.. _sec_rnn-scratch:
Recurrent Neural Network Implementation from Scratch
====================================================
We are now ready to implement an RNN from scratch. In particular, we
will build this RNN to function as a character-level language model (see
:numref:`sec_rnn`) and train it on a corpus consisting of the entire
text of H. G. Wells’ *The Time Machine*, following the data processing
steps outlined in :numref:`sec_text-sequence`. We start by loading the
dataset.
**PyTorch**
.. code:: python
%matplotlib inline
import math
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
**MXNet**
.. code:: python
%matplotlib inline
import math
from mxnet import autograd, gluon, np, npx
from d2l import mxnet as d2l
npx.set_np()
**JAX**
.. code:: python
%matplotlib inline
import math
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
**TensorFlow**
.. code:: python
%matplotlib inline
import math
import tensorflow as tf
from d2l import tensorflow as d2l
RNN Model
---------
We begin by defining a class to implement the RNN model
(:numref:`subsec_rnn_w_hidden_states`). Note that the number of hidden
units ``num_hiddens`` is a tunable hyperparameter.
**PyTorch**
.. code:: python
class RNNScratch(d2l.Module): #@save
"""The RNN model implemented from scratch."""
def __init__(self, num_inputs, num_hiddens, sigma=0.01):
super().__init__()
self.save_hyperparameters()
self.W_xh = nn.Parameter(
torch.randn(num_inputs, num_hiddens) * sigma)
self.W_hh = nn.Parameter(
torch.randn(num_hiddens, num_hiddens) * sigma)
self.b_h = nn.Parameter(torch.zeros(num_hiddens))
**MXNet**
.. code:: python
class RNNScratch(d2l.Module): #@save
"""The RNN model implemented from scratch."""
def __init__(self, num_inputs, num_hiddens, sigma=0.01):
super().__init__()
self.save_hyperparameters()
self.W_xh = np.random.randn(num_inputs, num_hiddens) * sigma
self.W_hh = np.random.randn(
num_hiddens, num_hiddens) * sigma
self.b_h = np.zeros(num_hiddens)
**JAX**
.. code:: python
class RNNScratch(nn.Module): #@save
"""The RNN model implemented from scratch."""
num_inputs: int
num_hiddens: int
sigma: float = 0.01
def setup(self):
self.W_xh = self.param('W_xh', nn.initializers.normal(self.sigma),
(self.num_inputs, self.num_hiddens))
self.W_hh = self.param('W_hh', nn.initializers.normal(self.sigma),
(self.num_hiddens, self.num_hiddens))
self.b_h = self.param('b_h', nn.initializers.zeros, (self.num_hiddens))
**TensorFlow**
.. code:: python
class RNNScratch(d2l.Module): #@save
"""The RNN model implemented from scratch."""
def __init__(self, num_inputs, num_hiddens, sigma=0.01):
super().__init__()
self.save_hyperparameters()
self.W_xh = tf.Variable(tf.random.normal(
(num_inputs, num_hiddens)) * sigma)
self.W_hh = tf.Variable(tf.random.normal(
(num_hiddens, num_hiddens)) * sigma)
self.b_h = tf.Variable(tf.zeros(num_hiddens))
The ``forward`` method below defines how to compute the output and
hidden state at any time step, given the current input and the state of
the model at the previous time step. Note that the RNN model loops
through the outermost dimension of ``inputs``, updating the hidden state
one time step at a time. The model here uses a :math:`\tanh` activation
function (:numref:`subsec_tanh`).
**PyTorch**
.. code:: python
@d2l.add_to_class(RNNScratch) #@save
def forward(self, inputs, state=None):
if state is None:
# Initial state with shape: (batch_size, num_hiddens)
state = torch.zeros((inputs.shape[1], self.num_hiddens),
device=inputs.device)
else:
state, = state
outputs = []
for X in inputs: # Shape of inputs: (num_steps, batch_size, num_inputs)
state = torch.tanh(torch.matmul(X, self.W_xh) +
torch.matmul(state, self.W_hh) + self.b_h)
outputs.append(state)
return outputs, state
**MXNet**
.. code:: python
@d2l.add_to_class(RNNScratch) #@save
def forward(self, inputs, state=None):
if state is None:
# Initial state with shape: (batch_size, num_hiddens)
state = np.zeros((inputs.shape[1], self.num_hiddens),
ctx=inputs.ctx)
else:
state, = state
outputs = []
for X in inputs: # Shape of inputs: (num_steps, batch_size, num_inputs)
state = np.tanh(np.dot(X, self.W_xh) +
np.dot(state, self.W_hh) + self.b_h)
outputs.append(state)
return outputs, state
**JAX**
.. code:: python
@d2l.add_to_class(RNNScratch) #@save
def __call__(self, inputs, state=None):
if state is not None:
state, = state
outputs = []
for X in inputs: # Shape of inputs: (num_steps, batch_size, num_inputs)
state = jnp.tanh(jnp.matmul(X, self.W_xh) + (
jnp.matmul(state, self.W_hh) if state is not None else 0)
+ self.b_h)
outputs.append(state)
return outputs, state
**TensorFlow**
.. code:: python
@d2l.add_to_class(RNNScratch) #@save
def forward(self, inputs, state=None):
if state is None:
# Initial state with shape: (batch_size, num_hiddens)
state = tf.zeros((inputs.shape[1], self.num_hiddens))
else:
state, = state
state = tf.reshape(state, (-1, self.num_hiddens))
outputs = []
for X in inputs: # Shape of inputs: (num_steps, batch_size, num_inputs)
state = tf.tanh(tf.matmul(X, self.W_xh) +
tf.matmul(state, self.W_hh) + self.b_h)
outputs.append(state)
return outputs, state
We can feed a minibatch of input sequences into an RNN model as follows.
**PyTorch**
.. code:: python
batch_size, num_inputs, num_hiddens, num_steps = 2, 16, 32, 100
rnn = RNNScratch(num_inputs, num_hiddens)
X = torch.ones((num_steps, batch_size, num_inputs))
outputs, state = rnn(X)
**MXNet**
.. code:: python
batch_size, num_inputs, num_hiddens, num_steps = 2, 16, 32, 100
rnn = RNNScratch(num_inputs, num_hiddens)
X = np.ones((num_steps, batch_size, num_inputs))
outputs, state = rnn(X)
.. parsed-literal::
   :class: output

   [22:31:16] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
**JAX**
.. code:: python
batch_size, num_inputs, num_hiddens, num_steps = 2, 16, 32, 100
rnn = RNNScratch(num_inputs, num_hiddens)
X = jnp.ones((num_steps, batch_size, num_inputs))
(outputs, state), _ = rnn.init_with_output(d2l.get_key(), X)
**TensorFlow**
.. code:: python
batch_size, num_inputs, num_hiddens, num_steps = 2, 16, 32, 100
rnn = RNNScratch(num_inputs, num_hiddens)
X = tf.ones((num_steps, batch_size, num_inputs))
outputs, state = rnn(X)
Let’s check whether the RNN model produces results of the correct shapes
to ensure that the dimensionality of the hidden state remains unchanged.
.. code:: python
def check_len(a, n): #@save
"""Check the length of a list."""
assert len(a) == n, f'list\'s length {len(a)} != expected length {n}'
def check_shape(a, shape): #@save
"""Check the shape of a tensor."""
assert a.shape == shape, \
f'tensor\'s shape {a.shape} != expected shape {shape}'
check_len(outputs, num_steps)
check_shape(outputs[0], (batch_size, num_hiddens))
check_shape(state, (batch_size, num_hiddens))
RNN-Based Language Model
------------------------
The following ``RNNLMScratch`` class defines an RNN-based language
model, where we pass in our RNN via the ``rnn`` argument of the
``__init__`` method. When training language models, the inputs and
outputs come from the same vocabulary; hence, they have the same
dimension, which is equal to the vocabulary size. Note that we use
perplexity to evaluate the model. As discussed in
:numref:`subsec_perplexity`, this ensures that sequences of different
lengths are comparable.
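To make the metric concrete, here is a minimal sketch in PyTorch (with
made-up logits and labels, independent of the model code below):
perplexity is simply the exponentiated average cross-entropy per token.

.. code:: python

   import torch
   from torch.nn import functional as F

   # Toy example: predictions for 4 tokens over a 5-token vocabulary
   logits = torch.randn(4, 5)            # (num_tokens, vocab_size)
   labels = torch.tensor([0, 2, 1, 4])   # ground-truth token indices
   ce = F.cross_entropy(logits, labels)  # average cross-entropy per token
   ppl = torch.exp(ce)                   # perplexity, as plotted during training
   print(ppl)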
**PyTorch**
.. code:: python
class RNNLMScratch(d2l.Classifier): #@save
"""The RNN-based language model implemented from scratch."""
def __init__(self, rnn, vocab_size, lr=0.01):
super().__init__()
self.save_hyperparameters()
self.init_params()
def init_params(self):
self.W_hq = nn.Parameter(
torch.randn(
self.rnn.num_hiddens, self.vocab_size) * self.rnn.sigma)
self.b_q = nn.Parameter(torch.zeros(self.vocab_size))
def training_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('ppl', torch.exp(l), train=True)
return l
def validation_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('ppl', torch.exp(l), train=False)
**MXNet**
.. code:: python
class RNNLMScratch(d2l.Classifier): #@save
"""The RNN-based language model implemented from scratch."""
def __init__(self, rnn, vocab_size, lr=0.01):
super().__init__()
self.save_hyperparameters()
self.init_params()
def init_params(self):
self.W_hq = np.random.randn(
self.rnn.num_hiddens, self.vocab_size) * self.rnn.sigma
self.b_q = np.zeros(self.vocab_size)
for param in self.get_scratch_params():
param.attach_grad()
def training_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('ppl', np.exp(l), train=True)
return l
def validation_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('ppl', np.exp(l), train=False)
**JAX**
.. code:: python
class RNNLMScratch(d2l.Classifier): #@save
"""The RNN-based language model implemented from scratch."""
rnn: nn.Module
vocab_size: int
lr: float = 0.01
def setup(self):
self.W_hq = self.param('W_hq', nn.initializers.normal(self.rnn.sigma),
(self.rnn.num_hiddens, self.vocab_size))
self.b_q = self.param('b_q', nn.initializers.zeros, (self.vocab_size))
def training_step(self, params, batch, state):
value, grads = jax.value_and_grad(
self.loss, has_aux=True)(params, batch[:-1], batch[-1], state)
l, _ = value
self.plot('ppl', jnp.exp(l), train=True)
return value, grads
def validation_step(self, params, batch, state):
l, _ = self.loss(params, batch[:-1], batch[-1], state)
self.plot('ppl', jnp.exp(l), train=False)
**TensorFlow**
.. code:: python
class RNNLMScratch(d2l.Classifier): #@save
"""The RNN-based language model implemented from scratch."""
def __init__(self, rnn, vocab_size, lr=0.01):
super().__init__()
self.save_hyperparameters()
self.init_params()
def init_params(self):
self.W_hq = tf.Variable(tf.random.normal(
(self.rnn.num_hiddens, self.vocab_size)) * self.rnn.sigma)
self.b_q = tf.Variable(tf.zeros(self.vocab_size))
def training_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('ppl', tf.exp(l), train=True)
return l
def validation_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('ppl', tf.exp(l), train=False)
One-Hot Encoding
~~~~~~~~~~~~~~~~
Recall that each token is represented by a numerical index indicating
the position in the vocabulary of the corresponding word/character/word
piece. You might be tempted to build a neural network with a single
input node (at each time step), where the index could be fed in as a
scalar value. This works when we are dealing with numerical inputs like
price or temperature, where any two values sufficiently close together
should be treated similarly. But this does not quite make sense. The
:math:`45^{\textrm{th}}` and :math:`46^{\textrm{th}}` words in our
vocabulary happen to be “their” and “said”, whose meanings are not
remotely similar.
When dealing with such categorical data, the most common strategy is to
represent each item by a *one-hot encoding* (recall from
:numref:`subsec_classification-problem`). A one-hot encoding is a
vector whose length is given by the size of the vocabulary :math:`N`,
where all entries are set to :math:`0`, except for the entry
corresponding to our token, which is set to :math:`1`. For example, if
the vocabulary had five elements, then the one-hot vectors corresponding
to indices 0 and 2 would be the following.
**PyTorch**
.. code:: python
F.one_hot(torch.tensor([0, 2]), 5)
.. parsed-literal::
   :class: output

   tensor([[1, 0, 0, 0, 0],
           [0, 0, 1, 0, 0]])
**MXNet**
.. code:: python
npx.one_hot(np.array([0, 2]), 5)
.. parsed-literal::
   :class: output

   array([[1., 0., 0., 0., 0.],
          [0., 0., 1., 0., 0.]])
**JAX**
.. code:: python
jax.nn.one_hot(jnp.array([0, 2]), 5)
.. parsed-literal::
   :class: output

   Array([[1., 0., 0., 0., 0.],
          [0., 0., 1., 0., 0.]], dtype=float32)
**TensorFlow**
.. code:: python
tf.one_hot(tf.constant([0, 2]), 5)
.. parsed-literal::
   :class: output

   <tf.Tensor: shape=(2, 5), dtype=float32, numpy=
   array([[1., 0., 0., 0., 0.],
          [0., 0., 1., 0., 0.]], dtype=float32)>
The minibatches that we sample at each iteration will take the shape
(batch size, number of time steps). Once we represent each input as a
one-hot vector, we can think of each minibatch as a three-dimensional
tensor, where the length along the third axis is given by the vocabulary
size (``len(vocab)``). We often transpose the input so that we obtain an
output of shape (number of time steps, batch size, vocabulary size).
This allows us to loop more conveniently through the outermost dimension
when updating the hidden states of a minibatch, time step by time step
(e.g., in the above ``forward`` method).
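Here is a minimal sketch of this layout in PyTorch, using a made-up
index tensor: transposing first and then one-hot encoding produces the
desired step-major shape.

.. code:: python

   import torch
   from torch.nn import functional as F

   X = torch.tensor([[0, 1, 2], [2, 1, 0]])      # (batch_size=2, num_steps=3)
   embs = F.one_hot(X.T, 5).type(torch.float32)  # transpose, then one-hot encode
   print(embs.shape)                             # torch.Size([3, 2, 5])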
**PyTorch**
.. code:: python
@d2l.add_to_class(RNNLMScratch) #@save
def one_hot(self, X):
# Output shape: (num_steps, batch_size, vocab_size)
return F.one_hot(X.T, self.vocab_size).type(torch.float32)
**MXNet**
.. code:: python
@d2l.add_to_class(RNNLMScratch) #@save
def one_hot(self, X):
# Output shape: (num_steps, batch_size, vocab_size)
return npx.one_hot(X.T, self.vocab_size)
**JAX**
.. code:: python
@d2l.add_to_class(RNNLMScratch) #@save
def one_hot(self, X):
# Output shape: (num_steps, batch_size, vocab_size)
return jax.nn.one_hot(X.T, self.vocab_size)
**TensorFlow**
.. code:: python
@d2l.add_to_class(RNNLMScratch) #@save
def one_hot(self, X):
# Output shape: (num_steps, batch_size, vocab_size)
return tf.one_hot(tf.transpose(X), self.vocab_size)
Transforming RNN Outputs
~~~~~~~~~~~~~~~~~~~~~~~~
The language model uses a fully connected output layer to transform RNN
outputs into token predictions at each time step.
**PyTorch**
.. code:: python
@d2l.add_to_class(RNNLMScratch) #@save
def output_layer(self, rnn_outputs):
outputs = [torch.matmul(H, self.W_hq) + self.b_q for H in rnn_outputs]
return torch.stack(outputs, 1)
@d2l.add_to_class(RNNLMScratch) #@save
def forward(self, X, state=None):
embs = self.one_hot(X)
rnn_outputs, _ = self.rnn(embs, state)
return self.output_layer(rnn_outputs)
**MXNet**
.. code:: python
@d2l.add_to_class(RNNLMScratch) #@save
def output_layer(self, rnn_outputs):
outputs = [np.dot(H, self.W_hq) + self.b_q for H in rnn_outputs]
return np.stack(outputs, 1)
@d2l.add_to_class(RNNLMScratch) #@save
def forward(self, X, state=None):
embs = self.one_hot(X)
rnn_outputs, _ = self.rnn(embs, state)
return self.output_layer(rnn_outputs)
**JAX**
.. code:: python
@d2l.add_to_class(RNNLMScratch) #@save
def output_layer(self, rnn_outputs):
outputs = [jnp.matmul(H, self.W_hq) + self.b_q for H in rnn_outputs]
return jnp.stack(outputs, 1)
@d2l.add_to_class(RNNLMScratch) #@save
def forward(self, X, state=None):
embs = self.one_hot(X)
rnn_outputs, _ = self.rnn(embs, state)
return self.output_layer(rnn_outputs)
**TensorFlow**
.. code:: python
@d2l.add_to_class(RNNLMScratch) #@save
def output_layer(self, rnn_outputs):
outputs = [tf.matmul(H, self.W_hq) + self.b_q for H in rnn_outputs]
return tf.stack(outputs, 1)
@d2l.add_to_class(RNNLMScratch) #@save
def forward(self, X, state=None):
embs = self.one_hot(X)
rnn_outputs, _ = self.rnn(embs, state)
return self.output_layer(rnn_outputs)
Let’s check whether the forward computation produces outputs with the
correct shape.
**PyTorch**
.. code:: python
model = RNNLMScratch(rnn, num_inputs)
outputs = model(torch.ones((batch_size, num_steps), dtype=torch.int64))
check_shape(outputs, (batch_size, num_steps, num_inputs))
**MXNet**
.. code:: python
model = RNNLMScratch(rnn, num_inputs)
outputs = model(np.ones((batch_size, num_steps), dtype=np.int64))
check_shape(outputs, (batch_size, num_steps, num_inputs))
**JAX**
.. code:: python
model = RNNLMScratch(rnn, num_inputs)
outputs, _ = model.init_with_output(d2l.get_key(),
jnp.ones((batch_size, num_steps),
dtype=jnp.int32))
check_shape(outputs, (batch_size, num_steps, num_inputs))
**TensorFlow**
.. code:: python
model = RNNLMScratch(rnn, num_inputs)
outputs = model(tf.ones((batch_size, num_steps), dtype=tf.int64))
check_shape(outputs, (batch_size, num_steps, num_inputs))
Gradient Clipping
-----------------
While you are already used to thinking of neural networks as “deep” in
the sense that many layers separate the input and output even within a
single time step, the length of the sequence introduces a new notion of
depth. In addition to passing through the network in the input-to-output
direction, inputs at the first time step must pass through a chain of
:math:`T` layers along the time steps in order to influence the output
of the model at the final time step. Taking the backwards view, in each
iteration, we backpropagate gradients through time, resulting in a chain
of matrix products of length :math:`\mathcal{O}(T)`. As mentioned in
:numref:`sec_numerical_stability`, this can result in numerical
instability, causing the gradients either to explode or vanish,
depending on the properties of the weight matrices.
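A toy sketch (independent of the model code) shows why such product
chains are dangerous: repeatedly multiplying by the same matrix blows up
components associated with eigenvalues above one and shrinks those below
one.

.. code:: python

   import torch

   M = torch.tensor([[1.2, 0.0], [0.0, 0.5]])
   P = torch.eye(2)
   for _ in range(50):  # a chain of 50 matrix products, as in BPTT over T = 50 steps
       P = P @ M
   print(P)  # diagonal entries are 1.2**50 (about 9e3) and 0.5**50 (about 9e-16)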
Dealing with vanishing and exploding gradients is a fundamental problem
when designing RNNs and has inspired some of the biggest advances in
modern neural network architectures. In the next chapter, we will talk
about specialized architectures that were designed in hopes of
mitigating the vanishing gradient problem. However, even modern RNNs
often suffer from exploding gradients. One inelegant but ubiquitous
solution is to simply clip the gradients, forcing the resulting
“clipped” gradients to take smaller values.
Generally speaking, when optimizing some objective by gradient descent,
we iteratively update the parameter of interest, say a vector
:math:`\mathbf{x}`, by pushing it in the direction of the negative
gradient :math:`\mathbf{g}` (in stochastic gradient descent, we
calculate this gradient on a randomly sampled minibatch). For example,
with learning rate :math:`\eta > 0`, each update takes the form
:math:`\mathbf{x} \gets \mathbf{x} - \eta \mathbf{g}`. Let’s further
assume that the objective function :math:`f` is sufficiently smooth.
Formally, we say that the objective is *Lipschitz continuous* with
constant :math:`L`, meaning that for any :math:`\mathbf{x}` and
:math:`\mathbf{y}`, we have

.. math:: |f(\mathbf{x}) - f(\mathbf{y})| \leq L \|\mathbf{x} - \mathbf{y}\|.
As you can see, when we update the parameter vector by subtracting
:math:`\eta \mathbf{g}`, the change in the value of the objective
depends on the learning rate, the norm of the gradient, and :math:`L`,
as follows:

.. math:: |f(\mathbf{x}) - f(\mathbf{x} - \eta\mathbf{g})| \leq L \eta\|\mathbf{g}\|.
In other words, the objective cannot change by more than
:math:`L \eta \|\mathbf{g}\|`. Having a small value for this upper bound
might be viewed as good or bad. On the downside, we are limiting the
speed at which we can reduce the value of the objective. On the bright
side, it limits just how much we can go wrong in any one gradient step.
When we say that gradients explode, we mean that :math:`\|\mathbf{g}\|`
becomes excessively large. In the worst case, we might do so much damage
in a single gradient step that we could undo all of the progress made
over the course of thousands of training iterations. When gradients can
be so large, neural network training often diverges, failing to reduce
the value of the objective. At other times, training eventually
converges but is unstable owing to massive spikes in the loss.
One way to limit the size of :math:`L \eta \|\mathbf{g}\|` is to shrink
the learning rate :math:`\eta` to tiny values. This has the advantage
that we do not bias the updates. But what if we only *rarely* get large
gradients? This drastic move slows down our progress at all steps, just
to deal with the rare exploding gradient events. A popular alternative
is to adopt a *gradient clipping* heuristic, projecting the gradients
:math:`\mathbf{g}` onto a ball of some given radius :math:`\theta` as
follows:

.. math:: \mathbf{g} \leftarrow \min\left(1, \frac{\theta}{\|\mathbf{g}\|}\right) \mathbf{g}.
This ensures that the gradient norm never exceeds :math:`\theta` and
that the updated gradient is entirely aligned with the original
direction of :math:`\mathbf{g}`. It also has the desirable side-effect
of limiting the influence any given minibatch (and within it any given
sample) can exert on the parameter vector. This bestows a certain degree
of robustness upon the model. To be clear, it is a hack. Gradient
clipping means that we are not always following the true gradient, and
it is hard to reason analytically about the possible side effects.
However, it is a very useful hack, and is widely adopted in RNN
implementations in most deep learning frameworks.
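Before turning to the framework-specific implementations, here is the
projection applied to a toy gradient with made-up numbers (a minimal
sketch):

.. code:: python

   import torch

   theta = 1.0
   g = torch.tensor([3.0, 4.0])                   # gradient norm is 5 > theta
   g = min(1.0, theta / torch.norm(g).item()) * g
   print(g, torch.norm(g))                        # direction kept; norm equals theta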
Below we define a method to clip gradients, which is invoked by the
``fit_epoch`` method of the ``d2l.Trainer`` class (see
:numref:`sec_linear_scratch`). Note that when computing the gradient
norm, we are concatenating all model parameters, treating them as a
single giant parameter vector.
**PyTorch**
.. code:: python
@d2l.add_to_class(d2l.Trainer) #@save
def clip_gradients(self, grad_clip_val, model):
params = [p for p in model.parameters() if p.requires_grad]
norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))
if norm > grad_clip_val:
for param in params:
param.grad[:] *= grad_clip_val / norm
**MXNet**
.. code:: python
@d2l.add_to_class(d2l.Trainer) #@save
def clip_gradients(self, grad_clip_val, model):
params = model.parameters()
if not isinstance(params, list):
params = [p.data() for p in params.values()]
norm = math.sqrt(sum((p.grad ** 2).sum() for p in params))
if norm > grad_clip_val:
for param in params:
param.grad[:] *= grad_clip_val / norm
**JAX**
.. code:: python
@d2l.add_to_class(d2l.Trainer) #@save
def clip_gradients(self, grad_clip_val, grads):
grad_leaves, _ = jax.tree_util.tree_flatten(grads)
norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in grad_leaves))
clip = lambda grad: jnp.where(norm < grad_clip_val,
grad, grad * (grad_clip_val / norm))
return jax.tree_util.tree_map(clip, grads)
**TensorFlow**
.. code:: python
@d2l.add_to_class(d2l.Trainer) #@save
def clip_gradients(self, grad_clip_val, grads):
grad_clip_val = tf.constant(grad_clip_val, dtype=tf.float32)
new_grads = [tf.convert_to_tensor(grad) if isinstance(
grad, tf.IndexedSlices) else grad for grad in grads]
norm = tf.math.sqrt(sum((tf.reduce_sum(grad ** 2)) for grad in new_grads))
if tf.greater(norm, grad_clip_val):
for i, grad in enumerate(new_grads):
new_grads[i] = grad * grad_clip_val / norm
return new_grads
return grads
Training
--------
Using *The Time Machine* dataset (``data``), we train a character-level
language model (``model``) based on the RNN (``rnn``) implemented from
scratch. Note that we first calculate the gradients, then clip them, and
finally update the model parameters using the clipped gradients.
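To make that order of operations concrete, here is a minimal,
self-contained PyTorch sketch of a single step on a toy one-parameter
objective (it uses the built-in ``nn.utils.clip_grad_norm_`` for
brevity; ``d2l.Trainer`` relies on the ``clip_gradients`` method defined
above):

.. code:: python

   import torch
   from torch import nn

   w = nn.Parameter(torch.tensor([10.0]))       # toy one-parameter "model"
   optimizer = torch.optim.SGD([w], lr=1.0)

   loss = (w ** 2).sum()                        # toy objective; gradient is 2w = 20
   loss.backward()                              # 1. calculate the gradients
   nn.utils.clip_grad_norm_([w], max_norm=1.0)  # 2. clip them in place
   optimizer.step()                             # 3. update with the clipped gradients
   print(w)                                     # w is now 9.0: moved by at most lr * 1.0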
**PyTorch**
.. code:: python
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNNScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = RNNLMScratch(rnn, vocab_size=len(data.vocab), lr=1)
trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
.. figure:: output_rnn-scratch_546c4d_155_0.svg
**MXNet**
.. code:: python
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNNScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = RNNLMScratch(rnn, vocab_size=len(data.vocab), lr=1)
trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
.. figure:: output_rnn-scratch_546c4d_158_0.svg
**JAX**
.. code:: python
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNNScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = RNNLMScratch(rnn, vocab_size=len(data.vocab), lr=1)
trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
.. figure:: output_rnn-scratch_546c4d_161_0.svg
**TensorFlow**
.. code:: python
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
with d2l.try_gpu():
rnn = RNNScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = RNNLMScratch(rnn, vocab_size=len(data.vocab), lr=1)
trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1)
trainer.fit(model, data)
.. figure:: output_rnn-scratch_546c4d_164_0.svg
Decoding
--------
Once a language model has been learned, we can use it not only to
predict the next token but to continue predicting each subsequent one,
treating the previously predicted token as though it were the next in
the input. Sometimes we will just want to generate text as though we
were starting at the beginning of a document. However, it is often
useful to condition the language model on a user-supplied prefix. For
example, if we were developing an autocomplete feature for a search
engine or to assist users in writing emails, we would want to feed in
what they had written so far (the prefix), and then generate a likely
continuation.
The following ``predict`` method generates a continuation, one character
at a time, after ingesting a user-provided ``prefix``. When looping
through the characters in ``prefix``, we keep passing the hidden state
to the next time step but do not generate any output. This is called the
*warm-up* period. After ingesting the prefix, we are now ready to begin
emitting the subsequent characters, each of which will be fed back into
the model as the input at the next time step.
**PyTorch**
.. code:: python
@d2l.add_to_class(RNNLMScratch) #@save
def predict(self, prefix, num_preds, vocab, device=None):
state, outputs = None, [vocab[prefix[0]]]
for i in range(len(prefix) + num_preds - 1):
X = torch.tensor([[outputs[-1]]], device=device)
embs = self.one_hot(X)
rnn_outputs, state = self.rnn(embs, state)
if i < len(prefix) - 1: # Warm-up period
outputs.append(vocab[prefix[i + 1]])
else: # Predict num_preds steps
Y = self.output_layer(rnn_outputs)
outputs.append(int(Y.argmax(axis=2).reshape(1)))
return ''.join([vocab.idx_to_token[i] for i in outputs])
**MXNet**
.. code:: python
@d2l.add_to_class(RNNLMScratch) #@save
def predict(self, prefix, num_preds, vocab, device=None):
state, outputs = None, [vocab[prefix[0]]]
for i in range(len(prefix) + num_preds - 1):
X = np.array([[outputs[-1]]], ctx=device)
embs = self.one_hot(X)
rnn_outputs, state = self.rnn(embs, state)
if i < len(prefix) - 1: # Warm-up period
outputs.append(vocab[prefix[i + 1]])
else: # Predict num_preds steps
Y = self.output_layer(rnn_outputs)
outputs.append(int(Y.argmax(axis=2).reshape(1)))
return ''.join([vocab.idx_to_token[i] for i in outputs])
**JAX**
.. code:: python
@d2l.add_to_class(RNNLMScratch) #@save
def predict(self, prefix, num_preds, vocab, params):
state, outputs = None, [vocab[prefix[0]]]
for i in range(len(prefix) + num_preds - 1):
X = jnp.array([[outputs[-1]]])
embs = self.one_hot(X)
rnn_outputs, state = self.rnn.apply({'params': params['rnn']},
embs, state)
if i < len(prefix) - 1: # Warm-up period
outputs.append(vocab[prefix[i + 1]])
else: # Predict num_preds steps
Y = self.apply({'params': params}, rnn_outputs,
method=self.output_layer)
outputs.append(int(Y.argmax(axis=2).reshape(1)))
return ''.join([vocab.idx_to_token[i] for i in outputs])
**TensorFlow**
.. code:: python
@d2l.add_to_class(RNNLMScratch) #@save
def predict(self, prefix, num_preds, vocab, device=None):
state, outputs = None, [vocab[prefix[0]]]
for i in range(len(prefix) + num_preds - 1):
X = tf.constant([[outputs[-1]]])
embs = self.one_hot(X)
rnn_outputs, state = self.rnn(embs, state)
if i < len(prefix) - 1: # Warm-up period
outputs.append(vocab[prefix[i + 1]])
else: # Predict num_preds steps
Y = self.output_layer(rnn_outputs)
outputs.append(int(tf.reshape(tf.argmax(Y, axis=2), 1)))
return ''.join([vocab.idx_to_token[i] for i in outputs])
In the following, we specify the prefix and have the model generate 20
additional characters.
**PyTorch**
.. code:: python
model.predict('it has', 20, data.vocab, d2l.try_gpu())
.. parsed-literal::
   :class: output

   'it has in the the the the '
**MXNet**
.. code:: python
model.predict('it has', 20, data.vocab, d2l.try_gpu())
.. parsed-literal::
   :class: output

   'it has in the the prace th'
**JAX**
.. code:: python
model.predict('it has', 20, data.vocab, trainer.state.params)
.. parsed-literal::
   :class: output

   'it has and the the the the'
**TensorFlow**
.. code:: python
model.predict('it has', 20, data.vocab)
.. parsed-literal::
   :class: output

   'it has it and the the the '
While implementing the above RNN model from scratch is instructive, it
is not convenient. In the next section, we will see how to leverage deep
learning frameworks to whip up RNNs using standard architectures, and to
reap performance gains by relying on highly optimized library functions.
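As a preview, here is a minimal sketch of the built-in PyTorch layer
(covered properly in the next section):

.. code:: python

   import torch
   from torch import nn

   rnn_layer = nn.RNN(input_size=16, hidden_size=32)  # replaces our scratch loop
   X = torch.ones(100, 2, 16)             # (num_steps, batch_size, num_inputs)
   outputs, state = rnn_layer(X)
   print(outputs.shape, state.shape)      # (100, 2, 32) and (1, 2, 32)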
Summary
-------
We can train RNN-based language models to generate text following the
user-provided text prefix. A simple RNN language model consists of input
encoding, RNN modeling, and output generation. During training, gradient
clipping can mitigate the problem of exploding gradients but does not
address the problem of vanishing gradients. In the experiment, we
implemented a simple RNN language model and trained it with gradient
clipping on sequences of text, tokenized at the character level. By
conditioning on a prefix, we can use a language model to generate likely
continuations, which proves useful in many applications, e.g.,
autocomplete features.
Exercises
---------
1. Does the implemented language model predict the next token based on
all the past tokens up to the very first token in *The Time
Machine*?
2. Which hyperparameter controls the length of history used for
prediction?
3. Show that one-hot encoding is equivalent to picking a different
embedding for each object.
4. Adjust the hyperparameters (e.g., number of epochs, number of hidden
units, number of time steps in a minibatch, and learning rate) to
improve the perplexity. How low can you go while sticking with this
simple architecture?
5. Replace one-hot encoding with learnable embeddings. Does this lead
to better performance?
6. Conduct an experiment to determine how well this language model
trained on *The Time Machine* works on other books by H. G. Wells,
e.g., *The War of the Worlds*.
7. Conduct another experiment to evaluate the perplexity of this model
on books written by other authors.
8. Modify the prediction method so as to use sampling rather than
picking the most likely next character.
- What happens?
- Bias the model towards more likely outputs, e.g., by sampling
from
:math:`q(x_t \mid x_{t-1}, \ldots, x_1) \propto P(x_t \mid x_{t-1}, \ldots, x_1)^\alpha`
for :math:`\alpha > 1`.
9. Run the code in this section without clipping the gradient. What
happens?
10. Replace the activation function used in this section with ReLU and
repeat the experiments in this section. Do we still need gradient
clipping? Why?