Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)
rnnt bug fix
This commit is contained in:
parent bdb8a99da4
commit 77045e7bb7
@@ -188,18 +188,15 @@ class Speech2Text:
         self.frontend = frontend
 
         self.window_size = self.chunk_size + self.right_context
-        self._ctx = self.asr_model.encoder.get_encoder_input_size(
-            self.window_size
-        )
+        if self.streaming:
+            self._ctx = self.asr_model.encoder.get_encoder_input_size(
+                self.window_size
+            )
 
-        #self.last_chunk_length = (
-        #    self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
-        #) * self.hop_length
-
-        self.last_chunk_length = (
-            self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
-        )
-        self.reset_inference_cache()
+            self.last_chunk_length = (
+                self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
+            )
+        self.reset_inference_cache()
 
     def reset_inference_cache(self) -> None:
         """Reset Speech2Text parameters."""
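Note on the hunk above: the chunk-context setup is only meaningful for streaming decoding, so it is now guarded by `self.streaming`; the stale commented-out variant (which measured the last chunk in raw samples by multiplying with `hop_length`) is dropped, and `last_chunk_length` stays in encoder frames. A minimal sketch of the sizing arithmetic, using hypothetical values for `chunk_size`, `right_context`, and `min_frame_length` (none of these defaults come from the commit):

# Hypothetical values; the real ones come from the model/frontend config.
chunk_size = 16        # encoder frames decoded per streaming step
right_context = 2      # lookahead frames appended to each chunk
min_frame_length = 7   # smallest input the subsampling frontend accepts

window_size = chunk_size + right_context                  # 18 frames
last_chunk_length = min_frame_length + right_context + 1  # 10 frames
print(window_size, last_chunk_length)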
@@ -33,6 +33,7 @@ class RNNTDecoder(torch.nn.Module):
         dropout_rate: float = 0.0,
         embed_dropout_rate: float = 0.0,
         embed_pad: int = 0,
+        use_embed_mask: bool = False,
     ) -> None:
         """Construct a RNNDecoder object."""
         super().__init__()
@@ -66,6 +67,15 @@ class RNNTDecoder(torch.nn.Module):
 
         self.device = next(self.parameters()).device
         self.score_cache = {}
 
+        self.use_embed_mask = use_embed_mask
+        if self.use_embed_mask:
+            self._embed_mask = SpecAug(
+                time_mask_width_range=3,
+                num_time_mask=4,
+                apply_freq_mask=False,
+                apply_time_warp=False
+            )
+
     def forward(
         self,
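The new `use_embed_mask` option applies SpecAug-style time masking to the label-embedding sequence, regularizing the transducer's prediction network. A self-contained sketch of that idea (not FunASR's actual SpecAug module); the mask count and width mirror the arguments passed above:

import torch

def time_mask(x: torch.Tensor, num_masks: int = 4, max_width: int = 3) -> torch.Tensor:
    # Zero out num_masks random spans of up to max_width steps along time.
    # x: (batch, time, dim) label-embedding tensor.
    x = x.clone()
    _, time, _ = x.shape
    for b in range(x.size(0)):
        for _ in range(num_masks):
            width = int(torch.randint(1, max_width + 1, (1,)))
            if time <= width:
                continue
            start = int(torch.randint(0, time - width, (1,)))
            x[b, start:start + width, :] = 0.0
    return x

masked = time_mask(torch.randn(2, 10, 8))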
@@ -88,6 +98,8 @@ class RNNTDecoder(torch.nn.Module):
         states = self.init_state(labels.size(0))
 
         dec_embed = self.dropout_embed(self.embed(labels))
+        if self.use_embed_mask and self.training:
+            dec_embed = self._embed_mask(dec_embed, label_lens)[0]
         dec_out, states = self.rnn_forward(dec_embed, states)
 
         return dec_out
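Because the mask is also gated on `self.training`, it behaves like dropout: active under `model.train()`, a no-op under `model.eval()`. A stub illustrating the gating (hypothetical class, not from the commit):

import torch

class MaskedEmbed(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Randomly zero whole time steps while training only.
            keep = (torch.rand(x.size(0), x.size(1), 1) > 0.3).to(x.dtype)
            return x * keep
        return x

m = MaskedEmbed()
m.train()   # masking active
y = m(torch.randn(2, 5, 4))
m.eval()    # masking off; inference sees the embeddings unchanged
y = m(torch.randn(2, 5, 4))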
@@ -12,7 +12,7 @@ from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.models.decoder.rnnt_decoder import RNNTDecoder
 from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
-from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.joint_net.joint_network import JointNetwork
 from funasr.modules.nets_utils import get_transducer_task_io
 from funasr.layers.abs_normalize import AbsNormalize
@@ -62,7 +62,7 @@ class TransducerModel(AbsESPnetModel):
         frontend: Optional[AbsFrontend],
         specaug: Optional[AbsSpecAug],
         normalize: Optional[AbsNormalize],
-        encoder: Encoder,
+        encoder: AbsEncoder,
         decoder: RNNTDecoder,
         joint_network: JointNetwork,
         att_decoder: Optional[AbsAttDecoder] = None,
@@ -286,7 +286,7 @@ class TransducerModel(AbsESPnetModel):
             feats, feats_lengths = self.normalize(feats, feats_lengths)
 
         # 4. Forward encoder
-        encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths)
+        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
 
         assert encoder_out.size(0) == speech.size(0), (
             encoder_out.size(),
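This call-site change matches the encoder's new three-element return (see the ConformerChunkEncoder hunk below, which returns `x, olens, None`); every caller must now unpack three values. A toy encoder with the matching signature, where the third slot is assumed to carry optional auxiliary output:

from typing import Optional, Tuple
import torch

class ToyEncoder(torch.nn.Module):
    def forward(
        self, feats: torch.Tensor, lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        # Third element reserved; None on the offline path.
        return feats, lengths, None

enc = ToyEncoder()
encoder_out, encoder_out_lens, _ = enc(torch.randn(2, 50, 80), torch.tensor([50, 42]))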
@@ -515,7 +515,7 @@ class UnifiedTransducerModel(AbsESPnetModel):
         frontend: Optional[AbsFrontend],
         specaug: Optional[AbsSpecAug],
         normalize: Optional[AbsNormalize],
-        encoder: Encoder,
+        encoder: AbsEncoder,
         decoder: RNNTDecoder,
         joint_network: JointNetwork,
         att_decoder: Optional[AbsAttDecoder] = None,
@@ -307,7 +307,7 @@ class ChunkEncoderLayer(torch.nn.Module):
         feed_forward: torch.nn.Module,
         feed_forward_macaron: torch.nn.Module,
         conv_mod: torch.nn.Module,
-        norm_class: torch.nn.Module = torch.nn.LayerNorm,
+        norm_class: torch.nn.Module = LayerNorm,
         norm_args: Dict = {},
         dropout_rate: float = 0.0,
     ) -> None:
@@ -1145,7 +1145,7 @@ class ConformerChunkEncoder(AbsEncoder):
             x = x[:,::self.time_reduction_factor,:]
             olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
 
-        return x, olens
+        return x, olens, None
 
     def simu_chunk_forward(
         self,
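The stride slice keeps every `time_reduction_factor`-th frame, and the length update `floor((olens - 1) / factor) + 1` is exactly the frame count the slicing produces. A quick check of that arithmetic with hypothetical shapes:

import torch

factor = 2
x = torch.randn(1, 9, 4)             # (batch, frames, dim)
olens = torch.tensor([9])

x = x[:, ::factor, :]                # keeps frames 0, 2, 4, 6, 8
olens = torch.floor_divide(olens - 1, factor) + 1

assert x.size(1) == 5 and olens.item() == 5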
@@ -485,14 +485,39 @@ def rename_state_dict(
         new_k = k.replace(old_prefix, new_prefix)
         state_dict[new_k] = v
 
 
 class Swish(torch.nn.Module):
-    """Construct an Swish object."""
-
-    def forward(self, x):
-        """Return Swich activation function."""
-        return x * torch.sigmoid(x)
+    """Swish activation definition.
+
+    Swish(x) = (beta * x) * sigmoid(x)
+        where beta = 1 defines standard Swish activation.
+
+    References:
+        https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
+        E-swish variant: https://arxiv.org/abs/1801.07145.
+
+    Args:
+        beta: Beta parameter for E-Swish.
+            (beta >= 1. If beta < 1, use standard Swish).
+        use_builtin: Whether to use PyTorch function if available.
+
+    """
+
+    def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
+        super().__init__()
+
+        self.beta = beta
+
+        if beta > 1:
+            self.swish = lambda x: (self.beta * x) * torch.sigmoid(x)
+        else:
+            if use_builtin:
+                self.swish = torch.nn.SiLU()
+            else:
+                self.swish = lambda x: x * torch.sigmoid(x)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward computation."""
+        return self.swish(x)
 
 
 def get_activation(act):
     """Return activation function."""
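With the rewritten class, `beta` selects the E-Swish variant and `use_builtin` can delegate to `torch.nn.SiLU`, which computes the same x * sigmoid(x). A quick numerical check, assuming `Swish` is imported as defined above:

import torch

x = torch.randn(4)
standard = Swish()                  # x * sigmoid(x)
builtin = Swish(use_builtin=True)   # torch.nn.SiLU under the hood
eswish = Swish(beta=2.0)            # (2 * x) * sigmoid(x)

assert torch.allclose(standard(x), builtin(x))
assert torch.allclose(eswish(x), 2.0 * standard(x))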
@@ -7,7 +7,7 @@
 """Repeat the same layer definition."""
 
 from typing import Dict, List, Optional
 
+from funasr.modules.layer_norm import LayerNorm
 import torch
@@ -48,7 +48,7 @@ class MultiBlocks(torch.nn.Module):
         self,
         block_list: List[torch.nn.Module],
         output_size: int,
-        norm_class: torch.nn.Module = torch.nn.LayerNorm,
+        norm_class: torch.nn.Module = LayerNorm,
     ) -> None:
         """Construct a MultiBlocks object."""
         super().__init__()
@@ -1682,7 +1682,7 @@ class ASRTransducerTask(AbsTask):
 
         # 7. Build model
 
-        if encoder.unified_model_training:
+        if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training:
             model = UnifiedTransducerModel(
                 vocab_size=vocab_size,
                 token_list=token_list,
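The `hasattr` guard lets the task fall back to a plain `TransducerModel` for encoders that never define `unified_model_training`, instead of raising `AttributeError`. The pattern in isolation, with stub classes:

class StreamingEncoder:
    unified_model_training = True

class OfflineEncoder:
    pass  # attribute absent entirely

def pick_model(encoder) -> str:
    if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training:
        return 'UnifiedTransducerModel'
    return 'TransducerModel'

assert pick_model(StreamingEncoder()) == 'UnifiedTransducerModel'
assert pick_model(OfflineEncoder()) == 'TransducerModel'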