mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Gcf (#1605)
* 修复无法预测nospeech标签的问题 * 修复prompt存储的设备的问题 --------- Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com> Co-authored-by: zhifu gao <zhifu.gzf@alibaba-inc.com>
This commit is contained in:
parent
b8bf792ce7
commit
851e3e3ef8
@ -10,6 +10,8 @@ from torch.distributions import Categorical
|
||||
from .audio import CHUNK_LENGTH
|
||||
from .tokenizer import Tokenizer, get_tokenizer
|
||||
from .utils import compression_ratio
|
||||
from funasr.models.transformer.utils.nets_utils import to_device
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
@ -58,6 +60,10 @@ def detect_language(
|
||||
# x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
|
||||
if x is None:
|
||||
x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device) # [n_audio, 1]
|
||||
|
||||
else:
|
||||
x = x.to(mel.device)
|
||||
|
||||
logits = model.logits(x[:,:-1], mel)[:, -1]
|
||||
# collect detected languages; suppress all non-language tokens
|
||||
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user