Merge branch 'main' of github.com:alibaba-damo-academy/FunASR

merge
This commit is contained in:
游雁 2024-02-20 14:05:58 +08:00
commit d79287c37e
9 changed files with 957 additions and 232 deletions

View File

@ -14,6 +14,7 @@ from funasr.models.transformer.attention import (
MultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttention, # noqa: H301
LegacyRelPositionMultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttentionChunk,
)
from funasr.models.transformer.embedding import (
PositionalEncoding, # noqa: H301
@ -610,4 +611,669 @@ class ConformerEncoder(nn.Module):
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
class CausalConvolution(torch.nn.Module):
"""ConformerConvolution module definition.
Args:
channels: The number of channels.
kernel_size: Size of the convolving kernel.
activation: Type of activation function.
norm_args: Normalization module arguments.
causal: Whether to use causal convolution (set to True if streaming).
"""
def __init__(
self,
channels: int,
kernel_size: int,
activation: torch.nn.Module = torch.nn.ReLU(),
norm_args: Dict = {},
causal: bool = False,
) -> None:
"""Construct an ConformerConvolution object."""
super().__init__()
assert (kernel_size - 1) % 2 == 0
self.kernel_size = kernel_size
self.pointwise_conv1 = torch.nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
)
if causal:
self.lorder = kernel_size - 1
padding = 0
else:
self.lorder = 0
padding = (kernel_size - 1) // 2
self.depthwise_conv = torch.nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=padding,
groups=channels,
)
self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
self.pointwise_conv2 = torch.nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
)
self.activation = activation
def forward(
self,
x: torch.Tensor,
cache: Optional[torch.Tensor] = None,
right_context: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute convolution module.
Args:
x: ConformerConvolution input sequences. (B, T, D_hidden)
cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden)
right_context: Number of frames in right context.
Returns:
x: ConformerConvolution output sequences. (B, T, D_hidden)
cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden)
"""
x = self.pointwise_conv1(x.transpose(1, 2))
x = torch.nn.functional.glu(x, dim=1)
if self.lorder > 0:
if cache is None:
x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
else:
x = torch.cat([cache, x], dim=2)
if right_context > 0:
cache = x[:, :, -(self.lorder + right_context) : -right_context]
else:
cache = x[:, :, -self.lorder :]
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
x = self.pointwise_conv2(x).transpose(1, 2)
return x, cache
class ChunkEncoderLayer(torch.nn.Module):
"""Chunk Conformer module definition.
Args:
block_size: Input/output size.
self_att: Self-attention module instance.
feed_forward: Feed-forward module instance.
feed_forward_macaron: Feed-forward module instance for macaron network.
conv_mod: Convolution module instance.
norm_class: Normalization module class.
norm_args: Normalization module arguments.
dropout_rate: Dropout rate.
"""
def __init__(
self,
block_size: int,
self_att: torch.nn.Module,
feed_forward: torch.nn.Module,
feed_forward_macaron: torch.nn.Module,
conv_mod: torch.nn.Module,
norm_class: torch.nn.Module = LayerNorm,
norm_args: Dict = {},
dropout_rate: float = 0.0,
) -> None:
"""Construct a Conformer object."""
super().__init__()
self.self_att = self_att
self.feed_forward = feed_forward
self.feed_forward_macaron = feed_forward_macaron
self.feed_forward_scale = 0.5
self.conv_mod = conv_mod
self.norm_feed_forward = norm_class(block_size, **norm_args)
self.norm_self_att = norm_class(block_size, **norm_args)
self.norm_macaron = norm_class(block_size, **norm_args)
self.norm_conv = norm_class(block_size, **norm_args)
self.norm_final = norm_class(block_size, **norm_args)
self.dropout = torch.nn.Dropout(dropout_rate)
self.block_size = block_size
self.cache = None
def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
"""Initialize/Reset self-attention and convolution modules cache for streaming.
Args:
left_context: Number of left frames during chunk-by-chunk inference.
device: Device to use for cache tensor.
"""
self.cache = [
torch.zeros(
(1, left_context, self.block_size),
device=device,
),
torch.zeros(
(
1,
self.block_size,
self.conv_mod.kernel_size - 1,
),
device=device,
),
]
def forward(
self,
x: torch.Tensor,
pos_enc: torch.Tensor,
mask: torch.Tensor,
chunk_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Encode input sequences.
Args:
x: Conformer input sequences. (B, T, D_block)
pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
mask: Source mask. (B, T)
chunk_mask: Chunk mask. (T_2, T_2)
Returns:
x: Conformer output sequences. (B, T, D_block)
mask: Source mask. (B, T)
pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
"""
residual = x
x = self.norm_macaron(x)
x = residual + self.feed_forward_scale * self.dropout(
self.feed_forward_macaron(x)
)
residual = x
x = self.norm_self_att(x)
x_q = x
x = residual + self.dropout(
self.self_att(
x_q,
x,
x,
pos_enc,
mask,
chunk_mask=chunk_mask,
)
)
residual = x
x = self.norm_conv(x)
x, _ = self.conv_mod(x)
x = residual + self.dropout(x)
residual = x
x = self.norm_feed_forward(x)
x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
x = self.norm_final(x)
return x, mask, pos_enc
def chunk_forward(
self,
x: torch.Tensor,
pos_enc: torch.Tensor,
mask: torch.Tensor,
chunk_size: int = 16,
left_context: int = 0,
right_context: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode chunk of input sequence.
Args:
x: Conformer input sequences. (B, T, D_block)
pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
mask: Source mask. (B, T_2)
left_context: Number of frames in left context.
right_context: Number of frames in right context.
Returns:
x: Conformer output sequences. (B, T, D_block)
pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
"""
residual = x
x = self.norm_macaron(x)
x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
residual = x
x = self.norm_self_att(x)
if left_context > 0:
key = torch.cat([self.cache[0], x], dim=1)
else:
key = x
val = key
if right_context > 0:
att_cache = key[:, -(left_context + right_context) : -right_context, :]
else:
att_cache = key[:, -left_context:, :]
x = residual + self.self_att(
x,
key,
val,
pos_enc,
mask,
left_context=left_context,
)
residual = x
x = self.norm_conv(x)
x, conv_cache = self.conv_mod(
x, cache=self.cache[1], right_context=right_context
)
x = residual + x
residual = x
x = self.norm_feed_forward(x)
x = residual + self.feed_forward_scale * self.feed_forward(x)
x = self.norm_final(x)
self.cache = [att_cache, conv_cache]
return x, pos_enc
@tables.register("encoder_classes", "ChunkConformerEncoder")
class ConformerChunkEncoder(torch.nn.Module):
"""Encoder module definition.
Args:
input_size: Input size.
body_conf: Encoder body configuration.
input_conf: Encoder input configuration.
main_conf: Encoder main configuration.
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
embed_vgg_like: bool = False,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 3,
macaron_style: bool = False,
rel_pos_type: str = "legacy",
pos_enc_layer_type: str = "rel_pos",
selfattention_layer_type: str = "rel_selfattn",
activation_type: str = "swish",
use_cnn_module: bool = True,
zero_triu: bool = False,
norm_type: str = "layer_norm",
cnn_module_kernel: int = 31,
conv_mod_norm_eps: float = 0.00001,
conv_mod_norm_momentum: float = 0.1,
simplified_att_score: bool = False,
dynamic_chunk_training: bool = False,
short_chunk_threshold: float = 0.75,
short_chunk_size: int = 25,
left_chunk_size: int = 0,
time_reduction_factor: int = 1,
unified_model_training: bool = False,
default_chunk_size: int = 16,
jitter_range: int = 4,
subsampling_factor: int = 1,
) -> None:
"""Construct an Encoder object."""
super().__init__()
self.embed = StreamingConvInput(
input_size=input_size,
conv_size=output_size,
subsampling_factor=subsampling_factor,
vgg_like=embed_vgg_like,
output_size=output_size,
)
self.pos_enc = StreamingRelPositionalEncoding(
output_size,
positional_dropout_rate,
)
activation = get_activation(
activation_type
)
pos_wise_args = (
output_size,
linear_units,
positional_dropout_rate,
activation,
)
conv_mod_norm_args = {
"eps": conv_mod_norm_eps,
"momentum": conv_mod_norm_momentum,
}
conv_mod_args = (
output_size,
cnn_module_kernel,
activation,
conv_mod_norm_args,
dynamic_chunk_training or unified_model_training,
)
mult_att_args = (
attention_heads,
output_size,
attention_dropout_rate,
simplified_att_score,
)
fn_modules = []
for _ in range(num_blocks):
module = lambda: ChunkEncoderLayer(
output_size,
RelPositionMultiHeadedAttentionChunk(*mult_att_args),
PositionwiseFeedForward(*pos_wise_args),
PositionwiseFeedForward(*pos_wise_args),
CausalConvolution(*conv_mod_args),
dropout_rate=dropout_rate,
)
fn_modules.append(module)
self.encoders = MultiBlocks(
[fn() for fn in fn_modules],
output_size,
)
self._output_size = output_size
self.dynamic_chunk_training = dynamic_chunk_training
self.short_chunk_threshold = short_chunk_threshold
self.short_chunk_size = short_chunk_size
self.left_chunk_size = left_chunk_size
self.unified_model_training = unified_model_training
self.default_chunk_size = default_chunk_size
self.jitter_range = jitter_range
self.time_reduction_factor = time_reduction_factor
def output_size(self) -> int:
return self._output_size
def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
"""Return the corresponding number of sample for a given chunk size, in frames.
Where size is the number of features frames after applying subsampling.
Args:
size: Number of frames after subsampling.
hop_length: Frontend's hop length
Returns:
: Number of raw samples
"""
return self.embed.get_size_before_subsampling(size) * hop_length
def get_encoder_input_size(self, size: int) -> int:
"""Return the corresponding number of sample for a given chunk size, in frames.
Where size is the number of features frames after applying subsampling.
Args:
size: Number of frames after subsampling.
Returns:
: Number of raw samples
"""
return self.embed.get_size_before_subsampling(size)
def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
"""Initialize/Reset encoder streaming cache.
Args:
left_context: Number of frames in left context.
device: Device ID.
"""
return self.encoders.reset_streaming_cache(left_context, device)
def forward(
self,
x: torch.Tensor,
x_len: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode input sequences.
Args:
x: Encoder input features. (B, T_in, F)
x_len: Encoder input features lengths. (B,)
Returns:
x: Encoder outputs. (B, T_out, D_enc)
x_len: Encoder outputs lenghts. (B,)
"""
short_status, limit_size = check_short_utt(
self.embed.subsampling_factor, x.size(1)
)
if short_status:
raise TooShortUttError(
f"has {x.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
x.size(1),
limit_size,
)
mask = make_source_mask(x_len).to(x.device)
if self.unified_model_training:
if self.training:
chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
else:
chunk_size = self.default_chunk_size
x, mask = self.embed(x, mask, chunk_size)
pos_enc = self.pos_enc(x)
chunk_mask = make_chunk_mask(
x.size(1),
chunk_size,
left_chunk_size=self.left_chunk_size,
device=x.device,
)
x_utt = self.encoders(
x,
pos_enc,
mask,
chunk_mask=None,
)
x_chunk = self.encoders(
x,
pos_enc,
mask,
chunk_mask=chunk_mask,
)
olens = mask.eq(0).sum(1)
if self.time_reduction_factor > 1:
x_utt = x_utt[:,::self.time_reduction_factor,:]
x_chunk = x_chunk[:,::self.time_reduction_factor,:]
olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
return x_utt, x_chunk, olens
elif self.dynamic_chunk_training:
max_len = x.size(1)
if self.training:
chunk_size = torch.randint(1, max_len, (1,)).item()
if chunk_size > (max_len * self.short_chunk_threshold):
chunk_size = max_len
else:
chunk_size = (chunk_size % self.short_chunk_size) + 1
else:
chunk_size = self.default_chunk_size
x, mask = self.embed(x, mask, chunk_size)
pos_enc = self.pos_enc(x)
chunk_mask = make_chunk_mask(
x.size(1),
chunk_size,
left_chunk_size=self.left_chunk_size,
device=x.device,
)
else:
x, mask = self.embed(x, mask, None)
pos_enc = self.pos_enc(x)
chunk_mask = None
x = self.encoders(
x,
pos_enc,
mask,
chunk_mask=chunk_mask,
)
olens = mask.eq(0).sum(1)
if self.time_reduction_factor > 1:
x = x[:,::self.time_reduction_factor,:]
olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
return x, olens, None
def full_utt_forward(
self,
x: torch.Tensor,
x_len: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode input sequences.
Args:
x: Encoder input features. (B, T_in, F)
x_len: Encoder input features lengths. (B,)
Returns:
x: Encoder outputs. (B, T_out, D_enc)
x_len: Encoder outputs lenghts. (B,)
"""
short_status, limit_size = check_short_utt(
self.embed.subsampling_factor, x.size(1)
)
if short_status:
raise TooShortUttError(
f"has {x.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
x.size(1),
limit_size,
)
mask = make_source_mask(x_len).to(x.device)
x, mask = self.embed(x, mask, None)
pos_enc = self.pos_enc(x)
x_utt = self.encoders(
x,
pos_enc,
mask,
chunk_mask=None,
)
if self.time_reduction_factor > 1:
x_utt = x_utt[:,::self.time_reduction_factor,:]
return x_utt
def simu_chunk_forward(
self,
x: torch.Tensor,
x_len: torch.Tensor,
chunk_size: int = 16,
left_context: int = 32,
right_context: int = 0,
) -> torch.Tensor:
short_status, limit_size = check_short_utt(
self.embed.subsampling_factor, x.size(1)
)
if short_status:
raise TooShortUttError(
f"has {x.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
x.size(1),
limit_size,
)
mask = make_source_mask(x_len)
x, mask = self.embed(x, mask, chunk_size)
pos_enc = self.pos_enc(x)
chunk_mask = make_chunk_mask(
x.size(1),
chunk_size,
left_chunk_size=self.left_chunk_size,
device=x.device,
)
x = self.encoders(
x,
pos_enc,
mask,
chunk_mask=chunk_mask,
)
olens = mask.eq(0).sum(1)
if self.time_reduction_factor > 1:
x = x[:,::self.time_reduction_factor,:]
return x
def chunk_forward(
self,
x: torch.Tensor,
x_len: torch.Tensor,
processed_frames: torch.tensor,
chunk_size: int = 16,
left_context: int = 32,
right_context: int = 0,
) -> torch.Tensor:
"""Encode input sequences as chunks.
Args:
x: Encoder input features. (1, T_in, F)
x_len: Encoder input features lengths. (1,)
processed_frames: Number of frames already seen.
left_context: Number of frames in left context.
right_context: Number of frames in right context.
Returns:
x: Encoder outputs. (B, T_out, D_enc)
"""
mask = make_source_mask(x_len)
x, mask = self.embed(x, mask, None)
if left_context > 0:
processed_mask = (
torch.arange(left_context, device=x.device)
.view(1, left_context)
.flip(1)
)
processed_mask = processed_mask >= processed_frames
mask = torch.cat([processed_mask, mask], dim=1)
pos_enc = self.pos_enc(x, left_context=left_context)
x = self.encoders.chunk_forward(
x,
pos_enc,
mask,
chunk_size=chunk_size,
left_context=left_context,
right_context=right_context,
)
if right_context > 0:
x = x[:, 0:-right_context, :]
if self.time_reduction_factor > 1:
x = x[:,::self.time_reduction_factor,:]
return x

View File

@ -335,7 +335,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# hotword
self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)

