前言
在笔者学习机器学习与深度的过程中,发现存在一些数学知识是会被经常用到,但是在理工科数学三大件(即高数,线代,概统)中未被提及的,例如本文的主题矩阵微分,或者叫矩阵求导。在数学分析中,我们会接触到的求导运算通常是如dx/dy一类标量对标量的求导,而矩阵微分,顾名思义就是涉及到向量与矩阵的求导。最典型的例子就是梯度下降,在梯度下降中,关键步骤就是损失函数对参数求导∂L/∂w,损失函数值是一个标量,但是参数是一个向量,因此这是一个标量对向量的求导,这就涉及到矩阵微分了。
本文将从机器学习的角度,介绍矩阵微分中机器学习中最常用的部分,并不会详细的介绍整个矩阵微分理论,且本文会着重于数学推导,而不是简单的堆砌矩阵微分的计算公式。
注1:对梯度下降不了解的可以移步本站的梯度下降blog
注2:分享一个在线计算矩阵微分的网站Matrix Calculus,可以用于验证计算结果
定义
首先,我们要了解一下矩阵微分的定义,微分运算的自变量与因变量均有可能是标量,向量以及矩阵,因此有九种可能的微分运算,如下表
自变量\因变量 |
标量 y |
向量 y |
矩阵 Y |
标量 x |
∂x∂y |
∂x∂y |
∂x∂Y |
向量 x |
∂x∂y |
∂x∂y |
∂x∂Y |
矩阵 X |
∂X∂y |
∂X∂y |
∂X∂Y |
在上述的九种微分运算中,本文会关注标量对向量,矩阵的微分运算,以及向量,矩阵对标量的微分运算,也就是上表的第一行与第一列,因为这四种微分运算在机器学习与深度学习中最常被使用。
现在,让我们进入正题,讲讲我们关注的四种矩阵微分的定义,其实非常简单,首先是标量对向量,矩阵的微分运算,以向量为例,记
x=(x1,x2,⋯,xm)T
则标量y对向量x求导为
∂x∂y=(∂x1∂y,∂x2∂y,⋯,∂xm∂y)T
其实就是标量y对向量x中的每一个分量分别求导,然后再拼回一个向量,对矩阵求导则同理,即标量y对矩阵X中的每一个分量分别求导,然后再拼回一个矩阵。
不仅如此,向量,矩阵对标量的微分运算也是类似的,还是以向量为例,记
y=(y1,y2,⋯,yn)T
则向量y对标量x求导为
∂x∂y=(∂x∂y1,∂x∂y2,⋯,∂x∂yn)T
就是向量y的每一个分量分别对标量x求导,然后再拼回一个向量,矩阵同理。
分子布局与分母布局
说完了定义,我们还需要补充一个很重要的内容,矩阵微分的布局,其用于规定矩阵微分计算结果的形状,防止出现混乱。矩阵微分的布局一共有两种,分别为分子布局与分母布局,其定义也很简单,分子布局代表计算结果的形状以分子的形状为主,而分母布局代表计算结果的形状以分母的形状为主,且分子布局与分母布局之间互为转置。以标量y对向量x求导为例,向量x的形状为m×1,且向量x的分母,因此在分母布局下,∂y/∂x的形状为m×1,而在分子布局下,只需要进行转置即可,即∂y/∂x的形状为1×m。
对于本文关注的四种矩阵微分,其两种布局下的形状如下表
自变量\因变量 |
标量 y |
向量 yn×1 |
矩阵 Yn×m |
标量 x |
|
∂x∂y 分子布局:n×1 分母布局:1×n |
∂x∂Y 分子布局:n×m 分母布局:m×n |
向量 xm×1 |
∂x∂y 分子布局:1×m 分母布局:m×1 |
|
|
矩阵 Xm×n |
∂X∂y 分子布局:n×m 分母布局:m×n |
|
|
注:布局相关的内容主要还是针对本文不涉及的四种矩阵微分运算的,因为其分子分母均为向量及矩阵,更加需要布局的规范
现在,我们了解了布局,那么在阅读机器学习资料的过程中,我们只要留意资料中的说明,明白资料中使用的是何种布局,就不会产生混乱了。但是,事实上有很多的资料都不会标注啦,这大概是因为大家多数时候都会约定俗成的使用一种“默认布局”(其实更可能是因为除了学数学的之外,大家都没有这种严谨的习惯)。
那么所谓的“默认布局”是分子布局还是分母布局呢?答案并没有这么简单啦,对于本文未提及的四种矩阵微分运算来说,“默认布局”就是分子布局,但是对于本文关注的四种矩阵微分运算来说,“默认布局”是一种被称为混合布局的布局方式。混合布局其实非常简单,让我们忘掉上面说的东西,现在有一个形状为m×1的向量x,那么∂y/∂x的形状是什么?大多数人的直觉肯定会觉得,既然向量x的形状为m×1,y又是一个标量,那∂y/∂x的形状就按照向量x的形状来,是m×1。同样的,有一个形状为n×1的向量y,那么∂y/∂x的形状是什么?按照上面的逻辑,应该是n×1。这时候我们会发现,在这种按直觉进行的布局中,对于∂y/∂x,我们使用了分母布局,而对于∂y/∂x,我们使用了分子布局,这就是所谓的混合布局,也是所谓的“默认布局”。
定义法
接着让我们回归正题,既然讲完了定义,那么随着而来的自然是矩阵微分的定义法,从定义中我们可以知道,矩阵微分可以被拆分为多个标量求导运算,标量求导的运算我们都会算,因此把矩阵微分按照定义,拆分成多个标量求导运算,就是定义法。当然,这种方法非常的简单粗暴,因此其主要的用途是用于推导一些矩阵微分的计算公式,例如,y=aTx,其中a与x的形状都为m×1,计算∂y/∂x,根据定义法,将该矩阵微分拆解为y对x每一个分量分别求导,将y拆开为
y=a1x1+a2x2+⋯+amxm=i=1∑maixi
y对xi求导为
∂xi∂y=ai
故
∂x∂y=(∂x1∂y,∂x2∂y,⋯,∂xm∂y)T=(a1,a2,⋯,am)T=a
因此我们得到了一个矩阵微分的求导公式
∂x∂aTx=a
多数矩阵微分的求导公式都可以通过定义法进行推导。
微分法
介绍完矩阵微分的定义以及定义法,对于一些简单的矩阵微分,大家已经可以进行计算了,但是,这显然不够优雅,定义法本质上还是在算标量微分,我们应该需要一种在矩阵层面进行操作的方法,也就是本章节要介绍的微分法,这也是本文的绝对核心方法。
矩阵形式
在正式开始介绍微分法之前,我想先做一些相关补充,以尽可能降低读者的理解门槛,由于接下来我们将从向量与矩阵的角度来进行操作,因此,会涉及到大量的矩阵形式,这个名字是我自己随便取的,大概意思就是会将标量内容写作向量或矩阵的形式,例如一个多元函数
y(x1,x2)=a1x1+a2x2
写作向量形式就是
y(x1,x2)=a1x1+a2x2=(a1,a2)(x1x2)=aTx
再例如
y(x1,x2,⋯,xm)=x12+x22+⋯+xm2=(x1,x2,⋯,xm)⎝⎜⎜⎜⎛x1x2⋯xm⎠⎟⎟⎟⎞=xTx
这都是常见的矩阵形式的例子。矩阵形式并不是什么复杂的内容,且在后文中会大量使用,但是笔者发现鲜有教程提及,因此在此做为补充,希望对读者有所帮助。
注:如果读者对于矩阵形式不甚熟悉,我能给予的建议就是自己动手进行推导验证,包括后文的一些公式推导也是如此,主要的目的是让自己熟悉矩阵形式。其实学数学就是这样,对于一个陌生的概念与定理,最好的理解方法就是动手去算一些例子,去自己推导一遍定理,这一点数学系出身的笔者深有体会。
微分法的数学推导
微分
现在,让我们进入正题,所谓微分法,自然跟微分有关系,我知道,微分这个概念在数学分析中存在跟很低,大部分人都只记得导数了,所以这里我们简单复习一下,从一元函数开始,定义如下
df=f′(x)dx
到了多元函数,情况会更复杂一点,其定义为
df=∂x1∂fdx1+∂x2∂fdx2+⋯+∂xm∂fdxm
上面的定义也被称为全微分,如果修过高数或者数分的读者,全微分期末考试应该是必考的。
标量对向量及矩阵的微分运算
现在,让我们结合一下上一小节提到的矩阵形式,将全微分写为矩阵形式,则为
df=∂x1∂fdx1+∂x2∂fdx2+⋯+∂xm∂fdxm=(∂x1∂f,∂x2∂f,⋯,∂xm∂f)⎝⎜⎜⎜⎛dx1dx2⋯dxm⎠⎟⎟⎟⎞=(∂x∂f)Tdx
至此,我们就得到了微分法的核心依赖公式,即
df=(∂x∂f)Tdx
这个公式说明,只要求出因变量的微分df,就可以得到矩阵微分∂f/∂x,例如我们计算得到df=Adx,则∂f/∂x=AT。
再进一步,如果自变量为矩阵X,则
df=∂x11∂fdx11+∂x12∂fdx12+⋯+∂xnm∂fdxnm=tr⎣⎢⎢⎢⎢⎡⎝⎜⎜⎜⎜⎛∂x11∂f∂x12∂f⋯∂x1m∂f∂x21∂f∂x22∂f⋯∂x2m∂f⋯⋯⋯⋯∂xn1∂f∂xn2∂f⋯∂xnm∂f⎠⎟⎟⎟⎟⎞⎝⎜⎜⎜⎛dx11dx21⋯dxn1dx12dx22⋯dxn2⋯⋯⋯⋯dx1mdx2m⋯dxnm⎠⎟⎟⎟⎞⎦⎥⎥⎥⎥⎤=tr[(∂X∂f)TdX]
至此,我们也得到了矩阵情况下的微分法核心依赖公式,即
df=tr[(∂X∂f)TdX]
注1:上述公式中的tr代表矩阵的迹(trace),其定义为矩阵对角线元素的和
注2:上述矩阵情况下的公式读者可以自行验证,也是一种常见的矩阵形式写法
向量及矩阵函数对标量的微分运算
而对于向量及矩阵函数对标量的微分运算,情况则更加简单,因为自变量是一个标量,因此其等价于一元函数的微分,即
df=∂x∂fdx,dF=∂x∂Fdx
微分法
综上,我们可以得到使用微分法计算矩阵微分的步骤
- 计算因变量f,f或F的微分
- 根据微分与导数的关系,得到导数
四种矩阵微分运算与其导数的关系式如下表
自变量\因变量 |
标量 y |
向量 y |
矩阵 Y |
标量 x |
|
dy=∂x∂ydx |
dY=∂x∂Ydx |
向量 x |
dy=(∂x∂y)Tdx |
|
|
矩阵 X |
dy=tr[(∂X∂y)TdX] |
|
|
微分运算法则
既然微分法的核心是计算微分,那我们就需要先学习一下矩阵微分的运算法则
- 加减法:d(X±Y)=dX±dY
- 矩阵乘法:d(XY)=dXY+XdY
- 转置:d(XT)=(dX)T
- 逆:d(X−1)=−X−1dXX−1
- 行列式:d∣X∣=tr(X∗dX),X∗表示X的伴随矩阵,若X可逆,则d∣X∣=tr(X−1dX)
- 迹:d(tr(X))=tr(dX)
- Hadamard积:d(X⊙Y)=dX⊙Y+X⊙dY,Hadamard表示两个形状相同的向量或矩阵逐元素相乘
- 逐元素函数:dσ(X)=σ′(X)⊙dX
这里单独说一下逐元素函数,对于大多数标量函数来说,例如指数函数ex,将其输入改为向量或矩阵,意为使用该函数对向量或矩阵的每一个分量进行函数运算,即σ(X)=[σ(Xij)],对这种求微分,就可以使用上述公式,以指数函数ex为例
deX=eX⊙dX
迹的运算法则与迹技巧
然后,由于在标量对矩阵求导的情况中,会涉及到矩阵的迹运算,因此,我们还需要复习一下矩阵的迹的运算法则以及一些常用的公式,一般称之为迹技巧(trace trick)
- 加减法:tr(A±B)=tr(A)±tr(B)
- 转置:tr(AT)=tr(A)
- 矩阵乘法交换:tr(AB)=tr(BA),A与BT的形状相同
- 矩阵乘法复合Hadamard积:tr(AT(B⊙C))=tr((A⊙B)TC)
并且,由于标量套上迹之后没有任何变化,因此,即使在不涉及迹的情况下,也经常会出现给标量套上迹,然后使用迹技巧的做法。
实例
经过了无聊的堆公式环节,让我们正式进入实例环节
例1:y=aTXb,a的形状为n×1,X的形状为n×m,b的形状为m×1
dy=aTdXb=tr(aTdXb)=tr(baTdX)
故
∂X∂y=(baT)T=abT
注:上述计算过程使用了tr(AB)=tr(BA),把aTdX看作A,b看作B,这是该公式的常见用法
例2:f=aTe(Xb),a的形状为n×1,X的形状为n×m,b的形状为m×1
dy=aTd(e(Xb))=aT(e(Xb)⊙d(Xb))=tr[aT(e(Xb)⊙d(Xb))]=tr[(a⊙e(Xb))Td(Xb)]=tr[(a⊙e(Xb))TdXb]=tr[b(a⊙e(Xb))TdX]
故
∂X∂y=(b(a⊙e(Xb))T)T=(a⊙e(Xb))bT
链式法则
接着,我们来介绍矩阵微分中的链式法则,这也是矩阵微分计算中非常重要的技巧。链式法则用于计算复合函数的导数,简单来说,假设有两个函数y=f(x)以及x=h(t),可以通过计算∂y/∂x以及∂x/∂t来算得∂y/∂t。对于标量的情况,只需要简单进行相乘即可,即
∂t∂y=∂x∂y∂t∂x
不过矩阵微分的情况要稍微复杂一点,这里我们直接通过一个例子来说明
例:将上文例1进行简单的修改,依旧是y=aTXb,再定义X=tcT,其中c的形状为m×1,t的形状为n×1,计算∂y/∂t
首先,我们需要计算∂y/∂X,将上文的计算结果拿过来,即
dy=tr(baTdX)
保留微分形式是因为后续我们还需要在微分形式上做运算,接下来我们来计算dX
dX=dtcT
将其代入上述dy的计算式中
dy=tr(baTdX)=tr(baTdtcT)=tr(cTbaTdt)
故
∂t∂y=(cTbaT)T=abTc
如果有多层嵌套也是一样的,假设上述t之后还有一层复合,那么就继续计算dt,然后代入。
矩阵微分在机器学习中的应用
最后,我们来讲一些机器学习中的经典例子
例1(线性回归):损失函数L=∣∣y−Xw∣∣2,计算∂L/∂w
计算损失函数L的微分
dL=d∣∣y−Xw∣∣2=d(y−Xw)T(y−Xw)=d(y−Xw)T(y−Xw)+(y−Xw)Td(y−Xw)=(−Xdw)T(y−Xw)+(y−Xw)T(−Xdw)=−tr[(Xdw)T(y−Xw)]−tr[(y−Xw)TXdw]=−2tr[(y−Xw)TXdw]=−2(y−Xw)TXdw=2(Xw−y)TXdw
故
∂w∂L=2XT(Xw−y)
例2(logistic回归):损失函数L(w)=−yTlny^−(1−y)Tln(1−y^),其中y^=σ(Xw),sigmoid函数σ(x)=1/1+e−x,计算∂L/∂w
由于损失函数比较复杂,使用链式法则进行逐层运算,首先计算∂L/∂y^
dL=−yTdlny^−(1−y)Tdln(1−y^)=−yT(y^1⊙dy^)−(1−y)T(1−y^1⊙dy^)=(1−y)T(1−y^1⊙dy^)−yT(y^1⊙dy^)=tr[(1−y)T(1−y^1⊙dy^)]−tr[yT(y^1⊙dy^)]
tr[(1−y)T(1−y^1⊙dy^)]=tr[((1−y)⊙1−y^1)Tdy^]
tr[yT(y^1⊙dy^)]=tr[(y⊙y^1)Tdy^]
dL=tr[((1−y)⊙1−y^1)Tdy^]−tr[(y⊙y^1)Tdy^]=[((1−y)⊙1−y^1)T−(y⊙y^1)T]dy^
接着计算dy^
dy^=dσ(Xw)=σ′(Xw)⊙d(Xw)=σ′(Xw)⊙(Xdw)
其中
σ′(Xw)=σ(Xw)(1−σ(Xw))=y^(1−y^)
将其代入上式
dL=[((1−y)⊙1−y^1)T−(y⊙y^1)T]dy^=[((1−y)⊙1−y^1)T−(y⊙y^1)T]σ′(Xw)⊙(Xdw)=tr[[((1−y)⊙1−y^1)−(y⊙y^1)]Tσ′(Xw)⊙(Xdw)]=tr[[[((1−y)⊙1−y^1)−(y⊙y^1)]⊙σ′(Xw)]TXdw]=tr[[((1−y)⊙y^)−(y⊙(1−y^))]TXdw]=[((1−y)⊙y^)−(y⊙(1−y^))]TXdw
故
∂w∂L=XT[((1−y)⊙y^)−(y⊙(1−y^))]
例3(softmax回归):损失函数为L(W)=−tr(YTlnY^),其中Y^=softmax(XW),计算∂L/∂w
由于在数学上,多样本的softmax函数写起来比较复杂,并且也不利于后续操作,因此,这里我们可以选择曲线救国,先计算单样本情况下的导数,再使用定义法扩展至多样本,会简单不少。在单样本的情况下,损失函数为L(w)=−yTlny^,其中y^=softmax(Xw),softmax函数为softmax(x)=ex/1Tex。
定义z=Xw,并将y^代入,因为softmax函数中包含指数运算,代入之后配合对数运算会好处理很多
L(w)=−yTlny^=−yTln1Tezez=−yTz+yT1ln(1Tez)=−yTz+ln(1Tez)
注意到ln(a/b)=lna−1lnb,a为向量,b为标量,因此需要使用全一向量进行维度扩展,且yT1=1,因为y是只有一个分量为1,剩余分量均为0的向量。
再计算微分
dL=−yTdz+dln(1Tez)=−yTdz+(1Tez1⊙d1Tez)=−yTdz+[1Tez1⊙(1T(ez⊙dz))]=−yTdz+1Tez(ez)Tdz=(1Tez(ez)T−yT)dz=(softmax(z)T−yT)dz
上述运算中,有两个需要提及的点,第一,由于1Tez是标量,所以逐元素相乘相当于直接相乘,第二,上述计算中使用了公式
1T(a⊙b)=aTb⇒1T(ez⊙dz)=(ez)Tdz
综上
∂z∂L=softmax(z)−y
扩展到多样本为
∂Z∂L=softmax(Z)−Y
其中Z=XW,计算dZ=XdW,故
dL=(softmax(Z)T−YT)XdW
∂W∂L=XT(softmax(Z)−Y)=XT(Y^−Y)