diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py index 605a95c8a..f414e4fd0 100644 --- a/funasr/models/e2e_asr_paraformer.py +++ b/funasr/models/e2e_asr_paraformer.py @@ -92,17 +92,8 @@ class Paraformer(FunASRModel): self.frontend = frontend self.specaug = specaug self.normalize = normalize - self.preencoder = preencoder - self.postencoder = postencoder self.encoder = encoder - if not hasattr(self.encoder, "interctc_use_conditioning"): - self.encoder.interctc_use_conditioning = False - if self.encoder.interctc_use_conditioning: - self.encoder.conditioning_layer = torch.nn.Linear( - vocab_size, self.encoder.output_size() - ) - self.error_calculator = None if ctc_weight == 1.0: @@ -170,9 +161,7 @@ class Paraformer(FunASRModel): # 1. Encoder encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) - intermediate_outs = None if isinstance(encoder_out, tuple): - intermediate_outs = encoder_out[1] encoder_out = encoder_out[0] loss_att, acc_att, cer_att, wer_att = None, None, None, None @@ -190,30 +179,6 @@ class Paraformer(FunASRModel): stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None stats["cer_ctc"] = cer_ctc - # Intermediate CTC (optional) - loss_interctc = 0.0 - if self.interctc_weight != 0.0 and intermediate_outs is not None: - for layer_idx, intermediate_out in intermediate_outs: - # we assume intermediate_out has the same length & padding - # as those of encoder_out - loss_ic, cer_ic = self._calc_ctc_loss( - intermediate_out, encoder_out_lens, text, text_lengths - ) - loss_interctc = loss_interctc + loss_ic - - # Collect Intermedaite CTC stats - stats["loss_interctc_layer{}".format(layer_idx)] = ( - loss_ic.detach() if loss_ic is not None else None - ) - stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic - - loss_interctc = loss_interctc / len(intermediate_outs) - - # calculate whole encoder loss - loss_ctc = ( - 1 - self.interctc_weight - ) * loss_ctc + self.interctc_weight * loss_interctc - # 2b. Attention decoder branch if self.ctc_weight != 1.0: loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss( @@ -281,29 +246,8 @@ class Paraformer(FunASRModel): if self.normalize is not None: feats, feats_lengths = self.normalize(feats, feats_lengths) - # Pre-encoder, e.g. used for raw input data - if self.preencoder is not None: - feats, feats_lengths = self.preencoder(feats, feats_lengths) - # 4. Forward encoder - # feats: (Batch, Length, Dim) - # -> encoder_out: (Batch, Length2, Dim2) - if self.encoder.interctc_use_conditioning: - encoder_out, encoder_out_lens, _ = self.encoder( - feats, feats_lengths, ctc=self.ctc - ) - else: - encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) - intermediate_outs = None - if isinstance(encoder_out, tuple): - intermediate_outs = encoder_out[1] - encoder_out = encoder_out[0] - - # Post-encoder, e.g. NLU - if self.postencoder is not None: - encoder_out, encoder_out_lens = self.postencoder( - encoder_out, encoder_out_lens - ) + encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) assert encoder_out.size(0) == speech.size(0), ( encoder_out.size(), @@ -314,9 +258,6 @@ class Paraformer(FunASRModel): encoder_out_lens.max(), ) - if intermediate_outs is not None: - return (encoder_out, intermediate_outs), encoder_out_lens - return encoder_out, encoder_out_lens def encode_chunk( @@ -340,32 +281,8 @@ class Paraformer(FunASRModel): if self.normalize is not None: feats, feats_lengths = self.normalize(feats, feats_lengths) - # Pre-encoder, e.g. used for raw input data - if self.preencoder is not None: - feats, feats_lengths = self.preencoder(feats, feats_lengths) - # 4. Forward encoder - # feats: (Batch, Length, Dim) - # -> encoder_out: (Batch, Length2, Dim2) - if self.encoder.interctc_use_conditioning: - encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk( - feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc - ) - else: - encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"]) - intermediate_outs = None - if isinstance(encoder_out, tuple): - intermediate_outs = encoder_out[1] - encoder_out = encoder_out[0] - - # Post-encoder, e.g. NLU - if self.postencoder is not None: - encoder_out, encoder_out_lens = self.postencoder( - encoder_out, encoder_out_lens - ) - - if intermediate_outs is not None: - return (encoder_out, intermediate_outs), encoder_out_lens + encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"]) return encoder_out, torch.tensor([encoder_out.size(1)])