update seaco finetune (#1526)

Shi Xian 2024-03-21 14:59:46 +08:00 committed by GitHub
parent 944a053a66
commit 980007a486
3 changed files with 7 additions and 8 deletions

@@ -10,7 +10,7 @@ gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 ## option 1, download model automatically
 model_name_or_model_dir="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
-model_revision="v2.0.4"
+model_revision="v2.0.7"
 ## option 2, download model by git
 #local_path_root=${workspace}/modelscope_models
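
For reference, loading the newly pinned revision through FunASR's AutoModel looks roughly like the sketch below. The keyword arguments follow the FunASR README, but the input file and hotword string are made-up placeholders; verify the exact API against your installed FunASR version.

from funasr import AutoModel

# Sketch only: load the SeACo-Paraformer checkpoint at the revision the
# script now pins (v2.0.7).
model = AutoModel(
    model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    model_revision="v2.0.7",
)

# SeACo-Paraformer accepts a space-separated hotword list at inference time;
# the wav path and hotwords here are placeholders.
res = model.generate(input="asr_example.wav", hotword="魔搭 达摩院")
print(res)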

@@ -94,10 +94,8 @@ class ContextualParaformer(Paraformer):
             text: (Batch, Length)
             text_lengths: (Batch,)
         """
-        if len(text_lengths.size()) > 1:
-            text_lengths = text_lengths[:, 0]
-        if len(speech_lengths.size()) > 1:
-            speech_lengths = speech_lengths[:, 0]
+        text_lengths = text_lengths.squeeze()
+        speech_lengths = speech_lengths.squeeze()
         batch_size = speech.shape[0]
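
The hunk above replaces the explicit (Batch, 1) handling with a single squeeze(). A minimal illustration of the two code paths, with shapes chosen for the example:

import torch

# Lengths sometimes arrive as (Batch, 1) rather than (Batch,) depending on
# the collator; squeeze() normalizes both cases in one call.
text_lengths = torch.tensor([[7], [5], [9]])   # shape (3, 1)
print(text_lengths.squeeze().shape)            # torch.Size([3])
print(text_lengths[:, 0].shape)                # torch.Size([3]) -- the old path

# Caveat of a bare squeeze(): with batch size 1 every singleton dim collapses,
# so a (1, 1) tensor becomes 0-dim rather than (1,).
print(torch.tensor([[7]]).squeeze().shape)     # torch.Size([])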

@@ -117,6 +117,8 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
             text: (Batch, Length)
             text_lengths: (Batch,)
         """
+        text_lengths = text_lengths.squeeze()
+        speech_lengths = speech_lengths.squeeze()
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
         assert (
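
The same normalization is added here before the dimensionality check. The assertion body is truncated in the hunk above; a hypothetical reconstruction of the batch-consistency check it feeds into, assuming all four batch dimensions must agree:

import torch

def check_shapes(speech, speech_lengths, text, text_lengths):
    # Mirrors the hunk: normalize (Batch, 1) lengths first, then assert.
    text_lengths = text_lengths.squeeze()
    speech_lengths = speech_lengths.squeeze()
    assert text_lengths.dim() == 1, text_lengths.shape
    # Hypothetical continuation of the truncated "assert (" above: all four
    # tensors must agree on batch size.
    assert (
        speech.shape[0]
        == speech_lengths.shape[0]
        == text.shape[0]
        == text_lengths.shape[0]
    )
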
@@ -164,7 +166,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = (text_lengths + self.predictor_bias).sum()
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
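
When length_normalized_loss is set, the weight handed to force_gatherable is the total target-token count (each utterance contributes its text length plus predictor_bias) rather than the plain batch size, and the dropped .type_as(batch_size) cast is no longer applied. A small numeric sketch, assuming predictor_bias == 1 (one EOS appended per utterance):

import torch

text_lengths = torch.tensor([7, 5, 9])
predictor_bias = 1

# Per-utterance token counts including the appended EOS: 8 + 6 + 10 = 24.
weight = (text_lengths + predictor_bias).sum()
print(weight)  # tensor(24), used as the DataParallel gather weight
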
@@ -190,8 +192,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
         # predictor forward
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
-        pre_acoustic_embeds, _, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
-                                                      ignore_id=self.ignore_id)
+        pre_acoustic_embeds = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)[0]
         # decoder forward
         decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True)
         selected = self._hotword_representation(hotword_pad,
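
The rewritten predictor call keeps only the first element of the returned tuple instead of unpacking four values, so the caller keeps working if the predictor's return arity ever changes. A toy stand-in (hypothetical signature, not FunASR's real CIF predictor):

import torch

def predictor(encoder_out, ys_pad, mask, ignore_id=-1):
    # Placeholder for the real CIF predictor; only the first return value
    # (the pre-acoustic embeddings) is consumed by the caller.
    pre_acoustic_embeds = encoder_out
    return pre_acoustic_embeds, None, None, None

enc = torch.zeros(2, 10, 4)
pre_acoustic_embeds = predictor(enc, ys_pad=None, mask=None)[0]  # robust to extra return values
print(pre_acoustic_embeds.shape)  # torch.Size([2, 10, 4])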