LSTM参数更新推导

本文转自:https://zybuluo.com/hanbingtao/note/581764,对其进行一定的整理。

LSTM前向计算

Understanding LSTM Networks一文中,介绍了LSTM的基本原理。LSTM网络使用了门(gete)的概念,门实际上就是一层全连接,它的输入是一个向量,输出是一个0到1之间的实数向量。假设W是门的权重向量,b是偏置项,那么门可以表示为:$$
g(x) =\sigma(Wx+b)
$$
门的使用,就是用门的输出向量按元素乘以要控制的向量。因为门的输出是0到1之间的实数向量。所以,当门输出为0时,任何向量与之相乘都会得到0向量,这就相当于不能通过;当输出为1时,任何向量与之相乘都不会有任何改变,相当都通过。因为$\sigma$函数的值域是(0,1),所以门的状态都是半开半闭的。

典型的LSTM的网络架构图如下图所示:


mark

相比RNN网络,LSTM新增加状态C,称为单元状态(cell state),如下图所示:


mark

图引自:https://zybuluo.com/hanbingtao/note/581764


从图中可以看出,LSTM的输入有三个:当前时刻网络的输入$x_t$、上一时刻LSTM的输出值$h_{t-1}$以及上一时刻的单元状态$c_{t-1}$。LSTM的输出有两个:当前时刻LSTM输出值$h_t$和当前时刻的单元状态$c_t$。在这里x、h、c都是向量。

LSTM中引入了三个门:遗忘门(forget gate)、输入门(input gate)和输出门(output gate)

  • 遗忘门:它决定了上一时刻的单元状态$c_{t-1}$有多少保留到当前时刻$c_t$;
  • 输入门:它决定了当前时刻网络的输入$x_t$有多少保留到单元状态$c_t$。
  • 输出门:来控制单元状态$c_t$有多少输出到LSTM的当前输出值$h_t$。

遗忘门计算
$$f_t = \sigma(W_f\cdot[h_{t-1},x_t]+b_f) (公式1)$$
其中,$W_f$是遗忘门的权重矩阵,$[h_{t-1},x_t]$是表示把两个向量连接成一个更长的向量,$b_f$是遗忘门的偏置项,$\sigma$是sigmoid函数。事实上权重矩阵$W_f$是由两个矩阵拼接而成的,一个是$W_{fh}$,它对应着输入项$h_{t-1}$,一个是$W_{fx}$。$W_f$可以写为:$$
[W_f]\begin{bmatrix}h_{t-1}\\\\x_t\end{bmatrix}=\begin{bmatrix}
W_{fh}&W_{fx}
\end{bmatrix}\begin{bmatrix}h_{t-1}\\\\x_t\end{bmatrix}=W_{fh}h_{t-1}+W_{fx}x_t
$$

输出门计算
$$i_t=\sigma(W_i\cdot[h_{t-1},x_t]+b_i) (公式2)$$

单元状态$\hat{c_t}$
$$\hat{c_t}=tanh(W_c\cdot[h_{t-1},x_t]+b_c) (公式3)$$
单元状态$c_t$,它是有上一时刻的单元状态$c_{t-1}$按元素乘以遗忘门$f_t$,再用当前输入的单元状态$\hat{c_t}$按元素乘以输入门$i_t$,再将两个积加和产生的:
$$c_t = f_t \bigodot c_{t-1}+i_t \bigodot \hat{c_t} (公式4)$$
其中$\bigodot$表示按位相乘。这样就把LSTM关于当前的记忆$\hat{c_t}$和长期的记忆$c_{t-1}$组合在一起,形成了新的状态单元$c_t$。由于遗忘门的控制,它可以保存很久很久之前的信息,由于输入门的控制,又可以避免当前无关紧要的内容进入记忆。下面看一下输出门,它控制了长期记忆对当前输出的影响:
输出门
$$o_t=\sigma(W_o\cdot[h_{t-1},x_t]+b_o) (公式5)$$

