June 30, 2020
Batch Norm Folding: An easy way to improve your network speed

Introduction
At Scortex, we automate visual inspection. The defects that we are looking for can be quite small (1 mm) and we sometimes have to handle rather large parts (up to several meters in length; for more information, check our use cases). Moreover, complex part geometry means that we may have to use several cameras at the same time to inspect the full surface of the part. Finally, real-time inference is necessary to ensure that we do not slow down production.
As a result, we are often in situations where we have to process many high-resolution images per second. When this happens, we need to maintain the highest possible throughput by optimising network inference speed.
As a matter of fact, Scortex joined the COGNITWIN consortium a few months ago in order to work on improving deep learning inference speed at the edge.

Along with our COGNITWIN partners, we are aiming to deploy fast deep networks in factories. To do this, we investigated several leads. One of them is Batch Norm Folding, the topic of this blog post.
Normalization Layers
Batch Normalization
Batch Normalization (or BatchNorm) is a widely used technique to better train deep learning models. Batch Normalization is defined as follows:
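In the notation of the original paper, for a mini-batch $\mathcal{B} = \{x_1, \dots, x_m\}$ of a given feature, with learned parameters $\gamma$ and $\beta$ and a small constant $\epsilon$ for numerical stability, the transform reads:

$$\mu_\mathcal{B} = \frac{1}{m}\sum_{i=1}^m x_i \qquad \sigma_\mathcal{B}^2 = \frac{1}{m}\sum_{i=1}^m (x_i - \mu_\mathcal{B})^2$$

$$\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}} \qquad y_i = \gamma \hat{x}_i + \beta$$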

Basically:
- Moments (mean and standard deviation) are computed for each feature across the mini-batch during training.
- The features are normalized using these moments
- Two parameters for scale and shift are learned during training to allow more expressivity
- At test time, it is not possible to normalize using batch statistics. As a result, rolling averages of the moments are calculated while training and are used at test time.
The original paper explains that the technique reduces internal covariate shift by normalizing the inputs of each layer. However, a more recent paper disproves this hypothesis and explains that the success of Batch Normalization is due to the fact that it “makes the optimization landscape significantly smoother”.
Batch Normalization supposedly has several advantages:
- It provides faster convergence, hence faster training of the model as well as better performances.
- It enables training with a higher learning rate and makes training globally more robust to untuned learning rate (this is a very nice property; as a matter of fact, we recently had trouble fine-tuning a VGG architecture, which does not use Batch Norm, with a poorly chosen learning rate).
- It makes the training more robust to the activation choice. In the paper, the authors manage to achieve decent performances with a model trained with sigmoid activations!
- It removes the need for Dropout by providing some kind of regularization.
Other normalization schemes
The dependency of Batch Normalization on the mini-batch size is an issue. It has been shown that the performance of models using Batch Normalization degrades quickly when the batch size gets small. This is typically the case when training segmentation or detection architectures, where people usually use smaller batch sizes (4 to 32, whereas classification usually uses 32 to 256).
To remedy this issue, several methods were proposed such as:
- Layer Norm, which normalizes across channels
- Instance Norm, which normalizes only across the height and width of the feature maps
- Group Norm, which defines groups of channels to replace the batch aggregation by a per-group channel aggregation. This can be seen as a relaxation of LayerNorm. Below is an illustration of the normalization schemes from the Group Norm paper.

Though Group Norm does not perform quite as well as Batch Normalization, notably in large batch size regimes, combining it with Weight Standardization enables it to reach BatchNorm performance.
In 2020, two new papers came out proposing alternative normalizations. In “Filter Response Normalization Layer”, the authors propose a new normalization that leads to better performance than GroupNorm and BatchNorm for all batch sizes. In “Evolving Normalization-Activation Layers”, an architecture search is performed to obtain the best combination of normalization and activation. The result is two normalization-activation layers: EvoNorm-B0 (which uses batch statistics) and EvoNorm-S0 (which is independent of the batch).
However, Batch Normalization has an advantage over Group Normalization and other methods: it can be easily folded in the convolution layers (NB: weight standardization can be folded too!).
Batch Norm Folding
NB: we know this technique as “folding” but it has also been called “fusing”, for example in this blog post or the tf.keras implementation.
Folding in Convolution layer
One of the advantages of Batch Normalization is that it can be folded in a convolution layer. This means that we can replace the Convolution followed by Batch Normalization operation by just one convolution with different weights.
To prove this, we only need a few equations. We keep the same notation as the Batch Normalization equations above. Below, in (1) we write the BatchNorm output as a function of its input. Locally, (2) expresses the input of BatchNorm as the product of the convolution weights with the previous activations, plus a bias. We can thus express in (3) the BatchNorm output as a function of the convolution input, which we can factor as equation (4) with new weights W’ and b’ given in (5) and (6).
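With $x$ the BatchNorm input, $y$ its output, $a$ the convolution input, $W$ and $b$ the convolution weights and bias, and $\mu$, $\sigma^2$, $\epsilon$ the test-time moments and epsilon of the BatchNorm layer, the equations can be written as:

$$y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \quad (1)$$
$$x = W a + b \quad (2)$$
$$y = \gamma \cdot \frac{W a + b - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \quad (3)$$
$$y = W' a + b' \quad (4)$$
$$W' = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \cdot W \quad (5)$$
$$b' = \frac{\gamma \cdot (b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta \quad (6)$$

where the scaling by $\gamma / \sqrt{\sigma^2 + \epsilon}$ is applied per output channel.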

In practice it is very easy to implement this in keras. We denote by I the dimension of the input, O the dimension of the output, and k the kernel size. The output of the keras get_weights() method of the convolution layer is a tuple of two elements: convolution weights of shape (k,k,I,O), and the biases of shape (O,). Meanwhile the BatchNorm weights are a tuple of four arrays of shape (O,).
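As a minimal sketch of what such a helper could look like (fold_conv_bn is a hypothetical name; we assume the convolution uses a bias and the BatchNorm layer has both scale and center enabled, so that get_weights() returns the four arrays above):

```python
import numpy as np

def fold_conv_bn(conv_layer, bn_layer):
    # Convolution weights: kernel (k, k, I, O) and bias (O,)
    kernel, bias = conv_layer.get_weights()
    # BatchNorm weights: gamma, beta, moving mean, moving variance, each (O,)
    gamma, beta, moving_mean, moving_var = bn_layer.get_weights()

    # Per-output-channel scaling factor, using the layer's own epsilon (see note below)
    scale = gamma / np.sqrt(moving_var + bn_layer.epsilon)

    # Reshape so the scaling broadcasts over the output dimension only
    folded_kernel = kernel * scale.reshape(1, 1, 1, -1)
    folded_bias = beta + (bias - moving_mean) * scale
    return folded_kernel, folded_bias
```

The folded weights can then be set, with set_weights(), on the corresponding convolution of a copy of the model built without BatchNorm layers.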

In the code above, the reshape is necessary to make the scaling broadcast over the output dimension; without it, a mistake could go unnoticed whenever the output dimension O happens to equal the input dimension I.
Setting the same epsilon as the one used in your Batch Norm layers (the Keras default is 1e-3) is absolutely necessary, as small differences in each activation can quickly create butterfly effects in the output.
Folding in other layers
Now, BatchNorm is used after many different kinds of layers, so we need to ask ourselves whether we can easily fold it into these as well.
Dense Layer
A Dense layer, in a fully connected network or after a “Flatten” operation, can be seen as a 1×1 convolution, so it is easy to see that folding works.
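For completeness, here is a sketch of the Dense case under the same assumptions (Dense kernels have shape (I, O) and biases (O,), so the scaling again broadcasts over the last axis):

```python
import numpy as np

def fold_dense_bn(dense_layer, bn_layer):
    # Dense weights: kernel (I, O) and bias (O,); BatchNorm weights: four (O,) arrays
    kernel, bias = dense_layer.get_weights()
    gamma, beta, moving_mean, moving_var = bn_layer.get_weights()

    scale = gamma / np.sqrt(moving_var + bn_layer.epsilon)
    folded_kernel = kernel * scale  # broadcasts over the output axis
    folded_bias = beta + (bias - moving_mean) * scale
    return folded_kernel, folded_bias
```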
Depthwise convolution
A depthwise convolution applies a separate kernel to each input feature map. As a result, its weights are a tuple of convolution weights of shape (k,k,I,1) and biases of shape (I,). Note that the weights of a BatchNorm layer following a depthwise convolution will all have shape (I,).
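A sketch under the same assumptions as before (fold_depthwise_bn is again a hypothetical name):

```python
import numpy as np

def fold_depthwise_bn(depthwise_layer, bn_layer):
    # Depthwise weights: kernel (k, k, I, 1) and bias (I,); BatchNorm weights: four (I,) arrays
    kernel, bias = depthwise_layer.get_weights()
    gamma, beta, moving_mean, moving_var = bn_layer.get_weights()

    scale = gamma / np.sqrt(moving_var + bn_layer.epsilon)
    # The scaling now broadcasts over the input-channel dimension of the kernel
    folded_kernel = kernel * scale.reshape(1, 1, -1, 1)
    folded_bias = beta + (bias - moving_mean) * scale
    return folded_kernel, folded_bias
```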

The only difference with the previous example is the reshape, which changes because the BatchNorm weights have a shape of (I,) instead of (O,).
Separable convolution
Separable convolution (the building block of MobileNet or Xception) is simply the concatenation of a depthwise convolution and a 1×1 convolution. As a result, Batch Norm can easily be folded into the 1×1 part, which produces the output channels.
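A sketch under the same assumptions (in Keras, SeparableConv2D exposes its weights as the depthwise kernel, the pointwise kernel and the bias):

```python
import numpy as np

def fold_separable_bn(separable_layer, bn_layer):
    # SeparableConv2D weights: depthwise kernel (k, k, I, 1),
    # pointwise kernel (1, 1, I, O) and bias (O,)
    depthwise_kernel, pointwise_kernel, bias = separable_layer.get_weights()
    gamma, beta, moving_mean, moving_var = bn_layer.get_weights()  # each (O,)

    scale = gamma / np.sqrt(moving_var + bn_layer.epsilon)
    # The BatchNorm sits after the pointwise (1x1) part, which produces the output
    # channels, so only the pointwise kernel and the bias need rescaling
    folded_pointwise = pointwise_kernel * scale.reshape(1, 1, 1, -1)
    folded_bias = beta + (bias - moving_mean) * scale
    return depthwise_kernel, folded_pointwise, folded_bias
```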

Transpose convolution
Batch Normalization cannot be as easily folded in the transpose convolution operation.
To see why, let’s consider the following schemas from this awesome GitHub repository.

The left schema represents convolution with “same” padding (so that output size equals input size). The input map is blue, and the output map is cyan. On the right is displayed the transpose convolution with stride 2. This time, the input map is cyan, and the output map is blue.
With convolution, each output pixel is calculated using only one kernel application. For transpose convolution with stride 2, however, each output is a sum of kernel applications. The number of terms in the sum is not the same for the top-left pixel (1) as for the middle one (9). As a result, folding cannot be obtained by a simple matrix multiplication as above (though maybe some closed-form calculation enables it).
Folding other normalizations
One of the advantages of weight normalization is that it can easily be folded: you simply have to normalize the weights of your trained model and voilà, it is done.
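As an illustration, here is what this could look like for a convolution kernel under Weight Normalization, assuming the usual per-output-filter parametrization (v and g are hypothetical names for the learned direction and scale parameters):

```python
import numpy as np

def fold_weight_norm(v, g):
    # Weight Normalization parametrizes each filter as w = g * v / ||v||.
    # Folding simply means computing w once after training and using it as a
    # regular convolution kernel at inference time.
    # v: unnormalized kernel (k, k, I, O), g: per-filter scale (O,)
    norm = np.sqrt((v ** 2).sum(axis=(0, 1, 2), keepdims=True))  # (1, 1, 1, O)
    return g * v / norm
```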
Other normalizations, such as GroupNorm, cannot necessarily be folded. In fact, the reason why BatchNorm can be folded comes from one of its potential weaknesses: at test time, estimators (rolling averages) of the moments are used. The fact that these statistics are fixed during inference is what enables folding. On the contrary, most other normalizations re-compute the moments at test time and thus cannot be folded. We could imagine replacing these moments with rolling average estimators, but it would probably come at the expense of accuracy.
Conv-BN-ReLU or Conv-ReLU-BN?
Folding order
There are two possible ways of ordering batch norm and activation (in our case ReLU): Conv-BatchNorm-ReLU and Conv-ReLU-BatchNorm.
In the original batch normalization paper, the batch normalization operation is used between the convolution and the activation.
But the order of normalization and activation has been debated in several discussions such as this one or this one. Some internet benchmarks like this one seem to show that Batch Norm after ReLU provides better performance.
Above, we showed that folding is possible in the Conv-BatchNorm-ReLU case. Can we do the same with Conv-ReLU-BatchNorm? Because of the ReLU non-linearity, we cannot fold the Batch Normalization into the convolution on its left, but we could imagine folding it into the next one, since we would have something like Conv-ReLU-BatchNorm-Conv.
We can follow the same approach. Below, z is the output of the convolution, y is the output of BatchNorm and x is the input of BatchNorm. We can express z locally as a function of y in (7) and expand it in (8). As previously, we can derive new formulas to fold the BatchNorm operation into the convolution.
Note that the notation is abused here: gamma and beta now have the shape of the input. Hence, the convolution kernel will be multiplied by different values along its input dimension.
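With this notation, and $\mu$, $\sigma^2$, $\epsilon$ the test-time moments and epsilon of the BatchNorm layer, the equations read:

$$z = W y + b \quad (7)$$
$$z = W \left( \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \right) + b = W' x + b' \quad (8)$$

with

$$W' = W \cdot \mathrm{diag}\!\left(\frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\right) \qquad b' = b + W \left( \beta - \frac{\gamma \cdot \mu}{\sqrt{\sigma^2 + \epsilon}} \right)$$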

So now, it should be easy to re-apply exactly the same ideas as previously and come up with a keras implementation, right? Except for one thing that has not been taken into account here: padding.
Indeed, if we use zero padding in the convolution, zero values will enter the convolution after the BatchNorm. Once folded, this will still be the case: to fold properly, we would need to replace the default 0 padding values by -(b’ – b), that is, apply the BatchNorm transformation to the padding so that we can apply its inverse later.
For this reason, we have only tried to fold BatchNorm in the Conv-BatchNorm-ReLU order.
Batch Normalization order experiments
We trained two kinds of architectures on our demonstrator data. A light one and a heavy one.
Our results are in line with the internet benchmark and seem to show that putting BatchNormalization after the ReLU provides better convergence properties. The effect seems larger on small architectures though.

Folding Experiments
Difference in output
We implemented folding using the functions described above. The first thing to check is that the outputs of a folded model and its original version are the same. Computing the segmentation probabilities of 500 images on 10 classes, the average per-pixel difference is around 2e-7 and the maximum observed over the 500 images is around 6e-6. These values are to be compared with simply re-running the prediction, since there is stochasticity in CUDA computations: when re-running predictions, we get an average of 1e-9 and a maximum of 6e-6. Though folding leads to larger differences, they remain quite reasonable and do not impact the accuracy of the network.
Also note that these differences are way below what one can expect from very small input perturbations, as we discussed in a previous blog post.
Speed Results
Keras
First, we started by comparing inference times, using Keras 2.3.0 (with tensorflow 1.13.1 as back-end, this was done 8 months ago). We aimed to compare, for two different architectures (“shallow” and “deep”), the impact of manual batch-normalization folding.
We ran inference on 500 images from our demonstrator. Below are the results.


In this experiment, we obtain a x1.69 speed-up for inference with the shallow architecture, and a whopping x2.7 with the deeper architecture.
It seems that when adding more layers, batch normalization layers become bigger bottlenecks in the inference process, resulting in a better speed-up when folding them.
tf.Keras
We then performed a similar experiment, using tensorflow 1.13.1 with tf.keras (tf.keras version is 2.2.4-tf). This Keras version benefits from the presence of a “fused” parameter in the BatchNormalization layer, which asks tensorflow to use a faster, fused implementation of the normalization when possible. We compare this built-in fused batch normalization with a baseline and our manual folding.
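For reference, this is just a constructor argument of the layer (a minimal sketch, shown here on a standalone layer; in the experiment it is set on every BatchNormalization layer of the model):

```python
import tensorflow as tf

# Ask tf.keras to use the faster, fused implementation of batch normalization
# whenever it is available
bn_layer = tf.keras.layers.BatchNormalization(fused=True)
```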
This time, we run the experiment 20 times, inferring 128 images from our demonstrator each time.


On average, with manual folding, we get a speed-up of 1.15 for the small network, and of 1.39 for the bigger network. Setting the “fused” parameter of the batch normalization layers to True also speeds up inference, although less than direct, manual folding. The results are relatively stable over runs, as the following box plots show.

As previously, the speed-up is more significant for the bigger network. It is, however, less impressive than in the previous experiment. This could be explained by other optimizations brought by the use of tf.keras (the keras and tf.keras versions are different).
Finally, we also did similar experiments using plaidml, gaining around 30% speed and improving memory usage (which is not the case with tensorflow!).
Conclusion and Next steps
Folding/fusing provides a nice speed-up at inference time. Doing so manually still slightly outperforms the tf.keras “fused” parameter. As a result, we have been using this at Scortex for several months in order to improve our inference time on the production lines.
Note that the speed-up gains are lower with tf.keras because the networks are already faster there. As a result, a lot of the speed-up comes from using tf.keras instead of keras in the first place. Using the latest version of tensorflow can also help.
Finally, in order to improve performance even further, or to run the model on different hardware (Jetson, FPGA, …), quantization can be considered. The implementation of both folding and quantization at the same time is far from trivial and has been studied in this paper.
If you are interested in the subject or in visual inspection in general, do not hesitate to contact us / Pierre / Antoine.
cite this blogpost: Gutierrez Pierre, Cordier Antoine, “Batch Norm Folding: An easy way to improve your network speed [Blog post]” June 30, 2020, https://scortex.io/batch-norm-folding-an-easy-way-to-improve-your-network-speed
