mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
merge
This commit is contained in:
commit
d79287c37e
@ -14,6 +14,7 @@ from funasr.models.transformer.attention import (
|
|||||||
MultiHeadedAttention, # noqa: H301
|
MultiHeadedAttention, # noqa: H301
|
||||||
RelPositionMultiHeadedAttention, # noqa: H301
|
RelPositionMultiHeadedAttention, # noqa: H301
|
||||||
LegacyRelPositionMultiHeadedAttention, # noqa: H301
|
LegacyRelPositionMultiHeadedAttention, # noqa: H301
|
||||||
|
RelPositionMultiHeadedAttentionChunk,
|
||||||
)
|
)
|
||||||
from funasr.models.transformer.embedding import (
|
from funasr.models.transformer.embedding import (
|
||||||
PositionalEncoding, # noqa: H301
|
PositionalEncoding, # noqa: H301
|
||||||
@ -610,4 +611,669 @@ class ConformerEncoder(nn.Module):
|
|||||||
if len(intermediate_outs) > 0:
|
if len(intermediate_outs) > 0:
|
||||||
return (xs_pad, intermediate_outs), olens, None
|
return (xs_pad, intermediate_outs), olens, None
|
||||||
return xs_pad, olens, None
|
return xs_pad, olens, None
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConvolution(torch.nn.Module):
|
||||||
|
"""ConformerConvolution module definition.
|
||||||
|
Args:
|
||||||
|
channels: The number of channels.
|
||||||
|
kernel_size: Size of the convolving kernel.
|
||||||
|
activation: Type of activation function.
|
||||||
|
norm_args: Normalization module arguments.
|
||||||
|
causal: Whether to use causal convolution (set to True if streaming).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
activation: torch.nn.Module = torch.nn.ReLU(),
|
||||||
|
norm_args: Dict = {},
|
||||||
|
causal: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Construct an ConformerConvolution object."""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert (kernel_size - 1) % 2 == 0
|
||||||
|
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
|
||||||
|
self.pointwise_conv1 = torch.nn.Conv1d(
|
||||||
|
channels,
|
||||||
|
2 * channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if causal:
|
||||||
|
self.lorder = kernel_size - 1
|
||||||
|
padding = 0
|
||||||
|
else:
|
||||||
|
self.lorder = 0
|
||||||
|
padding = (kernel_size - 1) // 2
|
||||||
|
|
||||||
|
self.depthwise_conv = torch.nn.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
padding=padding,
|
||||||
|
groups=channels,
|
||||||
|
)
|
||||||
|
self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
|
||||||
|
self.pointwise_conv2 = torch.nn.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cache: Optional[torch.Tensor] = None,
|
||||||
|
right_context: int = 0,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Compute convolution module.
|
||||||
|
Args:
|
||||||
|
x: ConformerConvolution input sequences. (B, T, D_hidden)
|
||||||
|
cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden)
|
||||||
|
right_context: Number of frames in right context.
|
||||||
|
Returns:
|
||||||
|
x: ConformerConvolution output sequences. (B, T, D_hidden)
|
||||||
|
cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden)
|
||||||
|
"""
|
||||||
|
x = self.pointwise_conv1(x.transpose(1, 2))
|
||||||
|
x = torch.nn.functional.glu(x, dim=1)
|
||||||
|
|
||||||
|
if self.lorder > 0:
|
||||||
|
if cache is None:
|
||||||
|
x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
|
||||||
|
else:
|
||||||
|
x = torch.cat([cache, x], dim=2)
|
||||||
|
|
||||||
|
if right_context > 0:
|
||||||
|
cache = x[:, :, -(self.lorder + right_context) : -right_context]
|
||||||
|
else:
|
||||||
|
cache = x[:, :, -self.lorder :]
|
||||||
|
|
||||||
|
x = self.depthwise_conv(x)
|
||||||
|
x = self.activation(self.norm(x))
|
||||||
|
|
||||||
|
x = self.pointwise_conv2(x).transpose(1, 2)
|
||||||
|
|
||||||
|
return x, cache
|
||||||
|
|
||||||
|
class ChunkEncoderLayer(torch.nn.Module):
|
||||||
|
"""Chunk Conformer module definition.
|
||||||
|
Args:
|
||||||
|
block_size: Input/output size.
|
||||||
|
self_att: Self-attention module instance.
|
||||||
|
feed_forward: Feed-forward module instance.
|
||||||
|
feed_forward_macaron: Feed-forward module instance for macaron network.
|
||||||
|
conv_mod: Convolution module instance.
|
||||||
|
norm_class: Normalization module class.
|
||||||
|
norm_args: Normalization module arguments.
|
||||||
|
dropout_rate: Dropout rate.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
block_size: int,
|
||||||
|
self_att: torch.nn.Module,
|
||||||
|
feed_forward: torch.nn.Module,
|
||||||
|
feed_forward_macaron: torch.nn.Module,
|
||||||
|
conv_mod: torch.nn.Module,
|
||||||
|
norm_class: torch.nn.Module = LayerNorm,
|
||||||
|
norm_args: Dict = {},
|
||||||
|
dropout_rate: float = 0.0,
|
||||||
|
) -> None:
|
||||||
|
"""Construct a Conformer object."""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.self_att = self_att
|
||||||
|
|
||||||
|
self.feed_forward = feed_forward
|
||||||
|
self.feed_forward_macaron = feed_forward_macaron
|
||||||
|
self.feed_forward_scale = 0.5
|
||||||
|
|
||||||
|
self.conv_mod = conv_mod
|
||||||
|
|
||||||
|
self.norm_feed_forward = norm_class(block_size, **norm_args)
|
||||||
|
self.norm_self_att = norm_class(block_size, **norm_args)
|
||||||
|
|
||||||
|
self.norm_macaron = norm_class(block_size, **norm_args)
|
||||||
|
self.norm_conv = norm_class(block_size, **norm_args)
|
||||||
|
self.norm_final = norm_class(block_size, **norm_args)
|
||||||
|
|
||||||
|
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||||
|
|
||||||
|
self.block_size = block_size
|
||||||
|
self.cache = None
|
||||||
|
|
||||||
|
def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
|
||||||
|
"""Initialize/Reset self-attention and convolution modules cache for streaming.
|
||||||
|
Args:
|
||||||
|
left_context: Number of left frames during chunk-by-chunk inference.
|
||||||
|
device: Device to use for cache tensor.
|
||||||
|
"""
|
||||||
|
self.cache = [
|
||||||
|
torch.zeros(
|
||||||
|
(1, left_context, self.block_size),
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
torch.zeros(
|
||||||
|
(
|
||||||
|
1,
|
||||||
|
self.block_size,
|
||||||
|
self.conv_mod.kernel_size - 1,
|
||||||
|
),
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
pos_enc: torch.Tensor,
|
||||||
|
mask: torch.Tensor,
|
||||||
|
chunk_mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""Encode input sequences.
|
||||||
|
Args:
|
||||||
|
x: Conformer input sequences. (B, T, D_block)
|
||||||
|
pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
|
||||||
|
mask: Source mask. (B, T)
|
||||||
|
chunk_mask: Chunk mask. (T_2, T_2)
|
||||||
|
Returns:
|
||||||
|
x: Conformer output sequences. (B, T, D_block)
|
||||||
|
mask: Source mask. (B, T)
|
||||||
|
pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
|
||||||
|
"""
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
x = self.norm_macaron(x)
|
||||||
|
x = residual + self.feed_forward_scale * self.dropout(
|
||||||
|
self.feed_forward_macaron(x)
|
||||||
|
)
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
x = self.norm_self_att(x)
|
||||||
|
x_q = x
|
||||||
|
x = residual + self.dropout(
|
||||||
|
self.self_att(
|
||||||
|
x_q,
|
||||||
|
x,
|
||||||
|
x,
|
||||||
|
pos_enc,
|
||||||
|
mask,
|
||||||
|
chunk_mask=chunk_mask,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
x = self.norm_conv(x)
|
||||||
|
x, _ = self.conv_mod(x)
|
||||||
|
x = residual + self.dropout(x)
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
x = self.norm_feed_forward(x)
|
||||||
|
x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
|
||||||
|
|
||||||
|
x = self.norm_final(x)
|
||||||
|
return x, mask, pos_enc
|
||||||
|
|
||||||
|
def chunk_forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
pos_enc: torch.Tensor,
|
||||||
|
mask: torch.Tensor,
|
||||||
|
chunk_size: int = 16,
|
||||||
|
left_context: int = 0,
|
||||||
|
right_context: int = 0,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Encode chunk of input sequence.
|
||||||
|
Args:
|
||||||
|
x: Conformer input sequences. (B, T, D_block)
|
||||||
|
pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
|
||||||
|
mask: Source mask. (B, T_2)
|
||||||
|
left_context: Number of frames in left context.
|
||||||
|
right_context: Number of frames in right context.
|
||||||
|
Returns:
|
||||||
|
x: Conformer output sequences. (B, T, D_block)
|
||||||
|
pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
|
||||||
|
"""
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
x = self.norm_macaron(x)
|
||||||
|
x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
x = self.norm_self_att(x)
|
||||||
|
if left_context > 0:
|
||||||
|
key = torch.cat([self.cache[0], x], dim=1)
|
||||||
|
else:
|
||||||
|
key = x
|
||||||
|
val = key
|
||||||
|
|
||||||
|
if right_context > 0:
|
||||||
|
att_cache = key[:, -(left_context + right_context) : -right_context, :]
|
||||||
|
else:
|
||||||
|
att_cache = key[:, -left_context:, :]
|
||||||
|
x = residual + self.self_att(
|
||||||
|
x,
|
||||||
|
key,
|
||||||
|
val,
|
||||||
|
pos_enc,
|
||||||
|
mask,
|
||||||
|
left_context=left_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
x = self.norm_conv(x)
|
||||||
|
x, conv_cache = self.conv_mod(
|
||||||
|
x, cache=self.cache[1], right_context=right_context
|
||||||
|
)
|
||||||
|
x = residual + x
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
x = self.norm_feed_forward(x)
|
||||||
|
x = residual + self.feed_forward_scale * self.feed_forward(x)
|
||||||
|
|
||||||
|
x = self.norm_final(x)
|
||||||
|
self.cache = [att_cache, conv_cache]
|
||||||
|
|
||||||
|
return x, pos_enc
|
||||||
|
|
||||||
|
@tables.register("encoder_classes", "ChunkConformerEncoder")
|
||||||
|
class ConformerChunkEncoder(torch.nn.Module):
|
||||||
|
"""Encoder module definition.
|
||||||
|
Args:
|
||||||
|
input_size: Input size.
|
||||||
|
body_conf: Encoder body configuration.
|
||||||
|
input_conf: Encoder input configuration.
|
||||||
|
main_conf: Encoder main configuration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_size: int,
|
||||||
|
output_size: int = 256,
|
||||||
|
attention_heads: int = 4,
|
||||||
|
linear_units: int = 2048,
|
||||||
|
num_blocks: int = 6,
|
||||||
|
dropout_rate: float = 0.1,
|
||||||
|
positional_dropout_rate: float = 0.1,
|
||||||
|
attention_dropout_rate: float = 0.0,
|
||||||
|
embed_vgg_like: bool = False,
|
||||||
|
normalize_before: bool = True,
|
||||||
|
concat_after: bool = False,
|
||||||
|
positionwise_layer_type: str = "linear",
|
||||||
|
positionwise_conv_kernel_size: int = 3,
|
||||||
|
macaron_style: bool = False,
|
||||||
|
rel_pos_type: str = "legacy",
|
||||||
|
pos_enc_layer_type: str = "rel_pos",
|
||||||
|
selfattention_layer_type: str = "rel_selfattn",
|
||||||
|
activation_type: str = "swish",
|
||||||
|
use_cnn_module: bool = True,
|
||||||
|
zero_triu: bool = False,
|
||||||
|
norm_type: str = "layer_norm",
|
||||||
|
cnn_module_kernel: int = 31,
|
||||||
|
conv_mod_norm_eps: float = 0.00001,
|
||||||
|
conv_mod_norm_momentum: float = 0.1,
|
||||||
|
simplified_att_score: bool = False,
|
||||||
|
dynamic_chunk_training: bool = False,
|
||||||
|
short_chunk_threshold: float = 0.75,
|
||||||
|
short_chunk_size: int = 25,
|
||||||
|
left_chunk_size: int = 0,
|
||||||
|
time_reduction_factor: int = 1,
|
||||||
|
unified_model_training: bool = False,
|
||||||
|
default_chunk_size: int = 16,
|
||||||
|
jitter_range: int = 4,
|
||||||
|
subsampling_factor: int = 1,
|
||||||
|
) -> None:
|
||||||
|
"""Construct an Encoder object."""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
|
self.embed = StreamingConvInput(
|
||||||
|
input_size=input_size,
|
||||||
|
conv_size=output_size,
|
||||||
|
subsampling_factor=subsampling_factor,
|
||||||
|
vgg_like=embed_vgg_like,
|
||||||
|
output_size=output_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pos_enc = StreamingRelPositionalEncoding(
|
||||||
|
output_size,
|
||||||
|
positional_dropout_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
activation = get_activation(
|
||||||
|
activation_type
|
||||||
|
)
|
||||||
|
|
||||||
|
pos_wise_args = (
|
||||||
|
output_size,
|
||||||
|
linear_units,
|
||||||
|
positional_dropout_rate,
|
||||||
|
activation,
|
||||||
|
)
|
||||||
|
|
||||||
|
conv_mod_norm_args = {
|
||||||
|
"eps": conv_mod_norm_eps,
|
||||||
|
"momentum": conv_mod_norm_momentum,
|
||||||
|
}
|
||||||
|
|
||||||
|
conv_mod_args = (
|
||||||
|
output_size,
|
||||||
|
cnn_module_kernel,
|
||||||
|
activation,
|
||||||
|
conv_mod_norm_args,
|
||||||
|
dynamic_chunk_training or unified_model_training,
|
||||||
|
)
|
||||||
|
|
||||||
|
mult_att_args = (
|
||||||
|
attention_heads,
|
||||||
|
output_size,
|
||||||
|
attention_dropout_rate,
|
||||||
|
simplified_att_score,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
fn_modules = []
|
||||||
|
for _ in range(num_blocks):
|
||||||
|
module = lambda: ChunkEncoderLayer(
|
||||||
|
output_size,
|
||||||
|
RelPositionMultiHeadedAttentionChunk(*mult_att_args),
|
||||||
|
PositionwiseFeedForward(*pos_wise_args),
|
||||||
|
PositionwiseFeedForward(*pos_wise_args),
|
||||||
|
CausalConvolution(*conv_mod_args),
|
||||||
|
dropout_rate=dropout_rate,
|
||||||
|
)
|
||||||
|
fn_modules.append(module)
|
||||||
|
|
||||||
|
self.encoders = MultiBlocks(
|
||||||
|
[fn() for fn in fn_modules],
|
||||||
|
output_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._output_size = output_size
|
||||||
|
|
||||||
|
self.dynamic_chunk_training = dynamic_chunk_training
|
||||||
|
self.short_chunk_threshold = short_chunk_threshold
|
||||||
|
self.short_chunk_size = short_chunk_size
|
||||||
|
self.left_chunk_size = left_chunk_size
|
||||||
|
|
||||||
|
self.unified_model_training = unified_model_training
|
||||||
|
self.default_chunk_size = default_chunk_size
|
||||||
|
self.jitter_range = jitter_range
|
||||||
|
|
||||||
|
self.time_reduction_factor = time_reduction_factor
|
||||||
|
|
||||||
|
def output_size(self) -> int:
|
||||||
|
return self._output_size
|
||||||
|
|
||||||
|
def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
|
||||||
|
"""Return the corresponding number of sample for a given chunk size, in frames.
|
||||||
|
Where size is the number of features frames after applying subsampling.
|
||||||
|
Args:
|
||||||
|
size: Number of frames after subsampling.
|
||||||
|
hop_length: Frontend's hop length
|
||||||
|
Returns:
|
||||||
|
: Number of raw samples
|
||||||
|
"""
|
||||||
|
return self.embed.get_size_before_subsampling(size) * hop_length
|
||||||
|
|
||||||
|
def get_encoder_input_size(self, size: int) -> int:
|
||||||
|
"""Return the corresponding number of sample for a given chunk size, in frames.
|
||||||
|
Where size is the number of features frames after applying subsampling.
|
||||||
|
Args:
|
||||||
|
size: Number of frames after subsampling.
|
||||||
|
Returns:
|
||||||
|
: Number of raw samples
|
||||||
|
"""
|
||||||
|
return self.embed.get_size_before_subsampling(size)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
|
||||||
|
"""Initialize/Reset encoder streaming cache.
|
||||||
|
Args:
|
||||||
|
left_context: Number of frames in left context.
|
||||||
|
device: Device ID.
|
||||||
|
"""
|
||||||
|
return self.encoders.reset_streaming_cache(left_context, device)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_len: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Encode input sequences.
|
||||||
|
Args:
|
||||||
|
x: Encoder input features. (B, T_in, F)
|
||||||
|
x_len: Encoder input features lengths. (B,)
|
||||||
|
Returns:
|
||||||
|
x: Encoder outputs. (B, T_out, D_enc)
|
||||||
|
x_len: Encoder outputs lenghts. (B,)
|
||||||
|
"""
|
||||||
|
short_status, limit_size = check_short_utt(
|
||||||
|
self.embed.subsampling_factor, x.size(1)
|
||||||
|
)
|
||||||
|
|
||||||
|
if short_status:
|
||||||
|
raise TooShortUttError(
|
||||||
|
f"has {x.size(1)} frames and is too short for subsampling "
|
||||||
|
+ f"(it needs more than {limit_size} frames), return empty results",
|
||||||
|
x.size(1),
|
||||||
|
limit_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
mask = make_source_mask(x_len).to(x.device)
|
||||||
|
|
||||||
|
if self.unified_model_training:
|
||||||
|
if self.training:
|
||||||
|
chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
|
||||||
|
else:
|
||||||
|
chunk_size = self.default_chunk_size
|
||||||
|
x, mask = self.embed(x, mask, chunk_size)
|
||||||
|
pos_enc = self.pos_enc(x)
|
||||||
|
chunk_mask = make_chunk_mask(
|
||||||
|
x.size(1),
|
||||||
|
chunk_size,
|
||||||
|
left_chunk_size=self.left_chunk_size,
|
||||||
|
device=x.device,
|
||||||
|
)
|
||||||
|
x_utt = self.encoders(
|
||||||
|
x,
|
||||||
|
pos_enc,
|
||||||
|
mask,
|
||||||
|
chunk_mask=None,
|
||||||
|
)
|
||||||
|
x_chunk = self.encoders(
|
||||||
|
x,
|
||||||
|
pos_enc,
|
||||||
|
mask,
|
||||||
|
chunk_mask=chunk_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
olens = mask.eq(0).sum(1)
|
||||||
|
if self.time_reduction_factor > 1:
|
||||||
|
x_utt = x_utt[:,::self.time_reduction_factor,:]
|
||||||
|
x_chunk = x_chunk[:,::self.time_reduction_factor,:]
|
||||||
|
olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
|
||||||
|
|
||||||
|
return x_utt, x_chunk, olens
|
||||||
|
|
||||||
|
elif self.dynamic_chunk_training:
|
||||||
|
max_len = x.size(1)
|
||||||
|
if self.training:
|
||||||
|
chunk_size = torch.randint(1, max_len, (1,)).item()
|
||||||
|
|
||||||
|
if chunk_size > (max_len * self.short_chunk_threshold):
|
||||||
|
chunk_size = max_len
|
||||||
|
else:
|
||||||
|
chunk_size = (chunk_size % self.short_chunk_size) + 1
|
||||||
|
else:
|
||||||
|
chunk_size = self.default_chunk_size
|
||||||
|
|
||||||
|
x, mask = self.embed(x, mask, chunk_size)
|
||||||
|
pos_enc = self.pos_enc(x)
|
||||||
|
|
||||||
|
chunk_mask = make_chunk_mask(
|
||||||
|
x.size(1),
|
||||||
|
chunk_size,
|
||||||
|
left_chunk_size=self.left_chunk_size,
|
||||||
|
device=x.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x, mask = self.embed(x, mask, None)
|
||||||
|
pos_enc = self.pos_enc(x)
|
||||||
|
chunk_mask = None
|
||||||
|
x = self.encoders(
|
||||||
|
x,
|
||||||
|
pos_enc,
|
||||||
|
mask,
|
||||||
|
chunk_mask=chunk_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
olens = mask.eq(0).sum(1)
|
||||||
|
if self.time_reduction_factor > 1:
|
||||||
|
x = x[:,::self.time_reduction_factor,:]
|
||||||
|
olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
|
||||||
|
|
||||||
|
return x, olens, None
|
||||||
|
|
||||||
|
def full_utt_forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_len: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Encode input sequences.
|
||||||
|
Args:
|
||||||
|
x: Encoder input features. (B, T_in, F)
|
||||||
|
x_len: Encoder input features lengths. (B,)
|
||||||
|
Returns:
|
||||||
|
x: Encoder outputs. (B, T_out, D_enc)
|
||||||
|
x_len: Encoder outputs lenghts. (B,)
|
||||||
|
"""
|
||||||
|
short_status, limit_size = check_short_utt(
|
||||||
|
self.embed.subsampling_factor, x.size(1)
|
||||||
|
)
|
||||||
|
|
||||||
|
if short_status:
|
||||||
|
raise TooShortUttError(
|
||||||
|
f"has {x.size(1)} frames and is too short for subsampling "
|
||||||
|
+ f"(it needs more than {limit_size} frames), return empty results",
|
||||||
|
x.size(1),
|
||||||
|
limit_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
mask = make_source_mask(x_len).to(x.device)
|
||||||
|
x, mask = self.embed(x, mask, None)
|
||||||
|
pos_enc = self.pos_enc(x)
|
||||||
|
x_utt = self.encoders(
|
||||||
|
x,
|
||||||
|
pos_enc,
|
||||||
|
mask,
|
||||||
|
chunk_mask=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.time_reduction_factor > 1:
|
||||||
|
x_utt = x_utt[:,::self.time_reduction_factor,:]
|
||||||
|
return x_utt
|
||||||
|
|
||||||
|
def simu_chunk_forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_len: torch.Tensor,
|
||||||
|
chunk_size: int = 16,
|
||||||
|
left_context: int = 32,
|
||||||
|
right_context: int = 0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
short_status, limit_size = check_short_utt(
|
||||||
|
self.embed.subsampling_factor, x.size(1)
|
||||||
|
)
|
||||||
|
|
||||||
|
if short_status:
|
||||||
|
raise TooShortUttError(
|
||||||
|
f"has {x.size(1)} frames and is too short for subsampling "
|
||||||
|
+ f"(it needs more than {limit_size} frames), return empty results",
|
||||||
|
x.size(1),
|
||||||
|
limit_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
mask = make_source_mask(x_len)
|
||||||
|
|
||||||
|
x, mask = self.embed(x, mask, chunk_size)
|
||||||
|
pos_enc = self.pos_enc(x)
|
||||||
|
chunk_mask = make_chunk_mask(
|
||||||
|
x.size(1),
|
||||||
|
chunk_size,
|
||||||
|
left_chunk_size=self.left_chunk_size,
|
||||||
|
device=x.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = self.encoders(
|
||||||
|
x,
|
||||||
|
pos_enc,
|
||||||
|
mask,
|
||||||
|
chunk_mask=chunk_mask,
|
||||||
|
)
|
||||||
|
olens = mask.eq(0).sum(1)
|
||||||
|
if self.time_reduction_factor > 1:
|
||||||
|
x = x[:,::self.time_reduction_factor,:]
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def chunk_forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_len: torch.Tensor,
|
||||||
|
processed_frames: torch.tensor,
|
||||||
|
chunk_size: int = 16,
|
||||||
|
left_context: int = 32,
|
||||||
|
right_context: int = 0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Encode input sequences as chunks.
|
||||||
|
Args:
|
||||||
|
x: Encoder input features. (1, T_in, F)
|
||||||
|
x_len: Encoder input features lengths. (1,)
|
||||||
|
processed_frames: Number of frames already seen.
|
||||||
|
left_context: Number of frames in left context.
|
||||||
|
right_context: Number of frames in right context.
|
||||||
|
Returns:
|
||||||
|
x: Encoder outputs. (B, T_out, D_enc)
|
||||||
|
"""
|
||||||
|
mask = make_source_mask(x_len)
|
||||||
|
x, mask = self.embed(x, mask, None)
|
||||||
|
|
||||||
|
if left_context > 0:
|
||||||
|
processed_mask = (
|
||||||
|
torch.arange(left_context, device=x.device)
|
||||||
|
.view(1, left_context)
|
||||||
|
.flip(1)
|
||||||
|
)
|
||||||
|
processed_mask = processed_mask >= processed_frames
|
||||||
|
mask = torch.cat([processed_mask, mask], dim=1)
|
||||||
|
pos_enc = self.pos_enc(x, left_context=left_context)
|
||||||
|
x = self.encoders.chunk_forward(
|
||||||
|
x,
|
||||||
|
pos_enc,
|
||||||
|
mask,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
left_context=left_context,
|
||||||
|
right_context=right_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
if right_context > 0:
|
||||||
|
x = x[:, 0:-right_context, :]
|
||||||
|
|
||||||
|
if self.time_reduction_factor > 1:
|
||||||
|
x = x[:,::self.time_reduction_factor,:]
|
||||||
|
return x
|
||||||
|
|||||||
@ -335,7 +335,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
|
|||||||
|
|
||||||
speech = speech.to(device=kwargs["device"])
|
speech = speech.to(device=kwargs["device"])
|
||||||
speech_lengths = speech_lengths.to(device=kwargs["device"])
|
speech_lengths = speech_lengths.to(device=kwargs["device"])
|
||||||
|
|
||||||
# hotword
|
# hotword
|
||||||
self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)
|
self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)
|
||||||
|
|
||||||
|
|||||||
@ -1,11 +1,13 @@
|
|||||||
"""Search algorithms for Transducer models."""
|
#!/usr/bin/env python3
|
||||||
|
# -*- encoding: utf-8 -*-
|
||||||
|
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||||
|
# MIT License (https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from funasr.models.transducer.joint_network import JointNetwork
|
from funasr.models.transducer.joint_network import JointNetwork
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,15 @@
|
|||||||
"""Transducer joint network implementation."""
|
#!/usr/bin/env python3
|
||||||
|
# -*- encoding: utf-8 -*-
|
||||||
|
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||||
|
# MIT License (https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from funasr.register import tables
|
||||||
from funasr.models.transformer.utils.nets_utils import get_activation
|
from funasr.models.transformer.utils.nets_utils import get_activation
|
||||||
|
|
||||||
|
|
||||||
|
@tables.register("joint_network_classes", "joint_network")
|
||||||
class JointNetwork(torch.nn.Module):
|
class JointNetwork(torch.nn.Module):
|
||||||
"""Transducer joint network module.
|
"""Transducer joint network module.
|
||||||
|
|
||||||
|
|||||||
@ -1,42 +1,26 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- encoding: utf-8 -*-
|
||||||
|
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||||
|
# MIT License (https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
import logging
|
import logging
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
from distutils.version import LooseVersion
|
from distutils.version import LooseVersion
|
||||||
from typing import Dict
|
|
||||||
from typing import List
|
|
||||||
from typing import Optional
|
|
||||||
from typing import Tuple
|
|
||||||
from typing import Union
|
|
||||||
import tempfile
|
|
||||||
import codecs
|
|
||||||
import requests
|
|
||||||
import re
|
|
||||||
import copy
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import random
|
|
||||||
import numpy as np
|
|
||||||
import time
|
|
||||||
from funasr.losses.label_smoothing_loss import (
|
|
||||||
LabelSmoothingLoss, # noqa: H301
|
|
||||||
)
|
|
||||||
# from funasr.models.ctc import CTC
|
|
||||||
# from funasr.models.decoder.abs_decoder import AbsDecoder
|
|
||||||
# from funasr.models.e2e_asr_common import ErrorCalculator
|
|
||||||
# from funasr.models.encoder.abs_encoder import AbsEncoder
|
|
||||||
# from funasr.frontends.abs_frontend import AbsFrontend
|
|
||||||
# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
|
|
||||||
from funasr.models.paraformer.cif_predictor import mae_loss
|
|
||||||
# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
|
||||||
# from funasr.models.specaug.abs_specaug import AbsSpecAug
|
|
||||||
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
|
|
||||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
|
|
||||||
from funasr.metrics.compute_acc import th_accuracy
|
|
||||||
from funasr.train_utils.device_funcs import force_gatherable
|
|
||||||
# from funasr.models.base_model import FunASRModel
|
|
||||||
# from funasr.models.paraformer.cif_predictor import CifPredictorV3
|
|
||||||
from funasr.models.paraformer.search import Hypothesis
|
|
||||||
|
|
||||||
from funasr.models.model_class_factory import *
|
from funasr.register import tables
|
||||||
|
from funasr.utils import postprocess_utils
|
||||||
|
from funasr.utils.datadir_writer import DatadirWriter
|
||||||
|
from funasr.train_utils.device_funcs import force_gatherable
|
||||||
|
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
|
||||||
|
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
|
||||||
|
from funasr.models.transformer.scorers.length_bonus import LengthBonus
|
||||||
|
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
|
||||||
|
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
||||||
|
from funasr.models.transducer.beam_search_transducer import BeamSearchTransducer
|
||||||
|
|
||||||
|
|
||||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
@ -45,16 +29,10 @@ else:
|
|||||||
@contextmanager
|
@contextmanager
|
||||||
def autocast(enabled=True):
|
def autocast(enabled=True):
|
||||||
yield
|
yield
|
||||||
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
|
||||||
from funasr.utils import postprocess_utils
|
|
||||||
from funasr.utils.datadir_writer import DatadirWriter
|
|
||||||
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
|
|
||||||
|
|
||||||
|
|
||||||
class Transducer(nn.Module):
|
@tables.register("model_classes", "Transducer")
|
||||||
"""ESPnet2ASRTransducerModel module definition."""
|
class Transducer(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
frontend: Optional[str] = None,
|
frontend: Optional[str] = None,
|
||||||
@ -96,28 +74,24 @@ class Transducer(nn.Module):
|
|||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if frontend is not None:
|
|
||||||
frontend_class = frontend_classes.get_class(frontend)
|
|
||||||
frontend = frontend_class(**frontend_conf)
|
|
||||||
if specaug is not None:
|
if specaug is not None:
|
||||||
specaug_class = specaug_classes.get_class(specaug)
|
specaug_class = tables.specaug_classes.get(specaug)
|
||||||
specaug = specaug_class(**specaug_conf)
|
specaug = specaug_class(**specaug_conf)
|
||||||
if normalize is not None:
|
if normalize is not None:
|
||||||
normalize_class = normalize_classes.get_class(normalize)
|
normalize_class = tables.normalize_classes.get(normalize)
|
||||||
normalize = normalize_class(**normalize_conf)
|
normalize = normalize_class(**normalize_conf)
|
||||||
encoder_class = encoder_classes.get_class(encoder)
|
encoder_class = tables.encoder_classes.get(encoder)
|
||||||
encoder = encoder_class(input_size=input_size, **encoder_conf)
|
encoder = encoder_class(input_size=input_size, **encoder_conf)
|
||||||
encoder_output_size = encoder.output_size()
|
encoder_output_size = encoder.output_size()
|
||||||
|
|
||||||
decoder_class = decoder_classes.get_class(decoder)
|
decoder_class = tables.decoder_classes.get(decoder)
|
||||||
decoder = decoder_class(
|
decoder = decoder_class(
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
encoder_output_size=encoder_output_size,
|
|
||||||
**decoder_conf,
|
**decoder_conf,
|
||||||
)
|
)
|
||||||
decoder_output_size = decoder.output_size
|
decoder_output_size = decoder.output_size
|
||||||
|
|
||||||
joint_network_class = joint_network_classes.get_class(decoder)
|
joint_network_class = tables.joint_network_classes.get(joint_network)
|
||||||
joint_network = joint_network_class(
|
joint_network = joint_network_class(
|
||||||
vocab_size,
|
vocab_size,
|
||||||
encoder_output_size,
|
encoder_output_size,
|
||||||
@ -125,7 +99,6 @@ class Transducer(nn.Module):
|
|||||||
**joint_network_conf,
|
**joint_network_conf,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
self.criterion_transducer = None
|
self.criterion_transducer = None
|
||||||
self.error_calculator = None
|
self.error_calculator = None
|
||||||
|
|
||||||
@ -157,23 +130,17 @@ class Transducer(nn.Module):
|
|||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
self.joint_network = joint_network
|
self.joint_network = joint_network
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
self.criterion_att = LabelSmoothingLoss(
|
self.criterion_att = LabelSmoothingLoss(
|
||||||
size=vocab_size,
|
size=vocab_size,
|
||||||
padding_idx=ignore_id,
|
padding_idx=ignore_id,
|
||||||
smoothing=lsm_weight,
|
smoothing=lsm_weight,
|
||||||
normalize_length=length_normalized_loss,
|
normalize_length=length_normalized_loss,
|
||||||
)
|
)
|
||||||
#
|
|
||||||
# if report_cer or report_wer:
|
|
||||||
# self.error_calculator = ErrorCalculator(
|
|
||||||
# token_list, sym_space, sym_blank, report_cer, report_wer
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
|
|
||||||
self.length_normalized_loss = length_normalized_loss
|
self.length_normalized_loss = length_normalized_loss
|
||||||
self.beam_search = None
|
self.beam_search = None
|
||||||
|
self.ctc = None
|
||||||
|
self.ctc_weight = 0.0
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -190,8 +157,6 @@ class Transducer(nn.Module):
|
|||||||
text: (Batch, Length)
|
text: (Batch, Length)
|
||||||
text_lengths: (Batch,)
|
text_lengths: (Batch,)
|
||||||
"""
|
"""
|
||||||
# import pdb;
|
|
||||||
# pdb.set_trace()
|
|
||||||
if len(text_lengths.size()) > 1:
|
if len(text_lengths.size()) > 1:
|
||||||
text_lengths = text_lengths[:, 0]
|
text_lengths = text_lengths[:, 0]
|
||||||
if len(speech_lengths.size()) > 1:
|
if len(speech_lengths.size()) > 1:
|
||||||
@ -283,12 +248,7 @@ class Transducer(nn.Module):
|
|||||||
# Forward encoder
|
# Forward encoder
|
||||||
# feats: (Batch, Length, Dim)
|
# feats: (Batch, Length, Dim)
|
||||||
# -> encoder_out: (Batch, Length2, Dim2)
|
# -> encoder_out: (Batch, Length2, Dim2)
|
||||||
if self.encoder.interctc_use_conditioning:
|
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
|
||||||
encoder_out, encoder_out_lens, _ = self.encoder(
|
|
||||||
speech, speech_lengths, ctc=self.ctc
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
|
|
||||||
intermediate_outs = None
|
intermediate_outs = None
|
||||||
if isinstance(encoder_out, tuple):
|
if isinstance(encoder_out, tuple):
|
||||||
intermediate_outs = encoder_out[1]
|
intermediate_outs = encoder_out[1]
|
||||||
@ -449,9 +409,6 @@ class Transducer(nn.Module):
|
|||||||
def init_beam_search(self,
|
def init_beam_search(self,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
from funasr.models.transformer.search import BeamSearch
|
|
||||||
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
|
|
||||||
from funasr.models.transformer.scorers.length_bonus import LengthBonus
|
|
||||||
|
|
||||||
# 1. Build ASR model
|
# 1. Build ASR model
|
||||||
scorers = {}
|
scorers = {}
|
||||||
@ -466,28 +423,16 @@ class Transducer(nn.Module):
|
|||||||
length_bonus=LengthBonus(len(token_list)),
|
length_bonus=LengthBonus(len(token_list)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 3. Build ngram model
|
# 3. Build ngram model
|
||||||
# ngram is not supported now
|
# ngram is not supported now
|
||||||
ngram = None
|
ngram = None
|
||||||
scorers["ngram"] = ngram
|
scorers["ngram"] = ngram
|
||||||
|
|
||||||
weights = dict(
|
beam_search = BeamSearchTransducer(
|
||||||
decoder=1.0 - kwargs.get("decoding_ctc_weight"),
|
self.decoder,
|
||||||
ctc=kwargs.get("decoding_ctc_weight", 0.0),
|
self.joint_network,
|
||||||
lm=kwargs.get("lm_weight", 0.0),
|
kwargs.get("beam_size", 2),
|
||||||
ngram=kwargs.get("ngram_weight", 0.0),
|
nbest=1,
|
||||||
length_bonus=kwargs.get("penalty", 0.0),
|
|
||||||
)
|
|
||||||
beam_search = BeamSearch(
|
|
||||||
beam_size=kwargs.get("beam_size", 2),
|
|
||||||
weights=weights,
|
|
||||||
scorers=scorers,
|
|
||||||
sos=self.sos,
|
|
||||||
eos=self.eos,
|
|
||||||
vocab_size=len(token_list),
|
|
||||||
token_list=token_list,
|
|
||||||
pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
|
|
||||||
)
|
)
|
||||||
# beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
|
# beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
|
||||||
# for scorer in scorers.values():
|
# for scorer in scorers.values():
|
||||||
@ -495,13 +440,13 @@ class Transducer(nn.Module):
|
|||||||
# scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
|
# scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
|
||||||
self.beam_search = beam_search
|
self.beam_search = beam_search
|
||||||
|
|
||||||
def generate(self,
|
def inference(self,
|
||||||
data_in: list,
|
data_in: list,
|
||||||
data_lengths: list=None,
|
data_lengths: list=None,
|
||||||
key: list=None,
|
key: list=None,
|
||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|
||||||
if kwargs.get("batch_size", 1) > 1:
|
if kwargs.get("batch_size", 1) > 1:
|
||||||
raise NotImplementedError("batch decoding is not implemented")
|
raise NotImplementedError("batch decoding is not implemented")
|
||||||
@ -509,10 +454,10 @@ class Transducer(nn.Module):
|
|||||||
# init beamsearch
|
# init beamsearch
|
||||||
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
|
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
|
||||||
is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
|
is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
|
||||||
if self.beam_search is None and (is_use_lm or is_use_ctc):
|
# if self.beam_search is None and (is_use_lm or is_use_ctc):
|
||||||
logging.info("enable beam_search")
|
logging.info("enable beam_search")
|
||||||
self.init_beam_search(**kwargs)
|
self.init_beam_search(**kwargs)
|
||||||
self.nbest = kwargs.get("nbest", 1)
|
self.nbest = kwargs.get("nbest", 1)
|
||||||
|
|
||||||
meta_data = {}
|
meta_data = {}
|
||||||
# extract fbank feats
|
# extract fbank feats
|
||||||
@ -534,13 +479,9 @@ class Transducer(nn.Module):
|
|||||||
encoder_out = encoder_out[0]
|
encoder_out = encoder_out[0]
|
||||||
|
|
||||||
# c. Passed the encoder result and the beam search
|
# c. Passed the encoder result and the beam search
|
||||||
nbest_hyps = self.beam_search(
|
nbest_hyps = self.beam_search(encoder_out[0], is_final=True)
|
||||||
x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
|
|
||||||
)
|
|
||||||
|
|
||||||
nbest_hyps = nbest_hyps[: self.nbest]
|
nbest_hyps = nbest_hyps[: self.nbest]
|
||||||
|
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
b, n, d = encoder_out.size()
|
b, n, d = encoder_out.size()
|
||||||
for i in range(b):
|
for i in range(b):
|
||||||
@ -553,9 +494,9 @@ class Transducer(nn.Module):
|
|||||||
# remove sos/eos and get results
|
# remove sos/eos and get results
|
||||||
last_pos = -1
|
last_pos = -1
|
||||||
if isinstance(hyp.yseq, list):
|
if isinstance(hyp.yseq, list):
|
||||||
token_int = hyp.yseq[1:last_pos]
|
token_int = hyp.yseq#[1:last_pos]
|
||||||
else:
|
else:
|
||||||
token_int = hyp.yseq[1:last_pos].tolist()
|
token_int = hyp.yseq#[1:last_pos].tolist()
|
||||||
|
|
||||||
# remove blank symbol id, which is assumed to be 0
|
# remove blank symbol id, which is assumed to be 0
|
||||||
token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
|
token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
|
||||||
|
|||||||
@ -1,10 +1,15 @@
|
|||||||
import random
|
#!/usr/bin/env python3
|
||||||
|
# -*- encoding: utf-8 -*-
|
||||||
|
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||||
|
# MIT License (https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from funasr.register import tables
|
||||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||||
from funasr.models.transformer.utils.nets_utils import to_device
|
from funasr.models.transformer.utils.nets_utils import to_device
|
||||||
from funasr.models.language_model.rnn.attentions import initial_att
|
from funasr.models.language_model.rnn.attentions import initial_att
|
||||||
@ -78,7 +83,7 @@ def build_attention_list(
|
|||||||
)
|
)
|
||||||
return att_list
|
return att_list
|
||||||
|
|
||||||
|
@tables.register("decoder_classes", "rnn_decoder")
|
||||||
class RNNDecoder(nn.Module):
|
class RNNDecoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -1,112 +0,0 @@
|
|||||||
|
|
||||||
from typing import Optional
|
|
||||||
from typing import Sequence
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
|
||||||
from funasr.models.language_model.rnn.encoders import RNN
|
|
||||||
from funasr.models.language_model.rnn.encoders import RNNP
|
|
||||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
|
||||||
|
|
||||||
|
|
||||||
class RNNEncoder(AbsEncoder):
|
|
||||||
"""RNNEncoder class.
|
|
||||||
Args:
|
|
||||||
input_size: The number of expected features in the input
|
|
||||||
output_size: The number of output features
|
|
||||||
hidden_size: The number of hidden features
|
|
||||||
bidirectional: If ``True`` becomes a bidirectional LSTM
|
|
||||||
use_projection: Use projection layer or not
|
|
||||||
num_layers: Number of recurrent layers
|
|
||||||
dropout: dropout probability
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
input_size: int,
|
|
||||||
rnn_type: str = "lstm",
|
|
||||||
bidirectional: bool = True,
|
|
||||||
use_projection: bool = True,
|
|
||||||
num_layers: int = 4,
|
|
||||||
hidden_size: int = 320,
|
|
||||||
output_size: int = 320,
|
|
||||||
dropout: float = 0.0,
|
|
||||||
subsample: Optional[Sequence[int]] = (2, 2, 1, 1),
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self._output_size = output_size
|
|
||||||
self.rnn_type = rnn_type
|
|
||||||
self.bidirectional = bidirectional
|
|
||||||
self.use_projection = use_projection
|
|
||||||
|
|
||||||
if rnn_type not in {"lstm", "gru"}:
|
|
||||||
raise ValueError(f"Not supported rnn_type={rnn_type}")
|
|
||||||
|
|
||||||
if subsample is None:
|
|
||||||
subsample = np.ones(num_layers + 1, dtype=np.int32)
|
|
||||||
else:
|
|
||||||
subsample = subsample[:num_layers]
|
|
||||||
# Append 1 at the beginning because the second or later is used
|
|
||||||
subsample = np.pad(
|
|
||||||
np.array(subsample, dtype=np.int32),
|
|
||||||
[1, num_layers - len(subsample)],
|
|
||||||
mode="constant",
|
|
||||||
constant_values=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
rnn_type = ("b" if bidirectional else "") + rnn_type
|
|
||||||
if use_projection:
|
|
||||||
self.enc = torch.nn.ModuleList(
|
|
||||||
[
|
|
||||||
RNNP(
|
|
||||||
input_size,
|
|
||||||
num_layers,
|
|
||||||
hidden_size,
|
|
||||||
output_size,
|
|
||||||
subsample,
|
|
||||||
dropout,
|
|
||||||
typ=rnn_type,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.enc = torch.nn.ModuleList(
|
|
||||||
[
|
|
||||||
RNN(
|
|
||||||
input_size,
|
|
||||||
num_layers,
|
|
||||||
hidden_size,
|
|
||||||
output_size,
|
|
||||||
dropout,
|
|
||||||
typ=rnn_type,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def output_size(self) -> int:
|
|
||||||
return self._output_size
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
xs_pad: torch.Tensor,
|
|
||||||
ilens: torch.Tensor,
|
|
||||||
prev_states: torch.Tensor = None,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
||||||
if prev_states is None:
|
|
||||||
prev_states = [None] * len(self.enc)
|
|
||||||
assert len(prev_states) == len(self.enc)
|
|
||||||
|
|
||||||
current_states = []
|
|
||||||
for module, prev_state in zip(self.enc, prev_states):
|
|
||||||
xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
|
|
||||||
current_states.append(states)
|
|
||||||
|
|
||||||
if self.use_projection:
|
|
||||||
xs_pad.masked_fill_(make_pad_mask(ilens, xs_pad, 1), 0.0)
|
|
||||||
else:
|
|
||||||
xs_pad = xs_pad.masked_fill(make_pad_mask(ilens, xs_pad, 1), 0.0)
|
|
||||||
return xs_pad, ilens, current_states
|
|
||||||
@ -1,12 +1,17 @@
|
|||||||
"""RNN decoder definition for Transducer models."""
|
#!/usr/bin/env python3
|
||||||
|
# -*- encoding: utf-8 -*-
|
||||||
from typing import List, Optional, Tuple
|
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||||
|
# MIT License (https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from funasr.models.transducer.beam_search_transducer import Hypothesis
|
from funasr.register import tables
|
||||||
from funasr.models.specaug.specaug import SpecAug
|
from funasr.models.specaug.specaug import SpecAug
|
||||||
|
from funasr.models.transducer.beam_search_transducer import Hypothesis
|
||||||
|
|
||||||
|
|
||||||
|
@tables.register("decoder_classes", "rnnt_decoder")
|
||||||
class RNNTDecoder(torch.nn.Module):
|
class RNNTDecoder(torch.nn.Module):
|
||||||
"""RNN decoder module.
|
"""RNN decoder module.
|
||||||
|
|
||||||
|
|||||||
@ -312,8 +312,221 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
|||||||
return self.forward_attention(v, scores, mask)
|
return self.forward_attention(v, scores, mask)
|
||||||
|
|
||||||
|
|
||||||
|
class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
|
||||||
|
"""RelPositionMultiHeadedAttention definition.
|
||||||
|
Args:
|
||||||
|
num_heads: Number of attention heads.
|
||||||
|
embed_size: Embedding size.
|
||||||
|
dropout_rate: Dropout rate.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_heads: int,
|
||||||
|
embed_size: int,
|
||||||
|
dropout_rate: float = 0.0,
|
||||||
|
simplified_attention_score: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Construct an MultiHeadedAttention object."""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.d_k = embed_size // num_heads
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
assert self.d_k * num_heads == embed_size, (
|
||||||
|
"embed_size (%d) must be divisible by num_heads (%d)",
|
||||||
|
(embed_size, num_heads),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.linear_q = torch.nn.Linear(embed_size, embed_size)
|
||||||
|
self.linear_k = torch.nn.Linear(embed_size, embed_size)
|
||||||
|
self.linear_v = torch.nn.Linear(embed_size, embed_size)
|
||||||
|
|
||||||
|
self.linear_out = torch.nn.Linear(embed_size, embed_size)
|
||||||
|
|
||||||
|
if simplified_attention_score:
|
||||||
|
self.linear_pos = torch.nn.Linear(embed_size, num_heads)
|
||||||
|
|
||||||
|
self.compute_att_score = self.compute_simplified_attention_score
|
||||||
|
else:
|
||||||
|
self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
|
||||||
|
|
||||||
|
self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
|
||||||
|
self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
|
||||||
|
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
||||||
|
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
||||||
|
|
||||||
|
self.compute_att_score = self.compute_attention_score
|
||||||
|
|
||||||
|
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||||
|
self.attn = None
|
||||||
|
|
||||||
|
def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
|
||||||
|
"""Compute relative positional encoding.
|
||||||
|
Args:
|
||||||
|
x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
|
||||||
|
left_context: Number of frames in left context.
|
||||||
|
Returns:
|
||||||
|
x: Output sequence. (B, H, T_1, T_2)
|
||||||
|
"""
|
||||||
|
batch_size, n_heads, time1, n = x.shape
|
||||||
|
time2 = time1 + left_context
|
||||||
|
|
||||||
|
batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
|
||||||
|
|
||||||
|
return x.as_strided(
|
||||||
|
(batch_size, n_heads, time1, time2),
|
||||||
|
(batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
|
||||||
|
storage_offset=(n_stride * (time1 - 1)),
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_simplified_attention_score(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
pos_enc: torch.Tensor,
|
||||||
|
left_context: int = 0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Simplified attention score computation.
|
||||||
|
Reference: https://github.com/k2-fsa/icefall/pull/458
|
||||||
|
Args:
|
||||||
|
query: Transformed query tensor. (B, H, T_1, d_k)
|
||||||
|
key: Transformed key tensor. (B, H, T_2, d_k)
|
||||||
|
pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
|
||||||
|
left_context: Number of frames in left context.
|
||||||
|
Returns:
|
||||||
|
: Attention score. (B, H, T_1, T_2)
|
||||||
|
"""
|
||||||
|
pos_enc = self.linear_pos(pos_enc)
|
||||||
|
|
||||||
|
matrix_ac = torch.matmul(query, key.transpose(2, 3))
|
||||||
|
|
||||||
|
matrix_bd = self.rel_shift(
|
||||||
|
pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
|
||||||
|
left_context=left_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
|
||||||
|
|
||||||
|
def compute_attention_score(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
pos_enc: torch.Tensor,
|
||||||
|
left_context: int = 0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Attention score computation.
|
||||||
|
Args:
|
||||||
|
query: Transformed query tensor. (B, H, T_1, d_k)
|
||||||
|
key: Transformed key tensor. (B, H, T_2, d_k)
|
||||||
|
pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
|
||||||
|
left_context: Number of frames in left context.
|
||||||
|
Returns:
|
||||||
|
: Attention score. (B, H, T_1, T_2)
|
||||||
|
"""
|
||||||
|
p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
|
||||||
|
|
||||||
|
query = query.transpose(1, 2)
|
||||||
|
q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
|
||||||
|
q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
|
||||||
|
|
||||||
|
matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
|
||||||
|
|
||||||
|
matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
|
||||||
|
matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
|
||||||
|
|
||||||
|
return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
|
||||||
|
|
||||||
|
def forward_qkv(
|
||||||
|
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""Transform query, key and value.
|
||||||
|
Args:
|
||||||
|
query: Query tensor. (B, T_1, size)
|
||||||
|
key: Key tensor. (B, T_2, size)
|
||||||
|
v: Value tensor. (B, T_2, size)
|
||||||
|
Returns:
|
||||||
|
q: Transformed query tensor. (B, H, T_1, d_k)
|
||||||
|
k: Transformed key tensor. (B, H, T_2, d_k)
|
||||||
|
v: Transformed value tensor. (B, H, T_2, d_k)
|
||||||
|
"""
|
||||||
|
n_batch = query.size(0)
|
||||||
|
|
||||||
|
q = (
|
||||||
|
self.linear_q(query)
|
||||||
|
.view(n_batch, -1, self.num_heads, self.d_k)
|
||||||
|
.transpose(1, 2)
|
||||||
|
)
|
||||||
|
k = (
|
||||||
|
self.linear_k(key)
|
||||||
|
.view(n_batch, -1, self.num_heads, self.d_k)
|
||||||
|
.transpose(1, 2)
|
||||||
|
)
|
||||||
|
v = (
|
||||||
|
self.linear_v(value)
|
||||||
|
.view(n_batch, -1, self.num_heads, self.d_k)
|
||||||
|
.transpose(1, 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
return q, k, v
|
||||||
|
|
||||||
|
def forward_attention(
|
||||||
|
self,
|
||||||
|
value: torch.Tensor,
|
||||||
|
scores: torch.Tensor,
|
||||||
|
mask: torch.Tensor,
|
||||||
|
chunk_mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Compute attention context vector.
|
||||||
|
Args:
|
||||||
|
value: Transformed value. (B, H, T_2, d_k)
|
||||||
|
scores: Attention score. (B, H, T_1, T_2)
|
||||||
|
mask: Source mask. (B, T_2)
|
||||||
|
chunk_mask: Chunk mask. (T_1, T_1)
|
||||||
|
Returns:
|
||||||
|
attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
|
||||||
|
"""
|
||||||
|
batch_size = scores.size(0)
|
||||||
|
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||||
|
if chunk_mask is not None:
|
||||||
|
mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
|
||||||
|
scores = scores.masked_fill(mask, float("-inf"))
|
||||||
|
self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
|
||||||
|
|
||||||
|
attn_output = self.dropout(self.attn)
|
||||||
|
attn_output = torch.matmul(attn_output, value)
|
||||||
|
|
||||||
|
attn_output = self.linear_out(
|
||||||
|
attn_output.transpose(1, 2)
|
||||||
|
.contiguous()
|
||||||
|
.view(batch_size, -1, self.num_heads * self.d_k)
|
||||||
|
)
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
pos_enc: torch.Tensor,
|
||||||
|
mask: torch.Tensor,
|
||||||
|
chunk_mask: Optional[torch.Tensor] = None,
|
||||||
|
left_context: int = 0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Compute scaled dot product attention with rel. positional encoding.
|
||||||
|
Args:
|
||||||
|
query: Query tensor. (B, T_1, size)
|
||||||
|
key: Key tensor. (B, T_2, size)
|
||||||
|
value: Value tensor. (B, T_2, size)
|
||||||
|
pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
|
||||||
|
mask: Source mask. (B, T_2)
|
||||||
|
chunk_mask: Chunk mask. (T_1, T_1)
|
||||||
|
left_context: Number of frames in left context.
|
||||||
|
Returns:
|
||||||
|
: Output tensor. (B, T_1, H * d_k)
|
||||||
|
"""
|
||||||
|
q, k, v = self.forward_qkv(query, key, value)
|
||||||
|
scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
|
||||||
|
return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user