diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py index d96464360..bd36907f7 100644 --- a/funasr/bin/asr_inference_rnnt.py +++ b/funasr/bin/asr_inference_rnnt.py @@ -188,18 +188,15 @@ class Speech2Text: self.frontend = frontend self.window_size = self.chunk_size + self.right_context - self._ctx = self.asr_model.encoder.get_encoder_input_size( - self.window_size - ) + if self.streaming: + self._ctx = self.asr_model.encoder.get_encoder_input_size( + self.window_size + ) - #self.last_chunk_length = ( - # self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 - #) * self.hop_length - - self.last_chunk_length = ( - self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 - ) - self.reset_inference_cache() + self.last_chunk_length = ( + self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 + ) + self.reset_inference_cache() def reset_inference_cache(self) -> None: """Reset Speech2Text parameters.""" diff --git a/funasr/models/decoder/rnnt_decoder.py b/funasr/models/decoder/rnnt_decoder.py index 5401ab20c..a0fe9eadc 100644 --- a/funasr/models/decoder/rnnt_decoder.py +++ b/funasr/models/decoder/rnnt_decoder.py @@ -33,6 +33,7 @@ class RNNTDecoder(torch.nn.Module): dropout_rate: float = 0.0, embed_dropout_rate: float = 0.0, embed_pad: int = 0, + use_embed_mask: bool = False, ) -> None: """Construct a RNNDecoder object.""" super().__init__() @@ -66,6 +67,15 @@ class RNNTDecoder(torch.nn.Module): self.device = next(self.parameters()).device self.score_cache = {} + + self.use_embed_mask = use_embed_mask + if self.use_embed_mask: + self._embed_mask = SpecAug( + time_mask_width_range=3, + num_time_mask=4, + apply_freq_mask=False, + apply_time_warp=False + ) def forward( self, @@ -88,6 +98,8 @@ class RNNTDecoder(torch.nn.Module): states = self.init_state(labels.size(0)) dec_embed = self.dropout_embed(self.embed(labels)) + if self.use_embed_mask and self.training: + dec_embed = self._embed_mask(dec_embed, label_lens)[0] dec_out, states = self.rnn_forward(dec_embed, states) return dec_out diff --git a/funasr/models/e2e_asr_transducer.py b/funasr/models/e2e_asr_transducer.py index f8ba0f0c6..a5aaa6c52 100644 --- a/funasr/models/e2e_asr_transducer.py +++ b/funasr/models/e2e_asr_transducer.py @@ -12,7 +12,7 @@ from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.models.decoder.rnnt_decoder import RNNTDecoder from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder -from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder +from funasr.models.encoder.abs_encoder import AbsEncoder from funasr.models.joint_net.joint_network import JointNetwork from funasr.modules.nets_utils import get_transducer_task_io from funasr.layers.abs_normalize import AbsNormalize @@ -62,7 +62,7 @@ class TransducerModel(AbsESPnetModel): frontend: Optional[AbsFrontend], specaug: Optional[AbsSpecAug], normalize: Optional[AbsNormalize], - encoder: Encoder, + encoder: AbsEncoder, decoder: RNNTDecoder, joint_network: JointNetwork, att_decoder: Optional[AbsAttDecoder] = None, @@ -286,7 +286,7 @@ class TransducerModel(AbsESPnetModel): feats, feats_lengths = self.normalize(feats, feats_lengths) # 4. Forward encoder - encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths) + encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) assert encoder_out.size(0) == speech.size(0), ( encoder_out.size(), @@ -515,7 +515,7 @@ class UnifiedTransducerModel(AbsESPnetModel): frontend: Optional[AbsFrontend], specaug: Optional[AbsSpecAug], normalize: Optional[AbsNormalize], - encoder: Encoder, + encoder: AbsEncoder, decoder: RNNTDecoder, joint_network: JointNetwork, att_decoder: Optional[AbsAttDecoder] = None, diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py index 9777ceed6..434f2a480 100644 --- a/funasr/models/encoder/conformer_encoder.py +++ b/funasr/models/encoder/conformer_encoder.py @@ -307,7 +307,7 @@ class ChunkEncoderLayer(torch.nn.Module): feed_forward: torch.nn.Module, feed_forward_macaron: torch.nn.Module, conv_mod: torch.nn.Module, - norm_class: torch.nn.Module = torch.nn.LayerNorm, + norm_class: torch.nn.Module = LayerNorm, norm_args: Dict = {}, dropout_rate: float = 0.0, ) -> None: @@ -1145,7 +1145,7 @@ class ConformerChunkEncoder(AbsEncoder): x = x[:,::self.time_reduction_factor,:] olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 - return x, olens + return x, olens, None def simu_chunk_forward( self, diff --git a/funasr/modules/nets_utils.py b/funasr/modules/nets_utils.py index 10df124f8..397a5c428 100644 --- a/funasr/modules/nets_utils.py +++ b/funasr/modules/nets_utils.py @@ -485,14 +485,39 @@ def rename_state_dict( new_k = k.replace(old_prefix, new_prefix) state_dict[new_k] = v - class Swish(torch.nn.Module): - """Construct an Swish object.""" + """Swish activation definition. - def forward(self, x): - """Return Swich activation function.""" - return x * torch.sigmoid(x) + Swish(x) = (beta * x) * sigmoid(x) + where beta = 1 defines standard Swish activation. + References: + https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1. + E-swish variant: https://arxiv.org/abs/1801.07145. + + Args: + beta: Beta parameter for E-Swish. + (beta >= 1. If beta < 1, use standard Swish). + use_builtin: Whether to use PyTorch function if available. + + """ + + def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None: + super().__init__() + + self.beta = beta + + if beta > 1: + self.swish = lambda x: (self.beta * x) * torch.sigmoid(x) + else: + if use_builtin: + self.swish = torch.nn.SiLU() + else: + self.swish = lambda x: x * torch.sigmoid(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward computation.""" + return self.swish(x) def get_activation(act): """Return activation function.""" diff --git a/funasr/modules/repeat.py b/funasr/modules/repeat.py index 2b2dac8f3..ff1e182af 100644 --- a/funasr/modules/repeat.py +++ b/funasr/modules/repeat.py @@ -7,7 +7,7 @@ """Repeat the same layer definition.""" from typing import Dict, List, Optional - +from funasr.modules.layer_norm import LayerNorm import torch @@ -48,7 +48,7 @@ class MultiBlocks(torch.nn.Module): self, block_list: List[torch.nn.Module], output_size: int, - norm_class: torch.nn.Module = torch.nn.LayerNorm, + norm_class: torch.nn.Module = LayerNorm, ) -> None: """Construct a MultiBlocks object.""" super().__init__() diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py index 87db05c67..a64b9e7e4 100644 --- a/funasr/tasks/asr.py +++ b/funasr/tasks/asr.py @@ -1682,7 +1682,7 @@ class ASRTransducerTask(AbsTask): # 7. Build model - if encoder.unified_model_training: + if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training: model = UnifiedTransducerModel( vocab_size=vocab_size, token_list=token_list,