amp pipeline

This commit is contained in:
游雁 2023-02-16 18:54:13 +08:00
parent 544b798b32
commit 6023271be6
2 changed files with 163 additions and 12 deletions

View File

@ -41,16 +41,7 @@ from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
header_colors = '\033[95m'
end_colors = '\033[0m'
global_asr_language: str = 'zh-cn'
global_sample_rate: Union[int, Dict[Any, int]] = {
'audio_fs': 16000,
'model_fs': 16000
}
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
class Speech2Text:
@ -346,6 +337,160 @@ class Speech2Text:
# assert check_return_type(results)
return results
class Speech2TextExport(torch.nn.Module):
"""Speech2TextExport class
"""
def __init__(
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
bpemodel: str = None,
device: str = "cpu",
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
dtype: str = "float32",
beam_size: int = 20,
ctc_weight: float = 0.5,
lm_weight: float = 1.0,
ngram_weight: float = 0.9,
penalty: float = 0.0,
nbest: int = 1,
frontend_conf: dict = None,
hotword_list_or_file: str = None,
**kwargs,
):
# 1. Build ASR model
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
token_list = asr_model.token_list
logging.info(f"Decoding device={device}, dtype={dtype}")
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
if token_type is None:
tokenizer = None
elif token_type == "bpe":
if bpemodel is not None:
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
else:
tokenizer = None
else:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
# self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
self.device = device
self.dtype = dtype
self.nbest = nbest
self.frontend = frontend
model = Paraformer_export(asr_model, onnx=False)
self.asr_model = model
@torch.no_grad()
def forward(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
):
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths)
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
self.asr_model.frontend = None
else:
feats = speech
feats_len = speech_lengths
enc_len_batch_total = feats_len.sum()
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
batch = {"speech": feats, "speech_lengths": feats_len}
# a. To device
batch = to_device(batch, device=self.device)
decoder_outs = self.asr_model(**batch)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
results = []
b, n, d = decoder_out.size()
for i in range(b):
am_scores = decoder_out[i, :ys_pad_lens[i], :]
yseq = am_scores.argmax(dim=-1)
score = am_scores.max(dim=-1)[0]
score = torch.sum(score, dim=-1)
# pad with mask tokens to ensure compatibility with sos/eos tokens
yseq = torch.tensor(
yseq.tolist(), device=yseq.device
)
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
return results
def inference(
maxlenratio: float,
@ -454,9 +599,11 @@ def inference_modelscope(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
export_mode = False
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
export_mode = param_dict.get("export_mode", False)
else:
hotword_list_or_file = None
@ -490,7 +637,10 @@ def inference_modelscope(
nbest=nbest,
hotword_list_or_file=hotword_list_or_file,
)
speech2text = Speech2Text(**speech2text_kwargs)
if export_mode:
speech2text = Speech2TextExport(**speech2text_kwargs)
else:
speech2text = Speech2Text(**speech2text_kwargs)
def _forward(
data_path_and_name_and_type,

View File

@ -44,6 +44,7 @@ class ASRModelExportParaformer:
model,
self.export_config,
)
model.eval()
# self._export_onnx(model, verbose, export_dir)
if self.onnx:
self._export_onnx(model, verbose, export_dir)