diff --git a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml index ef37b97eb..60f796c75 100644 --- a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml +++ b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml @@ -1,13 +1,13 @@ encoder_conf: main_conf: pos_wise_act_type: swish - pos_enc_dropout_rate: 0.3 + pos_enc_dropout_rate: 0.5 conv_mod_act_type: swish time_reduction_factor: 2 unified_model_training: true default_chunk_size: 16 jitter_range: 4 - left_chunk_size: 1 + left_chunk_size: 0 input_conf: block_type: conv2d conv_size: 512 @@ -18,9 +18,9 @@ encoder_conf: linear_size: 2048 hidden_size: 512 heads: 8 - dropout_rate: 0.3 - pos_wise_dropout_rate: 0.3 - att_dropout_rate: 0.3 + dropout_rate: 0.5 + pos_wise_dropout_rate: 0.5 + att_dropout_rate: 0.5 conv_mod_kernel_size: 15 num_blocks: 12 @@ -29,8 +29,8 @@ decoder: rnn decoder_conf: embed_size: 512 hidden_size: 512 - embed_dropout_rate: 0.2 - dropout_rate: 0.1 + embed_dropout_rate: 0.5 + dropout_rate: 0.5 joint_network_conf: joint_space_size: 512 @@ -41,14 +41,14 @@ model_conf: # minibatch related use_amp: true -batch_type: numel -batch_bins: 1600000 +batch_type: unsorted +batch_size: 16 num_workers: 16 # optimization related accum_grad: 1 grad_clip: 5 -max_epoch: 80 +max_epoch: 200 val_scheduler_criterion: - valid - loss @@ -56,11 +56,11 @@ best_model_criterion: - - valid - cer_transducer_chunk - min -keep_nbest_models: 5 +keep_nbest_models: 10 optim: adam optim_conf: - lr: 0.0003 + lr: 0.001 scheduler: warmuplr scheduler_conf: warmup_steps: 25000 @@ -75,10 +75,12 @@ specaug_conf: apply_freq_mask: true freq_mask_width_range: - 0 - - 30 + - 40 num_freq_mask: 2 apply_time_mask: true time_mask_width_range: - 0 - - 40 - num_time_mask: 2 + - 50 + num_time_mask: 5 + +log_interval: 50 diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py index 768bf7215..465f88254 100644 --- a/funasr/bin/asr_inference_rnnt.py +++ b/funasr/bin/asr_inference_rnnt.py @@ -16,11 +16,11 @@ import torch from packaging.version import parse as V from typeguard import check_argument_types, check_return_type -from funasr.models_transducer.beam_search_transducer import ( +from funasr.modules.beam_search.beam_search_transducer import ( BeamSearchTransducer, Hypothesis, ) -from funasr.models_transducer.utils import TooShortUttError +from funasr.modules.nets_utils import TooShortUttError from funasr.fileio.datadir_writer import DatadirWriter from funasr.tasks.asr_transducer import ASRTransducerTask from funasr.tasks.lm import LMTask @@ -500,7 +500,6 @@ def inference( _bs = len(next(iter(batch.values()))) assert len(keys) == _bs, f"{len(keys)} != {_bs}" -<<<<<<< HEAD batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} assert len(batch.keys()) == 1 @@ -541,59 +540,6 @@ def inference( if text is not None: ibest_writer["text"][key] = text -======= - # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")} - - logging.info("decoding, utt_id: {}".format(keys)) - # N-best list of (text, token, token_int, hyp_object) - - time_beg = time.time() - results = speech2text(cache=cache, **batch) - if len(results) < 1: - hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) - results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest - time_end = time.time() - forward_time = time_end - time_beg - lfr_factor = results[0][-1] - length = results[0][-2] - forward_time_total += forward_time - length_total += length - rtf_cur = "decoding, feature length: {}, 
forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time, 100 * forward_time / (length * lfr_factor)) - logging.info(rtf_cur) - - for batch_id in range(_bs): - result = [results[batch_id][:-2]] - - key = keys[batch_id] - for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), result): - # Create a directory: outdir/{n}best_recog - if writer is not None: - ibest_writer = writer[f"{n}best_recog"] - - # Write the result to each file - ibest_writer["token"][key] = " ".join(token) - # ibest_writer["token_int"][key] = " ".join(map(str, token_int)) - ibest_writer["score"][key] = str(hyp.score) - ibest_writer["rtf"][key] = rtf_cur - - if text is not None: - text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token) - item = {'key': key, 'value': text_postprocessed} - asr_result_list.append(item) - finish_count += 1 - # asr_utils.print_progress(finish_count / file_count) - if writer is not None: - ibest_writer["text"][key] = " ".join(word_lists) - - logging.info("decoding, utt: {}, predictions: {}".format(key, text)) - rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)) - logging.info(rtf_avg) - if writer is not None: - ibest_writer["rtf"]["rtf_avf"] = rtf_avg - return asr_result_list - - return _forward ->>>>>>> main def get_parser(): diff --git a/funasr/models_transducer/espnet_transducer_model.py b/funasr/models/e2e_transducer.py similarity index 98% rename from funasr/models_transducer/espnet_transducer_model.py rename to funasr/models/e2e_transducer.py index e32f6e350..b669c9d3e 100644 --- a/funasr/models_transducer/espnet_transducer_model.py +++ b/funasr/models/e2e_transducer.py @@ -10,11 +10,11 @@ from typeguard import check_argument_types from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.specaug.abs_specaug import AbsSpecAug -from funasr.models_transducer.decoder.abs_decoder import AbsDecoder +from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder -from funasr.models_transducer.encoder.encoder import Encoder -from funasr.models_transducer.joint_network import JointNetwork -from funasr.models_transducer.utils import get_transducer_task_io +from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder +from funasr.models.joint_network import JointNetwork +from funasr.modules.nets_utils import get_transducer_task_io from funasr.layers.abs_normalize import AbsNormalize from funasr.torch_utils.device_funcs import force_gatherable from funasr.train.abs_espnet_model import AbsESPnetModel @@ -28,7 +28,7 @@ else: yield -class ESPnetASRTransducerModel(AbsESPnetModel): +class TransducerModel(AbsESPnetModel): """ESPnet2ASRTransducerModel module definition. 
Args: diff --git a/funasr/models_transducer/espnet_transducer_model_unified.py b/funasr/models/e2e_transducer_unified.py similarity index 98% rename from funasr/models_transducer/espnet_transducer_model_unified.py rename to funasr/models/e2e_transducer_unified.py index be61e8381..600354216 100644 --- a/funasr/models_transducer/espnet_transducer_model_unified.py +++ b/funasr/models/e2e_transducer_unified.py @@ -10,10 +10,10 @@ from typeguard import check_argument_types from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.specaug.abs_specaug import AbsSpecAug -from funasr.models_transducer.decoder.abs_decoder import AbsDecoder -from funasr.models_transducer.encoder.encoder import Encoder -from funasr.models_transducer.joint_network import JointNetwork -from funasr.models_transducer.utils import get_transducer_task_io +from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder +from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder +from funasr.models.joint_network import JointNetwork +from funasr.modules.nets_utils import get_transducer_task_io from funasr.layers.abs_normalize import AbsNormalize from funasr.torch_utils.device_funcs import force_gatherable from funasr.train.abs_espnet_model import AbsESPnetModel @@ -23,7 +23,7 @@ from funasr.modules.nets_utils import th_accuracy from funasr.losses.label_smoothing_loss import ( # noqa: H301 LabelSmoothingLoss, ) -from funasr.models_transducer.error_calculator import ErrorCalculator +from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator if V(torch.__version__) >= V("1.6.0"): from torch.cuda.amp import autocast else: @@ -33,7 +33,7 @@ else: yield -class ESPnetASRUnifiedTransducerModel(AbsESPnetModel): +class UnifiedTransducerModel(AbsESPnetModel): """ESPnet2ASRTransducerModel module definition. Args: @@ -289,7 +289,6 @@ class ESPnetASRUnifiedTransducerModel(AbsESPnetModel): # force_gatherable: to-device and to-tensor if scalar for DataParallel loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) - return loss, stats, weight def collect_feats( diff --git a/funasr/models_transducer/encoder/encoder.py b/funasr/models/encoder/chunk_encoder.py similarity index 96% rename from funasr/models_transducer/encoder/encoder.py rename to funasr/models/encoder/chunk_encoder.py index b486a113f..c6fc292e0 100644 --- a/funasr/models_transducer/encoder/encoder.py +++ b/funasr/models/encoder/chunk_encoder.py @@ -1,26 +1,23 @@ -"""Encoder for Transducer model.""" - from typing import Any, Dict, List, Tuple import torch from typeguard import check_argument_types -from funasr.models_transducer.encoder.building import ( +from funasr.models.encoder.chunk_encoder_utils.building import ( build_body_blocks, build_input_block, build_main_parameters, build_positional_encoding, ) -from funasr.models_transducer.encoder.validation import validate_architecture -from funasr.models_transducer.utils import ( +from funasr.models.encoder.chunk_encoder_utils.validation import validate_architecture +from funasr.modules.nets_utils import ( TooShortUttError, check_short_utt, make_chunk_mask, make_source_mask, ) - -class Encoder(torch.nn.Module): +class ChunkEncoder(torch.nn.Module): """Encoder module definition. 
Args: @@ -61,10 +58,9 @@ class Encoder(torch.nn.Module): self.unified_model_training = main_params["unified_model_training"] self.default_chunk_size = main_params["default_chunk_size"] - self.jitter_range = main_params["jitter_range"] - - self.time_reduction_factor = main_params["time_reduction_factor"] + self.jitter_range = main_params["jitter_range"] + self.time_reduction_factor = main_params["time_reduction_factor"] def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int: """Return the corresponding number of sample for a given chunk size, in frames. @@ -79,7 +75,7 @@ class Encoder(torch.nn.Module): """ return self.embed.get_size_before_subsampling(size) * hop_length - + def get_encoder_input_size(self, size: int) -> int: """Return the corresponding number of sample for a given chunk size, in frames. @@ -157,7 +153,7 @@ class Encoder(torch.nn.Module): mask, chunk_mask=chunk_mask, ) - + olens = mask.eq(0).sum(1) if self.time_reduction_factor > 1: x_utt = x_utt[:,::self.time_reduction_factor,:] @@ -194,14 +190,14 @@ class Encoder(torch.nn.Module): mask, chunk_mask=chunk_mask, ) - + olens = mask.eq(0).sum(1) if self.time_reduction_factor > 1: x = x[:,::self.time_reduction_factor,:] olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 return x, olens - + def simu_chunk_forward( self, x: torch.Tensor, @@ -290,7 +286,7 @@ class Encoder(torch.nn.Module): if right_context > 0: x = x[:, 0:-right_context, :] - + if self.time_reduction_factor > 1: x = x[:,::self.time_reduction_factor,:] return x diff --git a/funasr/models_transducer/__init__.py b/funasr/models/encoder/chunk_encoder_blocks/__init__.py similarity index 100% rename from funasr/models_transducer/__init__.py rename to funasr/models/encoder/chunk_encoder_blocks/__init__.py diff --git a/funasr/models_transducer/encoder/blocks/branchformer.py b/funasr/models/encoder/chunk_encoder_blocks/branchformer.py similarity index 100% rename from funasr/models_transducer/encoder/blocks/branchformer.py rename to funasr/models/encoder/chunk_encoder_blocks/branchformer.py diff --git a/funasr/models_transducer/encoder/blocks/conformer.py b/funasr/models/encoder/chunk_encoder_blocks/conformer.py similarity index 100% rename from funasr/models_transducer/encoder/blocks/conformer.py rename to funasr/models/encoder/chunk_encoder_blocks/conformer.py diff --git a/funasr/models_transducer/encoder/blocks/conv1d.py b/funasr/models/encoder/chunk_encoder_blocks/conv1d.py similarity index 100% rename from funasr/models_transducer/encoder/blocks/conv1d.py rename to funasr/models/encoder/chunk_encoder_blocks/conv1d.py diff --git a/funasr/models_transducer/encoder/blocks/conv_input.py b/funasr/models/encoder/chunk_encoder_blocks/conv_input.py similarity index 98% rename from funasr/models_transducer/encoder/blocks/conv_input.py rename to funasr/models/encoder/chunk_encoder_blocks/conv_input.py index ffec93e5e..b9bd2fdc2 100644 --- a/funasr/models_transducer/encoder/blocks/conv_input.py +++ b/funasr/models/encoder/chunk_encoder_blocks/conv_input.py @@ -5,7 +5,7 @@ from typing import Optional, Tuple, Union import torch import math -from funasr.models_transducer.utils import sub_factor_to_params, pad_to_len +from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len class ConvInput(torch.nn.Module): diff --git a/funasr/models_transducer/encoder/blocks/linear_input.py b/funasr/models/encoder/chunk_encoder_blocks/linear_input.py similarity index 100% rename from funasr/models_transducer/encoder/blocks/linear_input.py rename to 
funasr/models/encoder/chunk_encoder_blocks/linear_input.py diff --git a/funasr/models_transducer/decoder/__init__.py b/funasr/models/encoder/chunk_encoder_modules/__init__.py similarity index 100% rename from funasr/models_transducer/decoder/__init__.py rename to funasr/models/encoder/chunk_encoder_modules/__init__.py diff --git a/funasr/models_transducer/encoder/modules/attention.py b/funasr/models/encoder/chunk_encoder_modules/attention.py similarity index 100% rename from funasr/models_transducer/encoder/modules/attention.py rename to funasr/models/encoder/chunk_encoder_modules/attention.py diff --git a/funasr/models_transducer/encoder/modules/convolution.py b/funasr/models/encoder/chunk_encoder_modules/convolution.py similarity index 100% rename from funasr/models_transducer/encoder/modules/convolution.py rename to funasr/models/encoder/chunk_encoder_modules/convolution.py diff --git a/funasr/models_transducer/encoder/modules/multi_blocks.py b/funasr/models/encoder/chunk_encoder_modules/multi_blocks.py similarity index 100% rename from funasr/models_transducer/encoder/modules/multi_blocks.py rename to funasr/models/encoder/chunk_encoder_modules/multi_blocks.py diff --git a/funasr/models_transducer/encoder/modules/normalization.py b/funasr/models/encoder/chunk_encoder_modules/normalization.py similarity index 100% rename from funasr/models_transducer/encoder/modules/normalization.py rename to funasr/models/encoder/chunk_encoder_modules/normalization.py diff --git a/funasr/models_transducer/encoder/modules/positional_encoding.py b/funasr/models/encoder/chunk_encoder_modules/positional_encoding.py similarity index 100% rename from funasr/models_transducer/encoder/modules/positional_encoding.py rename to funasr/models/encoder/chunk_encoder_modules/positional_encoding.py diff --git a/funasr/models_transducer/encoder/building.py b/funasr/models/encoder/chunk_encoder_utils/building.py similarity index 92% rename from funasr/models_transducer/encoder/building.py rename to funasr/models/encoder/chunk_encoder_utils/building.py index a19943be7..21611aa19 100644 --- a/funasr/models_transducer/encoder/building.py +++ b/funasr/models/encoder/chunk_encoder_utils/building.py @@ -2,22 +2,22 @@ from typing import Any, Dict, List, Optional, Union -from funasr.models_transducer.activation import get_activation -from funasr.models_transducer.encoder.blocks.branchformer import Branchformer -from funasr.models_transducer.encoder.blocks.conformer import Conformer -from funasr.models_transducer.encoder.blocks.conv1d import Conv1d -from funasr.models_transducer.encoder.blocks.conv_input import ConvInput -from funasr.models_transducer.encoder.blocks.linear_input import LinearInput -from funasr.models_transducer.encoder.modules.attention import ( # noqa: H301 +from funasr.modules.activation import get_activation +from funasr.models.encoder.chunk_encoder_blocks.branchformer import Branchformer +from funasr.models.encoder.chunk_encoder_blocks.conformer import Conformer +from funasr.models.encoder.chunk_encoder_blocks.conv1d import Conv1d +from funasr.models.encoder.chunk_encoder_blocks.conv_input import ConvInput +from funasr.models.encoder.chunk_encoder_blocks.linear_input import LinearInput +from funasr.models.encoder.chunk_encoder_modules.attention import ( # noqa: H301 RelPositionMultiHeadedAttention, ) -from funasr.models_transducer.encoder.modules.convolution import ( # noqa: H301 +from funasr.models.encoder.chunk_encoder_modules.convolution import ( # noqa: H301 ConformerConvolution, 
ConvolutionalSpatialGatingUnit, ) -from funasr.models_transducer.encoder.modules.multi_blocks import MultiBlocks -from funasr.models_transducer.encoder.modules.normalization import get_normalization -from funasr.models_transducer.encoder.modules.positional_encoding import ( # noqa: H301 +from funasr.models.encoder.chunk_encoder_modules.multi_blocks import MultiBlocks +from funasr.models.encoder.chunk_encoder_modules.normalization import get_normalization +from funasr.models.encoder.chunk_encoder_modules.positional_encoding import ( # noqa: H301 RelPositionalEncoding, ) from funasr.modules.positionwise_feed_forward import ( diff --git a/funasr/models_transducer/encoder/validation.py b/funasr/models/encoder/chunk_encoder_utils/validation.py similarity index 98% rename from funasr/models_transducer/encoder/validation.py rename to funasr/models/encoder/chunk_encoder_utils/validation.py index 00035363a..1103cb93f 100644 --- a/funasr/models_transducer/encoder/validation.py +++ b/funasr/models/encoder/chunk_encoder_utils/validation.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Tuple -from funasr.models_transducer.utils import sub_factor_to_params +from funasr.modules.nets_utils import sub_factor_to_params def validate_block_arguments( diff --git a/funasr/models_transducer/joint_network.py b/funasr/models/joint_network.py similarity index 96% rename from funasr/models_transducer/joint_network.py rename to funasr/models/joint_network.py index 119dd84a5..5cabdb4f7 100644 --- a/funasr/models_transducer/joint_network.py +++ b/funasr/models/joint_network.py @@ -2,7 +2,7 @@ import torch -from funasr.models_transducer.activation import get_activation +from funasr.modules.activation import get_activation class JointNetwork(torch.nn.Module): diff --git a/funasr/models_transducer/encoder/__init__.py b/funasr/models/rnnt_decoder/__init__.py similarity index 100% rename from funasr/models_transducer/encoder/__init__.py rename to funasr/models/rnnt_decoder/__init__.py diff --git a/funasr/models_transducer/decoder/abs_decoder.py b/funasr/models/rnnt_decoder/abs_decoder.py similarity index 100% rename from funasr/models_transducer/decoder/abs_decoder.py rename to funasr/models/rnnt_decoder/abs_decoder.py diff --git a/funasr/models_transducer/decoder/rnn_decoder.py b/funasr/models/rnnt_decoder/rnn_decoder.py similarity index 98% rename from funasr/models_transducer/decoder/rnn_decoder.py rename to funasr/models/rnnt_decoder/rnn_decoder.py index 04c32287a..c4e79511c 100644 --- a/funasr/models_transducer/decoder/rnn_decoder.py +++ b/funasr/models/rnnt_decoder/rnn_decoder.py @@ -5,8 +5,8 @@ from typing import List, Optional, Tuple import torch from typeguard import check_argument_types -from funasr.models_transducer.beam_search_transducer import Hypothesis -from funasr.models_transducer.decoder.abs_decoder import AbsDecoder +from funasr.modules.beam_search.beam_search_transducer import Hypothesis +from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder from funasr.models.specaug.specaug import SpecAug class RNNDecoder(AbsDecoder): diff --git a/funasr/models_transducer/decoder/stateless_decoder.py b/funasr/models/rnnt_decoder/stateless_decoder.py similarity index 86% rename from funasr/models_transducer/decoder/stateless_decoder.py rename to funasr/models/rnnt_decoder/stateless_decoder.py index 07c8f519b..a2e1fc14b 100644 --- a/funasr/models_transducer/decoder/stateless_decoder.py +++ b/funasr/models/rnnt_decoder/stateless_decoder.py @@ -5,8 +5,8 @@ from typing import List, Optional, Tuple import torch 
from typeguard import check_argument_types -from funasr.models_transducer.beam_search_transducer import Hypothesis -from funasr.models_transducer.decoder.abs_decoder import AbsDecoder +from funasr.modules.beam_search.beam_search_transducer import Hypothesis +from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder from funasr.models.specaug.specaug import SpecAug class StatelessDecoder(AbsDecoder): @@ -26,7 +26,6 @@ class StatelessDecoder(AbsDecoder): embed_size: int = 256, embed_dropout_rate: float = 0.0, embed_pad: int = 0, - use_embed_mask: bool = False, ) -> None: """Construct a StatelessDecoder object.""" super().__init__() @@ -42,14 +41,6 @@ class StatelessDecoder(AbsDecoder): self.device = next(self.parameters()).device self.score_cache = {} - self.use_embed_mask = use_embed_mask - if self.use_embed_mask: - self._embed_mask = SpecAug( - time_mask_width_range=3, - num_time_mask=1, - apply_freq_mask=False, - apply_time_warp=False - ) def forward( @@ -69,9 +60,6 @@ class StatelessDecoder(AbsDecoder): """ dec_embed = self.embed_dropout_rate(self.embed(labels)) - if self.use_embed_mask and self.training: - dec_embed = self._embed_mask(dec_embed, label_lens)[0] - return dec_embed def score( diff --git a/funasr/models_transducer/encoder/blocks/__init__.py b/funasr/models_transducer/encoder/blocks/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/funasr/models_transducer/encoder/modules/__init__.py b/funasr/models_transducer/encoder/modules/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/funasr/models_transducer/encoder/sanm_encoder.py b/funasr/models_transducer/encoder/sanm_encoder.py deleted file mode 100644 index 9e74bdfeb..000000000 --- a/funasr/models_transducer/encoder/sanm_encoder.py +++ /dev/null @@ -1,835 +0,0 @@ -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union -import logging -import torch -import torch.nn as nn -from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk -from typeguard import check_argument_types -import numpy as np -from funasr.modules.nets_utils import make_pad_mask -from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM -from funasr.modules.embedding import SinusoidalPositionEncoder -from funasr.modules.layer_norm import LayerNorm -from funasr.modules.multi_layer_conv import Conv1dLinear -from funasr.modules.multi_layer_conv import MultiLayeredConv1d -from funasr.modules.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) -from funasr.modules.repeat import repeat -from funasr.modules.subsampling import Conv2dSubsampling -from funasr.modules.subsampling import Conv2dSubsampling2 -from funasr.modules.subsampling import Conv2dSubsampling6 -from funasr.modules.subsampling import Conv2dSubsampling8 -from funasr.modules.subsampling import TooShortUttError -from funasr.modules.subsampling import check_short_utt -from funasr.models.ctc import CTC -from funasr.models.encoder.abs_encoder import AbsEncoder - - -class EncoderLayerSANM(nn.Module): - def __init__( - self, - in_size, - size, - self_attn, - feed_forward, - dropout_rate, - normalize_before=True, - concat_after=False, - stochastic_depth_rate=0.0, - ): - """Construct an EncoderLayer object.""" - super(EncoderLayerSANM, self).__init__() - self.self_attn = self_attn - self.feed_forward = feed_forward - self.norm1 = LayerNorm(in_size) - self.norm2 = LayerNorm(size) - self.dropout = nn.Dropout(dropout_rate) 
- self.in_size = in_size - self.size = size - self.normalize_before = normalize_before - self.concat_after = concat_after - if self.concat_after: - self.concat_linear = nn.Linear(size + size, size) - self.stochastic_depth_rate = stochastic_depth_rate - self.dropout_rate = dropout_rate - - def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None): - """Compute encoded features. - Args: - x_input (torch.Tensor): Input tensor (#batch, time, size). - mask (torch.Tensor): Mask tensor for the input (#batch, time). - cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). - Returns: - torch.Tensor: Output tensor (#batch, time, size). - torch.Tensor: Mask tensor (#batch, time). - """ - skip_layer = False - # with stochastic depth, residual connection `x + f(x)` becomes - # `x <- x + 1 / (1 - p) * f(x)` at training time. - stoch_layer_coeff = 1.0 - if self.training and self.stochastic_depth_rate > 0: - skip_layer = torch.rand(1).item() < self.stochastic_depth_rate - stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) - - if skip_layer: - if cache is not None: - x = torch.cat([cache, x], dim=1) - return x, mask - - residual = x - if self.normalize_before: - x = self.norm1(x) - - if self.concat_after: - x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1) - if self.in_size == self.size: - x = residual + stoch_layer_coeff * self.concat_linear(x_concat) - else: - x = stoch_layer_coeff * self.concat_linear(x_concat) - else: - if self.in_size == self.size: - x = residual + stoch_layer_coeff * self.dropout( - self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder) - ) - else: - x = stoch_layer_coeff * self.dropout( - self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder) - ) - if not self.normalize_before: - x = self.norm1(x) - - residual = x - if self.normalize_before: - x = self.norm2(x) - x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) - if not self.normalize_before: - x = self.norm2(x) - - - return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder - -class SANMEncoder(AbsEncoder): - """ - author: Speech Lab, Alibaba Group, China - San-m: Memory equipped self-attention for end-to-end speech recognition - https://arxiv.org/abs/2006.01713 - """ - - def __init__( - self, - input_size: int, - output_size: int = 256, - attention_heads: int = 4, - linear_units: int = 2048, - num_blocks: int = 6, - dropout_rate: float = 0.1, - positional_dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.0, - pos_enc_class=SinusoidalPositionEncoder, - normalize_before: bool = True, - concat_after: bool = False, - positionwise_layer_type: str = "linear", - positionwise_conv_kernel_size: int = 1, - padding_idx: int = -1, - interctc_layer_idx: List[int] = [], - interctc_use_conditioning: bool = False, - kernel_size : int = 11, - sanm_shfit : int = 0, - tf2torch_tensor_name_prefix_torch: str = "encoder", - tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder", - ): - assert check_argument_types() - super().__init__() - - self.embed = SinusoidalPositionEncoder() - self.normalize_before = normalize_before - if positionwise_layer_type == "linear": - positionwise_layer = PositionwiseFeedForward - positionwise_layer_args = ( - output_size, - linear_units, - dropout_rate, - ) - elif positionwise_layer_type == "conv1d": - positionwise_layer = 
MultiLayeredConv1d - positionwise_layer_args = ( - output_size, - linear_units, - positionwise_conv_kernel_size, - dropout_rate, - ) - elif positionwise_layer_type == "conv1d-linear": - positionwise_layer = Conv1dLinear - positionwise_layer_args = ( - output_size, - linear_units, - positionwise_conv_kernel_size, - dropout_rate, - ) - else: - raise NotImplementedError("Support only linear or conv1d.") - - encoder_selfattn_layer = MultiHeadedAttentionSANM - encoder_selfattn_layer_args0 = ( - attention_heads, - input_size, - output_size, - attention_dropout_rate, - kernel_size, - sanm_shfit, - ) - - encoder_selfattn_layer_args = ( - attention_heads, - output_size, - output_size, - attention_dropout_rate, - kernel_size, - sanm_shfit, - ) - self.encoders0 = repeat( - 1, - lambda lnum: EncoderLayerSANM( - input_size, - output_size, - encoder_selfattn_layer(*encoder_selfattn_layer_args0), - positionwise_layer(*positionwise_layer_args), - dropout_rate, - normalize_before, - concat_after, - ), - ) - - self.encoders = repeat( - num_blocks-1, - lambda lnum: EncoderLayerSANM( - output_size, - output_size, - encoder_selfattn_layer(*encoder_selfattn_layer_args), - positionwise_layer(*positionwise_layer_args), - dropout_rate, - normalize_before, - concat_after, - ), - ) - if self.normalize_before: - self.after_norm = LayerNorm(output_size) - - self.interctc_layer_idx = interctc_layer_idx - if len(interctc_layer_idx) > 0: - assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks - self.interctc_use_conditioning = interctc_use_conditioning - self.conditioning_layer = None - self.dropout = nn.Dropout(dropout_rate) - self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch - self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf - - def forward( - self, - xs_pad: torch.Tensor, - ilens: torch.Tensor, - prev_states: torch.Tensor = None, - ctc: CTC = None, - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """Embed positions in tensor. - Args: - xs_pad: input tensor (B, L, D) - ilens: input length (B) - prev_states: Not to be used now. 
- Returns: - position embedded tensor and mask - """ - masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) - xs_pad = xs_pad * self.output_size**0.5 - if self.embed is None: - xs_pad = xs_pad - elif ( - isinstance(self.embed, Conv2dSubsampling) - or isinstance(self.embed, Conv2dSubsampling2) - or isinstance(self.embed, Conv2dSubsampling6) - or isinstance(self.embed, Conv2dSubsampling8) - ): - short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) - if short_status: - raise TooShortUttError( - f"has {xs_pad.size(1)} frames and is too short for subsampling " - + f"(it needs more than {limit_size} frames), return empty results", - xs_pad.size(1), - limit_size, - ) - xs_pad, masks = self.embed(xs_pad, masks) - else: - xs_pad = self.embed(xs_pad) - - # xs_pad = self.dropout(xs_pad) - encoder_outs = self.encoders0(xs_pad, masks) - xs_pad, masks = encoder_outs[0], encoder_outs[1] - intermediate_outs = [] - if len(self.interctc_layer_idx) == 0: - encoder_outs = self.encoders(xs_pad, masks) - xs_pad, masks = encoder_outs[0], encoder_outs[1] - else: - for layer_idx, encoder_layer in enumerate(self.encoders): - encoder_outs = encoder_layer(xs_pad, masks) - xs_pad, masks = encoder_outs[0], encoder_outs[1] - - if layer_idx + 1 in self.interctc_layer_idx: - encoder_out = xs_pad - - # intermediate outputs are also normalized - if self.normalize_before: - encoder_out = self.after_norm(encoder_out) - - intermediate_outs.append((layer_idx + 1, encoder_out)) - - if self.interctc_use_conditioning: - ctc_out = ctc.softmax(encoder_out) - xs_pad = xs_pad + self.conditioning_layer(ctc_out) - - if self.normalize_before: - xs_pad = self.after_norm(xs_pad) - - olens = masks.squeeze(1).sum(1) - if len(intermediate_outs) > 0: - return (xs_pad, intermediate_outs), olens, None - return xs_pad, olens - - def gen_tf2torch_map_dict(self): - tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch - tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf - map_dict_local = { - ## encoder - # cicd - "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 0), - }, # (768,256),(1,256,768) - "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (768,),(768,) - "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 2, 0), - }, # (256,1,31),(1,31,256,1) - "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 0), - }, # (256,256),(1,256,256) - "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch): - {"name": 
"{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - # ffn - "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 0), - }, # (1024,256),(1,256,1024) - "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (1024,),(1024,) - "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 0), - }, # (256,1024),(1,1024,256) - "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - # out norm - "{}.after_norm.weight".format(tensor_name_prefix_torch): - {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.after_norm.bias".format(tensor_name_prefix_torch): - {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - - } - - return map_dict_local - - def convert_tf2torch(self, - var_dict_tf, - var_dict_torch, - ): - - map_dict = self.gen_tf2torch_map_dict() - - var_dict_torch_update = dict() - for name in sorted(var_dict_torch.keys(), reverse=False): - names = name.split('.') - if names[0] == self.tf2torch_tensor_name_prefix_torch: - if names[1] == "encoders0": - layeridx = int(names[2]) - name_q = name.replace(".{}.".format(layeridx), ".layeridx.") - - name_q = name_q.replace("encoders0", "encoders") - layeridx_bias = 0 - layeridx += layeridx_bias - if name_q in map_dict.keys(): - name_v = map_dict[name_q]["name"] - name_tf = name_v.replace("layeridx", "{}".format(layeridx)) - data_tf = var_dict_tf[name_tf] - if map_dict[name_q]["squeeze"] is not None: - data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) - if map_dict[name_q]["transpose"] is not None: - data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) - data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") - assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, - var_dict_torch[ - name].size(), - data_tf.size()) - var_dict_torch_update[name] = data_tf - logging.info( - "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, - var_dict_tf[name_tf].shape)) - elif names[1] == "encoders": - layeridx = int(names[2]) - name_q = name.replace(".{}.".format(layeridx), ".layeridx.") - layeridx_bias = 1 - layeridx += layeridx_bias - if name_q in map_dict.keys(): - name_v = map_dict[name_q]["name"] - name_tf = name_v.replace("layeridx", "{}".format(layeridx)) - data_tf = var_dict_tf[name_tf] - if map_dict[name_q]["squeeze"] 
is not None: - data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) - if map_dict[name_q]["transpose"] is not None: - data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) - data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") - assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, - var_dict_torch[ - name].size(), - data_tf.size()) - var_dict_torch_update[name] = data_tf - logging.info( - "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, - var_dict_tf[name_tf].shape)) - - elif names[1] == "after_norm": - name_tf = map_dict[name]["name"] - data_tf = var_dict_tf[name_tf] - data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") - var_dict_torch_update[name] = data_tf - logging.info( - "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, - var_dict_tf[name_tf].shape)) - - return var_dict_torch_update - - -class SANMEncoderChunkOpt(AbsEncoder): - """ - author: Speech Lab, Alibaba Group, China - SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition - https://arxiv.org/abs/2006.01713 - """ - - def __init__( - self, - input_size: int, - output_size: int = 256, - attention_heads: int = 4, - linear_units: int = 2048, - num_blocks: int = 6, - dropout_rate: float = 0.1, - positional_dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.0, - pos_enc_class=SinusoidalPositionEncoder, - normalize_before: bool = True, - concat_after: bool = False, - positionwise_layer_type: str = "linear", - positionwise_conv_kernel_size: int = 1, - padding_idx: int = -1, - interctc_layer_idx: List[int] = [], - interctc_use_conditioning: bool = False, - kernel_size: int = 11, - sanm_shfit: int = 0, - chunk_size: Union[int, Sequence[int]] = (16,), - stride: Union[int, Sequence[int]] = (10,), - pad_left: Union[int, Sequence[int]] = (0,), - time_reduction_factor: int = 1, - encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,), - decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,), - tf2torch_tensor_name_prefix_torch: str = "encoder", - tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder", - ): - assert check_argument_types() - super().__init__() - self.output_size = output_size - - self.embed = SinusoidalPositionEncoder() - - self.normalize_before = normalize_before - if positionwise_layer_type == "linear": - positionwise_layer = PositionwiseFeedForward - positionwise_layer_args = ( - output_size, - linear_units, - dropout_rate, - ) - elif positionwise_layer_type == "conv1d": - positionwise_layer = MultiLayeredConv1d - positionwise_layer_args = ( - output_size, - linear_units, - positionwise_conv_kernel_size, - dropout_rate, - ) - elif positionwise_layer_type == "conv1d-linear": - positionwise_layer = Conv1dLinear - positionwise_layer_args = ( - output_size, - linear_units, - positionwise_conv_kernel_size, - dropout_rate, - ) - else: - raise NotImplementedError("Support only linear or conv1d.") - - encoder_selfattn_layer = MultiHeadedAttentionSANM - encoder_selfattn_layer_args0 = ( - attention_heads, - input_size, - output_size, - attention_dropout_rate, - kernel_size, - sanm_shfit, - ) - - encoder_selfattn_layer_args = ( - attention_heads, - output_size, - output_size, - attention_dropout_rate, - kernel_size, - sanm_shfit, - ) - self.encoders0 = repeat( - 1, - lambda lnum: EncoderLayerSANM( - input_size, - output_size, - encoder_selfattn_layer(*encoder_selfattn_layer_args0), - 
positionwise_layer(*positionwise_layer_args), - dropout_rate, - normalize_before, - concat_after, - ), - ) - - self.encoders = repeat( - num_blocks - 1, - lambda lnum: EncoderLayerSANM( - output_size, - output_size, - encoder_selfattn_layer(*encoder_selfattn_layer_args), - positionwise_layer(*positionwise_layer_args), - dropout_rate, - normalize_before, - concat_after, - ), - ) - if self.normalize_before: - self.after_norm = LayerNorm(output_size) - - self.interctc_layer_idx = interctc_layer_idx - if len(interctc_layer_idx) > 0: - assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks - self.interctc_use_conditioning = interctc_use_conditioning - self.conditioning_layer = None - shfit_fsmn = (kernel_size - 1) // 2 - self.overlap_chunk_cls = overlap_chunk( - chunk_size=chunk_size, - stride=stride, - pad_left=pad_left, - shfit_fsmn=shfit_fsmn, - encoder_att_look_back_factor=encoder_att_look_back_factor, - decoder_att_look_back_factor=decoder_att_look_back_factor, - ) - self.time_reduction_factor = time_reduction_factor - self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch - self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf - - def forward( - self, - xs_pad: torch.Tensor, - ilens: torch.Tensor, - prev_states: torch.Tensor = None, - ctc: CTC = None, - ind: int = 0, - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """Embed positions in tensor. - Args: - xs_pad: input tensor (B, L, D) - ilens: input length (B) - prev_states: Not to be used now. - Returns: - position embedded tensor and mask - """ - masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) - xs_pad *= self.output_size ** 0.5 - if self.embed is None: - xs_pad = xs_pad - elif ( - isinstance(self.embed, Conv2dSubsampling) - or isinstance(self.embed, Conv2dSubsampling2) - or isinstance(self.embed, Conv2dSubsampling6) - or isinstance(self.embed, Conv2dSubsampling8) - ): - short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) - if short_status: - raise TooShortUttError( - f"has {xs_pad.size(1)} frames and is too short for subsampling " - + f"(it needs more than {limit_size} frames), return empty results", - xs_pad.size(1), - limit_size, - ) - xs_pad, masks = self.embed(xs_pad, masks) - else: - xs_pad = self.embed(xs_pad) - - mask_shfit_chunk, mask_att_chunk_encoder = None, None - if self.overlap_chunk_cls is not None: - ilens = masks.squeeze(1).sum(1) - chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind) - xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs) - masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) - mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0), - dtype=xs_pad.dtype) - mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device, - xs_pad.size(0), - dtype=xs_pad.dtype) - - encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder) - xs_pad, masks = encoder_outs[0], encoder_outs[1] - intermediate_outs = [] - if len(self.interctc_layer_idx) == 0: - encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder) - xs_pad, masks = encoder_outs[0], encoder_outs[1] - else: - for layer_idx, encoder_layer in enumerate(self.encoders): - encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder) - xs_pad, masks = encoder_outs[0], encoder_outs[1] - if layer_idx + 1 in 
self.interctc_layer_idx: - encoder_out = xs_pad - - # intermediate outputs are also normalized - if self.normalize_before: - encoder_out = self.after_norm(encoder_out) - - intermediate_outs.append((layer_idx + 1, encoder_out)) - - if self.interctc_use_conditioning: - ctc_out = ctc.softmax(encoder_out) - xs_pad = xs_pad + self.conditioning_layer(ctc_out) - - if self.normalize_before: - xs_pad = self.after_norm(xs_pad) - - olens = masks.squeeze(1).sum(1) - - xs_pad, olens = self.overlap_chunk_cls.remove_chunk(xs_pad, olens, chunk_outs=None) - - if self.time_reduction_factor > 1: - xs_pad = xs_pad[:,::self.time_reduction_factor,:] - olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 - - if len(intermediate_outs) > 0: - return (xs_pad, intermediate_outs), olens, None - return xs_pad, olens - - def gen_tf2torch_map_dict(self): - tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch - tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf - map_dict_local = { - ## encoder - # cicd - "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 0), - }, # (768,256),(1,256,768) - "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (768,),(768,) - "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 2, 0), - }, # (256,1,31),(1,31,256,1) - "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 0), - }, # (256,256),(1,256,256) - "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - # ffn - "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 0), - }, # (1024,256),(1,256,1024) - "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (1024,),(1024,) 
- "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 0), - }, # (256,1024),(1,1024,256) - "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - # out norm - "{}.after_norm.weight".format(tensor_name_prefix_torch): - {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.after_norm.bias".format(tensor_name_prefix_torch): - {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - - } - - return map_dict_local - - def convert_tf2torch(self, - var_dict_tf, - var_dict_torch, - ): - - map_dict = self.gen_tf2torch_map_dict() - - var_dict_torch_update = dict() - for name in sorted(var_dict_torch.keys(), reverse=False): - names = name.split('.') - if names[0] == self.tf2torch_tensor_name_prefix_torch: - if names[1] == "encoders0": - layeridx = int(names[2]) - name_q = name.replace(".{}.".format(layeridx), ".layeridx.") - - name_q = name_q.replace("encoders0", "encoders") - layeridx_bias = 0 - layeridx += layeridx_bias - if name_q in map_dict.keys(): - name_v = map_dict[name_q]["name"] - name_tf = name_v.replace("layeridx", "{}".format(layeridx)) - data_tf = var_dict_tf[name_tf] - if map_dict[name_q]["squeeze"] is not None: - data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) - if map_dict[name_q]["transpose"] is not None: - data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) - data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") - assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, - var_dict_torch[ - name].size(), - data_tf.size()) - var_dict_torch_update[name] = data_tf - logging.info( - "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, - var_dict_tf[name_tf].shape)) - elif names[1] == "encoders": - layeridx = int(names[2]) - name_q = name.replace(".{}.".format(layeridx), ".layeridx.") - layeridx_bias = 1 - layeridx += layeridx_bias - if name_q in map_dict.keys(): - name_v = map_dict[name_q]["name"] - name_tf = name_v.replace("layeridx", "{}".format(layeridx)) - data_tf = var_dict_tf[name_tf] - if map_dict[name_q]["squeeze"] is not None: - data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) - if map_dict[name_q]["transpose"] is not None: - data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) - data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") - assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, - var_dict_torch[ - name].size(), - data_tf.size()) - var_dict_torch_update[name] = data_tf - logging.info( - "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, - var_dict_tf[name_tf].shape)) - - elif names[1] == "after_norm": - name_tf = map_dict[name]["name"] - data_tf = var_dict_tf[name_tf] - data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") - var_dict_torch_update[name] = data_tf - logging.info( - "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, - var_dict_tf[name_tf].shape)) - - return var_dict_torch_update diff --git 
a/funasr/models_transducer/error_calculator.py b/funasr/models_transducer/error_calculator.py deleted file mode 100644 index 34b1dc74e..000000000 --- a/funasr/models_transducer/error_calculator.py +++ /dev/null @@ -1,169 +0,0 @@ -"""Error Calculator module for Transducer.""" - -from typing import List, Optional, Tuple - -import torch - -from funasr.models_transducer.beam_search_transducer import BeamSearchTransducer -from funasr.models_transducer.decoder.abs_decoder import AbsDecoder -from funasr.models_transducer.joint_network import JointNetwork - - -class ErrorCalculator: - """Calculate CER and WER for transducer models. - - Args: - decoder: Decoder module. - joint_network: Joint Network module. - token_list: List of token units. - sym_space: Space symbol. - sym_blank: Blank symbol. - report_cer: Whether to compute CER. - report_wer: Whether to compute WER. - - """ - - def __init__( - self, - decoder: AbsDecoder, - joint_network: JointNetwork, - token_list: List[int], - sym_space: str, - sym_blank: str, - report_cer: bool = False, - report_wer: bool = False, - ) -> None: - """Construct an ErrorCalculatorTransducer object.""" - super().__init__() - - self.beam_search = BeamSearchTransducer( - decoder=decoder, - joint_network=joint_network, - beam_size=1, - search_type="default", - score_norm=False, - ) - - self.decoder = decoder - - self.token_list = token_list - self.space = sym_space - self.blank = sym_blank - - self.report_cer = report_cer - self.report_wer = report_wer - - def __call__( - self, encoder_out: torch.Tensor, target: torch.Tensor - ) -> Tuple[Optional[float], Optional[float]]: - """Calculate sentence-level WER or/and CER score for Transducer model. - - Args: - encoder_out: Encoder output sequences. (B, T, D_enc) - target: Target label ID sequences. (B, L) - - Returns: - : Sentence-level CER score. - : Sentence-level WER score. - - """ - cer, wer = None, None - - batchsize = int(encoder_out.size(0)) - - encoder_out = encoder_out.to(next(self.decoder.parameters()).device) - - batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)] - pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest] - - char_pred, char_target = self.convert_to_char(pred, target) - - if self.report_cer: - cer = self.calculate_cer(char_pred, char_target) - - if self.report_wer: - wer = self.calculate_wer(char_pred, char_target) - - return cer, wer - - def convert_to_char( - self, pred: torch.Tensor, target: torch.Tensor - ) -> Tuple[List, List]: - """Convert label ID sequences to character sequences. - - Args: - pred: Prediction label ID sequences. (B, U) - target: Target label ID sequences. (B, L) - - Returns: - char_pred: Prediction character sequences. (B, ?) - char_target: Target character sequences. (B, ?) - - """ - char_pred, char_target = [], [] - - for i, pred_i in enumerate(pred): - char_pred_i = [self.token_list[int(h)] for h in pred_i] - char_target_i = [self.token_list[int(r)] for r in target[i]] - - char_pred_i = "".join(char_pred_i).replace(self.space, " ") - char_pred_i = char_pred_i.replace(self.blank, "") - - char_target_i = "".join(char_target_i).replace(self.space, " ") - char_target_i = char_target_i.replace(self.blank, "") - - char_pred.append(char_pred_i) - char_target.append(char_target_i) - - return char_pred, char_target - - def calculate_cer( - self, char_pred: torch.Tensor, char_target: torch.Tensor - ) -> float: - """Calculate sentence-level CER score. - - Args: - char_pred: Prediction character sequences. (B, ?) 
- char_target: Target character sequences. (B, ?) - - Returns: - : Average sentence-level CER score. - - """ - import editdistance - - distances, lens = [], [] - - for i, char_pred_i in enumerate(char_pred): - pred = char_pred_i.replace(" ", "") - target = char_target[i].replace(" ", "") - distances.append(editdistance.eval(pred, target)) - lens.append(len(target)) - - return float(sum(distances)) / sum(lens) - - def calculate_wer( - self, char_pred: torch.Tensor, char_target: torch.Tensor - ) -> float: - """Calculate sentence-level WER score. - - Args: - char_pred: Prediction character sequences. (B, ?) - char_target: Target character sequences. (B, ?) - - Returns: - : Average sentence-level WER score - - """ - import editdistance - - distances, lens = [], [] - - for i, char_pred_i in enumerate(char_pred): - pred = char_pred_i.replace("▁", " ").split() - target = char_target[i].replace("▁", " ").split() - - distances.append(editdistance.eval(pred, target)) - lens.append(len(target)) - - return float(sum(distances)) / sum(lens) diff --git a/funasr/models_transducer/espnet_transducer_model_uni_asr.py b/funasr/models_transducer/espnet_transducer_model_uni_asr.py deleted file mode 100644 index 2add3fa78..000000000 --- a/funasr/models_transducer/espnet_transducer_model_uni_asr.py +++ /dev/null @@ -1,485 +0,0 @@ -"""ESPnet2 ASR Transducer model.""" - -import logging -from contextlib import contextmanager -from typing import Dict, List, Optional, Tuple, Union - -import torch -from packaging.version import parse as V -from typeguard import check_argument_types - -from funasr.models.frontend.abs_frontend import AbsFrontend -from funasr.models.specaug.abs_specaug import AbsSpecAug -from funasr.models_transducer.decoder.abs_decoder import AbsDecoder -from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder -from funasr.models_transducer.encoder.encoder import Encoder -from funasr.models_transducer.joint_network import JointNetwork -from funasr.models_transducer.utils import get_transducer_task_io -from funasr.layers.abs_normalize import AbsNormalize -from funasr.torch_utils.device_funcs import force_gatherable -from funasr.train.abs_espnet_model import AbsESPnetModel - -if V(torch.__version__) >= V("1.6.0"): - from torch.cuda.amp import autocast -else: - - @contextmanager - def autocast(enabled=True): - yield - - -class UniASRTransducerModel(AbsESPnetModel): - """ESPnet2ASRTransducerModel module definition. - - Args: - vocab_size: Size of complete vocabulary (w/ EOS and blank included). - token_list: List of token - frontend: Frontend module. - specaug: SpecAugment module. - normalize: Normalization module. - encoder: Encoder module. - decoder: Decoder module. - joint_network: Joint Network module. - transducer_weight: Weight of the Transducer loss. - fastemit_lambda: FastEmit lambda value. - auxiliary_ctc_weight: Weight of auxiliary CTC loss. - auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs. - auxiliary_lm_loss_weight: Weight of auxiliary LM loss. - auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing. - ignore_id: Initial padding ID. - sym_space: Space symbol. - sym_blank: Blank Symbol - report_cer: Whether to report Character Error Rate during validation. - report_wer: Whether to report Word Error Rate during validation. - extract_feats_in_collect_stats: Whether to use extract_feats stats collection. 
- - """ - - def __init__( - self, - vocab_size: int, - token_list: Union[Tuple[str, ...], List[str]], - frontend: Optional[AbsFrontend], - specaug: Optional[AbsSpecAug], - normalize: Optional[AbsNormalize], - encoder, - decoder: AbsDecoder, - att_decoder: Optional[AbsAttDecoder], - joint_network: JointNetwork, - transducer_weight: float = 1.0, - fastemit_lambda: float = 0.0, - auxiliary_ctc_weight: float = 0.0, - auxiliary_ctc_dropout_rate: float = 0.0, - auxiliary_lm_loss_weight: float = 0.0, - auxiliary_lm_loss_smoothing: float = 0.0, - ignore_id: int = -1, - sym_space: str = "", - sym_blank: str = "", - report_cer: bool = True, - report_wer: bool = True, - extract_feats_in_collect_stats: bool = True, - ) -> None: - """Construct an ESPnetASRTransducerModel object.""" - super().__init__() - - assert check_argument_types() - - # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos) - self.blank_id = 0 - self.vocab_size = vocab_size - self.ignore_id = ignore_id - self.token_list = token_list.copy() - - self.sym_space = sym_space - self.sym_blank = sym_blank - - self.frontend = frontend - self.specaug = specaug - self.normalize = normalize - - self.encoder = encoder - self.decoder = decoder - self.joint_network = joint_network - - self.criterion_transducer = None - self.error_calculator = None - - self.use_auxiliary_ctc = auxiliary_ctc_weight > 0 - self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0 - - if self.use_auxiliary_ctc: - self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size) - self.ctc_dropout_rate = auxiliary_ctc_dropout_rate - - if self.use_auxiliary_lm_loss: - self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size) - self.lm_loss_smoothing = auxiliary_lm_loss_smoothing - - self.transducer_weight = transducer_weight - self.fastemit_lambda = fastemit_lambda - - self.auxiliary_ctc_weight = auxiliary_ctc_weight - self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight - - self.report_cer = report_cer - self.report_wer = report_wer - - self.extract_feats_in_collect_stats = extract_feats_in_collect_stats - - def forward( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - text: torch.Tensor, - text_lengths: torch.Tensor, - decoding_ind: int = None, - **kwargs, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: - """Forward architecture and compute loss(es). - - Args: - speech: Speech sequences. (B, S) - speech_lengths: Speech sequences lengths. (B,) - text: Label ID sequences. (B, L) - text_lengths: Label ID sequences lengths. (B,) - kwargs: Contains "utts_id". - - Return: - loss: Main loss value. - stats: Task statistics. - weight: Task weights. - - """ - assert text_lengths.dim() == 1, text_lengths.shape - assert ( - speech.shape[0] - == speech_lengths.shape[0] - == text.shape[0] - == text_lengths.shape[0] - ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) - - batch_size = speech.shape[0] - text = text[:, : text_lengths.max()] - - # 1. Encoder - ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind) - encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind) - # 2. Transducer-related I/O preparation - decoder_in, target, t_len, u_len = get_transducer_task_io( - text, - encoder_out_lens, - ignore_id=self.ignore_id, - ) - - # 3. Decoder - self.decoder.set_device(encoder_out.device) - decoder_out = self.decoder(decoder_in, u_len) - - # 4. 
Joint Network - joint_out = self.joint_network( - encoder_out.unsqueeze(2), decoder_out.unsqueeze(1) - ) - - # 5. Losses - loss_trans, cer_trans, wer_trans = self._calc_transducer_loss( - encoder_out, - joint_out, - target, - t_len, - u_len, - ) - - loss_ctc, loss_lm = 0.0, 0.0 - - if self.use_auxiliary_ctc: - loss_ctc = self._calc_ctc_loss( - encoder_out, - target, - t_len, - u_len, - ) - - if self.use_auxiliary_lm_loss: - loss_lm = self._calc_lm_loss(decoder_out, target) - - loss = ( - self.transducer_weight * loss_trans - + self.auxiliary_ctc_weight * loss_ctc - + self.auxiliary_lm_loss_weight * loss_lm - ) - - stats = dict( - loss=loss.detach(), - loss_transducer=loss_trans.detach(), - aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None, - aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None, - cer_transducer=cer_trans, - wer_transducer=wer_trans, - ) - - # force_gatherable: to-device and to-tensor if scalar for DataParallel - loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) - - return loss, stats, weight - - def collect_feats( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - text: torch.Tensor, - text_lengths: torch.Tensor, - **kwargs, - ) -> Dict[str, torch.Tensor]: - """Collect features sequences and features lengths sequences. - - Args: - speech: Speech sequences. (B, S) - speech_lengths: Speech sequences lengths. (B,) - text: Label ID sequences. (B, L) - text_lengths: Label ID sequences lengths. (B,) - kwargs: Contains "utts_id". - - Return: - {}: "feats": Features sequences. (B, T, D_feats), - "feats_lengths": Features sequences lengths. (B,) - - """ - if self.extract_feats_in_collect_stats: - feats, feats_lengths = self._extract_feats(speech, speech_lengths) - else: - # Generate dummy stats if extract_feats_in_collect_stats is False - logging.warning( - "Generating dummy stats for feats and feats_lengths, " - "because encoder_conf.extract_feats_in_collect_stats is " - f"{self.extract_feats_in_collect_stats}" - ) - - feats, feats_lengths = speech, speech_lengths - - return {"feats": feats, "feats_lengths": feats_lengths} - - def encode( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - ind: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Encoder speech sequences. - - Args: - speech: Speech sequences. (B, S) - speech_lengths: Speech sequences lengths. (B,) - - Return: - encoder_out: Encoder outputs. (B, T, D_enc) - encoder_out_lens: Encoder outputs lengths. (B,) - - """ - with autocast(False): - # 1. Extract feats - feats, feats_lengths = self._extract_feats(speech, speech_lengths) - - # 2. Data augmentation - if self.specaug is not None and self.training: - feats, feats_lengths = self.specaug(feats, feats_lengths) - - # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN - if self.normalize is not None: - feats, feats_lengths = self.normalize(feats, feats_lengths) - - # 4. Forward encoder - encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths, ind=ind) - - assert encoder_out.size(0) == speech.size(0), ( - encoder_out.size(), - speech.size(0), - ) - assert encoder_out.size(1) <= encoder_out_lens.max(), ( - encoder_out.size(), - encoder_out_lens.max(), - ) - - return encoder_out, encoder_out_lens - - def _extract_feats( - self, speech: torch.Tensor, speech_lengths: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Extract features sequences and features sequences lengths. - - Args: - speech: Speech sequences. (B, S) - speech_lengths: Speech sequences lengths. 
(B,) - - Return: - feats: Features sequences. (B, T, D_feats) - feats_lengths: Features sequences lengths. (B,) - - """ - assert speech_lengths.dim() == 1, speech_lengths.shape - - # for data-parallel - speech = speech[:, : speech_lengths.max()] - - if self.frontend is not None: - feats, feats_lengths = self.frontend(speech, speech_lengths) - else: - feats, feats_lengths = speech, speech_lengths - - return feats, feats_lengths - - def _calc_transducer_loss( - self, - encoder_out: torch.Tensor, - joint_out: torch.Tensor, - target: torch.Tensor, - t_len: torch.Tensor, - u_len: torch.Tensor, - ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]: - """Compute Transducer loss. - - Args: - encoder_out: Encoder output sequences. (B, T, D_enc) - joint_out: Joint Network output sequences (B, T, U, D_joint) - target: Target label ID sequences. (B, L) - t_len: Encoder output sequences lengths. (B,) - u_len: Target label ID sequences lengths. (B,) - - Return: - loss_transducer: Transducer loss value. - cer_transducer: Character error rate for Transducer. - wer_transducer: Word Error Rate for Transducer. - - """ - if self.criterion_transducer is None: - try: - # from warprnnt_pytorch import RNNTLoss - # self.criterion_transducer = RNNTLoss( - # reduction="mean", - # fastemit_lambda=self.fastemit_lambda, - # ) - from warp_rnnt import rnnt_loss as RNNTLoss - self.criterion_transducer = RNNTLoss - - except ImportError: - logging.error( - "warp-rnnt was not installed." - "Please consult the installation documentation." - ) - exit(1) - - # loss_transducer = self.criterion_transducer( - # joint_out, - # target, - # t_len, - # u_len, - # ) - log_probs = torch.log_softmax(joint_out, dim=-1) - - loss_transducer = self.criterion_transducer( - log_probs, - target, - t_len, - u_len, - reduction="mean", - blank=self.blank_id, - gather=True, - ) - - if not self.training and (self.report_cer or self.report_wer): - if self.error_calculator is None: - from espnet2.asr_transducer.error_calculator import ErrorCalculator - - self.error_calculator = ErrorCalculator( - self.decoder, - self.joint_network, - self.token_list, - self.sym_space, - self.sym_blank, - report_cer=self.report_cer, - report_wer=self.report_wer, - ) - - cer_transducer, wer_transducer = self.error_calculator(encoder_out, target) - - return loss_transducer, cer_transducer, wer_transducer - - return loss_transducer, None, None - - def _calc_ctc_loss( - self, - encoder_out: torch.Tensor, - target: torch.Tensor, - t_len: torch.Tensor, - u_len: torch.Tensor, - ) -> torch.Tensor: - """Compute CTC loss. - - Args: - encoder_out: Encoder output sequences. (B, T, D_enc) - target: Target label ID sequences. (B, L) - t_len: Encoder output sequences lengths. (B,) - u_len: Target label ID sequences lengths. (B,) - - Return: - loss_ctc: CTC loss value. - - """ - ctc_in = self.ctc_lin( - torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate) - ) - ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1) - - target_mask = target != 0 - ctc_target = target[target_mask].cpu() - - with torch.backends.cudnn.flags(deterministic=True): - loss_ctc = torch.nn.functional.ctc_loss( - ctc_in, - ctc_target, - t_len, - u_len, - zero_infinity=True, - reduction="sum", - ) - loss_ctc /= target.size(0) - - return loss_ctc - - def _calc_lm_loss( - self, - decoder_out: torch.Tensor, - target: torch.Tensor, - ) -> torch.Tensor: - """Compute LM loss. - - Args: - decoder_out: Decoder output sequences. (B, U, D_dec) - target: Target label ID sequences. 
(B, L) - - Return: - loss_lm: LM loss value. - - """ - lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size) - lm_target = target.view(-1).type(torch.int64) - - with torch.no_grad(): - true_dist = lm_loss_in.clone() - true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1)) - - # Ignore blank ID (0) - ignore = lm_target == 0 - lm_target = lm_target.masked_fill(ignore, 0) - - true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing)) - - loss_lm = torch.nn.functional.kl_div( - torch.log_softmax(lm_loss_in, dim=1), - true_dist, - reduction="none", - ) - loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size( - 0 - ) - - return loss_lm diff --git a/funasr/models_transducer/utils.py b/funasr/models_transducer/utils.py deleted file mode 100644 index fd3c531b4..000000000 --- a/funasr/models_transducer/utils.py +++ /dev/null @@ -1,200 +0,0 @@ -"""Utility functions for Transducer models.""" - -from typing import List, Tuple - -import torch - - -class TooShortUttError(Exception): - """Raised when the utt is too short for subsampling. - - Args: - message: Error message to display. - actual_size: The size that cannot pass the subsampling. - limit: The size limit for subsampling. - - """ - - def __init__(self, message: str, actual_size: int, limit: int) -> None: - """Construct a TooShortUttError module.""" - super().__init__(message) - - self.actual_size = actual_size - self.limit = limit - - -def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]: - """Check if the input is too short for subsampling. - - Args: - sub_factor: Subsampling factor for Conv2DSubsampling. - size: Input size. - - Returns: - : Whether an error should be sent. - : Size limit for specified subsampling factor. - - """ - if sub_factor == 2 and size < 3: - return True, 7 - elif sub_factor == 4 and size < 7: - return True, 7 - elif sub_factor == 6 and size < 11: - return True, 11 - - return False, -1 - - -def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]: - """Get conv2D second layer parameters for given subsampling factor. - - Args: - sub_factor: Subsampling factor (1/X). - input_size: Input size. - - Returns: - : Kernel size for second convolution. - : Stride for second convolution. - : Conv2DSubsampling output size. - - """ - if sub_factor == 2: - return 3, 1, (((input_size - 1) // 2 - 2)) - elif sub_factor == 4: - return 3, 2, (((input_size - 1) // 2 - 1) // 2) - elif sub_factor == 6: - return 5, 3, (((input_size - 1) // 2 - 2) // 3) - else: - raise ValueError( - "subsampling_factor parameter should be set to either 2, 4 or 6." - ) - - -def make_chunk_mask( - size: int, - chunk_size: int, - left_chunk_size: int = 0, - device: torch.device = None, -) -> torch.Tensor: - """Create chunk mask for the subsequent steps (size, size). - - Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py - - Args: - size: Size of the source mask. - chunk_size: Number of frames in chunk. - left_chunk_size: Size of the left context in chunks (0 means full context). - device: Device for the mask tensor. - - Returns: - mask: Chunk mask. 
(size, size) - - """ - mask = torch.zeros(size, size, device=device, dtype=torch.bool) - - for i in range(size): - if left_chunk_size <= 0: - start = 0 - else: - start = max((i // chunk_size - left_chunk_size) * chunk_size, 0) - - end = min((i // chunk_size + 1) * chunk_size, size) - mask[i, start:end] = True - - return ~mask - - -def make_source_mask(lengths: torch.Tensor) -> torch.Tensor: - """Create source mask for given lengths. - - Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py - - Args: - lengths: Sequence lengths. (B,) - - Returns: - : Mask for the sequence lengths. (B, max_len) - - """ - max_len = lengths.max() - batch_size = lengths.size(0) - - expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths) - - return expanded_lengths >= lengths.unsqueeze(1) - - -def get_transducer_task_io( - labels: torch.Tensor, - encoder_out_lens: torch.Tensor, - ignore_id: int = -1, - blank_id: int = 0, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Get Transducer loss I/O. - - Args: - labels: Label ID sequences. (B, L) - encoder_out_lens: Encoder output lengths. (B,) - ignore_id: Padding symbol ID. - blank_id: Blank symbol ID. - - Returns: - decoder_in: Decoder inputs. (B, U) - target: Target label ID sequences. (B, U) - t_len: Time lengths. (B,) - u_len: Label lengths. (B,) - - """ - - def pad_list(labels: List[torch.Tensor], padding_value: int = 0): - """Create padded batch of labels from a list of labels sequences. - - Args: - labels: Labels sequences. [B x (?)] - padding_value: Padding value. - - Returns: - labels: Batch of padded labels sequences. (B,) - - """ - batch_size = len(labels) - - padded = ( - labels[0] - .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:]) - .fill_(padding_value) - ) - - for i in range(batch_size): - padded[i, : labels[i].size(0)] = labels[i] - - return padded - - device = labels.device - - labels_unpad = [y[y != ignore_id] for y in labels] - blank = labels[0].new([blank_id]) - - decoder_in = pad_list( - [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id - ).to(device) - - target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device) - - encoder_out_lens = list(map(int, encoder_out_lens)) - t_len = torch.IntTensor(encoder_out_lens).to(device) - - u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device) - - return decoder_in, target, t_len, u_len - -def pad_to_len(t: torch.Tensor, pad_len: int, dim: int): - """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros.""" - if t.size(dim) == pad_len: - return t - else: - pad_size = list(t.shape) - pad_size[dim] = pad_len - t.size(dim) - return torch.cat( - [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim - ) diff --git a/funasr/models_transducer/activation.py b/funasr/modules/activation.py similarity index 100% rename from funasr/models_transducer/activation.py rename to funasr/modules/activation.py diff --git a/funasr/models_transducer/beam_search_transducer.py b/funasr/modules/beam_search/beam_search_transducer.py similarity index 99% rename from funasr/models_transducer/beam_search_transducer.py rename to funasr/modules/beam_search/beam_search_transducer.py index 8e234e45a..eaf5627f9 100644 --- a/funasr/models_transducer/beam_search_transducer.py +++ b/funasr/modules/beam_search/beam_search_transducer.py @@ -6,8 +6,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch -from 
funasr.models_transducer.decoder.abs_decoder import AbsDecoder -from funasr.models_transducer.joint_network import JointNetwork +from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder +from funasr.models.joint_network import JointNetwork @dataclass diff --git a/funasr/modules/e2e_asr_common.py b/funasr/modules/e2e_asr_common.py index 92f90796a..9b5039c91 100644 --- a/funasr/modules/e2e_asr_common.py +++ b/funasr/modules/e2e_asr_common.py @@ -6,6 +6,8 @@ """Common functions for ASR.""" +from typing import List, Optional, Tuple + import json import logging import sys @@ -13,7 +15,11 @@ import sys from itertools import groupby import numpy as np import six +import torch +from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer +from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder +from funasr.models.joint_network import JointNetwork def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))): """End detection. @@ -247,3 +253,148 @@ class ErrorCalculator(object): word_eds.append(editdistance.eval(hyp_words, ref_words)) word_ref_lens.append(len(ref_words)) return float(sum(word_eds)) / sum(word_ref_lens) + +class ErrorCalculatorTransducer: + """Calculate CER and WER for transducer models. + Args: + decoder: Decoder module. + joint_network: Joint Network module. + token_list: List of token units. + sym_space: Space symbol. + sym_blank: Blank symbol. + report_cer: Whether to compute CER. + report_wer: Whether to compute WER. + """ + + def __init__( + self, + decoder: AbsDecoder, + joint_network: JointNetwork, + token_list: List[int], + sym_space: str, + sym_blank: str, + report_cer: bool = False, + report_wer: bool = False, + ) -> None: + """Construct an ErrorCalculatorTransducer object.""" + super().__init__() + + self.beam_search = BeamSearchTransducer( + decoder=decoder, + joint_network=joint_network, + beam_size=1, + search_type="default", + score_norm=False, + ) + + self.decoder = decoder + + self.token_list = token_list + self.space = sym_space + self.blank = sym_blank + + self.report_cer = report_cer + self.report_wer = report_wer + + def __call__( + self, encoder_out: torch.Tensor, target: torch.Tensor + ) -> Tuple[Optional[float], Optional[float]]: + """Calculate sentence-level WER or/and CER score for Transducer model. + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + target: Target label ID sequences. (B, L) + Returns: + : Sentence-level CER score. + : Sentence-level WER score. + """ + cer, wer = None, None + + batchsize = int(encoder_out.size(0)) + + encoder_out = encoder_out.to(next(self.decoder.parameters()).device) + + batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)] + pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest] + + char_pred, char_target = self.convert_to_char(pred, target) + + if self.report_cer: + cer = self.calculate_cer(char_pred, char_target) + + if self.report_wer: + wer = self.calculate_wer(char_pred, char_target) + + return cer, wer + + def convert_to_char( + self, pred: torch.Tensor, target: torch.Tensor + ) -> Tuple[List, List]: + """Convert label ID sequences to character sequences. + Args: + pred: Prediction label ID sequences. (B, U) + target: Target label ID sequences. (B, L) + Returns: + char_pred: Prediction character sequences. (B, ?) + char_target: Target character sequences. (B, ?) 
+ """ + char_pred, char_target = [], [] + + for i, pred_i in enumerate(pred): + char_pred_i = [self.token_list[int(h)] for h in pred_i] + char_target_i = [self.token_list[int(r)] for r in target[i]] + + char_pred_i = "".join(char_pred_i).replace(self.space, " ") + char_pred_i = char_pred_i.replace(self.blank, "") + + char_target_i = "".join(char_target_i).replace(self.space, " ") + char_target_i = char_target_i.replace(self.blank, "") + + char_pred.append(char_pred_i) + char_target.append(char_target_i) + + return char_pred, char_target + + def calculate_cer( + self, char_pred: torch.Tensor, char_target: torch.Tensor + ) -> float: + """Calculate sentence-level CER score. + Args: + char_pred: Prediction character sequences. (B, ?) + char_target: Target character sequences. (B, ?) + Returns: + : Average sentence-level CER score. + """ + import editdistance + + distances, lens = [], [] + + for i, char_pred_i in enumerate(char_pred): + pred = char_pred_i.replace(" ", "") + target = char_target[i].replace(" ", "") + distances.append(editdistance.eval(pred, target)) + lens.append(len(target)) + + return float(sum(distances)) / sum(lens) + + def calculate_wer( + self, char_pred: torch.Tensor, char_target: torch.Tensor + ) -> float: + """Calculate sentence-level WER score. + Args: + char_pred: Prediction character sequences. (B, ?) + char_target: Target character sequences. (B, ?) + Returns: + : Average sentence-level WER score + """ + import editdistance + + distances, lens = [], [] + + for i, char_pred_i in enumerate(char_pred): + pred = char_pred_i.replace("▁", " ").split() + target = char_target[i].replace("▁", " ").split() + + distances.append(editdistance.eval(pred, target)) + lens.append(len(target)) + + return float(sum(distances)) / sum(lens) diff --git a/funasr/modules/nets_utils.py b/funasr/modules/nets_utils.py index 6d77d69a6..5d4fe1c85 100644 --- a/funasr/modules/nets_utils.py +++ b/funasr/modules/nets_utils.py @@ -3,7 +3,7 @@ """Network related utility tools.""" import logging -from typing import Dict +from typing import Dict, List, Tuple import numpy as np import torch @@ -506,3 +506,196 @@ def get_activation(act): } return activation_funcs[act]() + +class TooShortUttError(Exception): + """Raised when the utt is too short for subsampling. + + Args: + message: Error message to display. + actual_size: The size that cannot pass the subsampling. + limit: The size limit for subsampling. + + """ + + def __init__(self, message: str, actual_size: int, limit: int) -> None: + """Construct a TooShortUttError module.""" + super().__init__(message) + + self.actual_size = actual_size + self.limit = limit + + +def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]: + """Check if the input is too short for subsampling. + + Args: + sub_factor: Subsampling factor for Conv2DSubsampling. + size: Input size. + + Returns: + : Whether an error should be sent. + : Size limit for specified subsampling factor. + + """ + if sub_factor == 2 and size < 3: + return True, 7 + elif sub_factor == 4 and size < 7: + return True, 7 + elif sub_factor == 6 and size < 11: + return True, 11 + + return False, -1 + + +def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]: + """Get conv2D second layer parameters for given subsampling factor. + + Args: + sub_factor: Subsampling factor (1/X). + input_size: Input size. + + Returns: + : Kernel size for second convolution. + : Stride for second convolution. + : Conv2DSubsampling output size. 
+ + """ + if sub_factor == 2: + return 3, 1, (((input_size - 1) // 2 - 2)) + elif sub_factor == 4: + return 3, 2, (((input_size - 1) // 2 - 1) // 2) + elif sub_factor == 6: + return 5, 3, (((input_size - 1) // 2 - 2) // 3) + else: + raise ValueError( + "subsampling_factor parameter should be set to either 2, 4 or 6." + ) + + +def make_chunk_mask( + size: int, + chunk_size: int, + left_chunk_size: int = 0, + device: torch.device = None, +) -> torch.Tensor: + """Create chunk mask for the subsequent steps (size, size). + + Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py + + Args: + size: Size of the source mask. + chunk_size: Number of frames in chunk. + left_chunk_size: Size of the left context in chunks (0 means full context). + device: Device for the mask tensor. + + Returns: + mask: Chunk mask. (size, size) + + """ + mask = torch.zeros(size, size, device=device, dtype=torch.bool) + + for i in range(size): + if left_chunk_size <= 0: + start = 0 + else: + start = max((i // chunk_size - left_chunk_size) * chunk_size, 0) + + end = min((i // chunk_size + 1) * chunk_size, size) + mask[i, start:end] = True + + return ~mask + +def make_source_mask(lengths: torch.Tensor) -> torch.Tensor: + """Create source mask for given lengths. + + Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py + + Args: + lengths: Sequence lengths. (B,) + + Returns: + : Mask for the sequence lengths. (B, max_len) + + """ + max_len = lengths.max() + batch_size = lengths.size(0) + + expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths) + + return expanded_lengths >= lengths.unsqueeze(1) + + +def get_transducer_task_io( + labels: torch.Tensor, + encoder_out_lens: torch.Tensor, + ignore_id: int = -1, + blank_id: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Get Transducer loss I/O. + + Args: + labels: Label ID sequences. (B, L) + encoder_out_lens: Encoder output lengths. (B,) + ignore_id: Padding symbol ID. + blank_id: Blank symbol ID. + + Returns: + decoder_in: Decoder inputs. (B, U) + target: Target label ID sequences. (B, U) + t_len: Time lengths. (B,) + u_len: Label lengths. (B,) + + """ + + def pad_list(labels: List[torch.Tensor], padding_value: int = 0): + """Create padded batch of labels from a list of labels sequences. + + Args: + labels: Labels sequences. [B x (?)] + padding_value: Padding value. + + Returns: + labels: Batch of padded labels sequences. 
(B,) + + """ + batch_size = len(labels) + + padded = ( + labels[0] + .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:]) + .fill_(padding_value) + ) + + for i in range(batch_size): + padded[i, : labels[i].size(0)] = labels[i] + + return padded + + device = labels.device + + labels_unpad = [y[y != ignore_id] for y in labels] + blank = labels[0].new([blank_id]) + + decoder_in = pad_list( + [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id + ).to(device) + + target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device) + + encoder_out_lens = list(map(int, encoder_out_lens)) + t_len = torch.IntTensor(encoder_out_lens).to(device) + + u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device) + + return decoder_in, target, t_len, u_len + +def pad_to_len(t: torch.Tensor, pad_len: int, dim: int): + """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros.""" + if t.size(dim) == pad_len: + return t + else: + pad_size = list(t.shape) + pad_size[dim] = pad_len - t.size(dim) + return torch.cat( + [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim + ) diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py index be1445590..cae18c169 100644 --- a/funasr/tasks/asr_transducer.py +++ b/funasr/tasks/asr_transducer.py @@ -21,15 +21,13 @@ from funasr.models.decoder.transformer_decoder import ( LightweightConvolutionTransformerDecoder, TransformerDecoder, ) -from funasr.models_transducer.decoder.abs_decoder import AbsDecoder -from funasr.models_transducer.decoder.rnn_decoder import RNNDecoder -from funasr.models_transducer.decoder.stateless_decoder import StatelessDecoder -from funasr.models_transducer.encoder.encoder import Encoder -from funasr.models_transducer.encoder.sanm_encoder import SANMEncoderChunkOpt -from funasr.models_transducer.espnet_transducer_model import ESPnetASRTransducerModel -from funasr.models_transducer.espnet_transducer_model_unified import ESPnetASRUnifiedTransducerModel -from funasr.models_transducer.espnet_transducer_model_uni_asr import UniASRTransducerModel -from funasr.models_transducer.joint_network import JointNetwork +from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder +from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder +from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder +from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder +from funasr.models.e2e_transducer import TransducerModel +from funasr.models.e2e_transducer_unified import UnifiedTransducerModel +from funasr.models.joint_network import JointNetwork from funasr.layers.abs_normalize import AbsNormalize from funasr.layers.global_mvn import GlobalMVN from funasr.layers.utterance_mvn import UtteranceMVN @@ -75,7 +73,6 @@ encoder_choices = ClassChoices( "encoder", classes=dict( encoder=Encoder, - sanm_chunk_opt=SANMEncoderChunkOpt, ), default="encoder", ) @@ -158,7 +155,7 @@ class ASRTransducerTask(AbsTask): group.add_argument( "--model_conf", action=NestedDictAction, - default=get_default_kwargs(ESPnetASRTransducerModel), + default=get_default_kwargs(TransducerModel), help="The keyword arguments for the model class.", ) # group.add_argument( @@ -354,7 +351,7 @@ class ASRTransducerTask(AbsTask): return retval @classmethod - def build_model(cls, args: argparse.Namespace) -> ESPnetASRTransducerModel: + def build_model(cls, args: argparse.Namespace) -> TransducerModel: """Required data depending on task mode. Args: cls: ASRTransducerTask object. 
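Among the helpers this patch moves into funasr/modules/nets_utils.py, get_transducer_task_io turns padded label sequences into the four tensors the RNN-T loss consumes. A minimal sketch of calling it, assuming the relocated import path shown above; the toy labels and lengths are illustrative only:

import torch

from funasr.modules.nets_utils import get_transducer_task_io

# Two utterances padded with ignore_id = -1; blank_id = 0 as in the model.
labels = torch.tensor([[5, 7, 9, -1],
                       [4, 6, -1, -1]])
encoder_out_lens = torch.tensor([10, 8])  # encoder output frames per utterance

decoder_in, target, t_len, u_len = get_transducer_task_io(
    labels, encoder_out_lens, ignore_id=-1, blank_id=0
)

# decoder_in: blank-prefixed history for the prediction network, (B, U_max + 1)
#   [[0, 5, 7, 9], [0, 4, 6, 0]]
# target: padding stripped and cast to int32 for the loss, (B, U_max)
#   [[5, 7, 9], [4, 6, 0]]
# t_len = [10, 8], u_len = [3, 2]
print(decoder_in.tolist(), target.dtype, t_len.tolist(), u_len.tolist())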
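make_chunk_mask, relocated to the same module, builds the boolean self-attention mask used for chunk-wise streaming encoding: each frame sees its own chunk plus left_chunk_size chunks of left context, and left_chunk_size <= 0 keeps the full left history. A small sketch of its output under that assumption, with toy sizes:

from funasr.modules.nets_utils import make_chunk_mask

# 8 frames, chunks of 4 frames, one chunk of left context.
mask = make_chunk_mask(size=8, chunk_size=4, left_chunk_size=1)

# The returned mask is True where attention must be blocked: frames 0-3 are
# confined to chunk [0, 4) and cannot look ahead, while frames 4-7 see chunk
# [4, 8) plus one left chunk, i.e. all 8 frames.
print((~mask).int())  # 1 = visible, 0 = blocked
# rows 0-3 -> [1, 1, 1, 1, 0, 0, 0, 0]
# rows 4-7 -> [1, 1, 1, 1, 1, 1, 1, 1]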
@@ -440,22 +437,8 @@ class ASRTransducerTask(AbsTask): # 7. Build model - if getattr(args, "encoder", None) is not None and args.encoder == 'sanm_chunk_opt': - model = UniASRTransducerModel( - vocab_size=vocab_size, - token_list=token_list, - frontend=frontend, - specaug=specaug, - normalize=normalize, - encoder=encoder, - decoder=decoder, - att_decoder=att_decoder, - joint_network=joint_network, - **args.model_conf, - ) - - elif encoder.unified_model_training: - model = ESPnetASRUnifiedTransducerModel( + if encoder.unified_model_training: + model = UnifiedTransducerModel( vocab_size=vocab_size, token_list=token_list, frontend=frontend, @@ -469,7 +452,7 @@ class ASRTransducerTask(AbsTask): ) else: - model = ESPnetASRTransducerModel( + model = TransducerModel( vocab_size=vocab_size, token_list=token_list, frontend=frontend,