rnnt reorg

This commit is contained in:
aky15 2023-04-12 16:49:56 +08:00
parent d46a542fae
commit 7d1efe158e
35 changed files with 418 additions and 1849 deletions

View File

@ -1,13 +1,13 @@
encoder_conf:
main_conf:
pos_wise_act_type: swish
pos_enc_dropout_rate: 0.3
pos_enc_dropout_rate: 0.5
conv_mod_act_type: swish
time_reduction_factor: 2
unified_model_training: true
default_chunk_size: 16
jitter_range: 4
left_chunk_size: 1
left_chunk_size: 0
input_conf:
block_type: conv2d
conv_size: 512
@ -18,9 +18,9 @@ encoder_conf:
linear_size: 2048
hidden_size: 512
heads: 8
dropout_rate: 0.3
pos_wise_dropout_rate: 0.3
att_dropout_rate: 0.3
dropout_rate: 0.5
pos_wise_dropout_rate: 0.5
att_dropout_rate: 0.5
conv_mod_kernel_size: 15
num_blocks: 12
@ -29,8 +29,8 @@ decoder: rnn
decoder_conf:
embed_size: 512
hidden_size: 512
embed_dropout_rate: 0.2
dropout_rate: 0.1
embed_dropout_rate: 0.5
dropout_rate: 0.5
joint_network_conf:
joint_space_size: 512
@ -41,14 +41,14 @@ model_conf:
# minibatch related
use_amp: true
batch_type: numel
batch_bins: 1600000
batch_type: unsorted
batch_size: 16
num_workers: 16
# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 80
max_epoch: 200
val_scheduler_criterion:
- valid
- loss
@ -56,11 +56,11 @@ best_model_criterion:
- - valid
- cer_transducer_chunk
- min
keep_nbest_models: 5
keep_nbest_models: 10
optim: adam
optim_conf:
lr: 0.0003
lr: 0.001
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
@ -75,10 +75,12 @@ specaug_conf:
apply_freq_mask: true
freq_mask_width_range:
- 0
- 30
- 40
num_freq_mask: 2
apply_time_mask: true
time_mask_width_range:
- 0
- 40
num_time_mask: 2
- 50
num_time_mask: 5
log_interval: 50

View File

@ -16,11 +16,11 @@ import torch
from packaging.version import parse as V
from typeguard import check_argument_types, check_return_type
from funasr.models_transducer.beam_search_transducer import (
from funasr.modules.beam_search.beam_search_transducer import (
BeamSearchTransducer,
Hypothesis,
)
from funasr.models_transducer.utils import TooShortUttError
from funasr.modules.nets_utils import TooShortUttError
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.tasks.asr_transducer import ASRTransducerTask
from funasr.tasks.lm import LMTask
@ -500,7 +500,6 @@ def inference(
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
<<<<<<< HEAD
batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
assert len(batch.keys()) == 1
@ -541,59 +540,6 @@ def inference(
if text is not None:
ibest_writer["text"][key] = text
=======
# batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
logging.info("decoding, utt_id: {}".format(keys))
# N-best list of (text, token, token_int, hyp_object)
time_beg = time.time()
results = speech2text(cache=cache, **batch)
if len(results) < 1:
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
time_end = time.time()
forward_time = time_end - time_beg
lfr_factor = results[0][-1]
length = results[0][-2]
forward_time_total += forward_time
length_total += length
rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time, 100 * forward_time / (length * lfr_factor))
logging.info(rtf_cur)
for batch_id in range(_bs):
result = [results[batch_id][:-2]]
key = keys[batch_id]
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), result):
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
ibest_writer["rtf"][key] = rtf_cur
if text is not None:
text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
# asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = " ".join(word_lists)
logging.info("decoding, utt: {}, predictions: {}".format(key, text))
rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))
logging.info(rtf_avg)
if writer is not None:
ibest_writer["rtf"]["rtf_avf"] = rtf_avg
return asr_result_list
return _forward
>>>>>>> main
def get_parser():

View File

