Parameter Management
====================

Once we have chosen an architecture and set our hyperparameters, we proceed to the training loop, where our goal is to find parameter values that minimize our loss function. After training, we will need these parameters to make future predictions. Additionally, we will sometimes wish to extract the parameters, perhaps to reuse them in some other context, to save our model to disk so that it can be executed in other software, or to examine them in the hope of gaining scientific understanding.

Most of the time, we can ignore the nitty-gritty details of how parameters are declared and manipulated, relying on deep learning frameworks to do the heavy lifting. However, when we move away from stacked architectures with standard layers, we sometimes need to get into the weeds of declaring and manipulating parameters. In this section, we cover the following:

- Accessing parameters for debugging, diagnostics, and visualizations.
- Sharing parameters across different model components.
Throughout this section, each example is shown for PyTorch, MXNet, JAX, and TensorFlow in turn.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    import torch
    from torch import nn

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    from mxnet import init, np, npx
    from mxnet.gluon import nn

    npx.set_np()

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    import jax
    from flax import linen as nn
    from jax import numpy as jnp
    from d2l import jax as d2l

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    import tensorflow as tf
We start by focusing on an MLP with one hidden layer.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    net = nn.Sequential(nn.LazyLinear(8), nn.ReLU(), nn.LazyLinear(1))

    X = torch.rand(size=(2, 4))
    net(X).shape

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    torch.Size([2, 1])

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    net = nn.Sequential()
    net.add(nn.Dense(8, activation='relu'))
    net.add(nn.Dense(1))
    net.initialize()  # Use the default initialization method

    X = np.random.uniform(size=(2, 4))
    net(X).shape

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    [21:49:32] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    (2, 1)

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    net = nn.Sequential([nn.Dense(8), nn.relu, nn.Dense(1)])

    X = jax.random.uniform(d2l.get_key(), (2, 4))
    params = net.init(d2l.get_key(), X)
    net.apply(params, X).shape

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    (2, 1)

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    net = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(4, activation=tf.nn.relu),
        tf.keras.layers.Dense(1),
    ])

    X = tf.random.uniform((2, 4))
    net(X).shape

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    TensorShape([2, 1])
.. _subsec_param-access:

Parameter Access
----------------

Let’s start with how to access parameters from the models that you already know.
When a model is defined via the ``Sequential`` class, we can access any layer by first indexing into the model as though it were a list. Each layer’s parameters are conveniently accessible through its attributes.
Flax and JAX decouple the model and the parameters, as you may have observed in the models defined previously. When a model is defined via the ``Sequential`` class, we first need to initialize the network to generate the parameter dictionary. We can access any layer’s parameters through the keys of this dictionary.
We can inspect the parameters of the second fully connected layer as follows.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    net[2].state_dict()

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    OrderedDict([('weight', tensor([[-0.1649,  0.0605,  0.1694, -0.2524,  0.3526, -0.3414, -0.2322,  0.0822]])),
                 ('bias', tensor([0.0709]))])

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    net[1].params

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    dense1_ (
      Parameter dense1_weight (shape=(1, 8), dtype=float32)
      Parameter dense1_bias (shape=(1,), dtype=float32)
    )

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    params['params']['layers_2']

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    FrozenDict({
        kernel: Array([[ 0.2758769 ],
               [ 0.45259333],
               [ 0.28696904],
               [ 0.24622999],
               [-0.29272735],
               [ 0.07597765],
               [ 0.14919828],
               [ 0.18445292]], dtype=float32),
        bias: Array([0.], dtype=float32),
    })
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    net.layers[2].weights
We can see that this fully connected layer contains two parameters, corresponding to that layer’s weights and biases, respectively.

Targeted Parameters
~~~~~~~~~~~~~~~~~~~

Note that each parameter is represented as an instance of the parameter class. To do anything useful with the parameters, we first need to access the underlying numerical values. There are several ways to do this. Some are simpler while others are more general. The following code extracts the bias from the second neural network layer, which returns a parameter class instance, and further accesses that parameter’s value.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    type(net[2].bias), net[2].bias.data

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    (torch.nn.parameter.Parameter, tensor([0.0709]))

Parameters are complex objects, containing values, gradients, and additional information. That is why we need to request the value explicitly. In addition to the value, each parameter also allows us to access the gradient. Because we have not invoked backpropagation for this network yet, it is in its initial state.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    net[2].weight.grad == None

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    True
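For instance, a minimal sketch (the summed output below is an arbitrary placeholder objective, not part of the original example) shows that the gradient attribute is populated once a backward pass has been run:

.. code:: python

    # Minimal sketch: after backpropagation the gradient is no longer None
    loss = net(X).sum()  # placeholder scalar objective, for illustration only
    loss.backward()
    net[2].weight.grad is None, net[2].weight.grad.shape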
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    type(net[1].bias), net[1].bias.data()

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    (mxnet.gluon.parameter.Parameter, array([0.]))

Parameters are complex objects, containing values, gradients, and additional information. That is why we need to request the value explicitly. In addition to the value, each parameter also allows us to access the gradient. Because we have not invoked backpropagation for this network yet, it is in its initial state.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    net[1].weight.grad()

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([[0., 0., 0., 0., 0., 0., 0., 0.]])
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    bias = params['params']['layers_2']['bias']
    type(bias), bias

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    (jaxlib.xla_extension.ArrayImpl, Array([0.], dtype=float32))

Unlike the other frameworks, JAX does not keep track of gradients over the neural network parameters; instead, the parameters and the network are decoupled. This lets the user express the computation as a Python function and apply the ``grad`` transformation for the same purpose.
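To obtain gradients we therefore differentiate a function of the parameter dictionary. A minimal sketch (the squared-sum objective below is an arbitrary placeholder, not part of the original example):

.. code:: python

    # Minimal sketch: gradients are produced by transforming a function of params
    def loss_fn(params, X):
        return (net.apply(params, X) ** 2).sum()  # placeholder objective

    grads = jax.grad(loss_fn)(params, X)
    jax.tree_util.tree_map(lambda g: g.shape, grads)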
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    type(net.layers[2].weights[1]), tf.convert_to_tensor(net.layers[2].weights[1])

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    (tensorflow.python.ops.resource_variable_ops.ResourceVariable, )
All Parameters at Once
~~~~~~~~~~~~~~~~~~~~~~

When we need to perform operations on all parameters, accessing them one-by-one can grow tedious. The situation can grow especially unwieldy when we work with more complex, e.g., nested, modules, since we would need to recurse through the entire tree to extract each sub-module’s parameters. Below we demonstrate accessing the parameters of all layers.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    [(name, param.shape) for name, param in net.named_parameters()]

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    [('0.weight', torch.Size([8, 4])),
     ('0.bias', torch.Size([8])),
     ('2.weight', torch.Size([1, 8])),
     ('2.bias', torch.Size([1]))]
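Because every parameter is addressable by name, writing all of them to disk, as motivated in the introduction, takes only a line or two. A minimal PyTorch sketch (the file name is arbitrary):

.. code:: python

    # Minimal sketch (hypothetical file name): persist and restore all parameters
    torch.save(net.state_dict(), 'mlp.params')
    clone = nn.Sequential(nn.LazyLinear(8), nn.ReLU(), nn.LazyLinear(1))
    clone(X)  # run one forward pass so the lazy layers have concrete shapes
    clone.load_state_dict(torch.load('mlp.params'))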
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    net.collect_params()

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    sequential0_ (
      Parameter dense0_weight (shape=(8, 4), dtype=float32)
      Parameter dense0_bias (shape=(8,), dtype=float32)
      Parameter dense1_weight (shape=(1, 8), dtype=float32)
      Parameter dense1_bias (shape=(1,), dtype=float32)
    )

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    jax.tree_util.tree_map(lambda x: x.shape, params)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    FrozenDict({
        params: {
            layers_0: {
                bias: (8,),
                kernel: (4, 8),
            },
            layers_2: {
                bias: (1,),
                kernel: (8, 1),
            },
        },
    })

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    net.get_weights()

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    [array([[-0.06085569, -0.8411268 , -0.28591037,  0.31637532],
            [ 0.8330259 ,  0.4529298 ,  0.14709991, -0.18423098],
            [ 0.835087  ,  0.23927861, -0.7909084 , -0.49229068],
            [ 0.76430553,  0.40979892,  0.09074789,  0.4237972 ]], dtype=float32),
     array([0., 0., 0., 0.], dtype=float32),
     array([[-0.892862  ],
            [ 0.7337135 ],
            [-0.05061114],
            [-0.97688395]], dtype=float32),
     array([0.], dtype=float32)]
Tied Parameters
---------------

Often, we want to share parameters across multiple layers. Let’s see how to do this elegantly. In the following we allocate a fully connected layer and then use its parameters specifically to set those of another layer. Here we need to run the forward propagation ``net(X)`` before accessing the parameters.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # We need to give the shared layer a name so that we can refer to its
    # parameters
    shared = nn.LazyLinear(8)
    net = nn.Sequential(nn.LazyLinear(8), nn.ReLU(),
                        shared, nn.ReLU(),
                        shared, nn.ReLU(),
                        nn.LazyLinear(1))

    net(X)
    # Check whether the parameters are the same
    print(net[2].weight.data[0] == net[4].weight.data[0])
    net[2].weight.data[0, 0] = 100
    # Make sure that they are actually the same object rather than just having
    # the same value
    print(net[2].weight.data[0] == net[4].weight.data[0])

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor([True, True, True, True, True, True, True, True])
    tensor([True, True, True, True, True, True, True, True])

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    net = nn.Sequential()
    # We need to give the shared layer a name so that we can refer to its
    # parameters
    shared = nn.Dense(8, activation='relu')
    net.add(nn.Dense(8, activation='relu'),
            shared,
            nn.Dense(8, activation='relu', params=shared.params),
            nn.Dense(10))
    net.initialize()

    X = np.random.uniform(size=(2, 20))
    net(X)

    # Check whether the parameters are the same
    print(net[1].weight.data()[0] == net[2].weight.data()[0])
    net[1].weight.data()[0, 0] = 100
    # Make sure that they are actually the same object rather than just having
    # the same value
    print(net[1].weight.data()[0] == net[2].weight.data()[0])

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    [ True  True  True  True  True  True  True  True]
    [ True  True  True  True  True  True  True  True]

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # We need to give the shared layer a name so that we can refer to its
    # parameters
    shared = nn.Dense(8)
    net = nn.Sequential([nn.Dense(8), nn.relu,
                         shared, nn.relu,
                         shared, nn.relu,
                         nn.Dense(1)])

    params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)

    # Check whether the parameters are different
    print(len(params['params']) == 3)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    True

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # tf.keras behaves a bit differently. It removes the duplicate layer
    # automatically
    shared = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    net = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        shared,
        shared,
        tf.keras.layers.Dense(1),
    ])

    net(X)
    # Check whether the parameters are different
    print(len(net.layers) == 3)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    True
This example shows that the parameters of the second and third layer are tied. They are not just equal; they are represented by exactly the same tensor. Thus, if we change one of the parameters, the other one changes, too.
You might wonder: when parameters are tied, what happens to the gradients? Since the model parameters contain gradients, the gradients of the second hidden layer and the third hidden layer are added together during backpropagation.
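A minimal PyTorch sketch illustrating this (the summed output is again an arbitrary placeholder objective): because the second and third hidden layers are the very same module, they expose a single gradient tensor, which collects the contributions from both occurrences during the backward pass.

.. code:: python

    # Minimal sketch: the shared weight has one gradient tensor that accumulates
    # the contributions from both places where the layer is used
    net.zero_grad()
    net(X).sum().backward()  # placeholder scalar objective
    print(net[2].weight.grad is net[4].weight.grad)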
Summary
-------

We have several ways of accessing and tying model parameters.

Exercises
---------

1. Use the ``NestMLP`` model defined in :numref:`sec_model_construction` and access the parameters of the various layers.
2. Construct an MLP containing a shared parameter layer and train it. During the training process, observe the model parameters and gradients of each layer.
3. Why is sharing parameters a good idea?