sdpa bugfix: pass is_causal through instead of hard-coding True, and guard the mask fill against mask=None

游雁 2024-07-24 00:33:04 +08:00
parent 318d81be4a
commit b9bc982e4f


@@ -204,10 +204,11 @@ class MultiHeadAttentionSdpa(nn.Module):
             v,
             attn_mask=mask,
             dropout_p=0.0,
-            is_causal=True,
+            is_causal=is_causal,
             scale=scale,
         )
-        attn_output = attn_output.masked_fill(mask.logical_not(), 0.0)
+        if mask is not None:
+            attn_output = attn_output.masked_fill(mask.logical_not(), 0.0)
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.flatten(start_dim=2)
         return attn_output, None
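
The hunk fixes two defects in the SDPA forward pass. First, is_causal was hard-coded to True even when an explicit attn_mask was passed, and torch.nn.functional.scaled_dot_product_attention raises an error if both are set, so the flag is now forwarded from the caller. Second, the masked_fill step called mask.logical_not() unconditionally, which raises an AttributeError when mask is None; it is now guarded. A minimal sketch of the fixed call site follows, assuming the tensor names from the diff and a mask that broadcasts against the attention output (the enclosing module is not shown in this hunk):

import torch.nn.functional as F

def sdpa_attention(q, k, v, mask=None, is_causal=False, scale=None):
    # attn_mask and is_causal=True are mutually exclusive in PyTorch's SDPA,
    # so the flag is forwarded instead of being hard-coded to True.
    attn_output = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=mask,
        dropout_p=0.0,
        is_causal=is_causal,
        scale=scale,
    )
    # Guard the padding fill: the old code dereferenced mask even when it
    # was None. The mask is assumed to broadcast to attn_output's shape.
    if mask is not None:
        attn_output = attn_output.masked_fill(mask.logical_not(), 0.0)
    attn_output = attn_output.transpose(1, 2)      # (B, H, T, D) -> (B, T, H, D)
    return attn_output.flatten(start_dim=2), None  # merge heads -> (B, T, H*D)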