mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
sdpa bugfix
This commit is contained in:
parent
609c0e7e0d
commit
54e630159d
@ -71,7 +71,7 @@ def sense_voice_decode_forward(
|
|||||||
x = tgt.to(memory.dtype)
|
x = tgt.to(memory.dtype)
|
||||||
|
|
||||||
if use_padmask and hlens is not None:
|
if use_padmask and hlens is not None:
|
||||||
memory_mask = (~make_pad_mask(hlens)[:, :, None]).to(memory.device)
|
memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
|
||||||
else:
|
else:
|
||||||
memory_mask = None
|
memory_mask = None
|
||||||
|
|
||||||
|
|||||||
@ -42,7 +42,7 @@ def sense_voice_encode_forward(
|
|||||||
olens = None
|
olens = None
|
||||||
|
|
||||||
if use_padmask and olens is not None:
|
if use_padmask and olens is not None:
|
||||||
padding_mask = (~make_pad_mask(olens)[:, :, None]).to(torch.bool).to(x.device)
|
padding_mask = (~make_pad_mask(olens)[:, None, :]).to(torch.bool).to(x.device)
|
||||||
else:
|
else:
|
||||||
padding_mask = None
|
padding_mask = None
|
||||||
|
|
||||||
|
|||||||
@ -196,7 +196,7 @@ class MultiHeadAttentionSdpa(nn.Module):
|
|||||||
mask = None
|
mask = None
|
||||||
is_causal = True
|
is_causal = True
|
||||||
else:
|
else:
|
||||||
mask = mask.unsqueeze(1).to(torch.bool) # (batch, 1, t, 1)
|
mask = mask.unsqueeze(1).to(torch.bool) # (batch, 1, 1, t)
|
||||||
|
|
||||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||||
q,
|
q,
|
||||||
@ -208,7 +208,7 @@ class MultiHeadAttentionSdpa(nn.Module):
|
|||||||
scale=scale,
|
scale=scale,
|
||||||
)
|
)
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
attn_output = attn_output.masked_fill(mask.logical_not(), 0.0)
|
attn_output = attn_output.masked_fill(mask.transpose(2, 3).logical_not(), 0.0)
|
||||||
attn_output = attn_output.transpose(1, 2)
|
attn_output = attn_output.transpose(1, 2)
|
||||||
attn_output = attn_output.flatten(start_dim=2)
|
attn_output = attn_output.flatten(start_dim=2)
|
||||||
return attn_output, None
|
return attn_output, None
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user