LSTM最终的输出,是由输出门和单元状态共同确定的:
$$h_t = o_t \bigodot tanh(c_t) (公式6)$$
从公式1到公式6就是LSTM的前向计算的全部公式。
LSTM前向传播的更新过程如下:

  1. 更新遗忘门的输出:$$f_t = \sigma(W_f\cdot[h_{t-1},x_t]+b_f)$$
  2. 更新输入门的两部分输出:$$i_t=\sigma(W_i\cdot[h_{t-1},x_t]+b_i)\\\\\hat{c_t}=tanh(W_c\cdot[h_{t-1},x_t]+b_c)$$
  3. 更新细胞状态:$$c_t = f_t \bigodot c_{t-1}+i_t \bigodot \hat{c_t}$$
  4. 更新输出门状态:$$o_t=\sigma(W_o\cdot[h_{t-1},x_t]+b_o)\\\\h_t = o_t \bigodot tanh(c_t)$$
  5. 更新当前序列输出:$$\hat {y_t}=\sigma(Vh_t+b_y)$$

LSTM反向传播

LSTM的训练算法,仍然是反向传播算法,主要有三个步骤:

  1. 前向计算每个神经元的输出值,对于LSTM来说,即$f_t、i_t、c_t、o_t、h_t$ 5组向量。
  2. 反向计算每个神经元的误差项$\delta$。与循环神网络一样,LSTM误差项的反向传播也是包括两个方向:一个是沿时间的反向传播,即从当前时刻t开始,计算每个时刻的误差项;一个是将误差项向上一层传播。
  3. 根据相应的误差项,计算每个权重的梯度。

LSTM需要学习的参数共有8组,分别是:遗忘门的权重矩阵$W_f$和偏置项$b_f$、输入门的权重矩阵$W_i$和偏置项$b_i$、输出门的权重矩阵$W_o$和偏置项$b_o$以及计算单元状态的权重矩阵$W_c$和偏置项$b_c$。因为权重矩阵的两部分在反向传播中使用不同的公式,因此,权重矩阵$W_f$、$W_i$、$W_c$、$W_o$都将被写为分开的两个矩阵:$W_{fh}$、$W_{fx}$、$W_{ih}$、$W_{ix}$、$W_{ch}$、$W_{cx}$、$W_{oh}$、$W_{ox}$。

在t时刻,LSTM的输出值为$h_t$,定义t时刻的误差项为$\delta_t$为:$$
\delta_t = \frac {\partial E}{\partial h_t}$$

因为LSTM有四个加权输入,分别为$f_t、i_t、c_t、o_t、h_t$,定义这四个加权输入,以及他们对应的误差项。$$
net_{f,t} = W_f[h_{t-1},x_t]+b_f=W_{fh}h_{t-1}+W_{fx}x_t+b_f\\\\
net_{i,t} = W_i[h_{t-1},x_t]+b_i=W_{ih}h_{t-1}+W_{ix}x_t+b_i\\\\
net_{\hat{c},t} = W_c[h_{t-1},x_t]+b_f=W_{ch}h_{t-1}+W_{cx}x_t+b_c\\\\
net_{o,t} = W_o[h_{t-1},x_t]+b_o=W_{oh}h_{t-1}+W_{ox}x_t+b_o\\\\
\delta_{f,t}=\frac {\partial E}{\partial net_{f,t}}\\\\
\delta_{i,t}=\frac {\partial E}{\partial net_{i,t}}\\\\
\delta_{\hat{c},t}=\frac {\partial E}{\partial net_{\hat{c},t}}\\\\
\delta_{o,t}=\frac {\partial E}{\partial net_{o,t}}
$$

误差项沿时间反向传递