View File

@ -1,11 +1,13 @@
"""Search algorithms for Transducer models."""
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import torch
import numpy as np
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from funasr.models.transducer.joint_network import JointNetwork

View File

@ -1,10 +1,15 @@
"""Transducer joint network implementation."""
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import torch
from funasr.register import tables
from funasr.models.transformer.utils.nets_utils import get_activation
@tables.register("joint_network_classes", "joint_network")
class JointNetwork(torch.nn.Module):
"""Transducer joint network module.

View File

@ -1,42 +1,26 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import time
import torch
import logging
from contextlib import contextmanager
from typing import Dict, Optional, Tuple
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import tempfile
import codecs
import requests
import re
import copy
import torch
import torch.nn as nn
import random
import numpy as np
import time
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
# from funasr.models.ctc import CTC
# from funasr.models.decoder.abs_decoder import AbsDecoder
# from funasr.models.e2e_asr_common import ErrorCalculator
# from funasr.models.encoder.abs_encoder import AbsEncoder
# from funasr.frontends.abs_frontend import AbsFrontend
# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.paraformer.cif_predictor import mae_loss
# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
# from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
# from funasr.models.base_model import FunASRModel
# from funasr.models.paraformer.cif_predictor import CifPredictorV3
from funasr.models.paraformer.search import Hypothesis
from funasr.models.model_class_factory import *
from funasr.register import tables
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.scorers.length_bonus import LengthBonus
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.models.transducer.beam_search_transducer import BeamSearchTransducer
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@ -45,16 +29,10 @@ else:
@contextmanager
def autocast(enabled=True):
yield
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
class Transducer(nn.Module):
"""ESPnet2ASRTransducerModel module definition."""
@tables.register("model_classes", "Transducer")
class Transducer(torch.nn.Module):
def __init__(
self,
frontend: Optional[str] = None,
@ -96,28 +74,24 @@ class Transducer(nn.Module):
super().__init__()
if frontend is not None:
frontend_class = frontend_classes.get_class(frontend)
frontend = frontend_class(**frontend_conf)
if specaug is not None:
specaug_class = specaug_classes.get_class(specaug)
specaug_class = tables.specaug_classes.get(specaug)
specaug = specaug_class(**specaug_conf)
if normalize is not None:
normalize_class = normalize_classes.get_class(normalize)
normalize_class = tables.normalize_classes.get(normalize)
normalize = normalize_class(**normalize_conf)
encoder_class = encoder_classes.get_class(encoder)
encoder_class = tables.encoder_classes.get(encoder)
encoder = encoder_class(input_size=input_size, **encoder_conf)
encoder_output_size = encoder.output_size()
decoder_class = decoder_classes.get_class(decoder)
decoder_class = tables.decoder_classes.get(decoder)
decoder = decoder_class(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
**decoder_conf,
)
decoder_output_size = decoder.output_size
joint_network_class = joint_network_classes.get_class(decoder)
joint_network_class = tables.joint_network_classes.get(joint_network)
joint_network = joint_network_class(
vocab_size,
encoder_output_size,
@ -125,7 +99,6 @@ class Transducer(nn.Module):
**joint_network_conf,
)
self.criterion_transducer = None
self.error_calculator = None
@ -157,23 +130,17 @@ class Transducer(nn.Module):
self.decoder = decoder
self.joint_network = joint_network
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
#
# if report_cer or report_wer:
# self.error_calculator = ErrorCalculator(
# token_list, sym_space, sym_blank, report_cer, report_wer
# )
#
self.length_normalized_loss = length_normalized_loss
self.beam_search = None
self.ctc = None
self.ctc_weight = 0.0
def forward(
self,
@ -190,8 +157,6 @@ class Transducer(nn.Module):
text: (Batch, Length)
text_lengths: (Batch,)
"""
# import pdb;
# pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
@ -283,12 +248,7 @@ class Transducer(nn.Module):
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder(
speech, speech_lengths, ctc=self.ctc
)
else:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
@ -449,9 +409,6 @@ class Transducer(nn.Module):
def init_beam_search(self,
**kwargs,
):
from funasr.models.transformer.search import BeamSearch
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus
# 1. Build ASR model
scorers = {}
@ -466,28 +423,16 @@ class Transducer(nn.Module):
length_bonus=LengthBonus(len(token_list)),
)
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
weights = dict(
decoder=1.0 - kwargs.get("decoding_ctc_weight"),
ctc=kwargs.get("decoding_ctc_weight", 0.0),
lm=kwargs.get("lm_weight", 0.0),
ngram=kwargs.get("ngram_weight", 0.0),
length_bonus=kwargs.get("penalty", 0.0),
)
beam_search = BeamSearch(
beam_size=kwargs.get("beam_size", 2),
weights=weights,
scorers=scorers,
sos=self.sos,
eos=self.eos,
vocab_size=len(token_list),
token_list=token_list,
pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
beam_search = BeamSearchTransducer(
self.decoder,
self.joint_network,
kwargs.get("beam_size", 2),
nbest=1,
)
# beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
# for scorer in scorers.values():
@ -495,13 +440,13 @@ class Transducer(nn.Module):
# scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
self.beam_search = beam_search
def generate(self,
data_in: list,
data_lengths: list=None,
key: list=None,
tokenizer=None,
**kwargs,
):
def inference(self,
data_in: list,
data_lengths: list=None,
key: list=None,
tokenizer=None,
**kwargs,
):
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
@ -509,10 +454,10 @@ class Transducer(nn.Module):
# init beamsearch
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
# if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
# extract fbank feats
@ -534,13 +479,9 @@ class Transducer(nn.Module):
encoder_out = encoder_out[0]
# c. Passed the encoder result and the beam search
nbest_hyps = self.beam_search(
x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
)
nbest_hyps = self.beam_search(encoder_out[0], is_final=True)
nbest_hyps = nbest_hyps[: self.nbest]
results = []
b, n, d = encoder_out.size()
for i in range(b):
@ -553,9 +494,9 @@ class Transducer(nn.Module):
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
token_int = hyp.yseq#[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
token_int = hyp.yseq#[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))

View File

@ -1,10 +1,15 @@
import random
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import numpy as np
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from funasr.register import tables
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.nets_utils import to_device
from funasr.models.language_model.rnn.attentions import initial_att
@ -78,7 +83,7 @@ def build_attention_list(
)
return att_list
@tables.register("decoder_classes", "rnn_decoder")
class RNNDecoder(nn.Module):
def __init__(
self,

View File

@ -1,112 +0,0 @@
from typing import Optional
from typing import Sequence
from typing import Tuple
import numpy as np
import torch
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.language_model.rnn.encoders import RNN
from funasr.models.language_model.rnn.encoders import RNNP
from funasr.models.encoder.abs_encoder import AbsEncoder
class RNNEncoder(AbsEncoder):
"""RNNEncoder class.
Args:
input_size: The number of expected features in the input
output_size: The number of output features
hidden_size: The number of hidden features
bidirectional: If ``True`` becomes a bidirectional LSTM
use_projection: Use projection layer or not
num_layers: Number of recurrent layers
dropout: dropout probability
"""
def __init__(
self,
input_size: int,
rnn_type: str = "lstm",
bidirectional: bool = True,
use_projection: bool = True,
num_layers: int = 4,
hidden_size: int = 320,
output_size: int = 320,
dropout: float = 0.0,
subsample: Optional[Sequence[int]] = (2, 2, 1, 1),
):
super().__init__()
self._output_size = output_size
self.rnn_type = rnn_type
self.bidirectional = bidirectional
self.use_projection = use_projection
if rnn_type not in {"lstm", "gru"}:
raise ValueError(f"Not supported rnn_type={rnn_type}")
if subsample is None:
subsample = np.ones(num_layers + 1, dtype=np.int32)
else:
subsample = subsample[:num_layers]
# Append 1 at the beginning because the second or later is used
subsample = np.pad(
np.array(subsample, dtype=np.int32),
[1, num_layers - len(subsample)],
mode="constant",
constant_values=1,
)
rnn_type = ("b" if bidirectional else "") + rnn_type
if use_projection:
self.enc = torch.nn.ModuleList(
[
RNNP(
input_size,
num_layers,
hidden_size,
output_size,
subsample,
dropout,
typ=rnn_type,
)
]
)
else:
self.enc = torch.nn.ModuleList(
[
RNN(
input_size,
num_layers,
hidden_size,
output_size,
dropout,
typ=rnn_type,
)
]
)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if prev_states is None:
prev_states = [None] * len(self.enc)
assert len(prev_states) == len(self.enc)
current_states = []
for module, prev_state in zip(self.enc, prev_states):
xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
current_states.append(states)
if self.use_projection:
xs_pad.masked_fill_(make_pad_mask(ilens, xs_pad, 1), 0.0)
else:
xs_pad = xs_pad.masked_fill(make_pad_mask(ilens, xs_pad, 1), 0.0)
return xs_pad, ilens, current_states

View File

@ -1,12 +1,17 @@
"""RNN decoder definition for Transducer models."""
from typing import List, Optional, Tuple
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import torch
from typing import List, Optional, Tuple
from funasr.models.transducer.beam_search_transducer import Hypothesis
from funasr.register import tables
from funasr.models.specaug.specaug import SpecAug
from funasr.models.transducer.beam_search_transducer import Hypothesis
@tables.register("decoder_classes", "rnnt_decoder")
class RNNTDecoder(torch.nn.Module):
"""RNN decoder module.

View File

@ -312,8 +312,221 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
return self.forward_attention(v, scores, mask)
class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
"""RelPositionMultiHeadedAttention definition.
Args:
num_heads: Number of attention heads.
embed_size: Embedding size.
dropout_rate: Dropout rate.
"""
def __init__(
self,
num_heads: int,
embed_size: int,
dropout_rate: float = 0.0,
simplified_attention_score: bool = False,
) -> None:
"""Construct an MultiHeadedAttention object."""
super().__init__()
self.d_k = embed_size // num_heads
self.num_heads = num_heads
assert self.d_k * num_heads == embed_size, (
"embed_size (%d) must be divisible by num_heads (%d)",
(embed_size, num_heads),
)
self.linear_q = torch.nn.Linear(embed_size, embed_size)
self.linear_k = torch.nn.Linear(embed_size, embed_size)
self.linear_v = torch.nn.Linear(embed_size, embed_size)
self.linear_out = torch.nn.Linear(embed_size, embed_size)
if simplified_attention_score:
self.linear_pos = torch.nn.Linear(embed_size, num_heads)
self.compute_att_score = self.compute_simplified_attention_score
else:
self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
self.compute_att_score = self.compute_attention_score
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.attn = None
def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
"""Compute relative positional encoding.
Args:
x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
left_context: Number of frames in left context.
Returns:
x: Output sequence. (B, H, T_1, T_2)
"""
batch_size, n_heads, time1, n = x.shape
time2 = time1 + left_context
batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
return x.as_strided(
(batch_size, n_heads, time1, time2),
(batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
storage_offset=(n_stride * (time1 - 1)),
)
def compute_simplified_attention_score(
self,
query: torch.Tensor,
key: torch.Tensor,
pos_enc: torch.Tensor,
left_context: int = 0,
) -> torch.Tensor:
"""Simplified attention score computation.
Reference: https://github.com/k2-fsa/icefall/pull/458
Args:
query: Transformed query tensor. (B, H, T_1, d_k)
key: Transformed key tensor. (B, H, T_2, d_k)
pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
left_context: Number of frames in left context.
Returns:
: Attention score. (B, H, T_1, T_2)
"""
pos_enc = self.linear_pos(pos_enc)
matrix_ac = torch.matmul(query, key.transpose(2, 3))
matrix_bd = self.rel_shift(
pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
left_context=left_context,
)
return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
def compute_attention_score(
self,
query: torch.Tensor,
key: torch.Tensor,
pos_enc: torch.Tensor,
left_context: int = 0,
) -> torch.Tensor:
"""Attention score computation.
Args:
query: Transformed query tensor. (B, H, T_1, d_k)
key: Transformed key tensor. (B, H, T_2, d_k)
pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
left_context: Number of frames in left context.
Returns:
: Attention score. (B, H, T_1, T_2)
"""
p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
query = query.transpose(1, 2)
q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
def forward_qkv(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Transform query, key and value.
Args:
query: Query tensor. (B, T_1, size)
key: Key tensor. (B, T_2, size)
v: Value tensor. (B, T_2, size)
Returns:
q: Transformed query tensor. (B, H, T_1, d_k)
k: Transformed key tensor. (B, H, T_2, d_k)
v: Transformed value tensor. (B, H, T_2, d_k)
"""
n_batch = query.size(0)
q = (
self.linear_q(query)
.view(n_batch, -1, self.num_heads, self.d_k)
.transpose(1, 2)
)
k = (
self.linear_k(key)
.view(n_batch, -1, self.num_heads, self.d_k)
.transpose(1, 2)
)
v = (
self.linear_v(value)
.view(n_batch, -1, self.num_heads, self.d_k)
.transpose(1, 2)
)
return q, k, v
def forward_attention(
self,
value: torch.Tensor,
scores: torch.Tensor,
mask: torch.Tensor,
chunk_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Compute attention context vector.
Args:
value: Transformed value. (B, H, T_2, d_k)
scores: Attention score. (B, H, T_1, T_2)
mask: Source mask. (B, T_2)
chunk_mask: Chunk mask. (T_1, T_1)
Returns:
attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
"""
batch_size = scores.size(0)
mask = mask.unsqueeze(1).unsqueeze(2)
if chunk_mask is not None:
mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
scores = scores.masked_fill(mask, float("-inf"))
self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
attn_output = self.dropout(self.attn)
attn_output = torch.matmul(attn_output, value)
attn_output = self.linear_out(
attn_output.transpose(1, 2)
.contiguous()
.view(batch_size, -1, self.num_heads * self.d_k)
)
return attn_output
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
pos_enc: torch.Tensor,
mask: torch.Tensor,
chunk_mask: Optional[torch.Tensor] = None,
left_context: int = 0,
) -> torch.Tensor:
"""Compute scaled dot product attention with rel. positional encoding.
Args:
query: Query tensor. (B, T_1, size)
key: Key tensor. (B, T_2, size)
value: Value tensor. (B, T_2, size)
pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
mask: Source mask. (B, T_2)
chunk_mask: Chunk mask. (T_1, T_1)
left_context: Number of frames in left context.
Returns:
: Output tensor. (B, T_1, H * d_k)
"""
q, k, v = self.forward_qkv(query, key, value)
scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)