From 74f4f7b4c7228763fd923d2a27eb7d2515f89907 Mon Sep 17 00:00:00 2001 From: Shi Xian <40013335+R1ckShi@users.noreply.github.com> Date: Fri, 8 Mar 2024 11:33:04 +0800 Subject: [PATCH] seaco with cifv2 (#1450) * seaco with cifv2 --- funasr/models/seaco_paraformer/model.py | 74 +++++++++++++++---------- 1 file changed, 46 insertions(+), 28 deletions(-) diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py index a8b1f1fb1..f671db654 100644 --- a/funasr/models/seaco_paraformer/model.py +++ b/funasr/models/seaco_paraformer/model.py @@ -30,7 +30,7 @@ from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank -import pdb + if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import autocast else: @@ -99,6 +99,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer): ) self.train_decoder = kwargs.get("train_decoder", False) self.NO_BIAS = kwargs.get("NO_BIAS", 8377) + self.predictor_name = kwargs.get("predictor") def forward( self, @@ -170,6 +171,16 @@ class SeacoParaformer(BiCifParaformer, Paraformer): def _merge(self, cif_attended, dec_attended): return cif_attended + dec_attended + def calc_predictor(self, encoder_out, encoder_out_lens): + encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( + encoder_out.device) + predictor_outs = self.predictor(encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id) + if len(predictor_outs) == 4: + pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs + else: + pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = predictor_outs + return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index + def _calc_seaco_loss( self, encoder_out: torch.Tensor, @@ -248,7 +259,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer): def _merge_res(dec_output, dha_output): lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0]) dha_ids = dha_output.max(-1)[-1]# [0] - dha_mask = (dha_ids == 8377).int().unsqueeze(-1) + dha_mask = (dha_ids == self.NO_BIAS).int().unsqueeze(-1) a = (1 - lmbd) / lmbd b = 1 / lmbd a, b = a.to(dec_output.device), b.to(dec_output.device) @@ -332,23 +343,28 @@ class SeacoParaformer(BiCifParaformer, Paraformer): if isinstance(encoder_out, tuple): encoder_out = encoder_out[0] - # predictor predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens) - pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \ - predictor_outs[2], predictor_outs[3] + pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1] pre_token_length = pre_token_length.round().long() if torch.max(pre_token_length) < 1: return [] - decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens, - pre_acoustic_embeds, - pre_token_length, - hw_list=self.hotword_list) + decoder_out = self._seaco_decode_with_ASF(encoder_out, + encoder_out_lens, + pre_acoustic_embeds, + pre_token_length, + hw_list=self.hotword_list + ) # decoder_out, _ = decoder_outs[0], decoder_outs[1] - _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens, - pre_token_length) + if self.predictor_name == "CifPredictorV3": + _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, + encoder_out_lens, + pre_token_length) + else: + us_alphas = None + results = [] b, n, d = decoder_out.size() for i in range(b): @@ -393,23 +409,25 @@ class SeacoParaformer(BiCifParaformer, Paraformer): # Change integer-ids to tokens token = tokenizer.ids2tokens(token_int) text = tokenizer.tokens2text(token) - - _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3], - us_peaks[i][:encoder_out_lens[i] * 3], - copy.copy(token), - vad_offset=kwargs.get("begin_time", 0)) - - text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess( - token, timestamp) - - result_i = {"key": key[i], "text": text_postprocessed, - "timestamp": time_stamp_postprocessed - } - - if ibest_writer is not None: - ibest_writer["token"][key[i]] = " ".join(token) - ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed - ibest_writer["text"][key[i]] = text_postprocessed + if us_alphas is not None: + _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3], + us_peaks[i][:encoder_out_lens[i] * 3], + copy.copy(token), + vad_offset=kwargs.get("begin_time", 0)) + text_postprocessed, time_stamp_postprocessed, _ = \ + postprocess_utils.sentence_postprocess(token, timestamp) + result_i = {"key": key[i], "text": text_postprocessed, + "timestamp": time_stamp_postprocessed} + if ibest_writer is not None: + ibest_writer["token"][key[i]] = " ".join(token) + ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed + ibest_writer["text"][key[i]] = text_postprocessed + else: + text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) + result_i = {"key": key[i], "text": text_postprocessed} + if ibest_writer is not None: + ibest_writer["token"][key[i]] = " ".join(token) + ibest_writer["text"][key[i]] = text_postprocessed else: result_i = {"key": key[i], "token_int": token_int} results.append(result_i)