.. _sec_basic_gan:
Generative Adversarial Networks
===============================
Throughout most of this book, we have talked about how to make
predictions. In some form or another, we used deep neural networks to
learn mappings from data points to labels. This kind of learning is
called discriminative learning, as in, we’d like to be able to
discriminate between photos of cats and photos of dogs. Classifiers and
regressors are both examples of discriminative learning. And neural
networks trained by backpropagation have upended everything we thought
we knew about discriminative learning on large complicated datasets.
Classification accuracies on high-res images have gone from useless to
human-level (with some caveats) in just 5-6 years. We will spare you
another spiel about all the other discriminative tasks where deep neural
networks do astoundingly well.
But there is more to machine learning than just solving discriminative
tasks. For example, given a large dataset, without any labels, we might
want to learn a model that concisely captures the characteristics of
this data. Given such a model, we could sample synthetic data points
that resemble the distribution of the training data. For example, given
a large corpus of photographs of faces, we might want to be able to
generate a new photorealistic image that looks like it might plausibly
have come from the same dataset. This kind of learning is called
generative modeling.
Until recently, we had no method that could synthesize novel
photorealistic images. But the success of deep neural networks for
discriminative learning opened up new possibilities. One big trend over
the last three years has been the application of discriminative deep
nets to overcome challenges in problems that we do not generally think
of as supervised learning problems. Recurrent neural network
language models are one example of using a discriminative network
(trained to predict the next character) that, once trained, can act as a
generative model.
In 2014, a breakthrough paper introduced generative adversarial networks
(GANs) :cite:`Goodfellow.Pouget-Abadie.Mirza.ea.2014`, a clever new
way to leverage the power of discriminative models to get good
generative models. At their heart, GANs rely on the idea that a data
generator is good if we cannot tell fake data apart from real data. In
statistics, this is called a two-sample test: a test to answer the
question whether datasets :math:`X=\{x_1,\ldots,x_n\}` and
:math:`X'=\{x'_1,\ldots,x'_n\}` were drawn from the same distribution.
The main difference between most statistics papers and GANs is that the
latter use this idea in a constructive way. In other words, rather than
just training a model to say “hey, these two datasets do not look like
they came from the same distribution”, they use the two-sample test to
provide training signals to a generative model. This allows us to
improve the data generator until it generates something that resembles
the real data. At the very least, it needs to fool the classifier, even
if our classifier is a state-of-the-art deep neural network.
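
To make the idea of a two-sample test concrete, here is a minimal sketch
in plain Python of one classical variant: a permutation test on the
difference of sample means. It is only an illustration under simplifying
assumptions (one-dimensional samples, and a statistic that can only
detect a shift in the mean), and the function name ``permutation_test``
is our own. A GAN's discriminator can be thought of as replacing such a
hand-picked statistic with a learned classifier.

.. code:: python

    import random

    def permutation_test(xs, ys, num_permutations=2000, seed=0):
        """Estimate a p-value for H0: xs and ys come from the same distribution,
        using the absolute difference of sample means as the test statistic."""
        rng = random.Random(seed)
        observed = abs(sum(xs) / len(xs) - sum(ys) / len(ys))
        pooled = list(xs) + list(ys)
        count = 0
        for _ in range(num_permutations):
            # Under H0 the two labels are exchangeable, so reshuffle and re-split.
            rng.shuffle(pooled)
            a, b = pooled[:len(xs)], pooled[len(xs):]
            if abs(sum(a) / len(a) - sum(b) / len(b)) >= observed:
                count += 1
        # The +1 terms count the observed split itself, so p is never exactly 0.
        return (count + 1) / (num_permutations + 1)

    data_rng = random.Random(1)
    same = permutation_test([data_rng.gauss(0, 1) for _ in range(200)],
                            [data_rng.gauss(0, 1) for _ in range(200)])
    shifted = permutation_test([data_rng.gauss(0, 1) for _ in range(200)],
                               [data_rng.gauss(1, 1) for _ in range(200)])
    print(f"p-value, same distribution: {same:.3f}")
    print(f"p-value, shifted mean:      {shifted:.4f}")

With a mean shift of 1 and 200 samples per set, the p-value for
``shifted`` should be tiny, while the one for ``same`` is a random
number that is typically not small.
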
.. _fig_gan:
.. figure:: ../img/gan.svg
Generative Adversarial Networks
The GAN architecture is illustrated in :numref:`fig_gan`. As you can
see, there are two pieces in the GAN architecture. First, we need a
device (say, a deep network, but it really could be anything, such as a
game rendering engine) that might potentially be able to generate data
that looks just like the real thing. If we are dealing with images, this
needs to generate images. If we are dealing with speech, it needs to
generate audio sequences, and so on. We call this the generator network.
The second component is the discriminator network. It attempts to
distinguish fake data from real data. The two networks are in
competition with each other: the generator network attempts to fool the
discriminator network, the discriminator network then adapts to the new
fake data, and this information, in turn, is used to improve the
generator network, and so on.
The discriminator is a binary classifier that distinguishes whether the
input :math:`\mathbf x` is real (from real data) or fake (from the
generator). Typically, the discriminator outputs a scalar prediction
:math:`o\in\mathbb R` for input :math:`\mathbf x`, for example by using a
dense layer with output size 1, and then applies the sigmoid function to
obtain the predicted probability :math:`D(\mathbf x) = 1/(1+e^{-o})`.
Assume the label :math:`y` is :math:`1` for real data and :math:`0` for
fake data. We train the discriminator to minimize the cross-entropy
loss, *i.e.*,

.. math:: \min_D \{ - y \log D(\mathbf x) - (1-y)\log(1-D(\mathbf x)) \}.
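
As a small numerical illustration (plain Python, no MXNet), the sketch
below evaluates this loss for a single scalar score :math:`o`. The
helper names ``sigmoid`` and ``discriminator_bce`` are hypothetical and
only mirror the formula above.

.. code:: python

    import math

    def sigmoid(o):
        """Map a real-valued score o to a probability D(x) in (0, 1)."""
        return 1.0 / (1.0 + math.exp(-o))

    def discriminator_bce(o, y):
        """Cross-entropy -y log D(x) - (1 - y) log(1 - D(x)) for score o, label y."""
        d = sigmoid(o)
        return -y * math.log(d) - (1 - y) * math.log(1 - d)

    # A confident "real" score is cheap when the example really is real (y = 1) ...
    print(discriminator_bce(4.0, 1))
    # ... and expensive when the example is actually fake (y = 0).
    print(discriminator_bce(4.0, 0))

For the same score, the two losses differ by exactly :math:`o`, since
:math:`-\log(1-D)+\log D = \log\frac{D}{1-D} = o`. This direct form is
fine for moderate scores, but for large :math:`|o|` the sigmoid rounds
to exactly 0 or 1 and the logarithm fails, which is why practical
implementations usually work with the score :math:`o` directly.
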
For the generator, it first draws a vector
:math:`\mathbf z\in\mathbb R^d` from a source of randomness, *e.g.*, a
normal distribution :math:`\mathbf z \sim \mathcal{N} (0,1)`. We often
call :math:`\mathbf z` the latent variable. It then applies a
function to generate :math:`\mathbf x'=G(\mathbf z)`. The goal of the
generator is to fool the discriminator into classifying
:math:`\mathbf x'=G(\mathbf z)` as real data, *i.e.*, we want
:math:`D( G(\mathbf z)) \approx 1`. In other words, for a given
discriminator :math:`D`, we update the parameters of the generator
:math:`G` to maximize the cross-entropy loss when :math:`y=0`, *i.e.*,
.. math:: \max_G \{ - (1-y) \log(1-D(G(\mathbf z))) \} = \max_G \{ - \log(1-D(G(\mathbf z))) \}.
However, if the discriminator confidently rejects the generated data,
which is typical early in training when the generator is still poor,
then :math:`D(\mathbf x')\approx 0`. The above loss is then near 0, and
its gradients are too small for the generator to make good progress. So
commonly we minimize the following loss instead:

.. math:: \min_G \{ - y \log(D(G(\mathbf z))) \} = \min_G \{ - \log(D(G(\mathbf z))) \},

which amounts to feeding :math:`\mathbf x'=G(\mathbf z)` into the
discriminator but giving it the label :math:`y=1`.
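
To see the difference in gradient strength, write
:math:`D(G(\mathbf z)) = 1/(1+e^{-o})` in terms of the discriminator's
score :math:`o` for a generated example. Then :math:`-\log(1-D)` has
derivative :math:`D` with respect to :math:`o`, while :math:`-\log D`
has derivative :math:`D-1`. The plain-Python sketch below (helper names
are our own) tabulates both for a few scores. When the discriminator
confidently rejects a fake (:math:`D \approx 0`), the first gradient
nearly vanishes while the second stays close to :math:`-1`.

.. code:: python

    import math

    def sigmoid(o):
        """Probability D = 1 / (1 + exp(-o)) for discriminator score o."""
        return 1.0 / (1.0 + math.exp(-o))

    def grad_saturating(o):
        """d/do of -log(1 - sigmoid(o)), the loss the generator would maximize."""
        return sigmoid(o)

    def grad_non_saturating(o):
        """d/do of -log(sigmoid(o)), the loss the generator minimizes instead."""
        return sigmoid(o) - 1.0

    for o in [-6.0, -2.0, 0.0, 2.0]:
        print(f"D(G(z)) = {sigmoid(o):.4f}   "
              f"saturating: {grad_saturating(o):+.4f}   "
              f"non-saturating: {grad_non_saturating(o):+.4f}")
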
To sum up, :math:`D` and :math:`G` are playing a “minimax” game with the
comprehensive objective function:
.. math:: \min_D \max_G \{ -E_{x \sim \text{Data}} \log D(\mathbf x) - E_{z \sim \text{Noise}} \log(1 - D(G(\mathbf z))) \}.
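
The objective can be estimated by Monte Carlo from samples. The sketch
below does this in plain Python for a hypothetical one-dimensional
setup: data from :math:`\mathcal{N}(1, 1)`, noise from
:math:`\mathcal{N}(0, 1)`, and a toy generator :math:`G(z) = 1 + z` that
already matches the data. All function names are illustrative
assumptions. For a discriminator that always outputs :math:`1/2`, the
estimate is exactly :math:`2\log 2` whatever the samples are, while the
particular non-constant discriminator we tried gets a larger value. This
fits the intuition that once the generator matches the data, the
discriminator can do no better than guessing.

.. code:: python

    import math
    import random

    def gan_value(D, G, real_samples, noise_samples):
        """Monte Carlo estimate of -E_x[log D(x)] - E_z[log(1 - D(G(z)))]."""
        real_term = -sum(math.log(D(x)) for x in real_samples) / len(real_samples)
        fake_term = -sum(math.log(1 - D(G(z))) for z in noise_samples) / len(noise_samples)
        return real_term + fake_term

    rng = random.Random(0)
    real = [rng.gauss(1.0, 1.0) for _ in range(1000)]   # hypothetical "Data"
    noise = [rng.gauss(0.0, 1.0) for _ in range(1000)]  # "Noise" for z

    def generator(z):
        return 1.0 + z  # toy generator whose outputs follow the data distribution

    def d_blind(x):
        return 0.5  # discriminator that always guesses

    def d_sharp(x):
        return 1.0 / (1.0 + math.exp(-(x - 1.0)))  # hypothetical non-constant discriminator

    print(gan_value(d_blind, generator, real, noise))
    print(gan_value(d_sharp, generator, real, noise))
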
Many GAN applications are in the context of images. For demonstration
purposes, we are going to content ourselves with fitting a much simpler
distribution first. We will illustrate what happens if we use GANs to
build the world’s most inefficient estimator of parameters for a
Gaussian. Let us get started.
.. code:: python
%matplotlib inline
import d2l
from mxnet import np, npx, gluon, autograd, init
from mxnet.gluon import nn
npx.set_np()
Generate some “real” data
-------------------------
Since this is going to be the world’s lamest example, we simply generate
data drawn from a Gaussian.
.. code:: python
X = np.random.normal(size=(1000, 2))
A = np.array([[1, 2], [-0.1, 0.5]])
b = np.array([1, 2])
data = X.dot(A) + b
Let us see what we got. This should be a Gaussian shifted in some rather
arbitrary way with mean :math:`b` and covariance matrix :math:`A^TA`.
.. code:: python
d2l.set_figsize((3.5, 2.5))
d2l.plt.scatter(data[:100,0].asnumpy(), data[:100,1].asnumpy());
print("The covariance matrix is\n%s" % np.dot(A.T, A))
.. parsed-literal::
:class: output
The covariance matrix is
[[1.01 1.95]
[1.95 4.25]]
.. figure:: output_gan_4e4dd7_5_1.svg
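
The claim about the covariance follows because each row is
:math:`\mathbf x^\top A + \mathbf b` with
:math:`\mathbf x \sim \mathcal{N}(0, I)`, so its covariance is
:math:`A^\top I A = A^\top A`. The plain-Python sketch below
(independent of MXNet, with our own variable names) compares this exact
matrix with a sample covariance computed from freshly drawn points.

.. code:: python

    import random

    A = [[1.0, 2.0], [-0.1, 0.5]]
    b = [1.0, 2.0]

    # Exact covariance (A^T A)_{ij} = sum_k A[k][i] * A[k][j].
    cov_exact = [[sum(A[k][i] * A[k][j] for k in range(2)) for j in range(2)]
                 for i in range(2)]

    # Draw rows x^T A + b with x ~ N(0, I), as X.dot(A) + b does above.
    rng = random.Random(0)
    n = 20000
    samples = []
    for _ in range(n):
        x = [rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)]
        samples.append([x[0] * A[0][j] + x[1] * A[1][j] + b[j] for j in range(2)])

    mean = [sum(s[j] for s in samples) / n for j in range(2)]
    cov_sample = [[sum((s[i] - mean[i]) * (s[j] - mean[j]) for s in samples) / (n - 1)
                   for j in range(2)] for i in range(2)]
    print("sample mean:", mean)
    print("exact covariance A^T A:", cov_exact)
    print("sample covariance:", cov_sample)

With 20000 samples the sample mean and covariance should be close to
:math:`\mathbf b` and :math:`A^\top A`, though the exact numbers depend
on the random draws.
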
.. code:: python
batch_size = 8
data_iter = d2l.load_array((data,), batch_size)
Generator
---------
Our generator network will be the simplest network possible: a
single-layer linear model. This is because we will be driving that
linear network with a Gaussian data generator. Hence, it literally only
needs to learn the parameters to fake things perfectly.
.. code:: python
net_G = nn.Sequential()
net_G.add(nn.Dense(2))
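
Why a single linear layer is enough: if
:math:`\mathbf z \sim \mathcal{N}(0, I)`, then
:math:`\mathbf z^\top W + \mathbf c` is Gaussian with mean
:math:`\mathbf c` and covariance :math:`W^\top W`, so
:math:`W = A, \mathbf c = \mathbf b` would reproduce the data exactly.
The weights are not unique, though: rotating the latent space first,
:math:`W = RA` with an orthogonal :math:`R`, gives the same covariance.
The plain-Python sketch below checks this numerically; the helper names
and the rotation angle are arbitrary choices of ours. So we should not
expect the trained ``net_G`` to recover :math:`A` itself, only a map
with the same output distribution.

.. code:: python

    import math

    def cov_of_linear_map(W):
        """Covariance W^T W of z^T W when z ~ N(0, I)."""
        n = len(W[0])
        return [[sum(W[k][i] * W[k][j] for k in range(len(W))) for j in range(n)]
                for i in range(n)]

    def matmul(P, Q):
        return [[sum(P[i][k] * Q[k][j] for k in range(len(Q))) for j in range(len(Q[0]))]
                for i in range(len(P))]

    A = [[1.0, 2.0], [-0.1, 0.5]]
    theta = 0.7  # arbitrary rotation angle
    R = [[math.cos(theta), -math.sin(theta)],
         [math.sin(theta), math.cos(theta)]]  # orthogonal: R^T R = I

    print(cov_of_linear_map(A))             # A^T A
    print(cov_of_linear_map(matmul(R, A)))  # (RA)^T (RA) = A^T R^T R A = A^T A
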
Discriminator
-------------
For the discriminator we will be a bit more discriminating: we will use
an MLP with 3 layers to make things a bit more interesting.
.. code:: python
net_D = nn.Sequential()
net_D.add(nn.Dense(5, activation='tanh'),
nn.Dense(3, activation='tanh'),
nn.Dense(1))
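
For intuition about what ``net_D`` computes, the sketch below is a
plain-Python forward pass through the same layer widths (2 inputs, then
5 and 3 tanh units, then 1 output). It is a simplified stand-in rather
than Gluon's implementation: the weight layout, the
:math:`\mathcal{N}(0, 0.02^2)` weights, and the zero biases are
assumptions of ours. Note that the last layer has no activation: it
outputs the raw score :math:`o`, and the sigmoid is applied inside
``gluon.loss.SigmoidBCELoss`` used during training.

.. code:: python

    import math
    import random

    def init_layer(rng, n_in, n_out, sigma=0.02):
        """Weights W[i][j] ~ N(0, sigma^2), loosely mirroring init.Normal(0.02);
        biases start at zero (an assumption)."""
        W = [[rng.gauss(0.0, sigma) for _ in range(n_out)] for _ in range(n_in)]
        return W, [0.0] * n_out

    def dense(x, W, c, use_tanh):
        """y_j = sum_i x_i W[i][j] + c_j, optionally followed by tanh."""
        y = [sum(x[i] * W[i][j] for i in range(len(x))) + c[j] for j in range(len(c))]
        return [math.tanh(v) for v in y] if use_tanh else y

    rng = random.Random(0)
    sizes = [2, 5, 3, 1]  # input dimension 2, then the widths of the three layers
    layers = [init_layer(rng, n_in, n_out) for n_in, n_out in zip(sizes, sizes[1:])]

    def discriminator_logit(x):
        h = x
        for k, (W, c) in enumerate(layers):
            h = dense(h, W, c, use_tanh=(k < len(layers) - 1))  # no tanh on the last layer
        return h[0]  # raw score o; the sigmoid is left to the loss

    num_params = sum(len(W) * len(W[0]) + len(c) for W, c in layers)
    print(num_params)
    print(discriminator_logit([1.0, 2.0]))

Counting weights and biases gives
:math:`(2\cdot5+5)+(5\cdot3+3)+(3\cdot1+1)=37` parameters, which is what
``num_params`` reports.
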
Training
--------
First we define a function to update the discriminator.
.. code:: python
# Saved in the d2l package for later use
def update_D(X, Z, net_D, net_G, loss, trainer_D):
"""Update discriminator"""
batch_size = X.shape[0]
ones = np.ones((batch_size,), ctx=X.context)
zeros = np.zeros((batch_size,), ctx=X.context)
with autograd.record():
real_Y = net_D(X)
fake_X = net_G(Z)
# Do not need to compute gradient for net_G, detach it from
# computing gradients.
fake_Y = net_D(fake_X.detach())
loss_D = (loss(real_Y, ones) + loss(fake_Y, zeros)) / 2
loss_D.backward()
trainer_D.step(batch_size)
return float(loss_D.sum())
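
Since ``loss`` receives raw scores rather than probabilities, it helps
to see how binary cross-entropy can be evaluated directly on a score
:math:`o` without computing the sigmoid first. The plain-Python sketch
below uses the standard stable rewriting
:math:`\max(o,0) - oy + \log(1+e^{-|o|})` and mirrors how ``update_D``
combines the two terms: it averages the real (label 1) and fake
(label 0) losses per example and sums over the batch. The function
names are hypothetical.

.. code:: python

    import math

    def bce_with_logits(o, y):
        """-y log sigmoid(o) - (1 - y) log(1 - sigmoid(o)), computed from o directly
        as max(o, 0) - o * y + log(1 + exp(-|o|)) so that nothing overflows."""
        return max(o, 0.0) - o * y + math.log1p(math.exp(-abs(o)))

    def discriminator_batch_loss(real_scores, fake_scores):
        """Like loss_D in update_D: per-example average of the real-data loss
        (label 1) and the fake-data loss (label 0), summed over the batch."""
        return sum((bce_with_logits(r, 1.0) + bce_with_logits(f, 0.0)) / 2
                   for r, f in zip(real_scores, fake_scores))

    print(discriminator_batch_loss([0.0] * 8, [0.0] * 8))
    print(bce_with_logits(50.0, 0.0))  # a large score is handled without log(0)

When every score is 0, each term equals :math:`\log 2`, so the batch
total for 8 examples is :math:`8\log 2`.
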
The generator is updated similarly. Here we reuse the cross-entropy loss
but change the label of the fake data from :math:`0` to :math:`1`.
.. code:: python

    # Saved in the d2l package for later use
    def update_G(Z, net_D, net_G, loss, trainer_G):
        """Update generator"""
        batch_size = Z.shape[0]
        ones = np.ones((batch_size,), ctx=Z.context)
        with autograd.record():
            # We could reuse fake_X from update_D to save computation.
            fake_X = net_G(Z)
            # Recomputing fake_Y is needed since net_D is changed.
            fake_Y = net_D(fake_X)
            loss_G = loss(fake_Y, ones)
        loss_G.backward()
        trainer_G.step(batch_size)
        return float(loss_G.sum())
Both the discriminator and the generator perform a binary logistic
regression with the cross-entropy loss. We use Adam to smooth the
training process. In each iteration, we first update the discriminator
and then the generator. We visualize both losses and generated examples.
.. code:: python

    def train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G, latent_dim, data):
        loss = gluon.loss.SigmoidBCELoss()
        net_D.initialize(init=init.Normal(0.02), force_reinit=True)
        net_G.initialize(init=init.Normal(0.02), force_reinit=True)
        trainer_D = gluon.Trainer(net_D.collect_params(),
                                  'adam', {'learning_rate': lr_D})
        trainer_G = gluon.Trainer(net_G.collect_params(),
                                  'adam', {'learning_rate': lr_G})
        # The legend order must match the order of the values passed to
        # animator.add below, which is (loss_D, loss_G)
        animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                                xlim=[1, num_epochs], nrows=2, figsize=(5, 5),
                                legend=['discriminator', 'generator'])
        animator.fig.subplots_adjust(hspace=0.3)
        for epoch in range(1, num_epochs + 1):
            # Train one epoch
            timer = d2l.Timer()
            metric = d2l.Accumulator(3)  # loss_D, loss_G, num_examples
            for X in data_iter:
                batch_size = X.shape[0]
                Z = np.random.normal(0, 1, size=(batch_size, latent_dim))
                metric.add(update_D(X, Z, net_D, net_G, loss, trainer_D),
                           update_G(Z, net_D, net_G, loss, trainer_G),
                           batch_size)
            # Visualize generated examples
            Z = np.random.normal(0, 1, size=(100, latent_dim))
            fake_X = net_G(Z).asnumpy()
            animator.axes[1].cla()
            animator.axes[1].scatter(data[:, 0], data[:, 1])
            animator.axes[1].scatter(fake_X[:, 0], fake_X[:, 1])
            animator.axes[1].legend(['real', 'generated'])
            # Show the losses
            loss_D, loss_G = metric[0] / metric[2], metric[1] / metric[2]
            animator.add(epoch, (loss_D, loss_G))
        print('loss_D %.3f, loss_G %.3f, %d examples/sec' % (
            loss_D, loss_G, metric[2] / timer.stop()))
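
To see the alternating structure of ``train`` without any framework
machinery, here is a self-contained toy version in plain Python with
hand-derived gradients. It makes several simplifying assumptions that
differ from the code above: one-dimensional data drawn from
:math:`\mathcal{N}(2, 0.5^2)`, a generator :math:`x' = az + c`, a
logistic discriminator on the hypothetical features :math:`(1, x, x^2)`
so that it can react to both mean and spread, and plain SGD instead of
Adam. Each iteration first takes a discriminator step with labels 1
(real) and 0 (fake), then a generator step with the fake data labeled 1,
just like ``update_D`` followed by ``update_G``. GAN training with plain
gradient steps can oscillate, so treat the final parameters as
illustrative rather than as a guaranteed fit.

.. code:: python

    import math
    import random

    def sigmoid(o):
        """Logistic function written to avoid overflow for large |o|."""
        if o >= 0:
            return 1.0 / (1.0 + math.exp(-o))
        e = math.exp(o)
        return e / (1.0 + e)

    def features(x):
        # Hypothetical quadratic features so D can compare both mean and spread.
        return [1.0, x, x * x]

    def score(w, x):
        return sum(wi * fi for wi, fi in zip(w, features(x)))

    def train_toy_gan(steps=2000, lr_D=0.05, lr_G=0.01, batch=8, seed=0):
        rng = random.Random(seed)
        a, c = 1.0, 0.0      # generator x' = a * z + c
        w = [0.0, 0.0, 0.0]  # discriminator score o = w . features(x)
        for _ in range(steps):
            real = [rng.gauss(2.0, 0.5) for _ in range(batch)]
            zs = [rng.gauss(0.0, 1.0) for _ in range(batch)]
            fake = [a * z + c for z in zs]
            # 1) Discriminator step: real labeled 1, fake labeled 0.
            #    For either label, d(loss)/do = sigmoid(o) - label.
            grad_w = [0.0, 0.0, 0.0]
            for x, label in [(x, 1.0) for x in real] + [(x, 0.0) for x in fake]:
                g = sigmoid(score(w, x)) - label
                grad_w = [gw + g * fi for gw, fi in zip(grad_w, features(x))]
            w = [wi - lr_D * gw / (2 * batch) for wi, gw in zip(w, grad_w)]
            # 2) Generator step against the updated D with fake labeled 1
            #    (loss -log D(G(z))), chain rule through x' = a * z + c.
            grad_a = grad_c = 0.0
            for z in zs:
                x = a * z + c
                dL_dx = (sigmoid(score(w, x)) - 1.0) * (w[1] + 2.0 * w[2] * x)
                grad_a += dL_dx * z
                grad_c += dL_dx
            a -= lr_G * grad_a / batch
            c -= lr_G * grad_c / batch
        return a, c, w

    a, c, w = train_toy_gan()
    print(f"generator: x' = {a:.3f} * z + {c:.3f}   (data: N(2, 0.5**2))")

If the toy run settles, :math:`|a|` should drift toward 0.5 and
:math:`c` toward 2; the sign of :math:`a` is irrelevant because
:math:`z` is symmetric.
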
Now we specify the hyperparameters to fit the Gaussian distribution.
.. code:: python
lr_D, lr_G, latent_dim, num_epochs = 0.05, 0.005, 2, 20
train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G,
latent_dim, data[:100].asnumpy())
.. parsed-literal::
:class: output
loss_D 0.693, loss_G 0.693, 647 examples/sec
.. figure:: output_gan_4e4dd7_18_1.svg
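
Both printed losses are close to 0.693. As a sanity check on that
number, the sketch below computes the per-example losses that
``update_D`` and ``update_G`` would report if the discriminator output
exactly :math:`1/2` for every input: both equal
:math:`\log 2 \approx 0.693`. Reaching this value is consistent with a
discriminator that can no longer tell real from generated points,
although on its own it does not prove that the two distributions match.

.. code:: python

    import math

    d = 0.5  # a discriminator that always outputs probability 1/2
    # update_D: average of the real-data term (y = 1) and the fake-data term (y = 0)
    loss_D = (-math.log(d) - math.log(1 - d)) / 2
    # update_G: generated data labeled y = 1
    loss_G = -math.log(d)
    print(round(loss_D, 3), round(loss_G, 3))  # 0.693 0.693
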
Summary
-------
- Generative adversarial networks (GANs) consist of two deep networks,
  the generator and the discriminator.
- The generator tries to generate images as close as possible to the
  true images in order to fool the discriminator. In practice it does
  so by minimizing the cross-entropy loss with the generated data
  labeled as real, *i.e.*, :math:`\min -\log(D(\mathbf{x'}))`, which is
  the same as :math:`\max \log(D(\mathbf{x'}))`.
- The discriminator tries to distinguish the generated images from the
  true images, via minimizing the cross-entropy loss, *i.e.*,
  :math:`\min - y \log D(\mathbf{x}) - (1-y)\log(1-D(\mathbf{x}))`.
Exercises
---------
- Does an equilibrium exist where the generator wins, *i.e.* the
discriminator ends up unable to distinguish the two distributions on
finite samples?