沿时间反向传递误差项,就是要计算出t-1时刻的误差项$\delta_{t-1}$。
$$
\delta_{t-1}^T=\frac {\partial E}{\partial h_{t-1}}=\frac {\partial E}{\partial h_{t}}\frac {\partial h_t}{\partial h_{t-1}}=\delta_t^T\frac {\partial h_t}{\partial h_{t-1}}
$$
其中,$\frac {\partial h_t}{\partial h_{t-1}}$是一个Jacobian矩阵。如果隐藏层h的维度是N的话,那么它就是一个N*N的矩阵。为了求出它,先列出$h_t$的计算公式,即公式6和公式4:
$$
h_t=o_t\bigodot tanh(c_t)\\\\c_t=f_t \bigodot c_{t-1}+i_t \bigodot \hat {c_t}
$$
可以看出,$o_t、f_t、i_t、\hat{c_t}$都是$h_{t-1}$的函数,那么利用全导数公式可得:$$
\delta_t^T\frac {\partial h_t}{\partial h_{t-1}}=\delta_t^T\frac {\partial h_t}{\partial o_t}\frac {\partial o_t}{\partial net_{o,t}}\frac {\partial net_{o,t}}{\partial h_{t-1}}+\delta_t^T\frac {\partial h_t}{\partial c_t}\frac {\partial c_t}{\partial f_t}\frac{\partial f_t}{\partial net_{f,t}}\frac {\partial net_{f,t}}{\partial h_{t-1}}+\delta_t^T\frac{\partial h_t}{\partial c_t}\frac {\partial c_t}{\partial i_t}\frac {\partial i_t}{\partial net_{i,t}}\frac {\partial net_{i,j}}{\partial h_{t-1}}\\\\=\delta_{o,t}^T\frac {\partial net_{o,t}}{\partial h_{t-1}}+\delta_{f,t}^T\frac {\partial net_{f,t}}{\partial h_{t-1}}+\delta_{i,t}^T\frac {\partial net_{i,t}}{\partial h_{t-1}}+\delta_{\hat{c_t},t}^T\frac{\partial net_{\hat{c_t},t}}{\partial h_{t-1}}\\\(公式7)
$$
下面要把公式7中的每个偏导数都求出来,根据公式6,我们可以求出:$$
\frac {\partial h_t}{\partial o_t}= diag[tanh(c_t)]\\\\
\frac {\partial h_t}{\partial c_t}= diag[o_t\bigodot(1-tanh(c_t)^2)]
$$
根据公式4,可以求出:$$
\frac {\partial c_t}{\partial f_t} = diag[c_{t-1}]\\\\
\frac {\partial c_t}{\partial i_t} = diag[\hat{c_t}]\\\\
\frac {\partial c_t}{\partial \hat{c_t}} =diag[i_t]
$$
因为:$$\begin{aligned}
&o_t = \sigma(net_{o,t})\\\\
&net_{o,t} = W_{oh}h_{t-1}+W_{ox}x_t+b_o\\\
&f_t = \sigma(net_{f,t})\\\
&net_{f,t} = W_{ft}h_{t-1}+W_{fx}x_t+b_f\\\
&i_t = \sigma(net_{i,t})\\\\
&net_{i,t}=W_{ih}h_{t-1}+W_{ix}x_t+b_i\\\
&\hat{c_t}=tanh(net_{\hat{c},t})\\\\
&net_{\hat{c},t}=W_{ch}h_{t-1}+W_{cx}x_t+b_c
\end{aligned}
$$
很容易得出:
$$\begin{aligned}
&\frac {\partial o_t}{\partial net_{o,t}}= diag[o_t\bigodot(1-o_t)]\\\\
&\frac {\partial_{o,t}}{\partial h_{t-1}}=W_{oh}\\\\
&\frac {\partial f_t}{\partial net_{f,t}}=diag[f_t\bigodot(1-f_t)]\\\\
&\frac {\partial net_{f,t}}{\partial h_{t-1}}=W_{fh}\\\\
&\frac {\partial i_t}{\partial net_{i,j}}=diag[i_t\bigodot(1-i_t)]\\\\
&\frac {\partial net_{i,t}}{\partial h_{t-1}}=W_{ih}\\\\
&\frac {\partial \hat{c_t}}{\partial net_{\hat{c},t}}=diag[1-\hat{c_t}^2]\\\\
&\frac {\partial net_{\hat{c},t}}{\partial h_{t-1}}=W_{ch}
\end{aligned}
$$
将上述偏导数带入公式7,可以得到:$$
\delta_{t-1} = \delta_{o,t}^T\frac {\partial net_{o,t}}{\partial h_{t-1}}+\delta_{f,t}^T\frac {\partial net_{f,t}}{\partial h_{t-1}}+\delta_{i,t}^T\frac {\partial net_{i,t}}{\partial h_{t-1}}+\delta_{\hat{c},t}^T\frac {\partial net_{\hat{c},t}}{\partial h_{t-1}}=\delta_{o,t}^TW_{oh}+\delta_{f,t}^TW_{fh}+\delta_{i,t}^TW_{i,h}+\delta_{\hat{c},t}^TW_{ch} (公式8)$$

