D2L-69-Scaled Dot-Product Attention

# Scaled Dot-Product Attention

2022-04-21

Tags: #Attention #DeepLearning

(Figure: illustration of scaled dot-product attention)¹

# General Form
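
Given queries $\mathbf{Q} \in \mathbb{R}^{n \times d}$, keys $\mathbf{K} \in \mathbb{R}^{m \times d}$, and values $\mathbf{V} \in \mathbb{R}^{m \times v}$, scaled dot-product attention computes

$$
\mathrm{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}\right)\mathbf{V} \in \mathbb{R}^{n \times v}.
$$

Dividing by $\sqrt{d}$ keeps the scale of the dot products independent of the feature dimension $d$. The implementation below realizes this with batched matrix multiplication (`torch.bmm`), a masked softmax (so key-value pairs beyond each valid length are ignored), and dropout on the attention weights.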

```python
import math

import torch
from torch import nn


class DotProductAttention(nn.Module):
    """Scaled dot-product attention."""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # Shape of queries: (batch_size, no. of queries, d)
    # Shape of keys: (batch_size, no. of key-value pairs, d)
    # Shape of values: (batch_size, no. of key-value pairs, value dimension)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        # d is the feature dimension shared by the queries and keys
        d = queries.shape[-1]
        # Swap the last two dimensions of keys so that bmm computes the
        # query-key dot products, then scale by sqrt(d)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        # masked_softmax is defined earlier in the D2L book (also saved in the
        # d2l package); it masks out positions beyond each valid length
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
```
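
A minimal usage sketch, assuming the class above is defined and `masked_softmax` is in scope (e.g. the version from the D2L book, also available as `d2l.masked_softmax`); the tensor shapes are chosen only for illustration:

```python
# (batch_size, no. of queries, d) = (2, 1, 2)
queries = torch.normal(0, 1, (2, 1, 2))
# (batch_size, no. of key-value pairs, d) = (2, 10, 2)
keys = torch.ones((2, 10, 2))
# (batch_size, no. of key-value pairs, value dim) = (2, 10, 4)
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
# The first example attends over 2 key-value pairs, the second over 6
valid_lens = torch.tensor([2, 6])

attention = DotProductAttention(dropout=0.5)
attention.eval()  # disable dropout for a deterministic check
output = attention(queries, keys, values, valid_lens)
print(output.shape)  # torch.Size([2, 1, 4])
```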

1. Not the original source of the figure, but this is where I saw it: Attention机制详解(二)——Self-Attention与Transformer - 知乎 ↩︎