mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
ONNX and torchscript export for sensevoice
This commit is contained in:
parent
340c55838b
commit
2ae59b6ce0
15
examples/industrial_data_pretraining/sense_voice/export.py
Normal file
15
examples/industrial_data_pretraining/sense_voice/export.py
Normal file
@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

from funasr import AutoModel

# ModelScope hub identifier of the SenseVoice small model.
model_dir = "iic/SenseVoiceSmall"

# Load the model onto the first GPU.
model = AutoModel(model=model_dir, device="cuda:0")

# Export the loaded model to ONNX; quantization is left off here.
res = model.export(type="onnx", quantize=False)
|
||||
@ -5,31 +5,20 @@
|
||||
|
||||
import types
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from funasr.register import tables
|
||||
from funasr.utils.torch_function import sequence_mask
|
||||
|
||||
|
||||
def export_rebuild_model(model, **kwargs):
|
||||
model.device = kwargs.get("device")
|
||||
is_onnx = kwargs.get("type", "onnx") == "onnx"
|
||||
# encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
|
||||
# model.encoder = encoder_class(model.encoder, onnx=is_onnx)
|
||||
|
||||
from funasr.utils.torch_function import sequence_mask
|
||||
|
||||
model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)
|
||||
|
||||
model.forward = types.MethodType(export_forward, model)
|
||||
model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
|
||||
model.export_input_names = types.MethodType(export_input_names, model)
|
||||
model.export_output_names = types.MethodType(export_output_names, model)
|
||||
model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
|
||||
model.export_name = types.MethodType(export_name, model)
|
||||
|
||||
model.export_name = "model"
|
||||
return model
|
||||
|
||||
|
||||
def export_forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
@ -38,33 +27,29 @@ def export_forward(
|
||||
textnorm: torch.Tensor,
|
||||
**kwargs,
|
||||
):
|
||||
speech = speech.to(device=kwargs["device"])
|
||||
speech_lengths = speech_lengths.to(device=kwargs["device"])
|
||||
|
||||
language_query = self.embed(language).to(speech.device)
|
||||
|
||||
textnorm_query = self.embed(textnorm).to(speech.device)
|
||||
# speech = speech.to(device="cuda")
|
||||
# speech_lengths = speech_lengths.to(device="cuda")
|
||||
language_query = self.embed(language.to(speech.device)).unsqueeze(1)
|
||||
textnorm_query = self.embed(textnorm.to(speech.device)).unsqueeze(1)
|
||||
|
||||
speech = torch.cat((textnorm_query, speech), dim=1)
|
||||
speech_lengths += 1
|
||||
|
||||
|
||||
event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
|
||||
speech.size(0), 1, 1
|
||||
)
|
||||
input_query = torch.cat((language_query, event_emo_query), dim=1)
|
||||
speech = torch.cat((input_query, speech), dim=1)
|
||||
speech_lengths += 3
|
||||
|
||||
# Encoder
|
||||
encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
|
||||
|
||||
speech_lengths_new = speech_lengths + 4
|
||||
encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths_new)
|
||||
|
||||
if isinstance(encoder_out, tuple):
|
||||
encoder_out = encoder_out[0]
|
||||
|
||||
# c. Passed the encoder result and the beam search
|
||||
ctc_logits = self.ctc.log_softmax(encoder_out)
|
||||
|
||||
ctc_logits = self.ctc.ctc_lo(encoder_out)
|
||||
|
||||
return ctc_logits, encoder_out_lens
|
||||
|
||||
|
||||
def export_dummy_inputs(self):
|
||||
speech = torch.randn(2, 30, 560)
|
||||
speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
|
||||
@ -72,26 +57,21 @@ def export_dummy_inputs(self):
|
||||
textnorm = torch.tensor([15, 15], dtype=torch.int32)
|
||||
return (speech, speech_lengths, language, textnorm)
|
||||
|
||||
|
||||
def export_input_names(self):
    """Positional names of the exported graph's input tensors."""
    return ["speech", "speech_lengths", "language", "textnorm"]
|
||||
|
||||
|
||||
def export_output_names(self):
    """Positional names of the exported graph's output tensors."""
    return ["ctc_logits", "encoder_out_lens"]
