mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
boundary aware transducer (#691)
* boundary aware transducer * resolve conflict * delete type check --------- Co-authored-by: aky15 <ankeyu.aky@11.17.44.249>
This commit is contained in:
parent
cf36ce977c
commit
05ada32da8
@ -1604,6 +1604,8 @@ def inference_launch(**kwargs):
|
||||
return inference_mfcca(**kwargs)
|
||||
elif mode == "rnnt":
|
||||
return inference_transducer(**kwargs)
|
||||
elif mode == "bat":
|
||||
return inference_transducer(**kwargs)
|
||||
elif mode == "sa_asr":
|
||||
return inference_sa_asr(**kwargs)
|
||||
else:
|
||||
|
||||
@ -26,6 +26,7 @@ from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
|
||||
from funasr.models.e2e_asr_mfcca import MFCCA
|
||||
|
||||
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
|
||||
from funasr.models.e2e_asr_bat import BATModel
|
||||
|
||||
from funasr.models.e2e_sa_asr import SAASRModel
|
||||
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
|
||||
@ -46,7 +47,7 @@ from funasr.models.frontend.s3prl import S3prlFrontend
|
||||
from funasr.models.frontend.wav_frontend import WavFrontend
|
||||
from funasr.models.frontend.windowing import SlidingWindow
|
||||
from funasr.models.joint_net.joint_network import JointNetwork
|
||||
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
|
||||
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor
|
||||
from funasr.models.specaug.specaug import SpecAug
|
||||
from funasr.models.specaug.specaug import SpecAugLFR
|
||||
from funasr.modules.subsampling import Conv1dSubsampling
|
||||
@ -99,7 +100,7 @@ model_choices = ClassChoices(
|
||||
rnnt=TransducerModel,
|
||||
rnnt_unified=UnifiedTransducerModel,
|
||||
sa_asr=SAASRModel,
|
||||
|
||||
bat=BATModel,
|
||||
),
|
||||
default="asr",
|
||||
)
|
||||
@ -188,6 +189,7 @@ predictor_choices = ClassChoices(
|
||||
ctc_predictor=None,
|
||||
cif_predictor_v2=CifPredictorV2,
|
||||
cif_predictor_v3=CifPredictorV3,
|
||||
bat_predictor=BATPredictor,
|
||||
),
|
||||
default="cif_predictor",
|
||||
optional=True,
|
||||
@ -313,12 +315,15 @@ def build_asr_model(args):
|
||||
encoder = encoder_class(input_size=input_size, **args.encoder_conf)
|
||||
|
||||
# decoder
|
||||
decoder_class = decoder_choices.get_class(args.decoder)
|
||||
decoder = decoder_class(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder.output_size(),
|
||||
**args.decoder_conf,
|
||||
)
|
||||
if hasattr(args, "decoder") and args.decoder is not None:
|
||||
decoder_class = decoder_choices.get_class(args.decoder)
|
||||
decoder = decoder_class(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder.output_size(),
|
||||
**args.decoder_conf,
|
||||
)
|
||||
else:
|
||||
decoder = None
|
||||
|
||||
# ctc
|
||||
ctc = CTC(
|
||||
@ -463,6 +468,53 @@ def build_asr_model(args):
|
||||
joint_network=joint_network,
|
||||
**args.model_conf,
|
||||
)
|
||||
elif args.model == "bat":
|
||||
# 5. Decoder
|
||||
encoder_output_size = encoder.output_size()
|
||||
|
||||
rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
|
||||
decoder = rnnt_decoder_class(
|
||||
vocab_size,
|
||||
**args.rnnt_decoder_conf,
|
||||
)
|
||||
decoder_output_size = decoder.output_size
|
||||
|
||||
if getattr(args, "decoder", None) is not None:
|
||||
att_decoder_class = decoder_choices.get_class(args.decoder)
|
||||
|
||||
att_decoder = att_decoder_class(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder_output_size,
|
||||
**args.decoder_conf,
|
||||
)
|
||||
else:
|
||||
att_decoder = None
|
||||
# 6. Joint Network
|
||||
joint_network = JointNetwork(
|
||||
vocab_size,
|
||||
encoder_output_size,
|
||||
decoder_output_size,
|
||||
**args.joint_network_conf,
|
||||
)
|
||||
|
||||
predictor_class = predictor_choices.get_class(args.predictor)
|
||||
predictor = predictor_class(**args.predictor_conf)
|
||||
|
||||
model_class = model_choices.get_class(args.model)
|
||||
# 7. Build model
|
||||
model = model_class(
|
||||
vocab_size=vocab_size,
|
||||
token_list=token_list,
|
||||
frontend=frontend,
|
||||
specaug=specaug,
|
||||
normalize=normalize,
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
att_decoder=att_decoder,
|
||||
joint_network=joint_network,
|
||||
predictor=predictor,
|
||||
**args.model_conf,
|
||||
)
|
||||
elif args.model == "sa_asr":
|
||||
asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
|
||||
asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
|
||||
|
||||
496
funasr/models/e2e_asr_bat.py
Normal file
496
funasr/models/e2e_asr_bat.py
Normal file
@ -0,0 +1,496 @@
|
||||
"""Boundary Aware Transducer (BAT) model."""
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from packaging.version import parse as V
|
||||
from funasr.losses.label_smoothing_loss import (
|
||||
LabelSmoothingLoss, # noqa: H301
|
||||
)
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.joint_net.joint_network import JointNetwork
|
||||
from funasr.modules.nets_utils import get_transducer_task_io
|
||||
from funasr.modules.nets_utils import th_accuracy
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
from funasr.modules.add_sos_eos import add_sos_eos
|
||||
from funasr.layers.abs_normalize import AbsNormalize
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.models.base_model import FunASRModel
|
||||
|
||||
if V(torch.__version__) >= V("1.6.0"):
|
||||
from torch.cuda.amp import autocast
|
||||
else:
|
||||
|
||||
@contextmanager
|
||||
def autocast(enabled=True):
|
||||
yield
|
||||
|
||||
|
||||
class BATModel(FunASRModel):
|
||||
"""BATModel module definition.
|
||||
|
||||
Args:
|
||||
vocab_size: Size of complete vocabulary (w/ EOS and blank included).
|
||||
token_list: List of token
|
||||
frontend: Frontend module.
|
||||
specaug: SpecAugment module.
|
||||
normalize: Normalization module.
|
||||
encoder: Encoder module.
|
||||
decoder: Decoder module.
|
||||
joint_network: Joint Network module.
|
||||
transducer_weight: Weight of the Transducer loss.
|
||||
fastemit_lambda: FastEmit lambda value.
|
||||
auxiliary_ctc_weight: Weight of auxiliary CTC loss.
|
||||
auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
|
||||
auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
|
||||
auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
|
||||
ignore_id: Initial padding ID.
|
||||
sym_space: Space symbol.
|
||||
sym_blank: Blank Symbol
|
||||
report_cer: Whether to report Character Error Rate during validation.
|
||||
report_wer: Whether to report Word Error Rate during validation.
|
||||
extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
token_list: Union[Tuple[str, ...], List[str]],
|
||||
frontend: Optional[AbsFrontend],
|
||||
specaug: Optional[AbsSpecAug],
|
||||
normalize: Optional[AbsNormalize],
|
||||
encoder: AbsEncoder,
|
||||
decoder: RNNTDecoder,
|
||||
joint_network: JointNetwork,
|
||||
att_decoder: Optional[AbsAttDecoder] = None,
|
||||
predictor = None,
|
||||
transducer_weight: float = 1.0,
|
||||
predictor_weight: float = 1.0,
|
||||
cif_weight: float = 1.0,
|
||||
fastemit_lambda: float = 0.0,
|
||||
auxiliary_ctc_weight: float = 0.0,
|
||||
auxiliary_ctc_dropout_rate: float = 0.0,
|
||||
auxiliary_lm_loss_weight: float = 0.0,
|
||||
auxiliary_lm_loss_smoothing: float = 0.0,
|
||||
ignore_id: int = -1,
|
||||
sym_space: str = "<space>",
|
||||
sym_blank: str = "<blank>",
|
||||
report_cer: bool = True,
|
||||
report_wer: bool = True,
|
||||
extract_feats_in_collect_stats: bool = True,
|
||||
lsm_weight: float = 0.0,
|
||||
length_normalized_loss: bool = False,
|
||||
r_d: int = 5,
|
||||
r_u: int = 5,
|
||||
) -> None:
|
||||
"""Construct an BATModel object."""
|
||||
super().__init__()
|
||||
|
||||
# The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
|
||||
self.blank_id = 0
|
||||
self.vocab_size = vocab_size
|
||||
self.ignore_id = ignore_id
|
||||
self.token_list = token_list.copy()
|
||||
|
||||
self.sym_space = sym_space
|
||||
self.sym_blank = sym_blank
|
||||
|
||||
self.frontend = frontend
|
||||
self.specaug = specaug
|
||||
self.normalize = normalize
|
||||
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.joint_network = joint_network
|
||||
|
||||
self.criterion_transducer = None
|
||||
self.error_calculator = None
|
||||
|
||||
self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
|
||||
self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
|
||||
|
||||
if self.use_auxiliary_ctc:
|
||||
self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size)
|
||||
self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
|
||||
|
||||
if self.use_auxiliary_lm_loss:
|
||||
self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
|
||||
self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
|
||||
|
||||
self.transducer_weight = transducer_weight
|
||||
self.fastemit_lambda = fastemit_lambda
|
||||
|
||||
self.auxiliary_ctc_weight = auxiliary_ctc_weight
|
||||
self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
|
||||
|
||||
self.report_cer = report_cer
|
||||
self.report_wer = report_wer
|
||||
|
||||
self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
|
||||
|
||||
self.criterion_pre = torch.nn.L1Loss()
|
||||
self.predictor_weight = predictor_weight
|
||||
self.predictor = predictor
|
||||
|
||||
self.cif_weight = cif_weight
|
||||
if self.cif_weight > 0:
|
||||
self.cif_output_layer = torch.nn.Linear(encoder.output_size(), vocab_size)
|
||||
self.criterion_cif = LabelSmoothingLoss(
|
||||
size=vocab_size,
|
||||
padding_idx=ignore_id,
|
||||
smoothing=lsm_weight,
|
||||
normalize_length=length_normalized_loss,
|
||||
)
|
||||
self.r_d = r_d
|
||||
self.r_u = r_u
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Forward architecture and compute loss(es).
|
||||
|
||||
Args:
|
||||
speech: Speech sequences. (B, S)
|
||||
speech_lengths: Speech sequences lengths. (B,)
|
||||
text: Label ID sequences. (B, L)
|
||||
text_lengths: Label ID sequences lengths. (B,)
|
||||
kwargs: Contains "utts_id".
|
||||
|
||||
Return:
|
||||
loss: Main loss value.
|
||||
stats: Task statistics.
|
||||
weight: Task weights.
|
||||
|
||||
"""
|
||||
assert text_lengths.dim() == 1, text_lengths.shape
|
||||
assert (
|
||||
speech.shape[0]
|
||||
== speech_lengths.shape[0]
|
||||
== text.shape[0]
|
||||
== text_lengths.shape[0]
|
||||
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
|
||||
|
||||
batch_size = speech.shape[0]
|
||||
text = text[:, : text_lengths.max()]
|
||||
|
||||
# 1. Encoder
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
if hasattr(self.encoder, 'overlap_chunk_cls') and self.encoder.overlap_chunk_cls is not None:
|
||||
encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
|
||||
chunk_outs=None)
|
||||
|
||||
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(encoder_out.device)
|
||||
# 2. Transducer-related I/O preparation
|
||||
decoder_in, target, t_len, u_len = get_transducer_task_io(
|
||||
text,
|
||||
encoder_out_lens,
|
||||
ignore_id=self.ignore_id,
|
||||
)
|
||||
|
||||
# 3. Decoder
|
||||
self.decoder.set_device(encoder_out.device)
|
||||
decoder_out = self.decoder(decoder_in, u_len)
|
||||
|
||||
pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=self.ignore_id)
|
||||
loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length), pre_token_length)
|
||||
|
||||
if self.cif_weight > 0.0:
|
||||
cif_predict = self.cif_output_layer(pre_acoustic_embeds)
|
||||
loss_cif = self.criterion_cif(cif_predict, text)
|
||||
else:
|
||||
loss_cif = 0.0
|
||||
|
||||
# 5. Losses
|
||||
boundary = torch.zeros((encoder_out.size(0), 4), dtype=torch.int64, device=encoder_out.device)
|
||||
boundary[:, 2] = u_len.long().detach()
|
||||
boundary[:, 3] = t_len.long().detach()
|
||||
|
||||
pre_peak_index = torch.floor(pre_peak_index).long()
|
||||
s_begin = pre_peak_index - self.r_d
|
||||
|
||||
T = encoder_out.size(1)
|
||||
B = encoder_out.size(0)
|
||||
U = decoder_out.size(1)
|
||||
|
||||
mask = torch.arange(0, T, device=encoder_out.device).reshape(1, T).expand(B, T)
|
||||
mask = mask <= boundary[:, 3].reshape(B, 1) - 1
|
||||
|
||||
s_begin_padding = boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1
|
||||
# handle the cases where `len(symbols) < s_range`
|
||||
s_begin_padding = torch.clamp(s_begin_padding, min=0)
|
||||
|
||||
s_begin = torch.where(mask, s_begin, s_begin_padding)
|
||||
|
||||
mask2 = s_begin < boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1
|
||||
|
||||
s_begin = torch.where(mask2, s_begin, boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1)
|
||||
|
||||
s_begin = torch.clamp(s_begin, min=0)
|
||||
|
||||
ranges = s_begin.reshape((B, T, 1)).expand((B, T, min(self.r_u+self.r_d, min(u_len)))) + torch.arange(min(self.r_d+self.r_u, min(u_len)), device=encoder_out.device)
|
||||
|
||||
import fast_rnnt
|
||||
am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(
|
||||
am=self.joint_network.lin_enc(encoder_out),
|
||||
lm=self.joint_network.lin_dec(decoder_out),
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
logits = self.joint_network(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
loss_trans = fast_rnnt.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=target.long(),
|
||||
ranges=ranges,
|
||||
termination_symbol=self.blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
cer_trans, wer_trans = None, None
|
||||
if not self.training and (self.report_cer or self.report_wer):
|
||||
if self.error_calculator is None:
|
||||
from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
|
||||
self.error_calculator = ErrorCalculator(
|
||||
self.decoder,
|
||||
self.joint_network,
|
||||
self.token_list,
|
||||
self.sym_space,
|
||||
self.sym_blank,
|
||||
report_cer=self.report_cer,
|
||||
report_wer=self.report_wer,
|
||||
)
|
||||
cer_trans, wer_trans = self.error_calculator(encoder_out, target, t_len)
|
||||
|
||||
loss_ctc, loss_lm = 0.0, 0.0
|
||||
|
||||
if self.use_auxiliary_ctc:
|
||||
loss_ctc = self._calc_ctc_loss(
|
||||
encoder_out,
|
||||
target,
|
||||
t_len,
|
||||
u_len,
|
||||
)
|
||||
|
||||
if self.use_auxiliary_lm_loss:
|
||||
loss_lm = self._calc_lm_loss(decoder_out, target)
|
||||
|
||||
loss = (
|
||||
self.transducer_weight * loss_trans
|
||||
+ self.auxiliary_ctc_weight * loss_ctc
|
||||
+ self.auxiliary_lm_loss_weight * loss_lm
|
||||
+ self.predictor_weight * loss_pre
|
||||
+ self.cif_weight * loss_cif
|
||||
)
|
||||
|
||||
stats = dict(
|
||||
loss=loss.detach(),
|
||||
loss_transducer=loss_trans.detach(),
|
||||
loss_pre=loss_pre.detach(),
|
||||
loss_cif=loss_cif.detach() if loss_cif > 0.0 else None,
|
||||
aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
|
||||
aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
|
||||
cer_transducer=cer_trans,
|
||||
wer_transducer=wer_trans,
|
||||
)
|
||||
|
||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
||||
|
||||
return loss, stats, weight
|
||||
|
||||
def collect_feats(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Collect features sequences and features lengths sequences.
|
||||
|
||||
Args:
|
||||
speech: Speech sequences. (B, S)
|
||||
speech_lengths: Speech sequences lengths. (B,)
|
||||
text: Label ID sequences. (B, L)
|
||||
text_lengths: Label ID sequences lengths. (B,)
|
||||
kwargs: Contains "utts_id".
|
||||
|
||||
Return:
|
||||
{}: "feats": Features sequences. (B, T, D_feats),
|
||||
"feats_lengths": Features sequences lengths. (B,)
|
||||
|
||||
"""
|
||||
if self.extract_feats_in_collect_stats:
|
||||
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
else:
|
||||
# Generate dummy stats if extract_feats_in_collect_stats is False
|
||||
logging.warning(
|
||||
"Generating dummy stats for feats and feats_lengths, "
|
||||
"because encoder_conf.extract_feats_in_collect_stats is "
|
||||
f"{self.extract_feats_in_collect_stats}"
|
||||
)
|
||||
|
||||
feats, feats_lengths = speech, speech_lengths
|
||||
|
||||
return {"feats": feats, "feats_lengths": feats_lengths}
|
||||
|
||||
def encode(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Encoder speech sequences.
|
||||
|
||||
Args:
|
||||
speech: Speech sequences. (B, S)
|
||||
speech_lengths: Speech sequences lengths. (B,)
|
||||
|
||||
Return:
|
||||
encoder_out: Encoder outputs. (B, T, D_enc)
|
||||
encoder_out_lens: Encoder outputs lengths. (B,)
|
||||
|
||||
"""
|
||||
with autocast(False):
|
||||
# 1. Extract feats
|
||||
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
|
||||
# 2. Data augmentation
|
||||
if self.specaug is not None and self.training:
|
||||
feats, feats_lengths = self.specaug(feats, feats_lengths)
|
||||
|
||||
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
||||
if self.normalize is not None:
|
||||
feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||
|
||||
# 4. Forward encoder
|
||||
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
|
||||
|
||||
assert encoder_out.size(0) == speech.size(0), (
|
||||
encoder_out.size(),
|
||||
speech.size(0),
|
||||
)
|
||||
assert encoder_out.size(1) <= encoder_out_lens.max(), (
|
||||
encoder_out.size(),
|
||||
encoder_out_lens.max(),
|
||||
)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def _extract_feats(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Extract features sequences and features sequences lengths.
|
||||
|
||||
Args:
|
||||
speech: Speech sequences. (B, S)
|
||||
speech_lengths: Speech sequences lengths. (B,)
|
||||
|
||||
Return:
|
||||
feats: Features sequences. (B, T, D_feats)
|
||||
feats_lengths: Features sequences lengths. (B,)
|
||||
|
||||
"""
|
||||
assert speech_lengths.dim() == 1, speech_lengths.shape
|
||||
|
||||
# for data-parallel
|
||||
speech = speech[:, : speech_lengths.max()]
|
||||
|
||||
if self.frontend is not None:
|
||||
feats, feats_lengths = self.frontend(speech, speech_lengths)
|
||||
else:
|
||||
feats, feats_lengths = speech, speech_lengths
|
||||
|
||||
return feats, feats_lengths
|
||||
|
||||
def _calc_ctc_loss(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
t_len: torch.Tensor,
|
||||
u_len: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute CTC loss.
|
||||
|
||||
Args:
|
||||
encoder_out: Encoder output sequences. (B, T, D_enc)
|
||||
target: Target label ID sequences. (B, L)
|
||||
t_len: Encoder output sequences lengths. (B,)
|
||||
u_len: Target label ID sequences lengths. (B,)
|
||||
|
||||
Return:
|
||||
loss_ctc: CTC loss value.
|
||||
|
||||
"""
|
||||
ctc_in = self.ctc_lin(
|
||||
torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
|
||||
)
|
||||
ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
|
||||
|
||||
target_mask = target != 0
|
||||
ctc_target = target[target_mask].cpu()
|
||||
|
||||
with torch.backends.cudnn.flags(deterministic=True):
|
||||
loss_ctc = torch.nn.functional.ctc_loss(
|
||||
ctc_in,
|
||||
ctc_target,
|
||||
t_len,
|
||||
u_len,
|
||||
zero_infinity=True,
|
||||
reduction="sum",
|
||||
)
|
||||
loss_ctc /= target.size(0)
|
||||
|
||||
return loss_ctc
|
||||
|
||||
def _calc_lm_loss(
|
||||
self,
|
||||
decoder_out: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute LM loss.
|
||||
|
||||
Args:
|
||||
decoder_out: Decoder output sequences. (B, U, D_dec)
|
||||
target: Target label ID sequences. (B, L)
|
||||
|
||||
Return:
|
||||
loss_lm: LM loss value.
|
||||
|
||||
"""
|
||||
lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
|
||||
lm_target = target.view(-1).type(torch.int64)
|
||||
|
||||
with torch.no_grad():
|
||||
true_dist = lm_loss_in.clone()
|
||||
true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
|
||||
|
||||
# Ignore blank ID (0)
|
||||
ignore = lm_target == 0
|
||||
lm_target = lm_target.masked_fill(ignore, 0)
|
||||
|
||||
true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
|
||||
|
||||
loss_lm = torch.nn.functional.kl_div(
|
||||
torch.log_softmax(lm_loss_in, dim=1),
|
||||
true_dist,
|
||||
reduction="none",
|
||||
)
|
||||
loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
|
||||
0
|
||||
)
|
||||
|
||||
return loss_lm
|
||||
@ -353,11 +353,6 @@ class TransducerModel(FunASRModel):
|
||||
"""
|
||||
if self.criterion_transducer is None:
|
||||
try:
|
||||
# from warprnnt_pytorch import RNNTLoss
|
||||
# self.criterion_transducer = RNNTLoss(
|
||||
# reduction="mean",
|
||||
# fastemit_lambda=self.fastemit_lambda,
|
||||
# )
|
||||
from warp_rnnt import rnnt_loss as RNNTLoss
|
||||
self.criterion_transducer = RNNTLoss
|
||||
|
||||
@ -368,12 +363,6 @@ class TransducerModel(FunASRModel):
|
||||
)
|
||||
exit(1)
|
||||
|
||||
# loss_transducer = self.criterion_transducer(
|
||||
# joint_out,
|
||||
# target,
|
||||
# t_len,
|
||||
# u_len,
|
||||
# )
|
||||
log_probs = torch.log_softmax(joint_out, dim=-1)
|
||||
|
||||
loss_transducer = self.criterion_transducer(
|
||||
@ -637,7 +626,6 @@ class UnifiedTransducerModel(FunASRModel):
|
||||
|
||||
batch_size = speech.shape[0]
|
||||
text = text[:, : text_lengths.max()]
|
||||
#print(speech.shape)
|
||||
# 1. Encoder
|
||||
encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
|
||||
@ -854,11 +842,6 @@ class UnifiedTransducerModel(FunASRModel):
|
||||
"""
|
||||
if self.criterion_transducer is None:
|
||||
try:
|
||||
# from warprnnt_pytorch import RNNTLoss
|
||||
# self.criterion_transducer = RNNTLoss(
|
||||
# reduction="mean",
|
||||
# fastemit_lambda=self.fastemit_lambda,
|
||||
# )
|
||||
from warp_rnnt import rnnt_loss as RNNTLoss
|
||||
self.criterion_transducer = RNNTLoss
|
||||
|
||||
@ -869,12 +852,6 @@ class UnifiedTransducerModel(FunASRModel):
|
||||
)
|
||||
exit(1)
|
||||
|
||||
# loss_transducer = self.criterion_transducer(
|
||||
# joint_out,
|
||||
# target,
|
||||
# t_len,
|
||||
# u_len,
|
||||
# )
|
||||
log_probs = torch.log_softmax(joint_out, dim=-1)
|
||||
|
||||
loss_transducer = self.criterion_transducer(
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import Tensor
|
||||
import logging
|
||||
import numpy as np
|
||||
from funasr.torch_utils.device_funcs import to_device
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
from funasr.modules.streaming_utils.utils import sequence_mask
|
||||
from typing import Optional, Tuple
|
||||
|
||||
class CifPredictor(nn.Module):
|
||||
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
|
||||
@ -747,3 +749,128 @@ class CifPredictorV3(nn.Module):
|
||||
predictor_alignments = index_div_bool_zeros_count_tile_out
|
||||
predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
|
||||
return predictor_alignments.detach(), predictor_alignments_length.detach()
|
||||
|
||||
class BATPredictor(nn.Module):
|
||||
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, return_accum=False):
|
||||
super(BATPredictor, self).__init__()
|
||||
|
||||
self.pad = nn.ConstantPad1d((l_order, r_order), 0)
|
||||
self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
|
||||
self.cif_output = nn.Linear(idim, 1)
|
||||
self.dropout = torch.nn.Dropout(p=dropout)
|
||||
self.threshold = threshold
|
||||
self.smooth_factor = smooth_factor
|
||||
self.noise_threshold = noise_threshold
|
||||
self.return_accum = return_accum
|
||||
|
||||
def cif(
|
||||
self,
|
||||
input: Tensor,
|
||||
alpha: Tensor,
|
||||
beta: float = 1.0,
|
||||
return_accum: bool = False,
|
||||
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
B, S, C = input.size()
|
||||
assert tuple(alpha.size()) == (B, S), f"{alpha.size()} != {(B, S)}"
|
||||
|
||||
dtype = alpha.dtype
|
||||
alpha = alpha.float()
|
||||
|
||||
alpha_sum = alpha.sum(1)
|
||||
feat_lengths = (alpha_sum / beta).floor().long()
|
||||
T = feat_lengths.max()
|
||||
|
||||
# aggregate and integrate
|
||||
csum = alpha.cumsum(-1)
|
||||
with torch.no_grad():
|
||||
# indices used for scattering
|
||||
right_idx = (csum / beta).floor().long().clip(max=T)
|
||||
left_idx = right_idx.roll(1, dims=1)
|
||||
left_idx[:, 0] = 0
|
||||
|
||||
# count # of fires from each source
|
||||
fire_num = right_idx - left_idx
|
||||
extra_weights = (fire_num - 1).clip(min=0)
|
||||
# The extra entry in last dim is for
|
||||
output = input.new_zeros((B, T + 1, C))
|
||||
source_range = torch.arange(1, 1 + S).unsqueeze(0).type_as(input)
|
||||
zero = alpha.new_zeros((1,))
|
||||
|
||||
# right scatter
|
||||
fire_mask = fire_num > 0
|
||||
right_weight = torch.where(
|
||||
fire_mask,
|
||||
csum - right_idx.type_as(alpha) * beta,
|
||||
zero
|
||||
).type_as(input)
|
||||
# assert right_weight.ge(0).all(), f"{right_weight} should be non-negative."
|
||||
output.scatter_add_(
|
||||
1,
|
||||
right_idx.unsqueeze(-1).expand(-1, -1, C),
|
||||
right_weight.unsqueeze(-1) * input
|
||||
)
|
||||
|
||||
# left scatter
|
||||
left_weight = (
|
||||
alpha - right_weight - extra_weights.type_as(alpha) * beta
|
||||
).type_as(input)
|
||||
output.scatter_add_(
|
||||
1,
|
||||
left_idx.unsqueeze(-1).expand(-1, -1, C),
|
||||
left_weight.unsqueeze(-1) * input
|
||||
)
|
||||
|
||||
# extra scatters
|
||||
if extra_weights.ge(0).any():
|
||||
extra_steps = extra_weights.max().item()
|
||||
tgt_idx = left_idx
|
||||
src_feats = input * beta
|
||||
for _ in range(extra_steps):
|
||||
tgt_idx = (tgt_idx + 1).clip(max=T)
|
||||
# (B, S, 1)
|
||||
src_mask = (extra_weights > 0)
|
||||
output.scatter_add_(
|
||||
1,
|
||||
tgt_idx.unsqueeze(-1).expand(-1, -1, C),
|
||||
src_feats * src_mask.unsqueeze(2)
|
||||
)
|
||||
extra_weights -= 1
|
||||
|
||||
output = output[:, :T, :]
|
||||
|
||||
if return_accum:
|
||||
return output, csum
|
||||
else:
|
||||
return output, alpha
|
||||
|
||||
def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None, target_label_length=None):
|
||||
h = hidden
|
||||
context = h.transpose(1, 2)
|
||||
queries = self.pad(context)
|
||||
memory = self.cif_conv1d(queries)
|
||||
output = memory + context
|
||||
output = self.dropout(output)
|
||||
output = output.transpose(1, 2)
|
||||
output = torch.relu(output)
|
||||
output = self.cif_output(output)
|
||||
alphas = torch.sigmoid(output)
|
||||
alphas = torch.nn.functional.relu(alphas*self.smooth_factor - self.noise_threshold)
|
||||
if mask is not None:
|
||||
alphas = alphas * mask.transpose(-1, -2).float()
|
||||
if mask_chunk_predictor is not None:
|
||||
alphas = alphas * mask_chunk_predictor
|
||||
alphas = alphas.squeeze(-1)
|
||||
if target_label_length is not None:
|
||||
target_length = target_label_length
|
||||
elif target_label is not None:
|
||||
target_length = (target_label != ignore_id).float().sum(-1)
|
||||
# logging.info("target_length: {}".format(target_length))
|
||||
else:
|
||||
target_length = None
|
||||
token_num = alphas.sum(-1)
|
||||
if target_length is not None:
|
||||
# length_noise = torch.rand(alphas.size(0), device=alphas.device) - 0.5
|
||||
# target_length = length_noise + target_length
|
||||
alphas *= ((target_length + 1e-4) / token_num)[:, None].repeat(1, alphas.size(1))
|
||||
acoustic_embeds, cif_peak = self.cif(hidden, alphas, self.threshold, self.return_accum)
|
||||
return acoustic_embeds, token_num, alphas, cif_peak
|
||||
|
||||
@ -47,6 +47,7 @@ from funasr.models.e2e_asr_mfcca import MFCCA
|
||||
from funasr.models.e2e_sa_asr import SAASRModel
|
||||
from funasr.models.e2e_uni_asr import UniASR
|
||||
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
|
||||
from funasr.models.e2e_asr_bat import BATModel
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
|
||||
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
|
||||
@ -66,7 +67,7 @@ from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
|
||||
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
|
||||
HuggingFaceTransformersPostEncoder, # noqa: H301
|
||||
)
|
||||
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
|
||||
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor
|
||||
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
||||
from funasr.models.preencoder.linear import LinearProjection
|
||||
from funasr.models.preencoder.sinc import LightweightSincConvs
|
||||
@ -135,6 +136,7 @@ model_choices = ClassChoices(
|
||||
timestamp_prediction=TimestampPredictor,
|
||||
rnnt=TransducerModel,
|
||||
rnnt_unified=UnifiedTransducerModel,
|
||||
bat=BATModel,
|
||||
sa_asr=SAASRModel,
|
||||
),
|
||||
type_check=FunASRModel,
|
||||
@ -266,6 +268,7 @@ predictor_choices = ClassChoices(
|
||||
ctc_predictor=None,
|
||||
cif_predictor_v2=CifPredictorV2,
|
||||
cif_predictor_v3=CifPredictorV3,
|
||||
bat_predictor=BATPredictor,
|
||||
),
|
||||
type_check=None,
|
||||
default="cif_predictor",
|
||||
@ -1508,6 +1511,139 @@ class ASRTransducerTask(ASRTask):
|
||||
|
||||
return model
|
||||
|
||||
class ASRBATTask(ASRTask):
|
||||
"""ASR Boundary Aware Transducer Task definition."""
|
||||
|
||||
num_optimizers: int = 1
|
||||
|
||||
class_choices_list = [
|
||||
model_choices,
|
||||
frontend_choices,
|
||||
specaug_choices,
|
||||
normalize_choices,
|
||||
encoder_choices,
|
||||
rnnt_decoder_choices,
|
||||
joint_network_choices,
|
||||
predictor_choices,
|
||||
]
|
||||
|
||||
trainer = Trainer
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args: argparse.Namespace) -> BATModel:
|
||||
"""Required data depending on task mode.
|
||||
Args:
|
||||
cls: ASRBATTask object.
|
||||
args: Task arguments.
|
||||
Return:
|
||||
model: ASR BAT model.
|
||||
"""
|
||||
assert check_argument_types()
|
||||
|
||||
if isinstance(args.token_list, str):
|
||||
with open(args.token_list, encoding="utf-8") as f:
|
||||
token_list = [line.rstrip() for line in f]
|
||||
|
||||
# Overwriting token_list to keep it as "portable".
|
||||
args.token_list = list(token_list)
|
||||
elif isinstance(args.token_list, (tuple, list)):
|
||||
token_list = list(args.token_list)
|
||||
else:
|
||||
raise RuntimeError("token_list must be str or list")
|
||||
vocab_size = len(token_list)
|
||||
logging.info(f"Vocabulary size: {vocab_size }")
|
||||
|
||||
# 1. frontend
|
||||
if args.input_size is None:
|
||||
# Extract features in the model
|
||||
frontend_class = frontend_choices.get_class(args.frontend)
|
||||
frontend = frontend_class(**args.frontend_conf)
|
||||
input_size = frontend.output_size()
|
||||
else:
|
||||
# Give features from data-loader
|
||||
frontend = None
|
||||
input_size = args.input_size
|
||||
|
||||
# 2. Data augmentation for spectrogram
|
||||
if args.specaug is not None:
|
||||
specaug_class = specaug_choices.get_class(args.specaug)
|
||||
specaug = specaug_class(**args.specaug_conf)
|
||||
else:
|
||||
specaug = None
|
||||
|
||||
# 3. Normalization layer
|
||||
if args.normalize is not None:
|
||||
normalize_class = normalize_choices.get_class(args.normalize)
|
||||
normalize = normalize_class(**args.normalize_conf)
|
||||
else:
|
||||
normalize = None
|
||||
|
||||
# 4. Encoder
|
||||
if getattr(args, "encoder", None) is not None:
|
||||
encoder_class = encoder_choices.get_class(args.encoder)
|
||||
encoder = encoder_class(input_size, **args.encoder_conf)
|
||||
else:
|
||||
encoder = Encoder(input_size, **args.encoder_conf)
|
||||
encoder_output_size = encoder.output_size()
|
||||
|
||||
# 5. Decoder
|
||||
rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
|
||||
decoder = rnnt_decoder_class(
|
||||
vocab_size,
|
||||
**args.rnnt_decoder_conf,
|
||||
)
|
||||
decoder_output_size = decoder.output_size
|
||||
|
||||
if getattr(args, "decoder", None) is not None:
|
||||
att_decoder_class = decoder_choices.get_class(args.decoder)
|
||||
|
||||
att_decoder = att_decoder_class(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder_output_size,
|
||||
**args.decoder_conf,
|
||||
)
|
||||
else:
|
||||
att_decoder = None
|
||||
# 6. Joint Network
|
||||
joint_network = JointNetwork(
|
||||
vocab_size,
|
||||
encoder_output_size,
|
||||
decoder_output_size,
|
||||
**args.joint_network_conf,
|
||||
)
|
||||
|
||||
predictor_class = predictor_choices.get_class(args.predictor)
|
||||
predictor = predictor_class(**args.predictor_conf)
|
||||
|
||||
# 7. Build model
|
||||
try:
|
||||
model_class = model_choices.get_class(args.model)
|
||||
except AttributeError:
|
||||
model_class = model_choices.get_class("rnnt_unified")
|
||||
|
||||
model = model_class(
|
||||
vocab_size=vocab_size,
|
||||
token_list=token_list,
|
||||
frontend=frontend,
|
||||
specaug=specaug,
|
||||
normalize=normalize,
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
att_decoder=att_decoder,
|
||||
joint_network=joint_network,
|
||||
predictor=predictor,
|
||||
**args.model_conf,
|
||||
)
|
||||
# 8. Initialize model
|
||||
if args.init is not None:
|
||||
raise NotImplementedError(
|
||||
"Currently not supported.",
|
||||
"Initialization part will be reworked in a short future.",
|
||||
)
|
||||
|
||||
#assert check_return_type(model)
|
||||
|
||||
return model
|
||||
|
||||
class ASRTaskSAASR(ASRTask):
|
||||
# If you need more than one optimizers, change this value
|
||||
|
||||
Loading…
Reference in New Issue
Block a user