.. _sec_rnn-concise:

Concise Implementation of Recurrent Neural Networks
===================================================

Like most of our from-scratch implementations, :numref:`sec_rnn-scratch`
was designed to provide insight into how each component works. But when
you are using RNNs every day or writing production code, you will want
to rely more on libraries that cut down on both implementation time (by
supplying library code for common models and functions) and computation
time (by optimizing the heck out of these library implementations).
This section will show you how to implement the same language model
more efficiently using the high-level API provided by your deep
learning framework. We begin, as before, by loading *The Time Machine*
dataset.
**pytorch**

.. code:: python

    import torch
    from torch import nn
    from torch.nn import functional as F
    from d2l import torch as d2l
**mxnet**

.. code:: python

    from mxnet import np, npx
    from mxnet.gluon import nn, rnn
    from d2l import mxnet as d2l

    npx.set_np()
**jax**

.. code:: python

    from flax import linen as nn
    from jax import numpy as jnp
    from d2l import jax as d2l

.. parsed-literal::
    :class: output

    No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
**tensorflow**

.. code:: python

    import tensorflow as tf
    from d2l import tensorflow as d2l
Defining the Model
------------------

We define the following class using the RNN implemented by high-level
APIs.
**pytorch**

.. code:: python

    class RNN(d2l.Module):  #@save
        """The RNN model implemented with high-level APIs."""
        def __init__(self, num_inputs, num_hiddens):
            super().__init__()
            self.save_hyperparameters()
            self.rnn = nn.RNN(num_inputs, num_hiddens)

        def forward(self, inputs, H=None):
            return self.rnn(inputs, H)
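To see what this wrapper returns, we can push a dummy minibatch through
an instance. The sketch below is illustrative only: the shapes (9
steps, a batch of 4, 28 inputs) are arbitrary choices, not values from
the text. ``nn.RNN`` consumes inputs of shape (``num_steps``,
``batch_size``, ``num_inputs``) and returns the hidden states of all
time steps together with the final hidden state, whose shape is (number
of layers, ``batch_size``, ``num_hiddens``).

.. code:: python

    # Illustrative shape check (arbitrary dummy values, not from the text)
    dummy_rnn = RNN(num_inputs=28, num_hiddens=32)
    X = torch.randn(9, 4, 28)   # (num_steps, batch_size, num_inputs)
    outputs, H = dummy_rnn(X)
    outputs.shape, H.shape      # ((9, 4, 32), (1, 4, 32))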
**mxnet**

Specifically, to initialize the hidden state, we invoke the member
method ``begin_state``. This returns a list that contains an initial
hidden state for each example in the minibatch, whose shape is (number
of hidden layers, batch size, number of hidden units). For some models
to be introduced later (e.g., long short-term memory), this list will
also contain other information.

.. code:: python

    class RNN(d2l.Module):  #@save
        """The RNN model implemented with high-level APIs."""
        def __init__(self, num_hiddens):
            super().__init__()
            self.save_hyperparameters()
            self.rnn = rnn.RNN(num_hiddens)

        def forward(self, inputs, H=None):
            if H is None:
                H, = self.rnn.begin_state(inputs.shape[1], ctx=inputs.ctx)
            outputs, (H, ) = self.rnn(inputs, (H, ))
            return outputs, H
**jax**

As of today, Flax does not provide an RNNCell for a concise
implementation of vanilla RNNs. More advanced RNN variants, such as
LSTMs and GRUs, are available in the Flax ``linen`` API.

.. code:: python

    class RNN(nn.Module):  #@save
        """The RNN model implemented with high-level APIs."""
        num_hiddens: int

        @nn.compact
        def __call__(self, inputs, H=None):
            raise NotImplementedError
**tensorflow**

.. code:: python

    class RNN(d2l.Module):  #@save
        """The RNN model implemented with high-level APIs."""
        def __init__(self, num_hiddens):
            super().__init__()
            self.save_hyperparameters()
            self.rnn = tf.keras.layers.SimpleRNN(
                num_hiddens, return_sequences=True, return_state=True,
                time_major=True)

        def forward(self, inputs, H=None):
            outputs, H = self.rnn(inputs, H)
            return outputs, H
Inheriting from the ``RNNLMScratch`` class in :numref:`sec_rnn-scratch`,
the following ``RNNLM`` class defines a complete RNN-based language
model. Note that we need to create a separate fully connected output
layer.
**pytorch**

.. code:: python

    class RNNLM(d2l.RNNLMScratch):  #@save
        """The RNN-based language model implemented with high-level APIs."""
        def init_params(self):
            self.linear = nn.LazyLinear(self.vocab_size)

        def output_layer(self, hiddens):
            return self.linear(hiddens).swapaxes(0, 1)
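Why the ``swapaxes(0, 1)``? The RNN emits hidden states in time-major
order, (``num_steps``, ``batch_size``, ``num_hiddens``), while the loss
computation inherited from ``RNNLMScratch`` expects batch-major logits.
The following shape sketch uses arbitrary illustrative values (28
stands in for the vocabulary size); note that ``nn.LazyLinear`` infers
its input dimension on the first call.

.. code:: python

    # Shape sketch with arbitrary dummy values (28 = assumed vocabulary size)
    hiddens = torch.randn(9, 4, 32)        # (num_steps, batch_size, num_hiddens)
    linear = nn.LazyLinear(28)             # in_features inferred at first call
    linear(hiddens).swapaxes(0, 1).shape   # (batch_size, num_steps, vocab_size)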
**mxnet**

.. code:: python

    class RNNLM(d2l.RNNLMScratch):  #@save
        """The RNN-based language model implemented with high-level APIs."""
        def init_params(self):
            self.linear = nn.Dense(self.vocab_size, flatten=False)
            self.initialize()

        def output_layer(self, hiddens):
            return self.linear(hiddens).swapaxes(0, 1)
**jax**

.. code:: python

    class RNNLM(d2l.RNNLMScratch):  #@save
        """The RNN-based language model implemented with high-level APIs."""
        training: bool = True

        def setup(self):
            self.linear = nn.Dense(self.vocab_size)

        def output_layer(self, hiddens):
            return self.linear(hiddens).swapaxes(0, 1)

        def forward(self, X, state=None):
            embs = self.one_hot(X)
            rnn_outputs, _ = self.rnn(embs, state, self.training)
            return self.output_layer(rnn_outputs)
**tensorflow**

.. code:: python

    class RNNLM(d2l.RNNLMScratch):  #@save
        """The RNN-based language model implemented with high-level APIs."""
        def init_params(self):
            self.linear = tf.keras.layers.Dense(self.vocab_size)

        def output_layer(self, hiddens):
            return tf.transpose(self.linear(hiddens), (1, 0, 2))
Training and Predicting
-----------------------

Before training the model, let's make a prediction with a model
initialized with random weights. Given that we have not trained the
network, it will generate nonsensical predictions.
**pytorch**

.. code:: python

    data = d2l.TimeMachine(batch_size=1024, num_steps=32)
    rnn = RNN(num_inputs=len(data.vocab), num_hiddens=32)
    model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
    model.predict('it has', 20, data.vocab)

.. parsed-literal::
    :class: output

    'it hasoadd dd dd dd dd dd '
**mxnet**

.. code:: python

    data = d2l.TimeMachine(batch_size=1024, num_steps=32)
    rnn = RNN(num_hiddens=32)
    model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
    model.predict('it has', 20, data.vocab)

.. parsed-literal::
    :class: output

    [22:52:51] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU

.. parsed-literal::
    :class: output

    'it hasxlxlxlxlxlxlxlxlxlxl'
**tensorflow**

.. code:: python

    data = d2l.TimeMachine(batch_size=1024, num_steps=32)
    rnn = RNN(num_hiddens=32)
    model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
    model.predict('it has', 20, data.vocab)

.. parsed-literal::
    :class: output

    'it hasretsnrnrxnrnrgczntgq'
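Under the hood, ``predict`` (defined in :numref:`sec_rnn-scratch`)
warms the hidden state up on the prefix and then greedily emits the
most likely next character, one step at a time. The following is a
rough PyTorch paraphrase of that loop, simplified for illustration; the
actual d2l implementation may differ in details such as device
handling.

.. code:: python

    # Rough paraphrase of the predict loop (greedy, character-level);
    # simplified for illustration, not the d2l library code itself.
    def greedy_predict(model, prefix, num_preds, vocab):
        state, outputs = None, [vocab[prefix[0]]]
        for t in range(len(prefix) + num_preds - 1):
            X = torch.tensor([[outputs[-1]]])    # last char index, shape (1, 1)
            embs = model.one_hot(X)              # -> (1, 1, vocab_size)
            rnn_outputs, state = model.rnn(embs, state)
            if t < len(prefix) - 1:              # warm-up: force the prefix
                outputs.append(vocab[prefix[t + 1]])
            else:                                # prediction: take the argmax
                Y = model.output_layer(rnn_outputs)
                outputs.append(int(Y.argmax(axis=2).reshape(1)))
        return ''.join([vocab.idx_to_token[i] for i in outputs])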
Next, we train our model, leveraging the high-level API.
**pytorch**

.. code:: python

    trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
    trainer.fit(model, data)

.. figure:: output_rnn-concise_eff2f4_62_0.svg
**mxnet**

.. code:: python

    trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
    trainer.fit(model, data)

.. figure:: output_rnn-concise_eff2f4_65_0.svg
**tensorflow**

.. code:: python

    with d2l.try_gpu():
        trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1)
        trainer.fit(model, data)

.. figure:: output_rnn-concise_eff2f4_68_0.svg
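All three trainer calls above pass ``gradient_clip_val=1``, which
applies the gradient clipping introduced in :numref:`sec_rnn-scratch`:
before each parameter update, the gradient is rescaled so that its
global norm does not exceed 1, guarding against exploding gradients in
backpropagation through time. Below is a minimal PyTorch sketch of the
idea, written as a hypothetical standalone helper rather than the
Trainer's own code.

.. code:: python

    # Minimal sketch of clipping by global gradient norm (hypothetical
    # helper; the d2l Trainer applies the same idea internally).
    def clip_by_global_norm(model, max_norm):
        params = [p for p in model.parameters() if p.grad is not None]
        norm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in params))
        if norm > max_norm:
            for p in params:
                p.grad[:] *= max_norm / norm

PyTorch also ships ``torch.nn.utils.clip_grad_norm_``, which performs
the same rescaling.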
Compared with :numref:`sec_rnn-scratch`, this model achieves comparable
perplexity, but runs faster due to the optimized implementations. As
before, we can generate predicted tokens following the specified prefix
string; a short sketch relating perplexity to the training loss follows
the prediction examples below.
**pytorch**

.. code:: python

    model.predict('it has', 20, data.vocab, d2l.try_gpu())

.. parsed-literal::
    :class: output

    'it has and the trave the t'
**mxnet**

.. code:: python

    model.predict('it has', 20, data.vocab, d2l.try_gpu())

.. parsed-literal::
    :class: output

    'it has and the time the ti'
**tensorflow**

.. code:: python

    model.predict('it has', 20, data.vocab)

.. parsed-literal::
    :class: output

    'it has and the pas an and '
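As promised above, here is how perplexity relates to the training loss:
it is simply the exponential of the average cross-entropy per token.
The sketch below uses the PyTorch imports from the beginning of this
section; the logits and labels are random stand-ins, not real model
output.

.. code:: python

    # Perplexity = exp(mean cross-entropy per token); random stand-in data
    logits = torch.randn(4, 9, 28)          # (batch_size, num_steps, vocab_size)
    labels = torch.randint(0, 28, (4, 9))   # target token indices
    ce = F.cross_entropy(logits.reshape(-1, 28), labels.reshape(-1))
    torch.exp(ce)                           # roughly 28 (the vocab size here)

A model that guesses uniformly at random has a perplexity equal to the
vocabulary size, which is why the untrained model's predictions above
were nonsensical.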
Summary
-------

High-level APIs in deep learning frameworks provide implementations of
standard RNNs. These libraries help you to avoid wasting time
reimplementing standard models. Moreover, framework implementations are
often highly optimized, leading to significant (computational)
performance gains when compared with implementations from scratch.

Exercises
---------

1. Can you make the RNN model overfit using the high-level APIs?
2. Implement the autoregressive model of :numref:`sec_sequence` using an
   RNN.