.. _sec_bi_rnn:

Bidirectional Recurrent Neural Networks
=======================================

So far, our working example of a sequence learning task has been language
modeling, where we aim to predict the next token given all previous tokens in a
sequence. In this scenario, we wish only to condition upon the leftward
context, and thus the unidirectional chaining of a standard RNN seems
appropriate. However, there are many other sequence learning tasks where it is
perfectly fine to condition the prediction at every time step on both the
leftward and the rightward context. Consider, for example, part-of-speech
detection. Why shouldn’t we take the context in both directions into account
when assessing the part of speech associated with a given word?

Another common task—often useful as a pretraining exercise prior to fine-tuning
a model on an actual task of interest—is to mask out random tokens in a text
document and then to train a sequence model to predict the values of the
missing tokens. Note that depending on what comes after the blank, the likely
value of the missing token changes dramatically:

- I am ``___``.
- I am ``___`` hungry.
- I am ``___`` hungry, and I can eat half a pig.

In the first sentence, “happy” seems to be a likely candidate. The words “not”
and “very” seem plausible in the second sentence, but “not” seems incompatible
with the third sentence.

Fortunately, a simple technique transforms any unidirectional RNN into a
bidirectional RNN :cite:`Schuster.Paliwal.1997`. We simply implement two
unidirectional RNN layers chained together in opposite directions and acting on
the same input (:numref:`fig_birnn`). For the first RNN layer, the first input
is :math:`\mathbf{x}_1` and the last input is :math:`\mathbf{x}_T`, but for the
second RNN layer, the first input is :math:`\mathbf{x}_T` and the last input is
:math:`\mathbf{x}_1`. To produce the output of this bidirectional RNN layer, we
simply concatenate together the corresponding outputs of the two underlying
unidirectional RNN layers.

.. _fig_birnn:

.. figure:: ../img/birnn.svg

   Architecture of a bidirectional RNN.

Formally, for any time step :math:`t`, we consider a minibatch input
:math:`\mathbf{X}_t \in \mathbb{R}^{n \times d}` (number of examples
:math:`=n`; number of inputs in each example :math:`=d`) and let the hidden
layer activation function be :math:`\phi`. In the bidirectional architecture,
the forward and backward hidden states for this time step are
:math:`\overrightarrow{\mathbf{H}}_t \in \mathbb{R}^{n \times h}` and
:math:`\overleftarrow{\mathbf{H}}_t \in \mathbb{R}^{n \times h}`, respectively,
where :math:`h` is the number of hidden units. The forward and backward hidden
state updates are as follows:

.. math::

   \begin{aligned}
   \overrightarrow{\mathbf{H}}_t &= \phi(\mathbf{X}_t \mathbf{W}_{\textrm{xh}}^{(f)} + \overrightarrow{\mathbf{H}}_{t-1} \mathbf{W}_{\textrm{hh}}^{(f)} + \mathbf{b}_\textrm{h}^{(f)}),\\
   \overleftarrow{\mathbf{H}}_t &= \phi(\mathbf{X}_t \mathbf{W}_{\textrm{xh}}^{(b)} + \overleftarrow{\mathbf{H}}_{t+1} \mathbf{W}_{\textrm{hh}}^{(b)} + \mathbf{b}_\textrm{h}^{(b)}),
   \end{aligned}

where the weights
:math:`\mathbf{W}_{\textrm{xh}}^{(f)} \in \mathbb{R}^{d \times h}, \mathbf{W}_{\textrm{hh}}^{(f)} \in \mathbb{R}^{h \times h}, \mathbf{W}_{\textrm{xh}}^{(b)} \in \mathbb{R}^{d \times h}, \textrm{ and } \mathbf{W}_{\textrm{hh}}^{(b)} \in \mathbb{R}^{h \times h}`,
and the biases :math:`\mathbf{b}_\textrm{h}^{(f)} \in \mathbb{R}^{1 \times h}`
and :math:`\mathbf{b}_\textrm{h}^{(b)} \in \mathbb{R}^{1 \times h}` are all the
model parameters.

Next, we concatenate the forward and backward hidden states
:math:`\overrightarrow{\mathbf{H}}_t` and :math:`\overleftarrow{\mathbf{H}}_t`
to obtain the hidden state :math:`\mathbf{H}_t \in \mathbb{R}^{n \times 2h}`
for feeding into the output layer. In deep bidirectional RNNs with multiple
hidden layers, such information is passed on as *input* to the next
bidirectional layer. Last, the output layer computes the output
:math:`\mathbf{O}_t \in \mathbb{R}^{n \times q}` (number of outputs
:math:`=q`):

.. math:: \mathbf{O}_t = \mathbf{H}_t \mathbf{W}_{\textrm{hq}} + \mathbf{b}_\textrm{q}.

Here, the weight matrix :math:`\mathbf{W}_{\textrm{hq}} \in \mathbb{R}^{2h \times q}`
and the bias :math:`\mathbf{b}_\textrm{q} \in \mathbb{R}^{1 \times q}` are the
model parameters of the output layer. While technically the two directions can
have different numbers of hidden units, this design choice is seldom made in
practice.
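To make these shapes concrete, here is a minimal single-time-step sketch (using
PyTorch, with arbitrary sizes and random parameters). It merely spells out the
two update equations and the concatenation above; it is an illustration only,
not part of the model built below.

.. code:: python

    import torch

    # Arbitrary sizes: n examples, d inputs, h hidden units, q outputs
    n, d, h, q = 2, 5, 4, 3
    X_t = torch.randn(n, d)
    H_fwd_prev = torch.zeros(n, h)  # forward state from time step t - 1
    H_bwd_next = torch.zeros(n, h)  # backward state from time step t + 1

    W_xh_f, W_hh_f, b_h_f = torch.randn(d, h), torch.randn(h, h), torch.zeros(1, h)
    W_xh_b, W_hh_b, b_h_b = torch.randn(d, h), torch.randn(h, h), torch.zeros(1, h)

    H_fwd = torch.tanh(X_t @ W_xh_f + H_fwd_prev @ W_hh_f + b_h_f)  # (n, h)
    H_bwd = torch.tanh(X_t @ W_xh_b + H_bwd_next @ W_hh_b + b_h_b)  # (n, h)
    H_t = torch.cat((H_fwd, H_bwd), dim=-1)                         # (n, 2h)

    W_hq, b_q = torch.randn(2 * h, q), torch.zeros(1, q)
    O_t = H_t @ W_hq + b_q                                          # (n, q)
    print(H_t.shape, O_t.shape)  # torch.Size([2, 8]) torch.Size([2, 3])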
We now demonstrate a simple implementation of a bidirectional RNN. Throughout
this section, each code example is given for PyTorch, MXNet, JAX, and
TensorFlow, in that order.
.. code:: python

    import torch
    from torch import nn
    from d2l import torch as d2l
.. code:: python

    from mxnet import np, npx
    from mxnet.gluon import rnn
    from d2l import mxnet as d2l

    npx.set_np()
.. code:: python

    from jax import numpy as jnp
    from d2l import jax as d2l

.. parsed-literal::
    :class: output

    No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
.. code:: python

    import tensorflow as tf
    from d2l import tensorflow as d2l
Implementation from Scratch
---------------------------

To implement a bidirectional RNN from scratch, we can include two
unidirectional ``RNNScratch`` instances with separate learnable parameters.
.. code:: python

    class BiRNNScratch(d2l.Module):
        def __init__(self, num_inputs, num_hiddens, sigma=0.01):
            super().__init__()
            self.save_hyperparameters()
            self.f_rnn = d2l.RNNScratch(num_inputs, num_hiddens, sigma)
            self.b_rnn = d2l.RNNScratch(num_inputs, num_hiddens, sigma)
            self.num_hiddens *= 2  # The output dimension will be doubled
.. code:: python

    class BiRNNScratch(d2l.Module):
        def __init__(self, num_inputs, num_hiddens, sigma=0.01):
            super().__init__()
            self.save_hyperparameters()
            self.f_rnn = d2l.RNNScratch(num_inputs, num_hiddens, sigma)
            self.b_rnn = d2l.RNNScratch(num_inputs, num_hiddens, sigma)
            self.num_hiddens *= 2  # The output dimension will be doubled
.. code:: python

    class BiRNNScratch(d2l.Module):
        num_inputs: int
        num_hiddens: int
        sigma: float = 0.01

        def setup(self):
            self.f_rnn = d2l.RNNScratch(self.num_inputs, self.num_hiddens,
                                        self.sigma)
            self.b_rnn = d2l.RNNScratch(self.num_inputs, self.num_hiddens,
                                        self.sigma)
            self.num_hiddens *= 2  # The output dimension will be doubled
.. code:: python

    class BiRNNScratch(d2l.Module):
        def __init__(self, num_inputs, num_hiddens, sigma=0.01):
            super().__init__()
            self.save_hyperparameters()
            self.f_rnn = d2l.RNNScratch(num_inputs, num_hiddens, sigma)
            self.b_rnn = d2l.RNNScratch(num_inputs, num_hiddens, sigma)
            self.num_hiddens *= 2  # The output dimension will be doubled
The states of the forward and backward RNNs are updated separately, while the
outputs of the two RNNs are concatenated.
.. code:: python

    @d2l.add_to_class(BiRNNScratch)
    def forward(self, inputs, Hs=None):
        f_H, b_H = Hs if Hs is not None else (None, None)
        f_outputs, f_H = self.f_rnn(inputs, f_H)
        b_outputs, b_H = self.b_rnn(reversed(inputs), b_H)
        outputs = [torch.cat((f, b), -1) for f, b in zip(
            f_outputs, reversed(b_outputs))]
        return outputs, (f_H, b_H)
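As a quick sanity check of the PyTorch version above, the following sketch
feeds a random minibatch through ``BiRNNScratch`` and inspects the shapes. It
assumes that ``d2l.RNNScratch`` consumes inputs of shape (``num_steps``,
``batch_size``, ``num_inputs``) and returns one output per time step; the sizes
here are arbitrary.

.. code:: python

    # Shape check with arbitrary sizes (illustration only)
    num_steps, batch_size, num_inputs, num_hiddens = 7, 4, 16, 32
    birnn = BiRNNScratch(num_inputs, num_hiddens)
    X = torch.randn(num_steps, batch_size, num_inputs)
    outputs, (f_H, b_H) = birnn(X)
    # One concatenated output per time step, each with 2 * num_hiddens features
    print(len(outputs), outputs[0].shape)  # 7 torch.Size([4, 64])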
.. code:: python

    @d2l.add_to_class(BiRNNScratch)
    def forward(self, inputs, Hs=None):
        f_H, b_H = Hs if Hs is not None else (None, None)
        f_outputs, f_H = self.f_rnn(inputs, f_H)
        b_outputs, b_H = self.b_rnn(reversed(inputs), b_H)
        outputs = [np.concatenate((f, b), -1) for f, b in zip(
            f_outputs, reversed(b_outputs))]
        return outputs, (f_H, b_H)
.. code:: python

    @d2l.add_to_class(BiRNNScratch)
    def forward(self, inputs, Hs=None):
        f_H, b_H = Hs if Hs is not None else (None, None)
        f_outputs, f_H = self.f_rnn(inputs, f_H)
        b_outputs, b_H = self.b_rnn(reversed(inputs), b_H)
        outputs = [jnp.concatenate((f, b), -1) for f, b in zip(
            f_outputs, reversed(b_outputs))]
        return outputs, (f_H, b_H)
.. code:: python

    @d2l.add_to_class(BiRNNScratch)
    def forward(self, inputs, Hs=None):
        f_H, b_H = Hs if Hs is not None else (None, None)
        f_outputs, f_H = self.f_rnn(inputs, f_H)
        b_outputs, b_H = self.b_rnn(reversed(inputs), b_H)
        outputs = [tf.concat((f, b), -1) for f, b in zip(
            f_outputs, reversed(b_outputs))]
        return outputs, (f_H, b_H)
Concise Implementation
----------------------
Using the high-level APIs, we can implement bidirectional RNNs more concisely.
Here we take a GRU model as an example.

.. code:: python

    class BiGRU(d2l.RNN):
        def __init__(self, num_inputs, num_hiddens):
            d2l.Module.__init__(self)
            self.save_hyperparameters()
            self.rnn = nn.GRU(num_inputs, num_hiddens, bidirectional=True)
            self.num_hiddens *= 2
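The effect of ``bidirectional=True`` can be read directly off the shapes of the
underlying layer. The following sketch (with arbitrary sizes) shows that the
output feature dimension doubles, while one final state is kept per direction.

.. code:: python

    # Shape check for a bidirectional GRU layer (arbitrary sizes)
    num_inputs, num_hiddens = 16, 32
    birnn = nn.GRU(num_inputs, num_hiddens, bidirectional=True)
    X = torch.randn(100, 4, num_inputs)  # (num_steps, batch_size, num_inputs)
    outputs, H = birnn(X)
    print(outputs.shape)  # torch.Size([100, 4, 64]): 2 * num_hiddens features
    print(H.shape)        # torch.Size([2, 4, 32]): one final state per direction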
The MXNet Gluon version is analogous, with ``bidirectional=True`` passed to
``rnn.GRU``:

.. code:: python

    class BiGRU(d2l.RNN):
        def __init__(self, num_inputs, num_hiddens):
            d2l.Module.__init__(self)
            self.save_hyperparameters()
            self.rnn = rnn.GRU(num_hiddens, bidirectional=True)
            self.num_hiddens *= 2
The Flax API does not offer RNN layers, hence there is no notion of a
``bidirectional`` argument. If a bidirectional layer is needed, one has to
manually reverse the inputs, as shown in the implementation from scratch.
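In TensorFlow, a bidirectional recurrent layer can be obtained by wrapping a
Keras GRU in ``tf.keras.layers.Bidirectional``. The following standalone sketch
(not tied to the ``d2l`` wrapper classes, with arbitrary sizes) shows that the
concatenated outputs again carry twice as many features as the wrapped GRU.

.. code:: python

    # Bidirectional with merge_mode='concat' (the default) doubles the features
    num_hiddens = 32
    bigru = tf.keras.layers.Bidirectional(
        tf.keras.layers.GRU(num_hiddens, return_sequences=True))
    X = tf.random.normal((4, 100, 16))  # (batch_size, num_steps, num_inputs)
    outputs = bigru(X)
    print(outputs.shape)  # (4, 100, 64): 2 * num_hiddens features per step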
Summary
-------

In bidirectional RNNs, the hidden state for each time step is simultaneously
determined by the data prior to and after the current time step. Bidirectional
RNNs are mostly useful for sequence encoding and the estimation of observations
given bidirectional context. Bidirectional RNNs are very costly to train due to
long gradient chains.

Exercises
---------

1. If the different directions use a different number of hidden units, how will
   the shape of :math:`\mathbf{H}_t` change?
2. Design a bidirectional RNN with multiple hidden layers.
3. Polysemy is common in natural languages. For example, the word “bank” has
   different meanings in the contexts “I went to the bank to deposit cash” and
   “I went to the bank to sit down”. How can we design a neural network model
   such that, given a context sequence and a word, a vector representation of
   the word in the correct context will be returned? What type of neural
   architecture is preferred for handling polysemy?