根据$\delta_{o,t}、\delta_{f,t}、\delta_{i,t}、\delta_{\hat{c},t}$的定义,可知:
$$
\begin{aligned}
&\delta_{o,t}^T = \delta_t^T\bigodot tanh(c_t)\bigodot o_t\bigodot(1-o_t)(公式9)\\\\
&\delta_{f,t}^T =\delta_t^T \bigodot o_t\bigodot(1-tanh(c_t)^2)\bigodot c_{t-1}\bigodot f_t \bigodot (1-f_t)(公式10)\\\\
&\delta_{i,t}^T = \delta_t^T \bigodot o_t \bigodot(1-tanh(c_t)^2)\bigodot \hat{c_t}\bigodot i_t \bigodot (1-i_t)(公式11)\\\\
&\delta_{\hat{c},t}^T=\delta_t^T\bigodot o_t\bigodot(1-tanh(c_t)^2)\bigodot i_t\bigodot (1-\hat{c}^2)(公式12)
\end{aligned}
$$
公式8到公式12就是将误差沿时间反向传播的一个时刻的公式。有了它,我们可以写出将误差向前传递到任意时刻k的公式:$$
\delta_k^T = \prod_{j=k}^t-1\delta_{o,j}^TW_{oh}+\delta_{f,j}^TW_{fh}+\delta_{i,j}^TW_{ih}+\delta_{\hat{c},j}^TW_{ch}
$$

将误差传递到上一层

假设当前层为第l层,定义第l-1层的误差项是误差函数l-1层加权输入的导数,即:$$
\delta_t^{l-1}=\frac {\partial E}{\partial net_t^{l-1}}
$$
LSTM的输入$x_t$由下面的公式计算:
$$
x_t^l = f^{l-1}(net_t^{l-1})
$$
上式中,$f^{l-1}$表示第l-1层的激活函数。
因为$net_{f,t}^l、net_{i,t}^l、net_{\hat{c},t}^l、net_{o,t}^l$都是$x_t$的函数,$x_t$又是$net_t^{l-1}$的函数,因此,要求出E对$net_t^{l-1}$的导数,就需要使用全导数公式:
$$\begin{aligned}
\frac {\partial E}{\partial net_t^{l-1}}
&=\frac {\partial E}{\partial net_{f,t}^l}\frac {\partial net_{f,t}^l}{\partial x_t^l}\frac {\partial x_t^l}{\partial net_t^{l-1}}+\frac {\partial E}{\partial net_{i,t}^l}\frac{\partial net_{i,t}^l}{\partial x_t^l}\frac {\partial x_t^l}{\partial net_{i,t}^{l-1}}+\frac {\partial E}{\partial net_{\hat{c},t}^l}\frac{\partial net_{\hat{c},t}^l}{\partial x_t^l}\frac {\partial x_t^l}{\partial net_t^{l-1}}+\frac{\partial E}{\partial net_{o,t}^l} \frac{\partial net_{o,t}^l}{\partial x_t^l}\frac{\partial x_t^l}{\partial net_{o,t}^{l-1}}\\\\
&=\delta_{f,t}^TW_{fx}\bigodot f^{\prime}(net_t^{l-1})+\delta_{i,t}^TW_{ix}\bigodot f^{\prime}(net_t^{l-1})+\delta_{\hat{c},t}^TW_{cx}\bigodot f^{\prime}(net_t^{l-1})+\delta_{o,t}^TW_{ox}\bigodot f^{\prime}(net_t^{l-1})\\\\&=
(\delta_{f,t}^TW_{fx}+\delta_{i,t}^TW_{ix}+\delta_{\hat{c},t}^TW_{cx}+\delta_{o,t}^TW_{ox})\bigodot f^{\prime}(net_l^{l-1})
\end{aligned}(公式14)$$

公式14就是将误差传递到上一层的公式。

权重梯度计算

