sdpa bugfix

This commit is contained in:
游雁 2024-07-24 01:08:52 +08:00
parent 54e630159d
commit 20d32f68e8

View File

@ -312,7 +312,7 @@ class TextDecoder(nn.Module):
n_state,
n_head,
cross_attention=True,
att_type="default",
att_type=kwargs.get("att_type", "default"),
)
for _ in range(n_layer)
]