This commit is contained in:
游雁 2024-07-18 16:31:49 +08:00
parent 340b6efef2
commit 6aacee8f9e

View File

@@ -189,7 +189,7 @@ class MultiHeadAttentionSdpa(nn.Module):
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.5
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
-k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1)
+k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
if mask is not None:
@@ -204,7 +204,7 @@ class MultiHeadAttentionSdpa(nn.Module):
is_causal=kwargs.get("is_causal", True),
)
attn_output = attn_output.transpose(1, 2)
-attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+attn_output = attn_output.flatten(start_dim=2)
return attn_output, None
@@ -301,7 +301,7 @@ class TextDecoder(nn.Module):
n_state,
n_head,
cross_attention=True,
-att_type=kwargs.get("att_type", "default"),
+att_type="default",
)
for _ in range(n_layer)
]