.. _sec_classification:

The Base Classification Model
=============================

You may have noticed that the implementations from scratch and the
concise implementation using framework functionality were quite similar
in the case of regression. The same is true for classification. Since
many models in this book deal with classification, it is worth adding
functionality to support this setting specifically. This section
provides a base class for classification models to simplify future code.

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   import torch
   from d2l import torch as d2l

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   from mxnet import autograd, gluon, np, npx
   from d2l import mxnet as d2l

   npx.set_np()

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   from functools import partial
   import jax
   import optax
   from jax import numpy as jnp
   from d2l import jax as d2l

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   import tensorflow as tf
   from d2l import tensorflow as d2l

.. raw:: html
.. raw:: html
The ``Classifier`` Class
------------------------

.. raw:: html
.. raw:: html
We define the ``Classifier`` class below. In the ``validation_step`` we
report both the loss value and the classification accuracy on a
validation batch. We draw an update for every ``num_val_batches``
batches. This has the benefit of generating the averaged loss and
accuracy on the whole validation data. These average numbers are not
exactly correct if the final batch contains fewer examples, but we
ignore this minor difference to keep the code simple.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   class Classifier(d2l.Module):  #@save
       """The base class of classification models."""
       def validation_step(self, batch):
           Y_hat = self(*batch[:-1])
           self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
           self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)

.. raw:: html
.. raw:: html
We define the ``Classifier`` class below. In the ``validation_step`` we
report both the loss value and the classification accuracy on a
validation batch. We draw an update for every ``num_val_batches``
batches. This has the benefit of generating the averaged loss and
accuracy on the whole validation data. These average numbers are not
exactly correct if the last batch contains fewer examples, but we ignore
this minor difference to keep the code simple.

We also redefine the ``training_step`` method for JAX since all models
that will subclass ``Classifier`` later will have a loss that returns
auxiliary data. This auxiliary data can be used for models with batch
normalization (to be explained in :numref:`sec_batch_norm`), while in
all other cases we will make the loss also return a placeholder (empty
dictionary) to represent the auxiliary data.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   class Classifier(d2l.Module):  #@save
       """The base class of classification models."""
       def training_step(self, params, batch, state):
           # Here value is a tuple since models with BatchNorm layers require
           # the loss to return auxiliary data
           value, grads = jax.value_and_grad(
               self.loss, has_aux=True)(params, batch[:-1], batch[-1], state)
           l, _ = value
           self.plot("loss", l, train=True)
           return value, grads

       def validation_step(self, params, batch, state):
           # Discard the second returned value. It is used for training models
           # with BatchNorm layers since loss also returns auxiliary data
           l, _ = self.loss(params, batch[:-1], batch[-1], state)
           self.plot('loss', l, train=False)
           self.plot('acc', self.accuracy(params, batch[:-1], batch[-1], state),
                     train=False)

.. raw:: html
.. raw:: html
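To make the caveat about a smaller final batch concrete, the following
plain-Python sketch compares the average of per-batch means with the
exact mean over all examples. It is only an illustration and not part of
the ``d2l`` library: the helper names ``mean`` and ``batch_means`` and
the toy loss values are hypothetical.

```python
def mean(values):
    """Arithmetic mean of a non-empty sequence."""
    return sum(values) / len(values)


def batch_means(values, batch_size):
    """Split `values` into consecutive minibatches and return each batch's mean."""
    return [mean(values[i:i + batch_size])
            for i in range(0, len(values), batch_size)]


# Hypothetical per-example validation losses: 5 examples with batch size 2,
# so the last minibatch holds a single example.
losses = [1.0, 1.0, 1.0, 1.0, 6.0]

exact = mean(losses)                              # every example weighted equally
quick = mean(batch_means(losses, batch_size=2))   # every batch weighted equally
```

Here the exact mean is 2.0 while the average of the batch means is about
2.67, because the single example in the last batch is weighted as
heavily as a full batch. When the batch size divides the number of
examples, the two quantities agree.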
By default we use a stochastic gradient descent optimizer, operating on
minibatches, just as we did in the context of linear regression.

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   @d2l.add_to_class(d2l.Module)  #@save
   def configure_optimizers(self):
       return torch.optim.SGD(self.parameters(), lr=self.lr)

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   @d2l.add_to_class(d2l.Module)  #@save
   def configure_optimizers(self):
       params = self.parameters()
       if isinstance(params, list):
           return d2l.SGD(params, self.lr)
       return gluon.Trainer(params, 'sgd', {'learning_rate': self.lr})

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   @d2l.add_to_class(d2l.Module)  #@save
   def configure_optimizers(self):
       return optax.sgd(self.lr)

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   @d2l.add_to_class(d2l.Module)  #@save
   def configure_optimizers(self):
       return tf.keras.optimizers.SGD(self.lr)

.. raw:: html
.. raw:: html
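As a reminder of what such an optimizer does at each step, here is a
minimal framework-free sketch of the SGD update, which subtracts the
learning rate times the gradient from each parameter. The function name
``sgd_step`` and the toy objective are illustrative assumptions; the
book's code relies on the framework optimizers configured above.

```python
def sgd_step(params, grads, lr):
    """Return the parameters after one SGD step: p - lr * g for each pair."""
    return [p - lr * g for p, g in zip(params, grads)]


# Toy objective (hypothetical): f(w) = (w - 3)^2 with gradient 2 * (w - 3).
w = [0.0]
for _ in range(100):
    grad = [2 * (w[0] - 3)]
    w = sgd_step(w, grad, lr=0.1)
# After the loop, w[0] has moved from 0 to very close to the minimizer 3.
```

In real training the gradient is computed on a minibatch rather than on
a fixed function, so each step is noisy, but the update rule is the same.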
Accuracy
--------

Given the predicted probability distribution ``y_hat``, we typically
choose the class with the highest predicted probability whenever we must
output a hard prediction. Indeed, many applications require that we make
a choice. For instance, Gmail must categorize an email into “Primary”,
“Social”, “Updates”, “Forums”, or “Spam”. It might estimate
probabilities internally, but at the end of the day it has to choose one
among the classes.

When predictions are consistent with the label class ``y``, they are
correct. The classification accuracy is the fraction of all predictions
that are correct. Although it can be difficult to optimize accuracy
directly (it is not differentiable), it is often the performance measure
that we care about the most. It is often *the* relevant quantity in
benchmarks. As such, we will nearly always report it when training
classifiers.

Accuracy is computed as follows. First, if ``y_hat`` is a matrix, we
assume that the second dimension stores prediction scores for each
class. We use ``argmax`` to obtain the predicted class by the index for
the largest entry in each row. Then we compare the predicted class with
the ground truth ``y`` elementwise. Since the equality operator ``==``
is sensitive to data types, we convert ``y_hat``\ ’s data type to match
that of ``y``. The result is a tensor containing entries of 0 (false)
and 1 (true). Taking the sum yields the number of correct predictions,
and taking the mean yields the accuracy.

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   @d2l.add_to_class(Classifier)  #@save
   def accuracy(self, Y_hat, Y, averaged=True):
       """Compute the fraction of correct predictions."""
       Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
       preds = Y_hat.argmax(axis=1).type(Y.dtype)
       compare = (preds == Y.reshape(-1)).type(torch.float32)
       return compare.mean() if averaged else compare

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   @d2l.add_to_class(Classifier)  #@save
   def accuracy(self, Y_hat, Y, averaged=True):
       """Compute the fraction of correct predictions."""
       Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
       preds = Y_hat.argmax(axis=1).astype(Y.dtype)
       compare = (preds == Y.reshape(-1)).astype(np.float32)
       return compare.mean() if averaged else compare

   @d2l.add_to_class(d2l.Module)  #@save
   def get_scratch_params(self):
       params = []
       for attr in dir(self):
           a = getattr(self, attr)
           if isinstance(a, np.ndarray):
               params.append(a)
           if isinstance(a, d2l.Module):
               params.extend(a.get_scratch_params())
       return params

   @d2l.add_to_class(d2l.Module)  #@save
   def parameters(self):
       params = self.collect_params()
       return params if isinstance(params, gluon.parameter.ParameterDict) and len(
           params.keys()) else self.get_scratch_params()

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   @d2l.add_to_class(Classifier)  #@save
   @partial(jax.jit, static_argnums=(0, 5))
   def accuracy(self, params, X, Y, state, averaged=True):
       """Compute the fraction of correct predictions."""
       Y_hat = state.apply_fn({'params': params,
                               'batch_stats': state.batch_stats},  # BatchNorm Only
                              *X)
       Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
       preds = Y_hat.argmax(axis=1).astype(Y.dtype)
       compare = (preds == Y.reshape(-1)).astype(jnp.float32)
       return compare.mean() if averaged else compare

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   @d2l.add_to_class(Classifier)  #@save
   def accuracy(self, Y_hat, Y, averaged=True):
       """Compute the fraction of correct predictions."""
       Y_hat = tf.reshape(Y_hat, (-1, Y_hat.shape[-1]))
       preds = tf.cast(tf.argmax(Y_hat, axis=1), Y.dtype)
       compare = tf.cast(preds == tf.reshape(Y, -1), tf.float32)
       return tf.reduce_mean(compare) if averaged else compare

.. raw:: html
.. raw:: html
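The same steps (take the argmax of each row, compare with the labels,
then average) can be traced by hand on a tiny example without any
framework. This is only an illustrative sketch: the function name
``accuracy_from_scores`` and the scores below are made up.

```python
def accuracy_from_scores(scores, labels, averaged=True):
    """Fraction of rows whose argmax matches the label, or the 0/1 list."""
    # The index of the largest score in each row is the predicted class.
    preds = [max(range(len(row)), key=row.__getitem__) for row in scores]
    # 1.0 where the prediction matches the label, 0.0 otherwise.
    compare = [float(p == y) for p, y in zip(preds, labels)]
    return sum(compare) / len(compare) if averaged else compare


# Three examples, three classes; hypothetical predicted probabilities.
y_hat = [[0.1, 0.3, 0.6],
         [0.3, 0.2, 0.5],
         [0.7, 0.2, 0.1]]
y = [2, 0, 0]

per_example = accuracy_from_scores(y_hat, y, averaged=False)
acc = accuracy_from_scores(y_hat, y)
```

The predicted classes are 2, 2, and 0, so the first and third examples
are classified correctly and the accuracy is 2/3.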
Summary
-------

Classification is a sufficiently common problem that it warrants its own
convenience functions. Of central importance in classification is the
*accuracy* of the classifier. Note that while we often care primarily
about accuracy, we train classifiers to optimize a variety of other
objectives for statistical and computational reasons. However,
regardless of which loss function was minimized during training, it is
useful to have a convenience method for assessing the accuracy of our
classifier empirically.

Exercises
---------

1. Denote by :math:`L_\textrm{v}` the validation loss, and let
   :math:`L_\textrm{v}^\textrm{q}` be its quick and dirty estimate
   computed by the loss function averaging in this section. Lastly,
   denote by :math:`l_\textrm{v}^\textrm{b}` the loss on the last
   minibatch. Express :math:`L_\textrm{v}` in terms of
   :math:`L_\textrm{v}^\textrm{q}`, :math:`l_\textrm{v}^\textrm{b}`, and
   the sample and minibatch sizes.
2. Show that the quick and dirty estimate
   :math:`L_\textrm{v}^\textrm{q}` is unbiased. That is, show that
   :math:`E[L_\textrm{v}] = E[L_\textrm{v}^\textrm{q}]`. Why would you
   still want to use :math:`L_\textrm{v}` instead?
3. Given a multiclass classification loss, denoting by
   :math:`l(y,y')` the penalty of estimating :math:`y'` when we see
   :math:`y` and given a probability :math:`p(y \mid x)`, formulate the
   rule for an optimal selection of :math:`y'`. Hint: express the
   expected loss, using :math:`l` and :math:`p(y \mid x)`.

.. raw:: html