diff --git a/examples/industrial_data_pretraining/seaco_paraformer/infer.sh b/examples/industrial_data_pretraining/seaco_paraformer/infer.sh
index eb0da1fa0..a39083b36 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/infer.sh
+++ b/examples/industrial_data_pretraining/seaco_paraformer/infer.sh
@@ -2,7 +2,7 @@
 # download model
 local_path_root=../modelscope_models
 mkdir -p ${local_path_root}
-local_path=${local_path_root}/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404
+local_path=${local_path_root}/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
 
 git clone https://www.modelscope.cn/damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git ${local_path}
 
diff --git a/funasr/models/bici_paraformer/model.py b/funasr/models/bici_paraformer/model.py
deleted file mode 100644
index c37ba12f2..000000000
--- a/funasr/models/bici_paraformer/model.py
+++ /dev/null
@@ -1,338 +0,0 @@
-
-import logging
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-import tempfile
-import codecs
-import requests
-import re
-import copy
-import torch
-import torch.nn as nn
-import random
-import numpy as np
-import time
-
-from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.metrics.compute_acc import th_accuracy
-from funasr.train_utils.device_funcs import force_gatherable
-
-from funasr.models.paraformer.search import Hypothesis
-
-from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-from funasr.register import tables
-from funasr.models.ctc.ctc import CTC
-
-
-from funasr.models.paraformer.model import Paraformer
-
-@tables.register("model_classes", "BiCifParaformer")
-class BiCifParaformer(Paraformer):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
-    https://arxiv.org/abs/2206.08317
-    """
-
-    def __init__(
-        self,
-        *args,
-        **kwargs,
-    ):
-        super().__init__(*args, **kwargs)
-
-
-    def _calc_pre2_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
-    ):
-        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
-            encoder_out.device)
-        if self.predictor_bias == 1:
-            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
-            ys_pad_lens = ys_pad_lens + self.predictor_bias
-        _, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)
-
-        # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
-        loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
-
-        return loss_pre2
-
-
-    def _calc_att_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
-    ):
-        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
-            encoder_out.device)
-        if self.predictor_bias == 1:
-            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
-            ys_pad_lens = ys_pad_lens + 
self.predictor_bias - pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad, - encoder_out_mask, - ignore_id=self.ignore_id) - - # 0. sampler - decoder_out_1st = None - if self.sampling_ratio > 0.0: - sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, - pre_acoustic_embeds) - else: - sematic_embeds = pre_acoustic_embeds - - # 1. Forward decoder - decoder_outs = self.decoder( - encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens - ) - decoder_out, _ = decoder_outs[0], decoder_outs[1] - - if decoder_out_1st is None: - decoder_out_1st = decoder_out - # 2. Compute attention loss - loss_att = self.criterion_att(decoder_out, ys_pad) - acc_att = th_accuracy( - decoder_out_1st.view(-1, self.vocab_size), - ys_pad, - ignore_label=self.ignore_id, - ) - loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length) - - # Compute cer/wer using attention-decoder - if self.training or self.error_calculator is None: - cer_att, wer_att = None, None - else: - ys_hat = decoder_out_1st.argmax(dim=-1) - cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) - - return loss_att, acc_att, cer_att, wer_att, loss_pre - - - def calc_predictor(self, encoder_out, encoder_out_lens): - encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( - encoder_out.device) - pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, - None, - encoder_out_mask, - ignore_id=self.ignore_id) - return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index - - - def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num): - encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( - encoder_out.device) - ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out, - encoder_out_mask, - token_num) - return ds_alphas, ds_cif_peak, us_alphas, us_peaks - - - def forward( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - text: torch.Tensor, - text_lengths: torch.Tensor, - **kwargs, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: - """Frontend + Encoder + Decoder + Calc loss - Args: - speech: (Batch, Length, ...) - speech_lengths: (Batch, ) - text: (Batch, Length) - text_lengths: (Batch,) - """ - if len(text_lengths.size()) > 1: - text_lengths = text_lengths[:, 0] - if len(speech_lengths.size()) > 1: - speech_lengths = speech_lengths[:, 0] - - batch_size = speech.shape[0] - - # Encoder - encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) - - - loss_ctc, cer_ctc = None, None - loss_pre = None - stats = dict() - - # decoder: CTC branch - if self.ctc_weight != 0.0: - loss_ctc, cer_ctc = self._calc_ctc_loss( - encoder_out, encoder_out_lens, text, text_lengths - ) - - # Collect CTC branch stats - stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None - stats["cer_ctc"] = cer_ctc - - - # decoder: Attention decoder branch - loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss( - encoder_out, encoder_out_lens, text, text_lengths - ) - - loss_pre2 = self._calc_pre2_loss( - encoder_out, encoder_out_lens, text, text_lengths - ) - - # 3. 
CTC-Att loss definition - if self.ctc_weight == 0.0: - loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5 - else: - loss = self.ctc_weight * loss_ctc + ( - 1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5 - - # Collect Attn branch stats - stats["loss_att"] = loss_att.detach() if loss_att is not None else None - stats["acc"] = acc_att - stats["cer"] = cer_att - stats["wer"] = wer_att - stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None - stats["loss_pre2"] = loss_pre2.detach().cpu() - - stats["loss"] = torch.clone(loss.detach()) - - # force_gatherable: to-device and to-tensor if scalar for DataParallel - if self.length_normalized_loss: - batch_size = int((text_lengths + self.predictor_bias).sum()) - - loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) - return loss, stats, weight - - - def generate(self, - data_in, - data_lengths=None, - key: list = None, - tokenizer=None, - frontend=None, - **kwargs, - ): - - # init beamsearch - is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None - is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None - if self.beam_search is None and (is_use_lm or is_use_ctc): - logging.info("enable beam_search") - self.init_beam_search(**kwargs) - self.nbest = kwargs.get("nbest", 1) - - meta_data = {} - if isinstance(data_in, torch.Tensor): # fbank - speech, speech_lengths = data_in, data_lengths - if len(speech.shape) < 3: - speech = speech[None, :, :] - if speech_lengths is None: - speech_lengths = speech.shape[1] - else: - # extract fbank feats - time1 = time.perf_counter() - audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) - time2 = time.perf_counter() - meta_data["load_data"] = f"{time2 - time1:0.3f}" - speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), - frontend=frontend) - time3 = time.perf_counter() - meta_data["extract_feat"] = f"{time3 - time2:0.3f}" - meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 - - speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"]) - - # Encoder - encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) - if isinstance(encoder_out, tuple): - encoder_out = encoder_out[0] - - # predictor - predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens) - pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \ - predictor_outs[2], predictor_outs[3] - pre_token_length = pre_token_length.round().long() - if torch.max(pre_token_length) < 1: - return [] - decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds, - pre_token_length) - decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] - - # BiCifParaformer, test no bias cif2 - _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens, - pre_token_length) - - results = [] - b, n, d = decoder_out.size() - for i in range(b): - x = encoder_out[i, :encoder_out_lens[i], :] - am_scores = decoder_out[i, :pre_token_length[i], :] - if self.beam_search is not None: - nbest_hyps = self.beam_search( - x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0), - minlenratio=kwargs.get("minlenratio", 0.0) - ) - - nbest_hyps = nbest_hyps[: self.nbest] - else: - 
- yseq = am_scores.argmax(dim=-1) - score = am_scores.max(dim=-1)[0] - score = torch.sum(score, dim=-1) - # pad with mask tokens to ensure compatibility with sos/eos tokens - yseq = torch.tensor( - [self.sos] + yseq.tolist() + [self.eos], device=yseq.device - ) - nbest_hyps = [Hypothesis(yseq=yseq, score=score)] - for nbest_idx, hyp in enumerate(nbest_hyps): - ibest_writer = None - if ibest_writer is None and kwargs.get("output_dir") is not None: - writer = DatadirWriter(kwargs.get("output_dir")) - ibest_writer = writer[f"{nbest_idx + 1}best_recog"] - # remove sos/eos and get results - last_pos = -1 - if isinstance(hyp.yseq, list): - token_int = hyp.yseq[1:last_pos] - else: - token_int = hyp.yseq[1:last_pos].tolist() - - # remove blank symbol id, which is assumed to be 0 - token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)) - - if tokenizer is not None: - # Change integer-ids to tokens - token = tokenizer.ids2tokens(token_int) - text = tokenizer.tokens2text(token) - - _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3], - us_peaks[i][:encoder_out_lens[i] * 3], - copy.copy(token), - vad_offset=kwargs.get("begin_time", 0)) - - text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess( - token, timestamp) - - result_i = {"key": key[i], "text": text_postprocessed, - "timestamp": time_stamp_postprocessed, - } - - if ibest_writer is not None: - ibest_writer["token"][key[i]] = " ".join(token) - # ibest_writer["text"][key[i]] = text - ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed - ibest_writer["text"][key[i]] = text_postprocessed - else: - result_i = {"key": key[i], "token_int": token_int} - results.append(result_i) - - return results, meta_data \ No newline at end of file diff --git a/funasr/models/bici_paraformer/__init__.py b/funasr/models/bicif_paraformer/__init__.py similarity index 100% rename from funasr/models/bici_paraformer/__init__.py rename to funasr/models/bicif_paraformer/__init__.py diff --git a/funasr/models/bici_paraformer/cif_predictor.py b/funasr/models/bicif_paraformer/cif_predictor.py similarity index 100% rename from funasr/models/bici_paraformer/cif_predictor.py rename to funasr/models/bicif_paraformer/cif_predictor.py diff --git a/funasr/models/bicif_paraformer/model.py b/funasr/models/bicif_paraformer/model.py new file mode 100644 index 000000000..25b0462fb --- /dev/null +++ b/funasr/models/bicif_paraformer/model.py @@ -0,0 +1,340 @@ + +import logging +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union +import tempfile +import codecs +import requests +import re +import copy +import torch +import torch.nn as nn +import random +import numpy as np +import time + +from funasr.models.transformer.utils.add_sos_eos import add_sos_eos +from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list +from funasr.metrics.compute_acc import th_accuracy +from funasr.train_utils.device_funcs import force_gatherable + +from funasr.models.paraformer.search import Hypothesis + +from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank +from funasr.utils import postprocess_utils +from funasr.utils.datadir_writer import DatadirWriter +from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard +from funasr.register import tables +from funasr.models.ctc.ctc import CTC + + +from funasr.models.paraformer.model import 
Paraformer + +@tables.register("model_classes", "BiCifParaformer") +class BiCifParaformer(Paraformer): + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Paper1: FunASR: A Fundamental End-to-End Speech Recognition Toolkit + https://arxiv.org/abs/2305.11013 + Paper2: Achieving timestamp prediction while recognizing with non-autoregressive end-to-end ASR model + https://arxiv.org/abs/2301.12343 + """ + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + + def _calc_pre2_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( + encoder_out.device) + if self.predictor_bias == 1: + _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + ys_pad_lens = ys_pad_lens + self.predictor_bias + _, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id) + + # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length) + loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2) + + return loss_pre2 + + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( + encoder_out.device) + if self.predictor_bias == 1: + _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + ys_pad_lens = ys_pad_lens + self.predictor_bias + pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad, + encoder_out_mask, + ignore_id=self.ignore_id) + + # 0. sampler + decoder_out_1st = None + if self.sampling_ratio > 0.0: + sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, + pre_acoustic_embeds) + else: + sematic_embeds = pre_acoustic_embeds + + # 1. Forward decoder + decoder_outs = self.decoder( + encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens + ) + decoder_out, _ = decoder_outs[0], decoder_outs[1] + + if decoder_out_1st is None: + decoder_out_1st = decoder_out + # 2. 
Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_pad) + acc_att = th_accuracy( + decoder_out_1st.view(-1, self.vocab_size), + ys_pad, + ignore_label=self.ignore_id, + ) + loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length) + + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out_1st.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) + + return loss_att, acc_att, cer_att, wer_att, loss_pre + + + def calc_predictor(self, encoder_out, encoder_out_lens): + encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( + encoder_out.device) + pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, + None, + encoder_out_mask, + ignore_id=self.ignore_id) + return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index + + + def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num): + encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( + encoder_out.device) + ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out, + encoder_out_mask, + token_num) + return ds_alphas, ds_cif_peak, us_alphas, us_peaks + + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Frontend + Encoder + Decoder + Calc loss + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + if len(text_lengths.size()) > 1: + text_lengths = text_lengths[:, 0] + if len(speech_lengths.size()) > 1: + speech_lengths = speech_lengths[:, 0] + + batch_size = speech.shape[0] + + # Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + + + loss_ctc, cer_ctc = None, None + loss_pre = None + stats = dict() + + # decoder: CTC branch + if self.ctc_weight != 0.0: + loss_ctc, cer_ctc = self._calc_ctc_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # Collect CTC branch stats + stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None + stats["cer_ctc"] = cer_ctc + + + # decoder: Attention decoder branch + loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + loss_pre2 = self._calc_pre2_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # 3. 
CTC-Att loss definition + if self.ctc_weight == 0.0: + loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5 + else: + loss = self.ctc_weight * loss_ctc + ( + 1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5 + + # Collect Attn branch stats + stats["loss_att"] = loss_att.detach() if loss_att is not None else None + stats["acc"] = acc_att + stats["cer"] = cer_att + stats["wer"] = wer_att + stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None + stats["loss_pre2"] = loss_pre2.detach().cpu() + + stats["loss"] = torch.clone(loss.detach()) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = int((text_lengths + self.predictor_bias).sum()) + + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + + def generate(self, + data_in, + data_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + **kwargs, + ): + + # init beamsearch + is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None + is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None + if self.beam_search is None and (is_use_lm or is_use_ctc): + logging.info("enable beam_search") + self.init_beam_search(**kwargs) + self.nbest = kwargs.get("nbest", 1) + + meta_data = {} + if isinstance(data_in, torch.Tensor): # fbank + speech, speech_lengths = data_in, data_lengths + if len(speech.shape) < 3: + speech = speech[None, :, :] + if speech_lengths is None: + speech_lengths = speech.shape[1] + else: + # extract fbank feats + time1 = time.perf_counter() + audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), + frontend=frontend) + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 + + speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"]) + + # Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + if isinstance(encoder_out, tuple): + encoder_out = encoder_out[0] + + # predictor + predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens) + pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \ + predictor_outs[2], predictor_outs[3] + pre_token_length = pre_token_length.round().long() + if torch.max(pre_token_length) < 1: + return [] + decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds, + pre_token_length) + decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] + + # BiCifParaformer, test no bias cif2 + _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens, + pre_token_length) + + results = [] + b, n, d = decoder_out.size() + for i in range(b): + x = encoder_out[i, :encoder_out_lens[i], :] + am_scores = decoder_out[i, :pre_token_length[i], :] + if self.beam_search is not None: + nbest_hyps = self.beam_search( + x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0), + minlenratio=kwargs.get("minlenratio", 0.0) + ) + + nbest_hyps = nbest_hyps[: self.nbest] + else: + 
+ yseq = am_scores.argmax(dim=-1) + score = am_scores.max(dim=-1)[0] + score = torch.sum(score, dim=-1) + # pad with mask tokens to ensure compatibility with sos/eos tokens + yseq = torch.tensor( + [self.sos] + yseq.tolist() + [self.eos], device=yseq.device + ) + nbest_hyps = [Hypothesis(yseq=yseq, score=score)] + for nbest_idx, hyp in enumerate(nbest_hyps): + ibest_writer = None + if ibest_writer is None and kwargs.get("output_dir") is not None: + writer = DatadirWriter(kwargs.get("output_dir")) + ibest_writer = writer[f"{nbest_idx + 1}best_recog"] + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1:last_pos] + else: + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)) + + if tokenizer is not None: + # Change integer-ids to tokens + token = tokenizer.ids2tokens(token_int) + text = tokenizer.tokens2text(token) + + _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3], + us_peaks[i][:encoder_out_lens[i] * 3], + copy.copy(token), + vad_offset=kwargs.get("begin_time", 0)) + + text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess( + token, timestamp) + + result_i = {"key": key[i], "text": text_postprocessed, + "timestamp": time_stamp_postprocessed, + } + + if ibest_writer is not None: + ibest_writer["token"][key[i]] = " ".join(token) + # ibest_writer["text"][key[i]] = text + ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed + ibest_writer["text"][key[i]] = text_postprocessed + else: + result_i = {"key": key[i], "token_int": token_int} + results.append(result_i) + + return results, meta_data \ No newline at end of file diff --git a/funasr/models/bici_paraformer/template.yaml b/funasr/models/bicif_paraformer/template.yaml similarity index 100% rename from funasr/models/bici_paraformer/template.yaml rename to funasr/models/bicif_paraformer/template.yaml diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py index d25babe6f..d107a57e4 100644 --- a/funasr/models/seaco_paraformer/model.py +++ b/funasr/models/seaco_paraformer/model.py @@ -1,512 +1,534 @@ import os -import logging -from contextlib import contextmanager -from distutils.version import LooseVersion -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union -import tempfile -import codecs -import requests import re +import time import copy import torch -import torch.nn as nn -import random +import codecs +import logging +import tempfile +import requests import numpy as np -import time -# from funasr.layers.abs_normalize import AbsNormalize +from typing import Dict +from typing import List +from typing import Tuple +from typing import Union +from typing import Optional +from contextlib import contextmanager +from distutils.version import LooseVersion + from funasr.losses.label_smoothing_loss import ( - LabelSmoothingLoss, # noqa: H301 + LabelSmoothingLoss, # noqa: H301 ) -# from funasr.models.ctc import CTC -# from funasr.models.decoder.abs_decoder import AbsDecoder -# from funasr.models.e2e_asr_common import ErrorCalculator -# from funasr.models.encoder.abs_encoder import AbsEncoder -# from funasr.frontends.abs_frontend import AbsFrontend -# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder from 
funasr.models.paraformer.cif_predictor import mae_loss -# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder -# from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.models.transformer.utils.add_sos_eos import add_sos_eos from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list +from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard from funasr.metrics.compute_acc import th_accuracy from funasr.train_utils.device_funcs import force_gatherable -# from funasr.models.base_model import FunASRModel -# from funasr.models.paraformer.cif_predictor import CifPredictorV3 from funasr.models.paraformer.search import Hypothesis if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): - from torch.cuda.amp import autocast + from torch.cuda.amp import autocast else: - # Nothing to do if torch<1.6.0 - @contextmanager - def autocast(enabled=True): - yield + # Nothing to do if torch<1.6.0 + @contextmanager + def autocast(enabled=True): + yield from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank from funasr.utils import postprocess_utils from funasr.utils.datadir_writer import DatadirWriter from funasr.models.paraformer.model import Paraformer +from funasr.models.bicif_paraformer.model import BiCifParaformer from funasr.register import tables @tables.register("model_classes", "SeacoParaformer") -class SeacoParaformer(Paraformer): - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - SeACo-Paraformer: A Non-Autoregressive ASR System with Flexible and Effective Hotword Customization Ability - https://arxiv.org/abs/2308.03266 - """ - - def __init__( - self, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - - self.inner_dim = kwargs.get("inner_dim", 256) - self.bias_encoder_type = kwargs.get("bias_encoder_type", "lstm") - bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0) - bias_encoder_bid = kwargs.get("bias_encoder_bid", False) - seaco_lsm_weight = kwargs.get("seaco_lsm_weight", 0.0) - seaco_length_normalized_loss = kwargs.get("seaco_length_normalized_loss", True) +class SeacoParaformer(BiCifParaformer, Paraformer): + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + SeACo-Paraformer: A Non-Autoregressive ASR System with Flexible and Effective Hotword Customization Ability + https://arxiv.org/abs/2308.03266 + """ + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.inner_dim = kwargs.get("inner_dim", 256) + self.bias_encoder_type = kwargs.get("bias_encoder_type", "lstm") + bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0) + bias_encoder_bid = kwargs.get("bias_encoder_bid", False) + seaco_lsm_weight = kwargs.get("seaco_lsm_weight", 0.0) + seaco_length_normalized_loss = kwargs.get("seaco_length_normalized_loss", True) - # bias encoder - if self.bias_encoder_type == 'lstm': - logging.warning("enable bias encoder sampling and contextual training") - self.bias_encoder = torch.nn.LSTM(self.inner_dim, - self.inner_dim, - 2, - batch_first=True, - dropout=bias_encoder_dropout_rate, - bidirectional=bias_encoder_bid) - if bias_encoder_bid: - self.lstm_proj = torch.nn.Linear(self.inner_dim*2, self.inner_dim) - else: - self.lstm_proj = None - self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim) - elif self.bias_encoder_type == 'mean': - logging.warning("enable bias encoder sampling and contextual training") - self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim) - 
else: - logging.error("Unsupport bias encoder type: {}".format(self.bias_encoder_type)) + # bias encoder + if self.bias_encoder_type == 'lstm': + logging.warning("enable bias encoder sampling and contextual training") + self.bias_encoder = torch.nn.LSTM(self.inner_dim, + self.inner_dim, + 2, + batch_first=True, + dropout=bias_encoder_dropout_rate, + bidirectional=bias_encoder_bid) + if bias_encoder_bid: + self.lstm_proj = torch.nn.Linear(self.inner_dim*2, self.inner_dim) + else: + self.lstm_proj = None + self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim) + elif self.bias_encoder_type == 'mean': + logging.warning("enable bias encoder sampling and contextual training") + self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim) + else: + logging.error("Unsupport bias encoder type: {}".format(self.bias_encoder_type)) - # seaco decoder - seaco_decoder = kwargs.get("seaco_decoder", None) - if seaco_decoder is not None: - seaco_decoder_conf = kwargs.get("seaco_decoder_conf") - seaco_decoder_class = tables.decoder_classes.get(seaco_decoder.lower()) - self.seaco_decoder = seaco_decoder_class( - vocab_size=self.vocab_size, - encoder_output_size=self.inner_dim, - **seaco_decoder_conf, - ) - self.hotword_output_layer = torch.nn.Linear(self.inner_dim, self.vocab_size) - self.criterion_seaco = LabelSmoothingLoss( - size=self.vocab_size, - padding_idx=self.ignore_id, - smoothing=seaco_lsm_weight, - normalize_length=seaco_length_normalized_loss, - ) - self.train_decoder = kwargs.get("train_decoder", False) - self.NO_BIAS = kwargs.get("NO_BIAS", 8377) - - def forward( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - text: torch.Tensor, - text_lengths: torch.Tensor, - **kwargs, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: - """Frontend + Encoder + Decoder + Calc loss + # seaco decoder + seaco_decoder = kwargs.get("seaco_decoder", None) + if seaco_decoder is not None: + seaco_decoder_conf = kwargs.get("seaco_decoder_conf") + seaco_decoder_class = tables.decoder_classes.get(seaco_decoder.lower()) + self.seaco_decoder = seaco_decoder_class( + vocab_size=self.vocab_size, + encoder_output_size=self.inner_dim, + **seaco_decoder_conf, + ) + self.hotword_output_layer = torch.nn.Linear(self.inner_dim, self.vocab_size) + self.criterion_seaco = LabelSmoothingLoss( + size=self.vocab_size, + padding_idx=self.ignore_id, + smoothing=seaco_lsm_weight, + normalize_length=seaco_length_normalized_loss, + ) + self.train_decoder = kwargs.get("train_decoder", False) + self.NO_BIAS = kwargs.get("NO_BIAS", 8377) + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Frontend + Encoder + Decoder + Calc loss - Args: - speech: (Batch, Length, ...) - speech_lengths: (Batch, ) - text: (Batch, Length) - text_lengths: (Batch,) - """ - assert text_lengths.dim() == 1, text_lengths.shape - # Check that batch_size is unified - assert ( - speech.shape[0] - == speech_lengths.shape[0] - == text.shape[0] - == text_lengths.shape[0] - ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) - - hotword_pad = kwargs.get("hotword_pad") - hotword_lengths = kwargs.get("hotword_lengths") - dha_pad = kwargs.get("dha_pad") + Args: + speech: (Batch, Length, ...) 
+ speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert ( + speech.shape[0] + == speech_lengths.shape[0] + == text.shape[0] + == text_lengths.shape[0] + ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) + + hotword_pad = kwargs.get("hotword_pad") + hotword_lengths = kwargs.get("hotword_lengths") + dha_pad = kwargs.get("dha_pad") - batch_size = speech.shape[0] - self.step_cur += 1 - # for data-parallel - text = text[:, : text_lengths.max()] - speech = speech[:, :speech_lengths.max()] + batch_size = speech.shape[0] + self.step_cur += 1 + # for data-parallel + text = text[:, : text_lengths.max()] + speech = speech[:, :speech_lengths.max()] - # 1. Encoder - encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) - if self.predictor_bias == 1: - _, ys_pad = add_sos_eos(text, self.sos, self.eos, self.ignore_id) - ys_lengths = text_lengths + self.predictor_bias + # 1. Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + if self.predictor_bias == 1: + _, ys_pad = add_sos_eos(text, self.sos, self.eos, self.ignore_id) + ys_lengths = text_lengths + self.predictor_bias - stats = dict() - loss_seaco = self._calc_seaco_loss(encoder_out, - encoder_out_lens, - ys_pad, - ys_lengths, - hotword_pad, - hotword_lengths, - dha_pad, - ) - if self.train_decoder: - loss_att, acc_att = self._calc_att_loss( - encoder_out, encoder_out_lens, text, text_lengths - ) - loss = loss_seaco + loss_att - stats["loss_att"] = torch.clone(loss_att.detach()) - stats["acc_att"] = acc_att - else: - loss = loss_seaco - stats["loss_seaco"] = torch.clone(loss_seaco.detach()) - stats["loss"] = torch.clone(loss.detach()) + stats = dict() + loss_seaco = self._calc_seaco_loss(encoder_out, + encoder_out_lens, + ys_pad, + ys_lengths, + hotword_pad, + hotword_lengths, + dha_pad, + ) + if self.train_decoder: + loss_att, acc_att = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + loss = loss_seaco + loss_att + stats["loss_att"] = torch.clone(loss_att.detach()) + stats["acc_att"] = acc_att + else: + loss = loss_seaco + stats["loss_seaco"] = torch.clone(loss_seaco.detach()) + stats["loss"] = torch.clone(loss.detach()) - # force_gatherable: to-device and to-tensor if scalar for DataParallel - if self.length_normalized_loss: - batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size) - loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) - return loss, stats, weight + # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size) + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight - def _merge(self, cif_attended, dec_attended): - return cif_attended + dec_attended - - def _calc_seaco_loss( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ys_pad: torch.Tensor, - ys_lengths: torch.Tensor, - hotword_pad: torch.Tensor, - hotword_lengths: torch.Tensor, - dha_pad: torch.Tensor, - ): - # predictor forward - encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( - encoder_out.device) - pre_acoustic_embeds, _, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask, - ignore_id=self.ignore_id) - # decoder forward - decoder_out, _ = self.decoder(encoder_out, 
encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True) - selected = self._hotword_representation(hotword_pad, - hotword_lengths) - contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device) - num_hot_word = contextual_info.shape[1] - _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device) - # dha core - cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, pre_acoustic_embeds, ys_lengths) - dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_out, ys_lengths) - merged = self._merge(cif_attended, dec_attended) - dha_output = self.hotword_output_layer(merged[:, :-1]) # remove the last token in loss calculation - loss_att = self.criterion_seaco(dha_output, dha_pad) - return loss_att + def _merge(self, cif_attended, dec_attended): + return cif_attended + dec_attended + + def _calc_seaco_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_lengths: torch.Tensor, + hotword_pad: torch.Tensor, + hotword_lengths: torch.Tensor, + dha_pad: torch.Tensor, + ): + # predictor forward + encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( + encoder_out.device) + pre_acoustic_embeds, _, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask, + ignore_id=self.ignore_id) + # decoder forward + decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True) + selected = self._hotword_representation(hotword_pad, + hotword_lengths) + contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device) + num_hot_word = contextual_info.shape[1] + _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device) + # dha core + cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, pre_acoustic_embeds, ys_lengths) + dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_out, ys_lengths) + merged = self._merge(cif_attended, dec_attended) + dha_output = self.hotword_output_layer(merged[:, :-1]) # remove the last token in loss calculation + loss_att = self.criterion_seaco(dha_output, dha_pad) + return loss_att - def _seaco_decode_with_ASF(self, - encoder_out, - encoder_out_lens, - sematic_embeds, - ys_pad_lens, - hw_list, - nfilter=50, - seaco_weight=1.0): - # decoder forward - decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True) - decoder_pred = torch.log_softmax(decoder_out, dim=-1) - if hw_list is not None: - hw_lengths = [len(i) for i in hw_list] - hw_list_ = [torch.Tensor(i).long() for i in hw_list] - hw_list_pad = pad_list(hw_list_, 0).to(encoder_out.device) - selected = self._hotword_representation(hw_list_pad, torch.Tensor(hw_lengths).int().to(encoder_out.device)) - contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device) - num_hot_word = contextual_info.shape[1] - _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device) - - # ASF Core - if nfilter > 0 and nfilter < num_hot_word: - for dec in self.seaco_decoder.decoders: - dec.reserve_attn = True - # cif_attended, _ = self.decoder2(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens) - dec_attended, _ = self.seaco_decoder(contextual_info, 
_contextual_length, decoder_hidden, ys_pad_lens) - # cif_filter = torch.topk(self.decoder2.decoders[-1].attn_mat[0][0].sum(0).sum(0)[:-1], min(nfilter, num_hot_word-1))[1].tolist() - hotword_scores = self.seaco_decoder.decoders[-1].attn_mat[0][0].sum(0).sum(0)[:-1] - # hotword_scores /= torch.sqrt(torch.tensor(hw_lengths)[:-1].float()).to(hotword_scores.device) - dec_filter = torch.topk(hotword_scores, min(nfilter, num_hot_word-1))[1].tolist() - add_filter = dec_filter - add_filter.append(len(hw_list_pad)-1) - # filter hotword embedding - selected = selected[add_filter] - # again - contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device) - num_hot_word = contextual_info.shape[1] - _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device) - for dec in self.seaco_decoder.decoders: - dec.attn_mat = [] - dec.reserve_attn = False - - # SeACo Core - cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens) - dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens) - merged = self._merge(cif_attended, dec_attended) - - dha_output = self.hotword_output_layer(merged) # remove the last token in loss calculation - dha_pred = torch.log_softmax(dha_output, dim=-1) - # import pdb; pdb.set_trace() - def _merge_res(dec_output, dha_output): - lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0]) - dha_ids = dha_output.max(-1)[-1][0] - dha_mask = (dha_ids == 8377).int().unsqueeze(-1) - a = (1 - lmbd) / lmbd - b = 1 / lmbd - a, b = a.to(dec_output.device), b.to(dec_output.device) - dha_mask = (dha_mask + a.reshape(-1, 1, 1)) / b.reshape(-1, 1, 1) - # logits = dec_output * dha_mask + dha_output[:,:,:-1] * (1-dha_mask) - logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask) - return logits - merged_pred = _merge_res(decoder_pred, dha_pred) - return merged_pred - else: - return decoder_pred + def _seaco_decode_with_ASF(self, + encoder_out, + encoder_out_lens, + sematic_embeds, + ys_pad_lens, + hw_list, + nfilter=50, + seaco_weight=1.0): + # decoder forward + decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True) + decoder_pred = torch.log_softmax(decoder_out, dim=-1) + if hw_list is not None: + hw_lengths = [len(i) for i in hw_list] + hw_list_ = [torch.Tensor(i).long() for i in hw_list] + hw_list_pad = pad_list(hw_list_, 0).to(encoder_out.device) + selected = self._hotword_representation(hw_list_pad, torch.Tensor(hw_lengths).int().to(encoder_out.device)) + contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device) + num_hot_word = contextual_info.shape[1] + _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device) + + # ASF Core + if nfilter > 0 and nfilter < num_hot_word: + for dec in self.seaco_decoder.decoders: + dec.reserve_attn = True + # cif_attended, _ = self.decoder2(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens) + dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens) + # cif_filter = torch.topk(self.decoder2.decoders[-1].attn_mat[0][0].sum(0).sum(0)[:-1], min(nfilter, num_hot_word-1))[1].tolist() + hotword_scores = self.seaco_decoder.decoders[-1].attn_mat[0][0].sum(0).sum(0)[:-1] + # hotword_scores /= 
torch.sqrt(torch.tensor(hw_lengths)[:-1].float()).to(hotword_scores.device) + dec_filter = torch.topk(hotword_scores, min(nfilter, num_hot_word-1))[1].tolist() + add_filter = dec_filter + add_filter.append(len(hw_list_pad)-1) + # filter hotword embedding + selected = selected[add_filter] + # again + contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device) + num_hot_word = contextual_info.shape[1] + _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device) + for dec in self.seaco_decoder.decoders: + dec.attn_mat = [] + dec.reserve_attn = False + + # SeACo Core + cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens) + dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens) + merged = self._merge(cif_attended, dec_attended) + + dha_output = self.hotword_output_layer(merged) # remove the last token in loss calculation + dha_pred = torch.log_softmax(dha_output, dim=-1) + # import pdb; pdb.set_trace() + def _merge_res(dec_output, dha_output): + lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0]) + dha_ids = dha_output.max(-1)[-1][0] + dha_mask = (dha_ids == 8377).int().unsqueeze(-1) + a = (1 - lmbd) / lmbd + b = 1 / lmbd + a, b = a.to(dec_output.device), b.to(dec_output.device) + dha_mask = (dha_mask + a.reshape(-1, 1, 1)) / b.reshape(-1, 1, 1) + # logits = dec_output * dha_mask + dha_output[:,:,:-1] * (1-dha_mask) + logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask) + return logits + merged_pred = _merge_res(decoder_pred, dha_pred) + return merged_pred + else: + return decoder_pred - def _hotword_representation(self, - hotword_pad, - hotword_lengths): - if self.bias_encoder_type != 'lstm': - logging.error("Unsupported bias encoder type") - hw_embed = self.decoder.embed(hotword_pad) - hw_embed, (_, _) = self.bias_encoder(hw_embed) - if self.lstm_proj is not None: - hw_embed = self.lstm_proj(hw_embed) - _ind = np.arange(0, hw_embed.shape[0]).tolist() - selected = hw_embed[_ind, [i-1 for i in hotword_lengths.detach().cpu().tolist()]] - return selected - - def generate(self, - data_in, - data_lengths=None, - key: list = None, - tokenizer=None, - frontend=None, - **kwargs, - ): - - # init beamsearch - is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None - is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None - if self.beam_search is None and (is_use_lm or is_use_ctc): - logging.info("enable beam_search") - self.init_beam_search(**kwargs) - self.nbest = kwargs.get("nbest", 1) - - meta_data = {} - - # extract fbank feats - time1 = time.perf_counter() - audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) - time2 = time.perf_counter() - meta_data["load_data"] = f"{time2 - time1:0.3f}" - speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), - frontend=frontend) - time3 = time.perf_counter() - meta_data["extract_feat"] = f"{time3 - time2:0.3f}" - meta_data[ - "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 - - speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"]) + def _hotword_representation(self, + hotword_pad, + hotword_lengths): + if self.bias_encoder_type != 'lstm': + logging.error("Unsupported bias encoder type") + hw_embed = self.decoder.embed(hotword_pad) + hw_embed, 
(_, _) = self.bias_encoder(hw_embed) + if self.lstm_proj is not None: + hw_embed = self.lstm_proj(hw_embed) + _ind = np.arange(0, hw_embed.shape[0]).tolist() + selected = hw_embed[_ind, [i-1 for i in hotword_lengths.detach().cpu().tolist()]] + return selected - # hotword - self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend) - - # Encoder - encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) - if isinstance(encoder_out, tuple): - encoder_out = encoder_out[0] - - # predictor - predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens) - pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \ - predictor_outs[2], predictor_outs[3] - pre_token_length = pre_token_length.round().long() - if torch.max(pre_token_length) < 1: - return [] + ''' + def calc_predictor(self, encoder_out, encoder_out_lens): + encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( + encoder_out.device) + pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, + None, + encoder_out_mask, + ignore_id=self.ignore_id) + return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index - decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens, - pre_acoustic_embeds, - pre_token_length, - hw_list=self.hotword_list) - # decoder_out, _ = decoder_outs[0], decoder_outs[1] - - results = [] - b, n, d = decoder_out.size() - for i in range(b): - x = encoder_out[i, :encoder_out_lens[i], :] - am_scores = decoder_out[i, :pre_token_length[i], :] - if self.beam_search is not None: - nbest_hyps = self.beam_search( - x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0), - minlenratio=kwargs.get("minlenratio", 0.0) - ) - - nbest_hyps = nbest_hyps[: self.nbest] - else: - - yseq = am_scores.argmax(dim=-1) - score = am_scores.max(dim=-1)[0] - score = torch.sum(score, dim=-1) - # pad with mask tokens to ensure compatibility with sos/eos tokens - yseq = torch.tensor( - [self.sos] + yseq.tolist() + [self.eos], device=yseq.device - ) - nbest_hyps = [Hypothesis(yseq=yseq, score=score)] - for nbest_idx, hyp in enumerate(nbest_hyps): - ibest_writer = None - if ibest_writer is None and kwargs.get("output_dir") is not None: - writer = DatadirWriter(kwargs.get("output_dir")) - ibest_writer = writer[f"{nbest_idx + 1}best_recog"] - # remove sos/eos and get results - last_pos = -1 - if isinstance(hyp.yseq, list): - token_int = hyp.yseq[1:last_pos] - else: - token_int = hyp.yseq[1:last_pos].tolist() - - # remove blank symbol id, which is assumed to be 0 - token_int = list( - filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)) - - if tokenizer is not None: - # Change integer-ids to tokens - token = tokenizer.ids2tokens(token_int) - text = tokenizer.tokens2text(token) - - text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) - result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed} - - if ibest_writer is not None: - ibest_writer["token"][key[i]] = " ".join(token) - ibest_writer["text"][key[i]] = text - ibest_writer["text_postprocessed"][key[i]] = text_postprocessed - else: - result_i = {"key": key[i], "token_int": token_int} - results.append(result_i) - - return results, meta_data + def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num): + encoder_out_mask = (~make_pad_mask(encoder_out_lens, 
maxlen=encoder_out.size(1))[:, None, :]).to( + encoder_out.device) + ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out, + encoder_out_mask, + token_num) + return ds_alphas, ds_cif_peak, us_alphas, us_peaks + ''' + + def generate(self, + data_in, + data_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + **kwargs, + ): + + # init beamsearch + is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None + is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None + if self.beam_search is None and (is_use_lm or is_use_ctc): + logging.info("enable beam_search") + self.init_beam_search(**kwargs) + self.nbest = kwargs.get("nbest", 1) + + meta_data = {} + + # extract fbank feats + time1 = time.perf_counter() + audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), + frontend=frontend) + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + meta_data[ + "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 + + speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"]) + + # hotword + self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend) + + # Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + if isinstance(encoder_out, tuple): + encoder_out = encoder_out[0] + + # predictor + predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens) + pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \ + predictor_outs[2], predictor_outs[3] + pre_token_length = pre_token_length.round().long() + if torch.max(pre_token_length) < 1: + return [] - def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None): - def load_seg_dict(seg_dict_file): - seg_dict = {} - assert isinstance(seg_dict_file, str) - with open(seg_dict_file, "r", encoding="utf8") as f: - lines = f.readlines() - for line in lines: - s = line.strip().split() - key = s[0] - value = s[1:] - seg_dict[key] = " ".join(value) - return seg_dict - - def seg_tokenize(txt, seg_dict): - pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$') - out_txt = "" - for word in txt: - word = word.lower() - if word in seg_dict: - out_txt += seg_dict[word] + " " - else: - if pattern.match(word): - for char in word: - if char in seg_dict: - out_txt += seg_dict[char] + " " - else: - out_txt += "" + " " - else: - out_txt += "" + " " - return out_txt.strip().split() - - seg_dict = None - if frontend.cmvn_file is not None: - model_dir = os.path.dirname(frontend.cmvn_file) - seg_dict_file = os.path.join(model_dir, 'seg_dict') - if os.path.exists(seg_dict_file): - seg_dict = load_seg_dict(seg_dict_file) - else: - seg_dict = None - # for None - if hotword_list_or_file is None: - hotword_list = None - # for local txt inputs - elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'): - logging.info("Attempting to parse hotwords from local txt...") - hotword_list = [] - hotword_str_list = [] - with codecs.open(hotword_list_or_file, 'r') as fin: - for line in fin.readlines(): - hw = line.strip() - hw_list = hw.split() - if seg_dict is not None: - hw_list = seg_tokenize(hw_list, 
seg_dict) - hotword_str_list.append(hw) - hotword_list.append(tokenizer.tokens2ids(hw_list)) - hotword_list.append([self.sos]) - hotword_str_list.append('') - logging.info("Initialized hotword list from file: {}, hotword list: {}." - .format(hotword_list_or_file, hotword_str_list)) - # for url, download and generate txt - elif hotword_list_or_file.startswith('http'): - logging.info("Attempting to parse hotwords from url...") - work_dir = tempfile.TemporaryDirectory().name - if not os.path.exists(work_dir): - os.makedirs(work_dir) - text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file)) - local_file = requests.get(hotword_list_or_file) - open(text_file_path, "wb").write(local_file.content) - hotword_list_or_file = text_file_path - hotword_list = [] - hotword_str_list = [] - with codecs.open(hotword_list_or_file, 'r') as fin: - for line in fin.readlines(): - hw = line.strip() - hw_list = hw.split() - if seg_dict is not None: - hw_list = seg_tokenize(hw_list, seg_dict) - hotword_str_list.append(hw) - hotword_list.append(tokenizer.tokens2ids(hw_list)) - hotword_list.append([self.sos]) - hotword_str_list.append('') - logging.info("Initialized hotword list from file: {}, hotword list: {}." - .format(hotword_list_or_file, hotword_str_list)) - # for text str input - elif not hotword_list_or_file.endswith('.txt'): - logging.info("Attempting to parse hotwords as str...") - hotword_list = [] - hotword_str_list = [] - for hw in hotword_list_or_file.strip().split(): - hotword_str_list.append(hw) - hw_list = hw.strip().split() - if seg_dict is not None: - hw_list = seg_tokenize(hw_list, seg_dict) - hotword_list.append(tokenizer.tokens2ids(hw_list)) - hotword_list.append([self.sos]) - hotword_str_list.append('') - logging.info("Hotword list: {}.".format(hotword_str_list)) - else: - hotword_list = None - return hotword_list + decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens, + pre_acoustic_embeds, + pre_token_length, + hw_list=self.hotword_list) + # decoder_out, _ = decoder_outs[0], decoder_outs[1] + _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens, + pre_token_length) + + results = [] + b, n, d = decoder_out.size() + for i in range(b): + x = encoder_out[i, :encoder_out_lens[i], :] + am_scores = decoder_out[i, :pre_token_length[i], :] + if self.beam_search is not None: + nbest_hyps = self.beam_search( + x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0), + minlenratio=kwargs.get("minlenratio", 0.0) + ) + + nbest_hyps = nbest_hyps[: self.nbest] + else: + + yseq = am_scores.argmax(dim=-1) + score = am_scores.max(dim=-1)[0] + score = torch.sum(score, dim=-1) + # pad with mask tokens to ensure compatibility with sos/eos tokens + yseq = torch.tensor( + [self.sos] + yseq.tolist() + [self.eos], device=yseq.device + ) + nbest_hyps = [Hypothesis(yseq=yseq, score=score)] + for nbest_idx, hyp in enumerate(nbest_hyps): + ibest_writer = None + if ibest_writer is None and kwargs.get("output_dir") is not None: + writer = DatadirWriter(kwargs.get("output_dir")) + ibest_writer = writer[f"{nbest_idx + 1}best_recog"] + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1:last_pos] + else: + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list( + filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)) + + if tokenizer is not None: + # Change integer-ids to tokens + token = 
tokenizer.ids2tokens(token_int) + text = tokenizer.tokens2text(token) + + _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3], + us_peaks[i][:encoder_out_lens[i] * 3], + copy.copy(token), + vad_offset=kwargs.get("begin_time", 0)) + + text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess( + token, timestamp) + + result_i = {"key": key[i], "text": text_postprocessed, + "timestamp": time_stamp_postprocessed, + } + + if ibest_writer is not None: + ibest_writer["token"][key[i]] = " ".join(token) + # ibest_writer["text"][key[i]] = text + ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed + ibest_writer["text"][key[i]] = text_postprocessed + else: + result_i = {"key": key[i], "token_int": token_int} + results.append(result_i) + + return results, meta_data + + + def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None): + def load_seg_dict(seg_dict_file): + seg_dict = {} + assert isinstance(seg_dict_file, str) + with open(seg_dict_file, "r", encoding="utf8") as f: + lines = f.readlines() + for line in lines: + s = line.strip().split() + key = s[0] + value = s[1:] + seg_dict[key] = " ".join(value) + return seg_dict + + def seg_tokenize(txt, seg_dict): + pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$') + out_txt = "" + for word in txt: + word = word.lower() + if word in seg_dict: + out_txt += seg_dict[word] + " " + else: + if pattern.match(word): + for char in word: + if char in seg_dict: + out_txt += seg_dict[char] + " " + else: + out_txt += "" + " " + else: + out_txt += "" + " " + return out_txt.strip().split() + + seg_dict = None + if frontend.cmvn_file is not None: + model_dir = os.path.dirname(frontend.cmvn_file) + seg_dict_file = os.path.join(model_dir, 'seg_dict') + if os.path.exists(seg_dict_file): + seg_dict = load_seg_dict(seg_dict_file) + else: + seg_dict = None + # for None + if hotword_list_or_file is None: + hotword_list = None + # for local txt inputs + elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'): + logging.info("Attempting to parse hotwords from local txt...") + hotword_list = [] + hotword_str_list = [] + with codecs.open(hotword_list_or_file, 'r') as fin: + for line in fin.readlines(): + hw = line.strip() + hw_list = hw.split() + if seg_dict is not None: + hw_list = seg_tokenize(hw_list, seg_dict) + hotword_str_list.append(hw) + hotword_list.append(tokenizer.tokens2ids(hw_list)) + hotword_list.append([self.sos]) + hotword_str_list.append('') + logging.info("Initialized hotword list from file: {}, hotword list: {}." 
+ .format(hotword_list_or_file, hotword_str_list)) + # for url, download and generate txt + elif hotword_list_or_file.startswith('http'): + logging.info("Attempting to parse hotwords from url...") + work_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(work_dir): + os.makedirs(work_dir) + text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file)) + local_file = requests.get(hotword_list_or_file) + open(text_file_path, "wb").write(local_file.content) + hotword_list_or_file = text_file_path + hotword_list = [] + hotword_str_list = [] + with codecs.open(hotword_list_or_file, 'r') as fin: + for line in fin.readlines(): + hw = line.strip() + hw_list = hw.split() + if seg_dict is not None: + hw_list = seg_tokenize(hw_list, seg_dict) + hotword_str_list.append(hw) + hotword_list.append(tokenizer.tokens2ids(hw_list)) + hotword_list.append([self.sos]) + hotword_str_list.append('') + logging.info("Initialized hotword list from file: {}, hotword list: {}." + .format(hotword_list_or_file, hotword_str_list)) + # for text str input + elif not hotword_list_or_file.endswith('.txt'): + logging.info("Attempting to parse hotwords as str...") + hotword_list = [] + hotword_str_list = [] + for hw in hotword_list_or_file.strip().split(): + hotword_str_list.append(hw) + hw_list = hw.strip().split() + if seg_dict is not None: + hw_list = seg_tokenize(hw_list, seg_dict) + hotword_list.append(tokenizer.tokens2ids(hw_list)) + hotword_list.append([self.sos]) + hotword_str_list.append('') + logging.info("Hotword list: {}.".format(hotword_str_list)) + else: + hotword_list = None + return hotword_list
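
A minimal end-to-end usage sketch for the change above: with SeacoParaformer now inheriting from BiCifParaformer, hotword-customized recognition also returns the word-level timestamps produced by calc_predictor_timestamp() and ts_prediction_lfr6_standard(). The snippet assumes the FunASR AutoModel entry point and the local model directory cloned by infer.sh; the wav file name is illustrative.

from funasr import AutoModel

# Local copy cloned by examples/industrial_data_pretraining/seaco_paraformer/infer.sh
# (path is an assumption; a ModelScope model id should work as well).
model = AutoModel(
    model="../modelscope_models/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
)

# "hotword" accepts a space-separated string, a local .txt file, or a URL,
# matching SeacoParaformer.generate_hotwords_list() in this diff.
res = model.generate(input="asr_example_hotword.wav", hotword="达摩院 魔搭")

# Each result now carries both the recognized text and the timestamps added by this change.
print(res[0]["text"])
print(res[0]["timestamp"])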