mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update seaco finetune (#1526)
This commit is contained in:
parent
944a053a66
commit
980007a486
@ -10,7 +10,7 @@ gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
|
||||
## option 1, download model automatically
|
||||
model_name_or_model_dir="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
model_revision="v2.0.4"
|
||||
model_revision="v2.0.7"
|
||||
|
||||
## option 2, download model by git
|
||||
#local_path_root=${workspace}/modelscope_models
|
||||
|
||||
@ -94,10 +94,8 @@ class ContextualParaformer(Paraformer):
|
||||
text: (Batch, Length)
|
||||
text_lengths: (Batch,)
|
||||
"""
|
||||
if len(text_lengths.size()) > 1:
|
||||
text_lengths = text_lengths[:, 0]
|
||||
if len(speech_lengths.size()) > 1:
|
||||
speech_lengths = speech_lengths[:, 0]
|
||||
text_lengths = text_lengths.squeeze()
|
||||
speech_lengths = speech_lengths.squeeze()
|
||||
|
||||
batch_size = speech.shape[0]
|
||||
|
||||
|
||||
@ -117,6 +117,8 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
|
||||
text: (Batch, Length)
|
||||
text_lengths: (Batch,)
|
||||
"""
|
||||
text_lengths = text_lengths.squeeze()
|
||||
speech_lengths = speech_lengths.squeeze()
|
||||
assert text_lengths.dim() == 1, text_lengths.shape
|
||||
# Check that batch_size is unified
|
||||
assert (
|
||||
@ -164,7 +166,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
|
||||
|
||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||
if self.length_normalized_loss:
|
||||
batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
|
||||
batch_size = (text_lengths + self.predictor_bias).sum()
|
||||
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
||||
return loss, stats, weight
|
||||
|
||||
@ -190,8 +192,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
|
||||
# predictor forward
|
||||
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
|
||||
encoder_out.device)
|
||||
pre_acoustic_embeds, _, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
|
||||
ignore_id=self.ignore_id)
|
||||
pre_acoustic_embeds = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)[0]
|
||||
# decoder forward
|
||||
decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True)
|
||||
selected = self._hotword_representation(hotword_pad,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user