diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index d47db1130..6c7957c5f 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -181,8 +181,6 @@ class Paraformer(torch.nn.Module):
             text: (Batch, Length)
             text_lengths: (Batch,)
         """
-        # import pdb;
-        # pdb.set_trace()
         if len(text_lengths.size()) > 1:
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
@@ -190,7 +188,6 @@ class Paraformer(torch.nn.Module):
 
         batch_size = speech.shape[0]
 
-        # Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
 
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index 21b6abaec..8f8734025 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -97,7 +97,8 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
             smoothing=seaco_lsm_weight,
             normalize_length=seaco_length_normalized_loss,
         )
-        self.train_decoder = kwargs.get("train_decoder", False)
+        self.train_decoder = kwargs.get("train_decoder", True)
+        self.seaco_weight = kwargs.get("seaco_weight", 0.01)
         self.NO_BIAS = kwargs.get("NO_BIAS", 8377)
         self.predictor_name = kwargs.get("predictor")
 
@@ -117,9 +118,10 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
             text: (Batch, Length)
             text_lengths: (Batch,)
         """
-        text_lengths = text_lengths.squeeze()
-        speech_lengths = speech_lengths.squeeze()
-        assert text_lengths.dim() == 1, text_lengths.shape
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
         # Check that batch_size is unified
         assert (
             speech.shape[0]
@@ -131,6 +133,8 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
         hotword_pad = kwargs.get("hotword_pad")
         hotword_lengths = kwargs.get("hotword_lengths")
         seaco_label_pad = kwargs.get("seaco_label_pad")
+        if len(hotword_lengths.size()) > 1:
+            hotword_lengths = hotword_lengths[:, 0]
 
         batch_size = speech.shape[0]
         # for data-parallel
@@ -156,11 +160,12 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
             loss_att, acc_att = self._calc_att_loss(
                 encoder_out, encoder_out_lens, text, text_lengths
             )
-            loss = loss_seaco + loss_att
+            loss = loss_seaco + loss_att * self.seaco_weight
             stats["loss_att"] = torch.clone(loss_att.detach())
             stats["acc_att"] = acc_att
         else:
             loss = loss_seaco
+            stats["loss_seaco"] = torch.clone(loss_seaco.detach())
 
         stats["loss"] = torch.clone(loss.detach())
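
Note on the length-handling change: replacing `.squeeze()` with explicit `[:, 0]` indexing keeps the batch dimension intact when length tensors arrive padded as (Batch, 1). With a batch of one, `.squeeze()` collapses a (1, 1) tensor to a 0-dim scalar, which tripped the old `dim() == 1` assertion. A minimal sketch of the difference, with illustrative values only:

    import torch

    # Lengths presumably arrive as (Batch, 1) from upstream padding;
    # a batch of one is the edge case that breaks .squeeze().
    lengths = torch.tensor([[42]])

    squeezed = lengths.squeeze()   # shape () -- 0-dim, dim() == 0
    indexed = lengths[:, 0]        # shape (1,) -- batch dim preserved

    assert squeezed.dim() == 0
    assert indexed.dim() == 1

The new `seaco_weight` (default 0.01) down-weights the attention loss in `loss = loss_seaco + loss_att * self.seaco_weight`, so the SeACo bias loss dominates training unless the weight is raised via kwargs.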