mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
340b6efef2
commit
6aacee8f9e
@@ -189,7 +189,7 @@ class MultiHeadAttentionSdpa(nn.Module):
|
||||
n_batch, n_ctx, n_state = q.shape
|
||||
scale = (n_state // self.n_head) ** -0.5
|
||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
- k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1)
+ k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
|
||||
if mask is not None:
|
||||
@@ -204,7 +204,7 @@ class MultiHeadAttentionSdpa(nn.Module):
|
||||
is_causal=kwargs.get("is_causal", True),
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+ attn_output = attn_output.flatten(start_dim=2)
|
||||
return attn_output, None
|
||||
|
||||
|
||||
@@ -301,7 +301,7 @@ class TextDecoder(nn.Module):
|
||||
n_state,
|
||||
n_head,
|
||||
cross_attention=True,
|
||||
- att_type=kwargs.get("att_type", "default"),
+ att_type="default",
|
||||
)
|
||||
for _ in range(n_layer)
|
||||
]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user