Cyan's Blog

Search

Search IconIcon to open search

D2L-55-在时间上反向传播

Last updated Apr 2, 2022 Edit Source

# Backpropagation Through Time

2022-04-02

Tags: #Backpropagation #RNN

# 在时间上反向传播/RNN的反向传播

# 模型

简化版RNN

# 损失函数

# 求梯度: $\frac{\partial L}{\partial w_o}$

# 求梯度: $\frac{\partial L}{\partial w_h}$

# RNN缓解梯度问题的一些策略

# Truncating Time Steps

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):

...

	for X, Y in train_iter:
        if state is None or use_random_iter:
            # 在第一次迭代或使用随机抽样时初始化state
            state = net.begin_state(batch_size=X.shape[0], device=device)
        else:
            if isinstance(net, nn.Module) and not isinstance(state, tuple):
                # state对于nn.GRU是个张量
                state.detach_()
            else:
            # state对于nn.LSTM或对于我们从零开始实现的模型是个张量
                for s in state:
                    s.detach_()
 
 ...

    return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()

# Randomized Truncation

# 截断方式可视化

不同的截断方式代表了梯度不同的传播距离, 上面的图表示了每一个位置的隐状态可能的影响范围.

# RNN梯度传播的细节问题

其中圆圈代表运算, 方框代表变量或参数

# Step 1: $\frac{\partial L}{\partial \mathbf{o}_t}$

$$L = \frac{1}{T} \sum_{t=1}^T l(\mathbf{o}_t, y_t).$$ $$\frac{\partial L}{\partial \mathbf{o}_t} =\frac{1}{T}\frac{\partial l (\mathbf{o}_t, y_t)}{\partial \mathbf{o}_t} \in \mathbb{R}^q$$

# Step 2: $\frac{\partial L}{\partial \mathbf{W}_{qh}}$

根据计算图, 损失函数对 $\mathbf{W}_{qh}$ 的梯度依赖于 $\mathbf{o}1, \ldots, \mathbf{o}T$, 利用链式法则有: $$\frac{\partial L}{\partial \mathbf{W}{qh}} = \sum{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}t}, \frac{\partial \mathbf{o}t}{\partial \mathbf{W}{qh}}\right) = \sum{t=1}^T \frac{\partial L}{\partial \mathbf{o}_t} \mathbf{h}_t^\top$$

# Step 3: $\frac{\partial L}{\partial \mathbf{h}_t}$

我们先来看看对于最后一个时间步 $T$ 来说, 梯度 $\frac{\partial L}{\partial \mathbf{h}T}$ 的计算: $$\frac{\partial L}{\partial \mathbf{h}T} = \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}T}, \frac{\partial \mathbf{o}T}{\partial \mathbf{h}T} \right) = \mathbf{W}{qh}^\top \frac{\partial L}{\partial \mathbf{o}T}$$ 在 $t<T$ 的时候计算变得复杂起来, 因为 $h_t$ 的梯度同时依赖于 $o_t$ 和 $h{t+1}$ 根据链式法则有: $$\frac{\partial L}{\partial \mathbf{h}t} = \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}{t+1}}, \frac{\partial \mathbf{h}{t+1}}{\partial \mathbf{h}t} \right) + \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}t}, \frac{\partial \mathbf{o}t}{\partial \mathbf{h}t} \right) = \mathbf{W}{hh}^\top \frac{\partial L}{\partial \mathbf{h}{t+1}} + \mathbf{W}{qh}^\top \frac{\partial L}{\partial \mathbf{o}t}$$ 转化为通项公式: $$\frac{\partial L}{\partial \mathbf{h}t}= \sum{i=t}^T {\left(\mathbf{W}{hh}^\top\right)}^{T-i} \mathbf{W}{qh}^\top \frac{\partial L}{\partial \mathbf{o}{T+t-i}}.$$ 即使我们省略了激活函数, 从中我们已经能够看到一些问题: 表达式里面 $\mathbf{W}{hh}^\top$ 的指数部分可能会很大, 在 $\mathbf{W}{hh}^\top$ 里面特征值大于 $1$ 的部分会梯度爆炸, 而特征值小于 $1$ 的部分会梯度消失. 在多次矩阵连乘以后, 一个向量会越来越靠近特征值最大的特征向量的方向. EigenvalueMatrixPower5

# Step 4: $\partial L / \partial \mathbf{W}{hx}$ and $\partial L / \partial \mathbf{W}{hh}$,

最后我们基于$\frac{\partial L}{\partial \mathbf{h}t}$计算隐藏层参数的梯度: $\partial L / \partial \mathbf{W}{hx} \in \mathbb{R}^{h \times d}$ 和 $\partial L / \partial \mathbf{W}{hh} \in \mathbb{R}^{h \times h}$, $$ \begin{aligned} \frac{\partial L}{\partial \mathbf{W}{hx}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}t}, \frac{\partial \mathbf{h}t}{\partial \mathbf{W}{hx}}\right) = \sum{t=1}^T \frac{\partial L}{\partial \mathbf{h}t} \mathbf{x}t^\top,\\ \frac{\partial L}{\partial \mathbf{W}{hh}} &= \sum{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}t}, \frac{\partial \mathbf{h}t}{\partial \mathbf{W}{hh}}\right) = \sum{t=1}^T \frac{\partial L}{\partial \mathbf{h}t} \mathbf{h}{t-1}^\top, \end{aligned} $$


  1. 递推公式 $a_{t}=b_{t}+c_{t}a_{t-1}$ 转通项公式 ↩︎

  2. 梯度归一化 ↩︎

  3. 8.7. Backpropagation Through Time — Dive into Deep Learning 0.17.5 documentation ↩︎

  4. 具体看看这一节: D2L-5-拓展链式法则 利用抽象的符号可以省略掉很多繁杂的细节 ↩︎

  5. 如何理解矩阵特征值? - 知乎 ↩︎