diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index 37a5fe464..81513aea8 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -1604,6 +1604,8 @@ def inference_launch(**kwargs): return inference_mfcca(**kwargs) elif mode == "rnnt": return inference_transducer(**kwargs) + elif mode == "bat": + return inference_transducer(**kwargs) elif mode == "sa_asr": return inference_sa_asr(**kwargs) else: diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py index a76b20435..6606d3077 100644 --- a/funasr/build_utils/build_asr_model.py +++ b/funasr/build_utils/build_asr_model.py @@ -26,6 +26,7 @@ from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer from funasr.models.e2e_asr_mfcca import MFCCA from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel +from funasr.models.e2e_asr_bat import BATModel from funasr.models.e2e_sa_asr import SAASRModel from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer @@ -46,7 +47,7 @@ from funasr.models.frontend.s3prl import S3prlFrontend from funasr.models.frontend.wav_frontend import WavFrontend from funasr.models.frontend.windowing import SlidingWindow from funasr.models.joint_net.joint_network import JointNetwork -from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3 +from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor from funasr.models.specaug.specaug import SpecAug from funasr.models.specaug.specaug import SpecAugLFR from funasr.modules.subsampling import Conv1dSubsampling @@ -99,7 +100,7 @@ model_choices = ClassChoices( rnnt=TransducerModel, rnnt_unified=UnifiedTransducerModel, sa_asr=SAASRModel, - + bat=BATModel, ), default="asr", ) @@ -188,6 +189,7 @@ predictor_choices = ClassChoices( ctc_predictor=None, cif_predictor_v2=CifPredictorV2, cif_predictor_v3=CifPredictorV3, + bat_predictor=BATPredictor, ), default="cif_predictor", optional=True, @@ -313,12 +315,15 @@ def build_asr_model(args): encoder = encoder_class(input_size=input_size, **args.encoder_conf) # decoder - decoder_class = decoder_choices.get_class(args.decoder) - decoder = decoder_class( - vocab_size=vocab_size, - encoder_output_size=encoder.output_size(), - **args.decoder_conf, - ) + if hasattr(args, "decoder") and args.decoder is not None: + decoder_class = decoder_choices.get_class(args.decoder) + decoder = decoder_class( + vocab_size=vocab_size, + encoder_output_size=encoder.output_size(), + **args.decoder_conf, + ) + else: + decoder = None # ctc ctc = CTC( @@ -463,6 +468,53 @@ def build_asr_model(args): joint_network=joint_network, **args.model_conf, ) + elif args.model == "bat": + # 5. Decoder + encoder_output_size = encoder.output_size() + + rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder) + decoder = rnnt_decoder_class( + vocab_size, + **args.rnnt_decoder_conf, + ) + decoder_output_size = decoder.output_size + + if getattr(args, "decoder", None) is not None: + att_decoder_class = decoder_choices.get_class(args.decoder) + + att_decoder = att_decoder_class( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + **args.decoder_conf, + ) + else: + att_decoder = None + # 6. Joint Network + joint_network = JointNetwork( + vocab_size, + encoder_output_size, + decoder_output_size, + **args.joint_network_conf, + ) + + predictor_class = predictor_choices.get_class(args.predictor) + predictor = predictor_class(**args.predictor_conf) + + model_class = model_choices.get_class(args.model) + # 7. Build model + model = model_class( + vocab_size=vocab_size, + token_list=token_list, + frontend=frontend, + specaug=specaug, + normalize=normalize, + encoder=encoder, + decoder=decoder, + att_decoder=att_decoder, + joint_network=joint_network, + predictor=predictor, + **args.model_conf, + ) elif args.model == "sa_asr": asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder) asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf) diff --git a/funasr/models/e2e_asr_bat.py b/funasr/models/e2e_asr_bat.py new file mode 100644 index 000000000..9627292ec --- /dev/null +++ b/funasr/models/e2e_asr_bat.py @@ -0,0 +1,496 @@ +"""Boundary Aware Transducer (BAT) model.""" + +import logging +from contextlib import contextmanager +from typing import Dict, List, Optional, Tuple, Union + +import torch +from packaging.version import parse as V +from funasr.losses.label_smoothing_loss import ( + LabelSmoothingLoss, # noqa: H301 +) +from funasr.models.frontend.abs_frontend import AbsFrontend +from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.models.decoder.rnnt_decoder import RNNTDecoder +from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder +from funasr.models.encoder.abs_encoder import AbsEncoder +from funasr.models.joint_net.joint_network import JointNetwork +from funasr.modules.nets_utils import get_transducer_task_io +from funasr.modules.nets_utils import th_accuracy +from funasr.modules.nets_utils import make_pad_mask +from funasr.modules.add_sos_eos import add_sos_eos +from funasr.layers.abs_normalize import AbsNormalize +from funasr.torch_utils.device_funcs import force_gatherable +from funasr.models.base_model import FunASRModel + +if V(torch.__version__) >= V("1.6.0"): + from torch.cuda.amp import autocast +else: + + @contextmanager + def autocast(enabled=True): + yield + + +class BATModel(FunASRModel): + """BATModel module definition. + + Args: + vocab_size: Size of complete vocabulary (w/ EOS and blank included). + token_list: List of token + frontend: Frontend module. + specaug: SpecAugment module. + normalize: Normalization module. + encoder: Encoder module. + decoder: Decoder module. + joint_network: Joint Network module. + transducer_weight: Weight of the Transducer loss. + fastemit_lambda: FastEmit lambda value. + auxiliary_ctc_weight: Weight of auxiliary CTC loss. + auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs. + auxiliary_lm_loss_weight: Weight of auxiliary LM loss. + auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing. + ignore_id: Initial padding ID. + sym_space: Space symbol. + sym_blank: Blank Symbol + report_cer: Whether to report Character Error Rate during validation. + report_wer: Whether to report Word Error Rate during validation. + extract_feats_in_collect_stats: Whether to use extract_feats stats collection. + + """ + + def __init__( + self, + vocab_size: int, + token_list: Union[Tuple[str, ...], List[str]], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], + encoder: AbsEncoder, + decoder: RNNTDecoder, + joint_network: JointNetwork, + att_decoder: Optional[AbsAttDecoder] = None, + predictor = None, + transducer_weight: float = 1.0, + predictor_weight: float = 1.0, + cif_weight: float = 1.0, + fastemit_lambda: float = 0.0, + auxiliary_ctc_weight: float = 0.0, + auxiliary_ctc_dropout_rate: float = 0.0, + auxiliary_lm_loss_weight: float = 0.0, + auxiliary_lm_loss_smoothing: float = 0.0, + ignore_id: int = -1, + sym_space: str = "", + sym_blank: str = "", + report_cer: bool = True, + report_wer: bool = True, + extract_feats_in_collect_stats: bool = True, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + r_d: int = 5, + r_u: int = 5, + ) -> None: + """Construct an BATModel object.""" + super().__init__() + + # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos) + self.blank_id = 0 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.token_list = token_list.copy() + + self.sym_space = sym_space + self.sym_blank = sym_blank + + self.frontend = frontend + self.specaug = specaug + self.normalize = normalize + + self.encoder = encoder + self.decoder = decoder + self.joint_network = joint_network + + self.criterion_transducer = None + self.error_calculator = None + + self.use_auxiliary_ctc = auxiliary_ctc_weight > 0 + self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0 + + if self.use_auxiliary_ctc: + self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size) + self.ctc_dropout_rate = auxiliary_ctc_dropout_rate + + if self.use_auxiliary_lm_loss: + self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size) + self.lm_loss_smoothing = auxiliary_lm_loss_smoothing + + self.transducer_weight = transducer_weight + self.fastemit_lambda = fastemit_lambda + + self.auxiliary_ctc_weight = auxiliary_ctc_weight + self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight + + self.report_cer = report_cer + self.report_wer = report_wer + + self.extract_feats_in_collect_stats = extract_feats_in_collect_stats + + self.criterion_pre = torch.nn.L1Loss() + self.predictor_weight = predictor_weight + self.predictor = predictor + + self.cif_weight = cif_weight + if self.cif_weight > 0: + self.cif_output_layer = torch.nn.Linear(encoder.output_size(), vocab_size) + self.criterion_cif = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + self.r_d = r_d + self.r_u = r_u + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Forward architecture and compute loss(es). + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + text: Label ID sequences. (B, L) + text_lengths: Label ID sequences lengths. (B,) + kwargs: Contains "utts_id". + + Return: + loss: Main loss value. + stats: Task statistics. + weight: Task weights. + + """ + assert text_lengths.dim() == 1, text_lengths.shape + assert ( + speech.shape[0] + == speech_lengths.shape[0] + == text.shape[0] + == text_lengths.shape[0] + ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) + + batch_size = speech.shape[0] + text = text[:, : text_lengths.max()] + + # 1. Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + if hasattr(self.encoder, 'overlap_chunk_cls') and self.encoder.overlap_chunk_cls is not None: + encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens, + chunk_outs=None) + + encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(encoder_out.device) + # 2. Transducer-related I/O preparation + decoder_in, target, t_len, u_len = get_transducer_task_io( + text, + encoder_out_lens, + ignore_id=self.ignore_id, + ) + + # 3. Decoder + self.decoder.set_device(encoder_out.device) + decoder_out = self.decoder(decoder_in, u_len) + + pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=self.ignore_id) + loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length), pre_token_length) + + if self.cif_weight > 0.0: + cif_predict = self.cif_output_layer(pre_acoustic_embeds) + loss_cif = self.criterion_cif(cif_predict, text) + else: + loss_cif = 0.0 + + # 5. Losses + boundary = torch.zeros((encoder_out.size(0), 4), dtype=torch.int64, device=encoder_out.device) + boundary[:, 2] = u_len.long().detach() + boundary[:, 3] = t_len.long().detach() + + pre_peak_index = torch.floor(pre_peak_index).long() + s_begin = pre_peak_index - self.r_d + + T = encoder_out.size(1) + B = encoder_out.size(0) + U = decoder_out.size(1) + + mask = torch.arange(0, T, device=encoder_out.device).reshape(1, T).expand(B, T) + mask = mask <= boundary[:, 3].reshape(B, 1) - 1 + + s_begin_padding = boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1 + # handle the cases where `len(symbols) < s_range` + s_begin_padding = torch.clamp(s_begin_padding, min=0) + + s_begin = torch.where(mask, s_begin, s_begin_padding) + + mask2 = s_begin < boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1 + + s_begin = torch.where(mask2, s_begin, boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1) + + s_begin = torch.clamp(s_begin, min=0) + + ranges = s_begin.reshape((B, T, 1)).expand((B, T, min(self.r_u+self.r_d, min(u_len)))) + torch.arange(min(self.r_d+self.r_u, min(u_len)), device=encoder_out.device) + + import fast_rnnt + am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning( + am=self.joint_network.lin_enc(encoder_out), + lm=self.joint_network.lin_dec(decoder_out), + ranges=ranges, + ) + + logits = self.joint_network(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + loss_trans = fast_rnnt.rnnt_loss_pruned( + logits=logits.float(), + symbols=target.long(), + ranges=ranges, + termination_symbol=self.blank_id, + boundary=boundary, + reduction="sum", + ) + + cer_trans, wer_trans = None, None + if not self.training and (self.report_cer or self.report_wer): + if self.error_calculator is None: + from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator + self.error_calculator = ErrorCalculator( + self.decoder, + self.joint_network, + self.token_list, + self.sym_space, + self.sym_blank, + report_cer=self.report_cer, + report_wer=self.report_wer, + ) + cer_trans, wer_trans = self.error_calculator(encoder_out, target, t_len) + + loss_ctc, loss_lm = 0.0, 0.0 + + if self.use_auxiliary_ctc: + loss_ctc = self._calc_ctc_loss( + encoder_out, + target, + t_len, + u_len, + ) + + if self.use_auxiliary_lm_loss: + loss_lm = self._calc_lm_loss(decoder_out, target) + + loss = ( + self.transducer_weight * loss_trans + + self.auxiliary_ctc_weight * loss_ctc + + self.auxiliary_lm_loss_weight * loss_lm + + self.predictor_weight * loss_pre + + self.cif_weight * loss_cif + ) + + stats = dict( + loss=loss.detach(), + loss_transducer=loss_trans.detach(), + loss_pre=loss_pre.detach(), + loss_cif=loss_cif.detach() if loss_cif > 0.0 else None, + aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None, + aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None, + cer_transducer=cer_trans, + wer_transducer=wer_trans, + ) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + + return loss, stats, weight + + def collect_feats( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Dict[str, torch.Tensor]: + """Collect features sequences and features lengths sequences. + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + text: Label ID sequences. (B, L) + text_lengths: Label ID sequences lengths. (B,) + kwargs: Contains "utts_id". + + Return: + {}: "feats": Features sequences. (B, T, D_feats), + "feats_lengths": Features sequences lengths. (B,) + + """ + if self.extract_feats_in_collect_stats: + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + else: + # Generate dummy stats if extract_feats_in_collect_stats is False + logging.warning( + "Generating dummy stats for feats and feats_lengths, " + "because encoder_conf.extract_feats_in_collect_stats is " + f"{self.extract_feats_in_collect_stats}" + ) + + feats, feats_lengths = speech, speech_lengths + + return {"feats": feats, "feats_lengths": feats_lengths} + + def encode( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encoder speech sequences. + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + + Return: + encoder_out: Encoder outputs. (B, T, D_enc) + encoder_out_lens: Encoder outputs lengths. (B,) + + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Data augmentation + if self.specaug is not None and self.training: + feats, feats_lengths = self.specaug(feats, feats_lengths) + + # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # 4. Forward encoder + encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) + + assert encoder_out.size(0) == speech.size(0), ( + encoder_out.size(), + speech.size(0), + ) + assert encoder_out.size(1) <= encoder_out_lens.max(), ( + encoder_out.size(), + encoder_out_lens.max(), + ) + + return encoder_out, encoder_out_lens + + def _extract_feats( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Extract features sequences and features sequences lengths. + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + + Return: + feats: Features sequences. (B, T, D_feats) + feats_lengths: Features sequences lengths. (B,) + + """ + assert speech_lengths.dim() == 1, speech_lengths.shape + + # for data-parallel + speech = speech[:, : speech_lengths.max()] + + if self.frontend is not None: + feats, feats_lengths = self.frontend(speech, speech_lengths) + else: + feats, feats_lengths = speech, speech_lengths + + return feats, feats_lengths + + def _calc_ctc_loss( + self, + encoder_out: torch.Tensor, + target: torch.Tensor, + t_len: torch.Tensor, + u_len: torch.Tensor, + ) -> torch.Tensor: + """Compute CTC loss. + + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + target: Target label ID sequences. (B, L) + t_len: Encoder output sequences lengths. (B,) + u_len: Target label ID sequences lengths. (B,) + + Return: + loss_ctc: CTC loss value. + + """ + ctc_in = self.ctc_lin( + torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate) + ) + ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1) + + target_mask = target != 0 + ctc_target = target[target_mask].cpu() + + with torch.backends.cudnn.flags(deterministic=True): + loss_ctc = torch.nn.functional.ctc_loss( + ctc_in, + ctc_target, + t_len, + u_len, + zero_infinity=True, + reduction="sum", + ) + loss_ctc /= target.size(0) + + return loss_ctc + + def _calc_lm_loss( + self, + decoder_out: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: + """Compute LM loss. + + Args: + decoder_out: Decoder output sequences. (B, U, D_dec) + target: Target label ID sequences. (B, L) + + Return: + loss_lm: LM loss value. + + """ + lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size) + lm_target = target.view(-1).type(torch.int64) + + with torch.no_grad(): + true_dist = lm_loss_in.clone() + true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1)) + + # Ignore blank ID (0) + ignore = lm_target == 0 + lm_target = lm_target.masked_fill(ignore, 0) + + true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing)) + + loss_lm = torch.nn.functional.kl_div( + torch.log_softmax(lm_loss_in, dim=1), + true_dist, + reduction="none", + ) + loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size( + 0 + ) + + return loss_lm diff --git a/funasr/models/e2e_asr_transducer.py b/funasr/models/e2e_asr_transducer.py index 80914b119..729e918dd 100644 --- a/funasr/models/e2e_asr_transducer.py +++ b/funasr/models/e2e_asr_transducer.py @@ -353,11 +353,6 @@ class TransducerModel(FunASRModel): """ if self.criterion_transducer is None: try: - # from warprnnt_pytorch import RNNTLoss - # self.criterion_transducer = RNNTLoss( - # reduction="mean", - # fastemit_lambda=self.fastemit_lambda, - # ) from warp_rnnt import rnnt_loss as RNNTLoss self.criterion_transducer = RNNTLoss @@ -368,12 +363,6 @@ class TransducerModel(FunASRModel): ) exit(1) - # loss_transducer = self.criterion_transducer( - # joint_out, - # target, - # t_len, - # u_len, - # ) log_probs = torch.log_softmax(joint_out, dim=-1) loss_transducer = self.criterion_transducer( @@ -637,7 +626,6 @@ class UnifiedTransducerModel(FunASRModel): batch_size = speech.shape[0] text = text[:, : text_lengths.max()] - #print(speech.shape) # 1. Encoder encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths) @@ -854,11 +842,6 @@ class UnifiedTransducerModel(FunASRModel): """ if self.criterion_transducer is None: try: - # from warprnnt_pytorch import RNNTLoss - # self.criterion_transducer = RNNTLoss( - # reduction="mean", - # fastemit_lambda=self.fastemit_lambda, - # ) from warp_rnnt import rnnt_loss as RNNTLoss self.criterion_transducer = RNNTLoss @@ -869,12 +852,6 @@ class UnifiedTransducerModel(FunASRModel): ) exit(1) - # loss_transducer = self.criterion_transducer( - # joint_out, - # target, - # t_len, - # u_len, - # ) log_probs = torch.log_softmax(joint_out, dim=-1) loss_transducer = self.criterion_transducer( diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py index 3c363dbab..c66af94e5 100644 --- a/funasr/models/predictor/cif.py +++ b/funasr/models/predictor/cif.py @@ -1,10 +1,12 @@ import torch from torch import nn +from torch import Tensor import logging import numpy as np from funasr.torch_utils.device_funcs import to_device from funasr.modules.nets_utils import make_pad_mask from funasr.modules.streaming_utils.utils import sequence_mask +from typing import Optional, Tuple class CifPredictor(nn.Module): def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45): @@ -747,3 +749,128 @@ class CifPredictorV3(nn.Module): predictor_alignments = index_div_bool_zeros_count_tile_out predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype) return predictor_alignments.detach(), predictor_alignments_length.detach() + +class BATPredictor(nn.Module): + def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, return_accum=False): + super(BATPredictor, self).__init__() + + self.pad = nn.ConstantPad1d((l_order, r_order), 0) + self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim) + self.cif_output = nn.Linear(idim, 1) + self.dropout = torch.nn.Dropout(p=dropout) + self.threshold = threshold + self.smooth_factor = smooth_factor + self.noise_threshold = noise_threshold + self.return_accum = return_accum + + def cif( + self, + input: Tensor, + alpha: Tensor, + beta: float = 1.0, + return_accum: bool = False, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + B, S, C = input.size() + assert tuple(alpha.size()) == (B, S), f"{alpha.size()} != {(B, S)}" + + dtype = alpha.dtype + alpha = alpha.float() + + alpha_sum = alpha.sum(1) + feat_lengths = (alpha_sum / beta).floor().long() + T = feat_lengths.max() + + # aggregate and integrate + csum = alpha.cumsum(-1) + with torch.no_grad(): + # indices used for scattering + right_idx = (csum / beta).floor().long().clip(max=T) + left_idx = right_idx.roll(1, dims=1) + left_idx[:, 0] = 0 + + # count # of fires from each source + fire_num = right_idx - left_idx + extra_weights = (fire_num - 1).clip(min=0) + # The extra entry in last dim is for + output = input.new_zeros((B, T + 1, C)) + source_range = torch.arange(1, 1 + S).unsqueeze(0).type_as(input) + zero = alpha.new_zeros((1,)) + + # right scatter + fire_mask = fire_num > 0 + right_weight = torch.where( + fire_mask, + csum - right_idx.type_as(alpha) * beta, + zero + ).type_as(input) + # assert right_weight.ge(0).all(), f"{right_weight} should be non-negative." + output.scatter_add_( + 1, + right_idx.unsqueeze(-1).expand(-1, -1, C), + right_weight.unsqueeze(-1) * input + ) + + # left scatter + left_weight = ( + alpha - right_weight - extra_weights.type_as(alpha) * beta + ).type_as(input) + output.scatter_add_( + 1, + left_idx.unsqueeze(-1).expand(-1, -1, C), + left_weight.unsqueeze(-1) * input + ) + + # extra scatters + if extra_weights.ge(0).any(): + extra_steps = extra_weights.max().item() + tgt_idx = left_idx + src_feats = input * beta + for _ in range(extra_steps): + tgt_idx = (tgt_idx + 1).clip(max=T) + # (B, S, 1) + src_mask = (extra_weights > 0) + output.scatter_add_( + 1, + tgt_idx.unsqueeze(-1).expand(-1, -1, C), + src_feats * src_mask.unsqueeze(2) + ) + extra_weights -= 1 + + output = output[:, :T, :] + + if return_accum: + return output, csum + else: + return output, alpha + + def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None, target_label_length=None): + h = hidden + context = h.transpose(1, 2) + queries = self.pad(context) + memory = self.cif_conv1d(queries) + output = memory + context + output = self.dropout(output) + output = output.transpose(1, 2) + output = torch.relu(output) + output = self.cif_output(output) + alphas = torch.sigmoid(output) + alphas = torch.nn.functional.relu(alphas*self.smooth_factor - self.noise_threshold) + if mask is not None: + alphas = alphas * mask.transpose(-1, -2).float() + if mask_chunk_predictor is not None: + alphas = alphas * mask_chunk_predictor + alphas = alphas.squeeze(-1) + if target_label_length is not None: + target_length = target_label_length + elif target_label is not None: + target_length = (target_label != ignore_id).float().sum(-1) + # logging.info("target_length: {}".format(target_length)) + else: + target_length = None + token_num = alphas.sum(-1) + if target_length is not None: + # length_noise = torch.rand(alphas.size(0), device=alphas.device) - 0.5 + # target_length = length_noise + target_length + alphas *= ((target_length + 1e-4) / token_num)[:, None].repeat(1, alphas.size(1)) + acoustic_embeds, cif_peak = self.cif(hidden, alphas, self.threshold, self.return_accum) + return acoustic_embeds, token_num, alphas, cif_peak diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py index 4b94aeb94..39e0ea9f1 100644 --- a/funasr/tasks/asr.py +++ b/funasr/tasks/asr.py @@ -47,6 +47,7 @@ from funasr.models.e2e_asr_mfcca import MFCCA from funasr.models.e2e_sa_asr import SAASRModel from funasr.models.e2e_uni_asr import UniASR from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel +from funasr.models.e2e_asr_bat import BATModel from funasr.models.encoder.abs_encoder import AbsEncoder from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder from funasr.models.encoder.data2vec_encoder import Data2VecEncoder @@ -66,7 +67,7 @@ from funasr.models.postencoder.abs_postencoder import AbsPostEncoder from funasr.models.postencoder.hugging_face_transformers_postencoder import ( HuggingFaceTransformersPostEncoder, # noqa: H301 ) -from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3 +from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor from funasr.models.preencoder.abs_preencoder import AbsPreEncoder from funasr.models.preencoder.linear import LinearProjection from funasr.models.preencoder.sinc import LightweightSincConvs @@ -135,6 +136,7 @@ model_choices = ClassChoices( timestamp_prediction=TimestampPredictor, rnnt=TransducerModel, rnnt_unified=UnifiedTransducerModel, + bat=BATModel, sa_asr=SAASRModel, ), type_check=FunASRModel, @@ -266,6 +268,7 @@ predictor_choices = ClassChoices( ctc_predictor=None, cif_predictor_v2=CifPredictorV2, cif_predictor_v3=CifPredictorV3, + bat_predictor=BATPredictor, ), type_check=None, default="cif_predictor", @@ -1508,6 +1511,139 @@ class ASRTransducerTask(ASRTask): return model +class ASRBATTask(ASRTask): + """ASR Boundary Aware Transducer Task definition.""" + + num_optimizers: int = 1 + + class_choices_list = [ + model_choices, + frontend_choices, + specaug_choices, + normalize_choices, + encoder_choices, + rnnt_decoder_choices, + joint_network_choices, + predictor_choices, + ] + + trainer = Trainer + + @classmethod + def build_model(cls, args: argparse.Namespace) -> BATModel: + """Required data depending on task mode. + Args: + cls: ASRBATTask object. + args: Task arguments. + Return: + model: ASR BAT model. + """ + assert check_argument_types() + + if isinstance(args.token_list, str): + with open(args.token_list, encoding="utf-8") as f: + token_list = [line.rstrip() for line in f] + + # Overwriting token_list to keep it as "portable". + args.token_list = list(token_list) + elif isinstance(args.token_list, (tuple, list)): + token_list = list(args.token_list) + else: + raise RuntimeError("token_list must be str or list") + vocab_size = len(token_list) + logging.info(f"Vocabulary size: {vocab_size }") + + # 1. frontend + if args.input_size is None: + # Extract features in the model + frontend_class = frontend_choices.get_class(args.frontend) + frontend = frontend_class(**args.frontend_conf) + input_size = frontend.output_size() + else: + # Give features from data-loader + frontend = None + input_size = args.input_size + + # 2. Data augmentation for spectrogram + if args.specaug is not None: + specaug_class = specaug_choices.get_class(args.specaug) + specaug = specaug_class(**args.specaug_conf) + else: + specaug = None + + # 3. Normalization layer + if args.normalize is not None: + normalize_class = normalize_choices.get_class(args.normalize) + normalize = normalize_class(**args.normalize_conf) + else: + normalize = None + + # 4. Encoder + if getattr(args, "encoder", None) is not None: + encoder_class = encoder_choices.get_class(args.encoder) + encoder = encoder_class(input_size, **args.encoder_conf) + else: + encoder = Encoder(input_size, **args.encoder_conf) + encoder_output_size = encoder.output_size() + + # 5. Decoder + rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder) + decoder = rnnt_decoder_class( + vocab_size, + **args.rnnt_decoder_conf, + ) + decoder_output_size = decoder.output_size + + if getattr(args, "decoder", None) is not None: + att_decoder_class = decoder_choices.get_class(args.decoder) + + att_decoder = att_decoder_class( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + **args.decoder_conf, + ) + else: + att_decoder = None + # 6. Joint Network + joint_network = JointNetwork( + vocab_size, + encoder_output_size, + decoder_output_size, + **args.joint_network_conf, + ) + + predictor_class = predictor_choices.get_class(args.predictor) + predictor = predictor_class(**args.predictor_conf) + + # 7. Build model + try: + model_class = model_choices.get_class(args.model) + except AttributeError: + model_class = model_choices.get_class("rnnt_unified") + + model = model_class( + vocab_size=vocab_size, + token_list=token_list, + frontend=frontend, + specaug=specaug, + normalize=normalize, + encoder=encoder, + decoder=decoder, + att_decoder=att_decoder, + joint_network=joint_network, + predictor=predictor, + **args.model_conf, + ) + # 8. Initialize model + if args.init is not None: + raise NotImplementedError( + "Currently not supported.", + "Initialization part will be reworked in a short future.", + ) + + #assert check_return_type(model) + + return model class ASRTaskSAASR(ASRTask): # If you need more than one optimizers, change this value