Backpropagation Formula: An Optimal Control View
Consider the following optimal control problem (or equivalently, a constrained optimization problem),
$$ \begin{aligned} \min_{u_0,u_1,\ldots,u_{N-1}}\ &\phi(x_N) + \sum_{k=0}^{N-1}g(k,x_k,u_k)\\ \operatorname{s.t.}\quad&x_{k+1} = f(k,x_k,u_k),\qquad k\in ⟦0, N-1⟧. \end{aligned} $$With appropriate choices of \(f,g,\phi\) and initial state \(x_0\), this optimal control problem can be seen as a training step of a neural network for a single sample point \((x^{(i)},y^{(i)})\).
- Set \(x_0\) as the input of the neural network, i.e. \(x^{(i)}\);
- Set \(u_k\) as the parameters of the \(k\)-th layer;
- Set \(f(k,\cdot,u_k)\) as the operation of the \(k\)-th layer.
Thus, \(x_{k+1}\) becomes the output of the \(k\)-th layer. Then we need to specify the cost function.
- Set \(\phi\) as the loss between \(x_N\) and the target \(y^{(i)}\).
- Set \(g(k,x_k,\cdot)\) as the regularization loss of the \(k\)-th layer.
For example, with the widely used MSE loss and \(L_2\) regularization, the loss function is \[ L(x^{(i)},y^{(i)}) = \|x_N-y^{(i)}\|^2 + \sum_{k=0}^{N-1}\|u_k\|^2, \] where \(\phi(x_N)=\|x_N-y^{(i)}\|^2\) and \(g(k,x_k,u_k)=\|u_k\|^2\).
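To make the correspondence concrete, here is a minimal NumPy sketch (not taken from the text above; the two-layer tanh network, the weight shapes, and the names `W`, `b`, `x_i`, `y_i` are assumptions made purely for illustration) in which `f`, `g`, and `phi` play exactly the roles listed above:

```python
import numpy as np

# Layer transition x_{k+1} = f(k, x_k, u_k): a fully connected tanh layer,
# with the parameters u_k = (W_k, b_k) packed in a tuple.
def f(k, x, u):
    W, b = u
    return np.tanh(W @ x + b)

# Running cost g(k, x_k, u_k): L2 regularization on the k-th layer's parameters.
def g(k, x, u):
    W, b = u
    return np.sum(W**2) + np.sum(b**2)

# Terminal cost phi(x_N): squared error against the target y^(i).
def phi(x_N, y):
    return np.sum((x_N - y)**2)

# One sample (x^(i), y^(i)) and random parameters for N = 2 layers.
rng = np.random.default_rng(0)
x_i, y_i = rng.normal(size=3), rng.normal(size=2)
u = [(rng.normal(size=(4, 3)), rng.normal(size=4)),
     (rng.normal(size=(2, 4)), rng.normal(size=2))]

# Forward pass: x_0 = x^(i), x_{k+1} = f(k, x_k, u_k), plus the total cost J_0.
x = [x_i]
for k, u_k in enumerate(u):
    x.append(f(k, x[k], u_k))
J_0 = phi(x[-1], y_i) + sum(g(k, x[k], u_k) for k, u_k in enumerate(u))
```

The forward loop produces the states \(x_1,\ldots,x_N\) and the scalar cost \(J_0\); differentiating \(J_0\) with respect to each \(u_k\) is exactly what the adjoint recursion below delivers.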
Back to the general form. We need to calculate the derivatives of the cost functional (the objective function in optimization, or the loss function in machine learning) w.r.t. \(u_0,u_1,\ldots,u_{N-1}\). Introduce the \(k\)-th tail cost as \[ J_k := \phi(x_N) + \sum_{i=k}^{N-1}g(i,x_i,u_i),\qquad\forall k\in ⟦0,N⟧, \] which can be seen as a function of the input \(x_k\) and the parameters \(u_k,u_{k+1},\ldots,u_{N-1}\); in particular, \(J_0\) is the original cost. Since \(J_k = g(k,x_k,u_k) + J_{k+1}\) with \(x_{k+1} = f(k,x_k,u_k)\), the chain rule applied inductively shows that \(\partial J_k/\partial x_k\) satisfies the following adjoint equation
$$ \begin{aligned} \frac{\partial}{\partial x_N}J_N &= \phi'(x_N)\\ \frac{\partial}{\partial x_k}J_k &= \frac{\partial g}{\partial x}(k,x_k,u_k) + \langle\frac{\partial f}{\partial x}(k,x_k,u_k), \frac{\partial}{\partial x_{k+1}}J_{k+1}\rangle,\qquad k\in⟦0,N-1⟧. \end{aligned} $$This means the costate \(\partial J_k/\partial x_k\) can be calculated backward, i.e., from the last layer \(\phi'(x_N)\) to the very first layer \(\partial J_0/\partial x_0\). With the help of the costate, the rest is straightforward:
$$ \begin{aligned} \frac{\partial}{\partial u_k}J_0 &= \frac{\partial}{\partial u_k}J_k\\ &= \frac{\partial}{\partial u_k}\bigl(g(k,x_k,u_k) + J_{k+1}\bigr)\\ &= \frac{\partial g}{\partial u}(k,x_k,u_k) + \frac{\partial x_{k+1}}{\partial u_k}\cdot \frac{\partial}{\partial x_{k+1}}J_{k+1}\\ &= \frac{\partial g}{\partial u}(k,x_k,u_k) + \frac{\partial f}{\partial u}(k,x_k,u_k)\cdot \frac{\partial}{\partial x_{k+1}}J_{k+1},\qquad k\in ⟦0,N-1⟧. \end{aligned} $$To summarize, we introduce the Hamiltonian \[ H(k,x,u,p) = g(k,x,u) + \langle p, f(k,x,u)\rangle. \] To compute the gradient of the loss function at the point \((x^{(i)},y^{(i)})\), the forward phase is executed first to obtain the state series
$$ \begin{aligned} x_0 &= x^{(i)}\\ x_{k+1} &= \nabla_pH(k,x_k,u_k,p_{k+1}),\qquad k\in ⟦0,N-1⟧. \end{aligned} $$Then the costate series is obtained via the backward phase
$$ \begin{aligned} p_{N} &= \phi'(x_N)\\ p_{k} &= \nabla_xH(k,x_k,u_k,p_{k+1}),\qquad k\in ⟦0,N-1⟧. \end{aligned} $$Finally, the gradient is \[ \frac{\partial}{\partial u_k}J_0 = \nabla_uH(k,x_k,u_k,p_{k+1}),\qquad k\in ⟦0,N-1⟧.\]
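As a sanity check of the three formulas above, here is a self-contained NumPy sketch (again a hypothetical tanh network with \(L_2\) regularization; the helper names `grad_H` and `backprop` are assumptions, and the derivatives of \(H\) are written out by hand for this particular layer):

```python
import numpy as np

def f(k, x, u):
    """Layer map x_{k+1} = f(k, x_k, u_k): fully connected tanh layer, u = (W, b)."""
    W, b = u
    return np.tanh(W @ x + b)

def grad_H(k, x, u, p_next):
    """Gradients of H(k, x, u, p_{k+1}) = g + <p_{k+1}, f> for the tanh layer
    with g(k, x, u) = ||W||^2 + ||b||^2.  Returns (grad_x H, grad_W H, grad_b H)."""
    W, b = u
    a = np.tanh(W @ x + b)            # = x_{k+1}
    s = (1.0 - a**2) * p_next         # chain rule through tanh
    dH_dx = W.T @ s                   # costate update p_k = grad_x H
    dH_dW = np.outer(s, x) + 2.0 * W  # grad_u H, weight part (f-term + regularizer)
    dH_db = s + 2.0 * b               # grad_u H, bias part
    return dH_dx, dH_dW, dH_db

def backprop(x0, y, u):
    N = len(u)
    # Forward phase: x_{k+1} = grad_p H(k, x_k, u_k, p_{k+1}) = f(k, x_k, u_k).
    x = [x0]
    for k in range(N):
        x.append(f(k, x[k], u[k]))
    # Backward phase: p_N = phi'(x_N) with phi(x_N) = ||x_N - y||^2,
    # then p_k = grad_x H(k, x_k, u_k, p_{k+1}).
    p = 2.0 * (x[N] - y)
    grads = [None] * N
    for k in reversed(range(N)):
        dH_dx, dH_dW, dH_db = grad_H(k, x[k], u[k], p)
        grads[k] = (dH_dW, dH_db)     # dJ_0/du_k = grad_u H(k, x_k, u_k, p_{k+1})
        p = dH_dx                     # costate p_k
    return grads

# Example usage on random data (N = 2 layers).
rng = np.random.default_rng(0)
x_i, y_i = rng.normal(size=3), rng.normal(size=2)
u = [(rng.normal(size=(4, 3)), rng.normal(size=4)),
     (rng.normal(size=(2, 4)), rng.normal(size=2))]
grads = backprop(x_i, y_i, u)
```

Comparing `grads` against a finite-difference approximation of \(J_0\) is a convenient way to validate the hand-written derivatives of \(H\).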
Further Readings
- Li, Qianxiao, Long Chen, Cheng Tai, and Weinan E. "Maximum Principle Based Algorithms for Deep Learning." Journal of Machine Learning Research 18 (2018): 1-29.
- Li, Qianxiao, and Shuji Hao. "An Optimal Control Approach to Deep Learning and Applications to Discrete-Weight Neural Networks." International Conference on Machine Learning. PMLR, 2018.