import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import tempfile
import codecs
import requests
import re
import copy
import torch
import torch.nn as nn
import random
import numpy as np
import time

from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)

# from funasr.models.ctc import CTC
# from funasr.models.decoder.abs_decoder import AbsDecoder
# from funasr.models.e2e_asr_common import ErrorCalculator
# from funasr.models.encoder.abs_encoder import AbsEncoder
# from funasr.frontends.abs_frontend import AbsFrontend
# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.paraformer.cif_predictor import mae_loss
# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
# from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
# from funasr.models.base_model import FunASRModel
# from funasr.models.paraformer.cif_predictor import CifPredictorV3
from funasr.models.paraformer.search import Hypothesis

from funasr.models.model_class_factory import *

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield

from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io


class Transducer(nn.Module):
    """ESPnet2ASRTransducerModel module definition."""

    def __init__(
        self,
        frontend: Optional[str] = None,
        frontend_conf: Optional[Dict] = None,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        decoder: str = None,
        decoder_conf: Optional[Dict] = None,
        joint_network: str = None,
        joint_network_conf: Optional[Dict] = None,
        transducer_weight: float = 1.0,
        fastemit_lambda: float = 0.0,
        auxiliary_ctc_weight: float = 0.0,
        auxiliary_ctc_dropout_rate: float = 0.0,
        auxiliary_lm_loss_weight: float = 0.0,
        auxiliary_lm_loss_smoothing: float = 0.0,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        # report_cer: bool = True,
        # report_wer: bool = True,
        # sym_space: str = "",
        # sym_blank: str = "",
        # extract_feats_in_collect_stats: bool = True,
        share_embedding: bool = False,
        # preencoder: Optional[AbsPreEncoder] = None,
        # postencoder: Optional[AbsPostEncoder] = None,
        **kwargs,
    ):
        super().__init__()

        if frontend is not None:
            frontend_class = frontend_classes.get_class(frontend)
            frontend = frontend_class(**frontend_conf)
        if specaug is not None:
            specaug_class = specaug_classes.get_class(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = normalize_classes.get_class(normalize)
            normalize = normalize_class(**normalize_conf)

        encoder_class = encoder_classes.get_class(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()

        decoder_class = decoder_classes.get_class(decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **decoder_conf,
        )
        decoder_output_size = decoder.output_size

        # Look up the joint network by its registered name
        joint_network_class = joint_network_classes.get_class(joint_network)
        joint_network = joint_network_class(
            vocab_size,
            encoder_output_size,
            decoder_output_size,
            **joint_network_conf,
        )

        self.criterion_transducer = None
        self.error_calculator = None

        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0

        if self.use_auxiliary_ctc:
            self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size)
            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate

        if self.use_auxiliary_lm_loss:
            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing

        self.transducer_weight = transducer_weight
        self.fastemit_lambda = fastemit_lambda

        self.auxiliary_ctc_weight = auxiliary_ctc_weight
        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight

        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        self.decoder = decoder
        self.joint_network = joint_network

        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        #
        # if report_cer or report_wer:
        #     self.error_calculator = ErrorCalculator(
        #         token_list, sym_space, sym_blank, report_cer, report_wer
        #     )
        #
        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out, encoder_out_lens, chunk_outs=None
            )

        # 2. Transducer-related I/O preparation
        decoder_in, target, t_len, u_len = get_transducer_task_io(
            text,
            encoder_out_lens,
            ignore_id=self.ignore_id,
        )

        # 3. Decoder (prediction network)
        self.decoder.set_device(encoder_out.device)
        decoder_out = self.decoder(decoder_in, u_len)

        # 4. Joint Network: (B, T, 1, D_enc) x (B, 1, U, D_dec) -> (B, T, U, V)
        joint_out = self.joint_network(
            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
        )
        # 5. Losses
        loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
            encoder_out,
            joint_out,
            target,
            t_len,
            u_len,
        )

        loss_ctc, loss_lm = 0.0, 0.0
        if self.use_auxiliary_ctc:
            loss_ctc = self._calc_ctc_loss(
                encoder_out,
                target,
                t_len,
                u_len,
            )

        if self.use_auxiliary_lm_loss:
            loss_lm = self._calc_lm_loss(decoder_out, target)

        loss = (
            self.transducer_weight * loss_trans
            + self.auxiliary_ctc_weight * loss_ctc
            + self.auxiliary_lm_loss_weight * loss_lm
        )

        stats = dict(
            loss=loss.detach(),
            loss_transducer=loss_trans.detach(),
            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
            cer_transducer=cer_trans,
            wer_transducer=wer_trans,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

        return loss, stats, weight

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(
                speech, speech_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens

        return encoder_out, encoder_out_lens

    def _calc_transducer_loss(
        self,
        encoder_out: torch.Tensor,
        joint_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
        """Compute Transducer loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            joint_out: Joint Network output sequences. (B, T, U, D_joint)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)

        Return:
            loss_transducer: Transducer loss value.
            cer_transducer: Character error rate for Transducer.
            wer_transducer: Word Error Rate for Transducer.
        """
        if self.criterion_transducer is None:
            try:
                from warp_rnnt import rnnt_loss as RNNTLoss

                self.criterion_transducer = RNNTLoss
            except ImportError:
                logging.error(
                    "warp-rnnt was not installed. "
                    "Please consult the installation documentation."
                )
                exit(1)

        log_probs = torch.log_softmax(joint_out, dim=-1)

        loss_transducer = self.criterion_transducer(
            log_probs,
            target,
            t_len,
            u_len,
            reduction="mean",
            blank=self.blank_id,
            fastemit_lambda=self.fastemit_lambda,
            gather=True,
        )

        # report_cer/report_wer are not set in __init__; default to False if absent
        if not self.training and (
            getattr(self, "report_cer", False) or getattr(self, "report_wer", False)
        ):
            if self.error_calculator is None:
                from funasr.metrics import ErrorCalculatorTransducer as ErrorCalculator

                self.error_calculator = ErrorCalculator(
                    self.decoder,
                    self.joint_network,
                    self.token_list,
                    self.sym_space,
                    self.sym_blank,
                    report_cer=self.report_cer,
                    report_wer=self.report_wer,
                )

            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)

            return loss_transducer, cer_transducer, wer_transducer

        return loss_transducer, None, None

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)

        Return:
            loss_ctc: CTC loss value.
        """
        ctc_in = self.ctc_lin(
            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
        )
        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)

        target_mask = target != 0
        ctc_target = target[target_mask].cpu()

        with torch.backends.cudnn.flags(deterministic=True):
            loss_ctc = torch.nn.functional.ctc_loss(
                ctc_in,
                ctc_target,
                t_len,
                u_len,
                zero_infinity=True,
                reduction="sum",
            )
        loss_ctc /= target.size(0)

        return loss_ctc

    def _calc_lm_loss(
        self,
        decoder_out: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Compute LM loss (label-smoothed KL divergence on the prediction network output).

        Args:
            decoder_out: Decoder output sequences. (B, U, D_dec)
            target: Target label ID sequences. (B, L)

        Return:
            loss_lm: LM loss value.
        """
        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
        lm_target = target.view(-1).type(torch.int64)

        with torch.no_grad():
            true_dist = lm_loss_in.clone()
            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))

            # Ignore blank ID (0)
            ignore = lm_target == 0
            lm_target = lm_target.masked_fill(ignore, 0)

            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))

        loss_lm = torch.nn.functional.kl_div(
            torch.log_softmax(lm_loss_in, dim=1),
            true_dist,
            reduction="none",
        )
        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(0)

        return loss_lm

    def init_beam_search(
        self,
        **kwargs,
    ):
        from funasr.models.transformer.search import BeamSearch
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build ASR model scorers
        scorers = {}

        # No standalone CTC module is built in __init__, so guard the attribute access
        if getattr(self, "ctc", None) is not None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            length_bonus=LengthBonus(len(token_list)),
        )
        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        weights = dict(
            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
            ctc=kwargs.get("decoding_ctc_weight", 0.0),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearch(
            beam_size=kwargs.get("beam_size", 2),
            weights=weights,
            scorers=scorers,
            sos=self.sos,
            eos=self.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            # ctc_weight is not set in __init__; default to 0.0 if absent
            pre_beam_score_key=None if getattr(self, "ctc_weight", 0.0) == 1.0 else "full",
        )
        # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
        # for scorer in scorers.values():
        #     if isinstance(scorer, torch.nn.Module):
        #         scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()

        self.beam_search = beam_search

    def generate(
        self,
        data_in: list,
        data_lengths: list = None,
        key: list = None,
        tokenizer=None,
        **kwargs,
    ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beam search
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and getattr(self, "ctc", None) is not None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}

        # extract fbank feats
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000)
        )
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(
            audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend
        )
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = (
            speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
        )

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        # c. Pass the encoder output to the beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0],
            maxlenratio=kwargs.get("maxlenratio", 0.0),
            minlenratio=kwargs.get("minlenratio", 0.0),
        )

        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        b, n, d = encoder_out.size()
        for i in range(b):
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = writer[f"{nbest_idx + 1}best_recog"]

                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # remove blank symbol id, which is assumed to be 0
                token_int = list(
                    filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)
                )

                # Change integer-ids to tokens
                token = tokenizer.ids2tokens(token_int)
                text = tokenizer.tokens2text(token)

                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                result_i = {
                    "key": key[i],
                    "token": token,
                    "text": text,
                    "text_postprocessed": text_postprocessed,
                }
                results.append(result_i)

                if ibest_writer is not None:
                    ibest_writer["token"][key[i]] = " ".join(token)
                    ibest_writer["text"][key[i]] = text
                    ibest_writer["text_postprocessed"][key[i]] = text_postprocessed

        return results, meta_data
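

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model): a minimal, self-contained
# demonstration of the shape flow used in step 4 of `forward`, where the joint
# network receives encoder_out.unsqueeze(2) with shape (B, T, 1, D_enc) and
# decoder_out.unsqueeze(1) with shape (B, 1, U, D_dec) and broadcasts them to
# a (B, T, U, V) lattice. `ToyJointNetwork` and all tensor sizes below are
# hypothetical stand-ins for illustration only; they are not FunASR APIs.
if __name__ == "__main__":

    class ToyJointNetwork(nn.Module):
        """Minimal additive joint network: project, add, activate, classify."""

        def __init__(self, enc_dim: int, dec_dim: int, joint_dim: int, vocab: int):
            super().__init__()
            self.lin_enc = nn.Linear(enc_dim, joint_dim)
            self.lin_dec = nn.Linear(dec_dim, joint_dim)
            self.lin_out = nn.Linear(joint_dim, vocab)

        def forward(self, enc: torch.Tensor, dec: torch.Tensor) -> torch.Tensor:
            # enc: (B, T, 1, D_enc), dec: (B, 1, U, D_dec) -> (B, T, U, vocab)
            return self.lin_out(torch.tanh(self.lin_enc(enc) + self.lin_dec(dec)))

    B, T, U, D_ENC, D_DEC, VOCAB = 2, 7, 5, 16, 12, 30
    enc_out = torch.randn(B, T, D_ENC)  # encoder output (B, T, D_enc)
    dec_out = torch.randn(B, U, D_DEC)  # prediction-network output (B, U, D_dec)

    joint = ToyJointNetwork(D_ENC, D_DEC, joint_dim=8, vocab=VOCAB)
    lattice = joint(enc_out.unsqueeze(2), dec_out.unsqueeze(1))
    print(lattice.shape)  # torch.Size([2, 7, 5, 30])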