矩阵求导总结

矩阵求导总结

​ 这篇文章是为了总结矩阵求导和反向传播推导的,

求导布局

​ 求导布局包括:分子布局或分母布局。

  • 分子布局:求导结果的维度以分子为主。分子是列向量形式,分母是行向量形式,例如:

f2×1(x)x3×1T=[f1x1f1x2f1x3f2x1f1x2f2x3]2×3\frac{\partial {\boldsymbol{f}}_{2 \times 1}( \boldsymbol{x}) }{\partial \boldsymbol{x}_{3 \times 1}^T} = {\left\lbrack \begin{array}{lll} \frac{\partial f_1}{\partial x_1} & \frac{\partial f_1}{\partial x_2} & \frac{\partial f_1}{\partial x_3} \\ \frac{\partial f_2}{\partial x_1} & \frac{\partial f_1}{\partial x_2} & \frac{\partial f_2}{\partial x_3} \end{array}\right\rbrack }_{2 \times 3}

  • 分母布局:求导结果的维度以分母为主。分子是行向量形式,分母是列向量形式

f2×1T(x)x3×1=[f1x1f2x1f1x2f2x2f1x3f2x3]3×2\frac{\partial {\boldsymbol{f}}_{2 \times 1}^{T}\left( \boldsymbol{x}\right) }{\partial {\boldsymbol{x}}_{3 \times 1}} = {\left\lbrack \begin{array}{ll} \frac{\partial f_1}{\partial {x}_1} & \frac{\partial f_2}{\partial {x}_1}\\ \frac{\partial f_1}{\partial {x}_2} & \frac{\partial f_2}{\partial {x}_2}\\ \frac{\partial f_1}{\partial {x}_3} & \frac{\partial f_2}{\partial {x}_3} \end{array}\right\rbrack }_{3 \times 2}

​ 对于一个 nn-维输出张量 Y\boldsymbol{Y} 对一个 mm-维输入张量 X\boldsymbol{X} 的导数,雅可比张量将是一个 n+mn+m 维的张量。这是因为每个输出元素 yiy_i 对所有输入元素 xjx_j 的偏导数 yixj\frac{\partial y_i}{x_j} 都会形成一个新的维度,我们研究几个在反向传播的特例情况,

由 Loss 函数到向量对矩阵求导:

​ 从严格的数学定义来看, y\boldsymbol{y}W\boldsymbol{W} 的雅可比矩阵确实是一个三维结构。因为 y\boldsymbol{y} 是一个向量(假设大小为 mm ),W\boldsymbol{W} 是一个矩阵(假设大小为 m×nm \times n )。因此雅可比矩阵形成一个三维张量,其维度为 m×m×nm \times m \times n
但是,在神经网络的反向传播中,我们并不直接使用这个完整的三维雅可比张量。我们利用了链式法则和梯度计算
的特殊性质来简化计算。我们关心的是损失函数 LL 对权重矩阵 W\boldsymbol{W} 的梯度 LW\frac{\partial L}{\partial \boldsymbol{W}}

LW=i=1mLyiyiW\frac{\partial L}{\partial \boldsymbol{W}} = \sum_{i = 1}^m\frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial \boldsymbol{W}}

其中 Lyi\frac{\partial L}{\partial y_i} 是标量,而 yiW\frac{\partial y_i}{\partial \boldsymbol{W}} 是一个 m×nm \times n 矩阵,对应于单个输出 yiy_i 对所有权重 WijW_{ij} 的偏导数。然而,由于 y=Wx+b\boldsymbol{y} = \boldsymbol{W}\boldsymbol{x} + \boldsymbol{b},我们知道

