.. _sec_bert-pretraining:
Pretraining BERT
================
With the BERT model implemented in :numref:`sec_bert` and the
pretraining examples generated from the WikiText-2 dataset in
:numref:`sec_bert-dataset`, we will pretrain BERT on the WikiText-2
dataset in this section.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import torch
from torch import nn
from d2l import torch as d2l
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from mxnet import autograd, gluon, init, np, npx
from d2l import mxnet as d2l
npx.set_np()
To start, we load the WikiText-2 dataset as minibatches of pretraining
examples for masked language modeling and next sentence prediction. The
batch size is 512 and the maximum length of a BERT input sequence is 64.
Note that in the original BERT model, the maximum length is 512.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
batch_size, max_len = 512, 64
train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
batch_size, max_len = 512, 64
train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[22:11:29] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
Pretraining BERT
----------------
The original BERT has two versions of different model sizes
:cite:`Devlin.Chang.Lee.ea.2018`. The base model
(:math:`\textrm{BERT}_{\textrm{BASE}}`) uses 12 layers (Transformer
encoder blocks) with 768 hidden units (hidden size) and 12
self-attention heads. The large model
(:math:`\textrm{BERT}_{\textrm{LARGE}}`) uses 24 layers with 1024 hidden
units and 16 self-attention heads. Notably, the former has 110 million
parameters while the latter has 340 million parameters. For ease of
demonstration, we define a small BERT with 2 layers, 128 hidden units,
and 2 self-attention heads.
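
To see where parameter counts of this magnitude come from, the following back-of-the-envelope tally adds up the weights of a BERT-style encoder. It is a rough sketch rather than the model definition from :numref:`sec_bert`: the vocabulary size of 30522, the 512 positions, the feed-forward width of four times the hidden size, and the helper name ``bert_param_count`` are assumptions based on commonly released configurations, not values taken from this section.

```python
def bert_param_count(num_layers, num_hiddens, vocab_size=30522,
                     max_positions=512, num_segments=2):
    """Approximate parameter count of a BERT-style encoder (assumed layout)."""
    h = num_hiddens
    ffn = 4 * h  # Assumed feed-forward width
    layer_norm = 2 * h  # Scale and shift
    # Token, position, and segment embeddings, followed by layer norm
    embeddings = (vocab_size + max_positions + num_segments) * h + layer_norm
    # Query, key, value, and output projections, each with a bias
    attention = 4 * (h * h + h)
    # Two dense layers of the position-wise feed-forward network
    feed_forward = (h * ffn + ffn) + (ffn * h + h)
    block = attention + feed_forward + 2 * layer_norm
    # Dense layer applied to the "<cls>" representation
    pooler = h * h + h
    return embeddings + num_layers * block + pooler

base = bert_param_count(num_layers=12, num_hiddens=768)
large = bert_param_count(num_layers=24, num_hiddens=1024)
print(f'base: {base / 1e6:.1f}M, large: {large / 1e6:.1f}M')
```

Under these assumptions the tallies come out near 109 million and 335 million, in line with the rounded figures quoted above. The number of attention heads does not enter the count because the heads partition the same projection matrices.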
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = d2l.BERTModel(len(vocab), num_hiddens=128,
ffn_num_hiddens=256, num_heads=2, num_blks=2, dropout=0.2)
devices = d2l.try_all_gpus()
loss = nn.CrossEntropyLoss()
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = d2l.BERTModel(len(vocab), num_hiddens=128, ffn_num_hiddens=256,
num_heads=2, num_blks=2, dropout=0.2)
devices = d2l.try_all_gpus()
net.initialize(init.Xavier(), ctx=devices)
loss = gluon.loss.SoftmaxCELoss()
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[22:12:33] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for GPU
[22:12:34] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for GPU
Before defining the training loop, we define a helper function
``_get_batch_loss_bert``. Given the shard of training examples, this
function computes the loss for both the masked language modeling and
next sentence prediction tasks. The final loss of BERT pretraining is
simply the sum of the masked language modeling loss and the next
sentence prediction loss.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
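#@save
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,
                         segments_X, valid_lens_x,
                         pred_positions_X, mlm_weights_X,
                         mlm_Y, nsp_y):
    # Forward pass
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                                  valid_lens_x.reshape(-1),
                                  pred_positions_X)
    # Compute masked language model loss. Note: if `loss` averages over
    # positions (the default reduction of nn.CrossEntropyLoss), it returns a
    # scalar and the weights cancel out below, so padded positions are not
    # excluded; a loss created with reduction='none' gives a per-position
    # masked average instead
    mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
            mlm_weights_X.reshape(-1)
    mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
    # Compute next sentence prediction loss (.mean() leaves a scalar loss
    # unchanged)
    nsp_l = loss(nsp_Y_hat, nsp_y).mean()
    l = mlm_l + nsp_l
    return mlm_l, nsp_l, l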
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X_shards,
                         segments_X_shards, valid_lens_x_shards,
                         pred_positions_X_shards, mlm_weights_X_shards,
                         mlm_Y_shards, nsp_y_shards):
    mlm_ls, nsp_ls, ls = [], [], []
    for (tokens_X_shard, segments_X_shard, valid_lens_x_shard,
         pred_positions_X_shard, mlm_weights_X_shard, mlm_Y_shard,
         nsp_y_shard) in zip(
        tokens_X_shards, segments_X_shards, valid_lens_x_shards,
        pred_positions_X_shards, mlm_weights_X_shards, mlm_Y_shards,
        nsp_y_shards):
        # Forward pass
        _, mlm_Y_hat, nsp_Y_hat = net(
            tokens_X_shard, segments_X_shard, valid_lens_x_shard.reshape(-1),
            pred_positions_X_shard)
        # Compute masked language model loss
        mlm_l = loss(
            mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y_shard.reshape(-1),
            mlm_weights_X_shard.reshape((-1, 1)))
        mlm_l = mlm_l.sum() / (mlm_weights_X_shard.sum() + 1e-8)
        # Compute next sentence prediction loss
        nsp_l = loss(nsp_Y_hat, nsp_y_shard)
        nsp_l = nsp_l.mean()
        mlm_ls.append(mlm_l)
        nsp_ls.append(nsp_l)
        ls.append(mlm_l + nsp_l)
        npx.waitall()
    return mlm_ls, nsp_ls, ls
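
The role of the weights in the masked language modeling loss can be seen on toy numbers. The sketch below is framework-free and uses made-up per-position losses; ``masked_mean`` is a hypothetical helper, not part of ``d2l``. A weight of 1 marks a real prediction position and a weight of 0 marks padding, so only real positions should contribute to the average.

```python
def masked_mean(per_position_losses, weights, eps=1e-8):
    """Average losses over positions whose weight is nonzero."""
    weighted_sum = sum(l * w for l, w in zip(per_position_losses, weights))
    return weighted_sum / (sum(weights) + eps)

# Made-up losses for four prediction slots; the last two are padding
losses = [2.0, 4.0, 9.0, 9.0]
weights = [1.0, 1.0, 0.0, 0.0]

mlm_l = masked_mean(losses, weights)
plain_mean = sum(losses) / len(losses)  # Would also count the padded slots
nsp_l = 0.7  # Made-up next sentence prediction loss
total = mlm_l + nsp_l  # The pretraining loss is the sum of both terms
print(round(mlm_l, 3), plain_mean, round(total, 3))
```

The small constant in the denominator mirrors the ``1e-8`` in the helper function and guards against division by zero when a shard contains no prediction positions.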
Invoking the aforementioned helper function, the following
``train_bert`` function defines the procedure to pretrain BERT (``net``)
on the WikiText-2 (``train_iter``) dataset. Training BERT can take a very
long time. Instead of specifying the number of epochs for training as in
the ``train_ch13`` function (see :numref:`sec_image_augmentation`), the
input ``num_steps`` of the following function specifies the number of
iteration steps for training.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
    net(*next(iter(train_iter))[:4])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    trainer = torch.optim.Adam(net.parameters(), lr=0.01)
    step, timer = 0, d2l.Timer()
    animator = d2l.Animator(xlabel='step', ylabel='loss',
                            xlim=[1, num_steps], legend=['mlm', 'nsp'])
    # Sum of masked language modeling losses, sum of next sentence prediction
    # losses, no. of sentence pairs, count
    metric = d2l.Accumulator(4)
    num_steps_reached = False
    while step < num_steps and not num_steps_reached:
        for tokens_X, segments_X, valid_lens_x, pred_positions_X,\
            mlm_weights_X, mlm_Y, nsp_y in train_iter:
            tokens_X = tokens_X.to(devices[0])
            segments_X = segments_X.to(devices[0])
            valid_lens_x = valid_lens_x.to(devices[0])
            pred_positions_X = pred_positions_X.to(devices[0])
            mlm_weights_X = mlm_weights_X.to(devices[0])
            mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])
            trainer.zero_grad()
            timer.start()
            mlm_l, nsp_l, l = _get_batch_loss_bert(
                net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,
                pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)
            l.backward()
            trainer.step()
            metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)
            timer.stop()
            animator.add(step + 1,
                         (metric[0] / metric[3], metric[1] / metric[3]))
            step += 1
            if step == num_steps:
                num_steps_reached = True
                break
    print(f'MLM loss {metric[0] / metric[3]:.3f}, '
          f'NSP loss {metric[1] / metric[3]:.3f}')
    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
          f'{str(devices)}')
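
Note what the animator receives: ``metric[0] / metric[3]`` is the running mean of the loss over all steps so far, not the loss of the latest minibatch. The following minimal sketch mimics that bookkeeping in plain Python. ``RunningSums`` is a hypothetical stand-in written for illustration, not the ``d2l.Accumulator`` implementation, and the loss values are made up.

```python
class RunningSums:
    """Keep elementwise running sums over n quantities (illustrative)."""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def __getitem__(self, idx):
        return self.data[idx]

# Sums of MLM loss, NSP loss, number of sentence pairs, number of steps
metric = RunningSums(4)
curve = []
for mlm_l, nsp_l in [(8.0, 0.9), (6.0, 0.7), (4.0, 0.8)]:  # Made-up losses
    metric.add(mlm_l, nsp_l, 512, 1)
    curve.append((metric[0] / metric[3], metric[1] / metric[3]))

print([round(m, 2) for m, _ in curve])
```

Because earlier, larger losses stay in the average, the plotted curve lags behind the per-step loss; the same holds for the final printed values.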
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                            {'learning_rate': 0.01})
    step, timer = 0, d2l.Timer()
    animator = d2l.Animator(xlabel='step', ylabel='loss',
                            xlim=[1, num_steps], legend=['mlm', 'nsp'])
    # Sum of masked language modeling losses, sum of next sentence prediction
    # losses, no. of sentence pairs, count
    metric = d2l.Accumulator(4)
    num_steps_reached = False
    while step < num_steps and not num_steps_reached:
        for batch in train_iter:
            (tokens_X_shards, segments_X_shards, valid_lens_x_shards,
             pred_positions_X_shards, mlm_weights_X_shards,
             mlm_Y_shards, nsp_y_shards) = [gluon.utils.split_and_load(
                elem, devices, even_split=False) for elem in batch]
            timer.start()
            with autograd.record():
                mlm_ls, nsp_ls, ls = _get_batch_loss_bert(
                    net, loss, vocab_size, tokens_X_shards, segments_X_shards,
                    valid_lens_x_shards, pred_positions_X_shards,
                    mlm_weights_X_shards, mlm_Y_shards, nsp_y_shards)
            for l in ls:
                l.backward()
            trainer.step(1)
            mlm_l_mean = sum([float(l) for l in mlm_ls]) / len(mlm_ls)
            nsp_l_mean = sum([float(l) for l in nsp_ls]) / len(nsp_ls)
            metric.add(mlm_l_mean, nsp_l_mean, batch[0].shape[0], 1)
            timer.stop()
            animator.add(step + 1,
                         (metric[0] / metric[3], metric[1] / metric[3]))
            step += 1
            if step == num_steps:
                num_steps_reached = True
                break
    print(f'MLM loss {metric[0] / metric[3]:.3f}, '
          f'NSP loss {metric[1] / metric[3]:.3f}')
    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
          f'{str(devices)}')
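
The nested ``while``/``for`` loop with the ``num_steps_reached`` flag simply keeps cycling over the data until a fixed number of minibatches has been consumed. A minimal sketch of the same control flow, using a hypothetical generator ``take_steps`` and a toy list in place of ``train_iter``:

```python
def take_steps(data_iter, num_steps):
    """Yield (step, batch) pairs, restarting data_iter until num_steps."""
    # Assumes data_iter is non-empty and can be iterated more than once
    step = 0
    while step < num_steps:
        for batch in data_iter:  # One pass over the data is one epoch
            step += 1
            yield step, batch
            if step == num_steps:
                return

toy_iter = ['b0', 'b1', 'b2']  # Stand-in for three minibatches
print([batch for _, batch in take_steps(toy_iter, 5)])
```

With 50 steps and a batch size of 512, training in this section sees at most 50 × 512 = 25600 sentence pairs.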
We can plot both the masked language modeling loss and the next sentence
prediction loss during BERT pretraining.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
train_bert(train_iter, net, loss, len(vocab), devices, 50)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
MLM loss 5.885, NSP loss 0.760
4413.2 sentence pairs/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]
.. figure:: output_bert-pretraining_41429c_48_1.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
train_bert(train_iter, net, loss, len(vocab), devices, 50)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
MLM loss 7.292, NSP loss 0.822
2417.3 sentence pairs/sec on [gpu(0), gpu(1)]
.. figure:: output_bert-pretraining_41429c_51_1.svg
Representing Text with BERT
---------------------------
After pretraining BERT, we can use it to represent a single text, text
pairs, or any token in them. The following function returns the BERT
(``net``) representations for all tokens in ``tokens_a`` and
``tokens_b``.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def get_bert_encoding(net, tokens_a, tokens_b=None):
    tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
    token_ids = torch.tensor(vocab[tokens], device=devices[0]).unsqueeze(0)
    segments = torch.tensor(segments, device=devices[0]).unsqueeze(0)
    valid_len = torch.tensor(len(tokens), device=devices[0]).unsqueeze(0)
    encoded_X, _, _ = net(token_ids, segments, valid_len)
    return encoded_X
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def get_bert_encoding(net, tokens_a, tokens_b=None):
    tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
    token_ids = np.expand_dims(np.array(vocab[tokens], ctx=devices[0]),
                               axis=0)
    segments = np.expand_dims(np.array(segments, ctx=devices[0]), axis=0)
    valid_len = np.expand_dims(np.array(len(tokens), ctx=devices[0]), axis=0)
    encoded_X, _, _ = net(token_ids, segments, valid_len)
    return encoded_X
Consider the sentence “a crane is flying”. Recall the input
representation of BERT as discussed in
:numref:`subsec_bert_input_rep`. After inserting the special tokens
“<cls>” (used for classification) and “<sep>” (used for separation), the
BERT input sequence has a length of six. Since zero is the index of the
“<cls>” token, ``encoded_text[:, 0, :]`` is the BERT representation of
the entire input sentence. To evaluate the polysemous token “crane”, we
also print out the first three elements of the BERT representation of
the token.
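
The sequence length of six and the position of “crane” can be checked without running the model. The sketch below re-creates, under a hypothetical name, the kind of preprocessing that ``get_bert_encoding`` delegates to ``d2l.get_tokens_and_segments``. It assumes the special tokens are spelled ``<cls>`` and ``<sep>`` and that segment IDs are 0 for the first text and 1 for the second.

```python
def tokens_and_segments_sketch(tokens_a, tokens_b=None):
    """Build a BERT-style input sequence and its segment IDs (illustrative)."""
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    segments = [0] * len(tokens)
    if tokens_b is not None:
        tokens += tokens_b + ['<sep>']
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments

tokens, segments = tokens_and_segments_sketch(['a', 'crane', 'is', 'flying'])
print(len(tokens), tokens.index('crane'), segments)

# The sentence pair used later in this section
pair_tokens, pair_segments = tokens_and_segments_sketch(
    ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left'])
print(len(pair_tokens), pair_tokens.index('crane'), pair_segments)
```

In both cases “crane” sits at index 2, which is why position 2 is sliced to obtain its representation.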
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
tokens_a = ['a', 'crane', 'is', 'flying']
encoded_text = get_bert_encoding(net, tokens_a)
# Tokens: '<cls>', 'a', 'crane', 'is', 'flying', '<sep>'
encoded_text_cls = encoded_text[:, 0, :]
encoded_text_crane = encoded_text[:, 2, :]
encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(torch.Size([1, 6, 128]),
torch.Size([1, 128]),
tensor([0.8414, 1.4830, 0.8226], device='cuda:0', grad_fn=))
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
tokens_a = ['a', 'crane', 'is', 'flying']
encoded_text = get_bert_encoding(net, tokens_a)
# Tokens: '<cls>', 'a', 'crane', 'is', 'flying', '<sep>'
encoded_text_cls = encoded_text[:, 0, :]
encoded_text_crane = encoded_text[:, 2, :]
encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
((1, 6, 128),
(1, 128),
array([-1.2760178, -0.79205 , -1.0534445], ctx=gpu(0)))
Now consider a sentence pair “a crane driver came” and “he just left”.
Similarly, ``encoded_pair[:, 0, :]`` is the encoded result of the entire
sentence pair from the pretrained BERT. Note that the first three
elements of the representation of the polysemous token “crane” differ
from those obtained in the previous context. This supports the claim
that BERT representations are context-sensitive.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']
encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)
# Tokens: '<cls>', 'a', 'crane', 'driver', 'came', '<sep>', 'he', 'just',
# 'left', '<sep>'
encoded_pair_cls = encoded_pair[:, 0, :]
encoded_pair_crane = encoded_pair[:, 2, :]
encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(torch.Size([1, 10, 128]),
torch.Size([1, 128]),
tensor([0.0430, 1.6132, 0.0437], device='cuda:0', grad_fn=))
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']
encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)
# Tokens: '<cls>', 'a', 'crane', 'driver', 'came', '<sep>', 'he', 'just',
# 'left', '<sep>'
encoded_pair_cls = encoded_pair[:, 0, :]
encoded_pair_crane = encoded_pair[:, 2, :]
encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
((1, 10, 128),
(1, 128),
array([-1.2759778 , -0.79211384, -1.0534613 ], ctx=gpu(0)))
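
To put a number on how much the two representations of “crane” differ, one can compare the printed elements directly. The sketch below uses only the three elements per vector shown in the outputs above, so it says nothing about the remaining 125 dimensions; ``max_abs_diff`` is a hypothetical helper.

```python
def max_abs_diff(u, v):
    """Largest elementwise absolute difference between two vectors."""
    return max(abs(a - b) for a, b in zip(u, v))

# First three elements of the "crane" representation, copied from the outputs
pytorch_single = [0.8414, 1.4830, 0.8226]
pytorch_pair = [0.0430, 1.6132, 0.0437]
mxnet_single = [-1.2760178, -0.79205, -1.0534445]
mxnet_pair = [-1.2759778, -0.79211384, -1.0534613]

print(f'PyTorch run: {max_abs_diff(pytorch_single, pytorch_pair):.4f}')
print(f'MXNet run:   {max_abs_diff(mxnet_single, mxnet_pair):.6f}')
```

In these particular runs the gap is clearly visible for the PyTorch model and much smaller, though nonzero, for the MXNet model.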
In :numref:`chap_nlp_app`, we will fine-tune a pretrained BERT model
for downstream natural language processing applications.
Summary
-------
- The original BERT has two versions, where the base model has 110
  million parameters and the large model has 340 million parameters.
- After pretraining BERT, we can use it to represent a single text, text
  pairs, or any token in them.
- In the experiment, the same token has different BERT representations
  when its contexts are different. This supports the claim that BERT
  representations are context-sensitive.
Exercises
---------
1. In the experiment, we can see that the masked language modeling loss
   is significantly higher than the next sentence prediction loss. Why?
2. Set the maximum length of a BERT input sequence to be 512 (same as
   the original BERT model). Use the configurations of the original BERT
   model such as :math:`\textrm{BERT}_{\textrm{LARGE}}`. Do you
   encounter any error when running this section? Why?