From 25590804bab251dc60a055b0093a369ef6fdd6ee Mon Sep 17 00:00:00 2001 From: "haoneng.lhn" Date: Mon, 17 Apr 2023 15:19:56 +0800 Subject: [PATCH] update --- funasr/models/e2e_asr_paraformer.py | 241 +++++++++++++++++++------- funasr/models/encoder/sanm_encoder.py | 6 +- funasr/modules/embedding.py | 23 ++- funasr/tasks/asr.py | 3 +- 4 files changed, 204 insertions(+), 69 deletions(-) diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py index 5c8560d00..39902510e 100644 --- a/funasr/models/e2e_asr_paraformer.py +++ b/funasr/models/e2e_asr_paraformer.py @@ -325,56 +325,6 @@ class Paraformer(AbsESPnetModel): return encoder_out, encoder_out_lens - def encode_chunk( - self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Frontend + Encoder. Note that this method is used by asr_inference.py - - Args: - speech: (Batch, Length, ...) - speech_lengths: (Batch, ) - """ - with autocast(False): - # 1. Extract feats - feats, feats_lengths = self._extract_feats(speech, speech_lengths) - - # 2. Data augmentation - if self.specaug is not None and self.training: - feats, feats_lengths = self.specaug(feats, feats_lengths) - - # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN - if self.normalize is not None: - feats, feats_lengths = self.normalize(feats, feats_lengths) - - # Pre-encoder, e.g. used for raw input data - if self.preencoder is not None: - feats, feats_lengths = self.preencoder(feats, feats_lengths) - - # 4. Forward encoder - # feats: (Batch, Length, Dim) - # -> encoder_out: (Batch, Length2, Dim2) - if self.encoder.interctc_use_conditioning: - encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk( - feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc - ) - else: - encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"]) - intermediate_outs = None - if isinstance(encoder_out, tuple): - intermediate_outs = encoder_out[1] - encoder_out = encoder_out[0] - - # Post-encoder, e.g. NLU - if self.postencoder is not None: - encoder_out, encoder_out_lens = self.postencoder( - encoder_out, encoder_out_lens - ) - - if intermediate_outs is not None: - return (encoder_out, intermediate_outs), encoder_out_lens - - return encoder_out, torch.tensor([encoder_out.size(1)]) - def calc_predictor(self, encoder_out, encoder_out_lens): encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( @@ -383,11 +333,6 @@ class Paraformer(AbsESPnetModel): ignore_id=self.ignore_id) return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index - def calc_predictor_chunk(self, encoder_out, cache=None): - - pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor.forward_chunk(encoder_out, cache["encoder"]) - return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index - def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens): decoder_outs = self.decoder( @@ -397,14 +342,6 @@ class Paraformer(AbsESPnetModel): decoder_out = torch.log_softmax(decoder_out, dim=-1) return decoder_out, ys_pad_lens - def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None): - decoder_outs = self.decoder.forward_chunk( - encoder_out, sematic_embeds, cache["decoder"] - ) - decoder_out = decoder_outs - decoder_out = torch.log_softmax(decoder_out, dim=-1) - return decoder_out - def _extract_feats( self, speech: torch.Tensor, speech_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -610,6 +547,184 @@ class Paraformer(AbsESPnetModel): return loss_ctc, cer_ctc +class ParaformerOnline(Paraformer): + """ + Author: Speech Lab, Alibaba Group, China + Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition + https://arxiv.org/abs/2206.08317 + """ + + def __init__( + self, *args, **kwargs, + ): + super().__init__(*args, **kwargs) + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Frontend + Encoder + Decoder + Calc loss + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert ( + speech.shape[0] + == speech_lengths.shape[0] + == text.shape[0] + == text_lengths.shape[0] + ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) + batch_size = speech.shape[0] + self.step_cur += 1 + # for data-parallel + text = text[:, : text_lengths.max()] + speech = speech[:, :speech_lengths.max()] + + # 1. Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + loss_att, acc_att, cer_att, wer_att = None, None, None, None + loss_ctc, cer_ctc = None, None + loss_pre = None + stats = dict() + + # 1. CTC branch + if self.ctc_weight != 0.0: + loss_ctc, cer_ctc = self._calc_ctc_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # Collect CTC branch stats + stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None + stats["cer_ctc"] = cer_ctc + + # Intermediate CTC (optional) + loss_interctc = 0.0 + if self.interctc_weight != 0.0 and intermediate_outs is not None: + for layer_idx, intermediate_out in intermediate_outs: + # we assume intermediate_out has the same length & padding + # as those of encoder_out + loss_ic, cer_ic = self._calc_ctc_loss( + intermediate_out, encoder_out_lens, text, text_lengths + ) + loss_interctc = loss_interctc + loss_ic + + # Collect Intermedaite CTC stats + stats["loss_interctc_layer{}".format(layer_idx)] = ( + loss_ic.detach() if loss_ic is not None else None + ) + stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic + + loss_interctc = loss_interctc / len(intermediate_outs) + + # calculate whole encoder loss + loss_ctc = ( + 1 - self.interctc_weight + ) * loss_ctc + self.interctc_weight * loss_interctc + + # 2b. Attention decoder branch + if self.ctc_weight != 1.0: + loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # 3. CTC-Att loss definition + if self.ctc_weight == 0.0: + loss = loss_att + loss_pre * self.predictor_weight + elif self.ctc_weight == 1.0: + loss = loss_ctc + else: + loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + + # Collect Attn branch stats + stats["loss_att"] = loss_att.detach() if loss_att is not None else None + stats["acc"] = acc_att + stats["cer"] = cer_att + stats["wer"] = wer_att + stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None + + stats["loss"] = torch.clone(loss.detach()) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def encode_chunk( + self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Frontend + Encoder. Note that this method is used by asr_inference.py + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Data augmentation + if self.specaug is not None and self.training: + feats, feats_lengths = self.specaug(feats, feats_lengths) + + # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # Pre-encoder, e.g. used for raw input data + if self.preencoder is not None: + feats, feats_lengths = self.preencoder(feats, feats_lengths) + + # 4. Forward encoder + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim2) + if self.encoder.interctc_use_conditioning: + encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk( + feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc + ) + else: + encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"]) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + # Post-encoder, e.g. NLU + if self.postencoder is not None: + encoder_out, encoder_out_lens = self.postencoder( + encoder_out, encoder_out_lens + ) + + if intermediate_outs is not None: + return (encoder_out, intermediate_outs), encoder_out_lens + + return encoder_out, torch.tensor([encoder_out.size(1)]) + + def calc_predictor_chunk(self, encoder_out, cache=None): + + pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = \ + self.predictor.forward_chunk(encoder_out, cache["encoder"]) + return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index + + def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None): + decoder_outs = self.decoder.forward_chunk( + encoder_out, sematic_embeds, cache["decoder"] + ) + decoder_out = decoder_outs + decoder_out = torch.log_softmax(decoder_out, dim=-1) + return decoder_out + + class ParaformerBert(Paraformer): """ Author: Speech Lab of DAMO Academy, Alibaba Group diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py index 7ac912137..f2502bbb6 100644 --- a/funasr/models/encoder/sanm_encoder.py +++ b/funasr/models/encoder/sanm_encoder.py @@ -11,7 +11,7 @@ from typeguard import check_argument_types import numpy as np from funasr.modules.nets_utils import make_pad_mask from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask -from funasr.modules.embedding import SinusoidalPositionEncoder +from funasr.modules.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder from funasr.modules.layer_norm import LayerNorm from funasr.modules.multi_layer_conv import Conv1dLinear from funasr.modules.multi_layer_conv import MultiLayeredConv1d @@ -180,6 +180,8 @@ class SANMEncoder(AbsEncoder): self.embed = torch.nn.Linear(input_size, output_size) elif input_layer == "pe": self.embed = SinusoidalPositionEncoder() + elif input_layer == "pe_online": + self.embed = StreamSinusoidalPositionEncoder() else: raise ValueError("unknown input_layer: " + input_layer) self.normalize_before = normalize_before @@ -357,7 +359,7 @@ class SANMEncoder(AbsEncoder): if self.embed is None: xs_pad = xs_pad else: - xs_pad = self.embed.forward_chunk(xs_pad, cache) + xs_pad = self.embed(xs_pad, cache) encoder_outs = self.encoders0(xs_pad, None, None, None, None) xs_pad, masks = encoder_outs[0], encoder_outs[1] diff --git a/funasr/modules/embedding.py b/funasr/modules/embedding.py index 79ca0b2f8..4b292a79b 100644 --- a/funasr/modules/embedding.py +++ b/funasr/modules/embedding.py @@ -407,7 +407,24 @@ class SinusoidalPositionEncoder(torch.nn.Module): return x + position_encoding - def forward_chunk(self, x, cache=None): +class StreamSinusoidalPositionEncoder(torch.nn.Module): + ''' + + ''' + def __int__(self, d_model=80, dropout_rate=0.1): + pass + + def encode(self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32): + batch_size = positions.size(0) + positions = positions.type(dtype) + log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype)) / (depth / 2 - 1) + inv_timescales = torch.exp(torch.arange(depth / 2).type(dtype) * (-log_timescale_increment)) + inv_timescales = torch.reshape(inv_timescales, [batch_size, -1]) + scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(inv_timescales, [1, 1, -1]) + encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) + return encoding.type(dtype) + + def forward(self, x, cache=None): start_idx = 0 pad_left = 0 pad_right = 0 @@ -419,8 +436,8 @@ class SinusoidalPositionEncoder(torch.nn.Module): positions = torch.arange(1, timesteps+start_idx+1)[None, :] position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device) outputs = x + position_encoding[:, start_idx: start_idx + timesteps] - outputs = outputs.transpose(1,2) + outputs = outputs.transpose(1, 2) outputs = F.pad(outputs, (pad_left, pad_right)) - outputs = outputs.transpose(1,2) + outputs = outputs.transpose(1, 2) return outputs diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py index e15147332..52a0ce753 100644 --- a/funasr/tasks/asr.py +++ b/funasr/tasks/asr.py @@ -39,7 +39,7 @@ from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN from funasr.models.decoder.transformer_decoder import TransformerDecoder from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder from funasr.models.e2e_asr import ESPnetASRModel -from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer +from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer from funasr.models.e2e_tp import TimestampPredictor from funasr.models.e2e_asr_mfcca import MFCCA from funasr.models.e2e_uni_asr import UniASR @@ -121,6 +121,7 @@ model_choices = ClassChoices( asr=ESPnetASRModel, uniasr=UniASR, paraformer=Paraformer, + paraformer_online=ParaformerOnline, paraformer_bert=ParaformerBert, bicif_paraformer=BiCifParaformer, contextual_paraformer=ContextualParaformer,