sdpa bugfix

游雁 2024-07-24 00:28:56 +08:00
parent dfc52059c0
commit 318d81be4a
2 changed files with 18 additions and 8 deletions

View File

@@ -71,7 +71,7 @@ def sense_voice_decode_forward(
     x = tgt.to(memory.dtype)

     if use_padmask and hlens is not None:
-        memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
+        memory_mask = (~make_pad_mask(hlens)[:, :, None]).to(memory.device)
     else:
         memory_mask = None

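The hunk above flips the decoder's memory pad mask from (batch, 1, time), which applies the same key mask to every query, to (batch, time, 1), which blanks out entire rows for padded queries. A minimal sketch of the difference, using a stand-in for funasr's make_pad_mask (assumed, per the ESPnet/FunASR convention, to return True at padded positions):

import torch

def make_pad_mask(lengths, max_len=None):
    # Stand-in helper, not the library function: True where a frame is padding.
    max_len = max_len or int(lengths.max())
    return torch.arange(max_len)[None, :] >= lengths[:, None]  # (batch, time)

hlens = torch.tensor([3, 2])
valid = ~make_pad_mask(hlens)  # True at real frames

old_mask = valid[:, None, :]  # (batch, 1, time): same key mask for every query row
new_mask = valid[:, :, None]  # (batch, time, 1): masks whole rows for padded queries

# After mask.unsqueeze(1) in qkv_attention, new_mask becomes (batch, 1, time, 1),
# matching the "(batch, 1, t, 1)" comment introduced below.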
View File

@@ -127,7 +127,7 @@ class MultiHeadAttention(nn.Module):
            if not is_pad_mask:
                qk = qk + mask[:n_ctx, :n_ctx]
            else:
-                mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+                mask = mask.unsqueeze(1).eq(0)  # (batch, 1, t, 1)
                min_value = -float(
                    "inf"
                )  # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
@@ -172,9 +172,7 @@ class MultiHeadAttentionSdpa(nn.Module):
            k = kv_cache[self.key]
            v = kv_cache[self.value]

-        wv, qk = self.qkv_attention(
-            q, k, v, mask, is_pad_mask=is_pad_mask, is_causal=True if xa is not None else False
-        )
+        wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask, is_causal=False)
        return self.out(wv), qk

    def qkv_attention(
@@ -186,6 +184,7 @@ class MultiHeadAttentionSdpa(nn.Module):
        **kwargs,
    ):
        is_pad_mask = kwargs.get("is_pad_mask", False)
+        is_causal = kwargs.get("is_causal", False)
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.5
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
@@ -193,7 +192,11 @@ class MultiHeadAttentionSdpa(nn.Module):
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        if mask is not None:
-            mask = mask.unsqueeze(1).eq(0)
+            if not is_pad_mask:
+                mask = None
+                is_causal = True
+            else:
+                mask = mask.unsqueeze(1).to(torch.bool)  # (batch, 1, t, 1)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q,
@@ -201,8 +204,10 @@ class MultiHeadAttentionSdpa(nn.Module):
            v,
            attn_mask=mask,
            dropout_p=0.0,
-            is_causal=kwargs.get("is_causal", True),
+            is_causal=is_causal,
            scale=scale,
        )
+        if mask is not None:
+            attn_output = attn_output.masked_fill(mask.logical_not(), 0.0)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.flatten(start_dim=2)
        return attn_output, None
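Two constraints of torch.nn.functional.scaled_dot_product_attention drive the rewrite above: it raises a RuntimeError when attn_mask and is_causal=True are passed together, so the causal branch now sets mask=None and lets is_causal do the work, and a query row whose boolean mask is all False comes back as NaN, so the pad path zeroes those rows afterwards with masked_fill. The forward() change is the same bug seen from the call site: cross-attention (xa is not None) had been flagged causal and is now always is_causal=False. A minimal sketch of the pad-mask path, with shapes assumed as in the diff:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(2, 4, 3, 8)  # (batch, heads, time, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

# Boolean pad mask, True = keep, shaped (batch, 1, time, 1) as in the fix:
# it broadcasts to (batch, heads, time, time), so a padded query row has
# every key masked and SDPA returns NaN for that whole row.
mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool)[:, None, :, None]

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)
out = out.masked_fill(mask.logical_not(), 0.0)  # zero the NaN rows, as the new code does
assert not torch.isnan(out).any()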
@@ -241,12 +246,17 @@ class ResidualAttentionBlock(nn.Module):
    ):
        is_pad_mask = kwargs.get("is_pad_mask", False)
        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
+        memory_mask = kwargs.get("memory_mask", None)
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
        if self.cross_attn:
            x = (
                x
                + self.cross_attn(
-                    self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask
+                    self.cross_attn_ln(x),
+                    xa,
+                    mask=memory_mask,
+                    kv_cache=kv_cache,
+                    is_pad_mask=is_pad_memory_mask,
                )[0]
            )
        x = x + self.mlp(self.mlp_ln(x))
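This last hunk threads the memory_mask built in the first file through kwargs into cross-attention, which previously received no mask at all. A toy sketch of that plumbing; block_forward and cross_attn here are hypothetical stand-ins, not the repo's classes:

import torch

def cross_attn(x, xa, mask=None):
    # Toy stand-in for MultiHeadAttentionSdpa: pools the memory and
    # zeroes padded query rows when a mask is supplied.
    out = xa.mean(dim=1, keepdim=True).expand_as(x)
    return out if mask is None else out.masked_fill(~mask, 0.0)

def block_forward(x, xa, **kwargs):
    # memory_mask defaults to None, so callers that never pass it
    # keep the old (unmasked) behaviour.
    memory_mask = kwargs.get("memory_mask", None)
    return x + cross_attn(x, xa, mask=memory_mask)

x, xa = torch.randn(2, 3, 4), torch.randn(2, 5, 4)
memory_mask = torch.tensor([[1, 1, 1], [1, 0, 0]], dtype=torch.bool)[:, :, None]
y = block_forward(x, xa, memory_mask=memory_mask)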