From 4ac582341c5f88fe30bc47225cf9811cc1233983 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Mon, 15 May 2023 00:32:33 +0800 Subject: [PATCH] inference --- funasr/bin/asr_infer.py | 1270 ++++++++++++++++++++++++++++ funasr/bin/asr_inference.py | 65 +- funasr/bin/asr_inference_launch.py | 966 ++++++++++++++++++++- 3 files changed, 2196 insertions(+), 105 deletions(-) create mode 100644 funasr/bin/asr_infer.py diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py new file mode 100644 index 000000000..dce9ee009 --- /dev/null +++ b/funasr/bin/asr_infer.py @@ -0,0 +1,1270 @@ +#!/usr/bin/env python3 +import argparse +import logging +import sys +import time +import copy +import os +import codecs +import tempfile +import requests +from pathlib import Path +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union +from typing import Dict +from typing import Any +from typing import List + +import numpy as np +import torch +from typeguard import check_argument_types +from typeguard import check_return_type +from funasr.fileio.datadir_writer import DatadirWriter +from funasr.modules.beam_search.beam_search import BeamSearch +# from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch + +from funasr.modules.beam_search.beam_search import Hypothesis +from funasr.modules.scorers.ctc import CTCPrefixScorer +from funasr.modules.scorers.length_bonus import LengthBonus +from funasr.modules.subsampling import TooShortUttError +from funasr.tasks.asr import ASRTask +from funasr.tasks.lm import LMTask +from funasr.text.build_tokenizer import build_tokenizer +from funasr.text.token_id_converter import TokenIDConverter +from funasr.torch_utils.device_funcs import to_device +from funasr.torch_utils.set_all_random_seed import set_all_random_seed +from funasr.utils import config_argparse +from funasr.utils.cli_utils import get_commandline_args +from funasr.utils.types import str2bool +from funasr.utils.types import str2triple_str +from funasr.utils.types import str_or_none +from funasr.utils import asr_utils, wav_utils, postprocess_utils +from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline +from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer +from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer +from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export +from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard +from funasr.bin.tp_inference import SpeechText2Timestamp +from funasr.bin.vad_inference import Speech2VadSegment +from funasr.bin.punctuation_infer import Text2Punc +from funasr.utils.vad_utils import slice_padding_fbank +from funasr.tasks.vad import VADTask +from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard + + +class Speech2Text: + """Speech2Text class + + Examples: + >>> import soundfile + >>> speech2text = Speech2Text("asr_config.yml", "asr.pb") + >>> audio, rate = soundfile.read("speech.wav") + >>> speech2text(audio) + [(text, token, token_int, hypothesis object), ...] 
+ + """ + + def __init__( + self, + asr_train_config: Union[Path, str] = None, + asr_model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = "cpu", + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + batch_size: int = 1, + dtype: str = "float32", + beam_size: int = 20, + ctc_weight: float = 0.5, + lm_weight: float = 1.0, + ngram_weight: float = 0.9, + penalty: float = 0.0, + nbest: int = 1, + streaming: bool = False, + frontend_conf: dict = None, + **kwargs, + ): + assert check_argument_types() + + # 1. Build ASR model + scorers = {} + asr_model, asr_train_args = ASRTask.build_model_from_file( + asr_train_config, asr_model_file, cmvn_file, device + ) + frontend = None + if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: + if asr_train_args.frontend == 'wav_frontend': + frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) + else: + from funasr.tasks.asr import frontend_choices + frontend_class = frontend_choices.get_class(asr_train_args.frontend) + frontend = frontend_class(**asr_train_args.frontend_conf).eval() + + logging.info("asr_model: {}".format(asr_model)) + logging.info("asr_train_args: {}".format(asr_train_args)) + asr_model.to(dtype=getattr(torch, dtype)).eval() + + decoder = asr_model.decoder + + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + token_list = asr_model.token_list + scorers.update( + decoder=decoder, + ctc=ctc, + length_bonus=LengthBonus(len(token_list)), + ) + + # 2. Build Language model + if lm_train_config is not None: + lm, lm_train_args = LMTask.build_model_from_file( + lm_train_config, lm_file, None, device + ) + scorers["lm"] = lm.lm + + # 3. Build ngram model + # ngram is not supported now + ngram = None + scorers["ngram"] = ngram + + # 4. Build BeamSearch object + # transducer is not supported now + beam_search_transducer = None + from funasr.modules.beam_search.beam_search import BeamSearch + + weights = dict( + decoder=1.0 - ctc_weight, + ctc=ctc_weight, + lm=lm_weight, + ngram=ngram_weight, + length_bonus=penalty, + ) + beam_search = BeamSearch( + beam_size=beam_size, + weights=weights, + scorers=scorers, + sos=asr_model.sos, + eos=asr_model.eos, + vocab_size=len(token_list), + token_list=token_list, + pre_beam_score_key=None if ctc_weight == 1.0 else "full", + ) + + # 5. [Optional] Build Text converter: e.g. 
bpe-sym -> Text + if token_type is None: + token_type = asr_train_args.token_type + if bpemodel is None: + bpemodel = asr_train_args.bpemodel + + if token_type is None: + tokenizer = None + elif token_type == "bpe": + if bpemodel is not None: + tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) + else: + tokenizer = None + else: + tokenizer = build_tokenizer(token_type=token_type) + converter = TokenIDConverter(token_list=token_list) + logging.info(f"Text tokenizer: {tokenizer}") + + self.asr_model = asr_model + self.asr_train_args = asr_train_args + self.converter = converter + self.tokenizer = tokenizer + self.beam_search = beam_search + self.beam_search_transducer = beam_search_transducer + self.maxlenratio = maxlenratio + self.minlenratio = minlenratio + self.device = device + self.dtype = dtype + self.nbest = nbest + self.frontend = frontend + + @torch.no_grad() + def __call__( + self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None + ) -> List[ + Tuple[ + Optional[str], + List[str], + List[int], + Union[Hypothesis], + ] + ]: + """Inference + + Args: + speech: Input speech data + Returns: + text, token, token_int, hyp + + """ + assert check_argument_types() + + # Input as audio signal + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + + if self.frontend is not None: + feats, feats_len = self.frontend.forward(speech, speech_lengths) + feats = to_device(feats, device=self.device) + feats_len = feats_len.int() + self.asr_model.frontend = None + else: + feats = speech + feats_len = speech_lengths + lfr_factor = max(1, (feats.size()[-1] // 80) - 1) + batch = {"speech": feats, "speech_lengths": feats_len} + + # a. To device + batch = to_device(batch, device=self.device) + + # b. Forward Encoder + enc, _ = self.asr_model.encode(**batch) + if isinstance(enc, tuple): + enc = enc[0] + assert len(enc) == 1, len(enc) + + # c. Passed the encoder result and the beam search + nbest_hyps = self.beam_search( + x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio + ) + + nbest_hyps = nbest_hyps[: self.nbest] + + results = [] + for hyp in nbest_hyps: + assert isinstance(hyp, (Hypothesis)), type(hyp) + + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1:last_pos] + else: + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != 0, token_int)) + + # Change integer-ids to tokens + token = self.converter.ids2tokens(token_int) + + if self.tokenizer is not None: + text = self.tokenizer.tokens2text(token) + else: + text = None + results.append((text, token, token_int, hyp)) + + assert check_return_type(results) + return results + + +class Speech2TextParaformer: + """Speech2Text class + + Examples: + >>> import soundfile + >>> speech2text = Speech2TextParaformer("asr_config.yml", "asr.pb") + >>> audio, rate = soundfile.read("speech.wav") + >>> speech2text(audio) + [(text, token, token_int, hypothesis object), ...] 
+ + """ + + def __init__( + self, + asr_train_config: Union[Path, str] = None, + asr_model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = "cpu", + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + dtype: str = "float32", + beam_size: int = 20, + ctc_weight: float = 0.5, + lm_weight: float = 1.0, + ngram_weight: float = 0.9, + penalty: float = 0.0, + nbest: int = 1, + frontend_conf: dict = None, + hotword_list_or_file: str = None, + **kwargs, + ): + assert check_argument_types() + + # 1. Build ASR model + scorers = {} + from funasr.tasks.asr import ASRTaskParaformer as ASRTask + asr_model, asr_train_args = ASRTask.build_model_from_file( + asr_train_config, asr_model_file, cmvn_file, device + ) + frontend = None + if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: + frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) + + logging.info("asr_model: {}".format(asr_model)) + logging.info("asr_train_args: {}".format(asr_train_args)) + asr_model.to(dtype=getattr(torch, dtype)).eval() + + if asr_model.ctc != None: + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + scorers.update( + ctc=ctc + ) + token_list = asr_model.token_list + scorers.update( + length_bonus=LengthBonus(len(token_list)), + ) + + # 2. Build Language model + if lm_train_config is not None: + lm, lm_train_args = LMTask.build_model_from_file( + lm_train_config, lm_file, device + ) + scorers["lm"] = lm.lm + + # 3. Build ngram model + # ngram is not supported now + ngram = None + scorers["ngram"] = ngram + + # 4. Build BeamSearch object + # transducer is not supported now + beam_search_transducer = None + from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch + + weights = dict( + decoder=1.0 - ctc_weight, + ctc=ctc_weight, + lm=lm_weight, + ngram=ngram_weight, + length_bonus=penalty, + ) + beam_search = BeamSearch( + beam_size=beam_size, + weights=weights, + scorers=scorers, + sos=asr_model.sos, + eos=asr_model.eos, + vocab_size=len(token_list), + token_list=token_list, + pre_beam_score_key=None if ctc_weight == 1.0 else "full", + ) + + beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() + for scorer in scorers.values(): + if isinstance(scorer, torch.nn.Module): + scorer.to(device=device, dtype=getattr(torch, dtype)).eval() + + logging.info(f"Decoding device={device}, dtype={dtype}") + + # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text + if token_type is None: + token_type = asr_train_args.token_type + if bpemodel is None: + bpemodel = asr_train_args.bpemodel + + if token_type is None: + tokenizer = None + elif token_type == "bpe": + if bpemodel is not None: + tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) + else: + tokenizer = None + else: + tokenizer = build_tokenizer(token_type=token_type) + converter = TokenIDConverter(token_list=token_list) + logging.info(f"Text tokenizer: {tokenizer}") + + self.asr_model = asr_model + self.asr_train_args = asr_train_args + self.converter = converter + self.tokenizer = tokenizer + + # 6. 
[Optional] Build hotword list from str, local file or url + self.hotword_list = None + self.hotword_list = self.generate_hotwords_list(hotword_list_or_file) + + is_use_lm = lm_weight != 0.0 and lm_file is not None + if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm: + beam_search = None + self.beam_search = beam_search + logging.info(f"Beam_search: {self.beam_search}") + self.beam_search_transducer = beam_search_transducer + self.maxlenratio = maxlenratio + self.minlenratio = minlenratio + self.device = device + self.dtype = dtype + self.nbest = nbest + self.frontend = frontend + self.encoder_downsampling_factor = 1 + if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d": + self.encoder_downsampling_factor = 4 + + @torch.no_grad() + def __call__( + self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None, + begin_time: int = 0, end_time: int = None, + ): + """Inference + + Args: + speech: Input speech data + Returns: + text, token, token_int, hyp + + """ + assert check_argument_types() + + # Input as audio signal + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + + if self.frontend is not None: + feats, feats_len = self.frontend.forward(speech, speech_lengths) + feats = to_device(feats, device=self.device) + feats_len = feats_len.int() + self.asr_model.frontend = None + else: + feats = speech + feats_len = speech_lengths + lfr_factor = max(1, (feats.size()[-1] // 80) - 1) + batch = {"speech": feats, "speech_lengths": feats_len} + + # a. To device + batch = to_device(batch, device=self.device) + + # b. Forward Encoder + enc, enc_len = self.asr_model.encode(**batch) + if isinstance(enc, tuple): + enc = enc[0] + # assert len(enc) == 1, len(enc) + enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor + + predictor_outs = self.asr_model.calc_predictor(enc, enc_len) + pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \ + predictor_outs[2], predictor_outs[3] + pre_token_length = pre_token_length.round().long() + if torch.max(pre_token_length) < 1: + return [] + if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model, NeatContextualParaformer): + if self.hotword_list: + logging.warning("Hotword is given but asr model is not a ContextualParaformer.") + decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length) + decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] + else: + decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list) + decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] + + if isinstance(self.asr_model, BiCifParaformer): + _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len, + pre_token_length) # test no bias cif2 + + results = [] + b, n, d = decoder_out.size() + for i in range(b): + x = enc[i, :enc_len[i], :] + am_scores = decoder_out[i, :pre_token_length[i], :] + if self.beam_search is not None: + nbest_hyps = self.beam_search( + x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio + ) + + nbest_hyps = nbest_hyps[: self.nbest] + else: + yseq = am_scores.argmax(dim=-1) + score = am_scores.max(dim=-1)[0] + score = torch.sum(score, dim=-1) + # pad with mask tokens to ensure compatibility with sos/eos tokens + yseq = torch.tensor( + 
[self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device + ) + nbest_hyps = [Hypothesis(yseq=yseq, score=score)] + + for hyp in nbest_hyps: + assert isinstance(hyp, (Hypothesis)), type(hyp) + + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1:last_pos] + else: + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != 0 and x != 2, token_int)) + + # Change integer-ids to tokens + token = self.converter.ids2tokens(token_int) + + if self.tokenizer is not None: + text = self.tokenizer.tokens2text(token) + else: + text = None + timestamp = [] + if isinstance(self.asr_model, BiCifParaformer): + _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:enc_len[i]*3], + us_peaks[i][:enc_len[i]*3], + copy.copy(token), + vad_offset=begin_time) + results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor)) + + + # assert check_return_type(results) + return results + + def generate_hotwords_list(self, hotword_list_or_file): + # for None + if hotword_list_or_file is None: + hotword_list = None + # for local txt inputs + elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'): + logging.info("Attempting to parse hotwords from local txt...") + hotword_list = [] + hotword_str_list = [] + with codecs.open(hotword_list_or_file, 'r') as fin: + for line in fin.readlines(): + hw = line.strip() + hotword_str_list.append(hw) + hotword_list.append(self.converter.tokens2ids([i for i in hw])) + hotword_list.append([self.asr_model.sos]) + hotword_str_list.append('') + logging.info("Initialized hotword list from file: {}, hotword list: {}." + .format(hotword_list_or_file, hotword_str_list)) + # for url, download and generate txt + elif hotword_list_or_file.startswith('http'): + logging.info("Attempting to parse hotwords from url...") + work_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(work_dir): + os.makedirs(work_dir) + text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file)) + local_file = requests.get(hotword_list_or_file) + open(text_file_path, "wb").write(local_file.content) + hotword_list_or_file = text_file_path + hotword_list = [] + hotword_str_list = [] + with codecs.open(hotword_list_or_file, 'r') as fin: + for line in fin.readlines(): + hw = line.strip() + hotword_str_list.append(hw) + hotword_list.append(self.converter.tokens2ids([i for i in hw])) + hotword_list.append([self.asr_model.sos]) + hotword_str_list.append('') + logging.info("Initialized hotword list from file: {}, hotword list: {}." + .format(hotword_list_or_file, hotword_str_list)) + # for text str input + elif not hotword_list_or_file.endswith('.txt'): + logging.info("Attempting to parse hotwords as str...") + hotword_list = [] + hotword_str_list = [] + for hw in hotword_list_or_file.strip().split(): + hotword_str_list.append(hw) + hotword_list.append(self.converter.tokens2ids([i for i in hw])) + hotword_list.append([self.asr_model.sos]) + hotword_str_list.append('') + logging.info("Hotword list: {}.".format(hotword_str_list)) + else: + hotword_list = None + return hotword_list + +class Speech2TextParaformerOnline: + """Speech2Text class + + Examples: + >>> import soundfile + >>> speech2text = Speech2TextParaformerOnline("asr_config.yml", "asr.pth") + >>> audio, rate = soundfile.read("speech.wav") + >>> speech2text(audio) + [(text, token, token_int, hypothesis object), ...] 
+ + """ + + def __init__( + self, + asr_train_config: Union[Path, str] = None, + asr_model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = "cpu", + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + dtype: str = "float32", + beam_size: int = 20, + ctc_weight: float = 0.5, + lm_weight: float = 1.0, + ngram_weight: float = 0.9, + penalty: float = 0.0, + nbest: int = 1, + frontend_conf: dict = None, + hotword_list_or_file: str = None, + **kwargs, + ): + assert check_argument_types() + + # 1. Build ASR model + scorers = {} + asr_model, asr_train_args = ASRTask.build_model_from_file( + asr_train_config, asr_model_file, cmvn_file, device + ) + frontend = None + if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: + frontend = WavFrontendOnline(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) + + logging.info("asr_model: {}".format(asr_model)) + logging.info("asr_train_args: {}".format(asr_train_args)) + asr_model.to(dtype=getattr(torch, dtype)).eval() + + if asr_model.ctc != None: + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + scorers.update( + ctc=ctc + ) + token_list = asr_model.token_list + scorers.update( + length_bonus=LengthBonus(len(token_list)), + ) + + # 2. Build Language model + if lm_train_config is not None: + lm, lm_train_args = LMTask.build_model_from_file( + lm_train_config, lm_file, device + ) + scorers["lm"] = lm.lm + + # 3. Build ngram model + # ngram is not supported now + ngram = None + scorers["ngram"] = ngram + + # 4. Build BeamSearch object + # transducer is not supported now + beam_search_transducer = None + from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch + + weights = dict( + decoder=1.0 - ctc_weight, + ctc=ctc_weight, + lm=lm_weight, + ngram=ngram_weight, + length_bonus=penalty, + ) + beam_search = BeamSearch( + beam_size=beam_size, + weights=weights, + scorers=scorers, + sos=asr_model.sos, + eos=asr_model.eos, + vocab_size=len(token_list), + token_list=token_list, + pre_beam_score_key=None if ctc_weight == 1.0 else "full", + ) + + beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() + for scorer in scorers.values(): + if isinstance(scorer, torch.nn.Module): + scorer.to(device=device, dtype=getattr(torch, dtype)).eval() + + logging.info(f"Decoding device={device}, dtype={dtype}") + + # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text + if token_type is None: + token_type = asr_train_args.token_type + if bpemodel is None: + bpemodel = asr_train_args.bpemodel + + if token_type is None: + tokenizer = None + elif token_type == "bpe": + if bpemodel is not None: + tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) + else: + tokenizer = None + else: + tokenizer = build_tokenizer(token_type=token_type) + converter = TokenIDConverter(token_list=token_list) + logging.info(f"Text tokenizer: {tokenizer}") + + self.asr_model = asr_model + self.asr_train_args = asr_train_args + self.converter = converter + self.tokenizer = tokenizer + + # 6. 
[Optional] Build hotword list from str, local file or url + + is_use_lm = lm_weight != 0.0 and lm_file is not None + if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm: + beam_search = None + self.beam_search = beam_search + logging.info(f"Beam_search: {self.beam_search}") + self.beam_search_transducer = beam_search_transducer + self.maxlenratio = maxlenratio + self.minlenratio = minlenratio + self.device = device + self.dtype = dtype + self.nbest = nbest + self.frontend = frontend + self.encoder_downsampling_factor = 1 + if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d": + self.encoder_downsampling_factor = 4 + + @torch.no_grad() + def __call__( + self, cache: dict, speech: Union[torch.Tensor], speech_lengths: Union[torch.Tensor] = None + ): + """Inference + + Args: + speech: Input speech data + Returns: + text, token, token_int, hyp + + """ + assert check_argument_types() + results = [] + cache_en = cache["encoder"] + if speech.shape[1] < 16 * 60 and cache_en["is_final"]: + if cache_en["start_idx"] == 0: + return [] + cache_en["tail_chunk"] = True + feats = cache_en["feats"] + feats_len = torch.tensor([feats.shape[1]]) + self.asr_model.frontend = None + results = self.infer(feats, feats_len, cache) + return results + else: + if self.frontend is not None: + feats, feats_len = self.frontend.forward(speech, speech_lengths, cache_en["is_final"]) + feats = to_device(feats, device=self.device) + feats_len = feats_len.int() + self.asr_model.frontend = None + else: + feats = speech + feats_len = speech_lengths + + if feats.shape[1] != 0: + if cache_en["is_final"]: + if feats.shape[1] + cache_en["chunk_size"][2] < cache_en["chunk_size"][1]: + cache_en["last_chunk"] = True + else: + # first chunk + feats_chunk1 = feats[:, :cache_en["chunk_size"][1], :] + feats_len = torch.tensor([feats_chunk1.shape[1]]) + results_chunk1 = self.infer(feats_chunk1, feats_len, cache) + + # last chunk + cache_en["last_chunk"] = True + feats_chunk2 = feats[:, -(feats.shape[1] + cache_en["chunk_size"][2] - cache_en["chunk_size"][1]):, :] + feats_len = torch.tensor([feats_chunk2.shape[1]]) + results_chunk2 = self.infer(feats_chunk2, feats_len, cache) + + return [" ".join(results_chunk1 + results_chunk2)] + + results = self.infer(feats, feats_len, cache) + + return results + + @torch.no_grad() + def infer(self, feats: Union[torch.Tensor], feats_len: Union[torch.Tensor], cache: List = None): + batch = {"speech": feats, "speech_lengths": feats_len} + batch = to_device(batch, device=self.device) + # b. 
Forward Encoder + enc, enc_len = self.asr_model.encode_chunk(feats, feats_len, cache=cache) + if isinstance(enc, tuple): + enc = enc[0] + # assert len(enc) == 1, len(enc) + enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor + + predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache) + pre_acoustic_embeds, pre_token_length= predictor_outs[0], predictor_outs[1] + if torch.max(pre_token_length) < 1: + return [] + decoder_outs = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache) + decoder_out = decoder_outs + + results = [] + b, n, d = decoder_out.size() + for i in range(b): + x = enc[i, :enc_len[i], :] + am_scores = decoder_out[i, :pre_token_length[i], :] + if self.beam_search is not None: + nbest_hyps = self.beam_search( + x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio + ) + + nbest_hyps = nbest_hyps[: self.nbest] + else: + yseq = am_scores.argmax(dim=-1) + score = am_scores.max(dim=-1)[0] + score = torch.sum(score, dim=-1) + # pad with mask tokens to ensure compatibility with sos/eos tokens + yseq = torch.tensor( + [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device + ) + nbest_hyps = [Hypothesis(yseq=yseq, score=score)] + + for hyp in nbest_hyps: + assert isinstance(hyp, (Hypothesis)), type(hyp) + + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1:last_pos] + else: + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != 0 and x != 2, token_int)) + + # Change integer-ids to tokens + token = self.converter.ids2tokens(token_int) + token = " ".join(token) + + results.append(token) + + # assert check_return_type(results) + return results + + +class Speech2TextUniASR: + """Speech2Text class + + Examples: + >>> import soundfile + >>> speech2text = Speech2TextUniASR("asr_config.yml", "asr.pb") + >>> audio, rate = soundfile.read("speech.wav") + >>> speech2text(audio) + [(text, token, token_int, hypothesis object), ...] + + """ + + def __init__( + self, + asr_train_config: Union[Path, str] = None, + asr_model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = "cpu", + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + dtype: str = "float32", + beam_size: int = 20, + ctc_weight: float = 0.5, + lm_weight: float = 1.0, + ngram_weight: float = 0.9, + penalty: float = 0.0, + nbest: int = 1, + token_num_relax: int = 1, + decoding_ind: int = 0, + decoding_mode: str = "model1", + frontend_conf: dict = None, + **kwargs, + ): + assert check_argument_types() + + # 1. 
Build ASR model + scorers = {} + from funasr.tasks.asr import ASRTaskUniASR as ASRTask + asr_model, asr_train_args = ASRTask.build_model_from_file( + asr_train_config, asr_model_file, cmvn_file, device + ) + frontend = None + if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: + frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) + + logging.info("asr_train_args: {}".format(asr_train_args)) + asr_model.to(dtype=getattr(torch, dtype)).eval() + if decoding_mode == "model1": + decoder = asr_model.decoder + else: + decoder = asr_model.decoder2 + + if asr_model.ctc != None: + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + scorers.update( + ctc=ctc + ) + token_list = asr_model.token_list + scorers.update( + decoder=decoder, + length_bonus=LengthBonus(len(token_list)), + ) + + # 2. Build Language model + if lm_train_config is not None: + lm, lm_train_args = LMTask.build_model_from_file( + lm_train_config, lm_file, device + ) + scorers["lm"] = lm.lm + + # 3. Build ngram model + # ngram is not supported now + ngram = None + scorers["ngram"] = ngram + + # 4. Build BeamSearch object + # transducer is not supported now + beam_search_transducer = None + from funasr.modules.beam_search.beam_search import BeamSearchScama as BeamSearch + + weights = dict( + decoder=1.0 - ctc_weight, + ctc=ctc_weight, + lm=lm_weight, + ngram=ngram_weight, + length_bonus=penalty, + ) + beam_search = BeamSearch( + beam_size=beam_size, + weights=weights, + scorers=scorers, + sos=asr_model.sos, + eos=asr_model.eos, + vocab_size=len(token_list), + token_list=token_list, + pre_beam_score_key=None if ctc_weight == 1.0 else "full", + ) + + beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() + for scorer in scorers.values(): + if isinstance(scorer, torch.nn.Module): + scorer.to(device=device, dtype=getattr(torch, dtype)).eval() + # logging.info(f"Beam_search: {beam_search}") + logging.info(f"Decoding device={device}, dtype={dtype}") + + # 5. [Optional] Build Text converter: e.g. 
bpe-sym -> Text + if token_type is None: + token_type = asr_train_args.token_type + if bpemodel is None: + bpemodel = asr_train_args.bpemodel + + if token_type is None: + tokenizer = None + elif token_type == "bpe": + if bpemodel is not None: + tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) + else: + tokenizer = None + else: + tokenizer = build_tokenizer(token_type=token_type) + converter = TokenIDConverter(token_list=token_list) + logging.info(f"Text tokenizer: {tokenizer}") + + self.asr_model = asr_model + self.asr_train_args = asr_train_args + self.converter = converter + self.tokenizer = tokenizer + self.beam_search = beam_search + self.beam_search_transducer = beam_search_transducer + self.maxlenratio = maxlenratio + self.minlenratio = minlenratio + self.device = device + self.dtype = dtype + self.nbest = nbest + self.token_num_relax = token_num_relax + self.decoding_ind = decoding_ind + self.decoding_mode = decoding_mode + self.frontend = frontend + + @torch.no_grad() + def __call__( + self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None + ) -> List[ + Tuple[ + Optional[str], + List[str], + List[int], + Union[Hypothesis], + ] + ]: + """Inference + + Args: + speech: Input speech data + Returns: + text, token, token_int, hyp + + """ + assert check_argument_types() + + # Input as audio signal + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + + if self.frontend is not None: + feats, feats_len = self.frontend.forward(speech, speech_lengths) + feats = to_device(feats, device=self.device) + feats_len = feats_len.int() + self.asr_model.frontend = None + else: + feats = speech + feats_len = speech_lengths + lfr_factor = max(1, (feats.size()[-1] // 80) - 1) + feats_raw = feats.clone().to(self.device) + batch = {"speech": feats, "speech_lengths": feats_len} + + # a. To device + batch = to_device(batch, device=self.device) + # b. Forward Encoder + _, enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind) + if isinstance(enc, tuple): + enc = enc[0] + assert len(enc) == 1, len(enc) + if self.decoding_mode == "model1": + predictor_outs = self.asr_model.calc_predictor_mask(enc, enc_len) + else: + enc, enc_len = self.asr_model.encode2(enc, enc_len, feats_raw, feats_len, ind=self.decoding_ind) + predictor_outs = self.asr_model.calc_predictor_mask2(enc, enc_len) + + scama_mask = predictor_outs[4] + pre_token_length = predictor_outs[1] + pre_acoustic_embeds = predictor_outs[0] + maxlen = pre_token_length.sum().item() + self.token_num_relax + minlen = max(0, pre_token_length.sum().item() - self.token_num_relax) + # c. 
Passed the encoder result and the beam search + nbest_hyps = self.beam_search( + x=enc[0], scama_mask=scama_mask, pre_acoustic_embeds=pre_acoustic_embeds, maxlenratio=self.maxlenratio, + minlenratio=self.minlenratio, maxlen=int(maxlen), minlen=int(minlen), + ) + + nbest_hyps = nbest_hyps[: self.nbest] + + results = [] + for hyp in nbest_hyps: + assert isinstance(hyp, (Hypothesis)), type(hyp) + + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1:last_pos] + else: + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != 0, token_int)) + + # Change integer-ids to tokens + token = self.converter.ids2tokens(token_int) + token = list(filter(lambda x: x != "", token)) + + if self.tokenizer is not None: + text = self.tokenizer.tokens2text(token) + else: + text = None + results.append((text, token, token_int, hyp)) + + assert check_return_type(results) + return results + + + + +class Speech2TextMFCCA: + """Speech2Text class + + Examples: + >>> import soundfile + >>> speech2text = Speech2TextMFCCA("asr_config.yml", "asr.pb") + >>> audio, rate = soundfile.read("speech.wav") + >>> speech2text(audio) + [(text, token, token_int, hypothesis object), ...] + + """ + + def __init__( + self, + asr_train_config: Union[Path, str] = None, + asr_model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = "cpu", + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + batch_size: int = 1, + dtype: str = "float32", + beam_size: int = 20, + ctc_weight: float = 0.5, + lm_weight: float = 1.0, + ngram_weight: float = 0.9, + penalty: float = 0.0, + nbest: int = 1, + streaming: bool = False, + **kwargs, + ): + assert check_argument_types() + + # 1. Build ASR model + scorers = {} + asr_model, asr_train_args = ASRTask.build_model_from_file( + asr_train_config, asr_model_file, cmvn_file, device + ) + + logging.info("asr_model: {}".format(asr_model)) + logging.info("asr_train_args: {}".format(asr_train_args)) + asr_model.to(dtype=getattr(torch, dtype)).eval() + + decoder = asr_model.decoder + + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + token_list = asr_model.token_list + scorers.update( + decoder=decoder, + ctc=ctc, + length_bonus=LengthBonus(len(token_list)), + ) + + # 2. Build Language model + if lm_train_config is not None: + lm, lm_train_args = LMTask.build_model_from_file( + lm_train_config, lm_file, device + ) + lm.to(device) + scorers["lm"] = lm.lm + # 3. Build ngram model + # ngram is not supported now + ngram = None + scorers["ngram"] = ngram + + # 4. Build BeamSearch object + # transducer is not supported now + beam_search_transducer = None + + weights = dict( + decoder=1.0 - ctc_weight, + ctc=ctc_weight, + lm=lm_weight, + ngram=ngram_weight, + length_bonus=penalty, + ) + beam_search = BeamSearch( + beam_size=beam_size, + weights=weights, + scorers=scorers, + sos=asr_model.sos, + eos=asr_model.eos, + vocab_size=len(token_list), + token_list=token_list, + pre_beam_score_key=None if ctc_weight == 1.0 else "full", + ) + # beam_search.__class__ = BatchBeamSearch + # 5. [Optional] Build Text converter: e.g. 
bpe-sym -> Text + if token_type is None: + token_type = asr_train_args.token_type + if bpemodel is None: + bpemodel = asr_train_args.bpemodel + + if token_type is None: + tokenizer = None + elif token_type == "bpe": + if bpemodel is not None: + tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) + else: + tokenizer = None + else: + tokenizer = build_tokenizer(token_type=token_type) + converter = TokenIDConverter(token_list=token_list) + logging.info(f"Text tokenizer: {tokenizer}") + + self.asr_model = asr_model + self.asr_train_args = asr_train_args + self.converter = converter + self.tokenizer = tokenizer + self.beam_search = beam_search + self.beam_search_transducer = beam_search_transducer + self.maxlenratio = maxlenratio + self.minlenratio = minlenratio + self.device = device + self.dtype = dtype + self.nbest = nbest + + @torch.no_grad() + def __call__( + self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None + ) -> List[ + Tuple[ + Optional[str], + List[str], + List[int], + Union[Hypothesis], + ] + ]: + """Inference + + Args: + speech: Input speech data + Returns: + text, token, token_int, hyp + + """ + assert check_argument_types() + # Input as audio signal + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + if (speech.dim() == 3): + speech = torch.squeeze(speech, 2) + # speech = speech.unsqueeze(0).to(getattr(torch, self.dtype)) + speech = speech.to(getattr(torch, self.dtype)) + # lenghts: (1,) + lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1)) + batch = {"speech": speech, "speech_lengths": lengths} + + # a. To device + batch = to_device(batch, device=self.device) + + # b. Forward Encoder + enc, _ = self.asr_model.encode(**batch) + + assert len(enc) == 1, len(enc) + + # c. 
Passed the encoder result and the beam search + nbest_hyps = self.beam_search( + x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio + ) + + nbest_hyps = nbest_hyps[: self.nbest] + + results = [] + for hyp in nbest_hyps: + assert isinstance(hyp, (Hypothesis)), type(hyp) + + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1:last_pos] + else: + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != 0, token_int)) + + # Change integer-ids to tokens + token = self.converter.ids2tokens(token_int) + + if self.tokenizer is not None: + text = self.tokenizer.tokens2text(token) + else: + text = None + results.append((text, token, token_int, hyp)) + + assert check_return_type(results) + return results + + diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py index a52e94a7f..f70382bf1 100644 --- a/funasr/bin/asr_inference.py +++ b/funasr/bin/asr_inference.py @@ -256,70 +256,7 @@ class Speech2Text: assert check_return_type(results) return results -def inference( - maxlenratio: float, - minlenratio: float, - batch_size: int, - beam_size: int, - ngpu: int, - ctc_weight: float, - lm_weight: float, - penalty: float, - log_level: Union[int, str], - data_path_and_name_and_type, - asr_train_config: Optional[str], - asr_model_file: Optional[str], - cmvn_file: Optional[str] = None, - raw_inputs: Union[np.ndarray, torch.Tensor] = None, - lm_train_config: Optional[str] = None, - lm_file: Optional[str] = None, - token_type: Optional[str] = None, - key_file: Optional[str] = None, - word_lm_train_config: Optional[str] = None, - bpemodel: Optional[str] = None, - allow_variable_data_keys: bool = False, - streaming: bool = False, - output_dir: Optional[str] = None, - dtype: str = "float32", - seed: int = 0, - ngram_weight: float = 0.9, - nbest: int = 1, - num_workers: int = 1, - mc: bool = False, - **kwargs, -): - inference_pipeline = inference_modelscope( - maxlenratio=maxlenratio, - minlenratio=minlenratio, - batch_size=batch_size, - beam_size=beam_size, - ngpu=ngpu, - ctc_weight=ctc_weight, - lm_weight=lm_weight, - penalty=penalty, - log_level=log_level, - asr_train_config=asr_train_config, - asr_model_file=asr_model_file, - cmvn_file=cmvn_file, - raw_inputs=raw_inputs, - lm_train_config=lm_train_config, - lm_file=lm_file, - token_type=token_type, - key_file=key_file, - word_lm_train_config=word_lm_train_config, - bpemodel=bpemodel, - allow_variable_data_keys=allow_variable_data_keys, - streaming=streaming, - output_dir=output_dir, - dtype=dtype, - seed=seed, - ngram_weight=ngram_weight, - nbest=nbest, - num_workers=num_workers, - mc=mc, - **kwargs, - ) - return inference_pipeline(data_path_and_name_and_type, raw_inputs) + def inference_modelscope( maxlenratio: float, diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index 7b04a9e31..6ad17f0c6 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -12,6 +12,924 @@ from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none +#!/usr/bin/env python3 +import argparse +import logging +import sys +import time +import copy +import os +import codecs +import tempfile +import requests +from pathlib import Path +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union +from typing import Dict +from typing import Any 
+from typing import List +import yaml +import numpy as np +import torch +import torchaudio +from typeguard import check_argument_types +from typeguard import check_return_type +from funasr.fileio.datadir_writer import DatadirWriter +from funasr.modules.beam_search.beam_search import BeamSearch +# from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch + +from funasr.modules.beam_search.beam_search import Hypothesis +from funasr.modules.scorers.ctc import CTCPrefixScorer +from funasr.modules.scorers.length_bonus import LengthBonus +from funasr.modules.subsampling import TooShortUttError +from funasr.tasks.asr import ASRTask +from funasr.tasks.lm import LMTask +from funasr.text.build_tokenizer import build_tokenizer +from funasr.text.token_id_converter import TokenIDConverter +from funasr.torch_utils.device_funcs import to_device +from funasr.torch_utils.set_all_random_seed import set_all_random_seed +from funasr.utils import config_argparse +from funasr.utils.cli_utils import get_commandline_args +from funasr.utils.types import str2bool +from funasr.utils.types import str2triple_str +from funasr.utils.types import str_or_none +from funasr.utils import asr_utils, wav_utils, postprocess_utils +from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline +from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer +from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer +from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export +from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard +from funasr.bin.tp_inference import SpeechText2Timestamp +from funasr.bin.vad_inference import Speech2VadSegment +from funasr.bin.punctuation_infer import Text2Punc +from funasr.utils.vad_utils import slice_padding_fbank +from funasr.tasks.vad import VADTask +from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard +from funasr.bin.asr_infer import Speech2Text +from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline +from funasr.bin.asr_infer import Speech2TextUniASR + + +def inference_paraformer( + maxlenratio: float, + minlenratio: float, + batch_size: int, + beam_size: int, + ngpu: int, + ctc_weight: float, + lm_weight: float, + penalty: float, + log_level: Union[int, str], + # data_path_and_name_and_type, + asr_train_config: Optional[str], + asr_model_file: Optional[str], + cmvn_file: Optional[str] = None, + lm_train_config: Optional[str] = None, + lm_file: Optional[str] = None, + token_type: Optional[str] = None, + key_file: Optional[str] = None, + word_lm_train_config: Optional[str] = None, + bpemodel: Optional[str] = None, + allow_variable_data_keys: bool = False, + dtype: str = "float32", + seed: int = 0, + ngram_weight: float = 0.9, + nbest: int = 1, + num_workers: int = 1, + output_dir: Optional[str] = None, + timestamp_infer_config: Union[Path, str] = None, + timestamp_model_file: Union[Path, str] = None, + param_dict: dict = None, + **kwargs, +): + assert check_argument_types() + ncpu = kwargs.get("ncpu", 1) + torch.set_num_threads(ncpu) + + if word_lm_train_config is not None: + raise NotImplementedError("Word LM is not implemented") + if ngpu > 1: + raise NotImplementedError("only single GPU decoding is supported") + + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + + export_mode = False + if param_dict is not None: + hotword_list_or_file = 
param_dict.get('hotword') + export_mode = param_dict.get("export_mode", False) + else: + hotword_list_or_file = None + + if kwargs.get("device", None) == "cpu": + ngpu = 0 + if ngpu >= 1 and torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + batch_size = 1 + + # 1. Set random-seed + set_all_random_seed(seed) + + # 2. Build speech2text + speech2text_kwargs = dict( + asr_train_config=asr_train_config, + asr_model_file=asr_model_file, + cmvn_file=cmvn_file, + lm_train_config=lm_train_config, + lm_file=lm_file, + token_type=token_type, + bpemodel=bpemodel, + device=device, + maxlenratio=maxlenratio, + minlenratio=minlenratio, + dtype=dtype, + beam_size=beam_size, + ctc_weight=ctc_weight, + lm_weight=lm_weight, + ngram_weight=ngram_weight, + penalty=penalty, + nbest=nbest, + hotword_list_or_file=hotword_list_or_file, + ) + + speech2text = Speech2TextParaformer(**speech2text_kwargs) + + if timestamp_model_file is not None: + speechtext2timestamp = SpeechText2Timestamp( + timestamp_cmvn_file=cmvn_file, + timestamp_model_file=timestamp_model_file, + timestamp_infer_config=timestamp_infer_config, + ) + else: + speechtext2timestamp = None + + def _forward( + data_path_and_name_and_type, + raw_inputs: Union[np.ndarray, torch.Tensor] = None, + output_dir_v2: Optional[str] = None, + fs: dict = None, + param_dict: dict = None, + **kwargs, + ): + + hotword_list_or_file = None + if param_dict is not None: + hotword_list_or_file = param_dict.get('hotword') + if 'hotword' in kwargs and kwargs['hotword'] is not None: + hotword_list_or_file = kwargs['hotword'] + if hotword_list_or_file is not None or 'hotword' in kwargs: + speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file) + + # 3. Build data-iterator + if data_path_and_name_and_type is None and raw_inputs is not None: + if isinstance(raw_inputs, torch.Tensor): + raw_inputs = raw_inputs.numpy() + data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] + loader = ASRTask.build_streaming_iterator( + data_path_and_name_and_type, + dtype=dtype, + fs=fs, + batch_size=batch_size, + key_file=key_file, + num_workers=num_workers, + preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), + collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), + allow_variable_data_keys=allow_variable_data_keys, + inference=True, + ) + + if param_dict is not None: + use_timestamp = param_dict.get('use_timestamp', True) + else: + use_timestamp = True + + forward_time_total = 0.0 + length_total = 0.0 + finish_count = 0 + file_count = 1 + # 7 .Start for-loop + # FIXME(kamo): The output format should be discussed about + asr_result_list = [] + output_path = output_dir_v2 if output_dir_v2 is not None else output_dir + if output_path is not None: + writer = DatadirWriter(output_path) + else: + writer = None + + for keys, batch in loader: + assert isinstance(batch, dict), type(batch) + assert all(isinstance(s, str) for s in keys), keys + _bs = len(next(iter(batch.values()))) + assert len(keys) == _bs, f"{len(keys)} != {_bs}" + # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")} + + logging.info("decoding, utt_id: {}".format(keys)) + # N-best list of (text, token, token_int, hyp_object) + + time_beg = time.time() + results = speech2text(**batch) + if len(results) < 1: + hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) + results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest + time_end = time.time() + forward_time = time_end - time_beg + lfr_factor = results[0][-1] + 
length = results[0][-2] + forward_time_total += forward_time + length_total += length + rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time, + 100 * forward_time / ( + length * lfr_factor)) + logging.info(rtf_cur) + + for batch_id in range(_bs): + result = [results[batch_id][:-2]] + + key = keys[batch_id] + for n, result in zip(range(1, nbest + 1), result): + text, token, token_int, hyp = result[0], result[1], result[2], result[3] + timestamp = result[4] if len(result[4]) > 0 else None + # conduct timestamp prediction here + # timestamp inference requires token length + # thus following inference cannot be conducted in batch + if timestamp is None and speechtext2timestamp: + ts_batch = {} + ts_batch['speech'] = batch['speech'][batch_id].unsqueeze(0) + ts_batch['speech_lengths'] = torch.tensor([batch['speech_lengths'][batch_id]]) + ts_batch['text_lengths'] = torch.tensor([len(token)]) + us_alphas, us_peaks = speechtext2timestamp(**ts_batch) + ts_str, timestamp = ts_prediction_lfr6_standard(us_alphas[0], us_peaks[0], token, + force_time_shift=-3.0) + # Create a directory: outdir/{n}best_recog + if writer is not None: + ibest_writer = writer[f"{n}best_recog"] + + # Write the result to each file + ibest_writer["token"][key] = " ".join(token) + # ibest_writer["token_int"][key] = " ".join(map(str, token_int)) + ibest_writer["score"][key] = str(hyp.score) + ibest_writer["rtf"][key] = rtf_cur + + if text is not None: + if use_timestamp and timestamp is not None: + postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp) + else: + postprocessed_result = postprocess_utils.sentence_postprocess(token) + timestamp_postprocessed = "" + if len(postprocessed_result) == 3: + text_postprocessed, timestamp_postprocessed, word_lists = postprocessed_result[0], \ + postprocessed_result[1], \ + postprocessed_result[2] + else: + text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1] + item = {'key': key, 'value': text_postprocessed} + if timestamp_postprocessed != "": + item['timestamp'] = timestamp_postprocessed + asr_result_list.append(item) + finish_count += 1 + # asr_utils.print_progress(finish_count / file_count) + if writer is not None: + ibest_writer["text"][key] = " ".join(word_lists) + + logging.info("decoding, utt: {}, predictions: {}".format(key, text)) + rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, + forward_time_total, + 100 * forward_time_total / ( + length_total * lfr_factor)) + logging.info(rtf_avg) + if writer is not None: + ibest_writer["rtf"]["rtf_avf"] = rtf_avg + return asr_result_list + + return _forward + + +def inference_paraformer_vad_punc( + maxlenratio: float, + minlenratio: float, + batch_size: int, + beam_size: int, + ngpu: int, + ctc_weight: float, + lm_weight: float, + penalty: float, + log_level: Union[int, str], + # data_path_and_name_and_type, + asr_train_config: Optional[str], + asr_model_file: Optional[str], + cmvn_file: Optional[str] = None, + lm_train_config: Optional[str] = None, + lm_file: Optional[str] = None, + token_type: Optional[str] = None, + key_file: Optional[str] = None, + word_lm_train_config: Optional[str] = None, + bpemodel: Optional[str] = None, + allow_variable_data_keys: bool = False, + output_dir: Optional[str] = None, + dtype: str = "float32", + seed: int = 0, + ngram_weight: float = 0.9, + nbest: int = 1, + num_workers: int = 1, + vad_infer_config: Optional[str] = None, + vad_model_file: 
Optional[str] = None, + vad_cmvn_file: Optional[str] = None, + time_stamp_writer: bool = True, + punc_infer_config: Optional[str] = None, + punc_model_file: Optional[str] = None, + outputs_dict: Optional[bool] = True, + param_dict: dict = None, + **kwargs, +): + assert check_argument_types() + ncpu = kwargs.get("ncpu", 1) + torch.set_num_threads(ncpu) + + if word_lm_train_config is not None: + raise NotImplementedError("Word LM is not implemented") + if ngpu > 1: + raise NotImplementedError("only single GPU decoding is supported") + + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + + if param_dict is not None: + hotword_list_or_file = param_dict.get('hotword') + else: + hotword_list_or_file = None + + if ngpu >= 1 and torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + # 1. Set random-seed + set_all_random_seed(seed) + + # 2. Build speech2vadsegment + speech2vadsegment_kwargs = dict( + vad_infer_config=vad_infer_config, + vad_model_file=vad_model_file, + vad_cmvn_file=vad_cmvn_file, + device=device, + dtype=dtype, + ) + # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs)) + speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs) + + # 3. Build speech2text + speech2text_kwargs = dict( + asr_train_config=asr_train_config, + asr_model_file=asr_model_file, + cmvn_file=cmvn_file, + lm_train_config=lm_train_config, + lm_file=lm_file, + token_type=token_type, + bpemodel=bpemodel, + device=device, + maxlenratio=maxlenratio, + minlenratio=minlenratio, + dtype=dtype, + beam_size=beam_size, + ctc_weight=ctc_weight, + lm_weight=lm_weight, + ngram_weight=ngram_weight, + penalty=penalty, + nbest=nbest, + hotword_list_or_file=hotword_list_or_file, + ) + speech2text = Speech2TextParaformer(**speech2text_kwargs) + text2punc = None + if punc_model_file is not None: + text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype) + + if output_dir is not None: + writer = DatadirWriter(output_dir) + ibest_writer = writer[f"1best_recog"] + ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list) + + def _forward(data_path_and_name_and_type, + raw_inputs: Union[np.ndarray, torch.Tensor] = None, + output_dir_v2: Optional[str] = None, + fs: dict = None, + param_dict: dict = None, + **kwargs, + ): + + hotword_list_or_file = None + if param_dict is not None: + hotword_list_or_file = param_dict.get('hotword') + + if 'hotword' in kwargs: + hotword_list_or_file = kwargs['hotword'] + + if speech2text.hotword_list is None: + speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file) + + # 3. 
Build data-iterator
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, torch.Tensor):
+                raw_inputs = raw_inputs.numpy()
+            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+        loader = ASRTask.build_streaming_iterator(
+            data_path_and_name_and_type,
+            dtype=dtype,
+            fs=fs,
+            batch_size=1,
+            key_file=key_file,
+            num_workers=num_workers,
+            preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
+            collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
+            allow_variable_data_keys=allow_variable_data_keys,
+            inference=True,
+        )
+
+        if param_dict is not None:
+            use_timestamp = param_dict.get('use_timestamp', True)
+        else:
+            use_timestamp = True
+
+        finish_count = 0
+        file_count = 1
+        lfr_factor = 6
+        # 7. Start for-loop
+        asr_result_list = []
+        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+        writer = None
+        if output_path is not None:
+            writer = DatadirWriter(output_path)
+            ibest_writer = writer["1best_recog"]
+
+        for keys, batch in loader:
+            assert isinstance(batch, dict), type(batch)
+            assert all(isinstance(s, str) for s in keys), keys
+            _bs = len(next(iter(batch.values())))
+            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+
+            # Run VAD first to split the utterance into speech segments.
+            vad_results = speech2vadsegment(**batch)
+            _, vadsegments = vad_results[0], vad_results[1][0]
+
+            speech, speech_lengths = batch["speech"], batch["speech_lengths"]
+
+            # Sort segments by duration so that each ASR mini-batch holds segments
+            # of similar length (less padding, faster decoding).
+            n = len(vadsegments)
+            data_with_index = [(vadsegments[i], i) for i in range(n)]
+            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
+            results_sorted = []
+            for j, beg_idx in enumerate(range(0, n, batch_size)):
+                end_idx = min(n, beg_idx + batch_size)
+                speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
+
+                batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
+                batch = to_device(batch, device=device)
+                results = speech2text(**batch)
+
+                if len(results) < 1:
+                    results = [["", [], [], [], [], [], []]]
+                results_sorted.extend(results)
+            # Restore the per-segment results to the original segment order.
+            restored_data = [0] * n
+            for j in range(n):
+                index = sorted_data[j][1]
+                restored_data[index] = results_sorted[j]
+            # Concatenate text/tokens and shift timestamps by each segment's start time.
+            result = ["", [], [], [], [], [], []]
+            for j in range(n):
+                result[0] += restored_data[j][0]
+                result[1] += restored_data[j][1]
+                result[2] += restored_data[j][2]
+                if len(restored_data[j][4]) > 0:
+                    for t in restored_data[j][4]:
+                        t[0] += vadsegments[j][0]
+                        t[1] += vadsegments[j][0]
+                    result[4] += restored_data[j][4]
+                # result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))]
+
+            key = keys[0]
+            # result = result_segments[0]
+            text, token, token_int = result[0], result[1], result[2]
+            time_stamp = result[4] if len(result[4]) > 0 else None
+
+            if use_timestamp and time_stamp is not None:
+                postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+            else:
+                postprocessed_result = postprocess_utils.sentence_postprocess(token)
+            text_postprocessed = ""
+            time_stamp_postprocessed = ""
+            text_postprocessed_punc = postprocessed_result
+            if len(postprocessed_result) == 3:
+                text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
+                                                                           postprocessed_result[1], \
+                                                                           postprocessed_result[2]
+            else:
+                text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
+
+            text_postprocessed_punc = text_postprocessed
+            punc_id_list = []
+            if len(word_lists) > 0 and text2punc is not None:
+                text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
+
+            item = {'key': key, 'value': text_postprocessed_punc}
+            if text_postprocessed != "":
+                item['text_postprocessed'] = text_postprocessed
+            if time_stamp_postprocessed != "":
+                item['time_stamp'] = time_stamp_postprocessed
+
+            item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
+
+            asr_result_list.append(item)
+            finish_count += 1
+            # asr_utils.print_progress(finish_count / file_count)
+            if writer is not None:
+                # Write the result to each file
+                ibest_writer["token"][key] = " ".join(token)
+                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+                ibest_writer["vad"][key] = "{}".format(vadsegments)
+                ibest_writer["text"][key] = " ".join(word_lists)
+                ibest_writer["text_with_punc"][key] = text_postprocessed_punc
+                if time_stamp_postprocessed != "":
+                    ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
+
+            logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
+        return asr_result_list
+
+    return _forward
+
+def inference_paraformer_online(
+    maxlenratio: float,
+    minlenratio: float,
+    batch_size: int,
+    beam_size: int,
+    ngpu: int,
+    ctc_weight: float,
+    lm_weight: float,
+    penalty: float,
+    log_level: Union[int, str],
+    # data_path_and_name_and_type,
+    asr_train_config: Optional[str],
+    asr_model_file: Optional[str],
+    cmvn_file: Optional[str] = None,
+    lm_train_config: Optional[str] = None,
+    lm_file: Optional[str] = None,
+    token_type: Optional[str] = None,
+    key_file: Optional[str] = None,
+    word_lm_train_config: Optional[str] = None,
+    bpemodel: Optional[str] = None,
+    allow_variable_data_keys: bool = False,
+    dtype: str = "float32",
+    seed: int = 0,
+    ngram_weight: float = 0.9,
+    nbest: int = 1,
+    num_workers: int = 1,
+    output_dir: Optional[str] = None,
+    param_dict: dict = None,
+    **kwargs,
+):
+    assert check_argument_types()
+
+    if word_lm_train_config is not None:
+        raise NotImplementedError("Word LM is not implemented")
+    if ngpu > 1:
+        raise NotImplementedError("only single GPU decoding is supported")
+
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+
+    export_mode = False
+
+    if ngpu >= 1 and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+        batch_size = 1
+
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+
+    # 2. 
Build speech2text
+    speech2text_kwargs = dict(
+        asr_train_config=asr_train_config,
+        asr_model_file=asr_model_file,
+        cmvn_file=cmvn_file,
+        lm_train_config=lm_train_config,
+        lm_file=lm_file,
+        token_type=token_type,
+        bpemodel=bpemodel,
+        device=device,
+        maxlenratio=maxlenratio,
+        minlenratio=minlenratio,
+        dtype=dtype,
+        beam_size=beam_size,
+        ctc_weight=ctc_weight,
+        lm_weight=lm_weight,
+        ngram_weight=ngram_weight,
+        penalty=penalty,
+        nbest=nbest,
+    )
+
+    speech2text = Speech2TextParaformerOnline(**speech2text_kwargs)
+
+    def _load_bytes(input):
+        # Convert 16-bit PCM bytes into a float32 waveform scaled to [-1, 1].
+        middle_data = np.frombuffer(input, dtype=np.int16)
+        middle_data = np.asarray(middle_data)
+        if middle_data.dtype.kind not in 'iu':
+            raise TypeError("'middle_data' must be an array of integers")
+        dtype = np.dtype('float32')
+        if dtype.kind != 'f':
+            raise TypeError("'dtype' must be a floating point type")
+
+        i = np.iinfo(middle_data.dtype)
+        abs_max = 2 ** (i.bits - 1)
+        offset = i.min + abs_max
+        array = (middle_data.astype(dtype) - offset) / abs_max
+        return array
+
+    def _read_yaml(yaml_path: Union[str, Path]) -> Dict:
+        # yaml is not among the module-level imports of this file, so import it here.
+        import yaml
+
+        if not Path(yaml_path).exists():
+            raise FileNotFoundError(f'The {yaml_path} does not exist.')
+
+        with open(str(yaml_path), 'rb') as f:
+            data = yaml.load(f, Loader=yaml.Loader)
+        return data
+
+    def _prepare_cache(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
+        # Initialize the streaming encoder/decoder caches; a populated cache is returned unchanged.
+        if len(cache) > 0:
+            return cache
+        config = _read_yaml(asr_train_config)
+        enc_output_size = config["encoder_conf"]["output_size"]
+        feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
+        cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
+                    "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
+                    "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
+        cache["encoder"] = cache_en
+
+        cache_de = {"decode_fsmn": None}
+        cache["decoder"] = cache_de
+
+        return cache
+
+    def _cache_reset(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
+        # Re-initialize a populated cache, e.g. after the final chunk of an utterance.
+        if len(cache) > 0:
+            config = _read_yaml(asr_train_config)
+            enc_output_size = config["encoder_conf"]["output_size"]
+            feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
+            cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
+                        "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
+                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
+            cache["encoder"] = cache_en
+
+            cache_de = {"decode_fsmn": None}
+            cache["decoder"] = cache_de
+
+        return cache
+
+    def _forward(
+        data_path_and_name_and_type,
+        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+        output_dir_v2: Optional[str] = None,
+        fs: dict = None,
+        param_dict: dict = None,
+        **kwargs,
+    ):
+
+        # 3. Build data-iterator
+        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
+            raw_inputs = _load_bytes(data_path_and_name_and_type[0])
+            raw_inputs = torch.tensor(raw_inputs)
+        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
+            # torchaudio is not among the module-level imports of this file, so import it here.
+            import torchaudio
+            raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, np.ndarray):
+                raw_inputs = torch.tensor(raw_inputs)
+        is_final = False
+        cache = {}
+        chunk_size = [5, 10, 5]
+        if param_dict is not None and "cache" in param_dict:
+            cache = param_dict["cache"]
+        if param_dict is not None and "is_final" in param_dict:
+            is_final = param_dict["is_final"]
+        if param_dict is not None and "chunk_size" in param_dict:
+            chunk_size = param_dict["chunk_size"]
+
+        # 7. Start for-loop
+        # FIXME(kamo): The output format should be discussed
+        raw_inputs = torch.unsqueeze(raw_inputs, dim=0)  # add a batch dimension
+        asr_result_list = []
+        cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
+        item = {}
+        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
+            sample_offset = 0
+            speech_length = raw_inputs.shape[1]
+            # chunk_size is counted in 60 ms frames, i.e. 960 samples at 16 kHz
+            stride_size = chunk_size[1] * 960
+            cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
+            final_result = ""
+            # The range step is fixed when the loop starts; stride_size is only shrunk
+            # inside the loop to handle the last (partial) chunk.
+            for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
+                if sample_offset + stride_size >= speech_length - 1:
+                    stride_size = speech_length - sample_offset
+                    cache["encoder"]["is_final"] = True
+                else:
+                    cache["encoder"]["is_final"] = False
+                input_lens = torch.tensor([stride_size])
+                asr_result = speech2text(cache, raw_inputs[:, sample_offset: sample_offset + stride_size], input_lens)
+                if len(asr_result) != 0:
+                    final_result += " ".join(asr_result) + " "
+            item = {'key': "utt", 'value': final_result.strip()}
+        else:
+            input_lens = torch.tensor([raw_inputs.shape[1]])
+            cache["encoder"]["is_final"] = is_final
+            asr_result = speech2text(cache, raw_inputs, input_lens)
+            item = {'key': "utt", 'value': " ".join(asr_result)}
+
+        asr_result_list.append(item)
+        if is_final:
+            cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1)
+        return asr_result_list
+
+    return _forward
+
+
+def inference_uniasr(
+    maxlenratio: float,
+    minlenratio: float,
+    batch_size: int,
+    beam_size: int,
+    ngpu: int,
+    ctc_weight: float,
+    lm_weight: float,
+    penalty: float,
+    log_level: Union[int, str],
+    # data_path_and_name_and_type,
+    asr_train_config: Optional[str],
+    asr_model_file: Optional[str],
+    ngram_file: Optional[str] = None,
+    cmvn_file: Optional[str] = None,
+    # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+    lm_train_config: Optional[str] = None,
+    lm_file: Optional[str] = None,
+    token_type: Optional[str] = None,
+    key_file: Optional[str] = None,
+    word_lm_train_config: Optional[str] = None,
+    bpemodel: Optional[str] = None,
+    allow_variable_data_keys: bool = False,
+    streaming: bool = False,
+    output_dir: Optional[str] = None,
+    dtype: str = "float32",
+    seed: int = 0,
+    ngram_weight: float = 0.9,
+    nbest: int = 1,
+    num_workers: int = 1,
+    token_num_relax: int = 1,
+    decoding_ind: int = 0,
+    decoding_mode: str = "model1",
+    param_dict: dict = None,
+    **kwargs,
+):
+    assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
+    if batch_size > 1:
+        raise NotImplementedError("batch decoding is not implemented")
+    if word_lm_train_config is not None:
+        raise NotImplementedError("Word LM is not implemented")
+    if ngpu > 1:
+        raise NotImplementedError("only single GPU decoding is supported")
+
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+
+    if ngpu >= 1 and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    if param_dict is not None and "decoding_model" in param_dict:
+        if param_dict["decoding_model"] == "fast":
+            decoding_ind = 0
+            decoding_mode = "model1"
+        elif param_dict["decoding_model"] == "normal":
+            decoding_ind = 0
+            decoding_mode = "model2"
+        elif param_dict["decoding_model"] == "offline":
+            decoding_ind = 1
+            decoding_mode = "model2"
+        else:
+            raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
+
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+
+    # 2. Build speech2text
+    speech2text_kwargs = dict(
+        asr_train_config=asr_train_config,
+        asr_model_file=asr_model_file,
+        cmvn_file=cmvn_file,
+        lm_train_config=lm_train_config,
+        lm_file=lm_file,
+        ngram_file=ngram_file,
+        token_type=token_type,
+        bpemodel=bpemodel,
+        device=device,
+        maxlenratio=maxlenratio,
+        minlenratio=minlenratio,
+        dtype=dtype,
+        beam_size=beam_size,
+        ctc_weight=ctc_weight,
+        lm_weight=lm_weight,
+        ngram_weight=ngram_weight,
+        penalty=penalty,
+        nbest=nbest,
+        streaming=streaming,
+        token_num_relax=token_num_relax,
+        decoding_ind=decoding_ind,
+        decoding_mode=decoding_mode,
+    )
+    speech2text = Speech2Text(**speech2text_kwargs)
+
+    def _forward(data_path_and_name_and_type,
+                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+                 output_dir_v2: Optional[str] = None,
+                 fs: dict = None,
+                 param_dict: dict = None,
+                 **kwargs,
+                 ):
+        # 3. Build data-iterator
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, torch.Tensor):
+                raw_inputs = raw_inputs.numpy()
+            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+        loader = ASRTask.build_streaming_iterator(
+            data_path_and_name_and_type,
+            dtype=dtype,
+            fs=fs,
+            batch_size=batch_size,
+            key_file=key_file,
+            num_workers=num_workers,
+            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+            allow_variable_data_keys=allow_variable_data_keys,
+            inference=True,
+        )
+
+        finish_count = 0
+        file_count = 1
+        # 7. Start for-loop
+        # FIXME(kamo): The output format should be discussed
+        asr_result_list = []
+        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+        if output_path is not None:
+            writer = DatadirWriter(output_path)
+        else:
+            writer = None
+
+        for keys, batch in loader:
+            assert isinstance(batch, dict), type(batch)
+            assert all(isinstance(s, str) for s in keys), keys
+            _bs = len(next(iter(batch.values())))
+            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
+            # N-best list of (text, token, token_int, hyp_object)
+            try:
+                results = speech2text(**batch)
+            except TooShortUttError as e:
+                logging.warning(f"Utterance {keys} {e}")
+                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+                results = [[" ", ["sil"], [2], hyp]] * nbest
+
+            # Only supporting batch_size==1
+            key = keys[0]
+            logging.info(f"Utterance: {key}")
+            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+                # Create a directory: outdir/{n}best_recog
+                if writer is not None:
+                    ibest_writer = writer[f"{n}best_recog"]
+
+                    # Write the 
result to each file + ibest_writer["token"][key] = " ".join(token) + # ibest_writer["token_int"][key] = " ".join(map(str, token_int)) + ibest_writer["score"][key] = str(hyp.score) + + if text is not None: + text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token) + item = {'key': key, 'value': text_postprocessed} + asr_result_list.append(item) + finish_count += 1 + asr_utils.print_progress(finish_count / file_count) + if writer is not None: + ibest_writer["text"][key] = " ".join(word_lists) + return asr_result_list + + return _forward + def get_parser(): parser = config_argparse.ArgumentParser( @@ -252,17 +1170,13 @@ def inference_launch(**kwargs): from funasr.bin.asr_inference import inference_modelscope return inference_modelscope(**kwargs) elif mode == "uniasr": - from funasr.bin.asr_inference_uniasr import inference_modelscope - return inference_modelscope(**kwargs) + return inference_uniasr(**kwargs) elif mode == "paraformer": - from funasr.bin.asr_inference_paraformer import inference_modelscope - return inference_modelscope(**kwargs) + return inference_paraformer(**kwargs) elif mode == "paraformer_streaming": - from funasr.bin.asr_inference_paraformer_streaming import inference_modelscope - return inference_modelscope(**kwargs) + return inference_paraformer_online(**kwargs) elif mode.startswith("paraformer_vad"): - from funasr.bin.asr_inference_paraformer import inference_modelscope_vad_punc - return inference_modelscope_vad_punc(**kwargs) + return inference_paraformer_vad_punc(**kwargs) elif mode == "mfcca": from funasr.bin.asr_inference_mfcca import inference_modelscope return inference_modelscope(**kwargs) @@ -273,38 +1187,6 @@ def inference_launch(**kwargs): logging.info("Unknown decoding mode: {}".format(mode)) return None -def inference_launch_funasr(**kwargs): - if 'mode' in kwargs: - mode = kwargs['mode'] - else: - logging.info("Unknown decoding mode.") - return None - if mode == "asr": - from funasr.bin.asr_inference import inference - return inference(**kwargs) - elif mode == "sa_asr": - from funasr.bin.sa_asr_inference import inference - return inference(**kwargs) - elif mode == "uniasr": - from funasr.bin.asr_inference_uniasr import inference - return inference(**kwargs) - elif mode == "paraformer": - from funasr.bin.asr_inference_paraformer import inference_modelscope - inference_pipeline = inference_modelscope(**kwargs) - return inference_pipeline(kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None)) - elif mode.startswith("paraformer_vad"): - from funasr.bin.asr_inference_paraformer import inference_modelscope_vad_punc - inference_pipeline = inference_modelscope_vad_punc(**kwargs) - return inference_pipeline(kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None)) - elif mode == "mfcca": - from funasr.bin.asr_inference_mfcca import inference_modelscope - return inference_modelscope(**kwargs) - elif mode == "rnnt": - from funasr.bin.asr_inference_rnnt import inference - return inference(**kwargs) - else: - logging.info("Unknown decoding mode: {}".format(mode)) - return None def main(cmd=None): @@ -334,7 +1216,9 @@ def main(cmd=None): os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = gpuid - inference_launch_funasr(**kwargs) + inference_pipeline = inference_launch(**kwargs) + return inference_pipeline(kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None)) + if __name__ == "__main__":