[Figure: Scaled Dot-Product Attention (Transformer)]

Scaled Dot-Product Attention is the building block of the multi-head attention module in the Transformer encoder.

Because Scaled Dot-Product Attention sits inside multi-head attention, the input tensors q, k, v are usually reshaped to the following shape:

(batch, n_head, seqLen, dim), where n_head is the number of attention heads and n_head * dim = embedSize.
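
A minimal sketch of that reshape, going from (batch, seqLen, embedSize) to (batch, n_head, seqLen, dim); all sizes here are illustrative assumptions, not values from the original post:

import torch

# Illustrative sizes (assumptions for the sketch)
batch, seq_len, n_head, embed_size = 2, 5, 8, 512
dim = embed_size // n_head            # n_head * dim = embedSize

x = torch.randn(batch, seq_len, embed_size)              # (batch, seqLen, embedSize)
q = x.view(batch, seq_len, n_head, dim).transpose(1, 2)  # (batch, n_head, seqLen, dim)
print(q.shape)                                           # torch.Size([2, 8, 5, 64])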

From input to output, the tensor shape stays unchanged.

temperature is the scaling factor ("Scaled"), i.e. dim ** 0.5.

mask marks padding: for each sample in the batch, positions where the sequence is a pad token get the mask value False, so the mask initially has shape (batchSize, seqLen); for the computation it is expanded to (batchSize, 1, 1, seqLen) so it can broadcast over heads and query positions, as sketched below.
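
A minimal sketch of building and expanding such a padding mask (pad_id and the toy sequences are assumptions for illustration):

import torch

pad_id = 0                                  # assumed padding token id
seq = torch.tensor([[5, 7, 2, 0, 0],
                    [3, 9, 0, 0, 0]])       # (batchSize, seqLen) = (2, 5)
mask = (seq != pad_id)                      # (batchSize, seqLen), False at pad positions
mask = mask.unsqueeze(1).unsqueeze(1)       # (batchSize, 1, 1, seqLen), broadcasts over heads and query positions
print(mask.shape)                           # torch.Size([2, 1, 1, 5])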

import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None): # q/k/v.shape: (batchSize, n_head, seqLen, dim)

        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))  # attn.shape: (batchSize, n_head, q_seqLen, k_seqLen)

        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)

        attn = self.dropout(F.softmax(attn, dim=-1)) # attn.shape: (batchSize, n_head, q_seqLen, k_seqLen)
        output = torch.matmul(attn, v) # output.shape: (batchSize, n_head, q_seqLen, dim)

        return output, attn
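
A quick usage sketch of the class above with random tensors (all sizes are illustrative assumptions); note that temperature is passed as dim ** 0.5:

import torch

batch, n_head, seq_len, dim = 2, 8, 5, 64
attention = ScaledDotProductAttention(temperature=dim ** 0.5)

q = torch.randn(batch, n_head, seq_len, dim)
k = torch.randn(batch, n_head, seq_len, dim)
v = torch.randn(batch, n_head, seq_len, dim)
mask = torch.ones(batch, 1, 1, seq_len, dtype=torch.bool)  # no padding in this toy example

output, attn = attention(q, k, v, mask=mask)
print(output.shape)  # torch.Size([2, 8, 5, 64])
print(attn.shape)    # torch.Size([2, 8, 5, 5])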

Note:

When Q, K, and V are all obtained by matrix transformations of the same input vector, this is called self-attention;

when Q and K/V come from matrix transformations of different input vectors, it is called soft-attention (this setting is more commonly known as cross-attention).
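
A minimal sketch of the distinction, using hypothetical projection layers w_q, w_k, w_v and an assumed embed_size:

import torch
import torch.nn as nn

embed_size = 512                     # assumed embedding size
w_q = nn.Linear(embed_size, embed_size)
w_k = nn.Linear(embed_size, embed_size)
w_v = nn.Linear(embed_size, embed_size)

x = torch.randn(2, 5, embed_size)    # one sequence
y = torch.randn(2, 7, embed_size)    # a different sequence (e.g. encoder output)

# self-attention: Q, K, V are all projections of the same vector x
q, k, v = w_q(x), w_k(x), w_v(x)

# Q from x, K and V from a different vector y (the soft-/cross-attention case above)
q2, k2, v2 = w_q(x), w_k(y), w_v(y)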
