# SeACo-Paraformer model definition: a Paraformer with hotword (bias)
# customization, registered with FunASR's model registry below.
import os
import logging
from contextlib import contextmanager
# NOTE(review): `distutils` was removed from the standard library in
# Python 3.12 (PEP 632); this import then depends on setuptools' shim.
# Consider `packaging.version.Version` instead.
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import tempfile
import codecs
import requests
import re
import copy
import torch
import torch.nn as nn
import random
import numpy as np
import time
# NOTE(review): copy, random, nn, List, Optional, Union, mae_loss and
# th_accuracy appear unused in the code shown; confirm before removing.
# from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
# from funasr.models.ctc import CTC
# from funasr.models.decoder.abs_decoder import AbsDecoder
# from funasr.models.e2e_asr_common import ErrorCalculator
# from funasr.models.encoder.abs_encoder import AbsEncoder
# from funasr.frontends.abs_frontend import AbsFrontend
# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.paraformer.cif_predictor import mae_loss
# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
# from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
# from funasr.models.base_model import FunASRModel
# from funasr.models.paraformer.cif_predictor import CifPredictorV3
from funasr.models.paraformer.search import Hypothesis

# `autocast` fallback: torch >= 1.6 provides torch.cuda.amp.autocast; on
# older torch a no-op context manager with the same `enabled` parameter is
# defined instead, so `with autocast(False):` works either way.
# NOTE(review): `autocast` is not referenced by the class in this file;
# confirm it is still needed.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield

from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.model import Paraformer
from funasr.register import tables


