mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
459 lines
16 KiB
Python
459 lines
16 KiB
Python
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
|
|
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
|
|
|
import logging
|
|
from contextlib import contextmanager
|
|
from distutils.version import LooseVersion
|
|
from typing import Dict
|
|
from typing import List
|
|
from typing import Optional
|
|
from typing import Tuple
|
|
from typing import Union
|
|
|
|
import torch
|
|
from typeguard import check_argument_types
|
|
|
|
from funasr.layers.abs_normalize import AbsNormalize
|
|
from funasr.losses.label_smoothing_loss import (
|
|
LabelSmoothingLoss, # noqa: H301
|
|
)
|
|
from funasr.models.ctc import CTC
|
|
from funasr.models.decoder.abs_decoder import AbsDecoder
|
|
from funasr.models.encoder.abs_encoder import AbsEncoder
|
|
from funasr.models.frontend.abs_frontend import AbsFrontend
|
|
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
|
|
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
|
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
|
from funasr.modules.add_sos_eos import add_sos_eos
|
|
from funasr.modules.e2e_asr_common import ErrorCalculator
|
|
from funasr.modules.nets_utils import th_accuracy
|
|
from funasr.torch_utils.device_funcs import force_gatherable
|
|
from funasr.train.abs_espnet_model import AbsESPnetModel
|
|
|
|
if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
    # Nothing to do if torch<1.6.0: expose a context manager with the same
    # name that simply runs the wrapped block unchanged.
    @contextmanager
    def autocast(enabled=True):
        """No-op stand-in used in place of ``torch.cuda.amp.autocast``."""
        yield

else:
    from torch.cuda.amp import autocast
|
|
|
|
|
|
class ESPnetASRModel(AbsESPnetModel):
|
|
"""CTC-attention hybrid Encoder-Decoder model"""
|
|
|
|
def __init__(
    self,
    vocab_size: int,
    token_list: Union[Tuple[str, ...], List[str]],
    frontend: Optional[AbsFrontend],
    specaug: Optional[AbsSpecAug],
    normalize: Optional[AbsNormalize],
    preencoder: Optional[AbsPreEncoder],
    encoder: AbsEncoder,
    postencoder: Optional[AbsPostEncoder],
    decoder: AbsDecoder,
    ctc: CTC,
    ctc_weight: float = 0.5,
    interctc_weight: float = 0.0,
    ignore_id: int = -1,
    lsm_weight: float = 0.0,
    length_normalized_loss: bool = False,
    report_cer: bool = True,
    report_wer: bool = True,
    sym_space: str = "<space>",
    sym_blank: str = "<blank>",
    extract_feats_in_collect_stats: bool = True,
):
    """Assemble the model from already constructed sub-modules.

    Args:
        vocab_size: Passed as ``size`` to the attention loss and used as the
            input width of the optional encoder conditioning layer.
        token_list: Token symbols (list or tuple). A list copy is stored on
            the model; the sequence is also given to the error calculator.
        frontend: Optional stage applied first in ``_extract_feats``.
        specaug: Optional augmentation, applied only while training.
        normalize: Optional feature normalization.
        preencoder: Optional stage applied right before the encoder.
        encoder: Encoder module. If it has no ``interctc_use_conditioning``
            attribute, the attribute is created on it and set to ``False``.
        postencoder: Optional stage applied to the encoder output.
        decoder: Attention decoder; dropped when ``ctc_weight == 1.0``.
        ctc: CTC module; dropped when ``ctc_weight == 0.0``.
        ctc_weight: Weight of the CTC loss, must be within [0, 1].
        interctc_weight: Weight of the intermediate-CTC loss, within [0, 1).
        ignore_id: Label id forwarded as ``padding_idx`` to the attention
            loss and used as the ignore label elsewhere in the model.
        lsm_weight: Forwarded as ``smoothing`` to the attention loss.
        length_normalized_loss: Forwarded as ``normalize_length``.
        report_cer: Together with ``report_wer`` decides whether an
            ``ErrorCalculator`` is created (either flag being true suffices).
        report_wer: See ``report_cer``.
        sym_space: Forwarded to the ``ErrorCalculator``.
        sym_blank: Forwarded to the ``ErrorCalculator``.
        extract_feats_in_collect_stats: If false, ``collect_feats`` returns
            its inputs untouched instead of extracting features.
    """
    assert check_argument_types()
    assert 0.0 <= ctc_weight <= 1.0, ctc_weight
    assert 0.0 <= interctc_weight < 1.0, interctc_weight

    super().__init__()
    # Special-token ids are hard-coded; note that sos and eos are
    # distinct ids in this model.
    self.blank_id = 0
    self.sos = 1
    self.eos = 2
    self.vocab_size = vocab_size
    self.ignore_id = ignore_id
    self.ctc_weight = ctc_weight
    self.interctc_weight = interctc_weight
    # list(...) rather than .copy(): the annotation also admits tuples,
    # and tuples have no copy() method.
    self.token_list = list(token_list)

    self.frontend = frontend
    self.specaug = specaug
    self.normalize = normalize
    self.preencoder = preencoder
    self.postencoder = postencoder
    self.encoder = encoder

    # encode() reads this flag unconditionally, so make sure it exists.
    if not hasattr(self.encoder, "interctc_use_conditioning"):
        self.encoder.interctc_use_conditioning = False
    if self.encoder.interctc_use_conditioning:
        self.encoder.conditioning_layer = torch.nn.Linear(
            vocab_size, self.encoder.output_size()
        )

    self.error_calculator = None

    # we set self.decoder = None in the CTC mode since
    # self.decoder parameters were never used and PyTorch complained
    # and threw an Exception in the multi-GPU experiment.
    # thanks Jeff Farris for pointing out the issue.
    if ctc_weight == 1.0:
        self.decoder = None
    else:
        self.decoder = decoder

    self.criterion_att = LabelSmoothingLoss(
        size=vocab_size,
        padding_idx=ignore_id,
        smoothing=lsm_weight,
        normalize_length=length_normalized_loss,
    )

    if report_cer or report_wer:
        self.error_calculator = ErrorCalculator(
            token_list, sym_space, sym_blank, report_cer, report_wer
        )

    # Mirror of the decoder handling above: no CTC module is kept when the
    # CTC loss has zero weight.
    if ctc_weight == 0.0:
        self.ctc = None
    else:
        self.ctc = ctc

    self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
