From 3d70934e7fed7c0d3179fec340761466205cb3e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Wed, 14 Jun 2023 15:09:56 +0800 Subject: [PATCH] update repo --- funasr/bin/asr_infer.py | 543 ++++++++++++++++++++-------------------- 1 file changed, 265 insertions(+), 278 deletions(-) diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py index 47ce0ee99..288034c50 100644 --- a/funasr/bin/asr_infer.py +++ b/funasr/bin/asr_infer.py @@ -1,66 +1,48 @@ -# -*- encoding: utf-8 -*- #!/usr/bin/env python3 +# -*- encoding: utf-8 -*- # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. # MIT License (https://opensource.org/licenses/MIT) -import argparse -import logging -import sys -import time + +import codecs import copy +import logging import os import re -import codecs import tempfile -import requests from pathlib import Path +from typing import Any +from typing import Dict +from typing import List from typing import Optional -from typing import Sequence from typing import Tuple from typing import Union -from typing import Dict -from typing import Any -from typing import List import numpy as np +import requests import torch from packaging.version import parse as V from typeguard import check_argument_types from typeguard import check_return_type -from funasr.fileio.datadir_writer import DatadirWriter + +from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer +from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer +from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline from funasr.modules.beam_search.beam_search import BeamSearch -# from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch from funasr.modules.beam_search.beam_search import Hypothesis +from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer -from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR from funasr.modules.scorers.ctc import CTCPrefixScorer from funasr.modules.scorers.length_bonus import LengthBonus -from funasr.modules.subsampling import TooShortUttError from funasr.tasks.asr import ASRTask +from funasr.tasks.asr import frontend_choices from funasr.tasks.lm import LMTask from funasr.text.build_tokenizer import build_tokenizer from funasr.text.token_id_converter import TokenIDConverter from funasr.torch_utils.device_funcs import to_device -from funasr.torch_utils.set_all_random_seed import set_all_random_seed -from funasr.utils import config_argparse -from funasr.utils.cli_utils import get_commandline_args -from funasr.utils.types import str2bool -from funasr.utils.types import str2triple_str -from funasr.utils.types import str_or_none -from funasr.utils import asr_utils, wav_utils, postprocess_utils -from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline -from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer -from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer -from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard -from funasr.bin.tp_infer import Speech2Timestamp -from funasr.bin.vad_infer import Speech2VadSegment -from funasr.bin.punc_infer 
import Text2Punc -from funasr.utils.vad_utils import slice_padding_fbank -from funasr.tasks.vad import VADTask -from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard -from funasr.tasks.asr import frontend_choices + class Speech2Text: """Speech2Text class @@ -73,33 +55,33 @@ class Speech2Text: [(text, token, token_int, hypothesis object), ...] """ - + def __init__( - self, - asr_train_config: Union[Path, str] = None, - asr_model_file: Union[Path, str] = None, - cmvn_file: Union[Path, str] = None, - lm_train_config: Union[Path, str] = None, - lm_file: Union[Path, str] = None, - token_type: str = None, - bpemodel: str = None, - device: str = "cpu", - maxlenratio: float = 0.0, - minlenratio: float = 0.0, - batch_size: int = 1, - dtype: str = "float32", - beam_size: int = 20, - ctc_weight: float = 0.5, - lm_weight: float = 1.0, - ngram_weight: float = 0.9, - penalty: float = 0.0, - nbest: int = 1, - streaming: bool = False, - frontend_conf: dict = None, - **kwargs, + self, + asr_train_config: Union[Path, str] = None, + asr_model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = "cpu", + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + batch_size: int = 1, + dtype: str = "float32", + beam_size: int = 20, + ctc_weight: float = 0.5, + lm_weight: float = 1.0, + ngram_weight: float = 0.9, + penalty: float = 0.0, + nbest: int = 1, + streaming: bool = False, + frontend_conf: dict = None, + **kwargs, ): assert check_argument_types() - + # 1. Build ASR model scorers = {} asr_model, asr_train_args = ASRTask.build_model_from_file( @@ -113,13 +95,13 @@ class Speech2Text: from funasr.tasks.asr import frontend_choices frontend_class = frontend_choices.get_class(asr_train_args.frontend) frontend = frontend_class(**asr_train_args.frontend_conf).eval() - + logging.info("asr_model: {}".format(asr_model)) logging.info("asr_train_args: {}".format(asr_train_args)) asr_model.to(dtype=getattr(torch, dtype)).eval() - + decoder = asr_model.decoder - + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) token_list = asr_model.token_list scorers.update( @@ -127,24 +109,24 @@ class Speech2Text: ctc=ctc, length_bonus=LengthBonus(len(token_list)), ) - + # 2. Build Language model if lm_train_config is not None: lm, lm_train_args = LMTask.build_model_from_file( lm_train_config, lm_file, None, device ) scorers["lm"] = lm.lm - + # 3. Build ngram model # ngram is not supported now ngram = None scorers["ngram"] = ngram - + # 4. Build BeamSearch object # transducer is not supported now beam_search_transducer = None from funasr.modules.beam_search.beam_search import BeamSearch - + weights = dict( decoder=1.0 - ctc_weight, ctc=ctc_weight, @@ -162,13 +144,13 @@ class Speech2Text: token_list=token_list, pre_beam_score_key=None if ctc_weight == 1.0 else "full", ) - + # 5. [Optional] Build Text converter: e.g. 
bpe-sym -> Text if token_type is None: token_type = asr_train_args.token_type if bpemodel is None: bpemodel = asr_train_args.bpemodel - + if token_type is None: tokenizer = None elif token_type == "bpe": @@ -180,7 +162,7 @@ class Speech2Text: tokenizer = build_tokenizer(token_type=token_type) converter = TokenIDConverter(token_list=token_list) logging.info(f"Text tokenizer: {tokenizer}") - + self.asr_model = asr_model self.asr_train_args = asr_train_args self.converter = converter @@ -193,10 +175,10 @@ class Speech2Text: self.dtype = dtype self.nbest = nbest self.frontend = frontend - + @torch.no_grad() def __call__( - self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None + self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None ) -> List[ Tuple[ Optional[str], @@ -214,11 +196,11 @@ class Speech2Text: """ assert check_argument_types() - + # Input as audio signal if isinstance(speech, np.ndarray): speech = torch.tensor(speech) - + if self.frontend is not None: feats, feats_len = self.frontend.forward(speech, speech_lengths) feats = to_device(feats, device=self.device) @@ -229,49 +211,50 @@ class Speech2Text: feats_len = speech_lengths lfr_factor = max(1, (feats.size()[-1] // 80) - 1) batch = {"speech": feats, "speech_lengths": feats_len} - + # a. To device batch = to_device(batch, device=self.device) - + # b. Forward Encoder enc, _ = self.asr_model.encode(**batch) if isinstance(enc, tuple): enc = enc[0] assert len(enc) == 1, len(enc) - + # c. Passed the encoder result and the beam search nbest_hyps = self.beam_search( x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio ) - + nbest_hyps = nbest_hyps[: self.nbest] - + results = [] for hyp in nbest_hyps: assert isinstance(hyp, (Hypothesis)), type(hyp) - + # remove sos/eos and get results last_pos = -1 if isinstance(hyp.yseq, list): token_int = hyp.yseq[1:last_pos] else: token_int = hyp.yseq[1:last_pos].tolist() - + # remove blank symbol id, which is assumed to be 0 token_int = list(filter(lambda x: x != 0, token_int)) - + # Change integer-ids to tokens token = self.converter.ids2tokens(token_int) - + if self.tokenizer is not None: text = self.tokenizer.tokens2text(token) else: text = None results.append((text, token, token_int, hyp)) - + assert check_return_type(results) return results + class Speech2TextParaformer: """Speech2Text class @@ -466,18 +449,21 @@ class Speech2TextParaformer: pre_token_length = pre_token_length.round().long() if torch.max(pre_token_length) < 1: return [] - if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model, NeatContextualParaformer): + if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model, + NeatContextualParaformer): if self.hotword_list: logging.warning("Hotword is given but asr model is not a ContextualParaformer.") - decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length) + decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, + pre_token_length) decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] else: - decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list) + decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, + pre_token_length, hw_list=self.hotword_list) decoder_out, ys_pad_lens = decoder_outs[0], 
decoder_outs[1] if isinstance(self.asr_model, BiCifParaformer): _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len, - pre_token_length) # test no bias cif2 + pre_token_length) # test no bias cif2 results = [] b, n, d = decoder_out.size() @@ -527,13 +513,12 @@ class Speech2TextParaformer: text = None timestamp = [] if isinstance(self.asr_model, BiCifParaformer): - _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:enc_len[i]*3], - us_peaks[i][:enc_len[i]*3], - copy.copy(token), - vad_offset=begin_time) + _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:enc_len[i] * 3], + us_peaks[i][:enc_len[i] * 3], + copy.copy(token), + vad_offset=begin_time) results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor)) - # assert check_return_type(results) return results @@ -591,6 +576,7 @@ class Speech2TextParaformer: hotword_list = None return hotword_list + class Speech2TextParaformerOnline: """Speech2Text class @@ -789,7 +775,7 @@ class Speech2TextParaformerOnline: enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache) - pre_acoustic_embeds, pre_token_length= predictor_outs[0], predictor_outs[1] + pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1] if torch.max(pre_token_length) < 1: return [] decoder_outs = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache) @@ -839,12 +825,13 @@ class Speech2TextParaformerOnline: postprocessed_result += item + " " else: postprocessed_result += item - + results.append(postprocessed_result) # assert check_return_type(results) return results + class Speech2TextUniASR: """Speech2Text class @@ -1077,7 +1064,7 @@ class Speech2TextUniASR: assert check_return_type(results) return results - + class Speech2TextMFCCA: """Speech2Text class @@ -1090,45 +1077,45 @@ class Speech2TextMFCCA: [(text, token, token_int, hypothesis object), ...] """ - + def __init__( - self, - asr_train_config: Union[Path, str] = None, - asr_model_file: Union[Path, str] = None, - cmvn_file: Union[Path, str] = None, - lm_train_config: Union[Path, str] = None, - lm_file: Union[Path, str] = None, - token_type: str = None, - bpemodel: str = None, - device: str = "cpu", - maxlenratio: float = 0.0, - minlenratio: float = 0.0, - batch_size: int = 1, - dtype: str = "float32", - beam_size: int = 20, - ctc_weight: float = 0.5, - lm_weight: float = 1.0, - ngram_weight: float = 0.9, - penalty: float = 0.0, - nbest: int = 1, - streaming: bool = False, - **kwargs, + self, + asr_train_config: Union[Path, str] = None, + asr_model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = "cpu", + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + batch_size: int = 1, + dtype: str = "float32", + beam_size: int = 20, + ctc_weight: float = 0.5, + lm_weight: float = 1.0, + ngram_weight: float = 0.9, + penalty: float = 0.0, + nbest: int = 1, + streaming: bool = False, + **kwargs, ): assert check_argument_types() - + # 1. 
Build ASR model from funasr.tasks.asr import ASRTaskMFCCA as ASRTask scorers = {} asr_model, asr_train_args = ASRTask.build_model_from_file( asr_train_config, asr_model_file, cmvn_file, device ) - + logging.info("asr_model: {}".format(asr_model)) logging.info("asr_train_args: {}".format(asr_train_args)) asr_model.to(dtype=getattr(torch, dtype)).eval() - + decoder = asr_model.decoder - + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) token_list = asr_model.token_list scorers.update( @@ -1136,7 +1123,7 @@ class Speech2TextMFCCA: ctc=ctc, length_bonus=LengthBonus(len(token_list)), ) - + # 2. Build Language model if lm_train_config is not None: lm, lm_train_args = LMTask.build_model_from_file( @@ -1148,11 +1135,11 @@ class Speech2TextMFCCA: # ngram is not supported now ngram = None scorers["ngram"] = ngram - + # 4. Build BeamSearch object # transducer is not supported now beam_search_transducer = None - + weights = dict( decoder=1.0 - ctc_weight, ctc=ctc_weight, @@ -1176,7 +1163,7 @@ class Speech2TextMFCCA: token_type = asr_train_args.token_type if bpemodel is None: bpemodel = asr_train_args.bpemodel - + if token_type is None: tokenizer = None elif token_type == "bpe": @@ -1188,7 +1175,7 @@ class Speech2TextMFCCA: tokenizer = build_tokenizer(token_type=token_type) converter = TokenIDConverter(token_list=token_list) logging.info(f"Text tokenizer: {tokenizer}") - + self.asr_model = asr_model self.asr_train_args = asr_train_args self.converter = converter @@ -1200,10 +1187,10 @@ class Speech2TextMFCCA: self.device = device self.dtype = dtype self.nbest = nbest - + @torch.no_grad() def __call__( - self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None + self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None ) -> List[ Tuple[ Optional[str], @@ -1231,45 +1218,45 @@ class Speech2TextMFCCA: # lenghts: (1,) lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1)) batch = {"speech": speech, "speech_lengths": lengths} - + # a. To device batch = to_device(batch, device=self.device) - + # b. Forward Encoder enc, _ = self.asr_model.encode(**batch) - + assert len(enc) == 1, len(enc) - + # c. Passed the encoder result and the beam search nbest_hyps = self.beam_search( x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio ) - + nbest_hyps = nbest_hyps[: self.nbest] - + results = [] for hyp in nbest_hyps: assert isinstance(hyp, (Hypothesis)), type(hyp) - + # remove sos/eos and get results last_pos = -1 if isinstance(hyp.yseq, list): token_int = hyp.yseq[1:last_pos] else: token_int = hyp.yseq[1:last_pos].tolist() - + # remove blank symbol id, which is assumed to be 0 token_int = list(filter(lambda x: x != 0, token_int)) - + # Change integer-ids to tokens token = self.converter.ids2tokens(token_int) - + if self.tokenizer is not None: text = self.tokenizer.tokens2text(token) else: text = None results.append((text, token, token_int, hyp)) - + assert check_return_type(results) return results @@ -1298,45 +1285,45 @@ class Speech2TextTransducer: right_context: Number of frames in right context AFTER subsampling. display_partial_hypotheses: Whether to display partial hypotheses. 
""" - + def __init__( - self, - asr_train_config: Union[Path, str] = None, - asr_model_file: Union[Path, str] = None, - cmvn_file: Union[Path, str] = None, - beam_search_config: Dict[str, Any] = None, - lm_train_config: Union[Path, str] = None, - lm_file: Union[Path, str] = None, - token_type: str = None, - bpemodel: str = None, - device: str = "cpu", - beam_size: int = 5, - dtype: str = "float32", - lm_weight: float = 1.0, - quantize_asr_model: bool = False, - quantize_modules: List[str] = None, - quantize_dtype: str = "qint8", - nbest: int = 1, - streaming: bool = False, - simu_streaming: bool = False, - chunk_size: int = 16, - left_context: int = 32, - right_context: int = 0, - display_partial_hypotheses: bool = False, + self, + asr_train_config: Union[Path, str] = None, + asr_model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, + beam_search_config: Dict[str, Any] = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = "cpu", + beam_size: int = 5, + dtype: str = "float32", + lm_weight: float = 1.0, + quantize_asr_model: bool = False, + quantize_modules: List[str] = None, + quantize_dtype: str = "qint8", + nbest: int = 1, + streaming: bool = False, + simu_streaming: bool = False, + chunk_size: int = 16, + left_context: int = 32, + right_context: int = 0, + display_partial_hypotheses: bool = False, ) -> None: """Construct a Speech2Text object.""" super().__init__() - + assert check_argument_types() from funasr.tasks.asr import ASRTransducerTask asr_model, asr_train_args = ASRTransducerTask.build_model_from_file( asr_train_config, asr_model_file, cmvn_file, device ) - + frontend = None if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) - + if quantize_asr_model: if quantize_modules is not None: if not all([q in ["LSTM", "Linear"] for q in quantize_modules]): @@ -1344,24 +1331,24 @@ class Speech2TextTransducer: "Only 'Linear' and 'LSTM' modules are currently supported" " by PyTorch and in --quantize_modules" ) - + q_config = set([getattr(torch.nn, q) for q in quantize_modules]) else: q_config = {torch.nn.Linear} - + if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")): raise ValueError( "float16 dtype for dynamic quantization is not supported with torch" " version < 1.5.0. Switching to qint8 dtype instead." ) q_dtype = getattr(torch, quantize_dtype) - + asr_model = torch.quantization.quantize_dynamic( asr_model, q_config, dtype=q_dtype ).eval() else: asr_model.to(dtype=getattr(torch, dtype)).eval() - + if lm_train_config is not None: lm, lm_train_args = LMTask.build_model_from_file( lm_train_config, lm_file, device @@ -1369,11 +1356,11 @@ class Speech2TextTransducer: lm_scorer = lm.lm else: lm_scorer = None - + # 4. 
Build BeamSearch object if beam_search_config is None: beam_search_config = {} - + beam_search = BeamSearchTransducer( asr_model.decoder, asr_model.joint_network, @@ -1383,14 +1370,14 @@ class Speech2TextTransducer: nbest=nbest, **beam_search_config, ) - + token_list = asr_model.token_list - + if token_type is None: token_type = asr_train_args.token_type if bpemodel is None: bpemodel = asr_train_args.bpemodel - + if token_type is None: tokenizer = None elif token_type == "bpe": @@ -1402,60 +1389,60 @@ class Speech2TextTransducer: tokenizer = build_tokenizer(token_type=token_type) converter = TokenIDConverter(token_list=token_list) logging.info(f"Text tokenizer: {tokenizer}") - + self.asr_model = asr_model self.asr_train_args = asr_train_args self.device = device self.dtype = dtype self.nbest = nbest - + self.converter = converter self.tokenizer = tokenizer - + self.beam_search = beam_search self.streaming = streaming self.simu_streaming = simu_streaming self.chunk_size = max(chunk_size, 0) self.left_context = left_context self.right_context = max(right_context, 0) - + if not streaming or chunk_size == 0: self.streaming = False self.asr_model.encoder.dynamic_chunk_training = False - + if not simu_streaming or chunk_size == 0: self.simu_streaming = False self.asr_model.encoder.dynamic_chunk_training = False - + self.frontend = frontend self.window_size = self.chunk_size + self.right_context - + if self.streaming: self._ctx = self.asr_model.encoder.get_encoder_input_size( self.window_size ) - + self.last_chunk_length = ( - self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 + self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 ) self.reset_inference_cache() - + def reset_inference_cache(self) -> None: """Reset Speech2Text parameters.""" self.frontend_cache = None - + self.asr_model.encoder.reset_streaming_cache( self.left_context, device=self.device ) self.beam_search.reset_inference_cache() - + self.num_processed_frames = torch.tensor([[0]], device=self.device) - + @torch.no_grad() def streaming_decode( - self, - speech: Union[torch.Tensor, np.ndarray], - is_final: bool = True, + self, + speech: Union[torch.Tensor, np.ndarray], + is_final: bool = True, ) -> List[HypothesisTransducer]: """Speech2Text streaming call. Args: @@ -1473,13 +1460,13 @@ class Speech2TextTransducer: ) speech = torch.cat([speech, pad], dim=0) # feats, feats_length = self.apply_frontend(speech, is_final=is_final) - + feats = speech.unsqueeze(0).to(getattr(torch, self.dtype)) feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) - + if self.asr_model.normalize is not None: feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths) - + feats = to_device(feats, device=self.device) feats_lengths = to_device(feats_lengths, device=self.device) enc_out = self.asr_model.encoder.chunk_forward( @@ -1491,14 +1478,14 @@ class Speech2TextTransducer: right_context=self.right_context, ) nbest_hyps = self.beam_search(enc_out[0], is_final=is_final) - + self.num_processed_frames += self.chunk_size - + if is_final: self.reset_inference_cache() - + return nbest_hyps - + @torch.no_grad() def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]: """Speech2Text call. @@ -1508,29 +1495,29 @@ class Speech2TextTransducer: nbest_hypothesis: N-best hypothesis. 
""" assert check_argument_types() - + if isinstance(speech, np.ndarray): speech = torch.tensor(speech) - + if self.frontend is not None: speech = torch.unsqueeze(speech, axis=0) speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1)) feats, feats_lengths = self.frontend(speech, speech_lengths) - else: + else: feats = speech.unsqueeze(0).to(getattr(torch, self.dtype)) feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) - + if self.asr_model.normalize is not None: feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths) - + feats = to_device(feats, device=self.device) feats_lengths = to_device(feats_lengths, device=self.device) enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context, self.right_context) nbest_hyps = self.beam_search(enc_out[0]) - + return nbest_hyps - + @torch.no_grad() def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]: """Speech2Text call. @@ -1540,7 +1527,7 @@ class Speech2TextTransducer: nbest_hypothesis: N-best hypothesis. """ assert check_argument_types() - + if isinstance(speech, np.ndarray): speech = torch.tensor(speech) @@ -1548,19 +1535,19 @@ class Speech2TextTransducer: speech = torch.unsqueeze(speech, axis=0) speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1)) feats, feats_lengths = self.frontend(speech, speech_lengths) - else: + else: feats = speech.unsqueeze(0).to(getattr(torch, self.dtype)) feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) - + feats = to_device(feats, device=self.device) feats_lengths = to_device(feats_lengths, device=self.device) - + enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths) - + nbest_hyps = self.beam_search(enc_out[0]) - + return nbest_hyps - + def hypotheses_to_results(self, nbest_hyps: List[HypothesisTransducer]) -> List[Any]: """Build partial or final results from the hypotheses. Args: @@ -1569,26 +1556,26 @@ class Speech2TextTransducer: results: Results containing different representation for the hypothesis. """ results = [] - + for hyp in nbest_hyps: token_int = list(filter(lambda x: x != 0, hyp.yseq)) - + token = self.converter.ids2tokens(token_int) - + if self.tokenizer is not None: text = self.tokenizer.tokens2text(token) else: text = None results.append((text, token, token_int, hyp)) - + assert check_return_type(results) - + return results - + @staticmethod def from_pretrained( - model_tag: Optional[str] = None, - **kwargs: Optional[Any], + model_tag: Optional[str] = None, + **kwargs: Optional[Any], ) -> Speech2Text: """Build Speech2Text instance from the pretrained model. Args: @@ -1599,7 +1586,7 @@ class Speech2TextTransducer: if model_tag is not None: try: from espnet_model_zoo.downloader import ModelDownloader - + except ImportError: logging.error( "`espnet_model_zoo` is not installed. " @@ -1608,7 +1595,7 @@ class Speech2TextTransducer: raise d = ModelDownloader() kwargs.update(**d.download_and_unpack(model_tag)) - + return Speech2TextTransducer(**kwargs) @@ -1623,33 +1610,33 @@ class Speech2TextSAASR: [(text, token, token_int, hypothesis object), ...] 
""" - + def __init__( - self, - asr_train_config: Union[Path, str] = None, - asr_model_file: Union[Path, str] = None, - cmvn_file: Union[Path, str] = None, - lm_train_config: Union[Path, str] = None, - lm_file: Union[Path, str] = None, - token_type: str = None, - bpemodel: str = None, - device: str = "cpu", - maxlenratio: float = 0.0, - minlenratio: float = 0.0, - batch_size: int = 1, - dtype: str = "float32", - beam_size: int = 20, - ctc_weight: float = 0.5, - lm_weight: float = 1.0, - ngram_weight: float = 0.9, - penalty: float = 0.0, - nbest: int = 1, - streaming: bool = False, - frontend_conf: dict = None, - **kwargs, + self, + asr_train_config: Union[Path, str] = None, + asr_model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = "cpu", + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + batch_size: int = 1, + dtype: str = "float32", + beam_size: int = 20, + ctc_weight: float = 0.5, + lm_weight: float = 1.0, + ngram_weight: float = 0.9, + penalty: float = 0.0, + nbest: int = 1, + streaming: bool = False, + frontend_conf: dict = None, + **kwargs, ): assert check_argument_types() - + # 1. Build ASR model from funasr.tasks.sa_asr import ASRTask scorers = {} @@ -1663,13 +1650,13 @@ class Speech2TextSAASR: else: frontend_class = frontend_choices.get_class(asr_train_args.frontend) frontend = frontend_class(**asr_train_args.frontend_conf).eval() - + logging.info("asr_model: {}".format(asr_model)) logging.info("asr_train_args: {}".format(asr_train_args)) asr_model.to(dtype=getattr(torch, dtype)).eval() - + decoder = asr_model.decoder - + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) token_list = asr_model.token_list scorers.update( @@ -1677,24 +1664,24 @@ class Speech2TextSAASR: ctc=ctc, length_bonus=LengthBonus(len(token_list)), ) - + # 2. Build Language model if lm_train_config is not None: lm, lm_train_args = LMTask.build_model_from_file( lm_train_config, lm_file, None, device ) scorers["lm"] = lm.lm - + # 3. Build ngram model # ngram is not supported now ngram = None scorers["ngram"] = ngram - + # 4. Build BeamSearch object # transducer is not supported now beam_search_transducer = None from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch - + weights = dict( decoder=1.0 - ctc_weight, ctc=ctc_weight, @@ -1712,13 +1699,13 @@ class Speech2TextSAASR: token_list=token_list, pre_beam_score_key=None if ctc_weight == 1.0 else "full", ) - + # 5. [Optional] Build Text converter: e.g. 
bpe-sym -> Text if token_type is None: token_type = asr_train_args.token_type if bpemodel is None: bpemodel = asr_train_args.bpemodel - + if token_type is None: tokenizer = None elif token_type == "bpe": @@ -1730,7 +1717,7 @@ class Speech2TextSAASR: tokenizer = build_tokenizer(token_type=token_type) converter = TokenIDConverter(token_list=token_list) logging.info(f"Text tokenizer: {tokenizer}") - + self.asr_model = asr_model self.asr_train_args = asr_train_args self.converter = converter @@ -1743,11 +1730,11 @@ class Speech2TextSAASR: self.dtype = dtype self.nbest = nbest self.frontend = frontend - + @torch.no_grad() def __call__( - self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray], - profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray] + self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray], + profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray] ) -> List[ Tuple[ Optional[str], @@ -1766,14 +1753,14 @@ class Speech2TextSAASR: """ assert check_argument_types() - + # Input as audio signal if isinstance(speech, np.ndarray): speech = torch.tensor(speech) - + if isinstance(profile, np.ndarray): profile = torch.tensor(profile) - + if self.frontend is not None: feats, feats_len = self.frontend.forward(speech, speech_lengths) feats = to_device(feats, device=self.device) @@ -1784,10 +1771,10 @@ class Speech2TextSAASR: feats_len = speech_lengths lfr_factor = max(1, (feats.size()[-1] // 80) - 1) batch = {"speech": feats, "speech_lengths": feats_len} - + # a. To device batch = to_device(batch, device=self.device) - + # b. Forward Encoder asr_enc, _, spk_enc = self.asr_model.encode(**batch) if isinstance(asr_enc, tuple): @@ -1796,30 +1783,30 @@ class Speech2TextSAASR: spk_enc = spk_enc[0] assert len(asr_enc) == 1, len(asr_enc) assert len(spk_enc) == 1, len(spk_enc) - + # c. Passed the encoder result and the beam search nbest_hyps = self.beam_search( asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio ) - + nbest_hyps = nbest_hyps[: self.nbest] - + results = [] for hyp in nbest_hyps: assert isinstance(hyp, (HypothesisSAASR)), type(hyp) - + # remove sos/eos and get results last_pos = -1 if isinstance(hyp.yseq, list): token_int = hyp.yseq[1: last_pos] else: token_int = hyp.yseq[1: last_pos].tolist() - + spk_weigths = torch.stack(hyp.spk_weigths, dim=0) - + token_ori = self.converter.ids2tokens(token_int) text_ori = self.tokenizer.tokens2text(token_ori) - + text_ori_spklist = text_ori.split('$') cur_index = 0 spk_choose = [] @@ -1831,32 +1818,32 @@ class Speech2TextSAASR: spk_weights_local = spk_weights_local.mean(dim=0) spk_choose_local = spk_weights_local.argmax(-1) spk_choose.append(spk_choose_local.item() + 1) - + # remove blank symbol id, which is assumed to be 0 token_int = list(filter(lambda x: x != 0, token_int)) - + # Change integer-ids to tokens token = self.converter.ids2tokens(token_int) - + if self.tokenizer is not None: text = self.tokenizer.tokens2text(token) else: text = None - + text_spklist = text.split('$') assert len(spk_choose) == len(text_spklist) - + spk_list = [] for i in range(len(text_spklist)): text_split = text_spklist[i] n = len(text_split) spk_list.append(str(spk_choose[i]) * n) - + text_id = '$'.join(spk_list) - + assert len(text) == len(text_id) - + results.append((text, text_id, token, token_int, hyp)) - + assert check_return_type(results) return results
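
Usage note (reviewer commentary, not part of the patch): the hunks above are mostly mechanical (PEP 8 re-indentation of constructor signatures, import sorting, whitespace cleanup), so the public interface of `Speech2Text` is unchanged. For reference, a minimal offline-decoding sketch against the refactored class is shown below. The config/model paths and the 16 kHz mono WAV file are placeholders, and the (batch, samples) input shape assumes a `WavFrontend`-style frontend as configured in `frontend_conf`; adjust both to the actual trained model.

    import soundfile
    import torch

    from funasr.bin.asr_infer import Speech2Text

    # Placeholder paths: point these at a real FunASR model directory.
    speech2text = Speech2Text(
        asr_train_config="exp/asr/config.yaml",
        asr_model_file="exp/asr/model.pb",
        cmvn_file="exp/asr/am.mvn",
        device="cpu",
        beam_size=20,
        ctc_weight=0.5,
        nbest=1,
    )

    # Assumed input: 16 kHz mono audio matching the model's frontend_conf.
    waveform, _ = soundfile.read("test.wav")
    speech = torch.tensor(waveform, dtype=torch.float32).unsqueeze(0)  # (1, T)
    speech_lengths = torch.tensor([speech.size(1)], dtype=torch.long)  # (1,)

    # __call__ returns nbest-sorted (text, token, token_int, hypothesis) tuples.
    for text, token, token_int, hyp in speech2text(speech, speech_lengths):
        print(text)

The streaming (`Speech2TextParaformerOnline`), transducer, and SA-ASR classes in this file follow the same construct-then-call pattern, differing mainly in their extra decoding arguments (chunk sizes, speaker profiles), so the sketch above transfers to them with the corresponding call signatures.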