GRU神经网络
GRU(Gated Recurrent Unit)是LSTM的一种变体,它对LSTM做了很多简化,同时却保持着和LSTM几乎相同的效果。因此,GRU最近变得非常流行。下图是GRU的网络架构图。
GRU对LSTM做了两个大得改动:
- 将输入门、遗忘门和输出门改变为两个门:更新门(Update Gate)$z_t$和重置门(Reset Gate)$r_t$。
- 将单元状态与输出合并为一个状态:h。
根据上图的架构图可以得出GRU的前向计算公式:
$$\begin{aligned}
&r_t = \sigma(W_r \cdot [h_{t-1},x_t])\\\\
&z_t =\sigma(W_z \cdot [h_{t-1},x_t])\\\\
&\hat {h_t}=tanh(W_{\hat {h}} \cdot [r_t \bigodot h_{t-1},x_t])\\\\
&h_t =(1-z_t)\bigodot h_{t-1}+z_t \bigodot \hat {h_t}\\\\
&y_t=\sigma(W_o \cdot h_t)
\end{aligned}$$
GRU的反向传播梯度计算
GRU的参数更新方式同样是基于沿时间反向传播的算法(BPTT),为了为了更清晰的推导GRU反向传播梯度计算,对上文的GRU前向计算公式进行一定的改写,实质上还是一样的,只不过是将参数分开写而已。具体如下:假设,对于t时刻,GRU的输出为$\hat {y_t}$,输入为$x_t$,前一时刻的状态为$s_{t-1}$,则可以得出如下的前向计算的公式:
$$\begin{aligned}
&z_t = \sigma(U_zx_t+W_zs_{t-1}+b_z)\\\\
&r_t = \sigma(U_r x_t+W_rs_{t-1}+b_r)\\\\
&h_t = tanh(U_hx_t+W_h(s_{t-1}\bigodot r_t)+b_h)\\\\
&s_t = (1-z_t)\bigodot h_t + z_t \bigodot s_{t-1}\\\\
&\hat {y_t}=softmax(Vs_t+b_V)
\end{aligned}
$$
其中,$\bigodot$表示向量的点乘;$z_t$表示更新门;$r_t$表示重置门;$\hat {y_t}$表示t时刻的输出。
如果采用交叉熵损失函数,那么在t时刻的损失为$L_t$:$$
L_t =sumOfAllElements(-y_t\bigodot log(\hat {y_t}))
$$
为了训练GRU,需要把所有时刻的损失加在一起,并最小化损失$L=\sum_{t=1}^T L_t$:
$$argmin_{\Theta}L$$
其中,$\Theta={U_z,U_r,U_c,W_z,W_r,W_c,b_z,b_r,b_c,V,b_V}$。
这是一个非凸优化问题,通常采用随机梯度下降法去解决问题。因此,需要计算$\partial L/ \partial U_z,\partial L/ \partial U_r,\partial L/ \partial U_h,\partial L/ \partial W_z,\partial L/ \partial W_r,\partial L/ \partial W_h,\partial L/ \partial b_z,\partial L/ \partial b_r,\partial L/ \partial b_h,\partial L/ \partial V,\partial L/ \partial b_v$。计算上面的梯度,最好的方式是利用链式法则从输出到输入一步一步去计算,为了更好得看清输入、中间值以及输出之间的关系,画了一张GRU的计算图,如下图所示:
根据计算图,利用链式法则计算梯度,需要从上至下沿着边进行计算。如果节点X有多条出边和目标节点T相连,如果要计算$\partial T / \partial X$,需要分别计算每条边对X的梯度,并将梯度进行相加。
以计算$\frac {\partial L}{\partial U_z}$为例,其他的计算方式和其相似。因为$L=\sum_{t=1}^T L_t$,所以,$\frac {\partial L}{\partial U_z}=\sum_{t=1}^T \frac {\partial L_t}{\partial U_z}$,因此,可以先计算$\frac {\partial L_t}{\partial U_z}$,然后将不同时刻结果加起来就可以。
根据链式法则:$$\frac {\partial L_t}{\partial U_z} = \frac {\partial L_t}{\partial s_t} \frac{\partial s_t}{\partial U_z} (公式1)
$$
公式1右边的第一个式子的计算如下:$$\frac {\partial L_t}{\partial s_t}=V(\hat {y_t}-y_t) (公式2)
$$
对于$\frac {\partial z}{\partial U_z}$,一些人可能会直接进行如下的求导计算
$$
\frac {\overline{\partial s_t}}{\partial U_z}=((s_{t-1}-h_t)\bigodot z_t \bigodot (1-z_t))x_t^T (公式3)
$$
在$s_t$的计算公式里有$1-z$和$z\bigodot s_{t-1}$两个公式都会影响到$\frac {\partial s_t}{\partial U_z}$。正确的方法是分别计算每条边的偏导数,并将它们相加,因此,需要引入$\frac {\partial s_t}{\partial s_{t-1}}$。但是,公式3只计算了部分的梯度,因此用$\frac {\overline{\partial s_t}}{\partial U_z}$表示。
因为$s_{t-1}$同样依赖于$U_z$,因此我们不能把$s_{t-1}$作为一个常量处理。$s_{t-1}$同样会受到$s_i,i=1,…,t-2$的影响,因此,需要将公式1进行扩展,如下:
$$\begin{aligned}
\frac {\partial L_t}{\partial U_z} = &\frac {\partial L_t}{\partial s_t} \frac{\partial s_t}{\partial U_z}\\\\
=&\frac {\partial L_t}{\partial s_t}\sum_{i=1}^t(\frac{\partial s_t}{\partial s_i}\frac{\overline {\partial s_i}}{\partial U_z})\\\\
=&\frac {\partial L_t}{\partial s_t}\sum_{i=1}^t ((\prod_{j=i}^{t-1} \frac {\partial s_{j+1}}{\partial s_j})\frac{\overline {\partial s_i}}{\partial U_z})
\end{aligned} (公式4)
$$
其中,$\frac{\overline {\partial s_i}}{\partial U_z}$是$s_i$对$U_z$的梯度,其计算公式如公式3所示。
$\frac {\partial s_t}{\partial s_{t-1}}$的计算和$\frac {\partial s_t}{\partial z}$的计算相似。因为从$s_{t-1}$到$s_t$有四条边,直接或间接相连,通过$z_t,r_t和h_t$,因此,需要计算这四条边上的梯度,然后进行相加,计算公式如下:$$\begin{aligned}
\frac {\partial s_t}{\partial s_{t-1}}=&\frac {\partial s_t}{\partial h_t} \frac {\partial h_t}{\partial s_{t-1}}+\frac{\partial s_t}{\partial z_t}\frac{\partial z_t}{\partial s_{t-1}} + \frac {\overline {\partial s_t}}{\partial s_{t-1}}\\\\
=&\frac {\partial s_t}{\partial h_t}(\frac {\partial h_t}{\partial r_t}\frac {\partial r_t}{\partial s_{t-1}}+\frac {\overline {\partial h_t}}{\partial s_{t-1}}) + \frac {\partial s_t}{\partial z_t}\frac{\partial z_t}{\partial s_{t-1}}+\frac {\overline {\partial s_t}}{\partial s_{t-1}}
\end{aligned} (公式5)$$
其中,$\frac {\overline {\partial s_t}}{\partial s_{t-1}}$是对$s_t$关于$s_{t-1}$的导数,并将$h_t,z_t$看做常量。同样,$\frac {\overline {\partial h_t}}{\partial s_{t-1}}$是$h_t$关于$s_{t-1}$的导数,将$r_t$看做常量。最终,可以得到:
$$ \frac {\partial s_t}{\partial s_{t-1}}=(1-z_t)(W_r^T((W_h^T(1-h\bigodot h))\bigodot s_{t-1}\bigodot r \bigodot (1-r))+((W_h^T(1-h \bigodot h))\bigodot r_t))+\\\\W_z^T((s_{t-1}-h_t)\bigodot z_t \bigodot (1-z_t))+z (公式6)$$
到此为止,$\frac {\partial L}{\partial U_z}$的计算已经完成,而其余的参数的计算和它的计算方式类似,沿着计算图一步一步计算,这里就不一一计算了。
参考
- A Tutorial On Backward Propagation Through Time (BPTT) In The Gated Recurrent Unit (GRU) RNN