rnnt bug fix

This commit is contained in:
aky15 2023-05-09 11:16:07 +08:00
parent bdb8a99da4
commit 77045e7bb7
7 changed files with 59 additions and 25 deletions

View File

@ -188,18 +188,15 @@ class Speech2Text:
self.frontend = frontend
self.window_size = self.chunk_size + self.right_context
self._ctx = self.asr_model.encoder.get_encoder_input_size(
self.window_size
)
if self.streaming:
self._ctx = self.asr_model.encoder.get_encoder_input_size(
self.window_size
)
#self.last_chunk_length = (
# self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
#) * self.hop_length
self.last_chunk_length = (
self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
)
self.reset_inference_cache()
self.last_chunk_length = (
self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
)
self.reset_inference_cache()
def reset_inference_cache(self) -> None:
"""Reset Speech2Text parameters."""

View File

@ -33,6 +33,7 @@ class RNNTDecoder(torch.nn.Module):
dropout_rate: float = 0.0,
embed_dropout_rate: float = 0.0,
embed_pad: int = 0,
use_embed_mask: bool = False,
) -> None:
"""Construct a RNNDecoder object."""
super().__init__()
@ -66,6 +67,15 @@ class RNNTDecoder(torch.nn.Module):
self.device = next(self.parameters()).device
self.score_cache = {}
self.use_embed_mask = use_embed_mask
if self.use_embed_mask:
self._embed_mask = SpecAug(
time_mask_width_range=3,
num_time_mask=4,
apply_freq_mask=False,
apply_time_warp=False
)
def forward(
self,
@ -88,6 +98,8 @@ class RNNTDecoder(torch.nn.Module):
states = self.init_state(labels.size(0))
dec_embed = self.dropout_embed(self.embed(labels))
if self.use_embed_mask and self.training:
dec_embed = self._embed_mask(dec_embed, label_lens)[0]
dec_out, states = self.rnn_forward(dec_embed, states)
return dec_out

View File

@ -12,7 +12,7 @@ from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.modules.nets_utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
@ -62,7 +62,7 @@ class TransducerModel(AbsESPnetModel):
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
encoder: Encoder,
encoder: AbsEncoder,
decoder: RNNTDecoder,
joint_network: JointNetwork,
att_decoder: Optional[AbsAttDecoder] = None,
@ -286,7 +286,7 @@ class TransducerModel(AbsESPnetModel):
feats, feats_lengths = self.normalize(feats, feats_lengths)
# 4. Forward encoder
encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths)
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
@ -515,7 +515,7 @@ class UnifiedTransducerModel(AbsESPnetModel):
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
encoder: Encoder,
encoder: AbsEncoder,
decoder: RNNTDecoder,
joint_network: JointNetwork,
att_decoder: Optional[AbsAttDecoder] = None,

View File

@ -307,7 +307,7 @@ class ChunkEncoderLayer(torch.nn.Module):
feed_forward: torch.nn.Module,
feed_forward_macaron: torch.nn.Module,
conv_mod: torch.nn.Module,
norm_class: torch.nn.Module = torch.nn.LayerNorm,
norm_class: torch.nn.Module = LayerNorm,
norm_args: Dict = {},
dropout_rate: float = 0.0,
) -> None:
@ -1145,7 +1145,7 @@ class ConformerChunkEncoder(AbsEncoder):
x = x[:,::self.time_reduction_factor,:]
olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
return x, olens
return x, olens, None
def simu_chunk_forward(
self,

View File

@ -485,14 +485,39 @@ def rename_state_dict(
new_k = k.replace(old_prefix, new_prefix)
state_dict[new_k] = v
class Swish(torch.nn.Module):
"""Construct an Swish object."""
"""Swish activation definition.
def forward(self, x):
"""Return Swich activation function."""
return x * torch.sigmoid(x)
Swish(x) = (beta * x) * sigmoid(x)
where beta = 1 defines standard Swish activation.
References:
https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
E-swish variant: https://arxiv.org/abs/1801.07145.
Args:
beta: Beta parameter for E-Swish.
(beta >= 1. If beta < 1, use standard Swish).
use_builtin: Whether to use PyTorch function if available.
"""
def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
super().__init__()
self.beta = beta
if beta > 1:
self.swish = lambda x: (self.beta * x) * torch.sigmoid(x)
else:
if use_builtin:
self.swish = torch.nn.SiLU()
else:
self.swish = lambda x: x * torch.sigmoid(x)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward computation."""
return self.swish(x)
def get_activation(act):
"""Return activation function."""

View File

@ -7,7 +7,7 @@
"""Repeat the same layer definition."""
from typing import Dict, List, Optional
from funasr.modules.layer_norm import LayerNorm
import torch
@ -48,7 +48,7 @@ class MultiBlocks(torch.nn.Module):
self,
block_list: List[torch.nn.Module],
output_size: int,
norm_class: torch.nn.Module = torch.nn.LayerNorm,
norm_class: torch.nn.Module = LayerNorm,
) -> None:
"""Construct a MultiBlocks object."""
super().__init__()

View File

@ -1682,7 +1682,7 @@ class ASRTransducerTask(AbsTask):
# 7. Build model
if encoder.unified_model_training:
if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training:
model = UnifiedTransducerModel(
vocab_size=vocab_size,
token_list=token_list,