# 8.5. Batch Normalization¶

Training deep neural networks is difficult. Getting them to converge in
a reasonable amount of time can be tricky. In this section, we describe
*batch normalization*, a popular and effective technique that
consistently accelerates the convergence of deep networks
(Ioffe and Szegedy, 2015). Together with residual blocks—covered
later in Section 8.6—batch normalization has made it
possible for practitioners to routinely train networks with over 100
layers. A secondary (serendipitous) benefit of batch normalization lies
in its inherent regularization.

## 8.5.1. Training Deep Networks¶

When working with data, we often preprocess before training. Choices
regarding data preprocessing often make an enormous difference in the
final results. Recall our application of MLPs to predicting house prices
(Section 5.7). Our first step when working with real
data was to standardize our input features to have zero mean
\(\boldsymbol{\mu} = 0\) and unit variance
\(\boldsymbol{\Sigma} = \boldsymbol{1}\) across multiple
observations (Friedman, 1987). At a minimum, one frequently rescales
the features such that the diagonal of the covariance is unity, i.e.,
\(\Sigma_{ii} = 1\). Yet another strategy is to rescale vectors to
unit length, possibly zero mean *per observation*. This can work well,
e.g., for spatial sensor data. These preprocessing techniques and many
more are beneficial to keep the estimation problem well controlled. See
e.g., the articles by Guyon *et al.* (2008) for a review of
feature selection and extraction techniques. Standardizing vectors also
has the nice side-effect of constraining the complexity of functions
that act upon them. For instance, the celebrated radius-margin
bound (Vapnik, 1995) in support vector machines and the Perceptron
Convergence Theorem (Novikoff, 1962) rely on inputs of bounded norm.
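As a rough illustration of the feature standardization described above, the following is a minimal NumPy sketch. The helper name `standardize`, the synthetic data, and the small constant `eps` are assumptions made for this example and are not part of the `d2l` library.

```python
import numpy as np

def standardize(X, eps=1e-8):
    """Rescale each column (feature) of X to zero mean and unit variance."""
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    # `eps` guards against division by zero for constant features
    return (X - mean) / (std + eps)

rng = np.random.default_rng(0)
# Hypothetical data: 1000 observations of 3 features on very different scales
X = rng.normal(loc=[0.0, 50.0, -3.0], scale=[1.0, 200.0, 0.01], size=(1000, 3))
X_std = standardize(X)
print(X_std.mean(axis=0))  # each entry is close to 0
print(X_std.std(axis=0))   # each entry is close to 1
```

After this step all three features live on a comparable scale, regardless of the units in which they were originally measured.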

Intuitively, this standardization plays nicely with our optimizers since
it puts the parameters *a priori* at a similar scale. As such, it is
only natural to ask whether a corresponding normalization step *inside*
a deep network might not be beneficial. While this isn’t quite the
reasoning that led to the invention of batch normalization
(Ioffe and Szegedy, 2015), it is a useful way of understanding it and
its cousin, layer normalization (Ba *et al.*, 2016) within a
unified framework.

Second, for a typical MLP or CNN, as we train, the variables in
intermediate layers (e.g., affine transformation outputs in MLP) may
take values with widely varying magnitudes: along the layers from
input to output, across units in the same layer, and over time due to
our updates to the model parameters. The inventors of batch
normalization postulated informally that this drift in the distribution
of such variables could hamper the convergence of the network.
Intuitively, we might conjecture that if one layer has variable
activations that are 100 times those of another layer, this might
necessitate compensatory adjustments in the learning rates. Adaptive
solvers such as AdaGrad (Duchi *et al.*, 2011), Adam
(Kingma and Ba, 2014), Yogi (Zaheer *et al.*, 2018), or
Distributed Shampoo (Anil *et al.*, 2020) aim to address this from
the viewpoint of optimization, e.g., by adding aspects of second-order
methods. The alternative is to prevent the problem from occurring,
simply by adaptive normalization.

Third, deeper networks are complex and tend to be more easily capable of overfitting. This means that regularization becomes more critical. A common technique for regularization is noise injection. This has been known for a long time, e.g., with regard to noise injection for the inputs (Bishop, 1995). It also forms the basis of dropout in Section 5.6. As it turns out, quite serendipitously, batch normalization conveys all three benefits: preprocessing, numerical stability, and regularization.

Batch normalization is applied to individual layers, or optionally, to
all of them: In each training iteration, we first normalize the inputs
(of batch normalization) by subtracting their mean and dividing by their
standard deviation, where both are estimated based on the statistics of
the current minibatch. Next, we apply a scale coefficient and an offset
to recover the lost degrees of freedom. It is precisely due to this
*normalization* based on *batch* statistics that *batch normalization*
derives its name.

Note that if we tried to apply batch normalization with minibatches of size 1, we would not be able to learn anything. That is because after subtracting the means, each hidden unit would take value 0. As you might guess, since we are devoting a whole section to batch normalization, with large enough minibatches, the approach proves effective and stable. One takeaway here is that when applying batch normalization, the choice of batch size is even more significant than without batch normalization, or at least, suitable calibration is needed as we adjust the batch size.
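The degenerate case of a minibatch of size 1 can be checked numerically. The sketch below uses NumPy and a hypothetical helper `normalize` that only performs the centering and rescaling step (no learned scale or shift); the value of `eps` is an assumption.

```python
import numpy as np

def normalize(X, eps=1e-5):
    """Center and rescale each feature using the statistics of the batch X."""
    mean = X.mean(axis=0)
    var = ((X - mean) ** 2).mean(axis=0)
    return (X - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
single = rng.normal(size=(1, 4))   # a minibatch with one example
batch = rng.normal(size=(64, 4))   # a minibatch with 64 examples

out_single = normalize(single)
out_batch = normalize(batch)
print(out_single)       # all zeros: the example equals its own mean
print(out_batch.std())  # close to 1: the signal survives
```

With a single example the batch mean coincides with the example itself, so every normalized value is exactly zero and nothing about the input is passed on.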

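Denote by \(\mathcal{B}\) a minibatch and let \(\mathbf{x} \in \mathcal{B}\) be an input to batch normalization (\(\mathrm{BN}\)). In this case the batch normalization is defined as follows:

\[\mathrm{BN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \hat{\boldsymbol{\mu}}_\mathcal{B}}{\hat{\boldsymbol{\sigma}}_\mathcal{B}} + \boldsymbol{\beta}. \tag{8.5.1}\]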

In (8.5.1), \(\hat{\boldsymbol{\mu}}_\mathcal{B}\)
is the sample mean and \(\hat{\boldsymbol{\sigma}}_\mathcal{B}\) is
the sample standard deviation of the minibatch \(\mathcal{B}\).
After applying standardization, the resulting minibatch has zero mean
and unit variance. The choice of unit variance (vs. some other magic
number) is arbitrary. We recover this degree of freedom by
including an elementwise *scale parameter* \(\boldsymbol{\gamma}\)
and *shift parameter* \(\boldsymbol{\beta}\) that have the same
shape as \(\mathbf{x}\). Both are parameters that need to be learned
as part of model training.

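The variable magnitudes for intermediate layers cannot diverge during training since batch normalization actively centers and rescales them back to a given mean and size (via \(\hat{\boldsymbol{\mu}}_\mathcal{B}\) and \({\hat{\boldsymbol{\sigma}}_\mathcal{B}}\)). Practical experience confirms that, as alluded to when discussing feature rescaling, batch normalization seems to allow for more aggressive learning rates. We calculate \(\hat{\boldsymbol{\mu}}_\mathcal{B}\) and \({\hat{\boldsymbol{\sigma}}_\mathcal{B}}\) in (8.5.1) as follows:

\[\hat{\boldsymbol{\mu}}_\mathcal{B} = \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} \mathbf{x}
\quad\text{and}\quad
\hat{\boldsymbol{\sigma}}_\mathcal{B}^2 = \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} (\mathbf{x} - \hat{\boldsymbol{\mu}}_{\mathcal{B}})^2 + \epsilon.\]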

Note that we add a small constant \(\epsilon > 0\) to the variance estimate to ensure that we never attempt division by zero, even in cases where the empirical variance estimate might be very small or even vanish. The estimates \(\hat{\boldsymbol{\mu}}_\mathcal{B}\) and \({\hat{\boldsymbol{\sigma}}_\mathcal{B}}\) counteract the scaling issue by using noisy estimates of mean and variance. You might think that this noisiness should be a problem. Quite to the contrary, this is actually beneficial.
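To make the definition concrete, here is a minimal NumPy sketch of the training-time computation: minibatch mean, minibatch variance with \(\epsilon\) added, then scale and shift. The function name `batch_norm_train` and the synthetic data are assumptions for illustration. The last call also checks the claim that the scale and shift parameters recover the lost degrees of freedom: choosing \(\boldsymbol{\gamma} = \hat{\boldsymbol{\sigma}}_\mathcal{B}\) and \(\boldsymbol{\beta} = \hat{\boldsymbol{\mu}}_\mathcal{B}\) undoes the normalization.

```python
import numpy as np

def batch_norm_train(X, gamma, beta, eps=1e-5):
    """Training-mode batch normalization for X of shape (batch, features)."""
    mu = X.mean(axis=0)
    # Small constant added to the variance estimate to avoid division by zero
    var = ((X - mu) ** 2).mean(axis=0) + eps
    sigma = np.sqrt(var)
    return gamma * (X - mu) / sigma + beta, mu, sigma

rng = np.random.default_rng(0)
X = rng.normal(loc=3.0, scale=5.0, size=(128, 4))

# With gamma = 1 and beta = 0 the output is simply standardized
Y, mu, sigma = batch_norm_train(X, gamma=np.ones(4), beta=np.zeros(4))
print(Y.mean(axis=0), Y.std(axis=0))

# With gamma = sigma and beta = mu the layer reproduces its input
Y_id, _, _ = batch_norm_train(X, gamma=sigma, beta=mu)
print(np.allclose(Y_id, X))
```

In a real network `gamma` and `beta` are learned, so the layer can settle anywhere between full standardization and the identity.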

This turns out to be a recurring theme in deep learning. For reasons
that are not yet well-characterized theoretically, various sources of
noise in optimization often lead to faster training and less
overfitting: this variation appears to act as a form of regularization.
Teye *et al.* (2018) and Luo *et al.* (2018)
relate the properties of batch normalization to Bayesian priors and
penalties, respectively. In particular, this sheds some light on the
puzzle of why batch normalization works best for moderate minibatch
sizes in the \(50 \sim 100\) range. This particular size of
minibatch seems to inject just the “right amount” of noise per layer,
both in terms of scale via \(\hat{\boldsymbol{\sigma}}\), and in
terms of offset via \(\hat{\boldsymbol{\mu}}\): a larger minibatch
regularizes less due to the more stable estimates, whereas tiny
minibatches destroy useful signal due to high variance. Exploring this
direction further, considering alternative types of preprocessing and
filtering may yet lead to other effective types of regularization.
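The dependence of this noise on the minibatch size can be simulated directly. The sketch below is an assumption-laden toy: it draws hypothetical activations of a single hidden unit from a Gaussian and measures how much the minibatch mean and standard deviation fluctuate from one minibatch to the next. The helper `spread_of_batch_stats` is made up for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical activations of one hidden unit over a 100,000-example dataset
activations = rng.normal(loc=2.0, scale=3.0, size=100_000)

def spread_of_batch_stats(batch_size, num_batches=2000):
    """Std. dev. of the minibatch mean and of the minibatch std across batches."""
    idx = rng.integers(0, activations.size, size=(num_batches, batch_size))
    batches = activations[idx]
    return batches.mean(axis=1).std(), batches.std(axis=1).std()

for batch_size in (2, 8, 64, 512):
    mean_noise, std_noise = spread_of_batch_stats(batch_size)
    print(f'batch size {batch_size:4d}: '
          f'noise in mean {mean_noise:.3f}, noise in std {std_noise:.3f}')
```

Both fluctuations shrink as the minibatch grows, which matches the intuition that large minibatches inject little noise while tiny ones inject a lot.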

Fixing a trained model, you might think that we would prefer using the
entire dataset to estimate the mean and variance. Once training is
complete, why would we want the same image to be classified differently,
depending on the batch in which it happens to reside? During training,
such exact calculation is infeasible because the intermediate variables
for all data examples change every time we update our model. However,
once the model is trained, we can calculate the means and variances of
each layer’s variables based on the entire dataset. Indeed this is
standard practice for models employing batch normalization and thus
batch normalization layers function differently in *training mode*
(normalizing by minibatch statistics) and in *prediction mode*
(normalizing by dataset statistics). In this form they closely resemble
the behavior of dropout regularization of Section 5.6, where
noise is only injected during training.

## 8.5.2. Batch Normalization Layers¶

Batch normalization implementations for fully connected layers and convolutional layers are slightly different. One key difference between batch normalization and other layers is that because batch normalization operates on a full minibatch at a time, we cannot just ignore the batch dimension as we did before when introducing other layers.

### 8.5.2.1. Fully Connected Layers¶

When applying batch normalization to fully connected layers, the
original paper (Ioffe and Szegedy, 2015) inserted batch normalization
after the affine transformation and *before* the nonlinear activation
function. Later applications experimented with inserting batch
normalization right *after* activation functions. Denoting the
input to the fully connected layer by \(\mathbf{x}\), the affine
transformation by \(\mathbf{W}\mathbf{x} + \mathbf{b}\) (with the
weight parameter \(\mathbf{W}\) and the bias parameter
\(\mathbf{b}\)), and the activation function by \(\phi\), we can
express the computation of a batch-normalization-enabled, fully
connected layer output \(\mathbf{h}\) as follows:

\[\mathbf{h} = \phi(\mathrm{BN}(\mathbf{W}\mathbf{x} + \mathbf{b})).\]

Recall that mean and variance are computed on the *same* minibatch on
which the transformation is applied.
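The order of operations for a fully connected layer (affine transformation, then batch normalization, then activation) can be sketched in NumPy as follows. The names `bn_dense`, `batch_norm`, and `sigmoid` are hypothetical helpers written for this example. As a side observation that follows from the mean subtraction, the bias \(\mathbf{b}\) has no effect on the output in training mode, which the last two lines check.

```python
import numpy as np

def batch_norm(Z, gamma, beta, eps=1e-5):
    """Training-mode batch normalization over the batch axis."""
    mu = Z.mean(axis=0)
    var = ((Z - mu) ** 2).mean(axis=0)
    return gamma * (Z - mu) / np.sqrt(var + eps) + beta

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bn_dense(X, W, b, gamma, beta):
    """Affine map, then batch normalization, then the activation."""
    return sigmoid(batch_norm(X @ W.T + b, gamma, beta))

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 5))   # minibatch of 32 examples with 5 inputs
W = rng.normal(size=(3, 5))    # 3 output units
b = rng.normal(size=3)
gamma, beta = np.ones(3), np.zeros(3)

H = bn_dense(X, W, b, gamma, beta)
H_no_bias = bn_dense(X, W, np.zeros(3), gamma, beta)
print(H.shape)                     # (32, 3)
print(np.allclose(H, H_no_bias))   # the bias is absorbed by the mean
```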

### 8.5.2.2. Convolutional Layers¶

Similarly, with convolutional layers, we can apply batch normalization
after the convolution and before the nonlinear activation function. The
key difference from batch normalization in fully connected layers is
that we apply the operation on a per-channel basis *across all
locations*. This is compatible with our assumption of translation
invariance that led to convolutions: we assumed that the specific
location of a pattern within an image was not critical for the purpose
of understanding.

Assume that our minibatches contain \(m\) examples and that for each channel, the output of the convolution has height \(p\) and width \(q\). For convolutional layers, we carry out each batch normalization over the \(m \cdot p \cdot q\) elements per output channel simultaneously. Thus, we collect the values over all spatial locations when computing the mean and variance and consequently apply the same mean and variance within a given channel to normalize the value at each spatial location. Each channel has its own scale and shift parameters, both of which are scalars.
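The per-channel reduction over \(m \cdot p \cdot q\) elements can be sketched in NumPy as follows, assuming the channels-first layout `(m, channels, p, q)`. The function name `batch_norm_2d` and the tensor sizes are assumptions chosen for this example.

```python
import numpy as np

def batch_norm_2d(X, gamma, beta, eps=1e-5):
    """Training-mode batch normalization for X of shape (m, channels, p, q)."""
    # Reduce over the batch and both spatial axes: one statistic per channel
    mean = X.mean(axis=(0, 2, 3), keepdims=True)
    var = ((X - mean) ** 2).mean(axis=(0, 2, 3), keepdims=True)
    X_hat = (X - mean) / np.sqrt(var + eps)
    # One scalar scale and one scalar shift per channel, broadcast over space
    return gamma.reshape(1, -1, 1, 1) * X_hat + beta.reshape(1, -1, 1, 1)

rng = np.random.default_rng(0)
m, c, p, q = 8, 3, 5, 5
X = rng.normal(loc=4.0, scale=2.0, size=(m, c, p, q))
Y = batch_norm_2d(X, gamma=np.ones(c), beta=np.zeros(c))

per_channel_mean = Y.mean(axis=(0, 2, 3))
per_channel_std = Y.std(axis=(0, 2, 3))
print(per_channel_mean)  # close to 0 for every channel
print(per_channel_std)   # close to 1 for every channel
```

Each channel's statistics here are averaged over \(8 \cdot 5 \cdot 5 = 200\) values, so they are considerably less noisy than statistics over the 8 examples alone.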

### 8.5.2.3. Layer Normalization¶

Note that in the context of convolutions the batch normalization is
well-defined even for minibatches of size 1: after all, we have all the
locations across an image to average. Consequently, mean and variance
are well defined, even if it’s just within a single observation. This
consideration led Ba *et al.* (2016) to introduce the
notion of *layer normalization*. It works just like a batch norm, only
that it is applied to one observation at a time. Consequently both the
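offset and the scaling factor are scalars. Given an
\(n\)-dimensional vector \(\mathbf{x}\), layer norms are given by

\[\mathbf{x} \rightarrow \mathrm{LN}(\mathbf{x}) = \frac{\mathbf{x} - \hat{\mu}}{\hat{\sigma}},\]

where scaling and offset are applied coefficient-wise and given by

\[\hat{\mu} \stackrel{\mathrm{def}}{=} \frac{1}{n} \sum_{i=1}^n x_i
\quad\text{and}\quad
\hat{\sigma}^2 \stackrel{\mathrm{def}}{=} \frac{1}{n} \sum_{i=1}^n (x_i - \hat{\mu})^2 + \epsilon.\]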

As before we add a small offset \(\epsilon > 0\) to prevent division by zero. One of the major benefits of using layer normalization is that it prevents divergence. After all, ignoring \(\epsilon\), the output of the layer normalization is scale independent. That is, we have \(\mathrm{LN}(\mathbf{x}) \approx \mathrm{LN}(\alpha \mathbf{x})\) for any choice of \(\alpha > 0\) (a negative \(\alpha\) flips the sign of the output). The approximate equality is due to the offset \(\epsilon\) for the variance, whose influence on \(\mathrm{LN}(\alpha \mathbf{x})\) vanishes as \(|\alpha| \to \infty\).

Another advantage of layer normalization is that it doesn’t depend on the minibatch size. It is also independent of whether we are in training or test regime. In other words, it is simply a deterministic transformation that standardizes the activations to a given scale. This can be very beneficial in preventing divergence in optimization. We skip further details and recommend that interested readers consult the original paper.
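Scale independence is easy to check numerically. The NumPy sketch below normalizes over the last axis, one observation at a time, and omits any learned scale or shift. The function name `layer_norm` is an assumption for this example. The comparison uses positive rescaling factors; multiplying by a negative factor flips the sign of the output.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalize each observation over its own features (last axis)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = ((x - mu) ** 2).mean(axis=-1, keepdims=True) + eps
    return (x - mu) / np.sqrt(var)

rng = np.random.default_rng(0)
x = rng.normal(size=16)

for alpha in (0.5, 3.0, 100.0):
    gap = np.abs(layer_norm(x) - layer_norm(alpha * x)).max()
    print(f'alpha = {alpha:6.1f}: largest difference {gap:.2e}')

# A batch of observations is handled row by row; no batch statistics are used
X = rng.normal(size=(4, 16))
same = np.allclose(layer_norm(X)[0], layer_norm(X[0]))
print(same)
```

The small differences that remain come from \(\epsilon\). The last check confirms that the result for one observation does not depend on which other observations share its minibatch.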

### 8.5.2.4. Batch Normalization During Prediction¶

As we mentioned earlier, batch normalization typically behaves differently in training mode and prediction mode. First, the noise in the sample mean and the sample variance arising from estimating each on minibatches is no longer desirable once we have trained the model. Second, we might not have the luxury of computing per-batch normalization statistics. For example, we might need to apply our model to make one prediction at a time.

Typically, after training, we use the entire dataset to compute stable estimates of the variable statistics and then fix them at prediction time. Consequently, batch normalization behaves differently during training and at test time. Recall that dropout also exhibits this characteristic.
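The difference between the two modes can be illustrated with a small NumPy sketch. It assumes a fixed, hypothetical set of layer inputs called `dataset` and a helper `normalize` that takes the statistics as arguments; neither is part of the implementation that follows in this section. In training mode the same example is normalized differently depending on its minibatch, whereas with fixed dataset statistics the result is the same even for a single example.

```python
import numpy as np

def normalize(X, mean, var, eps=1e-5):
    return (X - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
dataset = rng.normal(loc=5.0, scale=2.0, size=(10_000, 3))
example = dataset[:1]

# Training mode: statistics come from whichever minibatch holds the example
batch_a = np.concatenate([example, dataset[1:64]])
batch_b = np.concatenate([example, dataset[64:127]])
train_a = normalize(batch_a, batch_a.mean(axis=0), batch_a.var(axis=0))[0]
train_b = normalize(batch_b, batch_b.mean(axis=0), batch_b.var(axis=0))[0]

# Prediction mode: statistics are computed once and then held fixed
data_mean, data_var = dataset.mean(axis=0), dataset.var(axis=0)
pred_alone = normalize(example, data_mean, data_var)[0]
pred_in_batch = normalize(batch_a, data_mean, data_var)[0]

print(train_a, train_b)           # differ from batch to batch
print(pred_alone, pred_in_batch)  # identical
```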

## 8.5.3. Implementation from Scratch¶

To see how batch normalization works in practice, we implement one from scratch below.

```
import torch
from torch import nn
from d2l import torch as d2l

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # Use `is_grad_enabled` to determine whether we are in training mode
    if not torch.is_grad_enabled():
        # In prediction mode, use mean and variance obtained by moving average
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # When using a fully connected layer, calculate the mean and
            # variance on the feature dimension
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # When using a two-dimensional convolutional layer, calculate the
            # mean and variance on the channel dimension (axis=1). Here we
            # need to maintain the shape of `X`, so that the broadcasting
            # operation can be carried out later
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # In training mode, the current mean and variance are used
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Update the mean and variance using moving average
        moving_mean = (1.0 - momentum) * moving_mean + momentum * mean
        moving_var = (1.0 - momentum) * moving_var + momentum * var
    Y = gamma * X_hat + beta  # Scale and shift
    return Y, moving_mean.data, moving_var.data
```

```
from mxnet import autograd, init, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # Use `autograd` to determine whether we are in training mode
    if not autograd.is_training():
        # In prediction mode, use mean and variance obtained by moving average
        X_hat = (X - moving_mean) / np.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # When using a fully connected layer, calculate the mean and
            # variance on the feature dimension
            mean = X.mean(axis=0)
            var = ((X - mean) ** 2).mean(axis=0)
        else:
            # When using a two-dimensional convolutional layer, calculate the
            # mean and variance on the channel dimension (axis=1). Here we
            # need to maintain the shape of `X`, so that the broadcasting
            # operation can be carried out later
            mean = X.mean(axis=(0, 2, 3), keepdims=True)
            var = ((X - mean) ** 2).mean(axis=(0, 2, 3), keepdims=True)
        # In training mode, the current mean and variance are used
        X_hat = (X - mean) / np.sqrt(var + eps)
        # Update the mean and variance using moving average
        moving_mean = (1.0 - momentum) * moving_mean + momentum * mean
        moving_var = (1.0 - momentum) * moving_var + momentum * var
    Y = gamma * X_hat + beta  # Scale and shift
    return Y, moving_mean, moving_var
```

```
from functools import partial
import jax
import optax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l

def batch_norm(X, deterministic, gamma, beta, moving_mean, moving_var, eps,
               momentum):
    # Use `deterministic` to determine whether the current mode is training
    # mode or prediction mode
    if deterministic:
        # In prediction mode, use mean and variance obtained by moving average
        # `linen.Module.variables` have a `value` attribute containing the array
        X_hat = (X - moving_mean.value) / jnp.sqrt(moving_var.value + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # When using a fully connected layer, calculate the mean and
            # variance on the feature dimension
            mean = X.mean(axis=0)
            var = ((X - mean) ** 2).mean(axis=0)
        else:
            # When using a two-dimensional convolutional layer, calculate the
            # mean and variance on the channel dimension (axis=1). Here we
            # need to maintain the shape of `X`, so that the broadcasting
            # operation can be carried out later
            mean = X.mean(axis=(0, 2, 3), keepdims=True)
            var = ((X - mean) ** 2).mean(axis=(0, 2, 3), keepdims=True)
        # In training mode, the current mean and variance are used
        X_hat = (X - mean) / jnp.sqrt(var + eps)
        # Update the mean and variance using moving average
        moving_mean.value = momentum * moving_mean.value + (1.0 - momentum) * mean
        moving_var.value = momentum * moving_var.value + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # Scale and shift
    return Y
```

```
import tensorflow as tf
from d2l import tensorflow as d2l

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps):
    # Compute reciprocal of square root of the moving variance elementwise
    inv = tf.cast(tf.math.rsqrt(moving_var + eps), X.dtype)
    # Scale and shift
    inv *= gamma
    Y = X * inv + (beta - moving_mean * inv)
    return Y
```

We can now create a proper `BatchNorm` layer. Our layer will maintain
proper parameters for scale `gamma` and shift `beta`, both of which
will be updated in the course of training. Additionally, our layer will
maintain moving averages of the means and variances for subsequent use
during model prediction.

Putting aside the algorithmic details, note the design pattern
underlying our implementation of the layer. Typically, we define the
mathematics in a separate function, say `batch_norm`. We then
integrate this functionality into a custom layer, whose code mostly
addresses bookkeeping matters, such as moving data to the right device
context, allocating and initializing any required variables, keeping
track of moving averages (here for mean and variance), and so on. This
pattern enables a clean separation of mathematics from boilerplate code.
Also note that for the sake of convenience we did not worry about
automatically inferring the input shape here, thus we need to specify
the number of features throughout. By now all modern deep learning
frameworks offer automatic detection of size and shape in the high-level
batch normalization APIs (in practice we will use this instead).

```
class BatchNorm(nn.Module):
    # `num_features`: the number of outputs for a fully connected layer
    # or the number of output channels for a convolutional layer. `num_dims`:
    # 2 for a fully connected layer and 4 for a convolutional layer
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # The variables that are not model parameters are initialized to 0 and
        # 1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)

    def forward(self, X):
        # If `X` is not on the main memory, copy `moving_mean` and
        # `moving_var` to the device where `X` is located
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # Save the updated `moving_mean` and `moving_var`
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.1)
        return Y
```

```
class BatchNorm(nn.Block):
    # `num_features`: the number of outputs for a fully connected layer
    # or the number of output channels for a convolutional layer. `num_dims`:
    # 2 for a fully connected layer and 4 for a convolutional layer
    def __init__(self, num_features, num_dims, **kwargs):
        super().__init__(**kwargs)
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        self.gamma = self.params.get('gamma', shape=shape, init=init.One())
        self.beta = self.params.get('beta', shape=shape, init=init.Zero())
        # The variables that are not model parameters are initialized to 0 and
        # 1
        self.moving_mean = np.zeros(shape)
        self.moving_var = np.ones(shape)

    def forward(self, X):
        # If `X` is not on the main memory, copy `moving_mean` and
        # `moving_var` to the device where `X` is located
        if self.moving_mean.ctx != X.ctx:
            self.moving_mean = self.moving_mean.copyto(X.ctx)
            self.moving_var = self.moving_var.copyto(X.ctx)
        # Save the updated `moving_mean` and `moving_var`
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma.data(), self.beta.data(), self.moving_mean,
            self.moving_var, eps=1e-12, momentum=0.1)
        return Y
```

```
class BatchNorm(nn.Module):
    # `num_features`: the number of outputs for a fully connected layer
    # or the number of output channels for a convolutional layer.
    # `num_dims`: 2 for a fully connected layer and 4 for a convolutional layer
    # Use `deterministic` to determine whether the current mode is training
    # mode or prediction mode
    num_features: int
    num_dims: int
    deterministic: bool = False

    @nn.compact
    def __call__(self, X):
        if self.num_dims == 2:
            shape = (1, self.num_features)
        else:
            shape = (1, 1, 1, self.num_features)
        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        gamma = self.param('gamma', jax.nn.initializers.ones, shape)
        beta = self.param('beta', jax.nn.initializers.zeros, shape)
        # The variables that are not model parameters are initialized to 0 and
        # 1. Save them to the 'batch_stats' collection
        moving_mean = self.variable('batch_stats', 'moving_mean', jnp.zeros, shape)
        moving_var = self.variable('batch_stats', 'moving_var', jnp.ones, shape)
        Y = batch_norm(X, self.deterministic, gamma, beta,
                       moving_mean, moving_var, eps=1e-5, momentum=0.9)
        return Y
```

```
class BatchNorm(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(BatchNorm, self).__init__(**kwargs)

    def build(self, input_shape):
        weight_shape = [input_shape[-1], ]
        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        self.gamma = self.add_weight(name='gamma', shape=weight_shape,
            initializer=tf.initializers.ones, trainable=True)
        self.beta = self.add_weight(name='beta', shape=weight_shape,
            initializer=tf.initializers.zeros, trainable=True)
        # The variables that are not model parameters are initialized to 0
        # (moving mean) and 1 (moving variance)
        self.moving_mean = self.add_weight(name='moving_mean',
            shape=weight_shape, initializer=tf.initializers.zeros,
            trainable=False)
        self.moving_variance = self.add_weight(name='moving_variance',
            shape=weight_shape, initializer=tf.initializers.ones,
            trainable=False)
        super(BatchNorm, self).build(input_shape)

    def assign_moving_average(self, variable, value):
        momentum = 0.1
        delta = (1.0 - momentum) * variable + momentum * value
        return variable.assign(delta)

    @tf.function
    def call(self, inputs, training):
        if training:
            # Reduce over every axis except the last (channel) axis
            axes = list(range(len(inputs.shape) - 1))
            batch_mean = tf.reduce_mean(inputs, axes, keepdims=True)
            batch_variance = tf.reduce_mean(tf.math.squared_difference(
                inputs, tf.stop_gradient(batch_mean)), axes, keepdims=True)
            batch_mean = tf.squeeze(batch_mean, axes)
            batch_variance = tf.squeeze(batch_variance, axes)
            mean_update = self.assign_moving_average(
                self.moving_mean, batch_mean)
            variance_update = self.assign_moving_average(
                self.moving_variance, batch_variance)
            self.add_update(mean_update)
            self.add_update(variance_update)
            mean, variance = batch_mean, batch_variance
        else:
            mean, variance = self.moving_mean, self.moving_variance
        output = batch_norm(inputs, moving_mean=mean, moving_var=variance,
            beta=self.beta, gamma=self.gamma, eps=1e-5)
        return output
```

We used `momentum` to govern the aggregation over past mean and
variance estimates. This is somewhat of a misnomer as it has nothing
whatsoever to do with the *momentum* term of optimization in
Section 12.6. Nonetheless, it is the commonly adopted name
for this term and in deference to API naming convention we use the same
variable name in our code, too.
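One practical wrinkle is that the listings above use two different conventions for `momentum`: the PyTorch, MXNet, and TensorFlow versions weight the *new* batch statistic by `momentum` (set to 0.1), whereas the JAX version weights the *old* moving average by `momentum` (set to 0.9). The plain-Python sketch below, with hypothetical helper names and made-up batch means, suggests that the two describe the same update once the values are matched up.

```python
def ema_weight_on_new(old, new, momentum):
    """Convention of the PyTorch/MXNet/TensorFlow listings above."""
    return (1.0 - momentum) * old + momentum * new

def ema_weight_on_old(old, new, momentum):
    """Convention of the JAX listing above."""
    return momentum * old + (1.0 - momentum) * new

batch_means = [2.0, 4.0, 3.0, 5.0]  # made-up minibatch means
a = b = 0.0                          # moving means start at 0
for mean in batch_means:
    a = ema_weight_on_new(a, mean, momentum=0.1)
    b = ema_weight_on_old(b, mean, momentum=0.9)

print(a, b)  # equal up to floating-point rounding
```

When porting hyperparameters between frameworks, it is therefore worth checking which of the two conventions a given API follows.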

## 8.5.4. LeNet with Batch Normalization¶

To see how to apply `BatchNorm` in context, below we apply it to a
traditional LeNet model (Section 7.6). Recall that batch
normalization is applied after the convolutional layers or fully
connected layers but before the corresponding activation functions.

```
class BNLeNetScratch(d2l.Classifier):
    def __init__(self, lr=0.1, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.net = nn.Sequential(
            nn.LazyConv2d(6, kernel_size=5), BatchNorm(6, num_dims=4),
            nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
            nn.LazyConv2d(16, kernel_size=5), BatchNorm(16, num_dims=4),
            nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten(), nn.LazyLinear(120),
            BatchNorm(120, num_dims=2), nn.Sigmoid(), nn.LazyLinear(84),
            BatchNorm(84, num_dims=2), nn.Sigmoid(),
            nn.LazyLinear(num_classes))
```

```
class BNLeNetScratch(d2l.Classifier):
    def __init__(self, lr=0.1, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.net = nn.Sequential()
        self.net.add(
            nn.Conv2D(6, kernel_size=5), BatchNorm(6, num_dims=4),
            nn.Activation('sigmoid'),
            nn.AvgPool2D(pool_size=2, strides=2),
            nn.Conv2D(16, kernel_size=5), BatchNorm(16, num_dims=4),
            nn.Activation('sigmoid'),
            nn.AvgPool2D(pool_size=2, strides=2), nn.Dense(120),
            BatchNorm(120, num_dims=2), nn.Activation('sigmoid'),
            nn.Dense(84), BatchNorm(84, num_dims=2),
            nn.Activation('sigmoid'), nn.Dense(num_classes))
        self.initialize()
```

```
class BNLeNetScratch(d2l.Classifier):
    lr: float = 0.1
    num_classes: int = 10
    training: bool = True

    def setup(self):
        self.net = nn.Sequential([
            nn.Conv(6, kernel_size=(5, 5)),
            BatchNorm(6, num_dims=4, deterministic=not self.training),
            nn.sigmoid,
            lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
            nn.Conv(16, kernel_size=(5, 5)),
            BatchNorm(16, num_dims=4, deterministic=not self.training),
            nn.sigmoid,
            lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
            lambda x: x.reshape((x.shape[0], -1)),
            nn.Dense(120),
            BatchNorm(120, num_dims=2, deterministic=not self.training),
            nn.sigmoid,
            nn.Dense(84),
            BatchNorm(84, num_dims=2, deterministic=not self.training),
            nn.sigmoid,
            nn.Dense(self.num_classes)])
```

Since `BatchNorm` layers need to calculate the batch statistics (mean
and variance), Flax keeps track of the `batch_stats` dictionary,
updating them with every minibatch. Collections like `batch_stats` can
be stored in the `TrainState` object (in the `d2l.Trainer` class
defined in Section 3.2.4) as an attribute and during
the model’s forward pass, these should be passed to the `mutable`
argument, so that Flax returns the mutated variables.

```
@d2l.add_to_class(d2l.Classifier)  #@save
@partial(jax.jit, static_argnums=(0, 5))
def loss(self, params, X, Y, state, averaged=True):
    Y_hat, updates = state.apply_fn({'params': params,
                                     'batch_stats': state.batch_stats},
                                    *X, mutable=['batch_stats'],
                                    rngs={'dropout': state.dropout_rng})
    Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
    Y = Y.reshape((-1,))
    fn = optax.softmax_cross_entropy_with_integer_labels
    return (fn(Y_hat, Y).mean(), updates) if averaged else (fn(Y_hat, Y), updates)
```

```
class BNLeNetScratch(d2l.Classifier):
    def __init__(self, lr=0.1, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.net = tf.keras.models.Sequential([
            tf.keras.layers.Conv2D(filters=6, kernel_size=5,
                                   input_shape=(28, 28, 1)),
            BatchNorm(), tf.keras.layers.Activation('sigmoid'),
            tf.keras.layers.AvgPool2D(pool_size=2, strides=2),
            tf.keras.layers.Conv2D(filters=16, kernel_size=5),
            BatchNorm(), tf.keras.layers.Activation('sigmoid'),
            tf.keras.layers.AvgPool2D(pool_size=2, strides=2),
            tf.keras.layers.Flatten(), tf.keras.layers.Dense(120),
            BatchNorm(), tf.keras.layers.Activation('sigmoid'),
            tf.keras.layers.Dense(84), BatchNorm(),
            tf.keras.layers.Activation('sigmoid'),
            tf.keras.layers.Dense(num_classes)])
```

As before, we will train our network on the Fashion-MNIST dataset. This code is virtually identical to that when we first trained LeNet.

```
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128)
model = BNLeNetScratch(lr=0.1)
model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)
trainer.fit(model, data)
```

```
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128)
model = BNLeNetScratch(lr=0.1)
trainer.fit(model, data)
```

```
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128)
model = BNLeNetScratch(lr=0.1)
trainer.fit(model, data)
```

```
trainer = d2l.Trainer(max_epochs=10)
data = d2l.FashionMNIST(batch_size=128)
with d2l.try_gpu():
    model = BNLeNetScratch(lr=0.5)
    trainer.fit(model, data)
```

Let’s have a look at the scale parameter `gamma` and the shift
parameter `beta` learned from the first batch normalization layer.

```
model.net[1].gamma.reshape((-1,)), model.net[1].beta.reshape((-1,))
```

```
(tensor([2.4753, 2.1055, 1.9300, 1.5846, 1.5855, 1.1205], device='cuda:0',
grad_fn=<ReshapeAliasBackward0>),
tensor([-1.3341, -0.7705, -1.9549, 0.4815, 0.7987, -1.0323], device='cuda:0',
grad_fn=<ReshapeAliasBackward0>))
```

```
model.net[1].gamma.data().reshape(-1,), model.net[1].beta.data().reshape(-1,)
```

```
(array([1.9818538, 1.3939314, 1.5982249, 2.0519156, 2.2772565, 1.6727468], ctx=gpu(0)),
array([ 1.1422609 , 1.1089567 , -1.442979 , 0.63024527, 0.7528393 ,
0.95314455], ctx=gpu(0)))
```

```
trainer.state.params['net']['layers_1']['gamma'].reshape((-1,)), \
trainer.state.params['net']['layers_1']['beta'].reshape((-1,))
```

```
(Array([1.3267217, 2.4688377, 1.7215432, 2.7555544, 1.2978013, 1.9420222], dtype=float32),
Array([-0.0358162 , -0.56116295, 0.11314265, 0.9476022 , 0.15608062,
-0.48156735], dtype=float32))
```

```
tf.reshape(model.net.layers[1].gamma, (-1,)), tf.reshape(
model.net.layers[1].beta, (-1,))
```

```
(<tf.Tensor: shape=(6,), dtype=float32, numpy=
array([2.1537447, 1.3743037, 3.4600253, 3.494863 , 4.2359805, 2.8958514],
dtype=float32)>,
<tf.Tensor: shape=(6,), dtype=float32, numpy=
array([-0.7886159 , 0.08862703, -0.07236256, -0.3829588 , 0.26527518,
0.2775884 ], dtype=float32)>)
```

## 8.5.5. Concise Implementation¶

Compared with the `BatchNorm` class, which we just defined ourselves,
we can use the `BatchNorm` class defined in high-level APIs from the
deep learning framework directly. The code looks virtually identical to
our implementation above, except that we no longer need to provide
additional arguments for it to get the dimensions right.

```
class BNLeNet(d2l.Classifier):
    def __init__(self, lr=0.1, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.net = nn.Sequential(
            nn.LazyConv2d(6, kernel_size=5), nn.LazyBatchNorm2d(),
            nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
            nn.LazyConv2d(16, kernel_size=5), nn.LazyBatchNorm2d(),
            nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten(), nn.LazyLinear(120), nn.LazyBatchNorm1d(),
            nn.Sigmoid(), nn.LazyLinear(84), nn.LazyBatchNorm1d(),
            nn.Sigmoid(), nn.LazyLinear(num_classes))
```

```
class BNLeNet(d2l.Classifier):
    def __init__(self, lr=0.1, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.net = nn.Sequential()
        self.net.add(
            nn.Conv2D(6, kernel_size=5), nn.BatchNorm(),
            nn.Activation('sigmoid'),
            nn.AvgPool2D(pool_size=2, strides=2),
            nn.Conv2D(16, kernel_size=5), nn.BatchNorm(),
            nn.Activation('sigmoid'),
            nn.AvgPool2D(pool_size=2, strides=2),
            nn.Dense(120), nn.BatchNorm(), nn.Activation('sigmoid'),
            nn.Dense(84), nn.BatchNorm(), nn.Activation('sigmoid'),
            nn.Dense(num_classes))
        self.initialize()
```

```
class BNLeNet(d2l.Classifier):
    lr: float = 0.1
    num_classes: int = 10
    training: bool = True

    def setup(self):
        self.net = nn.Sequential([
            nn.Conv(6, kernel_size=(5, 5)),
            nn.BatchNorm(not self.training),
            nn.sigmoid,
            lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
            nn.Conv(16, kernel_size=(5, 5)),
            nn.BatchNorm(not self.training),
            nn.sigmoid,
            lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
            lambda x: x.reshape((x.shape[0], -1)),
            nn.Dense(120),
            nn.BatchNorm(not self.training),
            nn.sigmoid,
            nn.Dense(84),
            nn.BatchNorm(not self.training),
            nn.sigmoid,
            nn.Dense(self.num_classes)])
```

```
class BNLeNet(d2l.Classifier):
    def __init__(self, lr=0.1, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.net = tf.keras.models.Sequential([
            tf.keras.layers.Conv2D(filters=6, kernel_size=5,
                                   input_shape=(28, 28, 1)),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('sigmoid'),
            tf.keras.layers.AvgPool2D(pool_size=2, strides=2),
            tf.keras.layers.Conv2D(filters=16, kernel_size=5),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('sigmoid'),
            tf.keras.layers.AvgPool2D(pool_size=2, strides=2),
            tf.keras.layers.Flatten(), tf.keras.layers.Dense(120),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('sigmoid'),
            tf.keras.layers.Dense(84),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('sigmoid'),
            tf.keras.layers.Dense(num_classes)])
```
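To make the remark about dimensions concrete for the PyTorch variant, the following minimal sketch shows `nn.LazyBatchNorm2d` inferring its number of channels from the first input it receives. The input shape is made up for the demonstration and is not taken from the training run in this section.

```python
import torch
from torch import nn

bn = nn.LazyBatchNorm2d()

# Before any data is seen, the per-channel scale parameter has no shape yet
was_uninitialized = isinstance(
    bn.weight, torch.nn.parameter.UninitializedParameter)

# Illustrative input: (batch, channels, height, width)
X = torch.randn(4, 6, 24, 24)

# The first forward pass reads the channel count from X.shape[1]
Y = bn(X)

print(was_uninitialized, bn.num_features, tuple(bn.weight.shape))
```

After the first call the layer holds one scale and one shift per channel, which is the information we had to pass by hand to the from-scratch class.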

Below, we use the same hyperparameters to train our model. Note that as usual, the high-level API variant runs much faster because its code has been compiled to C++ or CUDA while our custom implementation must be interpreted by Python.

```
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128)
model = BNLeNet(lr=0.1)
model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)
trainer.fit(model, data)
```

```
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128)
model = BNLeNet(lr=0.1)
trainer.fit(model, data)
```

```
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128)
model = BNLeNet(lr=0.1)
trainer.fit(model, data)
```

```
trainer = d2l.Trainer(max_epochs=10)
data = d2l.FashionMNIST(batch_size=128)
with d2l.try_gpu():
    model = BNLeNet(lr=0.5)
    trainer.fit(model, data)
```

## 8.5.6. Discussion¶

Intuitively, batch normalization is thought to make the optimization landscape smoother. However, we must be careful to distinguish between speculative intuitions and true explanations for the phenomena that we observe when training deep models. Recall that we do not even know why simpler deep neural networks (MLPs and conventional CNNs) generalize well in the first place. Even with dropout and weight decay, they remain so flexible that their ability to generalize to unseen data likely needs significantly more refined learning-theoretic generalization guarantees.

The original paper proposing batch normalization
(Ioffe and Szegedy, 2015), in addition to introducing a powerful and
useful tool, offered an explanation for why it works: by reducing
*internal covariate shift*. Presumably by *internal covariate shift* the
authors meant something like the intuition expressed above—the notion
that the distribution of variable values changes over the course of
training. However, there were two problems with this explanation: i)
This drift is very different from *covariate shift*, rendering the name
a misnomer. If anything, it’s closer to concept drift. ii) The
explanation offers an under-specified intuition but leaves open the
question of *why precisely this technique works*, which still lacks a
rigorous explanation. Throughout this book, we aim to convey the
intuitions that practitioners use to guide their development of deep
neural networks. However, we believe that it is important to separate
these guiding intuitions from established scientific fact. Eventually,
when you master this material and start writing your own research
papers, you will want to clearly delineate between technical claims and
hunches.

Following the success of batch normalization, its explanation in terms
of *internal covariate shift* has repeatedly surfaced in debates in the
technical literature and broader discourse about how to present machine
learning research. In a memorable speech given while accepting a Test of
Time Award at the 2017 NeurIPS conference, Ali Rahimi used *internal
covariate shift* as a focal point in an argument likening the modern
practice of deep learning to alchemy. Subsequently, the example was
revisited in detail in a position paper outlining troubling trends in
machine learning (Lipton and Steinhardt, 2018). Other authors have
proposed alternative explanations for the success of batch
normalization, some claiming that batch normalization’s success comes
despite exhibiting behavior that is in some ways opposite to that
claimed in the original paper (Santurkar *et al.*, 2018).

We note that the *internal covariate shift* is no more worthy of
criticism than any of thousands of similarly vague claims made every
year in the technical machine learning literature. Likely, its resonance
as a focal point of these debates owes to its broad recognizability to
the target audience. Batch normalization has proven an indispensable
method, applied in nearly all deployed image classifiers, earning the
paper that introduced the technique tens of thousands of citations. We
conjecture, though, that the guiding principles of regularization
through noise injection, acceleration through rescaling and lastly
preprocessing may well lead to further inventions of layers and
techniques in the future.

On a more practical note, there are a number of aspects worth remembering about batch normalization:

- During model training, batch normalization continuously adjusts the intermediate output of the network by utilizing the mean and standard deviation of the minibatch, so that the values of the intermediate output in each layer throughout the neural network are more stable.

- Batch normalization differs slightly between fully connected layers and convolutional layers. In fact, for convolutional layers, layer normalization can sometimes be used as an alternative.

- Like a dropout layer, batch normalization layers have different behaviors in training mode and prediction mode.

- Batch normalization is useful for regularization and improving convergence in optimization. On the other hand, the original motivation of reducing internal covariate shift seems not to be a valid explanation.
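The first three points can be made concrete with a small NumPy sketch. It is a simplified recap under stated assumptions, not the framework implementation: the helper names `batch_norm_stats` and `predict_mode`, the momentum of 0.1, and the synthetic data are all illustrative choices, and the learnable scale and shift are omitted.

```python
import numpy as np

def batch_norm_stats(x, eps=1e-5):
    """Normalize x with minibatch statistics and return (y, mean, var).

    2D input (batch, features): one mean/variance per feature.
    4D input (batch, channels, height, width): one mean/variance per
    channel, pooled over the batch and both spatial axes.
    """
    axes = (0,) if x.ndim == 2 else (0, 2, 3)
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps), mean, var

rng = np.random.default_rng(0)

# Fully connected case: 8 examples with 4 features each
dense = rng.normal(loc=3.0, scale=2.0, size=(8, 4))
y_dense, mean_dense, _ = batch_norm_stats(dense)

# Convolutional case: 8 examples, 3 channels, 5x5 feature maps
conv = rng.normal(loc=-1.0, scale=0.5, size=(8, 3, 5, 5))
y_conv, mean_conv, _ = batch_norm_stats(conv)

print(mean_dense.shape, mean_conv.shape)  # (1, 4) (1, 3, 1, 1)

# Training mode: track moving averages of the minibatch statistics
momentum = 0.1  # illustrative value
moving_mean = np.zeros((1, 4))
moving_var = np.ones((1, 4))
for _ in range(200):
    batch = rng.normal(loc=3.0, scale=2.0, size=(8, 4))
    _, m, v = batch_norm_stats(batch)
    moving_mean = (1 - momentum) * moving_mean + momentum * m
    moving_var = (1 - momentum) * moving_var + momentum * v

def predict_mode(x, eps=1e-5):
    """Prediction mode: normalize with the stored moving averages."""
    return (x - moving_mean) / np.sqrt(moving_var + eps)

# A single example cannot supply meaningful minibatch statistics
single = np.full((1, 4), 5.0)
y_single_train, _, _ = batch_norm_stats(single)
y_single_pred = predict_mode(single)
print(y_single_train)
print(y_single_pred)
```

Normalizing a single example with its own statistics yields all zeros, whatever the input, which is why prediction mode relies on the moving averages collected during training instead.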

## 8.5.7. Exercises¶

1. Should we remove the bias parameter from the fully connected layer or the convolutional layer before the batch normalization? Why?

2. Compare the learning rates for LeNet with and without batch normalization.

   1. Plot the increase in validation accuracy.

   2. How large can you make the learning rate before the optimization fails in both cases?

3. Do we need batch normalization in every layer? Experiment with it.

4. Implement a “lite” version of batch normalization that only removes the mean, or alternatively one that only removes the variance. How does it behave?

5. Fix the parameters `beta` and `gamma`. Observe and analyze the results.

6. Can you replace dropout by batch normalization? How does the behavior change?

7. Research ideas: think of other normalization transforms that you can apply:

   1. Can you apply the probability integral transform?

   2. Can you use a full rank covariance estimate? Why should you probably not do that?

   3. Can you use other compact matrix variants (block-diagonal, low-displacement rank, Monarch, etc.)?

   4. Does a sparsification compression act as a regularizer?

   5. Are there other projections (e.g., convex cone, symmetry group-specific transforms) that you can use?