24 Sep 2024

An In-Depth Introduction to Backpropagation and Automatic Differentiation

Backpropagation and automatic differentiation (AD) are fundamental components of modern deep learning frameworks. However, many practitioners pay little attention to their implementations and may regard them as some sort of "black magic". It does look like magic that PyTorch can calculate derivatives of virtually any function defined by the user, even one containing control-flow elements like conditional execution, which are not differentiable in the usual mathematical sense. Although we understand that these tools primarily rely on the chain rule, it remains unclear how they apply it efficiently to a function whose form is not known in advance and is determined entirely by the user.

In this post, I will introduce the underlying mechanisms of AD and give a simple implementation for demonstration. Moreover, I want to clarify the following facts about AD.

  1. Backpropagation is a special case of AD that runs in reverse mode: it is a specialized algorithm for calculating gradients of neural networks, whereas AD refers to a general technique that generates numerical derivative evaluations rather than derivative expressions. Besides reverse-mode AD, there are forward-mode AD and even reverse-on-forward AD.
  2. AD is neither traditional numerical differentiation nor symbolic differentiation. However, AD does in fact provide numerical values of derivatives, and it does so by using symbolic rules of differentiation, giving it a two-sided nature that is partly symbolic and partly numerical. To some extent, AD should be regarded as an efficient way to evaluate derivatives without explicitly constructing their mathematical expressions.
  3. AD is guaranteed to take no more than 6 times the computational cost of a single function evaluation. In practice, the overhead is typically closer to a factor of 2 or 3 (Bishop & Bishop, 2024, p. 250; see also Baydin et al., 2018, p. 16).
  4. Although evaluating the Hessian matrix \(H\) of some scalar function \(f: \mathbb{R}^n \to \mathbb{R}\) by AD has \(\mathcal{O}(n^2)\) complexity, evaluating the Hessian-vector product \(H v\) by reverse-on-forward AD has only \(\mathcal{O}(n)\) complexity. The same difference exists between evaluating the Jacobian matrix \(J\) of some vector-valued function \(\mathbb{R}^n \to \mathbb{R}^m\) and evaluating the Jacobian-vector product \(Jv\).

The original motivation for this post is the first homework assignment of Data C182, which requires implementing the forward pass and backward pass of common neural network layers, including the affine, dropout, and batch-norm layers. While scanning through the recommended textbook (Bishop & Bishop, 2024), I noticed that there is not only reverse-mode AD, which is a generalized version of backpropagation, but also forward-mode AD, which can be used in combination with reverse-mode AD to efficiently compute the Hessian-vector product. After reading the survey paper (Baydin et al., 2018), I finally opened up the "black box" of automatic differentiation hidden in PyTorch and gained a clearer understanding of AD.

Four ways to calculate gradients

Adapted from the textbook (Bishop & Bishop, 2024) and the survey paper (Baydin et al., 2018).

Methods for the computation of derivatives in computer programs can be classified into four categories:

  1. manually working out derivatives and coding them;
  2. numerical differentiation using finite difference approximations;
  3. symbolic differentiation using expression manipulation in computer algebra systems such as Mathematica, Maxima, and Maple;
  4. automatic differentiation, also called algorithmic differentiation.

Numerical approximations of derivatives are inherently ill-conditioned and unstable. They suffer from truncation and round-off errors inflicted by the limited precision of computations and the chosen value of the step size \(h\). Truncation error tends to zero as \(h \to 0\). However, as \(h\) is decreased, round-off error increases and becomes dominant. Moreover, numerical differentiation requires \(\mathcal{O}(n)\) evaluations of the original function for a gradient in \(n\) dimensions, which is the main obstacle to its usefulness in machine learning, where \(n\) can be as large as millions or billions in state-of-the-art deep learning models.
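
A quick sketch of this trade-off (my own illustration, not from the references): the central-difference error for \(\frac{d}{dx}\sin x\) at \(x=1\) first shrinks as \(h\) decreases (truncation dominates) and then grows again (round-off dominates).

import numpy as np

x = 1.0
exact = np.cos(x)  # true derivative of sin at x
for h in [1e-1, 1e-4, 1e-8, 1e-12]:
    approx = (np.sin(x + h) - np.sin(x - h)) / (2 * h)  # central difference
    print(f"h={h:.0e}  error={abs(approx - exact):.2e}")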

Symbolic differentiation, on the other hand, faces a problem called expression swell: the resulting expressions for derivatives can become exponentially longer than the original function. A further major drawback of symbolic differentiation is that it requires the expression to be differentiated to be given in closed form. It therefore excludes important control-flow operations such as loops, recursion, conditional execution, and procedure calls, which are valuable constructs that we might wish to use when defining the network function.
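
Expression swell is easy to observe with a few lines of SymPy (my own sketch, along the lines of the iterated logistic map example in Baydin et al., 2018): the operation count of the derivative expression grows rapidly with each composition.

import sympy as sp

x = sp.symbols('x')
expr = x
for n in range(1, 5):
    expr = 4 * expr * (1 - expr)        # iterate the logistic map l <- 4 l (1 - l)
    d = sp.expand(sp.diff(expr, x))     # symbolic derivative of the n-fold composition
    print(n, sp.count_ops(d))           # the derivative expression keeps growing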

The first approach, which formed the mainstay of neural networks for many years, is to derive the backpropagation equations by hand and then to implement them explicitly in software. If this is done carefully it results in efficient code that gives precise results that are accurate to numerical precision. However, the process of deriving the equations as well as the process of coding them both take time and are prone to errors. If the model is altered, both the forward and backward implementations need to be changed in unison. This effort can easily become a limitation on how quickly and effectively different architectures can be explored empirically.

Modern deep learning frameworks use automatic differentiation to evaluate gradients of neural networks. Unlike symbolic differentiation, the goal of automatic differentiation is not to find a mathematical expression for the derivatives but to have the computer automatically generate the code that implements the gradient calculations given only the code for the forward propagation equations. It is accurate to machine precision, just as with symbolic differentiation, but is more efficient because it is able to exploit intermediate variables used in the definition of the forward propagation equations and thereby avoid redundant evaluations.

Automatic differentiation

Automatic differentiation (AD) can be thought of as performing a non-standard interpretation of a computer program where this interpretation involves augmenting the standard computation with the calculation of various derivatives. All numerical computations are ultimately compositions of a finite set of elementary operations for which derivatives are known, and combining the derivatives of the constituent operations through the chain rule gives the derivative of the overall composition. Usually these elementary operations include the binary arithmetic operations, the unary sign switch, and transcendental functions such as the exponential, the logarithm, and the trigonometric functions. Besides that, users can define their own operations, as long as they provide their derivative functions.

Given a function \(y=f(x)\) to differentiate, assume it can be described by the following computational graph

$$ \begin{aligned} z_0 &= x, \\ z_1 &= \phi_1(z_0), \\ z_2 &= \phi_2(z_0, z_1), \\ &\cdots \\ z_k &= \phi_k(z_0, z_1, z_2, \ldots, z_{k-1}), \\ &\cdots \\ z_N &= \phi_N(z_0, z_1, z_2, \ldots, z_{k-1}, \ldots, z_{N-1}), \\ y &= z_N, \end{aligned} $$

where \(\phi_1, \phi_2, \ldots, \phi_N\) are elementary operations whose derivatives \(\partial \phi_1, \partial \phi_2, \ldots, \partial \phi_N\) are known. Obviously, any computational graph is a directed acyclic graph and always falls within the considered case. In practice, the initial variable \(z_0\) contains all leaf nodes of the computational graph, and \(\phi_k\) may depend on only one or two variables in \(\{z_\alpha\}_{\alpha=0}^{k-1}\).

Notations. For a vector \(x\), we denote by \(x^{(i)}\) the \(i\)-th component of \(x\), i.e., \(x = (x^{(0)}, x^{(1)}, x^{(2)} \ldots)^\intercal\). For a vector \(y\) that depends on \(x\), the Jacobian matrix \(\partial y/\partial x\) is defined by \[ \biggl(\frac{\partial y}{\partial x}\biggr)_{ij} = \frac{\partial y^{(i)}}{\partial x^{(j)}}. \] For a function \(y=f(x)\), we also use \(\partial f\) to represent the Jacobian matrix \(\partial y / \partial x\). With this notation, the chain rule can be written as \(\partial (f \circ g ) = [\partial f] [\partial g]\). It should be noted that the Jacobian matrix \(\partial f\) of a scalar function is a row vector in this notation.

Forward-mode AD

Forward-mode AD is conceptually the simplest mode: apply symbolic differentiation at the level of elementary operations and keep intermediate numerical results, in lockstep with the evaluation of the main function. Viewing \(z_k\) (\(k=1, 2, \ldots, N\)) as functions of \(x\), we differentiate the function \[ z_k(x) = \phi_k(z_0(x), z_1(x), z_2(x), \ldots, z_{k-1}(x)) \] w.r.t. \(x\) and obtain

\begin{equation}\tag{chain-rule: FAD} \frac{\partial z_k}{\partial x} = \sum_{\alpha=0}^{k-1} \frac{\partial \phi_k}{\partial z_\alpha} \frac{\partial z_\alpha}{\partial x}, \quad k=1, 2, \ldots, N. \end{equation}

Starting from \(\partial z_0 / \partial x = I\), we evaluate \(\partial z_k / \partial x\) in accordance with the evaluation of \(z_k\). When finally we obtain \(y=z_N\), the derivative \(\partial y/\partial x=\partial z_N/ \partial x\) is also obtained.

$$ \begin{aligned} \frac{\partial z_1}{\partial x} &= \frac{\partial \phi_1}{\partial z_0} \frac{\partial z_0}{\partial x}, \\ \frac{\partial z_2}{\partial x} &= \frac{\partial \phi_2}{\partial z_0} \frac{\partial z_0}{\partial x} + \frac{\partial \phi_2}{\partial z_1} \frac{\partial z_1}{\partial x}, \\ &\cdots \\ \frac{\partial z_N}{\partial x} &= \sum_{\alpha=0}^{N-1} \frac{\partial \phi_N}{\partial z_\alpha} \frac{\partial z_\alpha}{\partial x}. \end{aligned} $$

In other words, these equations are evaluated from top to bottom, row by row.
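
A common way to realize forward mode in code is with dual numbers, where each variable carries its value together with its tangent \(\frac{\partial z_k}{\partial x}v\). The following is a minimal sketch of my own for scalar inputs; it is not the tape-based implementation described later in this post.

import math

class Dual:
    """A value together with its tangent, propagated in lockstep."""

    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.tan + other.tan)

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # product rule applied at the elementary-operation level
        return Dual(self.val * other.val,
                    self.tan * other.val + self.val * other.tan)

def sin(z):
    return Dual(math.sin(z.val), math.cos(z.val) * z.tan)

# f(x) = x * sin(x) + x, so f'(x) = sin(x) + x * cos(x) + 1
x = Dual(2.0, 1.0)      # seed the tangent with v = 1
y = x * sin(x) + x
print(y.val, y.tan)     # value and derivative of f at x = 2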

Reverse-mode AD

Reverse-mode AD differs from forward-mode AD in the auxiliary variables it uses. In forward-mode AD, the auxiliary variables are \(\partial z_k / \partial x\) (called tangent variables), and are evaluated from \(k=1\) to \(k=N\). In reverse-mode AD, the auxiliary variables are \(\partial y / \partial z_k\) (called adjoint variables), and are evaluated from \(k=N-1\) to \(k=0\). For any \(0 \leq k \leq N-1\), we regard \(\mathcal{Z}_k := \{z_\alpha\}_{\alpha=0}^{k}\) as independent variables and view \(\{z_\beta\}_{\beta=k+1}^N\) as functions of \(\mathcal{Z}_k\). Thus, we can differentiate \[ y = y(z_0, z_1, \ldots, z_k) \] w.r.t. \(z_k\) and obtain (see the appendix for an inductive proof)

\begin{equation}\tag{chain-rule: RAD} \frac{\partial y}{\partial z_{k}} = \sum_{\beta=k+1}^N \frac{\partial y}{\partial z_\beta} \frac{\partial \phi_\beta}{\partial z_k}, \quad k=N-1, N-2, \ldots, 0. \end{equation}

The partial derivative \(\partial y/ \partial z_k\) should be understood as the sensitivity of \(y\) w.r.t. \(z_k\). An intuitive explanation of this formula is: when \(z_k\) changes, the change of the final output is determined by the cumulative effect of changes in downstream variables \(\{z_\beta\}_{\beta=k+1}^N\).

Starting from \(\partial y / \partial z_N = I\), we evaluate \(\partial y / \partial z_k\) reversely.

$$ \begin{aligned} \frac{\partial y}{\partial z_{N-1}} &= {\color{blue}\frac{\partial y}{\partial z_{N}} \frac{\partial \phi_N}{\partial z_{N-1}}}, \\ \frac{\partial y}{\partial z_{N-2}} &= {\color{blue} \frac{\partial y}{\partial z_{N}} \frac{\partial \phi_N}{\partial z_{N-2}}} + {\color{green} \frac{\partial y}{\partial z_{N-1}} \frac{\partial \phi_{N-1}}{\partial z_{N-2}}}, \\ \frac{\partial y}{\partial z_{N-3}} &= {\color{blue} \frac{\partial y}{\partial z_{N}} \frac{\partial \phi_N}{\partial z_{N-3}}} + {\color{green} \frac{\partial y}{\partial z_{N-1}} \frac{\partial \phi_{N-1}}{\partial z_{N-3}}} + \frac{\partial y}{\partial z_{N-2}} \frac{\partial \phi_{N-2}}{\partial z_{N-3}}, \\ &\cdots \\ \frac{\partial y}{\partial z_{0}} &= \sum_{\beta=1}^N \frac{\partial y}{\partial z_\beta} \frac{\partial \phi_\beta}{\partial z_0}. \end{aligned} $$

Initially, all adjoint variables are set to 0 except \(\partial y/\partial z_N\), which is set to \(I\). First, we compute the blue terms \(\frac{\partial y}{\partial z_N}\frac{\partial \phi_N}{\partial z_k}\) and add them to the corresponding adjoint variables \(\partial y/\partial z_k\). In the next step, we compute the green terms \(\frac{\partial y}{\partial z_{N-1}}\frac{\partial \phi_{N-1}}{\partial z_k}\) and add them to the corresponding adjoint variables \(\partial y/\partial z_k\). We repeat this process until all terms listed above have been calculated. Then, the accumulated values in the adjoint variables are their true values.

In other words, these equations are evaluated from left to right, column by column. After the evaluation of column \(\beta\), all adjoint variables \(\partial y/ \partial z_k\) for \(k \geq \beta-1\) have been obtained.

It should be noted that the reverse calculations of \(\partial y/ \partial z_k\) actually happen in the second phase of a two-phase process, while the intermediate variables \(z_k\) are calculated in the first phase. This is different from forward-mode AD, where \(\partial z_k / \partial x\) and \(z_k\) are calculated simultaneously and in a forward manner.

Examples: reverse-mode AD for scalar functions

Here we apply reverse-mode AD to scalar functions and demonstrate how it works. We will discuss the case of vector functions in the following sections.

Reverse-mode AD is inherently suitable for a scalar output \(y\). Examining the formula (chain-rule: RAD), the Jacobians \(\partial y / \partial z_\beta\) now simplify to row vectors, as \(y\) is one-dimensional. Consequently, all matrix multiplications \([\partial y/\partial z_\beta][\partial \phi_\beta /\partial z_k]\) reduce to vector-matrix products.

For a variable \(z\), we denote by \(\dot{z}\) the column vector \([\partial y / \partial z]^\intercal\). With this notation, the chain rule of reverse-mode AD can be written as \[ \dot{z}_k = \sum_{\beta=k+1}^N \biggl[ \frac{\partial \phi_\beta}{\partial z_k} \biggr]^\intercal \dot{z}_\beta, \quad k=N-1, N-2, \ldots, 0.\] Note that the true gradient \(\dot{z}_k\) is a summation accumulated from \(\beta=N\) down to \(\beta=k+1\). At each stage \(\beta\), we can only compute a single term \([\partial \phi_\beta /\partial z_k]^\intercal \dot{z}_\beta\) of \(\dot{z}_k\).

In order to apply the chain rule, we record all the operations \(\{\phi_k\}\) along with their inputs and outputs on a tape (alternatively known as a Wengert list or an evaluation trace).

Example. Consider the function \(f(a, b) = \langle a, a+ b\rangle\). The computational graph is

| \(k\) | input | \(\phi_k\) | output |
|---|---|---|---|
| 1 | \(a, b\) | \(z_1 = a+b\) | \(z_1\) |
| 2 | \(a, z_1\) | \(z_2 = a^\intercal z_1\) | \(z_2\) |

Initially, set \(\dot{a}=0, \dot{b}=0\) and \(\dot{z}_k = 0\). Starting at \(\dot{z}_2 = [\partial y / \partial z_2]^\intercal = 1\), we then apply the chain rule to propagate gradients in reverse and accumulate the obtained values.

$$ \begin{aligned} (\text{Initialize}) &\qquad & \dot{z}_2 &\leftarrow 1 \\ (k=2) & \qquad & \dot{a} &\leftarrow \dot{a} + \biggl[\frac{\partial \phi_k}{\partial a}\biggr]^\intercal \dot{z}_2 \\ & & \dot{z}_1 &\leftarrow \dot{z}_1 + \biggl[\frac{\partial \phi_k}{\partial z_1}\biggr]^\intercal \dot{z}_2 \\ (k=1) & \qquad & \dot{a} &\leftarrow \dot{a} + \biggl[ \frac{\partial \phi_k}{\partial a} \biggr]^\intercal \dot{z}_1 \\ & & \dot{b} &\leftarrow \dot{b} + \biggl[\frac{\partial \phi_k}{\partial b}\biggr]^\intercal \dot{z}_1\\ \end{aligned} $$

We can explicitly write down these values and verify.

$$ \begin{aligned} \dot{z}_2 &= 1 ,\\ \dot{z}_1 &= a ,\\ \dot{a} &= z_1 \dot{z}_2 + \dot{z}_1 = 2a + b ,\\ \dot{b} &= \dot{z}_1 = a. \end{aligned} $$
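
As a quick sanity check of my own (not part of the post's implementation), PyTorch's reverse-mode AD reproduces these gradients:

import torch

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
y = torch.dot(a, a + b)          # f(a, b) = <a, a + b>
y.backward()                     # reverse-mode AD on the recorded graph
assert torch.allclose(a.grad, 2 * a + b)
assert torch.allclose(b.grad, a)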

Example. Consider the function \(f(a, b) = \|a\|^2 + a^\intercal b - \sin (a^\intercal b)\). The computational graph is

| \(k\) | input | \(\phi_k\) | output |
|---|---|---|---|
| 1 | \(a, b\) | \(z_1 = a^\intercal b\) | \(z_1\) |
| 2 | \(a\) | \(z_2 = a^\intercal a\) | \(z_2\) |
| 3 | \(z_1,z_2\) | \(z_3 = z_1 + z_2\) | \(z_3\) |
| 4 | \(z_1\) | \(z_4 = \sin z_1\) | \(z_4\) |
| 5 | \(z_3,z_4\) | \(z_5=z_3-z_4\) | \(z_5\) |

Initially, set \(\dot{a}=0, \dot{b}=0\) and \(\dot{z}_k = 0\). Starting at \(\dot{z}_5 = [\partial y / \partial z_5]^\intercal = 1\), we then apply the chain rule to propagate gradients in reverse and accumulate the obtained values.

$$ \begin{aligned} (\text{Initialize}) &\qquad & \dot{z}_5 &\leftarrow 1 \\ (k=5) & \qquad & \dot{z}_4 &\leftarrow \dot{z}_4 + \biggl[ \frac{\partial \phi_k}{\partial z_4}\biggr]^\intercal \dot{z}_5 \\ & & \dot{z}_3 &\leftarrow \dot{z}_3 + \biggl[ \frac{\partial \phi_k}{\partial z_3}\biggr]^\intercal \dot{z}_5 \\ (k=4) & \qquad & \dot{z}_1 &\leftarrow \dot{z}_1 + \biggl[ \frac{\partial \phi_k}{\partial z_1}\biggr]^\intercal \dot{z}_4\\ (k=3) & \qquad & \dot{z}_2 &\leftarrow \dot{z}_2 + \biggl[ \frac{\partial \phi_k}{\partial z_2}\biggr]^\intercal \dot{z}_3\\ & & \dot{z}_1 &\leftarrow \dot{z}_1 + \biggl[ \frac{\partial \phi_k}{\partial z_1}\biggr]^\intercal \dot{z}_3 \\ (k=2) & \qquad & \dot{a} &\leftarrow \dot{a} + \biggl[ \frac{\partial \phi_k}{\partial a}\biggr]^\intercal \dot{z}_2\\ (k=1) & \qquad & \dot{a} &\leftarrow \dot{a} + \biggl[ \frac{\partial \phi_k}{\partial a}\biggr]^\intercal \dot{z}_1\\ & & \dot{b} &\leftarrow \dot{b} + \biggl[ \frac{\partial \phi_k}{\partial b}\biggr]^\intercal \dot{z}_1\\ \end{aligned} $$

We can explicitly write down these values and verify.

$$ \begin{aligned} \dot{z}_5 &= 1 ,\\ \dot{z}_4 &= -\dot{z}_5 = -1 ,\\ \dot{z}_3 &= \dot{z}_5 = 1 ,\\ \dot{z}_2 &= \dot{z}_3 = 1 ,\\ \dot{z}_1 &= (\cos z_1) \dot{z}_4 + \dot{z}_3 = -\cos a^\intercal b + 1 ,\\ \dot{a} &= 2a \dot{z}_2 + b\dot{z}_1 = 2a + b(1 - \cos a^\intercal b) ,\\ \dot{b} &= a\dot{z}_1 = a(1 - \cos a^\intercal b). \end{aligned} $$

Jacobian-vector product

Examining the chain-rule formulae used in forward-mode and reverse-mode AD, we note that it suffices to calculate the Jacobians \(\partial \phi_k\). In practice, however, we don't calculate the Jacobians directly. Instead, we calculate the so-called Jacobian-vector product (or vector-Jacobian product).

Consider the function \(y=f(x)\) and assume \(x\in\mathbb{R}^n, y\in\mathbb{R}^m\).

For a vector \(v\in\mathbb{R}^n\), forward-mode AD can efficiently calculate the Jacobian-vector product \([\partial f]v\) by

\begin{equation}\tag{chain-rule: JVP} \dot{z}_k = \sum_{\alpha=0}^{k-1} \biggl[\frac{\partial \phi_k}{\partial z_\alpha}\biggr] \dot{z}_\alpha, \quad \text{ where }\dot{z}_\alpha := \biggl[\frac{\partial z_\alpha}{\partial x}\biggr] v. \end{equation}

For a vector \(v\in\mathbb{R}^m\), reverse-mode AD can efficiently calculate the vector-Jacobian product \([\partial f]^\intercal v\) by

\begin{equation}\tag{chain-rule: VJP} \dot{z}_k = \sum_{\beta=k+1}^N \biggl[\frac{\partial \phi_\beta}{\partial z_k}\biggr]^\intercal \dot{z}_\beta, \quad \text{ where }\dot{z}_\beta := \biggl[ \frac{\partial y}{\partial z_\beta} \biggr]^\intercal v. \end{equation}

There are two reasons why working with Jacobian-vector products is preferable to forming the full Jacobian.

  1. It requires less memory. Indeed, all calculations in (chain-rule: VJP) are matrix-vector products, while all calculations in (chain-rule: RAD) are matrix multiplications.
  2. It remains efficient when parallelized. If we want the full Jacobian matrix, we can run the algorithm with different \(v_j=e_j\) concurrently and then stack the results, as the sketch below illustrates.
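
PyTorch exposes these primitives as torch.autograd.functional.{jvp, vjp, jacobian}. A quick check of my own that they agree with the explicit Jacobian:

import torch
from torch.autograd.functional import jacobian, jvp, vjp

def f(x):                        # a map from R^3 to R^2
    return torch.stack((x[0] * x[1], x[1] * torch.sin(x[2])))

x = torch.randn(3)
J = jacobian(f, x)               # the full 2 x 3 Jacobian, for reference

v_in = torch.randn(3)            # direction in the input space
_, jvp_val = jvp(f, x, v_in)     # forward mode: J v
assert torch.allclose(jvp_val, J @ v_in, atol=1e-5)

v_out = torch.randn(2)           # direction in the output space
_, vjp_val = vjp(f, x, v_out)    # reverse mode: J^T v
assert torch.allclose(vjp_val, J.T @ v_out, atol=1e-5)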

In deep learning, we use reverse-mode AD because the loss function is a scalar function \[ \ell(\theta) = L(y, \hat{y}), \quad \text{ where } y:=f(x;\theta). \] The gradient is indeed a vector-Jacobian product \[ \frac{\partial \ell(\theta)}{\partial \theta} = \frac{\partial L}{\partial y} \frac{\partial y}{\partial \theta}. \]

Finally, we should mention that in our formulation all vectors can be generalized to tensors. Indeed, for a tensor \(x[\eta]\) of order \(n\), where \(\eta \in \mathcal{I}(x) \subset \mathbb{N}^{n}\) is a multi-index, we can treat it as a vector by iterating on \(\mathcal{I}(x)\). Moreover, for a function \(y=f(x)\) with \(y\) a tensor of order \(m\), we can define the Jacobian-vector product by \[ \operatorname{JVP}_f(x, v)[\zeta] := \sum_{\eta \in \mathcal{I}(x)} \frac{\partial (f[\zeta])}{ \partial (x[\eta])} v[\eta], \quad \zeta\in\mathcal{I}(y)\subset\mathbb{N}^m. \] Similarly, the vector-Jacobian product is defined by \[ \operatorname{VJP}_f(x, v)[\eta] := \sum_{\zeta \in \mathcal{I}(y)} \frac{\partial (f[\zeta])}{ \partial (x[\eta])} v[\zeta], \quad \eta\in\mathcal{I}(x)\subset\mathbb{N}^n. \] See the appendix for an example illustrating how we apply this definition to matrix multiplication \(f(A, B) = AB\).

Implement forward-mode and reverse-mode AD

In this section, we give the formal algorithm for differentiating a function \(y=f(x)\).

Let's begin with forward-mode AD, which evaluates the Jacobian-vector product \([\partial f(x)]v\) at a given point \(x\). Assume we have access to the tape, which records the sequence of elementary operations and their inputs and outputs during the computation of the target function \(y=f(x)\). The tables in the section "Examples: reverse-mode AD for scalar functions" are examples of such tapes.

Forward-mode AD relies on the mathematical formula (chain-rule: JVP). Starting with the gradient \(v\) of the initial input \(x\), we traverse the tape in a forward direction, propagating gradients from the inputs of \(\phi_k\) to its output via its JVP. The pseudocode of forward-mode AD can be summarized as follows.

def forwardAD_along_tape(inputs, call_tape, inputs_v, *, gradkey):
    """Forward-propagate gradients starting at `inputs`, whose gradients are
    initialized to `inputs_v`. `gradkey` is a string used as the dict key: for
    a given tensor `a`, its gradient is stored in `a.buffer[gradkey]`."""
    # seed the tangents of the inputs
    for x, v in zip(inputs, inputs_v):
        x.buffer[gradkey] = v
    # traverse the tape forward, pushing tangents through each operation's JVP
    for k_inputs, k_outputs, k_phi, k_kwargs in call_tape:
        grad_inputs = [x.buffer[gradkey] for x in k_inputs]
        k_outputs.buffer[gradkey] = k_phi.jvp(
            k_inputs, k_outputs, grad_inputs, **k_kwargs
        )

Reverse-mode AD relies on the mathematical formula (chain-rule: VJP). Starting with the gradient \(v\) of the final output \(y\), we traverse the tape in a reverse direction, propagating gradients from the output of \(\phi_k\) to its inputs via its VJP. The pseudocode of reverse-mode AD can be summarized as follows.

def reverseAD_along_tape(y, call_tape, v, *, gradkey):
    """Backpropagate gradients starting at `y`, whose gradient is initialized
    to `v`. `gradkey` is a string used as the dict key: for a given tensor `a`,
    its gradient is stored in `a.buffer[gradkey]`."""
    # seed the adjoint of the output
    y.buffer[gradkey] = v
    # traverse the tape in reverse, pulling adjoints back through each operation's VJP
    for k_inputs, k_outputs, k_phi, k_kwargs in reversed(call_tape):
        grad_inputs = k_phi.vjp(
            k_inputs, k_outputs, k_outputs.buffer[gradkey], **k_kwargs
        )
        # accumulate gradients into each input's adjoint
        for x, grad in zip(k_inputs, grad_inputs):
            x.buffer[gradkey] += grad

Comparing forward-mode AD and reverse-mode AD, a subtle difference is the gradient accumulation process in the latter. In the formula (chain-rule: JVP), the gradient \(\dot{z}_k\) is obtained by a single call of the JVP of \(\phi_k\). Thus, in the algorithm, the value of \(\dot{z}_k\) is computed in a single step of the iteration. In the formula (chain-rule: VJP), however, the gradient \(\dot{z}_k\) is obtained by successive calls of the VJP of \(\{\phi_\beta\}_{\beta=k+1}^N\). Consequently, in the algorithm, the value of \(\dot{z}_k\) accumulates in several steps of the iteration.

See the appendix for an overview of my implementation in Python.

Hessian-vector product

The Hessian matrix of a scalar function \(f\) can be defined by the Jacobian matrix of \(\partial f\), \[ (\partial^2 f)_{i,j} := (\partial(\partial f))_{i,j} = \frac{\partial}{\partial x_j}(\partial f)_i = \frac{\partial}{\partial x_j} \frac{\partial}{\partial x_i} f. \] Let \(g(x)=\partial f(x)\). Calculating the Jacobian-vector product \([\partial g]v\) is essentially calculating the Hessian-vector product \([\partial^2 f]v\). This can be achieved by a combination of forward-mode AD and reverse-mode AD.

Given the input \(x\), first calculate \(y=f(x)\) and record the operations during this procedure in a tape \(T_1\). Then, take a new tape \(T_2\). Use \(T_2\) to record the operations of applying forward-mode AD on \(T_1\) to calculate the Jacobian-vector product \(L=[\partial f]v\), which is a scalar in this case. Finally, apply reverse-mode AD on \(T_1 \cup T_2\) to obtain gradient \(\partial L = \partial([\partial f]v) = v^\intercal [\partial^2 f]\). This is exactly the Hessian-vector product \([\partial^2f]v\) if the Hessian matrix is symmetric.

def hvp_by_reverse_forwardAD(f, inputs, v_vars, *, inputs_vars):
    """Calculate the Hessian-vector product of function `f` using
    reverse-on-forward mode automatic differentiation.

    `inputs_vars` is a subset of `inputs` specifying the independent
    variables of the Hessian matrix.  `inputs_vars` must have the same
    length as `v_vars`.
    """
    tape1 = []
    with my_func_tracker.track_func(True, tape=tape1):
        y = f(*inputs)  # do computations and track in tape1

    tape2 = []
    with my_func_tracker.track_func(True, tape=tape2):
        forwardAD_along_tape(inputs_vars, tape1, v_vars, gradkey="rfgrad1")
        yy = y.buffer["rfgrad1"]

    # apply reverse-mode AD to yy
    # ATTENTION: we have to use a different gradkey to avoid modifying inputs
    #            recorded in tape 2
    reverseAD_along_tape(yy, tape1 + tape2, MyTensor(1.0), gradkey="rfgrad2")
    return [x.buffer["rfgrad2"] for x in inputs_vars]

The above procedure is the reverse-on-forward mode AD for calculating the Hessian-vector product. There are, of course, other procedures to obtain the same result.

  1. Forward-on-reverse mode AD. After recording the operations of \(y=f(x)\) in a tape \(T_1\), use a new tape \(T_2\) to record the operations of applying reverse-mode AD on \(T_1\) to calculate the gradient \(\partial f(x)\). Then, apply forward-mode AD on \(T_1 \cup T_2\) to calculate the Jacobian-vector product \([\partial (\partial f)]v\), which is exactly the Hessian-vector product \([\partial^2f]v\).
  2. Reverse-on-reverse mode AD. After recording the operations of \(y=f(x)\) in a tape \(T_1\), use a new tape \(T_2\) to record the operations of 1) applying reverse-mode AD on \(T_1\) to calculate the gradient \(\partial f(x)\) (which is a row vector in our notation) and 2) computing the scalar \(L=[\partial f(x)]v\). Then, apply reverse-mode AD on \(T_1 \cup T_2\) to calculate the gradient \(\partial L = \partial([\partial f]v)=v^\intercal [\partial^2f]\). This is exactly the Hessian-vector product \([\partial^2f]v\) if the Hessian matrix is symmetric; a PyTorch sketch of this variant is given below.
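
For comparison, item 2 above is what PyTorch users usually do with a "double backward". A short sketch of my own, checked against the explicit Hessian:

import torch
from torch.autograd.functional import hessian

def f(x):
    return (x ** 2).sum() + torch.sin(x).sum()   # an arbitrary scalar function

x = torch.randn(5, requires_grad=True)
v = torch.randn(5)

# reverse-on-reverse: differentiate the scalar <grad f(x), v> once more
(g,) = torch.autograd.grad(f(x), x, create_graph=True)  # g = grad f(x), kept on the graph
(hv,) = torch.autograd.grad(g @ v, x)                   # d/dx <g, v> = H v

assert torch.allclose(hv, hessian(f, x) @ v, atol=1e-4)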

See the appendix for an overview of my implementation in Python.
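
As a sanity check of hvp_by_reverse_forwardAD, here is a hypothetical usage sketch on the scalar function \(f(x)=x^3\), whose second derivative is \(6x\). It assumes that MyTensor overloads `*` through a tracked MyFunction, as described in the appendix below.

# Hypothetical usage, assuming `*` on MyTensor is a tracked elementary operation.
a = MyTensor(2.0)
v = MyTensor(1.0)

def f(x):
    return x * x * x        # f(x) = x^3, so the Hessian is f''(x) = 6x

(hv,) = hvp_by_reverse_forwardAD(f, [a], [v], inputs_vars=[a])
print(hv.value)             # expect 6 * 2.0 * 1.0 = 12.0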

Conclusion

AD refers to a general technique that generates numerical derivative evaluations rather than derivative expressions. Backpropagation is a special case of AD that runs in reverse mode. Besides reverse-mode AD, there are forward-mode AD and even reverse-on-forward AD.

For a function \(y=f(x)\) with \(x\in\mathbb{R}^n\) and \(y\in\mathbb{R}^m\), forward-mode AD can efficiently compute the Jacobian-vector product \([\partial f]v\) for any given initial gradient \(v\in\mathbb{R}^n\), while reverse-mode AD can efficiently compute the vector-Jacobian product \(v^\intercal [\partial f]\) for any given terminal gradient \(v\in\mathbb{R}^m\). Deep learning uses reverse-mode AD because loss functions are scalar functions, which means \(m=1\) and setting \(v=1\) yields the gradient \(\partial f\) directly.

For a scalar function \(y=f(x)\), combining forward-mode AD and reverse-mode AD can efficiently evaluate the Hessian-vector product \([\partial^2 f]v\) for any vector \(v\).

References

Books and Papers

  • Bishop, C. M., & Bishop, H. (2024). Deep learning: Foundations and concepts. Springer.
  • Zhang, A., Lipton, Z. C., Li, M., & Smola, A. J. (2023). Dive into Deep Learning. Cambridge University Press. https://d2l.ai
  • Baydin, A. G., Pearlmutter, B. A., Radul, A. A., & Siskind, J. M. (2018). Automatic differentiation in machine learning: A survey. Journal of Machine Learning Research, 18(153), 1–43.
  • Paszke, A., Gross, S., Chintala, S., Chanan, G., Yang, E., DeVito, Z., Lin, Z., Desmaison, A., Antiga, L., & Lerer, A. (2017). Automatic differentiation in PyTorch. Neural Information Processing Systems. https://openreview.net/forum?id=BJJsrmfCZ


Appendix: An inductive proof of reverse-mode AD formula

Consider the following computational graph

$$ \begin{aligned} z_0 &= x, \\ z_k &= \phi_k(z_0, z_1, \ldots, z_{k-1}), \quad k=1, 2, \ldots, N, \\ y &= z_N. \end{aligned} $$

Define functions \(\{f_k\}\) inductively from \(k=N\) to \(k=1\).

  1. For \(k=N\), \(f_N:= \phi_N\).
  2. Assume \(f_{k+1}\) has been defined. Then, \(f_k\) is defined by the map \[ (z_0, z_1, \ldots, z_{k-1}) \mapsto f_{k+1}(z_0, z_1, \ldots, z_{k-1}, \phi_k(z_0, z_1, \ldots, z_{k-1})). \]

We prove by induction that for any \(0 \leq k \leq N-1\), \[ \frac{\partial f_{k+1}}{\partial z_s} = \frac{\partial \phi_N}{\partial z_s} + \sum_{\beta=k+1}^{N-1} \frac{\partial f_{\beta+1}}{\partial z_\beta}\frac{\partial \phi_\beta}{\partial z_s}, \quad 0 \leq s \leq k. \]

  1. For \(k=N-1\), \(f_{k+1} = f_N = \phi_N\). The statement holds.
  2. Assume the statement holds for function \(f_{k+1}\). We want to prove that it holds for \(f_k\) too. According to the definition, \[ f_k(z_0, z_1, \ldots, z_{k-1}) = f_{k+1}(z_0, z_1, \ldots, z_{k-1}, \phi_k(z_0, z_1, \ldots, z_{k-1})). \] For any \(0 \leq s \leq k-1\), we differentiate both sides of the equation w.r.t. \(z_s\) and obtain \[ \frac{\partial f_k}{\partial z_s} = \frac{\partial f_{k+1}}{\partial z_s} + \frac{\partial f_{k+1}}{\partial z_k} \frac{\partial \phi_k}{\partial z_s}. \] Substituting the hypothesis yields \[ \frac{\partial f_k}{\partial z_s} = \frac{\partial \phi_N}{\partial z_s} + \sum_{\beta=k+1}^{N-1} \frac{\partial f_{\beta+1}}{\partial z_\beta}\frac{\partial \phi_\beta}{\partial z_s} + \frac{\partial f_{k+1}}{\partial z_k} \frac{\partial \phi_k}{\partial z_s}. \] This shows that the statement holds for \(f_k\).
  3. By induction, the statement holds for any \(0 \leq k \leq N-1\).

The adjoint variables \(\partial y / \partial z_k\) are then formally defined by \(\partial f_{k+1} / \partial z_k\). Moreover, \(\partial y / \partial z_N\) is defined to be the identity matrix. Setting \(s=k\) in the statement yields \[ \frac{\partial y}{\partial z_k} = \frac{\partial \phi_N}{\partial z_k} + \sum_{\beta=k+1}^{N-1} \frac{\partial y}{\partial z_\beta}\frac{\partial \phi_\beta}{\partial z_k}. \] Since the \(\beta=N\) term of (chain-rule: RAD) equals \(\frac{\partial y}{\partial z_N}\frac{\partial \phi_N}{\partial z_k} = \frac{\partial \phi_N}{\partial z_k}\), this is exactly the formula used in reverse-mode AD.

Appendix: JVP and VJP for linear maps

Let \(y=f(A, B) = AB\) where \(A\) and \(B\) are two matrices. The output \(y\) is also a matrix and \(y[i,k] = \sum_j A[i,j]B[j,k]\).

For a matrix \(v\) with the same shape as \(A\), \[ \operatorname{JVP}_f(A, v)[i,k] = \sum_{i',j} \frac{\partial y[i,k]}{\partial A[i',j]}v[i',j] = \sum_j B[j, k]v[i, j] = vB. \]

For a matrix \(v\) with the same shape as \(B\), \[ \operatorname{JVP}_f(B, v)[i,k] = \sum_{j, k'} \frac{\partial y[i,k]}{\partial B[j, k']}v[j, k'] = \sum_j A[i, j]v[j, k] = Av. \]

For a matrix \(v\) with the same shape as \(y\), \[ \operatorname{VJP}_f(A, v)[i,j] = \sum_{i', k} \frac{\partial y[i',k]}{\partial A[i, j]}v[i', k] = \sum_k B[j, k]v[i, k] = vB^\intercal. \]

For a matrix \(v\) with the same shape as \(y\), \[ \operatorname{VJP}_f(B, v)[j, k] = \sum_{i, k'} \frac{\partial y[i,k']}{\partial B[j, k]}v[i, k'] = \sum_i A[i, j]v[i, k] = A^\intercal v. \]
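
These closed forms are easy to check numerically (my own sketch) by comparing them against finite-difference approximations of the corresponding directional derivatives:

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((4, 5))
vA = rng.standard_normal(A.shape)   # direction for the JVP w.r.t. A
vy = rng.standard_normal((3, 5))    # upstream gradient, same shape as y = A @ B
eps = 1e-6

# JVP w.r.t. A: directional derivative of A @ B along vA, closed form vA @ B
jvp_fd = ((A + eps * vA) @ B - A @ B) / eps
assert np.allclose(jvp_fd, vA @ B, atol=1e-4)

# VJP w.r.t. A: gradient of <vy, A @ B> w.r.t. A, closed form vy @ B.T
vjp_fd = np.zeros_like(A)
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        dA = np.zeros_like(A)
        dA[i, j] = eps
        vjp_fd[i, j] = (np.sum(vy * ((A + dA) @ B)) - np.sum(vy * (A @ B))) / eps
assert np.allclose(vjp_fd, vy @ B.T, atol=1e-4)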

Appendix: Overview of the Python Implementation

The complete code is available at this github repo. The overall framework follows this colab notebook. Basically, I implement the following data structures.

  1. MyFunction. This models the elementary operations, such as addition, multiplication, division, subtraction, matrix multiplication and so on. Instances of this class have attributes vjp and jvp, which are responsible for computing vector-Jacobian products and Jacobian-vector products.

    class MyFunction(object):
        """Functions with vjp and jvp as attributes."""
    
        def __init__(self, name, func, func_vjp=None, func_jvp=None):
            self.name = name
            self.func = func
            self.vjp = func_vjp
            self.jvp = func_jvp
    
        def __call__(self, *args, **kws):
            return self.func(*args, **kws)
    
  2. MyFuncTracker. This is a decorator class that tracks the inputs and outputs of elementary operations, designed for the class MyFunction. The global instance my_func_tracker decorates the MyFunction.__call__ method. By doing so, each call of an elementary operation is recorded in my_func_tracker.call_tape. Moreover, the method track_func returns a context manager for conveniently toggling the tracking functionality.

    class MyFuncTracker(object):
        """A decorator class to track function inputs and outputs.
        Designed for MyFunction class.
    
        Store recorded calls in the attribute `call_tape`, a list of tuples
        representing (inputs_k, outputs_k, func_k, kwargs_k).
    
        Args:
            do_track (bool): A boolean flag to determine whether tracking is enabled.
        """
    
        def __init__(self, do_track: bool):
            self.do_track = do_track
            self.call_tape = []
    
        def __call__(self, func):
            """Wrap the function to track inputs and outputs in `self.call_tape`.
            Expects `func` to receive `self` as its first argument."""
            pass
    
        @contextmanager
        def track_func(self, do_track: bool, tape: Union[None, list] = None):
            """Context manager to enable or disable tracking within a block.  If
            tape is not None, store records in it. Otherwise, store records in
            `self.call_tape`."""
            pass
    
    # initialize a global function tracker
    my_func_tracker = MyFuncTracker(do_track=True)
    # apply it to MyFunction class
    MyFunction.__call__ = my_func_tracker(MyFunction.__call__)
    
  3. MyTensor. This models the variables for inputs and outputs of elementary operations. Instances of this class hold value, a NumPy array, and buffer, a dictionary. Arithmetic operators of this class are overloaded by MyFunction instances following the instructions in the Python documentation Emulating numeric types. By doing so, calculations involving MyTensor instances are recorded by my_func_tracker automatically.

    class MyTensor(object):
        def __init__(self, value=0):
            self.value = np.asarray(value)
            self.buffer = defaultdict(self.default_grad)
    
        def default_grad(self):
            return MyTensor(np.zeros_like(self.value))
    
        def __add__(self, other):
            pass
    

There are two notable things when implementing these classes.

  1. The attributes vjp and jvp of MyFunction instances are functions. These functions accept MyTensor inputs and return MyTensor outputs, using only operations modeled by MyFunction for tensor manipulation. Therefore, calculations performed by vjp and jvp can themselves be tracked by my_func_tracker, a crucial aspect for computing higher-order derivatives. Below is the definition of addition along with its respective JVP/VJP functions. Notably, the JVP propagates the gradients of the inputs to the gradient of the output, while the VJP propagates the gradient of the output to the gradients of the inputs.

    def _add(a: MyTensor, b: MyTensor) -> MyTensor:
        return MyTensor(a.value + b.value)
    
    
    def _add_vjp(
        inputs: list[MyTensor], outputs: MyTensor, grad_outputs: MyTensor
    ) -> list[MyTensor]:
        return [grad_outputs for _ in inputs]  # addition passes the gradient through to every input
    
    
    def _add_jvp(
        inputs: list[MyTensor], outputs: MyTensor, grad_inputs: list[MyTensor]
    ) -> MyTensor:
        return grad_inputs[0] + grad_inputs[1]
    
    
    add = MyFunction("Add", _add, func_vjp=_add_vjp, func_jvp=_add_jvp)
    
  2. Broadcast operations should be tracked too. As NumPy arrays broadcast their shapes automatically, the function _add mentioned earlier may receive a vector \(x\in\mathbb{R}^3\) and a scalar \(k\in\mathbb{R}\) and return a new vector \(y=x+k\mathbb{1}\). Indeed, there is a broadcast operation \(k \mapsto k\mathbb{1}\in\mathbb{R}^3\) in addition to the addition itself. Failing to record these broadcast operations can lead to incorrect gradient calculations. Below are the definitions of the broadcast and addition operations in MyTensor.

    class MyTensor(object):
    
        def __add__(self, other):
            if not isinstance(other, MyTensor):
                other = MyTensor(other)
            a, b = MyTensor.broadcast(self, other)
            return add(a, b)
    
        @staticmethod
        def broadcast(*tensors):
            shape = np.broadcast_shapes(*[t.shape for t in tensors])
            return tuple(t.expand(shape=shape) for t in tensors)
    
    def _expand(a: MyTensor, *, shape: list[int]) -> MyTensor:
        return MyTensor(np.broadcast_to(a.value, shape))
    
    expand = MyFunction("Expand", _expand, func_vjp=_expand_vjp, func_jvp=_expand_jvp)
    

Finally, forward-mode AD and reverse-mode AD can be implemented as follows.

def forwardAD(
    f: Callable[[list[MyTensor]], MyTensor],
    inputs: list[MyTensor],
    inputs_v: list[MyTensor],
    *,
    gradkey: str = "grad",
) -> MyTensor:
    """Use forward-mode AD to compute the Jacobian-vector product of f.
    Return the Jacobian-vector product of `f` evaluated at `inputs`, with the direction given by `inputs_v`.

    Args
    ----
    - `f`: The function to be differentiated.

    - `inputs`: Inputs of `f`.

    - `inputs_v`: A list of tensors matching `inputs`.

    - `gradkey`: A string used for the dict key. For a given tensor `a`,
           the grad is stored in `a.buffer[gradkey]`.
    """
    tape = []
    with my_func_tracker.track_func(True, tape=tape):
        # do computations and track in tape
        y = f(*inputs)
    # forward propagate gradient starting at inputs
    forwardAD_along_tape(inputs, tape, inputs_v, gradkey=gradkey)
    return y.buffer[gradkey]

def reverseAD(
    f: Callable[[list[MyTensor]], MyTensor],
    inputs: list[MyTensor],
    v: MyTensor,
    *,
    gradkey: str = "grad",
) -> list[MyTensor]:
    """Use reverse-mode AD to compute the vector-Jacobian product of f.
    Return the gradient of dot(f, v) evaluated at inputs.

    Args
    ----
    - `f`: The function to be differentiated.

    - `inputs`: Inputs of `f`.

    - `v`: A tensor matching the output of `f`.

    - `gradkey`: A string used for the dict key. For a given tensor `a`,
           the grad is stored in `a.buffer[gradkey]`.

    Note
    ----
    The gradient of tensor `a` is accumulated in `a.buffer[gradkey]`, which is
    zero by default. However, this function does not check whether it is zero or
    not; it simply accumulates all gradients into it.
    """
    tape = []
    with my_func_tracker.track_func(True, tape=tape):
        # do computations and track in tape
        y = f(*inputs)
    # backpropagate gradient starting at y
    reverseAD_along_tape(y, tape, v, gradkey=gradkey)
    return [x.buffer[gradkey] for x in inputs]

There is also a Test class, which contains various functions and their derivative functions derived by hand. For example, \[ f(w; X, y) = L(s(Xw), y), \] where \[s(x) = 1/(1 + e^{-x}), \quad L(p, y) = -\sum[ y_i \log p_i + (1-y_i)\log (1-p_i)]\] are the sigmoid function and the binary cross-entropy loss. \(X\in\mathbb{R}^{n\times m}\) is the predictor matrix and \(y\in\mathbb{R}^n\) is the response. It is not hard to show that \[\partial_w f = (p-y)^\intercal X, \quad \partial^2_w f = X^\intercal \Omega X,\] where \(p=s(Xw)\) and \(\Omega = \operatorname{diag}(p(1-p))\).
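
The hand-derived gradient and Hessian above can be verified against finite differences with a few lines of NumPy (my own sketch, independent of the repo):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def f(w, X, y):
    p = sigmoid(X @ w)
    return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))

rng = np.random.default_rng(0)
n, m = 20, 4
X = rng.standard_normal((n, m))
y = (rng.random(n) < 0.5).astype(float)
w = rng.standard_normal(m)

p = sigmoid(X @ w)
grad = X.T @ (p - y)                    # hand-derived gradient (p - y)^T X, as a column vector
hess = X.T @ np.diag(p * (1 - p)) @ X   # hand-derived Hessian X^T diag(p(1-p)) X

eps = 1e-6
grad_fd = np.array([(f(w + eps * e, X, y) - f(w - eps * e, X, y)) / (2 * eps)
                    for e in np.eye(m)])
assert np.allclose(grad, grad_fd, atol=1e-4)

v = rng.standard_normal(m)
hvp_fd = (X.T @ (sigmoid(X @ (w + eps * v)) - y)
          - X.T @ (sigmoid(X @ (w - eps * v)) - y)) / (2 * eps)
assert np.allclose(hess @ v, hvp_fd, atol=1e-4)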

Tags: ai