decoding key

This commit is contained in:
游雁 2024-05-07 15:53:26 +08:00
parent 78ff06a45c
commit fb0da9f849
3 changed files with 8 additions and 5 deletions

View File

@ -472,7 +472,7 @@ class ResidualAttentionBlockFSMN(nn.Module):
is_pad_mask = kwargs.get("is_pad_mask", False)
is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
fsmn_cache = cache[layer]["fsmn_cache"] if len(cache) > 0 or cache is None else None
fsmn_cache = cache[layer]["fsmn_cache"] if cache is not None and len(cache) > 0 else None
# if fsmn_cache is not None:
# x = x[:, -1:]
att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache)

View File

@ -806,7 +806,6 @@ class SenseVoiceFSMN(nn.Module):
if len(kwargs.get("data_type", [])) > 1:
audio_sample_list, text_token_int_list = audio_sample_list
text_token_int = text_token_int_list[0]
text_token_int = tokenizer.encode(text_token_int)
else:
text_token_int = None
@ -846,7 +845,7 @@ class SenseVoiceFSMN(nn.Module):
)
if text_token_int is not None:
i = 1
i = 0
results = []
ibest_writer = None
if kwargs.get("output_dir") is not None:
@ -855,7 +854,9 @@ class SenseVoiceFSMN(nn.Module):
ibest_writer = self.writer[f"1best_recog"]
# 1. Forward decoder
ys_pad = torch.tensor(text_token_int, dtype=torch.int64).to(kwargs["device"])[None, :]
ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[
None, :
]
ys_pad_lens = torch.tensor([len(text_token_int)], dtype=torch.int64).to(
kwargs["device"]
)[None, :]

View File

@ -62,8 +62,10 @@ def detect_language(
else:
x = x.to(mel.device)
# FIX(funasr): sense vocie
# logits = model.logits(x[:, :-1], mel)[:, -1]
logits = model.logits(x[:, :], mel)[:, -1]
logits = model.logits(x[:, :-1], mel)[:, -1]
# collect detected languages; suppress all non-language tokens
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[list(tokenizer.all_language_tokens)] = False