* 修复无法预测nospeech标签的问题

* 修复prompt存储的设备的问题

---------

Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com>
Co-authored-by: zhifu gao <zhifu.gzf@alibaba-inc.com>
This commit is contained in:
gaochangfeng 2024-04-10 14:37:35 +08:00 committed by GitHub
parent b8bf792ce7
commit 851e3e3ef8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -10,6 +10,8 @@ from torch.distributions import Categorical
from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer
from .utils import compression_ratio
from funasr.models.transformer.utils.nets_utils import to_device
if TYPE_CHECKING:
from .model import Whisper
@ -58,6 +60,10 @@ def detect_language(
# x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
if x is None:
x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device) # [n_audio, 1]
else:
x = x.to(mel.device)
logits = model.logits(x[:,:-1], mel)[:, -1]
# collect detected languages; suppress all non-language tokens
mask = torch.ones(logits.shape[-1], dtype=torch.bool)