9.6. Concise Implementation of Recurrent Neural Networks

Like most of our from-scratch implementations, Section 9.5 was designed to provide insight into how each component works. But when you are using RNNs every day or writing production code, you will want to rely more on libraries that cut down on both implementation time (by supplying library code for common models and functions) and computation time (by optimizing the heck out of these library implementations). This section will show you how to implement the same language model more efficiently using the high-level API provided by your deep learning framework. We begin, as before, by loading The Time Machine dataset.

# PyTorch
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

# MXNet
from mxnet import np, npx
from mxnet.gluon import nn, rnn
from d2l import mxnet as d2l

npx.set_np()

# JAX
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l

# TensorFlow
import tensorflow as tf
from d2l import tensorflow as d2l

9.6.1. Defining the Model

We define the following class using the RNN implemented by high-level APIs.

class RNN(d2l.Module):  #@save
    """The RNN model implemented with high-level APIs."""
    def __init__(self, num_inputs, num_hiddens):
        super().__init__()
        self.save_hyperparameters()
        self.rnn = nn.RNN(num_inputs, num_hiddens)

    def forward(self, inputs, H=None):
        return self.rnn(inputs, H)

In the MXNet implementation, we initialize the hidden state by invoking the member method begin_state. This returns a list containing an initial hidden state for each example in the minibatch, whose shape is (number of hidden layers, batch size, number of hidden units). For some models to be introduced later (e.g., long short-term memory), this list also contains other information.

class RNN(d2l.Module):  #@save
    """The RNN model implemented with high-level APIs."""
    def __init__(self, num_hiddens):
        super().__init__()
        self.save_hyperparameters()
        self.rnn = rnn.RNN(num_hiddens)

    def forward(self, inputs, H=None):
        if H is None:
            H, = self.rnn.begin_state(inputs.shape[1], ctx=inputs.ctx)
        outputs, (H, ) = self.rnn(inputs, (H, ))
        return outputs, H

As of this writing, Flax does not provide an RNNCell for a concise implementation of vanilla RNNs, although more advanced variants such as LSTMs and GRUs are available in the Flax linen API. The class below is therefore only a placeholder.

class RNN(nn.Module):  #@save
    """The RNN model implemented with high-level APIs."""
    num_hiddens: int

    @nn.compact
    def __call__(self, inputs, H=None):
        raise NotImplementedError

class RNN(d2l.Module):  #@save
    """The RNN model implemented with high-level APIs."""
    def __init__(self, num_hiddens):
        super().__init__()
        self.save_hyperparameters()
        self.rnn = tf.keras.layers.SimpleRNN(
            num_hiddens, return_sequences=True, return_state=True,
            time_major=True)

    def forward(self, inputs, H=None):
        outputs, H = self.rnn(inputs, H)
        return outputs, H
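
To make the shapes concrete, here is a minimal plain-Python sketch of what a single-layer tanh RNN layer computes over a time-major input. It is not the framework code: the function name `rnn_forward` and the weight names `W_xh`, `W_hh`, and `b_h` are assumptions chosen for illustration, and nested lists stand in for tensors.

```python
import math
import random


def rnn_forward(inputs, W_xh, W_hh, b_h, H=None):
    """Run a single-layer tanh RNN over time-major inputs.

    inputs has shape (num_steps, batch_size, num_inputs) as nested lists.
    Returns (outputs, H), where outputs has shape
    (num_steps, batch_size, num_hiddens) and H is the last hidden state
    with shape (batch_size, num_hiddens).
    """
    num_hiddens = len(b_h)
    if H is None:
        # Default initial state: all zeros, one row per example.
        H = [[0.0] * num_hiddens for _ in inputs[0]]
    outputs = []
    for X in inputs:  # iterate over time steps
        H = [
            [
                math.tanh(
                    sum(x_i * W_xh[i][j] for i, x_i in enumerate(x))
                    + sum(h_k * W_hh[k][j] for k, h_k in enumerate(h))
                    + b_h[j]
                )
                for j in range(num_hiddens)
            ]
            for x, h in zip(X, H)
        ]
        outputs.append(H)
    return outputs, H


# Illustrative sizes and random weights (assumptions, not from the text).
random.seed(0)
num_steps, batch_size, num_inputs, num_hiddens = 5, 2, 3, 4
W_xh = [[random.gauss(0, 0.1) for _ in range(num_hiddens)]
        for _ in range(num_inputs)]
W_hh = [[random.gauss(0, 0.1) for _ in range(num_hiddens)]
        for _ in range(num_hiddens)]
b_h = [0.0] * num_hiddens
inputs = [[[random.random() for _ in range(num_inputs)]
           for _ in range(batch_size)]
          for _ in range(num_steps)]

outputs, H = rnn_forward(inputs, W_xh, W_hh, b_h)
print(len(outputs), len(outputs[0]), len(outputs[0][0]))  # 5 2 4
print(len(H), len(H[0]))  # 2 4
```

Passing the returned `H` back in as the initial state continues the sequence where it left off, which is the role of the `H=None` argument in the classes above. The PyTorch and MXNet layers additionally keep a leading axis for the number of hidden layers in the state.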

Inheriting from the RNNLMScratch class in Section 9.5, the following RNNLM class defines a complete RNN-based language model. Note that we need to create a separate fully connected output layer.

class RNNLM(d2l.RNNLMScratch):  #@save
    """The RNN-based language model implemented with high-level APIs."""
    def init_params(self):
        self.linear = nn.LazyLinear(self.vocab_size)

    def output_layer(self, hiddens):
        return self.linear(hiddens).swapaxes(0, 1)

class RNNLM(d2l.RNNLMScratch):  #@save
    """The RNN-based language model implemented with high-level APIs."""
    def init_params(self):
        self.linear = nn.Dense(self.vocab_size, flatten=False)
        self.initialize()

    def output_layer(self, hiddens):
        return self.linear(hiddens).swapaxes(0, 1)

class RNNLM(d2l.RNNLMScratch):  #@save
    """The RNN-based language model implemented with high-level APIs."""
    training: bool = True

    def setup(self):
        self.linear = nn.Dense(self.vocab_size)

    def output_layer(self, hiddens):
        return self.linear(hiddens).swapaxes(0, 1)

    def forward(self, X, state=None):
        embs = self.one_hot(X)
        rnn_outputs, _ = self.rnn(embs, state, self.training)
        return self.output_layer(rnn_outputs)

class RNNLM(d2l.RNNLMScratch):  #@save
    """The RNN-based language model implemented with high-level APIs."""
    def init_params(self):
        self.linear = tf.keras.layers.Dense(self.vocab_size)

    def output_layer(self, hiddens):
        return tf.transpose(self.linear(hiddens), (1, 0, 2))
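
The output layers above all follow the same pattern: apply one fully connected layer to every hidden state, then swap the first two axes so that the result is batch-major. The following plain-Python sketch illustrates this with nested lists in place of tensors. The helper names `linear` and `output_layer` and the toy sizes are assumptions for illustration only.

```python
def linear(h, W, bias):
    """Affine map from a hidden vector h to one score per vocabulary token."""
    return [sum(h_k * W[k][v] for k, h_k in enumerate(h)) + bias[v]
            for v in range(len(bias))]


def output_layer(hiddens, W, bias):
    """Map (num_steps, batch_size, num_hiddens) hidden states to
    (batch_size, num_steps, vocab_size) scores."""
    # Apply the same linear layer to every time step and every example.
    scores = [[linear(h, W, bias) for h in step] for step in hiddens]
    # Swap the first two axes: time-major -> batch-major.
    return [list(per_example) for per_example in zip(*scores)]


# Toy sizes: 3 time steps, 2 examples, 2 hidden units, 4 vocabulary tokens.
hiddens = [[[float(t), float(n)] for n in range(2)] for t in range(3)]
W = [[1.0, 0.0, 0.5, -1.0],
     [0.0, 1.0, 0.5, 2.0]]
bias = [0.0, 0.0, 0.0, 0.1]

out = output_layer(hiddens, W, bias)
print(len(out), len(out[0]), len(out[0][0]))  # 2 3 4
```

After the swap, `out[n][t]` holds the scores for example `n` at time step `t`, which matches the batch-major layout of the labels.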

9.6.2. Training and Predicting

Before training the model, let’s make a prediction with a model initialized with random weights. Given that we have not trained the network, it will generate nonsensical predictions.

data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNN(num_inputs=len(data.vocab), num_hiddens=32)
model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
model.predict('it has', 20, data.vocab)
'it hasoadd dd dd dd dd dd '

data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNN(num_hiddens=32)
model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
model.predict('it has', 20, data.vocab)
'it hasxlxlxlxlxlxlxlxlxlxl'

data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNN(num_hiddens=32)
model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
model.predict('it has', 20, data.vocab)
'it hasretsnrnrxnrnrgczntgq'
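
The short repeating patterns in some of these untrained outputs are what one would expect from greedy decoding. The sketch below assumes that prediction first warms up on the prefix and then repeatedly feeds the highest-scoring token back in as the next input, as the predict method inherited from Section 9.5 does. The function `toy_step` is a hypothetical stand-in for one RNN step plus the output layer, not part of any library.

```python
def predict(prefix, num_preds, step_fn):
    """Warm up on prefix, then greedily generate num_preds more tokens."""
    state, outputs = None, [prefix[0]]
    for i in range(len(prefix) + num_preds - 1):
        scores, state = step_fn(outputs[-1], state)
        if i < len(prefix) - 1:
            # Warm-up period: ignore the scores and feed the known prefix.
            outputs.append(prefix[i + 1])
        else:
            # Prediction period: append the highest-scoring token.
            outputs.append(max(scores, key=scores.get))
    return ''.join(outputs)


def toy_step(token, state):
    """Hypothetical model that always prefers the next letter in a cycle."""
    successor = {'a': 'b', 'b': 'c', 'c': 'a'}
    scores = {t: 0.0 for t in 'abc'}
    scores[successor[token]] = 1.0
    return scores, state


print(predict('ab', 4, toy_step))  # abcabc
```

Because each generated token depends only on the model's own previous output, a deterministic model can fall into a cycle and repeat it indefinitely.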

Next, we train our model, leveraging the high-level API.

trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)

trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)

with d2l.try_gpu():
    trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1)
trainer.fit(model, data)
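
The `gradient_clip_val=1` argument bounds the size of each update. As a rough sketch, assuming the trainer clips by the global gradient norm as in Section 9.5, the operation looks like the following in plain Python. The function name `clip_gradients` and the list-of-lists gradient layout are illustrative assumptions, not the trainer's actual code.

```python
import math


def clip_gradients(grads, clip_val):
    """Rescale all gradients so that their joint L2 norm is at most clip_val.

    grads is a list of flat gradient vectors, one per parameter.
    """
    norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    if norm > clip_val:
        scale = clip_val / norm
        return [[g * scale for g in vec] for vec in grads]
    return grads


# Two parameters with gradients of joint norm 5, clipped to norm 1.
clipped = clip_gradients([[3.0], [4.0]], clip_val=1)
print(clipped)
```

Clipping rescales every gradient by the same factor, so the update direction is preserved and only its length is limited.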

Compared with Section 9.5, this model achieves comparable perplexity, but runs faster due to the optimized implementations. As before, we can generate predicted tokens following the specified prefix string.

model.predict('it has', 20, data.vocab, d2l.try_gpu())
'it has and the trave the t'
model.predict('it has', 20, data.vocab, d2l.try_gpu())
'it has and the time the ti'
model.predict('it has', 20, data.vocab)
'it has and the pas an and '
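
The perplexity compared above is the exponential of the average per-token cross-entropy. A minimal plain-Python sketch, in which the function name `perplexity` and its input (the probability the model assigned to each correct next token) are assumptions for illustration:

```python
import math


def perplexity(target_probs):
    """Exponential of the mean negative log-likelihood of the targets."""
    nll = -sum(math.log(p) for p in target_probs) / len(target_probs)
    return math.exp(nll)


# A model that always guesses uniformly among k tokens has perplexity k,
# and a model that is always certain and correct has perplexity 1.
k = 28
print(perplexity([1 / k] * 10))
print(perplexity([1.0] * 10))
```

This gives a quick sanity check on training curves: perplexity close to the vocabulary size means the model is doing little better than uniform guessing, while lower values mean the model concentrates probability on the correct tokens.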

9.6.3. Summary

High-level APIs in deep learning frameworks provide implementations of standard RNNs. These libraries help you to avoid wasting time reimplementing standard models. Moreover, framework implementations are often highly optimized, leading to significant (computational) performance gains when compared with implementations from scratch.

9.6.4. Exercises

  1. Can you make the RNN model overfit using the high-level APIs?

  2. Implement the autoregressive model of Section 9.1 using an RNN.