mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
修复无法预测nospeech标签的问题 (Fix: language detection could not predict the "nospeech" label) (#1604)
Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com>
This commit is contained in:
parent
c6574bf4f4
commit
112c8e6eb7
@ -58,18 +58,20 @@ def detect_language(
|
|||||||
# x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
|
# x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
|
||||||
if x is None:
|
if x is None:
|
||||||
x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device) # [n_audio, 1]
|
x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device) # [n_audio, 1]
|
||||||
logits = model.logits(x, mel)[:, 0]
|
logits = model.logits(x[:,:-1], mel)[:, -1]
|
||||||
|
|
||||||
# collect detected languages; suppress all non-language tokens
|
# collect detected languages; suppress all non-language tokens
|
||||||
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
||||||
mask[list(tokenizer.all_language_tokens)] = False
|
mask[list(tokenizer.all_language_tokens)] = False
|
||||||
|
mask[tokenizer.no_speech] = False
|
||||||
|
|
||||||
logits[:, mask] = -np.inf
|
logits[:, mask] = -np.inf
|
||||||
language_tokens = logits.argmax(dim=-1)
|
language_tokens = logits.argmax(dim=-1)
|
||||||
language_token_probs = logits.softmax(dim=-1).cpu()
|
language_token_probs = logits.softmax(dim=-1).cpu()
|
||||||
|
|
||||||
language_probs = [
|
language_probs = [
|
||||||
{
|
{
|
||||||
c: language_token_probs[i, j].item()
|
c: language_token_probs[i, j].item()
|
||||||
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
|
for j, c in zip(list(tokenizer.all_language_tokens) + [tokenizer.no_speech], list(tokenizer.all_language_codes) + ["nospeech"])
|
||||||
}
|
}
|
||||||
for i in range(n_audio)
|
for i in range(n_audio)
|
||||||
]
|
]
|
||||||
|
|||||||
@ -179,7 +179,12 @@ class Tokenizer:
|
|||||||
langs = tuple(LANGUAGES.keys())[: self.num_languages]
|
langs = tuple(LANGUAGES.keys())[: self.num_languages]
|
||||||
sot_sequence = [sot]
|
sot_sequence = [sot]
|
||||||
if self.language is not None:
|
if self.language is not None:
|
||||||
sot_sequence.append(sot + 1 + langs.index(self.language))
|
if self.language == 'nospeech':
|
||||||
|
sot_sequence.append(self.no_speech)
|
||||||
|
else:
|
||||||
|
sot_sequence.append(sot + 1 + langs.index(self.language))
|
||||||
|
# if self.language is not None:
|
||||||
|
# sot_sequence.append(sot + 1 + langs.index(self.language))
|
||||||
if self.task is not None:
|
if self.task is not None:
|
||||||
task_token: int = transcribe if self.task == "transcribe" else translate
|
task_token: int = transcribe if self.task == "transcribe" else translate
|
||||||
sot_sequence.append(task_token)
|
sot_sequence.append(task_token)
|
||||||
@ -432,6 +437,8 @@ def get_tokenizer(
|
|||||||
if language not in LANGUAGES:
|
if language not in LANGUAGES:
|
||||||
if language in TO_LANGUAGE_CODE:
|
if language in TO_LANGUAGE_CODE:
|
||||||
language = TO_LANGUAGE_CODE[language]
|
language = TO_LANGUAGE_CODE[language]
|
||||||
|
elif language == 'nospeech':
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported language: {language}")
|
raise ValueError(f"Unsupported language: {language}")
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user