Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
sdpa bugfix

commit 54e630159d (parent 609c0e7e0d)
@@ -71,7 +71,7 @@ def sense_voice_decode_forward(
     x = tgt.to(memory.dtype)
 
     if use_padmask and hlens is not None:
-        memory_mask = (~make_pad_mask(hlens)[:, :, None]).to(memory.device)
+        memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
     else:
         memory_mask = None
 
@@ -42,7 +42,7 @@ def sense_voice_encode_forward(
     olens = None
 
     if use_padmask and olens is not None:
-        padding_mask = (~make_pad_mask(olens)[:, :, None]).to(torch.bool).to(x.device)
+        padding_mask = (~make_pad_mask(olens)[:, None, :]).to(torch.bool).to(x.device)
     else:
         padding_mask = None
 
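Note on the two hunks above: both make the same correction. `~make_pad_mask(...)` is True at valid frames, and a padding mask for attention has to flag the key axis (the last dimension), not the query axis, so the inserted broadcast dimension moves from `[:, :, None]` to `[:, None, :]`. A minimal sketch of the shape difference, using a hypothetical stand-in with the same contract as FunASR's `make_pad_mask` (True marks padded frames):

import torch

def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # (batch, t): True where a frame is padding
    t = int(lengths.max())
    return torch.arange(t)[None, :] >= lengths[:, None]

hlens = torch.tensor([3, 2])        # two sequences, padded to t = 3
valid = ~make_pad_mask(hlens)       # (batch, t), True = real frame

old = valid[:, :, None]             # (batch, t, 1): flags the query axis (wrong)
new = valid[:, None, :]             # (batch, 1, t): flags the key axis (the fix)
print(old.shape, new.shape)         # torch.Size([2, 3, 1]) torch.Size([2, 1, 3])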
@@ -196,7 +196,7 @@ class MultiHeadAttentionSdpa(nn.Module):
             mask = None
             is_causal = True
         else:
-            mask = mask.unsqueeze(1).to(torch.bool)  # (batch, 1, t, 1)
+            mask = mask.unsqueeze(1).to(torch.bool)  # (batch, 1, 1, t)
 
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             q,
@@ -208,7 +208,7 @@ class MultiHeadAttentionSdpa(nn.Module):
             scale=scale,
         )
         if mask is not None:
-            attn_output = attn_output.masked_fill(mask.logical_not(), 0.0)
+            attn_output = attn_output.masked_fill(mask.transpose(2, 3).logical_not(), 0.0)
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.flatten(start_dim=2)
         return attn_output, None
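Why the `transpose(2, 3)` in the last hunk: before the final `transpose(1, 2)`, `attn_output` has shape `(batch, heads, t, d_head)`, so zeroing padded time steps needs the key mask `(batch, 1, 1, t)` flipped onto the time axis as `(batch, 1, t, 1)`. A sketch assuming self-attention (`t_q == t_k`, as in this module), with made-up shapes:

import torch

batch, heads, t, d = 2, 4, 3, 8
attn_output = torch.randn(batch, heads, t, d)
mask = torch.tensor([[True, True, True],
                     [True, True, False]])[:, None, None, :]  # (batch, 1, 1, t)

# (batch, 1, t, 1) broadcasts over heads and d_head, zeroing padded time steps
zeroed = attn_output.masked_fill(mask.transpose(2, 3).logical_not(), 0.0)
print(zeroed[1, :, 2].abs().sum())  # tensor(0.) -- padded step fully zeroed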