Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
update seaco finetune
commit 149063ced4, parent a65016e23e
@@ -181,8 +181,6 @@ class Paraformer(torch.nn.Module):
             text: (Batch, Length)
             text_lengths: (Batch,)
         """
-        # import pdb;
-        # pdb.set_trace()
         if len(text_lengths.size()) > 1:
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
@@ -190,7 +188,6 @@ class Paraformer(torch.nn.Module):
 
         batch_size = speech.shape[0]
 
-
         # Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
 
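The `len(x.size()) > 1` guards above (and the analogous ones added for SeacoParaformer below) collapse lengths tensors that arrive as (Batch, 1) back to the expected (Batch,) shape. A minimal sketch of that pattern, not FunASR code, assuming the 2-D case comes from the data-parallel path noted by the `# for data-parallel` comment later in the file:

import torch

def normalize_lengths(lengths: torch.Tensor) -> torch.Tensor:
    # Keep only the first column when lengths is (Batch, 1); leave (Batch,) untouched.
    if len(lengths.size()) > 1:
        lengths = lengths[:, 0]
    return lengths

text_lengths = torch.tensor([[7], [5], [9]])    # (Batch, 1), e.g. from a stacked collate
print(normalize_lengths(text_lengths).shape)    # torch.Size([3])
speech_lengths = torch.tensor([120, 98, 143])   # already (Batch,)
print(normalize_lengths(speech_lengths).shape)  # torch.Size([3]), unchanged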
@@ -97,7 +97,8 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
             smoothing=seaco_lsm_weight,
             normalize_length=seaco_length_normalized_loss,
         )
-        self.train_decoder = kwargs.get("train_decoder", False)
+        self.train_decoder = kwargs.get("train_decoder", True)
+        self.seaco_weight = kwargs.get("seaco_weight", 0.01)
         self.NO_BIAS = kwargs.get("NO_BIAS", 8377)
         self.predictor_name = kwargs.get("predictor")
 
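This constructor change flips the `train_decoder` default to True and adds a `seaco_weight` knob (default 0.01) that scales the attention loss in the last hunk below. A small illustration of how the two kwargs resolve; the empty config dict is invented for the example, only the key names and defaults come from the diff:

finetune_kwargs = {}                                        # nothing overridden in the finetune config
train_decoder = finetune_kwargs.get("train_decoder", True)  # -> True (new default)
seaco_weight = finetune_kwargs.get("seaco_weight", 0.01)    # -> 0.01 (new knob)

finetune_kwargs = {"train_decoder": False}                  # opting back out of decoder training
train_decoder = finetune_kwargs.get("train_decoder", True)  # -> False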
@@ -117,9 +118,10 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
             text: (Batch, Length)
             text_lengths: (Batch,)
         """
-        text_lengths = text_lengths.squeeze()
-        speech_lengths = speech_lengths.squeeze()
-        assert text_lengths.dim() == 1, text_lengths.shape
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
         # Check that batch_size is unified
         assert (
             speech.shape[0]
@@ -131,6 +133,8 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
         hotword_pad = kwargs.get("hotword_pad")
         hotword_lengths = kwargs.get("hotword_lengths")
         seaco_label_pad = kwargs.get("seaco_label_pad")
+        if len(hotword_lengths.size()) > 1:
+            hotword_lengths = hotword_lengths[:, 0]
 
         batch_size = speech.shape[0]
         # for data-parallel
@@ -156,11 +160,12 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
             loss_att, acc_att = self._calc_att_loss(
                 encoder_out, encoder_out_lens, text, text_lengths
             )
-            loss = loss_seaco + loss_att
+            loss = loss_seaco + loss_att * self.seaco_weight
             stats["loss_att"] = torch.clone(loss_att.detach())
             stats["acc_att"] = acc_att
         else:
             loss = loss_seaco
+
         stats["loss_seaco"] = torch.clone(loss_seaco.detach())
         stats["loss"] = torch.clone(loss.detach())
 
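A minimal sketch, with made-up loss values rather than FunASR code, of how the updated branch combines the two losses when the decoder is trained:

import torch

seaco_weight = 0.01              # constructor default from kwargs.get("seaco_weight", 0.01)
loss_seaco = torch.tensor(2.3)   # example value only
loss_att = torch.tensor(4.1)     # example value only

train_decoder = True
if train_decoder:
    loss = loss_seaco + loss_att * seaco_weight   # 2.3 + 4.1 * 0.01 = 2.341
else:
    loss = loss_seaco                             # SeACo loss alone when train_decoder is False
print(round(loss.item(), 3))                      # 2.341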