I present a derivation of efficient backpropagation equations for batch-normalization layers.

## Introduction

A batch normalization layer is given a batch of $N$ examples, each of which is a $D$-dimensional vector. We can represent the inputs as a matrix $X \in \R^{N \times D}$ where each row $x_i$ is a single example. Each example $x_i$ is normalized by

$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

where $\mu, \sigma^2 \in \R^{1 \times D}$ are the mean and variance, respectively, of each input dimension across the batch. $\epsilon$ is some small constant that prevents division by zero. The mean and variance are computed by

$$\mu = \frac{1}{N} \sum_{i=1}^{N} x_i, \qquad \sigma^2 = \frac{1}{N} \sum_{i=1}^{N} (x_i - \mu)^2$$

An affine transform is then applied to the normalized rows to produce the final output

$$y_i = \gamma \odot \hat{x}_i + \beta$$

where $\gamma, \beta \in \R^{1 \times D}$ are learnable scale and shift parameters for each input dimension. For notational simplicity, we can express the entire layer as

$$Y = \gamma \odot \frac{X - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta = \gamma \odot \hat{X} + \beta$$

Notation: $\odot$ denotes the Hadamard (element-wise) product. In the case of $\gamma \odot \hat{X}$, where $\gamma$ is a row vector and $\hat{X}$ is a matrix, each row of $\hat{X}$ is multiplied element-wise by $\gamma$.
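To make the shapes concrete, here is a minimal NumPy sketch of the forward pass described above (the function name, return values, and `eps` default are my own illustrative choices, not from any particular library):

```python
import numpy as np

def batchnorm_forward(X, gamma, beta, eps=1e-5):
    """Forward pass sketch. X: (N, D); gamma, beta: (D,)."""
    mu = X.mean(axis=0)                    # per-dimension mean, shape (D,)
    var = X.var(axis=0)                    # per-dimension biased (1/N) variance, shape (D,)
    X_hat = (X - mu) / np.sqrt(var + eps)  # normalized inputs, shape (N, D)
    Y = gamma * X_hat + beta               # affine transform; gamma, beta broadcast over rows
    return Y, X_hat, mu, var               # cache X_hat and var for the backward pass
```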

Gradient Notes: Several times throughout this post, I mention my “gradient notes” which refers to this document.

## Backpropagation Basics

Let $J$ be the training loss. We are given $\frac{\partial J}{\partial Y} \in \R^{N \times D}$, the gradient signal with respect to $Y$. Our goal is to calculate three gradients:

1. $\frac{\partial J}{\partial \gamma} \in \R^{1 \times D}$, to perform a gradient descent update on $\gamma$
2. $\frac{\partial J}{\partial \beta} \in \R^{1 \times D}$, to perform a gradient descent update on $\beta$
3. $\frac{\partial J}{\partial X} \in \R^{N \times D}$, to pass on the gradient signal to lower layers

Both $\frac{\partial J}{\partial \gamma}$ and $\frac{\partial J}{\partial \beta}$ are straightforward. Let $y_i$ be the $i$-th row of $Y$. We refer to our gradient notes to get

$$\frac{\partial J}{\partial \gamma} = \sum_{i=1}^{N} \frac{\partial J}{\partial y_i} \odot \hat{x}_i, \qquad \frac{\partial J}{\partial \beta} = \sum_{i=1}^{N} \frac{\partial J}{\partial y_i}$$

Deriving $\frac{\partial J}{\partial X}$ requires backpropagation through $Y = \gamma \odot \hat{X} + \beta$, which yields

$$\frac{\partial J}{\partial \hat{X}} = \gamma \odot \frac{\partial J}{\partial Y}$$

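Continuing the NumPy sketch from the introduction (here `dY` holds $\frac{\partial J}{\partial Y}$, and `X_hat` and `gamma` come from the forward pass), the gradients derived so far translate to:

```python
# Sketch: dY is dJ/dY with shape (N, D); X_hat and gamma come from the forward pass.
dgamma = np.sum(dY * X_hat, axis=0)  # dJ/dgamma, shape (D,)
dbeta = np.sum(dY, axis=0)           # dJ/dbeta, shape (D,)
dX_hat = dY * gamma                  # dJ/dX_hat, shape (N, D)
```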
Next we have to backpropagate through $\hat{X} = \frac{X - \mu}{\sqrt{\sigma^2 + \epsilon}} = (\sigma^2 + \epsilon)^{-1/2}(X - \mu)$. Because both $\sigma^2$ and $\mu$ are functions of $X$, finding the gradient of $J$ with respect to $X$ is tricky. There are two approaches to break this down:

1. Take the gradient of $J$ with respect to each row (example) in $X$. This approach is complicated by the fact that the values of each row in $X$ influence the values of all rows in $\hat{X}$ (i.e. $\partial \hat{x}_{j \neq i} / \partial x_i \neq 0$). By properly considering how changes in $x_i$ influence $\mu$ and $\sigma^2$, this is doable, as explained here.
2. Take the gradient of $J$ with respect to each column (input dimension) in $X$. I find this more intuitive because batch normalization operates independently on each column: $\mu$, $\sigma^2$, $\gamma$, and $\beta$ are all calculated per column. This method is explained below.

Since we are taking the gradient of $J$ with respect to each column in $X$, we can start by considering the case where $X$ is just a single column vector. Thus, each example $x_i$ is a single number, and $\mu$ and $\sigma^2$ are scalars. This makes the math much easier. Later on, we generalize to $D$-dimensional input examples.

### Lemma

Let $a(B) \in \R$ be a real-valued function of vector $B \in \R^n$. Suppose $\frac{\partial a}{\partial B} \in \R^n$ is known. If $B = c(D) \cdot D$ where $c(D) \in \R$ and $D \in \R^n$, then

$$\frac{\partial a}{\partial D} = c(D)\,\frac{\partial a}{\partial B} + \left(D^T \frac{\partial a}{\partial B}\right)\frac{\partial c}{\partial D}$$

Proof

First we compute the gradient of $B$ with respect to a single element of $D$, where $e_j$ is the $j$-th standard basis vector.

$$\frac{\partial B}{\partial D_j} = c(D)\,e_j + \frac{\partial c}{\partial D_j}\,D$$

We apply the chain rule to obtain the gradient of $a$ with respect to a single element of $D$.

$$\frac{\partial a}{\partial D_j} = \frac{\partial a}{\partial B}^T \frac{\partial B}{\partial D_j} = c(D)\,\frac{\partial a}{\partial B_j} + \frac{\partial c}{\partial D_j}\left(D^T \frac{\partial a}{\partial B}\right)$$

Now we can write the gradient for all elements in $D$, where $I$ is the $n \times n$ identity matrix.

$$\frac{\partial a}{\partial D} = c(D)\,I\,\frac{\partial a}{\partial B} + \frac{\partial c}{\partial D}\,D^T \frac{\partial a}{\partial B} = c(D)\,\frac{\partial a}{\partial B} + \left(D^T \frac{\partial a}{\partial B}\right)\frac{\partial c}{\partial D}$$

This result is a generalization of the “product rule” in the completely scalar case. For a function $a(b)$ where $b = c(d) \cdot d$, we have

$$\frac{da}{dd} = \frac{da}{db} \cdot \frac{db}{dd} = \frac{da}{db}\left(c(d) + d\,\frac{dc}{dd}\right)$$

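As a quick sanity check, the lemma can be verified numerically. The functions $a$ and $c$ below are arbitrary hypothetical choices made only for this test:

```python
import numpy as np

def a(B):  # arbitrary scalar-valued function of a vector
    return np.sum(np.sin(B))

def c(D):  # arbitrary scalar-valued scaling function
    return 1.0 / np.linalg.norm(D)

rng = np.random.default_rng(0)
D = rng.normal(size=5)

B = c(D) * D
da_dB = np.cos(B)                    # gradient of a at B
dc_dD = -D / np.linalg.norm(D) ** 3  # gradient of c at D

# Lemma: da/dD = c(D) * da/dB + (D^T da/dB) * dc/dD
lemma_grad = c(D) * da_dB + (D @ da_dB) * dc_dD

# Central finite differences of a(c(D) * D) with respect to each element of D.
h = 1e-6
fd_grad = np.zeros_like(D)
for j in range(len(D)):
    e = np.zeros_like(D)
    e[j] = h
    fd_grad[j] = (a(c(D + e) * (D + e)) - a(c(D - e) * (D - e))) / (2 * h)

print(np.allclose(lemma_grad, fd_grad, atol=1e-6))  # expect True
```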
### Getting a single expression for $\frac{\partial J}{\partial X}$

We want a single expression for $\frac{\partial J}{\partial X}$, which we will derive in two steps.

1. Rewrite $\hat{X}$ in the form $\hat{X} = c(R) \cdot R$ for some choice of $R$ and $c(R)$. This enables us to use the lemma above to obtain $\frac{\partial J}{\partial R}$.
2. Rewrite $R$ in the form $R = A \cdot X$ for some choice of $A$. This enables us to use our gradient notes to obtain $\frac{\partial J}{\partial X}$.

We choose $R = X-\mu$. Then $\hat{X}$ can be expressed in terms of $R$ as follows:

$$\hat{X} = (\sigma^2 + \epsilon)^{-1/2}(X - \mu) = c(R) \cdot R$$

where $c(R) = (\sigma^2 + \epsilon)^{-1/2} = (\frac{1}{N} R^T R + \epsilon)^{-1/2}$. Now we apply our lemma above (with $a = J$, $B = \hat{X}$, and $D = R$).

$$\frac{\partial J}{\partial R} = c(R)\,\frac{\partial J}{\partial \hat{X}} + \left(R^T \frac{\partial J}{\partial \hat{X}}\right)\frac{\partial c}{\partial R}$$

$R$ can be written as a matrix multiplication with $X$, where $\mathbf{1}$ is an $N \times N$ matrix of all ones.

$$R = X - \mu = X - \frac{1}{N}\mathbf{1}X = \left(I - \frac{1}{N}\mathbf{1}\right)X$$

Using our gradient rules, we get (noting that $I - \frac{1}{N}\mathbf{1}$ is symmetric)

$$\frac{\partial J}{\partial X} = \left(I - \frac{1}{N}\mathbf{1}\right)^T \frac{\partial J}{\partial R} = \left(I - \frac{1}{N}\mathbf{1}\right)\frac{\partial J}{\partial R}$$

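Continuing in NumPy, here is a sketch of this unsimplified expression for the single-column case (`dX_hat` holds $\frac{\partial J}{\partial \hat{X}}$; the expression for $\frac{\partial c}{\partial R}$ is derived in the next section):

```python
def batchnorm_dX_column(X, dX_hat, eps=1e-5):
    """Sketch: unsimplified dJ/dX for a single-column X of shape (N,)."""
    N = X.shape[0]
    R = X - X.mean()                              # R = (I - (1/N) * ones) @ X
    c = (R @ R / N + eps) ** -0.5                 # c(R) = (sigma^2 + eps)^(-1/2)
    dc_dR = -(R / N) * (R @ R / N + eps) ** -1.5  # dc/dR (derived in the next section)
    dR = c * dX_hat + (R @ dX_hat) * dc_dR        # apply the lemma
    return dR - dR.mean()                         # multiply by (I - (1/N) * ones)
```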
### Simplifying the expression

First, we calculate $\frac{\partial c}{\partial R}$.

$$\frac{\partial c}{\partial R} = -\frac{1}{2}\left(\frac{1}{N} R^T R + \epsilon\right)^{-3/2} \cdot \frac{2}{N} R = -\frac{1}{N}\left(\sigma^2 + \epsilon\right)^{-3/2} R$$

We plug this into our equation for $\frac{\partial J}{\partial X}$ and rewrite $R$ and $c(R)$ in terms of $\mu$ and $\sigma$, using $R = \sqrt{\sigma^2 + \epsilon}\,\hat{X}$:

$$\begin{aligned}
\frac{\partial J}{\partial X} &= \left(I - \frac{1}{N}\mathbf{1}\right)\left[c(R)\,\frac{\partial J}{\partial \hat{X}} + \left(R^T \frac{\partial J}{\partial \hat{X}}\right)\frac{\partial c}{\partial R}\right] \\
&= \left(I - \frac{1}{N}\mathbf{1}\right)\left[\frac{1}{\sqrt{\sigma^2 + \epsilon}}\,\frac{\partial J}{\partial \hat{X}} - \frac{1}{N\sqrt{\sigma^2 + \epsilon}}\left(\hat{X}^T \frac{\partial J}{\partial \hat{X}}\right)\hat{X}\right] \\
&= \frac{1}{\sqrt{\sigma^2 + \epsilon}}\left[\frac{\partial J}{\partial \hat{X}} - \frac{1}{N}\mathbf{1}\frac{\partial J}{\partial \hat{X}} - \frac{1}{N}\left(\hat{X}^T \frac{\partial J}{\partial \hat{X}}\right)\hat{X} + \frac{1}{N^2}\left(\hat{X}^T \frac{\partial J}{\partial \hat{X}}\right)\mathbf{1}\hat{X}\right] \\
&= \frac{1}{\sqrt{\sigma^2 + \epsilon}}\left[\frac{\partial J}{\partial \hat{X}} - \frac{1}{N}\mathbf{1}\frac{\partial J}{\partial \hat{X}} - \frac{1}{N}\left(\hat{X}^T \frac{\partial J}{\partial \hat{X}}\right)\hat{X}\right]
\end{aligned}$$

The last step above is because $\mathbf{1} \hat{X}$ is the 0-vector: every entry of $\mathbf{1} \hat{X}$ equals

$$\sum_{i=1}^{N} \hat{x}_i = \frac{\sum_{i=1}^{N}(x_i - \mu)}{\sqrt{\sigma^2 + \epsilon}} = \frac{N\mu - N\mu}{\sqrt{\sigma^2 + \epsilon}} = 0$$

Note that when the inputs are scalars, $\frac{\partial J}{\partial \hat{X}} = \gamma \cdot \frac{\partial J}{\partial Y}$ where $\gamma$ is a scalar and $\frac{\partial J}{\partial Y}$ is a column vector. Thus,

$$\begin{aligned}
\frac{\partial J}{\partial X} &= \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\left[\frac{\partial J}{\partial Y} - \frac{1}{N}\mathbf{1}\frac{\partial J}{\partial Y} - \frac{1}{N}\left(\hat{X}^T \frac{\partial J}{\partial Y}\right)\hat{X}\right] \\
&= \frac{\gamma}{N\sqrt{\sigma^2 + \epsilon}}\left[N\,\frac{\partial J}{\partial Y} - \frac{\partial J}{\partial \beta}\,\mathbf{1}_N - \frac{\partial J}{\partial \gamma}\,\hat{X}\right]
\end{aligned}$$

where $\mathbf{1}_N$ is an $N$-dimensional column vector of ones. The last line uses the fact that when the input examples are scalars, the derivatives simplify to

$$\frac{\partial J}{\partial \gamma} = \hat{X}^T \frac{\partial J}{\partial Y}, \qquad \frac{\partial J}{\partial \beta} = \mathbf{1}_N^T \frac{\partial J}{\partial Y}$$

Finally, we generalize to the case when the input examples are $D$-dimensional vectors:

$$\frac{\partial J}{\partial X} = \frac{\gamma}{N\sqrt{\sigma^2 + \epsilon}} \odot \left(N\,\frac{\partial J}{\partial Y} - \mathbf{1}_N\,\frac{\partial J}{\partial \beta} - \hat{X} \odot \mathbf{1}_N\,\frac{\partial J}{\partial \gamma}\right)$$

where the row vector $\frac{\gamma}{N\sqrt{\sigma^2 + \epsilon}} \in \R^{1 \times D}$ multiplies each row element-wise, and $\mathbf{1}_N \frac{\partial J}{\partial \beta},\ \mathbf{1}_N \frac{\partial J}{\partial \gamma} \in \R^{N \times D}$ stack the corresponding row vectors $N$ times.
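Putting everything together, here is a compact NumPy sketch of the backward pass implied by this final expression (meant to pair with the hypothetical `batchnorm_forward` above):

```python
import numpy as np

def batchnorm_backward(dY, X_hat, gamma, var, eps=1e-5):
    """Backward pass sketch. dY, X_hat: (N, D); gamma, var: (D,)."""
    N = dY.shape[0]
    dgamma = np.sum(dY * X_hat, axis=0)  # dJ/dgamma
    dbeta = np.sum(dY, axis=0)           # dJ/dbeta
    # dJ/dX = gamma / (N * sqrt(var + eps)) * (N*dY - 1_N dbeta - X_hat * (1_N dgamma));
    # NumPy broadcasting over rows plays the role of multiplying by 1_N.
    dX = (gamma / (N * np.sqrt(var + eps))) * (N * dY - dbeta - X_hat * dgamma)
    return dX, dgamma, dbeta
```

A finite-difference check of `dX` against `batchnorm_forward` is a quick way to validate the expression end to end.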