mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00

sdpa bugfix

parent dfc52059c0
commit 318d81be4a
@@ -71,7 +71,7 @@ def sense_voice_decode_forward(
     x = tgt.to(memory.dtype)

     if use_padmask and hlens is not None:
-        memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
+        memory_mask = (~make_pad_mask(hlens)[:, :, None]).to(memory.device)
     else:
         memory_mask = None
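
The first hunk flips the memory pad mask from key-axis (`[:, None, :]`) to query-axis (`[:, :, None]`) orientation, matching the `(batch, 1, t, 1)` shape the SDPA path below expects after its `unsqueeze(1)`. A minimal sketch of the difference, assuming the ESPnet-style `make_pad_mask` contract (True at padded frames); the helper here is a stand-in, not FunASR's implementation:

```python
import torch

def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # Stand-in for FunASR's helper: True at padded positions.
    t = int(lengths.max())
    return torch.arange(t)[None, :] >= lengths[:, None]  # (batch, t)

hlens = torch.tensor([3, 5])
valid = ~make_pad_mask(hlens)   # (batch, t), True at real frames

key_mask = valid[:, None, :]    # (batch, 1, t): gates which keys are visible
query_mask = valid[:, :, None]  # (batch, t, 1): gates which queries are live
print(key_mask.shape, query_mask.shape)  # [2, 1, 5] and [2, 5, 1]
```
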
@@ -127,7 +127,7 @@ class MultiHeadAttention(nn.Module):
         if not is_pad_mask:
             qk = qk + mask[:n_ctx, :n_ctx]
         else:
-            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, t, 1)
             min_value = -float(
                 "inf"
             )  # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
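
The `MultiHeadAttention` hunk only corrects a shape comment, but it records the consequence of the change above: the mask arriving at the eager path is now query-side, `(batch, 1, t, 1)`. Filling a fully padded query row with `-inf` makes its softmax come out NaN, which is the hazard the SDPA hunks below have to handle. A small illustration:

```python
import torch

qk = torch.randn(2, 1, 4, 4)                      # (batch, head, t_q, t_k)
valid = torch.tensor([[True, True, True, False],
                      [True, True, False, False]])
pad = valid[:, None, :, None].eq(False)           # (batch, 1, t, 1), True at padding
w = qk.masked_fill(pad, -float("inf")).softmax(dim=-1)
print(torch.isnan(w).any())                       # tensor(True): padded query rows
```
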
@@ -172,9 +172,7 @@ class MultiHeadAttentionSdpa(nn.Module):
             k = kv_cache[self.key]
             v = kv_cache[self.value]

-        wv, qk = self.qkv_attention(
-            q, k, v, mask, is_pad_mask=is_pad_mask, is_causal=True if xa is not None else False
-        )
+        wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask, is_causal=False)
         return self.out(wv), qk

     def qkv_attention(
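
In `MultiHeadAttentionSdpa.forward`, the old call enabled causal attention exactly when `xa is not None`, i.e. for cross-attention over encoder memory, the one case that must never be causal. The fix passes `is_causal=False` and lets `qkv_attention` derive causality from the mask type (next hunks). For reference, SDPA's built-in causal path is equivalent to a lower-triangular boolean mask:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 8)   # (batch, head, t, d)
k, v = torch.randn_like(q), torch.randn_like(q)

out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
tril = torch.ones(4, 4, dtype=torch.bool).tril()  # True on/below the diagonal
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=tril)
print(torch.allclose(out_causal, out_masked, atol=1e-6))  # True
```
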
@@ -186,6 +184,7 @@ class MultiHeadAttentionSdpa(nn.Module):
         **kwargs,
     ):
         is_pad_mask = kwargs.get("is_pad_mask", False)
+        is_causal = kwargs.get("is_causal", False)
         n_batch, n_ctx, n_state = q.shape
         scale = (n_state // self.n_head) ** -0.5
         q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
@@ -193,7 +192,11 @@ class MultiHeadAttentionSdpa(nn.Module):
         v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

         if mask is not None:
-            mask = mask.unsqueeze(1).eq(0)
+            if not is_pad_mask:
+                mask = None
+                is_causal = True
+            else:
+                mask = mask.unsqueeze(1).to(torch.bool)  # (batch, 1, t, 1)

         attn_output = torch.nn.functional.scaled_dot_product_attention(
             q,
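
Two things are fixed here. First, `F.scaled_dot_product_attention` refuses an explicit `attn_mask` combined with `is_causal=True`, so when the incoming mask is the decoder's materialized causal mask it is dropped and the built-in causal path is used instead. Second, a boolean `attn_mask` means True = "may attend"; the old `.eq(0)` produced True at padded positions and therefore inverted the mask, while `.to(torch.bool)` keeps True at valid frames. A sketch of both points:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 8)
k, v = torch.randn_like(q), torch.randn_like(q)

# 1) attn_mask and is_causal=True together are rejected.
tril = torch.ones(4, 4, dtype=torch.bool).tril()
try:
    F.scaled_dot_product_attention(q, k, v, attn_mask=tril, is_causal=True)
except RuntimeError as e:
    print("rejected:", e)

# 2) Boolean attn_mask: True = position may be attended to.
valid = torch.tensor([[True, True, True, False]])  # (batch, t_k)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=valid[:, None, None, :])
```
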
@@ -201,8 +204,10 @@ class MultiHeadAttentionSdpa(nn.Module):
             v,
             attn_mask=mask,
             dropout_p=0.0,
-            is_causal=kwargs.get("is_causal", True),
+            is_causal=is_causal,
             scale=scale,
         )
+        if mask is not None:
+            attn_output = attn_output.masked_fill(mask.logical_not(), 0.0)
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.flatten(start_dim=2)
         return attn_output, None
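
That leaves one failure mode: with a query-side mask, a fully padded query row has no admissible key, SDPA fills the entire row with `-inf`, and the softmax yields NaN. The added `masked_fill` overwrites those rows with zeros before they can poison downstream layers. A sketch of the failure and the fix:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 8)
k, v = torch.randn_like(q), torch.randn_like(q)
valid = torch.tensor([[True, True, False, False]])  # queries 2-3 are padding
mask = valid[:, None, :, None]                      # (batch, 1, t, 1)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(torch.isnan(out[0, 0, 2]).all())              # True (math backend)

out = out.masked_fill(mask.logical_not(), 0.0)      # the fix: zero padded rows
print(torch.isnan(out).any())                       # False
```
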
@@ -241,12 +246,17 @@ class ResidualAttentionBlock(nn.Module):
     ):
         is_pad_mask = kwargs.get("is_pad_mask", False)
         is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
+        memory_mask = kwargs.get("memory_mask", None)
         x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
         if self.cross_attn:
             x = (
                 x
                 + self.cross_attn(
-                    self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask
+                    self.cross_attn_ln(x),
+                    xa,
+                    mask=memory_mask,
+                    kv_cache=kv_cache,
+                    is_pad_mask=is_pad_memory_mask,
                 )[0]
             )
         x = x + self.mlp(self.mlp_ln(x))
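
Finally, `ResidualAttentionBlock` actually threads the mask through: before this change the cross-attention call dropped `memory_mask` entirely, so encoder padding was never masked during cross-attention. A hypothetical call into the fixed block; `block`, `memory`, and `tgt_mask` are placeholders, while the keyword names come from the hunk above:

```python
x = block(
    x,
    xa=memory,                 # encoder output attended by cross-attention
    mask=tgt_mask,             # self-attention mask (pad or causal)
    is_pad_mask=True,
    is_pad_memory_mask=True,
    memory_mask=memory_mask,   # built in sense_voice_decode_forward; now used
)
```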