Scaled Dot-Product Attention is the core building block of the multi-head attention module in the Transformer encoder (and decoder).
Because it sits inside multi-head attention, the inputs q, k, v are usually reshaped to:
(batch, n_head, seqLen, dim), where n_head is the number of attention heads and n_head * dim = embedSize.
From input to output, the tensor shape stays unchanged.
temperature is the scaling factor, i.e. dim ** 0.5.
mask marks padding: for each sample in the batch, positions where the sequence is pad get mask = False. The mask therefore starts with shape (batchSize, seqLen) and, for the computation, is expanded to (batchSize, 1, 1, seqLen) so it broadcasts over heads and query positions.
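A minimal sketch of building such a padding mask, assuming a toy batch where the pad token id is 0 (both the batch contents and the pad id are illustrative assumptions):

```python
import torch

# hypothetical toy batch of token ids; pad id assumed to be 0
seqs = torch.tensor([[5, 7, 2, 0, 0],
                     [3, 9, 0, 0, 0]])
mask = (seqs != 0)                     # (batchSize, seqLen); False at pad positions
mask = mask.unsqueeze(1).unsqueeze(2)  # (batchSize, 1, 1, seqLen)
```

The two singleton dimensions are what let the same per-sample mask broadcast across all n_head heads and all query positions inside the attention computation.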
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        # q/k/v.shape: (batchSize, n_head, seqLen, dim)
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
        # attn.shape: (batchSize, n_head, q_seqLen, k_seqLen)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)
        attn = self.dropout(F.softmax(attn, dim=-1))
        # output.shape: (batchSize, n_head, q_seqLen, dim)
        output = torch.matmul(attn, v)
        return output, attn
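The forward pass above can be traced step by step as a self-contained functional sketch; the tensor sizes here are arbitrary assumptions chosen only to make the shapes concrete (dropout is omitted for clarity):

```python
import torch
import torch.nn.functional as F

batch, n_head, seq_len, dim = 2, 8, 5, 64
q = torch.randn(batch, n_head, seq_len, dim)
k = torch.randn(batch, n_head, seq_len, dim)
v = torch.randn(batch, n_head, seq_len, dim)

# pretend the last position of every sequence is padding
mask = torch.ones(batch, 1, 1, seq_len, dtype=torch.bool)
mask[:, :, :, -1] = False

attn = torch.matmul(q / dim ** 0.5, k.transpose(2, 3))  # (batch, n_head, seq_len, seq_len)
attn = attn.masked_fill(mask == 0, -1e9)                # padded keys get a huge negative score
attn = F.softmax(attn, dim=-1)                          # each row sums to 1; masked column ~ 0
output = torch.matmul(attn, v)                          # (batch, n_head, seq_len, dim)
```

Note that masking happens before the softmax: filling with -1e9 drives the padded positions' attention weights to effectively zero, while the remaining weights still sum to one.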
Note:
When Q, K, and V are all linear projections of the same input sequence, this is called self-attention.
When Q comes from one sequence while K and V come from another (e.g. decoder queries attending to encoder outputs), this is called cross-attention (encoder-decoder attention).
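The distinction can be sketched with hypothetical projection layers w_q, w_k, w_v (names and sizes are illustrative assumptions, not from the original):

```python
import torch
import torch.nn as nn

embed = 16
w_q = nn.Linear(embed, embed)
w_k = nn.Linear(embed, embed)
w_v = nn.Linear(embed, embed)

x = torch.randn(2, 5, embed)      # one input sequence
q, k, v = w_q(x), w_k(x), w_v(x)  # self-attention: Q, K, V all project the same x

mem = torch.randn(2, 7, embed)    # a different sequence, e.g. encoder output
q2 = w_q(x)                       # cross-attention: Q from x ...
k2, v2 = w_k(mem), w_v(mem)       # ... K and V from mem
```

In cross-attention the query and key sequence lengths may differ, which is why the attention matrix has shape (batch, n_head, q_seqLen, k_seqLen) rather than being square.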