sdpa bugfix

This commit is contained in:
游雁 2024-07-24 00:57:10 +08:00
parent 609c0e7e0d
commit 54e630159d
3 changed files with 4 additions and 4 deletions

View File

@@ -71,7 +71,7 @@ def sense_voice_decode_forward(
x = tgt.to(memory.dtype)
if use_padmask and hlens is not None:
-        memory_mask = (~make_pad_mask(hlens)[:, :, None]).to(memory.device)
+        memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
else:
memory_mask = None

View File

@@ -42,7 +42,7 @@ def sense_voice_encode_forward(
olens = None
if use_padmask and olens is not None:
-        padding_mask = (~make_pad_mask(olens)[:, :, None]).to(torch.bool).to(x.device)
+        padding_mask = (~make_pad_mask(olens)[:, None, :]).to(torch.bool).to(x.device)
else:
padding_mask = None

View File

@@ -196,7 +196,7 @@ class MultiHeadAttentionSdpa(nn.Module):
mask = None
is_causal = True
else:
-            mask = mask.unsqueeze(1).to(torch.bool)  # (batch, 1, t, 1)
+            mask = mask.unsqueeze(1).to(torch.bool)  # (batch, 1, 1, t)
attn_output = torch.nn.functional.scaled_dot_product_attention(
q,
@@ -208,7 +208,7 @@ class MultiHeadAttentionSdpa(nn.Module):
scale=scale,
)
if mask is not None:
-            attn_output = attn_output.masked_fill(mask.logical_not(), 0.0)
+            attn_output = attn_output.masked_fill(mask.transpose(2, 3).logical_not(), 0.0)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.flatten(start_dim=2)
return attn_output, None