mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
ec383fba56
commit
e358063f03
@ -92,17 +92,8 @@ class Paraformer(FunASRModel):
|
|||||||
self.frontend = frontend
|
self.frontend = frontend
|
||||||
self.specaug = specaug
|
self.specaug = specaug
|
||||||
self.normalize = normalize
|
self.normalize = normalize
|
||||||
self.preencoder = preencoder
|
|
||||||
self.postencoder = postencoder
|
|
||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
|
|
||||||
if not hasattr(self.encoder, "interctc_use_conditioning"):
|
|
||||||
self.encoder.interctc_use_conditioning = False
|
|
||||||
if self.encoder.interctc_use_conditioning:
|
|
||||||
self.encoder.conditioning_layer = torch.nn.Linear(
|
|
||||||
vocab_size, self.encoder.output_size()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.error_calculator = None
|
self.error_calculator = None
|
||||||
|
|
||||||
if ctc_weight == 1.0:
|
if ctc_weight == 1.0:
|
||||||
@ -170,9 +161,7 @@ class Paraformer(FunASRModel):
|
|||||||
|
|
||||||
# 1. Encoder
|
# 1. Encoder
|
||||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||||
intermediate_outs = None
|
|
||||||
if isinstance(encoder_out, tuple):
|
if isinstance(encoder_out, tuple):
|
||||||
intermediate_outs = encoder_out[1]
|
|
||||||
encoder_out = encoder_out[0]
|
encoder_out = encoder_out[0]
|
||||||
|
|
||||||
loss_att, acc_att, cer_att, wer_att = None, None, None, None
|
loss_att, acc_att, cer_att, wer_att = None, None, None, None
|
||||||
@ -190,30 +179,6 @@ class Paraformer(FunASRModel):
|
|||||||
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
|
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
|
||||||
stats["cer_ctc"] = cer_ctc
|
stats["cer_ctc"] = cer_ctc
|
||||||
|
|
||||||
# Intermediate CTC (optional)
|
|
||||||
loss_interctc = 0.0
|
|
||||||
if self.interctc_weight != 0.0 and intermediate_outs is not None:
|
|
||||||
for layer_idx, intermediate_out in intermediate_outs:
|
|
||||||
# we assume intermediate_out has the same length & padding
|
|
||||||
# as those of encoder_out
|
|
||||||
loss_ic, cer_ic = self._calc_ctc_loss(
|
|
||||||
intermediate_out, encoder_out_lens, text, text_lengths
|
|
||||||
)
|
|
||||||
loss_interctc = loss_interctc + loss_ic
|
|
||||||
|
|
||||||
# Collect Intermedaite CTC stats
|
|
||||||
stats["loss_interctc_layer{}".format(layer_idx)] = (
|
|
||||||
loss_ic.detach() if loss_ic is not None else None
|
|
||||||
)
|
|
||||||
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
|
|
||||||
|
|
||||||
loss_interctc = loss_interctc / len(intermediate_outs)
|
|
||||||
|
|
||||||
# calculate whole encoder loss
|
|
||||||
loss_ctc = (
|
|
||||||
1 - self.interctc_weight
|
|
||||||
) * loss_ctc + self.interctc_weight * loss_interctc
|
|
||||||
|
|
||||||
# 2b. Attention decoder branch
|
# 2b. Attention decoder branch
|
||||||
if self.ctc_weight != 1.0:
|
if self.ctc_weight != 1.0:
|
||||||
loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
|
loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
|
||||||
@ -281,29 +246,8 @@ class Paraformer(FunASRModel):
|
|||||||
if self.normalize is not None:
|
if self.normalize is not None:
|
||||||
feats, feats_lengths = self.normalize(feats, feats_lengths)
|
feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||||
|
|
||||||
# Pre-encoder, e.g. used for raw input data
|
|
||||||
if self.preencoder is not None:
|
|
||||||
feats, feats_lengths = self.preencoder(feats, feats_lengths)
|
|
||||||
|
|
||||||
# 4. Forward encoder
|
# 4. Forward encoder
|
||||||
# feats: (Batch, Length, Dim)
|
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
|
||||||
# -> encoder_out: (Batch, Length2, Dim2)
|
|
||||||
if self.encoder.interctc_use_conditioning:
|
|
||||||
encoder_out, encoder_out_lens, _ = self.encoder(
|
|
||||||
feats, feats_lengths, ctc=self.ctc
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
|
|
||||||
intermediate_outs = None
|
|
||||||
if isinstance(encoder_out, tuple):
|
|
||||||
intermediate_outs = encoder_out[1]
|
|
||||||
encoder_out = encoder_out[0]
|
|
||||||
|
|
||||||
# Post-encoder, e.g. NLU
|
|
||||||
if self.postencoder is not None:
|
|
||||||
encoder_out, encoder_out_lens = self.postencoder(
|
|
||||||
encoder_out, encoder_out_lens
|
|
||||||
)
|
|
||||||
|
|
||||||
assert encoder_out.size(0) == speech.size(0), (
|
assert encoder_out.size(0) == speech.size(0), (
|
||||||
encoder_out.size(),
|
encoder_out.size(),
|
||||||
@ -314,9 +258,6 @@ class Paraformer(FunASRModel):
|
|||||||
encoder_out_lens.max(),
|
encoder_out_lens.max(),
|
||||||
)
|
)
|
||||||
|
|
||||||
if intermediate_outs is not None:
|
|
||||||
return (encoder_out, intermediate_outs), encoder_out_lens
|
|
||||||
|
|
||||||
return encoder_out, encoder_out_lens
|
return encoder_out, encoder_out_lens
|
||||||
|
|
||||||
def encode_chunk(
|
def encode_chunk(
|
||||||
@ -340,32 +281,8 @@ class Paraformer(FunASRModel):
|
|||||||
if self.normalize is not None:
|
if self.normalize is not None:
|
||||||
feats, feats_lengths = self.normalize(feats, feats_lengths)
|
feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||||
|
|
||||||
# Pre-encoder, e.g. used for raw input data
|
|
||||||
if self.preencoder is not None:
|
|
||||||
feats, feats_lengths = self.preencoder(feats, feats_lengths)
|
|
||||||
|
|
||||||
# 4. Forward encoder
|
# 4. Forward encoder
|
||||||
# feats: (Batch, Length, Dim)
|
encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
|
||||||
# -> encoder_out: (Batch, Length2, Dim2)
|
|
||||||
if self.encoder.interctc_use_conditioning:
|
|
||||||
encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
|
|
||||||
feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
|
|
||||||
intermediate_outs = None
|
|
||||||
if isinstance(encoder_out, tuple):
|
|
||||||
intermediate_outs = encoder_out[1]
|
|
||||||
encoder_out = encoder_out[0]
|
|
||||||
|
|
||||||
# Post-encoder, e.g. NLU
|
|
||||||
if self.postencoder is not None:
|
|
||||||
encoder_out, encoder_out_lens = self.postencoder(
|
|
||||||
encoder_out, encoder_out_lens
|
|
||||||
)
|
|
||||||
|
|
||||||
if intermediate_outs is not None:
|
|
||||||
return (encoder_out, intermediate_outs), encoder_out_lens
|
|
||||||
|
|
||||||
return encoder_out, torch.tensor([encoder_out.size(1)])
|
return encoder_out, torch.tensor([encoder_out.size(1)])
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user