# Register the class below under ("model_classes", "SeacoParaformer"),
# presumably so configs can refer to it by that name.
@tables.register("model_classes", "SeacoParaformer")
class SeacoParaformer(Paraformer):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    SeACo-Paraformer: A Non-Autoregressive ASR System with Flexible and Effective Hotword Customization Ability
    https://arxiv.org/abs/2308.03266

    Paraformer extended with a hotword (bias) encoder and a SeACo decoder.
    The SeACo decoder attends from the CIF acoustic embeddings and from the
    Paraformer decoder hidden states to the encoded hotwords; at inference
    its predictions are blended into the decoder output
    (see ``_seaco_decode_with_ASF``).
    """

    # Words made only of CJK ideographs / digits are split into single
    # characters when the whole word is missing from seg_dict.
    _SEG_CHAR_PATTERN = re.compile(r'^[\u4E00-\u9FA50-9]+$')

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        """Build the Paraformer backbone, then the SeACo-specific modules.

        Extra keyword arguments read here (defaults in parentheses):
            inner_dim (256): width of the hotword embedding / bias encoder.
            bias_encoder_type ("lstm"): "lstm" or "mean"; only "lstm" is
                accepted by ``_hotword_representation``.
            bias_encoder_dropout_rate (0.0), bias_encoder_bid (False):
                dropout and bidirectionality of the 2-layer bias LSTM.
            seaco_lsm_weight (0.0), seaco_length_normalized_loss (True):
                label-smoothing settings of the SeACo loss.
            seaco_decoder, seaco_decoder_conf: registry name and config of
                the SeACo decoder (looked up in ``tables.decoder_classes``).
            train_decoder (False): also add the Paraformer attention loss.
            NO_BIAS (8377): token id which, when it is the SeACo head's top
                prediction, leaves the decoder output unchanged at merge time.
        """
        super().__init__(*args, **kwargs)

        self.inner_dim = kwargs.get("inner_dim", 256)
        self.bias_encoder_type = kwargs.get("bias_encoder_type", "lstm")
        bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
        bias_encoder_bid = kwargs.get("bias_encoder_bid", False)
        seaco_lsm_weight = kwargs.get("seaco_lsm_weight", 0.0)
        seaco_length_normalized_loss = kwargs.get("seaco_length_normalized_loss", True)

        # bias encoder
        if self.bias_encoder_type == 'lstm':
            logging.warning("enable bias encoder sampling and contextual training")
            self.bias_encoder = torch.nn.LSTM(
                self.inner_dim,
                self.inner_dim,
                2,
                batch_first=True,
                dropout=bias_encoder_dropout_rate,
                bidirectional=bias_encoder_bid,
            )
            if bias_encoder_bid:
                # bidirectional output is 2 * inner_dim wide; project it back
                self.lstm_proj = torch.nn.Linear(self.inner_dim * 2, self.inner_dim)
            else:
                self.lstm_proj = None
            self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim)
        elif self.bias_encoder_type == 'mean':
            logging.warning("enable bias encoder sampling and contextual training")
            self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim)
        else:
            logging.error("Unsupport bias encoder type: {}".format(self.bias_encoder_type))

        # seaco decoder
        seaco_decoder = kwargs.get("seaco_decoder", None)
        if seaco_decoder is not None:
            # tolerate a missing conf instead of failing on `**None`
            seaco_decoder_conf = kwargs.get("seaco_decoder_conf") or {}
            seaco_decoder_class = tables.decoder_classes.get(seaco_decoder.lower())
            self.seaco_decoder = seaco_decoder_class(
                vocab_size=self.vocab_size,
                encoder_output_size=self.inner_dim,
                **seaco_decoder_conf,
            )
        self.hotword_output_layer = torch.nn.Linear(self.inner_dim, self.vocab_size)
        self.criterion_seaco = LabelSmoothingLoss(
            size=self.vocab_size,
            padding_idx=self.ignore_id,
            smoothing=seaco_lsm_weight,
            normalize_length=seaco_length_normalized_loss,
        )
        self.train_decoder = kwargs.get("train_decoder", False)
        self.NO_BIAS = kwargs.get("NO_BIAS", 8377)

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + predictor + SeACo loss (+ attention loss if train_decoder).

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            **kwargs: must provide ``hotword_pad``, ``hotword_lengths`` and
                ``dha_pad`` (the SeACo targets) for ``_calc_seaco_loss``.

        Returns:
            ``(loss, stats, weight)`` as produced by ``force_gatherable``.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)

        hotword_pad = kwargs.get("hotword_pad")
        hotword_lengths = kwargs.get("hotword_lengths")
        dha_pad = kwargs.get("dha_pad")

        batch_size = speech.shape[0]
        self.step_cur += 1
        # for data-parallel
        text = text[:, : text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # Predictor targets: append <eos> only when predictor_bias == 1; the
        # default keeps ys_pad defined for any other predictor_bias value.
        ys_pad = text
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(text, self.sos, self.eos, self.ignore_id)
        ys_lengths = text_lengths + self.predictor_bias

        stats = dict()
        loss_seaco = self._calc_seaco_loss(
            encoder_out,
            encoder_out_lens,
            ys_pad,
            ys_lengths,
            hotword_pad,
            hotword_lengths,
            dha_pad,
        )
        if self.train_decoder:
            loss_att, acc_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            loss = loss_seaco + loss_att
            stats["loss_att"] = torch.clone(loss_att.detach())
            stats["acc_att"] = acc_att
        else:
            loss = loss_seaco
        stats["loss_seaco"] = torch.clone(loss_seaco.detach())
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            # batch_size is a plain int here, so it cannot be a type_as target
            batch_size = (text_lengths + self.predictor_bias).sum()
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def _merge(self, cif_attended, dec_attended):
        """Combine the two SeACo decoder outputs by element-wise addition."""
        return cif_attended + dec_attended

    @staticmethod
    def _expand_contextual(selected, batch_size, device):
        """Repeat the hotword vectors for every utterance in the batch.

        Returns:
            ``(contextual_info, contextual_length)``: the hotword vectors
            repeated ``batch_size`` times, and an int tensor of length
            ``batch_size`` holding the number of hotwords.
        """
        contextual_info = selected.squeeze(0).repeat(batch_size, 1, 1).to(device)
        num_hot_word = contextual_info.shape[1]
        contextual_length = torch.Tensor([num_hot_word]).int().repeat(batch_size).to(device)
        return contextual_info, contextual_length

    def _calc_seaco_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_lengths: torch.Tensor,
        hotword_pad: torch.Tensor,
        hotword_lengths: torch.Tensor,
        dha_pad: torch.Tensor,
    ):
        """Label-smoothed loss of the SeACo hotword head against ``dha_pad``.

        The predictor (with teacher targets ``ys_pad``) produces acoustic
        embeddings, the decoder produces hidden states, and the SeACo decoder
        attends from both to the encoded hotwords. The last output position
        is dropped before the loss.
        """
        # predictor forward
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        pre_acoustic_embeds, _, _, _ = self.predictor(
            encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
        )
        # decoder forward
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True
        )
        selected = self._hotword_representation(hotword_pad, hotword_lengths)
        contextual_info, contextual_length = self._expand_contextual(
            selected, encoder_out.shape[0], encoder_out.device
        )
        # dha core
        cif_attended, _ = self.seaco_decoder(
            contextual_info, contextual_length, pre_acoustic_embeds, ys_lengths
        )
        dec_attended, _ = self.seaco_decoder(
            contextual_info, contextual_length, decoder_out, ys_lengths
        )
        merged = self._merge(cif_attended, dec_attended)
        # remove the last token in loss calculation
        dha_output = self.hotword_output_layer(merged[:, :-1])
        loss_att = self.criterion_seaco(dha_output, dha_pad)
        return loss_att

    def _seaco_decode_with_ASF(
        self,
        encoder_out,
        encoder_out_lens,
        sematic_embeds,
        ys_pad_lens,
        hw_list,
        nfilter=50,
        seaco_weight=1.0,
    ):
        """Decode, then bias the decoder log-probs with the SeACo hotword head.

        Args:
            encoder_out, encoder_out_lens: encoder output and its lengths.
            sematic_embeds: acoustic embeddings from the predictor.
            ys_pad_lens: number of tokens to decode per utterance.
            hw_list: list of hotword token-id lists, or None for plain
                decoding. Its last entry is always kept by the filter.
            nfilter: if ``0 < nfilter < len(hw_list)``, a scoring pass keeps
                only the ``nfilter`` hotwords with the largest attention mass
                in the last SeACo layer, plus the last entry.
            seaco_weight: blend weight of the SeACo prediction
                (see ``_merge_seaco_pred``).

        Returns:
            Decoder log-probabilities, merged with the SeACo head's
            log-probabilities when ``hw_list`` is given.
        """
        # decoder forward
        decoder_out, decoder_hidden, _ = self.decoder(
            encoder_out,
            encoder_out_lens,
            sematic_embeds,
            ys_pad_lens,
            return_hidden=True,
            return_both=True,
        )
        decoder_pred = torch.log_softmax(decoder_out, dim=-1)
        if hw_list is None:
            return decoder_pred

        device = encoder_out.device
        batch_size = encoder_out.shape[0]
        hw_lengths = [len(i) for i in hw_list]
        hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(device)
        selected = self._hotword_representation(
            hw_list_pad, torch.Tensor(hw_lengths).int().to(device)
        )
        contextual_info, contextual_length = self._expand_contextual(selected, batch_size, device)
        num_hot_word = contextual_info.shape[1]

        # ASF core: score hotwords by recorded attention, keep the top ones
        if 0 < nfilter < num_hot_word:
            for dec in self.seaco_decoder.decoders:
                dec.reserve_attn = True
            try:
                self.seaco_decoder(contextual_info, contextual_length, decoder_hidden, ys_pad_lens)
                # The last hotword entry is excluded from scoring (it is
                # always kept below).
                # NOTE(review): attn_mat[0][0] presumably covers batch item 0 only.
                hotword_scores = (
                    self.seaco_decoder.decoders[-1].attn_mat[0][0].sum(0).sum(0)[:-1]
                )
            finally:
                # restore the layers' non-recording state even on failure
                for dec in self.seaco_decoder.decoders:
                    dec.attn_mat = []
                    dec.reserve_attn = False
            keep = torch.topk(hotword_scores, min(nfilter, num_hot_word - 1))[1].tolist()
            keep.append(num_hot_word - 1)
            # filter hotword embedding, then rebuild the contextual inputs
            selected = selected[keep]
            contextual_info, contextual_length = self._expand_contextual(
                selected, batch_size, device
            )

        # SeACo core
        cif_attended, _ = self.seaco_decoder(
            contextual_info, contextual_length, sematic_embeds, ys_pad_lens
        )
        dec_attended, _ = self.seaco_decoder(
            contextual_info, contextual_length, decoder_hidden, ys_pad_lens
        )
        merged = self._merge(cif_attended, dec_attended)
        dha_output = self.hotword_output_layer(merged)
        dha_pred = torch.log_softmax(dha_output, dim=-1)
        return self._merge_seaco_pred(decoder_pred, dha_pred, seaco_weight)

    def _merge_seaco_pred(self, dec_output, dha_output, seaco_weight=1.0):
        """Blend decoder and SeACo log-probs position by position.

        Where the SeACo head's top prediction is ``self.NO_BIAS`` the decoder
        output is used unchanged; elsewhere the result is
        ``(1 - seaco_weight) * dec_output + seaco_weight * dha_output``.
        The mask is computed per utterance, so every batch item uses its own
        SeACo predictions, and ``seaco_weight == 0`` yields the decoder
        output instead of NaNs.
        """
        dha_ids = dha_output.max(-1)[-1]
        no_bias = (dha_ids == self.NO_BIAS).float().unsqueeze(-1).to(dec_output.device)
        dec_weight = no_bias * seaco_weight + (1.0 - seaco_weight)
        return dec_output * dec_weight + dha_output * (1 - dec_weight)

    def _hotword_representation(self, hotword_pad, hotword_lengths):
        """Encode each padded hotword into a single vector.

        Token ids are embedded with ``self.decoder.embed`` and run through the
        bias LSTM (plus ``lstm_proj`` if present). The output at each
        hotword's last real token (``hotword_lengths - 1``) is returned, one
        row per hotword.

        Raises:
            NotImplementedError: if ``bias_encoder_type`` is not "lstm".
        """
        if self.bias_encoder_type != 'lstm':
            logging.error("Unsupported bias encoder type")
            raise NotImplementedError(
                "Unsupported bias encoder type: {}".format(self.bias_encoder_type)
            )
        hw_embed = self.decoder.embed(hotword_pad)
        hw_embed, (_, _) = self.bias_encoder(hw_embed)
        if self.lstm_proj is not None:
            hw_embed = self.lstm_proj(hw_embed)
        rows = list(range(hw_embed.shape[0]))
        last_token = [n - 1 for n in hotword_lengths.detach().cpu().tolist()]
        return hw_embed[rows, last_token]

    def _get_ibest_writer(self, output_dir, nbest_idx):
        """Return the writer for the ``nbest_idx``-th best results, or None.

        One ``DatadirWriter`` per ``output_dir`` is cached on the instance and
        reused across hypotheses, utterances and calls. Creating a fresh
        writer each time reopens the output files (presumably truncating
        earlier results; confirm against DatadirWriter).
        """
        if output_dir is None:
            return None
        cached = getattr(self, "_seaco_writer", None)
        if cached is None or cached[0] != output_dir:
            cached = (output_dir, DatadirWriter(output_dir))
            self._seaco_writer = cached
        return cached[1][f"{nbest_idx + 1}best_recog"]

    def generate(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Recognize a batch of audio, optionally biased towards hotwords.

        Args:
            data_in: audio input(s) accepted by ``load_audio``.
            data_lengths: unused.
            key: per-utterance ids; ``key[i]`` is stored in each result.
            tokenizer: maps ids to tokens/text; when None, results hold raw
                ``token_int`` lists.
            frontend: feature extractor (``fs``, ``frame_shift``, ``lfr_n``
                and ``cmvn_file`` are read).
            **kwargs: read here: ``device`` (required), ``hotword``,
                ``nfilter`` (50), ``seaco_weight`` (1.0), ``output_dir``,
                ``fs``, ``data_type`` and the beam-search options
                ``decoding_ctc_weight``, ``lm_weight``, ``lm_file``,
                ``nbest``, ``maxlenratio``, ``minlenratio``.

        Returns:
            ``(results, meta_data)``: a list of result dicts (one per
            utterance and n-best hypothesis) and timing information.
        """
        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}

        # extract fbank feats
        time1 = time.perf_counter()
        audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(
            audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
        )
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = (
            speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
        )

        # Tensor.to() is not in-place: keep the moved tensors
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # hotword
        self.hotword_list = self.generate_hotwords_list(
            kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend
        )

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # predictor
        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
        pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            # nothing to decode; keep the (results, meta_data) return shape
            return [], meta_data

        decoder_out = self._seaco_decode_with_ASF(
            encoder_out,
            encoder_out_lens,
            pre_acoustic_embeds,
            pre_token_length,
            hw_list=self.hotword_list,
            nfilter=kwargs.get("nfilter", 50),
            seaco_weight=kwargs.get("seaco_weight", 1.0),
        )

        output_dir = kwargs.get("output_dir")
        results = []
        for i in range(decoder_out.size(0)):
            am_scores = decoder_out[i, : pre_token_length[i], :]
            if self.beam_search is not None:
                x = encoder_out[i, : encoder_out_lens[i], :]
                nbest_hyps = self.beam_search(
                    x=x,
                    am_scores=am_scores,
                    maxlenratio=kwargs.get("maxlenratio", 0.0),
                    minlenratio=kwargs.get("minlenratio", 0.0),
                )
                nbest_hyps = nbest_hyps[: self.nbest]
            else:
                yseq = am_scores.argmax(dim=-1)
                score = am_scores.max(dim=-1)[0]
                score = torch.sum(score, dim=-1)
                # wrap with sos/eos so it has the same layout as beam-search hyps
                yseq = torch.tensor(
                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
                )
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]

            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = self._get_ibest_writer(output_dir, nbest_idx)

                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # drop any remaining eos/sos and the blank symbol id
                token_int = [
                    t for t in token_int
                    if t != self.eos and t != self.sos and t != self.blank_id
                ]

                if tokenizer is not None:
                    # Change integer-ids to tokens
                    token = tokenizer.ids2tokens(token_int)
                    text = tokenizer.tokens2text(token)
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    result_i = {
                        "key": key[i],
                        "token": token,
                        "text": text,
                        "text_postprocessed": text_postprocessed,
                    }
                    if ibest_writer is not None:
                        ibest_writer["token"][key[i]] = " ".join(token)
                        ibest_writer["text"][key[i]] = text
                        ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
                else:
                    result_i = {"key": key[i], "token_int": token_int}
                results.append(result_i)

        return results, meta_data

    @staticmethod
    def _load_seg_dict(frontend):
        """Load ``seg_dict`` from the directory of ``frontend.cmvn_file``.

        Each non-blank line is ``word piece1 piece2 ...``; the result maps
        ``word`` to the space-joined pieces. Returns None if there is no
        ``cmvn_file`` or no ``seg_dict`` next to it.
        """
        cmvn_file = getattr(frontend, "cmvn_file", None)
        if cmvn_file is None:
            return None
        seg_dict_file = os.path.join(os.path.dirname(cmvn_file), 'seg_dict')
        if not os.path.exists(seg_dict_file):
            return None
        seg_dict = {}
        with open(seg_dict_file, "r", encoding="utf8") as f:
            for line in f:
                fields = line.strip().split()
                if not fields:
                    continue  # tolerate blank lines
                seg_dict[fields[0]] = " ".join(fields[1:])
        return seg_dict

    @classmethod
    def _seg_tokenize(cls, words, seg_dict):
        """Map words to sub-word tokens through ``seg_dict``.

        Words are lower-cased. A word found in ``seg_dict`` expands to its
        pieces. Otherwise, if it is all CJK ideographs / digits, each known
        character expands to its pieces. Unknown words/characters contribute
        no tokens.
        NOTE(review): confirm whether an explicit unknown-token marker was
        intended for unknown pieces instead of dropping them.
        """
        tokens = []
        for word in words:
            word = word.lower()
            if word in seg_dict:
                tokens.extend(seg_dict[word].split())
            elif cls._SEG_CHAR_PATTERN.match(word):
                for char in word:
                    if char in seg_dict:
                        tokens.extend(seg_dict[char].split())
        return tokens

    def _hotwords_to_ids(self, hotwords, seg_dict, tokenizer):
        """Convert hotword strings to token-id lists, then append ``[self.sos]``."""
        hotword_list = []
        for hw in hotwords:
            hw_tokens = hw.split()
            if seg_dict is not None:
                hw_tokens = self._seg_tokenize(hw_tokens, seg_dict)
            hotword_list.append(tokenizer.tokens2ids(hw_tokens))
        hotword_list.append([self.sos])
        return hotword_list

    def _hotwords_from_file(self, path, seg_dict, tokenizer):
        """Read one hotword per non-blank line of a UTF-8 text file."""
        with open(path, "r", encoding="utf-8") as fin:
            hotword_str_list = [line.strip() for line in fin if line.strip()]
        hotword_list = self._hotwords_to_ids(hotword_str_list, seg_dict, tokenizer)
        logging.info("Initialized hotword list from file: {}, hotword list: {}."
                     .format(path, hotword_str_list + ['']))
        return hotword_list

    @staticmethod
    def _download_hotword_file(url):
        """Download ``url`` into a fresh temporary directory; return the local path.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        work_dir = tempfile.mkdtemp()
        file_name = os.path.basename(url) or "hotword.txt"
        text_file_path = os.path.join(work_dir, file_name)
        response = requests.get(url)
        # do not parse an error page as hotwords
        response.raise_for_status()
        with open(text_file_path, "wb") as f:
            f.write(response.content)
        return text_file_path

    def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
        """Turn the ``hotword`` option into a list of token-id lists.

        ``hotword_list_or_file`` may be:
            * None -> None (no biasing);
            * a path to an existing ``.txt`` file, one hotword per line;
            * a string starting with ``http``: downloaded, then parsed as a file;
            * any other string not ending in ``.txt``: whitespace-separated
              hotwords.
        A ``.txt`` path that does not exist yields None.

        When a ``seg_dict`` file sits next to ``frontend.cmvn_file``, words
        are mapped through it before ``tokenizer.tokens2ids``. Every non-None
        result ends with ``[self.sos]``, the entry ``_seaco_decode_with_ASF``
        always keeps.
        """
        # for None
        if hotword_list_or_file is None:
            return None

        seg_dict = self._load_seg_dict(frontend)

        # for local txt inputs
        if os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
            logging.info("Attempting to parse hotwords from local txt...")
            return self._hotwords_from_file(hotword_list_or_file, seg_dict, tokenizer)

        # for url, download and generate txt
        if hotword_list_or_file.startswith('http'):
            logging.info("Attempting to parse hotwords from url...")
            local_path = self._download_hotword_file(hotword_list_or_file)
            return self._hotwords_from_file(local_path, seg_dict, tokenizer)

        # for text str input
        if not hotword_list_or_file.endswith('.txt'):
            logging.info("Attempting to parse hotwords as str...")
            hotword_str_list = hotword_list_or_file.strip().split()
            hotword_list = self._hotwords_to_ids(hotword_str_list, seg_dict, tokenizer)
            logging.info("Hotword list: {}.".format(hotword_str_list + ['']))
            return hotword_list

        logging.warning("Hotword file not found: {}".format(hotword_list_or_file))
        return None