You can run this notebook using a free Colab GPU instance (Tesla T4) but some computations may be slow and/or run out of memory. If you have a GCP account and want to use a faster V100 GPU you can follow the instructions here to use that as an alternative Colab backend. I recommend giving it a try if you have a GCP account and haven't used this feature before. Don't forget to shut down the GCP instance (not just the Colab notebook) once you've finished!
FWIW, I use the following (preemptible) instance type:
```bash
export IMAGE_FAMILY="pytorch-latest-cu100"
export ZONE="europe-west4-a"
export INSTANCE_NAME="pytorch-colab-backend"

gcloud compute instances create $INSTANCE_NAME \
  --zone $ZONE \
  --machine-type n1-standard-4 \
  --accelerator type=nvidia-tesla-v100,count=1 \
  --image-family $IMAGE_FAMILY \
  --image-project=deeplearning-platform-release \
  --metadata install-nvidia-driver=True \
  --maintenance-policy TERMINATE \
  --preemptible
```
On GCP you will need to install altair for plotting. This is pre-installed on Colab instances.
You can check the details of your setup (free Colab GPU or V100 on GCP) with this nice utility from fastai:
The original idea of this project was to make an efficient ResNet training implementation go faster by replacing the batch norm layers. Things have not gone to plan (!) - there have been ample opportunities for optimisation, but batch norm has remained stubbornly in place. Why? Today, we're going to explore the detailed mechanism by which batch norm enables the high learning rates crucial for rapid training. Let's start with a high-level view.
Deep networks provide a way to represent a rich class of functions parameterised by learnable weights. The job of training is to select parameters which achieve a desired functionality. Within the class of functions parameterised by the network there are also many degenerate functions, such as those that ignore their inputs and produce constant outputs. We will show that such functions are not anomalies but rather that they exist throughout parameter space - in particular, near to the trajectory that the network follows during training. The training process must observe certain constraints on the model parameters to avoid moving to a degenerate configuration and suffering a catastrophic loss in accuracy.
First-order optimisers such as SGD do not like constraints. The sharp increase in loss in directions perpendicular to the space of 'good' configurations is a source of high curvature and consequent training instability. Batch norm works by reparameterising the function space such that these constraints are easier to enforce, curvature of the loss landscape is diminished and training can proceed at a high rate.
During the course of our investigations we will fill in this picture in detail. We shall learn some surprising behaviours of networks at initialisation and resolve a mystery about spikes in the spectrum of the Hessian...
Let's get under way by reviewing what batch norm is, what it's good for and the drawbacks that have encouraged many people to look for alternatives.
Batch norm acts by removing the mean and normalising the standard deviation of a channel of activations
Statistics @@1@@ are computed over pixels in an image and examples in a batch. They are frozen at test time.
A learnable output mean @@2@@ and standard deviation @@3@@ are usually applied, potentially undoing everything:
This is not as silly as it might seem! Reparametrising the network can radically change the optimisation landscape as we saw in the previous post for the example of weight rescaling.
For today's post, we shall omit @@5@@ and @@6@@ (or equivalently freeze them to 0 and 1) since they add complexity and are largely irrelevant to the issues under study. In order to achieve the highest training accuracies, learnable biases @@7@@ are recommended whilst learnable scales @@8@@ are sometimes actively unhelpful.
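For a concrete reference point, here is a minimal NumPy sketch of the training-mode batch norm forward pass on an (N, C, H, W) array. Statistics are pooled over examples and pixels, giving one mean and one variance per channel. The function name `batch_norm`, the `eps` constant and the names `beta`/`gamma` for the learnable output mean and standard deviation are our own choices. The defaults `beta=0, gamma=1` correspond to omitting them, as above. At test time the batch statistics would be replaced by frozen running averages, which this sketch does not model.

```python
import numpy as np

def batch_norm(x, beta=0.0, gamma=1.0, eps=1e-5):
    """Training-mode batch norm for an (N, C, H, W) array.

    Statistics are pooled over examples (axis 0) and pixels (axes 2, 3),
    giving one mean and one variance per channel. beta and gamma are the
    learnable output mean and standard deviation (scalars here for brevity;
    in practice they are per-channel parameters)."""
    mu = x.mean(axis=(0, 2, 3), keepdims=True)     # per-channel mean
    var = x.var(axis=(0, 2, 3), keepdims=True)     # per-channel variance
    x_hat = (x - mu) / np.sqrt(var + eps)          # zero mean, unit std per channel
    return gamma * x_hat + beta

rng = np.random.default_rng(0)
# 4 channels with deliberately different means and scales
scales = np.arange(1, 5).reshape(1, 4, 1, 1)
x = rng.normal(loc=3.0, scale=5.0, size=(8, 4, 6, 6)) * scales
y = batch_norm(x)                                  # beta = 0, gamma = 1: omitted, as in this post
print(y.mean(axis=(0, 2, 3)), y.std(axis=(0, 2, 3)))
```

Passing non-default `beta` and `gamma` moves every channel to that mean and standard deviation. This shows how the learnable parameters can, in principle, undo the normalisation.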
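Empirically, batch norm has been extremely successful, especially for training conv nets. Many proposed alternatives have failed to replace it. Its benefits include:

- stabilising optimisation, which allows much higher learning rates and faster training;
- improving generalisation;
- reducing sensitivity to weight initialisation;
- through its interaction with the scale of the weights, implicitly controlling learning rate dynamics.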
The first point - optimisation stability - is the key one and will be the focus of today's post. There are alternative methods to improve generalisation; with a little care, good weight initialisations can be found without batch norm; and learning rate dynamics are perhaps better controlled explicitly (with an optimiser such as LARS) than left to the implicit weight scale dynamics that we described last time.
Here is an experiment to demonstrate the effect of batch norm on optimisation stability. We train a simple, 8-layer, unbranched conv net, with and without batch norm, on CIFAR10. We shall be using variants of this network throughout the post. The reason for choosing an unbranched architecture is that we will be studying effects that grow with network depth and shortcut connections reduce the effective depth, meaning that we would need a deeper architecture to see similar effect sizes.
The learning rate, plotted on the left, is increased exponentially over time, as a stress test. Training accuracy is plotted on the right.
It can be seen that the network with batch norm is stable over a much larger range of learning rates (note the log scale on the x-axis of the second plot). The ability to use high learning rates allows training to proceed much more rapidly for the model with batch norm.
So why has so much effort been devoted to finding replacements? Batch norm has several drawbacks:
We won't have much to say on these today, except to remark that once one has accepted the necessity of something like batch norm, addressing these issues seems a lot more palatable. The speed issue in particular is far from insurmountable. A good compiler could fuse the computation of statistics into the previous layer and their application into the next. This would avoid unnecessary round trips to memory and remove almost all overhead.
Hopefully today's explorations can also provide some guidance on the essential features of batch norm which would be required in any replacement.
The aim of this section is to understand the typical behaviour of deep networks at initialisation. We shall see hints of the problems to come when we start training. In particular, care needs to be taken at initialisation if the network is to avoid computing a constant function, independent of the inputs. We shall review the surprising and under-appreciated fact that the standard He-et-al initialisation produces effectively constant functions for sufficiently deep ReLU networks.
We come to the first key point. Batch norm acts on histograms of per channel activations (by shifting means and rescaling variances), which means that these are a really good thing to monitor. This seems to be rarely done, even in papers studying batch norm. A notable exception is the recent paper of Luther and Seung, which we shall discuss shortly.
The diagram below shows histograms of activation values, across pixels and examples in a batch, before and after a batch norm layer. Different channels are represented by different colours, and the per channel mean is shown underneath. Minimum/maximum values per channel are indicated by the vertical ticks. Their height is held constant between plots in units of probability density, so small ticks indicate a more peaked distribution. Strictly, these are kernel density estimates rather than histograms. We display a maximum of 10 arbitrarily selected channels per layer to avoid over-crowding.
We are going to be looking at lots of plots like these during the remainder of the post, so please take a moment to understand the details...
Following Luther and Seung, let's consider a very boring network consisting of alternating fully-connected and ReLU layers, all the same size, and without batch norm. Here are the first 10 layers:
We initialise the weights randomly using the standard He-et-al initialisation, which is aimed at preserving the mean and variance of layer outputs throughout the network if we pool together the channels in a layer. Actually, this is not exactly what the He-et-al initialisation does - it preserves mean and variance of outputs if we marginalise over the weights distribution, but this is close enough to pooling channels that we will ignore the difference here.
Let's look at histograms of channel activations at different depths in the network with independent @@0@@ inputs. First we look at pooled channel histograms, to check that the initialisation is doing what it is supposed to:
Things get a bit wobbly by layer 50, but generally look how they should.
Now let's look at the same histograms, split by channel:
These tell a rather different story!
By layer 50, each channel has effectively chosen a constant value to take, independently of the input. Suppose that layer 50 was the final layer before a softmax classifier (and only consisted of the 10 channels displayed.) Most inputs would be classified as belonging to the 'orange' class with a small number of 'greens' and nothing else. In this sense, a randomly initialised network likes to compute (almost) constant functions!
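The per-channel collapse should be reproducible, qualitatively, with a few lines of NumPy. This sketch assumes a width of 100, 50 fully-connected layers with ReLUs in between, He-et-al initialisation and 1000 independent N(0, 1) inputs. These are our choices, not necessarily those behind the plots above. Instead of plotting histograms, it reports the fraction of each layer's pooled output variance that comes from variation within channels (the helper `within_channel_fraction` is our own). A value near 1 means channels still respond to their input. A value near 0 means each channel has settled on an input-independent constant.

```python
import numpy as np

def within_channel_fraction(h):
    """Fraction of the pooled variance of h (examples x channels) due to
    variation within channels rather than between per-channel means."""
    return h.var(axis=0).mean() / h.var()

rng = np.random.default_rng(0)
width, depth, n_examples = 100, 50, 1000
x = rng.standard_normal((n_examples, width))        # independent N(0, 1) inputs

h = x
fractions = []
for layer in range(depth):
    # He-et-al initialisation: weight std sqrt(2 / fan_in)
    W = rng.normal(0.0, np.sqrt(2.0 / width), size=(width, width))
    inp = h if layer == 0 else np.maximum(h, 0.0)   # ReLU between linear layers
    h = inp @ W.T
    fractions.append(within_channel_fraction(h))

for layer in (1, 10, 50):
    print(f"layer {layer:2d}: within-channel fraction {fractions[layer - 1]:.3f}")
```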
If we remove the ReLUs and rescale the weights appropriately so that we have a purely linear network, the problem goes away. At deeper layers, we eventually observe some variation between the variances of different channels but the effect is rather mild.
In fact, the per channel histograms are not telling the full story and there are significant correlations between channels in deep layers. We see this if we compute the singular value decomposition of the linear map represented by the linear network to different depths:
By layer 50, certain combinations of channels have been stretched by a factor of ~10, a fact which is not visible in the per channel histograms. There is an elegant analysis based on random matrix theory and reviewed in Pennington-et-al '17 that explains what is going on here. We shall have nothing more to say on the matter here, since the effect is small compared to that of adding ReLU layers; since this effect will be greatly ameliorated with the addition of residual connections; and since the effects of batch norm, seeing only per channel statistics, are largely orthogonal to this.
So how does the inclusion of ReLUs lead to the problems we observed above? It turns out that non-zero channel means introduced by the ReLU layers are at the root of the problem.
Since ReLUs only return positive values, the first time we pass a centred input distribution through a ReLU layer, we get an output distribution in which each channel has positive mean. After the following linear layer, channels still have non-zero means, although these can now be positive or negative depending on the weights for the particular channel. This is because the output mean of a linear layer equals the linear layer applied to the input mean, which is non-zero as a result of the preceding ReLU. The key point is that for fixed weights, per channel means are now non-zero, whilst if we average over the weights the effect disappears (by symmetry of the weights distribution around zero.)
It is less obvious that this effect should continue to grow through the network, with means increasing in magnitude with depth. This is shown by a direct calculation in Luther and Seung '19 and also follows from straightforward properties of ReLU and the symmetry about zero of the weights distribution.
The corresponding loss of variance within each channel follows from the Law of total variance which states that the total variance, integrating over the weights distribution - which is preserved by design under the He-et-al initialisation - decomposes into a sum of the variance of the per channel means and the average variance within each channel. (This decomposition is visible in the per channel histograms above for the ReLU network.) Since the per channel means increase in variance at each layer and the total variance is fixed, the variance within channels must decrease at each layer. At sufficient depth a fixed point is reached where there is no more variance within channels and the full variance is realised in the distribution of the per channel means. At this point the network is computing a constant function.
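Both steps of this argument can be checked numerically with the NumPy sketch below (sizes, seed and `sigma` are arbitrary choices). It confirms three things:

- A ReLU applied to centred Gaussian inputs with standard deviation sigma has mean sigma / sqrt(2 pi) > 0.
- A single fixed draw of He-initialised weights then produces non-zero per-channel means.
- The pooled variance splits exactly into the average within-channel variance plus the variance of the per-channel means, as the law of total variance says. This identity is exact here because every channel has the same number of examples and population (ddof=0) variances are used.

```python
import numpy as np

rng = np.random.default_rng(1)
n_in, n_out, n_examples, sigma = 200, 200, 10_000, 1.5

z = rng.normal(0.0, sigma, size=(n_examples, n_in))   # centred pre-ReLU activations
a = np.maximum(z, 0.0)                                # every channel now has positive mean
print(a.mean(), sigma / np.sqrt(2 * np.pi))           # E[relu(z)] = sigma / sqrt(2 pi)

# One fixed draw of He-et-al weights: each output channel mean is W @ E[a], non-zero
W = rng.normal(0.0, np.sqrt(2.0 / n_in), size=(n_out, n_in))
h = a @ W.T

channel_means = h.mean(axis=0)
within = h.var(axis=0).mean()      # average variance within each channel
between = channel_means.var()      # variance of the per-channel means
total = h.var()                    # pooled variance over all channels and examples
print(within, between, total, within + between)
```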
All of this suggests a rather simple fix: we should subtract the mean after each ReLU layer, and rescale to restore the total variance. In this simple network with @@0@@ input data, we can compute the relevant mean and scale factor analytically, but more generally, this corresponds to initialising with 'frozen batch norm' layers, which compute a static bias and scale to standardise the per channel means and variances at initialisation.
With this fix in place, order is restored and the network no longer computes a constant function:
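The 'frozen batch norm' form of the fix can be sketched by extending the fully-connected NumPy experiment. Per-channel means and standard deviations of every post-ReLU activation are computed once on a calibration batch at initialisation and then frozen. They are then used to standardise a fresh batch of inputs. Network size, batch sizes and the helper names `forward` and `within_channel_fraction` are our own assumptions. Unlike the single analytic shift and scale described above, this standardises each channel separately.

```python
import numpy as np

def within_channel_fraction(h):
    return h.var(axis=0).mean() / h.var()

rng = np.random.default_rng(0)
width, depth = 100, 50
weights = [rng.normal(0.0, np.sqrt(2.0 / width), size=(width, width)) for _ in range(depth)]

def forward(x, stats=None, calibrate=False):
    """Return the within-channel variance fraction after every linear layer.
    calibrate=True: record per-channel post-ReLU (mean, std) from this batch in `stats`.
    stats given: standardise each post-ReLU channel with the stored statistics.
    stats=None: plain network without normalisation."""
    h = x @ weights[0].T
    fractions = [within_channel_fraction(h)]
    for k, W in enumerate(weights[1:]):
        a = np.maximum(h, 0.0)
        if calibrate:
            stats.append((a.mean(axis=0), a.std(axis=0) + 1e-8))
        if stats is not None:
            mu, sd = stats[k]
            a = (a - mu) / sd                  # frozen batch norm after the ReLU
        h = a @ W.T
        fractions.append(within_channel_fraction(h))
    return fractions

frozen = []
forward(rng.standard_normal((4096, width)), stats=frozen, calibrate=True)
x_test = rng.standard_normal((1000, width))    # fresh inputs, not used for calibration
plain = forward(x_test)
fixed = forward(x_test, stats=frozen)
print(f"layer 50, no fix:            {plain[-1]:.3f}")
print(f"layer 50, frozen batch norm: {fixed[-1]:.3f}")
```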
What happens in a more realistic setting with convolutions, pooling and real data?
Let's consider the 8-layer unbranched conv net from earlier, without batch norm layers. At initialisation with independent @@0@@ input data we get the following per channel output distributions:
The collapse to a constant function happens even more rapidly with this network! This is partly because of the inclusion of max-pooling layers which also act to increase channel-wise means. With CIFAR10 inputs things are slightly improved but the effect clearly persists:
If we reintroduce batch norm layers (or frozen batch norm layers) then, by design, intermediate layers revert to mean zero, variance one outputs per channel. This prevents any growth in per channel means throughout the network. However, the final classifier is not acted on by a batch norm layer, since one there would presumably restrict expressivity and hurt training performance. The inputs to the classifier have passed through both a ReLU and a max-pooling layer, so they already have substantial positive means. This leads to variance in per channel means at the classifier output. The effect is somewhat milder than for the network without batch norm, but it is by no means absent in this case:
What have we learned so far?
In the absence of batch norm, the standard initialisation scheme for deep networks leads to 'bad' configurations in which the network effectively computes a constant function. By design, batch norm goes a long way towards fixing this. As far as initialisation is concerned, 'frozen batch norm', based on activation statistics at initialisation, works just as well.
A more interesting question, which we turn to next, is what happens when we start training? We shall find that the ubiquity of bad network configurations at initialisation reflects a deeper problem: in the absence of batch norm, there are bad configurations in a small neighbourhood of the good configurations traversed during a training run. The proximity of these bad configurations - with high training loss - indicates the existence of highly curved directions in the loss landscape and places a limit on the achievable learning rates.
The remaining sections are structured as follows.
We shall continue to work with the 8-layer unbranched network from above, with and without batch norm, with one minor alteration. Since we shall shortly be computing second derivatives using autograd, we are going to replace the ReLU non-linearity - whose second derivative vanishes almost everywhere and is therefore invisible to autograd - with a smoothed-out version. We choose to use the so-called 'softplus' non-linearity but the details are unimportant. This may be an unnecessary precaution but at least shouldn't hurt. We continue to refer to the (smoothed) non-linearity as ReLU below.
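The reason for the swap can be seen directly: the second derivative of ReLU is zero away from the origin, whereas that of softplus is strictly positive everywhere. The NumPy sketch below uses softplus(x) = log(1 + exp(beta x)) / beta. Its second derivative is beta s (1 - s), where s is the logistic sigmoid of beta x. The sharpness parameter `beta` (also exposed by PyTorch's `nn.Softplus`) is set to 1 here as an assumption; the value used in the notebook is not stated.

```python
import numpy as np

def softplus(x, beta=1.0):
    # Numerically stable log(1 + exp(beta * x)) / beta
    return np.logaddexp(0.0, beta * x) / beta

def softplus_second_derivative(x, beta=1.0):
    s = 1.0 / (1.0 + np.exp(-beta * x))     # sigmoid(beta * x), i.e. softplus'(x)
    return beta * s * (1.0 - s)

x = np.linspace(-3.0, 3.0, 7)
print(softplus_second_derivative(x))        # strictly positive at every point

# Central second differences of ReLU away from the kink vanish (up to rounding)
relu = lambda t: np.maximum(t, 0.0)
eps = 1e-3
x_off_kink = x[x != 0]
relu_fd = (relu(x_off_kink + eps) - 2 * relu(x_off_kink) + relu(x_off_kink - eps)) / eps**2
print(relu_fd)
```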
We want to understand the typical state of networks during training. We shall study the networks after 10 epochs of training - half-way through a typical training run - using the linear warm-up and decay schedule from previous posts. The maximal learning rate has to be tailored to each network, since the one without batch norm is unstable at higher learning rates. These choices are rather arbitrary, but the conclusions appear to be robust. In particular, experiments performed after a full 20 epochs of training produce similar results. The exact maximal learning rates and schedules also don't appear to make a large difference.
We shall consider a third network, identical to the batch norm network, with the batch norm layers frozen after the 10 epochs of training. This allows us to separate issues of initialisation and training trajectory from the ongoing stabilising effects of batch norm. We shall find that the two networks without active batch norm layers behave similarly to one another and entirely differently from the batch norm network.
Statements about the size of learning rates, gradients, Hessian eigenvalues and perturbations in parameter space only really make sense once we fix the scale of the weights. As we discussed in the previous post, there are many alternative ways to parameterise the same network in which these weight scales differ, especially in the presence of batch norm.
For comparisons between the 'batch norm' and 'frozen batch norm' networks, this is not an issue since they have identical parameters. For the 'no batch norm' network we have been careful to make sure that weight scales are the same at initialisation and similar after training to the other two. In particular we have added a 'label smoothing' term to the loss to help stabilise the scale of the networks during training. A more involved treatment, in which we tried to make scale invariant statements only, would produce similar conclusions.
First we set up the network with batch norm and train for 10 epochs:
We reset training=True for this network so that batch norm is active in the experiments that follow:
We create a 'frozen batch norm' model by making a clone of the batch norm model, refreshing the batch norm stats on the training set (since things are changing fast at this point in training and we want to make sure that the frozen stats are up-to-date) and setting training_mode=False:
Next let's set up the network without batch norm from the beginning (note the lower learning rate for this model):
Finally, we convert the models to float32 format. For the following experiments - some of which involve computing second derivatives and doing numerical linear algebra on the results - we are going to work in float32 to reduce the risk of numerical issues. This has the disadvantage of making things rather slow, but efficiency is not the priority here. Were we working with larger models/datasets, it might be worth doing the work to make these computations work correctly in mixed precision.
We also need a float32 version of the dataset. Since we are not training further, we are going to fix the shuffled order and the random choices of data augmentations, to reduce noise when comparing between experiments. We use a smaller batch size here since computing Hessian vector products is rather memory intensive. To reproduce the experiments on a GPU smaller than a V100, a further reduction in batch size may be necessary and things are likely to be slow.
We set ourselves the following task: for each of the three, partially trained networks, can we find nearby configurations in parameter space that compute nearly constant functions with correspondingly high training loss? In fact, let's sharpen the question as follows. Can we find nearby configurations in parameter space such that the networks (mostly) output a fixed, chosen class? This second question is amenable to a solution using backprop - we can compute the gradient, with respect to model parameters, of the mean classifier output for the chosen class. We can then perturb in that direction.
Let's put this into practice for our three networks. First, we compute gradients of the means of each of the 10 classifier outputs for each network:
Next we want to perturb by a fixed length vector in each of the gradient directions computed above. The vectors of learnable parameters of the three partially trained networks have approximately the same norm - indeed the 'batch norm' and 'frozen batch norm' networks have identical learnable parameters.
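The recipe can be illustrated on a toy model: take the gradient of a mean class output, rescale it to a fixed fraction of the parameter norm, and perturb. The NumPy sketch below uses a hypothetical two-layer ReLU classifier on random data, with hand-written backprop for the gradient of the batch-mean logit of one chosen class. It then compares the shift in that mean after a 1% step along the gradient with the shift after a 1% step in a random direction. It is a stand-in for the notebook's autograd-based computation, not a reproduction of it. All names (`grad_mean_logit`, `shift_along`, etc.) are our own.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_hidden, n_classes, n_examples = 20, 64, 10, 512

# Hypothetical toy classifier: logits = W2 relu(W1 x + b1) + b2
params = {
    "W1": rng.normal(0.0, np.sqrt(2.0 / n_in), (n_hidden, n_in)),
    "b1": np.zeros(n_hidden),
    "W2": rng.normal(0.0, np.sqrt(1.0 / n_hidden), (n_classes, n_hidden)),
    "b2": np.zeros(n_classes),
}
X = rng.standard_normal((n_examples, n_in))

def logits(p, X):
    return np.maximum(X @ p["W1"].T + p["b1"], 0.0) @ p["W2"].T + p["b2"]

def grad_mean_logit(p, X, cls):
    """Gradient of the batch mean of logit `cls` w.r.t. every parameter (manual backprop)."""
    pre = X @ p["W1"].T + p["b1"]
    g = {name: np.zeros_like(value) for name, value in p.items()}
    g["W2"][cls] = np.maximum(pre, 0.0).mean(axis=0)
    g["b2"][cls] = 1.0
    d_pre = (pre > 0) * p["W2"][cls] / X.shape[0]   # d(mean logit) / d(pre-activation)
    g["W1"] = d_pre.T @ X
    g["b1"] = d_pre.sum(axis=0)
    return g

def flatten(d):
    return np.concatenate([d[name].ravel() for name in sorted(d)])

def unflatten(vec, like):
    out, i = {}, 0
    for name in sorted(like):
        out[name] = vec[i:i + like[name].size].reshape(like[name].shape)
        i += like[name].size
    return out

cls = 3                                   # the class we try to push the outputs towards
theta = flatten(params)
g = flatten(grad_mean_logit(params, X, cls))

def mean_logit(vec):
    return logits(unflatten(vec, params), X)[:, cls].mean()

def shift_along(direction, fraction=0.01):
    # Perturb by a vector whose length is `fraction` of the parameter vector's length
    step = fraction * np.linalg.norm(theta) * direction / np.linalg.norm(direction)
    return mean_logit(theta + step) - mean_logit(theta)

shift_grad = shift_along(g)
shift_rand = shift_along(rng.standard_normal(theta.size))
print(f"mean logit shift - gradient direction: {shift_grad:.4f}, random direction: {shift_rand:.4f}")
```

To first order, a step along the gradient changes the mean logit by the step length times the gradient norm. A random direction of the same length captures only a fraction of order 1/sqrt(number of parameters) of that.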
Let's see what happens if we perturb by fixed length vectors of 1% of the length of the underlying parameter vector. As a baseline, we first inspect the unperturbed channel distributions and the effect of perturbing parameters by a random vector of 1% length:
A few things to note about the plots above. (Note that we are looking at a scaled version of the outputs of the final layer, whereas at initialisation we were monitoring the unscaled version - the scale of the outputs has grown through the training process.)
First, the three networks produce similar output distributions - indeed the 'batch norm' and 'frozen batch norm' outputs are almost indistinguishable, as one might hope. The distributions consist of a main mode and a second, smaller mode at some distance to the right which is characteristic of the network starting to make confident predictions of a single class per input. The overall distributions of different classes are similar, with no single class dominating, consistent with the ~85% training accuracy observed at this stage of training. Finally, the distributions after perturbation by a 1% length random vector are visually indistinguishable from the baseline.
Next, let's see what happens when we perturb by the gradients of channel means computed above, also normalised to 1% of the length of the underlying parameter vector.
First for the 'batch norm' model:
Unlike for a random vector of the same length, the effect of the targeted perturbations is clearly visible. However, a perturbation of this length is not sufficient to produce constant outputs and a catastrophic loss in training accuracy.
Next for the 'frozen batch norm' model:
Here the effect is much stronger! The perturbed networks predict the chosen class on a majority of inputs, leading to a dramatic increase in training loss.
The 'no batch norm' model behaves similarly to the 'frozen batch norm' model:
What has happened here? To get further insight, let's have a look at the gradients of the mean channel outputs that we computed above. In particular, let's look at the squared norm of these gradients and how this is distributed between layers for each network. Since the effect is similar for all 10 output channels, we show the gradient for the first channel only:
A couple of things to note. First, the gradient norms are much larger for the two networks without active batch norm. This is consistent with the results above: for a fixed length perturbation in the direction of the gradient, the effect on channel outputs is larger when the gradients are larger. Importantly, the gradient is concentrated in the final layer for the 'batch norm' network, but distributed throughout layers for the other two. We can understand this as follows.
For the networks without active batch norm, changes to the distribution of outputs in earlier layers can propagate to changes in distribution at later layers. In other words, internal covariate shift is able to propagate to external shift at the output layer. This is closely related to the issue at initialisation, where non-zero means at earlier layers propagate to non-zero means at later layers. Early layers receive inputs that have passed through a ReLU and have positive mean, even if the mean before the ReLU was zero. As a result, the weights can easily be adapted to produce a change in output mean. For the 'batch norm' network, changes to the mean and variance at earlier layers are removed by subsequent batch normalisation, so there is little opportunity for early layers to affect the output distribution.
We can visualise what is going on in the case of the 'frozen batch norm' network by looking at channel distributions throughout the network after perturbation by the normalised gradient for the first output channel:
The earlier layers - which had zero mean prior to perturbation, on account of the frozen batch norms - have increasingly large means as we progress through the layers, contributing to the large change for the targeted channel in the output. Apart from the shift in mean, little else changes from baseline in the per channel histograms, as the reader can verify by removing the perturbation.
An important point in the discussion above was that incoming activations at early layers have non-zero mean on account of the ReLUs, even if pre-ReLU activations have zero mean. This in turn allows perturbations of the weights to affect the output means. If we apply batch normalisation after the ReLU layers and then freeze after partial training, we might expect a different story, in which the effect of weight perturbations in early layers is greatly diminished. This is analogous to the situation at initialisation. There, we applied an analytic shift and rescaling after the ReLU layers in the simple, fully connected net and found that the problem of propagating means and constant outputs disappeared.
Let's check if this is the case by computing the gradients of output channel means for such a network. First we set up the network:
Next we compute channel mean gradients:
...and plot the effect of perturbing by these:
The impact of the perturbations is much smaller than for either of the other networks without active batch norm, and similar to that of the batch norm network! Similarly, if we inspect the squared norm of the gradient of a single output channel, we find that it is much smaller for this network and mostly concentrated in the final layer.
Is this a solution to the stability problem without requiring active batch norm layers? Not quite.
First of all, means can re-appear at early layers after the batch norms are frozen, because of random drift during subsequent training, and once they are back, the instability reappears. Also, as we will see in the next section, there are other issues beyond output means. For example, even in the absence of non-zero input means, early layers can still control their output variance by rescaling weights. Since ReLU is a convex function, increasing (resp. decreasing) the variance of the distribution passed to a ReLU is another way to increase (resp. decrease) the post-ReLU mean.
Having said this, the network with frozen batch norms after ReLUs is considerably more stable than the usual one, and this is an interesting direction for further investigation.
It's easy to fall into the trap of anthropomorphising algorithms, but gradient descent is a rather primitive creature. It possesses an ant's-eye view of the surrounding landscape and no memory. Finding itself in a steep enough valley, the poorly-named algorithm observes ever-growing gradients as it oscillates its way up the sides.
For the two networks without active batch norm, we have discovered nearby configurations with mostly constant output and high training loss, indicating that we find ourselves in such a steep valley. If we can show that these are the most curved directions in the loss landscape and that similarly curved directions do not exist in the presence of batch norm, then we shall have understood, rather precisely, the mechanism by which batch norm stabilises optimisation and enables higher learning rates. In order to make this further claim we need to have a way to identify the most curved directions in the loss landscape - those which place the tightest limit on achievable learning rates.
In one dimension, for a loss function @@0@@, the gradient after taking a step @@1@@ is:
In the case of gradient descent @@3@@ and the update in terms of @@4@@ is:
Assuming that @@6@@, the highest stable learning rate is:
beyond which gradients and losses grow exponentially with time as in the picture above.
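The one-dimensional argument can be written out as follows. The notation is our own and may not match the symbols above. Let h = l''(x) and expand the gradient to first order: l'(x + δ) ≈ l'(x) + h δ. Substituting the gradient descent step δ = −λ l'(x) gives l'(x) ← (1 − λh) l'(x). The gradient shrinks when |1 − λh| < 1, i.e. for 0 < λ < 2/h when h > 0, and grows exponentially otherwise. The plain-Python sketch below checks this on the quadratic loss l(x) = h x²/2, for which the expansion is exact.

```python
def gradient_after_steps(curvature, lr, g0=1.0, steps=50):
    """Run gradient descent on the 1D quadratic loss l(x) = curvature * x**2 / 2,
    starting from a point where l'(x) = g0, and return the final gradient.
    Each step multiplies the gradient by (1 - lr * curvature)."""
    x = g0 / curvature                  # l'(x) = curvature * x = g0
    for _ in range(steps):
        x -= lr * curvature * x         # x <- x - lr * l'(x)
    return curvature * x                # l'(x) after `steps` updates

h = 4.0                                 # curvature l''(x)
critical = 2.0 / h                      # highest stable learning rate
for lr in (0.5 * critical, 0.95 * critical, 1.05 * critical):
    g = gradient_after_steps(h, lr)
    print(f"lr = {lr:.3f}: |l'(x)| after 50 steps = {abs(g):.3g}")
```

At `lr = 0.5 * critical` the factor 1 − λh is zero, so the gradient vanishes after a single step. Just below `critical` the gradient oscillates in sign and decays. Just above, it oscillates and grows.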
In higher dimensions, the gradient after taking a step in direction @@0@@ is:
where @@2@@ is
Assuming that all the eigenvalues of @@4@@ are positive, the highest stable learning rate is:
Adding momentum brings a little extra stability, as explained nicely here, but the maximal learning rate is still limited by a constant times the reciprocal of @@6@@. This continues to hold in the presence of mild, localised non-convexity and well-behaved stochastic gradients.
If SGD goes suddenly wrong with 'nan's appearing in losses and parameter values, large eigenvalues of the Hessian are usually (always?) the culprit. It doesn't matter if nearly all of the millions of directions in parameter space are nicely behaved - as soon as there is one direction corresponding to a large eigenvalue of the Hessian, SGD will find itself oscillating out of control.
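The same threshold shows up in many dimensions. The NumPy sketch below builds a quadratic loss l(x) = xᵀHx/2 whose Hessian has 999 gently curved directions and a single eigenvalue of 100; the dimension, spectrum and step count are arbitrary choices. It then runs gradient descent just below and just above 2/λ_max. The one sharp direction alone decides whether the gradient decays or explodes.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 1000
eigenvalues = rng.uniform(0.01, 1.0, size=dim)       # gently curved directions...
eigenvalues[0] = 100.0                               # ...and one sharply curved one
Q, _ = np.linalg.qr(rng.standard_normal((dim, dim))) # random orthonormal eigenbasis
H = (Q * eigenvalues) @ Q.T                          # H = Q diag(eigenvalues) Q^T
x0 = Q @ np.ones(dim)                                # unit component along every eigenvector

def grad_norm_after(lr, steps=200):
    x = x0.copy()
    for _ in range(steps):
        x -= lr * (H @ x)        # the gradient of x^T H x / 2 is H x
    return np.linalg.norm(H @ x)

lr_max = 2.0 / eigenvalues.max()
stable = grad_norm_after(0.99 * lr_max)
unstable = grad_norm_after(1.01 * lr_max)
print(f"lr = 0.99 * 2/lambda_max: |grad| = {stable:.3g}")
print(f"lr = 1.01 * 2/lambda_max: |grad| = {unstable:.3g}")
```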
In the next section we shall show how to compute the leading eigenvalues/vectors of @@7@@ and see how these relate to the nearby configurations with constant outputs that we identified earlier.
Our networks have over 4 million trainable parameters each, so direct computation of the Hessian is going to be painful. Fortunately, if we're only interested in the leading eigenvalues and eigenvectors we don't need to do this.
The same repeated application of the Hessian that makes gradient descent unstable will, starting from a random vector, eventually produce a vector proportional to the leading eigenvector (the one whose eigenvalue has the largest magnitude). The leading eigenvalue is then easily computed from this vector. Orthogonalising to this vector and repeating the process allows one to isolate sub-leading eigenvectors and eigenvalues. This is the well-known power method. We shall use the slightly more sophisticated Lanczos algorithm, based on the same principles, via an implementation in scipy.
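Here is a minimal NumPy/SciPy sketch of both ideas. The 'Hessian' is a synthetic symmetric matrix with known leading eigenvalues 50, 30 and 20, chosen so that the result can be checked. It is accessed only through a matrix-vector product function, as a real Hessian would be. The power method converges to the eigenvalue of largest magnitude, so the spectrum is chosen so that this is also the most positive one.

- `power_method` is our own helper, with deflation by orthogonalisation.
- For Lanczos, we wrap the product in `scipy.sparse.linalg.LinearOperator` and call `scipy.sparse.linalg.eigsh`, which uses ARPACK's implicitly restarted Lanczos method for symmetric problems.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, eigsh

rng = np.random.default_rng(0)
dim = 500
spectrum = np.concatenate([[50.0, 30.0, 20.0], rng.uniform(-1.0, 1.0, dim - 3)])
Q, _ = np.linalg.qr(rng.standard_normal((dim, dim)))
H = (Q * spectrum) @ Q.T          # synthetic symmetric stand-in for a Hessian

def matvec(v):
    # Only matrix-vector products are used below, never the matrix itself
    return H @ v

def power_method(matvec, dim, deflate=(), iters=500):
    """Leading eigenpair, restricted to the orthogonal complement of `deflate`."""
    v = rng.standard_normal(dim)
    for _ in range(iters):
        for u in deflate:
            v -= (u @ v) * u      # orthogonalise against eigenvectors already found
        v = matvec(v)
        v /= np.linalg.norm(v)
    return v @ matvec(v), v       # Rayleigh quotient, eigenvector estimate

lam1, v1 = power_method(matvec, dim)
lam2, v2 = power_method(matvec, dim, deflate=[v1])

op = LinearOperator((dim, dim), matvec=matvec, dtype=np.float64)
top3 = np.sort(eigsh(op, k=3, which="LA", return_eigenvectors=False))[::-1]
print(lam1, lam2, top3)
```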
We also need a way to compute Hessian vector products, which can be done using autograd via the well-known Pearlmutter trick. This repo provided a useful reference in developing the implementation below.
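The key property is that a Hessian-vector product costs about as much as a couple of gradient evaluations and never forms the full matrix. In PyTorch, exact products are commonly obtained by differentiating twice with `torch.autograd.grad(..., create_graph=True)`. To keep this sketch dependency-free we swap in a different technique: a central difference of the gradient, Hv ≈ (∇l(w + εv) − ∇l(w − εv)) / 2ε. This is an approximation, not Pearlmutter's exact trick. The model (logistic regression on random data) and all names are our own assumptions, and the explicit Hessian is formed only to check the answer.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 200, 30
X = rng.standard_normal((n, d))
t = (rng.random(n) < 0.5).astype(float)        # binary labels
w = 0.1 * rng.standard_normal(d)                # current parameters

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def grad(w):
    # Gradient of the mean binary cross-entropy of logistic regression
    return X.T @ (sigmoid(X @ w) - t) / n

def hvp_fd(w, v, eps=1e-5):
    # Finite-difference Hessian-vector product: two gradient calls, no d x d matrix
    return (grad(w + eps * v) - grad(w - eps * v)) / (2 * eps)

# Exact Hessian X^T diag(p (1 - p)) X / n, formed here only to check the result
p = sigmoid(X @ w)
H = X.T @ (X * (p * (1 - p))[:, None]) / n
v = rng.standard_normal(d)
print(np.abs(hvp_fd(w, v) - H @ v).max())
```

A function like `hvp_fd` (or its exact autograd counterpart) is exactly what the `matvec` argument of a Lanczos routine such as `scipy.sparse.linalg.eigsh` needs.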
Warning: computing the eigenvalues/vectors for the three networks takes a while.
It's quicker, at the cost of some accuracy, if you reduce the number of batches to average over.
Let's plot the leading eigenvalues for the three networks:
The batch norm model has eigenvalues which are 1-2 orders of magnitude smaller than the leading eigenvalues of the other two networks, consistent with the much higher stable learning rates for that model.
The frozen batch norm model has 9 or 10 outlying eigenvalues. This agrees with results found elsewhere (Sagun et al '16, Sagun et al '17, Ghorbani et al '19), which empirically found the number of outliers to be roughly the number of classes. We shall relate these to the 10 output channel mean perturbations in the next section.
Finally the situation for the 'no batch norm' network appears to be more complicated at this stage of training, with a larger set of outlying eigenvalues.
Now it's time to tie things together by relating the outlying eigenvectors of the Hessian - which correspond to the most curved directions in the loss landscape, responsible for limiting the achievable learning rates - to the nearby 'bad' configurations that we found earlier.
As a first step, let's observe the action of perturbing the network parameters by these leading eigenvectors - scaled as before to 1% of the length of the underlying network parameters. We are going to observe the effect on the per channel histograms of the output layer and display results for the eigenvectors v1-v10 corresponding to the leading 10 eigenvalues.
For the 'batch norm' model, the means of different channels are somewhat affected, but the effects are rather mild, as expected. For the 'frozen batch norm' model, a clear picture emerges: each of the eigenvectors v1-v9 typically affects the means of a small number of channels, whilst v10 appears to correspond primarily to a rescaling of the outputs. For the 'no batch norm' model, the 10 eigenvectors seem to correspond to a mixture of mean perturbations and rescalings of the different channels.
How can we make this analysis more precise?
We have in hand the 10 vectors corresponding to gradients of the mean of each output channel. With these, we can compute the 10 dimensional subspace that they span and measure how much (of the norm squared) of a given eigenvector lies in this subspace. Note that a random vector of the same dimension (> 4 million) would have, with high probability, a norm squared component in a given 10 dimensional subspace of less than @@0@@.
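Here is a minimal NumPy sketch of this overlap measurement, assuming the k gradient vectors are stacked as the rows of an array. They are orthonormalised with a QR decomposition, and the overlap is the squared norm of the projection divided by the squared norm of the vector. The random stand-in gradients and the dimension D = 100,000 are illustrative only. For a random vector, the expected overlap with a fixed k-dimensional subspace is exactly k/D, which is why overlaps near 1 cannot be chance.

```python
import numpy as np

def overlap_with_span(vectors, u):
    """Fraction of ||u||^2 lying in the span of the rows of `vectors` (shape k x D)."""
    Q, _ = np.linalg.qr(vectors.T)     # orthonormal basis of the span, shape (D, k)
    coeffs = Q.T @ u                   # coordinates of the projection of u
    return (coeffs @ coeffs) / (u @ u)

rng = np.random.default_rng(0)
D, k = 100_000, 10
G = rng.standard_normal((k, D))        # stand-ins for the k channel-mean gradients
in_span = rng.standard_normal(k) @ G   # a vector lying exactly in their span
random_vec = rng.standard_normal(D)

print(overlap_with_span(G, in_span))              # vector in the span: overlap 1 up to rounding
print(overlap_with_span(G, random_vec), k / D)    # random vector vs the chance level k / D
```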
Actually, shifting all of the output means in parallel has no effect on the loss, since the class probabilities are computed by applying a softmax function, which is invariant under such an overall shift. As a result, the subspace of channel mean gradients that affects the loss is 9 dimensional, not 10. The loss is exactly flat in the direction of the overall shift, with a corresponding zero eigenvalue of the Hessian. From now on we will describe the subspace of channel mean gradients as 9 dimensional, although the code below works with the full 10 dimensional space for simplicity.
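Both facts can be checked numerically for a single example with 10 classes. The sketch uses the standard result that the Hessian of softmax cross-entropy with respect to the logits is diag(p) - p p^T, independent of the label. The all-ones vector (a uniform shift of every logit) lies in its null space, leaving 9 curved directions.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max()                        # numerical stabilisation; itself relies on shift invariance
    return z - np.log(np.exp(z).sum())

def logit_hessian(z):
    """Hessian of the cross-entropy loss with respect to the logits z: diag(p) - p p^T."""
    p = np.exp(log_softmax(z))
    return np.diag(p) - np.outer(p, p)

z = np.random.default_rng(0).normal(size=10)                # logits for 10 classes
print(np.allclose(log_softmax(z), log_softmax(z + 3.0)))     # True: an overall shift changes nothing
H = logit_hessian(z)
print(np.allclose(H @ np.ones(10), 0.0))                     # True: flat along the all-ones direction
print(np.linalg.matrix_rank(H))                              # 9
```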
Let's compute the norm squared overlaps between the eigenvectors and our 9 dimensional subspace:
Wow!! These are not the result of random chance!
For the 'frozen batch norm' model we arrive immediately at a clear understanding of the 9 outlying eigenvectors. These lie almost entirely in the 9 dimensional subspace of gradients of output channel means. We conclude that for this model, the leading instability of SGD is caused simply by a failure to control the means of the outputs.
Can we do better and explain an even greater fraction of the leading eigenvectors for the other models?
A clue is in the previous plots, which show that these eigenvectors also affect output variances. A reasonable idea would be to extend our 9 dimensional subspace to 19 dimensions by including gradients of the 10 output channel variances, or even to 29 dimensions by including 3rd moments or 'skews' of the output channels. Rescaling an output channel is another way to damage the loss across all training examples at once, and since the output distributions are highly skewed (which allows them to confidently predict a given class), changing the skews is also likely to have an effect.
Let's use the machinery developed above to test this out. First we compute gradients of the variance and skew of the output channels of each model:
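In the notebook these gradients come from backpropagation through the full ResNet. As a self-contained stand-in, here is a minimal sketch for a hypothetical toy linear model Y = X W with hand-derived gradients. It assumes population statistics (means over the batch) and takes 'skew' to be the third central moment; a standardised skewness would differ by a normalisation factor.

```python
import numpy as np

def moment_grads(X, W):
    """Gradients w.r.t. W of the mean, variance and third central moment of each
    output channel of the toy linear model Y = X @ W.

    Returns three arrays of shape (C, D, C): entry [k] is the gradient of the
    statistic of channel k with respect to the whole weight matrix W (D, C).
    """
    d = X.shape[1]
    c = W.shape[1]
    Y = X @ W
    Yc = Y - Y.mean(axis=0)          # centred outputs, (n, C)
    Xc = X - X.mean(axis=0)          # centred inputs, (n, D); d(y_i - mean y)/dw = x_i - mean x
    g_mean = np.zeros((c, d, c))
    g_var = np.zeros((c, d, c))
    g_skew = np.zeros((c, d, c))
    for k in range(c):
        # Channel k depends only on column k of W.
        g_mean[k, :, k] = X.mean(axis=0)
        g_var[k, :, k] = 2.0 * (Yc[:, k:k + 1] * Xc).mean(axis=0)
        g_skew[k, :, k] = 3.0 * (Yc[:, k:k + 1] ** 2 * Xc).mean(axis=0)
    return g_mean, g_var, g_skew

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 5))     # hypothetical inputs: 64 examples, 5 features
W = rng.normal(size=(5, 3))      # hypothetical weights: 3 output channels
g_mean, g_var, g_skew = moment_grads(X, W)
# Flatten each per channel gradient to a parameter-space vector and stack them:
# 3 moments x 3 channels = 9 basis vectors, each of length W.size.
basis = np.concatenate([g.reshape(W.shape[1], -1) for g in (g_mean, g_var, g_skew)])
print(basis.shape)               # (9, 15)
```

With 10 output channels the same stacking gives the 30 vectors (29 directions once the overall mean shift is discounted) used for the overlap measurement.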
Then we compute the norm squared overlaps between eigenvectors and subspaces:
The results are even more striking!
We have shown that the leading 10 eigenvectors of the 'frozen batch norm' model lie almost entirely inside a 29 dimensional subspace of the full > 4 million dimensional parameter space, and that this subspace is interpretable: it is spanned by gradients of the first three moments of the per class output distributions. The results are almost as striking for the 'no batch norm' model and even for the 'batch norm' model.
This last point may seem surprising, but recall that the leading eigenvalues are much smaller for the batch norm model. Below we rescale the bars by the corresponding eigenvalues. We see that even if we fixed the instability from the first three moments of the output, the two models without active batch norm would remain substantially less stable than the batch norm model.
A more precise way to see this is to recompute the leading eigenvalues of the Hessian restricted to the subspace orthogonal to the 29 dimensional subspace of gradients of the three output moments:
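One way to compute such restricted eigenvalues is to project every Hessian-vector product onto the orthogonal complement of the subspace. The sketch below uses orthogonal (block power) iteration with a final Rayleigh-Ritz step; this is a swapped-in technique, and I don't know which eigensolver the notebook itself uses. The `hvp` function is a hypothetical stand-in: for a real network it would typically be built from double backprop (e.g. `torch.autograd.grad` with `create_graph=True`). `Q` is assumed to have orthonormal columns, for instance from a QR or SVD of the stacked moment gradients.

```python
import numpy as np

def top_eigs_orthogonal_to(hvp, Q, dim, k=3, iters=300, seed=0):
    """Leading eigenvalues of P H P, where P = I - Q Q^T projects out span(Q).

    hvp: function mapping a (dim, j) array V to H @ V (H is never formed here).
    Q:   (dim, m) array with orthonormal columns spanning the excluded subspace.
    Block power iteration finds the eigenvalues of largest magnitude, so this
    assumes the leading ones are positive outliers, as in the post.
    """
    def project(V):
        return V - Q @ (Q.T @ V)

    def op(V):
        return project(hvp(project(V)))

    rng = np.random.default_rng(seed)
    V, _ = np.linalg.qr(project(rng.normal(size=(dim, k))))
    for _ in range(iters):
        V, _ = np.linalg.qr(op(V))
    T = V.T @ op(V)                                  # Rayleigh-Ritz on the final block
    return np.linalg.eigvalsh((T + T.T) / 2)[::-1]   # descending order

rng = np.random.default_rng(1)
dim = 50
U, _ = np.linalg.qr(rng.normal(size=(dim, dim)))                    # random orthonormal eigenbasis
evals = np.concatenate([[100, 80, 60, 40, 20], rng.uniform(0, 1, dim - 5)])
H = U @ np.diag(evals) @ U.T                                        # toy "Hessian" with 5 outliers
Q = U[:, :2]                                                        # toy subspace holding the top 2 directions
print(top_eigs_orthogonal_to(lambda V: H @ V, Q, dim))              # approximately [60, 40, 20]
```

Projecting out the two leading directions of the toy Hessian leaves the next outliers as the new leading eigenvalues, which is the kind of comparison the plot below makes for the real models.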
We see that the 'no batch norm' and 'frozen batch norm' models are substantially more stable than before, but there is still work to do before they reach the stability of the 'batch norm' model. We discuss this below.
So what have we learned?
First we reviewed the result that, in the absence of batch norm, deep networks with standard initialisations tend to produce 'bad', almost constant output distributions in which the inputs are ignored. We discussed how batch norm prevents this, and how this can also be fixed at initialisation by using 'frozen' batch norm.
Next we turned to training and showed that 'bad' configurations exist near to the parameter configurations traversed during a typical training run. We explicitly found such nearby configurations by computing gradients of the means of the per channel output distributions. Later we extended this by computing gradients of the variance and skew of the per channel output distributions, arguing that changing these higher order statistics would also lead to a large increase in loss. We explained how batch norm, by preventing the propagation of changes to the statistics of internal layer distributions, greatly reduces the gradients in these directions.
Finally, we investigated the leading eigenvalues and eigenvectors of the Hessian of the loss, which account for the instability of SGD, and showed that the leading eigenvectors lie primarily in the low dimensional subspaces of gradients of output statistics that we computed before. The interpretation of this fact is that the cause of instability is indeed the highly curved loss landscape that is produced by failing to enforce appropriate constraints on the moments of the output distribution.
Taking a step back, the main lesson is that deep nets provide a convenient parameterisation of a large class of functions, capable of expressing many computations of interest, but that within this parameterisation there are certain constraints that need to be observed to produce useful functions, rather than, say, constants. During training, it is necessary to maintain these constraints, and this places a heavy burden on a first-order optimiser such as SGD, whose learning rate is limited by the steep curvature in the directions orthogonal to the 'good' subspace.
A possible solution would be to use a second-order optimiser to handle the problematic curvature. (Indeed, our results cast some light on the success of second-order methods using the so-called 'generalised Gauss-Newton' approximation to the Hessian, since the outlying eigenvalues that we identify originate within this term.) However, the results above might convince one that a method such as batch normalisation - which directly tackles the issue underlying the problematic curvature by improving the parameterisation of function space - coupled with a first-order optimiser, is going to be hard to beat.
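For reference, here is the standard decomposition that the generalised Gauss-Newton approximation is built on, written for a single example with logits z = f(x; θ) and Jacobian J = ∂z/∂θ (the Hessian of the training loss averages this over examples). The approximation keeps only the first term. For softmax cross-entropy the inner matrix is diag(p) - p pᵀ, which is positive semi-definite with the all-ones vector in its null space. So the first term's curvature along a parameter direction v, namely (Jv)ᵀ (diag(p) - p pᵀ)(Jv), is non-negative and vanishes only when v shifts all the logits uniformly.

```latex
\nabla^2_\theta \, \ell\big(f(x;\theta)\big)
  = \underbrace{J^\top \left(\nabla^2_z \ell\right) J}_{\text{generalised Gauss-Newton term}}
  + \sum_{k} \frac{\partial \ell}{\partial z_k} \, \nabla^2_\theta z_k ,
\qquad
\nabla^2_z \ell = \operatorname{diag}(p) - p\,p^\top \quad \text{(softmax cross-entropy)}
```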
Despite our focus on the distributions of channels in the output layer, it is not sufficient to fix these. If it were, we could place a single batch norm layer towards the end of the network and this could be expected to work as well as batch norms distributed throughout. The final plot above - which shows the scale of residual instabilities after fixing the leading moments of the output distributions - indicates that doing so is unlikely to approach the stability of the network with batch norm layers throughout.
A reasonable expectation is that the internal layer distributions are also important in order to maintain the expressivity of the function computed by the network. For example, if at some intermediate layer, a small number of channels become dominant, this would introduce a bottleneck, greatly reducing the set of functions that the network is capable of representing. We would expect that directly tackling this kind of internal distributional shift is a further role of batch norm and this is an interesting direction for future research.
Another lesson is that batch norm seems almost ideally designed to tackle the key optimisation issue, which is indeed one of propagation of changes to channel-wise distributions. It would be very interesting to investigate the various alternatives to batch norm that have been proposed and understand if and how they manage to achieve a similar thing. These insights should also provide guidance in developing new stabilisation mechanisms, potentially providing a sharper toolkit than the alternative run-it-and-see approach.
Before closing, there are a couple of 'meta'-lessons I'd like to draw. The first is that visualisation is important! This project was stuck for longer than I care to recall until I finally bit the bullet and developed the channel-wise histogram plots that appear throughout the post. With these in hand, it became possible to visualise eigenvectors of the Hessian and develop an intuition for what was really going on.
Secondly, this is a story about optimisers and parameterisations of function spaces - it has almost nothing to do with datasets! Running this on ImageNet would not only have slowed things down massively, but the higher number of classes would have obscured the key issues under a mountain of numerical noise. I predict that progress on foundational issues will continue to be made using small models and datasets.