@ -10,11 +10,11 @@ from typeguard import check_argument_types
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
from funasr.models_transducer.encoder.encoder import Encoder
from funasr.models_transducer.joint_network import JointNetwork
from funasr.models_transducer.utils import get_transducer_task_io
from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
from funasr.models.joint_network import JointNetwork
from funasr.modules.nets_utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
@ -28,7 +28,7 @@ else:
yield
class ESPnetASRTransducerModel(AbsESPnetModel):
class TransducerModel(AbsESPnetModel):
"""ESPnet2ASRTransducerModel module definition.
Args:

View File

@ -10,10 +10,10 @@ from typeguard import check_argument_types
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.models_transducer.encoder.encoder import Encoder
from funasr.models_transducer.joint_network import JointNetwork
from funasr.models_transducer.utils import get_transducer_task_io
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
from funasr.models.joint_network import JointNetwork
from funasr.modules.nets_utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
@ -23,7 +23,7 @@ from funasr.modules.nets_utils import th_accuracy
from funasr.losses.label_smoothing_loss import ( # noqa: H301
LabelSmoothingLoss,
)
from funasr.models_transducer.error_calculator import ErrorCalculator
from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
if V(torch.__version__) >= V("1.6.0"):
from torch.cuda.amp import autocast
else:
@ -33,7 +33,7 @@ else:
yield
class ESPnetASRUnifiedTransducerModel(AbsESPnetModel):
class UnifiedTransducerModel(AbsESPnetModel):
"""ESPnet2ASRTransducerModel module definition.
Args:
@ -289,7 +289,6 @@ class ESPnetASRUnifiedTransducerModel(AbsESPnetModel):
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def collect_feats(

View File

@ -1,26 +1,23 @@
"""Encoder for Transducer model."""
from typing import Any, Dict, List, Tuple
import torch
from typeguard import check_argument_types
from funasr.models_transducer.encoder.building import (
from funasr.models.encoder.chunk_encoder_utils.building import (
build_body_blocks,
build_input_block,
build_main_parameters,
build_positional_encoding,
)
from funasr.models_transducer.encoder.validation import validate_architecture
from funasr.models_transducer.utils import (
from funasr.models.encoder.chunk_encoder_utils.validation import validate_architecture
from funasr.modules.nets_utils import (
TooShortUttError,
check_short_utt,
make_chunk_mask,
make_source_mask,
)
class Encoder(torch.nn.Module):
class ChunkEncoder(torch.nn.Module):
"""Encoder module definition.
Args:
@ -61,10 +58,9 @@ class Encoder(torch.nn.Module):
self.unified_model_training = main_params["unified_model_training"]
self.default_chunk_size = main_params["default_chunk_size"]
self.jitter_range = main_params["jitter_range"]
self.time_reduction_factor = main_params["time_reduction_factor"]
self.jitter_range = main_params["jitter_range"]
self.time_reduction_factor = main_params["time_reduction_factor"]
def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
"""Return the corresponding number of sample for a given chunk size, in frames.
@ -79,7 +75,7 @@ class Encoder(torch.nn.Module):
"""
return self.embed.get_size_before_subsampling(size) * hop_length
def get_encoder_input_size(self, size: int) -> int:
"""Return the corresponding number of sample for a given chunk size, in frames.
@ -157,7 +153,7 @@ class Encoder(torch.nn.Module):
mask,
chunk_mask=chunk_mask,
)
olens = mask.eq(0).sum(1)
if self.time_reduction_factor > 1:
x_utt = x_utt[:,::self.time_reduction_factor,:]
@ -194,14 +190,14 @@ class Encoder(torch.nn.Module):
mask,
chunk_mask=chunk_mask,
)
olens = mask.eq(0).sum(1)
if self.time_reduction_factor > 1:
x = x[:,::self.time_reduction_factor,:]
olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
return x, olens
def simu_chunk_forward(
self,
x: torch.Tensor,
@ -290,7 +286,7 @@ class Encoder(torch.nn.Module):
if right_context > 0:
x = x[:, 0:-right_context, :]
if self.time_reduction_factor > 1:
x = x[:,::self.time_reduction_factor,:]
return x

View File

@ -5,7 +5,7 @@ from typing import Optional, Tuple, Union
import torch
import math
from funasr.models_transducer.utils import sub_factor_to_params, pad_to_len
from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len
class ConvInput(torch.nn.Module):

View File

@ -2,22 +2,22 @@
from typing import Any, Dict, List, Optional, Union
from funasr.models_transducer.activation import get_activation
from funasr.models_transducer.encoder.blocks.branchformer import Branchformer
from funasr.models_transducer.encoder.blocks.conformer import Conformer
from funasr.models_transducer.encoder.blocks.conv1d import Conv1d
from funasr.models_transducer.encoder.blocks.conv_input import ConvInput
from funasr.models_transducer.encoder.blocks.linear_input import LinearInput
from funasr.models_transducer.encoder.modules.attention import ( # noqa: H301
from funasr.modules.activation import get_activation
from funasr.models.encoder.chunk_encoder_blocks.branchformer import Branchformer
from funasr.models.encoder.chunk_encoder_blocks.conformer import Conformer
from funasr.models.encoder.chunk_encoder_blocks.conv1d import Conv1d
from funasr.models.encoder.chunk_encoder_blocks.conv_input import ConvInput
from funasr.models.encoder.chunk_encoder_blocks.linear_input import LinearInput
from funasr.models.encoder.chunk_encoder_modules.attention import ( # noqa: H301
RelPositionMultiHeadedAttention,
)
from funasr.models_transducer.encoder.modules.convolution import ( # noqa: H301
from funasr.models.encoder.chunk_encoder_modules.convolution import ( # noqa: H301
ConformerConvolution,
ConvolutionalSpatialGatingUnit,
)
from funasr.models_transducer.encoder.modules.multi_blocks import MultiBlocks
from funasr.models_transducer.encoder.modules.normalization import get_normalization
from funasr.models_transducer.encoder.modules.positional_encoding import ( # noqa: H301
from funasr.models.encoder.chunk_encoder_modules.multi_blocks import MultiBlocks
from funasr.models.encoder.chunk_encoder_modules.normalization import get_normalization
from funasr.models.encoder.chunk_encoder_modules.positional_encoding import ( # noqa: H301
RelPositionalEncoding,
)
from funasr.modules.positionwise_feed_forward import (

View File

@ -2,7 +2,7 @@
from typing import Any, Dict, List, Tuple
from funasr.models_transducer.utils import sub_factor_to_params
from funasr.modules.nets_utils import sub_factor_to_params
def validate_block_arguments(

View File

@ -2,7 +2,7 @@
import torch
from funasr.models_transducer.activation import get_activation
from funasr.modules.activation import get_activation
class JointNetwork(torch.nn.Module):

View File

@ -5,8 +5,8 @@ from typing import List, Optional, Tuple
import torch
from typeguard import check_argument_types
from funasr.models_transducer.beam_search_transducer import Hypothesis
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.modules.beam_search.beam_search_transducer import Hypothesis
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.specaug.specaug import SpecAug
class RNNDecoder(AbsDecoder):

View File

@ -5,8 +5,8 @@ from typing import List, Optional, Tuple
import torch
from typeguard import check_argument_types
from funasr.models_transducer.beam_search_transducer import Hypothesis
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.modules.beam_search.beam_search_transducer import Hypothesis
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.specaug.specaug import SpecAug
class StatelessDecoder(AbsDecoder):
@ -26,7 +26,6 @@ class StatelessDecoder(AbsDecoder):
embed_size: int = 256,
embed_dropout_rate: float = 0.0,
embed_pad: int = 0,
use_embed_mask: bool = False,
) -> None:
"""Construct a StatelessDecoder object."""
super().__init__()
@ -42,14 +41,6 @@ class StatelessDecoder(AbsDecoder):
self.device = next(self.parameters()).device
self.score_cache = {}
self.use_embed_mask = use_embed_mask
if self.use_embed_mask:
self._embed_mask = SpecAug(
time_mask_width_range=3,
num_time_mask=1,
apply_freq_mask=False,
apply_time_warp=False
)
def forward(
@ -69,9 +60,6 @@ class StatelessDecoder(AbsDecoder):
"""
dec_embed = self.embed_dropout_rate(self.embed(labels))
if self.use_embed_mask and self.training:
dec_embed = self._embed_mask(dec_embed, label_lens)[0]
return dec_embed
def score(

View File

@ -1,835 +0,0 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
class EncoderLayerSANM(nn.Module):
def __init__(
self,
in_size,
size,
self_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
stochastic_depth_rate=0.0,
):
"""Construct an EncoderLayer object."""
super(EncoderLayerSANM, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(in_size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.in_size = in_size
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
self.stochastic_depth_rate = stochastic_depth_rate
self.dropout_rate = dropout_rate
def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
"""Compute encoded features.
Args:
x_input (torch.Tensor): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
"""
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
stoch_layer_coeff = 1.0
if self.training and self.stochastic_depth_rate > 0:
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
if skip_layer:
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask
residual = x
if self.normalize_before:
x = self.norm1(x)
if self.concat_after:
x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
x = stoch_layer_coeff * self.concat_linear(x_concat)
else:
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.dropout(
self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
)
else:
x = stoch_layer_coeff * self.dropout(
self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
)
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
class SANMEncoder(AbsEncoder):
"""
author: Speech Lab, Alibaba Group, China
San-m: Memory equipped self-attention for end-to-end speech recognition
https://arxiv.org/abs/2006.01713
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
pos_enc_class=SinusoidalPositionEncoder,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
kernel_size : int = 11,
sanm_shfit : int = 0,
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
assert check_argument_types()
super().__init__()
self.embed = SinusoidalPositionEncoder()
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
encoder_selfattn_layer = MultiHeadedAttentionSANM
encoder_selfattn_layer_args0 = (
attention_heads,
input_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
encoder_selfattn_layer_args = (
attention_heads,
output_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
self.encoders0 = repeat(
1,
lambda lnum: EncoderLayerSANM(
input_size,
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args0),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
self.encoders = repeat(
num_blocks-1,
lambda lnum: EncoderLayerSANM(
output_size,
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
self.dropout = nn.Dropout(dropout_rate)
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
xs_pad = xs_pad * self.output_size**0.5
if self.embed is None:
xs_pad = xs_pad
elif (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
# xs_pad = self.dropout(xs_pad)
encoder_outs = self.encoders0(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
encoder_outs = self.encoders(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
# intermediate outputs are also normalized
if self.normalize_before:
encoder_out = self.after_norm(encoder_out)
intermediate_outs.append((layer_idx + 1, encoder_out))
if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
map_dict_local = {
## encoder
# cicd
"{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (768,256),(1,256,768)
"{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (768,),(768,)
"{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 2, 0),
}, # (256,1,31),(1,31,256,1)
"{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,256),(1,256,256)
"{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
# ffn
"{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (1024,256),(1,256,1024)
"{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,1024),(1,1024,256)
"{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
# out norm
"{}.after_norm.weight".format(tensor_name_prefix_torch):
{"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.after_norm.bias".format(tensor_name_prefix_torch):
{"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
}
return map_dict_local
def convert_tf2torch(self,
var_dict_tf,
var_dict_torch,
):
map_dict = self.gen_tf2torch_map_dict()
var_dict_torch_update = dict()
for name in sorted(var_dict_torch.keys(), reverse=False):
names = name.split('.')
if names[0] == self.tf2torch_tensor_name_prefix_torch:
if names[1] == "encoders0":
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
name_q = name_q.replace("encoders0", "encoders")
layeridx_bias = 0
layeridx += layeridx_bias
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
elif names[1] == "encoders":
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
layeridx_bias = 1
layeridx += layeridx_bias
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
elif names[1] == "after_norm":
name_tf = map_dict[name]["name"]
data_tf = var_dict_tf[name_tf]
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
return var_dict_torch_update
class SANMEncoderChunkOpt(AbsEncoder):
"""
author: Speech Lab, Alibaba Group, China
SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
https://arxiv.org/abs/2006.01713
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
pos_enc_class=SinusoidalPositionEncoder,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
kernel_size: int = 11,
sanm_shfit: int = 0,
chunk_size: Union[int, Sequence[int]] = (16,),
stride: Union[int, Sequence[int]] = (10,),
pad_left: Union[int, Sequence[int]] = (0,),
time_reduction_factor: int = 1,
encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
assert check_argument_types()
super().__init__()
self.output_size = output_size
self.embed = SinusoidalPositionEncoder()
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
encoder_selfattn_layer = MultiHeadedAttentionSANM
encoder_selfattn_layer_args0 = (
attention_heads,
input_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
encoder_selfattn_layer_args = (
attention_heads,
output_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
self.encoders0 = repeat(
1,
lambda lnum: EncoderLayerSANM(
input_size,
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args0),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
self.encoders = repeat(
num_blocks - 1,
lambda lnum: EncoderLayerSANM(
output_size,
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
shfit_fsmn = (kernel_size - 1) // 2
self.overlap_chunk_cls = overlap_chunk(
chunk_size=chunk_size,
stride=stride,
pad_left=pad_left,
shfit_fsmn=shfit_fsmn,
encoder_att_look_back_factor=encoder_att_look_back_factor,
decoder_att_look_back_factor=decoder_att_look_back_factor,
)
self.time_reduction_factor = time_reduction_factor
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
xs_pad *= self.output_size ** 0.5
if self.embed is None:
xs_pad = xs_pad
elif (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
mask_shfit_chunk, mask_att_chunk_encoder = None, None
if self.overlap_chunk_cls is not None:
ilens = masks.squeeze(1).sum(1)
chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
dtype=xs_pad.dtype)
mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
xs_pad.size(0),
dtype=xs_pad.dtype)
encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
# intermediate outputs are also normalized
if self.normalize_before:
encoder_out = self.after_norm(encoder_out)
intermediate_outs.append((layer_idx + 1, encoder_out))
if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
xs_pad, olens = self.overlap_chunk_cls.remove_chunk(xs_pad, olens, chunk_outs=None)
if self.time_reduction_factor > 1:
xs_pad = xs_pad[:,::self.time_reduction_factor,:]
olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
map_dict_local = {
## encoder
# cicd
"{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (768,256),(1,256,768)
"{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (768,),(768,)
"{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 2, 0),
}, # (256,1,31),(1,31,256,1)
"{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,256),(1,256,256)
"{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
# ffn
"{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (1024,256),(1,256,1024)
"{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,1024),(1,1024,256)
"{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
# out norm
"{}.after_norm.weight".format(tensor_name_prefix_torch):
{"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.after_norm.bias".format(tensor_name_prefix_torch):
{"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
}
return map_dict_local
def convert_tf2torch(self,
var_dict_tf,
var_dict_torch,
):
map_dict = self.gen_tf2torch_map_dict()
var_dict_torch_update = dict()
for name in sorted(var_dict_torch.keys(), reverse=False):
names = name.split('.')
if names[0] == self.tf2torch_tensor_name_prefix_torch:
if names[1] == "encoders0":
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
name_q = name_q.replace("encoders0", "encoders")
layeridx_bias = 0
layeridx += layeridx_bias
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
elif names[1] == "encoders":
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
layeridx_bias = 1
layeridx += layeridx_bias
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
elif names[1] == "after_norm":
name_tf = map_dict[name]["name"]
data_tf = var_dict_tf[name_tf]
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
return var_dict_torch_update

View File

@ -1,169 +0,0 @@
"""Error Calculator module for Transducer."""
from typing import List, Optional, Tuple
import torch
from funasr.models_transducer.beam_search_transducer import BeamSearchTransducer
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.models_transducer.joint_network import JointNetwork
class ErrorCalculator:
"""Calculate CER and WER for transducer models.
Args:
decoder: Decoder module.
joint_network: Joint Network module.
token_list: List of token units.
sym_space: Space symbol.
sym_blank: Blank symbol.
report_cer: Whether to compute CER.
report_wer: Whether to compute WER.
"""
def __init__(
self,
decoder: AbsDecoder,
joint_network: JointNetwork,
token_list: List[int],
sym_space: str,
sym_blank: str,
report_cer: bool = False,
report_wer: bool = False,
) -> None:
"""Construct an ErrorCalculatorTransducer object."""
super().__init__()
self.beam_search = BeamSearchTransducer(
decoder=decoder,
joint_network=joint_network,
beam_size=1,
search_type="default",
score_norm=False,
)
self.decoder = decoder
self.token_list = token_list
self.space = sym_space
self.blank = sym_blank
self.report_cer = report_cer
self.report_wer = report_wer
def __call__(
self, encoder_out: torch.Tensor, target: torch.Tensor
) -> Tuple[Optional[float], Optional[float]]:
"""Calculate sentence-level WER or/and CER score for Transducer model.
Args:
encoder_out: Encoder output sequences. (B, T, D_enc)
target: Target label ID sequences. (B, L)
Returns:
: Sentence-level CER score.
: Sentence-level WER score.
"""
cer, wer = None, None
batchsize = int(encoder_out.size(0))
encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
char_pred, char_target = self.convert_to_char(pred, target)
if self.report_cer:
cer = self.calculate_cer(char_pred, char_target)
if self.report_wer:
wer = self.calculate_wer(char_pred, char_target)
return cer, wer
def convert_to_char(
self, pred: torch.Tensor, target: torch.Tensor
) -> Tuple[List, List]:
"""Convert label ID sequences to character sequences.
Args:
pred: Prediction label ID sequences. (B, U)
target: Target label ID sequences. (B, L)
Returns:
char_pred: Prediction character sequences. (B, ?)
char_target: Target character sequences. (B, ?)
"""
char_pred, char_target = [], []
for i, pred_i in enumerate(pred):
char_pred_i = [self.token_list[int(h)] for h in pred_i]
char_target_i = [self.token_list[int(r)] for r in target[i]]
char_pred_i = "".join(char_pred_i).replace(self.space, " ")
char_pred_i = char_pred_i.replace(self.blank, "")
char_target_i = "".join(char_target_i).replace(self.space, " ")
char_target_i = char_target_i.replace(self.blank, "")
char_pred.append(char_pred_i)
char_target.append(char_target_i)
return char_pred, char_target
def calculate_cer(
self, char_pred: torch.Tensor, char_target: torch.Tensor
) -> float:
"""Calculate sentence-level CER score.
Args:
char_pred: Prediction character sequences. (B, ?)
char_target: Target character sequences. (B, ?)
Returns:
: Average sentence-level CER score.
"""
import editdistance
distances, lens = [], []
for i, char_pred_i in enumerate(char_pred):
pred = char_pred_i.replace(" ", "")
target = char_target[i].replace(" ", "")
distances.append(editdistance.eval(pred, target))
lens.append(len(target))
return float(sum(distances)) / sum(lens)
def calculate_wer(
self, char_pred: torch.Tensor, char_target: torch.Tensor
) -> float:
"""Calculate sentence-level WER score.
Args:
char_pred: Prediction character sequences. (B, ?)
char_target: Target character sequences. (B, ?)
Returns:
: Average sentence-level WER score
"""
import editdistance
distances, lens = [], []
for i, char_pred_i in enumerate(char_pred):
pred = char_pred_i.replace("", " ").split()
target = char_target[i].replace("", " ").split()
distances.append(editdistance.eval(pred, target))
lens.append(len(target))
return float(sum(distances)) / sum(lens)

View File

@ -1,485 +0,0 @@
"""ESPnet2 ASR Transducer model."""
import logging
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union
import torch
from packaging.version import parse as V
from typeguard import check_argument_types
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
from funasr.models_transducer.encoder.encoder import Encoder
from funasr.models_transducer.joint_network import JointNetwork
from funasr.models_transducer.utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
if V(torch.__version__) >= V("1.6.0"):
from torch.cuda.amp import autocast
else:
@contextmanager
def autocast(enabled=True):
yield
class UniASRTransducerModel(AbsESPnetModel):
"""ESPnet2ASRTransducerModel module definition.
Args:
vocab_size: Size of complete vocabulary (w/ EOS and blank included).
token_list: List of token
frontend: Frontend module.
specaug: SpecAugment module.
normalize: Normalization module.
encoder: Encoder module.
decoder: Decoder module.
joint_network: Joint Network module.
transducer_weight: Weight of the Transducer loss.
fastemit_lambda: FastEmit lambda value.
auxiliary_ctc_weight: Weight of auxiliary CTC loss.
auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
ignore_id: Initial padding ID.
sym_space: Space symbol.
sym_blank: Blank Symbol
report_cer: Whether to report Character Error Rate during validation.
report_wer: Whether to report Word Error Rate during validation.
extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
"""
def __init__(
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
encoder,
decoder: AbsDecoder,
att_decoder: Optional[AbsAttDecoder],
joint_network: JointNetwork,
transducer_weight: float = 1.0,
fastemit_lambda: float = 0.0,
auxiliary_ctc_weight: float = 0.0,
auxiliary_ctc_dropout_rate: float = 0.0,
auxiliary_lm_loss_weight: float = 0.0,
auxiliary_lm_loss_smoothing: float = 0.0,
ignore_id: int = -1,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
report_cer: bool = True,
report_wer: bool = True,
extract_feats_in_collect_stats: bool = True,
) -> None:
"""Construct an ESPnetASRTransducerModel object."""
super().__init__()
assert check_argument_types()
# The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
self.blank_id = 0
self.vocab_size = vocab_size
self.ignore_id = ignore_id
self.token_list = token_list.copy()
self.sym_space = sym_space
self.sym_blank = sym_blank
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
self.encoder = encoder
self.decoder = decoder
self.joint_network = joint_network
self.criterion_transducer = None
self.error_calculator = None
self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
if self.use_auxiliary_ctc:
self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
if self.use_auxiliary_lm_loss:
self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
self.transducer_weight = transducer_weight
self.fastemit_lambda = fastemit_lambda
self.auxiliary_ctc_weight = auxiliary_ctc_weight
self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
self.report_cer = report_cer
self.report_wer = report_wer
self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
decoding_ind: int = None,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Forward architecture and compute loss(es).
Args:
speech: Speech sequences. (B, S)
speech_lengths: Speech sequences lengths. (B,)
text: Label ID sequences. (B, L)
text_lengths: Label ID sequences lengths. (B,)
kwargs: Contains "utts_id".
Return:
loss: Main loss value.
stats: Task statistics.
weight: Task weights.
"""
assert text_lengths.dim() == 1, text_lengths.shape
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
batch_size = speech.shape[0]
text = text[:, : text_lengths.max()]
# 1. Encoder
ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
# 2. Transducer-related I/O preparation
decoder_in, target, t_len, u_len = get_transducer_task_io(
text,
encoder_out_lens,
ignore_id=self.ignore_id,
)
# 3. Decoder
self.decoder.set_device(encoder_out.device)
decoder_out = self.decoder(decoder_in, u_len)
# 4. Joint Network
joint_out = self.joint_network(
encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
)
# 5. Losses
loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
encoder_out,
joint_out,
target,
t_len,
u_len,
)
loss_ctc, loss_lm = 0.0, 0.0
if self.use_auxiliary_ctc:
loss_ctc = self._calc_ctc_loss(
encoder_out,
target,
t_len,
u_len,
)
if self.use_auxiliary_lm_loss:
loss_lm = self._calc_lm_loss(decoder_out, target)
loss = (
self.transducer_weight * loss_trans
+ self.auxiliary_ctc_weight * loss_ctc
+ self.auxiliary_lm_loss_weight * loss_lm
)
stats = dict(
loss=loss.detach(),
loss_transducer=loss_trans.detach(),
aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
cer_transducer=cer_trans,
wer_transducer=wer_trans,
)
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""Collect features sequences and features lengths sequences.
Args:
speech: Speech sequences. (B, S)
speech_lengths: Speech sequences lengths. (B,)
text: Label ID sequences. (B, L)
text_lengths: Label ID sequences lengths. (B,)
kwargs: Contains "utts_id".
Return:
{}: "feats": Features sequences. (B, T, D_feats),
"feats_lengths": Features sequences lengths. (B,)
"""
if self.extract_feats_in_collect_stats:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
else:
# Generate dummy stats if extract_feats_in_collect_stats is False
logging.warning(
"Generating dummy stats for feats and feats_lengths, "
"because encoder_conf.extract_feats_in_collect_stats is "
f"{self.extract_feats_in_collect_stats}"
)
feats, feats_lengths = speech, speech_lengths
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
ind: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encoder speech sequences.
Args:
speech: Speech sequences. (B, S)
speech_lengths: Speech sequences lengths. (B,)
Return:
encoder_out: Encoder outputs. (B, T, D_enc)
encoder_out_lens: Encoder outputs lengths. (B,)
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# 4. Forward encoder
encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths, ind=ind)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
assert encoder_out.size(1) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
)
return encoder_out, encoder_out_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Extract features sequences and features sequences lengths.
Args:
speech: Speech sequences. (B, S)
speech_lengths: Speech sequences lengths. (B,)
Return:
feats: Features sequences. (B, T, D_feats)
feats_lengths: Features sequences lengths. (B,)
"""
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
def _calc_transducer_loss(
self,
encoder_out: torch.Tensor,
joint_out: torch.Tensor,
target: torch.Tensor,
t_len: torch.Tensor,
u_len: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
"""Compute Transducer loss.
Args:
encoder_out: Encoder output sequences. (B, T, D_enc)
joint_out: Joint Network output sequences (B, T, U, D_joint)
target: Target label ID sequences. (B, L)
t_len: Encoder output sequences lengths. (B,)
u_len: Target label ID sequences lengths. (B,)
Return:
loss_transducer: Transducer loss value.
cer_transducer: Character error rate for Transducer.
wer_transducer: Word Error Rate for Transducer.
"""
if self.criterion_transducer is None:
try:
# from warprnnt_pytorch import RNNTLoss
# self.criterion_transducer = RNNTLoss(
# reduction="mean",
# fastemit_lambda=self.fastemit_lambda,
# )
from warp_rnnt import rnnt_loss as RNNTLoss
self.criterion_transducer = RNNTLoss
except ImportError:
logging.error(
"warp-rnnt was not installed."
"Please consult the installation documentation."
)
exit(1)
# loss_transducer = self.criterion_transducer(
# joint_out,
# target,
# t_len,
# u_len,
# )
log_probs = torch.log_softmax(joint_out, dim=-1)
loss_transducer = self.criterion_transducer(
log_probs,
target,
t_len,
u_len,
reduction="mean",
blank=self.blank_id,
gather=True,
)
if not self.training and (self.report_cer or self.report_wer):
if self.error_calculator is None:
from espnet2.asr_transducer.error_calculator import ErrorCalculator
self.error_calculator = ErrorCalculator(
self.decoder,
self.joint_network,
self.token_list,
self.sym_space,
self.sym_blank,
report_cer=self.report_cer,
report_wer=self.report_wer,
)
cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
return loss_transducer, cer_transducer, wer_transducer
return loss_transducer, None, None
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
target: torch.Tensor,
t_len: torch.Tensor,
u_len: torch.Tensor,
) -> torch.Tensor:
"""Compute CTC loss.
Args:
encoder_out: Encoder output sequences. (B, T, D_enc)
target: Target label ID sequences. (B, L)
t_len: Encoder output sequences lengths. (B,)
u_len: Target label ID sequences lengths. (B,)
Return:
loss_ctc: CTC loss value.
"""
ctc_in = self.ctc_lin(
torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
)
ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
target_mask = target != 0
ctc_target = target[target_mask].cpu()
with torch.backends.cudnn.flags(deterministic=True):
loss_ctc = torch.nn.functional.ctc_loss(
ctc_in,
ctc_target,
t_len,
u_len,
zero_infinity=True,
reduction="sum",
)
loss_ctc /= target.size(0)
return loss_ctc
def _calc_lm_loss(
self,
decoder_out: torch.Tensor,
target: torch.Tensor,
) -> torch.Tensor:
"""Compute LM loss.
Args:
decoder_out: Decoder output sequences. (B, U, D_dec)
target: Target label ID sequences. (B, L)
Return:
loss_lm: LM loss value.
"""
lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
lm_target = target.view(-1).type(torch.int64)
with torch.no_grad():
true_dist = lm_loss_in.clone()
true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
# Ignore blank ID (0)
ignore = lm_target == 0
lm_target = lm_target.masked_fill(ignore, 0)
true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
loss_lm = torch.nn.functional.kl_div(
torch.log_softmax(lm_loss_in, dim=1),
true_dist,
reduction="none",
)
loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
0
)
return loss_lm

View File

@ -1,200 +0,0 @@
"""Utility functions for Transducer models."""
from typing import List, Tuple
import torch
class TooShortUttError(Exception):
"""Raised when the utt is too short for subsampling.
Args:
message: Error message to display.
actual_size: The size that cannot pass the subsampling.
limit: The size limit for subsampling.
"""
def __init__(self, message: str, actual_size: int, limit: int) -> None:
"""Construct a TooShortUttError module."""
super().__init__(message)
self.actual_size = actual_size
self.limit = limit
def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
"""Check if the input is too short for subsampling.
Args:
sub_factor: Subsampling factor for Conv2DSubsampling.
size: Input size.
Returns:
: Whether an error should be sent.
: Size limit for specified subsampling factor.
"""
if sub_factor == 2 and size < 3:
return True, 7
elif sub_factor == 4 and size < 7:
return True, 7
elif sub_factor == 6 and size < 11:
return True, 11
return False, -1
def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
"""Get conv2D second layer parameters for given subsampling factor.
Args:
sub_factor: Subsampling factor (1/X).
input_size: Input size.
Returns:
: Kernel size for second convolution.
: Stride for second convolution.
: Conv2DSubsampling output size.
"""
if sub_factor == 2:
return 3, 1, (((input_size - 1) // 2 - 2))
elif sub_factor == 4:
return 3, 2, (((input_size - 1) // 2 - 1) // 2)
elif sub_factor == 6:
return 5, 3, (((input_size - 1) // 2 - 2) // 3)
else:
raise ValueError(
"subsampling_factor parameter should be set to either 2, 4 or 6."
)
def make_chunk_mask(
size: int,
chunk_size: int,
left_chunk_size: int = 0,
device: torch.device = None,
) -> torch.Tensor:
"""Create chunk mask for the subsequent steps (size, size).
Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
Args:
size: Size of the source mask.
chunk_size: Number of frames in chunk.
left_chunk_size: Size of the left context in chunks (0 means full context).
device: Device for the mask tensor.
Returns:
mask: Chunk mask. (size, size)
"""
mask = torch.zeros(size, size, device=device, dtype=torch.bool)
for i in range(size):
if left_chunk_size <= 0:
start = 0
else:
start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)
end = min((i // chunk_size + 1) * chunk_size, size)
mask[i, start:end] = True
return ~mask
def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
"""Create source mask for given lengths.
Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
Args:
lengths: Sequence lengths. (B,)
Returns:
: Mask for the sequence lengths. (B, max_len)
"""
max_len = lengths.max()
batch_size = lengths.size(0)
expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths)
return expanded_lengths >= lengths.unsqueeze(1)
def get_transducer_task_io(
labels: torch.Tensor,
encoder_out_lens: torch.Tensor,
ignore_id: int = -1,
blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Get Transducer loss I/O.
Args:
labels: Label ID sequences. (B, L)
encoder_out_lens: Encoder output lengths. (B,)
ignore_id: Padding symbol ID.
blank_id: Blank symbol ID.
Returns:
decoder_in: Decoder inputs. (B, U)
target: Target label ID sequences. (B, U)
t_len: Time lengths. (B,)
u_len: Label lengths. (B,)
"""
def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
"""Create padded batch of labels from a list of labels sequences.
Args:
labels: Labels sequences. [B x (?)]
padding_value: Padding value.
Returns:
labels: Batch of padded labels sequences. (B,)
"""
batch_size = len(labels)
padded = (
labels[0]
.new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
.fill_(padding_value)
)
for i in range(batch_size):
padded[i, : labels[i].size(0)] = labels[i]
return padded
device = labels.device
labels_unpad = [y[y != ignore_id] for y in labels]
blank = labels[0].new([blank_id])
decoder_in = pad_list(
[torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
).to(device)
target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)
encoder_out_lens = list(map(int, encoder_out_lens))
t_len = torch.IntTensor(encoder_out_lens).to(device)
u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device)
return decoder_in, target, t_len, u_len
def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
"""Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
if t.size(dim) == pad_len:
return t
else:
pad_size = list(t.shape)
pad_size[dim] = pad_len - t.size(dim)
return torch.cat(
[t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
)

View File

@ -6,8 +6,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.models_transducer.joint_network import JointNetwork
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.joint_network import JointNetwork
@dataclass

View File

@ -6,6 +6,8 @@
"""Common functions for ASR."""
from typing import List, Optional, Tuple
import json
import logging
import sys
@ -13,7 +15,11 @@ import sys
from itertools import groupby
import numpy as np
import six
import torch
from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.joint_network import JointNetwork
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
"""End detection.
@ -247,3 +253,148 @@ class ErrorCalculator(object):
word_eds.append(editdistance.eval(hyp_words, ref_words))
word_ref_lens.append(len(ref_words))
return float(sum(word_eds)) / sum(word_ref_lens)
class ErrorCalculatorTransducer:
"""Calculate CER and WER for transducer models.
Args:
decoder: Decoder module.
joint_network: Joint Network module.
token_list: List of token units.
sym_space: Space symbol.
sym_blank: Blank symbol.
report_cer: Whether to compute CER.
report_wer: Whether to compute WER.
"""
def __init__(
self,
decoder: AbsDecoder,
joint_network: JointNetwork,
token_list: List[int],
sym_space: str,
sym_blank: str,
report_cer: bool = False,
report_wer: bool = False,
) -> None:
"""Construct an ErrorCalculatorTransducer object."""
super().__init__()
self.beam_search = BeamSearchTransducer(
decoder=decoder,
joint_network=joint_network,
beam_size=1,
search_type="default",
score_norm=False,
)
self.decoder = decoder
self.token_list = token_list
self.space = sym_space
self.blank = sym_blank
self.report_cer = report_cer
self.report_wer = report_wer
def __call__(
self, encoder_out: torch.Tensor, target: torch.Tensor
) -> Tuple[Optional[float], Optional[float]]:
"""Calculate sentence-level WER or/and CER score for Transducer model.
Args:
encoder_out: Encoder output sequences. (B, T, D_enc)
target: Target label ID sequences. (B, L)
Returns:
: Sentence-level CER score.
: Sentence-level WER score.
"""
cer, wer = None, None
batchsize = int(encoder_out.size(0))
encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
char_pred, char_target = self.convert_to_char(pred, target)
if self.report_cer:
cer = self.calculate_cer(char_pred, char_target)
if self.report_wer:
wer = self.calculate_wer(char_pred, char_target)
return cer, wer
def convert_to_char(
self, pred: torch.Tensor, target: torch.Tensor
) -> Tuple[List, List]:
"""Convert label ID sequences to character sequences.
Args:
pred: Prediction label ID sequences. (B, U)
target: Target label ID sequences. (B, L)
Returns:
char_pred: Prediction character sequences. (B, ?)
char_target: Target character sequences. (B, ?)
"""
char_pred, char_target = [], []
for i, pred_i in enumerate(pred):
char_pred_i = [self.token_list[int(h)] for h in pred_i]
char_target_i = [self.token_list[int(r)] for r in target[i]]
char_pred_i = "".join(char_pred_i).replace(self.space, " ")
char_pred_i = char_pred_i.replace(self.blank, "")
char_target_i = "".join(char_target_i).replace(self.space, " ")
char_target_i = char_target_i.replace(self.blank, "")
char_pred.append(char_pred_i)
char_target.append(char_target_i)
return char_pred, char_target
def calculate_cer(
self, char_pred: torch.Tensor, char_target: torch.Tensor
) -> float:
"""Calculate sentence-level CER score.
Args:
char_pred: Prediction character sequences. (B, ?)
char_target: Target character sequences. (B, ?)
Returns:
: Average sentence-level CER score.
"""
import editdistance
distances, lens = [], []
for i, char_pred_i in enumerate(char_pred):
pred = char_pred_i.replace(" ", "")
target = char_target[i].replace(" ", "")
distances.append(editdistance.eval(pred, target))
lens.append(len(target))
return float(sum(distances)) / sum(lens)
def calculate_wer(
self, char_pred: torch.Tensor, char_target: torch.Tensor
) -> float:
"""Calculate sentence-level WER score.
Args:
char_pred: Prediction character sequences. (B, ?)
char_target: Target character sequences. (B, ?)
Returns:
: Average sentence-level WER score
"""
import editdistance
distances, lens = [], []
for i, char_pred_i in enumerate(char_pred):
pred = char_pred_i.replace("", " ").split()
target = char_target[i].replace("", " ").split()
distances.append(editdistance.eval(pred, target))
lens.append(len(target))
return float(sum(distances)) / sum(lens)

View File

@ -3,7 +3,7 @@
"""Network related utility tools."""
import logging
from typing import Dict
from typing import Dict, List, Tuple
import numpy as np
import torch
@ -506,3 +506,196 @@ def get_activation(act):
}
return activation_funcs[act]()
class TooShortUttError(Exception):
"""Raised when the utt is too short for subsampling.
Args:
message: Error message to display.
actual_size: The size that cannot pass the subsampling.
limit: The size limit for subsampling.
"""
def __init__(self, message: str, actual_size: int, limit: int) -> None:
"""Construct a TooShortUttError module."""
super().__init__(message)
self.actual_size = actual_size
self.limit = limit
def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
"""Check if the input is too short for subsampling.
Args:
sub_factor: Subsampling factor for Conv2DSubsampling.
size: Input size.
Returns:
: Whether an error should be sent.
: Size limit for specified subsampling factor.
"""
if sub_factor == 2 and size < 3:
return True, 7
elif sub_factor == 4 and size < 7:
return True, 7
elif sub_factor == 6 and size < 11:
return True, 11
return False, -1
def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
"""Get conv2D second layer parameters for given subsampling factor.
Args:
sub_factor: Subsampling factor (1/X).
input_size: Input size.
Returns:
: Kernel size for second convolution.
: Stride for second convolution.
: Conv2DSubsampling output size.
"""
if sub_factor == 2:
return 3, 1, (((input_size - 1) // 2 - 2))
elif sub_factor == 4:
return 3, 2, (((input_size - 1) // 2 - 1) // 2)
elif sub_factor == 6:
return 5, 3, (((input_size - 1) // 2 - 2) // 3)
else:
raise ValueError(
"subsampling_factor parameter should be set to either 2, 4 or 6."
)
def make_chunk_mask(
size: int,
chunk_size: int,
left_chunk_size: int = 0,
device: torch.device = None,
) -> torch.Tensor:
"""Create chunk mask for the subsequent steps (size, size).
Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
Args:
size: Size of the source mask.
chunk_size: Number of frames in chunk.
left_chunk_size: Size of the left context in chunks (0 means full context).
device: Device for the mask tensor.
Returns:
mask: Chunk mask. (size, size)
"""
mask = torch.zeros(size, size, device=device, dtype=torch.bool)
for i in range(size):
if left_chunk_size <= 0:
start = 0
else:
start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)
end = min((i // chunk_size + 1) * chunk_size, size)
mask[i, start:end] = True
return ~mask
def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
"""Create source mask for given lengths.
Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
Args:
lengths: Sequence lengths. (B,)
Returns:
: Mask for the sequence lengths. (B, max_len)
"""
max_len = lengths.max()
batch_size = lengths.size(0)
expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths)
return expanded_lengths >= lengths.unsqueeze(1)
def get_transducer_task_io(
labels: torch.Tensor,
encoder_out_lens: torch.Tensor,
ignore_id: int = -1,
blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Get Transducer loss I/O.
Args:
labels: Label ID sequences. (B, L)
encoder_out_lens: Encoder output lengths. (B,)
ignore_id: Padding symbol ID.
blank_id: Blank symbol ID.
Returns:
decoder_in: Decoder inputs. (B, U)
target: Target label ID sequences. (B, U)
t_len: Time lengths. (B,)
u_len: Label lengths. (B,)
"""
def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
"""Create padded batch of labels from a list of labels sequences.
Args:
labels: Labels sequences. [B x (?)]
padding_value: Padding value.
Returns:
labels: Batch of padded labels sequences. (B,)
"""
batch_size = len(labels)
padded = (
labels[0]
.new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
.fill_(padding_value)
)
for i in range(batch_size):
padded[i, : labels[i].size(0)] = labels[i]
return padded
device = labels.device
labels_unpad = [y[y != ignore_id] for y in labels]
blank = labels[0].new([blank_id])
decoder_in = pad_list(
[torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
).to(device)
target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)
encoder_out_lens = list(map(int, encoder_out_lens))
t_len = torch.IntTensor(encoder_out_lens).to(device)
u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device)
return decoder_in, target, t_len, u_len
def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
"""Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
if t.size(dim) == pad_len:
return t
else:
pad_size = list(t.shape)
pad_size[dim] = pad_len - t.size(dim)
return torch.cat(
[t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
)

View File

@ -21,15 +21,13 @@ from funasr.models.decoder.transformer_decoder import (
LightweightConvolutionTransformerDecoder,
TransformerDecoder,
)
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.models_transducer.decoder.rnn_decoder import RNNDecoder
from funasr.models_transducer.decoder.stateless_decoder import StatelessDecoder
from funasr.models_transducer.encoder.encoder import Encoder
from funasr.models_transducer.encoder.sanm_encoder import SANMEncoderChunkOpt
from funasr.models_transducer.espnet_transducer_model import ESPnetASRTransducerModel
from funasr.models_transducer.espnet_transducer_model_unified import ESPnetASRUnifiedTransducerModel
from funasr.models_transducer.espnet_transducer_model_uni_asr import UniASRTransducerModel
from funasr.models_transducer.joint_network import JointNetwork
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder
from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder
from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
from funasr.models.e2e_transducer import TransducerModel
from funasr.models.e2e_transducer_unified import UnifiedTransducerModel
from funasr.models.joint_network import JointNetwork
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
@ -75,7 +73,6 @@ encoder_choices = ClassChoices(
"encoder",
classes=dict(
encoder=Encoder,
sanm_chunk_opt=SANMEncoderChunkOpt,
),
default="encoder",
)
@ -158,7 +155,7 @@ class ASRTransducerTask(AbsTask):
group.add_argument(
"--model_conf",
action=NestedDictAction,
default=get_default_kwargs(ESPnetASRTransducerModel),
default=get_default_kwargs(TransducerModel),
help="The keyword arguments for the model class.",
)
# group.add_argument(
@ -354,7 +351,7 @@ class ASRTransducerTask(AbsTask):
return retval
@classmethod
def build_model(cls, args: argparse.Namespace) -> ESPnetASRTransducerModel:
def build_model(cls, args: argparse.Namespace) -> TransducerModel:
"""Required data depending on task mode.
Args:
cls: ASRTransducerTask object.
@ -440,22 +437,8 @@ class ASRTransducerTask(AbsTask):
# 7. Build model
if getattr(args, "encoder", None) is not None and args.encoder == 'sanm_chunk_opt':
model = UniASRTransducerModel(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
specaug=specaug,
normalize=normalize,
encoder=encoder,
decoder=decoder,
att_decoder=att_decoder,
joint_network=joint_network,
**args.model_conf,
)
elif encoder.unified_model_training:
model = ESPnetASRUnifiedTransducerModel(
if encoder.unified_model_training:
model = UnifiedTransducerModel(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
@ -469,7 +452,7 @@ class ASRTransducerTask(AbsTask):
)
else:
model = ESPnetASRTransducerModel(
model = TransducerModel(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,