8.4. Multi-Branch Networks (GoogLeNet)
In 2014, GoogLeNet won the ImageNet Challenge (Szegedy et al., 2015), using a structure that combined the strengths of NiN (Lin et al., 2013), repeated blocks (Simonyan and Zisserman, 2014), and a cocktail of convolution kernels. It was arguably also the first network that exhibited a clear distinction among the stem (data ingest), body (data processing), and head (prediction) in a CNN. This design pattern has persisted ever since in the design of deep networks: the stem is given by the first two or three convolutions that operate on the image. They extract low-level features from the underlying images. This is followed by a body of convolutional blocks. Finally, the head maps the features obtained so far to the required classification, segmentation, detection, or tracking problem at hand.
The key contribution in GoogLeNet was the design of the network body. It solved the problem of selecting convolution kernels in an ingenious way. While other works tried to identify which convolution, ranging from \(1 \times 1\) to \(11 \times 11\), would be best, it simply concatenated multi-branch convolutions. In what follows we introduce a slightly simplified version of GoogLeNet: the original design included a number of tricks for stabilizing training through intermediate loss functions, applied to multiple layers of the network. They are no longer necessary due to the availability of improved training algorithms.
# PyTorch
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

# MXNet
from mxnet import init, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()

# JAX
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l

# TensorFlow
import tensorflow as tf
from d2l import tensorflow as d2l
8.4.1. Inception Blocks
The basic convolutional block in GoogLeNet is called an Inception block, stemming from the meme “we need to go deeper” from the movie Inception.
As depicted in Fig. 8.4.1, the Inception block consists of four parallel branches. The first three branches use convolutional layers with window sizes of \(1\times 1\), \(3\times 3\), and \(5\times 5\) to extract information from different spatial sizes. The middle two branches also add a \(1\times 1\) convolution of the input to reduce the number of channels, reducing the model’s complexity. The fourth branch uses a \(3\times 3\) max-pooling layer, followed by a \(1\times 1\) convolutional layer to change the number of channels. The four branches all use appropriate padding to give the input and output the same height and width. Finally, the outputs along each branch are concatenated along the channel dimension and comprise the block’s output. The commonly tuned hyperparameters of the Inception block are the number of output channels per layer, i.e., how to allocate capacity among convolutions of different sizes.
# PyTorch
class Inception(nn.Module):
    # c1--c4 are the number of output channels for each branch
    def __init__(self, c1, c2, c3, c4, **kwargs):
        super(Inception, self).__init__(**kwargs)
        # Branch 1
        self.b1_1 = nn.LazyConv2d(c1, kernel_size=1)
        # Branch 2
        self.b2_1 = nn.LazyConv2d(c2[0], kernel_size=1)
        self.b2_2 = nn.LazyConv2d(c2[1], kernel_size=3, padding=1)
        # Branch 3
        self.b3_1 = nn.LazyConv2d(c3[0], kernel_size=1)
        self.b3_2 = nn.LazyConv2d(c3[1], kernel_size=5, padding=2)
        # Branch 4
        self.b4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.b4_2 = nn.LazyConv2d(c4, kernel_size=1)

    def forward(self, x):
        b1 = F.relu(self.b1_1(x))
        b2 = F.relu(self.b2_2(F.relu(self.b2_1(x))))
        b3 = F.relu(self.b3_2(F.relu(self.b3_1(x))))
        b4 = F.relu(self.b4_2(self.b4_1(x)))
        return torch.cat((b1, b2, b3, b4), dim=1)

# MXNet
class Inception(nn.Block):
    # c1--c4 are the number of output channels for each branch
    def __init__(self, c1, c2, c3, c4, **kwargs):
        super(Inception, self).__init__(**kwargs)
        # Branch 1
        self.b1_1 = nn.Conv2D(c1, kernel_size=1, activation='relu')
        # Branch 2
        self.b2_1 = nn.Conv2D(c2[0], kernel_size=1, activation='relu')
        self.b2_2 = nn.Conv2D(c2[1], kernel_size=3, padding=1,
                              activation='relu')
        # Branch 3
        self.b3_1 = nn.Conv2D(c3[0], kernel_size=1, activation='relu')
        self.b3_2 = nn.Conv2D(c3[1], kernel_size=5, padding=2,
                              activation='relu')
        # Branch 4
        self.b4_1 = nn.MaxPool2D(pool_size=3, strides=1, padding=1)
        self.b4_2 = nn.Conv2D(c4, kernel_size=1, activation='relu')

    def forward(self, x):
        b1 = self.b1_1(x)
        b2 = self.b2_2(self.b2_1(x))
        b3 = self.b3_2(self.b3_1(x))
        b4 = self.b4_2(self.b4_1(x))
        return np.concatenate((b1, b2, b3, b4), axis=1)

# JAX
class Inception(nn.Module):
    # `c1`--`c4` are the number of output channels for each branch
    c1: int
    c2: tuple
    c3: tuple
    c4: int

    def setup(self):
        # Branch 1
        self.b1_1 = nn.Conv(self.c1, kernel_size=(1, 1))
        # Branch 2
        self.b2_1 = nn.Conv(self.c2[0], kernel_size=(1, 1))
        self.b2_2 = nn.Conv(self.c2[1], kernel_size=(3, 3), padding='same')
        # Branch 3
        self.b3_1 = nn.Conv(self.c3[0], kernel_size=(1, 1))
        self.b3_2 = nn.Conv(self.c3[1], kernel_size=(5, 5), padding='same')
        # Branch 4
        self.b4_1 = lambda x: nn.max_pool(x, window_shape=(3, 3),
                                          strides=(1, 1), padding='same')
        self.b4_2 = nn.Conv(self.c4, kernel_size=(1, 1))

    def __call__(self, x):
        b1 = nn.relu(self.b1_1(x))
        b2 = nn.relu(self.b2_2(nn.relu(self.b2_1(x))))
        b3 = nn.relu(self.b3_2(nn.relu(self.b3_1(x))))
        b4 = nn.relu(self.b4_2(self.b4_1(x)))
        return jnp.concatenate((b1, b2, b3, b4), axis=-1)

# TensorFlow
class Inception(tf.keras.Model):
    # c1--c4 are the number of output channels for each branch
    def __init__(self, c1, c2, c3, c4):
        super().__init__()
        self.b1_1 = tf.keras.layers.Conv2D(c1, 1, activation='relu')
        self.b2_1 = tf.keras.layers.Conv2D(c2[0], 1, activation='relu')
        self.b2_2 = tf.keras.layers.Conv2D(c2[1], 3, padding='same',
                                           activation='relu')
        self.b3_1 = tf.keras.layers.Conv2D(c3[0], 1, activation='relu')
        self.b3_2 = tf.keras.layers.Conv2D(c3[1], 5, padding='same',
                                           activation='relu')
        self.b4_1 = tf.keras.layers.MaxPool2D(3, 1, padding='same')
        self.b4_2 = tf.keras.layers.Conv2D(c4, 1, activation='relu')

    def call(self, x):
        b1 = self.b1_1(x)
        b2 = self.b2_2(self.b2_1(x))
        b3 = self.b3_2(self.b3_1(x))
        b4 = self.b4_2(self.b4_1(x))
        return tf.keras.layers.Concatenate()([b1, b2, b3, b4])
To gain some intuition for why this network works so well, consider the combination of the filters. They explore the image in a variety of filter sizes. This means that details at different extents can be recognized efficiently by filters of different sizes. At the same time, we can allocate different amounts of parameters for different filters.
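To make the last point concrete, the following sketch counts the weights and biases in each branch of an Inception block. It assumes 192 input channels and the channel settings `64, (96, 128), (16, 32), 32`, which are those of the first Inception block in Section 8.4.2. The helpers `conv_params` and `inception_params` are names made up for this illustration and are not part of d2l or any framework.

```python
def conv_params(c_in, c_out, k):
    """Weights plus biases of a k x k convolution."""
    return c_in * c_out * k * k + c_out


def inception_params(c_in, c1, c2, c3, c4):
    """Per-branch parameter counts of an Inception block (illustrative)."""
    return {
        'branch 1 (1x1)': conv_params(c_in, c1, 1),
        'branch 2 (1x1 -> 3x3)': (conv_params(c_in, c2[0], 1)
                                  + conv_params(c2[0], c2[1], 3)),
        'branch 3 (1x1 -> 5x5)': (conv_params(c_in, c3[0], 1)
                                  + conv_params(c3[0], c3[1], 5)),
        # Max-pooling has no parameters; only the 1x1 convolution counts
        'branch 4 (pool -> 1x1)': conv_params(c_in, c4, 1),
    }


# Assumed setting: 192 input channels, first Inception block of the model
counts = inception_params(192, 64, (96, 128), (16, 32), 32)
for name, n in counts.items():
    print(f'{name}: {n}')
total = sum(counts.values())
print('total:', total)

# The same 3x3 and 5x5 output widths without the 1x1 reductions
no_reduction_3x3 = conv_params(192, 128, 3)
no_reduction_5x5 = conv_params(192, 32, 5)
print('3x3 without reduction:', no_reduction_3x3)
print('5x5 without reduction:', no_reduction_5x5)
```

Under these assumptions the \(5\times 5\) branch needs 15,920 parameters instead of 153,632, and the \(3\times 3\) branch 129,248 instead of 221,312, which illustrates how the \(1\times 1\) reductions keep the wide-kernel branches affordable.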
8.4.2. GoogLeNet Model
As shown in Fig. 8.4.2, GoogLeNet uses a stack of nine Inception blocks, arranged into three groups with max-pooling in between, and global average pooling in its head to generate its estimates. Max-pooling between Inception blocks reduces the dimensionality. At its stem, the first module is similar to AlexNet and LeNet.
We can now implement GoogLeNet piece by piece. Let’s begin with the stem. The first module uses a 64-channel \(7\times 7\) convolutional layer.
# PyTorch
class GoogleNet(d2l.Classifier):
    def b1(self):
        return nn.Sequential(
            nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(), nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

# MXNet
class GoogleNet(d2l.Classifier):
    def b1(self):
        net = nn.Sequential()
        net.add(nn.Conv2D(64, kernel_size=7, strides=2, padding=3,
                          activation='relu'),
                nn.MaxPool2D(pool_size=3, strides=2, padding=1))
        return net

# JAX
class GoogleNet(d2l.Classifier):
    lr: float = 0.1
    num_classes: int = 10

    def setup(self):
        self.net = nn.Sequential([self.b1(), self.b2(), self.b3(), self.b4(),
                                  self.b5(), nn.Dense(self.num_classes)])

    def b1(self):
        return nn.Sequential([
            nn.Conv(64, kernel_size=(7, 7), strides=(2, 2), padding='same'),
            nn.relu,
            lambda x: nn.max_pool(x, window_shape=(3, 3), strides=(2, 2),
                                  padding='same')])

# TensorFlow
class GoogleNet(d2l.Classifier):
    def b1(self):
        return tf.keras.models.Sequential([
            tf.keras.layers.Conv2D(64, 7, strides=2, padding='same',
                                   activation='relu'),
            tf.keras.layers.MaxPool2D(pool_size=3, strides=2,
                                      padding='same')])
The second module uses two convolutional layers: first, a 64-channel \(1\times 1\) convolutional layer, followed by a \(3\times 3\) convolutional layer that triples the number of channels. This corresponds to the second branch in the Inception block and concludes the design of the stem. At this point we have 192 channels.
# PyTorch
@d2l.add_to_class(GoogleNet)
def b2(self):
    return nn.Sequential(
        nn.LazyConv2d(64, kernel_size=1), nn.ReLU(),
        nn.LazyConv2d(192, kernel_size=3, padding=1), nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

# MXNet
@d2l.add_to_class(GoogleNet)
def b2(self):
    net = nn.Sequential()
    net.add(nn.Conv2D(64, kernel_size=1, activation='relu'),
            nn.Conv2D(192, kernel_size=3, padding=1, activation='relu'),
            nn.MaxPool2D(pool_size=3, strides=2, padding=1))
    return net

# JAX
@d2l.add_to_class(GoogleNet)
def b2(self):
    return nn.Sequential([nn.Conv(64, kernel_size=(1, 1)),
                          nn.relu,
                          nn.Conv(192, kernel_size=(3, 3), padding='same'),
                          nn.relu,
                          lambda x: nn.max_pool(x, window_shape=(3, 3),
                                                strides=(2, 2),
                                                padding='same')])

# TensorFlow
@d2l.add_to_class(GoogleNet)
def b2(self):
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(64, 1, activation='relu'),
        tf.keras.layers.Conv2D(192, 3, padding='same', activation='relu'),
        tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')])
The third module connects two complete Inception blocks in series. The number of output channels of the first Inception block is \(64+128+32+32=256\). This amounts to a ratio of the number of output channels among the four branches of \(2:4:1:1\). To achieve this, we first reduce the input dimensions by \(\frac{1}{2}\) and by \(\frac{1}{12}\) in the second and third branch respectively to arrive at \(96 = 192/2\) and \(16 = 192/12\) channels respectively.
The number of output channels of the second Inception block is increased to \(128+192+96+64=480\), yielding a ratio of \(128:192:96:64 = 4:6:3:2\). As before, we need to reduce the number of intermediate dimensions in the second and third branches. A scale of \(\frac{1}{2}\) and \(\frac{1}{8}\), respectively, suffices, yielding \(128\) and \(32\) channels. This is captured by the arguments of the following Inception block constructors.
# PyTorch
@d2l.add_to_class(GoogleNet)
def b3(self):
    return nn.Sequential(Inception(64, (96, 128), (16, 32), 32),
                         Inception(128, (128, 192), (32, 96), 64),
                         nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

# MXNet
@d2l.add_to_class(GoogleNet)
def b3(self):
    net = nn.Sequential()
    net.add(Inception(64, (96, 128), (16, 32), 32),
            Inception(128, (128, 192), (32, 96), 64),
            nn.MaxPool2D(pool_size=3, strides=2, padding=1))
    return net

# JAX
@d2l.add_to_class(GoogleNet)
def b3(self):
    return nn.Sequential([Inception(64, (96, 128), (16, 32), 32),
                          Inception(128, (128, 192), (32, 96), 64),
                          lambda x: nn.max_pool(x, window_shape=(3, 3),
                                                strides=(2, 2),
                                                padding='same')])

# TensorFlow
@d2l.add_to_class(GoogleNet)
def b3(self):
    return tf.keras.models.Sequential([
        Inception(64, (96, 128), (16, 32), 32),
        Inception(128, (128, 192), (32, 96), 64),
        tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')])
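The ratios quoted for the third module are easy to verify. The sketch below uses only the standard library, with the channel numbers taken from the two `Inception` constructors above. The helpers `branch_ratio` and `reduction` are hypothetical names used only for this check.

```python
from fractions import Fraction
from functools import reduce
from math import gcd


def branch_ratio(c1, c2, c3, c4):
    """Output channels of the four branches, reduced to lowest terms."""
    outs = (c1, c2[1], c3[1], c4)
    g = reduce(gcd, outs)
    return tuple(c // g for c in outs)


def reduction(c_in, c2, c3):
    """Fraction of input channels kept by the 1x1 layers in branches 2, 3."""
    return Fraction(c2[0], c_in), Fraction(c3[0], c_in)


# First block: 192 input channels (the output of the second module)
block1 = (64, (96, 128), (16, 32), 32)
out1 = block1[0] + block1[1][1] + block1[2][1] + block1[3]
ratio1 = branch_ratio(*block1)
red1 = reduction(192, block1[1], block1[2])

# Second block: its input is the concatenated output of the first block
block2 = (128, (128, 192), (32, 96), 64)
out2 = block2[0] + block2[1][1] + block2[2][1] + block2[3]
ratio2 = branch_ratio(*block2)
red2 = reduction(out1, block2[1], block2[2])

print(out1, ratio1, red1)
print(out2, ratio2, red2)
```

Note that the reduction factors of the second block are measured against its own input of 256 channels, not against the 192 channels entering the module.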
The fourth module is more complicated. It connects five Inception blocks in series, and they have \(192+208+48+64=512\), \(160+224+64+64=512\), \(128+256+64+64=512\), \(112+288+64+64=528\), and \(256+320+128+128=832\) output channels, respectively. The number of channels assigned to these branches is similar to that in the third module: the second branch with the \(3\times 3\) convolutional layer outputs the largest number of channels, followed by the first branch with only the \(1\times 1\) convolutional layer, the third branch with the \(5\times 5\) convolutional layer, and the fourth branch with the \(3\times 3\) max-pooling layer. The second and third branches will first reduce the number of channels according to the ratio. These ratios are slightly different in different Inception blocks.
# PyTorch
@d2l.add_to_class(GoogleNet)
def b4(self):
    return nn.Sequential(Inception(192, (96, 208), (16, 48), 64),
                         Inception(160, (112, 224), (24, 64), 64),
                         Inception(128, (128, 256), (24, 64), 64),
                         Inception(112, (144, 288), (32, 64), 64),
                         Inception(256, (160, 320), (32, 128), 128),
                         nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

# MXNet
@d2l.add_to_class(GoogleNet)
def b4(self):
    net = nn.Sequential()
    net.add(Inception(192, (96, 208), (16, 48), 64),
            Inception(160, (112, 224), (24, 64), 64),
            Inception(128, (128, 256), (24, 64), 64),
            Inception(112, (144, 288), (32, 64), 64),
            Inception(256, (160, 320), (32, 128), 128),
            nn.MaxPool2D(pool_size=3, strides=2, padding=1))
    return net

# JAX
@d2l.add_to_class(GoogleNet)
def b4(self):
    return nn.Sequential([Inception(192, (96, 208), (16, 48), 64),
                          Inception(160, (112, 224), (24, 64), 64),
                          Inception(128, (128, 256), (24, 64), 64),
                          Inception(112, (144, 288), (32, 64), 64),
                          Inception(256, (160, 320), (32, 128), 128),
                          lambda x: nn.max_pool(x, window_shape=(3, 3),
                                                strides=(2, 2),
                                                padding='same')])

# TensorFlow
@d2l.add_to_class(GoogleNet)
def b4(self):
    return tf.keras.Sequential([
        Inception(192, (96, 208), (16, 48), 64),
        Inception(160, (112, 224), (24, 64), 64),
        Inception(128, (128, 256), (24, 64), 64),
        Inception(112, (144, 288), (32, 64), 64),
        Inception(256, (160, 320), (32, 128), 128),
        tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')])
The fifth module has two Inception blocks with \(256+320+128+128=832\) and \(384+384+128+128=1024\) output channels. The channels are allocated across the branches in the same way as in the third and fourth modules, although the specific values differ. It should be noted that the fifth block is followed by the output layer. This block uses the global average pooling layer to change the height and width of each channel to 1, just as in NiN. Finally, we turn the output into a two-dimensional array followed by a fully connected layer whose number of outputs is the number of label classes.
# PyTorch
@d2l.add_to_class(GoogleNet)
def b5(self):
    return nn.Sequential(Inception(256, (160, 320), (32, 128), 128),
                         Inception(384, (192, 384), (48, 128), 128),
                         nn.AdaptiveAvgPool2d((1,1)), nn.Flatten())

# MXNet
@d2l.add_to_class(GoogleNet)
def b5(self):
    net = nn.Sequential()
    net.add(Inception(256, (160, 320), (32, 128), 128),
            Inception(384, (192, 384), (48, 128), 128),
            nn.GlobalAvgPool2D())
    return net

# JAX
@d2l.add_to_class(GoogleNet)
def b5(self):
    return nn.Sequential([Inception(256, (160, 320), (32, 128), 128),
                          Inception(384, (192, 384), (48, 128), 128),
                          # Flax does not provide a GlobalAvgPool2D layer
                          lambda x: nn.avg_pool(x,
                                                window_shape=x.shape[1:3],
                                                strides=x.shape[1:3],
                                                padding='valid'),
                          lambda x: x.reshape((x.shape[0], -1))])

# TensorFlow
@d2l.add_to_class(GoogleNet)
def b5(self):
    return tf.keras.Sequential([
        Inception(256, (160, 320), (32, 128), 128),
        Inception(384, (192, 384), (48, 128), 128),
        tf.keras.layers.GlobalAvgPool2D(),
        tf.keras.layers.Flatten()])
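Global average pooling itself is a simple operation: every channel's feature map is replaced by its mean, so an input of shape (batch, channels, height, width) becomes (batch, channels) once flattened. The following framework-free sketch shows this on nested Python lists in the channels-first layout. The function `global_avg_pool` is a name chosen for this illustration and is not a d2l or framework function.

```python
def global_avg_pool(x):
    """Average each channel's map; x is nested lists shaped (N, C, H, W)."""
    pooled = []
    for example in x:
        row = []
        for channel in example:
            # Collect all H * W values of this channel and take their mean
            values = [v for line in channel for v in line]
            row.append(sum(values) / len(values))
        pooled.append(row)
    return pooled  # Shape (N, C): already flat, ready for a dense layer


# One example with two channels of size 2 x 2
x = [[[[1.0, 2.0], [3.0, 4.0]],
      [[0.0, 0.0], [2.0, 2.0]]]]
features = global_avg_pool(x)
print(features)  # [[2.5, 1.0]]
```

Because the result no longer depends on the height and width, the fully connected layer that follows needs only as many inputs as there are channels (1024 in this model).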
Now that we have defined all blocks b1 through b5, it is just a matter of assembling them all into a full network.
# PyTorch
@d2l.add_to_class(GoogleNet)
def __init__(self, lr=0.1, num_classes=10):
    super(GoogleNet, self).__init__()
    self.save_hyperparameters()
    self.net = nn.Sequential(self.b1(), self.b2(), self.b3(), self.b4(),
                             self.b5(), nn.LazyLinear(num_classes))
    self.net.apply(d2l.init_cnn)

# MXNet
@d2l.add_to_class(GoogleNet)
def __init__(self, lr=0.1, num_classes=10):
    super(GoogleNet, self).__init__()
    self.save_hyperparameters()
    self.net = nn.Sequential()
    self.net.add(self.b1(), self.b2(), self.b3(), self.b4(), self.b5(),
                 nn.Dense(num_classes))
    self.net.initialize(init.Xavier())

# TensorFlow
@d2l.add_to_class(GoogleNet)
def __init__(self, lr=0.1, num_classes=10):
    super(GoogleNet, self).__init__()
    self.save_hyperparameters()
    self.net = tf.keras.Sequential([
        self.b1(), self.b2(), self.b3(), self.b4(), self.b5(),
        tf.keras.layers.Dense(num_classes)])
The GoogLeNet model is computationally complex. Note the large number of relatively arbitrary hyperparameters in terms of the number of channels chosen, the number of blocks prior to dimensionality reduction, the relative partitioning of capacity across channels, etc. Much of it is due to the fact that at the time when GoogLeNet was introduced, automatic tools for network definition or design exploration were not yet available. For instance, by now we take it for granted that a competent deep learning framework is capable of inferring dimensionalities of input tensors automatically. At the time, many such configurations had to be specified explicitly by the experimenter, thus often slowing down active experimentation. Moreover, the tools needed for automatic exploration were still in flux and initial experiments largely amounted to costly brute-force exploration, genetic algorithms, and similar strategies.
For now the only modification we will carry out is to reduce the input height and width from 224 to 96 to have a reasonable training time on Fashion-MNIST. This simplifies the computation. Let’s have a look at the changes in the shape of the output between the various modules.
model = GoogleNet().layer_summary((1, 1, 96, 96))
Sequential output shape: torch.Size([1, 64, 24, 24])
Sequential output shape: torch.Size([1, 192, 12, 12])
Sequential output shape: torch.Size([1, 480, 6, 6])
Sequential output shape: torch.Size([1, 832, 3, 3])
Sequential output shape: torch.Size([1, 1024])
Linear output shape: torch.Size([1, 10])
model = GoogleNet().layer_summary((1, 1, 96, 96))
Sequential output shape: (1, 64, 24, 24)
Sequential output shape: (1, 192, 12, 12)
Sequential output shape: (1, 480, 6, 6)
Sequential output shape: (1, 832, 3, 3)
Sequential output shape: (1, 1024, 1, 1)
Dense output shape: (1, 10)
model = GoogleNet().layer_summary((1, 96, 96, 1))
Sequential output shape: (1, 24, 24, 64)
Sequential output shape: (1, 12, 12, 192)
Sequential output shape: (1, 6, 6, 480)
Sequential output shape: (1, 3, 3, 832)
Sequential output shape: (1, 1024)
Dense output shape: (1, 10)
model = GoogleNet().layer_summary((1, 96, 96, 1))
Sequential output shape: (1, 24, 24, 64)
Sequential output shape: (1, 12, 12, 192)
Sequential output shape: (1, 6, 6, 480)
Sequential output shape: (1, 3, 3, 832)
Sequential output shape: (1, 1024)
Dense output shape: (1, 10)
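The shapes reported above can be reproduced by hand. The sketch below assumes the usual output-size formula \(\lfloor (n + 2p - k)/s \rfloor + 1\) for convolution and pooling, with the kernel, stride, and padding values of the PyTorch definitions; for a 96-pixel input the `'same'` padding of the JAX and TensorFlow variants gives the same sizes, as the summaries confirm. The helpers `out_size` and `trace` are hypothetical names for this illustration.

```python
def out_size(n, k, s, p):
    """Output height/width of a convolution or pooling layer."""
    return (n + 2 * p - k) // s + 1


def trace(n):
    """Return (channels, size) after each of the five modules."""
    shapes = []
    # b1: 7x7 conv (stride 2, padding 3), then 3x3 max-pool (stride 2)
    n = out_size(out_size(n, 7, 2, 3), 3, 2, 1)
    shapes.append((64, n))
    # b2: the 1x1 and padded 3x3 convs keep the size; the pool halves it
    n = out_size(n, 3, 2, 1)
    shapes.append((192, n))
    # b3, b4: Inception blocks keep the size; each module ends in a pool
    n = out_size(n, 3, 2, 1)
    shapes.append((128 + 192 + 96 + 64, n))
    n = out_size(n, 3, 2, 1)
    shapes.append((256 + 320 + 128 + 128, n))
    # b5: global average pooling collapses the spatial dimensions
    shapes.append((384 + 384 + 128 + 128, 1))
    return shapes


shapes = trace(96)
for channels, size in shapes:
    print(channels, size)
```

Calling `trace(224)` instead gives the spatial sizes for the original 224-pixel input.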
8.4.3. Training
As before, we train our model using the Fashion-MNIST dataset. We transform it to \(96 \times 96\) pixel resolution before invoking the training procedure.
# PyTorch
model = GoogleNet(lr=0.01)
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128, resize=(96, 96))
model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)
trainer.fit(model, data)

# MXNet
model = GoogleNet(lr=0.01)
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128, resize=(96, 96))
trainer.fit(model, data)

# JAX
model = GoogleNet(lr=0.01)
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128, resize=(96, 96))
trainer.fit(model, data)

# TensorFlow
trainer = d2l.Trainer(max_epochs=10)
data = d2l.FashionMNIST(batch_size=128, resize=(96, 96))
with d2l.try_gpu():
    model = GoogleNet(lr=0.01)
    trainer.fit(model, data)
8.4.4. Discussion
A key feature of GoogLeNet is that it is actually cheaper to compute than its predecessors while simultaneously providing improved accuracy. This marks the beginning of a much more deliberate network design that trades off the cost of evaluating a network with a reduction in errors. It also marks the beginning of experimentation at a block level with network design hyperparameters, even though it was entirely manual at the time. We will revisit this topic in Section 8.8 when discussing strategies for network structure exploration.
Over the following sections we will encounter a number of design choices (e.g., batch normalization, residual connections, and channel grouping) that allow us to improve networks significantly. For now, you can be proud to have implemented what is arguably the first truly modern CNN.
8.4.5. Exercises
1. GoogLeNet was so successful that it went through a number of iterations, progressively improving speed and accuracy. Try to implement and run some of them. They include the following:
    1. Add a batch normalization layer (Ioffe and Szegedy, 2015), as described later in Section 8.5.
    2. Make adjustments to the Inception block (width, choice and order of convolutions), as described in Szegedy et al. (2016).
    3. Use label smoothing for model regularization, as described in Szegedy et al. (2016).
    4. Make further adjustments to the Inception block by adding residual connections (Szegedy et al., 2017), as described later in Section 8.6.
2. What is the minimum image size needed for GoogLeNet to work?
3. Can you design a variant of GoogLeNet that works on Fashion-MNIST’s native resolution of \(28 \times 28\) pixels? How would you need to change the stem, the body, and the head of the network, if anything at all?
4. Compare the model parameter sizes of AlexNet, VGG, NiN, and GoogLeNet. How do the latter two network architectures significantly reduce the model parameter size?
5. Compare the amount of computation needed in GoogLeNet and AlexNet. How does this affect the design of an accelerator chip, e.g., in terms of memory size, memory bandwidth, cache size, the amount of computation, and the benefit of specialized operations?