.. _sec_attention-scoring-functions:

Attention Scoring Functions
===========================

In :numref:`sec_attention-pooling`, we used a number of different
distance-based kernels, including a Gaussian kernel to model interactions
between queries and keys. As it turns out, distance functions are slightly
more expensive to compute than dot products. As such, with the softmax
operation to ensure nonnegative attention weights, much of the work has gone
into *attention scoring functions* :math:`a` in :eq:`eq_softmax_attention`
and :numref:`fig_attention_output` that are simpler to compute.

.. _fig_attention_output:

.. figure:: ../img/attention-output.svg

   Computing the output of attention pooling as a weighted average of values,
   where weights are computed with the attention scoring function
   :math:`\mathit{a}` and the softmax operation.
.. code:: python

    import math
    import torch
    from torch import nn
    from d2l import torch as d2l
.. code:: python

    import math
    from mxnet import np, npx
    from mxnet.gluon import nn
    from d2l import mxnet as d2l

    npx.set_np()
.. code:: python

    import math
    import jax
    from flax import linen as nn
    from jax import numpy as jnp
    from d2l import jax as d2l

.. parsed-literal::
    :class: output

    No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
.. code:: python

    import tensorflow as tf
    from d2l import tensorflow as d2l
Dot Product Attention
---------------------

Let’s review the attention function (without exponentiation) from the
Gaussian kernel for a moment:

.. math:: a(\mathbf{q}, \mathbf{k}_i) = -\frac{1}{2} \|\mathbf{q} - \mathbf{k}_i\|^2  = \mathbf{q}^\top \mathbf{k}_i -\frac{1}{2} \|\mathbf{k}_i\|^2  -\frac{1}{2} \|\mathbf{q}\|^2.

First, note that the final term depends on :math:`\mathbf{q}` only. As such
it is identical for all :math:`(\mathbf{q}, \mathbf{k}_i)` pairs. Normalizing
the attention weights so that they sum to :math:`1`, as is done in
:eq:`eq_softmax_attention`, ensures that this term disappears entirely.
Second, note that both batch and layer normalization (to be discussed later)
lead to activations that have well-bounded, and often constant, norms
:math:`\|\mathbf{k}_i\|`. This is the case, for instance, whenever the keys
:math:`\mathbf{k}_i` were generated by a layer norm. As such, we can drop it
from the definition of :math:`a` without any major change in the outcome.

Last, we need to keep the order of magnitude of the arguments in the
exponential function under control. Assume that all the elements of the
query :math:`\mathbf{q} \in \mathbb{R}^d` and the key
:math:`\mathbf{k}_i \in \mathbb{R}^d` are independent and identically
distributed random variables with zero mean and unit variance. The dot
product between both vectors then has zero mean and a variance of :math:`d`.
To ensure that the variance of the dot product remains :math:`1` regardless
of vector length, we use the *scaled dot product attention* scoring function.
That is, we rescale the dot product by :math:`1/\sqrt{d}`. We thus arrive at
the first commonly used attention scoring function, used, for instance, in
Transformers :cite:`Vaswani.Shazeer.Parmar.ea.2017`:

.. math:: a(\mathbf{q}, \mathbf{k}_i) = \mathbf{q}^\top \mathbf{k}_i  / \sqrt{d}.
   :label: eq_dot_product_attention

Note that attention weights :math:`\alpha` still need normalizing. We can
simplify this further via :eq:`eq_softmax_attention` by using the softmax
operation:

.. math:: \alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(\mathbf{q}^\top \mathbf{k}_i / \sqrt{d})}{\sum_{j} \exp(\mathbf{q}^\top \mathbf{k}_j / \sqrt{d})}.
   :label: eq_attn-scoring-alpha

As it turns out, all popular attention mechanisms use the softmax, hence we
will limit ourselves to that in the remainder of this chapter.

Convenience Functions
---------------------

We need a few functions to make the attention mechanism efficient to deploy.
This includes tools for dealing with sequences of variable length (common in
natural language processing) and tools for efficient evaluation on
minibatches (batch matrix multiplication).

Masked Softmax Operation
~~~~~~~~~~~~~~~~~~~~~~~~

One of the most popular applications of the attention mechanism is to
sequence models. Hence we need to be able to deal with sequences of different
lengths. In some cases, such sequences may end up in the same minibatch,
necessitating padding with dummy tokens for shorter sequences (see
:numref:`sec_machine_translation` for an example). These special tokens do
not carry meaning. For instance, assume that we have the following three
sentences:

::

   Dive into Deep Learning
   Learn to code
   Hello world

Since we do not want blanks in our attention model, we simply need to limit
:math:`\sum_{i=1}^n \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i` to
:math:`\sum_{i=1}^l \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i` for
however long, :math:`l \leq n`, the actual sentence is.
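To see why this restriction can be folded into the softmax itself, write
:math:`\tilde{a}` for the masked scores (a notation used only in this aside):
set :math:`\tilde{a}(\mathbf{q}, \mathbf{k}_j) = a(\mathbf{q}, \mathbf{k}_j)`
for :math:`j \leq l` and :math:`\tilde{a}(\mathbf{q}, \mathbf{k}_j) = -\infty`
for :math:`j > l`. Then, since :math:`\exp(-\infty) = 0`,

.. math:: \mathrm{softmax}(\tilde{a}(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(\tilde{a}(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^n \exp(\tilde{a}(\mathbf{q}, \mathbf{k}_j))} = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^l \exp(a(\mathbf{q}, \mathbf{k}_j))} \quad \textrm{for } i \leq l,

and the weights for :math:`i > l` are exactly zero. Replacing
:math:`-\infty` with a sufficiently large negative constant gives the same
result up to numerical precision.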
Since this is such a common problem, it has a name: the *masked softmax
operation*. Let’s implement it. Actually, the implementation cheats ever so
slightly: rather than restricting the sum, it sets the attention scores of
the positions beyond :math:`l` to a large negative number, such as
:math:`-10^{6}`, so that after the softmax the corresponding attention
weights, and with them the contributions of the values :math:`\mathbf{v}_i`
for :math:`i > l` to the output and to gradients, vanish in practice. This is
done because linear algebra kernels and operators are heavily optimized for
GPUs and it is faster to be slightly wasteful in computation than to have
code with conditional (if-then-else) statements.
.. code:: python

    def masked_softmax(X, valid_lens):  #@save
        """Perform softmax operation by masking elements on the last axis."""
        # X: 3D tensor, valid_lens: 1D or 2D tensor
        def _sequence_mask(X, valid_len, value=0):
            maxlen = X.size(1)
            mask = torch.arange((maxlen), dtype=torch.float32,
                                device=X.device)[None, :] < valid_len[:, None]
            X[~mask] = value
            return X

        if valid_lens is None:
            return nn.functional.softmax(X, dim=-1)
        else:
            shape = X.shape
            if valid_lens.dim() == 1:
                valid_lens = torch.repeat_interleave(valid_lens, shape[1])
            else:
                valid_lens = valid_lens.reshape(-1)
            # On the last axis, replace masked elements with a very large
            # negative value, whose exponentiation outputs 0
            X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
            return nn.functional.softmax(X.reshape(shape), dim=-1)
.. code:: python

    def masked_softmax(X, valid_lens):  #@save
        """Perform softmax operation by masking elements on the last axis."""
        # X: 3D tensor, valid_lens: 1D or 2D tensor
        if valid_lens is None:
            return npx.softmax(X)
        else:
            shape = X.shape
            if valid_lens.ndim == 1:
                valid_lens = valid_lens.repeat(shape[1])
            else:
                valid_lens = valid_lens.reshape(-1)
            # On the last axis, replace masked elements with a very large
            # negative value, whose exponentiation outputs 0
            X = npx.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, True,
                                  value=-1e6, axis=1)
            return npx.softmax(X).reshape(shape)
.. code:: python

    def masked_softmax(X, valid_lens):  #@save
        """Perform softmax operation by masking elements on the last axis."""
        # X: 3D tensor, valid_lens: 1D or 2D tensor
        def _sequence_mask(X, valid_len, value=0):
            maxlen = X.shape[1]
            mask = jnp.arange((maxlen),
                              dtype=jnp.float32)[None, :] < valid_len[:, None]
            return jnp.where(mask, X, value)

        if valid_lens is None:
            return nn.softmax(X, axis=-1)
        else:
            shape = X.shape
            if valid_lens.ndim == 1:
                valid_lens = jnp.repeat(valid_lens, shape[1])
            else:
                valid_lens = valid_lens.reshape(-1)
            # On the last axis, replace masked elements with a very large
            # negative value, whose exponentiation outputs 0
            X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
            return nn.softmax(X.reshape(shape), axis=-1)
.. code:: python

    def masked_softmax(X, valid_lens):  #@save
        """Perform softmax operation by masking elements on the last axis."""
        # X: 3D tensor, valid_lens: 1D or 2D tensor
        def _sequence_mask(X, valid_len, value=0):
            maxlen = X.shape[1]
            mask = tf.range(start=0, limit=maxlen, dtype=tf.float32)[
                None, :] < tf.cast(valid_len[:, None], dtype=tf.float32)

            if len(X.shape) == 3:
                return tf.where(tf.expand_dims(mask, axis=-1), X, value)
            else:
                return tf.where(mask, X, value)

        if valid_lens is None:
            return tf.nn.softmax(X, axis=-1)
        else:
            shape = X.shape
            if len(valid_lens.shape) == 1:
                valid_lens = tf.repeat(valid_lens, repeats=shape[1])
            else:
                valid_lens = tf.reshape(valid_lens, shape=-1)
            # On the last axis, replace masked elements with a very large
            # negative value, whose exponentiation outputs 0
            X = _sequence_mask(tf.reshape(X, shape=(-1, shape[-1])), valid_lens,
                               value=-1e6)
            return tf.nn.softmax(tf.reshape(X, shape=shape), axis=-1)
To illustrate how this function works, consider a minibatch of two examples,
each of size :math:`2 \times 4`, where their valid lengths are :math:`2` and
:math:`3`, respectively. As a result of the masked softmax operation, values
beyond the valid length of each example are all masked as zero.
.. code:: python

    masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))

.. parsed-literal::
    :class: output

    tensor([[[0.4448, 0.5552, 0.0000, 0.0000],
             [0.4032, 0.5968, 0.0000, 0.0000]],

            [[0.2795, 0.2805, 0.4400, 0.0000],
             [0.2798, 0.3092, 0.4110, 0.0000]]])
.. code:: python

    masked_softmax(np.random.uniform(size=(2, 2, 4)), np.array([2, 3]))

.. parsed-literal::
    :class: output

    [22:05:24] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU

.. parsed-literal::
    :class: output

    array([[[0.488994  , 0.511006  , 0.        , 0.        ],
            [0.43654838, 0.56345165, 0.        , 0.        ]],

           [[0.28817102, 0.3519408 , 0.3598882 , 0.        ],
            [0.29034293, 0.25239873, 0.45725834, 0.        ]]])
.. code:: python

    masked_softmax(jax.random.uniform(d2l.get_key(), (2, 2, 4)),
                   jnp.array([2, 3]))

.. parsed-literal::
    :class: output

    Array([[[0.2914798 , 0.7085202 , 0.        , 0.        ],
            [0.5130609 , 0.48693904, 0.        , 0.        ]],

           [[0.17453432, 0.4599773 , 0.36548832, 0.        ],
            [0.3574293 , 0.3150612 , 0.32750952, 0.        ]]], dtype=float32)
.. code:: python

    masked_softmax(tf.random.uniform(shape=(2, 2, 4)), tf.constant([2, 3]))
If we need more fine-grained control, so as to specify a valid length for
each row of every example, we simply use a two-dimensional tensor of valid
lengths. This yields:
.. code:: python

    masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))

.. parsed-literal::
    :class: output

    tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
             [0.4109, 0.2794, 0.3097, 0.0000]],

            [[0.3960, 0.6040, 0.0000, 0.0000],
             [0.2557, 0.1833, 0.2420, 0.3190]]])
.. code:: python

    masked_softmax(np.random.uniform(size=(2, 2, 4)),
                   np.array([[1, 3], [2, 4]]))

.. parsed-literal::
    :class: output

    array([[[1.        , 0.        , 0.        , 0.        ],
            [0.35848376, 0.36588794, 0.2756283 , 0.        ]],

           [[0.54370314, 0.45629686, 0.        , 0.        ],
            [0.19598779, 0.25580424, 0.19916737, 0.34904057]]])
.. code:: python

    masked_softmax(jax.random.uniform(d2l.get_key(), (2, 2, 4)),
                   jnp.array([[1, 3], [2, 4]]))

.. parsed-literal::
    :class: output

    Array([[[1.        , 0.        , 0.        , 0.        ],
            [0.31556115, 0.28214547, 0.40229338, 0.        ]],

           [[0.5613054 , 0.43869466, 0.        , 0.        ],
            [0.29578257, 0.20095006, 0.2151548 , 0.28811258]]], dtype=float32)
.. code:: python

    masked_softmax(tf.random.uniform((2, 2, 4)), tf.constant([[1, 3], [2, 4]]))
.. _subsec_batch_dot:

Batch Matrix Multiplication
~~~~~~~~~~~~~~~~~~~~~~~~~~~

Another commonly used operation is to multiply batches of matrices by one
another. This comes in handy when we have minibatches of queries, keys, and
values. More specifically, assume that

.. math::

   \mathbf{Q} = [\mathbf{Q}_1, \mathbf{Q}_2, \ldots, \mathbf{Q}_n]  \in \mathbb{R}^{n \times a \times b}, \\
   \mathbf{K} = [\mathbf{K}_1, \mathbf{K}_2, \ldots, \mathbf{K}_n]  \in \mathbb{R}^{n \times b \times c}.

Then the batch matrix multiplication (BMM) computes the batch of matrix
products

.. math:: \textrm{BMM}(\mathbf{Q}, \mathbf{K}) = [\mathbf{Q}_1 \mathbf{K}_1, \mathbf{Q}_2 \mathbf{K}_2, \ldots, \mathbf{Q}_n \mathbf{K}_n] \in \mathbb{R}^{n \times a \times c}.
   :label: eq_batch-matrix-mul

Let’s see this in action in a deep learning framework.
.. code:: python

    Q = torch.ones((2, 3, 4))
    K = torch.ones((2, 4, 6))
    d2l.check_shape(torch.bmm(Q, K), (2, 3, 6))
.. code:: python

    Q = np.ones((2, 3, 4))
    K = np.ones((2, 4, 6))
    d2l.check_shape(npx.batch_dot(Q, K), (2, 3, 6))
.. code:: python

    Q = jnp.ones((2, 3, 4))
    K = jnp.ones((2, 4, 6))
    d2l.check_shape(jax.lax.batch_matmul(Q, K), (2, 3, 6))
.. code:: python

    Q = tf.ones((2, 3, 4))
    K = tf.ones((2, 4, 6))
    d2l.check_shape(tf.matmul(Q, K).numpy(), (2, 3, 6))
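As a quick sanity check on :eq:`eq_batch-matrix-mul`, each slice of the
batched product should coincide with an ordinary matrix product. The
following is a minimal sketch, using PyTorch and freshly drawn random inputs
rather than the all-ones tensors above:

.. code:: python

    # Sanity check: the i-th slice of BMM(Q, K) is the matrix product Q[i] @ K[i]
    Q = torch.randn(2, 3, 4)
    K = torch.randn(2, 4, 6)
    out = torch.bmm(Q, K)  # shape (2, 3, 6)
    for i in range(Q.shape[0]):
        assert torch.allclose(out[i], Q[i] @ K[i], atol=1e-5)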
Scaled Dot Product Attention
----------------------------

Let’s return to the dot product attention introduced in
:eq:`eq_dot_product_attention`. In general, it requires that both the query
and the key have the same vector length, say :math:`d`, even though this can
be addressed easily by replacing :math:`\mathbf{q}^\top \mathbf{k}` with
:math:`\mathbf{q}^\top \mathbf{M} \mathbf{k}`, where :math:`\mathbf{M}` is a
matrix suitably chosen for translating between both spaces. For now assume
that the dimensions match.

In practice, we often think of minibatches for efficiency, such as computing
attention for :math:`n` queries and :math:`m` key-value pairs, where queries
and keys are of length :math:`d` and values are of length :math:`v`. The
scaled dot product attention of queries
:math:`\mathbf Q\in\mathbb R^{n\times d}`, keys
:math:`\mathbf K\in\mathbb R^{m\times d}`, and values
:math:`\mathbf V\in\mathbb R^{m\times v}` thus can be written as

.. math:: \mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \in \mathbb{R}^{n\times v}.
   :label: eq_softmax_QK_V

Note that when applying this to a minibatch, we need the batch matrix
multiplication introduced in :eq:`eq_batch-matrix-mul`. In the following
implementation of the scaled dot product attention, we use dropout for model
regularization.
.. code:: python

    class DotProductAttention(nn.Module):  #@save
        """Scaled dot product attention."""
        def __init__(self, dropout):
            super().__init__()
            self.dropout = nn.Dropout(dropout)

        # Shape of queries: (batch_size, no. of queries, d)
        # Shape of keys: (batch_size, no. of key-value pairs, d)
        # Shape of values: (batch_size, no. of key-value pairs, value dimension)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        def forward(self, queries, keys, values, valid_lens=None):
            d = queries.shape[-1]
            # Swap the last two dimensions of keys with keys.transpose(1, 2)
            scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
            self.attention_weights = masked_softmax(scores, valid_lens)
            return torch.bmm(self.dropout(self.attention_weights), values)
.. code:: python

    class DotProductAttention(nn.Block):  #@save
        """Scaled dot product attention."""
        def __init__(self, dropout):
            super().__init__()
            self.dropout = nn.Dropout(dropout)

        # Shape of queries: (batch_size, no. of queries, d)
        # Shape of keys: (batch_size, no. of key-value pairs, d)
        # Shape of values: (batch_size, no. of key-value pairs, value dimension)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        def forward(self, queries, keys, values, valid_lens=None):
            d = queries.shape[-1]
            # Set transpose_b=True to swap the last two dimensions of keys
            scores = npx.batch_dot(queries, keys, transpose_b=True) / math.sqrt(d)
            self.attention_weights = masked_softmax(scores, valid_lens)
            return npx.batch_dot(self.dropout(self.attention_weights), values)
.. code:: python

    class DotProductAttention(nn.Module):  #@save
        """Scaled dot product attention."""
        dropout: float

        # Shape of queries: (batch_size, no. of queries, d)
        # Shape of keys: (batch_size, no. of key-value pairs, d)
        # Shape of values: (batch_size, no. of key-value pairs, value dimension)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        @nn.compact
        def __call__(self, queries, keys, values, valid_lens=None,
                     training=False):
            d = queries.shape[-1]
            # Swap the last two dimensions of keys with keys.swapaxes(1, 2)
            scores = queries@(keys.swapaxes(1, 2)) / math.sqrt(d)
            attention_weights = masked_softmax(scores, valid_lens)
            dropout_layer = nn.Dropout(self.dropout, deterministic=not training)
            return dropout_layer(attention_weights)@values, attention_weights
.. code:: python

    class DotProductAttention(tf.keras.layers.Layer):  #@save
        """Scaled dot product attention."""
        def __init__(self, dropout):
            super().__init__()
            self.dropout = tf.keras.layers.Dropout(dropout)

        # Shape of queries: (batch_size, no. of queries, d)
        # Shape of keys: (batch_size, no. of key-value pairs, d)
        # Shape of values: (batch_size, no. of key-value pairs, value dimension)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        def call(self, queries, keys, values, valid_lens=None, **kwargs):
            d = queries.shape[-1]
            scores = tf.matmul(queries, keys, transpose_b=True)/tf.math.sqrt(
                tf.cast(d, dtype=tf.float32))
            self.attention_weights = masked_softmax(scores, valid_lens)
            return tf.matmul(self.dropout(self.attention_weights, **kwargs),
                             values)
To illustrate how the ``DotProductAttention`` class works, we set up a small
toy example (whose keys, values, and valid lengths we will reuse later for
additive attention). We assume that we have a minibatch size of :math:`2`, a
total of :math:`10` keys and values, and that the dimensionality of the
values is :math:`4`. Lastly, we assume that the valid lengths per observation
are :math:`2` and :math:`6`, respectively. Given that, we expect the output
to be a :math:`2 \times 1 \times 4` tensor, i.e., one row per example of the
minibatch.
.. code:: python

    queries = torch.normal(0, 1, (2, 1, 2))
    keys = torch.normal(0, 1, (2, 10, 2))
    values = torch.normal(0, 1, (2, 10, 4))
    valid_lens = torch.tensor([2, 6])

    attention = DotProductAttention(dropout=0.5)
    attention.eval()
    d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
.. code:: python

    queries = np.random.normal(0, 1, (2, 1, 2))
    keys = np.random.normal(0, 1, (2, 10, 2))
    values = np.random.normal(0, 1, (2, 10, 4))
    valid_lens = np.array([2, 6])

    attention = DotProductAttention(dropout=0.5)
    attention.initialize()
    d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
.. code:: python

    queries = jax.random.normal(d2l.get_key(), (2, 1, 2))
    keys = jax.random.normal(d2l.get_key(), (2, 10, 2))
    values = jax.random.normal(d2l.get_key(), (2, 10, 4))
    valid_lens = jnp.array([2, 6])

    attention = DotProductAttention(dropout=0.5)
    (output, attention_weights), params = attention.init_with_output(
        d2l.get_key(), queries, keys, values, valid_lens)
    print(output)

.. parsed-literal::
    :class: output

    [[[ 0.75924027 -0.4776329   0.19306126  0.15036084]]

     [[-0.07728005  1.1064801  -0.839485   -0.36051023]]]
.. code:: python

    queries = tf.random.normal(shape=(2, 1, 2))
    keys = tf.random.normal(shape=(2, 10, 2))
    values = tf.random.normal(shape=(2, 10, 4))
    valid_lens = tf.constant([2, 6])

    attention = DotProductAttention(dropout=0.5)
    d2l.check_shape(attention(queries, keys, values, valid_lens, training=False),
                    (2, 1, 4))
Let’s check whether the attention weights actually vanish for anything beyond
the second and sixth columns, respectively (because we set the valid lengths
to :math:`2` and :math:`6`).
.. code:: python

    d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                      xlabel='Keys', ylabel='Queries')

.. figure:: output_attention-scoring-functions_722781_108_0.svg
.. code:: python

    d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                      xlabel='Keys', ylabel='Queries')

.. figure:: output_attention-scoring-functions_722781_111_0.svg
.. code:: python

    d2l.show_heatmaps(attention_weights.reshape((1, 1, 2, 10)),
                      xlabel='Keys', ylabel='Queries')

.. figure:: output_attention-scoring-functions_722781_114_0.svg
.. code:: python

    d2l.show_heatmaps(tf.reshape(attention.attention_weights, (1, 1, 2, 10)),
                      xlabel='Keys', ylabel='Queries')

.. figure:: output_attention-scoring-functions_722781_117_0.svg
.. _subsec_additive-attention:

Additive Attention
------------------

When queries :math:`\mathbf{q}` and keys :math:`\mathbf{k}` are vectors of
different dimensionalities, we can either use a matrix to address the
mismatch via :math:`\mathbf{q}^\top \mathbf{M} \mathbf{k}`, or we can use
additive attention as the scoring function. One benefit of the latter is
that, as its name indicates, the attention is additive. This can lead to
some minor computational savings. Given a query
:math:`\mathbf{q} \in \mathbb{R}^q` and a key
:math:`\mathbf{k} \in \mathbb{R}^k`, the *additive attention* scoring
function :cite:`Bahdanau.Cho.Bengio.2014` is given by

.. math:: a(\mathbf q, \mathbf k) = \mathbf w_v^\top \textrm{tanh}(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R},
   :label: eq_additive-attn

where :math:`\mathbf W_q\in\mathbb R^{h\times q}`,
:math:`\mathbf W_k\in\mathbb R^{h\times k}`, and
:math:`\mathbf w_v\in\mathbb R^{h}` are the learnable parameters. This term
is then fed into a softmax to ensure both nonnegativity and normalization. An
equivalent interpretation of :eq:`eq_additive-attn` is that the query and key
are concatenated and fed into an MLP with a single hidden layer. Using
:math:`\tanh` as the activation function and disabling bias terms, we
implement additive attention below.
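To make the MLP interpretation explicit, write :math:`[\mathbf q; \mathbf k]`
for the concatenation of query and key and
:math:`\mathbf W = [\mathbf W_q \; \mathbf W_k] \in \mathbb R^{h \times (q + k)}`
for the corresponding block matrix (both introduced only for this aside).
Then

.. math:: \mathbf w_v^\top \tanh\left(\mathbf W [\mathbf q; \mathbf k]\right) = \mathbf w_v^\top \tanh(\mathbf W_q \mathbf q + \mathbf W_k \mathbf k) = a(\mathbf q, \mathbf k),

which is exactly a hidden layer with :math:`h` units, no bias, and a
:math:`\tanh` activation, followed by a linear output layer with weights
:math:`\mathbf w_v`.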
.. code:: python

    class AdditiveAttention(nn.Module):  #@save
        """Additive attention."""
        def __init__(self, num_hiddens, dropout, **kwargs):
            super(AdditiveAttention, self).__init__(**kwargs)
            self.W_k = nn.LazyLinear(num_hiddens, bias=False)
            self.W_q = nn.LazyLinear(num_hiddens, bias=False)
            self.w_v = nn.LazyLinear(1, bias=False)
            self.dropout = nn.Dropout(dropout)

        def forward(self, queries, keys, values, valid_lens):
            queries, keys = self.W_q(queries), self.W_k(keys)
            # After dimension expansion, shape of queries: (batch_size, no. of
            # queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
            # key-value pairs, num_hiddens). Sum them up with broadcasting
            features = queries.unsqueeze(2) + keys.unsqueeze(1)
            features = torch.tanh(features)
            # There is only one output of self.w_v, so we remove the last
            # one-dimensional entry from the shape. Shape of scores: (batch_size,
            # no. of queries, no. of key-value pairs)
            scores = self.w_v(features).squeeze(-1)
            self.attention_weights = masked_softmax(scores, valid_lens)
            # Shape of values: (batch_size, no. of key-value pairs, value
            # dimension)
            return torch.bmm(self.dropout(self.attention_weights), values)
.. code:: python

    class AdditiveAttention(nn.Block):  #@save
        """Additive attention."""
        def __init__(self, num_hiddens, dropout, **kwargs):
            super(AdditiveAttention, self).__init__(**kwargs)
            # Use flatten=False to only transform the last axis so that the
            # shapes for the other axes are kept the same
            self.W_k = nn.Dense(num_hiddens, use_bias=False, flatten=False)
            self.W_q = nn.Dense(num_hiddens, use_bias=False, flatten=False)
            self.w_v = nn.Dense(1, use_bias=False, flatten=False)
            self.dropout = nn.Dropout(dropout)

        def forward(self, queries, keys, values, valid_lens):
            queries, keys = self.W_q(queries), self.W_k(keys)
            # After dimension expansion, shape of queries: (batch_size, no. of
            # queries, 1, num_hiddens) and shape of keys: (batch_size, 1,
            # no. of key-value pairs, num_hiddens). Sum them up with
            # broadcasting
            features = np.expand_dims(queries, axis=2) + np.expand_dims(
                keys, axis=1)
            features = np.tanh(features)
            # There is only one output of self.w_v, so we remove the last
            # one-dimensional entry from the shape. Shape of scores:
            # (batch_size, no. of queries, no. of key-value pairs)
            scores = np.squeeze(self.w_v(features), axis=-1)
            self.attention_weights = masked_softmax(scores, valid_lens)
            # Shape of values: (batch_size, no. of key-value pairs, value
            # dimension)
            return npx.batch_dot(self.dropout(self.attention_weights), values)
.. code:: python

    class AdditiveAttention(nn.Module):  #@save
        num_hiddens: int
        dropout: float

        def setup(self):
            self.W_k = nn.Dense(self.num_hiddens, use_bias=False)
            self.W_q = nn.Dense(self.num_hiddens, use_bias=False)
            self.w_v = nn.Dense(1, use_bias=False)

        @nn.compact
        def __call__(self, queries, keys, values, valid_lens, training=False):
            queries, keys = self.W_q(queries), self.W_k(keys)
            # After dimension expansion, shape of queries: (batch_size, no. of
            # queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
            # key-value pairs, num_hiddens). Sum them up with broadcasting
            features = jnp.expand_dims(queries, axis=2) + jnp.expand_dims(
                keys, axis=1)
            features = nn.tanh(features)
            # There is only one output of self.w_v, so we remove the last
            # one-dimensional entry from the shape. Shape of scores: (batch_size,
            # no. of queries, no. of key-value pairs)
            scores = self.w_v(features).squeeze(-1)
            attention_weights = masked_softmax(scores, valid_lens)
            dropout_layer = nn.Dropout(self.dropout, deterministic=not training)
            # Shape of values: (batch_size, no. of key-value pairs, value
            # dimension)
            return dropout_layer(attention_weights)@values, attention_weights
.. code:: python

    class AdditiveAttention(tf.keras.layers.Layer):  #@save
        """Additive attention."""
        def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
            super().__init__(**kwargs)
            self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=False)
            self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=False)
            self.w_v = tf.keras.layers.Dense(1, use_bias=False)
            self.dropout = tf.keras.layers.Dropout(dropout)

        def call(self, queries, keys, values, valid_lens, **kwargs):
            queries, keys = self.W_q(queries), self.W_k(keys)
            # After dimension expansion, shape of queries: (batch_size, no. of
            # queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
            # key-value pairs, num_hiddens). Sum them up with broadcasting
            features = tf.expand_dims(queries, axis=2) + tf.expand_dims(
                keys, axis=1)
            features = tf.nn.tanh(features)
            # There is only one output of self.w_v, so we remove the last
            # one-dimensional entry from the shape. Shape of scores: (batch_size,
            # no. of queries, no. of key-value pairs)
            scores = tf.squeeze(self.w_v(features), axis=-1)
            self.attention_weights = masked_softmax(scores, valid_lens)
            # Shape of values: (batch_size, no. of key-value pairs, value
            # dimension)
            return tf.matmul(self.dropout(
                self.attention_weights, **kwargs), values)
Let’s see how ``AdditiveAttention`` works. In our toy example we pick queries, keys and values of size :math:`(2, 1, 20)`, :math:`(2, 10, 2)` and :math:`(2, 10, 4)`, respectively. This is identical to our choice for ``DotProductAttention``, except that now the queries are :math:`20`-dimensional. Likewise, we pick :math:`(2, 6)` as the valid lengths for the sequences in the minibatch. .. raw:: html
.. raw:: html
.. code:: python

    queries = torch.normal(0, 1, (2, 1, 20))

    attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
    attention.eval()
    d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
.. code:: python

    queries = np.random.normal(0, 1, (2, 1, 20))

    attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
    attention.initialize()
    d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
.. code:: python

    queries = jax.random.normal(d2l.get_key(), (2, 1, 20))

    attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
    (output, attention_weights), params = attention.init_with_output(
        d2l.get_key(), queries, keys, values, valid_lens)
    print(output)

.. parsed-literal::
    :class: output

    [[[ 0.8057054  -0.45312855  0.233752    0.32691044]]

     [[-0.23993565  0.23599407  0.04756263  0.13463953]]]
.. code:: python

    queries = tf.random.normal(shape=(2, 1, 20))

    attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,
                                  dropout=0.1)
    d2l.check_shape(attention(queries, keys, values, valid_lens, training=False),
                    (2, 1, 4))
When reviewing the attention function we see a behavior that is qualitatively
quite similar to that of ``DotProductAttention``. That is, only terms within
the chosen valid lengths :math:`(2, 6)` are nonzero.
.. code:: python

    d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                      xlabel='Keys', ylabel='Queries')

.. figure:: output_attention-scoring-functions_722781_153_0.svg
.. code:: python

    d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                      xlabel='Keys', ylabel='Queries')

.. figure:: output_attention-scoring-functions_722781_156_0.svg
.. code:: python

    d2l.show_heatmaps(attention_weights.reshape((1, 1, 2, 10)),
                      xlabel='Keys', ylabel='Queries')

.. figure:: output_attention-scoring-functions_722781_159_0.svg
.. code:: python

    d2l.show_heatmaps(tf.reshape(attention.attention_weights, (1, 1, 2, 10)),
                      xlabel='Keys', ylabel='Queries')

.. figure:: output_attention-scoring-functions_722781_162_0.svg
Summary
-------

In this section we introduced the two key attention scoring functions: dot
product attention and additive attention. They are effective tools for
aggregating across sequences of variable length. In particular, dot product
attention is the mainstay of modern Transformer architectures. When queries
and keys are vectors of different dimensionalities, we can use the additive
attention scoring function instead. Optimizing these layers has been one of
the key areas of advance in recent years. For instance, NVIDIA’s Transformer
Library and Megatron :cite:`shoeybi2019megatron` crucially rely on efficient
variants of the attention mechanism. We will dive into this in quite a bit
more detail as we review Transformers in later sections.

Exercises
---------

1. Implement distance-based attention by modifying the
   ``DotProductAttention`` code. Note that you only need the squared norms of
   the keys :math:`\|\mathbf{k}_i\|^2` for an efficient implementation.
2. Modify the dot product attention to allow for queries and keys of
   different dimensionalities by employing a matrix to adjust dimensions.
3. How does the computational cost scale with the dimensionality of the keys,
   queries, values, and their number? What about the memory bandwidth
   requirements?