From 54e630159d9b3eb87379063cb814766cdc2b67aa Mon Sep 17 00:00:00 2001
From: 游雁
Date: Wed, 24 Jul 2024 00:57:10 +0800
Subject: [PATCH] sdpa bugfix: correct padding mask shapes for scaled_dot_product_attention

The padding masks were built as (batch, t, 1), so after unsqueeze(1) in
MultiHeadAttentionSdpa they became (batch, 1, t, 1) and masked query rows
instead of keys. Build them as (batch, 1, t) key masks, which broadcast to
the (batch, 1, 1, t) layout SDPA expects, and zero out padded query rows in
the attention output with the transposed mask afterwards.
---
 funasr/models/sense_voice/decoder.py           | 2 +-
 funasr/models/sense_voice/encoder.py           | 2 +-
 funasr/models/sense_voice/whisper_lib/model.py | 4 ++--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
index 82bdc4a77..ff933d77b 100644
--- a/funasr/models/sense_voice/decoder.py
+++ b/funasr/models/sense_voice/decoder.py
@@ -71,7 +71,7 @@ def sense_voice_decode_forward(
     x = tgt.to(memory.dtype)
 
     if use_padmask and hlens is not None:
-        memory_mask = (~make_pad_mask(hlens)[:, :, None]).to(memory.device)
+        memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
     else:
         memory_mask = None
 
diff --git a/funasr/models/sense_voice/encoder.py b/funasr/models/sense_voice/encoder.py
index 64156e9dd..6c6d15600 100644
--- a/funasr/models/sense_voice/encoder.py
+++ b/funasr/models/sense_voice/encoder.py
@@ -42,7 +42,7 @@ def sense_voice_encode_forward(
         olens = None
 
     if use_padmask and olens is not None:
-        padding_mask = (~make_pad_mask(olens)[:, :, None]).to(torch.bool).to(x.device)
+        padding_mask = (~make_pad_mask(olens)[:, None, :]).to(torch.bool).to(x.device)
     else:
         padding_mask = None
 
diff --git a/funasr/models/sense_voice/whisper_lib/model.py b/funasr/models/sense_voice/whisper_lib/model.py
index f2b82a5d1..e712e22bf 100644
--- a/funasr/models/sense_voice/whisper_lib/model.py
+++ b/funasr/models/sense_voice/whisper_lib/model.py
@@ -196,7 +196,7 @@ class MultiHeadAttentionSdpa(nn.Module):
             mask = None
             is_causal = True
         else:
-            mask = mask.unsqueeze(1).to(torch.bool)  # (batch, 1, t, 1)
+            mask = mask.unsqueeze(1).to(torch.bool)  # (batch, 1, 1, t)
 
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             q,
@@ -208,7 +208,7 @@ class MultiHeadAttentionSdpa(nn.Module):
             scale=scale,
         )
         if mask is not None:
-            attn_output = attn_output.masked_fill(mask.logical_not(), 0.0)
+            attn_output = attn_output.masked_fill(mask.transpose(2, 3).logical_not(), 0.0)
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.flatten(start_dim=2)
         return attn_output, None
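
Note (not part of the patch): a minimal sanity-check sketch of the shape
reasoning behind the fix. It assumes only torch; the inline `valid` tensor
stands in for funasr's make_pad_mask (which is True at padded frames, hence
the negation in the patched code), and all sizes are illustrative.

    import torch
    import torch.nn.functional as F

    batch, heads, t, head_dim = 2, 4, 5, 8
    lengths = torch.tensor([5, 3])  # second sequence has two padded frames

    # Stand-in for ~make_pad_mask(lengths): True at valid frames.  (batch, t)
    valid = torch.arange(t)[None, :] < lengths[:, None]

    # Patched layout: [:, None, :] -> (batch, 1, t), then unsqueeze(1) inside
    # MultiHeadAttentionSdpa -> (batch, 1, 1, t), a key mask that broadcasts
    # over heads and query positions.  The old [:, :, None] layout ended up as
    # (batch, 1, t, 1), i.e. a query mask, whose all-False rows can make the
    # softmax inside SDPA produce NaNs for padded frames.
    mask = valid[:, None, :].unsqueeze(1).to(torch.bool)

    q = torch.randn(batch, heads, t, head_dim)
    k = torch.randn(batch, heads, t, head_dim)
    v = torch.randn(batch, heads, t, head_dim)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    # Padded *query* rows still carry values (they attended to valid keys),
    # so the patch zeroes them afterwards: transpose(2, 3) turns the key mask
    # into a (batch, 1, t, 1) query mask, valid here since t_q == t_k.
    out = out.masked_fill(mask.transpose(2, 3).logical_not(), 0.0)
    assert out[1, :, 3:].abs().sum() == 0  # frames past length 3 are zeroed

The transpose(2, 3) trick relies on query and key lengths being equal, which
holds for the self-attention path this module serves; a cross-attention mask
with t_q != t_k would need a separate query-length mask instead.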