Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)
update bicifparaformer forward
commit 2d65e5e754 (parent cf8646cd92)
@@ -977,6 +977,59 @@ class BiCifParaformer(Paraformer):
        loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)

        return loss_pre2

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
                                                                                     ignore_id=self.ignore_id)

        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
                                                           pre_acoustic_embeds)
        else:
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]

        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre

    def calc_predictor(self, encoder_out, encoder_out_lens):
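Note on the mask: the updated `_calc_att_loss` builds a broadcastable encoder mask before running the CIF predictor. The usage `(~make_pad_mask(...))[:, None, :]` implies the ESPnet-style convention where `make_pad_mask` returns True at padded frames, so the diff inverts it and adds a singleton dimension for broadcasting. Below is a minimal sketch of that behavior; the helper name `make_pad_mask_sketch` and the example lengths are illustrative, not from the repo, and the real FunASR `make_pad_mask` has more options.

    import torch

    def make_pad_mask_sketch(lengths: torch.Tensor, maxlen: int) -> torch.Tensor:
        """Return a (B, maxlen) bool mask that is True at padded positions."""
        # positions >= length are padding
        positions = torch.arange(maxlen, device=lengths.device)  # (maxlen,)
        return positions[None, :] >= lengths[:, None]            # (B, maxlen)

    # Usage matching the diff: invert, then add a singleton dim for broadcasting
    lengths = torch.tensor([3, 5])
    encoder_out_mask = (~make_pad_mask_sketch(lengths, maxlen=5))[:, None, :]
    print(encoder_out_mask.shape)  # torch.Size([2, 1, 5])
    print(encoder_out_mask[0])     # tensor([[ True,  True,  True, False, False]])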
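Note on the length losses: `loss_pre` and `loss_pre2` both come from `criterion_pre`, which compares the predictor's estimated token counts (`pre_token_length`, `pre_token_length2`) against the padded target lengths; the `ys_pad_lens.type_as(...)` cast is needed because CIF accumulates fractional counts. A minimal sketch assuming an MAE-style length loss, which is how Paraformer's length criterion is usually described; the exact reduction and normalization in FunASR's criterion may differ.

    import torch

    def length_loss_sketch(target_lens: torch.Tensor, pred_lens: torch.Tensor) -> torch.Tensor:
        """MAE between predicted and true token counts, averaged over the batch.

        A hedged stand-in for criterion_pre(ys_pad_lens.type_as(...), pre_token_length);
        not the actual FunASR implementation.
        """
        return torch.nn.functional.l1_loss(pred_lens, target_lens.float(), reduction="mean")

    true_lens = torch.tensor([7, 12, 9])
    pred_lens = torch.tensor([6.8, 12.4, 9.1])  # fractional CIF weight sums
    print(length_loss_sketch(true_lens, pred_lens))  # tensor(0.2333)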