From e4035edb4620e483279923be3de9af9c35b9ce67 Mon Sep 17 00:00:00 2001
From: zhifu gao
Date: Wed, 24 Jan 2024 16:59:26 +0800
Subject: [PATCH] Funasr1.0 (#1297)

* fix add_file bug (#1296)

Co-authored-by: shixian.shi

* funasr1.0 uniasr

* funasr1.0 uniasr

---------

Co-authored-by: shixian.shi
---
 .../uniasr/demo.py                  |  29 +
 .../uniasr/infer.sh                 |  11 +
 funasr/models/transformer/model.py  |   2 +-
 funasr/models/uniasr/beam_search.py | 496 +++++++++++++
 funasr/models/uniasr/model.py       | 656 ++++++++----------
 5 files changed, 809 insertions(+), 385 deletions(-)
 create mode 100644 examples/industrial_data_pretraining/uniasr/demo.py
 create mode 100644 examples/industrial_data_pretraining/uniasr/infer.sh
 create mode 100644 funasr/models/uniasr/beam_search.py

diff --git a/examples/industrial_data_pretraining/uniasr/demo.py b/examples/industrial_data_pretraining/uniasr/demo.py
new file mode 100644
index 000000000..125902174
--- /dev/null
+++ b/examples/industrial_data_pretraining/uniasr/demo.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+model = AutoModel(model="damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online", model_revision="v2.0.4",
+                  # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  # vad_model_revision="v2.0.4",
+                  # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  # punc_model_revision="v2.0.4",
+                  )
+
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
+print(res)
+
+
+''' cannot be used currently
+from funasr import AutoFrontend
+
+frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4")
+
+fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)
+
+for batch_idx, fbank_dict in enumerate(fbanks):
+    res = model.generate(**fbank_dict)
+    print(res)
+'''
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/uniasr/infer.sh b/examples/industrial_data_pretraining/uniasr/infer.sh
new file mode 100644
index 000000000..7491e98b5
--- /dev/null
+++ b/examples/industrial_data_pretraining/uniasr/infer.sh
@@ -0,0 +1,11 @@
+
+model="damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online"
+model_revision="v2.0.4"
+
+python funasr/bin/inference.py \
++model=${model} \
++model_revision=${model_revision} \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \
++output_dir="./outputs/debug" \
++device="cpu" \
+
diff --git a/funasr/models/transformer/model.py b/funasr/models/transformer/model.py
index e2367a74b..7e40060dc 100644
--- a/funasr/models/transformer/model.py
+++ b/funasr/models/transformer/model.py
@@ -348,7 +348,7 @@ class Transformer(nn.Module):
         scorers["ngram"] = ngram
 
         weights = dict(
-            decoder=1.0 - kwargs.get("decoding_ctc_weight"),
+            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
             ctc=kwargs.get("decoding_ctc_weight", 0.0),
             lm=kwargs.get("lm_weight", 0.0),
             ngram=kwargs.get("ngram_weight", 0.0),
diff --git a/funasr/models/uniasr/beam_search.py b/funasr/models/uniasr/beam_search.py
new file mode 100644
index 000000000..839a1f4e0
--- /dev/null
+++ b/funasr/models/uniasr/beam_search.py
@@ -0,0 +1,496 @@
+"""Beam 
search module.""" + +from itertools import chain +import logging +from typing import Any +from typing import Dict +from typing import List +from typing import NamedTuple +from typing import Tuple +from typing import Union + +import torch + +from funasr.metrics.common import end_detect +from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface +from funasr.models.transformer.scorers.scorer_interface import ScorerInterface + + +class Hypothesis(NamedTuple): + """Hypothesis data type.""" + + yseq: torch.Tensor + score: Union[float, torch.Tensor] = 0 + scores: Dict[str, Union[float, torch.Tensor]] = dict() + states: Dict[str, Any] = dict() + + def asdict(self) -> dict: + """Convert data to JSON-friendly dict.""" + return self._replace( + yseq=self.yseq.tolist(), + score=float(self.score), + scores={k: float(v) for k, v in self.scores.items()}, + )._asdict() + + + +class BeamSearchScama(torch.nn.Module): + """Beam search implementation.""" + + def __init__( + self, + scorers: Dict[str, ScorerInterface], + weights: Dict[str, float], + beam_size: int, + vocab_size: int, + sos: int, + eos: int, + token_list: List[str] = None, + pre_beam_ratio: float = 1.5, + pre_beam_score_key: str = None, + ): + """Initialize beam search. + + Args: + scorers (dict[str, ScorerInterface]): Dict of decoder modules + e.g., Decoder, CTCPrefixScorer, LM + The scorer will be ignored if it is `None` + weights (dict[str, float]): Dict of weights for each scorers + The scorer will be ignored if its weight is 0 + beam_size (int): The number of hypotheses kept during search + vocab_size (int): The number of vocabulary + sos (int): Start of sequence id + eos (int): End of sequence id + token_list (list[str]): List of tokens for debug log + pre_beam_score_key (str): key of scores to perform pre-beam search + pre_beam_ratio (float): beam size in the pre-beam search + will be `int(pre_beam_ratio * beam_size)` + + """ + super().__init__() + # set scorers + self.weights = weights + self.scorers = dict() + self.full_scorers = dict() + self.part_scorers = dict() + # this module dict is required for recursive cast + # `self.to(device, dtype)` in `recog.py` + self.nn_dict = torch.nn.ModuleDict() + for k, v in scorers.items(): + w = weights.get(k, 0) + if w == 0 or v is None: + continue + assert isinstance( + v, ScorerInterface + ), f"{k} ({type(v)}) does not implement ScorerInterface" + self.scorers[k] = v + if isinstance(v, PartialScorerInterface): + self.part_scorers[k] = v + else: + self.full_scorers[k] = v + if isinstance(v, torch.nn.Module): + self.nn_dict[k] = v + + # set configurations + self.sos = sos + self.eos = eos + self.token_list = token_list + self.pre_beam_size = int(pre_beam_ratio * beam_size) + self.beam_size = beam_size + self.n_vocab = vocab_size + if ( + pre_beam_score_key is not None + and pre_beam_score_key != "full" + and pre_beam_score_key not in self.full_scorers + ): + raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}") + self.pre_beam_score_key = pre_beam_score_key + self.do_pre_beam = ( + self.pre_beam_score_key is not None + and self.pre_beam_size < self.n_vocab + and len(self.part_scorers) > 0 + ) + + def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]: + """Get an initial hypothesis data. + + Args: + x (torch.Tensor): The encoder output feature + + Returns: + Hypothesis: The initial hypothesis. 
+ + """ + init_states = dict() + init_scores = dict() + for k, d in self.scorers.items(): + init_states[k] = d.init_state(x) + init_scores[k] = 0.0 + return [ + Hypothesis( + score=0.0, + scores=init_scores, + states=init_states, + yseq=torch.tensor([self.sos], device=x.device), + ) + ] + + @staticmethod + def append_token(xs: torch.Tensor, x: int) -> torch.Tensor: + """Append new token to prefix tokens. + + Args: + xs (torch.Tensor): The prefix token + x (int): The new token to append + + Returns: + torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device + + """ + x = torch.tensor([x], dtype=xs.dtype, device=xs.device) + return torch.cat((xs, x)) + + def score_full( + self, hyp: Hypothesis, + x: torch.Tensor, + x_mask: torch.Tensor = None, + pre_acoustic_embeds: torch.Tensor = None, + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: + """Score new hypothesis by `self.full_scorers`. + + Args: + hyp (Hypothesis): Hypothesis with prefix tokens to score + x (torch.Tensor): Corresponding input feature + + Returns: + Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of + score dict of `hyp` that has string keys of `self.full_scorers` + and tensor score values of shape: `(self.n_vocab,)`, + and state dict that has string keys + and state values of `self.full_scorers` + + """ + scores = dict() + states = dict() + for k, d in self.full_scorers.items(): + scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds) + return scores, states + + def score_partial( + self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: + """Score new hypothesis by `self.part_scorers`. + + Args: + hyp (Hypothesis): Hypothesis with prefix tokens to score + ids (torch.Tensor): 1D tensor of new partial tokens to score + x (torch.Tensor): Corresponding input feature + + Returns: + Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of + score dict of `hyp` that has string keys of `self.part_scorers` + and tensor score values of shape: `(len(ids),)`, + and state dict that has string keys + and state values of `self.part_scorers` + + """ + scores = dict() + states = dict() + for k, d in self.part_scorers.items(): + scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x) + return scores, states + + def beam( + self, weighted_scores: torch.Tensor, ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute topk full token ids and partial token ids. + + Args: + weighted_scores (torch.Tensor): The weighted sum scores for each tokens. + Its shape is `(self.n_vocab,)`. + ids (torch.Tensor): The partial token ids to compute topk + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + The topk full token ids and partial token ids. + Their shapes are `(self.beam_size,)` + + """ + # no pre beam performed + if weighted_scores.size(0) == ids.size(0): + top_ids = weighted_scores.topk(self.beam_size)[1] + return top_ids, top_ids + + # mask pruned in pre-beam not to select in topk + tmp = weighted_scores[ids] + weighted_scores[:] = -float("inf") + weighted_scores[ids] = tmp + top_ids = weighted_scores.topk(self.beam_size)[1] + local_ids = weighted_scores[ids].topk(self.beam_size)[1] + return top_ids, local_ids + + @staticmethod + def merge_scores( + prev_scores: Dict[str, float], + next_full_scores: Dict[str, torch.Tensor], + full_idx: int, + next_part_scores: Dict[str, torch.Tensor], + part_idx: int, + ) -> Dict[str, torch.Tensor]: + """Merge scores for new hypothesis. 
+ + Args: + prev_scores (Dict[str, float]): + The previous hypothesis scores by `self.scorers` + next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers` + full_idx (int): The next token id for `next_full_scores` + next_part_scores (Dict[str, torch.Tensor]): + scores of partial tokens by `self.part_scorers` + part_idx (int): The new token id for `next_part_scores` + + Returns: + Dict[str, torch.Tensor]: The new score dict. + Its keys are names of `self.full_scorers` and `self.part_scorers`. + Its values are scalar tensors by the scorers. + + """ + new_scores = dict() + for k, v in next_full_scores.items(): + new_scores[k] = prev_scores[k] + v[full_idx] + for k, v in next_part_scores.items(): + new_scores[k] = prev_scores[k] + v[part_idx] + return new_scores + + def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any: + """Merge states for new hypothesis. + + Args: + states: states of `self.full_scorers` + part_states: states of `self.part_scorers` + part_idx (int): The new token id for `part_scores` + + Returns: + Dict[str, torch.Tensor]: The new score dict. + Its keys are names of `self.full_scorers` and `self.part_scorers`. + Its values are states of the scorers. + + """ + new_states = dict() + for k, v in states.items(): + new_states[k] = v + for k, d in self.part_scorers.items(): + new_states[k] = d.select_state(part_states[k], part_idx) + return new_states + + def search( + self, running_hyps: List[Hypothesis], + x: torch.Tensor, + x_mask: torch.Tensor = None, + pre_acoustic_embeds: torch.Tensor = None, + ) -> List[Hypothesis]: + """Search new tokens for running hypotheses and encoded speech x. + + Args: + running_hyps (List[Hypothesis]): Running hypotheses on beam + x (torch.Tensor): Encoded speech feature (T, D) + + Returns: + List[Hypotheses]: Best sorted hypotheses + + """ + best_hyps = [] + part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam + for hyp in running_hyps: + # scoring + weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device) + scores, states = self.score_full(hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds) + for k in self.full_scorers: + weighted_scores += self.weights[k] * scores[k] + # partial scoring + if self.do_pre_beam: + pre_beam_scores = ( + weighted_scores + if self.pre_beam_score_key == "full" + else scores[self.pre_beam_score_key] + ) + part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1] + part_scores, part_states = self.score_partial(hyp, part_ids, x) + for k in self.part_scorers: + weighted_scores[part_ids] += self.weights[k] * part_scores[k] + # add previous hyp score + weighted_scores += hyp.score + + # update hyps + for j, part_j in zip(*self.beam(weighted_scores, part_ids)): + # will be (2 x beam at most) + best_hyps.append( + Hypothesis( + score=weighted_scores[j], + yseq=self.append_token(hyp.yseq, j), + scores=self.merge_scores( + hyp.scores, scores, j, part_scores, part_j + ), + states=self.merge_states(states, part_states, part_j), + ) + ) + + # sort and prune 2 x beam -> beam + best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[ + : min(len(best_hyps), self.beam_size) + ] + return best_hyps + + def forward( + self, x: torch.Tensor, + scama_mask: torch.Tensor = None, + pre_acoustic_embeds: torch.Tensor = None, + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + maxlen: int = None, + minlen: int = 0, + ) -> List[Hypothesis]: + """Perform beam search. 
+
+        Args:
+            x (torch.Tensor): Encoded speech feature (T, D)
+            maxlenratio (float): Input length ratio to obtain max output length.
+                If maxlenratio=0.0 (default), it uses an end-detect function
+                to automatically find maximum hypothesis lengths
+                If maxlenratio<0.0, its absolute value is interpreted
+                as a constant max output length.
+            minlenratio (float): Input length ratio to obtain min output length.
+
+        Returns:
+            list[Hypothesis]: N-best decoding results
+
+        """
+        if maxlen is None:
+            # set length bounds
+            if maxlenratio == 0:
+                maxlen = x.shape[0]
+            elif maxlenratio < 0:
+                maxlen = -1 * int(maxlenratio)
+            else:
+                maxlen = max(1, int(maxlenratio * x.size(0)))
+            minlen = int(minlenratio * x.size(0))
+
+        logging.info("decoder input length: " + str(x.shape[0]))
+        logging.info("max output length: " + str(maxlen))
+        logging.info("min output length: " + str(minlen))
+
+        # main loop of prefix search
+        running_hyps = self.init_hyp(x)
+        ended_hyps = []
+        for i in range(maxlen):
+            logging.debug("position " + str(i))
+            mask_enc = None
+            if scama_mask is not None:
+                token_num_predictor = scama_mask.size(1)
+                token_id_slice = min(i, token_num_predictor-1)
+                mask_enc = scama_mask[:, token_id_slice:token_id_slice+1, :]
+                # if mask_enc.size(1) == 0:
+                #     mask_enc = scama_mask[:, -2:-1, :]
+                #     # mask_enc = torch.zeros_like(mask_enc)
+            pre_acoustic_embeds_cur = None
+            if pre_acoustic_embeds is not None:
+                b, t, d = pre_acoustic_embeds.size()
+                pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(device=pre_acoustic_embeds.device)
+                pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1)
+                token_id_slice = min(i, t)
+                pre_acoustic_embeds_cur = pre_acoustic_embeds[:, token_id_slice:token_id_slice+1, :]
+
+            best = self.search(running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur)
+            # post process of one iteration
+            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
+            # end detection
+            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
+                logging.info(f"end detected at {i}")
+                break
+            if len(running_hyps) == 0:
+                logging.info("no hypothesis. Finish decoding.")
+                break
+            else:
+                logging.debug(f"remained hypotheses: {len(running_hyps)}")
+
+        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
+        # check the number of hypotheses that reached eos
+        if len(nbest_hyps) == 0:
+            logging.warning(
+                "there are no N-best results, perform recognition "
+                "again with smaller minlenratio."
+            )
+            return (
+                []
+                if minlenratio < 0.1
+                else self.forward(x, maxlenratio=maxlenratio, minlenratio=max(0.0, minlenratio - 0.1))
+            )
+
+        # report the best result
+        for hyp in nbest_hyps:
+            yseq = "".join([self.token_list[t] for t in hyp.yseq])
+            logging.debug("nbest: y: {}, yseq: {}, score: {}".format(hyp.yseq, yseq, hyp.score))
+        best = nbest_hyps[0]
+        for k, v in best.scores.items():
+            logging.info(
+                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
+            )
+        logging.info(f"total log probability: {best.score:.2f}")
+        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
+        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
+        if self.token_list is not None:
+            logging.info(
+                "best hypo: "
+                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
+                + "\n"
+            )
+        return nbest_hyps
+
+    def post_process(
+        self,
+        i: int,
+        maxlen: int,
+        maxlenratio: float,
+        running_hyps: List[Hypothesis],
+        ended_hyps: List[Hypothesis],
+    ) -> List[Hypothesis]:
+        """Perform post-processing of beam search iterations. 
+ + Args: + i (int): The length of hypothesis tokens. + maxlen (int): The maximum length of tokens in beam search. + maxlenratio (int): The maximum length ratio in beam search. + running_hyps (List[Hypothesis]): The running hypotheses in beam search. + ended_hyps (List[Hypothesis]): The ended hypotheses in beam search. + + Returns: + List[Hypothesis]: The new running hypotheses. + + """ + logging.debug(f"the number of running hypotheses: {len(running_hyps)}") + if self.token_list is not None: + logging.debug( + "best hypo: " + + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]]) + ) + # add eos in the final loop to avoid that there are no ended hyps + if i == maxlen - 1: + logging.info("adding in the last position in the loop") + running_hyps = [ + h._replace(yseq=self.append_token(h.yseq, self.eos)) + for h in running_hyps + ] + + # add ended hypotheses to a final list, and removed them from current hypotheses + # (this will be a problem, number of hyps < beam) + remained_hyps = [] + for hyp in running_hyps: + if hyp.yseq[-1] == self.eos: + # e.g., Word LM needs to add final score + for k, d in chain(self.full_scorers.items(), self.part_scorers.items()): + s = d.final_score(hyp.states[k]) + hyp.scores[k] += s + hyp = hyp._replace(score=hyp.score + self.weights[k] * s) + ended_hyps.append(hyp) + else: + remained_hyps.append(hyp) + return remained_hyps diff --git a/funasr/models/uniasr/model.py b/funasr/models/uniasr/model.py index de80d4ac7..6e564dca7 100644 --- a/funasr/models/uniasr/model.py +++ b/funasr/models/uniasr/model.py @@ -14,14 +14,13 @@ from funasr.models.ctc.ctc import CTC from funasr.utils import postprocess_utils from funasr.metrics.compute_acc import th_accuracy from funasr.utils.datadir_writer import DatadirWriter -from funasr.models.paraformer.search import Hypothesis from funasr.models.paraformer.cif_predictor import mae_loss from funasr.train_utils.device_funcs import force_gatherable from funasr.losses.label_smoothing_loss import LabelSmoothingLoss from funasr.models.transformer.utils.add_sos_eos import add_sos_eos from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank - +from funasr.models.scama.utils import sequence_mask @tables.register("model_classes", "UniASR") class UniASR(torch.nn.Module): @@ -31,19 +30,37 @@ class UniASR(torch.nn.Module): def __init__( self, - specaug: Optional[str] = None, - specaug_conf: Optional[Dict] = None, + specaug: str = None, + specaug_conf: dict = None, normalize: str = None, - normalize_conf: Optional[Dict] = None, + normalize_conf: dict = None, encoder: str = None, - encoder_conf: Optional[Dict] = None, + encoder_conf: dict = None, + encoder2: str = None, + encoder2_conf: dict = None, decoder: str = None, - decoder_conf: Optional[Dict] = None, - ctc: str = None, - ctc_conf: Optional[Dict] = None, + decoder_conf: dict = None, + decoder2: str = None, + decoder2_conf: dict = None, predictor: str = None, - predictor_conf: Optional[Dict] = None, + predictor_conf: dict = None, + predictor_bias: int = 0, + predictor_weight: float = 0.0, + predictor2: str = None, + predictor2_conf: dict = None, + predictor2_bias: int = 0, + predictor2_weight: float = 0.0, + ctc: str = None, + ctc_conf: dict = None, ctc_weight: float = 0.5, + ctc2: str = None, + ctc2_conf: dict = None, + ctc2_weight: float = 0.5, + decoder_attention_chunk_type: str = 'chunk', + decoder_attention_chunk_type2: str = 'chunk', + stride_conv=None, + 
stride_conv_conf: dict = None, + loss_weight_model1: float = 0.5, input_size: int = 80, vocab_size: int = -1, ignore_id: int = -1, @@ -52,60 +69,72 @@ class UniASR(torch.nn.Module): eos: int = 2, lsm_weight: float = 0.0, length_normalized_loss: bool = False, - # report_cer: bool = True, - # report_wer: bool = True, - # sym_space: str = "", - # sym_blank: str = "", - # extract_feats_in_collect_stats: bool = True, - # predictor=None, - predictor_weight: float = 0.0, - predictor_bias: int = 0, - sampling_ratio: float = 0.2, share_embedding: bool = False, - # preencoder: Optional[AbsPreEncoder] = None, - # postencoder: Optional[AbsPostEncoder] = None, - use_1st_decoder_loss: bool = False, - encoder1_encoder2_joint_training: bool = True, **kwargs, ): - assert 0.0 <= ctc_weight <= 1.0, ctc_weight - assert 0.0 <= interctc_weight < 1.0, interctc_weight - super().__init__() - self.blank_id = 0 - self.sos = 1 - self.eos = 2 + + if specaug is not None: + specaug_class = tables.specaug_classes.get(specaug) + specaug = specaug_class(**specaug_conf) + if normalize is not None: + normalize_class = tables.normalize_classes.get(normalize) + normalize = normalize_class(**normalize_conf) + + encoder_class = tables.encoder_classes.get(encoder) + encoder = encoder_class(input_size=input_size, **encoder_conf) + encoder_output_size = encoder.output_size() + + decoder_class = tables.decoder_classes.get(decoder) + decoder = decoder_class( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + **decoder_conf, + ) + predictor_class = tables.predictor_classes.get(predictor) + predictor = predictor_class(**predictor_conf) + + + + from funasr.models.transformer.utils.subsampling import Conv1dSubsampling + stride_conv = Conv1dSubsampling(**stride_conv_conf, idim=input_size + encoder_output_size, + odim=input_size + encoder_output_size) + stride_conv_output_size = stride_conv.output_size() + + encoder_class = tables.encoder_classes.get(encoder2) + encoder2 = encoder_class(input_size=stride_conv_output_size, **encoder2_conf) + encoder2_output_size = encoder2.output_size() + + decoder_class = tables.decoder_classes.get(decoder2) + decoder2 = decoder_class( + vocab_size=vocab_size, + encoder_output_size=encoder2_output_size, + **decoder2_conf, + ) + predictor_class = tables.predictor_classes.get(predictor2) + predictor2 = predictor_class(**predictor2_conf) + + + + self.blank_id = blank_id + self.sos = sos + self.eos = eos self.vocab_size = vocab_size self.ignore_id = ignore_id self.ctc_weight = ctc_weight - self.interctc_weight = interctc_weight - self.token_list = token_list.copy() + self.ctc2_weight = ctc2_weight - self.frontend = frontend self.specaug = specaug self.normalize = normalize - self.preencoder = preencoder - self.postencoder = postencoder + self.encoder = encoder - if not hasattr(self.encoder, "interctc_use_conditioning"): - self.encoder.interctc_use_conditioning = False - if self.encoder.interctc_use_conditioning: - self.encoder.conditioning_layer = torch.nn.Linear( - vocab_size, self.encoder.output_size() - ) - self.error_calculator = None - # we set self.decoder = None in the CTC mode since - # self.decoder parameters were never used and PyTorch complained - # and threw an Exception in the multi-GPU experiment. - # thanks Jeff Farris for pointing out the issue. 
- if ctc_weight == 1.0: - self.decoder = None - else: - self.decoder = decoder + self.decoder = decoder + self.ctc = None + self.ctc2 = None self.criterion_att = LabelSmoothingLoss( size=vocab_size, @@ -113,22 +142,13 @@ class UniASR(torch.nn.Module): smoothing=lsm_weight, normalize_length=length_normalized_loss, ) - - if report_cer or report_wer: - self.error_calculator = ErrorCalculator( - token_list, sym_space, sym_blank, report_cer, report_wer - ) - - if ctc_weight == 0.0: - self.ctc = None - else: - self.ctc = ctc - - self.extract_feats_in_collect_stats = extract_feats_in_collect_stats + self.predictor = predictor self.predictor_weight = predictor_weight self.criterion_pre = mae_loss(normalize_length=length_normalized_loss) - self.step_cur = 0 + self.encoder1_encoder2_joint_training = kwargs.get("encoder1_encoder2_joint_training", True) + + if self.encoder.overlap_chunk_cls is not None: from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder @@ -136,14 +156,10 @@ class UniASR(torch.nn.Module): self.encoder2 = encoder2 self.decoder2 = decoder2 - self.ctc_weight2 = ctc_weight2 - if ctc_weight2 == 0.0: - self.ctc2 = None - else: - self.ctc2 = ctc2 - self.interctc_weight2 = interctc_weight2 + self.ctc2_weight = ctc2_weight + self.predictor2 = predictor2 - self.predictor_weight2 = predictor_weight2 + self.predictor2_weight = predictor2_weight self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2 self.stride_conv = stride_conv self.loss_weight_model1 = loss_weight_model1 @@ -152,10 +168,10 @@ class UniASR(torch.nn.Module): self.build_scama_mask_for_cross_attention_decoder_fn2 = build_scama_mask_for_cross_attention_decoder self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2 - self.enable_maas_finetune = enable_maas_finetune - self.freeze_encoder2 = freeze_encoder2 - self.encoder1_encoder2_joint_training = encoder1_encoder2_joint_training self.length_normalized_loss = length_normalized_loss + self.enable_maas_finetune = kwargs.get("enable_maas_finetune", False) + self.freeze_encoder2 = kwargs.get("freeze_encoder2", False) + self.beam_search = None def forward( self, @@ -163,7 +179,7 @@ class UniASR(torch.nn.Module): speech_lengths: torch.Tensor, text: torch.Tensor, text_lengths: torch.Tensor, - decoding_ind: int = None, + **kwargs, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Frontend + Encoder + Decoder + Calc loss Args: @@ -172,19 +188,14 @@ class UniASR(torch.nn.Module): text: (Batch, Length) text_lengths: (Batch,) """ - assert text_lengths.dim() == 1, text_lengths.shape - # Check that batch_size is unified - assert ( - speech.shape[0] - == speech_lengths.shape[0] - == text.shape[0] - == text_lengths.shape[0] - ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) + decoding_ind = kwargs.get("decoding_ind", None) + if len(text_lengths.size()) > 1: + text_lengths = text_lengths[:, 0] + if len(speech_lengths.size()) > 1: + speech_lengths = speech_lengths[:, 0] + batch_size = speech.shape[0] - # for data-parallel - text = text[:, : text_lengths.max()] - speech = speech[:, :speech_lengths.max()] ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind) # 1. 
Encoder @@ -194,10 +205,6 @@ class UniASR(torch.nn.Module): else: speech_raw, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind) - intermediate_outs = None - if isinstance(encoder_out, tuple): - intermediate_outs = encoder_out[1] - encoder_out = encoder_out[0] loss_att, acc_att, cer_att, wer_att = None, None, None, None loss_ctc, cer_ctc = None, None @@ -210,62 +217,12 @@ class UniASR(torch.nn.Module): # 1. CTC branch if self.enable_maas_finetune: with torch.no_grad(): - if self.ctc_weight != 0.0: - if self.encoder.overlap_chunk_cls is not None: - encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, - encoder_out_lens, - chunk_outs=None) - loss_ctc, cer_ctc = self._calc_ctc_loss( - encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths - ) - # Collect CTC branch stats - stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None - stats["cer_ctc"] = cer_ctc + loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) - # Intermediate CTC (optional) - loss_interctc = 0.0 - if self.interctc_weight != 0.0 and intermediate_outs is not None: - for layer_idx, intermediate_out in intermediate_outs: - # we assume intermediate_out has the same length & padding - # as those of encoder_out - if self.encoder.overlap_chunk_cls is not None: - encoder_out_ctc, encoder_out_lens_ctc = \ - self.encoder.overlap_chunk_cls.remove_chunk( - intermediate_out, - encoder_out_lens, - chunk_outs=None) - loss_ic, cer_ic = self._calc_ctc_loss( - encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths - ) - loss_interctc = loss_interctc + loss_ic - - # Collect Intermedaite CTC stats - stats["loss_interctc_layer{}".format(layer_idx)] = ( - loss_ic.detach() if loss_ic is not None else None - ) - stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic - - loss_interctc = loss_interctc / len(intermediate_outs) - - # calculate whole encoder loss - loss_ctc = ( - 1 - self.interctc_weight - ) * loss_ctc + self.interctc_weight * loss_interctc - - # 2b. Attention decoder branch - if self.ctc_weight != 1.0: - loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss( - encoder_out, encoder_out_lens, text, text_lengths - ) - - # 3. 
CTC-Att loss definition - if self.ctc_weight == 0.0: - loss = loss_att + loss_pre * self.predictor_weight - elif self.ctc_weight == 1.0: - loss = loss_ctc - else: - loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss = loss_att + loss_pre * self.predictor_weight # Collect Attn branch stats stats["loss_att"] = loss_att.detach() if loss_att is not None else None @@ -274,62 +231,13 @@ class UniASR(torch.nn.Module): stats["wer"] = wer_att stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None else: - if self.ctc_weight != 0.0: - if self.encoder.overlap_chunk_cls is not None: - encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, - encoder_out_lens, - chunk_outs=None) - loss_ctc, cer_ctc = self._calc_ctc_loss( - encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths - ) + + loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) - # Collect CTC branch stats - stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None - stats["cer_ctc"] = cer_ctc - # Intermediate CTC (optional) - loss_interctc = 0.0 - if self.interctc_weight != 0.0 and intermediate_outs is not None: - for layer_idx, intermediate_out in intermediate_outs: - # we assume intermediate_out has the same length & padding - # as those of encoder_out - if self.encoder.overlap_chunk_cls is not None: - encoder_out_ctc, encoder_out_lens_ctc = \ - self.encoder.overlap_chunk_cls.remove_chunk( - intermediate_out, - encoder_out_lens, - chunk_outs=None) - loss_ic, cer_ic = self._calc_ctc_loss( - encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths - ) - loss_interctc = loss_interctc + loss_ic - - # Collect Intermedaite CTC stats - stats["loss_interctc_layer{}".format(layer_idx)] = ( - loss_ic.detach() if loss_ic is not None else None - ) - stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic - - loss_interctc = loss_interctc / len(intermediate_outs) - - # calculate whole encoder loss - loss_ctc = ( - 1 - self.interctc_weight - ) * loss_ctc + self.interctc_weight * loss_interctc - - # 2b. Attention decoder branch - if self.ctc_weight != 1.0: - loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss( - encoder_out, encoder_out_lens, text, text_lengths - ) - - # 3. 
CTC-Att loss definition - if self.ctc_weight == 0.0: - loss = loss_att + loss_pre * self.predictor_weight - elif self.ctc_weight == 1.0: - loss = loss_ctc - else: - loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss = loss_att + loss_pre * self.predictor_weight # Collect Attn branch stats stats["loss_att"] = loss_att.detach() if loss_att is not None else None @@ -354,67 +262,14 @@ class UniASR(torch.nn.Module): if isinstance(encoder_out, tuple): intermediate_outs = encoder_out[1] encoder_out = encoder_out[0] - # CTC2 - if self.ctc_weight2 != 0.0: - if self.encoder2.overlap_chunk_cls is not None: - encoder_out_ctc, encoder_out_lens_ctc = \ - self.encoder2.overlap_chunk_cls.remove_chunk( - encoder_out, - encoder_out_lens, - chunk_outs=None, - ) - loss_ctc, cer_ctc = self._calc_ctc_loss2( - encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths - ) - # Collect CTC branch stats - stats["loss_ctc2"] = loss_ctc.detach() if loss_ctc is not None else None - stats["cer_ctc2"] = cer_ctc - # Intermediate CTC (optional) - loss_interctc = 0.0 - if self.interctc_weight2 != 0.0 and intermediate_outs is not None: - for layer_idx, intermediate_out in intermediate_outs: - # we assume intermediate_out has the same length & padding - # as those of encoder_out - if self.encoder2.overlap_chunk_cls is not None: - encoder_out_ctc, encoder_out_lens_ctc = \ - self.encoder2.overlap_chunk_cls.remove_chunk( - intermediate_out, - encoder_out_lens, - chunk_outs=None) - loss_ic, cer_ic = self._calc_ctc_loss2( - encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths - ) - loss_interctc = loss_interctc + loss_ic + loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss2( + encoder_out, encoder_out_lens, text, text_lengths + ) - # Collect Intermedaite CTC stats - stats["loss_interctc_layer{}2".format(layer_idx)] = ( - loss_ic.detach() if loss_ic is not None else None - ) - stats["cer_interctc_layer{}2".format(layer_idx)] = cer_ic - loss_interctc = loss_interctc / len(intermediate_outs) - - # calculate whole encoder loss - loss_ctc = ( - 1 - self.interctc_weight2 - ) * loss_ctc + self.interctc_weight2 * loss_interctc - - # 2b. Attention decoder branch - if self.ctc_weight2 != 1.0: - loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss2( - encoder_out, encoder_out_lens, text, text_lengths - ) - - # 3. CTC-Att loss definition - if self.ctc_weight2 == 0.0: - loss = loss_att + loss_pre * self.predictor_weight2 - elif self.ctc_weight2 == 1.0: - loss = loss_ctc - else: - loss = self.ctc_weight2 * loss_ctc + ( - 1 - self.ctc_weight2) * loss_att + loss_pre * self.predictor_weight2 + loss = loss_att + loss_pre * self.predictor2_weight # Collect Attn branch stats stats["loss_att2"] = loss_att.detach() if loss_att is not None else None @@ -422,6 +277,7 @@ class UniASR(torch.nn.Module): stats["cer2"] = cer_att stats["wer2"] = wer_att stats["loss_pre2"] = loss_pre.detach().cpu() if loss_pre is not None else None + loss2 = loss loss = loss1 * self.loss_weight_model1 + loss2 * (1 - self.loss_weight_model1) @@ -456,61 +312,31 @@ class UniASR(torch.nn.Module): return {"feats": feats, "feats_lengths": feats_lengths} def encode( - self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0, - ) -> Tuple[torch.Tensor, torch.Tensor]: + self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs, + ): """Frontend + Encoder. Note that this method is used by asr_inference.py Args: speech: (Batch, Length, ...) 
speech_lengths: (Batch, ) """ + ind = kwargs.get("ind", 0) with autocast(False): - # 1. Extract feats - feats, feats_lengths = self._extract_feats(speech, speech_lengths) - - # 2. Data augmentation + # Data augmentation if self.specaug is not None and self.training: - feats, feats_lengths = self.specaug(feats, feats_lengths) - - # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + speech, speech_lengths = self.specaug(speech, speech_lengths) + + # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN if self.normalize is not None: - feats, feats_lengths = self.normalize(feats, feats_lengths) - speech_raw = feats.clone().to(feats.device) - # Pre-encoder, e.g. used for raw input data - if self.preencoder is not None: - feats, feats_lengths = self.preencoder(feats, feats_lengths) + speech, speech_lengths = self.normalize(speech, speech_lengths) + + speech_raw = speech.clone().to(speech.device) + # 4. Forward encoder - # feats: (Batch, Length, Dim) - # -> encoder_out: (Batch, Length2, Dim2) - if self.encoder.interctc_use_conditioning: - encoder_out, encoder_out_lens, _ = self.encoder( - feats, feats_lengths, ctc=self.ctc, ind=ind - ) - else: - encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind) - intermediate_outs = None + encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ind=ind) if isinstance(encoder_out, tuple): - intermediate_outs = encoder_out[1] encoder_out = encoder_out[0] - # Post-encoder, e.g. NLU - if self.postencoder is not None: - encoder_out, encoder_out_lens = self.postencoder( - encoder_out, encoder_out_lens - ) - - assert encoder_out.size(0) == speech.size(0), ( - encoder_out.size(), - speech.size(0), - ) - assert encoder_out.size(1) <= encoder_out_lens.max(), ( - encoder_out.size(), - encoder_out_lens.max(), - ) - - if intermediate_outs is not None: - return (encoder_out, intermediate_outs), encoder_out_lens - return speech_raw, encoder_out, encoder_out_lens def encode2( @@ -519,28 +345,15 @@ class UniASR(torch.nn.Module): encoder_out_lens: torch.Tensor, speech: torch.Tensor, speech_lengths: torch.Tensor, - ind: int = 0, - ) -> Tuple[torch.Tensor, torch.Tensor]: + **kwargs, + ): """Frontend + Encoder. Note that this method is used by asr_inference.py Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) """ - # with autocast(False): - # # 1. Extract feats - # feats, feats_lengths = self._extract_feats(speech, speech_lengths) - # - # # 2. Data augmentation - # if self.specaug is not None and self.training: - # feats, feats_lengths = self.specaug(feats, feats_lengths) - # - # # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN - # if self.normalize is not None: - # feats, feats_lengths = self.normalize(feats, feats_lengths) - # Pre-encoder, e.g. used for raw input data - # if self.preencoder is not None: - # feats, feats_lengths = self.preencoder(feats, feats_lengths) + ind = kwargs.get("ind", 0) encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk( encoder_out, encoder_out_lens, @@ -557,55 +370,14 @@ class UniASR(torch.nn.Module): # 4. 
Forward encoder # feats: (Batch, Length, Dim) # -> encoder_out: (Batch, Length2, Dim2) - if self.encoder2.interctc_use_conditioning: - encoder_out, encoder_out_lens, _ = self.encoder2( - speech, speech_lengths, ctc=self.ctc2, ind=ind - ) - else: - encoder_out, encoder_out_lens, _ = self.encoder2(speech, speech_lengths, ind=ind) - intermediate_outs = None + + encoder_out, encoder_out_lens, _ = self.encoder2(speech, speech_lengths, ind=ind) if isinstance(encoder_out, tuple): - intermediate_outs = encoder_out[1] encoder_out = encoder_out[0] - # # Post-encoder, e.g. NLU - # if self.postencoder is not None: - # encoder_out, encoder_out_lens = self.postencoder( - # encoder_out, encoder_out_lens - # ) - - assert encoder_out.size(0) == speech.size(0), ( - encoder_out.size(), - speech.size(0), - ) - assert encoder_out.size(1) <= encoder_out_lens.max(), ( - encoder_out.size(), - encoder_out_lens.max(), - ) - - if intermediate_outs is not None: - return (encoder_out, intermediate_outs), encoder_out_lens return encoder_out, encoder_out_lens - def _extract_feats( - self, speech: torch.Tensor, speech_lengths: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - assert speech_lengths.dim() == 1, speech_lengths.shape - - # for data-parallel - speech = speech[:, : speech_lengths.max()] - - if self.frontend is not None: - # Frontend - # e.g. STFT and Feature extract - # data_loader may send time-domain signal in this case - # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim) - feats, feats_lengths = self.frontend(speech, speech_lengths) - else: - # No frontend and no feature extract - feats, feats_lengths = speech, speech_lengths - return feats, feats_lengths def nll( self, @@ -1024,36 +796,152 @@ class UniASR(torch.nn.Module): return pre_acoustic_embeds, pre_token_length, predictor_alignments, predictor_alignments_len, scama_mask - def _calc_ctc_loss( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ys_pad: torch.Tensor, - ys_pad_lens: torch.Tensor, - ): - # Calc CTC loss - loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + def init_beam_search(self, + **kwargs, + ): + from funasr.models.uniasr.beam_search import BeamSearchScama + from funasr.models.transformer.scorers.ctc import CTCPrefixScorer + from funasr.models.transformer.scorers.length_bonus import LengthBonus - # Calc CER using CTC - cer_ctc = None - if not self.training and self.error_calculator is not None: - ys_hat = self.ctc.argmax(encoder_out).data - cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) - return loss_ctc, cer_ctc + decoding_mode = kwargs.get("decoding_mode", "model1") + if decoding_mode == "model1": + decoder = self.decoder + else: + decoder = self.decoder2 + # 1. Build ASR model + scorers = {} + + if self.ctc != None: + ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos) + scorers.update( + ctc=ctc + ) + token_list = kwargs.get("token_list") + scorers.update( + decoder=decoder, + length_bonus=LengthBonus(len(token_list)), + ) + + # 3. 
Build ngram model + # ngram is not supported now + ngram = None + scorers["ngram"] = ngram + + weights = dict( + decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0), + ctc=kwargs.get("decoding_ctc_weight", 0.0), + lm=kwargs.get("lm_weight", 0.0), + ngram=kwargs.get("ngram_weight", 0.0), + length_bonus=kwargs.get("penalty", 0.0), + ) + beam_search = BeamSearchScama( + beam_size=kwargs.get("beam_size", 5), + weights=weights, + scorers=scorers, + sos=self.sos, + eos=self.eos, + vocab_size=len(token_list), + token_list=token_list, + pre_beam_score_key=None if self.ctc_weight == 1.0 else "full", + ) + + self.beam_search = beam_search - def _calc_ctc_loss2( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ys_pad: torch.Tensor, - ys_pad_lens: torch.Tensor, - ): - # Calc CTC loss - loss_ctc = self.ctc2(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + def inference(self, + data_in, + data_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + **kwargs, + ): - # Calc CER using CTC - cer_ctc = None - if not self.training and self.error_calculator is not None: - ys_hat = self.ctc2.argmax(encoder_out).data - cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) - return loss_ctc, cer_ctc + decoding_model = kwargs.get("decoding_model", "normal") + token_num_relax = kwargs.get("token_num_relax", 5) + if decoding_model == "fast": + decoding_ind = 0 + decoding_mode = "model1" + elif decoding_model == "offline": + decoding_ind = 1 + decoding_mode = "model2" + else: + decoding_ind = 0 + decoding_mode = "model2" + # init beamsearch + + if self.beam_search is None: + logging.info("enable beam_search") + self.init_beam_search(decoding_mode=decoding_mode, **kwargs) + self.nbest = kwargs.get("nbest", 1) + + meta_data = {} + if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank + speech, speech_lengths = data_in, data_lengths + if len(speech.shape) < 3: + speech = speech[None, :, :] + if speech_lengths is None: + speech_lengths = speech.shape[1] + else: + # extract fbank feats + time1 = time.perf_counter() + audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), + data_type=kwargs.get("data_type", "sound"), + tokenizer=tokenizer) + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), + frontend=frontend) + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 + + speech = speech.to(device=kwargs["device"]) + speech_lengths = speech_lengths.to(device=kwargs["device"]) + speech_raw = speech.clone().to(device=kwargs["device"]) + # Encoder + _, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=decoding_ind) + if decoding_mode == "model1": + predictor_outs = self.calc_predictor_mask(encoder_out, encoder_out_lens) + else: + encoder_out, encoder_out_lens = self.encode2(encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=decoding_ind) + predictor_outs = self.calc_predictor_mask2(encoder_out, encoder_out_lens) + + + scama_mask = predictor_outs[4] + pre_token_length = predictor_outs[1] + pre_acoustic_embeds = predictor_outs[0] + maxlen = pre_token_length.sum().item() + token_num_relax + minlen = max(0, pre_token_length.sum().item() - token_num_relax) + # c. 
Pass the encoder result to the beam search
+        nbest_hyps = self.beam_search(
+            x=encoder_out[0], scama_mask=scama_mask, pre_acoustic_embeds=pre_acoustic_embeds, maxlenratio=0.0,
+            minlenratio=0.0, maxlen=int(maxlen), minlen=int(minlen),
+        )
+
+        nbest_hyps = nbest_hyps[: self.nbest]
+
+        results = []
+        for hyp in nbest_hyps:
+
+            # remove sos/eos and get results
+            last_pos = -1
+            if isinstance(hyp.yseq, list):
+                token_int = hyp.yseq[1:last_pos]
+            else:
+                token_int = hyp.yseq[1:last_pos].tolist()
+
+            # remove blank symbol id, which is assumed to be 0
+            token_int = list(filter(lambda x: x != 0, token_int))
+
+
+            # Change integer-ids to tokens
+            token = tokenizer.ids2tokens(token_int)
+            text_postprocessed = tokenizer.tokens2text(token)
+            if not hasattr(tokenizer, "bpemodel"):
+                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+
+
+            result_i = {"key": key[0], "text": text_postprocessed}
+            results.append(result_i)
+
+        return results, meta_data
\ No newline at end of file
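
Usage note: UniASR.inference selects between the two cascaded recognizers via the
decoding_model switch ("fast" decodes with the first-pass model only, "offline" runs the
second-pass model on full-context encoding via decoding_ind=1, and any other value runs
the second pass in the low-latency setting). The sketch below drives all three modes
through the AutoModel API used in demo.py. It is a minimal sketch, not part of the
tested example: it assumes AutoModel.generate forwards extra keyword arguments
(decoding_model, beam_size, nbest, token_num_relax, ...) through to inference(), and it
rebuilds the model per mode because init_beam_search caches self.beam_search, so the
model1/model2 decoder choice is frozen after the first generate call on an instance.

from funasr import AutoModel

wav = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"

for mode in ("fast", "normal", "offline"):
    # A fresh AutoModel per mode: init_beam_search() runs only once and caches
    # self.beam_search, so switching decoding_model on an existing instance
    # would keep reusing the decoder selected by the first call.
    model = AutoModel(
        model="damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online",
        model_revision="v2.0.4",
    )
    res = model.generate(input=wav, decoding_model=mode)
    print(mode, res)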