
D2L-54-Gradient Clipping

Last updated Apr 2, 2022

# Gradient Clipping

2022-04-02

Tags: #GradientClipping

Gradient clipping rescales the gradient $\mathbf{g}$ so that its norm never exceeds a threshold $\theta$: if $\|\mathbf{g}\| > \theta$, the gradient is scaled down to have norm exactly $\theta$; otherwise it is left unchanged.

$$\mathbf{g} \leftarrow \min \left(1, \frac{\theta}{\|\mathbf{g}\|}\right) \mathbf{g}$$

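A quick numeric check of the formula; a minimal sketch where the gradient values and the threshold are made up for illustration:

```python
import torch

theta = 1.0                    # made-up clipping threshold
g = torch.tensor([3.0, 4.0])   # made-up gradient with norm 5

# min(1, theta / ||g||) * g: rescale only when the norm exceeds theta
g_clipped = min(1.0, theta / torch.norm(g).item()) * g
print(g_clipped)               # tensor([0.6000, 0.8000])
print(torch.norm(g_clipped))   # tensor(1.) == theta
```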

# PyTorch

torch.nn.utils.clip_grad_norm_ — PyTorch 1.11.0 documentation

```python
nn.utils.clip_grad_norm_(model.parameters(), max_norm=CLIP_GRAD)
```
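In a training loop the clipping call goes between `loss.backward()` and `optimizer.step()`, so the optimizer updates with the clipped gradients. A minimal sketch: `train_step`, `model`, `optimizer`, `loss_fn`, and `CLIP_GRAD` are all hypothetical names, not from the original post:

```python
import torch
from torch import nn

CLIP_GRAD = 1.0  # hypothetical threshold; tune per task

def train_step(model, optimizer, loss_fn, X, y):
    """One training step with gradient clipping (hypothetical helper)."""
    optimizer.zero_grad()
    loss = loss_fn(model(X), y)
    loss.backward()
    # rescales all gradients in place; returns the total norm before clipping
    total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=CLIP_GRAD)
    optimizer.step()
    return loss.item(), total_norm.item()
```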

# The simple implementation in D2L

```python
import torch
from torch import nn

def grad_clipping(net, theta):  #@save
    """Clip the gradient."""
    if isinstance(net, nn.Module):
        # PyTorch model: gather all trainable parameters
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        # d2l from-scratch model: parameters live in net.params
        params = net.params
    # L2 norm of all gradients concatenated into one vector
    norm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in params))
    if norm > theta:
        for param in params:
            param.grad[:] *= theta / norm  # note the in-place *=
```
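To check that this matches PyTorch's built-in, one can clip identical gradients both ways and compare; a sketch assuming `grad_clipping` from above is in scope (the linear model and data are made up, and a small tolerance absorbs the tiny epsilon PyTorch adds to the norm internally):

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)
net = nn.Linear(4, 2)          # made-up model
net_ref = copy.deepcopy(net)   # identical parameters

# identical parameters + identical batch => identical gradients
X, y = torch.randn(8, 4), torch.randn(8, 2)
nn.MSELoss()(net(X), y).backward()
nn.MSELoss()(net_ref(X), y).backward()

theta = 0.1
grad_clipping(net, theta)                               # D2L version
nn.utils.clip_grad_norm_(net_ref.parameters(), theta)   # built-in

for p, q in zip(net.parameters(), net_ref.parameters()):
    assert torch.allclose(p.grad, q.grad, atol=1e-6)
```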