.. _sec_queries-keys-values:

Queries, Keys, and Values
=========================

So far all the networks we have reviewed crucially relied on the input
being of a well-defined size. For instance, the images in ImageNet are
of size :math:`224 \times 224` pixels and CNNs are specifically tuned
to this size. Even in natural language processing the input size for
RNNs is well defined and fixed. Variable size is addressed by
sequentially processing one token at a time, or by specially designed
convolution kernels :cite:`Kalchbrenner.Grefenstette.Blunsom.2014`.
This approach can lead to significant problems when the input is truly
of varying size with varying information content, such as in
:numref:`sec_seq2seq` in the transformation of text
:cite:`Sutskever.Vinyals.Le.2014`. In particular, for long sequences it
becomes quite difficult to keep track of everything that has already
been generated or even viewed by the network. Even explicit tracking
heuristics such as proposed by :cite:t:`yang2016neural` only offer
limited benefit.

Compare this to databases. In their simplest form they are collections
of keys (:math:`k`) and values (:math:`v`). For instance, our database
:math:`\mathcal{D}` might consist of tuples {(“Zhang”, “Aston”),
(“Lipton”, “Zachary”), (“Li”, “Mu”), (“Smola”, “Alex”), (“Hu”,
“Rachel”), (“Werness”, “Brent”)} with the last name being the key and
the first name being the value. We can operate on :math:`\mathcal{D}`,
for instance with the exact query (:math:`q`) for “Li”, which would
return the value “Mu”. If (“Li”, “Mu”) were not a record in
:math:`\mathcal{D}`, there would be no valid answer. If we also allowed
for approximate matches, we would retrieve (“Lipton”, “Zachary”)
instead. This quite simple example nonetheless teaches us a number of
useful things:

- We can design queries :math:`q` that operate on (:math:`k`,\ :math:`v`)
  pairs in such a manner as to be valid regardless of the database size.
- The same query can receive different answers, according to the
  contents of the database.
- The “code” being executed for operating on a large state space (the
  database) can be quite simple (e.g., exact match, approximate match,
  top-:math:`k`).
- There is no need to compress or simplify the database to make the
  operations effective.

Clearly we would not have introduced a simple database here if it were
not for the purpose of explaining deep learning. Indeed, this leads to
one of the most exciting concepts introduced in deep learning in the
past decade: the *attention mechanism*
:cite:`Bahdanau.Cho.Bengio.2014`. We will cover the specifics of its
application to machine translation later. For now, simply consider the
following: denote by
:math:`\mathcal{D} \stackrel{\textrm{def}}{=} \{(\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)\}`
a database of :math:`m` tuples of *keys* and *values*. Moreover, denote
by :math:`\mathbf{q}` a *query*. Then we can define the *attention*
over :math:`\mathcal{D}` as

.. math:: \textrm{Attention}(\mathbf{q}, \mathcal{D}) \stackrel{\textrm{def}}{=} \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i,
   :label: eq_attention_pooling

where :math:`\alpha(\mathbf{q}, \mathbf{k}_i) \in \mathbb{R}`
(:math:`i = 1, \ldots, m`) are scalar attention weights. The operation
itself is typically referred to as *attention pooling*. The name
*attention* derives from the fact that the operation pays particular
attention to the terms for which the weight :math:`\alpha` is
significant (i.e., large).
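As a concrete, purely illustrative sketch of :eq:`eq_attention_pooling`,
the snippet below pools a tiny set of hypothetical two-dimensional
values with hand-picked weights (the score function that would produce
such weights is introduced shortly); note how a one-hot weight vector
recovers the exact database lookup discussed above.

.. code:: python

    import torch

    # Three hypothetical values v_1, v_2, v_3. The keys are implicit here:
    # we hand-pick the weights alpha instead of computing them from keys.
    values = torch.tensor([[1.0, 0.0],
                           [0.0, 1.0],
                           [0.5, 0.5]])

    def attention_pool(alpha, values):
        """Attention pooling: the weighted sum sum_i alpha_i * v_i."""
        return alpha @ values  # (m,) @ (m, d) -> (d,)

    # One-hot weights act like an exact database lookup (second value).
    print(attention_pool(torch.tensor([0.0, 1.0, 0.0]), values))
    # Generic weights blend all values into a linear combination.
    print(attention_pool(torch.tensor([0.7, 0.2, 0.1]), values))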
As such, the attention over :math:`\mathcal{D}` generates a linear
combination of values contained in the database. In fact, this contains
the above example as a special case where all but one weight is zero.
We have a number of special cases:

- The weights :math:`\alpha(\mathbf{q}, \mathbf{k}_i)` are nonnegative.
  In this case the output of the attention mechanism is contained in
  the convex cone spanned by the values :math:`\mathbf{v}_i`.
- The weights :math:`\alpha(\mathbf{q}, \mathbf{k}_i)` form a convex
  combination, i.e., :math:`\sum_i \alpha(\mathbf{q}, \mathbf{k}_i) = 1`
  and :math:`\alpha(\mathbf{q}, \mathbf{k}_i) \geq 0` for all :math:`i`.
  This is the most common setting in deep learning.
- Exactly one of the weights :math:`\alpha(\mathbf{q}, \mathbf{k}_i)`
  is :math:`1`, while all others are :math:`0`. This is akin to a
  traditional database query.
- All weights are equal, i.e.,
  :math:`\alpha(\mathbf{q}, \mathbf{k}_i) = \frac{1}{m}` for all
  :math:`i`. This amounts to averaging across the entire database, also
  called average pooling in deep learning.

A common strategy for ensuring that the weights sum up to :math:`1` is
to normalize them via

.. math:: \alpha(\mathbf{q}, \mathbf{k}_i) = \frac{\alpha(\mathbf{q}, \mathbf{k}_i)}{\sum_j \alpha(\mathbf{q}, \mathbf{k}_j)}.

In particular, to ensure that the weights are also nonnegative, one can
resort to exponentiation. This means that we can now pick *any*
function :math:`a(\mathbf{q}, \mathbf{k})` and then apply the softmax
operation used for multinomial models to it via

.. math:: \alpha(\mathbf{q}, \mathbf{k}_i) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_j \exp(a(\mathbf{q}, \mathbf{k}_j))}.
   :label: eq_softmax_attention

This operation is readily available in all deep learning frameworks. It
is differentiable and its gradient never vanishes, both of which are
desirable properties in a model. Note, though, that the attention
mechanism introduced above is not the only option. For instance, we can
design a non-differentiable attention model that can be trained using
reinforcement learning methods :cite:`Mnih.Heess.Graves.ea.2014`. As
one would expect, training such a model is quite complex. Consequently
the bulk of modern attention research follows the framework outlined in
:numref:`fig_qkv`. We thus focus our exposition on this family of
differentiable mechanisms.

.. _fig_qkv:

.. figure:: ../img/qkv.svg

   The attention mechanism computes a linear combination over values
   :math:`\mathbf{v}_\mathit{i}` via attention pooling, where weights
   are derived according to the compatibility between a query
   :math:`\mathbf{q}` and keys :math:`\mathbf{k}_\mathit{i}`.

What is quite remarkable is that the actual “code” for executing on the
set of keys and values, namely the query, can be quite concise, even
though the space to operate on is significant. This is a desirable
property for a network layer as it does not require too many parameters
to learn. Just as convenient is the fact that attention can operate on
arbitrarily large databases without the need to change the way the
attention pooling operation is performed.
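To make the connection between :eq:`eq_softmax_attention` and
:eq:`eq_attention_pooling` explicit, here is a minimal PyTorch sketch.
It assumes a dot-product score
:math:`a(\mathbf{q}, \mathbf{k}) = \mathbf{q}^\top \mathbf{k}` chosen
purely for illustration; the design of score functions is discussed in
later sections.

.. code:: python

    import torch

    def softmax_attention(q, keys, values):
        """Differentiable attention pooling with a dot-product score.

        q: query of shape (d,); keys: (m, d); values: (m, p).
        """
        scores = keys @ q                     # a(q, k_i) = q^T k_i, shape (m,)
        alpha = torch.softmax(scores, dim=0)  # convex combination weights
        return alpha @ values                 # sum_i alpha_i v_i, shape (p,)

    q, keys, values = torch.randn(4), torch.randn(10, 4), torch.randn(10, 3)
    print(softmax_attention(q, keys, values).shape)  # torch.Size([3])

Observe that the number of (key, value) pairs :math:`m` can be changed
without modifying ``softmax_attention``, reflecting the point above
that attention operates on arbitrarily large databases with the same
pooling operation.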
The import cells below are for PyTorch, MXNet, JAX, and TensorFlow, in
that order.

.. code:: python

    import torch
    from d2l import torch as d2l

.. code:: python

    from mxnet import np, npx
    from d2l import mxnet as d2l

    npx.set_np()

.. code:: python

    from jax import numpy as jnp
    from d2l import jax as d2l

.. code:: python

    import tensorflow as tf
    from d2l import tensorflow as d2l
Visualization
-------------

One of the benefits of the attention mechanism is that it can be quite
intuitive, particularly when the weights are nonnegative and sum to
:math:`1`. In this case we might *interpret* large weights as a way for
the model to select components of relevance. While this is a good
intuition, it is important to remember that it is just that, an
*intuition*. Regardless, we may want to visualize its effect on the
given set of keys when applying a variety of different queries. We thus
define the ``show_heatmaps`` function. Note that it does not take a
matrix (of attention weights) as its input but rather a tensor with
four axes, allowing for an array of different queries and weights.
Consequently the input ``matrices`` has the shape (number of rows for
display, number of columns for display, number of queries, number of
keys). This will come in handy later on when we want to visualize the
inner workings of Transformers.
The implementations below are for PyTorch, MXNet, JAX, and TensorFlow,
in that order; they differ only in how each attention matrix is
converted into an array that ``imshow`` can plot.

.. code:: python

    #@save
    def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                      cmap='Reds'):
        """Show heatmaps of matrices."""
        d2l.use_svg_display()
        num_rows, num_cols, _, _ = matrices.shape
        fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                     sharex=True, sharey=True, squeeze=False)
        for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
            for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
                pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)
                if i == num_rows - 1:
                    ax.set_xlabel(xlabel)
                if j == 0:
                    ax.set_ylabel(ylabel)
                if titles:
                    ax.set_title(titles[j])
        fig.colorbar(pcm, ax=axes, shrink=0.6);

.. code:: python

    #@save
    def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                      cmap='Reds'):
        """Show heatmaps of matrices."""
        d2l.use_svg_display()
        num_rows, num_cols, _, _ = matrices.shape
        fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                     sharex=True, sharey=True, squeeze=False)
        for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
            for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
                pcm = ax.imshow(matrix.asnumpy(), cmap=cmap)
                if i == num_rows - 1:
                    ax.set_xlabel(xlabel)
                if j == 0:
                    ax.set_ylabel(ylabel)
                if titles:
                    ax.set_title(titles[j])
        fig.colorbar(pcm, ax=axes, shrink=0.6);

.. code:: python

    #@save
    def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                      cmap='Reds'):
        """Show heatmaps of matrices."""
        d2l.use_svg_display()
        num_rows, num_cols, _, _ = matrices.shape
        fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                     sharex=True, sharey=True, squeeze=False)
        for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
            for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
                pcm = ax.imshow(matrix, cmap=cmap)
                if i == num_rows - 1:
                    ax.set_xlabel(xlabel)
                if j == 0:
                    ax.set_ylabel(ylabel)
                if titles:
                    ax.set_title(titles[j])
        fig.colorbar(pcm, ax=axes, shrink=0.6);

.. code:: python

    #@save
    def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                      cmap='Reds'):
        """Show heatmaps of matrices."""
        d2l.use_svg_display()
        num_rows, num_cols, _, _ = matrices.shape
        fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                     sharex=True, sharey=True, squeeze=False)
        for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
            for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
                pcm = ax.imshow(matrix.numpy(), cmap=cmap)
                if i == num_rows - 1:
                    ax.set_xlabel(xlabel)
                if j == 0:
                    ax.set_ylabel(ylabel)
                if titles:
                    ax.set_title(titles[j])
        fig.colorbar(pcm, ax=axes, shrink=0.6);
As a quick sanity check, let’s visualize the identity matrix,
representing a case where the attention weight is :math:`1` only when
the query and the key are the same.
.. code:: python

    attention_weights = torch.eye(10).reshape((1, 1, 10, 10))
    show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')

.. figure:: output_queries-keys-values_7fe0e8_33_0.svg

.. code:: python

    attention_weights = np.eye(10).reshape((1, 1, 10, 10))
    show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')

.. figure:: output_queries-keys-values_7fe0e8_36_1.svg

.. code:: python

    attention_weights = jnp.eye(10).reshape((1, 1, 10, 10))
    show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')

.. figure:: output_queries-keys-values_7fe0e8_39_0.svg

.. code:: python

    attention_weights = tf.reshape(tf.eye(10), (1, 1, 10, 10))
    show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')

.. figure:: output_queries-keys-values_7fe0e8_42_0.svg
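Beyond the identity pattern, we can also visualize weights produced by
the softmax normalization in :eq:`eq_softmax_attention`. The sketch
below (PyTorch, using the ``show_heatmaps`` variant defined above, with
randomly generated queries and keys that are purely illustrative)
computes dot-product scores, applies softmax over the keys, and plots
the resulting weights.

.. code:: python

    import torch

    # Random queries and keys, chosen only to produce a nontrivial pattern.
    queries, keys = torch.randn(10, 4), torch.randn(10, 4)
    scores = queries @ keys.T                          # dot-product scores
    attention_weights = torch.softmax(scores, dim=-1)  # each row sums to 1
    show_heatmaps(attention_weights.reshape((1, 1, 10, 10)),
                  xlabel='Keys', ylabel='Queries')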
Summary
-------

The attention mechanism allows us to aggregate data from many (key,
value) pairs. So far our discussion was quite abstract, simply
describing a way to pool data. We have not explained yet where those
mysterious queries, keys, and values might arise from. Some intuition
might help here: for instance, in a regression setting, the query might
correspond to the location where the regression should be carried out.
The keys are the locations where past data was observed and the values
are the (regression) values themselves. This is the so-called
Nadaraya–Watson estimator :cite:`Nadaraya.1964,Watson.1964` that we
will be studying in the next section.

By design, the attention mechanism provides a *differentiable* means of
control by which a neural network can select elements from a set and
construct an associated weighted sum over representations.

Exercises
---------

1. Suppose that you wanted to reimplement approximate (key, query)
   matches as used in classical databases. Which attention function
   would you pick?
2. Suppose that the attention function is given by
   :math:`a(\mathbf{q}, \mathbf{k}_i) = \mathbf{q}^\top \mathbf{k}_i`
   and that :math:`\mathbf{k}_i = \mathbf{v}_i` for
   :math:`i = 1, \ldots, m`. Denote by
   :math:`p(\mathbf{k}_i; \mathbf{q})` the probability distribution
   over keys when using the softmax normalization in
   :eq:`eq_softmax_attention`. Prove that
   :math:`\nabla_{\mathbf{q}} \mathop{\textrm{Attention}}(\mathbf{q}, \mathcal{D}) = \textrm{Cov}_{p(\mathbf{k}_i; \mathbf{q})}[\mathbf{k}_i]`.
3. Design a differentiable search engine using the attention mechanism.
4. Review the design of the Squeeze and Excitation Networks
   :cite:`Hu.Shen.Sun.2018` and interpret them through the lens of the
   attention mechanism.