From 607073619cedf2c114e1589aa6d5953d171f33bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Thu, 27 Apr 2023 19:27:49 +0800 Subject: [PATCH] update --- funasr/models/data2vec.py | 14 +- funasr/models/e2e_asr.py | 21 +- funasr/models/e2e_asr_mfcca.py | 148 ++++---- funasr/models/e2e_asr_paraformer.py | 511 +++++++++++++++++++++------- funasr/models/e2e_diar_eend_ola.py | 3 +- funasr/models/e2e_diar_sond.py | 28 +- funasr/models/e2e_sv.py | 30 +- funasr/models/e2e_tp.py | 26 +- funasr/models/e2e_uni_asr.py | 22 +- funasr/models/e2e_vad.py | 45 ++- 10 files changed, 579 insertions(+), 269 deletions(-) diff --git a/funasr/models/data2vec.py b/funasr/models/data2vec.py index 380c137b6..e5bd64034 100644 --- a/funasr/models/data2vec.py +++ b/funasr/models/data2vec.py @@ -12,7 +12,11 @@ from typing import Tuple import torch from typeguard import check_argument_types +from funasr.layers.abs_normalize import AbsNormalize +from funasr.models.encoder.abs_encoder import AbsEncoder +from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.preencoder.abs_preencoder import AbsPreEncoder +from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.torch_utils.device_funcs import force_gatherable from funasr.models.base_model import FunASRModel @@ -30,11 +34,11 @@ class Data2VecPretrainModel(FunASRModel): def __init__( self, - frontend: Optional[torch.nn.Module], - specaug: Optional[torch.nn.Module], - normalize: Optional[torch.nn.Module], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], preencoder: Optional[AbsPreEncoder], - encoder: torch.nn.Module, + encoder: AbsEncoder, ): assert check_argument_types() @@ -53,7 +57,6 @@ class Data2VecPretrainModel(FunASRModel): speech_lengths: torch.Tensor, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Frontend + Encoder + Calc loss - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) @@ -102,7 +105,6 @@ class Data2VecPretrainModel(FunASRModel): speech_lengths: torch.Tensor, ): """Frontend + Encoder. - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) diff --git a/funasr/models/e2e_asr.py b/funasr/models/e2e_asr.py index 779d70365..8410ede18 100644 --- a/funasr/models/e2e_asr.py +++ b/funasr/models/e2e_asr.py @@ -13,18 +13,22 @@ from typing import Union import torch from typeguard import check_argument_types +from funasr.layers.abs_normalize import AbsNormalize from funasr.losses.label_smoothing_loss import ( LabelSmoothingLoss, # noqa: H301 ) from funasr.models.ctc import CTC -from funasr.models.frontend.abs_frontend import AbsFrontend -from funasr.models.encoder.abs_encoder import AbsEncoder from funasr.models.decoder.abs_decoder import AbsDecoder -from funasr.models.base_model import FunASRModel +from funasr.models.encoder.abs_encoder import AbsEncoder +from funasr.models.frontend.abs_frontend import AbsFrontend +from funasr.models.postencoder.abs_postencoder import AbsPostEncoder +from funasr.models.preencoder.abs_preencoder import AbsPreEncoder +from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.modules.add_sos_eos import add_sos_eos from funasr.modules.e2e_asr_common import ErrorCalculator from funasr.modules.nets_utils import th_accuracy from funasr.torch_utils.device_funcs import force_gatherable +from funasr.models.base_model import FunASRModel if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import autocast @@ -43,9 +47,11 @@ class ESPnetASRModel(FunASRModel): vocab_size: int, token_list: Union[Tuple[str, ...], List[str]], frontend: Optional[AbsFrontend], - specaug: Optional[torch.nn.Module], - normalize: Optional[torch.nn.Module], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], + preencoder: Optional[AbsPreEncoder], encoder: AbsEncoder, + postencoder: Optional[AbsPostEncoder], decoder: AbsDecoder, ctc: CTC, ctc_weight: float = 0.5, @@ -127,7 +133,6 @@ class ESPnetASRModel(FunASRModel): text_lengths: torch.Tensor, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Frontend + Encoder + Decoder + Calc loss - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) @@ -243,7 +248,6 @@ class ESPnetASRModel(FunASRModel): self, speech: torch.Tensor, speech_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Frontend + Encoder. Note that this method is used by asr_inference.py - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) @@ -325,9 +329,7 @@ class ESPnetASRModel(FunASRModel): ys_pad_lens: torch.Tensor, ) -> torch.Tensor: """Compute negative log likelihood(nll) from transformer-decoder - Normally, this function is called in batchify_nll. - Args: encoder_out: (Batch, Length, Dim) encoder_out_lens: (Batch,) @@ -364,7 +366,6 @@ class ESPnetASRModel(FunASRModel): batch_size: int = 100, ): """Compute negative log likelihood(nll) from transformer-decoder - To avoid OOM, this fuction seperate the input into batches. Then call nll for each batch and combine and return results. Args: diff --git a/funasr/models/e2e_asr_mfcca.py b/funasr/models/e2e_asr_mfcca.py index efdd90dc7..44679ef2c 100644 --- a/funasr/models/e2e_asr_mfcca.py +++ b/funasr/models/e2e_asr_mfcca.py @@ -17,10 +17,13 @@ from funasr.losses.label_smoothing_loss import ( ) from funasr.models.ctc import CTC from funasr.models.decoder.abs_decoder import AbsDecoder +from funasr.models.encoder.abs_encoder import AbsEncoder +from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.preencoder.abs_preencoder import AbsPreEncoder -from funasr.models.base_model import FunASRModel +from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.layers.abs_normalize import AbsNormalize from funasr.torch_utils.device_funcs import force_gatherable - +from funasr.models.base_model import FunASRModel if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import autocast @@ -32,30 +35,36 @@ else: import pdb import random import math + + class MFCCA(FunASRModel): - """CTC-attention hybrid Encoder-Decoder model""" + """ + Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University + MFCCA:Multi-Frame Cross-Channel attention for multi-speaker ASR in Multi-party meeting scenario + https://arxiv.org/abs/2210.05265 + """ def __init__( - self, - vocab_size: int, - token_list: Union[Tuple[str, ...], List[str]], - frontend: Optional[torch.nn.Module], - specaug: Optional[torch.nn.Module], - normalize: Optional[torch.nn.Module], - preencoder: Optional[AbsPreEncoder], - encoder: torch.nn.Module, - decoder: AbsDecoder, - ctc: CTC, - rnnt_decoder: None, - ctc_weight: float = 0.5, - ignore_id: int = -1, - lsm_weight: float = 0.0, - mask_ratio: float = 0.0, - length_normalized_loss: bool = False, - report_cer: bool = True, - report_wer: bool = True, - sym_space: str = "", - sym_blank: str = "", + self, + vocab_size: int, + token_list: Union[Tuple[str, ...], List[str]], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], + preencoder: Optional[AbsPreEncoder], + encoder: AbsEncoder, + decoder: AbsDecoder, + ctc: CTC, + rnnt_decoder: None, + ctc_weight: float = 0.5, + ignore_id: int = -1, + lsm_weight: float = 0.0, + mask_ratio: float = 0.0, + length_normalized_loss: bool = False, + report_cer: bool = True, + report_wer: bool = True, + sym_space: str = "", + sym_blank: str = "", ): assert check_argument_types() assert 0.0 <= ctc_weight <= 1.0, ctc_weight @@ -69,10 +78,9 @@ class MFCCA(FunASRModel): self.ignore_id = ignore_id self.ctc_weight = ctc_weight self.token_list = token_list.copy() - + self.mask_ratio = mask_ratio - self.frontend = frontend self.specaug = specaug self.normalize = normalize @@ -106,14 +114,13 @@ class MFCCA(FunASRModel): self.error_calculator = None def forward( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - text: torch.Tensor, - text_lengths: torch.Tensor, + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Frontend + Encoder + Decoder + Calc loss - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) @@ -123,22 +130,22 @@ class MFCCA(FunASRModel): assert text_lengths.dim() == 1, text_lengths.shape # Check that batch_size is unified assert ( - speech.shape[0] - == speech_lengths.shape[0] - == text.shape[0] - == text_lengths.shape[0] + speech.shape[0] + == speech_lengths.shape[0] + == text.shape[0] + == text_lengths.shape[0] ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) - #pdb.set_trace() - if(speech.dim()==3 and speech.size(2)==8 and self.mask_ratio !=0): + # pdb.set_trace() + if (speech.dim() == 3 and speech.size(2) == 8 and self.mask_ratio != 0): rate_num = random.random() - #rate_num = 0.1 - if(rate_num<=self.mask_ratio): - retain_channel = math.ceil(random.random() *8) - if(retain_channel>1): - speech = speech[:,:,torch.randperm(8)[0:retain_channel].sort().values] + # rate_num = 0.1 + if (rate_num <= self.mask_ratio): + retain_channel = math.ceil(random.random() * 8) + if (retain_channel > 1): + speech = speech[:, :, torch.randperm(8)[0:retain_channel].sort().values] else: - speech = speech[:,:,torch.randperm(8)[0]] - #pdb.set_trace() + speech = speech[:, :, torch.randperm(8)[0]] + # pdb.set_trace() batch_size = speech.shape[0] # for data-parallel text = text[:, : text_lengths.max()] @@ -188,20 +195,19 @@ class MFCCA(FunASRModel): return loss, stats, weight def collect_feats( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - text: torch.Tensor, - text_lengths: torch.Tensor, + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, ) -> Dict[str, torch.Tensor]: feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths) return {"feats": feats, "feats_lengths": feats_lengths} def encode( - self, speech: torch.Tensor, speech_lengths: torch.Tensor + self, speech: torch.Tensor, speech_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Frontend + Encoder. Note that this method is used by asr_inference.py - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) @@ -220,14 +226,14 @@ class MFCCA(FunASRModel): # Pre-encoder, e.g. used for raw input data if self.preencoder is not None: feats, feats_lengths = self.preencoder(feats, feats_lengths) - #pdb.set_trace() + # pdb.set_trace() encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size) assert encoder_out.size(0) == speech.size(0), ( encoder_out.size(), speech.size(0), ) - if(encoder_out.dim()==4): + if (encoder_out.dim() == 4): assert encoder_out.size(2) <= encoder_out_lens.max(), ( encoder_out.size(), encoder_out_lens.max(), @@ -241,7 +247,7 @@ class MFCCA(FunASRModel): return encoder_out, encoder_out_lens def _extract_feats( - self, speech: torch.Tensor, speech_lengths: torch.Tensor + self, speech: torch.Tensor, speech_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: assert speech_lengths.dim() == 1, speech_lengths.shape # for data-parallel @@ -259,11 +265,11 @@ class MFCCA(FunASRModel): return feats, feats_lengths, channel_size def _calc_att_loss( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ys_pad: torch.Tensor, - ys_pad_lens: torch.Tensor, + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, ): ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_in_lens = ys_pad_lens + 1 @@ -291,14 +297,14 @@ class MFCCA(FunASRModel): return loss_att, acc_att, cer_att, wer_att def _calc_ctc_loss( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ys_pad: torch.Tensor, - ys_pad_lens: torch.Tensor, + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, ): # Calc CTC loss - if(encoder_out.dim()==4): + if (encoder_out.dim() == 4): encoder_out = encoder_out.mean(1) loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) @@ -310,10 +316,10 @@ class MFCCA(FunASRModel): return loss_ctc, cer_ctc def _calc_rnnt_loss( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ys_pad: torch.Tensor, - ys_pad_lens: torch.Tensor, + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, ): - raise NotImplementedError + raise NotImplementedError \ No newline at end of file diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py index f414e4fd0..9d4f10663 100644 --- a/funasr/models/e2e_asr_paraformer.py +++ b/funasr/models/e2e_asr_paraformer.py @@ -12,23 +12,26 @@ import random import numpy as np from typeguard import check_argument_types +from funasr.layers.abs_normalize import AbsNormalize from funasr.losses.label_smoothing_loss import ( LabelSmoothingLoss, # noqa: H301 ) from funasr.models.ctc import CTC from funasr.models.decoder.abs_decoder import AbsDecoder from funasr.models.e2e_asr_common import ErrorCalculator +from funasr.models.encoder.abs_encoder import AbsEncoder +from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.postencoder.abs_postencoder import AbsPostEncoder from funasr.models.predictor.cif import mae_loss from funasr.models.preencoder.abs_preencoder import AbsPreEncoder -from funasr.models.base_model import FunASRModel +from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.modules.add_sos_eos import add_sos_eos from funasr.modules.nets_utils import make_pad_mask, pad_list from funasr.modules.nets_utils import th_accuracy from funasr.torch_utils.device_funcs import force_gatherable +from funasr.models.base_model import FunASRModel from funasr.models.predictor.cif import CifPredictorV3 - if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import autocast else: @@ -40,7 +43,7 @@ else: class Paraformer(FunASRModel): """ - Author: Speech Lab, Alibaba Group, China + Author: Speech Lab of DAMO Academy, Alibaba Group Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition https://arxiv.org/abs/2206.08317 """ @@ -49,10 +52,12 @@ class Paraformer(FunASRModel): self, vocab_size: int, token_list: Union[Tuple[str, ...], List[str]], - frontend: Optional[torch.nn.Module], - specaug: Optional[torch.nn.Module], - normalize: Optional[torch.nn.Module], - encoder: torch.nn.Module, + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], + preencoder: Optional[AbsPreEncoder], + encoder: AbsEncoder, + postencoder: Optional[AbsPostEncoder], decoder: AbsDecoder, ctc: CTC, ctc_weight: float = 0.5, @@ -92,8 +97,17 @@ class Paraformer(FunASRModel): self.frontend = frontend self.specaug = specaug self.normalize = normalize + self.preencoder = preencoder + self.postencoder = postencoder self.encoder = encoder + if not hasattr(self.encoder, "interctc_use_conditioning"): + self.encoder.interctc_use_conditioning = False + if self.encoder.interctc_use_conditioning: + self.encoder.conditioning_layer = torch.nn.Linear( + vocab_size, self.encoder.output_size() + ) + self.error_calculator = None if ctc_weight == 1.0: @@ -138,7 +152,6 @@ class Paraformer(FunASRModel): text_lengths: torch.Tensor, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Frontend + Encoder + Decoder + Calc loss - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) @@ -161,7 +174,9 @@ class Paraformer(FunASRModel): # 1. Encoder encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + intermediate_outs = None if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] encoder_out = encoder_out[0] loss_att, acc_att, cer_att, wer_att = None, None, None, None @@ -179,6 +194,30 @@ class Paraformer(FunASRModel): stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None stats["cer_ctc"] = cer_ctc + # Intermediate CTC (optional) + loss_interctc = 0.0 + if self.interctc_weight != 0.0 and intermediate_outs is not None: + for layer_idx, intermediate_out in intermediate_outs: + # we assume intermediate_out has the same length & padding + # as those of encoder_out + loss_ic, cer_ic = self._calc_ctc_loss( + intermediate_out, encoder_out_lens, text, text_lengths + ) + loss_interctc = loss_interctc + loss_ic + + # Collect Intermedaite CTC stats + stats["loss_interctc_layer{}".format(layer_idx)] = ( + loss_ic.detach() if loss_ic is not None else None + ) + stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic + + loss_interctc = loss_interctc / len(intermediate_outs) + + # calculate whole encoder loss + loss_ctc = ( + 1 - self.interctc_weight + ) * loss_ctc + self.interctc_weight * loss_interctc + # 2b. Attention decoder branch if self.ctc_weight != 1.0: loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss( @@ -229,7 +268,6 @@ class Paraformer(FunASRModel): self, speech: torch.Tensor, speech_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Frontend + Encoder. Note that this method is used by asr_inference.py - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) @@ -246,8 +284,29 @@ class Paraformer(FunASRModel): if self.normalize is not None: feats, feats_lengths = self.normalize(feats, feats_lengths) + # Pre-encoder, e.g. used for raw input data + if self.preencoder is not None: + feats, feats_lengths = self.preencoder(feats, feats_lengths) + # 4. Forward encoder - encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim2) + if self.encoder.interctc_use_conditioning: + encoder_out, encoder_out_lens, _ = self.encoder( + feats, feats_lengths, ctc=self.ctc + ) + else: + encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + # Post-encoder, e.g. NLU + if self.postencoder is not None: + encoder_out, encoder_out_lens = self.postencoder( + encoder_out, encoder_out_lens + ) assert encoder_out.size(0) == speech.size(0), ( encoder_out.size(), @@ -258,45 +317,18 @@ class Paraformer(FunASRModel): encoder_out_lens.max(), ) + if intermediate_outs is not None: + return (encoder_out, intermediate_outs), encoder_out_lens + return encoder_out, encoder_out_lens - def encode_chunk( - self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Frontend + Encoder. Note that this method is used by asr_inference.py - - Args: - speech: (Batch, Length, ...) - speech_lengths: (Batch, ) - """ - with autocast(False): - # 1. Extract feats - feats, feats_lengths = self._extract_feats(speech, speech_lengths) - - # 2. Data augmentation - if self.specaug is not None and self.training: - feats, feats_lengths = self.specaug(feats, feats_lengths) - - # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN - if self.normalize is not None: - feats, feats_lengths = self.normalize(feats, feats_lengths) - - # 4. Forward encoder - encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"]) - - return encoder_out, torch.tensor([encoder_out.size(1)]) - def calc_predictor(self, encoder_out, encoder_out_lens): encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( encoder_out.device) - pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask, - ignore_id=self.ignore_id) - return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index - - def calc_predictor_chunk(self, encoder_out, cache=None): - - pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor.forward_chunk(encoder_out, cache["encoder"]) + pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, + encoder_out_mask, + ignore_id=self.ignore_id) return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens): @@ -308,14 +340,6 @@ class Paraformer(FunASRModel): decoder_out = torch.log_softmax(decoder_out, dim=-1) return decoder_out, ys_pad_lens - def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None): - decoder_outs = self.decoder.forward_chunk( - encoder_out, sematic_embeds, cache["decoder"] - ) - decoder_out = decoder_outs - decoder_out = torch.log_softmax(decoder_out, dim=-1) - return decoder_out - def _extract_feats( self, speech: torch.Tensor, speech_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -342,9 +366,7 @@ class Paraformer(FunASRModel): ys_pad_lens: torch.Tensor, ) -> torch.Tensor: """Compute negative log likelihood(nll) from transformer-decoder - Normally, this function is called in batchify_nll. - Args: encoder_out: (Batch, Length, Dim) encoder_out_lens: (Batch,) @@ -381,7 +403,6 @@ class Paraformer(FunASRModel): batch_size: int = 100, ): """Compute negative log likelihood(nll) from transformer-decoder - To avoid OOM, this fuction seperate the input into batches. Then call nll for each batch and combine and return results. Args: @@ -521,9 +542,186 @@ class Paraformer(FunASRModel): return loss_ctc, cer_ctc -class ParaformerBert(Paraformer): +class ParaformerOnline(Paraformer): """ Author: Speech Lab, Alibaba Group, China + Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition + https://arxiv.org/abs/2206.08317 + """ + + def __init__( + self, *args, **kwargs, + ): + super().__init__(*args, **kwargs) + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Frontend + Encoder + Decoder + Calc loss + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert ( + speech.shape[0] + == speech_lengths.shape[0] + == text.shape[0] + == text_lengths.shape[0] + ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) + batch_size = speech.shape[0] + self.step_cur += 1 + # for data-parallel + text = text[:, : text_lengths.max()] + speech = speech[:, :speech_lengths.max()] + + # 1. Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + loss_att, acc_att, cer_att, wer_att = None, None, None, None + loss_ctc, cer_ctc = None, None + loss_pre = None + stats = dict() + + # 1. CTC branch + if self.ctc_weight != 0.0: + loss_ctc, cer_ctc = self._calc_ctc_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # Collect CTC branch stats + stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None + stats["cer_ctc"] = cer_ctc + + # Intermediate CTC (optional) + loss_interctc = 0.0 + if self.interctc_weight != 0.0 and intermediate_outs is not None: + for layer_idx, intermediate_out in intermediate_outs: + # we assume intermediate_out has the same length & padding + # as those of encoder_out + loss_ic, cer_ic = self._calc_ctc_loss( + intermediate_out, encoder_out_lens, text, text_lengths + ) + loss_interctc = loss_interctc + loss_ic + + # Collect Intermedaite CTC stats + stats["loss_interctc_layer{}".format(layer_idx)] = ( + loss_ic.detach() if loss_ic is not None else None + ) + stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic + + loss_interctc = loss_interctc / len(intermediate_outs) + + # calculate whole encoder loss + loss_ctc = ( + 1 - self.interctc_weight + ) * loss_ctc + self.interctc_weight * loss_interctc + + # 2b. Attention decoder branch + if self.ctc_weight != 1.0: + loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # 3. CTC-Att loss definition + if self.ctc_weight == 0.0: + loss = loss_att + loss_pre * self.predictor_weight + elif self.ctc_weight == 1.0: + loss = loss_ctc + else: + loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + + # Collect Attn branch stats + stats["loss_att"] = loss_att.detach() if loss_att is not None else None + stats["acc"] = acc_att + stats["cer"] = cer_att + stats["wer"] = wer_att + stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None + + stats["loss"] = torch.clone(loss.detach()) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def encode_chunk( + self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Frontend + Encoder. Note that this method is used by asr_inference.py + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Data augmentation + if self.specaug is not None and self.training: + feats, feats_lengths = self.specaug(feats, feats_lengths) + + # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # Pre-encoder, e.g. used for raw input data + if self.preencoder is not None: + feats, feats_lengths = self.preencoder(feats, feats_lengths) + + # 4. Forward encoder + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim2) + if self.encoder.interctc_use_conditioning: + encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk( + feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc + ) + else: + encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"]) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + # Post-encoder, e.g. NLU + if self.postencoder is not None: + encoder_out, encoder_out_lens = self.postencoder( + encoder_out, encoder_out_lens + ) + + if intermediate_outs is not None: + return (encoder_out, intermediate_outs), encoder_out_lens + + return encoder_out, torch.tensor([encoder_out.size(1)]) + + def calc_predictor_chunk(self, encoder_out, cache=None): + + pre_acoustic_embeds, pre_token_length = \ + self.predictor.forward_chunk(encoder_out, cache["encoder"]) + return pre_acoustic_embeds, pre_token_length + + def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None): + decoder_outs = self.decoder.forward_chunk( + encoder_out, sematic_embeds, cache["decoder"] + ) + decoder_out = decoder_outs + decoder_out = torch.log_softmax(decoder_out, dim=-1) + return decoder_out + + +class ParaformerBert(Paraformer): + """ + Author: Speech Lab of DAMO Academy, Alibaba Group Paraformer2: advanced paraformer with LFMMI and bert for non-autoregressive end-to-end speech recognition """ @@ -531,11 +729,11 @@ class ParaformerBert(Paraformer): self, vocab_size: int, token_list: Union[Tuple[str, ...], List[str]], - frontend: Optional[torch.nn.Module], - specaug: Optional[torch.nn.Module], - normalize: Optional[torch.nn.Module], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], preencoder: Optional[AbsPreEncoder], - encoder: torch.nn.Module, + encoder: AbsEncoder, postencoder: Optional[AbsPostEncoder], decoder: AbsDecoder, ctc: CTC, @@ -690,7 +888,6 @@ class ParaformerBert(Paraformer): embed_lengths: torch.Tensor = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Frontend + Encoder + Decoder + Calc loss - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) @@ -799,74 +996,73 @@ class ParaformerBert(Paraformer): class BiCifParaformer(Paraformer): - """ Paraformer model with an extra cif predictor to conduct accurate timestamp prediction """ def __init__( - self, - vocab_size: int, - token_list: Union[Tuple[str, ...], List[str]], - frontend: Optional[torch.nn.Module], - specaug: Optional[torch.nn.Module], - normalize: Optional[torch.nn.Module], - preencoder: Optional[AbsPreEncoder], - encoder: torch.nn.Module, - postencoder: Optional[AbsPostEncoder], - decoder: AbsDecoder, - ctc: CTC, - ctc_weight: float = 0.5, - interctc_weight: float = 0.0, - ignore_id: int = -1, - blank_id: int = 0, - sos: int = 1, - eos: int = 2, - lsm_weight: float = 0.0, - length_normalized_loss: bool = False, - report_cer: bool = True, - report_wer: bool = True, - sym_space: str = "", - sym_blank: str = "", - extract_feats_in_collect_stats: bool = True, - predictor = None, - predictor_weight: float = 0.0, - predictor_bias: int = 0, - sampling_ratio: float = 0.2, + self, + vocab_size: int, + token_list: Union[Tuple[str, ...], List[str]], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], + preencoder: Optional[AbsPreEncoder], + encoder: AbsEncoder, + postencoder: Optional[AbsPostEncoder], + decoder: AbsDecoder, + ctc: CTC, + ctc_weight: float = 0.5, + interctc_weight: float = 0.0, + ignore_id: int = -1, + blank_id: int = 0, + sos: int = 1, + eos: int = 2, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + report_cer: bool = True, + report_wer: bool = True, + sym_space: str = "", + sym_blank: str = "", + extract_feats_in_collect_stats: bool = True, + predictor=None, + predictor_weight: float = 0.0, + predictor_bias: int = 0, + sampling_ratio: float = 0.2, ): assert check_argument_types() assert 0.0 <= ctc_weight <= 1.0, ctc_weight assert 0.0 <= interctc_weight < 1.0, interctc_weight super().__init__( - vocab_size=vocab_size, - token_list=token_list, - frontend=frontend, - specaug=specaug, - normalize=normalize, - preencoder=preencoder, - encoder=encoder, - postencoder=postencoder, - decoder=decoder, - ctc=ctc, - ctc_weight=ctc_weight, - interctc_weight=interctc_weight, - ignore_id=ignore_id, - blank_id=blank_id, - sos=sos, - eos=eos, - lsm_weight=lsm_weight, - length_normalized_loss=length_normalized_loss, - report_cer=report_cer, - report_wer=report_wer, - sym_space=sym_space, - sym_blank=sym_blank, - extract_feats_in_collect_stats=extract_feats_in_collect_stats, - predictor=predictor, - predictor_weight=predictor_weight, - predictor_bias=predictor_bias, - sampling_ratio=sampling_ratio, + vocab_size=vocab_size, + token_list=token_list, + frontend=frontend, + specaug=specaug, + normalize=normalize, + preencoder=preencoder, + encoder=encoder, + postencoder=postencoder, + decoder=decoder, + ctc=ctc, + ctc_weight=ctc_weight, + interctc_weight=interctc_weight, + ignore_id=ignore_id, + blank_id=blank_id, + sos=sos, + eos=eos, + lsm_weight=lsm_weight, + length_normalized_loss=length_normalized_loss, + report_cer=report_cer, + report_wer=report_wer, + sym_space=sym_space, + sym_blank=sym_blank, + extract_feats_in_collect_stats=extract_feats_in_collect_stats, + predictor=predictor, + predictor_weight=predictor_weight, + predictor_bias=predictor_bias, + sampling_ratio=sampling_ratio, ) assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3" @@ -888,21 +1084,77 @@ class BiCifParaformer(Paraformer): loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2) return loss_pre2 - + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( + encoder_out.device) + if self.predictor_bias == 1: + _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + ys_pad_lens = ys_pad_lens + self.predictor_bias + pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad, + encoder_out_mask, + ignore_id=self.ignore_id) + + # 0. sampler + decoder_out_1st = None + if self.sampling_ratio > 0.0: + if self.step_cur < 2: + logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)) + sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, + pre_acoustic_embeds) + else: + if self.step_cur < 2: + logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)) + sematic_embeds = pre_acoustic_embeds + + # 1. Forward decoder + decoder_outs = self.decoder( + encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens + ) + decoder_out, _ = decoder_outs[0], decoder_outs[1] + + if decoder_out_1st is None: + decoder_out_1st = decoder_out + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_pad) + acc_att = th_accuracy( + decoder_out_1st.view(-1, self.vocab_size), + ys_pad, + ignore_label=self.ignore_id, + ) + loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length) + + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out_1st.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) + + return loss_att, acc_att, cer_att, wer_att, loss_pre + def calc_predictor(self, encoder_out, encoder_out_lens): encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( encoder_out.device) - pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, None, encoder_out_mask, - ignore_id=self.ignore_id) + pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, + None, + encoder_out_mask, + ignore_id=self.ignore_id) return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index - + def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num): encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( encoder_out.device) ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out, - encoder_out_mask, - token_num) + encoder_out_mask, + token_num) return ds_alphas, ds_cif_peak, us_alphas, us_peaks def forward( @@ -913,7 +1165,6 @@ class BiCifParaformer(Paraformer): text_lengths: torch.Tensor, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Frontend + Encoder + Decoder + Calc loss - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) @@ -996,7 +1247,8 @@ class BiCifParaformer(Paraformer): elif self.ctc_weight == 1.0: loss = loss_ctc else: - loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5 + loss = self.ctc_weight * loss_ctc + ( + 1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5 # Collect Attn branch stats stats["loss_att"] = loss_att.detach() if loss_att is not None else None @@ -1022,11 +1274,11 @@ class ContextualParaformer(Paraformer): self, vocab_size: int, token_list: Union[Tuple[str, ...], List[str]], - frontend: Optional[torch.nn.Module], - specaug: Optional[torch.nn.Module], - normalize: Optional[torch.nn.Module], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], preencoder: Optional[AbsPreEncoder], - encoder: torch.nn.Module, + encoder: AbsEncoder, postencoder: Optional[AbsPostEncoder], decoder: AbsDecoder, ctc: CTC, @@ -1120,7 +1372,6 @@ class ContextualParaformer(Paraformer): text_lengths: torch.Tensor, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Frontend + Encoder + Decoder + Calc loss - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) @@ -1504,4 +1755,4 @@ class ContextualParaformer(Paraformer): "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape)) - return var_dict_torch_update + return var_dict_torch_update \ No newline at end of file diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py index b4a3fa290..da7c6745a 100644 --- a/funasr/models/e2e_diar_eend_ola.py +++ b/funasr/models/e2e_diar_eend_ola.py @@ -15,8 +15,8 @@ from funasr.models.frontend.wav_frontend import WavFrontendMel23 from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor from funasr.modules.eend_ola.utils.power import generate_mapping_dict -from funasr.models.base_model import FunASRModel from funasr.torch_utils.device_funcs import force_gatherable +from funasr.models.base_model import FunASRModel if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): pass @@ -91,7 +91,6 @@ class DiarEENDOLAModel(FunASRModel): text_lengths: torch.Tensor, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Frontend + Encoder + Decoder + Calc loss - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py index dc7135f4e..9c3fb92bc 100644 --- a/funasr/models/e2e_diar_sond.py +++ b/funasr/models/e2e_diar_sond.py @@ -14,9 +14,15 @@ import torch from torch.nn import functional as F from typeguard import check_argument_types +from funasr.modules.nets_utils import to_device from funasr.modules.nets_utils import make_pad_mask -from funasr.models.base_model import FunASRModel +from funasr.models.decoder.abs_decoder import AbsDecoder +from funasr.models.encoder.abs_encoder import AbsEncoder +from funasr.models.frontend.abs_frontend import AbsFrontend +from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.layers.abs_normalize import AbsNormalize from funasr.torch_utils.device_funcs import force_gatherable +from funasr.models.base_model import FunASRModel from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy from funasr.utils.misc import int2vec @@ -30,16 +36,20 @@ else: class DiarSondModel(FunASRModel): - """Speaker overlap-aware neural diarization model - reference: https://arxiv.org/abs/2211.10243 + """ + Author: Speech Lab, Alibaba Group, China + SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis + https://arxiv.org/abs/2211.10243 + TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization + https://arxiv.org/abs/2303.05397 """ def __init__( self, vocab_size: int, - frontend: Optional[torch.nn.Module], - specaug: Optional[torch.nn.Module], - normalize: Optional[torch.nn.Module], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], encoder: torch.nn.Module, speaker_encoder: Optional[torch.nn.Module], ci_scorer: torch.nn.Module, @@ -105,7 +115,6 @@ class DiarSondModel(FunASRModel): binary_labels_lengths: torch.Tensor = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss - Args: speech: (Batch, samples) or (Batch, frames, input_size) speech_lengths: (Batch,) default None for chunk interator, @@ -342,7 +351,7 @@ class DiarSondModel(FunASRModel): cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1]) cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1]) - if isinstance(self.ci_scorer, torch.nn.Module): + if isinstance(self.ci_scorer, AbsEncoder): ci_simi = self.ci_scorer(ge_in, ge_len)[0] ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute([0, 2, 1]) else: @@ -381,7 +390,6 @@ class DiarSondModel(FunASRModel): self, speech: torch.Tensor, speech_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Frontend + Encoder - Args: speech: (Batch, Length, ...) speech_lengths: (Batch,) @@ -481,4 +489,4 @@ class DiarSondModel(FunASRModel): speaker_miss, speaker_falarm, speaker_error, - ) + ) \ No newline at end of file diff --git a/funasr/models/e2e_sv.py b/funasr/models/e2e_sv.py index 582c25dfb..bd82c7c35 100644 --- a/funasr/models/e2e_sv.py +++ b/funasr/models/e2e_sv.py @@ -1,3 +1,8 @@ + +""" +Author: Speech Lab, Alibaba Group, China +""" + import logging from contextlib import contextmanager from distutils.version import LooseVersion @@ -10,11 +15,22 @@ from typing import Union import torch from typeguard import check_argument_types +from funasr.layers.abs_normalize import AbsNormalize +from funasr.losses.label_smoothing_loss import ( + LabelSmoothingLoss, # noqa: H301 +) +from funasr.models.ctc import CTC from funasr.models.decoder.abs_decoder import AbsDecoder +from funasr.models.encoder.abs_encoder import AbsEncoder +from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.postencoder.abs_postencoder import AbsPostEncoder from funasr.models.preencoder.abs_preencoder import AbsPreEncoder -from funasr.models.base_model import FunASRModel +from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.modules.add_sos_eos import add_sos_eos +from funasr.modules.e2e_asr_common import ErrorCalculator +from funasr.modules.nets_utils import th_accuracy from funasr.torch_utils.device_funcs import force_gatherable +from funasr.models.base_model import FunASRModel if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import autocast @@ -32,11 +48,11 @@ class ESPnetSVModel(FunASRModel): self, vocab_size: int, token_list: Union[Tuple[str, ...], List[str]], - frontend: Optional[torch.nn.Module], - specaug: Optional[torch.nn.Module], - normalize: Optional[torch.nn.Module], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], preencoder: Optional[AbsPreEncoder], - encoder: torch.nn.Module, + encoder: AbsEncoder, postencoder: Optional[AbsPostEncoder], pooling_layer: torch.nn.Module, decoder: AbsDecoder, @@ -65,7 +81,6 @@ class ESPnetSVModel(FunASRModel): text_lengths: torch.Tensor, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Frontend + Encoder + Decoder + Calc loss - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) @@ -206,7 +221,6 @@ class ESPnetSVModel(FunASRModel): self, speech: torch.Tensor, speech_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Frontend + Encoder. Note that this method is used by asr_inference.py - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) @@ -256,4 +270,4 @@ class ESPnetSVModel(FunASRModel): else: # No frontend and no feature extract feats, feats_lengths = speech, speech_lengths - return feats, feats_lengths + return feats, feats_lengths \ No newline at end of file diff --git a/funasr/models/e2e_tp.py b/funasr/models/e2e_tp.py index c5dc63c59..39419c8d4 100644 --- a/funasr/models/e2e_tp.py +++ b/funasr/models/e2e_tp.py @@ -2,20 +2,24 @@ import logging from contextlib import contextmanager from distutils.version import LooseVersion from typing import Dict +from typing import List from typing import Optional from typing import Tuple +from typing import Union import torch +import numpy as np from typeguard import check_argument_types +from funasr.models.encoder.abs_encoder import AbsEncoder +from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.predictor.cif import mae_loss -from funasr.models.base_model import FunASRModel from funasr.modules.add_sos_eos import add_sos_eos from funasr.modules.nets_utils import make_pad_mask, pad_list from funasr.torch_utils.device_funcs import force_gatherable +from funasr.models.base_model import FunASRModel from funasr.models.predictor.cif import CifPredictorV3 - if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import autocast else: @@ -25,15 +29,15 @@ else: yield -class TimestampPredictor(FunASRModel): +class TimestampPredictor(AbsESPnetModel): """ - Author: Speech Lab, Alibaba Group, China + Author: Speech Lab of DAMO Academy, Alibaba Group """ def __init__( self, - frontend: Optional[torch.nn.Module], - encoder: torch.nn.Module, + frontend: Optional[AbsFrontend], + encoder: AbsEncoder, predictor: CifPredictorV3, predictor_bias: int = 0, token_list=None, @@ -51,7 +55,7 @@ class TimestampPredictor(FunASRModel): self.predictor_bias = predictor_bias self.criterion_pre = mae_loss() self.token_list = token_list - + def forward( self, speech: torch.Tensor, @@ -60,7 +64,6 @@ class TimestampPredictor(FunASRModel): text_lengths: torch.Tensor, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Frontend + Encoder + Decoder + Calc loss - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) @@ -108,7 +111,6 @@ class TimestampPredictor(FunASRModel): self, speech: torch.Tensor, speech_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Frontend + Encoder. Note that this method is used by asr_inference.py - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) @@ -123,7 +125,7 @@ class TimestampPredictor(FunASRModel): encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) return encoder_out, encoder_out_lens - + def _extract_feats( self, speech: torch.Tensor, speech_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -146,8 +148,8 @@ class TimestampPredictor(FunASRModel): encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( encoder_out.device) ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out, - encoder_out_mask, - token_num) + encoder_out_mask, + token_num) return ds_alphas, ds_cif_peak, us_alphas, us_peaks def collect_feats( diff --git a/funasr/models/e2e_uni_asr.py b/funasr/models/e2e_uni_asr.py index 0c533899f..d08ea37fd 100644 --- a/funasr/models/e2e_uni_asr.py +++ b/funasr/models/e2e_uni_asr.py @@ -17,10 +17,13 @@ from funasr.losses.label_smoothing_loss import ( LabelSmoothingLoss, # noqa: H301 ) from funasr.models.ctc import CTC -from funasr.models.encoder.abs_encoder import AbsEncoder from funasr.models.decoder.abs_decoder import AbsDecoder +from funasr.models.encoder.abs_encoder import AbsEncoder +from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.postencoder.abs_postencoder import AbsPostEncoder from funasr.models.preencoder.abs_preencoder import AbsPreEncoder +from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.layers.abs_normalize import AbsNormalize from funasr.torch_utils.device_funcs import force_gatherable from funasr.models.base_model import FunASRModel from funasr.modules.streaming_utils.chunk_utilis import sequence_mask @@ -37,18 +40,18 @@ else: class UniASR(FunASRModel): """ - Author: Speech Lab, Alibaba Group, China + Author: Speech Lab of DAMO Academy, Alibaba Group """ def __init__( self, vocab_size: int, token_list: Union[Tuple[str, ...], List[str]], - frontend: Optional[torch.nn.Module], - specaug: Optional[torch.nn.Module], - normalize: Optional[torch.nn.Module], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], preencoder: Optional[AbsPreEncoder], - encoder: torch.nn.Module, + encoder: AbsEncoder, postencoder: Optional[AbsPostEncoder], decoder: AbsDecoder, ctc: CTC, @@ -176,7 +179,6 @@ class UniASR(FunASRModel): decoding_ind: int = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Frontend + Encoder + Decoder + Calc loss - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) @@ -466,7 +468,6 @@ class UniASR(FunASRModel): self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Frontend + Encoder. Note that this method is used by asr_inference.py - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) @@ -530,7 +531,6 @@ class UniASR(FunASRModel): ind: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Frontend + Encoder. Note that this method is used by asr_inference.py - Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) @@ -624,9 +624,7 @@ class UniASR(FunASRModel): ys_pad_lens: torch.Tensor, ) -> torch.Tensor: """Compute negative log likelihood(nll) from transformer-decoder - Normally, this function is called in batchify_nll. - Args: encoder_out: (Batch, Length, Dim) encoder_out_lens: (Batch,) @@ -663,7 +661,6 @@ class UniASR(FunASRModel): batch_size: int = 100, ): """Compute negative log likelihood(nll) from transformer-decoder - To avoid OOM, this fuction seperate the input into batches. Then call nll for each batch and combine and return results. Args: @@ -1069,4 +1066,3 @@ class UniASR(FunASRModel): ys_hat = self.ctc2.argmax(encoder_out).data cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) return loss_ctc, cer_ctc - diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py index ff3742949..e477750d0 100644 --- a/funasr/models/e2e_vad.py +++ b/funasr/models/e2e_vad.py @@ -35,6 +35,12 @@ class VadDetectMode(Enum): class VADXOptions: + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Deep-FSMN for Large Vocabulary Continuous Speech Recognition + https://arxiv.org/abs/1803.05030 + """ + def __init__( self, sample_rate: int = 16000, @@ -99,6 +105,12 @@ class VADXOptions: class E2EVadSpeechBufWithDoa(object): + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Deep-FSMN for Large Vocabulary Continuous Speech Recognition + https://arxiv.org/abs/1803.05030 + """ + def __init__(self): self.start_ms = 0 self.end_ms = 0 @@ -117,6 +129,12 @@ class E2EVadSpeechBufWithDoa(object): class E2EVadFrameProb(object): + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Deep-FSMN for Large Vocabulary Continuous Speech Recognition + https://arxiv.org/abs/1803.05030 + """ + def __init__(self): self.noise_prob = 0.0 self.speech_prob = 0.0 @@ -126,6 +144,12 @@ class E2EVadFrameProb(object): class WindowDetector(object): + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Deep-FSMN for Large Vocabulary Continuous Speech Recognition + https://arxiv.org/abs/1803.05030 + """ + def __init__(self, window_size_ms: int, sil_to_speech_time: int, speech_to_sil_time: int, frame_size_ms: int): self.window_size_ms = window_size_ms @@ -192,6 +216,12 @@ class WindowDetector(object): class E2EVadModel(nn.Module): + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Deep-FSMN for Large Vocabulary Continuous Speech Recognition + https://arxiv.org/abs/1803.05030 + """ + def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None): super(E2EVadModel, self).__init__() self.vad_opts = VADXOptions(**vad_post_args) @@ -286,7 +316,7 @@ class E2EVadModel(nn.Module): 0.000001)) def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None: - scores = self.encoder(feats, in_cache) # return B * T * D + scores = self.encoder(feats, in_cache).to('cpu') # return B * T * D assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match" self.vad_opts.nn_eval_block_size = scores.shape[1] self.frm_cnt += scores.shape[1] # count total frames @@ -444,7 +474,7 @@ class E2EVadModel(nn.Module): - 1)) / self.vad_opts.noise_frame_num_used_for_snr return frame_state - + def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]: @@ -460,8 +490,9 @@ class E2EVadModel(nn.Module): segment_batch = [] if len(self.output_data_buf) > 0: for i in range(self.output_data_buf_offset, len(self.output_data_buf)): - if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[ - i].contain_seg_end_point: + if not is_final and ( + not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[ + i].contain_seg_end_point): continue segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms] segment_batch.append(segment) @@ -474,11 +505,11 @@ class E2EVadModel(nn.Module): return segments, in_cache def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(), - is_final: bool = False, max_end_sil: int = 800 - ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]: + is_final: bool = False, max_end_sil: int = 800 + ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]: self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres self.waveform = waveform # compute decibel for each frame - + self.ComputeScores(feats, in_cache) self.ComputeDecibel() if not is_final: