diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index 18cd343e7..508eb7378 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -935,6 +935,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
         hlens: torch.Tensor,
         ys_in_pad: torch.Tensor,
         ys_in_lens: torch.Tensor,
+        chunk_mask: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward decoder.
 
@@ -955,6 +956,10 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
         """
         tgt = ys_in_pad
         tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
 
         memory = hs_pad
         memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+        if chunk_mask is not None:
+            memory_mask = memory_mask * chunk_mask
+            if tgt_mask.size(1) != memory_mask.size(1):
+                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 83a45245f..54db971d0 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -161,6 +161,7 @@ class Paraformer(FunASRModel):
             speech_lengths: (Batch, )
             text: (Batch, Length)
             text_lengths: (Batch,)
+            decoding_ind: int
         """
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
@@ -278,11 +279,12 @@ class Paraformer(FunASRModel):
 
     def encode(
             self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[Tuple[Any, Optional[Any]], Any]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
+            ind: int
         """
         with autocast(False):
             # 1. Extract feats
@@ -611,9 +613,108 @@ class ParaformerOnline(Paraformer):
     """
 
     def __init__(
-        self, *args, **kwargs,
+        self,
+        vocab_size: int,
+        token_list: Union[Tuple[str, ...], List[str]],
+        frontend: Optional[AbsFrontend],
+        specaug: Optional[AbsSpecAug],
+        normalize: Optional[AbsNormalize],
+        encoder: AbsEncoder,
+        decoder: AbsDecoder,
+        ctc: CTC,
+        ctc_weight: float = 0.5,
+        interctc_weight: float = 0.0,
+        ignore_id: int = -1,
+        blank_id: int = 0,
+        sos: int = 1,
+        eos: int = 2,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+        report_cer: bool = True,
+        report_wer: bool = True,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        extract_feats_in_collect_stats: bool = True,
+        predictor=None,
+        predictor_weight: float = 0.0,
+        predictor_bias: int = 0,
+        sampling_ratio: float = 0.2,
+        decoder_attention_chunk_type: str = 'chunk',
+        share_embedding: bool = False,
+        preencoder: Optional[AbsPreEncoder] = None,
+        postencoder: Optional[AbsPostEncoder] = None,
+        use_1st_decoder_loss: bool = False,
     ):
-        super().__init__(*args, **kwargs)
+        assert check_argument_types()
+        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
+        assert 0.0 <= interctc_weight < 1.0, interctc_weight
+
+        super().__init__()
+        # note that eos is the same as sos (equivalent ID)
+        self.blank_id = blank_id
+        self.sos = vocab_size - 1 if sos is None else sos
+        self.eos = vocab_size - 1 if eos is None else eos
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.ctc_weight = ctc_weight
+        self.interctc_weight = interctc_weight
+        self.token_list = token_list.copy()
+
+        self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+        self.preencoder = preencoder
+        self.postencoder = postencoder
+        self.encoder = encoder
+
+        if not hasattr(self.encoder, "interctc_use_conditioning"):
+            self.encoder.interctc_use_conditioning = False
+        if self.encoder.interctc_use_conditioning:
+            self.encoder.conditioning_layer = torch.nn.Linear(
+                vocab_size, self.encoder.output_size()
+            )
+
+        self.error_calculator = None
+
+        if ctc_weight == 1.0:
+            self.decoder = None
+        else:
+            self.decoder = decoder
+
+        self.criterion_att = LabelSmoothingLoss(
+            size=vocab_size,
+            padding_idx=ignore_id,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+
+        if report_cer or report_wer:
+            self.error_calculator = ErrorCalculator(
+                token_list, sym_space, sym_blank, report_cer, report_wer
+            )
+
+        if ctc_weight == 0.0:
+            self.ctc = None
+        else:
+            self.ctc = ctc
+
+        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+        self.predictor = predictor
+        self.predictor_weight = predictor_weight
+        self.predictor_bias = predictor_bias
+        self.sampling_ratio = sampling_ratio
+        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
+        self.step_cur = 0
+        if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
+            from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
+            self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
+            self.decoder_attention_chunk_type = decoder_attention_chunk_type
+
+        self.share_embedding = share_embedding
+        if self.share_embedding:
+            self.decoder.embed = None
+
+        self.use_1st_decoder_loss = use_1st_decoder_loss
 
     def forward(
         self,
@@ -621,6 +722,7 @@ class ParaformerOnline(Paraformer):
         speech_lengths: torch.Tensor,
         text: torch.Tensor,
         text_lengths: torch.Tensor,
+        decoding_ind: int = None,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
         Args:
@@ -628,6 +730,7 @@ class ParaformerOnline(Paraformer):
             speech_lengths: (Batch, )
             text: (Batch, Length)
             text_lengths: (Batch,)
+            decoding_ind: int
         """
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
@@ -644,7 +747,11 @@ class ParaformerOnline(Paraformer):
         speech = speech[:, :speech_lengths.max()]
 
         # 1. Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        if hasattr(self.encoder, "overlap_chunk_cls"):
+            ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
+        else:
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
         intermediate_outs = None
         if isinstance(encoder_out, tuple):
             intermediate_outs = encoder_out[1]
@@ -657,8 +764,12 @@ class ParaformerOnline(Paraformer):
 
         # 1. CTC branch
         if self.ctc_weight != 0.0:
+            if hasattr(self.encoder, "overlap_chunk_cls"):
+                encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(
+                    encoder_out,
+                    encoder_out_lens,
+                    chunk_outs=None)
             loss_ctc, cer_ctc = self._calc_ctc_loss(
-                encoder_out, encoder_out_lens, text, text_lengths
+                encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
             )
 
             # Collect CTC branch stats
@@ -671,8 +782,14 @@ class ParaformerOnline(Paraformer):
             for layer_idx, intermediate_out in intermediate_outs:
                 # we assume intermediate_out has the same length & padding
                 # as those of encoder_out
+                if hasattr(self.encoder, "overlap_chunk_cls"):
+                    encoder_out_ctc, encoder_out_lens_ctc = \
+                        self.encoder.overlap_chunk_cls.remove_chunk(
+                            intermediate_out,
+                            encoder_out_lens,
+                            chunk_outs=None)
                 loss_ic, cer_ic = self._calc_ctc_loss(
-                    intermediate_out, encoder_out_lens, text, text_lengths
+                    encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                 )
                 loss_interctc = loss_interctc + loss_ic
 
@@ -691,7 +808,7 @@ class ParaformerOnline(Paraformer):
 
         # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
-            loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
+            loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_predictor_loss(
                 encoder_out, encoder_out_lens, text, text_lengths
             )
 
@@ -703,8 +820,12 @@ class ParaformerOnline(Paraformer):
         else:
             loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
 
+        if self.use_1st_decoder_loss and pre_loss_att is not None:
+            loss = loss + pre_loss_att
+
         # Collect Attn branch stats
         stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+        stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
         stats["acc"] = acc_att
         stats["cer"] = cer_att
         stats["wer"] = wer_att
@@ -716,6 +837,63 @@ class ParaformerOnline(Paraformer):
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
+    def encode(
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder. Note that this method is used by asr_inference.py
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+        """
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # Pre-encoder, e.g. used for raw input data
+        if self.preencoder is not None:
+            feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+        # 4. Forward encoder
+        # feats: (Batch, Length, Dim)
+        # -> encoder_out: (Batch, Length2, Dim2)
+        if self.encoder.interctc_use_conditioning:
+            encoder_out, encoder_out_lens, _ = self.encoder(
+                feats, feats_lengths, ctc=self.ctc, ind=ind
+            )
+        else:
+            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
+        intermediate_outs = None
+        if isinstance(encoder_out, tuple):
+            intermediate_outs = encoder_out[1]
+            encoder_out = encoder_out[0]
+
+        # Post-encoder, e.g. NLU
+        if self.postencoder is not None:
+            encoder_out, encoder_out_lens = self.postencoder(
+                encoder_out, encoder_out_lens
+            )
+
+        assert encoder_out.size(0) == speech.size(0), (
+            encoder_out.size(),
+            speech.size(0),
+        )
+        assert encoder_out.size(1) <= encoder_out_lens.max(), (
+            encoder_out.size(),
+            encoder_out_lens.max(),
+        )
+
+        if intermediate_outs is not None:
+            return (encoder_out, intermediate_outs), encoder_out_lens
+
+        return encoder_out, encoder_out_lens
+
     def encode_chunk(
         self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -765,6 +943,174 @@ class ParaformerOnline(Paraformer):
 
         return encoder_out, torch.tensor([encoder_out.size(1)])
 
+    def _calc_att_predictor_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+            encoder_out.device)
+        if self.predictor_bias == 1:
+            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+            ys_pad_lens = ys_pad_lens + self.predictor_bias
+        mask_chunk_predictor = None
+        if self.encoder.overlap_chunk_cls is not None:
+            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
+                None, device=encoder_out.device, batch_size=encoder_out.size(0))
+            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
+                None, device=encoder_out.device, batch_size=encoder_out.size(0))
+            encoder_out = encoder_out * mask_shfit_chunk
+        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(
+            encoder_out,
+            ys_pad,
+            encoder_out_mask,
+            ignore_id=self.ignore_id,
+            mask_chunk_predictor=mask_chunk_predictor,
+            target_label_length=ys_pad_lens,
+        )
+        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
+            pre_alphas, encoder_out_lens)
+
+        scama_mask = None
+        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+            attention_chunk_center_bias = 0
+            attention_chunk_size = encoder_chunk_size
+            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
+                None, device=encoder_out.device, batch_size=encoder_out.size(0))
+            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+                predictor_alignments=predictor_alignments,
+                encoder_sequence_length=encoder_out_lens,
+                chunk_size=1,
+                encoder_chunk_size=encoder_chunk_size,
+                attention_chunk_center_bias=attention_chunk_center_bias,
+                attention_chunk_size=attention_chunk_size,
+                attention_chunk_type=self.decoder_attention_chunk_type,
+                step=None,
+                predictor_mask_chunk_hopping=mask_chunk_predictor,
+                decoder_att_look_back_factor=decoder_att_look_back_factor,
+                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+                target_length=ys_pad_lens,
+                is_training=self.training,
+            )
+        elif self.encoder.overlap_chunk_cls is not None:
+            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
+                encoder_out, encoder_out_lens, chunk_outs=None)
+
+        # 0. sampler
+        decoder_out_1st = None
+        pre_loss_att = None
+        if self.sampling_ratio > 0.0:
+            if self.step_cur < 2:
+                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+            if self.use_1st_decoder_loss:
+                sematic_embeds, decoder_out_1st, pre_loss_att = \
+                    self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
+                                           ys_pad_lens, pre_acoustic_embeds, scama_mask)
+            else:
+                sematic_embeds, decoder_out_1st = \
+                    self.sampler(encoder_out, encoder_out_lens, ys_pad,
+                                 ys_pad_lens, pre_acoustic_embeds, scama_mask)
+        else:
+            if self.step_cur < 2:
+                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+            sematic_embeds = pre_acoustic_embeds
+
+        # 1. Forward decoder
+        decoder_outs = self.decoder(
+            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, scama_mask
+        )
+        decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+        if decoder_out_1st is None:
+            decoder_out_1st = decoder_out
+        # 2. Compute attention loss
+        loss_att = self.criterion_att(decoder_out, ys_pad)
+        acc_att = th_accuracy(
+            decoder_out_1st.view(-1, self.vocab_size),
+            ys_pad,
+            ignore_label=self.ignore_id,
+        )
+        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+        # Compute cer/wer using attention-decoder
+        if self.training or self.error_calculator is None:
+            cer_att, wer_att = None, None
+        else:
+            ys_hat = decoder_out_1st.argmax(dim=-1)
+            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+        return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
+
+    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
+
+        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
+        if self.share_embedding:
+            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
+        else:
+            ys_pad_embed = self.decoder.embed(ys_pad_masked)
+        with torch.no_grad():
+            decoder_outs = self.decoder(
+                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
+            )
+            decoder_out, _ = decoder_outs[0], decoder_outs[1]
+            pred_tokens = decoder_out.argmax(-1)
+            nonpad_positions = ys_pad.ne(self.ignore_id)
+            seq_lens = (nonpad_positions).sum(1)
+            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+            input_mask = torch.ones_like(nonpad_positions)
+            bsz, seq_len = ys_pad.size()
+            for li in range(bsz):
+                target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+                if target_num > 0:
+                    input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
+            input_mask = input_mask.eq(1)
+            input_mask = input_mask.masked_fill(~nonpad_positions, False)
+            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+            input_mask_expand_dim, 0)
+        return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+
+    def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
+        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
+        if self.share_embedding:
+            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
+        else:
+            ys_pad_embed = self.decoder.embed(ys_pad_masked)
+        decoder_outs = self.decoder(
+            encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
+        )
+        pre_loss_att = self.criterion_att(decoder_outs[0], ys_pad)
+        decoder_out, _ = decoder_outs[0], decoder_outs[1]
+        pred_tokens = decoder_out.argmax(-1)
+        nonpad_positions = ys_pad.ne(self.ignore_id)
+        seq_lens = (nonpad_positions).sum(1)
+        same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+        input_mask = torch.ones_like(nonpad_positions)
+        bsz, seq_len = ys_pad.size()
+        for li in range(bsz):
+            target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+            if target_num > 0:
+                input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
+        input_mask = input_mask.eq(1)
+        input_mask = input_mask.masked_fill(~nonpad_positions, False)
+        input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+            input_mask_expand_dim, 0)
+
+        return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
+
     def calc_predictor_chunk(self, encoder_out, cache=None):
         pre_acoustic_embeds, pre_token_length = \
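
Illustrative sketch (not part of the patch): sampler() and sampler_with_grad() above implement Paraformer's glancing-style mixing. A first decoding pass counts wrong tokens per utterance, a random subset of positions proportional to that error count is selected, and the embeddings fed to the second pass take the CIF predictor's acoustic embedding at unselected positions and the ground-truth token embedding at selected ones. A minimal standalone Python sketch of just that mixing step, with made-up tensors and a hypothetical per-utterance error count (none of these names are FunASR APIs):

import torch

torch.manual_seed(0)
bsz, tgt_len, dim = 2, 5, 4          # hypothetical batch / target-length / embedding sizes
sampling_ratio = 0.75                # plays the role of self.sampling_ratio

pre_acoustic_embeds = torch.randn(bsz, tgt_len, dim)   # stand-in for the CIF predictor output
ys_pad_embed = torch.randn(bsz, tgt_len, dim)          # stand-in for embedded ground-truth tokens
num_errors = torch.tensor([3, 2])                      # pretend first-pass error counts per utterance

input_mask = torch.ones(bsz, tgt_len, dtype=torch.bool)
for li in range(bsz):
    # sample a fraction of positions, proportional to the error count, to replace
    target_num = int(num_errors[li].item() * sampling_ratio)
    if target_num > 0:
        idx = torch.randperm(tgt_len)[:target_num]
        input_mask[li, idx] = False

# keep acoustic embeddings where the mask is True, ground-truth embeddings elsewhere
sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask[:, :, None], 0) \
    + ys_pad_embed.masked_fill(input_mask[:, :, None], 0)
print(sematic_embeds.shape)  # torch.Size([2, 5, 4])

With sampling_ratio set to 0, no positions are replaced and the second pass sees only the acoustic embeddings, matching the `sematic_embeds = pre_acoustic_embeds` branch in _calc_att_predictor_loss above.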