对度$W_{fh}、W_{ih}、W_{ch}、W_{oh}$的权重梯度,我们知道我们知道它的梯度是各个时刻梯度之和,我们首先求出它们在t时刻的梯度,然后再求出他们最终的梯度。我们已经求得误差项$\delta_{o,t}、\delta_{f,t}、\delta_{i,t}、\delta_{\hat{c},t}$很容易求出t时刻$W_{fh}、W_{ih}、W_{ch}、W_{oh}$的梯度。$$
\begin{aligned}
&\frac{\partial E}{\partial W_{oh,t}}=\frac{\partial E}{\partial net_{o,t}}\frac{\partial net_{o,t}}{\partial W_{oh,t}}=\delta_{o,t}h_{t-1}^T\\\\
&\frac{\partial E}{\partial W_{fh,t}}=\frac{\partial E}{\partial net_{f,t}}\frac{\partial net_{f,t}}{\partial W_{fh,t}}=\delta_{f,t}h_{t-1}^T\\\\
&\frac{\partial E}{\partial W_{ih,t}}=\frac{\partial E}{\partial net_{i,t}}\frac{\partial net_{i,t}}{\partial W_{ih,t}}=\delta_{i,t}h_{t-1}^T\\\\
&\frac{\partial E}{\partial W_{ch,t}}=\frac{\partial E}{\partial net_{\hat{c},t}}\frac{\partial net_{\hat{c},t}}{\partial W_{ch,t}}=\delta_{\hat{c},t}h_{t-1}^T
\end{aligned}
$$
将各个时刻的梯度加在一起,就能得到最终的梯度:$$
\begin{aligned}
&\frac{\partial E}{\partial W_{oh}}=\sum_{j=1}^t\delta_{o,j}h_{j-1}^T\\\\
&\frac{\partial E}{\partial W_{fh}}=\sum_{j=1}^t\delta_{f,j}h_{j-1}^T\\\\
&\frac{\partial E}{\partial W_{ih}}=\sum_{j=1}^t\delta_{i,j}h_{j-1}^T\\\\
&\frac{\partial E}{\partial W_{ch}}=\sum_{j=1}^t\delta_{\hat{c},j}h_{j-1}^T\\\\
\end{aligned}
$$
对于偏置项$b_f、b_i、b_c、b_o$的梯度,也是将各个时刻的梯度加在一起,下面是各个时刻的偏置梯度:
$$\begin{aligned}
&\frac{\partial E}{\partial b_{o,t}}=\frac {\partial E}{\partial net_{o,t}}\frac{\partial net_{o,t}}{\partial b_{o,t}}=\delta_{o,t}\\\\
&\frac{\partial E}{\partial b_{f,t}}=\frac {\partial E}{\partial net_{f,t}}\frac{\partial net_{f,t}}{\partial b_{f,t}}=\delta_{f,t}\\\\
&\frac{\partial E}{\partial b_{i,t}}=\frac {\partial E}{\partial net_{i,t}}\frac{\partial net_{i,t}}{\partial b_{i,t}}=\delta_{i,t}\\\\
&\frac{\partial E}{\partial b_{c,t}}=\frac {\partial E}{\partial net_{\hat{c},t}}\frac{\partial net_{\hat{c},t}}{\partial b_{c,t}}=\delta_{\hat{c},t}
\end{aligned}
$$
下面是最终的偏置项的梯度,即将各个时刻的偏置项梯度加在一起:
$$
\begin{aligned}
&\frac {\partial E}{\partial b_o}=\sum_{j=1}^t\delta_{o,j}\\\\
&\frac {\partial E}{\partial b_i}=\sum_{j=1}^t\delta_{i,j}\\\\
&\frac {\partial E}{\partial b_f}=\sum_{j=1}^t\delta_{f,j}\\\\
&\frac {\partial E}{\partial b_c}=\sum_{j=1}^t\delta_{\hat{c},j}
\end{aligned}
$$
对于$W_{fx}、W_{ix}、W_{cx}、W_{ox}$的权重梯度,只需要根据相应的误差项直接计算即可:
$$\begin{aligned}
&\frac {\partial E}{\partial W_{ox}}=\frac{\partial E}{\partial net_{o,t}}\frac{\partial net_{o,t}}{\partial W_{ox}}=\delta_{o,t}x_t^T\\\\
&\frac {\partial E}{\partial W_{fx}}=\frac{\partial E}{\partial net_{f,t}}\frac{\partial net_{f,t}}{\partial W_{fx}}=\delta_{f,t}x_t^T\\\\
&\frac {\partial E}{\partial W_{ix}}=\frac{\partial E}{\partial net_{i,t}}\frac{\partial net_{i,t}}{\partial W_{ix}}=\delta_{i,t}x_t^T\\\\
&\frac {\partial E}{\partial W_{cx}}=\frac{\partial E}{\partial net_{\hat{c},t}}\frac{\partial net_{\hat{c},t}}{\partial W_{cx}}=\delta_{\hat{c},t}x_t^T\\\\
\end{aligned}
$$

<–end–>

坚持原创技术分享,您的支持将鼓励我继续创作!