mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
sdpa bugfix
This commit is contained in:
parent
318d81be4a
commit
b9bc982e4f
@ -204,10 +204,11 @@ class MultiHeadAttentionSdpa(nn.Module):
|
||||
v,
|
||||
attn_mask=mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=True,
|
||||
is_causal=is_causal,
|
||||
scale=scale,
|
||||
)
|
||||
attn_output = attn_output.masked_fill(mask.logical_not(), 0.0)
|
||||
if mask is not None:
|
||||
attn_output = attn_output.masked_fill(mask.logical_not(), 0.0)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.flatten(start_dim=2)
|
||||
return attn_output, None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user