.. _sec_queries-keys-values:

Queries, Keys, and Values
=========================

So far all the networks we have reviewed crucially relied on the input
being of a well-defined size. For instance, the images in ImageNet are
typically resized to :math:`224 \times 224` pixels and CNNs are
specifically tuned to this size. Even in natural language processing the
input size for RNNs is well defined and fixed. Variable size is addressed
by sequentially processing one token at a time, or by specially designed
convolution kernels :cite:`Kalchbrenner.Grefenstette.Blunsom.2014`. This
approach can lead to significant problems when the input is truly of
varying size with varying information content, such as in the
transformation of text in :numref:`sec_seq2seq`
:cite:`Sutskever.Vinyals.Le.2014`. In particular, for long sequences it
becomes quite difficult to keep track of everything that has already
been generated or even viewed by the network. Even explicit tracking
heuristics such as those proposed by :cite:t:`yang2016neural` only offer
limited benefit.

Compare this to databases. In their simplest form they are collections
of keys (:math:`k`) and values (:math:`v`). For instance, our database
:math:`\mathcal{D}` might consist of tuples {(“Zhang”, “Aston”),
(“Lipton”, “Zachary”), (“Li”, “Mu”), (“Smola”, “Alex”), (“Hu”,
“Rachel”), (“Werness”, “Brent”)} with the last name being the key and
the first name being the value. We can operate on :math:`\mathcal{D}`,
for instance with the exact query (:math:`q`) for “Li”, which would
return the value “Mu”. If (“Li”, “Mu”) were not a record in
:math:`\mathcal{D}`, there would be no valid answer. If we also allowed
for approximate matches, we would retrieve (“Lipton”, “Zachary”)
instead. This simple example nonetheless teaches us a number of useful
things:

- We can design queries :math:`q` that operate on
  (:math:`k`,\ :math:`v`) pairs in such a manner as to be valid
  regardless of the database size.
- The same query can receive different answers, according to the
  contents of the database.
- The “code” being executed for operating on a large state space (the
  database) can be quite simple (e.g., exact match, approximate match,
  top-:math:`k`).
- There is no need to compress or simplify the database to make the
  operations effective.

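
To make the database analogy concrete, the following minimal sketch runs an exact-match query and a crude approximate-match query over the toy database. The helper names ``exact_query``, ``common_prefix_length``, and ``approximate_query`` are hypothetical, and using the longest common prefix as the similarity measure is an assumption made purely for illustration; real databases use more elaborate matching rules.

```python
# Toy database from the text: last name (key) -> first name (value).
database = {
    "Zhang": "Aston", "Lipton": "Zachary", "Li": "Mu",
    "Smola": "Alex", "Hu": "Rachel", "Werness": "Brent",
}


def exact_query(db, q):
    """Return the value stored under key q, or None if q is absent."""
    return db.get(q)


def common_prefix_length(a, b):
    """Count the leading characters that strings a and b share."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


def approximate_query(db, q):
    """Return the (key, value) pair whose key is closest to q.

    Closeness is the length of the shared prefix, a hypothetical
    similarity measure chosen only for this illustration.
    """
    if q in db:  # an exact match always wins
        return q, db[q]
    best_key = max(db, key=lambda k: common_prefix_length(q, k))
    return best_key, db[best_key]


print(exact_query(database, "Li"))

# Remove the exact record: the same query now has no exact answer,
# while the approximate query falls back to the closest remaining key.
smaller = {k: v for k, v in database.items() if k != "Li"}
print(exact_query(smaller, "Li"))
print(approximate_query(smaller, "Li"))
```

Note that neither function changes when records are added or removed, and that the same query returns different answers depending on the contents of the database.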
Clearly we would not have introduced a simple database here if it wasn’t
for the purpose of explaining deep learning. Indeed, this leads to one
of the most exciting concepts introduced in deep learning in the past
decade: the *attention mechanism* :cite:`Bahdanau.Cho.Bengio.2014`. We
will cover the specifics of its application to machine translation
later. For now, simply consider the following: denote by
:math:`\mathcal{D} \stackrel{\textrm{def}}{=} \{(\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)\}`
a database of :math:`m` tuples of *keys* and *values*. Moreover, denote
by :math:`\mathbf{q}` a *query*. Then we can define the *attention* over
:math:`\mathcal{D}` as

.. math:: \textrm{Attention}(\mathbf{q}, \mathcal{D}) \stackrel{\textrm{def}}{=} \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i,
   :label: eq_attention_pooling

where :math:`\alpha(\mathbf{q}, \mathbf{k}_i) \in \mathbb{R}`
(:math:`i = 1, \ldots, m`) are scalar attention weights. The operation
itself is typically referred to as *attention pooling*. The name
*attention* derives from the fact that the operation pays particular
attention to the terms for which the weight :math:`\alpha` is
significant (i.e., large). As such, the attention over
:math:`\mathcal{D}` generates a linear combination of values contained
in the database. In fact, this contains the above example as a special
case where all but one weight is zero. We have a number of special
cases:

- The weights :math:`\alpha(\mathbf{q}, \mathbf{k}_i)` are nonnegative.
  In this case the output of the attention mechanism is contained in
  the convex cone spanned by the values :math:`\mathbf{v}_i`.
- The weights :math:`\alpha(\mathbf{q}, \mathbf{k}_i)` form a convex
  combination, i.e.,
  :math:`\sum_i \alpha(\mathbf{q}, \mathbf{k}_i) = 1` and
  :math:`\alpha(\mathbf{q}, \mathbf{k}_i) \geq 0` for all :math:`i`.
  This is the most common setting in deep learning.
- Exactly one of the weights :math:`\alpha(\mathbf{q}, \mathbf{k}_i)`
  is :math:`1`, while all others are :math:`0`. This is akin to a
  traditional database query.
- All weights are equal, i.e.,
  :math:`\alpha(\mathbf{q}, \mathbf{k}_i) = \frac{1}{m}` for all
  :math:`i`. This amounts to averaging across the entire database, also
  called average pooling in deep learning.

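
These special cases can be checked numerically. The sketch below implements the weighted sum in :eq:`eq_attention_pooling` with NumPy and applies it with one-hot, uniform, and general convex weights. The function name ``attention_pooling`` and the value vectors are made up for illustration.

```python
import numpy as np


def attention_pooling(weights, values):
    """Compute sum_i weights[i] * values[i].

    weights: shape (m,), the attention weights alpha(q, k_i).
    values:  shape (m, d), one value vector v_i per row.
    """
    return weights @ values


# Three two-dimensional values (made-up numbers for illustration).
values = np.array([[1.0, 0.0],
                   [0.0, 1.0],
                   [2.0, 2.0]])
m = len(values)

one_hot = np.array([0.0, 1.0, 0.0])    # database-style lookup of record 2
uniform = np.full(m, 1.0 / m)          # average pooling
convex = np.array([0.5, 0.25, 0.25])   # nonnegative and sums to 1

lookup = attention_pooling(one_hot, values)
average = attention_pooling(uniform, values)
mixture = attention_pooling(convex, values)
print(lookup, average, mixture)
```

The one-hot weights return the second value unchanged, the uniform weights return the mean of all values, and the convex weights return a point inside the convex hull of the values.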
A common strategy for ensuring that the weights sum up to :math:`1` is
to normalize them via
.. math:: \alpha(\mathbf{q}, \mathbf{k}_i) = \frac{\alpha(\mathbf{q}, \mathbf{k}_i)}{{\sum_j} \alpha(\mathbf{q}, \mathbf{k}_j)}.
In particular, to ensure that the weights are also nonnegative, one can
resort to exponentiation. This means that we can now pick *any* function
:math:`a(\mathbf{q}, \mathbf{k})` and then apply the softmax operation
used for multinomial models to it via

.. math:: \alpha(\mathbf{q}, \mathbf{k}_i) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_j \exp(a(\mathbf{q}, \mathbf{k}_j))}.
   :label: eq_softmax_attention

This operation is readily available in all deep learning frameworks. It
is differentiable and its gradient never vanishes, both of which are
desirable properties in a model. Note, though, that the attention
mechanism introduced above is not the only option. For instance, we can
design a non-differentiable attention model that can be trained using
reinforcement learning methods :cite:`Mnih.Heess.Graves.ea.2014`. As
one would expect, training such a model is quite complex. Consequently,
the bulk of modern attention research follows the framework outlined in
:numref:`fig_qkv`. We thus focus our exposition on this family of
differentiable mechanisms.
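
As a minimal sketch of the softmax normalization in :eq:`eq_softmax_attention`, the NumPy code below turns arbitrary scores into attention weights and pools the values with them. The text leaves the scoring function :math:`a` open; the dot product used here is an assumption for illustration, as are the function names and the toy keys and values.

```python
import numpy as np


def softmax_attention_weights(q, keys, a):
    """Turn scores a(q, k_i) into weights that are positive and sum to 1."""
    scores = np.array([a(q, k) for k in keys])
    # Subtracting the maximum avoids overflow in exp and leaves the
    # normalized result unchanged.
    scores = scores - scores.max()
    exp_scores = np.exp(scores)
    return exp_scores / exp_scores.sum()


def dot_product(q, k):
    """Assumed scoring function for this sketch: a(q, k) = q^T k."""
    return float(q @ k)


keys = np.array([[1.0, 0.0],
                 [0.0, 1.0],
                 [1.0, 1.0]])
values = np.array([[10.0], [20.0], [30.0]])
q = np.array([2.0, 0.0])

alpha = softmax_attention_weights(q, keys, dot_product)
output = alpha @ values  # attention pooling over the values
print(alpha, output)
```

The first and third keys receive the same score and hence the same weight, while the second key, which is orthogonal to the query, receives a small but strictly positive weight.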

.. _fig_qkv:

.. figure:: ../img/qkv.svg

   The attention mechanism computes a linear combination over values
   :math:`\mathbf{v}_\mathit{i}` via attention pooling, where weights
   are derived according to the compatibility between a query
   :math:`\mathbf{q}` and keys :math:`\mathbf{k}_\mathit{i}`.

What is quite remarkable is that the actual “code” for executing on the
set of keys and values, namely the query, can be quite concise, even
though the space to operate on is significant. This is a desirable
property for a network layer as it does not require too many parameters
to learn. Just as convenient is the fact that attention can operate on
arbitrarily large databases without the need to change the way the
attention pooling operation is performed.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import torch
from d2l import torch as d2l
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from mxnet import np, npx
from d2l import mxnet as d2l
npx.set_np()
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from jax import numpy as jnp
from d2l import jax as d2l
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import tensorflow as tf
from d2l import tensorflow as d2l
Visualization
-------------
One of the benefits of the attention mechanism is that it can be quite
intuitive, particularly when the weights are nonnegative and sum to
:math:`1`. In this case we might *interpret* large weights as a way for
the model to select components of relevance. While this is a good
intuition, it is important to remember that it is just that, an
*intuition*. Regardless, we may want to visualize the effect of the
attention weights on a given set of keys when applying a variety of
different queries.

We thus define the ``show_heatmaps`` function. Note that it does not
take a matrix (of attention weights) as its input but rather a tensor
with four axes, allowing for an array of different queries and weights.
Consequently the input ``matrices`` has the shape (number of rows for
display, number of columns for display, number of queries, number of
keys). This will come in handy later on when we want to visualize the
inner workings of Transformers.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                      cmap='Reds'):
        """Show heatmaps of matrices."""
        d2l.use_svg_display()
        num_rows, num_cols, _, _ = matrices.shape
        fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                     sharex=True, sharey=True, squeeze=False)
        for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
            for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
                pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)
                if i == num_rows - 1:
                    ax.set_xlabel(xlabel)
                if j == 0:
                    ax.set_ylabel(ylabel)
                if titles:
                    ax.set_title(titles[j])
        fig.colorbar(pcm, ax=axes, shrink=0.6);

.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                      cmap='Reds'):
        """Show heatmaps of matrices."""
        d2l.use_svg_display()
        num_rows, num_cols, _, _ = matrices.shape
        fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                     sharex=True, sharey=True, squeeze=False)
        for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
            for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
                pcm = ax.imshow(matrix.asnumpy(), cmap=cmap)
                if i == num_rows - 1:
                    ax.set_xlabel(xlabel)
                if j == 0:
                    ax.set_ylabel(ylabel)
                if titles:
                    ax.set_title(titles[j])
        fig.colorbar(pcm, ax=axes, shrink=0.6);

.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                      cmap='Reds'):
        """Show heatmaps of matrices."""
        d2l.use_svg_display()
        num_rows, num_cols, _, _ = matrices.shape
        fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                     sharex=True, sharey=True, squeeze=False)
        for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
            for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
                pcm = ax.imshow(matrix, cmap=cmap)
                if i == num_rows - 1:
                    ax.set_xlabel(xlabel)
                if j == 0:
                    ax.set_ylabel(ylabel)
                if titles:
                    ax.set_title(titles[j])
        fig.colorbar(pcm, ax=axes, shrink=0.6);

.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                      cmap='Reds'):
        """Show heatmaps of matrices."""
        d2l.use_svg_display()
        num_rows, num_cols, _, _ = matrices.shape
        fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                     sharex=True, sharey=True, squeeze=False)
        for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
            for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
                pcm = ax.imshow(matrix.numpy(), cmap=cmap)
                if i == num_rows - 1:
                    ax.set_xlabel(xlabel)
                if j == 0:
                    ax.set_ylabel(ylabel)
                if titles:
                    ax.set_title(titles[j])
        fig.colorbar(pcm, ax=axes, shrink=0.6);

As a quick sanity check, let’s visualize the identity matrix,
representing a case where the attention weight is :math:`1` only when
the query and the key are the same.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
attention_weights = torch.eye(10).reshape((1, 1, 10, 10))
show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')
.. figure:: output_queries-keys-values_7fe0e8_33_0.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
attention_weights = np.eye(10).reshape((1, 1, 10, 10))
show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')
.. figure:: output_queries-keys-values_7fe0e8_36_1.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
attention_weights = jnp.eye(10).reshape((1, 1, 10, 10))
show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')
.. figure:: output_queries-keys-values_7fe0e8_39_0.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
attention_weights = tf.reshape(tf.eye(10), (1, 1, 10, 10))
show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')
.. figure:: output_queries-keys-values_7fe0e8_42_0.svg
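
The identity matrix is a degenerate case. As a further, hypothetical illustration of the four-axis layout that ``show_heatmaps`` expects, the sketch below builds softmax attention weights for scalar queries and keys under two score functions and stacks them into an array of shape (1, 2, number of queries, number of keys). It uses plain NumPy rather than any of the frameworks above, and the Gaussian-shaped scores and their widths are assumptions chosen for illustration.

```python
import numpy as np


def softmax(scores, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = scores - scores.max(axis=axis, keepdims=True)
    exp_scores = np.exp(shifted)
    return exp_scores / exp_scores.sum(axis=axis, keepdims=True)


queries = np.linspace(0.0, 1.0, 6)   # 6 scalar queries
keys = np.linspace(0.0, 1.0, 10)     # 10 scalar keys

# Pairwise squared distances between queries and keys, shape (6, 10).
sq_dist = (queries[:, None] - keys[None, :]) ** 2

# Two score functions that differ only in their width (assumed values).
narrow = softmax(-sq_dist / (2 * 0.05 ** 2))
wide = softmax(-sq_dist / (2 * 0.5 ** 2))

# One display row, two display columns, then (queries, keys).
attention_weights = np.stack([narrow, wide])[None, ...]
print(attention_weights.shape)
```

Passing such an array to ``show_heatmaps`` with ``titles=['narrow', 'wide']`` would draw two panels side by side. Depending on the framework variant of ``show_heatmaps``, the array may first need to be converted to that framework's tensor type.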
Summary
-------
The attention mechanism allows us to aggregate data from many (key,
value) pairs. So far our discussion has been quite abstract, simply
describing a way to pool data. We have not yet explained where those
mysterious queries, keys, and values might come from. Some intuition
might help here: for instance, in a regression setting, the query might
correspond to the location where the regression should be carried out.
The keys are the locations where past data was observed and the values
are the (regression) values themselves. This is the so-called
Nadaraya–Watson estimator :cite:`Nadaraya.1964,Watson.1964` that we
will be studying in the next section.
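
To make the regression intuition concrete, here is a minimal sketch of kernel regression written as attention pooling. It assumes a Gaussian kernel, that is, scores :math:`a(q, k) = -(q - k)^2 / (2\sigma^2)` followed by the softmax normalization. The toy data, the bandwidth ``sigma``, and the function name ``kernel_regression`` are made up for illustration.

```python
import numpy as np


def kernel_regression(query, keys, values, sigma=0.5):
    """Average the values, weighting each by how close its key is to the query."""
    scores = -((query - keys) ** 2) / (2 * sigma ** 2)
    scores = scores - scores.max()      # numerical stability
    unnormalized = np.exp(scores)
    weights = unnormalized / unnormalized.sum()
    return weights @ values, weights


# Keys: locations of past observations. Values: what was observed there.
keys = np.array([0.0, 1.0, 2.0, 3.0])
values = np.array([0.0, 1.0, 4.0, 9.0])   # y = x^2, noise-free toy data

# Query: the location at which we want a prediction.
prediction, weights = kernel_regression(1.5, keys, values)
print(prediction, weights)
```

The two keys nearest the query receive almost all of the weight, so the prediction is essentially the average of their values; a larger ``sigma`` would spread the weight over more distant observations.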

By design, the attention mechanism provides a *differentiable* means of
control by which a neural network can select elements from a set and
construct an associated weighted sum over representations.

Exercises
---------

1. Suppose that you wanted to reimplement approximate (key, query)
   matches as used in classical databases. Which attention function
   would you pick?
2. Suppose that the attention function is given by
   :math:`a(\mathbf{q}, \mathbf{k}_i) = \mathbf{q}^\top \mathbf{k}_i`
   and that :math:`\mathbf{k}_i = \mathbf{v}_i` for
   :math:`i = 1, \ldots, m`. Denote by
   :math:`p(\mathbf{k}_i; \mathbf{q})` the probability distribution over
   keys when using the softmax normalization in
   :eq:`eq_softmax_attention`. Prove that
   :math:`\nabla_{\mathbf{q}} \mathop{\textrm{Attention}}(\mathbf{q}, \mathcal{D}) = \textrm{Cov}_{p(\mathbf{k}_i; \mathbf{q})}[\mathbf{k}_i]`.
3. Design a differentiable search engine using the attention mechanism.
4. Review the design of the Squeeze and Excitation Networks
   :cite:`Hu.Shen.Sun.2018` and interpret them through the lens of the
   attention mechanism.
