mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Gcf (#1605)
* 修复无法预测nospeech标签的问题 * 修复prompt存储的设备的问题 --------- Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com> Co-authored-by: zhifu gao <zhifu.gzf@alibaba-inc.com>
This commit is contained in:
parent
b8bf792ce7
commit
851e3e3ef8
@ -10,6 +10,8 @@ from torch.distributions import Categorical
|
|||||||
from .audio import CHUNK_LENGTH
|
from .audio import CHUNK_LENGTH
|
||||||
from .tokenizer import Tokenizer, get_tokenizer
|
from .tokenizer import Tokenizer, get_tokenizer
|
||||||
from .utils import compression_ratio
|
from .utils import compression_ratio
|
||||||
|
from funasr.models.transformer.utils.nets_utils import to_device
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .model import Whisper
|
from .model import Whisper
|
||||||
@ -58,6 +60,10 @@ def detect_language(
|
|||||||
# x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
|
# x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
|
||||||
if x is None:
|
if x is None:
|
||||||
x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device) # [n_audio, 1]
|
x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device) # [n_audio, 1]
|
||||||
|
|
||||||
|
else:
|
||||||
|
x = x.to(mel.device)
|
||||||
|
|
||||||
logits = model.logits(x[:,:-1], mel)[:, -1]
|
logits = model.logits(x[:,:-1], mel)[:, -1]
|
||||||
# collect detected languages; suppress all non-language tokens
|
# collect detected languages; suppress all non-language tokens
|
||||||
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user