.. _sec_vgg:

Networks Using Blocks (VGG)
===========================

While AlexNet offered empirical evidence that deep CNNs can achieve good
results, it did not provide a general template to guide subsequent
researchers in designing new networks. In the following sections, we
will introduce several heuristic concepts commonly used to design deep
networks.

Progress in this field mirrors that of VLSI (very large scale
integration) in chip design, where engineers moved from placing
transistors to logical elements to logic blocks :cite:`Mead.1980`.
Similarly, the design of neural network architectures has grown
progressively more abstract, with researchers moving from thinking in
terms of individual neurons to whole layers, and now to blocks,
repeating patterns of layers. A decade later, researchers have gone a
step further and repurpose entire trained models for different, albeit
related, tasks. Such large pretrained models are typically called
*foundation models* :cite:`bommasani2021opportunities`.

Back to network design. The idea of using blocks first emerged from the
Visual Geometry Group (VGG) at Oxford University, in their
eponymously-named *VGG* network :cite:`Simonyan.Zisserman.2014`. It is
easy to implement these repeated structures in code with any modern deep
learning framework by using loops and subroutines.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # PyTorch
   import torch
   from torch import nn
   from d2l import torch as d2l

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # MXNet
   from mxnet import init, np, npx
   from mxnet.gluon import nn
   from d2l import mxnet as d2l
   npx.set_np()

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # JAX
   import jax
   from flax import linen as nn
   from d2l import jax as d2l

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # TensorFlow
   import tensorflow as tf
   from d2l import tensorflow as d2l

.. _subsec_vgg-blocks:

VGG Blocks
----------

The basic building block of CNNs is a sequence of the following: (i) a
convolutional layer with padding to maintain the resolution, (ii) a
nonlinearity such as a ReLU, and (iii) a pooling layer such as max-pooling
to reduce the resolution. One problem with this approach is that the
spatial resolution decreases quite rapidly. In particular, since each
pooling layer halves the resolution, this imposes a hard limit of about
:math:`\log_2 d` convolutional layers on the network before an input of
side length :math:`d` is reduced to a single pixel. For instance, for
ImageNet images of size :math:`224 \times 224` we have
:math:`\log_2 224 \approx 7.8`, so it would be impossible to have more
than 8 convolutional layers in this way.
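
To make this limit concrete, the following minimal pure-Python sketch counts how many :math:`2 \times 2`, stride-2 poolings a square input of side length :math:`d` survives before it is reduced to a single pixel. The helper name ``num_halvings`` is hypothetical, and the sketch assumes that odd sizes are rounded down (e.g., 7 becomes 3).

```python
import math


def num_halvings(d):
    """Count 2x2, stride-2 poolings that fit before side length d drops below 2.

    Hypothetical helper; assumes odd sizes are rounded down (floor), e.g. 7 -> 3.
    """
    count = 0
    while d >= 2:
        d //= 2  # one pooling layer halves the resolution
        count += 1
    return count


d = 224  # ImageNet-style input resolution
print(f"log2({d}) = {math.log2(d):.2f}")
print(f"poolings before reaching 1x1: {num_halvings(d)}")
```

With floor rounding, 224 shrinks to 1 after seven poolings (224, 112, 56, 28, 14, 7, 3, 1), consistent with the rough :math:`\log_2 d` estimate.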

The key idea of :cite:t:`Simonyan.Zisserman.2014` was to use *multiple*
convolutions in between downsampling via max-pooling in the form of a
block. They were primarily interested in whether deep or wide networks
perform better. For instance, the successive application of two
:math:`3 \times 3` convolutions touches the same pixels as a single
:math:`5 \times 5` convolution does. At the same time, assuming
:math:`c` input and :math:`c` output channels, the latter uses
approximately as many parameters (:math:`25 \cdot c^2`) as three
:math:`3 \times 3` convolutions do (:math:`3 \cdot 9 \cdot c^2`), whereas
the two stacked :math:`3 \times 3` convolutions need only
:math:`2 \cdot 9 \cdot c^2 = 18 \cdot c^2`. In a rather detailed
analysis they showed that deep and narrow networks significantly
outperform their shallow counterparts. This set deep learning on a quest
for ever deeper networks, with over 100 layers for typical applications.
Stacking :math:`3 \times 3` convolutions has become a gold standard in
later deep networks (a design decision only revisited recently by
:cite:t:`liu2022convnet`). Consequently, fast implementations for small
convolutions have become a staple on GPUs :cite:`lavin2016fast`.
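
The trade-off between stacked small kernels and a single large kernel can be checked with a little arithmetic. The sketch below is a minimal illustration under stated assumptions: stride-1 convolutions, :math:`c` input and :math:`c` output channels, biases ignored, and a hypothetical channel count of :math:`c = 64`. The helper names ``receptive_field`` and ``conv_weights`` are our own.

```python
def receptive_field(num_layers, kernel_size=3):
    """Side length of the input region seen by stacked stride-1 convolutions."""
    # Each extra k x k layer (stride 1) widens the region by k - 1 pixels.
    return 1 + num_layers * (kernel_size - 1)


def conv_weights(kernel_size, c, num_layers=1):
    """Weights of num_layers convolutions with c input and c output channels (biases ignored)."""
    return num_layers * kernel_size ** 2 * c ** 2


c = 64  # hypothetical channel count, chosen only for illustration
for name, k, n in [("one 5x5", 5, 1), ("two 3x3", 3, 2), ("three 3x3", 3, 3)]:
    rf = receptive_field(n, k)
    print(f"{name}: receptive field {rf}x{rf}, {conv_weights(k, c, n)} weights")
```

For :math:`c = 64` this reports 102400 weights for a single :math:`5 \times 5` layer versus 73728 for two :math:`3 \times 3` layers covering the same :math:`5 \times 5` region, and 110592 for three :math:`3 \times 3` layers, which already cover a :math:`7 \times 7` region. The stacked version also inserts an extra nonlinearity between the two convolutions.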

Back to VGG: a VGG block consists of a *sequence* of convolutions with
:math:`3\times3` kernels with padding of 1 (keeping height and width)
followed by a :math:`2 \times 2` max-pooling layer with stride of 2
(halving height and width after each block). In the code below, we
define a function called ``vgg_block`` to implement one VGG block.
It takes two arguments: the number of convolutional layers
``num_convs`` and the number of output channels (``out_channels`` or
``num_channels``, depending on the framework).

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # PyTorch
   def vgg_block(num_convs, out_channels):
       layers = []
       for _ in range(num_convs):
           # 3x3 convolution with padding 1 keeps height and width unchanged
           layers.append(nn.LazyConv2d(out_channels, kernel_size=3, padding=1))
           layers.append(nn.ReLU())
       # 2x2 max-pooling with stride 2 halves height and width
       layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
       return nn.Sequential(*layers)

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # MXNet
   def vgg_block(num_convs, num_channels):
       blk = nn.Sequential()
       for _ in range(num_convs):
           # 3x3 convolution (with ReLU) and padding 1 keeps height and width
           blk.add(nn.Conv2D(num_channels, kernel_size=3,
                             padding=1, activation='relu'))
       # 2x2 max-pooling with stride 2 halves height and width
       blk.add(nn.MaxPool2D(pool_size=2, strides=2))
       return blk

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # JAX
   def vgg_block(num_convs, out_channels):
       layers = []
       for _ in range(num_convs):
           # 3x3 convolution with padding 1 keeps height and width unchanged
           layers.append(nn.Conv(out_channels, kernel_size=(3, 3), padding=(1, 1)))
           layers.append(nn.relu)
       # 2x2 max-pooling with stride 2 halves height and width
       layers.append(lambda x: nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)))
       return nn.Sequential(layers)

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # TensorFlow
   def vgg_block(num_convs, num_channels):
       blk = tf.keras.models.Sequential()
       for _ in range(num_convs):
           # 'same' padding keeps height and width unchanged
           blk.add(
               tf.keras.layers.Conv2D(num_channels, kernel_size=3,
                                      padding='same', activation='relu'))
       # 2x2 max-pooling with stride 2 halves height and width
       blk.add(tf.keras.layers.MaxPool2D(pool_size=2, strides=2))
       return blk
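
Independently of the framework, the effect of one block on tensor shapes follows from the usual output-size formula :math:`\lfloor (n + 2p - k)/s \rfloor + 1` for a kernel of size :math:`k`, padding :math:`p`, and stride :math:`s`. The sketch below traces a (channels, height, width) shape through one block; the helpers ``conv_out`` and ``vgg_block_shape`` are hypothetical, and the batch dimension and the layout differences between frameworks are ignored.

```python
def conv_out(size, kernel_size, padding=0, stride=1):
    """Output size: floor((size + 2 * padding - kernel_size) / stride) + 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


def vgg_block_shape(in_shape, num_convs, out_channels):
    """Shape (channels, height, width) after one VGG block, computed without a framework."""
    _, h, w = in_shape
    for _ in range(num_convs):
        # 3x3 convolution, padding 1, stride 1: resolution unchanged
        h, w = conv_out(h, 3, padding=1), conv_out(w, 3, padding=1)
    # 2x2 max-pooling with stride 2: resolution halved (rounded down)
    h, w = conv_out(h, 2, stride=2), conv_out(w, 2, stride=2)
    return (out_channels, h, w)


print(vgg_block_shape((1, 224, 224), num_convs=2, out_channels=64))  # (64, 112, 112)
```

Note that the number of convolutions inside a block changes only the channel count and the cost, not the output resolution; only the pooling layer at the end of each block shrinks height and width.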

.. _subsec_vgg-network:

VGG Network
-----------

Like AlexNet and LeNet, the VGG Network can be partitioned into two
parts: the first consisting mostly of convolutional and pooling layers
and the second consisting of fully connected layers that are identical
to those in AlexNet. The key difference is that the convolutional layers
are grouped in nonlinear transformations that leave the dimensionality
unchanged, followed by a resolution-reduction step, as depicted in
:numref:`fig_vgg`.

.. _fig_vgg:

.. figure:: ../img/vgg.svg
   :width: 400px

   From AlexNet to VGG. The key difference is that VGG consists of
   blocks of layers, whereas AlexNet’s layers are all designed
   individually.

The convolutional part of the network connects several VGG blocks from
:numref:`fig_vgg` (as implemented by the ``vgg_block`` function) in
succession. This grouping of convolutions is a pattern that has remained
almost unchanged over the past decade, although the specific choice of
operations has undergone considerable modifications. The variable
``arch`` consists of a list of tuples (one per block), where each
contains two values: the number of convolutional layers and the number
of output channels, which are precisely the arguments required to call
the ``vgg_block`` function. As such, VGG defines a *family* of networks
rather than just a specific manifestation. To build a specific network,
we simply iterate over ``arch`` to compose the blocks.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # PyTorch
   class VGG(d2l.Classifier):
       def __init__(self, arch, lr=0.1, num_classes=10):
           super().__init__()
           self.save_hyperparameters()
           conv_blks = []
           for (num_convs, out_channels) in arch:
               conv_blks.append(vgg_block(num_convs, out_channels))
           # Convolutional blocks followed by AlexNet-style fully connected layers
           self.net = nn.Sequential(
               *conv_blks, nn.Flatten(),
               nn.LazyLinear(4096), nn.ReLU(), nn.Dropout(0.5),
               nn.LazyLinear(4096), nn.ReLU(), nn.Dropout(0.5),
               nn.LazyLinear(num_classes))
           self.net.apply(d2l.init_cnn)

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # MXNet
   class VGG(d2l.Classifier):
       def __init__(self, arch, lr=0.1, num_classes=10):
           super().__init__()
           self.save_hyperparameters()
           self.net = nn.Sequential()
           for (num_convs, num_channels) in arch:
               self.net.add(vgg_block(num_convs, num_channels))
           # Gluon's Dense flattens its input by default, so no Flatten layer is needed
           self.net.add(nn.Dense(4096, activation='relu'), nn.Dropout(0.5),
                        nn.Dense(4096, activation='relu'), nn.Dropout(0.5),
                        nn.Dense(num_classes))
           self.net.initialize(init.Xavier())

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # JAX
   class VGG(d2l.Classifier):
       arch: list
       lr: float = 0.1
       num_classes: int = 10
       training: bool = True

       def setup(self):
           conv_blks = []
           for (num_convs, out_channels) in self.arch:
               conv_blks.append(vgg_block(num_convs, out_channels))
           self.net = nn.Sequential([
               *conv_blks,
               lambda x: x.reshape((x.shape[0], -1)),  # flatten
               nn.Dense(4096), nn.relu,
               nn.Dropout(0.5, deterministic=not self.training),
               nn.Dense(4096), nn.relu,
               nn.Dropout(0.5, deterministic=not self.training),
               nn.Dense(self.num_classes)])

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # TensorFlow
   class VGG(d2l.Classifier):
       def __init__(self, arch, lr=0.1, num_classes=10):
           super().__init__()
           self.save_hyperparameters()
           self.net = tf.keras.models.Sequential()
           for (num_convs, num_channels) in arch:
               self.net.add(vgg_block(num_convs, num_channels))
           self.net.add(
               tf.keras.models.Sequential([
                   tf.keras.layers.Flatten(),
                   tf.keras.layers.Dense(4096, activation='relu'),
                   tf.keras.layers.Dropout(0.5),
                   tf.keras.layers.Dense(4096, activation='relu'),
                   tf.keras.layers.Dropout(0.5),
                   tf.keras.layers.Dense(num_classes)]))

The original VGG network had five convolutional blocks, among which the
first two have one convolutional layer each and the latter three contain
two convolutional layers each. The first block has 64 output channels
and each subsequent block doubles the number of output channels, until
that number reaches 512. Since this network uses eight convolutional
layers and three fully connected layers, it is often called VGG-11.
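
The name follows directly from ``arch``: it counts only the layers with weights (convolutional and fully connected), not pooling, ReLU, or dropout. A minimal sketch, assuming the three fully connected layers of the ``VGG`` class above (the helper ``vgg_depth`` is hypothetical):

```python
def vgg_depth(arch, num_fc_layers=3):
    """Count weight layers: all convolutions listed in arch plus the fully connected head."""
    return sum(num_convs for num_convs, _ in arch) + num_fc_layers


vgg11_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))
print(f"VGG-{vgg_depth(vgg11_arch)}")  # VGG-11
```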

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # PyTorch
   VGG(arch=((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))).layer_summary(
       (1, 1, 224, 224))

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   Sequential output shape: torch.Size([1, 64, 112, 112])
   Sequential output shape: torch.Size([1, 128, 56, 56])
   Sequential output shape: torch.Size([1, 256, 28, 28])
   Sequential output shape: torch.Size([1, 512, 14, 14])
   Sequential output shape: torch.Size([1, 512, 7, 7])
   Flatten output shape: torch.Size([1, 25088])
   Linear output shape: torch.Size([1, 4096])
   ReLU output shape: torch.Size([1, 4096])
   Dropout output shape: torch.Size([1, 4096])
   Linear output shape: torch.Size([1, 4096])
   ReLU output shape: torch.Size([1, 4096])
   Dropout output shape: torch.Size([1, 4096])
   Linear output shape: torch.Size([1, 10])

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # MXNet
   VGG(arch=((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))).layer_summary(
       (1, 1, 224, 224))

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   Sequential output shape: (1, 64, 112, 112)
   Sequential output shape: (1, 128, 56, 56)
   Sequential output shape: (1, 256, 28, 28)
   Sequential output shape: (1, 512, 14, 14)
   Sequential output shape: (1, 512, 7, 7)
   Dense output shape: (1, 4096)
   Dropout output shape: (1, 4096)
   Dense output shape: (1, 4096)
   Dropout output shape: (1, 4096)
   Dense output shape: (1, 10)
   [22:40:53] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # JAX
   VGG(arch=((1, 64), (1, 128), (2, 256), (2, 512), (2, 512)),
       training=False).layer_summary((1, 224, 224, 1))

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   Sequential output shape: (1, 112, 112, 64)
   Sequential output shape: (1, 56, 56, 128)
   Sequential output shape: (1, 28, 28, 256)
   Sequential output shape: (1, 14, 14, 512)
   Sequential output shape: (1, 7, 7, 512)
   function output shape: (1, 25088)
   Dense output shape: (1, 4096)
   custom_jvp output shape: (1, 4096)
   Dropout output shape: (1, 4096)
   Dense output shape: (1, 4096)
   custom_jvp output shape: (1, 4096)
   Dropout output shape: (1, 4096)
   Dense output shape: (1, 10)

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # TensorFlow
   VGG(arch=((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))).layer_summary(
       (1, 224, 224, 1))

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   Sequential output shape: (1, 112, 112, 64)
   Sequential output shape: (1, 56, 56, 128)
   Sequential output shape: (1, 28, 28, 256)
   Sequential output shape: (1, 14, 14, 512)
   Sequential output shape: (1, 7, 7, 512)
   Sequential output shape: (1, 10)

As you can see, we halve height and width at each block, finally
reaching a height and width of 7 before flattening the representations
for processing by the fully connected part of the network.
:cite:t:`Simonyan.Zisserman.2014` described several other variants of
VGG. In fact, it has become the norm to propose *families* of networks
with different speed–accuracy trade-offs when introducing a new
architecture.
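
The shapes in the summaries above also determine where the parameters live. The following bookkeeping sketch (the helper ``trace_and_count`` is hypothetical) traces the resolution through the blocks and counts weights plus biases, assuming :math:`3 \times 3` convolutions, a single input channel, floor-rounded :math:`2 \times 2` pooling, hidden layers of width 4096, and 10 classes, as in the ``VGG`` class above.

```python
def trace_and_count(arch, in_channels=1, size=224, hidden=4096, num_classes=10):
    """Return final (channels, size), flattened length, conv params, and FC params."""
    conv_params, c = 0, in_channels
    for num_convs, out_channels in arch:
        for _ in range(num_convs):
            # 3x3 kernel: c * out_channels * 9 weights plus out_channels biases
            conv_params += c * out_channels * 9 + out_channels
            c = out_channels
        size //= 2  # 2x2 max-pooling with stride 2
    flat = c * size * size  # length of the flattened feature vector
    fc_params = 0
    for fan_in, fan_out in [(flat, hidden), (hidden, hidden), (hidden, num_classes)]:
        fc_params += fan_in * fan_out + fan_out  # weights plus biases
    return (c, size), flat, conv_params, fc_params


arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))
(c, s), flat, conv_p, fc_p = trace_and_count(arch)
print(f"final feature map: {c} x {s} x {s}, flattened to {flat}")
print(f"conv parameters: {conv_p:,}")
print(f"fully connected parameters: {fc_p:,}")
```

Under these assumptions the flattened length is 25088, matching the summaries, and the first fully connected layer alone holds :math:`25088 \times 4096` weights. As a result, the dense layers account for more than 90% of the roughly 129 million parameters, even though nearly all of the computation happens in the convolutions.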

Training
--------

Since VGG-11 is computationally more demanding than AlexNet, we construct
a network with a smaller number of channels. This is more than
sufficient for training on Fashion-MNIST. The model training process is
similar to that of AlexNet in :numref:`sec_alexnet`. Again, observe the
close match between validation and training loss, suggesting only a
small amount of overfitting.
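
To quantify the savings from the narrower network, one can count multiply-accumulate operations (MACs) in the convolutional layers. This is a rough sketch: the helper ``conv_macs`` is hypothetical, it assumes single-channel :math:`224 \times 224` inputs as in the resized Fashion-MNIST data, and it ignores biases, ReLUs, pooling, and the fully connected layers.

```python
def conv_macs(arch, in_channels=1, size=224):
    """Multiply-accumulates in the 3x3 convolutions only (pooling, ReLU, biases ignored)."""
    macs, c = 0, in_channels
    for num_convs, out_channels in arch:
        for _ in range(num_convs):
            # padding 1 keeps the resolution, so each output pixel costs c * out_channels * 9
            macs += size * size * c * out_channels * 9
            c = out_channels
        size //= 2  # 2x2 max-pooling with stride 2
    return macs


vgg11 = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))
small = ((1, 16), (1, 32), (2, 64), (2, 128), (2, 128))
full, reduced = conv_macs(vgg11), conv_macs(small)
print(f"VGG-11: {full / 1e9:.2f} GMACs, narrow variant: {reduced / 1e9:.2f} GMACs, "
      f"ratio {full / reduced:.1f}")
```

Because the cost of each convolution scales with the product of its input and output channels, shrinking every width by a factor of 4 cuts the convolutional work by nearly 16 (about 7.43 versus 0.47 GMACs). Only the first layer, with its single input channel, saves just a factor of 4.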

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # PyTorch
   model = VGG(arch=((1, 16), (1, 32), (2, 64), (2, 128), (2, 128)), lr=0.01)
   trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
   data = d2l.FashionMNIST(batch_size=128, resize=(224, 224))
   model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)
   trainer.fit(model, data)

.. figure:: output_vgg_4a7574_63_0.svg

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # MXNet
   model = VGG(arch=((1, 16), (1, 32), (2, 64), (2, 128), (2, 128)), lr=0.01)
   trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
   data = d2l.FashionMNIST(batch_size=128, resize=(224, 224))
   trainer.fit(model, data)

.. figure:: output_vgg_4a7574_66_0.svg

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # JAX
   model = VGG(arch=((1, 16), (1, 32), (2, 64), (2, 128), (2, 128)), lr=0.01)
   trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
   data = d2l.FashionMNIST(batch_size=128, resize=(224, 224))
   trainer.fit(model, data)

.. figure:: output_vgg_4a7574_69_0.svg

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # TensorFlow
   trainer = d2l.Trainer(max_epochs=10)
   data = d2l.FashionMNIST(batch_size=128, resize=(224, 224))
   with d2l.try_gpu():
       model = VGG(arch=((1, 16), (1, 32), (2, 64), (2, 128), (2, 128)), lr=0.01)
       trainer.fit(model, data)

.. figure:: output_vgg_4a7574_72_0.svg

Summary
-------

One might argue that VGG is the first truly modern convolutional neural
network. While AlexNet introduced many of the components that make
deep learning effective at scale, it is VGG that arguably introduced key
properties such as blocks of multiple convolutions and a preference for
deep and narrow networks. It is also the first network that is actually
an entire family of similarly parametrized models, giving the
practitioner ample choice in trading off complexity against speed. This
is also where modern deep learning frameworks shine: it is no longer
necessary to generate XML configuration files to specify a network;
instead, networks can be assembled with simple Python code.

More recently, ParNet :cite:`Goyal.Bochkovskiy.Deng.ea.2021`
demonstrated that it is possible to achieve competitive performance
using a much shallower architecture through a large number of
parallel computations. This is an exciting development, and there is
hope that it will influence architecture designs in the future. For the
remainder of the chapter, though, we will follow the path of scientific
progress over the past decade.

Exercises
---------

1. Compared with AlexNet, VGG is much slower in terms of computation,
   and it also needs more GPU memory.

   1. Compare the number of parameters needed for AlexNet and VGG.
   2. Compare the number of floating point operations used in the
      convolutional layers and in the fully connected layers.
   3. How could you reduce the computational cost created by the fully
      connected layers?

2. When displaying the dimensions associated with the various layers of
   the network, we only see the information associated with eight blocks
   (plus some auxiliary transforms), even though the network has 11
   layers. Where did the remaining three layers go?

3. Use Table 1 in the VGG paper :cite:`Simonyan.Zisserman.2014` to
   construct other common models, such as VGG-16 or VGG-19.

4. Upsampling the resolution in Fashion-MNIST eight-fold from
   :math:`28 \times 28` to :math:`224 \times 224` dimensions is very
   wasteful. Try modifying the network architecture and resolution
   conversion, e.g., to 56 or to 84 dimensions for its input instead.
   Can you do so without reducing the accuracy of the network? Consult
   the VGG paper :cite:`Simonyan.Zisserman.2014` for ideas on adding
   more nonlinearities prior to downsampling.
`Discussions `__