mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update seaco finetune (#1526)
This commit is contained in:
parent
944a053a66
commit
980007a486
@ -10,7 +10,7 @@ gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
|||||||
|
|
||||||
## option 1, download model automatically
|
## option 1, download model automatically
|
||||||
model_name_or_model_dir="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
model_name_or_model_dir="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||||
model_revision="v2.0.4"
|
model_revision="v2.0.7"
|
||||||
|
|
||||||
## option 2, download model by git
|
## option 2, download model by git
|
||||||
#local_path_root=${workspace}/modelscope_models
|
#local_path_root=${workspace}/modelscope_models
|
||||||
|
|||||||
@ -94,10 +94,8 @@ class ContextualParaformer(Paraformer):
|
|||||||
text: (Batch, Length)
|
text: (Batch, Length)
|
||||||
text_lengths: (Batch,)
|
text_lengths: (Batch,)
|
||||||
"""
|
"""
|
||||||
if len(text_lengths.size()) > 1:
|
text_lengths = text_lengths.squeeze()
|
||||||
text_lengths = text_lengths[:, 0]
|
speech_lengths = speech_lengths.squeeze()
|
||||||
if len(speech_lengths.size()) > 1:
|
|
||||||
speech_lengths = speech_lengths[:, 0]
|
|
||||||
|
|
||||||
batch_size = speech.shape[0]
|
batch_size = speech.shape[0]
|
||||||
|
|
||||||
|
|||||||
@ -117,6 +117,8 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
|
|||||||
text: (Batch, Length)
|
text: (Batch, Length)
|
||||||
text_lengths: (Batch,)
|
text_lengths: (Batch,)
|
||||||
"""
|
"""
|
||||||
|
text_lengths = text_lengths.squeeze()
|
||||||
|
speech_lengths = speech_lengths.squeeze()
|
||||||
assert text_lengths.dim() == 1, text_lengths.shape
|
assert text_lengths.dim() == 1, text_lengths.shape
|
||||||
# Check that batch_size is unified
|
# Check that batch_size is unified
|
||||||
assert (
|
assert (
|
||||||
@ -164,7 +166,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
|
|||||||
|
|
||||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||||
if self.length_normalized_loss:
|
if self.length_normalized_loss:
|
||||||
batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
|
batch_size = (text_lengths + self.predictor_bias).sum()
|
||||||
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
||||||
return loss, stats, weight
|
return loss, stats, weight
|
||||||
|
|
||||||
@ -190,8 +192,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
|
|||||||
# predictor forward
|
# predictor forward
|
||||||
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
|
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
|
||||||
encoder_out.device)
|
encoder_out.device)
|
||||||
pre_acoustic_embeds, _, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
|
pre_acoustic_embeds = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)[0]
|
||||||
ignore_id=self.ignore_id)
|
|
||||||
# decoder forward
|
# decoder forward
|
||||||
decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True)
|
decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True)
|
||||||
selected = self._hotword_representation(hotword_pad,
|
selected = self._hotword_representation(hotword_pad,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user