yi=k=1nWikxk+biyiWjk=xkδij={xkif i=j0elsey_i = \sum_{k=1}^n W_{ik} x_k + b_i \\ \Rightarrow \frac{\partial y_i}{\partial W_{jk}} = x_k\delta_{ij} = \left\{ \begin{array}{l} x_k \quad \text{if } i=j\\ 0 \quad \text{else}\\ \end{array} \right.

这意味着在实际计算中,yiW\frac{\partial y_i}{\partial \boldsymbol{W}} 虽然是二维的结构,但是他内部只有某一行元素不为零,且对于每一个 ii 这个行向量都是一样的(因为 ii 维度没有参与计算),意味着我们可以在计算的时候给它降维为一维:

yiWjk=xkδijLWjk=i=1mLyiyiWjk=Lyixk\frac{\partial y_i}{\partial W_{jk}} = x_k\delta_{ij} \\ \Rightarrow \frac{\partial L}{\partial W_{jk}} = \sum_{i = 1}^m \frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial W_{jk}} = \frac{\partial L}{\partial y_i}x_k

观察 LWjk\frac{\partial L}{\partial W_{jk}} 下标,jj 为行索引,kk 为列索引,Ly,x\frac{\partial L}{\partial \boldsymbol{y}}, \boldsymbol{x} 都为列向量,因此梯度矩阵可以化简为一个列向量乘一个行向量

LW=LyxT\frac{\partial L}{\partial \boldsymbol{W}} = \frac{\partial L}{\partial \boldsymbol{y}}\boldsymbol{x}^T

从此出可以看出梯度矩阵的秩是很低的,意味着在某些情况下,模型中可能存在冗余参数。同时低秩梯度矩阵可能会导致优化过程中的平坦区域(plateaus),使得梯度下降算法收敛速度变慢。这是因为当梯度矩阵接近低秩时,Hessian矩阵(二阶导数矩阵)可能会变得病态(ill-conditioned)

由于低秩矩阵的特殊结构,可以在不损失太多精度的情况下对它们进行近似处理,比如通过奇异值分解(SVD)保留主要成分。可以用来加速训练过程

记上面的输入层的输入为 x\boldsymbol{x},hidden layer 之后带有一个非线性函数 ff,hidden layer 的输出记为 b\boldsymbol{b},输入,输出及中间隐藏层的关系式为:

b=f(Vx+b0)y=Wb+b1\boldsymbol{b} = f(\boldsymbol{Vx} + \boldsymbol{b}_0) \\ \boldsymbol{y} = \boldsymbol{Wb} + \boldsymbol{b}_1 \\

其中 b0,b1\boldsymbol{b}_0, \boldsymbol{b}_1 分别为对应层的偏置项

对矩阵乘法与 einsum

​ 我们思考为什么上面的标量 yiy_iW\boldsymbol{W} 求导会使雅可比 tensor yiW\frac{\partial y_i}{\partial \boldsymbol{W}} 出现稀疏性,原因是在于矩阵乘法是对于一个维度求和进行的,因此另一个维度的元素并没有涉及,导致雅可比矩阵是二维的结构但是完全可以压缩到一维上面去。称雅可比矩阵有一个维度冗余

​ 下面我们去推导更加一般的结论,由于矩阵乘法是 einsum 的一种特例,我们直接对 einsum 总结一般性规律。我们定义基础算子为:对某一维度进行对应元素相乘并求和,那么矩阵乘法就是进行广播-基础算子-转置的结果,那么如果 einsum22 个维度进行这个基础算子运算,例如:

1
2
# A.shape:[b, i, j ,k], B.shape[j, k], C.shape:[b, i]
C = torch.einsum('b i j k, j k -> b i', A, B)

那么虽然 C\boldsymbol{C}A,B\boldsymbol{A,B} 的雅可比矩阵的维度数分别是 6,46,4,但是同理上面的矩阵乘法,A\boldsymbol{A} 的前两个维度并没有参与计算,因此 A\boldsymbol{A} 的雅可比矩阵会出现两个维度的冗余,因此计算 A\boldsymbol{A} 的时候只需要 22 维的 tensor 就行了:

LAbijk=bi(LCbiCbiAbijk)Cbi=jkAbijkBjkLAbijk= LCbiBjk \frac{\partial L}{\partial \boldsymbol{A}_{bijk}} = \sum_{bi}\left( \frac{\partial L}{\partial \boldsymbol{C}_{bi}}\frac{\partial \boldsymbol{C}_{bi}}{\partial \boldsymbol{A}_{bijk}}\right) \\ \boldsymbol{C}_{bi} = \sum_{jk} \boldsymbol{A}_{bijk}\boldsymbol{B}_{jk} \\ \Rightarrow \frac{\partial L}{\partial \boldsymbol{A}_{bijk}} = \ \frac{\partial L}{\partial \boldsymbol{C}_{bi}} \cdot \boldsymbol{B}_{jk}

因此可以推出 A\boldsymbol{A} 梯度的表达式为:

1
dL_dA = torch.einsum('b i, j k', dL_dC, B)

image-20241231233020026

记法说明:记上面的输入层的输入为 x\boldsymbol{x},hidden layer 之后带有一个非线性函数 ff,hidden layer 的输出记为 b\boldsymbol{b},输入,输出及中间隐藏层的关系式为:

b=f(Vx+b0)y=Wb+b1\boldsymbol{b} = f(\boldsymbol{Vx} + \boldsymbol{b}_0) \\ \boldsymbol{y} = \boldsymbol{Wb} + \boldsymbol{b}_1 \\

其中 b0,b1\boldsymbol{b}_0, \boldsymbol{b}_1 分别为对应层的偏置项

损失函数对 W\boldsymbol{W} 的偏导数:

LW=LyyW=LybT\frac{\partial L}{\partial \boldsymbol{W}} = \frac{\partial L}{\partial \boldsymbol{y}}\frac{\partial \boldsymbol{y}}{\partial \boldsymbol{W}} = \frac{\partial L}{\partial \boldsymbol{y}}\boldsymbol{b}^T

损失函数对 b1{\boldsymbol{b}}_1 的偏导数:

Lb1=Lyyb1=Ly\frac{\partial L}{\partial \boldsymbol{b}_1} = \frac{\partial L}{\partial \boldsymbol{y}}\frac{\partial \boldsymbol{y}}{\partial \boldsymbol{b}_1} = \frac{\partial L}{\partial \boldsymbol{y}}

  1. 对于输入层到隐藏层的连接 (V 和 b0)\left( {\boldsymbol{V}\text{ 和 }{\boldsymbol{b}}_0}\right) :
    ○ 损失函数对 b\boldsymbol{b} 的偏导数(这一步是将输出层的误差传递回隐藏层):

Lb=Lyyb=WTLy\frac{\partial L}{\partial \boldsymbol{b}} = \frac{\partial L}{\partial \boldsymbol{y}}\frac{\partial \boldsymbol{y}}{\partial \boldsymbol{b}} = {\boldsymbol{W}}^T\frac{\partial L}{\partial \boldsymbol{y}}

  • 现在我们可以计算 V\boldsymbol{V}b0{\boldsymbol{b}}_0 的偏导数了,但是首先我们需要应用链式法则来考虑激活函数 ff 的影响:

L(Vx+b0)=Lbf(Vx+b0)\frac{\partial L}{\partial \left( {\boldsymbol{V}\boldsymbol{x} + {\boldsymbol{b}}_0}\right) } = \frac{\partial L}{\partial \boldsymbol{b}} \odot {f}^{\prime }( {\boldsymbol{V}\boldsymbol{x} + {\boldsymbol{b}}_0})

  • ○ 损失函数对 V\boldsymbol{V} 的偏导数:

LV=(Lbf(Vx+b0))xT\frac{\partial L}{\partial \boldsymbol{V}} = \left( {\frac{\partial L}{\partial \boldsymbol{b}} \odot {f}^{\prime }\left( {\boldsymbol{V}\boldsymbol{x} + {\boldsymbol{b}}_0}\right) }\right) {\boldsymbol{x}}^T

损失函数对 b0{\boldsymbol{b}}_0 的偏导数:

Lb0=Lbf(Vx+b0)\frac{\partial L}{\partial {\boldsymbol{b}}_0} = \frac{\partial L}{\partial \boldsymbol{b}} \odot {f}^{\prime }\left( {\boldsymbol{V}\boldsymbol{x} + {\boldsymbol{b}}_0}\right)