This is an example implementation of gradient calculation for different neural network operations.
tensor.py contains the details of how back-propagation happens over the whole graph (based on the chain rule).
ops.py contains the operator implementations and the corresponding derivatives.
Below is a set of operations with backward implementations.
In the calculations, L denotes the loss, calculated at the end of the graph; it is assumed to be a scalar.
This is the original paper for gelu. There are two formulas:
- the original one
- an approximation for better inference speed.

Here, we will only deal with the original. The formula for the forward pass:

$$\mathrm{gelu}(x) = x \cdot \Phi(x) = \frac{x}{2}\left(1 + \mathrm{erf}\left(\frac{x}{\sqrt{2}}\right)\right)$$
The error function can be defined by the following integral:

$$\mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2}\,dt$$

The derivative of the error function is therefore the function inside the integral, by definition:

$$\mathrm{erf}'(x) = \frac{2}{\sqrt{\pi}}\, e^{-x^2}$$

The derivative function for the gelu:

$$\mathrm{gelu}'(x) = \Phi(x) + x\,\varphi(x) = \frac{1}{2}\left(1 + \mathrm{erf}\left(\frac{x}{\sqrt{2}}\right)\right) + \frac{x}{\sqrt{2\pi}}\, e^{-x^2/2}$$
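As a quick sanity check of the derivative, here is a minimal plain-Python sketch (the names `gelu` and `gelu_grad` are illustrative, not the identifiers used in ops.py), compared against a central finite difference:

```python
import math

def gelu(x):
    # exact gelu: x * Phi(x), written with the error function
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_grad(x):
    # d/dx gelu(x) = Phi(x) + x * phi(x)
    cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return cdf + x * pdf

# compare against a central finite difference at a few points
eps = 1e-6
for x in (-2.0, -0.5, 0.0, 0.7, 3.0):
    num = (gelu(x + eps) - gelu(x - eps)) / (2 * eps)
    assert abs(num - gelu_grad(x)) < 1e-5
```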
The matrix multiplication with tensors:

$$Y = A B$$

To derive the derivative of the matrix multiplication, let's apply tensor notation. The above formula is equivalent to:

$$Y_{ij} = \sum_k A_{ik} B_{kj}$$

(The Einstein summation convention is not used here, for simplicity.) The loss will be a function of Y, therefore the following derivative needs to be calculated with the chain rule (for A):

$$\frac{\partial L}{\partial A_{ij}} = \sum_{k,l} \frac{\partial L}{\partial Y_{kl}}\, \frac{\partial Y_{kl}}{\partial A_{ij}}$$

Taking into account that

$$Y_{kl} = \sum_m A_{km} B_{ml},$$

the second term becomes

$$\frac{\partial Y_{kl}}{\partial A_{ij}} = \sum_m \delta_{ki}\,\delta_{mj}\, B_{ml} = \delta_{ki}\, B_{jl}$$

Therefore:

$$\frac{\partial L}{\partial A_{ij}} = \sum_l \frac{\partial L}{\partial Y_{il}}\, B_{jl}$$

As a summary for matrix A:

$$\frac{\partial L}{\partial A} = \frac{\partial L}{\partial Y}\, B^T$$

A similar calculation shows the derivative with respect to B:

$$\frac{\partial L}{\partial B} = A^T\, \frac{\partial L}{\partial Y}$$
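The two matrix results can be verified numerically. A minimal NumPy sketch (all variable names are illustrative), using the scalar loss L = sum(Y * G) so that dL/dY = G by construction:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 5))
B = rng.standard_normal((5, 3))
Y = A @ B
G = rng.standard_normal(Y.shape)      # upstream gradient dL/dY

grad_A = G @ B.T                      # dL/dA = (dL/dY) B^T
grad_B = A.T @ G                      # dL/dB = A^T (dL/dY)

# check grad_A against a central finite difference of L = sum(Y * G)
eps = 1e-6
num = np.zeros_like(A)
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        Ap = A.copy(); Ap[i, j] += eps
        Am = A.copy(); Am[i, j] -= eps
        num[i, j] = (np.sum((Ap @ B) * G) - np.sum((Am @ B) * G)) / (2 * eps)
print(np.allclose(grad_A, num, atol=1e-5))
```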
The convolution is assumed to be 2-dimensional; the extension to other dimensions is straightforward. There are three variables in a convolution: x (image), weight and bias. The derivative can be calculated with respect to any of them. Let's start with the bias, as it is the simplest one.
First, the formula for the convolution with bias (B):

$$Y_{b,c,i,j} = K_{b,c,i,j} + B_c$$

The calculation of K is a convolution without bias. We first need the derivative with respect to B:

$$\frac{\partial L}{\partial B_c} = \sum_{b,i,j} \frac{\partial L}{\partial Y_{b,c,i,j}}\, \frac{\partial Y_{b,c,i,j}}{\partial B_c} = \sum_{b,i,j} g_{b,c,i,j}$$

Where g = dL/dY is just an abbreviation. The formula means a sum-reduce over 3 axes: b, i, j (b: batch, i: height index for images, j: width index for images).
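The sum-reduce above is a one-liner in NumPy. A minimal sketch (shapes and names are illustrative), checked against the written-out formula:

```python
import numpy as np

# g = dL/dY with shape (batch, channels, height, width)
g = np.arange(2 * 3 * 4 * 5, dtype=float).reshape(2, 3, 4, 5)

# dL/dB_c: sum-reduce over the b, i, j axes; one scalar per output channel
grad_b = g.sum(axis=(0, 2, 3))

# the same, written out from the formula
grad_ref = np.zeros(3)
for b in range(2):
    for c in range(3):
        for i in range(4):
            for j in range(5):
                grad_ref[c] += g[b, c, i, j]
print(np.allclose(grad_b, grad_ref))  # True
```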
The implementation uses 2d spatial dimensions, but here, for simpler indexing in the formulas, we will assume 1d spatial indices (1d convolution). Assume non-trivial dilation (d) and stride (s).
The forward pass for the convolution can be expressed with tensor notation as:

$$Y_{b,o,i} = \sum_{c,k} X_{b,c,\,s i + d k}\; W_{o,c,k} + B_o$$

(b: batch, o: output channel, c: input channel, i: output spatial index, k: kernel index.) We need the derivative with respect to the weight:

$$\frac{\partial L}{\partial W_{o,c,k}} = \sum_{b,i} \frac{\partial L}{\partial Y_{b,o,i}}\; \frac{\partial Y_{b,o,i}}{\partial W_{o,c,k}}$$

Now, the calculation for the second derivative:

$$\frac{\partial Y_{b,o,i}}{\partial W_{o,c,k}} = X_{b,c,\,s i + d k}$$

Introducing G for the derivative by Y, it simplifies to:

$$\frac{\partial L}{\partial W_{o,c,k}} = \sum_{b,i} G_{b,o,i}\; X_{b,c,\,s i + d k}$$

Do some rearrangement (in indices and order), so we can compare this formula accurately to the first one (the forward convolution):

$$\frac{\partial L}{\partial W_{o,c,k}} = \sum_{b,i} X'_{c,b,\,s' k + d' i}\; W'_{o,b,i}$$

Then, the main point is to express X', W' and s', d':

$$X'_{c,b,n} = X_{b,c,n}, \qquad W'_{o,b,i} = G_{b,o,i}, \qquad s' = d, \qquad d' = s$$
The last two formulas indicate that the roles of the stride and dilation are interchanged.
Summary:
- the derivative of the convolution according to the kernel weights (W) is also a convolution
- this effective convolution swaps the batch and channel axes in the input tensor (image)
- the weight of the effective convolution is the gradient G calculated after the convolution (or it can also be filled with ones, but then its shape needs to match the shape of Y)
- the original stride becomes the dilation in the effective convolution
- the original dilation becomes the stride in the effective convolution
- cropping can be necessary to adjust the output size of the effective convolution; the unused outputs should be cropped, because the convolution can leave some of the X elements untouched when the stride and kernel size parameters are chosen in a specific way
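The summary above can be checked numerically. A plain-NumPy sketch under the 1d, no-padding assumptions (the `conv1d` helper and all names are illustrative, not the implementation in ops.py): the weight gradient from the definition must match an effective convolution with swapped batch/channel axes and swapped stride/dilation, cropped to the kernel size.

```python
import numpy as np

def conv1d(x, w, stride=1, dilation=1):
    # naive 1d convolution; x: (B, C, N), w: (O, C, K) -> y: (B, O, M)
    (B, C, N), (O, _, K) = x.shape, w.shape
    M = (N - dilation * (K - 1) - 1) // stride + 1
    y = np.zeros((B, O, M))
    for b in range(B):
        for o in range(O):
            for i in range(M):
                for c in range(C):
                    for k in range(K):
                        y[b, o, i] += x[b, c, stride * i + dilation * k] * w[o, c, k]
    return y

rng = np.random.default_rng(0)
B, C, O, N, K, s, d = 2, 3, 4, 20, 3, 2, 2
x = rng.standard_normal((B, C, N))
w = rng.standard_normal((O, C, K))
y = conv1d(x, w, stride=s, dilation=d)
g = rng.standard_normal(y.shape)             # upstream gradient G = dL/dY

# reference gradient straight from the definition: sum_{b,i} G[b,o,i] X[b,c,si+dk]
grad_ref = np.zeros_like(w)
for o in range(O):
    for c in range(C):
        for k in range(K):
            for b in range(B):
                for i in range(y.shape[2]):
                    grad_ref[o, c, k] += g[b, o, i] * x[b, c, s * i + d * k]

# effective convolution: swap batch/channel axes, use G as the weight,
# swap stride and dilation, crop the output to the kernel size K
gw = conv1d(x.transpose(1, 0, 2), g.transpose(1, 0, 2), stride=d, dilation=s)
grad_w = gw[:, :, :K].transpose(1, 0, 2)
print(np.allclose(grad_w, grad_ref))         # True
```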
The derivative with respect to the image (X) is a bit more involved, but it can also be calculated with a convolution (or with a transposed convolution).
Let's start with the forward convolution in tensor notation:

$$Y_{b,o,i} = \sum_{c,k} X_{b,c,\,s i + d k}\; W_{o,c,k} + B_o$$

Ignore the bias and introduce a new index:

$$r = s i + d k$$

The derivative we are looking for:

$$\frac{\partial L}{\partial X_{b,c,i'}} = \sum_{o,i} G_{b,o,i}\; \frac{\partial Y_{b,o,i}}{\partial X_{b,c,i'}} = \sum_{o,i,k} G_{b,o,i}\; W_{o,c,k}\; \frac{\partial X_{b,c,r}}{\partial X_{b,c,i'}}$$

The second term after taking the derivative (for fixed i' and i, r becomes concrete too):

$$\frac{\partial X_{b,c,r}}{\partial X_{b,c,i'}} = \delta_{r,i'} = \begin{cases} 1 & \text{if } s i + d k = i' \\ 0 & \text{otherwise} \end{cases}$$

In order to arrive at a formula similar to a convolution, we have to get rid of the condition (s i + d k = i'). This is like padding the G and W matrices internally with zeros (similarly to dilation). Then the derivative becomes:

$$\frac{\partial L}{\partial X_{b,c,i'}} = \sum_{o,m} G'_{b,o,\,i' - m}\; W'_{o,c,m}$$

Where G' is G with s-1 zeros inserted between its elements, and W' is W with d-1 zeros inserted between its elements. This formula looks exactly like a convolution, but due to the negative sign of m, the spatial axis of W' has to be reverted (flipped) first.
Summary:
- the derivative of the convolution in terms of X is also a convolution
- the weight of the effective convolution is W with its output and input channel axes swapped; the spatial axis needs to be reverted (or rotated by 180 degrees when pictured in 2d)
- the G matrix requires s-1 zeros to be inserted between its elements along each spatial axis (the new spatial size will be (gh - 1)(s - 1) + gh, where gh is the original size); this results in G'
- the G' matrix also requires padding outside with K - 1 zeros on both sides (K is the size of W')
- both the stride and the dilation of the effective convolution are 1
- the output shape should be adjusted to the shape of X, with zero padding on the right side
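The recipe above can also be verified numerically. A plain-NumPy sketch under the same 1d, no-padding assumptions (all names are illustrative): build the internally zero-padded G', the dilated and flipped W', run a stride-1 convolution, and compare against the gradient obtained directly from the definition.

```python
import numpy as np

rng = np.random.default_rng(0)
B, C, O, N, K, s, d = 2, 3, 4, 20, 3, 2, 2
x = rng.standard_normal((B, C, N))
w = rng.standard_normal((O, C, K))
M = (N - d * (K - 1) - 1) // s + 1           # forward output length
g = rng.standard_normal((B, O, M))           # upstream gradient G = dL/dY

# reference gradient straight from the definition: scatter G*W to index si + dk
grad_ref = np.zeros_like(x)
for b in range(B):
    for o in range(O):
        for i in range(M):
            for c in range(C):
                for k in range(K):
                    grad_ref[b, c, s * i + d * k] += g[b, o, i] * w[o, c, k]

# effective convolution route
Kd = d * (K - 1) + 1                          # spatial size of the dilated W'
wd = np.zeros((O, C, Kd))
wd[:, :, ::d] = w                             # W': d-1 zeros between elements
wr = wd[:, :, ::-1]                           # revert (flip) the spatial axis
Mg = s * (M - 1) + 1                          # spatial size of the dilated G'
gd = np.zeros((B, O, Mg))
gd[:, :, ::s] = g                             # G': s-1 zeros between elements
gp = np.pad(gd, ((0, 0), (0, 0), (Kd - 1, Kd - 1)))  # outer padding, Kd-1 zeros

# stride-1, dilation-1 convolution; output zero-padded on the right up to N
grad_x = np.zeros_like(x)
for b in range(B):
    for c in range(C):
        for i in range(min(N, Mg + Kd - 1)):
            for o in range(O):
                for m in range(Kd):
                    grad_x[b, c, i] += gp[b, o, i + m] * wr[o, c, m]
print(np.allclose(grad_x, grad_ref))          # True
```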
Alternative with transposed convolution:
- the steps outlined in the summary are very similar to a transposed convolution; in fact, it is a transposed convolution with stride s, dilation d and an output padding (outpad)
- where outpad is the adjustment to the shape of X.