From 84b4a01979ecc483096cccf5185dbe5e56946217 Mon Sep 17 00:00:00 2001 From: "haoneng.lhn" Date: Fri, 26 May 2023 11:43:27 +0800 Subject: [PATCH] add paraformer online infer and finetune --- funasr/bin/asr_inference_launch.py | 4 +- funasr/models/decoder/sanm_decoder.py | 6 +- funasr/models/e2e_asr_paraformer.py | 96 ++++++++++++++++++++++++++- 3 files changed, 99 insertions(+), 7 deletions(-) diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index dbbb3ed1e..f5296f678 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -1618,6 +1618,8 @@ def inference_launch(**kwargs): return inference_uniasr(**kwargs) elif mode == "paraformer": return inference_paraformer(**kwargs) + elif mode == "paraformer_online": + return inference_paraformer(**kwargs) elif mode == "paraformer_streaming": return inference_paraformer_online(**kwargs) elif mode.startswith("paraformer_vad"): @@ -1900,4 +1902,4 @@ def main(cmd=None): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py index 508eb7378..ed920bfbb 100644 --- a/funasr/models/decoder/sanm_decoder.py +++ b/funasr/models/decoder/sanm_decoder.py @@ -956,14 +956,14 @@ class ParaformerSANMDecoder(BaseTransformerDecoder): """ tgt = ys_in_pad tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None] + + memory = hs_pad + memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :] if chunk_mask is not None: memory_mask = memory_mask * chunk_mask if tgt_mask.size(1) != memory_mask.size(1): memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1) - memory = hs_pad - memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :] - x = tgt x, tgt_mask, memory, memory_mask, _ = self.decoders( x, tgt_mask, memory, memory_mask diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py index 54db971d0..09af2cdef 100644 --- a/funasr/models/e2e_asr_paraformer.py +++ b/funasr/models/e2e_asr_paraformer.py @@ -279,7 +279,7 @@ class Paraformer(FunASRModel): def encode( self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0, - ) -> Tuple[Tuple[Any, Optional[Any]], Any]: + ) -> Tuple[torch.Tensor, torch.Tensor]: """Frontend + Encoder. Note that this method is used by asr_inference.py Args: speech: (Batch, Length, ...) 
                speech_lengths: (Batch, )
@@ -649,7 +649,35 @@ class ParaformerOnline(Paraformer):
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
-        super().__init__()
+        super().__init__(
+            vocab_size=vocab_size,
+            token_list=token_list,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            preencoder=preencoder,
+            encoder=encoder,
+            postencoder=postencoder,
+            decoder=decoder,
+            ctc=ctc,
+            ctc_weight=ctc_weight,
+            interctc_weight=interctc_weight,
+            ignore_id=ignore_id,
+            blank_id=blank_id,
+            sos=sos,
+            eos=eos,
+            lsm_weight=lsm_weight,
+            length_normalized_loss=length_normalized_loss,
+            report_cer=report_cer,
+            report_wer=report_wer,
+            sym_space=sym_space,
+            sym_blank=sym_blank,
+            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
+            predictor=predictor,
+            predictor_weight=predictor_weight,
+            predictor_bias=predictor_bias,
+            sampling_ratio=sampling_ratio,
+        )
         # note that eos is the same as sos (equivalent ID)
         self.blank_id = blank_id
         self.sos = vocab_size - 1 if sos is None else sos
@@ -705,6 +733,7 @@ class ParaformerOnline(Paraformer):
         self.sampling_ratio = sampling_ratio
         self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
         self.step_cur = 0
+        self.scama_mask = None
         if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
             from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
             self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
@@ -859,7 +888,7 @@ class ParaformerOnline(Paraformer):
         # Pre-encoder, e.g. used for raw input data
         if self.preencoder is not None:
             feats, feats_lengths = self.preencoder(feats, feats_lengths)
-            
+
         # 4. Forward encoder
         # feats: (Batch, Length, Dim)
         # -> encoder_out: (Batch, Length2, Dim2)
@@ -1111,12 +1140,73 @@ class ParaformerOnline(Paraformer):
 
         return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
 
+    def calc_predictor(self, encoder_out, encoder_out_lens):
+        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+            encoder_out.device
+        )
+        mask_chunk_predictor = None
+        if self.encoder.overlap_chunk_cls is not None:
+            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
+                None, device=encoder_out.device, batch_size=encoder_out.size(0)
+            )
+            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
+                None, device=encoder_out.device, batch_size=encoder_out.size(0)
+            )
+            encoder_out = encoder_out * mask_shfit_chunk
+        pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(
+            encoder_out,
+            None,
+            encoder_out_mask,
+            ignore_id=self.ignore_id,
+            mask_chunk_predictor=mask_chunk_predictor,
+            target_label_length=None,
+        )
+        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
+            pre_alphas[:, :-1], encoder_out_lens
+        )
+
+        scama_mask = None
+        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+            attention_chunk_center_bias = 0
+            attention_chunk_size = encoder_chunk_size
+            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
+                None, device=encoder_out.device, batch_size=encoder_out.size(0)
+            )
+            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+                predictor_alignments=predictor_alignments,
+                encoder_sequence_length=encoder_out_lens,
+                chunk_size=1,
+                encoder_chunk_size=encoder_chunk_size,
+                attention_chunk_center_bias=attention_chunk_center_bias,
+                attention_chunk_size=attention_chunk_size,
+                attention_chunk_type=self.decoder_attention_chunk_type,
+                step=None,
+                predictor_mask_chunk_hopping=mask_chunk_predictor,
+                decoder_att_look_back_factor=decoder_att_look_back_factor,
+                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+                target_length=None,
+                is_training=self.training,
+            )
+        self.scama_mask = scama_mask
+
+        return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
+
     def calc_predictor_chunk(self, encoder_out, cache=None):
 
         pre_acoustic_embeds, pre_token_length = \
             self.predictor.forward_chunk(encoder_out, cache["encoder"])
 
         return pre_acoustic_embeds, pre_token_length
 
+    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
+        decoder_outs = self.decoder(
+            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
+        )
+        decoder_out = decoder_outs[0]
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        return decoder_out, ys_pad_lens
+
     def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
         decoder_outs = self.decoder.forward_chunk(
             encoder_out, sematic_embeds, cache["decoder"]
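
Usage note: with this patch, mode="paraformer_online" dispatches to the same offline-style
inference_paraformer builder as plain "paraformer" (useful for whole-utterance scoring of a
streaming-trained model), while the pre-existing "paraformer_streaming" mode continues to route
to the chunk-by-chunk inference_paraformer_online. A minimal sketch follows; apart from mode,
the keyword arguments and paths are illustrative assumptions about what the paraformer builder
accepts, not part of this patch:

    from funasr.bin.asr_inference_launch import inference_launch

    # New branch from this patch: offline-style decoding of an online
    # (streaming-trained) Paraformer. Everything besides `mode` is forwarded
    # unchanged to inference_paraformer(**kwargs); the argument names and
    # paths below are assumptions for illustration.
    speech2text = inference_launch(
        mode="paraformer_online",
        asr_train_config="exp/paraformer_online/config.yaml",  # assumed path
        asr_model_file="exp/paraformer_online/model.pb",       # assumed path
        ngpu=1,
    )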
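
A second sketch, for the two-step offline decoding API this patch adds to ParaformerOnline:
calc_predictor runs the CIF predictor and, when the encoder uses overlapped chunking with
decoder_attention_chunk_type == 'chunk', also builds the SCAMA cross-attention mask and caches
it on the model as self.scama_mask; cal_decoder_with_predictor then consumes that cached mask.
The setup around the two calls (a restored model, a padded feature batch) is assumed, not shown
in the patch:

    import torch

    # Assumptions: `model` is a trained ParaformerOnline instance restored
    # from a checkpoint; `speech` is a (Batch, Length, Dim) feature tensor
    # and `speech_lengths` a (Batch,) length tensor.
    with torch.no_grad():
        # Frontend + encoder (per this patch, encode() now advertises a
        # Tuple[torch.Tensor, torch.Tensor] return).
        encoder_out, encoder_out_lens = model.encode(speech, speech_lengths)

        # Step 1: CIF predictor; also builds and caches self.scama_mask.
        pre_acoustic_embeds, pre_token_length, _, _ = model.calc_predictor(
            encoder_out, encoder_out_lens
        )

        # Step 2: decoder pass that reuses the cached SCAMA mask; returns
        # log-probabilities over the vocabulary.
        decoder_out, ys_pad_lens = model.cal_decoder_with_predictor(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length
        )

        token_ids = decoder_out.argmax(dim=-1)  # greedy read-out, illustration only

Caching the mask on self keeps cal_decoder_with_predictor signature-compatible with the offline
Paraformer entry points that inference_paraformer expects, at the cost of making the pair
stateful: calc_predictor must run before cal_decoder_with_predictor on the same batch. The
chunked counterparts (calc_predictor_chunk, cal_decoder_with_predictor_chunk) instead carry
their state explicitly through the cache dict.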