diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py index e3ad56a5a..c47d96d06 100644 --- a/funasr/modules/attention.py +++ b/funasr/modules/attention.py @@ -622,4 +622,108 @@ class MultiHeadedAttentionCrossAtt(nn.Module): q_h, k_h, v_h = self.forward_qkv(x, memory) q_h = q_h * self.d_k ** (-0.5) scores = torch.matmul(q_h, k_h.transpose(-2, -1)) - return self.forward_attention(v_h, scores, memory_mask) \ No newline at end of file + return self.forward_attention(v_h, scores, memory_mask) + + +class MultiHeadSelfAttention(nn.Module): + """Multi-Head Attention layer. + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, n_head, in_feat, n_feat, dropout_rate): + """Construct an MultiHeadedAttention object.""" + super(MultiHeadSelfAttention, self).__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_out = nn.Linear(n_feat, n_feat) + self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3) + self.attn = None + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv(self, x): + """Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + + Returns: + torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). + torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). + torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). + + """ + b, t, d = x.size() + q_k_v = self.linear_q_k_v(x) + q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1) + q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k) + k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k) + v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k) + + return q_h, k_h, v_h, v + + def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None): + """Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). + mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + + """ + n_batch = value.size(0) + if mask is not None: + if mask_att_chunk_encoder is not None: + mask = mask * mask_att_chunk_encoder + + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + + min_value = float( + numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min + ) + scores = scores.masked_fill(mask, min_value) + self.attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0 + ) # (batch, head, time1, time2) + else: + self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(self.attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = ( + x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, x, mask, mask_att_chunk_encoder=None): + """Compute scaled dot product attention. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + + """ + q_h, k_h, v_h, v = self.forward_qkv(x) + q_h = q_h * self.d_k ** (-0.5) + scores = torch.matmul(q_h, k_h.transpose(-2, -1)) + att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder) + return att_outs