mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
python runtime
This commit is contained in:
parent
ea8e5bb726
commit
0cf7d386c6
@ -33,14 +33,13 @@ class SenseVoiceSmall:
|
|||||||
self,
|
self,
|
||||||
model_dir: Union[str, Path] = None,
|
model_dir: Union[str, Path] = None,
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
device_id: Union[str, int] = "-1",
|
|
||||||
plot_timestamp_to: str = "",
|
plot_timestamp_to: str = "",
|
||||||
quantize: bool = False,
|
quantize: bool = False,
|
||||||
intra_op_num_threads: int = 4,
|
intra_op_num_threads: int = 4,
|
||||||
cache_dir: str = None,
|
cache_dir: str = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
self.device = kwargs.get("device", "cpu")
|
||||||
if not Path(model_dir).exists():
|
if not Path(model_dir).exists():
|
||||||
try:
|
try:
|
||||||
from modelscope.hub.snapshot_download import snapshot_download
|
from modelscope.hub.snapshot_download import snapshot_download
|
||||||
@ -99,10 +98,10 @@ class SenseVoiceSmall:
|
|||||||
end_idx = min(waveform_nums, beg_idx + self.batch_size)
|
end_idx = min(waveform_nums, beg_idx + self.batch_size)
|
||||||
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
|
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
|
||||||
ctc_logits, encoder_out_lens = self.ort_infer(
|
ctc_logits, encoder_out_lens = self.ort_infer(
|
||||||
torch.Tensor(feats),
|
torch.Tensor(feats).to(self.device),
|
||||||
torch.Tensor(feats_len),
|
torch.Tensor(feats_len).to(self.device),
|
||||||
torch.tensor([language]),
|
torch.tensor([language]).to(self.device),
|
||||||
torch.tensor([textnorm]),
|
torch.tensor([textnorm]).to(self.device),
|
||||||
)
|
)
|
||||||
# support batch_size=1 only currently
|
# support batch_size=1 only currently
|
||||||
x = ctc_logits[0, : encoder_out_lens[0].item(), :]
|
x = ctc_logits[0, : encoder_out_lens[0].item(), :]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user