|
|
|
|
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        text: (Batch, Length)
        text_lengths: (Batch,)

    Returns:
        ``(loss, stats, weight)`` as produced by ``force_gatherable`` from
        the combined loss, the dict of per-branch statistics and the batch
        size.
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # Check that batch_size is unified
    assert (
        speech.shape[0]
        == speech_lengths.shape[0]
        == text.shape[0]
        == text_lengths.shape[0]
    ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
    batch_size = speech.shape[0]

    # for data-parallel
    text = text[:, : text_lengths.max()]

    # 1. Encoder
    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
    # encode() returns (encoder_out, intermediate_outs) in place of
    # encoder_out when the encoder produced intermediate outputs.
    intermediate_outs = None
    if isinstance(encoder_out, tuple):
        intermediate_outs = encoder_out[1]
        encoder_out = encoder_out[0]

    loss_att, acc_att, cer_att, wer_att = None, None, None, None
    loss_ctc, cer_ctc = None, None
    stats = dict()

    # 2a. CTC branch
    # NOTE(review): the CTC stats are taken to be collected inside this
    # branch — confirm against the original layout.
    if self.ctc_weight != 0.0:
        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        # Collect CTC branch stats
        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
        stats["cer_ctc"] = cer_ctc

    # Intermediate CTC (optional)
    # Skipped when ctc_weight == 0.0: in that mode self.ctc is None and
    # loss_ctc stays None, so there is nothing to mix the term into, and
    # the final loss is the attention loss alone anyway. An empty
    # intermediate_outs is skipped as well to avoid dividing by zero.
    use_interctc = (
        self.ctc_weight != 0.0
        and self.interctc_weight != 0.0
        and intermediate_outs is not None
        and len(intermediate_outs) > 0
    )
    if use_interctc:
        loss_interctc = 0.0
        for layer_idx, intermediate_out in intermediate_outs:
            # we assume intermediate_out has the same length & padding
            # as those of encoder_out
            loss_ic, cer_ic = self._calc_ctc_loss(
                intermediate_out, encoder_out_lens, text, text_lengths
            )
            loss_interctc = loss_interctc + loss_ic

            # Collect Intermediate CTC stats
            stats["loss_interctc_layer{}".format(layer_idx)] = (
                loss_ic.detach() if loss_ic is not None else None
            )
            stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

        # Average over the intermediate layers.
        loss_interctc = loss_interctc / len(intermediate_outs)

        # calculate whole encoder loss
        loss_ctc = (
            1 - self.interctc_weight
        ) * loss_ctc + self.interctc_weight * loss_interctc

    # 2b. Attention decoder branch
    if self.ctc_weight != 1.0:
        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

    # 3. CTC-Att loss definition
    if self.ctc_weight == 0.0:
        loss = loss_att
    elif self.ctc_weight == 1.0:
        loss = loss_ctc
    else:
        loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

    # Collect Attn branch stats
    stats["loss_att"] = loss_att.detach() if loss_att is not None else None
    stats["acc"] = acc_att
    stats["cer"] = cer_att
    stats["wer"] = wer_att

    # Collect total loss stats
    stats["loss"] = torch.clone(loss.detach())

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
|
|
|
|
def collect_feats(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
    """Return features and their lengths for statistics collection.

    When ``extract_feats_in_collect_stats`` is false the inputs are passed
    through unchanged (after a warning); otherwise they go through
    ``_extract_feats``. ``text`` and ``text_lengths`` are accepted but not
    used.

    Returns:
        Dict with the keys ``"feats"`` and ``"feats_lengths"``.
    """
    if not self.extract_feats_in_collect_stats:
        # Generate dummy stats if extract_feats_in_collect_stats is False
        logging.warning(
            "Generating dummy stats for feats and feats_lengths, "
            "because encoder_conf.extract_feats_in_collect_stats is "
            f"{self.extract_feats_in_collect_stats}"
        )
        return {"feats": speech, "feats_lengths": speech_lengths}

    extracted, extracted_lengths = self._extract_feats(speech, speech_lengths)
    return {"feats": extracted, "feats_lengths": extracted_lengths}
|
|
|
|
def encode(
    self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Frontend + Encoder. Note that this method is used by asr_inference.py

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )

    Returns:
        ``(encoder_out, encoder_out_lens)``. When the encoder hands back a
        tuple as its first output, the first element of the result is
        ``(encoder_out, intermediate_outs)`` instead of a single tensor.
    """
    # NOTE(review): feature extraction, SpecAug and normalization are taken
    # to be the statements run under autocast(False) — confirm.
    with autocast(False):
        # 1. Extract feats
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)

        # 2. Data augmentation (training only)
        if self.specaug is not None and self.training:
            feats, feats_lengths = self.specaug(feats, feats_lengths)

        # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
        if self.normalize is not None:
            feats, feats_lengths = self.normalize(feats, feats_lengths)

    # Pre-encoder, e.g. used for raw input data
    if self.preencoder is not None:
        feats, feats_lengths = self.preencoder(feats, feats_lengths)

    # 4. Forward encoder
    # feats: (Batch, Length, Dim)
    # -> encoder_out: (Batch, Length2, Dim2)
    # The CTC module is handed to the encoder only when it is flagged as
    # using inter-CTC conditioning.
    encoder_kwargs = {}
    if self.encoder.interctc_use_conditioning:
        encoder_kwargs["ctc"] = self.ctc
    encoder_out, encoder_out_lens, _ = self.encoder(
        feats, feats_lengths, **encoder_kwargs
    )

    # Separate the final output from the intermediate ones, if any.
    if isinstance(encoder_out, tuple):
        encoder_out, intermediate_outs = encoder_out[0], encoder_out[1]
    else:
        intermediate_outs = None

    # Post-encoder, e.g. NLU
    if self.postencoder is not None:
        encoder_out, encoder_out_lens = self.postencoder(
            encoder_out, encoder_out_lens
        )

    assert encoder_out.size(0) == speech.size(0), (
        encoder_out.size(),
        speech.size(0),
    )
    assert encoder_out.size(1) <= encoder_out_lens.max(), (
        encoder_out.size(),
        encoder_out_lens.max(),
    )

    if intermediate_outs is None:
        return encoder_out, encoder_out_lens
    return (encoder_out, intermediate_outs), encoder_out_lens
|
|
|
|
def _extract_feats(
|
|
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
assert speech_lengths.dim() == 1, speech_lengths.shape
|
|
|
|
# for data-parallel
|
|
speech = speech[:, : speech_lengths.max()]
|
|
|
|
if self.frontend is not None:
|
|
# Frontend
|
|
# e.g. STFT and Feature extract
|
|
# data_loader may send time-domain signal in this case
|
|
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
|
|
feats, feats_lengths = self.frontend(speech, speech_lengths)
|
|
else:
|
|
# No frontend and no feature extract
|
|
feats, feats_lengths = speech, speech_lengths
|
|
return feats, feats_lengths
|
|
|
|
def nll(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
    """Compute negative log likelihood(nll) from transformer-decoder

    Normally, this function is called in batchify_nll.

    Args:
        encoder_out: (Batch, Length, Dim)
        encoder_out_lens: (Batch,)
        ys_pad: (Batch, Length)
        ys_pad_lens: (Batch,)

    Returns:
        Tensor with one summed nll value per batch entry.
    """
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)

    # 1. Forward decoder
    # Decoder input lengths are the label lengths plus one (presumably for
    # the prepended sos — confirm against add_sos_eos).
    decoder_out, _ = self.decoder(
        encoder_out, encoder_out_lens, ys_in_pad, ys_pad_lens + 1
    )  # [batch, seqlen, dim]
    num_utts, num_classes = decoder_out.size(0), decoder_out.size(2)

    # 2. Per-token negative log-likelihood; positions equal to ignore_id
    # are excluded by cross_entropy.
    token_nll = torch.nn.functional.cross_entropy(
        decoder_out.view(-1, num_classes),
        ys_out_pad.view(-1),
        ignore_index=self.ignore_id,
        reduction="none",
    )

    # 3. Sum over the sequence axis to get one value per batch entry.
    utt_nll = token_nll.view(num_utts, -1).sum(dim=1)
    assert utt_nll.size(0) == num_utts
    return utt_nll
|
|
|
|
def batchify_nll(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
    batch_size: int = 100,
):
    """Compute negative log likelihood(nll) from transformer-decoder

    To avoid OOM, this function separates the input into batches,
    calls nll for each batch, then combines and returns the results.

    Args:
        encoder_out: (Batch, Length, Dim)
        encoder_out_lens: (Batch,)
        ys_pad: (Batch, Length)
        ys_pad_lens: (Batch,)
        batch_size: int, samples each batch contain when computing nll,
                    you may change this to avoid OOM or increase
                    GPU memory usage

    Returns:
        The nll tensor, with as many entries along dim 0 as ``encoder_out``.

    Raises:
        ValueError: If the input has to be split but ``batch_size`` is
            smaller than 1.
    """
    total_num = encoder_out.size(0)
    if total_num <= batch_size:
        # Everything fits into one call.
        nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
    else:
        # A non-positive chunk size can never cover the input; fail loudly
        # instead of iterating forever.
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )
        # Slicing clamps at the end, so the last chunk may be shorter.
        nll = torch.cat(
            [
                self.nll(
                    encoder_out[start : start + batch_size, :, :],
                    encoder_out_lens[start : start + batch_size],
                    ys_pad[start : start + batch_size, :],
                    ys_pad_lens[start : start + batch_size],
                )
                for start in range(0, total_num, batch_size)
            ]
        )
    assert nll.size(0) == total_num
    return nll
|
|
|
|
def _calc_att_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
):
    """Run the attention decoder and compute its loss and metrics.

    Returns:
        ``(loss_att, acc_att, cer_att, wer_att)``. The two error rates are
        ``None`` while training or when no error calculator is configured.
    """
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)

    # 1. Forward decoder (input lengths are the label lengths plus one)
    decoder_out, _ = self.decoder(
        encoder_out, encoder_out_lens, ys_in_pad, ys_pad_lens + 1
    )

    # 2. Compute attention loss and token accuracy
    loss_att = self.criterion_att(decoder_out, ys_out_pad)
    flat_scores = decoder_out.view(-1, self.vocab_size)
    acc_att = th_accuracy(flat_scores, ys_out_pad, ignore_label=self.ignore_id)

    # 3. Compute cer/wer using attention-decoder (evaluation only)
    cer_att, wer_att = None, None
    if not self.training and self.error_calculator is not None:
        best_tokens = decoder_out.argmax(dim=-1)
        cer_att, wer_att = self.error_calculator(best_tokens.cpu(), ys_pad.cpu())

    return loss_att, acc_att, cer_att, wer_att
|
|
|
|
def _calc_ctc_loss(
|
|
self,
|
|
encoder_out: torch.Tensor,
|
|
encoder_out_lens: torch.Tensor,
|
|
ys_pad: torch.Tensor,
|
|
ys_pad_lens: torch.Tensor,
|
|
):
|
|
# Calc CTC loss
|
|
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
|
|
|
|
# Calc CER using CTC
|
|
cer_ctc = None
|
|
if not self.training and self.error_calculator is not None:
|
|
ys_hat = self.ctc.argmax(encoder_out).data
|
|
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
|
|
return loss_ctc, cer_ctc
|