From 980007a486a4a2cfbcf4f98c7a88e28e17fd0f48 Mon Sep 17 00:00:00 2001
From: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Date: Thu, 21 Mar 2024 14:59:46 +0800
Subject: [PATCH] update seaco finetune (#1526)

---
 .../seaco_paraformer/finetune.sh             | 2 +-
 funasr/models/contextual_paraformer/model.py | 6 ++----
 funasr/models/seaco_paraformer/model.py      | 7 ++++---
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/examples/industrial_data_pretraining/seaco_paraformer/finetune.sh b/examples/industrial_data_pretraining/seaco_paraformer/finetune.sh
index cfdec7779..5614f44a5 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/finetune.sh
+++ b/examples/industrial_data_pretraining/seaco_paraformer/finetune.sh
@@ -10,7 +10,7 @@ gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 
 ## option 1, download model automatically
 model_name_or_model_dir="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
-model_revision="v2.0.4"
+model_revision="v2.0.7"
 
 ## option 2, download model by git
 #local_path_root=${workspace}/modelscope_models
diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py
index 9968bf2e4..b9fd3c463 100644
--- a/funasr/models/contextual_paraformer/model.py
+++ b/funasr/models/contextual_paraformer/model.py
@@ -94,10 +94,8 @@ class ContextualParaformer(Paraformer):
             text: (Batch, Length)
             text_lengths: (Batch,)
         """
-        if len(text_lengths.size()) > 1:
-            text_lengths = text_lengths[:, 0]
-        if len(speech_lengths.size()) > 1:
-            speech_lengths = speech_lengths[:, 0]
+        text_lengths = text_lengths.squeeze()
+        speech_lengths = speech_lengths.squeeze()
 
         batch_size = speech.shape[0]
 
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index 27ff5d1f8..21b6abaec 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -117,6 +117,8 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
             text: (Batch, Length)
             text_lengths: (Batch,)
         """
+        text_lengths = text_lengths.squeeze()
+        speech_lengths = speech_lengths.squeeze()
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
         assert (
@@ -164,7 +166,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = (text_lengths + self.predictor_bias).sum()
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
@@ -190,8 +192,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
         # predictor forward
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
-        pre_acoustic_embeds, _, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
-                                                      ignore_id=self.ignore_id)
+        pre_acoustic_embeds = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)[0]
         # decoder forward
         decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True)
         selected = self._hotword_representation(hotword_pad,
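
Note (not part of the patch): the patch replaces the explicit (Batch, 1) -> (Batch,)
indexing with torch.Tensor.squeeze(). A minimal sketch of that behavior, assuming
PyTorch and a made-up lengths tensor (the values here are illustrative only):

    import torch

    # Hypothetical lengths batch shaped (Batch, 1), as some collators emit.
    text_lengths = torch.tensor([[7], [5], [9]])

    # squeeze() removes every size-1 dimension, producing the (Batch,) shape
    # that SeacoParaformer.forward() asserts on (text_lengths.dim() == 1).
    assert text_lengths.squeeze().shape == (3,)

    # Lengths that are already one-dimensional pass through unchanged.
    assert torch.tensor([7, 5, 9]).squeeze().shape == (3,)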