I present a derivation of efficient backpropagation equations for batch-normalization layers.

## Introduction

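A batch normalization layer is given a batch of $N$ examples, each of which is a $D$-dimensional vector. We can represent the inputs as a matrix $X \in \mathbb{R}^{N \times D}$ where each row $x_i$ is a single example. Each example $x_i$ is normalized by

$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \in \mathbb{R}^{1 \times D}$$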

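where $\mu, \sigma^2 \in \mathbb{R}^{1 \times D}$ are the mean and variance, respectively, of each input dimension across the batch, and the subtraction, division, and square root are applied element-wise. $\epsilon$ is some small constant that prevents division by 0. The mean and variance are computed by

$$\mu = \frac{1}{N}\sum_{i=1}^{N} x_i, \qquad \sigma^2 = \frac{1}{N}\sum_{i=1}^{N} (x_i - \mu)^2$$

where the square is also taken element-wise.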

An affine transform is then applied to the normalized rows to produce the final output

$$y_i = \gamma \odot \hat{x}_i + \beta$$

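where $\gamma, \beta \in \mathbb{R}^{1 \times D}$ are learnable scale and shift parameters, respectively, for each input dimension. For notational simplicity, we can express the entire layer as

$$Y = \gamma \odot \hat{X} + \beta$$

where $\hat{X} \in \mathbb{R}^{N \times D}$ is the matrix whose rows are the $\hat{x}_i$, and $\beta$ is added to every row.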

**Notation**: $\odot$ denotes the Hadamard (element-wise) product. In the case of $\gamma \odot \hat{X}$, where $\gamma$ is a row vector and $\hat{X}$ is a matrix, each row of $\hat{X}$ is multiplied element-wise by $\gamma$.
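To make the shapes concrete, here is a minimal NumPy sketch of the forward pass. It assumes the inputs are stored as an `(N, D)` array with one example per row. The function name `batchnorm_forward` and the default `eps=1e-5` are illustrative choices, not part of the derivation. NumPy broadcasting of a length-`D` array against an `(N, D)` array performs exactly the row-wise $\odot$ described above.

```python
import numpy as np

def batchnorm_forward(X, gamma, beta, eps=1e-5):
    """Batch-norm forward pass for X of shape (N, D); gamma and beta have shape (D,)."""
    mu = X.mean(axis=0)                     # per-column mean, shape (D,)
    var = X.var(axis=0)                     # per-column variance with 1/N (ddof=0), shape (D,)
    X_hat = (X - mu) / np.sqrt(var + eps)   # normalize every row
    Y = gamma * X_hat + beta                # broadcasting scales/shifts each row
    return Y, X_hat, var

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3))
gamma = np.array([1.0, 2.0, 0.5])
beta = np.array([0.0, -1.0, 3.0])
Y, X_hat, var = batchnorm_forward(X, gamma, beta)
print(X_hat.mean(axis=0))  # zero for every column, up to floating-point rounding
```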

**Gradient Notes**: Several times throughout this post, I mention my “gradient notes” which refers to this document.

## Backpropagation Basics

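Let $L$ be the training loss. We are given $\frac{\partial L}{\partial Y} \in \mathbb{R}^{N \times D}$, the gradient signal with respect to $Y$. Our goal is to calculate three gradients:

- $\frac{\partial L}{\partial \gamma} \in \mathbb{R}^{1 \times D}$, to perform a gradient descent update on $\gamma$
- $\frac{\partial L}{\partial \beta} \in \mathbb{R}^{1 \times D}$, to perform a gradient descent update on $\beta$
- $\frac{\partial L}{\partial X} \in \mathbb{R}^{N \times D}$, to pass on the gradient signal to lower layers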

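Both $\frac{\partial L}{\partial \gamma}$ and $\frac{\partial L}{\partial \beta}$ are straightforward. Let $y_i$ be the $i$-th row of $Y$. We refer to our gradient notes to get

$$\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{N} \frac{\partial L}{\partial y_i} \odot \hat{x}_i, \qquad \frac{\partial L}{\partial \beta} = \sum_{i=1}^{N} \frac{\partial L}{\partial y_i}$$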

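Deriving $\frac{\partial L}{\partial X}$ requires backpropagation through $Y$, which yields

$$\frac{\partial L}{\partial \hat{X}} = \gamma \odot \frac{\partial L}{\partial Y}$$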
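The two parameter gradients and $\frac{\partial L}{\partial \hat{X}}$ translate directly into array operations. The sketch below assumes NumPy and a hypothetical helper `param_and_xhat_grads`. To check it, it uses an illustrative loss $L = \sum W \odot Y$ (so $\frac{\partial L}{\partial Y} = W$) and compares $\frac{\partial L}{\partial \gamma}$ and $\frac{\partial L}{\partial \beta}$ against central finite differences.

```python
import numpy as np

def param_and_xhat_grads(dY, X_hat, gamma):
    """Given dY = dL/dY of shape (N, D), return dL/dgamma, dL/dbeta, dL/dX_hat."""
    dgamma = np.sum(dY * X_hat, axis=0)  # sum over examples of dL/dy_i ⊙ x_hat_i, shape (D,)
    dbeta = np.sum(dY, axis=0)           # sum over examples of dL/dy_i, shape (D,)
    dX_hat = gamma * dY                  # gamma ⊙ dL/dY: each row scaled by gamma
    return dgamma, dbeta, dX_hat

rng = np.random.default_rng(1)
N, D = 5, 3
X_hat = rng.normal(size=(N, D))   # stand-in for an already-normalized batch
gamma = rng.normal(size=D)
beta = rng.normal(size=D)
W = rng.normal(size=(N, D))       # illustrative loss L = sum(W * Y), so dL/dY = W

def loss(g, b):
    return np.sum(W * (g * X_hat + b))

dgamma, dbeta, dX_hat = param_and_xhat_grads(W, X_hat, gamma)

# Central finite differences, perturbing one coordinate at a time
h = 1e-5
fd_dgamma = np.array([(loss(gamma + h * e, beta) - loss(gamma - h * e, beta)) / (2 * h) for e in np.eye(D)])
fd_dbeta = np.array([(loss(gamma, beta + h * e) - loss(gamma, beta - h * e)) / (2 * h) for e in np.eye(D)])
print(np.allclose(dgamma, fd_dgamma), np.allclose(dbeta, fd_dbeta))
```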

Next we have to backpropagate through $\hat{X}$. Because both $\mu$ and $\sigma^2$ are functions of $X$, finding the gradient of $L$ with respect to $X$ is tricky. There are two approaches to break this down:

- Take the gradient of $L$ with respect to each **row** (example) in $X$. This approach is complicated by the fact that the values of each row in $X$ influence the values of **all** rows in $\hat{X}$ (i.e. $\frac{\partial \hat{x}_j}{\partial x_i} \neq 0$ even when $i \neq j$). By properly considering how changes in $x_i$ influence $\mu$ and $\sigma^2$, this is doable, as explained in the row-wise derivations listed under References.
- Take the gradient of $L$ with respect to each **column** (input dimension) in $X$. I find this more intuitive because batch normalization operates independently for each column: $\mu$ and $\sigma^2$ are computed per column, and $\gamma$ and $\beta$ are applied per column. This method is explained below.

## Column-wise Gradient

Since we are taking the gradient of $L$ with respect to each **column** in $X$, we can start by considering the case where $X$ is just a single column vector $x \in \mathbb{R}^{N}$. Thus, each example $x_i$ is a single number, and $\mu$ and $\sigma^2$ are scalar real numbers. This makes the math much easier. Later on, we generalize to $D$-dimensional input examples.

### Lemma

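Let $f$ be a real-valued function of vector $z \in \mathbb{R}^N$. Suppose $\frac{\partial f}{\partial z}$ is known. If $z = a\,w$, where $w \in \mathbb{R}^N$ and $a = a(w) \in \mathbb{R}$ is a scalar-valued function of $w$, then

$$\frac{\partial f}{\partial w} = a \frac{\partial f}{\partial z} + \left(w^T \frac{\partial f}{\partial z}\right) \frac{\partial a}{\partial w}$$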

*Proof*

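First we compute the gradient of $z_i$ with respect to a single element $w_j$ of $w$:

$$\frac{\partial z_i}{\partial w_j} = \frac{\partial a}{\partial w_j} w_i + a\,\delta_{ij}$$

where $\delta_{ij}$ is 1 if $i = j$ and 0 otherwise.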

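We apply the chain rule to obtain the gradient of $f$ with respect to a single element $w_j$ of $w$:

$$\frac{\partial f}{\partial w_j} = \sum_{i=1}^{N} \frac{\partial f}{\partial z_i} \frac{\partial z_i}{\partial w_j} = \frac{\partial a}{\partial w_j} \sum_{i=1}^{N} w_i \frac{\partial f}{\partial z_i} + a \frac{\partial f}{\partial z_j}$$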

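Now we can write the gradient for all elements in $w$, where $I$ is the $N \times N$ identity matrix:

$$\frac{\partial f}{\partial w} = \left(w \left(\frac{\partial a}{\partial w}\right)^T + a I\right)^T \frac{\partial f}{\partial z} = \left(w^T \frac{\partial f}{\partial z}\right) \frac{\partial a}{\partial w} + a \frac{\partial f}{\partial z}$$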

This result is a generalization of the “product rule” in the completely scalar case. For a function $f(z)$ where $z = a(w)\,w$, we have

$$\frac{df}{dw} = \frac{df}{dz}\left(a + w \frac{da}{dw}\right)$$
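A quick numerical sanity check of the lemma, under assumptions chosen only for illustration: $f(z) = c^T z + \frac{1}{2} z^T z$ (so $\frac{\partial f}{\partial z} = c + z$) and $a(w) = w^T w$ (so $\frac{\partial a}{\partial w} = 2w$). The lemma's formula is compared against central finite differences of $w \mapsto f(a(w)\,w)$.

```python
import numpy as np

rng = np.random.default_rng(2)
N = 4
w = rng.normal(size=N)
c = rng.normal(size=N)

def f(z):
    """Illustrative f(z) = c^T z + z^T z / 2, so df/dz = c + z."""
    return c @ z + 0.5 * (z @ z)

def a(w):
    """Illustrative scalar function a(w) = w^T w, so da/dw = 2 w."""
    return w @ w

z = a(w) * w
df_dz = c + z
da_dw = 2 * w

# The lemma: df/dw = a * df/dz + (w^T df/dz) * da/dw
lemma_grad = a(w) * df_dz + (w @ df_dz) * da_dw

# Central finite differences of w -> f(a(w) w)
h = 1e-5
fd_grad = np.array([
    (f(a(w + h * e) * (w + h * e)) - f(a(w - h * e) * (w - h * e))) / (2 * h)
    for e in np.eye(N)
])
print(np.max(np.abs(lemma_grad - fd_grad)))  # small: only finite-difference error remains
```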

### Getting a single expression for $\frac{\partial L}{\partial x}$

We want a single expression for $\frac{\partial L}{\partial x}$, which we will derive in two steps.

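- Rewrite $\hat{x}$ in the form $a\,v$ for some choice of scalar $a$ and vector $v$. This enables us to use the lemma above to obtain $\frac{\partial L}{\partial v}$.
- Rewrite $v$ in the form $Mx$ for some choice of matrix $M$. This enables us to use our gradient notes to obtain $\frac{\partial L}{\partial x}$.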

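We choose $v = x - \mu$, where $\mu$ is subtracted from every element of $x$. Since $\sigma^2 = \frac{1}{N}\sum_i (x_i - \mu)^2 = \frac{1}{N} v^T v$, $\hat{x}$ can be expressed in terms of $v$ as follows:

$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} = \frac{v}{\sqrt{\frac{1}{N} v^T v + \epsilon}} = a\,v$$

where $a = \left(\frac{1}{N} v^T v + \epsilon\right)^{-1/2}$. Now we apply our lemma above:

$$\frac{\partial L}{\partial v} = a \frac{\partial L}{\partial \hat{x}} + \left(v^T \frac{\partial L}{\partial \hat{x}}\right) \frac{\partial a}{\partial v}$$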

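$v$ can be written as a matrix multiplication with $x$:

$$v = \left(I - \frac{1}{N} J\right) x$$

where $I$ is the $N \times N$ identity matrix and $J$ is the $N \times N$ matrix of all ones.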

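Using our gradient rules (and the fact that $I - \frac{1}{N} J$ is symmetric), we get

$$\frac{\partial L}{\partial x} = \left(I - \frac{1}{N} J\right)^T \frac{\partial L}{\partial v} = \left(I - \frac{1}{N} J\right) \frac{\partial L}{\partial v}$$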
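The centering step can be checked numerically as well. This NumPy sketch builds the matrix $I - \frac{1}{N}J$ explicitly, where $J$ is the all-ones matrix (fine for small $N$; in practice one would just subtract the mean). It confirms that the matrix reproduces $x - \mu$ and checks the gradient rule on an illustrative loss $L(v) = \sum_i p_i v_i^2$ with arbitrary positive weights `p`.

```python
import numpy as np

N = 5
J = np.ones((N, N))          # N x N matrix of all ones
C = np.eye(N) - J / N        # centering matrix I - J/N (symmetric)

rng = np.random.default_rng(3)
x = rng.normal(size=N)
v = C @ x                    # identical to x - x.mean()

# Illustrative loss of v with known gradient dL/dv = 2 * p * v
p = rng.uniform(0.5, 2.0, size=N)

def loss_of_v(v):
    return np.sum(p * v ** 2)

dL_dv = 2 * p * v
dL_dx = C.T @ dL_dv          # gradient rule for v = C x

# Central finite differences of x -> loss_of_v(C x)
h = 1e-5
fd_dL_dx = np.array([
    (loss_of_v(C @ (x + h * e)) - loss_of_v(C @ (x - h * e))) / (2 * h)
    for e in np.eye(N)
])
```

`dL_dx` also sums to zero: shifting every input by the same constant leaves $v$, and hence $L$, unchanged.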

### Simplifying the expression

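First, we calculate $\frac{\partial a}{\partial v}$:

$$\frac{\partial a}{\partial v} = -\frac{1}{2}\left(\frac{1}{N} v^T v + \epsilon\right)^{-3/2} \cdot \frac{2}{N} v = -\frac{1}{N} a^3 v$$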
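The derivative of $a = \left(\frac{1}{N} v^T v + \epsilon\right)^{-1/2}$ is easy to get wrong by a factor of 2 or $N$, so here is a finite-difference check in NumPy. The values of `N`, `eps`, and the random `v` are arbitrary assumptions made for the test.

```python
import numpy as np

N, eps = 6, 1e-5
rng = np.random.default_rng(4)
v = rng.normal(size=N)
v -= v.mean()                # centered, like v = x - mu in the derivation

def a(v):
    """a(v) = (v^T v / N + eps)^(-1/2)"""
    return (v @ v / N + eps) ** -0.5

analytic = -(1.0 / N) * a(v) ** 3 * v    # closed form: -(1/N) a^3 v

# Central finite differences, one coordinate at a time
h = 1e-6
numeric = np.array([(a(v + h * e) - a(v - h * e)) / (2 * h) for e in np.eye(N)])
```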

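We plug this into our equation for $\frac{\partial L}{\partial v}$, substitute the result into $\frac{\partial L}{\partial x}$, and rewrite $v$ and $a$ in terms of $\hat{x}$ and $\sigma^2$ (using $v = \hat{x}/a$ and $a = (\sigma^2 + \epsilon)^{-1/2}$):

$$\begin{aligned}
\frac{\partial L}{\partial x} &= \left(I - \frac{1}{N} J\right)\left[a \frac{\partial L}{\partial \hat{x}} - \frac{1}{N} a^3 \left(v^T \frac{\partial L}{\partial \hat{x}}\right) v\right] \\
&= a \left(I - \frac{1}{N} J\right)\left[\frac{\partial L}{\partial \hat{x}} - \frac{1}{N}\left(\hat{x}^T \frac{\partial L}{\partial \hat{x}}\right)\hat{x}\right] \\
&= a \left[\frac{\partial L}{\partial \hat{x}} - \frac{1}{N} J \frac{\partial L}{\partial \hat{x}} - \frac{1}{N}\left(\hat{x}^T \frac{\partial L}{\partial \hat{x}}\right)\hat{x} + \frac{1}{N^2}\left(\hat{x}^T \frac{\partial L}{\partial \hat{x}}\right) J \hat{x}\right] \\
&= \frac{1}{N}(\sigma^2 + \epsilon)^{-1/2}\left[N \frac{\partial L}{\partial \hat{x}} - J \frac{\partial L}{\partial \hat{x}} - \left(\hat{x}^T \frac{\partial L}{\partial \hat{x}}\right)\hat{x}\right]
\end{aligned}$$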

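The last step above is because $J\hat{x}$ is the 0-vector (the normalized values in a column sum to zero):

$$J \hat{x} = a J \left(I - \frac{1}{N} J\right) x = a\left(J - \frac{1}{N} J^2\right) x = a (J - J) x = \mathbf{0}$$

since $J^2 = N J$.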
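The simplification can also be checked numerically. This NumPy sketch uses an arbitrary upstream gradient `g` in place of $\frac{\partial L}{\partial \hat{x}}$ (an assumption for testing only). It computes the unsimplified expression, the lemma result for $\frac{\partial L}{\partial v}$ followed by the centering matrix, and compares it with the simplified form. It also confirms that $J\hat{x}$ vanishes up to rounding.

```python
import numpy as np

N, eps = 6, 1e-5
rng = np.random.default_rng(5)
x = rng.normal(size=N)
g = rng.normal(size=N)            # arbitrary stand-in for dL/dx_hat

J = np.ones((N, N))               # all-ones matrix
C = np.eye(N) - J / N             # centering matrix
v = C @ x
var = v @ v / N
a = (var + eps) ** -0.5
x_hat = a * v

# Unsimplified: lemma for dL/dv, then multiply by the centering matrix
da_dv = -(1.0 / N) * a ** 3 * v
dL_dv = a * g + (v @ g) * da_dv
unsimplified = C @ dL_dv

# Simplified form, after dropping the J @ x_hat term
simplified = (a / N) * (N * g - J @ g - (x_hat @ g) * x_hat)

print(np.abs(J @ x_hat).max())    # zero up to floating-point rounding
```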

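Note that when the inputs are scalars, $\frac{\partial L}{\partial \hat{x}} = \gamma \frac{\partial L}{\partial y}$, where $\gamma$ is a scalar and $\frac{\partial L}{\partial y}$ is a column vector. Thus,

$$\begin{aligned}
\frac{\partial L}{\partial x} &= \frac{\gamma}{N}(\sigma^2 + \epsilon)^{-1/2}\left[N \frac{\partial L}{\partial y} - \mathbf{1}\mathbf{1}^T \frac{\partial L}{\partial y} - \left(\hat{x}^T \frac{\partial L}{\partial y}\right)\hat{x}\right] \\
&= \frac{\gamma}{N}(\sigma^2 + \epsilon)^{-1/2}\left[N \frac{\partial L}{\partial y} - \mathbf{1}\frac{\partial L}{\partial \beta} - \hat{x}\frac{\partial L}{\partial \gamma}\right]
\end{aligned}$$

where $\mathbf{1}$ is an $N$-dimensional column vector of ones (so the all-ones matrix is $\mathbf{1}\mathbf{1}^T$). The last line uses the fact that when the input examples are scalars, the derivatives simplify to

$$\frac{\partial L}{\partial \gamma} = \hat{x}^T \frac{\partial L}{\partial y}, \qquad \frac{\partial L}{\partial \beta} = \mathbf{1}^T \frac{\partial L}{\partial y}$$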
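Putting the scalar-input result into code gives a short backward pass for one column. The sketch below assumes NumPy and uses hypothetical helper names `bn_column_forward` and `bn_column_backward`. It verifies $\frac{\partial L}{\partial x}$ against central finite differences of an illustrative loss $L = w^T y$, for which $\frac{\partial L}{\partial y} = w$.

```python
import numpy as np

def bn_column_forward(x, gamma, beta, eps=1e-5):
    """Batch norm for a single column x of shape (N,); gamma and beta are scalars."""
    mu = x.mean()
    var = x.var()                             # 1/N variance
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta, x_hat, var

def bn_column_backward(dy, x_hat, var, gamma, eps=1e-5):
    """Scalar-input formula: dx = gamma/N (var+eps)^(-1/2) [N dy - 1 dbeta - x_hat dgamma]."""
    N = dy.shape[0]
    dgamma = x_hat @ dy                       # x_hat^T dL/dy
    dbeta = dy.sum()                          # 1^T dL/dy
    dx = gamma / N / np.sqrt(var + eps) * (N * dy - dbeta - x_hat * dgamma)
    return dx, dgamma, dbeta

rng = np.random.default_rng(6)
N = 7
x = rng.normal(size=N)
gamma, beta = 1.5, -0.3
w = rng.normal(size=N)                        # illustrative loss L = w^T y, so dL/dy = w

y, x_hat, var = bn_column_forward(x, gamma, beta)
dx, dgamma, dbeta = bn_column_backward(w, x_hat, var, gamma)

def loss(x):
    return np.sum(w * bn_column_forward(x, gamma, beta)[0])

h = 1e-5
fd_dx = np.array([(loss(x + h * e) - loss(x - h * e)) / (2 * h) for e in np.eye(N)])
```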

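Finally, we generalize to the case when the input examples are $D$-dimensional vectors:

$$\frac{\partial L}{\partial X} = \frac{1}{N}\,\gamma \odot (\sigma^2 + \epsilon)^{-1/2} \odot \left[N \frac{\partial L}{\partial Y} - \mathbf{1}\frac{\partial L}{\partial \beta} - \hat{X} \odot \frac{\partial L}{\partial \gamma}\right]$$

Here $\gamma$, $(\sigma^2 + \epsilon)^{-1/2}$ (taken element-wise), $\frac{\partial L}{\partial \beta}$, and $\frac{\partial L}{\partial \gamma}$ are all $1 \times D$ row vectors, $\mathbf{1}\frac{\partial L}{\partial \beta}$ is an $N \times D$ outer product, and each $\odot$ between a row vector and a matrix multiplies every row of the matrix element-wise, as in the notation above. Column by column, this is exactly the scalar-input result.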
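Vectorized over columns, the final formula becomes a few lines of NumPy. This is a sketch, not a reference implementation: `batchnorm_forward`, `batchnorm_backward`, and the `(X_hat, var)` cache are illustrative names. The gradient is checked against central finite differences of the test loss $L = \sum W \odot Y$.

```python
import numpy as np

def batchnorm_forward(X, gamma, beta, eps=1e-5):
    mu = X.mean(axis=0)
    var = X.var(axis=0)                       # 1/N variance per column
    X_hat = (X - mu) / np.sqrt(var + eps)
    return gamma * X_hat + beta, (X_hat, var)

def batchnorm_backward(dY, cache, gamma, eps=1e-5):
    X_hat, var = cache
    N = dY.shape[0]
    dgamma = np.sum(dY * X_hat, axis=0)       # shape (D,)
    dbeta = np.sum(dY, axis=0)                # shape (D,)
    # (D,) arrays broadcast over rows; subtracting dbeta from an (N, D)
    # array is the outer product 1 dL/dbeta in the formula.
    dX = (gamma / N) / np.sqrt(var + eps) * (N * dY - dbeta - X_hat * dgamma)
    return dX, dgamma, dbeta

rng = np.random.default_rng(7)
N, D = 8, 3
X = rng.normal(size=(N, D))
gamma = rng.normal(size=D)
beta = rng.normal(size=D)
W = rng.normal(size=(N, D))                   # test loss L = sum(W * Y), so dL/dY = W

Y, cache = batchnorm_forward(X, gamma, beta)
dX, dgamma, dbeta = batchnorm_backward(W, cache, gamma)

def loss(X):
    return np.sum(W * batchnorm_forward(X, gamma, beta)[0])

# Central finite differences over every entry of X
h = 1e-5
fd_dX = np.zeros_like(X)
for i in range(N):
    for j in range(D):
        E = np.zeros_like(X)
        E[i, j] = h
        fd_dX[i, j] = (loss(X + E) - loss(X - E)) / (2 * h)

print(np.max(np.abs(dX - fd_dX)))             # on the order of finite-difference error
```

Because the forward pass caches only $\hat{X}$ and $\sigma^2$, the backward pass never needs $X$ or $\mu$ again. Each column of `dX` also sums to zero, consistent with the normalized columns having zero mean.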

## References

- Batch Normalization
  - the original paper by Sergey Ioffe and Christian Szegedy
- Efficient Batch Normalization
  - row-wise derivation of $\frac{\partial L}{\partial X}$
- Deriving the Gradient for the Backward Pass of Batch Normalization
  - another take on row-wise derivation of $\frac{\partial L}{\partial X}$
- Understanding the backward pass through Batch Normalization Layer
  - (slow) step-by-step backpropagation through the batch normalization layer
- Batch Normalization - What the Hey?
  - explains some intuition behind batch normalization
  - clarifies the difference between using batch statistics during training and sample statistics during inference