|
||||
|
||||
|
||||
def export_dynamic_axes(self):
|
||||
return {
|
||||
"speech": {0: "batch_size", 1: "feats_length"},
|
||||
"speech_lengths": {
|
||||
0: "batch_size",
|
||||
},
|
||||
"logits": {0: "batch_size", 1: "logits_length"},
|
||||
"speech_lengths": {0: "batch_size"},
|
||||
"language": {0: "batch_size"},
|
||||
"textnorm": {0: "batch_size"},
|
||||
"ctc_logits": {0: "batch_size", 1: "logits_length"},
|
||||
"encoder_out_lens": {0: "batch_size"},
|
||||
}
|
||||
|
||||
|
||||
def export_name(self):
    """Default file name of the exported ONNX artifact."""
    return "model.onnx"
|
||||
|
||||
@ -54,7 +54,10 @@ def _onnx(
|
||||
|
||||
verbose = kwargs.get("verbose", False)
|
||||
|
||||
export_name = model.export_name + ".onnx"
|
||||
if isinstance(model.export_name, str):
|
||||
export_name = model.export_name + ".onnx"
|
||||
else:
|
||||
export_name = model.export_name()
|
||||
model_path = os.path.join(export_dir, export_name)
|
||||
torch.onnx.export(
|
||||
model,
|
||||
@ -72,35 +75,38 @@ def _onnx(
|
||||
import onnx
|
||||
|
||||
quant_model_path = model_path.replace(".onnx", "_quant.onnx")
|
||||
if not os.path.exists(quant_model_path):
|
||||
onnx_model = onnx.load(model_path)
|
||||
nodes = [n.name for n in onnx_model.graph.node]
|
||||
nodes_to_exclude = [
|
||||
m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m
|
||||
]
|
||||
quantize_dynamic(
|
||||
model_input=model_path,
|
||||
model_output=quant_model_path,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
per_channel=True,
|
||||
reduce_range=False,
|
||||
weight_type=QuantType.QUInt8,
|
||||
nodes_to_exclude=nodes_to_exclude,
|
||||
)
|
||||
onnx_model = onnx.load(model_path)
|
||||
nodes = [n.name for n in onnx_model.graph.node]
|
||||
nodes_to_exclude = [
|
||||
m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m
|
||||
]
|
||||
print("Quantizing model from {} to {}".format(model_path, quant_model_path))
|
||||
quantize_dynamic(
|
||||
model_input=model_path,
|
||||
model_output=quant_model_path,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
per_channel=True,
|
||||
reduce_range=False,
|
||||
weight_type=QuantType.QUInt8,
|
||||
nodes_to_exclude=nodes_to_exclude,
|
||||
)
|
||||
|
||||
|
||||
def _torchscripts(model, path, device="cuda"):
|
||||
dummy_input = model.export_dummy_inputs()
|
||||
|
||||
|
||||
if device == "cuda":
|
||||
model = model.cuda()
|
||||
if isinstance(dummy_input, torch.Tensor):
|
||||
dummy_input = dummy_input.cuda()
|
||||
else:
|
||||
dummy_input = tuple([i.cuda() for i in dummy_input])
|
||||
|
||||
|
||||
model_script = torch.jit.trace(model, dummy_input)
|
||||
model_script.save(os.path.join(path, f"{model.export_name}.torchscript"))
|
||||
if isinstance(model.export_name, str):
|
||||
model_script.save(os.path.join(path, f"{model.export_name}".replace("onnx", "torchscript")))
|
||||
else:
|
||||
model_script.save(os.path.join(path, f"{model.export_name()}".replace("onnx", "torchscript")))
|
||||
|
||||
|
||||
def _bladedisc_opt(model, model_inputs, enable_fp16=True):
|
||||
|
||||
38
runtime/python/libtorch/demo_sensevoicesmall.py
Normal file
38
runtime/python/libtorch/demo_sensevoicesmall.py
Normal file
@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
#
# Demo: run the TorchScript-exported SenseVoiceSmall model on one wav file.

import os
import torch
from pathlib import Path
from funasr import AutoModel
from funasr_torch import SenseVoiceSmallTorchScript as SenseVoiceSmall
from funasr.utils.postprocess_utils import rich_transcription_postprocess


model_dir = "iic/SenseVoiceSmall"
model = AutoModel(
    model=model_dir,
    device="cuda:0",
)

# res = model.export(type="torchscript", quantize=False)

# export model init — default ModelScope cache layout under the user's home dir.
model_path = "{}/.cache/modelscope/hub/{}".format(Path.home(), model_dir)
model_bin = SenseVoiceSmall(model_path)

# build tokenizer: optional — decoding falls back to raw token ids without it.
try:
    from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer

    tokenizer = SentencepiecesTokenizer(
        bpemodel=os.path.join(model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model")
    )
except Exception:
    # Was a bare `except:`, which also swallowed SystemExit/KeyboardInterrupt;
    # Exception keeps the best-effort behavior without masking interrupts.
    tokenizer = None

# inference
# TODO(review): hard-coded developer-local path — replace with a portable sample.
wav_or_scp = "/Users/shixian/Downloads/asr_example_hotword.wav"
language_list = [0]
textnorm_list = [15]
res = model_bin(wav_or_scp, language_list, textnorm_list, tokenizer=tokenizer)
print([rich_transcription_postprocess(i) for i in res])
|
||||
@ -1,2 +1,3 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
from .paraformer_bin import Paraformer
|
||||
from .sensevoice_bin import SenseVoiceSmallTorchScript
|
||||
|
||||
130
runtime/python/libtorch/funasr_torch/sensevoice_bin.py
Normal file
130
runtime/python/libtorch/funasr_torch/sensevoice_bin.py
Normal file
@ -0,0 +1,130 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- encoding: utf-8 -*-
|
||||
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
|
||||
import torch
|
||||
import os.path
|
||||
import librosa
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Tuple
|
||||
|
||||
from .utils.utils import (
|
||||
CharTokenizer,
|
||||
get_logger,
|
||||
read_yaml,
|
||||
)
|
||||
from .utils.frontend import WavFrontend
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
|
||||
class SenseVoiceSmallTorchScript:
    """TorchScript runtime wrapper for the exported SenseVoiceSmall model.

    Loads the traced model and its fbank/LFR-CMVN frontend from *model_dir*
    and greedily decodes the CTC logits it produces.

    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        model_dir: Union[str, Path] = None,
        batch_size: int = 1,
        device_id: Union[str, int] = "-1",
        plot_timestamp_to: str = "",
        quantize: bool = False,
        intra_op_num_threads: int = 4,
        cache_dir: str = None,
        **kwargs,
    ):
        # Pick the quantized artifact when requested, the fp32 one otherwise.
        script_name = "model_quant.torchscript" if quantize else "model.torchscript"
        model_file = os.path.join(model_dir, script_name)

        config = read_yaml(os.path.join(model_dir, "config.yaml"))
        # token_list = os.path.join(model_dir, "tokens.json")
        # with open(token_list, "r", encoding="utf-8") as f:
        #     token_list = json.load(f)

        # self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()
        config["frontend_conf"]['cmvn_file'] = os.path.join(model_dir, "am.mvn")
        self.frontend = WavFrontend(**config["frontend_conf"])
        self.ort_infer = torch.jit.load(model_file)
        self.batch_size = batch_size
        self.blank_id = 0

    def __call__(self,
                 wav_content: Union[str, np.ndarray, List[str]],
                 language: List,
                 textnorm: List,
                 tokenizer=None,
                 **kwargs) -> List:
        """Recognize *wav_content*; returns text when *tokenizer* is given,
        raw token-id lists otherwise."""
        sample_rate = self.frontend.opts.frame_opts.samp_freq
        waveforms = self.load_data(wav_content, sample_rate)
        results = []
        for start in range(0, len(waveforms), self.batch_size):
            stop = min(len(waveforms), start + self.batch_size)
            feats, feats_len = self.extract_feat(waveforms[start:stop])
            ctc_logits, encoder_out_lens = self.ort_infer(
                torch.Tensor(feats),
                torch.Tensor(feats_len),
                torch.tensor(language),
                torch.tensor(textnorm),
            )
            # Greedy CTC decoding; only batch_size == 1 is supported so far.
            frames = ctc_logits[0, : encoder_out_lens[0].item(), :]
            yseq = torch.unique_consecutive(frames.argmax(dim=-1), dim=-1)
            token_int = yseq[yseq != self.blank_id].tolist()

            if tokenizer is not None:
                results.append(tokenizer.tokens2text(token_int))
            else:
                results.append(token_int)
        return results

    def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        """Normalize *wav_content* (array, path, or list of paths) into a
        list of waveform arrays; paths are loaded at sample rate *fs*."""
        def _read(path: str) -> np.ndarray:
            audio, _ = librosa.load(path, sr=fs)
            return audio

        if isinstance(wav_content, np.ndarray):
            return [wav_content]
        if isinstance(wav_content, str):
            return [_read(wav_content)]
        if isinstance(wav_content, list):
            return [_read(p) for p in wav_content]
        raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")

    def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        """Compute fbank + LFR/CMVN features for each waveform and pad them
        to a common length; returns (feats, int32 lengths)."""
        feats = []
        feats_len = []
        for waveform in waveform_list:
            fbank, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(fbank)
            feats.append(feat)
            feats_len.append(feat_len)

        padded = self.pad_feats(feats, np.max(feats_len))
        return padded, np.array(feats_len).astype(np.int32)

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        """Zero-pad each feature matrix along time to *max_feat_len* and
        stack them into one float32 batch array."""
        padded = [
            np.pad(feat, ((0, max_feat_len - feat.shape[0]), (0, 0)), "constant", constant_values=0)
            for feat in feats
        ]
        return np.array(padded).astype(np.float32)
|
||||
|
||||
38
runtime/python/onnxruntime/demo_sencevoicesmall.py
Normal file
38
runtime/python/onnxruntime/demo_sencevoicesmall.py
Normal file
@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
#
# Demo: export SenseVoiceSmall to ONNX and run it on one wav file.

import os
import torch
from pathlib import Path
from funasr import AutoModel
from funasr_onnx import SenseVoiceSmallONNX as SenseVoiceSmall
from funasr.utils.postprocess_utils import rich_transcription_postprocess


model_dir = "iic/SenseVoiceSmall"
model = AutoModel(
    model=model_dir,
    device="cuda:0",
)

res = model.export(type="onnx", quantize=False)

# export model init — default ModelScope cache layout under the user's home dir.
model_path = "{}/.cache/modelscope/hub/{}".format(Path.home(), model_dir)
model_bin = SenseVoiceSmall(model_path)

# build tokenizer: optional — decoding falls back to raw token ids without it.
try:
    from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer

    tokenizer = SentencepiecesTokenizer(
        bpemodel=os.path.join(model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model")
    )
except Exception:
    # Was a bare `except:`, which also swallowed SystemExit/KeyboardInterrupt;
    # Exception keeps the best-effort behavior without masking interrupts.
    tokenizer = None

# inference
# TODO(review): hard-coded developer-local path — replace with a portable sample.
wav_or_scp = "/Users/shixian/Downloads/asr_example_hotword.wav"
language_list = [0]
textnorm_list = [15]
res = model_bin(wav_or_scp, language_list, textnorm_list, tokenizer=tokenizer)
print([rich_transcription_postprocess(i) for i in res])
|
||||
@ -4,3 +4,4 @@ from .vad_bin import Fsmn_vad
|
||||
from .vad_bin import Fsmn_vad_online
|
||||
from .punc_bin import CT_Transformer
|
||||
from .punc_bin import CT_Transformer_VadRealtime
|
||||
from .sensevoice_bin import SenseVoiceSmallONNX
|
||||
|
||||
145
runtime/python/onnxruntime/funasr_onnx/sensevoice_bin.py
Normal file
145
runtime/python/onnxruntime/funasr_onnx/sensevoice_bin.py
Normal file
@ -0,0 +1,145 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- encoding: utf-8 -*-
|
||||
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
|
||||
import torch
|
||||
import os.path
|
||||
import librosa
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Tuple
|
||||
|
||||
from .utils.utils import (
|
||||
CharTokenizer,
|
||||
Hypothesis,
|
||||
ONNXRuntimeError,
|
||||
OrtInferSession,
|
||||
TokenIDConverter,
|
||||
get_logger,
|
||||
read_yaml,
|
||||
)
|
||||
from .utils.frontend import WavFrontend
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
|
||||
class SenseVoiceSmallONNX:
    """ONNX Runtime wrapper for the exported SenseVoiceSmall model.

    Loads the ONNX graph and its fbank/LFR-CMVN frontend from *model_dir*
    and greedily decodes the CTC logits it produces.

    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        model_dir: Union[str, Path] = None,
        batch_size: int = 1,
        device_id: Union[str, int] = "-1",
        plot_timestamp_to: str = "",
        quantize: bool = False,
        intra_op_num_threads: int = 4,
        cache_dir: str = None,
        **kwargs,
    ):
        # Pick the quantized artifact when requested, the fp32 one otherwise.
        onnx_name = "model_quant.onnx" if quantize else "model.onnx"
        model_file = os.path.join(model_dir, onnx_name)

        config = read_yaml(os.path.join(model_dir, "config.yaml"))
        # token_list = os.path.join(model_dir, "tokens.json")
        # with open(token_list, "r", encoding="utf-8") as f:
        #     token_list = json.load(f)

        # self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()
        config["frontend_conf"]['cmvn_file'] = os.path.join(model_dir, "am.mvn")
        self.frontend = WavFrontend(**config["frontend_conf"])
        self.ort_infer = OrtInferSession(
            model_file, device_id, intra_op_num_threads=intra_op_num_threads
        )
        self.batch_size = batch_size
        self.blank_id = 0

    def __call__(self,
                 wav_content: Union[str, np.ndarray, List[str]],
                 language: List,
                 textnorm: List,
                 tokenizer=None,
                 **kwargs) -> List:
        """Recognize *wav_content*; returns text when *tokenizer* is given,
        raw token-id lists otherwise."""
        sample_rate = self.frontend.opts.frame_opts.samp_freq
        waveforms = self.load_data(wav_content, sample_rate)
        results = []
        for start in range(0, len(waveforms), self.batch_size):
            stop = min(len(waveforms), start + self.batch_size)
            feats, feats_len = self.extract_feat(waveforms[start:stop])
            ctc_logits, encoder_out_lens = self.infer(
                feats,
                feats_len,
                np.array(language, dtype=np.int32),
                np.array(textnorm, dtype=np.int32),
            )
            # back to torch.Tensor for the decoding ops below
            ctc_logits = torch.from_numpy(ctc_logits).float()
            # Greedy CTC decoding; only batch_size == 1 is supported so far.
            frames = ctc_logits[0, : encoder_out_lens[0].item(), :]
            yseq = torch.unique_consecutive(frames.argmax(dim=-1), dim=-1)
            token_int = yseq[yseq != self.blank_id].tolist()

            if tokenizer is not None:
                results.append(tokenizer.tokens2text(token_int))
            else:
                results.append(token_int)
        return results

    def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        """Normalize *wav_content* (array, path, or list of paths) into a
        list of waveform arrays; paths are loaded at sample rate *fs*."""
        def _read(path: str) -> np.ndarray:
            audio, _ = librosa.load(path, sr=fs)
            return audio

        if isinstance(wav_content, np.ndarray):
            return [wav_content]
        if isinstance(wav_content, str):
            return [_read(wav_content)]
        if isinstance(wav_content, list):
            return [_read(p) for p in wav_content]
        raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")

    def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        """Compute fbank + LFR/CMVN features for each waveform and pad them
        to a common length; returns (feats, int32 lengths)."""
        feats = []
        feats_len = []
        for waveform in waveform_list:
            fbank, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(fbank)
            feats.append(feat)
            feats_len.append(feat_len)

        padded = self.pad_feats(feats, np.max(feats_len))
        return padded, np.array(feats_len).astype(np.int32)

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        """Zero-pad each feature matrix along time to *max_feat_len* and
        stack them into one float32 batch array."""
        padded = [
            np.pad(feat, ((0, max_feat_len - feat.shape[0]), (0, 0)), "constant", constant_values=0)
            for feat in feats
        ]
        return np.array(padded).astype(np.float32)

    def infer(self,
              feats: np.ndarray,
              feats_len: np.ndarray,
              language: np.ndarray,
              textnorm: np.ndarray,) -> Tuple[np.ndarray, np.ndarray]:
        """Run one forward pass of the ONNX session on the prepared inputs."""
        return self.ort_infer([feats, feats_len, language, textnorm])
|
||||
Loading…
Reference in New Issue
Block a user