mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
decoding key
This commit is contained in:
parent
78ff06a45c
commit
fb0da9f849
@ -472,7 +472,7 @@ class ResidualAttentionBlockFSMN(nn.Module):
|
||||
is_pad_mask = kwargs.get("is_pad_mask", False)
|
||||
is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
|
||||
|
||||
fsmn_cache = cache[layer]["fsmn_cache"] if len(cache) > 0 or cache is None else None
|
||||
fsmn_cache = cache[layer]["fsmn_cache"] if cache is not None and len(cache) > 0 else None
|
||||
# if fsmn_cache is not None:
|
||||
# x = x[:, -1:]
|
||||
att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache)
|
||||
|
||||
@ -806,7 +806,6 @@ class SenseVoiceFSMN(nn.Module):
|
||||
if len(kwargs.get("data_type", [])) > 1:
|
||||
audio_sample_list, text_token_int_list = audio_sample_list
|
||||
text_token_int = text_token_int_list[0]
|
||||
text_token_int = tokenizer.encode(text_token_int)
|
||||
else:
|
||||
text_token_int = None
|
||||
|
||||
@ -846,7 +845,7 @@ class SenseVoiceFSMN(nn.Module):
|
||||
)
|
||||
|
||||
if text_token_int is not None:
|
||||
i = 1
|
||||
i = 0
|
||||
results = []
|
||||
ibest_writer = None
|
||||
if kwargs.get("output_dir") is not None:
|
||||
@ -855,7 +854,9 @@ class SenseVoiceFSMN(nn.Module):
|
||||
ibest_writer = self.writer[f"1best_recog"]
|
||||
|
||||
# 1. Forward decoder
|
||||
ys_pad = torch.tensor(text_token_int, dtype=torch.int64).to(kwargs["device"])[None, :]
|
||||
ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[
|
||||
None, :
|
||||
]
|
||||
ys_pad_lens = torch.tensor([len(text_token_int)], dtype=torch.int64).to(
|
||||
kwargs["device"]
|
||||
)[None, :]
|
||||
|
||||
@ -62,8 +62,10 @@ def detect_language(
|
||||
|
||||
else:
|
||||
x = x.to(mel.device)
|
||||
# FIX(funasr): SenseVoice
|
||||
# logits = model.logits(x[:, :-1], mel)[:, -1]
|
||||
logits = model.logits(x[:, :], mel)[:, -1]
|
||||
|
||||
logits = model.logits(x[:, :-1], mel)[:, -1]
|
||||
# collect detected languages; suppress all non-language tokens
|
||||
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
||||
mask[list(tokenizer.all_language_tokens)] = False
|
||||
|
||||
Loading…
Reference in New Issue
Block a user