mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update error calculator for rnnt
This commit is contained in:
parent
490531ed51
commit
bdb8a99da4
@ -386,7 +386,7 @@ class TransducerModel(AbsESPnetModel):
|
|||||||
|
|
||||||
if not self.training and (self.report_cer or self.report_wer):
|
if not self.training and (self.report_cer or self.report_wer):
|
||||||
if self.error_calculator is None:
|
if self.error_calculator is None:
|
||||||
from espnet2.asr_transducer.error_calculator import ErrorCalculator
|
from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
|
||||||
|
|
||||||
self.error_calculator = ErrorCalculator(
|
self.error_calculator = ErrorCalculator(
|
||||||
self.decoder,
|
self.decoder,
|
||||||
@ -398,7 +398,7 @@ class TransducerModel(AbsESPnetModel):
|
|||||||
report_wer=self.report_wer,
|
report_wer=self.report_wer,
|
||||||
)
|
)
|
||||||
|
|
||||||
cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
|
cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
|
||||||
|
|
||||||
return loss_transducer, cer_transducer, wer_transducer
|
return loss_transducer, cer_transducer, wer_transducer
|
||||||
|
|
||||||
@ -889,6 +889,8 @@ class UnifiedTransducerModel(AbsESPnetModel):
|
|||||||
|
|
||||||
if not self.training and (self.report_cer or self.report_wer):
|
if not self.training and (self.report_cer or self.report_wer):
|
||||||
if self.error_calculator is None:
|
if self.error_calculator is None:
|
||||||
|
from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
|
||||||
|
|
||||||
self.error_calculator = ErrorCalculator(
|
self.error_calculator = ErrorCalculator(
|
||||||
self.decoder,
|
self.decoder,
|
||||||
self.joint_network,
|
self.joint_network,
|
||||||
@ -899,7 +901,7 @@ class UnifiedTransducerModel(AbsESPnetModel):
|
|||||||
report_wer=self.report_wer,
|
report_wer=self.report_wer,
|
||||||
)
|
)
|
||||||
|
|
||||||
cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
|
cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
|
||||||
return loss_transducer, cer_transducer, wer_transducer
|
return loss_transducer, cer_transducer, wer_transducer
|
||||||
|
|
||||||
return loss_transducer, None, None
|
return loss_transducer, None, None
|
||||||
|
|||||||
@ -296,12 +296,13 @@ class ErrorCalculatorTransducer:
|
|||||||
self.report_wer = report_wer
|
self.report_wer = report_wer
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, encoder_out: torch.Tensor, target: torch.Tensor
|
self, encoder_out: torch.Tensor, target: torch.Tensor, encoder_out_lens: torch.Tensor,
|
||||||
) -> Tuple[Optional[float], Optional[float]]:
|
) -> Tuple[Optional[float], Optional[float]]:
|
||||||
"""Calculate sentence-level WER or/and CER score for Transducer model.
|
"""Calculate sentence-level WER or/and CER score for Transducer model.
|
||||||
Args:
|
Args:
|
||||||
encoder_out: Encoder output sequences. (B, T, D_enc)
|
encoder_out: Encoder output sequences. (B, T, D_enc)
|
||||||
target: Target label ID sequences. (B, L)
|
target: Target label ID sequences. (B, L)
|
||||||
|
encoder_out_lens: Encoder output sequences length. (B,)
|
||||||
Returns:
|
Returns:
|
||||||
: Sentence-level CER score.
|
: Sentence-level CER score.
|
||||||
: Sentence-level WER score.
|
: Sentence-level WER score.
|
||||||
@ -312,7 +313,10 @@ class ErrorCalculatorTransducer:
|
|||||||
|
|
||||||
encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
|
encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
|
||||||
|
|
||||||
batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
|
batch_nbest = [
|
||||||
|
self.beam_search(encoder_out[b][: encoder_out_lens[b]])
|
||||||
|
for b in range(batchsize)
|
||||||
|
]
|
||||||
pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
|
pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
|
||||||
|
|
||||||
char_pred, char_target = self.convert_to_char(pred, target)
|
char_pred, char_target = self.convert_to_char(pred, target)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user