From fc606ceef3aa5a1dbca795a43147c0aa9ddf0b34 Mon Sep 17 00:00:00 2001 From: aky15 Date: Tue, 14 Mar 2023 20:42:08 +0800 Subject: [PATCH 01/14] rnnt --- funasr/bin/asr_inference_launch.py | 43 + funasr/bin/asr_inference_rnnt.py | 1297 ++++++++--------- funasr/bin/asr_train_transducer.py | 46 + funasr/models_transducer/__init__.py | 0 funasr/models_transducer/activation.py | 213 +++ .../beam_search_transducer.py | 705 +++++++++ funasr/models_transducer/decoder/__init__.py | 0 .../models_transducer/decoder/abs_decoder.py | 110 ++ .../models_transducer/decoder/rnn_decoder.py | 259 ++++ .../decoder/stateless_decoder.py | 157 ++ funasr/models_transducer/encoder/__init__.py | 0 .../encoder/blocks/__init__.py | 0 .../encoder/blocks/branchformer.py | 178 +++ .../encoder/blocks/conformer.py | 198 +++ .../encoder/blocks/conv1d.py | 221 +++ .../encoder/blocks/conv_input.py | 226 +++ .../encoder/blocks/linear_input.py | 52 + funasr/models_transducer/encoder/building.py | 352 +++++ funasr/models_transducer/encoder/encoder.py | 294 ++++ .../encoder/modules/__init__.py | 0 .../encoder/modules/attention.py | 246 ++++ .../encoder/modules/convolution.py | 196 +++ .../encoder/modules/multi_blocks.py | 105 ++ .../encoder/modules/normalization.py | 170 +++ .../encoder/modules/positional_encoding.py | 91 ++ .../models_transducer/encoder/sanm_encoder.py | 835 +++++++++++ .../models_transducer/encoder/validation.py | 171 +++ funasr/models_transducer/error_calculator.py | 170 +++ .../espnet_transducer_model.py | 484 ++++++ .../espnet_transducer_model_uni_asr.py | 485 ++++++ .../espnet_transducer_model_unified.py | 588 ++++++++ funasr/models_transducer/joint_network.py | 62 + funasr/models_transducer/utils.py | 200 +++ funasr/tasks/asr_transducer.py | 487 +++++++ 34 files changed, 7945 insertions(+), 696 deletions(-) create mode 100755 funasr/bin/asr_train_transducer.py create mode 100644 funasr/models_transducer/__init__.py create mode 100644 funasr/models_transducer/activation.py create 
mode 100644 funasr/models_transducer/beam_search_transducer.py create mode 100644 funasr/models_transducer/decoder/__init__.py create mode 100644 funasr/models_transducer/decoder/abs_decoder.py create mode 100644 funasr/models_transducer/decoder/rnn_decoder.py create mode 100644 funasr/models_transducer/decoder/stateless_decoder.py create mode 100644 funasr/models_transducer/encoder/__init__.py create mode 100644 funasr/models_transducer/encoder/blocks/__init__.py create mode 100644 funasr/models_transducer/encoder/blocks/branchformer.py create mode 100644 funasr/models_transducer/encoder/blocks/conformer.py create mode 100644 funasr/models_transducer/encoder/blocks/conv1d.py create mode 100644 funasr/models_transducer/encoder/blocks/conv_input.py create mode 100644 funasr/models_transducer/encoder/blocks/linear_input.py create mode 100644 funasr/models_transducer/encoder/building.py create mode 100644 funasr/models_transducer/encoder/encoder.py create mode 100644 funasr/models_transducer/encoder/modules/__init__.py create mode 100644 funasr/models_transducer/encoder/modules/attention.py create mode 100644 funasr/models_transducer/encoder/modules/convolution.py create mode 100644 funasr/models_transducer/encoder/modules/multi_blocks.py create mode 100644 funasr/models_transducer/encoder/modules/normalization.py create mode 100644 funasr/models_transducer/encoder/modules/positional_encoding.py create mode 100644 funasr/models_transducer/encoder/sanm_encoder.py create mode 100644 funasr/models_transducer/encoder/validation.py create mode 100644 funasr/models_transducer/error_calculator.py create mode 100644 funasr/models_transducer/espnet_transducer_model.py create mode 100644 funasr/models_transducer/espnet_transducer_model_uni_asr.py create mode 100644 funasr/models_transducer/espnet_transducer_model_unified.py create mode 100644 funasr/models_transducer/joint_network.py create mode 100644 funasr/models_transducer/utils.py create mode 100644 
funasr/tasks/asr_transducer.py diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index 1fae766ea..b9be3e221 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -131,6 +131,11 @@ def get_parser(): help="Pretrained model tag. If specify this option, *_train_config and " "*_file will be overwritten", ) + group.add_argument( + "--beam_search_config", + default={}, + help="The keyword arguments for transducer beam search.", + ) group = parser.add_argument_group("Beam-search related") group.add_argument( @@ -168,6 +173,41 @@ def get_parser(): group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") group.add_argument("--streaming", type=str2bool, default=False) + group.add_argument("--simu_streaming", type=str2bool, default=False) + group.add_argument("--chunk_size", type=int, default=16) + group.add_argument("--left_context", type=int, default=16) + group.add_argument("--right_context", type=int, default=0) + group.add_argument( + "--display_partial_hypotheses", + type=bool, + default=False, + help="Whether to display partial hypotheses during chunk-by-chunk inference.", + ) + + group = parser.add_argument_group("Dynamic quantization related") + group.add_argument( + "--quantize_asr_model", + type=bool, + default=False, + help="Apply dynamic quantization to ASR model.", + ) + group.add_argument( + "--quantize_modules", + nargs="*", + default=None, + help="""Module names to apply dynamic quantization on. + The module names are provided as a list, where each name is separated + by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]). 
+ Each specified name should be an attribute of 'torch.nn', e.g.: + torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""", + ) + group.add_argument( + "--quantize_dtype", + type=str, + default="qint8", + choices=["float16", "qint8"], + help="Dtype for dynamic quantization.", + ) group = parser.add_argument_group("Text converter related") group.add_argument( @@ -262,6 +302,9 @@ def inference_launch_funasr(**kwargs): elif mode == "mfcca": from funasr.bin.asr_inference_mfcca import inference_modelscope return inference_modelscope(**kwargs) + elif mode == "rnnt": + from funasr.bin.asr_inference_rnnt import inference + return inference(**kwargs) else: logging.info("Unknown decoding mode: {}".format(mode)) return None diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py index 6cd70613b..f651f118d 100644 --- a/funasr/bin/asr_inference_rnnt.py +++ b/funasr/bin/asr_inference_rnnt.py @@ -1,151 +1,145 @@ #!/usr/bin/env python3 + +""" Inference class definition for Transducer models.""" + +from __future__ import annotations + import argparse import logging +import math import sys -import time -import copy -import os -import codecs -import tempfile -import requests from pathlib import Path -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union -from typing import Dict -from typing import Any -from typing import List +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch -from typeguard import check_argument_types +from packaging.version import parse as V +from typeguard import check_argument_types, check_return_type +from funasr.models_transducer.beam_search_transducer import ( + BeamSearchTransducer, + Hypothesis, +) +from funasr.models_transducer.utils import TooShortUttError from funasr.fileio.datadir_writer import DatadirWriter -from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch -from 
funasr.modules.beam_search.beam_search import Hypothesis -from funasr.modules.scorers.ctc import CTCPrefixScorer -from funasr.modules.scorers.length_bonus import LengthBonus -from funasr.modules.subsampling import TooShortUttError -from funasr.tasks.asr import ASRTaskParaformer as ASRTask +from funasr.tasks.asr_transducer import ASRTransducerTask from funasr.tasks.lm import LMTask from funasr.text.build_tokenizer import build_tokenizer from funasr.text.token_id_converter import TokenIDConverter from funasr.torch_utils.device_funcs import to_device from funasr.torch_utils.set_all_random_seed import set_all_random_seed from funasr.utils import config_argparse +from funasr.utils.types import str2bool, str2triple_str, str_or_none from funasr.utils.cli_utils import get_commandline_args -from funasr.utils.types import str2bool -from funasr.utils.types import str2triple_str -from funasr.utils.types import str_or_none -from funasr.utils import asr_utils, wav_utils, postprocess_utils -from funasr.models.frontend.wav_frontend import WavFrontend -from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer -from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export class Speech2Text: - """Speech2Text class - - Examples: - >>> import soundfile - >>> speech2text = Speech2Text("asr_config.yml", "asr.pth") - >>> audio, rate = soundfile.read("speech.wav") - >>> speech2text(audio) - [(text, token, token_int, hypothesis object), ...] - + """Speech2Text class for Transducer models. + Args: + asr_train_config: ASR model training config path. + asr_model_file: ASR model path. + beam_search_config: Beam search config path. + lm_train_config: Language Model training config path. + lm_file: Language Model config path. + token_type: Type of token units. + bpemodel: BPE model path. + device: Device to use for inference. + beam_size: Size of beam during search. + dtype: Data type. + lm_weight: Language model weight. 
+ quantize_asr_model: Whether to apply dynamic quantization to ASR model. + quantize_modules: List of module names to apply dynamic quantization on. + quantize_dtype: Dynamic quantization data type. + nbest: Number of final hypothesis. + streaming: Whether to perform chunk-by-chunk inference. + chunk_size: Number of frames in chunk AFTER subsampling. + left_context: Number of frames in left context AFTER subsampling. + right_context: Number of frames in right context AFTER subsampling. + display_partial_hypotheses: Whether to display partial hypotheses. """ def __init__( - self, - asr_train_config: Union[Path, str] = None, - asr_model_file: Union[Path, str] = None, - cmvn_file: Union[Path, str] = None, - lm_train_config: Union[Path, str] = None, - lm_file: Union[Path, str] = None, - token_type: str = None, - bpemodel: str = None, - device: str = "cpu", - maxlenratio: float = 0.0, - minlenratio: float = 0.0, - dtype: str = "float32", - beam_size: int = 20, - ctc_weight: float = 0.5, - lm_weight: float = 1.0, - ngram_weight: float = 0.9, - penalty: float = 0.0, - nbest: int = 1, - frontend_conf: dict = None, - hotword_list_or_file: str = None, - **kwargs, - ): + self, + asr_train_config: Union[Path, str] = None, + asr_model_file: Union[Path, str] = None, + beam_search_config: Dict[str, Any] = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = "cpu", + beam_size: int = 5, + dtype: str = "float32", + lm_weight: float = 1.0, + quantize_asr_model: bool = False, + quantize_modules: List[str] = None, + quantize_dtype: str = "qint8", + nbest: int = 1, + streaming: bool = False, + simu_streaming: bool = False, + chunk_size: int = 16, + left_context: int = 32, + right_context: int = 0, + display_partial_hypotheses: bool = False, + ) -> None: + """Construct a Speech2Text object.""" + super().__init__() + assert check_argument_types() - # 1. 
Build ASR model - scorers = {} - asr_model, asr_train_args = ASRTask.build_model_from_file( - asr_train_config, asr_model_file, cmvn_file, device - ) - frontend = None - if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: - frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) - - logging.info("asr_model: {}".format(asr_model)) - logging.info("asr_train_args: {}".format(asr_train_args)) - asr_model.to(dtype=getattr(torch, dtype)).eval() - - if asr_model.ctc != None: - ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) - scorers.update( - ctc=ctc - ) - token_list = asr_model.token_list - scorers.update( - length_bonus=LengthBonus(len(token_list)), + asr_model, asr_train_args = ASRTransducerTask.build_model_from_file( + asr_train_config, asr_model_file, device ) - # 2. Build Language model + if quantize_asr_model: + if quantize_modules is not None: + if not all([q in ["LSTM", "Linear"] for q in quantize_modules]): + raise ValueError( + "Only 'Linear' and 'LSTM' modules are currently supported" + " by PyTorch and in --quantize_modules" + ) + + q_config = set([getattr(torch.nn, q) for q in quantize_modules]) + else: + q_config = {torch.nn.Linear} + + if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")): + raise ValueError( + "float16 dtype for dynamic quantization is not supported with torch" + " version < 1.5.0. Switching to qint8 dtype instead." + ) + q_dtype = getattr(torch, quantize_dtype) + + asr_model = torch.quantization.quantize_dynamic( + asr_model, q_config, dtype=q_dtype + ).eval() + else: + asr_model.to(dtype=getattr(torch, dtype)).eval() + if lm_train_config is not None: lm, lm_train_args = LMTask.build_model_from_file( lm_train_config, lm_file, device ) - scorers["lm"] = lm.lm - - # 3. Build ngram model - # ngram is not supported now - ngram = None - scorers["ngram"] = ngram + lm_scorer = lm.lm + else: + lm_scorer = None # 4. 
Build BeamSearch object - # transducer is not supported now - beam_search_transducer = None + if beam_search_config is None: + beam_search_config = {} - weights = dict( - decoder=1.0 - ctc_weight, - ctc=ctc_weight, - lm=lm_weight, - ngram=ngram_weight, - length_bonus=penalty, - ) - beam_search = BeamSearch( - beam_size=beam_size, - weights=weights, - scorers=scorers, - sos=asr_model.sos, - eos=asr_model.eos, - vocab_size=len(token_list), - token_list=token_list, - pre_beam_score_key=None if ctc_weight == 1.0 else "full", + beam_search = BeamSearchTransducer( + asr_model.decoder, + asr_model.joint_network, + beam_size, + lm=lm_scorer, + lm_weight=lm_weight, + nbest=nbest, + **beam_search_config, ) - beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() - for scorer in scorers.values(): - if isinstance(scorer, torch.nn.Module): - scorer.to(device=device, dtype=getattr(torch, dtype)).eval() + token_list = asr_model.token_list - logging.info(f"Decoding device={device}, dtype={dtype}") - - # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text if token_type is None: token_type = asr_train_args.token_type if bpemodel is None: @@ -165,439 +159,397 @@ class Speech2Text: self.asr_model = asr_model self.asr_train_args = asr_train_args + self.device = device + self.dtype = dtype + self.nbest = nbest + self.converter = converter self.tokenizer = tokenizer - # 6. 
[Optional] Build hotword list from str, local file or url - self.hotword_list = None - self.hotword_list = self.generate_hotwords_list(hotword_list_or_file) - - is_use_lm = lm_weight != 0.0 and lm_file is not None - if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm: - beam_search = None self.beam_search = beam_search - logging.info(f"Beam_search: {self.beam_search}") - self.beam_search_transducer = beam_search_transducer - self.maxlenratio = maxlenratio - self.minlenratio = minlenratio - self.device = device - self.dtype = dtype - self.nbest = nbest - self.frontend = frontend - self.encoder_downsampling_factor = 1 - if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d": - self.encoder_downsampling_factor = 4 + self.streaming = streaming + self.simu_streaming = simu_streaming + self.chunk_size = max(chunk_size, 0) + self.left_context = max(left_context, 0) + self.right_context = max(right_context, 0) - @torch.no_grad() - def __call__( - self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None - ): - """Inference + if not streaming or chunk_size == 0: + self.streaming = False + self.asr_model.encoder.dynamic_chunk_training = False + + if not simu_streaming or chunk_size == 0: + self.simu_streaming = False + self.asr_model.encoder.dynamic_chunk_training = False - Args: - speech: Input speech data - Returns: - text, token, token_int, hyp + self.n_fft = asr_train_args.frontend_conf.get("n_fft", 512) + self.hop_length = asr_train_args.frontend_conf.get("hop_length", 128) - """ - assert check_argument_types() - - # Input as audio signal - if isinstance(speech, np.ndarray): - speech = torch.tensor(speech) - - if self.frontend is not None: - feats, feats_len = self.frontend.forward(speech, speech_lengths) - feats = to_device(feats, device=self.device) - feats_len = feats_len.int() - self.asr_model.frontend = None + if 
asr_train_args.frontend_conf.get("win_length", None) is not None: + self.frontend_window_size = asr_train_args.frontend_conf["win_length"] else: - feats = speech - feats_len = speech_lengths - lfr_factor = max(1, (feats.size()[-1] // 80) - 1) - batch = {"speech": feats, "speech_lengths": feats_len} + self.frontend_window_size = self.n_fft - # a. To device - batch = to_device(batch, device=self.device) - - # b. Forward Encoder - enc, enc_len = self.asr_model.encode(**batch) - if isinstance(enc, tuple): - enc = enc[0] - # assert len(enc) == 1, len(enc) - enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor - - predictor_outs = self.asr_model.calc_predictor(enc, enc_len) - pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \ - predictor_outs[2], predictor_outs[3] - pre_token_length = pre_token_length.round().long() - if torch.max(pre_token_length) < 1: - return [] - if not isinstance(self.asr_model, ContextualParaformer): - if self.hotword_list: - logging.warning("Hotword is given but asr model is not a ContextualParaformer.") - decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length) - decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] - else: - decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list) - decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] - - results = [] - b, n, d = decoder_out.size() - for i in range(b): - x = enc[i, :enc_len[i], :] - am_scores = decoder_out[i, :pre_token_length[i], :] - if self.beam_search is not None: - nbest_hyps = self.beam_search( - x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio - ) - - nbest_hyps = nbest_hyps[: self.nbest] - else: - yseq = am_scores.argmax(dim=-1) - score = am_scores.max(dim=-1)[0] - score = torch.sum(score, dim=-1) - # pad with mask tokens to 
ensure compatibility with sos/eos tokens - yseq = torch.tensor( - [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device - ) - nbest_hyps = [Hypothesis(yseq=yseq, score=score)] - - for hyp in nbest_hyps: - assert isinstance(hyp, (Hypothesis)), type(hyp) - - # remove sos/eos and get results - last_pos = -1 - if isinstance(hyp.yseq, list): - token_int = hyp.yseq[1:last_pos] - else: - token_int = hyp.yseq[1:last_pos].tolist() - - # remove blank symbol id, which is assumed to be 0 - token_int = list(filter(lambda x: x != 0 and x != 2, token_int)) - - # Change integer-ids to tokens - token = self.converter.ids2tokens(token_int) - - if self.tokenizer is not None: - text = self.tokenizer.tokens2text(token) - else: - text = None - - results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor)) - - # assert check_return_type(results) - return results - - def generate_hotwords_list(self, hotword_list_or_file): - # for None - if hotword_list_or_file is None: - hotword_list = None - # for local txt inputs - elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'): - logging.info("Attempting to parse hotwords from local txt...") - hotword_list = [] - hotword_str_list = [] - with codecs.open(hotword_list_or_file, 'r') as fin: - for line in fin.readlines(): - hw = line.strip() - hotword_str_list.append(hw) - hotword_list.append(self.converter.tokens2ids([i for i in hw])) - hotword_list.append([self.asr_model.sos]) - hotword_str_list.append('') - logging.info("Initialized hotword list from file: {}, hotword list: {}." 
- .format(hotword_list_or_file, hotword_str_list)) - # for url, download and generate txt - elif hotword_list_or_file.startswith('http'): - logging.info("Attempting to parse hotwords from url...") - work_dir = tempfile.TemporaryDirectory().name - if not os.path.exists(work_dir): - os.makedirs(work_dir) - text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file)) - local_file = requests.get(hotword_list_or_file) - open(text_file_path, "wb").write(local_file.content) - hotword_list_or_file = text_file_path - hotword_list = [] - hotword_str_list = [] - with codecs.open(hotword_list_or_file, 'r') as fin: - for line in fin.readlines(): - hw = line.strip() - hotword_str_list.append(hw) - hotword_list.append(self.converter.tokens2ids([i for i in hw])) - hotword_list.append([self.asr_model.sos]) - hotword_str_list.append('') - logging.info("Initialized hotword list from file: {}, hotword list: {}." - .format(hotword_list_or_file, hotword_str_list)) - # for text str input - elif not hotword_list_or_file.endswith('.txt'): - logging.info("Attempting to parse hotwords as str...") - hotword_list = [] - hotword_str_list = [] - for hw in hotword_list_or_file.strip().split(): - hotword_str_list.append(hw) - hotword_list.append(self.converter.tokens2ids([i for i in hw])) - hotword_list.append([self.asr_model.sos]) - hotword_str_list.append('') - logging.info("Hotword list: {}.".format(hotword_str_list)) - else: - hotword_list = None - return hotword_list - -class Speech2TextExport: - """Speech2TextExport class - - """ - - def __init__( - self, - asr_train_config: Union[Path, str] = None, - asr_model_file: Union[Path, str] = None, - cmvn_file: Union[Path, str] = None, - lm_train_config: Union[Path, str] = None, - lm_file: Union[Path, str] = None, - token_type: str = None, - bpemodel: str = None, - device: str = "cpu", - maxlenratio: float = 0.0, - minlenratio: float = 0.0, - dtype: str = "float32", - beam_size: int = 20, - ctc_weight: float = 0.5, - lm_weight: 
float = 1.0, - ngram_weight: float = 0.9, - penalty: float = 0.0, - nbest: int = 1, - frontend_conf: dict = None, - hotword_list_or_file: str = None, - **kwargs, - ): - - # 1. Build ASR model - asr_model, asr_train_args = ASRTask.build_model_from_file( - asr_train_config, asr_model_file, cmvn_file, device + self.window_size = self.chunk_size + self.right_context + self._raw_ctx = self.asr_model.encoder.get_encoder_input_raw_size( + self.window_size, self.hop_length ) - frontend = None - if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: - frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) + self._ctx = self.asr_model.encoder.get_encoder_input_size( + self.window_size + ) + - logging.info("asr_model: {}".format(asr_model)) - logging.info("asr_train_args: {}".format(asr_train_args)) - asr_model.to(dtype=getattr(torch, dtype)).eval() + #self.last_chunk_length = ( + # self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 + #) * self.hop_length - token_list = asr_model.token_list + self.last_chunk_length = ( + self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 + ) + self.reset_inference_cache() + def reset_inference_cache(self) -> None: + """Reset Speech2Text parameters.""" + self.frontend_cache = None + self.asr_model.encoder.reset_streaming_cache( + self.left_context, device=self.device + ) + self.beam_search.reset_inference_cache() - logging.info(f"Decoding device={device}, dtype={dtype}") - - # 5. [Optional] Build Text converter: e.g. 
bpe-sym -> Text - if token_type is None: - token_type = asr_train_args.token_type - if bpemodel is None: - bpemodel = asr_train_args.bpemodel - - if token_type is None: - tokenizer = None - elif token_type == "bpe": - if bpemodel is not None: - tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) - else: - tokenizer = None - else: - tokenizer = build_tokenizer(token_type=token_type) - converter = TokenIDConverter(token_list=token_list) - logging.info(f"Text tokenizer: {tokenizer}") - - # self.asr_model = asr_model - self.asr_train_args = asr_train_args - self.converter = converter - self.tokenizer = tokenizer - - self.device = device - self.dtype = dtype - self.nbest = nbest - self.frontend = frontend - - model = Paraformer_export(asr_model, onnx=False) - self.asr_model = model - - @torch.no_grad() - def __call__( - self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None - ): - """Inference + self.num_processed_frames = torch.tensor([[0]], device=self.device) + def apply_frontend( + self, speech: torch.Tensor, is_final: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward frontend. Args: - speech: Input speech data + speech: Speech data. (S) + is_final: Whether speech corresponds to the final (or only) chunk of data. Returns: - text, token, token_int, hyp + feats: Features sequence. (1, T_in, F) + feats_lengths: Features sequence length. 
(1, T_in, F) + """ + if self.frontend_cache is not None: + speech = torch.cat([self.frontend_cache["waveform_buffer"], speech], dim=0) + if is_final: + if self.streaming and speech.size(0) < self.last_chunk_length: + pad = torch.zeros( + self.last_chunk_length - speech.size(0), dtype=speech.dtype + ) + speech = torch.cat([speech, pad], dim=0) + + speech_to_process = speech + waveform_buffer = None + else: + n_frames = ( + speech.size(0) - (self.frontend_window_size - self.hop_length) + ) // self.hop_length + + n_residual = ( + speech.size(0) - (self.frontend_window_size - self.hop_length) + ) % self.hop_length + + speech_to_process = speech.narrow( + 0, + 0, + (self.frontend_window_size - self.hop_length) + + n_frames * self.hop_length, + ) + + waveform_buffer = speech.narrow( + 0, + speech.size(0) + - (self.frontend_window_size - self.hop_length) + - n_residual, + (self.frontend_window_size - self.hop_length) + n_residual, + ).clone() + + speech_to_process = speech_to_process.unsqueeze(0).to( + getattr(torch, self.dtype) + ) + lengths = speech_to_process.new_full( + [1], dtype=torch.long, fill_value=speech_to_process.size(1) + ) + batch = {"speech": speech_to_process, "speech_lengths": lengths} + batch = to_device(batch, device=self.device) + + feats, feats_lengths = self.asr_model._extract_feats(**batch) + if self.asr_model.normalize is not None: + feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths) + + if is_final: + if self.frontend_cache is None: + pass + else: + feats = feats.narrow( + 1, + math.ceil( + math.ceil(self.frontend_window_size / self.hop_length) / 2 + ), + feats.size(1) + - math.ceil( + math.ceil(self.frontend_window_size / self.hop_length) / 2 + ), + ) + else: + if self.frontend_cache is None: + feats = feats.narrow( + 1, + 0, + feats.size(1) + - math.ceil( + math.ceil(self.frontend_window_size / self.hop_length) / 2 + ), + ) + else: + feats = feats.narrow( + 1, + math.ceil( + math.ceil(self.frontend_window_size / 
self.hop_length) / 2 + ), + feats.size(1) + - 2 + * math.ceil( + math.ceil(self.frontend_window_size / self.hop_length) / 2 + ), + ) + + feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) + + if is_final: + self.frontend_cache = None + else: + self.frontend_cache = {"waveform_buffer": waveform_buffer} + + return feats, feats_lengths + + @torch.no_grad() + def streaming_decode( + self, + speech: Union[torch.Tensor, np.ndarray], + is_final: bool = True, + ) -> List[Hypothesis]: + """Speech2Text streaming call. + Args: + speech: Chunk of speech data. (S) + is_final: Whether speech corresponds to the final chunk of data. + Returns: + nbest_hypothesis: N-best hypothesis. + """ + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + if is_final: + if self.streaming and speech.size(0) < self.last_chunk_length: + pad = torch.zeros( + self.last_chunk_length - speech.size(0), speech.size(1), dtype=speech.dtype + ) + speech = torch.cat([speech, pad], dim=0) #feats, feats_length = self.apply_frontend(speech, is_final=is_final) + + feats = speech.unsqueeze(0).to(getattr(torch, self.dtype)) + feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) + + if self.asr_model.normalize is not None: + feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths) + + feats = to_device(feats, device=self.device) + feats_lengths = to_device(feats_lengths, device=self.device) + enc_out = self.asr_model.encoder.chunk_forward( + feats, + feats_lengths, + self.num_processed_frames, + chunk_size=self.chunk_size, + left_context=self.left_context, + right_context=self.right_context, + ) + nbest_hyps = self.beam_search(enc_out[0], is_final=is_final) + + self.num_processed_frames += self.chunk_size + + if is_final: + self.reset_inference_cache() + + return nbest_hyps + + @torch.no_grad() + def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[Hypothesis]: + """Speech2Text call. 
+ Args: + speech: Speech data. (S) + Returns: + nbest_hypothesis: N-best hypothesis. """ assert check_argument_types() - # Input as audio signal if isinstance(speech, np.ndarray): speech = torch.tensor(speech) - - if self.frontend is not None: - feats, feats_len = self.frontend.forward(speech, speech_lengths) - feats = to_device(feats, device=self.device) - feats_len = feats_len.int() - self.asr_model.frontend = None - else: - feats = speech - feats_len = speech_lengths - - enc_len_batch_total = feats_len.sum() - lfr_factor = max(1, (feats.size()[-1] // 80) - 1) - batch = {"speech": feats, "speech_lengths": feats_len} - - # a. To device - batch = to_device(batch, device=self.device) - - decoder_outs = self.asr_model(**batch) - decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] + feats = speech.unsqueeze(0).to(getattr(torch, self.dtype)) + feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) + + if self.asr_model.normalize is not None: + feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths) + + feats = to_device(feats, device=self.device) + feats_lengths = to_device(feats_lengths, device=self.device) + enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context, self.right_context) + nbest_hyps = self.beam_search(enc_out[0]) + + return nbest_hyps + + @torch.no_grad() + def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[Hypothesis]: + """Speech2Text call. + Args: + speech: Speech data. (S) + Returns: + nbest_hypothesis: N-best hypothesis. 
+ """ + assert check_argument_types() + + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + + # lengths: (1,) + # feats, feats_length = self.apply_frontend(speech) + feats = speech.unsqueeze(0).to(getattr(torch, self.dtype)) + # lengths: (1,) + feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) + + # print(feats.shape) + # print(feats_lengths) + if self.asr_model.normalize is not None: + feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths) + + feats = to_device(feats, device=self.device) + feats_lengths = to_device(feats_lengths, device=self.device) + + enc_out, _ = self.asr_model.encoder(feats, feats_lengths) + + nbest_hyps = self.beam_search(enc_out[0]) + + return nbest_hyps + + def hypotheses_to_results(self, nbest_hyps: List[Hypothesis]) -> List[Any]: + """Build partial or final results from the hypotheses. + Args: + nbest_hyps: N-best hypothesis. + Returns: + results: Results containing different representation for the hypothesis. 
+ """ results = [] - b, n, d = decoder_out.size() - for i in range(b): - am_scores = decoder_out[i, :ys_pad_lens[i], :] - yseq = am_scores.argmax(dim=-1) - score = am_scores.max(dim=-1)[0] - score = torch.sum(score, dim=-1) - # pad with mask tokens to ensure compatibility with sos/eos tokens - yseq = torch.tensor( - yseq.tolist(), device=yseq.device - ) - nbest_hyps = [Hypothesis(yseq=yseq, score=score)] + for hyp in nbest_hyps: + token_int = list(filter(lambda x: x != 0, hyp.yseq)) - for hyp in nbest_hyps: - assert isinstance(hyp, (Hypothesis)), type(hyp) + token = self.converter.ids2tokens(token_int) - # remove sos/eos and get results - last_pos = -1 - if isinstance(hyp.yseq, list): - token_int = hyp.yseq[1:last_pos] - else: - token_int = hyp.yseq[1:last_pos].tolist() + if self.tokenizer is not None: + text = self.tokenizer.tokens2text(token) + else: + text = None + results.append((text, token, token_int, hyp)) - # remove blank symbol id, which is assumed to be 0 - token_int = list(filter(lambda x: x != 0 and x != 2, token_int)) - - # Change integer-ids to tokens - token = self.converter.ids2tokens(token_int) - - if self.tokenizer is not None: - text = self.tokenizer.tokens2text(token) - else: - text = None - - results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor)) + assert check_return_type(results) return results + @staticmethod + def from_pretrained( + model_tag: Optional[str] = None, + **kwargs: Optional[Any], + ) -> Speech2Text: + """Build Speech2Text instance from the pretrained model. + Args: + model_tag: Model tag of the pretrained models. + Return: + : Speech2Text instance. + """ + if model_tag is not None: + try: + from espnet_model_zoo.downloader import ModelDownloader + + except ImportError: + logging.error( + "`espnet_model_zoo` is not installed. " + "Please install via `pip install -U espnet_model_zoo`." 
+ ) + raise + d = ModelDownloader() + kwargs.update(**d.download_and_unpack(model_tag)) + + return Speech2Text(**kwargs) + def inference( - maxlenratio: float, - minlenratio: float, - batch_size: int, - beam_size: int, - ngpu: int, - ctc_weight: float, - lm_weight: float, - penalty: float, - log_level: Union[int, str], - data_path_and_name_and_type, - asr_train_config: Optional[str], - asr_model_file: Optional[str], - cmvn_file: Optional[str] = None, - raw_inputs: Union[np.ndarray, torch.Tensor] = None, - lm_train_config: Optional[str] = None, - lm_file: Optional[str] = None, - token_type: Optional[str] = None, - key_file: Optional[str] = None, - word_lm_train_config: Optional[str] = None, - bpemodel: Optional[str] = None, - allow_variable_data_keys: bool = False, - streaming: bool = False, - output_dir: Optional[str] = None, - dtype: str = "float32", - seed: int = 0, - ngram_weight: float = 0.9, - nbest: int = 1, - num_workers: int = 1, - - **kwargs, -): - inference_pipeline = inference_modelscope( - maxlenratio=maxlenratio, - minlenratio=minlenratio, - batch_size=batch_size, - beam_size=beam_size, - ngpu=ngpu, - ctc_weight=ctc_weight, - lm_weight=lm_weight, - penalty=penalty, - log_level=log_level, - asr_train_config=asr_train_config, - asr_model_file=asr_model_file, - cmvn_file=cmvn_file, - raw_inputs=raw_inputs, - lm_train_config=lm_train_config, - lm_file=lm_file, - token_type=token_type, - key_file=key_file, - word_lm_train_config=word_lm_train_config, - bpemodel=bpemodel, - allow_variable_data_keys=allow_variable_data_keys, - streaming=streaming, - output_dir=output_dir, - dtype=dtype, - seed=seed, - ngram_weight=ngram_weight, - nbest=nbest, - num_workers=num_workers, - - **kwargs, - ) - return inference_pipeline(data_path_and_name_and_type, raw_inputs) - - -def inference_modelscope( - maxlenratio: float, - minlenratio: float, - batch_size: int, - beam_size: int, - ngpu: int, - ctc_weight: float, - lm_weight: float, - penalty: float, - log_level: Union[int, 
str], - # data_path_and_name_and_type, - asr_train_config: Optional[str], - asr_model_file: Optional[str], - cmvn_file: Optional[str] = None, - lm_train_config: Optional[str] = None, - lm_file: Optional[str] = None, - token_type: Optional[str] = None, - key_file: Optional[str] = None, - word_lm_train_config: Optional[str] = None, - bpemodel: Optional[str] = None, - allow_variable_data_keys: bool = False, - dtype: str = "float32", - seed: int = 0, - ngram_weight: float = 0.9, - nbest: int = 1, - num_workers: int = 1, - output_dir: Optional[str] = None, - param_dict: dict = None, - **kwargs, -): + output_dir: str, + batch_size: int, + dtype: str, + beam_size: int, + ngpu: int, + seed: int, + lm_weight: float, + nbest: int, + num_workers: int, + log_level: Union[int, str], + data_path_and_name_and_type: Sequence[Tuple[str, str, str]], + asr_train_config: Optional[str], + asr_model_file: Optional[str], + beam_search_config: Optional[dict], + lm_train_config: Optional[str], + lm_file: Optional[str], + model_tag: Optional[str], + token_type: Optional[str], + bpemodel: Optional[str], + key_file: Optional[str], + allow_variable_data_keys: bool, + quantize_asr_model: Optional[bool], + quantize_modules: Optional[List[str]], + quantize_dtype: Optional[str], + streaming: Optional[bool], + simu_streaming: Optional[bool], + chunk_size: Optional[int], + left_context: Optional[int], + right_context: Optional[int], + display_partial_hypotheses: bool, + **kwargs, +) -> None: + """Transducer model inference. + Args: + output_dir: Output directory path. + batch_size: Batch decoding size. + dtype: Data type. + beam_size: Beam size. + ngpu: Number of GPUs. + seed: Random number generator seed. + lm_weight: Weight of language model. + nbest: Number of final hypothesis. + num_workers: Number of workers. + log_level: Level of verbose for logs. + data_path_and_name_and_type: + asr_train_config: ASR model training config path. + asr_model_file: ASR model path. 
+ beam_search_config: Beam search config path. + lm_train_config: Language Model training config path. + lm_file: Language Model path. + model_tag: Model tag. + token_type: Type of token units. + bpemodel: BPE model path. + key_file: File key. + allow_variable_data_keys: Whether to allow variable data keys. + quantize_asr_model: Whether to apply dynamic quantization to ASR model. + quantize_modules: List of module names to apply dynamic quantization on. + quantize_dtype: Dynamic quantization data type. + streaming: Whether to perform chunk-by-chunk inference. + chunk_size: Number of frames in chunk AFTER subsampling. + left_context: Number of frames in left context AFTER subsampling. + right_context: Number of frames in right context AFTER subsampling. + display_partial_hypotheses: Whether to display partial hypotheses. + """ assert check_argument_types() - if word_lm_train_config is not None: - raise NotImplementedError("Word LM is not implemented") + if batch_size > 1: + raise NotImplementedError("batch decoding is not implemented") if ngpu > 1: raise NotImplementedError("only single GPU decoding is supported") @@ -605,19 +557,11 @@ def inference_modelscope( level=log_level, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) - - export_mode = False - if param_dict is not None: - hotword_list_or_file = param_dict.get('hotword') - export_mode = param_dict.get("export_mode", False) - else: - hotword_list_or_file = None - if ngpu >= 1 and torch.cuda.is_available(): + if ngpu >= 1: device = "cuda" else: device = "cpu" - batch_size = 1 # 1. 
Set random-seed set_all_random_seed(seed) @@ -626,144 +570,105 @@ def inference_modelscope( speech2text_kwargs = dict( asr_train_config=asr_train_config, asr_model_file=asr_model_file, - cmvn_file=cmvn_file, + beam_search_config=beam_search_config, lm_train_config=lm_train_config, lm_file=lm_file, token_type=token_type, bpemodel=bpemodel, device=device, - maxlenratio=maxlenratio, - minlenratio=minlenratio, dtype=dtype, beam_size=beam_size, - ctc_weight=ctc_weight, lm_weight=lm_weight, - ngram_weight=ngram_weight, - penalty=penalty, nbest=nbest, - hotword_list_or_file=hotword_list_or_file, + quantize_asr_model=quantize_asr_model, + quantize_modules=quantize_modules, + quantize_dtype=quantize_dtype, + streaming=streaming, + simu_streaming=simu_streaming, + chunk_size=chunk_size, + left_context=left_context, + right_context=right_context, + ) + speech2text = Speech2Text.from_pretrained( + model_tag=model_tag, + **speech2text_kwargs, ) - if export_mode: - speech2text = Speech2TextExport(**speech2text_kwargs) - else: - speech2text = Speech2Text(**speech2text_kwargs) - def _forward( - data_path_and_name_and_type, - raw_inputs: Union[np.ndarray, torch.Tensor] = None, - output_dir_v2: Optional[str] = None, - fs: dict = None, - param_dict: dict = None, - **kwargs, - ): - - hotword_list_or_file = None - if param_dict is not None: - hotword_list_or_file = param_dict.get('hotword') - if 'hotword' in kwargs: - hotword_list_or_file = kwargs['hotword'] - if hotword_list_or_file is not None or 'hotword' in kwargs: - speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file) - cache = None - if 'cache' in param_dict: - cache = param_dict['cache'] - # 3. 
Build data-iterator - if data_path_and_name_and_type is None and raw_inputs is not None: - if isinstance(raw_inputs, torch.Tensor): - raw_inputs = raw_inputs.numpy() - data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] - loader = ASRTask.build_streaming_iterator( - data_path_and_name_and_type, - dtype=dtype, - fs=fs, - batch_size=batch_size, - key_file=key_file, - num_workers=num_workers, - preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), - collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), - allow_variable_data_keys=allow_variable_data_keys, - inference=True, - ) - - forward_time_total = 0.0 - length_total = 0.0 - finish_count = 0 - file_count = 1 - # 7 .Start for-loop - # FIXME(kamo): The output format should be discussed about - asr_result_list = [] - output_path = output_dir_v2 if output_dir_v2 is not None else output_dir - if output_path is not None: - writer = DatadirWriter(output_path) - else: - writer = None + # 3. Build data-iterator + loader = ASRTransducerTask.build_streaming_iterator( + data_path_and_name_and_type, + dtype=dtype, + batch_size=batch_size, + key_file=key_file, + num_workers=num_workers, + preprocess_fn=ASRTransducerTask.build_preprocess_fn( + speech2text.asr_train_args, False + ), + collate_fn=ASRTransducerTask.build_collate_fn( + speech2text.asr_train_args, False + ), + allow_variable_data_keys=allow_variable_data_keys, + inference=True, + ) + # 4 .Start for-loop + with DatadirWriter(output_dir) as writer: for keys, batch in loader: assert isinstance(batch, dict), type(batch) assert all(isinstance(s, str) for s in keys), keys + _bs = len(next(iter(batch.values()))) assert len(keys) == _bs, f"{len(keys)} != {_bs}" - # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")} + batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} + assert len(batch.keys()) == 1 - logging.info("decoding, utt_id: {}".format(keys)) - # N-best list of (text, 
token, token_int, hyp_object) + try: + if speech2text.streaming: + speech = batch["speech"] - time_beg = time.time() - results = speech2text(cache=cache, **batch) - if len(results) < 1: - hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) - results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest - time_end = time.time() - forward_time = time_end - time_beg - lfr_factor = results[0][-1] - length = results[0][-2] - forward_time_total += forward_time - length_total += length - rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time, 100 * forward_time / (length * lfr_factor)) - logging.info(rtf_cur) + _steps = len(speech) // speech2text._ctx + _end = 0 + for i in range(_steps): + _end = (i + 1) * speech2text._ctx - for batch_id in range(_bs): - result = [results[batch_id][:-2]] + speech2text.streaming_decode( + speech[i * speech2text._ctx : _end], is_final=False + ) - key = keys[batch_id] - for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), result): - # Create a directory: outdir/{n}best_recog - if writer is not None: - ibest_writer = writer[f"{n}best_recog"] + final_hyps = speech2text.streaming_decode( + speech[_end : len(speech)], is_final=True + ) + elif speech2text.simu_streaming: + final_hyps = speech2text.simu_streaming_decode(**batch) + else: + final_hyps = speech2text(**batch) - # Write the result to each file - ibest_writer["token"][key] = " ".join(token) - # ibest_writer["token_int"][key] = " ".join(map(str, token_int)) - ibest_writer["score"][key] = str(hyp.score) - ibest_writer["rtf"][key] = rtf_cur + results = speech2text.hypotheses_to_results(final_hyps) + except TooShortUttError as e: + logging.warning(f"Utterance {keys} {e}") + hyp = Hypothesis(score=0.0, yseq=[], dec_state=None) + results = [[" ", [""], [2], hyp]] * nbest - if text is not None: - text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) - item = {'key': key, 'value': text_postprocessed} - 
asr_result_list.append(item) - finish_count += 1 - # asr_utils.print_progress(finish_count / file_count) - if writer is not None: - ibest_writer["text"][key] = text_postprocessed + key = keys[0] + for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results): + ibest_writer = writer[f"{n}best_recog"] - logging.info("decoding, utt: {}, predictions: {}".format(key, text)) - rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)) - logging.info(rtf_avg) - if writer is not None: - ibest_writer["rtf"]["rtf_avf"] = rtf_avg - return asr_result_list + ibest_writer["token"][key] = " ".join(token) + ibest_writer["token_int"][key] = " ".join(map(str, token_int)) + ibest_writer["score"][key] = str(hyp.score) - return _forward + if text is not None: + ibest_writer["text"][key] = text def get_parser(): + """Get Transducer model inference parser.""" + parser = config_argparse.ArgumentParser( - description="ASR Decoding", + description="ASR Transducer Decoding", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - # Note(kamo): Use '_' instead of '-' as separator. - # '-' is confusing if written in yaml. 
parser.add_argument( "--log_level", type=lambda x: x.upper(), @@ -792,17 +697,12 @@ def get_parser(): default=1, help="The number of workers used for DataLoader", ) - parser.add_argument( - "--hotword", - type=str_or_none, - default=None, - help="hotword file path or hotwords seperated by space" - ) + group = parser.add_argument_group("Input data related") group.add_argument( "--data_path_and_name_and_type", type=str2triple_str, - required=False, + required=True, action="append", ) group.add_argument("--key_file", type=str_or_none) @@ -819,11 +719,6 @@ def get_parser(): type=str, help="ASR model parameter file", ) - group.add_argument( - "--cmvn_file", - type=str, - help="Global cmvn file", - ) group.add_argument( "--lm_train_config", type=str, @@ -834,26 +729,11 @@ def get_parser(): type=str, help="LM parameter file", ) - group.add_argument( - "--word_lm_train_config", - type=str, - help="Word LM training configuration", - ) - group.add_argument( - "--word_lm_file", - type=str, - help="Word LM parameter file", - ) - group.add_argument( - "--ngram_file", - type=str, - help="N-gram parameter file", - ) group.add_argument( "--model_tag", type=str, help="Pretrained model tag. If specify this option, *_train_config and " - "*_file will be overwritten", + "*_file will be overwritten", ) group = parser.add_argument_group("Beam-search related") @@ -864,42 +744,13 @@ def get_parser(): help="The batch size for inference", ) group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") - group.add_argument("--beam_size", type=int, default=20, help="Beam size") - group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty") - group.add_argument( - "--maxlenratio", - type=float, - default=0.0, - help="Input length ratio to obtain max output length. " - "If maxlenratio=0.0 (default), it uses a end-detect " - "function " - "to automatically find maximum hypothesis lengths." 
- "If maxlenratio<0.0, its absolute value is interpreted" - "as a constant max output length", - ) - group.add_argument( - "--minlenratio", - type=float, - default=0.0, - help="Input length ratio to obtain min output length", - ) - group.add_argument( - "--ctc_weight", - type=float, - default=0.5, - help="CTC weight in joint decoding", - ) + group.add_argument("--beam_size", type=int, default=5, help="Beam size") group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") - group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") - group.add_argument("--streaming", type=str2bool, default=False) - group.add_argument( - "--frontend_conf", - default=None, - help="", + "--beam_search_config", + default={}, + help="The keyword arguments for transducer beam search.", ) - group.add_argument("--raw_inputs", type=list, default=None) - # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}]) group = parser.add_argument_group("Text converter related") group.add_argument( @@ -908,14 +759,77 @@ def get_parser(): default=None, choices=["char", "bpe", None], help="The token type for ASR model. " - "If not given, refers from the training args", + "If not given, refers from the training args", ) group.add_argument( "--bpemodel", type=str_or_none, default=None, help="The model path of sentencepiece. " - "If not given, refers from the training args", + "If not given, refers from the training args", + ) + + group = parser.add_argument_group("Dynamic quantization related") + parser.add_argument( + "--quantize_asr_model", + type=bool, + default=False, + help="Apply dynamic quantization to ASR model.", + ) + parser.add_argument( + "--quantize_modules", + nargs="*", + default=None, + help="""Module names to apply dynamic quantization on. 
+ The module names are provided as a list, where each name is separated + by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]). + Each specified name should be an attribute of 'torch.nn', e.g.: + torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""", + ) + parser.add_argument( + "--quantize_dtype", + type=str, + default="qint8", + choices=["float16", "qint8"], + help="Dtype for dynamic quantization.", + ) + + group = parser.add_argument_group("Streaming related") + parser.add_argument( + "--streaming", + type=bool, + default=False, + help="Whether to perform chunk-by-chunk inference.", + ) + parser.add_argument( + "--simu_streaming", + type=bool, + default=False, + help="Whether to simulate chunk-by-chunk inference.", + ) + parser.add_argument( + "--chunk_size", + type=int, + default=16, + help="Number of frames in chunk AFTER subsampling.", + ) + parser.add_argument( + "--left_context", + type=int, + default=32, + help="Number of frames in left context of the chunk AFTER subsampling.", + ) + parser.add_argument( + "--right_context", + type=int, + default=0, + help="Number of frames in right context of the chunk AFTER subsampling.", + ) + parser.add_argument( + "--display_partial_hypotheses", + type=bool, + default=False, + help="Whether to display partial hypotheses during chunk-by-chunk inference.", ) return parser @@ -923,24 +837,15 @@ def get_parser(): def main(cmd=None): print(get_commandline_args(), file=sys.stderr) + parser = get_parser() args = parser.parse_args(cmd) - param_dict = {'hotword': args.hotword} kwargs = vars(args) + kwargs.pop("config", None) - kwargs['param_dict'] = param_dict inference(**kwargs) if __name__ == "__main__": main() - # from modelscope.pipelines import pipeline - # from modelscope.utils.constant import Tasks - # - # inference_16k_pipline = pipeline( - # task=Tasks.auto_speech_recognition, - # model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') - # - # rec_result = 
inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav') - # print(rec_result) diff --git a/funasr/bin/asr_train_transducer.py b/funasr/bin/asr_train_transducer.py new file mode 100755 index 000000000..9b6d287dd --- /dev/null +++ b/funasr/bin/asr_train_transducer.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +import os + +from funasr.tasks.asr_transducer import ASRTransducerTask + + +# for ASR Training +def parse_args(): + parser = ASRTransducerTask.get_parser() + parser.add_argument( + "--gpu_id", + type=int, + default=0, + help="local gpu id.", + ) + args = parser.parse_args() + return args + + +def main(args=None, cmd=None): + # for ASR Training + ASRTransducerTask.main(args=args, cmd=cmd) + + +if __name__ == '__main__': + args = parse_args() + + # setup local gpu_id + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) + + # DDP settings + if args.ngpu > 1: + args.distributed = True + else: + args.distributed = False + assert args.num_worker_count == 1 + + # re-compute batch size: when dataset type is small + if args.dataset_type == "small": + if args.batch_size is not None: + args.batch_size = args.batch_size * args.ngpu + if args.batch_bins is not None: + args.batch_bins = args.batch_bins * args.ngpu + + main(args=args) diff --git a/funasr/models_transducer/__init__.py b/funasr/models_transducer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/models_transducer/activation.py b/funasr/models_transducer/activation.py new file mode 100644 index 000000000..82cda1251 --- /dev/null +++ b/funasr/models_transducer/activation.py @@ -0,0 +1,213 @@ +"""Activation functions for Transducer.""" + +import torch +from packaging.version import parse as V + + +def get_activation( + activation_type: str, + ftswish_threshold: float = -0.2, + ftswish_mean_shift: float = 0.0, + hardtanh_min_val: int = -1.0, + hardtanh_max_val: int = 1.0, + leakyrelu_neg_slope: float = 0.01, + 
smish_alpha: float = 1.0, + smish_beta: float = 1.0, + softplus_beta: float = 1.0, + softplus_threshold: int = 20, + swish_beta: float = 1.0, +) -> torch.nn.Module: + """Return activation function. + + Args: + activation_type: Activation function type. + ftswish_threshold: Threshold value for FTSwish activation formulation. + ftswish_mean_shift: Mean shifting value for FTSwish activation formulation. + hardtanh_min_val: Minimum value of the linear region range for HardTanh. + hardtanh_max_val: Maximum value of the linear region range for HardTanh. + leakyrelu_neg_slope: Negative slope value for LeakyReLU activation formulation. + smish_alpha: Alpha value for Smish activation fomulation. + smish_beta: Beta value for Smish activation formulation. + softplus_beta: Beta value for softplus activation formulation in Mish. + softplus_threshold: Values above this revert to a linear function in Mish. + swish_beta: Beta value for Swish variant formulation. + + Returns: + : Activation function. + + """ + torch_version = V(torch.__version__) + + activations = { + "ftswish": ( + FTSwish, + {"threshold": ftswish_threshold, "mean_shift": ftswish_mean_shift}, + ), + "hardtanh": ( + torch.nn.Hardtanh, + {"min_val": hardtanh_min_val, "max_val": hardtanh_max_val}, + ), + "leaky_relu": (torch.nn.LeakyReLU, {"negative_slope": leakyrelu_neg_slope}), + "mish": ( + Mish, + { + "softplus_beta": softplus_beta, + "softplus_threshold": softplus_threshold, + "use_builtin": torch_version >= V("1.9"), + }, + ), + "relu": (torch.nn.ReLU, {}), + "selu": (torch.nn.SELU, {}), + "smish": (Smish, {"alpha": smish_alpha, "beta": smish_beta}), + "swish": ( + Swish, + {"beta": swish_beta, "use_builtin": torch_version >= V("1.8")}, + ), + "tanh": (torch.nn.Tanh, {}), + "identity": (torch.nn.Identity, {}), + } + + act_func, act_args = activations[activation_type] + + return act_func(**act_args) + + +class FTSwish(torch.nn.Module): + """Flatten-T Swish activation definition. 
+ + FTSwish(x) = x * sigmoid(x) + threshold + where FTSwish(x) < 0 = threshold + + Reference: https://arxiv.org/abs/1812.06247 + + Args: + threshold: Threshold value for FTSwish activation formulation. (threshold < 0) + mean_shift: Mean shifting value for FTSwish activation formulation. + (applied only if != 0, disabled by default) + + """ + + def __init__(self, threshold: float = -0.2, mean_shift: float = 0) -> None: + super().__init__() + + assert threshold < 0, "FTSwish threshold parameter should be < 0." + + self.threshold = threshold + self.mean_shift = mean_shift + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward computation.""" + x = (x * torch.sigmoid(x)) + self.threshold + x = torch.where(x >= 0, x, torch.tensor([self.threshold], device=x.device)) + + if self.mean_shift != 0: + x.sub_(self.mean_shift) + + return x + + +class Mish(torch.nn.Module): + """Mish activation definition. + + Mish(x) = x * tanh(softplus(x)) + + Reference: https://arxiv.org/abs/1908.08681. + + Args: + softplus_beta: Beta value for softplus activation formulation. + (Usually 0 > softplus_beta >= 2) + softplus_threshold: Values above this revert to a linear function. + (Usually 10 > softplus_threshold >= 20) + use_builtin: Whether to use PyTorch activation function if available. + + """ + + def __init__( + self, + softplus_beta: float = 1.0, + softplus_threshold: int = 20, + use_builtin: bool = False, + ) -> None: + super().__init__() + + if use_builtin: + self.mish = torch.nn.Mish() + else: + self.tanh = torch.nn.Tanh() + self.softplus = torch.nn.Softplus( + beta=softplus_beta, threshold=softplus_threshold + ) + + self.mish = lambda x: x * self.tanh(self.softplus(x)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward computation.""" + return self.mish(x) + + +class Smish(torch.nn.Module): + """Smish activation definition. 
+ + Smish(x) = (alpha * x) * tanh(log(1 + sigmoid(beta * x))) + where alpha > 0 and beta > 0 + + Reference: https://www.mdpi.com/2079-9292/11/4/540/htm. + + Args: + alpha: Alpha value for Smish activation fomulation. + (Usually, alpha = 1. If alpha <= 0, set value to 1). + beta: Beta value for Smish activation formulation. + (Usually, beta = 1. If beta <= 0, set value to 1). + + """ + + def __init__(self, alpha: float = 1.0, beta: float = 1.0) -> None: + super().__init__() + + self.tanh = torch.nn.Tanh() + + self.alpha = alpha if alpha > 0 else 1 + self.beta = beta if beta > 0 else 1 + + self.smish = lambda x: (self.alpha * x) * self.tanh( + torch.log(1 + torch.sigmoid((self.beta * x))) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward computation.""" + return self.smish(x) + + +class Swish(torch.nn.Module): + """Swish activation definition. + + Swish(x) = (beta * x) * sigmoid(x) + where beta = 1 defines standard Swish activation. + + References: + https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1. + E-swish variant: https://arxiv.org/abs/1801.07145. + + Args: + beta: Beta parameter for E-Swish. + (beta >= 1. If beta < 1, use standard Swish). + use_builtin: Whether to use PyTorch function if available. 
+ + """ + + def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None: + super().__init__() + + self.beta = beta + + if beta > 1: + self.swish = lambda x: (self.beta * x) * torch.sigmoid(x) + else: + if use_builtin: + self.swish = torch.nn.SiLU() + else: + self.swish = lambda x: x * torch.sigmoid(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward computation.""" + return self.swish(x) diff --git a/funasr/models_transducer/beam_search_transducer.py b/funasr/models_transducer/beam_search_transducer.py new file mode 100644 index 000000000..8e234e45a --- /dev/null +++ b/funasr/models_transducer/beam_search_transducer.py @@ -0,0 +1,705 @@ +"""Search algorithms for Transducer models.""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + +from funasr.models_transducer.decoder.abs_decoder import AbsDecoder +from funasr.models_transducer.joint_network import JointNetwork + + +@dataclass +class Hypothesis: + """Default hypothesis definition for Transducer search algorithms. + + Args: + score: Total log-probability. + yseq: Label sequence as integer ID sequence. + dec_state: RNNDecoder or StatelessDecoder state. + ((N, 1, D_dec), (N, 1, D_dec) or None) or None + lm_state: RNNLM state. ((N, D_lm), (N, D_lm)) or None + + """ + + score: float + yseq: List[int] + dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None + lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None + + +@dataclass +class ExtendedHypothesis(Hypothesis): + """Extended hypothesis definition for NSC beam search and mAES. + + Args: + : Hypothesis dataclass arguments. + dec_out: Decoder output sequence. (B, D_dec) + lm_score: Log-probabilities of the LM for given label. (vocab_size) + + """ + + dec_out: torch.Tensor = None + lm_score: torch.Tensor = None + + +class BeamSearchTransducer: + """Beam search implementation for Transducer. + + Args: + decoder: Decoder module. 
+ joint_network: Joint network module. + beam_size: Size of the beam. + lm: LM class. + lm_weight: LM weight for soft fusion. + search_type: Search algorithm to use during inference. + max_sym_exp: Number of maximum symbol expansions at each time step. (TSD) + u_max: Maximum expected target sequence length. (ALSD) + nstep: Number of maximum expansion steps at each time step. (mAES) + expansion_gamma: Allowed logp difference for prune-by-value method. (mAES) + expansion_beta: + Number of additional candidates for expanded hypotheses selection. (mAES) + score_norm: Normalize final scores by length. + nbest: Number of final hypothesis. + streaming: Whether to perform chunk-by-chunk beam search. + + """ + + def __init__( + self, + decoder: AbsDecoder, + joint_network: JointNetwork, + beam_size: int, + lm: Optional[torch.nn.Module] = None, + lm_weight: float = 0.1, + search_type: str = "default", + max_sym_exp: int = 3, + u_max: int = 50, + nstep: int = 2, + expansion_gamma: float = 2.3, + expansion_beta: int = 2, + score_norm: bool = False, + nbest: int = 1, + streaming: bool = False, + ) -> None: + """Construct a BeamSearchTransducer object.""" + super().__init__() + + self.decoder = decoder + self.joint_network = joint_network + + self.vocab_size = decoder.vocab_size + + assert beam_size <= self.vocab_size, ( + "beam_size (%d) should be smaller than or equal to vocabulary size (%d)." + % ( + beam_size, + self.vocab_size, + ) + ) + self.beam_size = beam_size + + if search_type == "default": + self.search_algorithm = self.default_beam_search + elif search_type == "tsd": + assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % ( + max_sym_exp + ) + self.max_sym_exp = max_sym_exp + + self.search_algorithm = self.time_sync_decoding + elif search_type == "alsd": + assert not streaming, "ALSD is not available in streaming mode." + + assert u_max >= 0, "u_max should be a positive integer, a portion of max_T." 
+ self.u_max = u_max + + self.search_algorithm = self.align_length_sync_decoding + elif search_type == "maes": + assert self.vocab_size >= beam_size + expansion_beta, ( + "beam_size (%d) + expansion_beta (%d) " + " should be smaller than or equal to vocab size (%d)." + % (beam_size, expansion_beta, self.vocab_size) + ) + self.max_candidates = beam_size + expansion_beta + + self.nstep = nstep + self.expansion_gamma = expansion_gamma + + self.search_algorithm = self.modified_adaptive_expansion_search + else: + raise NotImplementedError( + "Specified search type (%s) is not supported." % search_type + ) + + self.use_lm = lm is not None + + if self.use_lm: + assert hasattr(lm, "rnn_type"), "Transformer LM is currently not supported." + + self.sos = self.vocab_size - 1 + + self.lm = lm + self.lm_weight = lm_weight + + self.score_norm = score_norm + self.nbest = nbest + + self.reset_inference_cache() + + def __call__( + self, + enc_out: torch.Tensor, + is_final: bool = True, + ) -> List[Hypothesis]: + """Perform beam search. + + Args: + enc_out: Encoder output sequence. (T, D_enc) + is_final: Whether enc_out is the final chunk of data. + + Returns: + nbest_hyps: N-best decoding results + + """ + self.decoder.set_device(enc_out.device) + + hyps = self.search_algorithm(enc_out) + + if is_final: + self.reset_inference_cache() + + return self.sort_nbest(hyps) + + self.search_cache = hyps + + return hyps + + def reset_inference_cache(self) -> None: + """Reset cache for decoder scoring and streaming.""" + self.decoder.score_cache = {} + self.search_cache = None + + def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]: + """Sort in-place hypotheses by score or score given sequence length. + + Args: + hyps: Hypothesis. + + Return: + hyps: Sorted hypothesis. 
+ + """ + if self.score_norm: + hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True) + else: + hyps.sort(key=lambda x: x.score, reverse=True) + + return hyps[: self.nbest] + + def recombine_hyps(self, hyps: List[Hypothesis]) -> List[Hypothesis]: + """Recombine hypotheses with same label ID sequence. + + Args: + hyps: Hypotheses. + + Returns: + final: Recombined hypotheses. + + """ + final = {} + + for hyp in hyps: + str_yseq = "_".join(map(str, hyp.yseq)) + + if str_yseq in final: + final[str_yseq].score = np.logaddexp(final[str_yseq].score, hyp.score) + else: + final[str_yseq] = hyp + + return [*final.values()] + + def select_k_expansions( + self, + hyps: List[ExtendedHypothesis], + topk_idx: torch.Tensor, + topk_logp: torch.Tensor, + ) -> List[ExtendedHypothesis]: + """Return K hypotheses candidates for expansion from a list of hypothesis. + + K candidates are selected according to the extended hypotheses probabilities + and a prune-by-value method. Where K is equal to beam_size + beta. + + Args: + hyps: Hypotheses. + topk_idx: Indices of candidates hypothesis. + topk_logp: Log-probabilities of candidates hypothesis. + + Returns: + k_expansions: Best K expansion hypotheses candidates. + + """ + k_expansions = [] + + for i, hyp in enumerate(hyps): + hyp_i = [ + (int(k), hyp.score + float(v)) + for k, v in zip(topk_idx[i], topk_logp[i]) + ] + k_best_exp = max(hyp_i, key=lambda x: x[1])[1] + + k_expansions.append( + sorted( + filter( + lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i + ), + key=lambda x: x[1], + reverse=True, + ) + ) + + return k_expansions + + def create_lm_batch_inputs(self, hyps_seq: List[List[int]]) -> torch.Tensor: + """Make batch of inputs with left padding for LM scoring. + + Args: + hyps_seq: Hypothesis sequences. + + Returns: + : Padded batch of sequences. 
+ + """ + max_len = max([len(h) for h in hyps_seq]) + + return torch.LongTensor( + [[self.sos] + ([0] * (max_len - len(h))) + h[1:] for h in hyps_seq], + device=self.decoder.device, + ) + + def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]: + """Beam search implementation without prefix search. + + Modified from https://arxiv.org/pdf/1211.3711.pdf + + Args: + enc_out: Encoder output sequence. (T, D) + + Returns: + nbest_hyps: N-best hypothesis. + + """ + beam_k = min(self.beam_size, (self.vocab_size - 1)) + max_t = len(enc_out) + + if self.search_cache is not None: + kept_hyps = self.search_cache + else: + kept_hyps = [ + Hypothesis( + score=0.0, + yseq=[0], + dec_state=self.decoder.init_state(1), + ) + ] + + for t in range(max_t): + hyps = kept_hyps + kept_hyps = [] + + while True: + max_hyp = max(hyps, key=lambda x: x.score) + hyps.remove(max_hyp) + + label = torch.full( + (1, 1), + max_hyp.yseq[-1], + dtype=torch.long, + device=self.decoder.device, + ) + dec_out, state = self.decoder.score( + label, + max_hyp.yseq, + max_hyp.dec_state, + ) + + logp = torch.log_softmax( + self.joint_network(enc_out[t : t + 1, :], dec_out), + dim=-1, + ).squeeze(0) + top_k = logp[1:].topk(beam_k, dim=-1) + + kept_hyps.append( + Hypothesis( + score=(max_hyp.score + float(logp[0:1])), + yseq=max_hyp.yseq, + dec_state=max_hyp.dec_state, + lm_state=max_hyp.lm_state, + ) + ) + + if self.use_lm: + lm_scores, lm_state = self.lm.score( + torch.LongTensor( + [self.sos] + max_hyp.yseq[1:], device=self.decoder.device + ), + max_hyp.lm_state, + None, + ) + else: + lm_state = max_hyp.lm_state + + for logp, k in zip(*top_k): + score = max_hyp.score + float(logp) + + if self.use_lm: + score += self.lm_weight * lm_scores[k + 1] + + hyps.append( + Hypothesis( + score=score, + yseq=max_hyp.yseq + [int(k + 1)], + dec_state=state, + lm_state=lm_state, + ) + ) + + hyps_max = float(max(hyps, key=lambda x: x.score).score) + kept_most_prob = sorted( + [hyp for hyp in kept_hyps if 
hyp.score > hyps_max], + key=lambda x: x.score, + ) + if len(kept_most_prob) >= self.beam_size: + kept_hyps = kept_most_prob + break + + return kept_hyps + + def align_length_sync_decoding( + self, + enc_out: torch.Tensor, + ) -> List[Hypothesis]: + """Alignment-length synchronous beam search implementation. + + Based on https://ieeexplore.ieee.org/document/9053040 + + Args: + h: Encoder output sequences. (T, D) + + Returns: + nbest_hyps: N-best hypothesis. + + """ + t_max = int(enc_out.size(0)) + u_max = min(self.u_max, (t_max - 1)) + + B = [Hypothesis(yseq=[0], score=0.0, dec_state=self.decoder.init_state(1))] + final = [] + + if self.use_lm: + B[0].lm_state = self.lm.zero_state() + + for i in range(t_max + u_max): + A = [] + + B_ = [] + B_enc_out = [] + for hyp in B: + u = len(hyp.yseq) - 1 + t = i - u + + if t > (t_max - 1): + continue + + B_.append(hyp) + B_enc_out.append((t, enc_out[t])) + + if B_: + beam_enc_out = torch.stack([b[1] for b in B_enc_out]) + beam_dec_out, beam_state = self.decoder.batch_score(B_) + + beam_logp = torch.log_softmax( + self.joint_network(beam_enc_out, beam_dec_out), + dim=-1, + ) + beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1) + + if self.use_lm: + beam_lm_scores, beam_lm_states = self.lm.batch_score( + self.create_lm_batch_inputs([b.yseq for b in B_]), + [b.lm_state for b in B_], + None, + ) + + for i, hyp in enumerate(B_): + new_hyp = Hypothesis( + score=(hyp.score + float(beam_logp[i, 0])), + yseq=hyp.yseq[:], + dec_state=hyp.dec_state, + lm_state=hyp.lm_state, + ) + + A.append(new_hyp) + + if B_enc_out[i][0] == (t_max - 1): + final.append(new_hyp) + + for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1): + new_hyp = Hypothesis( + score=(hyp.score + float(logp)), + yseq=(hyp.yseq[:] + [int(k)]), + dec_state=self.decoder.select_state(beam_state, i), + lm_state=hyp.lm_state, + ) + + if self.use_lm: + new_hyp.score += self.lm_weight * beam_lm_scores[i, k] + new_hyp.lm_state = beam_lm_states[i] + + A.append(new_hyp) 
+ + B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size] + B = self.recombine_hyps(B) + + if final: + return final + + return B + + def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]: + """Time synchronous beam search implementation. + + Based on https://ieeexplore.ieee.org/document/9053040 + + Args: + enc_out: Encoder output sequence. (T, D) + + Returns: + nbest_hyps: N-best hypothesis. + + """ + if self.search_cache is not None: + B = self.search_cache + else: + B = [ + Hypothesis( + yseq=[0], + score=0.0, + dec_state=self.decoder.init_state(1), + ) + ] + + if self.use_lm: + B[0].lm_state = self.lm.zero_state() + + for enc_out_t in enc_out: + A = [] + C = B + + enc_out_t = enc_out_t.unsqueeze(0) + + for v in range(self.max_sym_exp): + D = [] + + beam_dec_out, beam_state = self.decoder.batch_score(C) + + beam_logp = torch.log_softmax( + self.joint_network(enc_out_t, beam_dec_out), + dim=-1, + ) + beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1) + + seq_A = [h.yseq for h in A] + + for i, hyp in enumerate(C): + if hyp.yseq not in seq_A: + A.append( + Hypothesis( + score=(hyp.score + float(beam_logp[i, 0])), + yseq=hyp.yseq[:], + dec_state=hyp.dec_state, + lm_state=hyp.lm_state, + ) + ) + else: + dict_pos = seq_A.index(hyp.yseq) + + A[dict_pos].score = np.logaddexp( + A[dict_pos].score, (hyp.score + float(beam_logp[i, 0])) + ) + + if v < (self.max_sym_exp - 1): + if self.use_lm: + beam_lm_scores, beam_lm_states = self.lm.batch_score( + self.create_lm_batch_inputs([c.yseq for c in C]), + [c.lm_state for c in C], + None, + ) + + for i, hyp in enumerate(C): + for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1): + new_hyp = Hypothesis( + score=(hyp.score + float(logp)), + yseq=(hyp.yseq + [int(k)]), + dec_state=self.decoder.select_state(beam_state, i), + lm_state=hyp.lm_state, + ) + + if self.use_lm: + new_hyp.score += self.lm_weight * beam_lm_scores[i, k] + new_hyp.lm_state = beam_lm_states[i] + + D.append(new_hyp) + + 
C = sorted(D, key=lambda x: x.score, reverse=True)[: self.beam_size] + + B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size] + + return B + + def modified_adaptive_expansion_search( + self, + enc_out: torch.Tensor, + ) -> List[ExtendedHypothesis]: + """Modified version of Adaptive Expansion Search (mAES). + + Based on AES (https://ieeexplore.ieee.org/document/9250505) and + NSC (https://arxiv.org/abs/2201.05420). + + Args: + enc_out: Encoder output sequence. (T, D_enc) + + Returns: + nbest_hyps: N-best hypothesis. + + """ + if self.search_cache is not None: + kept_hyps = self.search_cache + else: + init_tokens = [ + ExtendedHypothesis( + yseq=[0], + score=0.0, + dec_state=self.decoder.init_state(1), + ) + ] + + beam_dec_out, beam_state = self.decoder.batch_score( + init_tokens, + ) + + if self.use_lm: + beam_lm_scores, beam_lm_states = self.lm.batch_score( + self.create_lm_batch_inputs([h.yseq for h in init_tokens]), + [h.lm_state for h in init_tokens], + None, + ) + + lm_state = beam_lm_states[0] + lm_score = beam_lm_scores[0] + else: + lm_state = None + lm_score = None + + kept_hyps = [ + ExtendedHypothesis( + yseq=[0], + score=0.0, + dec_state=self.decoder.select_state(beam_state, 0), + dec_out=beam_dec_out[0], + lm_state=lm_state, + lm_score=lm_score, + ) + ] + + for enc_out_t in enc_out: + hyps = kept_hyps + kept_hyps = [] + + beam_enc_out = enc_out_t.unsqueeze(0) + + list_b = [] + for n in range(self.nstep): + beam_dec_out = torch.stack([h.dec_out for h in hyps]) + + beam_logp, beam_idx = torch.log_softmax( + self.joint_network(beam_enc_out, beam_dec_out), + dim=-1, + ).topk(self.max_candidates, dim=-1) + + k_expansions = self.select_k_expansions(hyps, beam_idx, beam_logp) + + list_exp = [] + for i, hyp in enumerate(hyps): + for k, new_score in k_expansions[i]: + new_hyp = ExtendedHypothesis( + yseq=hyp.yseq[:], + score=new_score, + dec_out=hyp.dec_out, + dec_state=hyp.dec_state, + lm_state=hyp.lm_state, + lm_score=hyp.lm_score, + ) + + if k 
== 0: + list_b.append(new_hyp) + else: + new_hyp.yseq.append(int(k)) + + if self.use_lm: + new_hyp.score += self.lm_weight * float(hyp.lm_score[k]) + + list_exp.append(new_hyp) + + if not list_exp: + kept_hyps = sorted( + self.recombine_hyps(list_b), key=lambda x: x.score, reverse=True + )[: self.beam_size] + + break + else: + beam_dec_out, beam_state = self.decoder.batch_score( + list_exp, + ) + + if self.use_lm: + beam_lm_scores, beam_lm_states = self.lm.batch_score( + self.create_lm_batch_inputs([h.yseq for h in list_exp]), + [h.lm_state for h in list_exp], + None, + ) + + if n < (self.nstep - 1): + for i, hyp in enumerate(list_exp): + hyp.dec_out = beam_dec_out[i] + hyp.dec_state = self.decoder.select_state(beam_state, i) + + if self.use_lm: + hyp.lm_state = beam_lm_states[i] + hyp.lm_score = beam_lm_scores[i] + + hyps = list_exp[:] + else: + beam_logp = torch.log_softmax( + self.joint_network(beam_enc_out, beam_dec_out), + dim=-1, + ) + + for i, hyp in enumerate(list_exp): + hyp.score += float(beam_logp[i, 0]) + + hyp.dec_out = beam_dec_out[i] + hyp.dec_state = self.decoder.select_state(beam_state, i) + + if self.use_lm: + hyp.lm_state = beam_lm_states[i] + hyp.lm_score = beam_lm_scores[i] + + kept_hyps = sorted( + self.recombine_hyps(list_b + list_exp), + key=lambda x: x.score, + reverse=True, + )[: self.beam_size] + + return kept_hyps diff --git a/funasr/models_transducer/decoder/__init__.py b/funasr/models_transducer/decoder/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/models_transducer/decoder/abs_decoder.py b/funasr/models_transducer/decoder/abs_decoder.py new file mode 100644 index 000000000..5b4a335be --- /dev/null +++ b/funasr/models_transducer/decoder/abs_decoder.py @@ -0,0 +1,110 @@ +"""Abstract decoder definition for Transducer models.""" + +from abc import ABC, abstractmethod +from typing import Any, List, Optional, Tuple + +import torch + + +class AbsDecoder(torch.nn.Module, ABC): + """Abstract decoder module.""" 
+
+    @abstractmethod
+    def forward(self, labels: torch.Tensor) -> torch.Tensor:
+        """Encode source label sequences.
+
+        Args:
+            labels: Label ID sequences. (B, L)
+
+        Returns:
+            dec_out: Decoder output sequences. (B, T, D_dec)
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def score(
+        self,
+        label: torch.Tensor,
+        label_sequence: List[int],
+        dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]],
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]]:
+        """One-step forward hypothesis.
+
+        Args:
+            label: Previous label. (1, 1)
+            label_sequence: Current label sequence.
+            dec_state: Previous decoder hidden states.
+                         ((N, 1, D_dec), (N, 1, D_dec) or None) or None
+
+        Returns:
+            dec_out: Decoder output sequence. (1, D_dec) or (1, D_emb)
+            dec_state: Decoder hidden states.
+                         ((N, 1, D_dec), (N, 1, D_dec) or None) or None
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def batch_score(
+        self,
+        hyps: List[Any],
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]]:
+        """One-step forward hypotheses.
+
+        Args:
+            hyps: Hypotheses.
+
+        Returns:
+            dec_out: Decoder output sequences. (B, D_dec) or (B, D_emb)
+            states: Decoder hidden states.
+                      ((N, B, D_dec), (N, B, D_dec) or None) or None
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def set_device(self, device: torch.device) -> None:
+        # NOTE(fix): annotation was `torch.Tensor`; the argument is a device
+        # handle, not a tensor.
+        """Set GPU device to use.
+
+        Args:
+            device: Device ID.
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def init_state(
+        self, batch_size: int
+    ) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
+        # NOTE(fix): annotation was lowercase `torch.tensor` (the factory
+        # function); the type is `torch.Tensor`.
+        """Initialize decoder states.
+
+        Args:
+            batch_size: Batch size.
+
+        Returns:
+            : Initial decoder hidden states.
+ ((N, B, D_dec), (N, B, D_dec) or None) or None + + """ + raise NotImplementedError + + @abstractmethod + def select_state( + self, + states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None, + idx: int = 0, + ) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Get specified ID state from batch of states, if provided. + + Args: + states: Decoder hidden states. + ((N, B, D_dec), (N, B, D_dec) or None) or None + idx: State ID to extract. + + Returns: + : Decoder hidden state for given ID. + ((N, 1, D_dec), (N, 1, D_dec) or None) or None + + """ + raise NotImplementedError diff --git a/funasr/models_transducer/decoder/rnn_decoder.py b/funasr/models_transducer/decoder/rnn_decoder.py new file mode 100644 index 000000000..04c32287a --- /dev/null +++ b/funasr/models_transducer/decoder/rnn_decoder.py @@ -0,0 +1,259 @@ +"""RNN decoder definition for Transducer models.""" + +from typing import List, Optional, Tuple + +import torch +from typeguard import check_argument_types + +from funasr.models_transducer.beam_search_transducer import Hypothesis +from funasr.models_transducer.decoder.abs_decoder import AbsDecoder +from funasr.models.specaug.specaug import SpecAug + +class RNNDecoder(AbsDecoder): + """RNN decoder module. + + Args: + vocab_size: Vocabulary size. + embed_size: Embedding size. + hidden_size: Hidden size.. + rnn_type: Decoder layers type. + num_layers: Number of decoder layers. + dropout_rate: Dropout rate for decoder layers. + embed_dropout_rate: Dropout rate for embedding layer. + embed_pad: Embedding padding symbol ID. 
+ + """ + + def __init__( + self, + vocab_size: int, + embed_size: int = 256, + hidden_size: int = 256, + rnn_type: str = "lstm", + num_layers: int = 1, + dropout_rate: float = 0.0, + embed_dropout_rate: float = 0.0, + embed_pad: int = 0, + ) -> None: + """Construct a RNNDecoder object.""" + super().__init__() + + assert check_argument_types() + + if rnn_type not in ("lstm", "gru"): + raise ValueError(f"Not supported: rnn_type={rnn_type}") + + self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad) + self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate) + + rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU + + self.rnn = torch.nn.ModuleList( + [rnn_class(embed_size, hidden_size, 1, batch_first=True)] + ) + + for _ in range(1, num_layers): + self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)] + + self.dropout_rnn = torch.nn.ModuleList( + [torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)] + ) + + self.dlayers = num_layers + self.dtype = rnn_type + + self.output_size = hidden_size + self.vocab_size = vocab_size + + self.device = next(self.parameters()).device + self.score_cache = {} + + def forward( + self, + labels: torch.Tensor, + label_lens: torch.Tensor, + states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None, + ) -> torch.Tensor: + """Encode source label sequences. + + Args: + labels: Label ID sequences. (B, L) + states: Decoder hidden states. + ((N, B, D_dec), (N, B, D_dec) or None) or None + + Returns: + dec_out: Decoder output sequences. (B, U, D_dec) + + """ + if states is None: + states = self.init_state(labels.size(0)) + + dec_embed = self.dropout_embed(self.embed(labels)) + dec_out, states = self.rnn_forward(dec_embed, states) + return dec_out + + def rnn_forward( + self, + x: torch.Tensor, + state: Tuple[torch.Tensor, Optional[torch.Tensor]], + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Encode source label sequences. 
+ + Args: + x: RNN input sequences. (B, D_emb) + state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None) + + Returns: + x: RNN output sequences. (B, D_dec) + (h_next, c_next): Decoder hidden states. + (N, B, D_dec), (N, B, D_dec) or None) + + """ + h_prev, c_prev = state + h_next, c_next = self.init_state(x.size(0)) + + for layer in range(self.dlayers): + if self.dtype == "lstm": + x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[ + layer + ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1])) + else: + x, h_next[layer : layer + 1] = self.rnn[layer]( + x, hx=h_prev[layer : layer + 1] + ) + + x = self.dropout_rnn[layer](x) + + return x, (h_next, c_next) + + def score( + self, + label: torch.Tensor, + label_sequence: List[int], + dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]], + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: + """One-step forward hypothesis. + + Args: + label: Previous label. (1, 1) + label_sequence: Current label sequence. + dec_state: Previous decoder hidden states. + ((N, 1, D_dec), (N, 1, D_dec) or None) + + Returns: + dec_out: Decoder output sequence. (1, D_dec) + dec_state: Decoder hidden states. + ((N, 1, D_dec), (N, 1, D_dec) or None) + + """ + str_labels = "_".join(map(str, label_sequence)) + + if str_labels in self.score_cache: + dec_out, dec_state = self.score_cache[str_labels] + else: + dec_embed = self.embed(label) + dec_out, dec_state = self.rnn_forward(dec_embed, dec_state) + + self.score_cache[str_labels] = (dec_out, dec_state) + + return dec_out[0], dec_state + + def batch_score( + self, + hyps: List[Hypothesis], + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: + """One-step forward hypotheses. + + Args: + hyps: Hypotheses. + + Returns: + dec_out: Decoder output sequences. (B, D_dec) + states: Decoder hidden states. 
((N, B, D_dec), (N, B, D_dec) or None)
+
+        """
+        # NOTE(fix): the legacy torch.LongTensor(...) constructor does not
+        # accept a `device` keyword; torch.tensor with dtype=torch.long does.
+        labels = torch.tensor(
+            [[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device
+        )
+        dec_embed = self.embed(labels)
+
+        states = self.create_batch_states([h.dec_state for h in hyps])
+        dec_out, states = self.rnn_forward(dec_embed, states)
+
+        return dec_out.squeeze(1), states
+
+    def set_device(self, device: torch.device) -> None:
+        """Set GPU device to use.
+
+        Args:
+            device: Device ID.
+
+        """
+        self.device = device
+
+    def init_state(
+        self, batch_size: int
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # NOTE(fix): annotation was lowercase `torch.tensor` (the factory
+        # function); the type is `torch.Tensor`.
+        """Initialize decoder states.
+
+        Args:
+            batch_size: Batch size.
+
+        Returns:
+            : Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
+
+        """
+        h_n = torch.zeros(
+            self.dlayers,
+            batch_size,
+            self.output_size,
+            device=self.device,
+        )
+
+        if self.dtype == "lstm":
+            c_n = torch.zeros(
+                self.dlayers,
+                batch_size,
+                self.output_size,
+                device=self.device,
+            )
+
+            return (h_n, c_n)
+
+        return (h_n, None)
+
+    def select_state(
+        self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Get specified ID state from decoder hidden states.
+
+        Args:
+            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
+            idx: State ID to extract.
+
+        Returns:
+            : Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)
+
+        """
+        return (
+            states[0][:, idx : idx + 1, :],
+            states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
+        )
+
+    def create_batch_states(
+        self,
+        new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Create decoder hidden states.
+
+        Args:
+            new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)]
+
+        Returns:
+            states: Decoder hidden states.
((N, B, D_dec), (N, B, D_dec) or None) + + """ + return ( + torch.cat([s[0] for s in new_states], dim=1), + torch.cat([s[1] for s in new_states], dim=1) + if self.dtype == "lstm" + else None, + ) diff --git a/funasr/models_transducer/decoder/stateless_decoder.py b/funasr/models_transducer/decoder/stateless_decoder.py new file mode 100644 index 000000000..07c8f519b --- /dev/null +++ b/funasr/models_transducer/decoder/stateless_decoder.py @@ -0,0 +1,157 @@ +"""Stateless decoder definition for Transducer models.""" + +from typing import List, Optional, Tuple + +import torch +from typeguard import check_argument_types + +from funasr.models_transducer.beam_search_transducer import Hypothesis +from funasr.models_transducer.decoder.abs_decoder import AbsDecoder +from funasr.models.specaug.specaug import SpecAug + +class StatelessDecoder(AbsDecoder): + """Stateless Transducer decoder module. + + Args: + vocab_size: Output size. + embed_size: Embedding size. + embed_dropout_rate: Dropout rate for embedding layer. + embed_pad: Embed/Blank symbol ID. 
+ + """ + + def __init__( + self, + vocab_size: int, + embed_size: int = 256, + embed_dropout_rate: float = 0.0, + embed_pad: int = 0, + use_embed_mask: bool = False, + ) -> None: + """Construct a StatelessDecoder object.""" + super().__init__() + + assert check_argument_types() + + self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad) + self.embed_dropout_rate = torch.nn.Dropout(p=embed_dropout_rate) + + self.output_size = embed_size + self.vocab_size = vocab_size + + self.device = next(self.parameters()).device + self.score_cache = {} + + self.use_embed_mask = use_embed_mask + if self.use_embed_mask: + self._embed_mask = SpecAug( + time_mask_width_range=3, + num_time_mask=1, + apply_freq_mask=False, + apply_time_warp=False + ) + + + def forward( + self, + labels: torch.Tensor, + label_lens: torch.Tensor, + states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None, + ) -> torch.Tensor: + """Encode source label sequences. + + Args: + labels: Label ID sequences. (B, L) + states: Decoder hidden states. None + + Returns: + dec_embed: Decoder output sequences. (B, U, D_emb) + + """ + dec_embed = self.embed_dropout_rate(self.embed(labels)) + if self.use_embed_mask and self.training: + dec_embed = self._embed_mask(dec_embed, label_lens)[0] + + return dec_embed + + def score( + self, + label: torch.Tensor, + label_sequence: List[int], + state: None, + ) -> Tuple[torch.Tensor, None]: + """One-step forward hypothesis. + + Args: + label: Previous label. (1, 1) + label_sequence: Current label sequence. + state: Previous decoder hidden states. None + + Returns: + dec_out: Decoder output sequence. (1, D_emb) + state: Decoder hidden states. 
None + + """ + str_labels = "_".join(map(str, label_sequence)) + + if str_labels in self.score_cache: + dec_embed = self.score_cache[str_labels] + else: + dec_embed = self.embed(label) + + self.score_cache[str_labels] = dec_embed + + return dec_embed[0], None + + def batch_score( + self, + hyps: List[Hypothesis], + ) -> Tuple[torch.Tensor, None]: + """One-step forward hypotheses. + + Args: + hyps: Hypotheses. + + Returns: + dec_out: Decoder output sequences. (B, D_dec) + states: Decoder hidden states. None + + """ + labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device) + dec_embed = self.embed(labels) + + return dec_embed.squeeze(1), None + + def set_device(self, device: torch.device) -> None: + """Set GPU device to use. + + Args: + device: Device ID. + + """ + self.device = device + + def init_state(self, batch_size: int) -> None: + """Initialize decoder states. + + Args: + batch_size: Batch size. + + Returns: + : Initial decoder hidden states. None + + """ + return None + + def select_state(self, states: Optional[torch.Tensor], idx: int) -> None: + """Get specified ID state from decoder hidden states. + + Args: + states: Decoder hidden states. None + idx: State ID to extract. + + Returns: + : Decoder hidden state for given ID. 
None + + """ + return None diff --git a/funasr/models_transducer/encoder/__init__.py b/funasr/models_transducer/encoder/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/models_transducer/encoder/blocks/__init__.py b/funasr/models_transducer/encoder/blocks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/models_transducer/encoder/blocks/branchformer.py b/funasr/models_transducer/encoder/blocks/branchformer.py new file mode 100644 index 000000000..ba0b25d83 --- /dev/null +++ b/funasr/models_transducer/encoder/blocks/branchformer.py @@ -0,0 +1,178 @@ +"""Branchformer block for Transducer encoder.""" + +from typing import Dict, Optional, Tuple + +import torch + + +class Branchformer(torch.nn.Module): + """Branchformer module definition. + + Reference: https://arxiv.org/pdf/2207.02971.pdf + + Args: + block_size: Input/output size. + linear_size: Linear layers' hidden size. + self_att: Self-attention module instance. + conv_mod: Convolution module instance. + norm_class: Normalization class. + norm_args: Normalization module arguments. + dropout_rate: Dropout rate. 
+ + """ + + def __init__( + self, + block_size: int, + linear_size: int, + self_att: torch.nn.Module, + conv_mod: torch.nn.Module, + norm_class: torch.nn.Module = torch.nn.LayerNorm, + norm_args: Dict = {}, + dropout_rate: float = 0.0, + ) -> None: + """Construct a Branchformer object.""" + super().__init__() + + self.self_att = self_att + self.conv_mod = conv_mod + + self.channel_proj1 = torch.nn.Sequential( + torch.nn.Linear(block_size, linear_size), torch.nn.GELU() + ) + self.channel_proj2 = torch.nn.Linear(linear_size // 2, block_size) + + self.merge_proj = torch.nn.Linear(block_size + block_size, block_size) + + self.norm_self_att = norm_class(block_size, **norm_args) + self.norm_mlp = norm_class(block_size, **norm_args) + self.norm_final = norm_class(block_size, **norm_args) + + self.dropout = torch.nn.Dropout(dropout_rate) + + self.block_size = block_size + self.linear_size = linear_size + self.cache = None + + def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: + """Initialize/Reset self-attention and convolution modules cache for streaming. + + Args: + left_context: Number of left frames during chunk-by-chunk inference. + device: Device to use for cache tensor. + + """ + self.cache = [ + torch.zeros( + (1, left_context, self.block_size), + device=device, + ), + torch.zeros( + ( + 1, + self.linear_size // 2, + self.conv_mod.kernel_size - 1, + ), + device=device, + ), + ] + + def forward( + self, + x: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + chunk_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Encode input sequences. + + Args: + x: Branchformer input sequences. (B, T, D_block) + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) + mask: Source mask. (B, T) + chunk_mask: Chunk mask. (T_2, T_2) + + Returns: + x: Branchformer output sequences. (B, T, D_block) + mask: Source mask. (B, T) + pos_enc: Positional embedding sequences. 
(B, 2 * (T - 1), D_block) + + """ + x1 = x + x2 = x + + x1 = self.norm_self_att(x1) + + x1 = self.dropout( + self.self_att(x1, x1, x1, pos_enc, mask=mask, chunk_mask=chunk_mask) + ) + + x2 = self.norm_mlp(x2) + + x2 = self.channel_proj1(x2) + x2, _ = self.conv_mod(x2) + x2 = self.channel_proj2(x2) + + x2 = self.dropout(x2) + + x = x + self.dropout(self.merge_proj(torch.cat([x1, x2], dim=-1))) + + x = self.norm_final(x) + + return x, mask, pos_enc + + def chunk_forward( + self, + x: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + left_context: int = 0, + right_context: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode chunk of input sequence. + + Args: + x: Branchformer input sequences. (B, T, D_block) + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) + mask: Source mask. (B, T_2) + left_context: Number of frames in left context. + right_context: Number of frames in right context. + + Returns: + x: Branchformer output sequences. (B, T, D_block) + pos_enc: Positional embedding sequences. 
(B, 2 * (T - 1), D_block) + + """ + x1 = x + x2 = x + + x1 = self.norm_self_att(x1) + + if left_context > 0: + key = torch.cat([self.cache[0], x1], dim=1) + else: + key = x1 + val = key + + if right_context > 0: + att_cache = key[:, -(left_context + right_context) : -right_context, :] + else: + att_cache = key[:, -left_context:, :] + + x1 = self.self_att(x1, key, val, pos_enc, mask=mask, left_context=left_context) + + x2 = self.norm_mlp(x2) + x2 = self.channel_proj1(x2) + + x2, conv_cache = self.conv_mod( + x2, cache=self.cache[1], right_context=right_context + ) + + x2 = self.channel_proj2(x2) + + x = x + self.merge_proj(torch.cat([x1, x2], dim=-1)) + + x = self.norm_final(x) + self.cache = [att_cache, conv_cache] + + return x, pos_enc diff --git a/funasr/models_transducer/encoder/blocks/conformer.py b/funasr/models_transducer/encoder/blocks/conformer.py new file mode 100644 index 000000000..0b9bbbf12 --- /dev/null +++ b/funasr/models_transducer/encoder/blocks/conformer.py @@ -0,0 +1,198 @@ +"""Conformer block for Transducer encoder.""" + +from typing import Dict, Optional, Tuple + +import torch + + +class Conformer(torch.nn.Module): + """Conformer module definition. + + Args: + block_size: Input/output size. + self_att: Self-attention module instance. + feed_forward: Feed-forward module instance. + feed_forward_macaron: Feed-forward module instance for macaron network. + conv_mod: Convolution module instance. + norm_class: Normalization module class. + norm_args: Normalization module arguments. + dropout_rate: Dropout rate. 
+ + """ + + def __init__( + self, + block_size: int, + self_att: torch.nn.Module, + feed_forward: torch.nn.Module, + feed_forward_macaron: torch.nn.Module, + conv_mod: torch.nn.Module, + norm_class: torch.nn.Module = torch.nn.LayerNorm, + norm_args: Dict = {}, + dropout_rate: float = 0.0, + ) -> None: + """Construct a Conformer object.""" + super().__init__() + + self.self_att = self_att + + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.feed_forward_scale = 0.5 + + self.conv_mod = conv_mod + + self.norm_feed_forward = norm_class(block_size, **norm_args) + self.norm_self_att = norm_class(block_size, **norm_args) + + self.norm_macaron = norm_class(block_size, **norm_args) + self.norm_conv = norm_class(block_size, **norm_args) + self.norm_final = norm_class(block_size, **norm_args) + + self.dropout = torch.nn.Dropout(dropout_rate) + + self.block_size = block_size + self.cache = None + + def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: + """Initialize/Reset self-attention and convolution modules cache for streaming. + + Args: + left_context: Number of left frames during chunk-by-chunk inference. + device: Device to use for cache tensor. + + """ + self.cache = [ + torch.zeros( + (1, left_context, self.block_size), + device=device, + ), + torch.zeros( + ( + 1, + self.block_size, + self.conv_mod.kernel_size - 1, + ), + device=device, + ), + ] + + def forward( + self, + x: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + chunk_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Encode input sequences. + + Args: + x: Conformer input sequences. (B, T, D_block) + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) + mask: Source mask. (B, T) + chunk_mask: Chunk mask. (T_2, T_2) + + Returns: + x: Conformer output sequences. (B, T, D_block) + mask: Source mask. (B, T) + pos_enc: Positional embedding sequences. 
(B, 2 * (T - 1), D_block) + + """ + residual = x + + x = self.norm_macaron(x) + x = residual + self.feed_forward_scale * self.dropout( + self.feed_forward_macaron(x) + ) + + residual = x + x = self.norm_self_att(x) + x_q = x + x = residual + self.dropout( + self.self_att( + x_q, + x, + x, + pos_enc, + mask, + chunk_mask=chunk_mask, + ) + ) + + residual = x + + x = self.norm_conv(x) + x, _ = self.conv_mod(x) + x = residual + self.dropout(x) + residual = x + + x = self.norm_feed_forward(x) + x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x)) + + x = self.norm_final(x) + return x, mask, pos_enc + + def chunk_forward( + self, + x: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + chunk_size: int = 16, + left_context: int = 0, + right_context: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode chunk of input sequence. + + Args: + x: Conformer input sequences. (B, T, D_block) + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) + mask: Source mask. (B, T_2) + left_context: Number of frames in left context. + right_context: Number of frames in right context. + + Returns: + x: Conformer output sequences. (B, T, D_block) + pos_enc: Positional embedding sequences. 
(B, 2 * (T - 1), D_block) + + """ + residual = x + + x = self.norm_macaron(x) + x = residual + self.feed_forward_scale * self.feed_forward_macaron(x) + + residual = x + x = self.norm_self_att(x) + if left_context > 0: + key = torch.cat([self.cache[0], x], dim=1) + else: + key = x + val = key + + if right_context > 0: + att_cache = key[:, -(left_context + right_context) : -right_context, :] + else: + att_cache = key[:, -left_context:, :] + x = residual + self.self_att( + x, + key, + val, + pos_enc, + mask, + left_context=left_context, + ) + + residual = x + x = self.norm_conv(x) + x, conv_cache = self.conv_mod( + x, cache=self.cache[1], right_context=right_context + ) + x = residual + x + residual = x + + x = self.norm_feed_forward(x) + x = residual + self.feed_forward_scale * self.feed_forward(x) + + x = self.norm_final(x) + self.cache = [att_cache, conv_cache] + + return x, pos_enc diff --git a/funasr/models_transducer/encoder/blocks/conv1d.py b/funasr/models_transducer/encoder/blocks/conv1d.py new file mode 100644 index 000000000..f79cc37b4 --- /dev/null +++ b/funasr/models_transducer/encoder/blocks/conv1d.py @@ -0,0 +1,221 @@ +"""Conv1d block for Transducer encoder.""" + +from typing import Optional, Tuple, Union + +import torch + + +class Conv1d(torch.nn.Module): + """Conv1d module definition. + + Args: + input_size: Input dimension. + output_size: Output dimension. + kernel_size: Size of the convolving kernel. + stride: Stride of the convolution. + dilation: Spacing between the kernel points. + groups: Number of blocked connections from input channels to output channels. + bias: Whether to add a learnable bias to the output. + batch_norm: Whether to use batch normalization after convolution. + relu: Whether to use a ReLU activation after convolution. + causal: Whether to use causal convolution (set to True if streaming). + dropout_rate: Dropout rate. 
+ + """ + + def __init__( + self, + input_size: int, + output_size: int, + kernel_size: Union[int, Tuple], + stride: Union[int, Tuple] = 1, + dilation: Union[int, Tuple] = 1, + groups: Union[int, Tuple] = 1, + bias: bool = True, + batch_norm: bool = False, + relu: bool = True, + causal: bool = False, + dropout_rate: float = 0.0, + ) -> None: + """Construct a Conv1d object.""" + super().__init__() + + if causal: + self.lorder = kernel_size - 1 + stride = 1 + else: + self.lorder = 0 + stride = stride + + self.conv = torch.nn.Conv1d( + input_size, + output_size, + kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + bias=bias, + ) + + self.dropout = torch.nn.Dropout(p=dropout_rate) + + if relu: + self.relu_func = torch.nn.ReLU() + + if batch_norm: + self.bn = torch.nn.BatchNorm1d(output_size) + + self.out_pos = torch.nn.Linear(input_size, output_size) + + self.input_size = input_size + self.output_size = output_size + + self.relu = relu + self.batch_norm = batch_norm + self.causal = causal + + self.kernel_size = kernel_size + self.padding = dilation * (kernel_size - 1) + self.stride = stride + + self.cache = None + + def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: + """Initialize/Reset Conv1d cache for streaming. + + Args: + left_context: Number of left frames during chunk-by-chunk inference. + device: Device to use for cache tensor. + + """ + self.cache = torch.zeros( + (1, self.input_size, self.kernel_size - 1), device=device + ) + + def forward( + self, + x: torch.Tensor, + pos_enc: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Encode input sequences. + + Args: + x: Conv1d input sequences. (B, T, D_in) + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_in) + mask: Source mask. (B, T) + chunk_mask: Chunk mask. (T_2, T_2) + + Returns: + x: Conv1d output sequences. 
(B, sub(T), D_out) + mask: Source mask. (B, T) or (B, sub(T)) + pos_enc: Positional embedding sequences. + (B, 2 * (T - 1), D_att) or (B, 2 * (sub(T) - 1), D_out) + + """ + x = x.transpose(1, 2) + + if self.lorder > 0: + x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) + else: + mask = self.create_new_mask(mask) + pos_enc = self.create_new_pos_enc(pos_enc) + + x = self.conv(x) + + if self.batch_norm: + x = self.bn(x) + + x = self.dropout(x) + + if self.relu: + x = self.relu_func(x) + + x = x.transpose(1, 2) + + return x, mask, self.out_pos(pos_enc) + + def chunk_forward( + self, + x: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + left_context: int = 0, + right_context: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode chunk of input sequence. + + Args: + x: Conv1d input sequences. (B, T, D_in) + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_in) + mask: Source mask. (B, T) + left_context: Number of frames in left context. + right_context: Number of frames in right context. + + Returns: + x: Conv1d output sequences. (B, T, D_out) + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_out) + + """ + x = torch.cat([self.cache, x.transpose(1, 2)], dim=2) + + if right_context > 0: + self.cache = x[:, :, -(self.lorder + right_context) : -right_context] + else: + self.cache = x[:, :, -self.lorder :] + + x = self.conv(x) + + if self.batch_norm: + x = self.bn(x) + + x = self.dropout(x) + + if self.relu: + x = self.relu_func(x) + + x = x.transpose(1, 2) + + return x, self.out_pos(pos_enc) + + def create_new_mask(self, mask: torch.Tensor) -> torch.Tensor: + """Create new mask for output sequences. + + Args: + mask: Mask of input sequences. (B, T) + + Returns: + mask: Mask of output sequences. 
(B, sub(T)) + + """ + if self.padding != 0: + mask = mask[:, : -self.padding] + + return mask[:, :: self.stride] + + def create_new_pos_enc(self, pos_enc: torch.Tensor) -> torch.Tensor: + """Create new positional embedding vector. + + Args: + pos_enc: Input sequences positional embedding. + (B, 2 * (T - 1), D_in) + + Returns: + pos_enc: Output sequences positional embedding. + (B, 2 * (sub(T) - 1), D_in) + + """ + pos_enc_positive = pos_enc[:, : pos_enc.size(1) // 2 + 1, :] + pos_enc_negative = pos_enc[:, pos_enc.size(1) // 2 :, :] + + if self.padding != 0: + pos_enc_positive = pos_enc_positive[:, : -self.padding, :] + pos_enc_negative = pos_enc_negative[:, : -self.padding, :] + + pos_enc_positive = pos_enc_positive[:, :: self.stride, :] + pos_enc_negative = pos_enc_negative[:, :: self.stride, :] + + pos_enc = torch.cat([pos_enc_positive, pos_enc_negative[:, 1:, :]], dim=1) + + return pos_enc diff --git a/funasr/models_transducer/encoder/blocks/conv_input.py b/funasr/models_transducer/encoder/blocks/conv_input.py new file mode 100644 index 000000000..931d0f0eb --- /dev/null +++ b/funasr/models_transducer/encoder/blocks/conv_input.py @@ -0,0 +1,226 @@ +"""ConvInput block for Transducer encoder.""" + +from typing import Optional, Tuple, Union + +import torch +import math + +from funasr.models_transducer.utils import sub_factor_to_params, pad_to_len + + +class ConvInput(torch.nn.Module): + """ConvInput module definition. + + Args: + input_size: Input size. + conv_size: Convolution size. + subsampling_factor: Subsampling factor. + vgg_like: Whether to use a VGG-like network. + output_size: Block output dimension. 
+ + """ + + def __init__( + self, + input_size: int, + conv_size: Union[int, Tuple], + subsampling_factor: int = 4, + vgg_like: bool = True, + output_size: Optional[int] = None, + ) -> None: + """Construct a ConvInput object.""" + super().__init__() + if vgg_like: + if subsampling_factor == 1: + conv_size1, conv_size2 = conv_size + + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((1, 2)), + torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((1, 2)), + ) + + output_proj = conv_size2 * ((input_size // 2) // 2) + + self.subsampling_factor = 1 + + self.stride_1 = 1 + + self.create_new_mask = self.create_new_vgg_mask + + else: + conv_size1, conv_size2 = conv_size + + kernel_1 = int(subsampling_factor / 2) + + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((kernel_1, 2)), + torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((2, 2)), + ) + + output_proj = conv_size2 * ((input_size // 2) // 2) + + self.subsampling_factor = subsampling_factor + + self.create_new_mask = self.create_new_vgg_mask + + self.stride_1 = kernel_1 + + else: + if subsampling_factor == 1: + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]), + torch.nn.ReLU(), + torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]), + torch.nn.ReLU(), + ) + + output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2) + + self.subsampling_factor = subsampling_factor + 
self.kernel_2 = 3 + self.stride_2 = 1 + + self.create_new_mask = self.create_new_conv2d_mask + + else: + kernel_2, stride_2, conv_2_output_size = sub_factor_to_params( + subsampling_factor, + input_size, + ) + + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, conv_size, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2), + torch.nn.ReLU(), + ) + + output_proj = conv_size * conv_2_output_size + + self.subsampling_factor = subsampling_factor + self.kernel_2 = kernel_2 + self.stride_2 = stride_2 + + self.create_new_mask = self.create_new_conv2d_mask + + self.vgg_like = vgg_like + self.min_frame_length = 2 + + if output_size is not None: + self.output = torch.nn.Linear(output_proj, output_size) + self.output_size = output_size + else: + self.output = None + self.output_size = output_proj + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode input sequences. + + Args: + x: ConvInput input sequences. (B, T, D_feats) + mask: Mask of input sequences. (B, 1, T) + + Returns: + x: ConvInput output sequences. (B, sub(T), D_out) + mask: Mask of output sequences. (B, 1, sub(T)) + + """ + if mask is not None: + mask = self.create_new_mask(mask) + olens = max(mask.eq(0).sum(1)) + + b, t_input, f = x.size() + x = x.unsqueeze(1) # (b. 1. t. 
f) + if chunk_size is not None: + max_input_length = int( + chunk_size * self.subsampling_factor * (math.ceil(float(t_input) / (chunk_size * self.subsampling_factor) )) + ) + x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x) + x = list(x) + x = torch.stack(x, dim=0) + N_chunks = max_input_length // ( chunk_size * self.subsampling_factor) + x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f) + x = self.conv(x) + + _, c, t, f = x.size() + + if chunk_size is not None: + x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:] + else: + x = x.transpose(1, 2).contiguous().view(b, t, c * f) + + if self.output is not None: + x = self.output(x) + + return x, mask[:,:olens][:,:x.size(1)] + + def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor: + """Create a new mask for VGG output sequences. + + Args: + mask: Mask of input sequences. (B, T) + + Returns: + mask: Mask of output sequences. (B, sub(T)) + + """ + if self.subsampling_factor > 1: + vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2 )) + mask = mask[:, :vgg1_t_len][:, ::self.subsampling_factor // 2] + + vgg2_t_len = mask.size(1) - (mask.size(1) % 2) + mask = mask[:, :vgg2_t_len][:, ::2] + else: + mask = mask + + return mask + + def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor: + """Create new conformer mask for Conv2d output sequences. + + Args: + mask: Mask of input sequences. (B, T) + + Returns: + mask: Mask of output sequences. (B, sub(T)) + + """ + if self.subsampling_factor > 1: + return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2] + else: + return mask + + def get_size_before_subsampling(self, size: int) -> int: + """Return the original size before subsampling for a given size. + + Args: + size: Number of frames after subsampling. + + Returns: + : Number of frames before subsampling. 
+ + """ + if self.subsampling_factor > 1: + if self.vgg_like: + return ((size * 2) * self.stride_1) + 1 + + return ((size + 2) * 2) + (self.kernel_2 - 1) * self.stride_2 + return size diff --git a/funasr/models_transducer/encoder/blocks/linear_input.py b/funasr/models_transducer/encoder/blocks/linear_input.py new file mode 100644 index 000000000..9bb9698a7 --- /dev/null +++ b/funasr/models_transducer/encoder/blocks/linear_input.py @@ -0,0 +1,52 @@ +"""LinearInput block for Transducer encoder.""" + +from typing import Optional, Tuple, Union + +import torch + +class LinearInput(torch.nn.Module): + """LinearInput module definition. + + Args: + input_size: Input size. + conv_size: Convolution size. + subsampling_factor: Subsampling factor. + vgg_like: Whether to use a VGG-like network. + output_size: Block output dimension. + + """ + + def __init__( + self, + input_size: int, + output_size: Optional[int] = None, + subsampling_factor: int = 1, + ) -> None: + """Construct a LinearInput object.""" + super().__init__() + self.embed = torch.nn.Sequential( + torch.nn.Linear(input_size, output_size), + torch.nn.LayerNorm(output_size), + torch.nn.Dropout(0.1), + ) + self.subsampling_factor = subsampling_factor + self.min_frame_length = 1 + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + + x = self.embed(x) + return x, mask + + def get_size_before_subsampling(self, size: int) -> int: + """Return the original size before subsampling for a given size. + + Args: + size: Number of frames after subsampling. + + Returns: + : Number of frames before subsampling.
+ + """ + return size diff --git a/funasr/models_transducer/encoder/building.py b/funasr/models_transducer/encoder/building.py new file mode 100644 index 000000000..a19943be7 --- /dev/null +++ b/funasr/models_transducer/encoder/building.py @@ -0,0 +1,352 @@ +"""Set of methods to build Transducer encoder architecture.""" + +from typing import Any, Dict, List, Optional, Union + +from funasr.models_transducer.activation import get_activation +from funasr.models_transducer.encoder.blocks.branchformer import Branchformer +from funasr.models_transducer.encoder.blocks.conformer import Conformer +from funasr.models_transducer.encoder.blocks.conv1d import Conv1d +from funasr.models_transducer.encoder.blocks.conv_input import ConvInput +from funasr.models_transducer.encoder.blocks.linear_input import LinearInput +from funasr.models_transducer.encoder.modules.attention import ( # noqa: H301 + RelPositionMultiHeadedAttention, +) +from funasr.models_transducer.encoder.modules.convolution import ( # noqa: H301 + ConformerConvolution, + ConvolutionalSpatialGatingUnit, +) +from funasr.models_transducer.encoder.modules.multi_blocks import MultiBlocks +from funasr.models_transducer.encoder.modules.normalization import get_normalization +from funasr.models_transducer.encoder.modules.positional_encoding import ( # noqa: H301 + RelPositionalEncoding, +) +from funasr.modules.positionwise_feed_forward import ( + PositionwiseFeedForward, +) + + +def build_main_parameters( + pos_wise_act_type: str = "swish", + conv_mod_act_type: str = "swish", + pos_enc_dropout_rate: float = 0.0, + pos_enc_max_len: int = 5000, + simplified_att_score: bool = False, + norm_type: str = "layer_norm", + conv_mod_norm_type: str = "layer_norm", + after_norm_eps: Optional[float] = None, + after_norm_partial: Optional[float] = None, + dynamic_chunk_training: bool = False, + short_chunk_threshold: float = 0.75, + short_chunk_size: int = 25, + left_chunk_size: int = 0, + time_reduction_factor: int = 1, + 
unified_model_training: bool = False, + default_chunk_size: int = 16, + jitter_range: int = 4, + **activation_parameters, +) -> Dict[str, Any]: + """Build encoder main parameters. + + Args: + pos_wise_act_type: Conformer position-wise feed-forward activation type. + conv_mod_act_type: Conformer convolution module activation type. + pos_enc_dropout_rate: Positional encoding dropout rate. + pos_enc_max_len: Positional encoding maximum length. + simplified_att_score: Whether to use simplified attention score computation. + norm_type: X-former normalization module type. + conv_mod_norm_type: Conformer convolution module normalization type. + after_norm_eps: Epsilon value for the final normalization. + after_norm_partial: Value for the final normalization with RMSNorm. + dynamic_chunk_training: Whether to use dynamic chunk training. + short_chunk_threshold: Threshold for dynamic chunk selection. + short_chunk_size: Minimum number of frames during dynamic chunk training. + left_chunk_size: Number of frames in left context. + **activation_parameters: Parameters of the activation functions.
+ (See espnet2/asr_transducer/activation.py) + + Returns: + : Main encoder parameters + + """ + main_params = {} + + main_params["pos_wise_act"] = get_activation( + pos_wise_act_type, **activation_parameters + ) + + main_params["conv_mod_act"] = get_activation( + conv_mod_act_type, **activation_parameters + ) + + main_params["pos_enc_dropout_rate"] = pos_enc_dropout_rate + main_params["pos_enc_max_len"] = pos_enc_max_len + + main_params["simplified_att_score"] = simplified_att_score + + main_params["norm_type"] = norm_type + main_params["conv_mod_norm_type"] = conv_mod_norm_type + + ( + main_params["after_norm_class"], + main_params["after_norm_args"], + ) = get_normalization(norm_type, eps=after_norm_eps, partial=after_norm_partial) + + main_params["dynamic_chunk_training"] = dynamic_chunk_training + main_params["short_chunk_threshold"] = max(0, short_chunk_threshold) + main_params["short_chunk_size"] = max(0, short_chunk_size) + main_params["left_chunk_size"] = max(0, left_chunk_size) + + main_params["unified_model_training"] = unified_model_training + main_params["default_chunk_size"] = max(0, default_chunk_size) + main_params["jitter_range"] = max(0, jitter_range) + + main_params["time_reduction_factor"] = time_reduction_factor + + return main_params + + +def build_positional_encoding( + block_size: int, configuration: Dict[str, Any] +) -> RelPositionalEncoding: + """Build positional encoding block. + + Args: + block_size: Input/output size. + configuration: Positional encoding configuration. + + Returns: + : Positional encoding module. + + """ + return RelPositionalEncoding( + block_size, + configuration.get("pos_enc_dropout_rate", 0.0), + max_len=configuration.get("pos_enc_max_len", 5000), + ) + + +def build_input_block( + input_size: int, + configuration: Dict[str, Union[str, int]], +) -> ConvInput: + """Build encoder input block. + + Args: + input_size: Input size. + configuration: Input block configuration. + + Returns: + : ConvInput block function. 
+ + """ + if configuration["linear"]: + return LinearInput( + input_size, + configuration["output_size"], + configuration["subsampling_factor"], + ) + else: + return ConvInput( + input_size, + configuration["conv_size"], + configuration["subsampling_factor"], + vgg_like=configuration["vgg_like"], + output_size=configuration["output_size"], + ) + + +def build_branchformer_block( + configuration: List[Dict[str, Any]], + main_params: Dict[str, Any], +) -> Conformer: + """Build Branchformer block. + + Args: + configuration: Branchformer block configuration. + main_params: Encoder main parameters. + + Returns: + : Branchformer block function. + + """ + hidden_size = configuration["hidden_size"] + linear_size = configuration["linear_size"] + + dropout_rate = configuration.get("dropout_rate", 0.0) + + conv_mod_norm_class, conv_mod_norm_args = get_normalization( + main_params["conv_mod_norm_type"], + eps=configuration.get("conv_mod_norm_eps"), + partial=configuration.get("conv_mod_norm_partial"), + ) + + conv_mod_args = ( + linear_size, + configuration["conv_mod_kernel_size"], + conv_mod_norm_class, + conv_mod_norm_args, + dropout_rate, + main_params["dynamic_chunk_training"], + ) + + mult_att_args = ( + configuration.get("heads", 4), + hidden_size, + configuration.get("att_dropout_rate", 0.0), + main_params["simplified_att_score"], + ) + + norm_class, norm_args = get_normalization( + main_params["norm_type"], + eps=configuration.get("norm_eps"), + partial=configuration.get("norm_partial"), + ) + + return lambda: Branchformer( + hidden_size, + linear_size, + RelPositionMultiHeadedAttention(*mult_att_args), + ConvolutionalSpatialGatingUnit(*conv_mod_args), + norm_class=norm_class, + norm_args=norm_args, + dropout_rate=dropout_rate, + ) + + +def build_conformer_block( + configuration: List[Dict[str, Any]], + main_params: Dict[str, Any], +) -> Conformer: + """Build Conformer block. + + Args: + configuration: Conformer block configuration. 
+ main_params: Encoder main parameters. + + Returns: + : Conformer block function. + + """ + hidden_size = configuration["hidden_size"] + linear_size = configuration["linear_size"] + + pos_wise_args = ( + hidden_size, + linear_size, + configuration.get("pos_wise_dropout_rate", 0.0), + main_params["pos_wise_act"], + ) + + conv_mod_norm_args = { + "eps": configuration.get("conv_mod_norm_eps", 1e-05), + "momentum": configuration.get("conv_mod_norm_momentum", 0.1), + } + + conv_mod_args = ( + hidden_size, + configuration["conv_mod_kernel_size"], + main_params["conv_mod_act"], + conv_mod_norm_args, + main_params["dynamic_chunk_training"] or main_params["unified_model_training"], + ) + + mult_att_args = ( + configuration.get("heads", 4), + hidden_size, + configuration.get("att_dropout_rate", 0.0), + main_params["simplified_att_score"], + ) + + norm_class, norm_args = get_normalization( + main_params["norm_type"], + eps=configuration.get("norm_eps"), + partial=configuration.get("norm_partial"), + ) + + return lambda: Conformer( + hidden_size, + RelPositionMultiHeadedAttention(*mult_att_args), + PositionwiseFeedForward(*pos_wise_args), + PositionwiseFeedForward(*pos_wise_args), + ConformerConvolution(*conv_mod_args), + norm_class=norm_class, + norm_args=norm_args, + dropout_rate=configuration.get("dropout_rate", 0.0), + ) + + +def build_conv1d_block( + configuration: List[Dict[str, Any]], + causal: bool, +) -> Conv1d: + """Build Conv1d block. + + Args: + configuration: Conv1d block configuration. + + Returns: + : Conv1d block function. 
+ + """ + return lambda: Conv1d( + configuration["input_size"], + configuration["output_size"], + configuration["kernel_size"], + stride=configuration.get("stride", 1), + dilation=configuration.get("dilation", 1), + groups=configuration.get("groups", 1), + bias=configuration.get("bias", True), + relu=configuration.get("relu", True), + batch_norm=configuration.get("batch_norm", False), + causal=causal, + dropout_rate=configuration.get("dropout_rate", 0.0), + ) + + +def build_body_blocks( + configuration: List[Dict[str, Any]], + main_params: Dict[str, Any], + output_size: int, +) -> MultiBlocks: + """Build encoder body blocks. + + Args: + configuration: Body blocks configuration. + main_params: Encoder main parameters. + output_size: Architecture output size. + + Returns: + MultiBlocks function encapsulation all encoder blocks. + + """ + fn_modules = [] + extended_conf = [] + + for c in configuration: + if c.get("num_blocks") is not None: + extended_conf += c["num_blocks"] * [ + {c_i: c[c_i] for c_i in c if c_i != "num_blocks"} + ] + else: + extended_conf += [c] + + for i, c in enumerate(extended_conf): + block_type = c["block_type"] + + if block_type == "branchformer": + module = build_branchformer_block(c, main_params) + elif block_type == "conformer": + module = build_conformer_block(c, main_params) + elif block_type == "conv1d": + module = build_conv1d_block(c, main_params["dynamic_chunk_training"]) + else: + raise NotImplementedError + + fn_modules.append(module) + + return MultiBlocks( + [fn() for fn in fn_modules], + output_size, + norm_class=main_params["after_norm_class"], + norm_args=main_params["after_norm_args"], + ) diff --git a/funasr/models_transducer/encoder/encoder.py b/funasr/models_transducer/encoder/encoder.py new file mode 100644 index 000000000..45c99c1de --- /dev/null +++ b/funasr/models_transducer/encoder/encoder.py @@ -0,0 +1,294 @@ +"""Encoder for Transducer model.""" + +from typing import Any, Dict, List, Tuple + +import torch +from 
typeguard import check_argument_types + +from funasr.models_transducer.encoder.building import ( + build_body_blocks, + build_input_block, + build_main_parameters, + build_positional_encoding, +) +from funasr.models_transducer.encoder.validation import validate_architecture +from funasr.models_transducer.utils import ( + TooShortUttError, + check_short_utt, + make_chunk_mask, + make_source_mask, +) + + +class Encoder(torch.nn.Module): + """Encoder module definition. + + Args: + input_size: Input size. + body_conf: Encoder body configuration. + input_conf: Encoder input configuration. + main_conf: Encoder main configuration. + + """ + + def __init__( + self, + input_size: int, + body_conf: List[Dict[str, Any]], + input_conf: Dict[str, Any] = {}, + main_conf: Dict[str, Any] = {}, + ) -> None: + """Construct an Encoder object.""" + super().__init__() + + assert check_argument_types() + + embed_size, output_size = validate_architecture( + input_conf, body_conf, input_size + ) + main_params = build_main_parameters(**main_conf) + + self.embed = build_input_block(input_size, input_conf) + self.pos_enc = build_positional_encoding(embed_size, main_params) + self.encoders = build_body_blocks(body_conf, main_params, output_size) + + self.output_size = output_size + + self.dynamic_chunk_training = main_params["dynamic_chunk_training"] + self.short_chunk_threshold = main_params["short_chunk_threshold"] + self.short_chunk_size = main_params["short_chunk_size"] + self.left_chunk_size = main_params["left_chunk_size"] + + self.unified_model_training = main_params["unified_model_training"] + self.default_chunk_size = main_params["default_chunk_size"] + self.jitter_range = main_params["jitter_range"] + + self.time_reduction_factor = main_params["time_reduction_factor"] + + def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int: + """Return the corresponding number of sample for a given chunk size, in frames. 
+ + Where size is the number of feature frames after applying subsampling. + + Args: + size: Number of frames after subsampling. + hop_length: Frontend's hop length + + Returns: + : Number of raw samples + + """ + return self.embed.get_size_before_subsampling(size) * hop_length + + def get_encoder_input_size(self, size: int) -> int: + """Return the corresponding number of frames for a given chunk size. + + Where size is the number of feature frames after applying subsampling. + + Args: + size: Number of frames after subsampling. + + Returns: + : Number of frames before subsampling + + """ + return self.embed.get_size_before_subsampling(size) + + + def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: + """Initialize/Reset encoder streaming cache. + + Args: + left_context: Number of frames in left context. + device: Device ID. + + """ + return self.encoders.reset_streaming_cache(left_context, device) + + def forward( + self, + x: torch.Tensor, + x_len: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode input sequences. + + Args: + x: Encoder input features. (B, T_in, F) + x_len: Encoder input features lengths. (B,) + + Returns: + x: Encoder outputs. (B, T_out, D_enc) + x_len: Encoder outputs lengths.
(B,) + + """ + short_status, limit_size = check_short_utt( + self.embed.subsampling_factor, x.size(1) + ) + + if short_status: + raise TooShortUttError( + f"has {x.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + x.size(1), + limit_size, + ) + + mask = make_source_mask(x_len) + if self.unified_model_training: + x, mask = self.embed(x, mask, self.default_chunk_size) + else: + x, mask = self.embed(x, mask) + pos_enc = self.pos_enc(x) + + if self.unified_model_training: + chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item() + chunk_mask = make_chunk_mask( + x.size(1), + chunk_size, + left_chunk_size=self.left_chunk_size, + device=x.device, + ) + x_utt = self.encoders( + x, + pos_enc, + mask, + chunk_mask=None, + ) + x_chunk = self.encoders( + x, + pos_enc, + mask, + chunk_mask=chunk_mask, + ) + + olens = mask.eq(0).sum(1) + if self.time_reduction_factor > 1: + x_utt = x_utt[:,::self.time_reduction_factor,:] + x_chunk = x_chunk[:,::self.time_reduction_factor,:] + olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 + + return x_utt, x_chunk, olens + + elif self.dynamic_chunk_training: + max_len = x.size(1) + chunk_size = torch.randint(1, max_len, (1,)).item() + + if chunk_size > (max_len * self.short_chunk_threshold): + chunk_size = max_len + else: + chunk_size = (chunk_size % self.short_chunk_size) + 1 + + chunk_mask = make_chunk_mask( + x.size(1), + chunk_size, + left_chunk_size=self.left_chunk_size, + device=x.device, + ) + else: + chunk_mask = None + x = self.encoders( + x, + pos_enc, + mask, + chunk_mask=chunk_mask, + ) + + olens = mask.eq(0).sum(1) + if self.time_reduction_factor > 1: + x = x[:,::self.time_reduction_factor,:] + olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 + + return x, olens + + def simu_chunk_forward( + self, + x: torch.Tensor, + x_len: torch.Tensor, + chunk_size: int = 16, + 
left_context: int = 32, + right_context: int = 0, + ) -> torch.Tensor: + short_status, limit_size = check_short_utt( + self.embed.subsampling_factor, x.size(1) + ) + + if short_status: + raise TooShortUttError( + f"has {x.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + x.size(1), + limit_size, + ) + + mask = make_source_mask(x_len) + + x, mask = self.embed(x, mask, chunk_size) + pos_enc = self.pos_enc(x) + chunk_mask = make_chunk_mask( + x.size(1), + chunk_size, + left_chunk_size=self.left_chunk_size, + device=x.device, + ) + + x = self.encoders( + x, + pos_enc, + mask, + chunk_mask=chunk_mask, + ) + olens = mask.eq(0).sum(1) + if self.time_reduction_factor > 1: + x = x[:,::self.time_reduction_factor,:] + + return x + + def chunk_forward( + self, + x: torch.Tensor, + x_len: torch.Tensor, + processed_frames: torch.tensor, + chunk_size: int = 16, + left_context: int = 32, + right_context: int = 0, + ) -> torch.Tensor: + """Encode input sequences as chunks. + + Args: + x: Encoder input features. (1, T_in, F) + x_len: Encoder input features lengths. (1,) + processed_frames: Number of frames already seen. + left_context: Number of frames in left context. + right_context: Number of frames in right context. + + Returns: + x: Encoder outputs. 
(B, T_out, D_enc) + + """ + mask = make_source_mask(x_len) + x, mask = self.embed(x, mask, None) + + if left_context > 0: + processed_mask = ( + torch.arange(left_context, device=x.device) + .view(1, left_context) + .flip(1) + ) + processed_mask = processed_mask >= processed_frames + mask = torch.cat([processed_mask, mask], dim=1) + pos_enc = self.pos_enc(x, left_context=left_context) + x = self.encoders.chunk_forward( + x, + pos_enc, + mask, + chunk_size=chunk_size, + left_context=left_context, + right_context=right_context, + ) + + if right_context > 0: + x = x[:, 0:-right_context, :] + + if self.time_reduction_factor > 1: + x = x[:,::self.time_reduction_factor,:] + return x diff --git a/funasr/models_transducer/encoder/modules/__init__.py b/funasr/models_transducer/encoder/modules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/models_transducer/encoder/modules/attention.py b/funasr/models_transducer/encoder/modules/attention.py new file mode 100644 index 000000000..53e708750 --- /dev/null +++ b/funasr/models_transducer/encoder/modules/attention.py @@ -0,0 +1,246 @@ +"""Multi-Head attention layers with relative positional encoding.""" + +import math +from typing import Optional, Tuple + +import torch + + +class RelPositionMultiHeadedAttention(torch.nn.Module): + """RelPositionMultiHeadedAttention definition. + + Args: + num_heads: Number of attention heads. + embed_size: Embedding size. + dropout_rate: Dropout rate. 
+ + """ + + def __init__( + self, + num_heads: int, + embed_size: int, + dropout_rate: float = 0.0, + simplified_attention_score: bool = False, + ) -> None: + """Construct an MultiHeadedAttention object.""" + super().__init__() + + self.d_k = embed_size // num_heads + self.num_heads = num_heads + + assert self.d_k * num_heads == embed_size, ( + "embed_size (%d) must be divisible by num_heads (%d)", + (embed_size, num_heads), + ) + + self.linear_q = torch.nn.Linear(embed_size, embed_size) + self.linear_k = torch.nn.Linear(embed_size, embed_size) + self.linear_v = torch.nn.Linear(embed_size, embed_size) + + self.linear_out = torch.nn.Linear(embed_size, embed_size) + + if simplified_attention_score: + self.linear_pos = torch.nn.Linear(embed_size, num_heads) + + self.compute_att_score = self.compute_simplified_attention_score + else: + self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False) + + self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k)) + self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + self.compute_att_score = self.compute_attention_score + + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.attn = None + + def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor: + """Compute relative positional encoding. + + Args: + x: Input sequence. (B, H, T_1, 2 * T_1 - 1) + left_context: Number of frames in left context. + + Returns: + x: Output sequence. 
(B, H, T_1, T_2) + + """ + batch_size, n_heads, time1, n = x.shape + time2 = time1 + left_context + + batch_stride, n_heads_stride, time1_stride, n_stride = x.stride() + + return x.as_strided( + (batch_size, n_heads, time1, time2), + (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride), + storage_offset=(n_stride * (time1 - 1)), + ) + + def compute_simplified_attention_score( + self, + query: torch.Tensor, + key: torch.Tensor, + pos_enc: torch.Tensor, + left_context: int = 0, + ) -> torch.Tensor: + """Simplified attention score computation. + + Reference: https://github.com/k2-fsa/icefall/pull/458 + + Args: + query: Transformed query tensor. (B, H, T_1, d_k) + key: Transformed key tensor. (B, H, T_2, d_k) + pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) + left_context: Number of frames in left context. + + Returns: + : Attention score. (B, H, T_1, T_2) + + """ + pos_enc = self.linear_pos(pos_enc) + + matrix_ac = torch.matmul(query, key.transpose(2, 3)) + + matrix_bd = self.rel_shift( + pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1), + left_context=left_context, + ) + + return (matrix_ac + matrix_bd) / math.sqrt(self.d_k) + + def compute_attention_score( + self, + query: torch.Tensor, + key: torch.Tensor, + pos_enc: torch.Tensor, + left_context: int = 0, + ) -> torch.Tensor: + """Attention score computation. + + Args: + query: Transformed query tensor. (B, H, T_1, d_k) + key: Transformed key tensor. (B, H, T_2, d_k) + pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) + left_context: Number of frames in left context. + + Returns: + : Attention score. 
(B, H, T_1, T_2) + + """ + p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k) + + query = query.transpose(1, 2) + q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2) + + matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1)) + + matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1)) + matrix_bd = self.rel_shift(matrix_bd, left_context=left_context) + + return (matrix_ac + matrix_bd) / math.sqrt(self.d_k) + + def forward_qkv( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Transform query, key and value. + + Args: + query: Query tensor. (B, T_1, size) + key: Key tensor. (B, T_2, size) + v: Value tensor. (B, T_2, size) + + Returns: + q: Transformed query tensor. (B, H, T_1, d_k) + k: Transformed key tensor. (B, H, T_2, d_k) + v: Transformed value tensor. (B, H, T_2, d_k) + + """ + n_batch = query.size(0) + + q = ( + self.linear_q(query) + .view(n_batch, -1, self.num_heads, self.d_k) + .transpose(1, 2) + ) + k = ( + self.linear_k(key) + .view(n_batch, -1, self.num_heads, self.d_k) + .transpose(1, 2) + ) + v = ( + self.linear_v(value) + .view(n_batch, -1, self.num_heads, self.d_k) + .transpose(1, 2) + ) + + return q, k, v + + def forward_attention( + self, + value: torch.Tensor, + scores: torch.Tensor, + mask: torch.Tensor, + chunk_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Compute attention context vector. + + Args: + value: Transformed value. (B, H, T_2, d_k) + scores: Attention score. (B, H, T_1, T_2) + mask: Source mask. (B, T_2) + chunk_mask: Chunk mask. (T_1, T_1) + + Returns: + attn_output: Transformed value weighted by attention score. 
(B, T_1, H * d_k) + + """ + batch_size = scores.size(0) + mask = mask.unsqueeze(1).unsqueeze(2) + if chunk_mask is not None: + mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask + scores = scores.masked_fill(mask, float("-inf")) + self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) + + attn_output = self.dropout(self.attn) + attn_output = torch.matmul(attn_output, value) + + attn_output = self.linear_out( + attn_output.transpose(1, 2) + .contiguous() + .view(batch_size, -1, self.num_heads * self.d_k) + ) + + return attn_output + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + chunk_mask: Optional[torch.Tensor] = None, + left_context: int = 0, + ) -> torch.Tensor: + """Compute scaled dot product attention with rel. positional encoding. + + Args: + query: Query tensor. (B, T_1, size) + key: Key tensor. (B, T_2, size) + value: Value tensor. (B, T_2, size) + pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) + mask: Source mask. (B, T_2) + chunk_mask: Chunk mask. (T_1, T_1) + left_context: Number of frames in left context. + + Returns: + : Output tensor. (B, T_1, H * d_k) + + """ + q, k, v = self.forward_qkv(query, key, value) + scores = self.compute_att_score(q, k, pos_enc, left_context=left_context) + return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask) diff --git a/funasr/models_transducer/encoder/modules/convolution.py b/funasr/models_transducer/encoder/modules/convolution.py new file mode 100644 index 000000000..012538a7d --- /dev/null +++ b/funasr/models_transducer/encoder/modules/convolution.py @@ -0,0 +1,196 @@ +"""Convolution modules for X-former blocks.""" + +from typing import Dict, Optional, Tuple + +import torch + + +class ConformerConvolution(torch.nn.Module): + """ConformerConvolution module definition. + + Args: + channels: The number of channels. + kernel_size: Size of the convolving kernel. 
+ activation: Type of activation function. + norm_args: Normalization module arguments. + causal: Whether to use causal convolution (set to True if streaming). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + activation: torch.nn.Module = torch.nn.ReLU(), + norm_args: Dict = {}, + causal: bool = False, + ) -> None: + """Construct an ConformerConvolution object.""" + super().__init__() + + assert (kernel_size - 1) % 2 == 0 + + self.kernel_size = kernel_size + + self.pointwise_conv1 = torch.nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + ) + + if causal: + self.lorder = kernel_size - 1 + padding = 0 + else: + self.lorder = 0 + padding = (kernel_size - 1) // 2 + + self.depthwise_conv = torch.nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + ) + self.norm = torch.nn.BatchNorm1d(channels, **norm_args) + self.pointwise_conv2 = torch.nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + ) + + self.activation = activation + + def forward( + self, + x: torch.Tensor, + cache: Optional[torch.Tensor] = None, + right_context: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute convolution module. + + Args: + x: ConformerConvolution input sequences. (B, T, D_hidden) + cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden) + right_context: Number of frames in right context. + + Returns: + x: ConformerConvolution output sequences. (B, T, D_hidden) + cache: ConformerConvolution output cache. 
(1, conv_kernel, D_hidden) + + """ + x = self.pointwise_conv1(x.transpose(1, 2)) + x = torch.nn.functional.glu(x, dim=1) + + if self.lorder > 0: + if cache is None: + x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) + else: + x = torch.cat([cache, x], dim=2) + + if right_context > 0: + cache = x[:, :, -(self.lorder + right_context) : -right_context] + else: + cache = x[:, :, -self.lorder :] + + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x).transpose(1, 2) + + return x, cache + + +class ConvolutionalSpatialGatingUnit(torch.nn.Module): + """Convolutional Spatial Gating Unit module definition. + + Args: + size: Initial size to determine the number of channels. + kernel_size: Size of the convolving kernel. + norm_class: Normalization module class. + norm_args: Normalization module arguments. + dropout_rate: Dropout rate. + causal: Whether to use causal convolution (set to True if streaming). + + """ + + def __init__( + self, + size: int, + kernel_size: int, + norm_class: torch.nn.Module = torch.nn.LayerNorm, + norm_args: Dict = {}, + dropout_rate: float = 0.0, + causal: bool = False, + ) -> None: + """Construct a ConvolutionalSpatialGatingUnit object.""" + super().__init__() + + channels = size // 2 + + self.kernel_size = kernel_size + + if causal: + self.lorder = kernel_size - 1 + padding = 0 + else: + self.lorder = 0 + padding = (kernel_size - 1) // 2 + + self.conv = torch.nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + ) + + self.norm = norm_class(channels, **norm_args) + self.activation = torch.nn.Identity() + + self.dropout = torch.nn.Dropout(dropout_rate) + + def forward( + self, + x: torch.Tensor, + cache: Optional[torch.Tensor] = None, + right_context: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute convolution module. + + Args: + x: ConvolutionalSpatialGatingUnit input sequences. 
(B, T, D_hidden) + cache: ConvolutionalSpationGatingUnit input cache. + (1, conv_kernel, D_hidden) + right_context: Number of frames in right context. + + Returns: + x: ConvolutionalSpatialGatingUnit output sequences. (B, T, D_hidden // 2) + + """ + x_r, x_g = x.chunk(2, dim=-1) + + x_g = self.norm(x_g).transpose(1, 2) + + if self.lorder > 0: + if cache is None: + x_g = torch.nn.functional.pad(x_g, (self.lorder, 0), "constant", 0.0) + else: + x_g = torch.cat([cache, x_g], dim=2) + + if right_context > 0: + cache = x_g[:, :, -(self.lorder + right_context) : -right_context] + else: + cache = x_g[:, :, -self.lorder :] + + x_g = self.conv(x_g).transpose(1, 2) + + x = self.dropout(x_r * self.activation(x_g)) + + return x, cache diff --git a/funasr/models_transducer/encoder/modules/multi_blocks.py b/funasr/models_transducer/encoder/modules/multi_blocks.py new file mode 100644 index 000000000..14aca8b6d --- /dev/null +++ b/funasr/models_transducer/encoder/modules/multi_blocks.py @@ -0,0 +1,105 @@ +"""MultiBlocks for encoder architecture.""" + +from typing import Dict, List, Optional + +import torch + + +class MultiBlocks(torch.nn.Module): + """MultiBlocks definition. + + Args: + block_list: Individual blocks of the encoder architecture. + output_size: Architecture output size. + norm_class: Normalization module class. + norm_args: Normalization module arguments. + + """ + + def __init__( + self, + block_list: List[torch.nn.Module], + output_size: int, + norm_class: torch.nn.Module = torch.nn.LayerNorm, + norm_args: Optional[Dict] = None, + ) -> None: + """Construct a MultiBlocks object.""" + super().__init__() + + self.blocks = torch.nn.ModuleList(block_list) + self.norm_blocks = norm_class(output_size, **norm_args) + + self.num_blocks = len(block_list) + + def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: + """Initialize/Reset encoder streaming cache. + + Args: + left_context: Number of left frames during chunk-by-chunk inference. 
+ device: Device to use for cache tensor. + + """ + for idx in range(self.num_blocks): + self.blocks[idx].reset_streaming_cache(left_context, device) + + def forward( + self, + x: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + chunk_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward each block of the encoder architecture. + + Args: + x: MultiBlocks input sequences. (B, T, D_block_1) + pos_enc: Positional embedding sequences. + mask: Source mask. (B, T) + chunk_mask: Chunk mask. (T_2, T_2) + + Returns: + x: Output sequences. (B, T, D_block_N) + + """ + for block_index, block in enumerate(self.blocks): + x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask) + + x = self.norm_blocks(x) + + return x + + def chunk_forward( + self, + x: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + chunk_size: int = 0, + left_context: int = 0, + right_context: int = 0, + ) -> torch.Tensor: + """Forward each block of the encoder architecture. + + Args: + x: MultiBlocks input sequences. (B, T, D_block_1) + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att) + mask: Source mask. (B, T_2) + left_context: Number of frames in left context. + right_context: Number of frames in right context. + + Returns: + x: MultiBlocks output sequences. 
(B, T, D_block_N) + + """ + for block_idx, block in enumerate(self.blocks): + x, pos_enc = block.chunk_forward( + x, + pos_enc, + mask, + chunk_size=chunk_size, + left_context=left_context, + right_context=right_context, + ) + + x = self.norm_blocks(x) + + return x diff --git a/funasr/models_transducer/encoder/modules/normalization.py b/funasr/models_transducer/encoder/modules/normalization.py new file mode 100644 index 000000000..ae35fd43f --- /dev/null +++ b/funasr/models_transducer/encoder/modules/normalization.py @@ -0,0 +1,170 @@ +"""Normalization modules for X-former blocks.""" + +from typing import Dict, Optional, Tuple + +import torch + + +def get_normalization( + normalization_type: str, + eps: Optional[float] = None, + partial: Optional[float] = None, +) -> Tuple[torch.nn.Module, Dict]: + """Get normalization module and arguments given parameters. + + Args: + normalization_type: Normalization module type. + eps: Value added to the denominator. + partial: Value defining the part of the input used for RMS stats (RMSNorm). + + Return: + : Normalization module class + : Normalization module arguments + + """ + norm = { + "basic_norm": ( + BasicNorm, + {"eps": eps if eps is not None else 0.25}, + ), + "layer_norm": (torch.nn.LayerNorm, {"eps": eps if eps is not None else 1e-12}), + "rms_norm": ( + RMSNorm, + { + "eps": eps if eps is not None else 1e-05, + "partial": partial if partial is not None else -1.0, + }, + ), + "scale_norm": ( + ScaleNorm, + {"eps": eps if eps is not None else 1e-05}, + ), + } + + return norm[normalization_type] + + +class BasicNorm(torch.nn.Module): + """BasicNorm module definition. + + Reference: https://github.com/k2-fsa/icefall/pull/288 + + Args: + normalized_shape: Expected size. + eps: Value added to the denominator for numerical stability. 
+ + """ + + def __init__( + self, + normalized_shape: int, + eps: float = 0.25, + ) -> None: + """Construct a BasicNorm object.""" + super().__init__() + + self.eps = torch.nn.Parameter(torch.tensor(eps).log().detach()) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Compute basic normalization. + + Args: + x: Input sequences. (B, T, D_hidden) + + Returns: + : Output sequences. (B, T, D_hidden) + + """ + scales = (torch.mean(x.pow(2), dim=-1, keepdim=True) + self.eps.exp()) ** -0.5 + + return x * scales + + +class RMSNorm(torch.nn.Module): + """RMSNorm module definition. + + Reference: https://arxiv.org/pdf/1910.07467.pdf + + Args: + normalized_shape: Expected size. + eps: Value added to the denominator for numerical stability. + partial: Value defining the part of the input used for RMS stats. + + """ + + def __init__( + self, + normalized_shape: int, + eps: float = 1e-5, + partial: float = 0.0, + ) -> None: + """Construct a RMSNorm object.""" + super().__init__() + + self.normalized_shape = normalized_shape + + self.partial = True if 0 < partial < 1 else False + self.p = partial + self.eps = eps + + self.scale = torch.nn.Parameter(torch.ones(normalized_shape)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Compute RMS normalization. + + Args: + x: Input sequences. (B, T, D_hidden) + + Returns: + x: Output sequences. (B, T, D_hidden) + + """ + if self.partial: + partial_size = int(self.normalized_shape * self.p) + partial_x, _ = torch.split( + x, [partial_size, self.normalized_shape - partial_size], dim=-1 + ) + + norm_x = partial_x.norm(2, dim=-1, keepdim=True) + d_x = partial_size + else: + norm_x = x.norm(2, dim=-1, keepdim=True) + d_x = self.normalized_shape + + rms_x = norm_x * d_x ** (-1.0 / 2) + x = self.scale * (x / (rms_x + self.eps)) + + return x + + +class ScaleNorm(torch.nn.Module): + """ScaleNorm module definition. + + Reference: https://arxiv.org/pdf/1910.05895.pdf + + Args: + normalized_shape: Expected size. 
+ eps: Value added to the denominator for numerical stability. + + """ + + def __init__(self, normalized_shape: int, eps: float = 1e-5) -> None: + """Construct a ScaleNorm object.""" + super().__init__() + + self.eps = eps + self.scale = torch.nn.Parameter(torch.tensor(normalized_shape**0.5)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Compute scale normalization. + + Args: + x: Input sequences. (B, T, D_hidden) + + Returns: + : Output sequences. (B, T, D_hidden) + + """ + norm = self.scale / torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps) + + return x * norm diff --git a/funasr/models_transducer/encoder/modules/positional_encoding.py b/funasr/models_transducer/encoder/modules/positional_encoding.py new file mode 100644 index 000000000..5b56e2671 --- /dev/null +++ b/funasr/models_transducer/encoder/modules/positional_encoding.py @@ -0,0 +1,91 @@ +"""Positional encoding modules.""" + +import math + +import torch + +from funasr.modules.embedding import _pre_hook + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding. + + Args: + size: Module size. + max_len: Maximum input length. + dropout_rate: Dropout rate. + + """ + + def __init__( + self, size: int, dropout_rate: float = 0.0, max_len: int = 5000 + ) -> None: + """Construct a RelativePositionalEncoding object.""" + super().__init__() + + self.size = size + + self.pe = None + self.dropout = torch.nn.Dropout(p=dropout_rate) + + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + self._register_load_state_dict_pre_hook(_pre_hook) + + def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None: + """Reset positional encoding. + + Args: + x: Input sequences. (B, T, ?) + left_context: Number of frames in left context. 
+ + """ + time1 = x.size(1) + left_context + + if self.pe is not None: + if self.pe.size(1) >= time1 * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(device=x.device, dtype=x.dtype) + return + + pe_positive = torch.zeros(time1, self.size) + pe_negative = torch.zeros(time1, self.size) + + position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.size, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.size) + ) + + pe_positive[:, 0::2] = torch.sin(position * div_term) + pe_positive[:, 1::2] = torch.cos(position * div_term) + pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0) + + pe_negative[:, 0::2] = torch.sin(-1 * position * div_term) + pe_negative[:, 1::2] = torch.cos(-1 * position * div_term) + pe_negative = pe_negative[1:].unsqueeze(0) + + self.pe = torch.cat([pe_positive, pe_negative], dim=1).to( + dtype=x.dtype, device=x.device + ) + + def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor: + """Compute positional encoding. + + Args: + x: Input sequences. (B, T, ?) + left_context: Number of frames in left context. + + Returns: + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), ?) 
+ + """ + self.extend_pe(x, left_context=left_context) + + time1 = x.size(1) + left_context + + pos_enc = self.pe[ + :, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1) + ] + pos_enc = self.dropout(pos_enc) + + return pos_enc diff --git a/funasr/models_transducer/encoder/sanm_encoder.py b/funasr/models_transducer/encoder/sanm_encoder.py new file mode 100644 index 000000000..9e74bdfeb --- /dev/null +++ b/funasr/models_transducer/encoder/sanm_encoder.py @@ -0,0 +1,835 @@ +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union +import logging +import torch +import torch.nn as nn +from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk +from typeguard import check_argument_types +import numpy as np +from funasr.modules.nets_utils import make_pad_mask +from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM +from funasr.modules.embedding import SinusoidalPositionEncoder +from funasr.modules.layer_norm import LayerNorm +from funasr.modules.multi_layer_conv import Conv1dLinear +from funasr.modules.multi_layer_conv import MultiLayeredConv1d +from funasr.modules.positionwise_feed_forward import ( + PositionwiseFeedForward, # noqa: H301 +) +from funasr.modules.repeat import repeat +from funasr.modules.subsampling import Conv2dSubsampling +from funasr.modules.subsampling import Conv2dSubsampling2 +from funasr.modules.subsampling import Conv2dSubsampling6 +from funasr.modules.subsampling import Conv2dSubsampling8 +from funasr.modules.subsampling import TooShortUttError +from funasr.modules.subsampling import check_short_utt +from funasr.models.ctc import CTC +from funasr.models.encoder.abs_encoder import AbsEncoder + + +class EncoderLayerSANM(nn.Module): + def __init__( + self, + in_size, + size, + self_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + stochastic_depth_rate=0.0, + ): + """Construct an 
EncoderLayer object.""" + super(EncoderLayerSANM, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(in_size) + self.norm2 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.in_size = in_size + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + self.stochastic_depth_rate = stochastic_depth_rate + self.dropout_rate = dropout_rate + + def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None): + """Compute encoded features. + Args: + x_input (torch.Tensor): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). + """ + skip_layer = False + # with stochastic depth, residual connection `x + f(x)` becomes + # `x <- x + 1 / (1 - p) * f(x)` at training time. 
+ stoch_layer_coeff = 1.0 + if self.training and self.stochastic_depth_rate > 0: + skip_layer = torch.rand(1).item() < self.stochastic_depth_rate + stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) + + if skip_layer: + if cache is not None: + x = torch.cat([cache, x], dim=1) + return x, mask + + residual = x + if self.normalize_before: + x = self.norm1(x) + + if self.concat_after: + x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1) + if self.in_size == self.size: + x = residual + stoch_layer_coeff * self.concat_linear(x_concat) + else: + x = stoch_layer_coeff * self.concat_linear(x_concat) + else: + if self.in_size == self.size: + x = residual + stoch_layer_coeff * self.dropout( + self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder) + ) + else: + x = stoch_layer_coeff * self.dropout( + self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder) + ) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + + return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder + +class SANMEncoder(AbsEncoder): + """ + author: Speech Lab, Alibaba Group, China + San-m: Memory equipped self-attention for end-to-end speech recognition + https://arxiv.org/abs/2006.01713 + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + pos_enc_class=SinusoidalPositionEncoder, + normalize_before: bool = True, + concat_after: bool = False, + positionwise_layer_type: str = "linear", + 
positionwise_conv_kernel_size: int = 1, + padding_idx: int = -1, + interctc_layer_idx: List[int] = [], + interctc_use_conditioning: bool = False, + kernel_size : int = 11, + sanm_shfit : int = 0, + tf2torch_tensor_name_prefix_torch: str = "encoder", + tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder", + ): + assert check_argument_types() + super().__init__() + + self.embed = SinusoidalPositionEncoder() + self.normalize_before = normalize_before + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + + encoder_selfattn_layer = MultiHeadedAttentionSANM + encoder_selfattn_layer_args0 = ( + attention_heads, + input_size, + output_size, + attention_dropout_rate, + kernel_size, + sanm_shfit, + ) + + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + output_size, + attention_dropout_rate, + kernel_size, + sanm_shfit, + ) + self.encoders0 = repeat( + 1, + lambda lnum: EncoderLayerSANM( + input_size, + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args0), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + ), + ) + + self.encoders = repeat( + num_blocks-1, + lambda lnum: EncoderLayerSANM( + output_size, + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + 
self.after_norm = LayerNorm(output_size) + + self.interctc_layer_idx = interctc_layer_idx + if len(interctc_layer_idx) > 0: + assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks + self.interctc_use_conditioning = interctc_use_conditioning + self.conditioning_layer = None + self.dropout = nn.Dropout(dropout_rate) + self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch + self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ctc: CTC = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Embed positions in tensor. + Args: + xs_pad: input tensor (B, L, D) + ilens: input length (B) + prev_states: Not to be used now. + Returns: + position embedded tensor and mask + """ + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + xs_pad = xs_pad * self.output_size**0.5 + if self.embed is None: + xs_pad = xs_pad + elif ( + isinstance(self.embed, Conv2dSubsampling) + or isinstance(self.embed, Conv2dSubsampling2) + or isinstance(self.embed, Conv2dSubsampling6) + or isinstance(self.embed, Conv2dSubsampling8) + ): + short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) + if short_status: + raise TooShortUttError( + f"has {xs_pad.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + xs_pad.size(1), + limit_size, + ) + xs_pad, masks = self.embed(xs_pad, masks) + else: + xs_pad = self.embed(xs_pad) + + # xs_pad = self.dropout(xs_pad) + encoder_outs = self.encoders0(xs_pad, masks) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + intermediate_outs = [] + if len(self.interctc_layer_idx) == 0: + encoder_outs = self.encoders(xs_pad, masks) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + else: + for layer_idx, encoder_layer in enumerate(self.encoders): + encoder_outs = encoder_layer(xs_pad, 
masks) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + + if layer_idx + 1 in self.interctc_layer_idx: + encoder_out = xs_pad + + # intermediate outputs are also normalized + if self.normalize_before: + encoder_out = self.after_norm(encoder_out) + + intermediate_outs.append((layer_idx + 1, encoder_out)) + + if self.interctc_use_conditioning: + ctc_out = ctc.softmax(encoder_out) + xs_pad = xs_pad + self.conditioning_layer(ctc_out) + + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + + olens = masks.squeeze(1).sum(1) + if len(intermediate_outs) > 0: + return (xs_pad, intermediate_outs), olens, None + return xs_pad, olens + + def gen_tf2torch_map_dict(self): + tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch + tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf + map_dict_local = { + ## encoder + # cicd + "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (768,256),(1,256,768) + "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (768,),(768,) + "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 2, 0), + }, # 
(256,1,31),(1,31,256,1) + "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (256,256),(1,256,256) + "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + # ffn + "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (1024,256),(1,256,1024) + "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (1024,),(1024,) + "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (256,1024),(1,1024,256) + "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + # out norm + "{}.after_norm.weight".format(tensor_name_prefix_torch): + {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf), + "squeeze": 
None, + "transpose": None, + }, # (256,),(256,) + "{}.after_norm.bias".format(tensor_name_prefix_torch): + {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + + } + + return map_dict_local + + def convert_tf2torch(self, + var_dict_tf, + var_dict_torch, + ): + + map_dict = self.gen_tf2torch_map_dict() + + var_dict_torch_update = dict() + for name in sorted(var_dict_torch.keys(), reverse=False): + names = name.split('.') + if names[0] == self.tf2torch_tensor_name_prefix_torch: + if names[1] == "encoders0": + layeridx = int(names[2]) + name_q = name.replace(".{}.".format(layeridx), ".layeridx.") + + name_q = name_q.replace("encoders0", "encoders") + layeridx_bias = 0 + layeridx += layeridx_bias + if name_q in map_dict.keys(): + name_v = map_dict[name_q]["name"] + name_tf = name_v.replace("layeridx", "{}".format(layeridx)) + data_tf = var_dict_tf[name_tf] + if map_dict[name_q]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) + if map_dict[name_q]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[ + name].size(), + data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info( + "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, + var_dict_tf[name_tf].shape)) + elif names[1] == "encoders": + layeridx = int(names[2]) + name_q = name.replace(".{}.".format(layeridx), ".layeridx.") + layeridx_bias = 1 + layeridx += layeridx_bias + if name_q in map_dict.keys(): + name_v = map_dict[name_q]["name"] + name_tf = name_v.replace("layeridx", "{}".format(layeridx)) + data_tf = var_dict_tf[name_tf] + if map_dict[name_q]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, 
axis=map_dict[name_q]["squeeze"]) + if map_dict[name_q]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[ + name].size(), + data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info( + "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, + var_dict_tf[name_tf].shape)) + + elif names[1] == "after_norm": + name_tf = map_dict[name]["name"] + data_tf = var_dict_tf[name_tf] + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + var_dict_torch_update[name] = data_tf + logging.info( + "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, + var_dict_tf[name_tf].shape)) + + return var_dict_torch_update + + +class SANMEncoderChunkOpt(AbsEncoder): + """ + author: Speech Lab, Alibaba Group, China + SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition + https://arxiv.org/abs/2006.01713 + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + pos_enc_class=SinusoidalPositionEncoder, + normalize_before: bool = True, + concat_after: bool = False, + positionwise_layer_type: str = "linear", + positionwise_conv_kernel_size: int = 1, + padding_idx: int = -1, + interctc_layer_idx: List[int] = [], + interctc_use_conditioning: bool = False, + kernel_size: int = 11, + sanm_shfit: int = 0, + chunk_size: Union[int, Sequence[int]] = (16,), + stride: Union[int, Sequence[int]] = (10,), + pad_left: Union[int, Sequence[int]] = (0,), + time_reduction_factor: int = 1, + encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,), 
+ decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,), + tf2torch_tensor_name_prefix_torch: str = "encoder", + tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder", + ): + assert check_argument_types() + super().__init__() + self.output_size = output_size + + self.embed = SinusoidalPositionEncoder() + + self.normalize_before = normalize_before + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + + encoder_selfattn_layer = MultiHeadedAttentionSANM + encoder_selfattn_layer_args0 = ( + attention_heads, + input_size, + output_size, + attention_dropout_rate, + kernel_size, + sanm_shfit, + ) + + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + output_size, + attention_dropout_rate, + kernel_size, + sanm_shfit, + ) + self.encoders0 = repeat( + 1, + lambda lnum: EncoderLayerSANM( + input_size, + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args0), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + ), + ) + + self.encoders = repeat( + num_blocks - 1, + lambda lnum: EncoderLayerSANM( + output_size, + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(output_size) + + self.interctc_layer_idx = interctc_layer_idx + if 
len(interctc_layer_idx) > 0: + assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks + self.interctc_use_conditioning = interctc_use_conditioning + self.conditioning_layer = None + shfit_fsmn = (kernel_size - 1) // 2 + self.overlap_chunk_cls = overlap_chunk( + chunk_size=chunk_size, + stride=stride, + pad_left=pad_left, + shfit_fsmn=shfit_fsmn, + encoder_att_look_back_factor=encoder_att_look_back_factor, + decoder_att_look_back_factor=decoder_att_look_back_factor, + ) + self.time_reduction_factor = time_reduction_factor + self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch + self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ctc: CTC = None, + ind: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Embed positions in tensor. + Args: + xs_pad: input tensor (B, L, D) + ilens: input length (B) + prev_states: Not to be used now. 
+ Returns: + position embedded tensor and mask + """ + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + xs_pad *= self.output_size ** 0.5 + if self.embed is None: + xs_pad = xs_pad + elif ( + isinstance(self.embed, Conv2dSubsampling) + or isinstance(self.embed, Conv2dSubsampling2) + or isinstance(self.embed, Conv2dSubsampling6) + or isinstance(self.embed, Conv2dSubsampling8) + ): + short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) + if short_status: + raise TooShortUttError( + f"has {xs_pad.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + xs_pad.size(1), + limit_size, + ) + xs_pad, masks = self.embed(xs_pad, masks) + else: + xs_pad = self.embed(xs_pad) + + mask_shfit_chunk, mask_att_chunk_encoder = None, None + if self.overlap_chunk_cls is not None: + ilens = masks.squeeze(1).sum(1) + chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind) + xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs) + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0), + dtype=xs_pad.dtype) + mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device, + xs_pad.size(0), + dtype=xs_pad.dtype) + + encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + intermediate_outs = [] + if len(self.interctc_layer_idx) == 0: + encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + else: + for layer_idx, encoder_layer in enumerate(self.encoders): + encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + if layer_idx + 1 in 
self.interctc_layer_idx: + encoder_out = xs_pad + + # intermediate outputs are also normalized + if self.normalize_before: + encoder_out = self.after_norm(encoder_out) + + intermediate_outs.append((layer_idx + 1, encoder_out)) + + if self.interctc_use_conditioning: + ctc_out = ctc.softmax(encoder_out) + xs_pad = xs_pad + self.conditioning_layer(ctc_out) + + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + + olens = masks.squeeze(1).sum(1) + + xs_pad, olens = self.overlap_chunk_cls.remove_chunk(xs_pad, olens, chunk_outs=None) + + if self.time_reduction_factor > 1: + xs_pad = xs_pad[:,::self.time_reduction_factor,:] + olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 + + if len(intermediate_outs) > 0: + return (xs_pad, intermediate_outs), olens, None + return xs_pad, olens + + def gen_tf2torch_map_dict(self): + tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch + tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf + map_dict_local = { + ## encoder + # cicd + "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (768,256),(1,256,768) + "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (768,),(768,) + 
"{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 2, 0), + }, # (256,1,31),(1,31,256,1) + "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (256,256),(1,256,256) + "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + # ffn + "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (1024,256),(1,256,1024) + "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (1024,),(1024,) + "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (256,1024),(1,1024,256) + "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch): + {"name": 
"{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + # out norm + "{}.after_norm.weight".format(tensor_name_prefix_torch): + {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.after_norm.bias".format(tensor_name_prefix_torch): + {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + + } + + return map_dict_local + + def convert_tf2torch(self, + var_dict_tf, + var_dict_torch, + ): + + map_dict = self.gen_tf2torch_map_dict() + + var_dict_torch_update = dict() + for name in sorted(var_dict_torch.keys(), reverse=False): + names = name.split('.') + if names[0] == self.tf2torch_tensor_name_prefix_torch: + if names[1] == "encoders0": + layeridx = int(names[2]) + name_q = name.replace(".{}.".format(layeridx), ".layeridx.") + + name_q = name_q.replace("encoders0", "encoders") + layeridx_bias = 0 + layeridx += layeridx_bias + if name_q in map_dict.keys(): + name_v = map_dict[name_q]["name"] + name_tf = name_v.replace("layeridx", "{}".format(layeridx)) + data_tf = var_dict_tf[name_tf] + if map_dict[name_q]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) + if map_dict[name_q]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[ + name].size(), + data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info( + "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, + var_dict_tf[name_tf].shape)) + elif names[1] == "encoders": + layeridx = int(names[2]) + name_q = name.replace(".{}.".format(layeridx), ".layeridx.") + layeridx_bias = 1 + layeridx += 
layeridx_bias + if name_q in map_dict.keys(): + name_v = map_dict[name_q]["name"] + name_tf = name_v.replace("layeridx", "{}".format(layeridx)) + data_tf = var_dict_tf[name_tf] + if map_dict[name_q]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) + if map_dict[name_q]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[ + name].size(), + data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info( + "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, + var_dict_tf[name_tf].shape)) + + elif names[1] == "after_norm": + name_tf = map_dict[name]["name"] + data_tf = var_dict_tf[name_tf] + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + var_dict_torch_update[name] = data_tf + logging.info( + "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, + var_dict_tf[name_tf].shape)) + + return var_dict_torch_update diff --git a/funasr/models_transducer/encoder/validation.py b/funasr/models_transducer/encoder/validation.py new file mode 100644 index 000000000..00035363a --- /dev/null +++ b/funasr/models_transducer/encoder/validation.py @@ -0,0 +1,171 @@ +"""Set of methods to validate encoder architecture.""" + +from typing import Any, Dict, List, Tuple + +from funasr.models_transducer.utils import sub_factor_to_params + + +def validate_block_arguments( + configuration: Dict[str, Any], + block_id: int, + previous_block_output: int, +) -> Tuple[int, int]: + """Validate block arguments. + + Args: + configuration: Architecture configuration. + block_id: Block ID. + previous_block_output: Previous block output size. + + Returns: + input_size: Block input size. + output_size: Block output size. 
+ + """ + block_type = configuration.get("block_type") + + if block_type is None: + raise ValueError( + "Block %d in encoder doesn't have a type assigned. " % block_id + ) + + if block_type in ["branchformer", "conformer"]: + if configuration.get("linear_size") is None: + raise ValueError( + "Missing 'linear_size' argument for X-former block (ID: %d)" % block_id + ) + + if configuration.get("conv_mod_kernel_size") is None: + raise ValueError( + "Missing 'conv_mod_kernel_size' argument for X-former block (ID: %d)" + % block_id + ) + + input_size = configuration.get("hidden_size") + output_size = configuration.get("hidden_size") + + elif block_type == "conv1d": + output_size = configuration.get("output_size") + + if output_size is None: + raise ValueError( + "Missing 'output_size' argument for Conv1d block (ID: %d)" % block_id + ) + + if configuration.get("kernel_size") is None: + raise ValueError( + "Missing 'kernel_size' argument for Conv1d block (ID: %d)" % block_id + ) + + input_size = configuration["input_size"] = previous_block_output + else: + raise ValueError("Block type: %s is not supported." % block_type) + + return input_size, output_size + + +def validate_input_block( + configuration: Dict[str, Any], body_first_conf: Dict[str, Any], input_size: int +) -> int: + """Validate input block. + + Args: + configuration: Encoder input block configuration. + body_first_conf: Encoder first body block configuration. + input_size: Encoder input block input size. + + Return: + output_size: Encoder input block output size. 
+ + """ + vgg_like = configuration.get("vgg_like", False) + linear = configuration.get("linear", False) + next_block_type = body_first_conf.get("block_type") + allowed_next_block_type = ["branchformer", "conformer", "conv1d"] + + if next_block_type is None or (next_block_type not in allowed_next_block_type): + return -1 + + if configuration.get("subsampling_factor") is None: + configuration["subsampling_factor"] = 4 + + if vgg_like: + conv_size = configuration.get("conv_size", (64, 128)) + + if isinstance(conv_size, int): + conv_size = (conv_size, conv_size) + else: + conv_size = configuration.get("conv_size", None) + + if isinstance(conv_size, tuple): + conv_size = conv_size[0] + + if next_block_type == "conv1d": + if vgg_like: + output_size = conv_size[1] * ((input_size // 2) // 2) + else: + if conv_size is None: + conv_size = body_first_conf.get("output_size", 64) + + sub_factor = configuration["subsampling_factor"] + + _, _, conv_osize = sub_factor_to_params(sub_factor, input_size) + assert ( + conv_osize > 0 + ), "Conv2D output size is <1 with input size %d and subsampling %d" % ( + input_size, + sub_factor, + ) + + output_size = conv_osize * conv_size + + configuration["output_size"] = None + else: + output_size = body_first_conf.get("hidden_size") + + if conv_size is None: + conv_size = output_size + + configuration["output_size"] = output_size + + configuration["conv_size"] = conv_size + configuration["vgg_like"] = vgg_like + configuration["linear"] = linear + + return output_size + + +def validate_architecture( + input_conf: Dict[str, Any], body_conf: List[Dict[str, Any]], input_size: int +) -> Tuple[int, int]: + """Validate specified architecture is valid. + + Args: + input_conf: Encoder input block configuration. + body_conf: Encoder body blocks configuration. + input_size: Encoder input size. + + Returns: + input_block_osize: Encoder input block output size. + : Encoder body block output size. 
+ + """ + input_block_osize = validate_input_block(input_conf, body_conf[0], input_size) + + cmp_io = [] + + for i, b in enumerate(body_conf): + _io = validate_block_arguments( + b, (i + 1), input_block_osize if i == 0 else cmp_io[i - 1][1] + ) + + cmp_io.append(_io) + + for i in range(1, len(cmp_io)): + if cmp_io[(i - 1)][1] != cmp_io[i][0]: + raise ValueError( + "Output/Input mismatch between blocks %d and %d" + " in the encoder body." % ((i - 1), i) + ) + + return input_block_osize, cmp_io[-1][1] diff --git a/funasr/models_transducer/error_calculator.py b/funasr/models_transducer/error_calculator.py new file mode 100644 index 000000000..17dbf362f --- /dev/null +++ b/funasr/models_transducer/error_calculator.py @@ -0,0 +1,170 @@ +"""Error Calculator module for Transducer.""" + +from typing import List, Optional, Tuple + +import torch + +from funasr.models_transducer.beam_search_transducer import BeamSearchTransducer +from funasr.models_transducer.decoder.abs_decoder import AbsDecoder +from funasr.models_transducer.joint_network import JointNetwork + + +class ErrorCalculator: + """Calculate CER and WER for transducer models. + + Args: + decoder: Decoder module. + joint_network: Joint Network module. + token_list: List of token units. + sym_space: Space symbol. + sym_blank: Blank symbol. + report_cer: Whether to compute CER. + report_wer: Whether to compute WER. 
+ + """ + + def __init__( + self, + decoder: AbsDecoder, + joint_network: JointNetwork, + token_list: List[int], + sym_space: str, + sym_blank: str, + report_cer: bool = False, + report_wer: bool = False, + ) -> None: + """Construct an ErrorCalculatorTransducer object.""" + super().__init__() + + self.beam_search = BeamSearchTransducer( + decoder=decoder, + joint_network=joint_network, + beam_size=1, + search_type="default", + score_norm=False, + ) + + self.decoder = decoder + + self.token_list = token_list + self.space = sym_space + self.blank = sym_blank + + self.report_cer = report_cer + self.report_wer = report_wer + + def __call__( + self, encoder_out: torch.Tensor, target: torch.Tensor + ) -> Tuple[Optional[float], Optional[float]]: + """Calculate sentence-level WER or/and CER score for Transducer model. + + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + target: Target label ID sequences. (B, L) + + Returns: + : Sentence-level CER score. + : Sentence-level WER score. + + """ + cer, wer = None, None + + batchsize = int(encoder_out.size(0)) + + encoder_out = encoder_out.to(next(self.decoder.parameters()).device) + + batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)] + pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest] + + char_pred, char_target = self.convert_to_char(pred, target) + + if self.report_cer: + cer = self.calculate_cer(char_pred, char_target) + + if self.report_wer: + wer = self.calculate_wer(char_pred, char_target) + + return cer, wer + + def convert_to_char( + self, pred: torch.Tensor, target: torch.Tensor + ) -> Tuple[List, List]: + """Convert label ID sequences to character sequences. + + Args: + pred: Prediction label ID sequences. (B, U) + target: Target label ID sequences. (B, L) + + Returns: + char_pred: Prediction character sequences. (B, ?) + char_target: Target character sequences. (B, ?) 
+ + """ + char_pred, char_target = [], [] + + for i, pred_i in enumerate(pred): + char_pred_i = [self.token_list[int(h)] for h in pred_i] + char_target_i = [self.token_list[int(r)] for r in target[i]] + + char_pred_i = "".join(char_pred_i).replace(self.space, " ") + char_pred_i = char_pred_i.replace(self.blank, "") + + char_target_i = "".join(char_target_i).replace(self.space, " ") + char_target_i = char_target_i.replace(self.blank, "") + + char_pred.append(char_pred_i) + char_target.append(char_target_i) + + return char_pred, char_target + + def calculate_cer( + self, char_pred: torch.Tensor, char_target: torch.Tensor + ) -> float: + """Calculate sentence-level CER score. + + Args: + char_pred: Prediction character sequences. (B, ?) + char_target: Target character sequences. (B, ?) + + Returns: + : Average sentence-level CER score. + + """ + import editdistance + + distances, lens = [], [] + + for i, char_pred_i in enumerate(char_pred): + pred = char_pred_i.replace(" ", "") + target = char_target[i].replace(" ", "") + + distances.append(editdistance.eval(pred, target)) + lens.append(len(target)) + + return float(sum(distances)) / sum(lens) + + def calculate_wer( + self, char_pred: torch.Tensor, char_target: torch.Tensor + ) -> float: + """Calculate sentence-level WER score. + + Args: + char_pred: Prediction character sequences. (B, ?) + char_target: Target character sequences. (B, ?) 
+ + Returns: + : Average sentence-level WER score + + """ + import editdistance + + distances, lens = [], [] + + for i, char_pred_i in enumerate(char_pred): + pred = char_pred_i.replace("▁", " ").split() + target = char_target[i].replace("▁", " ").split() + + distances.append(editdistance.eval(pred, target)) + lens.append(len(target)) + + return float(sum(distances)) / sum(lens) diff --git a/funasr/models_transducer/espnet_transducer_model.py b/funasr/models_transducer/espnet_transducer_model.py new file mode 100644 index 000000000..e32f6e350 --- /dev/null +++ b/funasr/models_transducer/espnet_transducer_model.py @@ -0,0 +1,484 @@ +"""ESPnet2 ASR Transducer model.""" + +import logging +from contextlib import contextmanager +from typing import Dict, List, Optional, Tuple, Union + +import torch +from packaging.version import parse as V +from typeguard import check_argument_types + +from funasr.models.frontend.abs_frontend import AbsFrontend +from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.models_transducer.decoder.abs_decoder import AbsDecoder +from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder +from funasr.models_transducer.encoder.encoder import Encoder +from funasr.models_transducer.joint_network import JointNetwork +from funasr.models_transducer.utils import get_transducer_task_io +from funasr.layers.abs_normalize import AbsNormalize +from funasr.torch_utils.device_funcs import force_gatherable +from funasr.train.abs_espnet_model import AbsESPnetModel + +if V(torch.__version__) >= V("1.6.0"): + from torch.cuda.amp import autocast +else: + + @contextmanager + def autocast(enabled=True): + yield + + +class ESPnetASRTransducerModel(AbsESPnetModel): + """ESPnet2ASRTransducerModel module definition. + + Args: + vocab_size: Size of complete vocabulary (w/ EOS and blank included). + token_list: List of token + frontend: Frontend module. + specaug: SpecAugment module. + normalize: Normalization module. 
+ encoder: Encoder module. + decoder: Decoder module. + joint_network: Joint Network module. + transducer_weight: Weight of the Transducer loss. + fastemit_lambda: FastEmit lambda value. + auxiliary_ctc_weight: Weight of auxiliary CTC loss. + auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs. + auxiliary_lm_loss_weight: Weight of auxiliary LM loss. + auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing. + ignore_id: Initial padding ID. + sym_space: Space symbol. + sym_blank: Blank Symbol + report_cer: Whether to report Character Error Rate during validation. + report_wer: Whether to report Word Error Rate during validation. + extract_feats_in_collect_stats: Whether to use extract_feats stats collection. + + """ + + def __init__( + self, + vocab_size: int, + token_list: Union[Tuple[str, ...], List[str]], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], + encoder: Encoder, + decoder: AbsDecoder, + att_decoder: Optional[AbsAttDecoder], + joint_network: JointNetwork, + transducer_weight: float = 1.0, + fastemit_lambda: float = 0.0, + auxiliary_ctc_weight: float = 0.0, + auxiliary_ctc_dropout_rate: float = 0.0, + auxiliary_lm_loss_weight: float = 0.0, + auxiliary_lm_loss_smoothing: float = 0.0, + ignore_id: int = -1, + sym_space: str = "", + sym_blank: str = "", + report_cer: bool = True, + report_wer: bool = True, + extract_feats_in_collect_stats: bool = True, + ) -> None: + """Construct an ESPnetASRTransducerModel object.""" + super().__init__() + + assert check_argument_types() + + # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos) + self.blank_id = 0 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.token_list = token_list.copy() + + self.sym_space = sym_space + self.sym_blank = sym_blank + + self.frontend = frontend + self.specaug = specaug + self.normalize = normalize + + self.encoder = encoder + self.decoder = decoder + 
self.joint_network = joint_network + + self.criterion_transducer = None + self.error_calculator = None + + self.use_auxiliary_ctc = auxiliary_ctc_weight > 0 + self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0 + + if self.use_auxiliary_ctc: + self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size) + self.ctc_dropout_rate = auxiliary_ctc_dropout_rate + + if self.use_auxiliary_lm_loss: + self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size) + self.lm_loss_smoothing = auxiliary_lm_loss_smoothing + + self.transducer_weight = transducer_weight + self.fastemit_lambda = fastemit_lambda + + self.auxiliary_ctc_weight = auxiliary_ctc_weight + self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight + + self.report_cer = report_cer + self.report_wer = report_wer + + self.extract_feats_in_collect_stats = extract_feats_in_collect_stats + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Forward architecture and compute loss(es). + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + text: Label ID sequences. (B, L) + text_lengths: Label ID sequences lengths. (B,) + kwargs: Contains "utts_id". + + Return: + loss: Main loss value. + stats: Task statistics. + weight: Task weights. + + """ + assert text_lengths.dim() == 1, text_lengths.shape + assert ( + speech.shape[0] + == speech_lengths.shape[0] + == text.shape[0] + == text_lengths.shape[0] + ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) + + batch_size = speech.shape[0] + text = text[:, : text_lengths.max()] + + # 1. Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + + # 2. Transducer-related I/O preparation + decoder_in, target, t_len, u_len = get_transducer_task_io( + text, + encoder_out_lens, + ignore_id=self.ignore_id, + ) + + # 3. 
Decoder + self.decoder.set_device(encoder_out.device) + decoder_out = self.decoder(decoder_in, u_len) + + # 4. Joint Network + joint_out = self.joint_network( + encoder_out.unsqueeze(2), decoder_out.unsqueeze(1) + ) + + # 5. Losses + loss_trans, cer_trans, wer_trans = self._calc_transducer_loss( + encoder_out, + joint_out, + target, + t_len, + u_len, + ) + + loss_ctc, loss_lm = 0.0, 0.0 + + if self.use_auxiliary_ctc: + loss_ctc = self._calc_ctc_loss( + encoder_out, + target, + t_len, + u_len, + ) + + if self.use_auxiliary_lm_loss: + loss_lm = self._calc_lm_loss(decoder_out, target) + + loss = ( + self.transducer_weight * loss_trans + + self.auxiliary_ctc_weight * loss_ctc + + self.auxiliary_lm_loss_weight * loss_lm + ) + + stats = dict( + loss=loss.detach(), + loss_transducer=loss_trans.detach(), + aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None, + aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None, + cer_transducer=cer_trans, + wer_transducer=wer_trans, + ) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + + return loss, stats, weight + + def collect_feats( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Dict[str, torch.Tensor]: + """Collect features sequences and features lengths sequences. + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + text: Label ID sequences. (B, L) + text_lengths: Label ID sequences lengths. (B,) + kwargs: Contains "utts_id". + + Return: + {}: "feats": Features sequences. (B, T, D_feats), + "feats_lengths": Features sequences lengths. 
(B,) + + """ + if self.extract_feats_in_collect_stats: + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + else: + # Generate dummy stats if extract_feats_in_collect_stats is False + logging.warning( + "Generating dummy stats for feats and feats_lengths, " + "because encoder_conf.extract_feats_in_collect_stats is " + f"{self.extract_feats_in_collect_stats}" + ) + + feats, feats_lengths = speech, speech_lengths + + return {"feats": feats, "feats_lengths": feats_lengths} + + def encode( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encoder speech sequences. + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + + Return: + encoder_out: Encoder outputs. (B, T, D_enc) + encoder_out_lens: Encoder outputs lengths. (B,) + + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Data augmentation + if self.specaug is not None and self.training: + feats, feats_lengths = self.specaug(feats, feats_lengths) + + # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # 4. Forward encoder + encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths) + + assert encoder_out.size(0) == speech.size(0), ( + encoder_out.size(), + speech.size(0), + ) + assert encoder_out.size(1) <= encoder_out_lens.max(), ( + encoder_out.size(), + encoder_out_lens.max(), + ) + + return encoder_out, encoder_out_lens + + def _extract_feats( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Extract features sequences and features sequences lengths. + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + + Return: + feats: Features sequences. (B, T, D_feats) + feats_lengths: Features sequences lengths. 
(B,) + + """ + assert speech_lengths.dim() == 1, speech_lengths.shape + + # for data-parallel + speech = speech[:, : speech_lengths.max()] + + if self.frontend is not None: + feats, feats_lengths = self.frontend(speech, speech_lengths) + else: + feats, feats_lengths = speech, speech_lengths + + return feats, feats_lengths + + def _calc_transducer_loss( + self, + encoder_out: torch.Tensor, + joint_out: torch.Tensor, + target: torch.Tensor, + t_len: torch.Tensor, + u_len: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]: + """Compute Transducer loss. + + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + joint_out: Joint Network output sequences (B, T, U, D_joint) + target: Target label ID sequences. (B, L) + t_len: Encoder output sequences lengths. (B,) + u_len: Target label ID sequences lengths. (B,) + + Return: + loss_transducer: Transducer loss value. + cer_transducer: Character error rate for Transducer. + wer_transducer: Word Error Rate for Transducer. + + """ + if self.criterion_transducer is None: + try: + # from warprnnt_pytorch import RNNTLoss + # self.criterion_transducer = RNNTLoss( + # reduction="mean", + # fastemit_lambda=self.fastemit_lambda, + # ) + from warp_rnnt import rnnt_loss as RNNTLoss + self.criterion_transducer = RNNTLoss + + except ImportError: + logging.error( + "warp-rnnt was not installed." + "Please consult the installation documentation." 
+ ) + exit(1) + + # loss_transducer = self.criterion_transducer( + # joint_out, + # target, + # t_len, + # u_len, + # ) + log_probs = torch.log_softmax(joint_out, dim=-1) + + loss_transducer = self.criterion_transducer( + log_probs, + target, + t_len, + u_len, + reduction="mean", + blank=self.blank_id, + fastemit_lambda=self.fastemit_lambda, + gather=True, + ) + + if not self.training and (self.report_cer or self.report_wer): + if self.error_calculator is None: + from funasr.models_transducer.error_calculator import ErrorCalculator + + self.error_calculator = ErrorCalculator( + self.decoder, + self.joint_network, + self.token_list, + self.sym_space, + self.sym_blank, + report_cer=self.report_cer, + report_wer=self.report_wer, + ) + + cer_transducer, wer_transducer = self.error_calculator(encoder_out, target) + + return loss_transducer, cer_transducer, wer_transducer + + return loss_transducer, None, None + + def _calc_ctc_loss( + self, + encoder_out: torch.Tensor, + target: torch.Tensor, + t_len: torch.Tensor, + u_len: torch.Tensor, + ) -> torch.Tensor: + """Compute CTC loss. + + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + target: Target label ID sequences. (B, L) + t_len: Encoder output sequences lengths. (B,) + u_len: Target label ID sequences lengths. (B,) + + Return: + loss_ctc: CTC loss value. + + """ + ctc_in = self.ctc_lin( + torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate) + ) + ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1) + + target_mask = target != 0 + ctc_target = target[target_mask].cpu() + + with torch.backends.cudnn.flags(deterministic=True): + loss_ctc = torch.nn.functional.ctc_loss( + ctc_in, + ctc_target, + t_len, + u_len, + zero_infinity=True, + reduction="sum", + ) + loss_ctc /= target.size(0) + + return loss_ctc + + def _calc_lm_loss( + self, + decoder_out: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: + """Compute LM loss. + + Args: + decoder_out: Decoder output sequences.
(B, U, D_dec) + target: Target label ID sequences. (B, L) + + Return: + loss_lm: LM loss value. + + """ + lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size) + lm_target = target.view(-1).type(torch.int64) + + with torch.no_grad(): + true_dist = lm_loss_in.clone() + true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1)) + + # Ignore blank ID (0) + ignore = lm_target == 0 + lm_target = lm_target.masked_fill(ignore, 0) + + true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing)) + + loss_lm = torch.nn.functional.kl_div( + torch.log_softmax(lm_loss_in, dim=1), + true_dist, + reduction="none", + ) + loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size( + 0 + ) + + return loss_lm diff --git a/funasr/models_transducer/espnet_transducer_model_uni_asr.py b/funasr/models_transducer/espnet_transducer_model_uni_asr.py new file mode 100644 index 000000000..2add3fa78 --- /dev/null +++ b/funasr/models_transducer/espnet_transducer_model_uni_asr.py @@ -0,0 +1,485 @@ +"""ESPnet2 ASR Transducer model.""" + +import logging +from contextlib import contextmanager +from typing import Dict, List, Optional, Tuple, Union + +import torch +from packaging.version import parse as V +from typeguard import check_argument_types + +from funasr.models.frontend.abs_frontend import AbsFrontend +from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.models_transducer.decoder.abs_decoder import AbsDecoder +from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder +from funasr.models_transducer.encoder.encoder import Encoder +from funasr.models_transducer.joint_network import JointNetwork +from funasr.models_transducer.utils import get_transducer_task_io +from funasr.layers.abs_normalize import AbsNormalize +from funasr.torch_utils.device_funcs import force_gatherable +from funasr.train.abs_espnet_model import AbsESPnetModel + +if V(torch.__version__) >= V("1.6.0"): + from torch.cuda.amp import 
autocast +else: + + @contextmanager + def autocast(enabled=True): + yield + + +class UniASRTransducerModel(AbsESPnetModel): + """ESPnet2ASRTransducerModel module definition. + + Args: + vocab_size: Size of complete vocabulary (w/ EOS and blank included). + token_list: List of token + frontend: Frontend module. + specaug: SpecAugment module. + normalize: Normalization module. + encoder: Encoder module. + decoder: Decoder module. + joint_network: Joint Network module. + transducer_weight: Weight of the Transducer loss. + fastemit_lambda: FastEmit lambda value. + auxiliary_ctc_weight: Weight of auxiliary CTC loss. + auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs. + auxiliary_lm_loss_weight: Weight of auxiliary LM loss. + auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing. + ignore_id: Initial padding ID. + sym_space: Space symbol. + sym_blank: Blank Symbol + report_cer: Whether to report Character Error Rate during validation. + report_wer: Whether to report Word Error Rate during validation. + extract_feats_in_collect_stats: Whether to use extract_feats stats collection. 
+ + """ + + def __init__( + self, + vocab_size: int, + token_list: Union[Tuple[str, ...], List[str]], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], + encoder, + decoder: AbsDecoder, + att_decoder: Optional[AbsAttDecoder], + joint_network: JointNetwork, + transducer_weight: float = 1.0, + fastemit_lambda: float = 0.0, + auxiliary_ctc_weight: float = 0.0, + auxiliary_ctc_dropout_rate: float = 0.0, + auxiliary_lm_loss_weight: float = 0.0, + auxiliary_lm_loss_smoothing: float = 0.0, + ignore_id: int = -1, + sym_space: str = "", + sym_blank: str = "", + report_cer: bool = True, + report_wer: bool = True, + extract_feats_in_collect_stats: bool = True, + ) -> None: + """Construct an ESPnetASRTransducerModel object.""" + super().__init__() + + assert check_argument_types() + + # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos) + self.blank_id = 0 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.token_list = token_list.copy() + + self.sym_space = sym_space + self.sym_blank = sym_blank + + self.frontend = frontend + self.specaug = specaug + self.normalize = normalize + + self.encoder = encoder + self.decoder = decoder + self.joint_network = joint_network + + self.criterion_transducer = None + self.error_calculator = None + + self.use_auxiliary_ctc = auxiliary_ctc_weight > 0 + self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0 + + if self.use_auxiliary_ctc: + self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size) + self.ctc_dropout_rate = auxiliary_ctc_dropout_rate + + if self.use_auxiliary_lm_loss: + self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size) + self.lm_loss_smoothing = auxiliary_lm_loss_smoothing + + self.transducer_weight = transducer_weight + self.fastemit_lambda = fastemit_lambda + + self.auxiliary_ctc_weight = auxiliary_ctc_weight + self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight + + self.report_cer = report_cer + 
self.report_wer = report_wer + + self.extract_feats_in_collect_stats = extract_feats_in_collect_stats + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + decoding_ind: int = None, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Forward architecture and compute loss(es). + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + text: Label ID sequences. (B, L) + text_lengths: Label ID sequences lengths. (B,) + kwargs: Contains "utts_id". + + Return: + loss: Main loss value. + stats: Task statistics. + weight: Task weights. + + """ + assert text_lengths.dim() == 1, text_lengths.shape + assert ( + speech.shape[0] + == speech_lengths.shape[0] + == text.shape[0] + == text_lengths.shape[0] + ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) + + batch_size = speech.shape[0] + text = text[:, : text_lengths.max()] + + # 1. Encoder + ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind) + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind) + # 2. Transducer-related I/O preparation + decoder_in, target, t_len, u_len = get_transducer_task_io( + text, + encoder_out_lens, + ignore_id=self.ignore_id, + ) + + # 3. Decoder + self.decoder.set_device(encoder_out.device) + decoder_out = self.decoder(decoder_in, u_len) + + # 4. Joint Network + joint_out = self.joint_network( + encoder_out.unsqueeze(2), decoder_out.unsqueeze(1) + ) + + # 5. 
Losses + loss_trans, cer_trans, wer_trans = self._calc_transducer_loss( + encoder_out, + joint_out, + target, + t_len, + u_len, + ) + + loss_ctc, loss_lm = 0.0, 0.0 + + if self.use_auxiliary_ctc: + loss_ctc = self._calc_ctc_loss( + encoder_out, + target, + t_len, + u_len, + ) + + if self.use_auxiliary_lm_loss: + loss_lm = self._calc_lm_loss(decoder_out, target) + + loss = ( + self.transducer_weight * loss_trans + + self.auxiliary_ctc_weight * loss_ctc + + self.auxiliary_lm_loss_weight * loss_lm + ) + + stats = dict( + loss=loss.detach(), + loss_transducer=loss_trans.detach(), + aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None, + aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None, + cer_transducer=cer_trans, + wer_transducer=wer_trans, + ) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + + return loss, stats, weight + + def collect_feats( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Dict[str, torch.Tensor]: + """Collect features sequences and features lengths sequences. + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + text: Label ID sequences. (B, L) + text_lengths: Label ID sequences lengths. (B,) + kwargs: Contains "utts_id". + + Return: + {}: "feats": Features sequences. (B, T, D_feats), + "feats_lengths": Features sequences lengths. 
(B,) + + """ + if self.extract_feats_in_collect_stats: + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + else: + # Generate dummy stats if extract_feats_in_collect_stats is False + logging.warning( + "Generating dummy stats for feats and feats_lengths, " + "because encoder_conf.extract_feats_in_collect_stats is " + f"{self.extract_feats_in_collect_stats}" + ) + + feats, feats_lengths = speech, speech_lengths + + return {"feats": feats, "feats_lengths": feats_lengths} + + def encode( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + ind: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encoder speech sequences. + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + + Return: + encoder_out: Encoder outputs. (B, T, D_enc) + encoder_out_lens: Encoder outputs lengths. (B,) + + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Data augmentation + if self.specaug is not None and self.training: + feats, feats_lengths = self.specaug(feats, feats_lengths) + + # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # 4. Forward encoder + encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths, ind=ind) + + assert encoder_out.size(0) == speech.size(0), ( + encoder_out.size(), + speech.size(0), + ) + assert encoder_out.size(1) <= encoder_out_lens.max(), ( + encoder_out.size(), + encoder_out_lens.max(), + ) + + return encoder_out, encoder_out_lens + + def _extract_feats( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Extract features sequences and features sequences lengths. + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + + Return: + feats: Features sequences. 
(B, T, D_feats) + feats_lengths: Features sequences lengths. (B,) + + """ + assert speech_lengths.dim() == 1, speech_lengths.shape + + # for data-parallel + speech = speech[:, : speech_lengths.max()] + + if self.frontend is not None: + feats, feats_lengths = self.frontend(speech, speech_lengths) + else: + feats, feats_lengths = speech, speech_lengths + + return feats, feats_lengths + + def _calc_transducer_loss( + self, + encoder_out: torch.Tensor, + joint_out: torch.Tensor, + target: torch.Tensor, + t_len: torch.Tensor, + u_len: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]: + """Compute Transducer loss. + + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + joint_out: Joint Network output sequences (B, T, U, D_joint) + target: Target label ID sequences. (B, L) + t_len: Encoder output sequences lengths. (B,) + u_len: Target label ID sequences lengths. (B,) + + Return: + loss_transducer: Transducer loss value. + cer_transducer: Character error rate for Transducer. + wer_transducer: Word Error Rate for Transducer. + + """ + if self.criterion_transducer is None: + try: + # from warprnnt_pytorch import RNNTLoss + # self.criterion_transducer = RNNTLoss( + # reduction="mean", + # fastemit_lambda=self.fastemit_lambda, + # ) + from warp_rnnt import rnnt_loss as RNNTLoss + self.criterion_transducer = RNNTLoss + + except ImportError: + logging.error( + "warp-rnnt was not installed." + "Please consult the installation documentation." 
+ ) + exit(1) + + # loss_transducer = self.criterion_transducer( + # joint_out, + # target, + # t_len, + # u_len, + # ) + log_probs = torch.log_softmax(joint_out, dim=-1) + + loss_transducer = self.criterion_transducer( + log_probs, + target, + t_len, + u_len, + reduction="mean", + blank=self.blank_id, + gather=True, + ) + + if not self.training and (self.report_cer or self.report_wer): + if self.error_calculator is None: + from funasr.models_transducer.error_calculator import ErrorCalculator + + self.error_calculator = ErrorCalculator( + self.decoder, + self.joint_network, + self.token_list, + self.sym_space, + self.sym_blank, + report_cer=self.report_cer, + report_wer=self.report_wer, + ) + + cer_transducer, wer_transducer = self.error_calculator(encoder_out, target) + + return loss_transducer, cer_transducer, wer_transducer + + return loss_transducer, None, None + + def _calc_ctc_loss( + self, + encoder_out: torch.Tensor, + target: torch.Tensor, + t_len: torch.Tensor, + u_len: torch.Tensor, + ) -> torch.Tensor: + """Compute CTC loss. + + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + target: Target label ID sequences. (B, L) + t_len: Encoder output sequences lengths. (B,) + u_len: Target label ID sequences lengths. (B,) + + Return: + loss_ctc: CTC loss value. + + """ + ctc_in = self.ctc_lin( + torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate) + ) + ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1) + + target_mask = target != 0 + ctc_target = target[target_mask].cpu() + + with torch.backends.cudnn.flags(deterministic=True): + loss_ctc = torch.nn.functional.ctc_loss( + ctc_in, + ctc_target, + t_len, + u_len, + zero_infinity=True, + reduction="sum", + ) + loss_ctc /= target.size(0) + + return loss_ctc + + def _calc_lm_loss( + self, + decoder_out: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: + """Compute LM loss. + + Args: + decoder_out: Decoder output sequences. (B, U, D_dec) + target: Target label ID sequences.
(B, L) + + Return: + loss_lm: LM loss value. + + """ + lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size) + lm_target = target.view(-1).type(torch.int64) + + with torch.no_grad(): + true_dist = lm_loss_in.clone() + true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1)) + + # Ignore blank ID (0) + ignore = lm_target == 0 + lm_target = lm_target.masked_fill(ignore, 0) + + true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing)) + + loss_lm = torch.nn.functional.kl_div( + torch.log_softmax(lm_loss_in, dim=1), + true_dist, + reduction="none", + ) + loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size( + 0 + ) + + return loss_lm diff --git a/funasr/models_transducer/espnet_transducer_model_unified.py b/funasr/models_transducer/espnet_transducer_model_unified.py new file mode 100644 index 000000000..efe3f4eb1 --- /dev/null +++ b/funasr/models_transducer/espnet_transducer_model_unified.py @@ -0,0 +1,588 @@ +"""ESPnet2 ASR Transducer model.""" + +import logging +from contextlib import contextmanager +from typing import Dict, List, Optional, Tuple, Union + +import torch +from packaging.version import parse as V +from typeguard import check_argument_types + +from funasr.models.frontend.abs_frontend import AbsFrontend +from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.models_transducer.decoder.abs_decoder import AbsDecoder +from funasr.models_transducer.encoder.encoder import Encoder +from funasr.models_transducer.joint_network import JointNetwork +from funasr.models_transducer.utils import get_transducer_task_io +from funasr.layers.abs_normalize import AbsNormalize +from funasr.torch_utils.device_funcs import force_gatherable +from funasr.train.abs_espnet_model import AbsESPnetModel +from funasr.modules.add_sos_eos import add_sos_eos +from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder +from funasr.modules.nets_utils import th_accuracy +from 
funasr.losses.label_smoothing_loss import ( # noqa: H301 + LabelSmoothingLoss, +) +from funasr.models_transducer.error_calculator import ErrorCalculator +if V(torch.__version__) >= V("1.6.0"): + from torch.cuda.amp import autocast +else: + + @contextmanager + def autocast(enabled=True): + yield + + +class ESPnetASRUnifiedTransducerModel(AbsESPnetModel): + """ESPnet2ASRTransducerModel module definition. + + Args: + vocab_size: Size of complete vocabulary (w/ EOS and blank included). + token_list: List of token + frontend: Frontend module. + specaug: SpecAugment module. + normalize: Normalization module. + encoder: Encoder module. + decoder: Decoder module. + joint_network: Joint Network module. + transducer_weight: Weight of the Transducer loss. + fastemit_lambda: FastEmit lambda value. + auxiliary_ctc_weight: Weight of auxiliary CTC loss. + auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs. + auxiliary_lm_loss_weight: Weight of auxiliary LM loss. + auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing. + ignore_id: Initial padding ID. + sym_space: Space symbol. + sym_blank: Blank Symbol + report_cer: Whether to report Character Error Rate during validation. + report_wer: Whether to report Word Error Rate during validation. + extract_feats_in_collect_stats: Whether to use extract_feats stats collection. 
+ + """ + + def __init__( + self, + vocab_size: int, + token_list: Union[Tuple[str, ...], List[str]], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], + encoder: Encoder, + decoder: AbsDecoder, + att_decoder: Optional[AbsAttDecoder], + joint_network: JointNetwork, + transducer_weight: float = 1.0, + fastemit_lambda: float = 0.0, + auxiliary_ctc_weight: float = 0.0, + auxiliary_att_weight: float = 0.0, + auxiliary_ctc_dropout_rate: float = 0.0, + auxiliary_lm_loss_weight: float = 0.0, + auxiliary_lm_loss_smoothing: float = 0.0, + ignore_id: int = -1, + sym_space: str = "", + sym_blank: str = "", + report_cer: bool = True, + report_wer: bool = True, + sym_sos: str = "", + sym_eos: str = "", + extract_feats_in_collect_stats: bool = True, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + ) -> None: + """Construct an ESPnetASRTransducerModel object.""" + super().__init__() + + assert check_argument_types() + + # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos) + self.blank_id = 0 + + if sym_sos in token_list: + self.sos = token_list.index(sym_sos) + else: + self.sos = vocab_size - 1 + if sym_eos in token_list: + self.eos = token_list.index(sym_eos) + else: + self.eos = vocab_size - 1 + + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.token_list = token_list.copy() + + self.sym_space = sym_space + self.sym_blank = sym_blank + + self.frontend = frontend + self.specaug = specaug + self.normalize = normalize + + self.encoder = encoder + self.decoder = decoder + self.joint_network = joint_network + + self.criterion_transducer = None + self.error_calculator = None + + self.use_auxiliary_ctc = auxiliary_ctc_weight > 0 + self.use_auxiliary_att = auxiliary_att_weight > 0 + self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0 + + if self.use_auxiliary_ctc: + self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size) + self.ctc_dropout_rate = 
auxiliary_ctc_dropout_rate + + if self.use_auxiliary_att: + self.att_decoder = att_decoder + + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + if self.use_auxiliary_lm_loss: + self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size) + self.lm_loss_smoothing = auxiliary_lm_loss_smoothing + + self.transducer_weight = transducer_weight + self.fastemit_lambda = fastemit_lambda + + self.auxiliary_ctc_weight = auxiliary_ctc_weight + self.auxiliary_att_weight = auxiliary_att_weight + self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight + + self.report_cer = report_cer + self.report_wer = report_wer + + self.extract_feats_in_collect_stats = extract_feats_in_collect_stats + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Forward architecture and compute loss(es). + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + text: Label ID sequences. (B, L) + text_lengths: Label ID sequences lengths. (B,) + kwargs: Contains "utts_id". + + Return: + loss: Main loss value. + stats: Task statistics. + weight: Task weights. + + """ + assert text_lengths.dim() == 1, text_lengths.shape + assert ( + speech.shape[0] + == speech_lengths.shape[0] + == text.shape[0] + == text_lengths.shape[0] + ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) + + batch_size = speech.shape[0] + text = text[:, : text_lengths.max()] + #print(speech.shape) + # 1. 
Encoder + encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths) + + loss_att, loss_att_chunk = 0.0, 0.0 + + if self.use_auxiliary_att: + loss_att, _ = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + loss_att_chunk, _ = self._calc_att_loss( + encoder_out_chunk, encoder_out_lens, text, text_lengths + ) + + # 2. Transducer-related I/O preparation + decoder_in, target, t_len, u_len = get_transducer_task_io( + text, + encoder_out_lens, + ignore_id=self.ignore_id, + ) + + # 3. Decoder + self.decoder.set_device(encoder_out.device) + decoder_out = self.decoder(decoder_in, u_len) + + # 4. Joint Network + joint_out = self.joint_network( + encoder_out.unsqueeze(2), decoder_out.unsqueeze(1) + ) + + joint_out_chunk = self.joint_network( + encoder_out_chunk.unsqueeze(2), decoder_out.unsqueeze(1) + ) + + # 5. Losses + loss_trans_utt, cer_trans, wer_trans = self._calc_transducer_loss( + encoder_out, + joint_out, + target, + t_len, + u_len, + ) + + loss_trans_chunk, cer_trans_chunk, wer_trans_chunk = self._calc_transducer_loss( + encoder_out_chunk, + joint_out_chunk, + target, + t_len, + u_len, + ) + + loss_ctc, loss_ctc_chunk, loss_lm = 0.0, 0.0, 0.0 + + if self.use_auxiliary_ctc: + loss_ctc = self._calc_ctc_loss( + encoder_out, + target, + t_len, + u_len, + ) + loss_ctc_chunk = self._calc_ctc_loss( + encoder_out_chunk, + target, + t_len, + u_len, + ) + + if self.use_auxiliary_lm_loss: + loss_lm = self._calc_lm_loss(decoder_out, target) + + loss_trans = loss_trans_utt + loss_trans_chunk + loss_ctc = loss_ctc + loss_ctc_chunk + loss_att = loss_att + loss_att_chunk + + loss = ( + self.transducer_weight * loss_trans + + self.auxiliary_ctc_weight * loss_ctc + + self.auxiliary_att_weight * loss_att + + self.auxiliary_lm_loss_weight * loss_lm + ) + + stats = dict( + loss=loss.detach(), + loss_transducer=loss_trans_utt.detach(), + loss_transducer_chunk=loss_trans_chunk.detach(), + aux_ctc_loss=loss_ctc.detach() if loss_ctc >
0.0 else None, + aux_ctc_loss_chunk=loss_ctc_chunk.detach() if loss_ctc_chunk > 0.0 else None, + aux_att_loss=loss_att.detach() if loss_att > 0.0 else None, + aux_att_loss_chunk=loss_att_chunk.detach() if loss_att_chunk > 0.0 else None, + aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None, + cer_transducer=cer_trans, + wer_transducer=wer_trans, + cer_transducer_chunk=cer_trans_chunk, + wer_transducer_chunk=wer_trans_chunk, + ) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + + return loss, stats, weight + + def collect_feats( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Dict[str, torch.Tensor]: + """Collect features sequences and features lengths sequences. + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + text: Label ID sequences. (B, L) + text_lengths: Label ID sequences lengths. (B,) + kwargs: Contains "utts_id". + + Return: + {}: "feats": Features sequences. (B, T, D_feats), + "feats_lengths": Features sequences lengths. (B,) + + """ + if self.extract_feats_in_collect_stats: + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + else: + # Generate dummy stats if extract_feats_in_collect_stats is False + logging.warning( + "Generating dummy stats for feats and feats_lengths, " + "because encoder_conf.extract_feats_in_collect_stats is " + f"{self.extract_feats_in_collect_stats}" + ) + + feats, feats_lengths = speech, speech_lengths + + return {"feats": feats, "feats_lengths": feats_lengths} + + def encode( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encoder speech sequences. + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + + Return: + encoder_out: Encoder outputs. 
(B, T, D_enc) + encoder_out_lens: Encoder outputs lengths. (B,) + + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Data augmentation + if self.specaug is not None and self.training: + feats, feats_lengths = self.specaug(feats, feats_lengths) + + # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # 4. Forward encoder + encoder_out, encoder_out_chunk, encoder_out_lens = self.encoder(feats, feats_lengths) + + assert encoder_out.size(0) == speech.size(0), ( + encoder_out.size(), + speech.size(0), + ) + assert encoder_out.size(1) <= encoder_out_lens.max(), ( + encoder_out.size(), + encoder_out_lens.max(), + ) + + return encoder_out, encoder_out_chunk, encoder_out_lens + + def _extract_feats( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Extract features sequences and features sequences lengths. + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + + Return: + feats: Features sequences. (B, T, D_feats) + feats_lengths: Features sequences lengths. (B,) + + """ + assert speech_lengths.dim() == 1, speech_lengths.shape + + # for data-parallel + speech = speech[:, : speech_lengths.max()] + + if self.frontend is not None: + feats, feats_lengths = self.frontend(speech, speech_lengths) + else: + feats, feats_lengths = speech, speech_lengths + + return feats, feats_lengths + + def _calc_transducer_loss( + self, + encoder_out: torch.Tensor, + joint_out: torch.Tensor, + target: torch.Tensor, + t_len: torch.Tensor, + u_len: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]: + """Compute Transducer loss. + + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + joint_out: Joint Network output sequences (B, T, U, D_joint) + target: Target label ID sequences. 
(B, L) + t_len: Encoder output sequences lengths. (B,) + u_len: Target label ID sequences lengths. (B,) + + Return: + loss_transducer: Transducer loss value. + cer_transducer: Character error rate for Transducer. + wer_transducer: Word Error Rate for Transducer. + + """ + if self.criterion_transducer is None: + try: + # from warprnnt_pytorch import RNNTLoss + # self.criterion_transducer = RNNTLoss( + # reduction="mean", + # fastemit_lambda=self.fastemit_lambda, + # ) + from warp_rnnt import rnnt_loss as RNNTLoss + self.criterion_transducer = RNNTLoss + + except ImportError: + logging.error( + "warp-rnnt was not installed." + "Please consult the installation documentation." + ) + exit(1) + + # loss_transducer = self.criterion_transducer( + # joint_out, + # target, + # t_len, + # u_len, + # ) + log_probs = torch.log_softmax(joint_out, dim=-1) + + loss_transducer = self.criterion_transducer( + log_probs, + target, + t_len, + u_len, + reduction="mean", + blank=self.blank_id, + fastemit_lambda=self.fastemit_lambda, + gather=True, + ) + + if not self.training and (self.report_cer or self.report_wer): + if self.error_calculator is None: + self.error_calculator = ErrorCalculator( + self.decoder, + self.joint_network, + self.token_list, + self.sym_space, + self.sym_blank, + report_cer=self.report_cer, + report_wer=self.report_wer, + ) + + cer_transducer, wer_transducer = self.error_calculator(encoder_out, target) + + return loss_transducer, cer_transducer, wer_transducer + + return loss_transducer, None, None + + def _calc_ctc_loss( + self, + encoder_out: torch.Tensor, + target: torch.Tensor, + t_len: torch.Tensor, + u_len: torch.Tensor, + ) -> torch.Tensor: + """Compute CTC loss. + + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + target: Target label ID sequences. (B, L) + t_len: Encoder output sequences lengths. (B,) + u_len: Target label ID sequences lengths. (B,) + + Return: + loss_ctc: CTC loss value. 
+ + """ + ctc_in = self.ctc_lin( + torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate) + ) + ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1) + + target_mask = target != 0 + ctc_target = target[target_mask].cpu() + + with torch.backends.cudnn.flags(deterministic=True): + loss_ctc = torch.nn.functional.ctc_loss( + ctc_in, + ctc_target, + t_len, + u_len, + zero_infinity=True, + reduction="sum", + ) + loss_ctc /= target.size(0) + + return loss_ctc + + def _calc_lm_loss( + self, + decoder_out: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: + """Compute LM loss. + + Args: + decoder_out: Decoder output sequences. (B, U, D_dec) + target: Target label ID sequences. (B, L) + + Return: + loss_lm: LM loss value. + + """ + lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size) + lm_target = target.view(-1).type(torch.int64) + + with torch.no_grad(): + true_dist = lm_loss_in.clone() + true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1)) + + # Ignore blank ID (0) + ignore = lm_target == 0 + lm_target = lm_target.masked_fill(ignore, 0) + + true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing)) + + loss_lm = torch.nn.functional.kl_div( + torch.log_softmax(lm_loss_in, dim=1), + true_dist, + reduction="none", + ) + loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size( + 0 + ) + + return loss_lm + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + if hasattr(self, "lang_token_id") and self.lang_token_id is not None: + ys_pad = torch.cat( + [ + self.lang_token_id.repeat(ys_pad.size(0), 1).to(ys_pad.device), + ys_pad, + ], + dim=1, + ) + ys_pad_lens += 1 + + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. 
Forward decoder + decoder_out, _ = self.att_decoder( + encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens + ) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, + ) + + return loss_att, acc_att diff --git a/funasr/models_transducer/joint_network.py b/funasr/models_transducer/joint_network.py new file mode 100644 index 000000000..119dd84a5 --- /dev/null +++ b/funasr/models_transducer/joint_network.py @@ -0,0 +1,62 @@ +"""Transducer joint network implementation.""" + +import torch + +from funasr.models_transducer.activation import get_activation + + +class JointNetwork(torch.nn.Module): + """Transducer joint network module. + + Args: + output_size: Output size. + encoder_size: Encoder output size. + decoder_size: Decoder output size.. + joint_space_size: Joint space size. + joint_act_type: Type of activation for joint network. + **activation_parameters: Parameters for the activation function. + + """ + + def __init__( + self, + output_size: int, + encoder_size: int, + decoder_size: int, + joint_space_size: int = 256, + joint_activation_type: str = "tanh", + **activation_parameters, + ) -> None: + """Construct a JointNetwork object.""" + super().__init__() + + self.lin_enc = torch.nn.Linear(encoder_size, joint_space_size) + self.lin_dec = torch.nn.Linear(decoder_size, joint_space_size, bias=False) + + self.lin_out = torch.nn.Linear(joint_space_size, output_size) + + self.joint_activation = get_activation( + joint_activation_type, **activation_parameters + ) + + def forward( + self, + enc_out: torch.Tensor, + dec_out: torch.Tensor, + project_input: bool = True, + ) -> torch.Tensor: + """Joint computation of encoder and decoder hidden state sequences. 
+ + Args: + enc_out: Expanded encoder output state sequences (B, T, 1, D_enc) + dec_out: Expanded decoder output state sequences (B, 1, U, D_dec) + + Returns: + joint_out: Joint output state sequences. (B, T, U, D_out) + + """ + if project_input: + joint_out = self.joint_activation(self.lin_enc(enc_out) + self.lin_dec(dec_out)) + else: + joint_out = self.joint_activation(enc_out + dec_out) + return self.lin_out(joint_out) diff --git a/funasr/models_transducer/utils.py b/funasr/models_transducer/utils.py new file mode 100644 index 000000000..fd3c531b4 --- /dev/null +++ b/funasr/models_transducer/utils.py @@ -0,0 +1,200 @@ +"""Utility functions for Transducer models.""" + +from typing import List, Tuple + +import torch + + +class TooShortUttError(Exception): + """Raised when the utt is too short for subsampling. + + Args: + message: Error message to display. + actual_size: The size that cannot pass the subsampling. + limit: The size limit for subsampling. + + """ + + def __init__(self, message: str, actual_size: int, limit: int) -> None: + """Construct a TooShortUttError module.""" + super().__init__(message) + + self.actual_size = actual_size + self.limit = limit + + +def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]: + """Check if the input is too short for subsampling. + + Args: + sub_factor: Subsampling factor for Conv2DSubsampling. + size: Input size. + + Returns: + : Whether an error should be sent. + : Size limit for specified subsampling factor. + + """ + if sub_factor == 2 and size < 3: + return True, 7 + elif sub_factor == 4 and size < 7: + return True, 7 + elif sub_factor == 6 and size < 11: + return True, 11 + + return False, -1 + + +def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]: + """Get conv2D second layer parameters for given subsampling factor. + + Args: + sub_factor: Subsampling factor (1/X). + input_size: Input size. + + Returns: + : Kernel size for second convolution. 
+ : Stride for second convolution. + : Conv2DSubsampling output size. + + """ + if sub_factor == 2: + return 3, 1, (((input_size - 1) // 2 - 2)) + elif sub_factor == 4: + return 3, 2, (((input_size - 1) // 2 - 1) // 2) + elif sub_factor == 6: + return 5, 3, (((input_size - 1) // 2 - 2) // 3) + else: + raise ValueError( + "subsampling_factor parameter should be set to either 2, 4 or 6." + ) + + +def make_chunk_mask( + size: int, + chunk_size: int, + left_chunk_size: int = 0, + device: torch.device = None, +) -> torch.Tensor: + """Create chunk mask for the subsequent steps (size, size). + + Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py + + Args: + size: Size of the source mask. + chunk_size: Number of frames in chunk. + left_chunk_size: Size of the left context in chunks (0 means full context). + device: Device for the mask tensor. + + Returns: + mask: Chunk mask. (size, size) + + """ + mask = torch.zeros(size, size, device=device, dtype=torch.bool) + + for i in range(size): + if left_chunk_size <= 0: + start = 0 + else: + start = max((i // chunk_size - left_chunk_size) * chunk_size, 0) + + end = min((i // chunk_size + 1) * chunk_size, size) + mask[i, start:end] = True + + return ~mask + + +def make_source_mask(lengths: torch.Tensor) -> torch.Tensor: + """Create source mask for given lengths. + + Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py + + Args: + lengths: Sequence lengths. (B,) + + Returns: + : Mask for the sequence lengths. (B, max_len) + + """ + max_len = lengths.max() + batch_size = lengths.size(0) + + expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths) + + return expanded_lengths >= lengths.unsqueeze(1) + + +def get_transducer_task_io( + labels: torch.Tensor, + encoder_out_lens: torch.Tensor, + ignore_id: int = -1, + blank_id: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Get Transducer loss I/O. + + Args: + labels: Label ID sequences. 
(B, L) + encoder_out_lens: Encoder output lengths. (B,) + ignore_id: Padding symbol ID. + blank_id: Blank symbol ID. + + Returns: + decoder_in: Decoder inputs. (B, U) + target: Target label ID sequences. (B, U) + t_len: Time lengths. (B,) + u_len: Label lengths. (B,) + + """ + + def pad_list(labels: List[torch.Tensor], padding_value: int = 0): + """Create padded batch of labels from a list of labels sequences. + + Args: + labels: Labels sequences. [B x (?)] + padding_value: Padding value. + + Returns: + labels: Batch of padded labels sequences. (B,) + + """ + batch_size = len(labels) + + padded = ( + labels[0] + .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:]) + .fill_(padding_value) + ) + + for i in range(batch_size): + padded[i, : labels[i].size(0)] = labels[i] + + return padded + + device = labels.device + + labels_unpad = [y[y != ignore_id] for y in labels] + blank = labels[0].new([blank_id]) + + decoder_in = pad_list( + [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id + ).to(device) + + target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device) + + encoder_out_lens = list(map(int, encoder_out_lens)) + t_len = torch.IntTensor(encoder_out_lens).to(device) + + u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device) + + return decoder_in, target, t_len, u_len + +def pad_to_len(t: torch.Tensor, pad_len: int, dim: int): + """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros.""" + if t.size(dim) == pad_len: + return t + else: + pad_size = list(t.shape) + pad_size[dim] = pad_len - t.size(dim) + return torch.cat( + [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim + ) diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py new file mode 100644 index 000000000..3c7a78261 --- /dev/null +++ b/funasr/tasks/asr_transducer.py @@ -0,0 +1,487 @@ +"""ASR Transducer Task.""" + +import argparse +import logging +from typing import Callable, 
Collection, Dict, List, Optional, Tuple + +import numpy as np +import torch +from typeguard import check_argument_types, check_return_type + +from funasr.models.frontend.abs_frontend import AbsFrontend +from funasr.models.frontend.default import DefaultFrontend +from funasr.models.frontend.windowing import SlidingWindow +from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.models.specaug.specaug import SpecAug +from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder +from funasr.models.decoder.transformer_decoder import ( + DynamicConvolution2DTransformerDecoder, + DynamicConvolutionTransformerDecoder, + LightweightConvolution2DTransformerDecoder, + LightweightConvolutionTransformerDecoder, + TransformerDecoder, +) +from funasr.models_transducer.decoder.abs_decoder import AbsDecoder +from funasr.models_transducer.decoder.rnn_decoder import RNNDecoder +from funasr.models_transducer.decoder.stateless_decoder import StatelessDecoder +from funasr.models_transducer.encoder.encoder import Encoder +from funasr.models_transducer.encoder.sanm_encoder import SANMEncoderChunkOpt +from funasr.models_transducer.espnet_transducer_model import ESPnetASRTransducerModel +from funasr.models_transducer.espnet_transducer_model_unified import ESPnetASRUnifiedTransducerModel +from funasr.models_transducer.espnet_transducer_model_uni_asr import UniASRTransducerModel +from funasr.models_transducer.joint_network import JointNetwork +from funasr.layers.abs_normalize import AbsNormalize +from funasr.layers.global_mvn import GlobalMVN +from funasr.layers.utterance_mvn import UtteranceMVN +from funasr.tasks.abs_task import AbsTask +from funasr.text.phoneme_tokenizer import g2p_choices +from funasr.train.class_choices import ClassChoices +from funasr.datasets.collate_fn import CommonCollateFn +from funasr.datasets.preprocessor import CommonPreprocessor +from funasr.train.trainer import Trainer +from funasr.utils.get_default_kwargs import get_default_kwargs 
+from funasr.utils.nested_dict_action import NestedDictAction +from funasr.utils.types import float_or_none, int_or_none, str2bool, str_or_none + +frontend_choices = ClassChoices( + name="frontend", + classes=dict( + default=DefaultFrontend, + sliding_window=SlidingWindow, + ), + type_check=AbsFrontend, + default="default", +) +specaug_choices = ClassChoices( + "specaug", + classes=dict( + specaug=SpecAug, + ), + type_check=AbsSpecAug, + default=None, + optional=True, +) +normalize_choices = ClassChoices( + "normalize", + classes=dict( + global_mvn=GlobalMVN, + utterance_mvn=UtteranceMVN, + ), + type_check=AbsNormalize, + default="utterance_mvn", + optional=True, +) +encoder_choices = ClassChoices( + "encoder", + classes=dict( + encoder=Encoder, + sanm_chunk_opt=SANMEncoderChunkOpt, + ), + default="encoder", +) + +decoder_choices = ClassChoices( + "decoder", + classes=dict( + rnn=RNNDecoder, + stateless=StatelessDecoder, + ), + type_check=AbsDecoder, + default="rnn", +) + +att_decoder_choices = ClassChoices( + "att_decoder", + classes=dict( + transformer=TransformerDecoder, + lightweight_conv=LightweightConvolutionTransformerDecoder, + lightweight_conv2d=LightweightConvolution2DTransformerDecoder, + dynamic_conv=DynamicConvolutionTransformerDecoder, + dynamic_conv2d=DynamicConvolution2DTransformerDecoder, + ), + type_check=AbsAttDecoder, + default=None, + optional=True, +) +class ASRTransducerTask(AbsTask): + """ASR Transducer Task definition.""" + + num_optimizers: int = 1 + + class_choices_list = [ + frontend_choices, + specaug_choices, + normalize_choices, + encoder_choices, + decoder_choices, + att_decoder_choices, + ] + + trainer = Trainer + + @classmethod + def add_task_arguments(cls, parser: argparse.ArgumentParser): + """Add Transducer task arguments. + Args: + cls: ASRTransducerTask object. + parser: Transducer arguments parser. 
+ """ + group = parser.add_argument_group(description="Task related.") + + # required = parser.get_default("required") + # required += ["token_list"] + + group.add_argument( + "--token_list", + type=str_or_none, + default=None, + help="Integer-string mapper for tokens.", + ) + group.add_argument( + "--input_size", + type=int_or_none, + default=None, + help="The number of dimensions for input features.", + ) + group.add_argument( + "--init", + type=str_or_none, + default=None, + help="Type of model initialization to use.", + ) + group.add_argument( + "--model_conf", + action=NestedDictAction, + default=get_default_kwargs(ESPnetASRTransducerModel), + help="The keyword arguments for the model class.", + ) + # group.add_argument( + # "--encoder_conf", + # action=NestedDictAction, + # default={}, + # help="The keyword arguments for the encoder class.", + # ) + group.add_argument( + "--joint_network_conf", + action=NestedDictAction, + default={}, + help="The keyword arguments for the joint network class.", + ) + group = parser.add_argument_group(description="Preprocess related.") + group.add_argument( + "--use_preprocessor", + type=str2bool, + default=True, + help="Whether to apply preprocessing to input data.", + ) + group.add_argument( + "--token_type", + type=str, + default="bpe", + choices=["bpe", "char", "word", "phn"], + help="The type of tokens to use during tokenization.", + ) + group.add_argument( + "--bpemodel", + type=str_or_none, + default=None, + help="The path of the sentencepiece model.", + ) + parser.add_argument( + "--non_linguistic_symbols", + type=str_or_none, + help="The 'non_linguistic_symbols' file path.", + ) + parser.add_argument( + "--cleaner", + type=str_or_none, + choices=[None, "tacotron", "jaconv", "vietnamese"], + default=None, + help="Text cleaner to use.", + ) + parser.add_argument( + "--g2p", + type=str_or_none, + choices=g2p_choices, + default=None, + help="g2p method to use if --token_type=phn.", + ) + parser.add_argument( + 
"--speech_volume_normalize", + type=float_or_none, + default=None, + help="Normalization value for maximum amplitude scaling.", + ) + parser.add_argument( + "--rir_scp", + type=str_or_none, + default=None, + help="The RIR SCP file path.", + ) + parser.add_argument( + "--rir_apply_prob", + type=float, + default=1.0, + help="The probability of the applied RIR convolution.", + ) + parser.add_argument( + "--noise_scp", + type=str_or_none, + default=None, + help="The path of noise SCP file.", + ) + parser.add_argument( + "--noise_apply_prob", + type=float, + default=1.0, + help="The probability of the applied noise addition.", + ) + parser.add_argument( + "--noise_db_range", + type=str, + default="13_15", + help="The range of the noise decibel level.", + ) + for class_choices in cls.class_choices_list: + # Append -- and --_conf. + # e.g. --decoder and --decoder_conf + class_choices.add_arguments(group) + + @classmethod + def build_collate_fn( + cls, args: argparse.Namespace, train: bool + ) -> Callable[ + [Collection[Tuple[str, Dict[str, np.ndarray]]]], + Tuple[List[str], Dict[str, torch.Tensor]], + ]: + """Build collate function. + Args: + cls: ASRTransducerTask object. + args: Task arguments. + train: Training mode. + Return: + : Callable collate function. + """ + assert check_argument_types() + + return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) + + @classmethod + def build_preprocess_fn( + cls, args: argparse.Namespace, train: bool + ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: + """Build pre-processing function. + Args: + cls: ASRTransducerTask object. + args: Task arguments. + train: Training mode. + Return: + : Callable pre-processing function. 
+ """ + assert check_argument_types() + + if args.use_preprocessor: + retval = CommonPreprocessor( + train=train, + token_type=args.token_type, + token_list=args.token_list, + bpemodel=args.bpemodel, + non_linguistic_symbols=args.non_linguistic_symbols, + text_cleaner=args.cleaner, + g2p_type=args.g2p, + rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None, + rir_apply_prob=args.rir_apply_prob + if hasattr(args, "rir_apply_prob") + else 1.0, + noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None, + noise_apply_prob=args.noise_apply_prob + if hasattr(args, "noise_apply_prob") + else 1.0, + noise_db_range=args.noise_db_range + if hasattr(args, "noise_db_range") + else "13_15", + speech_volume_normalize=args.speech_volume_normalize + if hasattr(args, "rir_scp") + else None, + ) + else: + retval = None + + assert check_return_type(retval) + return retval + + @classmethod + def required_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + """Required data depending on task mode. + Args: + cls: ASRTransducerTask object. + train: Training mode. + inference: Inference mode. + Return: + retval: Required task data. + """ + if not inference: + retval = ("speech", "text") + else: + retval = ("speech",) + + return retval + + @classmethod + def optional_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + """Optional data depending on task mode. + Args: + cls: ASRTransducerTask object. + train: Training mode. + inference: Inference mode. + Return: + retval: Optional task data. + """ + retval = () + assert check_return_type(retval) + + return retval + + @classmethod + def build_model(cls, args: argparse.Namespace) -> ESPnetASRTransducerModel: + """Required data depending on task mode. + Args: + cls: ASRTransducerTask object. + args: Task arguments. + Return: + model: ASR Transducer model. 
+ """ + assert check_argument_types() + + if isinstance(args.token_list, str): + with open(args.token_list, encoding="utf-8") as f: + token_list = [line.rstrip() for line in f] + + # Overwriting token_list to keep it as "portable". + args.token_list = list(token_list) + elif isinstance(args.token_list, (tuple, list)): + token_list = list(args.token_list) + else: + raise RuntimeError("token_list must be str or list") + vocab_size = len(token_list) + logging.info(f"Vocabulary size: {vocab_size }") + + # 1. frontend + if args.input_size is None: + # Extract features in the model + frontend_class = frontend_choices.get_class(args.frontend) + frontend = frontend_class(**args.frontend_conf) + input_size = frontend.output_size() + else: + # Give features from data-loader + frontend = None + input_size = args.input_size + + # 2. Data augmentation for spectrogram + if args.specaug is not None: + specaug_class = specaug_choices.get_class(args.specaug) + specaug = specaug_class(**args.specaug_conf) + else: + specaug = None + + # 3. Normalization layer + if args.normalize is not None: + normalize_class = normalize_choices.get_class(args.normalize) + normalize = normalize_class(**args.normalize_conf) + else: + normalize = None + + # 4. Encoder + + if getattr(args, "encoder", None) is not None: + encoder_class = encoder_choices.get_class(args.encoder) + encoder = encoder_class(input_size, **args.encoder_conf) + else: + encoder = Encoder(input_size, **args.encoder_conf) + encoder_output_size = encoder.output_size + + # 5. 
Decoder + decoder_class = decoder_choices.get_class(args.decoder) + decoder = decoder_class( + vocab_size, + **args.decoder_conf, + ) + decoder_output_size = decoder.output_size + + if getattr(args, "att_decoder", None) is not None: + att_decoder_class = att_decoder_choices.get_class(args.att_decoder) + + att_decoder = att_decoder_class( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + **args.att_decoder_conf, + ) + else: + att_decoder = None + + # 6. Joint Network + joint_network = JointNetwork( + vocab_size, + encoder_output_size, + decoder_output_size, + **args.joint_network_conf, + ) + + # 7. Build model + + if getattr(args, "encoder", None) is not None and args.encoder == 'sanm_chunk_opt': + model = UniASRTransducerModel( + vocab_size=vocab_size, + token_list=token_list, + frontend=frontend, + specaug=specaug, + normalize=normalize, + encoder=encoder, + decoder=decoder, + att_decoder=att_decoder, + joint_network=joint_network, + **args.model_conf, + ) + + elif encoder.unified_model_training: + model = ESPnetASRUnifiedTransducerModel( + vocab_size=vocab_size, + token_list=token_list, + frontend=frontend, + specaug=specaug, + normalize=normalize, + encoder=encoder, + decoder=decoder, + att_decoder=att_decoder, + joint_network=joint_network, + **args.model_conf, + ) + + else: + model = ESPnetASRTransducerModel( + vocab_size=vocab_size, + token_list=token_list, + frontend=frontend, + specaug=specaug, + normalize=normalize, + encoder=encoder, + decoder=decoder, + att_decoder=att_decoder, + joint_network=joint_network, + **args.model_conf, + ) + + # 8. 
Initialize model + if args.init is not None: + raise NotImplementedError( + "Currently not supported.", + "Initialization part will be reworked in a short future.", + ) + + #assert check_return_type(model) + + return model From 96bae0153cb04c82d6e7ca7cb9654d55eb987567 Mon Sep 17 00:00:00 2001 From: aky15 Date: Wed, 15 Mar 2023 17:34:34 +0800 Subject: [PATCH 02/14] rnnt bug fix --- funasr/bin/asr_inference_rnnt.py | 145 ++---------------- .../encoder/blocks/conv_input.py | 9 +- funasr/tasks/abs_task.py | 2 +- 3 files changed, 20 insertions(+), 136 deletions(-) diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py index f651f118d..c8a2916c2 100644 --- a/funasr/bin/asr_inference_rnnt.py +++ b/funasr/bin/asr_inference_rnnt.py @@ -31,7 +31,7 @@ from funasr.torch_utils.set_all_random_seed import set_all_random_seed from funasr.utils import config_argparse from funasr.utils.types import str2bool, str2triple_str, str_or_none from funasr.utils.cli_utils import get_commandline_args - +from funasr.models.frontend.wav_frontend import WavFrontend class Speech2Text: """Speech2Text class for Transducer models. 
@@ -62,6 +62,7 @@ class Speech2Text: self, asr_train_config: Union[Path, str] = None, asr_model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, beam_search_config: Dict[str, Any] = None, lm_train_config: Union[Path, str] = None, lm_file: Union[Path, str] = None, @@ -86,11 +87,14 @@ class Speech2Text: super().__init__() assert check_argument_types() - asr_model, asr_train_args = ASRTransducerTask.build_model_from_file( - asr_train_config, asr_model_file, device + asr_train_config, asr_model_file, cmvn_file, device ) + frontend = None + if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: + frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) + if quantize_asr_model: if quantize_modules is not None: if not all([q in ["LSTM", "Linear"] for q in quantize_modules]): @@ -156,7 +160,7 @@ class Speech2Text: tokenizer = build_tokenizer(token_type=token_type) converter = TokenIDConverter(token_list=token_list) logging.info(f"Text tokenizer: {tokenizer}") - + self.asr_model = asr_model self.asr_train_args = asr_train_args self.device = device @@ -181,23 +185,13 @@ class Speech2Text: self.simu_streaming = False self.asr_model.encoder.dynamic_chunk_training = False - self.n_fft = asr_train_args.frontend_conf.get("n_fft", 512) - self.hop_length = asr_train_args.frontend_conf.get("hop_length", 128) - - if asr_train_args.frontend_conf.get("win_length", None) is not None: - self.frontend_window_size = asr_train_args.frontend_conf["win_length"] - else: - self.frontend_window_size = self.n_fft - + self.frontend = frontend self.window_size = self.chunk_size + self.right_context - self._raw_ctx = self.asr_model.encoder.get_encoder_input_raw_size( - self.window_size, self.hop_length - ) + self._ctx = self.asr_model.encoder.get_encoder_input_size( self.window_size ) - #self.last_chunk_length = ( # self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 #) * self.hop_length @@ -218,112 +212,6 @@ 
class Speech2Text: self.num_processed_frames = torch.tensor([[0]], device=self.device) - def apply_frontend( - self, speech: torch.Tensor, is_final: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Forward frontend. - Args: - speech: Speech data. (S) - is_final: Whether speech corresponds to the final (or only) chunk of data. - Returns: - feats: Features sequence. (1, T_in, F) - feats_lengths: Features sequence length. (1, T_in, F) - """ - if self.frontend_cache is not None: - speech = torch.cat([self.frontend_cache["waveform_buffer"], speech], dim=0) - - if is_final: - if self.streaming and speech.size(0) < self.last_chunk_length: - pad = torch.zeros( - self.last_chunk_length - speech.size(0), dtype=speech.dtype - ) - speech = torch.cat([speech, pad], dim=0) - - speech_to_process = speech - waveform_buffer = None - else: - n_frames = ( - speech.size(0) - (self.frontend_window_size - self.hop_length) - ) // self.hop_length - - n_residual = ( - speech.size(0) - (self.frontend_window_size - self.hop_length) - ) % self.hop_length - - speech_to_process = speech.narrow( - 0, - 0, - (self.frontend_window_size - self.hop_length) - + n_frames * self.hop_length, - ) - - waveform_buffer = speech.narrow( - 0, - speech.size(0) - - (self.frontend_window_size - self.hop_length) - - n_residual, - (self.frontend_window_size - self.hop_length) + n_residual, - ).clone() - - speech_to_process = speech_to_process.unsqueeze(0).to( - getattr(torch, self.dtype) - ) - lengths = speech_to_process.new_full( - [1], dtype=torch.long, fill_value=speech_to_process.size(1) - ) - batch = {"speech": speech_to_process, "speech_lengths": lengths} - batch = to_device(batch, device=self.device) - - feats, feats_lengths = self.asr_model._extract_feats(**batch) - if self.asr_model.normalize is not None: - feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths) - - if is_final: - if self.frontend_cache is None: - pass - else: - feats = feats.narrow( - 1, - math.ceil( - 
math.ceil(self.frontend_window_size / self.hop_length) / 2 - ), - feats.size(1) - - math.ceil( - math.ceil(self.frontend_window_size / self.hop_length) / 2 - ), - ) - else: - if self.frontend_cache is None: - feats = feats.narrow( - 1, - 0, - feats.size(1) - - math.ceil( - math.ceil(self.frontend_window_size / self.hop_length) / 2 - ), - ) - else: - feats = feats.narrow( - 1, - math.ceil( - math.ceil(self.frontend_window_size / self.hop_length) / 2 - ), - feats.size(1) - - 2 - * math.ceil( - math.ceil(self.frontend_window_size / self.hop_length) / 2 - ), - ) - - feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) - - if is_final: - self.frontend_cache = None - else: - self.frontend_cache = {"waveform_buffer": waveform_buffer} - - return feats, feats_lengths - @torch.no_grad() def streaming_decode( self, @@ -410,14 +298,9 @@ class Speech2Text: if isinstance(speech, np.ndarray): speech = torch.tensor(speech) - # lengths: (1,) - # feats, feats_length = self.apply_frontend(speech) feats = speech.unsqueeze(0).to(getattr(torch, self.dtype)) - # lengths: (1,) feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) - # print(feats.shape) - # print(feats_lengths) if self.asr_model.normalize is not None: feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths) @@ -495,6 +378,7 @@ def inference( data_path_and_name_and_type: Sequence[Tuple[str, str, str]], asr_train_config: Optional[str], asr_model_file: Optional[str], + cmvn_file: Optional[str], beam_search_config: Optional[dict], lm_train_config: Optional[str], lm_file: Optional[str], @@ -562,7 +446,6 @@ def inference( device = "cuda" else: device = "cpu" - # 1. 
Set random-seed set_all_random_seed(seed) @@ -570,6 +453,7 @@ def inference( speech2text_kwargs = dict( asr_train_config=asr_train_config, asr_model_file=asr_model_file, + cmvn_file=cmvn_file, beam_search_config=beam_search_config, lm_train_config=lm_train_config, lm_file=lm_file, @@ -719,6 +603,11 @@ def get_parser(): type=str, help="ASR model parameter file", ) + group.add_argument( + "--cmvn_file", + type=str, + help="Global cmvn file", + ) group.add_argument( "--lm_train_config", type=str, diff --git a/funasr/models_transducer/encoder/blocks/conv_input.py b/funasr/models_transducer/encoder/blocks/conv_input.py index 931d0f0eb..c68c73b3d 100644 --- a/funasr/models_transducer/encoder/blocks/conv_input.py +++ b/funasr/models_transducer/encoder/blocks/conv_input.py @@ -120,7 +120,7 @@ class ConvInput(torch.nn.Module): self.create_new_mask = self.create_new_conv2d_mask self.vgg_like = vgg_like - self.min_frame_length = 2 + self.min_frame_length = 7 if output_size is not None: self.output = torch.nn.Linear(output_proj, output_size) @@ -218,9 +218,4 @@ class ConvInput(torch.nn.Module): : Number of frames before subsampling. 
""" - if self.subsampling_factor > 1: - if self.vgg_like: - return ((size * 2) * self.stride_1) + 1 - - return ((size + 2) * 2) + (self.kernel_2 - 1) * self.stride_2 - return size + return size * self.subsampling_factor diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py index e0884cef6..cc5b70886 100644 --- a/funasr/tasks/abs_task.py +++ b/funasr/tasks/abs_task.py @@ -1576,7 +1576,7 @@ class AbsTask(ABC): preprocess=iter_options.preprocess_fn, max_cache_size=iter_options.max_cache_size, max_cache_fd=iter_options.max_cache_fd, - dest_sample_rate=args.frontend_conf["fs"], + dest_sample_rate=args.frontend_conf["fs"] if args.frontend_conf else 16000, ) cls.check_task_requirements( dataset, args.allow_variable_data_keys, train=iter_options.train From 3e333c0abf31825e84d9673faf5e77601ced1112 Mon Sep 17 00:00:00 2001 From: aky15 Date: Thu, 16 Mar 2023 16:49:03 +0800 Subject: [PATCH 03/14] space between tokens --- funasr/models_transducer/error_calculator.py | 1 - .../models_transducer/espnet_transducer_model_unified.py | 4 ++-- funasr/tasks/asr_transducer.py | 7 +++++++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/funasr/models_transducer/error_calculator.py b/funasr/models_transducer/error_calculator.py index 17dbf362f..34b1dc74e 100644 --- a/funasr/models_transducer/error_calculator.py +++ b/funasr/models_transducer/error_calculator.py @@ -137,7 +137,6 @@ class ErrorCalculator: for i, char_pred_i in enumerate(char_pred): pred = char_pred_i.replace(" ", "") target = char_target[i].replace(" ", "") - distances.append(editdistance.eval(pred, target)) lens.append(len(target)) diff --git a/funasr/models_transducer/espnet_transducer_model_unified.py b/funasr/models_transducer/espnet_transducer_model_unified.py index efe3f4eb1..6df86f892 100644 --- a/funasr/models_transducer/espnet_transducer_model_unified.py +++ b/funasr/models_transducer/espnet_transducer_model_unified.py @@ -455,7 +455,8 @@ class 
ESPnetASRUnifiedTransducerModel(AbsESPnetModel): gather=True, ) - if not self.training and (self.report_cer or self.report_wer): + #if not self.training and (self.report_cer or self.report_wer): + if self.report_cer or self.report_wer: if self.error_calculator is None: self.error_calculator = ErrorCalculator( self.decoder, @@ -468,7 +469,6 @@ class ESPnetASRUnifiedTransducerModel(AbsESPnetModel): ) cer_transducer, wer_transducer = self.error_calculator(encoder_out, target) - return loss_transducer, cer_transducer, wer_transducer return loss_transducer, None, None diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py index 3c7a78261..be1445590 100644 --- a/funasr/tasks/asr_transducer.py +++ b/funasr/tasks/asr_transducer.py @@ -137,6 +137,12 @@ class ASRTransducerTask(AbsTask): default=None, help="Integer-string mapper for tokens.", ) + group.add_argument( + "--split_with_space", + type=str2bool, + default=True, + help="whether to split text using ", + ) group.add_argument( "--input_size", type=int_or_none, @@ -289,6 +295,7 @@ class ASRTransducerTask(AbsTask): non_linguistic_symbols=args.non_linguistic_symbols, text_cleaner=args.cleaner, g2p_type=args.g2p, + split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False, rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None, rir_apply_prob=args.rir_apply_prob if hasattr(args, "rir_apply_prob") From fc9595625855be5b63f86a38ac785e49c142c0ae Mon Sep 17 00:00:00 2001 From: aky15 Date: Tue, 21 Mar 2023 14:10:03 +0800 Subject: [PATCH 04/14] embed debug --- .../encoder/blocks/conv_input.py | 15 ++++++++------- funasr/models_transducer/encoder/encoder.py | 12 +++++++----- .../espnet_transducer_model_unified.py | 3 +-- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/funasr/models_transducer/encoder/blocks/conv_input.py b/funasr/models_transducer/encoder/blocks/conv_input.py index c68c73b3d..ffec93e5e 100644 --- 
a/funasr/models_transducer/encoder/blocks/conv_input.py +++ b/funasr/models_transducer/encoder/blocks/conv_input.py @@ -146,30 +146,31 @@ class ConvInput(torch.nn.Module): if mask is not None: mask = self.create_new_mask(mask) olens = max(mask.eq(0).sum(1)) - - b, t_input, f = x.size() + + b, t, f = x.size() x = x.unsqueeze(1) # (b. 1. t. f) + if chunk_size is not None: max_input_length = int( - chunk_size * self.subsampling_factor * (math.ceil(float(t_input) / (chunk_size * self.subsampling_factor) )) + chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor) )) ) x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x) x = list(x) x = torch.stack(x, dim=0) N_chunks = max_input_length // ( chunk_size * self.subsampling_factor) x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f) + x = self.conv(x) - _, c, t, f = x.size() - + _, c, _, f = x.size() if chunk_size is not None: x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:] else: - x = x.transpose(1, 2).contiguous().view(b, t, c * f) + x = x.transpose(1, 2).contiguous().view(b, -1, c * f) if self.output is not None: x = self.output(x) - + return x, mask[:,:olens][:,:x.size(1)] def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor: diff --git a/funasr/models_transducer/encoder/encoder.py b/funasr/models_transducer/encoder/encoder.py index 45c99c1de..b486a113f 100644 --- a/funasr/models_transducer/encoder/encoder.py +++ b/funasr/models_transducer/encoder/encoder.py @@ -134,14 +134,11 @@ class Encoder(torch.nn.Module): ) mask = make_source_mask(x_len) - if self.unified_model_training: - x, mask = self.embed(x, mask, self.default_chunk_size) - else: - x, mask = self.embed(x, mask) - pos_enc = self.pos_enc(x) if self.unified_model_training: chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item() + x, mask = self.embed(x, mask, chunk_size) + pos_enc = self.pos_enc(x) 
chunk_mask = make_chunk_mask( x.size(1), chunk_size, @@ -178,6 +175,9 @@ class Encoder(torch.nn.Module): else: chunk_size = (chunk_size % self.short_chunk_size) + 1 + x, mask = self.embed(x, mask, chunk_size) + pos_enc = self.pos_enc(x) + chunk_mask = make_chunk_mask( x.size(1), chunk_size, @@ -185,6 +185,8 @@ class Encoder(torch.nn.Module): device=x.device, ) else: + x, mask = self.embed(x, mask, None) + pos_enc = self.pos_enc(x) chunk_mask = None x = self.encoders( x, diff --git a/funasr/models_transducer/espnet_transducer_model_unified.py b/funasr/models_transducer/espnet_transducer_model_unified.py index 6df86f892..be61e8381 100644 --- a/funasr/models_transducer/espnet_transducer_model_unified.py +++ b/funasr/models_transducer/espnet_transducer_model_unified.py @@ -455,8 +455,7 @@ class ESPnetASRUnifiedTransducerModel(AbsESPnetModel): gather=True, ) - #if not self.training and (self.report_cer or self.report_wer): - if self.report_cer or self.report_wer: + if not self.training and (self.report_cer or self.report_wer): if self.error_calculator is None: self.error_calculator = ErrorCalculator( self.decoder, From 8a100b731efba8c18f7e7b6cb1cb04ded94248b3 Mon Sep 17 00:00:00 2001 From: aky15 Date: Tue, 21 Mar 2023 14:52:15 +0800 Subject: [PATCH 05/14] add aishell-1 rnnt egs --- egs/aishell/rnnt/README.md | 17 ++ .../conf/decode_rnnt_conformer_streaming.yaml | 8 + .../decode_rnnt_conformer_streaming_simu.yaml | 5 + .../conf/train_conformer_rnnt_unified.yaml | 84 ++++++ egs/aishell/rnnt/local/aishell_data_prep.sh | 66 +++++ egs/aishell/rnnt/path.sh | 5 + egs/aishell/rnnt/run.sh | 247 ++++++++++++++++++ egs/aishell/rnnt/utils | 1 + 8 files changed, 433 insertions(+) create mode 100644 egs/aishell/rnnt/README.md create mode 100644 egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming.yaml create mode 100644 egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming_simu.yaml create mode 100644 egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml create mode 100755 
egs/aishell/rnnt/local/aishell_data_prep.sh create mode 100644 egs/aishell/rnnt/path.sh create mode 100755 egs/aishell/rnnt/run.sh create mode 120000 egs/aishell/rnnt/utils diff --git a/egs/aishell/rnnt/README.md b/egs/aishell/rnnt/README.md new file mode 100644 index 000000000..4d6ac9de3 --- /dev/null +++ b/egs/aishell/rnnt/README.md @@ -0,0 +1,17 @@ + +# Streaming RNN-T Result + +## Training Config +- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment +- Train config: conf/train_conformer_rnnt_unified +- chunk config: chunk size 16, 1 left chunk +- LM config: LM was not used +- Model size: 90M + +## Results (CER) +- Decode config: conf/train_conformer_rnnt_unified.yaml + +| testset | CER(%) | +|:-----------:|:-------:| +| dev | 5.89 | +| test | 6.76 | diff --git a/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming.yaml b/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming.yaml new file mode 100644 index 000000000..26e43c64d --- /dev/null +++ b/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming.yaml @@ -0,0 +1,8 @@ +# The conformer transducer decoding configuration from @jeon30c +beam_size: 10 +simu_streaming: false +streaming: true +chunk_size: 16 +left_context: 16 +right_context: 0 + diff --git a/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming_simu.yaml b/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming_simu.yaml new file mode 100644 index 000000000..dc3eff2a5 --- /dev/null +++ b/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming_simu.yaml @@ -0,0 +1,5 @@ +# The conformer transducer decoding configuration from @jeon30c +beam_size: 10 +simu_streaming: true +streaming: false +chunk_size: 16 diff --git a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml new file mode 100644 index 000000000..ef37b97eb --- /dev/null +++ b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml @@ -0,0 +1,84 @@ +encoder_conf: + main_conf: + pos_wise_act_type: 
swish + pos_enc_dropout_rate: 0.3 + conv_mod_act_type: swish + time_reduction_factor: 2 + unified_model_training: true + default_chunk_size: 16 + jitter_range: 4 + left_chunk_size: 1 + input_conf: + block_type: conv2d + conv_size: 512 + subsampling_factor: 4 + num_frame: 1 + body_conf: + - block_type: conformer + linear_size: 2048 + hidden_size: 512 + heads: 8 + dropout_rate: 0.3 + pos_wise_dropout_rate: 0.3 + att_dropout_rate: 0.3 + conv_mod_kernel_size: 15 + num_blocks: 12 + +# decoder related +decoder: rnn +decoder_conf: + embed_size: 512 + hidden_size: 512 + embed_dropout_rate: 0.2 + dropout_rate: 0.1 + +joint_network_conf: + joint_space_size: 512 + +# Auxiliary CTC +model_conf: + auxiliary_ctc_weight: 0.0 + +# minibatch related +use_amp: true +batch_type: numel +batch_bins: 1600000 +num_workers: 16 + +# optimization related +accum_grad: 1 +grad_clip: 5 +max_epoch: 80 +val_scheduler_criterion: + - valid + - loss +best_model_criterion: +- - valid + - cer_transducer_chunk + - min +keep_nbest_models: 5 + +optim: adam +optim_conf: + lr: 0.0003 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 25000 + +normalize: None + +specaug: specaug +specaug_conf: + apply_time_warp: true + time_warp_window: 5 + time_warp_mode: bicubic + apply_freq_mask: true + freq_mask_width_range: + - 0 + - 30 + num_freq_mask: 2 + apply_time_mask: true + time_mask_width_range: + - 0 + - 40 + num_time_mask: 2 diff --git a/egs/aishell/rnnt/local/aishell_data_prep.sh b/egs/aishell/rnnt/local/aishell_data_prep.sh new file mode 100755 index 000000000..83f489b3c --- /dev/null +++ b/egs/aishell/rnnt/local/aishell_data_prep.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +# Copyright 2017 Xingyu Na +# Apache 2.0 + +#. 
./path.sh || exit 1; + +if [ $# != 3 ]; then + echo "Usage: $0 " + echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data" + exit 1; +fi + +aishell_audio_dir=$1 +aishell_text=$2/aishell_transcript_v0.8.txt +output_dir=$3 + +train_dir=$output_dir/data/local/train +dev_dir=$output_dir/data/local/dev +test_dir=$output_dir/data/local/test +tmp_dir=$output_dir/data/local/tmp + +mkdir -p $train_dir +mkdir -p $dev_dir +mkdir -p $test_dir +mkdir -p $tmp_dir + +# data directory check +if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then + echo "Error: $0 requires two directory arguments" + exit 1; +fi + +# find wav audio file for train, dev and test resp. +find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist +n=`cat $tmp_dir/wav.flist | wc -l` +[ $n -ne 141925 ] && \ + echo Warning: expected 141925 data data files, found $n + +grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1; +grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1; +grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1; + +rm -r $tmp_dir + +# Transcriptions preparation +for dir in $train_dir $dev_dir $test_dir; do + echo Preparing $dir transcriptions + sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list + paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all + utils/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt + awk '{print $1}' $dir/transcripts.txt > $dir/utt.list + utils/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp + sort -u $dir/transcripts.txt > $dir/text +done + +mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test + +for f in wav.scp text; do + cp $train_dir/$f $output_dir/data/train/$f || exit 1; + cp $dev_dir/$f $output_dir/data/dev/$f || exit 1; + cp $test_dir/$f $output_dir/data/test/$f || exit 1; +done + +echo "$0: AISHELL data preparation succeeded" +exit 0; diff --git 
a/egs/aishell/rnnt/path.sh b/egs/aishell/rnnt/path.sh new file mode 100644 index 000000000..7972642d0 --- /dev/null +++ b/egs/aishell/rnnt/path.sh @@ -0,0 +1,5 @@ +export FUNASR_DIR=$PWD/../../.. + +# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PATH=$FUNASR_DIR/funasr/bin:$PATH diff --git a/egs/aishell/rnnt/run.sh b/egs/aishell/rnnt/run.sh new file mode 100755 index 000000000..bcd4a8b9f --- /dev/null +++ b/egs/aishell/rnnt/run.sh @@ -0,0 +1,247 @@ +#!/usr/bin/env bash + +. ./path.sh || exit 1; + +# machines configuration +CUDA_VISIBLE_DEVICES="0,1,2,3" +gpu_num=4 +count=1 +gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding +# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob +njob=5 +train_cmd=utils/run.pl +infer_cmd=utils/run.pl + +# general configuration +feats_dir= #feature output dictionary +exp_dir= +lang=zh +dumpdir=dump/fbank +feats_type=fbank +token_type=char +scp=feats.scp +type=kaldi_ark +stage=0 +stop_stage=4 + +# feature configuration +feats_dim=80 +sample_frequency=16000 +nj=32 +speed_perturb="0.9,1.0,1.1" + +# data +data_aishell= + +# exp tag +tag="exp1" + +. utils/parse_options.sh || exit 1; + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 
'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +train_set=train +valid_set=dev +test_sets="dev test" + +asr_config=conf/train_conformer_rnnt_unified.yaml +model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}" + +inference_config=conf/decode_rnnt_conformer_streaming.yaml +inference_asr_model=valid.cer_transducer_chunk.ave_5best.pth + +# you can set gpu num for decoding here +gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default +ngpu=$(echo $gpuid_list | awk -F "," '{print NF}') + +if ${gpu_inference}; then + inference_nj=$[${ngpu}*${njob}] + _ngpu=1 +else + inference_nj=$njob + _ngpu=0 +fi + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + echo "stage 0: Data preparation" + # Data preparation + local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir} + for x in train dev test; do + cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org + paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \ + > ${feats_dir}/data/${x}/text + utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org + mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text + done +fi + +feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir} +feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir} +feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir} +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "stage 1: Feature Generation" + # compute fbank features + fbankdir=${feats_dir}/fbank + utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \ + ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train + utils/fix_data_feat.sh 
${fbankdir}/train + utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \ + ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev + utils/fix_data_feat.sh ${fbankdir}/dev + utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \ + ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test + utils/fix_data_feat.sh ${fbankdir}/test + + # compute global cmvn + utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \ + ${fbankdir}/train ${exp_dir}/exp/make_fbank/train + + # apply cmvn + utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \ + ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir} + utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \ + ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir} + utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \ + ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir} + + cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir} + cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir} + cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir} + + utils/fix_data_feat.sh ${feat_train_dir} + utils/fix_data_feat.sh ${feat_dev_dir} + utils/fix_data_feat.sh ${feat_test_dir} + + #generate ark list + utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir} + utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir} +fi + +token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt +echo "dictionary: ${token_list}" +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "stage 2: Dictionary Preparation" + mkdir -p 
${feats_dir}/data/${lang}_token_list/char/ + + echo "make a dictionary" + echo "" > ${token_list} + echo "" >> ${token_list} + echo "" >> ${token_list} + utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \ + | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list} + num_token=$(cat ${token_list} | wc -l) + echo "" >> ${token_list} + vocab_size=$(cat ${token_list} | wc -l) + awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char + awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char + mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train + mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev + cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train + cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev +fi + +# Training Stage +world_size=$gpu_num # run on one machine +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "stage 3: Training" + mkdir -p ${exp_dir}/exp/${model_dir} + mkdir -p ${exp_dir}/exp/${model_dir}/log + INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $gpu_num; ++i)); do + { + rank=$i + local_rank=$i + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + asr_train_transducer.py \ + --gpu_id $gpu_id \ + --use_preprocessor true \ + --token_type char \ + --token_list $token_list \ + --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \ + --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \ + --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \ + 
--train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \ + --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \ + --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \ + --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \ + --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \ + --resume true \ + --output_dir ${exp_dir}/exp/${model_dir} \ + --config $asr_config \ + --input_size $feats_dim \ + --ngpu $gpu_num \ + --num_worker_count $count \ + --multiprocessing_distributed true \ + --dist_init_method $init_method \ + --dist_world_size $world_size \ + --dist_rank $rank \ + --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1 + } & + done + wait +fi + +# Testing Stage +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + echo "stage 4: Inference" + for dset in ${test_sets}; do + asr_exp=${exp_dir}/exp/${model_dir} + inference_tag="$(basename "${inference_config}" .yaml)" + _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}" + _logdir="${_dir}/logdir" + if [ -d ${_dir} ]; then + echo "${_dir} is already exists. if you want to decode again, please delete this dir first." 
+ exit 0 + fi + mkdir -p "${_logdir}" + _data="${feats_dir}/${dumpdir}/${dset}" + key_file=${_data}/${scp} + num_scp_file="$(<${key_file} wc -l)" + _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file") + split_scps= + for n in $(seq "${_nj}"); do + split_scps+=" ${_logdir}/keys.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + _opts= + if [ -n "${inference_config}" ]; then + _opts+="--config ${inference_config} " + fi + ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ + python -m funasr.bin.asr_inference_launch \ + --batch_size 1 \ + --ngpu "${_ngpu}" \ + --njob ${njob} \ + --gpuid_list ${gpuid_list} \ + --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \ + --key_file "${_logdir}"/keys.JOB.scp \ + --asr_train_config "${asr_exp}"/config.yaml \ + --asr_model_file "${asr_exp}"/"${inference_asr_model}" \ + --output_dir "${_logdir}"/output.JOB \ + --mode rnnt \ + ${_opts} + + for f in token token_int score text; do + if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then + for i in $(seq "${_nj}"); do + cat "${_logdir}/output.${i}/1best_recog/${f}" + done | sort -k1 >"${_dir}/${f}" + fi + done + python utils/proce_text.py ${_dir}/text ${_dir}/text.proc + python utils/proce_text.py ${_data}/text ${_data}/text.proc + python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer + tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt + cat ${_dir}/text.cer.txt + done +fi diff --git a/egs/aishell/rnnt/utils b/egs/aishell/rnnt/utils new file mode 120000 index 000000000..4072eacc1 --- /dev/null +++ b/egs/aishell/rnnt/utils @@ -0,0 +1 @@ +../transformer/utils \ No newline at end of file From e01f742a859dfccbd4fb1208f08f16f70abda45a Mon Sep 17 00:00:00 2001 From: aky15 Date: Mon, 10 Apr 2023 15:54:12 +0800 Subject: [PATCH 06/14] rnnt infer --- funasr/bin/asr_inference_rnnt.py | 3 --- 1 file changed, 3 
deletions(-) diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py index c8a2916c2..f65bd07eb 100644 --- a/funasr/bin/asr_inference_rnnt.py +++ b/funasr/bin/asr_inference_rnnt.py @@ -301,9 +301,6 @@ class Speech2Text: feats = speech.unsqueeze(0).to(getattr(torch, self.dtype)) feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) - if self.asr_model.normalize is not None: - feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths) - feats = to_device(feats, device=self.device) feats_lengths = to_device(feats_lengths, device=self.device) From 7d1efe158eda74dc847c397db906f6cb77ac0f84 Mon Sep 17 00:00:00 2001 From: aky15 Date: Wed, 12 Apr 2023 16:49:56 +0800 Subject: [PATCH 07/14] rnnt reorg --- .../conf/train_conformer_rnnt_unified.yaml | 32 +- funasr/bin/asr_inference_rnnt.py | 58 +- .../e2e_transducer.py} | 10 +- .../e2e_transducer_unified.py} | 13 +- .../encoder/chunk_encoder.py} | 26 +- .../encoder/chunk_encoder_blocks}/__init__.py | 0 .../chunk_encoder_blocks}/branchformer.py | 0 .../chunk_encoder_blocks}/conformer.py | 0 .../encoder/chunk_encoder_blocks}/conv1d.py | 0 .../chunk_encoder_blocks}/conv_input.py | 2 +- .../chunk_encoder_blocks}/linear_input.py | 0 .../chunk_encoder_modules}/__init__.py | 0 .../chunk_encoder_modules}/attention.py | 0 .../chunk_encoder_modules}/convolution.py | 0 .../chunk_encoder_modules}/multi_blocks.py | 0 .../chunk_encoder_modules}/normalization.py | 0 .../positional_encoding.py | 0 .../encoder/chunk_encoder_utils}/building.py | 22 +- .../chunk_encoder_utils}/validation.py | 2 +- .../joint_network.py | 2 +- .../rnnt_decoder}/__init__.py | 0 .../rnnt_decoder}/abs_decoder.py | 0 .../rnnt_decoder}/rnn_decoder.py | 4 +- .../rnnt_decoder}/stateless_decoder.py | 16 +- .../encoder/blocks/__init__.py | 0 .../encoder/modules/__init__.py | 0 .../models_transducer/encoder/sanm_encoder.py | 835 ------------------ funasr/models_transducer/error_calculator.py | 169 ---- 
.../espnet_transducer_model_uni_asr.py | 485 ---------- funasr/models_transducer/utils.py | 200 ----- .../activation.py | 0 .../beam_search}/beam_search_transducer.py | 4 +- funasr/modules/e2e_asr_common.py | 151 ++++ funasr/modules/nets_utils.py | 195 +++- funasr/tasks/asr_transducer.py | 41 +- 35 files changed, 418 insertions(+), 1849 deletions(-) rename funasr/{models_transducer/espnet_transducer_model.py => models/e2e_transducer.py} (98%) rename funasr/{models_transducer/espnet_transducer_model_unified.py => models/e2e_transducer_unified.py} (98%) rename funasr/{models_transducer/encoder/encoder.py => models/encoder/chunk_encoder.py} (96%) rename funasr/{models_transducer => models/encoder/chunk_encoder_blocks}/__init__.py (100%) rename funasr/{models_transducer/encoder/blocks => models/encoder/chunk_encoder_blocks}/branchformer.py (100%) rename funasr/{models_transducer/encoder/blocks => models/encoder/chunk_encoder_blocks}/conformer.py (100%) rename funasr/{models_transducer/encoder/blocks => models/encoder/chunk_encoder_blocks}/conv1d.py (100%) rename funasr/{models_transducer/encoder/blocks => models/encoder/chunk_encoder_blocks}/conv_input.py (98%) rename funasr/{models_transducer/encoder/blocks => models/encoder/chunk_encoder_blocks}/linear_input.py (100%) rename funasr/{models_transducer/decoder => models/encoder/chunk_encoder_modules}/__init__.py (100%) rename funasr/{models_transducer/encoder/modules => models/encoder/chunk_encoder_modules}/attention.py (100%) rename funasr/{models_transducer/encoder/modules => models/encoder/chunk_encoder_modules}/convolution.py (100%) rename funasr/{models_transducer/encoder/modules => models/encoder/chunk_encoder_modules}/multi_blocks.py (100%) rename funasr/{models_transducer/encoder/modules => models/encoder/chunk_encoder_modules}/normalization.py (100%) rename funasr/{models_transducer/encoder/modules => models/encoder/chunk_encoder_modules}/positional_encoding.py (100%) rename funasr/{models_transducer/encoder 
=> models/encoder/chunk_encoder_utils}/building.py (92%) rename funasr/{models_transducer/encoder => models/encoder/chunk_encoder_utils}/validation.py (98%) rename funasr/{models_transducer => models}/joint_network.py (96%) rename funasr/{models_transducer/encoder => models/rnnt_decoder}/__init__.py (100%) rename funasr/{models_transducer/decoder => models/rnnt_decoder}/abs_decoder.py (100%) rename funasr/{models_transducer/decoder => models/rnnt_decoder}/rnn_decoder.py (98%) rename funasr/{models_transducer/decoder => models/rnnt_decoder}/stateless_decoder.py (86%) delete mode 100644 funasr/models_transducer/encoder/blocks/__init__.py delete mode 100644 funasr/models_transducer/encoder/modules/__init__.py delete mode 100644 funasr/models_transducer/encoder/sanm_encoder.py delete mode 100644 funasr/models_transducer/error_calculator.py delete mode 100644 funasr/models_transducer/espnet_transducer_model_uni_asr.py delete mode 100644 funasr/models_transducer/utils.py rename funasr/{models_transducer => modules}/activation.py (100%) rename funasr/{models_transducer => modules/beam_search}/beam_search_transducer.py (99%) diff --git a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml index ef37b97eb..60f796c75 100644 --- a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml +++ b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml @@ -1,13 +1,13 @@ encoder_conf: main_conf: pos_wise_act_type: swish - pos_enc_dropout_rate: 0.3 + pos_enc_dropout_rate: 0.5 conv_mod_act_type: swish time_reduction_factor: 2 unified_model_training: true default_chunk_size: 16 jitter_range: 4 - left_chunk_size: 1 + left_chunk_size: 0 input_conf: block_type: conv2d conv_size: 512 @@ -18,9 +18,9 @@ encoder_conf: linear_size: 2048 hidden_size: 512 heads: 8 - dropout_rate: 0.3 - pos_wise_dropout_rate: 0.3 - att_dropout_rate: 0.3 + dropout_rate: 0.5 + pos_wise_dropout_rate: 0.5 + att_dropout_rate: 0.5 conv_mod_kernel_size: 15 
num_blocks: 12 @@ -29,8 +29,8 @@ decoder: rnn decoder_conf: embed_size: 512 hidden_size: 512 - embed_dropout_rate: 0.2 - dropout_rate: 0.1 + embed_dropout_rate: 0.5 + dropout_rate: 0.5 joint_network_conf: joint_space_size: 512 @@ -41,14 +41,14 @@ model_conf: # minibatch related use_amp: true -batch_type: numel -batch_bins: 1600000 +batch_type: unsorted +batch_size: 16 num_workers: 16 # optimization related accum_grad: 1 grad_clip: 5 -max_epoch: 80 +max_epoch: 200 val_scheduler_criterion: - valid - loss @@ -56,11 +56,11 @@ best_model_criterion: - - valid - cer_transducer_chunk - min -keep_nbest_models: 5 +keep_nbest_models: 10 optim: adam optim_conf: - lr: 0.0003 + lr: 0.001 scheduler: warmuplr scheduler_conf: warmup_steps: 25000 @@ -75,10 +75,12 @@ specaug_conf: apply_freq_mask: true freq_mask_width_range: - 0 - - 30 + - 40 num_freq_mask: 2 apply_time_mask: true time_mask_width_range: - 0 - - 40 - num_time_mask: 2 + - 50 + num_time_mask: 5 + +log_interval: 50 diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py index 768bf7215..465f88254 100644 --- a/funasr/bin/asr_inference_rnnt.py +++ b/funasr/bin/asr_inference_rnnt.py @@ -16,11 +16,11 @@ import torch from packaging.version import parse as V from typeguard import check_argument_types, check_return_type -from funasr.models_transducer.beam_search_transducer import ( +from funasr.modules.beam_search.beam_search_transducer import ( BeamSearchTransducer, Hypothesis, ) -from funasr.models_transducer.utils import TooShortUttError +from funasr.modules.nets_utils import TooShortUttError from funasr.fileio.datadir_writer import DatadirWriter from funasr.tasks.asr_transducer import ASRTransducerTask from funasr.tasks.lm import LMTask @@ -500,7 +500,6 @@ def inference( _bs = len(next(iter(batch.values()))) assert len(keys) == _bs, f"{len(keys)} != {_bs}" -<<<<<<< HEAD batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} assert len(batch.keys()) == 1 @@ -541,59 +540,6 @@ def 
inference( if text is not None: ibest_writer["text"][key] = text -======= - # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")} - - logging.info("decoding, utt_id: {}".format(keys)) - # N-best list of (text, token, token_int, hyp_object) - - time_beg = time.time() - results = speech2text(cache=cache, **batch) - if len(results) < 1: - hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) - results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest - time_end = time.time() - forward_time = time_end - time_beg - lfr_factor = results[0][-1] - length = results[0][-2] - forward_time_total += forward_time - length_total += length - rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time, 100 * forward_time / (length * lfr_factor)) - logging.info(rtf_cur) - - for batch_id in range(_bs): - result = [results[batch_id][:-2]] - - key = keys[batch_id] - for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), result): - # Create a directory: outdir/{n}best_recog - if writer is not None: - ibest_writer = writer[f"{n}best_recog"] - - # Write the result to each file - ibest_writer["token"][key] = " ".join(token) - # ibest_writer["token_int"][key] = " ".join(map(str, token_int)) - ibest_writer["score"][key] = str(hyp.score) - ibest_writer["rtf"][key] = rtf_cur - - if text is not None: - text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token) - item = {'key': key, 'value': text_postprocessed} - asr_result_list.append(item) - finish_count += 1 - # asr_utils.print_progress(finish_count / file_count) - if writer is not None: - ibest_writer["text"][key] = " ".join(word_lists) - - logging.info("decoding, utt: {}, predictions: {}".format(key, text)) - rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)) - logging.info(rtf_avg) - if writer is not None: - 
ibest_writer["rtf"]["rtf_avf"] = rtf_avg - return asr_result_list - - return _forward ->>>>>>> main def get_parser(): diff --git a/funasr/models_transducer/espnet_transducer_model.py b/funasr/models/e2e_transducer.py similarity index 98% rename from funasr/models_transducer/espnet_transducer_model.py rename to funasr/models/e2e_transducer.py index e32f6e350..b669c9d3e 100644 --- a/funasr/models_transducer/espnet_transducer_model.py +++ b/funasr/models/e2e_transducer.py @@ -10,11 +10,11 @@ from typeguard import check_argument_types from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.specaug.abs_specaug import AbsSpecAug -from funasr.models_transducer.decoder.abs_decoder import AbsDecoder +from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder -from funasr.models_transducer.encoder.encoder import Encoder -from funasr.models_transducer.joint_network import JointNetwork -from funasr.models_transducer.utils import get_transducer_task_io +from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder +from funasr.models.joint_network import JointNetwork +from funasr.modules.nets_utils import get_transducer_task_io from funasr.layers.abs_normalize import AbsNormalize from funasr.torch_utils.device_funcs import force_gatherable from funasr.train.abs_espnet_model import AbsESPnetModel @@ -28,7 +28,7 @@ else: yield -class ESPnetASRTransducerModel(AbsESPnetModel): +class TransducerModel(AbsESPnetModel): """ESPnet2ASRTransducerModel module definition. 
Args: diff --git a/funasr/models_transducer/espnet_transducer_model_unified.py b/funasr/models/e2e_transducer_unified.py similarity index 98% rename from funasr/models_transducer/espnet_transducer_model_unified.py rename to funasr/models/e2e_transducer_unified.py index be61e8381..600354216 100644 --- a/funasr/models_transducer/espnet_transducer_model_unified.py +++ b/funasr/models/e2e_transducer_unified.py @@ -10,10 +10,10 @@ from typeguard import check_argument_types from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.specaug.abs_specaug import AbsSpecAug -from funasr.models_transducer.decoder.abs_decoder import AbsDecoder -from funasr.models_transducer.encoder.encoder import Encoder -from funasr.models_transducer.joint_network import JointNetwork -from funasr.models_transducer.utils import get_transducer_task_io +from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder +from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder +from funasr.models.joint_network import JointNetwork +from funasr.modules.nets_utils import get_transducer_task_io from funasr.layers.abs_normalize import AbsNormalize from funasr.torch_utils.device_funcs import force_gatherable from funasr.train.abs_espnet_model import AbsESPnetModel @@ -23,7 +23,7 @@ from funasr.modules.nets_utils import th_accuracy from funasr.losses.label_smoothing_loss import ( # noqa: H301 LabelSmoothingLoss, ) -from funasr.models_transducer.error_calculator import ErrorCalculator +from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator if V(torch.__version__) >= V("1.6.0"): from torch.cuda.amp import autocast else: @@ -33,7 +33,7 @@ else: yield -class ESPnetASRUnifiedTransducerModel(AbsESPnetModel): +class UnifiedTransducerModel(AbsESPnetModel): """ESPnet2ASRTransducerModel module definition. 
Args: @@ -289,7 +289,6 @@ class ESPnetASRUnifiedTransducerModel(AbsESPnetModel): # force_gatherable: to-device and to-tensor if scalar for DataParallel loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) - return loss, stats, weight def collect_feats( diff --git a/funasr/models_transducer/encoder/encoder.py b/funasr/models/encoder/chunk_encoder.py similarity index 96% rename from funasr/models_transducer/encoder/encoder.py rename to funasr/models/encoder/chunk_encoder.py index b486a113f..c6fc292e0 100644 --- a/funasr/models_transducer/encoder/encoder.py +++ b/funasr/models/encoder/chunk_encoder.py @@ -1,26 +1,23 @@ -"""Encoder for Transducer model.""" - from typing import Any, Dict, List, Tuple import torch from typeguard import check_argument_types -from funasr.models_transducer.encoder.building import ( +from funasr.models.encoder.chunk_encoder_utils.building import ( build_body_blocks, build_input_block, build_main_parameters, build_positional_encoding, ) -from funasr.models_transducer.encoder.validation import validate_architecture -from funasr.models_transducer.utils import ( +from funasr.models.encoder.chunk_encoder_utils.validation import validate_architecture +from funasr.modules.nets_utils import ( TooShortUttError, check_short_utt, make_chunk_mask, make_source_mask, ) - -class Encoder(torch.nn.Module): +class ChunkEncoder(torch.nn.Module): """Encoder module definition. 
Args: @@ -61,10 +58,9 @@ class Encoder(torch.nn.Module): self.unified_model_training = main_params["unified_model_training"] self.default_chunk_size = main_params["default_chunk_size"] - self.jitter_range = main_params["jitter_range"] - - self.time_reduction_factor = main_params["time_reduction_factor"] + self.jitter_range = main_params["jitter_range"] + self.time_reduction_factor = main_params["time_reduction_factor"] def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int: """Return the corresponding number of sample for a given chunk size, in frames. @@ -79,7 +75,7 @@ class Encoder(torch.nn.Module): """ return self.embed.get_size_before_subsampling(size) * hop_length - + def get_encoder_input_size(self, size: int) -> int: """Return the corresponding number of sample for a given chunk size, in frames. @@ -157,7 +153,7 @@ class Encoder(torch.nn.Module): mask, chunk_mask=chunk_mask, ) - + olens = mask.eq(0).sum(1) if self.time_reduction_factor > 1: x_utt = x_utt[:,::self.time_reduction_factor,:] @@ -194,14 +190,14 @@ class Encoder(torch.nn.Module): mask, chunk_mask=chunk_mask, ) - + olens = mask.eq(0).sum(1) if self.time_reduction_factor > 1: x = x[:,::self.time_reduction_factor,:] olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 return x, olens - + def simu_chunk_forward( self, x: torch.Tensor, @@ -290,7 +286,7 @@ class Encoder(torch.nn.Module): if right_context > 0: x = x[:, 0:-right_context, :] - + if self.time_reduction_factor > 1: x = x[:,::self.time_reduction_factor,:] return x diff --git a/funasr/models_transducer/__init__.py b/funasr/models/encoder/chunk_encoder_blocks/__init__.py similarity index 100% rename from funasr/models_transducer/__init__.py rename to funasr/models/encoder/chunk_encoder_blocks/__init__.py diff --git a/funasr/models_transducer/encoder/blocks/branchformer.py b/funasr/models/encoder/chunk_encoder_blocks/branchformer.py similarity index 100% rename from 
funasr/models_transducer/encoder/blocks/branchformer.py rename to funasr/models/encoder/chunk_encoder_blocks/branchformer.py diff --git a/funasr/models_transducer/encoder/blocks/conformer.py b/funasr/models/encoder/chunk_encoder_blocks/conformer.py similarity index 100% rename from funasr/models_transducer/encoder/blocks/conformer.py rename to funasr/models/encoder/chunk_encoder_blocks/conformer.py diff --git a/funasr/models_transducer/encoder/blocks/conv1d.py b/funasr/models/encoder/chunk_encoder_blocks/conv1d.py similarity index 100% rename from funasr/models_transducer/encoder/blocks/conv1d.py rename to funasr/models/encoder/chunk_encoder_blocks/conv1d.py diff --git a/funasr/models_transducer/encoder/blocks/conv_input.py b/funasr/models/encoder/chunk_encoder_blocks/conv_input.py similarity index 98% rename from funasr/models_transducer/encoder/blocks/conv_input.py rename to funasr/models/encoder/chunk_encoder_blocks/conv_input.py index ffec93e5e..b9bd2fdc2 100644 --- a/funasr/models_transducer/encoder/blocks/conv_input.py +++ b/funasr/models/encoder/chunk_encoder_blocks/conv_input.py @@ -5,7 +5,7 @@ from typing import Optional, Tuple, Union import torch import math -from funasr.models_transducer.utils import sub_factor_to_params, pad_to_len +from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len class ConvInput(torch.nn.Module): diff --git a/funasr/models_transducer/encoder/blocks/linear_input.py b/funasr/models/encoder/chunk_encoder_blocks/linear_input.py similarity index 100% rename from funasr/models_transducer/encoder/blocks/linear_input.py rename to funasr/models/encoder/chunk_encoder_blocks/linear_input.py diff --git a/funasr/models_transducer/decoder/__init__.py b/funasr/models/encoder/chunk_encoder_modules/__init__.py similarity index 100% rename from funasr/models_transducer/decoder/__init__.py rename to funasr/models/encoder/chunk_encoder_modules/__init__.py diff --git a/funasr/models_transducer/encoder/modules/attention.py 
b/funasr/models/encoder/chunk_encoder_modules/attention.py similarity index 100% rename from funasr/models_transducer/encoder/modules/attention.py rename to funasr/models/encoder/chunk_encoder_modules/attention.py diff --git a/funasr/models_transducer/encoder/modules/convolution.py b/funasr/models/encoder/chunk_encoder_modules/convolution.py similarity index 100% rename from funasr/models_transducer/encoder/modules/convolution.py rename to funasr/models/encoder/chunk_encoder_modules/convolution.py diff --git a/funasr/models_transducer/encoder/modules/multi_blocks.py b/funasr/models/encoder/chunk_encoder_modules/multi_blocks.py similarity index 100% rename from funasr/models_transducer/encoder/modules/multi_blocks.py rename to funasr/models/encoder/chunk_encoder_modules/multi_blocks.py diff --git a/funasr/models_transducer/encoder/modules/normalization.py b/funasr/models/encoder/chunk_encoder_modules/normalization.py similarity index 100% rename from funasr/models_transducer/encoder/modules/normalization.py rename to funasr/models/encoder/chunk_encoder_modules/normalization.py diff --git a/funasr/models_transducer/encoder/modules/positional_encoding.py b/funasr/models/encoder/chunk_encoder_modules/positional_encoding.py similarity index 100% rename from funasr/models_transducer/encoder/modules/positional_encoding.py rename to funasr/models/encoder/chunk_encoder_modules/positional_encoding.py diff --git a/funasr/models_transducer/encoder/building.py b/funasr/models/encoder/chunk_encoder_utils/building.py similarity index 92% rename from funasr/models_transducer/encoder/building.py rename to funasr/models/encoder/chunk_encoder_utils/building.py index a19943be7..21611aa19 100644 --- a/funasr/models_transducer/encoder/building.py +++ b/funasr/models/encoder/chunk_encoder_utils/building.py @@ -2,22 +2,22 @@ from typing import Any, Dict, List, Optional, Union -from funasr.models_transducer.activation import get_activation -from 
funasr.models_transducer.encoder.blocks.branchformer import Branchformer -from funasr.models_transducer.encoder.blocks.conformer import Conformer -from funasr.models_transducer.encoder.blocks.conv1d import Conv1d -from funasr.models_transducer.encoder.blocks.conv_input import ConvInput -from funasr.models_transducer.encoder.blocks.linear_input import LinearInput -from funasr.models_transducer.encoder.modules.attention import ( # noqa: H301 +from funasr.modules.activation import get_activation +from funasr.models.encoder.chunk_encoder_blocks.branchformer import Branchformer +from funasr.models.encoder.chunk_encoder_blocks.conformer import Conformer +from funasr.models.encoder.chunk_encoder_blocks.conv1d import Conv1d +from funasr.models.encoder.chunk_encoder_blocks.conv_input import ConvInput +from funasr.models.encoder.chunk_encoder_blocks.linear_input import LinearInput +from funasr.models.encoder.chunk_encoder_modules.attention import ( # noqa: H301 RelPositionMultiHeadedAttention, ) -from funasr.models_transducer.encoder.modules.convolution import ( # noqa: H301 +from funasr.models.encoder.chunk_encoder_modules.convolution import ( # noqa: H301 ConformerConvolution, ConvolutionalSpatialGatingUnit, ) -from funasr.models_transducer.encoder.modules.multi_blocks import MultiBlocks -from funasr.models_transducer.encoder.modules.normalization import get_normalization -from funasr.models_transducer.encoder.modules.positional_encoding import ( # noqa: H301 +from funasr.models.encoder.chunk_encoder_modules.multi_blocks import MultiBlocks +from funasr.models.encoder.chunk_encoder_modules.normalization import get_normalization +from funasr.models.encoder.chunk_encoder_modules.positional_encoding import ( # noqa: H301 RelPositionalEncoding, ) from funasr.modules.positionwise_feed_forward import ( diff --git a/funasr/models_transducer/encoder/validation.py b/funasr/models/encoder/chunk_encoder_utils/validation.py similarity index 98% rename from 
funasr/models_transducer/encoder/validation.py rename to funasr/models/encoder/chunk_encoder_utils/validation.py index 00035363a..1103cb93f 100644 --- a/funasr/models_transducer/encoder/validation.py +++ b/funasr/models/encoder/chunk_encoder_utils/validation.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Tuple -from funasr.models_transducer.utils import sub_factor_to_params +from funasr.modules.nets_utils import sub_factor_to_params def validate_block_arguments( diff --git a/funasr/models_transducer/joint_network.py b/funasr/models/joint_network.py similarity index 96% rename from funasr/models_transducer/joint_network.py rename to funasr/models/joint_network.py index 119dd84a5..5cabdb4f7 100644 --- a/funasr/models_transducer/joint_network.py +++ b/funasr/models/joint_network.py @@ -2,7 +2,7 @@ import torch -from funasr.models_transducer.activation import get_activation +from funasr.modules.activation import get_activation class JointNetwork(torch.nn.Module): diff --git a/funasr/models_transducer/encoder/__init__.py b/funasr/models/rnnt_decoder/__init__.py similarity index 100% rename from funasr/models_transducer/encoder/__init__.py rename to funasr/models/rnnt_decoder/__init__.py diff --git a/funasr/models_transducer/decoder/abs_decoder.py b/funasr/models/rnnt_decoder/abs_decoder.py similarity index 100% rename from funasr/models_transducer/decoder/abs_decoder.py rename to funasr/models/rnnt_decoder/abs_decoder.py diff --git a/funasr/models_transducer/decoder/rnn_decoder.py b/funasr/models/rnnt_decoder/rnn_decoder.py similarity index 98% rename from funasr/models_transducer/decoder/rnn_decoder.py rename to funasr/models/rnnt_decoder/rnn_decoder.py index 04c32287a..c4e79511c 100644 --- a/funasr/models_transducer/decoder/rnn_decoder.py +++ b/funasr/models/rnnt_decoder/rnn_decoder.py @@ -5,8 +5,8 @@ from typing import List, Optional, Tuple import torch from typeguard import check_argument_types -from funasr.models_transducer.beam_search_transducer import 
Hypothesis -from funasr.models_transducer.decoder.abs_decoder import AbsDecoder +from funasr.modules.beam_search.beam_search_transducer import Hypothesis +from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder from funasr.models.specaug.specaug import SpecAug class RNNDecoder(AbsDecoder): diff --git a/funasr/models_transducer/decoder/stateless_decoder.py b/funasr/models/rnnt_decoder/stateless_decoder.py similarity index 86% rename from funasr/models_transducer/decoder/stateless_decoder.py rename to funasr/models/rnnt_decoder/stateless_decoder.py index 07c8f519b..a2e1fc14b 100644 --- a/funasr/models_transducer/decoder/stateless_decoder.py +++ b/funasr/models/rnnt_decoder/stateless_decoder.py @@ -5,8 +5,8 @@ from typing import List, Optional, Tuple import torch from typeguard import check_argument_types -from funasr.models_transducer.beam_search_transducer import Hypothesis -from funasr.models_transducer.decoder.abs_decoder import AbsDecoder +from funasr.modules.beam_search.beam_search_transducer import Hypothesis +from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder from funasr.models.specaug.specaug import SpecAug class StatelessDecoder(AbsDecoder): @@ -26,7 +26,6 @@ class StatelessDecoder(AbsDecoder): embed_size: int = 256, embed_dropout_rate: float = 0.0, embed_pad: int = 0, - use_embed_mask: bool = False, ) -> None: """Construct a StatelessDecoder object.""" super().__init__() @@ -42,14 +41,6 @@ class StatelessDecoder(AbsDecoder): self.device = next(self.parameters()).device self.score_cache = {} - self.use_embed_mask = use_embed_mask - if self.use_embed_mask: - self._embed_mask = SpecAug( - time_mask_width_range=3, - num_time_mask=1, - apply_freq_mask=False, - apply_time_warp=False - ) def forward( @@ -69,9 +60,6 @@ class StatelessDecoder(AbsDecoder): """ dec_embed = self.embed_dropout_rate(self.embed(labels)) - if self.use_embed_mask and self.training: - dec_embed = self._embed_mask(dec_embed, label_lens)[0] - return dec_embed def score( diff 
--git a/funasr/models_transducer/encoder/blocks/__init__.py b/funasr/models_transducer/encoder/blocks/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/funasr/models_transducer/encoder/modules/__init__.py b/funasr/models_transducer/encoder/modules/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/funasr/models_transducer/encoder/sanm_encoder.py b/funasr/models_transducer/encoder/sanm_encoder.py deleted file mode 100644 index 9e74bdfeb..000000000 --- a/funasr/models_transducer/encoder/sanm_encoder.py +++ /dev/null @@ -1,835 +0,0 @@ -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union -import logging -import torch -import torch.nn as nn -from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk -from typeguard import check_argument_types -import numpy as np -from funasr.modules.nets_utils import make_pad_mask -from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM -from funasr.modules.embedding import SinusoidalPositionEncoder -from funasr.modules.layer_norm import LayerNorm -from funasr.modules.multi_layer_conv import Conv1dLinear -from funasr.modules.multi_layer_conv import MultiLayeredConv1d -from funasr.modules.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) -from funasr.modules.repeat import repeat -from funasr.modules.subsampling import Conv2dSubsampling -from funasr.modules.subsampling import Conv2dSubsampling2 -from funasr.modules.subsampling import Conv2dSubsampling6 -from funasr.modules.subsampling import Conv2dSubsampling8 -from funasr.modules.subsampling import TooShortUttError -from funasr.modules.subsampling import check_short_utt -from funasr.models.ctc import CTC -from funasr.models.encoder.abs_encoder import AbsEncoder - - -class EncoderLayerSANM(nn.Module): - def __init__( - self, - in_size, - size, - self_attn, - feed_forward, - dropout_rate, - 
normalize_before=True, - concat_after=False, - stochastic_depth_rate=0.0, - ): - """Construct an EncoderLayer object.""" - super(EncoderLayerSANM, self).__init__() - self.self_attn = self_attn - self.feed_forward = feed_forward - self.norm1 = LayerNorm(in_size) - self.norm2 = LayerNorm(size) - self.dropout = nn.Dropout(dropout_rate) - self.in_size = in_size - self.size = size - self.normalize_before = normalize_before - self.concat_after = concat_after - if self.concat_after: - self.concat_linear = nn.Linear(size + size, size) - self.stochastic_depth_rate = stochastic_depth_rate - self.dropout_rate = dropout_rate - - def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None): - """Compute encoded features. - Args: - x_input (torch.Tensor): Input tensor (#batch, time, size). - mask (torch.Tensor): Mask tensor for the input (#batch, time). - cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). - Returns: - torch.Tensor: Output tensor (#batch, time, size). - torch.Tensor: Mask tensor (#batch, time). - """ - skip_layer = False - # with stochastic depth, residual connection `x + f(x)` becomes - # `x <- x + 1 / (1 - p) * f(x)` at training time. 
- stoch_layer_coeff = 1.0 - if self.training and self.stochastic_depth_rate > 0: - skip_layer = torch.rand(1).item() < self.stochastic_depth_rate - stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) - - if skip_layer: - if cache is not None: - x = torch.cat([cache, x], dim=1) - return x, mask - - residual = x - if self.normalize_before: - x = self.norm1(x) - - if self.concat_after: - x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1) - if self.in_size == self.size: - x = residual + stoch_layer_coeff * self.concat_linear(x_concat) - else: - x = stoch_layer_coeff * self.concat_linear(x_concat) - else: - if self.in_size == self.size: - x = residual + stoch_layer_coeff * self.dropout( - self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder) - ) - else: - x = stoch_layer_coeff * self.dropout( - self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder) - ) - if not self.normalize_before: - x = self.norm1(x) - - residual = x - if self.normalize_before: - x = self.norm2(x) - x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) - if not self.normalize_before: - x = self.norm2(x) - - - return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder - -class SANMEncoder(AbsEncoder): - """ - author: Speech Lab, Alibaba Group, China - San-m: Memory equipped self-attention for end-to-end speech recognition - https://arxiv.org/abs/2006.01713 - """ - - def __init__( - self, - input_size: int, - output_size: int = 256, - attention_heads: int = 4, - linear_units: int = 2048, - num_blocks: int = 6, - dropout_rate: float = 0.1, - positional_dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.0, - pos_enc_class=SinusoidalPositionEncoder, - normalize_before: bool = True, - concat_after: bool = False, - positionwise_layer_type: str = "linear", - 
positionwise_conv_kernel_size: int = 1, - padding_idx: int = -1, - interctc_layer_idx: List[int] = [], - interctc_use_conditioning: bool = False, - kernel_size : int = 11, - sanm_shfit : int = 0, - tf2torch_tensor_name_prefix_torch: str = "encoder", - tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder", - ): - assert check_argument_types() - super().__init__() - - self.embed = SinusoidalPositionEncoder() - self.normalize_before = normalize_before - if positionwise_layer_type == "linear": - positionwise_layer = PositionwiseFeedForward - positionwise_layer_args = ( - output_size, - linear_units, - dropout_rate, - ) - elif positionwise_layer_type == "conv1d": - positionwise_layer = MultiLayeredConv1d - positionwise_layer_args = ( - output_size, - linear_units, - positionwise_conv_kernel_size, - dropout_rate, - ) - elif positionwise_layer_type == "conv1d-linear": - positionwise_layer = Conv1dLinear - positionwise_layer_args = ( - output_size, - linear_units, - positionwise_conv_kernel_size, - dropout_rate, - ) - else: - raise NotImplementedError("Support only linear or conv1d.") - - encoder_selfattn_layer = MultiHeadedAttentionSANM - encoder_selfattn_layer_args0 = ( - attention_heads, - input_size, - output_size, - attention_dropout_rate, - kernel_size, - sanm_shfit, - ) - - encoder_selfattn_layer_args = ( - attention_heads, - output_size, - output_size, - attention_dropout_rate, - kernel_size, - sanm_shfit, - ) - self.encoders0 = repeat( - 1, - lambda lnum: EncoderLayerSANM( - input_size, - output_size, - encoder_selfattn_layer(*encoder_selfattn_layer_args0), - positionwise_layer(*positionwise_layer_args), - dropout_rate, - normalize_before, - concat_after, - ), - ) - - self.encoders = repeat( - num_blocks-1, - lambda lnum: EncoderLayerSANM( - output_size, - output_size, - encoder_selfattn_layer(*encoder_selfattn_layer_args), - positionwise_layer(*positionwise_layer_args), - dropout_rate, - normalize_before, - concat_after, - ), - ) - if self.normalize_before: - 
self.after_norm = LayerNorm(output_size) - - self.interctc_layer_idx = interctc_layer_idx - if len(interctc_layer_idx) > 0: - assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks - self.interctc_use_conditioning = interctc_use_conditioning - self.conditioning_layer = None - self.dropout = nn.Dropout(dropout_rate) - self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch - self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf - - def forward( - self, - xs_pad: torch.Tensor, - ilens: torch.Tensor, - prev_states: torch.Tensor = None, - ctc: CTC = None, - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """Embed positions in tensor. - Args: - xs_pad: input tensor (B, L, D) - ilens: input length (B) - prev_states: Not to be used now. - Returns: - position embedded tensor and mask - """ - masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) - xs_pad = xs_pad * self.output_size**0.5 - if self.embed is None: - xs_pad = xs_pad - elif ( - isinstance(self.embed, Conv2dSubsampling) - or isinstance(self.embed, Conv2dSubsampling2) - or isinstance(self.embed, Conv2dSubsampling6) - or isinstance(self.embed, Conv2dSubsampling8) - ): - short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) - if short_status: - raise TooShortUttError( - f"has {xs_pad.size(1)} frames and is too short for subsampling " - + f"(it needs more than {limit_size} frames), return empty results", - xs_pad.size(1), - limit_size, - ) - xs_pad, masks = self.embed(xs_pad, masks) - else: - xs_pad = self.embed(xs_pad) - - # xs_pad = self.dropout(xs_pad) - encoder_outs = self.encoders0(xs_pad, masks) - xs_pad, masks = encoder_outs[0], encoder_outs[1] - intermediate_outs = [] - if len(self.interctc_layer_idx) == 0: - encoder_outs = self.encoders(xs_pad, masks) - xs_pad, masks = encoder_outs[0], encoder_outs[1] - else: - for layer_idx, encoder_layer in enumerate(self.encoders): - encoder_outs = encoder_layer(xs_pad, 
masks) - xs_pad, masks = encoder_outs[0], encoder_outs[1] - - if layer_idx + 1 in self.interctc_layer_idx: - encoder_out = xs_pad - - # intermediate outputs are also normalized - if self.normalize_before: - encoder_out = self.after_norm(encoder_out) - - intermediate_outs.append((layer_idx + 1, encoder_out)) - - if self.interctc_use_conditioning: - ctc_out = ctc.softmax(encoder_out) - xs_pad = xs_pad + self.conditioning_layer(ctc_out) - - if self.normalize_before: - xs_pad = self.after_norm(xs_pad) - - olens = masks.squeeze(1).sum(1) - if len(intermediate_outs) > 0: - return (xs_pad, intermediate_outs), olens, None - return xs_pad, olens - - def gen_tf2torch_map_dict(self): - tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch - tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf - map_dict_local = { - ## encoder - # cicd - "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 0), - }, # (768,256),(1,256,768) - "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (768,),(768,) - "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 2, 0), - }, # 
(256,1,31),(1,31,256,1) - "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 0), - }, # (256,256),(1,256,256) - "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - # ffn - "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 0), - }, # (1024,256),(1,256,1024) - "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (1024,),(1024,) - "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 0), - }, # (256,1024),(1,1024,256) - "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - # out norm - "{}.after_norm.weight".format(tensor_name_prefix_torch): - {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf), - "squeeze": 
None, - "transpose": None, - }, # (256,),(256,) - "{}.after_norm.bias".format(tensor_name_prefix_torch): - {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - - } - - return map_dict_local - - def convert_tf2torch(self, - var_dict_tf, - var_dict_torch, - ): - - map_dict = self.gen_tf2torch_map_dict() - - var_dict_torch_update = dict() - for name in sorted(var_dict_torch.keys(), reverse=False): - names = name.split('.') - if names[0] == self.tf2torch_tensor_name_prefix_torch: - if names[1] == "encoders0": - layeridx = int(names[2]) - name_q = name.replace(".{}.".format(layeridx), ".layeridx.") - - name_q = name_q.replace("encoders0", "encoders") - layeridx_bias = 0 - layeridx += layeridx_bias - if name_q in map_dict.keys(): - name_v = map_dict[name_q]["name"] - name_tf = name_v.replace("layeridx", "{}".format(layeridx)) - data_tf = var_dict_tf[name_tf] - if map_dict[name_q]["squeeze"] is not None: - data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) - if map_dict[name_q]["transpose"] is not None: - data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) - data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") - assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, - var_dict_torch[ - name].size(), - data_tf.size()) - var_dict_torch_update[name] = data_tf - logging.info( - "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, - var_dict_tf[name_tf].shape)) - elif names[1] == "encoders": - layeridx = int(names[2]) - name_q = name.replace(".{}.".format(layeridx), ".layeridx.") - layeridx_bias = 1 - layeridx += layeridx_bias - if name_q in map_dict.keys(): - name_v = map_dict[name_q]["name"] - name_tf = name_v.replace("layeridx", "{}".format(layeridx)) - data_tf = var_dict_tf[name_tf] - if map_dict[name_q]["squeeze"] is not None: - data_tf = np.squeeze(data_tf, 
axis=map_dict[name_q]["squeeze"]) - if map_dict[name_q]["transpose"] is not None: - data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) - data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") - assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, - var_dict_torch[ - name].size(), - data_tf.size()) - var_dict_torch_update[name] = data_tf - logging.info( - "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, - var_dict_tf[name_tf].shape)) - - elif names[1] == "after_norm": - name_tf = map_dict[name]["name"] - data_tf = var_dict_tf[name_tf] - data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") - var_dict_torch_update[name] = data_tf - logging.info( - "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, - var_dict_tf[name_tf].shape)) - - return var_dict_torch_update - - -class SANMEncoderChunkOpt(AbsEncoder): - """ - author: Speech Lab, Alibaba Group, China - SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition - https://arxiv.org/abs/2006.01713 - """ - - def __init__( - self, - input_size: int, - output_size: int = 256, - attention_heads: int = 4, - linear_units: int = 2048, - num_blocks: int = 6, - dropout_rate: float = 0.1, - positional_dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.0, - pos_enc_class=SinusoidalPositionEncoder, - normalize_before: bool = True, - concat_after: bool = False, - positionwise_layer_type: str = "linear", - positionwise_conv_kernel_size: int = 1, - padding_idx: int = -1, - interctc_layer_idx: List[int] = [], - interctc_use_conditioning: bool = False, - kernel_size: int = 11, - sanm_shfit: int = 0, - chunk_size: Union[int, Sequence[int]] = (16,), - stride: Union[int, Sequence[int]] = (10,), - pad_left: Union[int, Sequence[int]] = (0,), - time_reduction_factor: int = 1, - encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,), 
- decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,), - tf2torch_tensor_name_prefix_torch: str = "encoder", - tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder", - ): - assert check_argument_types() - super().__init__() - self.output_size = output_size - - self.embed = SinusoidalPositionEncoder() - - self.normalize_before = normalize_before - if positionwise_layer_type == "linear": - positionwise_layer = PositionwiseFeedForward - positionwise_layer_args = ( - output_size, - linear_units, - dropout_rate, - ) - elif positionwise_layer_type == "conv1d": - positionwise_layer = MultiLayeredConv1d - positionwise_layer_args = ( - output_size, - linear_units, - positionwise_conv_kernel_size, - dropout_rate, - ) - elif positionwise_layer_type == "conv1d-linear": - positionwise_layer = Conv1dLinear - positionwise_layer_args = ( - output_size, - linear_units, - positionwise_conv_kernel_size, - dropout_rate, - ) - else: - raise NotImplementedError("Support only linear or conv1d.") - - encoder_selfattn_layer = MultiHeadedAttentionSANM - encoder_selfattn_layer_args0 = ( - attention_heads, - input_size, - output_size, - attention_dropout_rate, - kernel_size, - sanm_shfit, - ) - - encoder_selfattn_layer_args = ( - attention_heads, - output_size, - output_size, - attention_dropout_rate, - kernel_size, - sanm_shfit, - ) - self.encoders0 = repeat( - 1, - lambda lnum: EncoderLayerSANM( - input_size, - output_size, - encoder_selfattn_layer(*encoder_selfattn_layer_args0), - positionwise_layer(*positionwise_layer_args), - dropout_rate, - normalize_before, - concat_after, - ), - ) - - self.encoders = repeat( - num_blocks - 1, - lambda lnum: EncoderLayerSANM( - output_size, - output_size, - encoder_selfattn_layer(*encoder_selfattn_layer_args), - positionwise_layer(*positionwise_layer_args), - dropout_rate, - normalize_before, - concat_after, - ), - ) - if self.normalize_before: - self.after_norm = LayerNorm(output_size) - - self.interctc_layer_idx = interctc_layer_idx - if 
len(interctc_layer_idx) > 0: - assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks - self.interctc_use_conditioning = interctc_use_conditioning - self.conditioning_layer = None - shfit_fsmn = (kernel_size - 1) // 2 - self.overlap_chunk_cls = overlap_chunk( - chunk_size=chunk_size, - stride=stride, - pad_left=pad_left, - shfit_fsmn=shfit_fsmn, - encoder_att_look_back_factor=encoder_att_look_back_factor, - decoder_att_look_back_factor=decoder_att_look_back_factor, - ) - self.time_reduction_factor = time_reduction_factor - self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch - self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf - - def forward( - self, - xs_pad: torch.Tensor, - ilens: torch.Tensor, - prev_states: torch.Tensor = None, - ctc: CTC = None, - ind: int = 0, - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """Embed positions in tensor. - Args: - xs_pad: input tensor (B, L, D) - ilens: input length (B) - prev_states: Not to be used now. 
- Returns: - position embedded tensor and mask - """ - masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) - xs_pad *= self.output_size ** 0.5 - if self.embed is None: - xs_pad = xs_pad - elif ( - isinstance(self.embed, Conv2dSubsampling) - or isinstance(self.embed, Conv2dSubsampling2) - or isinstance(self.embed, Conv2dSubsampling6) - or isinstance(self.embed, Conv2dSubsampling8) - ): - short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) - if short_status: - raise TooShortUttError( - f"has {xs_pad.size(1)} frames and is too short for subsampling " - + f"(it needs more than {limit_size} frames), return empty results", - xs_pad.size(1), - limit_size, - ) - xs_pad, masks = self.embed(xs_pad, masks) - else: - xs_pad = self.embed(xs_pad) - - mask_shfit_chunk, mask_att_chunk_encoder = None, None - if self.overlap_chunk_cls is not None: - ilens = masks.squeeze(1).sum(1) - chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind) - xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs) - masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) - mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0), - dtype=xs_pad.dtype) - mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device, - xs_pad.size(0), - dtype=xs_pad.dtype) - - encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder) - xs_pad, masks = encoder_outs[0], encoder_outs[1] - intermediate_outs = [] - if len(self.interctc_layer_idx) == 0: - encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder) - xs_pad, masks = encoder_outs[0], encoder_outs[1] - else: - for layer_idx, encoder_layer in enumerate(self.encoders): - encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder) - xs_pad, masks = encoder_outs[0], encoder_outs[1] - if layer_idx + 1 in 
self.interctc_layer_idx: - encoder_out = xs_pad - - # intermediate outputs are also normalized - if self.normalize_before: - encoder_out = self.after_norm(encoder_out) - - intermediate_outs.append((layer_idx + 1, encoder_out)) - - if self.interctc_use_conditioning: - ctc_out = ctc.softmax(encoder_out) - xs_pad = xs_pad + self.conditioning_layer(ctc_out) - - if self.normalize_before: - xs_pad = self.after_norm(xs_pad) - - olens = masks.squeeze(1).sum(1) - - xs_pad, olens = self.overlap_chunk_cls.remove_chunk(xs_pad, olens, chunk_outs=None) - - if self.time_reduction_factor > 1: - xs_pad = xs_pad[:,::self.time_reduction_factor,:] - olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 - - if len(intermediate_outs) > 0: - return (xs_pad, intermediate_outs), olens, None - return xs_pad, olens - - def gen_tf2torch_map_dict(self): - tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch - tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf - map_dict_local = { - ## encoder - # cicd - "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 0), - }, # (768,256),(1,256,768) - "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (768,),(768,) - 
"{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 2, 0), - }, # (256,1,31),(1,31,256,1) - "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 0), - }, # (256,256),(1,256,256) - "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - # ffn - "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 0), - }, # (1024,256),(1,256,1024) - "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (1024,),(1024,) - "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): - {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf), - "squeeze": 0, - "transpose": (1, 0), - }, # (256,1024),(1,1024,256) - "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch): - {"name": 
"{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - # out norm - "{}.after_norm.weight".format(tensor_name_prefix_torch): - {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - "{}.after_norm.bias".format(tensor_name_prefix_torch): - {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf), - "squeeze": None, - "transpose": None, - }, # (256,),(256,) - - } - - return map_dict_local - - def convert_tf2torch(self, - var_dict_tf, - var_dict_torch, - ): - - map_dict = self.gen_tf2torch_map_dict() - - var_dict_torch_update = dict() - for name in sorted(var_dict_torch.keys(), reverse=False): - names = name.split('.') - if names[0] == self.tf2torch_tensor_name_prefix_torch: - if names[1] == "encoders0": - layeridx = int(names[2]) - name_q = name.replace(".{}.".format(layeridx), ".layeridx.") - - name_q = name_q.replace("encoders0", "encoders") - layeridx_bias = 0 - layeridx += layeridx_bias - if name_q in map_dict.keys(): - name_v = map_dict[name_q]["name"] - name_tf = name_v.replace("layeridx", "{}".format(layeridx)) - data_tf = var_dict_tf[name_tf] - if map_dict[name_q]["squeeze"] is not None: - data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) - if map_dict[name_q]["transpose"] is not None: - data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) - data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") - assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, - var_dict_torch[ - name].size(), - data_tf.size()) - var_dict_torch_update[name] = data_tf - logging.info( - "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, - var_dict_tf[name_tf].shape)) - elif names[1] == "encoders": - layeridx = int(names[2]) - name_q = name.replace(".{}.".format(layeridx), ".layeridx.") - layeridx_bias = 1 - layeridx += 
layeridx_bias - if name_q in map_dict.keys(): - name_v = map_dict[name_q]["name"] - name_tf = name_v.replace("layeridx", "{}".format(layeridx)) - data_tf = var_dict_tf[name_tf] - if map_dict[name_q]["squeeze"] is not None: - data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) - if map_dict[name_q]["transpose"] is not None: - data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) - data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") - assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, - var_dict_torch[ - name].size(), - data_tf.size()) - var_dict_torch_update[name] = data_tf - logging.info( - "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, - var_dict_tf[name_tf].shape)) - - elif names[1] == "after_norm": - name_tf = map_dict[name]["name"] - data_tf = var_dict_tf[name_tf] - data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") - var_dict_torch_update[name] = data_tf - logging.info( - "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, - var_dict_tf[name_tf].shape)) - - return var_dict_torch_update diff --git a/funasr/models_transducer/error_calculator.py b/funasr/models_transducer/error_calculator.py deleted file mode 100644 index 34b1dc74e..000000000 --- a/funasr/models_transducer/error_calculator.py +++ /dev/null @@ -1,169 +0,0 @@ -"""Error Calculator module for Transducer.""" - -from typing import List, Optional, Tuple - -import torch - -from funasr.models_transducer.beam_search_transducer import BeamSearchTransducer -from funasr.models_transducer.decoder.abs_decoder import AbsDecoder -from funasr.models_transducer.joint_network import JointNetwork - - -class ErrorCalculator: - """Calculate CER and WER for transducer models. - - Args: - decoder: Decoder module. - joint_network: Joint Network module. - token_list: List of token units. - sym_space: Space symbol. - sym_blank: Blank symbol. 
- report_cer: Whether to compute CER. - report_wer: Whether to compute WER. - - """ - - def __init__( - self, - decoder: AbsDecoder, - joint_network: JointNetwork, - token_list: List[int], - sym_space: str, - sym_blank: str, - report_cer: bool = False, - report_wer: bool = False, - ) -> None: - """Construct an ErrorCalculatorTransducer object.""" - super().__init__() - - self.beam_search = BeamSearchTransducer( - decoder=decoder, - joint_network=joint_network, - beam_size=1, - search_type="default", - score_norm=False, - ) - - self.decoder = decoder - - self.token_list = token_list - self.space = sym_space - self.blank = sym_blank - - self.report_cer = report_cer - self.report_wer = report_wer - - def __call__( - self, encoder_out: torch.Tensor, target: torch.Tensor - ) -> Tuple[Optional[float], Optional[float]]: - """Calculate sentence-level WER or/and CER score for Transducer model. - - Args: - encoder_out: Encoder output sequences. (B, T, D_enc) - target: Target label ID sequences. (B, L) - - Returns: - : Sentence-level CER score. - : Sentence-level WER score. - - """ - cer, wer = None, None - - batchsize = int(encoder_out.size(0)) - - encoder_out = encoder_out.to(next(self.decoder.parameters()).device) - - batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)] - pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest] - - char_pred, char_target = self.convert_to_char(pred, target) - - if self.report_cer: - cer = self.calculate_cer(char_pred, char_target) - - if self.report_wer: - wer = self.calculate_wer(char_pred, char_target) - - return cer, wer - - def convert_to_char( - self, pred: torch.Tensor, target: torch.Tensor - ) -> Tuple[List, List]: - """Convert label ID sequences to character sequences. - - Args: - pred: Prediction label ID sequences. (B, U) - target: Target label ID sequences. (B, L) - - Returns: - char_pred: Prediction character sequences. (B, ?) - char_target: Target character sequences. (B, ?) 
- - """ - char_pred, char_target = [], [] - - for i, pred_i in enumerate(pred): - char_pred_i = [self.token_list[int(h)] for h in pred_i] - char_target_i = [self.token_list[int(r)] for r in target[i]] - - char_pred_i = "".join(char_pred_i).replace(self.space, " ") - char_pred_i = char_pred_i.replace(self.blank, "") - - char_target_i = "".join(char_target_i).replace(self.space, " ") - char_target_i = char_target_i.replace(self.blank, "") - - char_pred.append(char_pred_i) - char_target.append(char_target_i) - - return char_pred, char_target - - def calculate_cer( - self, char_pred: torch.Tensor, char_target: torch.Tensor - ) -> float: - """Calculate sentence-level CER score. - - Args: - char_pred: Prediction character sequences. (B, ?) - char_target: Target character sequences. (B, ?) - - Returns: - : Average sentence-level CER score. - - """ - import editdistance - - distances, lens = [], [] - - for i, char_pred_i in enumerate(char_pred): - pred = char_pred_i.replace(" ", "") - target = char_target[i].replace(" ", "") - distances.append(editdistance.eval(pred, target)) - lens.append(len(target)) - - return float(sum(distances)) / sum(lens) - - def calculate_wer( - self, char_pred: torch.Tensor, char_target: torch.Tensor - ) -> float: - """Calculate sentence-level WER score. - - Args: - char_pred: Prediction character sequences. (B, ?) - char_target: Target character sequences. (B, ?) 
- - Returns: - : Average sentence-level WER score - - """ - import editdistance - - distances, lens = [], [] - - for i, char_pred_i in enumerate(char_pred): - pred = char_pred_i.replace("▁", " ").split() - target = char_target[i].replace("▁", " ").split() - - distances.append(editdistance.eval(pred, target)) - lens.append(len(target)) - - return float(sum(distances)) / sum(lens) diff --git a/funasr/models_transducer/espnet_transducer_model_uni_asr.py b/funasr/models_transducer/espnet_transducer_model_uni_asr.py deleted file mode 100644 index 2add3fa78..000000000 --- a/funasr/models_transducer/espnet_transducer_model_uni_asr.py +++ /dev/null @@ -1,485 +0,0 @@ -"""ESPnet2 ASR Transducer model.""" - -import logging -from contextlib import contextmanager -from typing import Dict, List, Optional, Tuple, Union - -import torch -from packaging.version import parse as V -from typeguard import check_argument_types - -from funasr.models.frontend.abs_frontend import AbsFrontend -from funasr.models.specaug.abs_specaug import AbsSpecAug -from funasr.models_transducer.decoder.abs_decoder import AbsDecoder -from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder -from funasr.models_transducer.encoder.encoder import Encoder -from funasr.models_transducer.joint_network import JointNetwork -from funasr.models_transducer.utils import get_transducer_task_io -from funasr.layers.abs_normalize import AbsNormalize -from funasr.torch_utils.device_funcs import force_gatherable -from funasr.train.abs_espnet_model import AbsESPnetModel - -if V(torch.__version__) >= V("1.6.0"): - from torch.cuda.amp import autocast -else: - - @contextmanager - def autocast(enabled=True): - yield - - -class UniASRTransducerModel(AbsESPnetModel): - """ESPnet2ASRTransducerModel module definition. - - Args: - vocab_size: Size of complete vocabulary (w/ EOS and blank included). - token_list: List of token - frontend: Frontend module. - specaug: SpecAugment module. - normalize: Normalization module. 
- encoder: Encoder module. - decoder: Decoder module. - joint_network: Joint Network module. - transducer_weight: Weight of the Transducer loss. - fastemit_lambda: FastEmit lambda value. - auxiliary_ctc_weight: Weight of auxiliary CTC loss. - auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs. - auxiliary_lm_loss_weight: Weight of auxiliary LM loss. - auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing. - ignore_id: Initial padding ID. - sym_space: Space symbol. - sym_blank: Blank Symbol - report_cer: Whether to report Character Error Rate during validation. - report_wer: Whether to report Word Error Rate during validation. - extract_feats_in_collect_stats: Whether to use extract_feats stats collection. - - """ - - def __init__( - self, - vocab_size: int, - token_list: Union[Tuple[str, ...], List[str]], - frontend: Optional[AbsFrontend], - specaug: Optional[AbsSpecAug], - normalize: Optional[AbsNormalize], - encoder, - decoder: AbsDecoder, - att_decoder: Optional[AbsAttDecoder], - joint_network: JointNetwork, - transducer_weight: float = 1.0, - fastemit_lambda: float = 0.0, - auxiliary_ctc_weight: float = 0.0, - auxiliary_ctc_dropout_rate: float = 0.0, - auxiliary_lm_loss_weight: float = 0.0, - auxiliary_lm_loss_smoothing: float = 0.0, - ignore_id: int = -1, - sym_space: str = "", - sym_blank: str = "", - report_cer: bool = True, - report_wer: bool = True, - extract_feats_in_collect_stats: bool = True, - ) -> None: - """Construct an ESPnetASRTransducerModel object.""" - super().__init__() - - assert check_argument_types() - - # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos) - self.blank_id = 0 - self.vocab_size = vocab_size - self.ignore_id = ignore_id - self.token_list = token_list.copy() - - self.sym_space = sym_space - self.sym_blank = sym_blank - - self.frontend = frontend - self.specaug = specaug - self.normalize = normalize - - self.encoder = encoder - self.decoder = decoder - 
self.joint_network = joint_network - - self.criterion_transducer = None - self.error_calculator = None - - self.use_auxiliary_ctc = auxiliary_ctc_weight > 0 - self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0 - - if self.use_auxiliary_ctc: - self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size) - self.ctc_dropout_rate = auxiliary_ctc_dropout_rate - - if self.use_auxiliary_lm_loss: - self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size) - self.lm_loss_smoothing = auxiliary_lm_loss_smoothing - - self.transducer_weight = transducer_weight - self.fastemit_lambda = fastemit_lambda - - self.auxiliary_ctc_weight = auxiliary_ctc_weight - self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight - - self.report_cer = report_cer - self.report_wer = report_wer - - self.extract_feats_in_collect_stats = extract_feats_in_collect_stats - - def forward( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - text: torch.Tensor, - text_lengths: torch.Tensor, - decoding_ind: int = None, - **kwargs, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: - """Forward architecture and compute loss(es). - - Args: - speech: Speech sequences. (B, S) - speech_lengths: Speech sequences lengths. (B,) - text: Label ID sequences. (B, L) - text_lengths: Label ID sequences lengths. (B,) - kwargs: Contains "utts_id". - - Return: - loss: Main loss value. - stats: Task statistics. - weight: Task weights. - - """ - assert text_lengths.dim() == 1, text_lengths.shape - assert ( - speech.shape[0] - == speech_lengths.shape[0] - == text.shape[0] - == text_lengths.shape[0] - ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) - - batch_size = speech.shape[0] - text = text[:, : text_lengths.max()] - - # 1. Encoder - ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind) - encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind) - # 2. 
Transducer-related I/O preparation - decoder_in, target, t_len, u_len = get_transducer_task_io( - text, - encoder_out_lens, - ignore_id=self.ignore_id, - ) - - # 3. Decoder - self.decoder.set_device(encoder_out.device) - decoder_out = self.decoder(decoder_in, u_len) - - # 4. Joint Network - joint_out = self.joint_network( - encoder_out.unsqueeze(2), decoder_out.unsqueeze(1) - ) - - # 5. Losses - loss_trans, cer_trans, wer_trans = self._calc_transducer_loss( - encoder_out, - joint_out, - target, - t_len, - u_len, - ) - - loss_ctc, loss_lm = 0.0, 0.0 - - if self.use_auxiliary_ctc: - loss_ctc = self._calc_ctc_loss( - encoder_out, - target, - t_len, - u_len, - ) - - if self.use_auxiliary_lm_loss: - loss_lm = self._calc_lm_loss(decoder_out, target) - - loss = ( - self.transducer_weight * loss_trans - + self.auxiliary_ctc_weight * loss_ctc - + self.auxiliary_lm_loss_weight * loss_lm - ) - - stats = dict( - loss=loss.detach(), - loss_transducer=loss_trans.detach(), - aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None, - aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None, - cer_transducer=cer_trans, - wer_transducer=wer_trans, - ) - - # force_gatherable: to-device and to-tensor if scalar for DataParallel - loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) - - return loss, stats, weight - - def collect_feats( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - text: torch.Tensor, - text_lengths: torch.Tensor, - **kwargs, - ) -> Dict[str, torch.Tensor]: - """Collect features sequences and features lengths sequences. - - Args: - speech: Speech sequences. (B, S) - speech_lengths: Speech sequences lengths. (B,) - text: Label ID sequences. (B, L) - text_lengths: Label ID sequences lengths. (B,) - kwargs: Contains "utts_id". - - Return: - {}: "feats": Features sequences. (B, T, D_feats), - "feats_lengths": Features sequences lengths. 
(B,) - - """ - if self.extract_feats_in_collect_stats: - feats, feats_lengths = self._extract_feats(speech, speech_lengths) - else: - # Generate dummy stats if extract_feats_in_collect_stats is False - logging.warning( - "Generating dummy stats for feats and feats_lengths, " - "because encoder_conf.extract_feats_in_collect_stats is " - f"{self.extract_feats_in_collect_stats}" - ) - - feats, feats_lengths = speech, speech_lengths - - return {"feats": feats, "feats_lengths": feats_lengths} - - def encode( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - ind: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Encoder speech sequences. - - Args: - speech: Speech sequences. (B, S) - speech_lengths: Speech sequences lengths. (B,) - - Return: - encoder_out: Encoder outputs. (B, T, D_enc) - encoder_out_lens: Encoder outputs lengths. (B,) - - """ - with autocast(False): - # 1. Extract feats - feats, feats_lengths = self._extract_feats(speech, speech_lengths) - - # 2. Data augmentation - if self.specaug is not None and self.training: - feats, feats_lengths = self.specaug(feats, feats_lengths) - - # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN - if self.normalize is not None: - feats, feats_lengths = self.normalize(feats, feats_lengths) - - # 4. Forward encoder - encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths, ind=ind) - - assert encoder_out.size(0) == speech.size(0), ( - encoder_out.size(), - speech.size(0), - ) - assert encoder_out.size(1) <= encoder_out_lens.max(), ( - encoder_out.size(), - encoder_out_lens.max(), - ) - - return encoder_out, encoder_out_lens - - def _extract_feats( - self, speech: torch.Tensor, speech_lengths: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Extract features sequences and features sequences lengths. - - Args: - speech: Speech sequences. (B, S) - speech_lengths: Speech sequences lengths. (B,) - - Return: - feats: Features sequences. 
(B, T, D_feats) - feats_lengths: Features sequences lengths. (B,) - - """ - assert speech_lengths.dim() == 1, speech_lengths.shape - - # for data-parallel - speech = speech[:, : speech_lengths.max()] - - if self.frontend is not None: - feats, feats_lengths = self.frontend(speech, speech_lengths) - else: - feats, feats_lengths = speech, speech_lengths - - return feats, feats_lengths - - def _calc_transducer_loss( - self, - encoder_out: torch.Tensor, - joint_out: torch.Tensor, - target: torch.Tensor, - t_len: torch.Tensor, - u_len: torch.Tensor, - ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]: - """Compute Transducer loss. - - Args: - encoder_out: Encoder output sequences. (B, T, D_enc) - joint_out: Joint Network output sequences (B, T, U, D_joint) - target: Target label ID sequences. (B, L) - t_len: Encoder output sequences lengths. (B,) - u_len: Target label ID sequences lengths. (B,) - - Return: - loss_transducer: Transducer loss value. - cer_transducer: Character error rate for Transducer. - wer_transducer: Word Error Rate for Transducer. - - """ - if self.criterion_transducer is None: - try: - # from warprnnt_pytorch import RNNTLoss - # self.criterion_transducer = RNNTLoss( - # reduction="mean", - # fastemit_lambda=self.fastemit_lambda, - # ) - from warp_rnnt import rnnt_loss as RNNTLoss - self.criterion_transducer = RNNTLoss - - except ImportError: - logging.error( - "warp-rnnt was not installed." - "Please consult the installation documentation." 
- ) - exit(1) - - # loss_transducer = self.criterion_transducer( - # joint_out, - # target, - # t_len, - # u_len, - # ) - log_probs = torch.log_softmax(joint_out, dim=-1) - - loss_transducer = self.criterion_transducer( - log_probs, - target, - t_len, - u_len, - reduction="mean", - blank=self.blank_id, - gather=True, - ) - - if not self.training and (self.report_cer or self.report_wer): - if self.error_calculator is None: - from espnet2.asr_transducer.error_calculator import ErrorCalculator - - self.error_calculator = ErrorCalculator( - self.decoder, - self.joint_network, - self.token_list, - self.sym_space, - self.sym_blank, - report_cer=self.report_cer, - report_wer=self.report_wer, - ) - - cer_transducer, wer_transducer = self.error_calculator(encoder_out, target) - - return loss_transducer, cer_transducer, wer_transducer - - return loss_transducer, None, None - - def _calc_ctc_loss( - self, - encoder_out: torch.Tensor, - target: torch.Tensor, - t_len: torch.Tensor, - u_len: torch.Tensor, - ) -> torch.Tensor: - """Compute CTC loss. - - Args: - encoder_out: Encoder output sequences. (B, T, D_enc) - target: Target label ID sequences. (B, L) - t_len: Encoder output sequences lengths. (B,) - u_len: Target label ID sequences lengths. (B,) - - Return: - loss_ctc: CTC loss value. - - """ - ctc_in = self.ctc_lin( - torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate) - ) - ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1) - - target_mask = target != 0 - ctc_target = target[target_mask].cpu() - - with torch.backends.cudnn.flags(deterministic=True): - loss_ctc = torch.nn.functional.ctc_loss( - ctc_in, - ctc_target, - t_len, - u_len, - zero_infinity=True, - reduction="sum", - ) - loss_ctc /= target.size(0) - - return loss_ctc - - def _calc_lm_loss( - self, - decoder_out: torch.Tensor, - target: torch.Tensor, - ) -> torch.Tensor: - """Compute LM loss. - - Args: - decoder_out: Decoder output sequences. (B, U, D_dec) - target: Target label ID sequences. 
(B, L) - - Return: - loss_lm: LM loss value. - - """ - lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size) - lm_target = target.view(-1).type(torch.int64) - - with torch.no_grad(): - true_dist = lm_loss_in.clone() - true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1)) - - # Ignore blank ID (0) - ignore = lm_target == 0 - lm_target = lm_target.masked_fill(ignore, 0) - - true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing)) - - loss_lm = torch.nn.functional.kl_div( - torch.log_softmax(lm_loss_in, dim=1), - true_dist, - reduction="none", - ) - loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size( - 0 - ) - - return loss_lm diff --git a/funasr/models_transducer/utils.py b/funasr/models_transducer/utils.py deleted file mode 100644 index fd3c531b4..000000000 --- a/funasr/models_transducer/utils.py +++ /dev/null @@ -1,200 +0,0 @@ -"""Utility functions for Transducer models.""" - -from typing import List, Tuple - -import torch - - -class TooShortUttError(Exception): - """Raised when the utt is too short for subsampling. - - Args: - message: Error message to display. - actual_size: The size that cannot pass the subsampling. - limit: The size limit for subsampling. - - """ - - def __init__(self, message: str, actual_size: int, limit: int) -> None: - """Construct a TooShortUttError module.""" - super().__init__(message) - - self.actual_size = actual_size - self.limit = limit - - -def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]: - """Check if the input is too short for subsampling. - - Args: - sub_factor: Subsampling factor for Conv2DSubsampling. - size: Input size. - - Returns: - : Whether an error should be sent. - : Size limit for specified subsampling factor. 
- - """ - if sub_factor == 2 and size < 3: - return True, 7 - elif sub_factor == 4 and size < 7: - return True, 7 - elif sub_factor == 6 and size < 11: - return True, 11 - - return False, -1 - - -def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]: - """Get conv2D second layer parameters for given subsampling factor. - - Args: - sub_factor: Subsampling factor (1/X). - input_size: Input size. - - Returns: - : Kernel size for second convolution. - : Stride for second convolution. - : Conv2DSubsampling output size. - - """ - if sub_factor == 2: - return 3, 1, (((input_size - 1) // 2 - 2)) - elif sub_factor == 4: - return 3, 2, (((input_size - 1) // 2 - 1) // 2) - elif sub_factor == 6: - return 5, 3, (((input_size - 1) // 2 - 2) // 3) - else: - raise ValueError( - "subsampling_factor parameter should be set to either 2, 4 or 6." - ) - - -def make_chunk_mask( - size: int, - chunk_size: int, - left_chunk_size: int = 0, - device: torch.device = None, -) -> torch.Tensor: - """Create chunk mask for the subsequent steps (size, size). - - Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py - - Args: - size: Size of the source mask. - chunk_size: Number of frames in chunk. - left_chunk_size: Size of the left context in chunks (0 means full context). - device: Device for the mask tensor. - - Returns: - mask: Chunk mask. (size, size) - - """ - mask = torch.zeros(size, size, device=device, dtype=torch.bool) - - for i in range(size): - if left_chunk_size <= 0: - start = 0 - else: - start = max((i // chunk_size - left_chunk_size) * chunk_size, 0) - - end = min((i // chunk_size + 1) * chunk_size, size) - mask[i, start:end] = True - - return ~mask - - -def make_source_mask(lengths: torch.Tensor) -> torch.Tensor: - """Create source mask for given lengths. - - Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py - - Args: - lengths: Sequence lengths. (B,) - - Returns: - : Mask for the sequence lengths. 
(B, max_len) - - """ - max_len = lengths.max() - batch_size = lengths.size(0) - - expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths) - - return expanded_lengths >= lengths.unsqueeze(1) - - -def get_transducer_task_io( - labels: torch.Tensor, - encoder_out_lens: torch.Tensor, - ignore_id: int = -1, - blank_id: int = 0, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Get Transducer loss I/O. - - Args: - labels: Label ID sequences. (B, L) - encoder_out_lens: Encoder output lengths. (B,) - ignore_id: Padding symbol ID. - blank_id: Blank symbol ID. - - Returns: - decoder_in: Decoder inputs. (B, U) - target: Target label ID sequences. (B, U) - t_len: Time lengths. (B,) - u_len: Label lengths. (B,) - - """ - - def pad_list(labels: List[torch.Tensor], padding_value: int = 0): - """Create padded batch of labels from a list of labels sequences. - - Args: - labels: Labels sequences. [B x (?)] - padding_value: Padding value. - - Returns: - labels: Batch of padded labels sequences. 
(B,) - - """ - batch_size = len(labels) - - padded = ( - labels[0] - .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:]) - .fill_(padding_value) - ) - - for i in range(batch_size): - padded[i, : labels[i].size(0)] = labels[i] - - return padded - - device = labels.device - - labels_unpad = [y[y != ignore_id] for y in labels] - blank = labels[0].new([blank_id]) - - decoder_in = pad_list( - [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id - ).to(device) - - target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device) - - encoder_out_lens = list(map(int, encoder_out_lens)) - t_len = torch.IntTensor(encoder_out_lens).to(device) - - u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device) - - return decoder_in, target, t_len, u_len - -def pad_to_len(t: torch.Tensor, pad_len: int, dim: int): - """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros.""" - if t.size(dim) == pad_len: - return t - else: - pad_size = list(t.shape) - pad_size[dim] = pad_len - t.size(dim) - return torch.cat( - [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim - ) diff --git a/funasr/models_transducer/activation.py b/funasr/modules/activation.py similarity index 100% rename from funasr/models_transducer/activation.py rename to funasr/modules/activation.py diff --git a/funasr/models_transducer/beam_search_transducer.py b/funasr/modules/beam_search/beam_search_transducer.py similarity index 99% rename from funasr/models_transducer/beam_search_transducer.py rename to funasr/modules/beam_search/beam_search_transducer.py index 8e234e45a..eaf5627f9 100644 --- a/funasr/models_transducer/beam_search_transducer.py +++ b/funasr/modules/beam_search/beam_search_transducer.py @@ -6,8 +6,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch -from funasr.models_transducer.decoder.abs_decoder import AbsDecoder -from funasr.models_transducer.joint_network import 
JointNetwork +from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder +from funasr.models.joint_network import JointNetwork @dataclass diff --git a/funasr/modules/e2e_asr_common.py b/funasr/modules/e2e_asr_common.py index 92f90796a..9b5039c91 100644 --- a/funasr/modules/e2e_asr_common.py +++ b/funasr/modules/e2e_asr_common.py @@ -6,6 +6,8 @@ """Common functions for ASR.""" +from typing import List, Optional, Tuple + import json import logging import sys @@ -13,7 +15,11 @@ import sys from itertools import groupby import numpy as np import six +import torch +from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer +from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder +from funasr.models.joint_network import JointNetwork def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))): """End detection. @@ -247,3 +253,148 @@ class ErrorCalculator(object): word_eds.append(editdistance.eval(hyp_words, ref_words)) word_ref_lens.append(len(ref_words)) return float(sum(word_eds)) / sum(word_ref_lens) + +class ErrorCalculatorTransducer: + """Calculate CER and WER for transducer models. + Args: + decoder: Decoder module. + joint_network: Joint Network module. + token_list: List of token units. + sym_space: Space symbol. + sym_blank: Blank symbol. + report_cer: Whether to compute CER. + report_wer: Whether to compute WER. 
+ """ + + def __init__( + self, + decoder: AbsDecoder, + joint_network: JointNetwork, + token_list: List[int], + sym_space: str, + sym_blank: str, + report_cer: bool = False, + report_wer: bool = False, + ) -> None: + """Construct an ErrorCalculatorTransducer object.""" + super().__init__() + + self.beam_search = BeamSearchTransducer( + decoder=decoder, + joint_network=joint_network, + beam_size=1, + search_type="default", + score_norm=False, + ) + + self.decoder = decoder + + self.token_list = token_list + self.space = sym_space + self.blank = sym_blank + + self.report_cer = report_cer + self.report_wer = report_wer + + def __call__( + self, encoder_out: torch.Tensor, target: torch.Tensor + ) -> Tuple[Optional[float], Optional[float]]: + """Calculate sentence-level WER or/and CER score for Transducer model. + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + target: Target label ID sequences. (B, L) + Returns: + : Sentence-level CER score. + : Sentence-level WER score. + """ + cer, wer = None, None + + batchsize = int(encoder_out.size(0)) + + encoder_out = encoder_out.to(next(self.decoder.parameters()).device) + + batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)] + pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest] + + char_pred, char_target = self.convert_to_char(pred, target) + + if self.report_cer: + cer = self.calculate_cer(char_pred, char_target) + + if self.report_wer: + wer = self.calculate_wer(char_pred, char_target) + + return cer, wer + + def convert_to_char( + self, pred: torch.Tensor, target: torch.Tensor + ) -> Tuple[List, List]: + """Convert label ID sequences to character sequences. + Args: + pred: Prediction label ID sequences. (B, U) + target: Target label ID sequences. (B, L) + Returns: + char_pred: Prediction character sequences. (B, ?) + char_target: Target character sequences. (B, ?) 
+ """ + char_pred, char_target = [], [] + + for i, pred_i in enumerate(pred): + char_pred_i = [self.token_list[int(h)] for h in pred_i] + char_target_i = [self.token_list[int(r)] for r in target[i]] + + char_pred_i = "".join(char_pred_i).replace(self.space, " ") + char_pred_i = char_pred_i.replace(self.blank, "") + + char_target_i = "".join(char_target_i).replace(self.space, " ") + char_target_i = char_target_i.replace(self.blank, "") + + char_pred.append(char_pred_i) + char_target.append(char_target_i) + + return char_pred, char_target + + def calculate_cer( + self, char_pred: torch.Tensor, char_target: torch.Tensor + ) -> float: + """Calculate sentence-level CER score. + Args: + char_pred: Prediction character sequences. (B, ?) + char_target: Target character sequences. (B, ?) + Returns: + : Average sentence-level CER score. + """ + import editdistance + + distances, lens = [], [] + + for i, char_pred_i in enumerate(char_pred): + pred = char_pred_i.replace(" ", "") + target = char_target[i].replace(" ", "") + distances.append(editdistance.eval(pred, target)) + lens.append(len(target)) + + return float(sum(distances)) / sum(lens) + + def calculate_wer( + self, char_pred: torch.Tensor, char_target: torch.Tensor + ) -> float: + """Calculate sentence-level WER score. + Args: + char_pred: Prediction character sequences. (B, ?) + char_target: Target character sequences. (B, ?) 
+ Returns: + : Average sentence-level WER score + """ + import editdistance + + distances, lens = [], [] + + for i, char_pred_i in enumerate(char_pred): + pred = char_pred_i.replace("▁", " ").split() + target = char_target[i].replace("▁", " ").split() + + distances.append(editdistance.eval(pred, target)) + lens.append(len(target)) + + return float(sum(distances)) / sum(lens) diff --git a/funasr/modules/nets_utils.py b/funasr/modules/nets_utils.py index 6d77d69a6..5d4fe1c85 100644 --- a/funasr/modules/nets_utils.py +++ b/funasr/modules/nets_utils.py @@ -3,7 +3,7 @@ """Network related utility tools.""" import logging -from typing import Dict +from typing import Dict, List, Tuple import numpy as np import torch @@ -506,3 +506,196 @@ def get_activation(act): } return activation_funcs[act]() + +class TooShortUttError(Exception): + """Raised when the utt is too short for subsampling. + + Args: + message: Error message to display. + actual_size: The size that cannot pass the subsampling. + limit: The size limit for subsampling. + + """ + + def __init__(self, message: str, actual_size: int, limit: int) -> None: + """Construct a TooShortUttError module.""" + super().__init__(message) + + self.actual_size = actual_size + self.limit = limit + + +def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]: + """Check if the input is too short for subsampling. + + Args: + sub_factor: Subsampling factor for Conv2DSubsampling. + size: Input size. + + Returns: + : Whether an error should be sent. + : Size limit for specified subsampling factor. + + """ + if sub_factor == 2 and size < 3: + return True, 7 + elif sub_factor == 4 and size < 7: + return True, 7 + elif sub_factor == 6 and size < 11: + return True, 11 + + return False, -1 + + +def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]: + """Get conv2D second layer parameters for given subsampling factor. + + Args: + sub_factor: Subsampling factor (1/X). + input_size: Input size. 
+ + Returns: + : Kernel size for second convolution. + : Stride for second convolution. + : Conv2DSubsampling output size. + + """ + if sub_factor == 2: + return 3, 1, (((input_size - 1) // 2 - 2)) + elif sub_factor == 4: + return 3, 2, (((input_size - 1) // 2 - 1) // 2) + elif sub_factor == 6: + return 5, 3, (((input_size - 1) // 2 - 2) // 3) + else: + raise ValueError( + "subsampling_factor parameter should be set to either 2, 4 or 6." + ) + + +def make_chunk_mask( + size: int, + chunk_size: int, + left_chunk_size: int = 0, + device: torch.device = None, +) -> torch.Tensor: + """Create chunk mask for the subsequent steps (size, size). + + Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py + + Args: + size: Size of the source mask. + chunk_size: Number of frames in chunk. + left_chunk_size: Size of the left context in chunks (0 means full context). + device: Device for the mask tensor. + + Returns: + mask: Chunk mask. (size, size) + + """ + mask = torch.zeros(size, size, device=device, dtype=torch.bool) + + for i in range(size): + if left_chunk_size <= 0: + start = 0 + else: + start = max((i // chunk_size - left_chunk_size) * chunk_size, 0) + + end = min((i // chunk_size + 1) * chunk_size, size) + mask[i, start:end] = True + + return ~mask + +def make_source_mask(lengths: torch.Tensor) -> torch.Tensor: + """Create source mask for given lengths. + + Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py + + Args: + lengths: Sequence lengths. (B,) + + Returns: + : Mask for the sequence lengths. 
(B, max_len) + + """ + max_len = lengths.max() + batch_size = lengths.size(0) + + expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths) + + return expanded_lengths >= lengths.unsqueeze(1) + + +def get_transducer_task_io( + labels: torch.Tensor, + encoder_out_lens: torch.Tensor, + ignore_id: int = -1, + blank_id: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Get Transducer loss I/O. + + Args: + labels: Label ID sequences. (B, L) + encoder_out_lens: Encoder output lengths. (B,) + ignore_id: Padding symbol ID. + blank_id: Blank symbol ID. + + Returns: + decoder_in: Decoder inputs. (B, U) + target: Target label ID sequences. (B, U) + t_len: Time lengths. (B,) + u_len: Label lengths. (B,) + + """ + + def pad_list(labels: List[torch.Tensor], padding_value: int = 0): + """Create padded batch of labels from a list of labels sequences. + + Args: + labels: Labels sequences. [B x (?)] + padding_value: Padding value. + + Returns: + labels: Batch of padded labels sequences. 
(B,) + + """ + batch_size = len(labels) + + padded = ( + labels[0] + .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:]) + .fill_(padding_value) + ) + + for i in range(batch_size): + padded[i, : labels[i].size(0)] = labels[i] + + return padded + + device = labels.device + + labels_unpad = [y[y != ignore_id] for y in labels] + blank = labels[0].new([blank_id]) + + decoder_in = pad_list( + [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id + ).to(device) + + target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device) + + encoder_out_lens = list(map(int, encoder_out_lens)) + t_len = torch.IntTensor(encoder_out_lens).to(device) + + u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device) + + return decoder_in, target, t_len, u_len + +def pad_to_len(t: torch.Tensor, pad_len: int, dim: int): + """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros.""" + if t.size(dim) == pad_len: + return t + else: + pad_size = list(t.shape) + pad_size[dim] = pad_len - t.size(dim) + return torch.cat( + [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim + ) diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py index be1445590..cae18c169 100644 --- a/funasr/tasks/asr_transducer.py +++ b/funasr/tasks/asr_transducer.py @@ -21,15 +21,13 @@ from funasr.models.decoder.transformer_decoder import ( LightweightConvolutionTransformerDecoder, TransformerDecoder, ) -from funasr.models_transducer.decoder.abs_decoder import AbsDecoder -from funasr.models_transducer.decoder.rnn_decoder import RNNDecoder -from funasr.models_transducer.decoder.stateless_decoder import StatelessDecoder -from funasr.models_transducer.encoder.encoder import Encoder -from funasr.models_transducer.encoder.sanm_encoder import SANMEncoderChunkOpt -from funasr.models_transducer.espnet_transducer_model import ESPnetASRTransducerModel -from funasr.models_transducer.espnet_transducer_model_unified import 
ESPnetASRUnifiedTransducerModel -from funasr.models_transducer.espnet_transducer_model_uni_asr import UniASRTransducerModel -from funasr.models_transducer.joint_network import JointNetwork +from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder +from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder +from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder +from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder +from funasr.models.e2e_transducer import TransducerModel +from funasr.models.e2e_transducer_unified import UnifiedTransducerModel +from funasr.models.joint_network import JointNetwork from funasr.layers.abs_normalize import AbsNormalize from funasr.layers.global_mvn import GlobalMVN from funasr.layers.utterance_mvn import UtteranceMVN @@ -75,7 +73,6 @@ encoder_choices = ClassChoices( "encoder", classes=dict( encoder=Encoder, - sanm_chunk_opt=SANMEncoderChunkOpt, ), default="encoder", ) @@ -158,7 +155,7 @@ class ASRTransducerTask(AbsTask): group.add_argument( "--model_conf", action=NestedDictAction, - default=get_default_kwargs(ESPnetASRTransducerModel), + default=get_default_kwargs(TransducerModel), help="The keyword arguments for the model class.", ) # group.add_argument( @@ -354,7 +351,7 @@ class ASRTransducerTask(AbsTask): return retval @classmethod - def build_model(cls, args: argparse.Namespace) -> ESPnetASRTransducerModel: + def build_model(cls, args: argparse.Namespace) -> TransducerModel: """Required data depending on task mode. Args: cls: ASRTransducerTask object. @@ -440,22 +437,8 @@ class ASRTransducerTask(AbsTask): # 7. 
Build model - if getattr(args, "encoder", None) is not None and args.encoder == 'sanm_chunk_opt': - model = UniASRTransducerModel( - vocab_size=vocab_size, - token_list=token_list, - frontend=frontend, - specaug=specaug, - normalize=normalize, - encoder=encoder, - decoder=decoder, - att_decoder=att_decoder, - joint_network=joint_network, - **args.model_conf, - ) - - elif encoder.unified_model_training: - model = ESPnetASRUnifiedTransducerModel( + if encoder.unified_model_training: + model = UnifiedTransducerModel( vocab_size=vocab_size, token_list=token_list, frontend=frontend, @@ -469,7 +452,7 @@ class ASRTransducerTask(AbsTask): ) else: - model = ESPnetASRTransducerModel( + model = TransducerModel( vocab_size=vocab_size, token_list=token_list, frontend=frontend, From 0bde4aefbd060d311b531ae711fa400d2ce2b84a Mon Sep 17 00:00:00 2001 From: aky15 Date: Wed, 12 Apr 2023 17:42:28 +0800 Subject: [PATCH 08/14] update README --- egs/aishell/rnnt/README.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/egs/aishell/rnnt/README.md b/egs/aishell/rnnt/README.md index 4d6ac9de3..45f1f3f98 100644 --- a/egs/aishell/rnnt/README.md +++ b/egs/aishell/rnnt/README.md @@ -2,9 +2,10 @@ # Streaming RNN-T Result ## Training Config +- 8 gpu(Tesla V100) - Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment -- Train config: conf/train_conformer_rnnt_unified -- chunk config: chunk size 16, 1 left chunk +- Train config: conf/train_conformer_rnnt_unified.yaml +- chunk config: chunk size 16, full left chunk - LM config: LM was not used - Model size: 90M @@ -13,5 +14,5 @@ | testset | CER(%) | |:-----------:|:-------:| -| dev | 5.89 | -| test | 6.76 | +| dev | 5.53 | +| test | 6.24 | From d1c9782515df60bf7578075c649c45b224106cfe Mon Sep 17 00:00:00 2001 From: aky15 Date: Wed, 12 Apr 2023 17:51:21 +0800 Subject: [PATCH 09/14] Delete abs_task.py resolve abs_task.py conflict --- funasr/tasks/abs_task.py | 1958 
-------------------------------------- 1 file changed, 1958 deletions(-) delete mode 100644 funasr/tasks/abs_task.py diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py deleted file mode 100644 index d4b8a72ff..000000000 --- a/funasr/tasks/abs_task.py +++ /dev/null @@ -1,1958 +0,0 @@ -# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""Abstract task module.""" -import argparse -import functools -import logging -import os -import sys -from abc import ABC -from abc import abstractmethod -from dataclasses import dataclass -from distutils.version import LooseVersion -from io import BytesIO -from pathlib import Path -from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union - -import humanfriendly -import numpy as np -import torch -import torch.distributed as dist -import torch.multiprocessing -import torch.nn -import torch.optim -import yaml -from torch.utils.data import DataLoader -from typeguard import check_argument_types -from typeguard import check_return_type - -from funasr import __version__ -from funasr.datasets.dataset import AbsDataset -from funasr.datasets.dataset import DATA_TYPES -from funasr.datasets.dataset import ESPnetDataset -from funasr.datasets.iterable_dataset import IterableESPnetDataset -from funasr.iterators.abs_iter_factory import AbsIterFactory -from funasr.iterators.chunk_iter_factory import ChunkIterFactory -from funasr.iterators.multiple_iter_factory import MultipleIterFactory -from funasr.iterators.sequence_iter_factory import SequenceIterFactory -from funasr.main_funcs.collect_stats import collect_stats -from funasr.optimizers.sgd import SGD -from funasr.optimizers.fairseq_adam import FairseqAdam -from funasr.samplers.build_batch_sampler import BATCH_TYPES -from 
funasr.samplers.build_batch_sampler import build_batch_sampler -from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler -from funasr.schedulers.noam_lr import NoamLR -from funasr.schedulers.warmup_lr import WarmupLR -from funasr.schedulers.tri_stage_scheduler import TriStageLR -from funasr.torch_utils.load_pretrained_model import load_pretrained_model -from funasr.torch_utils.model_summary import model_summary -from funasr.torch_utils.pytorch_version import pytorch_cudnn_version -from funasr.torch_utils.set_all_random_seed import set_all_random_seed -from funasr.train.abs_espnet_model import AbsESPnetModel -from funasr.train.class_choices import ClassChoices -from funasr.train.distributed_utils import DistributedOption -from funasr.train.trainer import Trainer -from funasr.utils import config_argparse -from funasr.utils.build_dataclass import build_dataclass -from funasr.utils.cli_utils import get_commandline_args -from funasr.utils.get_default_kwargs import get_default_kwargs -from funasr.utils.nested_dict_action import NestedDictAction -from funasr.utils.types import humanfriendly_parse_size_or_none -from funasr.utils.types import int_or_none -from funasr.utils.types import str2bool -from funasr.utils.types import str2triple_str -from funasr.utils.types import str_or_int -from funasr.utils.types import str_or_none -from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text -from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump - -try: - import wandb -except Exception: - wandb = None - -if LooseVersion(torch.__version__) >= LooseVersion("1.5.0"): - pass -else: - pass - -optim_classes = dict( - adam=torch.optim.Adam, - fairseq_adam=FairseqAdam, - adamw=torch.optim.AdamW, - sgd=SGD, - adadelta=torch.optim.Adadelta, - adagrad=torch.optim.Adagrad, - adamax=torch.optim.Adamax, - asgd=torch.optim.ASGD, - lbfgs=torch.optim.LBFGS, - rmsprop=torch.optim.RMSprop, - rprop=torch.optim.Rprop, -) -if 
LooseVersion(torch.__version__) >= LooseVersion("1.10.0"): - # From 1.10.0, RAdam is officially supported - optim_classes.update( - radam=torch.optim.RAdam, - ) -try: - import torch_optimizer - - optim_classes.update( - accagd=torch_optimizer.AccSGD, - adabound=torch_optimizer.AdaBound, - adamod=torch_optimizer.AdaMod, - diffgrad=torch_optimizer.DiffGrad, - lamb=torch_optimizer.Lamb, - novograd=torch_optimizer.NovoGrad, - pid=torch_optimizer.PID, - # torch_optimizer<=0.0.1a10 doesn't support - # qhadam=torch_optimizer.QHAdam, - qhm=torch_optimizer.QHM, - sgdw=torch_optimizer.SGDW, - yogi=torch_optimizer.Yogi, - ) - if LooseVersion(torch_optimizer.__version__) < LooseVersion("0.2.0"): - # From 0.2.0, RAdam is dropped - optim_classes.update( - radam=torch_optimizer.RAdam, - ) - del torch_optimizer -except ImportError: - pass -try: - import apex - - optim_classes.update( - fusedadam=apex.optimizers.FusedAdam, - fusedlamb=apex.optimizers.FusedLAMB, - fusednovograd=apex.optimizers.FusedNovoGrad, - fusedsgd=apex.optimizers.FusedSGD, - ) - del apex -except ImportError: - pass -try: - import fairscale -except ImportError: - fairscale = None - -scheduler_classes = dict( - ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau, - lambdalr=torch.optim.lr_scheduler.LambdaLR, - steplr=torch.optim.lr_scheduler.StepLR, - multisteplr=torch.optim.lr_scheduler.MultiStepLR, - exponentiallr=torch.optim.lr_scheduler.ExponentialLR, - CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR, - noamlr=NoamLR, - warmuplr=WarmupLR, - tri_stage=TriStageLR, - cycliclr=torch.optim.lr_scheduler.CyclicLR, - onecyclelr=torch.optim.lr_scheduler.OneCycleLR, - CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts, -) -# To lower keys -optim_classes = {k.lower(): v for k, v in optim_classes.items()} -scheduler_classes = {k.lower(): v for k, v in scheduler_classes.items()} - - -@dataclass -class IteratorOptions: - preprocess_fn: callable - collate_fn: callable - 
data_path_and_name_and_type: list - shape_files: list - batch_size: int - batch_bins: int - batch_type: str - max_cache_size: float - max_cache_fd: int - distributed: bool - num_batches: Optional[int] - num_iters_per_epoch: Optional[int] - train: bool - - -class AbsTask(ABC): - # Use @staticmethod, or @classmethod, - # instead of instance method to avoid God classes - - # If you need more than one optimizers, change this value in inheritance - num_optimizers: int = 1 - trainer = Trainer - class_choices_list: List[ClassChoices] = [] - finetune_args: None - - def __init__(self): - raise RuntimeError("This class can't be instantiated.") - - @classmethod - @abstractmethod - def add_task_arguments(cls, parser: argparse.ArgumentParser): - pass - - @classmethod - @abstractmethod - def build_collate_fn( - cls, args: argparse.Namespace, train: bool - ) -> Callable[[Sequence[Dict[str, np.ndarray]]], Dict[str, torch.Tensor]]: - """Return "collate_fn", which is a callable object and given to DataLoader. - - >>> from torch.utils.data import DataLoader - >>> loader = DataLoader(collate_fn=cls.build_collate_fn(args, train=True), ...) - - In many cases, you can use our common collate_fn. - """ - raise NotImplementedError - - @classmethod - @abstractmethod - def build_preprocess_fn( - cls, args: argparse.Namespace, train: bool - ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: - raise NotImplementedError - - @classmethod - @abstractmethod - def required_data_names( - cls, train: bool = True, inference: bool = False - ) -> Tuple[str, ...]: - """Define the required names by Task - - This function is used by - >>> cls.check_task_requirements() - If your model is defined as following, - - >>> from funasr.train.abs_espnet_model import AbsESPnetModel - >>> class Model(AbsESPnetModel): - ... 
def forward(self, input, output, opt=None): pass - - then "required_data_names" should be as - - >>> required_data_names = ('input', 'output') - """ - raise NotImplementedError - - @classmethod - @abstractmethod - def optional_data_names( - cls, train: bool = True, inference: bool = False - ) -> Tuple[str, ...]: - """Define the optional names by Task - - This function is used by - >>> cls.check_task_requirements() - If your model is defined as follows, - - >>> from funasr.train.abs_espnet_model import AbsESPnetModel - >>> class Model(AbsESPnetModel): - ... def forward(self, input, output, opt=None): pass - - then "optional_data_names" should be as - - >>> optional_data_names = ('opt',) - """ - raise NotImplementedError - - @classmethod - @abstractmethod - def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel: - raise NotImplementedError - - @classmethod - def get_parser(cls) -> config_argparse.ArgumentParser: - assert check_argument_types() - - class ArgumentDefaultsRawTextHelpFormatter( - argparse.RawTextHelpFormatter, - argparse.ArgumentDefaultsHelpFormatter, - ): - pass - - parser = config_argparse.ArgumentParser( - description="base parser", - formatter_class=ArgumentDefaultsRawTextHelpFormatter, - ) - - # NOTE(kamo): Use '_' instead of '-' to avoid confusion. - # I think '-' looks really confusing if it's written in yaml. - - # NOTE(kamo): add_arguments(..., required=True) can't be used - # to provide --print_config mode. 
Instead of it, do as - # parser.set_defaults(required=["output_dir"]) - - group = parser.add_argument_group("Common configuration") - - group.add_argument( - "--print_config", - action="store_true", - help="Print the config file and exit", - ) - group.add_argument( - "--log_level", - type=lambda x: x.upper(), - default="INFO", - choices=("ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), - help="The verbose level of logging", - ) - group.add_argument( - "--dry_run", - type=str2bool, - default=False, - help="Perform process without training", - ) - group.add_argument( - "--iterator_type", - type=str, - choices=["sequence", "chunk", "task", "none"], - default="sequence", - help="Specify iterator type", - ) - - group.add_argument("--output_dir", type=str_or_none, default=None) - group.add_argument( - "--ngpu", - type=int, - default=0, - help="The number of gpus. 0 indicates CPU mode", - ) - group.add_argument("--seed", type=int, default=0, help="Random seed") - group.add_argument( - "--num_workers", - type=int, - default=1, - help="The number of workers used for DataLoader", - ) - group.add_argument( - "--num_att_plot", - type=int, - default=3, - help="The number images to plot the outputs from attention. " - "This option makes sense only when attention-based model. 
" - "We can also disable the attention plot by setting it 0", - ) - - group = parser.add_argument_group("distributed training related") - group.add_argument( - "--dist_backend", - default="nccl", - type=str, - help="distributed backend", - ) - group.add_argument( - "--dist_init_method", - type=str, - default="env://", - help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", ' - '"WORLD_SIZE", and "RANK" are referred.', - ) - group.add_argument( - "--dist_world_size", - default=None, - type=int_or_none, - help="number of nodes for distributed training", - ) - group.add_argument( - "--dist_rank", - type=int_or_none, - default=None, - help="node rank for distributed training", - ) - group.add_argument( - # Not starting with "dist_" for compatibility to launch.py - "--local_rank", - type=int_or_none, - default=None, - help="local rank for distributed training. This option is used if " - "--multiprocessing_distributed=false", - ) - group.add_argument( - "--dist_master_addr", - default=None, - type=str_or_none, - help="The master address for distributed training. " - "This value is used when dist_init_method == 'env://'", - ) - group.add_argument( - "--dist_master_port", - default=None, - type=int_or_none, - help="The master port for distributed training" - "This value is used when dist_init_method == 'env://'", - ) - group.add_argument( - "--dist_launcher", - default=None, - type=str_or_none, - choices=["slurm", "mpi", None], - help="The launcher type for distributed training", - ) - group.add_argument( - "--multiprocessing_distributed", - default=False, - type=str2bool, - help="Use multi-processing distributed training to launch " - "N processes per node, which has N GPUs. 
This is the " - "fastest way to use PyTorch for either single node or " - "multi node data parallel training", - ) - group.add_argument( - "--unused_parameters", - type=str2bool, - default=False, - help="Whether to use the find_unused_parameters in " - "torch.nn.parallel.DistributedDataParallel ", - ) - group.add_argument( - "--sharded_ddp", - default=False, - type=str2bool, - help="Enable sharded training provided by fairscale", - ) - - group = parser.add_argument_group("cudnn mode related") - group.add_argument( - "--cudnn_enabled", - type=str2bool, - default=torch.backends.cudnn.enabled, - help="Enable CUDNN", - ) - group.add_argument( - "--cudnn_benchmark", - type=str2bool, - default=torch.backends.cudnn.benchmark, - help="Enable cudnn-benchmark mode", - ) - group.add_argument( - "--cudnn_deterministic", - type=str2bool, - default=True, - help="Enable cudnn-deterministic mode", - ) - - group = parser.add_argument_group("collect stats mode related") - group.add_argument( - "--collect_stats", - type=str2bool, - default=False, - help='Perform on "collect stats" mode', - ) - group.add_argument( - "--write_collected_feats", - type=str2bool, - default=False, - help='Write the output features from the model when "collect stats" mode', - ) - - group = parser.add_argument_group("Trainer related") - group.add_argument( - "--max_epoch", - type=int, - default=40, - help="The maximum number epoch to train", - ) - group.add_argument( - "--max_update", - type=int, - default=sys.maxsize, - help="The maximum number update step to train", - ) - parser.add_argument( - "--batch_interval", - type=int, - default=10000, - help="The batch interval for saving model.", - ) - group.add_argument( - "--patience", - type=int_or_none, - default=None, - help="Number of epochs to wait without improvement " - "before stopping the training", - ) - group.add_argument( - "--val_scheduler_criterion", - type=str, - nargs=2, - default=("valid", "loss"), - help="The criterion used for the value given 
to the lr scheduler. " - 'Give a pair referring the phase, "train" or "valid",' - 'and the criterion name. The mode specifying "min" or "max" can ' - "be changed by --scheduler_conf", - ) - group.add_argument( - "--early_stopping_criterion", - type=str, - nargs=3, - default=("valid", "loss", "min"), - help="The criterion used for judging of early stopping. " - 'Give a pair referring the phase, "train" or "valid",' - 'the criterion name and the mode, "min" or "max", e.g. "acc,max".', - ) - group.add_argument( - "--best_model_criterion", - type=str2triple_str, - nargs="+", - default=[ - ("train", "loss", "min"), - ("valid", "loss", "min"), - ("train", "acc", "max"), - ("valid", "acc", "max"), - ], - help="The criterion used for judging of the best model. " - 'Give a pair referring the phase, "train" or "valid",' - 'the criterion name, and the mode, "min" or "max", e.g. "acc,max".', - ) - group.add_argument( - "--keep_nbest_models", - type=int, - nargs="+", - default=[10], - help="Remove previous snapshots excluding the n-best scored epochs", - ) - group.add_argument( - "--nbest_averaging_interval", - type=int, - default=0, - help="The epoch interval to apply model averaging and save nbest models", - ) - group.add_argument( - "--grad_clip", - type=float, - default=5.0, - help="Gradient norm threshold to clip", - ) - group.add_argument( - "--grad_clip_type", - type=float, - default=2.0, - help="The type of the used p-norm for gradient clip. 
Can be inf", - ) - group.add_argument( - "--grad_noise", - type=str2bool, - default=False, - help="The flag to switch to use noise injection to " - "gradients during training", - ) - group.add_argument( - "--accum_grad", - type=int, - default=1, - help="The number of gradient accumulation", - ) - group.add_argument( - "--no_forward_run", - type=str2bool, - default=False, - help="Just only iterating data loading without " - "model forwarding and training", - ) - group.add_argument( - "--resume", - type=str2bool, - default=False, - help="Enable resuming if checkpoint is existing", - ) - group.add_argument( - "--train_dtype", - default="float32", - choices=["float16", "float32", "float64"], - help="Data type for training.", - ) - group.add_argument( - "--use_amp", - type=str2bool, - default=False, - help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6", - ) - group.add_argument( - "--log_interval", - type=int_or_none, - default=None, - help="Show the logs every the number iterations in each epochs at the " - "training phase. 
If None is given, it is decided according the number " - "of training samples automatically .", - ) - group.add_argument( - "--use_tensorboard", - type=str2bool, - default=True, - help="Enable tensorboard logging", - ) - group.add_argument( - "--use_wandb", - type=str2bool, - default=False, - help="Enable wandb logging", - ) - group.add_argument( - "--wandb_project", - type=str, - default=None, - help="Specify wandb project", - ) - group.add_argument( - "--wandb_id", - type=str, - default=None, - help="Specify wandb id", - ) - group.add_argument( - "--wandb_entity", - type=str, - default=None, - help="Specify wandb entity", - ) - group.add_argument( - "--wandb_name", - type=str, - default=None, - help="Specify wandb run name", - ) - group.add_argument( - "--wandb_model_log_interval", - type=int, - default=-1, - help="Set the model log period", - ) - group.add_argument( - "--detect_anomaly", - type=str2bool, - default=False, - help="Set torch.autograd.set_detect_anomaly", - ) - - group = parser.add_argument_group("Pretraining model related") - group.add_argument("--pretrain_path", help="This option is obsoleted") - group.add_argument( - "--init_param", - type=str, - default=[], - nargs="*", - help="Specify the file path used for initialization of parameters. " - "The format is ':::', " - "where file_path is the model file path, " - "src_key specifies the key of model states to be used in the model file, " - "dst_key specifies the attribute of the model to be initialized, " - "and exclude_keys excludes keys of model states for the initialization." 
- "e.g.\n" - " # Load all parameters" - " --init_param some/where/model.pb\n" - " # Load only decoder parameters" - " --init_param some/where/model.pb:decoder:decoder\n" - " # Load only decoder parameters excluding decoder.embed" - " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n" - " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n", - ) - group.add_argument( - "--ignore_init_mismatch", - type=str2bool, - default=False, - help="Ignore size mismatch when loading pre-trained model", - ) - group.add_argument( - "--freeze_param", - type=str, - default=[], - nargs="*", - help="Freeze parameters", - ) - - group = parser.add_argument_group("BatchSampler related") - group.add_argument( - "--num_iters_per_epoch", - type=int_or_none, - default=None, - help="Restrict the number of iterations for training per epoch", - ) - group.add_argument( - "--batch_size", - type=int, - default=20, - help="The mini-batch size used for training. Used if batch_type='unsorted'," - " 'sorted', or 'folded'.", - ) - group.add_argument( - "--valid_batch_size", - type=int_or_none, - default=None, - help="If not given, the value of --batch_size is used", - ) - group.add_argument( - "--batch_bins", - type=int, - default=1000000, - help="The number of batch bins. 
Used if batch_type='length' or 'numel'", - ) - group.add_argument( - "--valid_batch_bins", - type=int_or_none, - default=None, - help="If not given, the value of --batch_bins is used", - ) - - group.add_argument("--train_shape_file", type=str, action="append", default=[]) - group.add_argument("--valid_shape_file", type=str, action="append", default=[]) - - group = parser.add_argument_group("Sequence iterator related") - _batch_type_help = "" - for key, value in BATCH_TYPES.items(): - _batch_type_help += f'"{key}":\n{value}\n' - group.add_argument( - "--batch_type", - type=str, - default="length", - choices=list(BATCH_TYPES), - help=_batch_type_help, - ) - group.add_argument( - "--valid_batch_type", - type=str_or_none, - default=None, - choices=list(BATCH_TYPES) + [None], - help="If not given, the value of --batch_type is used", - ) - group.add_argument( - "--speech_length_min", - type=int, - default=-1, - help="speech length min", - ) - group.add_argument( - "--speech_length_max", - type=int, - default=-1, - help="speech length max", - ) - group.add_argument("--fold_length", type=int, action="append", default=[]) - group.add_argument( - "--sort_in_batch", - type=str, - default="descending", - choices=["descending", "ascending"], - help="Sort the samples in each mini-batches by the sample " - 'lengths. To enable this, "shape_file" must have the length information.', - ) - group.add_argument( - "--sort_batch", - type=str, - default="descending", - choices=["descending", "ascending"], - help="Sort mini-batches by the sample lengths", - ) - group.add_argument( - "--multiple_iterator", - type=str2bool, - default=False, - help="Use multiple iterator mode", - ) - - group = parser.add_argument_group("Chunk iterator related") - group.add_argument( - "--chunk_length", - type=str_or_int, - default=500, - help="Specify chunk length. e.g. '300', '300,400,500', or '300-400'." 
- "If multiple numbers separated by command are given, " - "one of them is selected randomly for each samples. " - "If two numbers are given with '-', it indicates the range of the choices. " - "Note that if the sequence length is shorter than the all chunk_lengths, " - "the sample is discarded. ", - ) - group.add_argument( - "--chunk_shift_ratio", - type=float, - default=0.5, - help="Specify the shift width of chunks. If it's less than 1, " - "allows the overlapping and if bigger than 1, there are some gaps " - "between each chunk.", - ) - group.add_argument( - "--num_cache_chunks", - type=int, - default=1024, - help="Shuffle in the specified number of chunks and generate mini-batches " - "More larger this value, more randomness can be obtained.", - ) - - group = parser.add_argument_group("Dataset related") - _data_path_and_name_and_type_help = ( - "Give three words splitted by comma. It's used for the training data. " - "e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. " - "The first value, some/path/a.scp, indicates the file path, " - "and the second, foo, is the key name used for the mini-batch data, " - "and the last, sound, decides the file type. " - "This option is repeatable, so you can input any number of features " - "for your task. 
Supported file types are as follows:\n\n" - ) - for key, dic in DATA_TYPES.items(): - _data_path_and_name_and_type_help += f'"{key}":\n{dic["help"]}\n\n' - - # for large dataset - group.add_argument( - "--dataset_type", - type=str, - default="small", - help="whether to use dataloader for large dataset", - ) - parser.add_argument( - "--dataset_conf", - action=NestedDictAction, - default=dict(), - help=f"The keyword arguments for dataset", - ) - group.add_argument( - "--train_data_file", - type=str, - default=None, - help="train_list for large dataset", - ) - group.add_argument( - "--valid_data_file", - type=str, - default=None, - help="valid_list for large dataset", - ) - - group.add_argument( - "--train_data_path_and_name_and_type", - type=str2triple_str, - action="append", - default=[], - help=_data_path_and_name_and_type_help, - ) - group.add_argument( - "--valid_data_path_and_name_and_type", - type=str2triple_str, - action="append", - default=[], - ) - group.add_argument( - "--allow_variable_data_keys", - type=str2bool, - default=False, - help="Allow the arbitrary keys for mini-batch with ignoring " - "the task requirements", - ) - group.add_argument( - "--max_cache_size", - type=humanfriendly.parse_size, - default=0.0, - help="The maximum cache size for data loader. e.g. 10MB, 20GB.", - ) - group.add_argument( - "--max_cache_fd", - type=int, - default=32, - help="The maximum number of file descriptors to be kept " - "as opened for ark files. " - "This feature is only valid when data type is 'kaldi_ark'.", - ) - group.add_argument( - "--valid_max_cache_size", - type=humanfriendly_parse_size_or_none, - default=None, - help="The maximum cache size for validation data loader. e.g. 10MB, 20GB. 
" - "If None, the 5 percent size of --max_cache_size", - ) - - group = parser.add_argument_group("Optimizer related") - for i in range(1, cls.num_optimizers + 1): - suf = "" if i == 1 else str(i) - group.add_argument( - f"--optim{suf}", - type=lambda x: x.lower(), - default="adadelta", - choices=list(optim_classes), - help="The optimizer type", - ) - group.add_argument( - f"--optim{suf}_conf", - action=NestedDictAction, - default=dict(), - help="The keyword arguments for optimizer", - ) - group.add_argument( - f"--scheduler{suf}", - type=lambda x: str_or_none(x.lower()), - default=None, - choices=list(scheduler_classes) + [None], - help="The lr scheduler type", - ) - group.add_argument( - f"--scheduler{suf}_conf", - action=NestedDictAction, - default=dict(), - help="The keyword arguments for lr scheduler", - ) - - # for training on PAI - group = parser.add_argument_group("PAI training related") - group.add_argument( - "--use_pai", - type=str2bool, - default=False, - help="flag to indicate whether training on PAI", - ) - group.add_argument( - "--simple_ddp", - type=str2bool, - default=False, - ) - group.add_argument( - "--num_worker_count", - type=int, - default=1, - help="The number of machines on PAI.", - ) - group.add_argument( - "--access_key_id", - type=str, - default=None, - help="The username for oss.", - ) - group.add_argument( - "--access_key_secret", - type=str, - default=None, - help="The password for oss.", - ) - group.add_argument( - "--endpoint", - type=str, - default=None, - help="The endpoint for oss.", - ) - group.add_argument( - "--bucket_name", - type=str, - default=None, - help="The bucket name for oss.", - ) - group.add_argument( - "--oss_bucket", - default=None, - help="oss bucket.", - ) - - cls.trainer.add_arguments(parser) - cls.add_task_arguments(parser) - - assert check_return_type(parser) - return parser - - @classmethod - def build_optimizers( - cls, - args: argparse.Namespace, - model: torch.nn.Module, - ) -> List[torch.optim.Optimizer]: 
@classmethod
def build_optimizers(
    cls,
    args: argparse.Namespace,
    model: torch.nn.Module,
) -> List[torch.optim.Optimizer]:
    """Build the single optimizer described by ``--optim``/``--optim_conf``.

    Tasks needing more than one optimizer must override this method
    (``num_optimizers != 1``).

    Args:
        args: Parsed command line (reads ``optim``, ``optim_conf`` and
            ``sharded_ddp``).
        model: Model whose parameters are handed to the optimizer.

    Returns:
        A one-element list holding the built optimizer.

    Raises:
        RuntimeError: If ``num_optimizers != 1``, or ``--sharded_ddp`` is
            requested but fairscale is not installed.
        ValueError: If ``args.optim`` is not a registered optimizer name.
    """
    if cls.num_optimizers != 1:
        raise RuntimeError(
            "build_optimizers() must be overridden if num_optimizers != 1"
        )

    optim_class = optim_classes.get(args.optim)
    if optim_class is None:
        raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
    if args.sharded_ddp:
        # Sharded training (fairscale OSS) splits optimizer state across
        # workers; it wraps the optimizer class rather than an instance.
        if fairscale is None:
            raise RuntimeError("Requiring fairscale. Do 'pip install fairscale'")
        optim = fairscale.optim.oss.OSS(
            params=model.parameters(), optim=optim_class, **args.optim_conf
        )
    else:
        optim = optim_class(model.parameters(), **args.optim_conf)

    optimizers = [optim]
    return optimizers

@classmethod
def exclude_opts(cls) -> Tuple[str, ...]:
    """The options not to be shown by --print_config"""
    return "required", "print_config", "config", "ngpu"

@classmethod
def get_default_config(cls) -> Dict[str, Any]:
    """Return the default configuration as a dict.

    This method is used by print_config(): it parses the default command
    line, then merges each optimizer/scheduler/class-choice's constructor
    defaults with the configured ``*_conf`` overrides.
    """

    def get_class_type(name: str, classes: dict):
        # Resolve a registered class by its (lower-cased) name or fail loudly.
        _cls = classes.get(name)
        if _cls is None:
            raise ValueError(f"must be one of {list(classes)}: {name}")
        return _cls

    # This method is used only for --print_config
    assert check_argument_types()
    parser = cls.get_parser()
    args, _ = parser.parse_known_args()
    config = vars(args)
    # Excludes the options not to be shown.  pop(..., None) keeps this
    # robust if an excluded option is ever removed from the parser.
    for k in AbsTask.exclude_opts():
        config.pop(k, None)

    for i in range(1, cls.num_optimizers + 1):
        suf = "" if i == 1 else str(i)
        name = config[f"optim{suf}"]
        optim_class = get_class_type(name, optim_classes)
        conf = get_default_kwargs(optim_class)
        # Overwrite the defaults by the arguments, and set them back.
        conf.update(config[f"optim{suf}_conf"])
        config[f"optim{suf}_conf"] = conf

        name = config[f"scheduler{suf}"]
        if name is not None:
            scheduler_class = get_class_type(name, scheduler_classes)
            conf = get_default_kwargs(scheduler_class)
            # Overwrite the defaults by the arguments, and set them back.
            conf.update(config[f"scheduler{suf}_conf"])
            config[f"scheduler{suf}_conf"] = conf

    for class_choices in cls.class_choices_list:
        if getattr(args, class_choices.name) is not None:
            class_obj = class_choices.get_class(getattr(args, class_choices.name))
            conf = get_default_kwargs(class_obj)
            name = class_choices.name
            # Overwrite the defaults by the arguments, and set them back.
            conf.update(config[f"{name}_conf"])
            config[f"{name}_conf"] = conf
    return config

@classmethod
def check_required_command_args(cls, args: argparse.Namespace):
    """Verify that every option listed in ``args.required`` was given.

    Prints the parser help and exits with status 2 when any required
    option is still ``None``.  Also rejects dash-style option attribute
    names, which would be inconsistent with the underscore convention.
    """
    assert check_argument_types()
    if hasattr(args, "required"):
        for k in vars(args):
            if "-" in k:
                raise RuntimeError(
                    f'Use "_" instead of "-": parser.get_parser("{k}")'
                )

        required = ", ".join(
            f"--{a}" for a in args.required if getattr(args, a) is None
        )

        if len(required) != 0:
            parser = cls.get_parser()
            parser.print_help(file=sys.stderr)
            p = Path(sys.argv[0]).name
            print(file=sys.stderr)
            print(
                f"{p}: error: the following arguments are required: "
                f"{required}",
                file=sys.stderr,
            )
            sys.exit(2)

@classmethod
def check_task_requirements(
    cls,
    dataset: Union[AbsDataset, IterableESPnetDataset],
    allow_variable_data_keys: bool,
    train: bool,
    inference: bool = False,
) -> None:
    """Check if the dataset satisfies the requirement of the current Task.

    Raises:
        RuntimeError: If a required data name is missing from ``dataset``,
            or (when ``allow_variable_data_keys`` is false) an unknown
            data name is present.
    """
    assert check_argument_types()
    mes = (
        f"If you intend to use an additional input, modify "
        f'"{cls.__name__}.required_data_names()" or '
        f'"{cls.__name__}.optional_data_names()". '
        f"Otherwise you need to set --allow_variable_data_keys true "
    )

    # Hoisted so the hook is invoked once, and so the error below can
    # name the exact key that is missing (the original message omitted it).
    required = cls.required_data_names(train, inference)
    for k in required:
        if not dataset.has_name(k):
            raise RuntimeError(
                f'"{required}" are required for {cls.__name__}:'
                f' "{k}" is missing, but "{dataset.names()}" are input.\n{mes}'
            )
    if not allow_variable_data_keys:
        task_keys = required + cls.optional_data_names(train, inference)
        for k in dataset.names():
            if k not in task_keys:
                raise RuntimeError(
                    f"The data-name must be one of {task_keys} "
                    f'for {cls.__name__}: "{k}" is not allowed.\n{mes}'
                )

@classmethod
def print_config(cls, file=sys.stdout) -> None:
    """Dump the default configuration as YAML to ``file``."""
    assert check_argument_types()
    # Shows the config: e.g. python train.py asr --print_config
    config = cls.get_default_config()
    file.write(yaml_no_alias_safe_dump(config, indent=4, sort_keys=False))

@classmethod
def main(cls, args: argparse.Namespace = None, cmd: Sequence[str] = None):
    """Command-line entry point: parse arguments and run main_worker()."""
    assert check_argument_types()
    print(get_commandline_args(), file=sys.stderr)
    if args is None:
        parser = cls.get_parser()
        args = parser.parse_args(cmd)
    args.version = __version__
    if args.pretrain_path is not None:
        raise RuntimeError("--pretrain_path is deprecated. Use --init_param")
    if args.print_config:
        cls.print_config()
        sys.exit(0)
    cls.check_required_command_args(args)

    # NOTE(review): both branches call main_worker() directly here; the
    # actual process spawning for --multiprocessing_distributed appears to
    # be handled elsewhere -- confirm against the launcher.
    if not args.distributed or not args.multiprocessing_distributed:
        cls.main_worker(args)
    else:
        assert args.ngpu > 1
        cls.main_worker(args)

@classmethod
def run(cls):
    """Finetuning entry point driven by ``cls.finetune_args``.

    Normalizes the stored finetune namespace (shape files recomputed,
    PAI/OSS disabled, length batching) and launches main_worker().
    """
    assert hasattr(cls, "finetune_args")
    args = cls.finetune_args
    args.train_shape_file = None
    if args.distributed:
        args.simple_ddp = True
    else:
        args.simple_ddp = False
        # NOTE(review): ngpu forced to 1 only in the non-distributed
        # branch -- grouping inferred from upstream FunASR; confirm.
        args.ngpu = 1
    args.use_pai = False
    args.batch_type = "length"
    args.oss_bucket = None
    args.input_size = None
    cls.main_worker(args)
- if args.use_pai: - distributed_option.init_options_pai() - elif not args.simple_ddp: - distributed_option.init_options() - - # Invoking torch.distributed.init_process_group - if args.use_pai: - distributed_option.init_torch_distributed_pai(args) - elif not args.simple_ddp: - distributed_option.init_torch_distributed(args) - elif args.distributed and args.simple_ddp: - distributed_option.init_torch_distributed_pai(args) - args.ngpu = dist.get_world_size() - if args.dataset_type == "small": - if args.batch_size is not None: - args.batch_size = args.batch_size * args.ngpu - if args.batch_bins is not None: - args.batch_bins = args.batch_bins * args.ngpu - - # filter samples if wav.scp and text are mismatch - if (args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large": - if not args.simple_ddp or distributed_option.dist_rank == 0: - filter_wav_text(args.data_dir, args.train_set) - filter_wav_text(args.data_dir, args.dev_set) - if args.simple_ddp: - dist.barrier() - - if args.train_shape_file is None and args.dataset_type == "small": - if not args.simple_ddp or distributed_option.dist_rank == 0: - calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max) - calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min, args.speech_length_max) - if args.simple_ddp: - dist.barrier() - args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")] - args.valid_shape_file = [os.path.join(args.data_dir, args.dev_set, "speech_shape")] - - if args.train_data_file is None and args.dataset_type == "large": - if not args.simple_ddp or distributed_option.dist_rank == 0: - generate_data_list(args.data_dir, args.train_set) - generate_data_list(args.data_dir, args.dev_set) - if args.simple_ddp: - dist.barrier() - args.train_data_file = os.path.join(args.data_dir, args.train_set, "data.list") - args.valid_data_file = 
os.path.join(args.data_dir, args.dev_set, "data.list") - - # NOTE(kamo): Don't use logging before invoking logging.basicConfig() - if not distributed_option.distributed or distributed_option.dist_rank == 0: - if not distributed_option.distributed: - _rank = "" - else: - _rank = ( - f":{distributed_option.dist_rank}/" - f"{distributed_option.dist_world_size}" - ) - - # NOTE(kamo): - # logging.basicConfig() is invoked in main_worker() instead of main() - # because it can be invoked only once in a process. - # FIXME(kamo): Should we use logging.getLogger()? - # BUGFIX: Remove previous handlers and reset log level - for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) - logging.basicConfig( - level=args.log_level, - format=f"[{os.uname()[1].split('.')[0]}]" - f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - else: - # BUGFIX: Remove previous handlers and reset log level - for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) - # Suppress logging if RANK != 0 - logging.basicConfig( - level="ERROR", - format=f"[{os.uname()[1].split('.')[0]}]" - f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size, - distributed_option.dist_rank, - distributed_option.local_rank)) - - # 1. Set random-seed - set_all_random_seed(args.seed) - torch.backends.cudnn.enabled = args.cudnn_enabled - torch.backends.cudnn.benchmark = args.cudnn_benchmark - torch.backends.cudnn.deterministic = args.cudnn_deterministic - if args.detect_anomaly: - logging.info("Invoking torch.autograd.set_detect_anomaly(True)") - torch.autograd.set_detect_anomaly(args.detect_anomaly) - - # 2. 
Build model - model = cls.build_model(args=args) - if not isinstance(model, AbsESPnetModel): - raise RuntimeError( - f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" - ) - model = model.to( - dtype=getattr(torch, args.train_dtype), - device="cuda" if args.ngpu > 0 else "cpu", - ) - for t in args.freeze_param: - for k, p in model.named_parameters(): - if k.startswith(t + ".") or k == t: - logging.info(f"Setting {k}.requires_grad = False") - p.requires_grad = False - - # 3. Build optimizer - optimizers = cls.build_optimizers(args, model=model) - - # 4. Build schedulers - schedulers = [] - for i, optim in enumerate(optimizers, 1): - suf = "" if i == 1 else str(i) - name = getattr(args, f"scheduler{suf}") - conf = getattr(args, f"scheduler{suf}_conf") - if name is not None: - cls_ = scheduler_classes.get(name) - if cls_ is None: - raise ValueError( - f"must be one of {list(scheduler_classes)}: {name}" - ) - scheduler = cls_(optim, **conf) - else: - scheduler = None - - schedulers.append(scheduler) - - logging.info(pytorch_cudnn_version()) - logging.info(model_summary(model)) - for i, (o, s) in enumerate(zip(optimizers, schedulers), 1): - suf = "" if i == 1 else str(i) - logging.info(f"Optimizer{suf}:\n{o}") - logging.info(f"Scheduler{suf}: {s}") - - # 5. Dump "args" to config.yaml - # NOTE(kamo): "args" should be saved after object-buildings are done - # because they are allowed to modify "args". 
- output_dir = Path(args.output_dir) - if not distributed_option.distributed or distributed_option.dist_rank == 0: - output_dir.mkdir(parents=True, exist_ok=True) - with (output_dir / "config.yaml").open("w", encoding="utf-8") as f: - logging.info( - f'Saving the configuration in {output_dir / "config.yaml"}' - ) - if args.use_pai: - buffer = BytesIO() - torch.save({"config": vars(args)}, buffer) - args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue()) - else: - yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False) - - if args.dry_run: - pass - elif args.collect_stats: - # Perform on collect_stats mode. This mode has two roles - # - Derive the length and dimension of all input data - # - Accumulate feats, square values, and the length for whitening - - if args.valid_batch_size is None: - args.valid_batch_size = args.batch_size - - if len(args.train_shape_file) != 0: - train_key_file = args.train_shape_file[0] - else: - train_key_file = None - if len(args.valid_shape_file) != 0: - valid_key_file = args.valid_shape_file[0] - else: - valid_key_file = None - - collect_stats( - model=model, - train_iter=cls.build_streaming_iterator( - data_path_and_name_and_type=args.train_data_path_and_name_and_type, - key_file=train_key_file, - batch_size=args.batch_size, - dtype=args.train_dtype, - num_workers=args.num_workers, - allow_variable_data_keys=args.allow_variable_data_keys, - ngpu=args.ngpu, - preprocess_fn=cls.build_preprocess_fn(args, train=False), - collate_fn=cls.build_collate_fn(args, train=False), - ), - valid_iter=cls.build_streaming_iterator( - data_path_and_name_and_type=args.valid_data_path_and_name_and_type, - key_file=valid_key_file, - batch_size=args.valid_batch_size, - dtype=args.train_dtype, - num_workers=args.num_workers, - allow_variable_data_keys=args.allow_variable_data_keys, - ngpu=args.ngpu, - preprocess_fn=cls.build_preprocess_fn(args, train=False), - collate_fn=cls.build_collate_fn(args, train=False), 
- ), - output_dir=output_dir, - ngpu=args.ngpu, - log_interval=args.log_interval, - write_collected_feats=args.write_collected_feats, - ) - else: - logging.info("Training args: {}".format(args)) - # 6. Loads pre-trained model - for p in args.init_param: - logging.info(f"Loading pretrained params from {p}") - load_pretrained_model( - model=model, - init_param=p, - ignore_init_mismatch=args.ignore_init_mismatch, - # NOTE(kamo): "cuda" for torch.load always indicates cuda:0 - # in PyTorch<=1.4 - map_location=f"cuda:{torch.cuda.current_device()}" - if args.ngpu > 0 - else "cpu", - oss_bucket=args.oss_bucket, - ) - - # 7. Build iterator factories - if args.dataset_type == "large": - from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader - train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf, - frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None, - seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None, - punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None, - bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None, - mode="train") - valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf, - frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None, - seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None, - punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None, - bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None, - mode="eval") - elif args.dataset_type == "small": - train_iter_factory = cls.build_iter_factory( - args=args, - distributed_option=distributed_option, - mode="train", - ) - valid_iter_factory = cls.build_iter_factory( - args=args, - distributed_option=distributed_option, - mode="valid", - ) - else: - raise ValueError(f"Not supported dataset_type={args.dataset_type}") - - if args.scheduler == "tri_stage": - for scheduler 
in schedulers: - scheduler.init_tri_stage_scheudler(max_update=args.max_update) - - # 8. Start training - if args.use_wandb: - if wandb is None: - raise RuntimeError("Please install wandb") - - try: - wandb.login() - except wandb.errors.UsageError: - logging.info("wandb not configured! run `wandb login` to enable") - args.use_wandb = False - - if args.use_wandb: - if ( - not distributed_option.distributed - or distributed_option.dist_rank == 0 - ): - if args.wandb_project is None: - project = "FunASR_" + cls.__name__ - else: - project = args.wandb_project - - if args.wandb_name is None: - name = str(Path(".").resolve()).replace("/", "_") - else: - name = args.wandb_name - - wandb.init( - entity=args.wandb_entity, - project=project, - name=name, - dir=output_dir, - id=args.wandb_id, - resume="allow", - ) - wandb.config.update(args) - else: - # wandb also supports grouping for distributed training, - # but we only logs aggregated data, - # so it's enough to perform on rank0 node. - args.use_wandb = False - - # Don't give args to trainer.run() directly!!! - # Instead of it, define "Options" object and build here. 
- trainer_options = cls.trainer.build_options(args) - cls.trainer.run( - model=model, - optimizers=optimizers, - schedulers=schedulers, - train_iter_factory=train_iter_factory, - valid_iter_factory=valid_iter_factory, - trainer_options=trainer_options, - distributed_option=distributed_option, - ) - - if args.use_wandb and wandb.run: - wandb.finish() - - @classmethod - def build_iter_options( - cls, - args: argparse.Namespace, - distributed_option: DistributedOption, - mode: str, - ): - if mode == "train": - preprocess_fn = cls.build_preprocess_fn(args, train=True) - collate_fn = cls.build_collate_fn(args, train=True) - data_path_and_name_and_type = args.train_data_path_and_name_and_type - shape_files = args.train_shape_file - batch_size = args.batch_size - batch_bins = args.batch_bins - batch_type = args.batch_type - max_cache_size = args.max_cache_size - max_cache_fd = args.max_cache_fd - distributed = distributed_option.distributed - num_batches = None - num_iters_per_epoch = args.num_iters_per_epoch - train = True - - elif mode == "valid": - preprocess_fn = cls.build_preprocess_fn(args, train=False) - collate_fn = cls.build_collate_fn(args, train=False) - data_path_and_name_and_type = args.valid_data_path_and_name_and_type - shape_files = args.valid_shape_file - - if args.valid_batch_type is None: - batch_type = args.batch_type - else: - batch_type = args.valid_batch_type - if args.valid_batch_size is None: - batch_size = args.batch_size - else: - batch_size = args.valid_batch_size - if args.valid_batch_bins is None: - batch_bins = args.batch_bins - else: - batch_bins = args.valid_batch_bins - if args.valid_max_cache_size is None: - # Cache 5% of maximum size for validation loader - max_cache_size = 0.05 * args.max_cache_size - else: - max_cache_size = args.valid_max_cache_size - max_cache_fd = args.max_cache_fd - distributed = distributed_option.distributed - num_batches = None - num_iters_per_epoch = None - train = False - else: - raise 
NotImplementedError(f"mode={mode}") - - return IteratorOptions( - preprocess_fn=preprocess_fn, - collate_fn=collate_fn, - data_path_and_name_and_type=data_path_and_name_and_type, - shape_files=shape_files, - batch_type=batch_type, - batch_size=batch_size, - batch_bins=batch_bins, - num_batches=num_batches, - max_cache_size=max_cache_size, - max_cache_fd=max_cache_fd, - distributed=distributed, - num_iters_per_epoch=num_iters_per_epoch, - train=train, - ) - - @classmethod - def build_iter_factory( - cls, - args: argparse.Namespace, - distributed_option: DistributedOption, - mode: str, - kwargs: dict = None, - ) -> AbsIterFactory: - """Build a factory object of mini-batch iterator. - - This object is invoked at every epochs to build the iterator for each epoch - as following: - - >>> iter_factory = cls.build_iter_factory(...) - >>> for epoch in range(1, max_epoch): - ... for keys, batch in iter_fatory.build_iter(epoch): - ... model(**batch) - - The mini-batches for each epochs are fully controlled by this class. - Note that the random seed used for shuffling is decided as "seed + epoch" and - the generated mini-batches can be reproduces when resuming. - - Note that the definition of "epoch" doesn't always indicate - to run out of the whole training corpus. - "--num_iters_per_epoch" option restricts the number of iterations for each epoch - and the rest of samples for the originally epoch are left for the next epoch. - e.g. 
If The number of mini-batches equals to 4, the following two are same: - - - 1 epoch without "--num_iters_per_epoch" - - 4 epoch with "--num_iters_per_epoch" == 4 - - """ - assert check_argument_types() - iter_options = cls.build_iter_options(args, distributed_option, mode) - - # Overwrite iter_options if any kwargs is given - if kwargs is not None: - for k, v in kwargs.items(): - setattr(iter_options, k, v) - - if args.iterator_type == "sequence": - return cls.build_sequence_iter_factory( - args=args, - iter_options=iter_options, - mode=mode, - ) - elif args.iterator_type == "chunk": - return cls.build_chunk_iter_factory( - args=args, - iter_options=iter_options, - mode=mode, - ) - elif args.iterator_type == "task": - return cls.build_task_iter_factory( - args=args, - iter_options=iter_options, - mode=mode, - ) - else: - raise RuntimeError(f"Not supported: iterator_type={args.iterator_type}") - - @classmethod - def build_sequence_iter_factory( - cls, args: argparse.Namespace, iter_options: IteratorOptions, mode: str - ) -> AbsIterFactory: - assert check_argument_types() - - if args.frontend_conf is not None and "fs" in args.frontend_conf: - dest_sample_rate = args.frontend_conf["fs"] - else: - dest_sample_rate = 16000 - - dataset = ESPnetDataset( - iter_options.data_path_and_name_and_type, - float_dtype=args.train_dtype, - preprocess=iter_options.preprocess_fn, - max_cache_size=iter_options.max_cache_size, - max_cache_fd=iter_options.max_cache_fd, -<<<<<<< HEAD - dest_sample_rate=args.frontend_conf["fs"] if args.frontend_conf else 16000, -======= - dest_sample_rate=dest_sample_rate, ->>>>>>> main - ) - cls.check_task_requirements( - dataset, args.allow_variable_data_keys, train=iter_options.train - ) - - if Path( - Path(iter_options.data_path_and_name_and_type[0][0]).parent, "utt2category" - ).exists(): - utt2category_file = str( - Path( - Path(iter_options.data_path_and_name_and_type[0][0]).parent, - "utt2category", - ) - ) - else: - utt2category_file = None - 
batch_sampler = build_batch_sampler( - type=iter_options.batch_type, - shape_files=iter_options.shape_files, - fold_lengths=args.fold_length, - batch_size=iter_options.batch_size, - batch_bins=iter_options.batch_bins, - sort_in_batch=args.sort_in_batch, - sort_batch=args.sort_batch, - drop_last=False, - min_batch_size=torch.distributed.get_world_size() - if iter_options.distributed - else 1, - utt2category_file=utt2category_file, - ) - - batches = list(batch_sampler) - if iter_options.num_batches is not None: - batches = batches[: iter_options.num_batches] - - bs_list = [len(batch) for batch in batches] - - logging.info(f"[{mode}] dataset:\n{dataset}") - logging.info(f"[{mode}] Batch sampler: {batch_sampler}") - logging.info( - f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, " - f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}" - ) - - if args.scheduler == "tri_stage" and mode == "train": - args.max_update = len(bs_list) * args.max_epoch - logging.info("Max update: {}".format(args.max_update)) - - if iter_options.distributed: - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - for batch in batches: - if len(batch) < world_size: - raise RuntimeError( - f"The batch-size must be equal or more than world_size: " - f"{len(batch)} < {world_size}" - ) - batches = [batch[rank::world_size] for batch in batches] - - return SequenceIterFactory( - dataset=dataset, - batches=batches, - seed=args.seed, - num_iters_per_epoch=iter_options.num_iters_per_epoch, - shuffle=iter_options.train, - num_workers=args.num_workers, - collate_fn=iter_options.collate_fn, - pin_memory=args.ngpu > 0, - ) - - @classmethod - def build_chunk_iter_factory( - cls, - args: argparse.Namespace, - iter_options: IteratorOptions, - mode: str, - ) -> AbsIterFactory: - assert check_argument_types() - - dataset = ESPnetDataset( - iter_options.data_path_and_name_and_type, - float_dtype=args.train_dtype, - 
preprocess=iter_options.preprocess_fn, - max_cache_size=iter_options.max_cache_size, - max_cache_fd=iter_options.max_cache_fd, - ) - cls.check_task_requirements( - dataset, args.allow_variable_data_keys, train=iter_options.train - ) - - if len(iter_options.shape_files) == 0: - key_file = iter_options.data_path_and_name_and_type[0][0] - else: - key_file = iter_options.shape_files[0] - - batch_sampler = UnsortedBatchSampler(batch_size=1, key_file=key_file) - batches = list(batch_sampler) - if iter_options.num_batches is not None: - batches = batches[: iter_options.num_batches] - logging.info(f"[{mode}] dataset:\n{dataset}") - - if iter_options.distributed: - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - if len(batches) < world_size: - raise RuntimeError("Number of samples is smaller than world_size") - if iter_options.batch_size < world_size: - raise RuntimeError("batch_size must be equal or more than world_size") - - if rank < iter_options.batch_size % world_size: - batch_size = iter_options.batch_size // world_size + 1 - else: - batch_size = iter_options.batch_size // world_size - num_cache_chunks = args.num_cache_chunks // world_size - # NOTE(kamo): Split whole corpus by sample numbers without considering - # each of the lengths, therefore the number of iteration counts are not - # always equal to each other and the iterations are limitted - # by the fewest iterations. - # i.e. the samples over the counts are discarded. - batches = batches[rank::world_size] - else: - batch_size = iter_options.batch_size - num_cache_chunks = args.num_cache_chunks - - return ChunkIterFactory( - dataset=dataset, - batches=batches, - seed=args.seed, - batch_size=batch_size, - # For chunk iterator, - # --num_iters_per_epoch doesn't indicate the number of iterations, - # but indicates the number of samples. 
- num_samples_per_epoch=iter_options.num_iters_per_epoch, - shuffle=iter_options.train, - num_workers=args.num_workers, - collate_fn=iter_options.collate_fn, - pin_memory=args.ngpu > 0, - chunk_length=args.chunk_length, - chunk_shift_ratio=args.chunk_shift_ratio, - num_cache_chunks=num_cache_chunks, - ) - - # NOTE(kamo): Not abstract class - @classmethod - def build_task_iter_factory( - cls, - args: argparse.Namespace, - iter_options: IteratorOptions, - mode: str, - ) -> AbsIterFactory: - """Build task specific iterator factory - - Example: - - >>> class YourTask(AbsTask): - ... @classmethod - ... def add_task_arguments(cls, parser: argparse.ArgumentParser): - ... parser.set_defaults(iterator_type="task") - ... - ... @classmethod - ... def build_task_iter_factory( - ... cls, - ... args: argparse.Namespace, - ... iter_options: IteratorOptions, - ... mode: str, - ... ): - ... return FooIterFactory(...) - ... - ... @classmethod - ... def build_iter_options( - .... args: argparse.Namespace, - ... distributed_option: DistributedOption, - ... mode: str - ... ): - ... # if you need to customize options object - """ - raise NotImplementedError - - @classmethod - def build_multiple_iter_factory( - cls, args: argparse.Namespace, distributed_option: DistributedOption, mode: str - ): - assert check_argument_types() - iter_options = cls.build_iter_options(args, distributed_option, mode) - assert len(iter_options.data_path_and_name_and_type) > 0, len( - iter_options.data_path_and_name_and_type - ) - - # 1. 
Sanity check - num_splits = None - for path in [ - path for path, _, _ in iter_options.data_path_and_name_and_type - ] + list(iter_options.shape_files): - if not Path(path).is_dir(): - raise RuntimeError(f"{path} is not a directory") - p = Path(path) / "num_splits" - if not p.exists(): - raise FileNotFoundError(f"{p} is not found") - with p.open() as f: - _num_splits = int(f.read()) - if num_splits is not None and num_splits != _num_splits: - raise RuntimeError( - f"Number of splits are mismathed: " - f"{iter_options.data_path_and_name_and_type[0][0]} and {path}" - ) - num_splits = _num_splits - - for i in range(num_splits): - p = Path(path) / f"split.{i}" - if not p.exists(): - raise FileNotFoundError(f"{p} is not found") - - # 2. Create functions to build an iter factory for each splits - data_path_and_name_and_type_list = [ - [ - (str(Path(p) / f"split.{i}"), n, t) - for p, n, t in iter_options.data_path_and_name_and_type - ] - for i in range(num_splits) - ] - shape_files_list = [ - [str(Path(s) / f"split.{i}") for s in iter_options.shape_files] - for i in range(num_splits) - ] - num_iters_per_epoch_list = [ - (iter_options.num_iters_per_epoch + i) // num_splits - if iter_options.num_iters_per_epoch is not None - else None - for i in range(num_splits) - ] - max_cache_size = iter_options.max_cache_size / num_splits - - # Note that iter-factories are built for each epoch at runtime lazily. - build_funcs = [ - functools.partial( - cls.build_iter_factory, - args, - distributed_option, - mode, - kwargs=dict( - data_path_and_name_and_type=_data_path_and_name_and_type, - shape_files=_shape_files, - num_iters_per_epoch=_num_iters_per_epoch, - max_cache_size=max_cache_size, - ), - ) - for ( - _data_path_and_name_and_type, - _shape_files, - _num_iters_per_epoch, - ) in zip( - data_path_and_name_and_type_list, - shape_files_list, - num_iters_per_epoch_list, - ) - ] - - # 3. 
Build MultipleIterFactory - return MultipleIterFactory( - build_funcs=build_funcs, shuffle=iter_options.train, seed=args.seed - ) - - @classmethod - def build_streaming_iterator( - cls, - data_path_and_name_and_type, - preprocess_fn, - collate_fn, - key_file: str = None, - batch_size: int = 1, - fs: dict = None, - mc: bool = False, - dtype: str = np.float32, - num_workers: int = 1, - allow_variable_data_keys: bool = False, - ngpu: int = 0, - inference: bool = False, - ) -> DataLoader: - """Build DataLoader using iterable dataset""" - assert check_argument_types() - # For backward compatibility for pytorch DataLoader - if collate_fn is not None: - kwargs = dict(collate_fn=collate_fn) - else: - kwargs = {} - - dataset = IterableESPnetDataset( - data_path_and_name_and_type, - float_dtype=dtype, - fs=fs, - mc=mc, - preprocess=preprocess_fn, - key_file=key_file, - ) - if dataset.apply_utt2category: - kwargs.update(batch_size=1) - else: - kwargs.update(batch_size=batch_size) - - cls.check_task_requirements( - dataset, allow_variable_data_keys, train=False, inference=inference - ) - - return DataLoader( - dataset=dataset, - pin_memory=ngpu > 0, - num_workers=num_workers, - **kwargs, - ) - - # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~ - @classmethod - def build_model_from_file( - cls, - config_file: Union[Path, str] = None, - model_file: Union[Path, str] = None, - cmvn_file: Union[Path, str] = None, - device: str = "cpu", - ) -> Tuple[AbsESPnetModel, argparse.Namespace]: - """Build model from the files. - - This method is used for inference or fine-tuning. - - Args: - config_file: The yaml file saved when training. - model_file: The model file saved when training. - device: Device type, "cpu", "cuda", or "cuda:N". - - """ - assert check_argument_types() - if config_file is None: - assert model_file is not None, ( - "The argument 'model_file' must be provided " - "if the argument 'config_file' is not specified." 
- ) - config_file = Path(model_file).parent / "config.yaml" - else: - config_file = Path(config_file) - - with config_file.open("r", encoding="utf-8") as f: - args = yaml.safe_load(f) - if cmvn_file is not None: - args["cmvn_file"] = cmvn_file - args = argparse.Namespace(**args) - model = cls.build_model(args) - if not isinstance(model, AbsESPnetModel): - raise RuntimeError( - f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" - ) - model.to(device) - if model_file is not None: - if device == "cuda": - # NOTE(kamo): "cuda" for torch.load always indicates cuda:0 - # in PyTorch<=1.4 - device = f"cuda:{torch.cuda.current_device()}" - model.load_state_dict(torch.load(model_file, map_location=device)) - model.to(device) - return model, args From dccbcc48a583eb38b4bc5b459d417087826846ac Mon Sep 17 00:00:00 2001 From: aky15 Date: Wed, 12 Apr 2023 18:06:29 +0800 Subject: [PATCH 10/14] resolve conflict --- funasr/tasks/abs_task.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py index e70b0623b..8d63b27d9 100644 --- a/funasr/tasks/abs_task.py +++ b/funasr/tasks/abs_task.py @@ -1594,11 +1594,7 @@ class AbsTask(ABC): preprocess=iter_options.preprocess_fn, max_cache_size=iter_options.max_cache_size, max_cache_fd=iter_options.max_cache_fd, -<<<<<<< HEAD - dest_sample_rate=args.frontend_conf["fs"] if args.frontend_conf else 16000, -======= dest_sample_rate=dest_sample_rate, ->>>>>>> main ) cls.check_task_requirements( dataset, args.allow_variable_data_keys, train=iter_options.train From 256035b6c1fa6115b6f33972ed243eb43f3e4299 Mon Sep 17 00:00:00 2001 From: aky15 Date: Fri, 14 Apr 2023 11:38:00 +0800 Subject: [PATCH 11/14] rnnt reorg --- funasr/models/e2e_transducer.py | 2 +- funasr/models/e2e_transducer_unified.py | 2 +- funasr/models/encoder/chunk_encoder.py | 292 -------- .../encoder/chunk_encoder_blocks/__init__.py | 0 .../chunk_encoder_blocks/branchformer.py | 178 ----- 
.../encoder/chunk_encoder_blocks/conformer.py | 198 ------ .../encoder/chunk_encoder_blocks/conv1d.py | 221 ------ .../chunk_encoder_blocks/conv_input.py | 222 ------ .../chunk_encoder_blocks/linear_input.py | 52 -- .../encoder/chunk_encoder_modules/__init__.py | 0 .../chunk_encoder_modules/attention.py | 246 ------- .../chunk_encoder_modules/convolution.py | 196 ------ .../chunk_encoder_modules/multi_blocks.py | 105 --- .../positional_encoding.py | 91 --- .../encoder/chunk_encoder_utils/building.py | 352 ---------- .../encoder/chunk_encoder_utils/validation.py | 171 ----- funasr/models/encoder/conformer_encoder.py | 640 +++++++++++++++++- funasr/modules/attention.py | 220 +++++- funasr/modules/embedding.py | 77 ++- .../normalization.py | 0 funasr/modules/repeat.py | 92 +++ funasr/modules/subsampling.py | 202 ++++++ funasr/tasks/asr_transducer.py | 6 +- 23 files changed, 1233 insertions(+), 2332 deletions(-) delete mode 100644 funasr/models/encoder/chunk_encoder.py delete mode 100644 funasr/models/encoder/chunk_encoder_blocks/__init__.py delete mode 100644 funasr/models/encoder/chunk_encoder_blocks/branchformer.py delete mode 100644 funasr/models/encoder/chunk_encoder_blocks/conformer.py delete mode 100644 funasr/models/encoder/chunk_encoder_blocks/conv1d.py delete mode 100644 funasr/models/encoder/chunk_encoder_blocks/conv_input.py delete mode 100644 funasr/models/encoder/chunk_encoder_blocks/linear_input.py delete mode 100644 funasr/models/encoder/chunk_encoder_modules/__init__.py delete mode 100644 funasr/models/encoder/chunk_encoder_modules/attention.py delete mode 100644 funasr/models/encoder/chunk_encoder_modules/convolution.py delete mode 100644 funasr/models/encoder/chunk_encoder_modules/multi_blocks.py delete mode 100644 funasr/models/encoder/chunk_encoder_modules/positional_encoding.py delete mode 100644 funasr/models/encoder/chunk_encoder_utils/building.py delete mode 100644 funasr/models/encoder/chunk_encoder_utils/validation.py rename 
funasr/{models/encoder/chunk_encoder_modules => modules}/normalization.py (100%) diff --git a/funasr/models/e2e_transducer.py b/funasr/models/e2e_transducer.py index b669c9d3e..8630aec40 100644 --- a/funasr/models/e2e_transducer.py +++ b/funasr/models/e2e_transducer.py @@ -12,7 +12,7 @@ from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder -from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder +from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder from funasr.models.joint_network import JointNetwork from funasr.modules.nets_utils import get_transducer_task_io from funasr.layers.abs_normalize import AbsNormalize diff --git a/funasr/models/e2e_transducer_unified.py b/funasr/models/e2e_transducer_unified.py index 600354216..124bc0938 100644 --- a/funasr/models/e2e_transducer_unified.py +++ b/funasr/models/e2e_transducer_unified.py @@ -11,7 +11,7 @@ from typeguard import check_argument_types from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder -from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder +from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder from funasr.models.joint_network import JointNetwork from funasr.modules.nets_utils import get_transducer_task_io from funasr.layers.abs_normalize import AbsNormalize diff --git a/funasr/models/encoder/chunk_encoder.py b/funasr/models/encoder/chunk_encoder.py deleted file mode 100644 index c6fc292e0..000000000 --- a/funasr/models/encoder/chunk_encoder.py +++ /dev/null @@ -1,292 +0,0 @@ -from typing import Any, Dict, List, Tuple - -import torch -from typeguard import check_argument_types - -from 
funasr.models.encoder.chunk_encoder_utils.building import ( - build_body_blocks, - build_input_block, - build_main_parameters, - build_positional_encoding, -) -from funasr.models.encoder.chunk_encoder_utils.validation import validate_architecture -from funasr.modules.nets_utils import ( - TooShortUttError, - check_short_utt, - make_chunk_mask, - make_source_mask, -) - -class ChunkEncoder(torch.nn.Module): - """Encoder module definition. - - Args: - input_size: Input size. - body_conf: Encoder body configuration. - input_conf: Encoder input configuration. - main_conf: Encoder main configuration. - - """ - - def __init__( - self, - input_size: int, - body_conf: List[Dict[str, Any]], - input_conf: Dict[str, Any] = {}, - main_conf: Dict[str, Any] = {}, - ) -> None: - """Construct an Encoder object.""" - super().__init__() - - assert check_argument_types() - - embed_size, output_size = validate_architecture( - input_conf, body_conf, input_size - ) - main_params = build_main_parameters(**main_conf) - - self.embed = build_input_block(input_size, input_conf) - self.pos_enc = build_positional_encoding(embed_size, main_params) - self.encoders = build_body_blocks(body_conf, main_params, output_size) - - self.output_size = output_size - - self.dynamic_chunk_training = main_params["dynamic_chunk_training"] - self.short_chunk_threshold = main_params["short_chunk_threshold"] - self.short_chunk_size = main_params["short_chunk_size"] - self.left_chunk_size = main_params["left_chunk_size"] - - self.unified_model_training = main_params["unified_model_training"] - self.default_chunk_size = main_params["default_chunk_size"] - self.jitter_range = main_params["jitter_range"] - - self.time_reduction_factor = main_params["time_reduction_factor"] - def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int: - """Return the corresponding number of sample for a given chunk size, in frames. - - Where size is the number of features frames after applying subsampling. 
- - Args: - size: Number of frames after subsampling. - hop_length: Frontend's hop length - - Returns: - : Number of raw samples - - """ - return self.embed.get_size_before_subsampling(size) * hop_length - - def get_encoder_input_size(self, size: int) -> int: - """Return the corresponding number of sample for a given chunk size, in frames. - - Where size is the number of features frames after applying subsampling. - - Args: - size: Number of frames after subsampling. - - Returns: - : Number of raw samples - - """ - return self.embed.get_size_before_subsampling(size) - - - def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: - """Initialize/Reset encoder streaming cache. - - Args: - left_context: Number of frames in left context. - device: Device ID. - - """ - return self.encoders.reset_streaming_cache(left_context, device) - - def forward( - self, - x: torch.Tensor, - x_len: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Encode input sequences. - - Args: - x: Encoder input features. (B, T_in, F) - x_len: Encoder input features lengths. (B,) - - Returns: - x: Encoder outputs. (B, T_out, D_enc) - x_len: Encoder outputs lenghts. 
(B,) - - """ - short_status, limit_size = check_short_utt( - self.embed.subsampling_factor, x.size(1) - ) - - if short_status: - raise TooShortUttError( - f"has {x.size(1)} frames and is too short for subsampling " - + f"(it needs more than {limit_size} frames), return empty results", - x.size(1), - limit_size, - ) - - mask = make_source_mask(x_len) - - if self.unified_model_training: - chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item() - x, mask = self.embed(x, mask, chunk_size) - pos_enc = self.pos_enc(x) - chunk_mask = make_chunk_mask( - x.size(1), - chunk_size, - left_chunk_size=self.left_chunk_size, - device=x.device, - ) - x_utt = self.encoders( - x, - pos_enc, - mask, - chunk_mask=None, - ) - x_chunk = self.encoders( - x, - pos_enc, - mask, - chunk_mask=chunk_mask, - ) - - olens = mask.eq(0).sum(1) - if self.time_reduction_factor > 1: - x_utt = x_utt[:,::self.time_reduction_factor,:] - x_chunk = x_chunk[:,::self.time_reduction_factor,:] - olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 - - return x_utt, x_chunk, olens - - elif self.dynamic_chunk_training: - max_len = x.size(1) - chunk_size = torch.randint(1, max_len, (1,)).item() - - if chunk_size > (max_len * self.short_chunk_threshold): - chunk_size = max_len - else: - chunk_size = (chunk_size % self.short_chunk_size) + 1 - - x, mask = self.embed(x, mask, chunk_size) - pos_enc = self.pos_enc(x) - - chunk_mask = make_chunk_mask( - x.size(1), - chunk_size, - left_chunk_size=self.left_chunk_size, - device=x.device, - ) - else: - x, mask = self.embed(x, mask, None) - pos_enc = self.pos_enc(x) - chunk_mask = None - x = self.encoders( - x, - pos_enc, - mask, - chunk_mask=chunk_mask, - ) - - olens = mask.eq(0).sum(1) - if self.time_reduction_factor > 1: - x = x[:,::self.time_reduction_factor,:] - olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 - - return x, olens - - def simu_chunk_forward( - self, - x: torch.Tensor, - 
x_len: torch.Tensor, - chunk_size: int = 16, - left_context: int = 32, - right_context: int = 0, - ) -> torch.Tensor: - short_status, limit_size = check_short_utt( - self.embed.subsampling_factor, x.size(1) - ) - - if short_status: - raise TooShortUttError( - f"has {x.size(1)} frames and is too short for subsampling " - + f"(it needs more than {limit_size} frames), return empty results", - x.size(1), - limit_size, - ) - - mask = make_source_mask(x_len) - - x, mask = self.embed(x, mask, chunk_size) - pos_enc = self.pos_enc(x) - chunk_mask = make_chunk_mask( - x.size(1), - chunk_size, - left_chunk_size=self.left_chunk_size, - device=x.device, - ) - - x = self.encoders( - x, - pos_enc, - mask, - chunk_mask=chunk_mask, - ) - olens = mask.eq(0).sum(1) - if self.time_reduction_factor > 1: - x = x[:,::self.time_reduction_factor,:] - - return x - - def chunk_forward( - self, - x: torch.Tensor, - x_len: torch.Tensor, - processed_frames: torch.tensor, - chunk_size: int = 16, - left_context: int = 32, - right_context: int = 0, - ) -> torch.Tensor: - """Encode input sequences as chunks. - - Args: - x: Encoder input features. (1, T_in, F) - x_len: Encoder input features lengths. (1,) - processed_frames: Number of frames already seen. - left_context: Number of frames in left context. - right_context: Number of frames in right context. - - Returns: - x: Encoder outputs. 
(B, T_out, D_enc) - - """ - mask = make_source_mask(x_len) - x, mask = self.embed(x, mask, None) - - if left_context > 0: - processed_mask = ( - torch.arange(left_context, device=x.device) - .view(1, left_context) - .flip(1) - ) - processed_mask = processed_mask >= processed_frames - mask = torch.cat([processed_mask, mask], dim=1) - pos_enc = self.pos_enc(x, left_context=left_context) - x = self.encoders.chunk_forward( - x, - pos_enc, - mask, - chunk_size=chunk_size, - left_context=left_context, - right_context=right_context, - ) - - if right_context > 0: - x = x[:, 0:-right_context, :] - - if self.time_reduction_factor > 1: - x = x[:,::self.time_reduction_factor,:] - return x diff --git a/funasr/models/encoder/chunk_encoder_blocks/__init__.py b/funasr/models/encoder/chunk_encoder_blocks/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/funasr/models/encoder/chunk_encoder_blocks/branchformer.py b/funasr/models/encoder/chunk_encoder_blocks/branchformer.py deleted file mode 100644 index ba0b25d83..000000000 --- a/funasr/models/encoder/chunk_encoder_blocks/branchformer.py +++ /dev/null @@ -1,178 +0,0 @@ -"""Branchformer block for Transducer encoder.""" - -from typing import Dict, Optional, Tuple - -import torch - - -class Branchformer(torch.nn.Module): - """Branchformer module definition. - - Reference: https://arxiv.org/pdf/2207.02971.pdf - - Args: - block_size: Input/output size. - linear_size: Linear layers' hidden size. - self_att: Self-attention module instance. - conv_mod: Convolution module instance. - norm_class: Normalization class. - norm_args: Normalization module arguments. - dropout_rate: Dropout rate. 
- - """ - - def __init__( - self, - block_size: int, - linear_size: int, - self_att: torch.nn.Module, - conv_mod: torch.nn.Module, - norm_class: torch.nn.Module = torch.nn.LayerNorm, - norm_args: Dict = {}, - dropout_rate: float = 0.0, - ) -> None: - """Construct a Branchformer object.""" - super().__init__() - - self.self_att = self_att - self.conv_mod = conv_mod - - self.channel_proj1 = torch.nn.Sequential( - torch.nn.Linear(block_size, linear_size), torch.nn.GELU() - ) - self.channel_proj2 = torch.nn.Linear(linear_size // 2, block_size) - - self.merge_proj = torch.nn.Linear(block_size + block_size, block_size) - - self.norm_self_att = norm_class(block_size, **norm_args) - self.norm_mlp = norm_class(block_size, **norm_args) - self.norm_final = norm_class(block_size, **norm_args) - - self.dropout = torch.nn.Dropout(dropout_rate) - - self.block_size = block_size - self.linear_size = linear_size - self.cache = None - - def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: - """Initialize/Reset self-attention and convolution modules cache for streaming. - - Args: - left_context: Number of left frames during chunk-by-chunk inference. - device: Device to use for cache tensor. - - """ - self.cache = [ - torch.zeros( - (1, left_context, self.block_size), - device=device, - ), - torch.zeros( - ( - 1, - self.linear_size // 2, - self.conv_mod.kernel_size - 1, - ), - device=device, - ), - ] - - def forward( - self, - x: torch.Tensor, - pos_enc: torch.Tensor, - mask: torch.Tensor, - chunk_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Encode input sequences. - - Args: - x: Branchformer input sequences. (B, T, D_block) - pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) - mask: Source mask. (B, T) - chunk_mask: Chunk mask. (T_2, T_2) - - Returns: - x: Branchformer output sequences. (B, T, D_block) - mask: Source mask. (B, T) - pos_enc: Positional embedding sequences. 
(B, 2 * (T - 1), D_block) - - """ - x1 = x - x2 = x - - x1 = self.norm_self_att(x1) - - x1 = self.dropout( - self.self_att(x1, x1, x1, pos_enc, mask=mask, chunk_mask=chunk_mask) - ) - - x2 = self.norm_mlp(x2) - - x2 = self.channel_proj1(x2) - x2, _ = self.conv_mod(x2) - x2 = self.channel_proj2(x2) - - x2 = self.dropout(x2) - - x = x + self.dropout(self.merge_proj(torch.cat([x1, x2], dim=-1))) - - x = self.norm_final(x) - - return x, mask, pos_enc - - def chunk_forward( - self, - x: torch.Tensor, - pos_enc: torch.Tensor, - mask: torch.Tensor, - left_context: int = 0, - right_context: int = 0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Encode chunk of input sequence. - - Args: - x: Branchformer input sequences. (B, T, D_block) - pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) - mask: Source mask. (B, T_2) - left_context: Number of frames in left context. - right_context: Number of frames in right context. - - Returns: - x: Branchformer output sequences. (B, T, D_block) - pos_enc: Positional embedding sequences. 
(B, 2 * (T - 1), D_block) - - """ - x1 = x - x2 = x - - x1 = self.norm_self_att(x1) - - if left_context > 0: - key = torch.cat([self.cache[0], x1], dim=1) - else: - key = x1 - val = key - - if right_context > 0: - att_cache = key[:, -(left_context + right_context) : -right_context, :] - else: - att_cache = key[:, -left_context:, :] - - x1 = self.self_att(x1, key, val, pos_enc, mask=mask, left_context=left_context) - - x2 = self.norm_mlp(x2) - x2 = self.channel_proj1(x2) - - x2, conv_cache = self.conv_mod( - x2, cache=self.cache[1], right_context=right_context - ) - - x2 = self.channel_proj2(x2) - - x = x + self.merge_proj(torch.cat([x1, x2], dim=-1)) - - x = self.norm_final(x) - self.cache = [att_cache, conv_cache] - - return x, pos_enc diff --git a/funasr/models/encoder/chunk_encoder_blocks/conformer.py b/funasr/models/encoder/chunk_encoder_blocks/conformer.py deleted file mode 100644 index 0b9bbbf12..000000000 --- a/funasr/models/encoder/chunk_encoder_blocks/conformer.py +++ /dev/null @@ -1,198 +0,0 @@ -"""Conformer block for Transducer encoder.""" - -from typing import Dict, Optional, Tuple - -import torch - - -class Conformer(torch.nn.Module): - """Conformer module definition. - - Args: - block_size: Input/output size. - self_att: Self-attention module instance. - feed_forward: Feed-forward module instance. - feed_forward_macaron: Feed-forward module instance for macaron network. - conv_mod: Convolution module instance. - norm_class: Normalization module class. - norm_args: Normalization module arguments. - dropout_rate: Dropout rate. 
- - """ - - def __init__( - self, - block_size: int, - self_att: torch.nn.Module, - feed_forward: torch.nn.Module, - feed_forward_macaron: torch.nn.Module, - conv_mod: torch.nn.Module, - norm_class: torch.nn.Module = torch.nn.LayerNorm, - norm_args: Dict = {}, - dropout_rate: float = 0.0, - ) -> None: - """Construct a Conformer object.""" - super().__init__() - - self.self_att = self_att - - self.feed_forward = feed_forward - self.feed_forward_macaron = feed_forward_macaron - self.feed_forward_scale = 0.5 - - self.conv_mod = conv_mod - - self.norm_feed_forward = norm_class(block_size, **norm_args) - self.norm_self_att = norm_class(block_size, **norm_args) - - self.norm_macaron = norm_class(block_size, **norm_args) - self.norm_conv = norm_class(block_size, **norm_args) - self.norm_final = norm_class(block_size, **norm_args) - - self.dropout = torch.nn.Dropout(dropout_rate) - - self.block_size = block_size - self.cache = None - - def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: - """Initialize/Reset self-attention and convolution modules cache for streaming. - - Args: - left_context: Number of left frames during chunk-by-chunk inference. - device: Device to use for cache tensor. - - """ - self.cache = [ - torch.zeros( - (1, left_context, self.block_size), - device=device, - ), - torch.zeros( - ( - 1, - self.block_size, - self.conv_mod.kernel_size - 1, - ), - device=device, - ), - ] - - def forward( - self, - x: torch.Tensor, - pos_enc: torch.Tensor, - mask: torch.Tensor, - chunk_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Encode input sequences. - - Args: - x: Conformer input sequences. (B, T, D_block) - pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) - mask: Source mask. (B, T) - chunk_mask: Chunk mask. (T_2, T_2) - - Returns: - x: Conformer output sequences. (B, T, D_block) - mask: Source mask. (B, T) - pos_enc: Positional embedding sequences. 
(B, 2 * (T - 1), D_block) - - """ - residual = x - - x = self.norm_macaron(x) - x = residual + self.feed_forward_scale * self.dropout( - self.feed_forward_macaron(x) - ) - - residual = x - x = self.norm_self_att(x) - x_q = x - x = residual + self.dropout( - self.self_att( - x_q, - x, - x, - pos_enc, - mask, - chunk_mask=chunk_mask, - ) - ) - - residual = x - - x = self.norm_conv(x) - x, _ = self.conv_mod(x) - x = residual + self.dropout(x) - residual = x - - x = self.norm_feed_forward(x) - x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x)) - - x = self.norm_final(x) - return x, mask, pos_enc - - def chunk_forward( - self, - x: torch.Tensor, - pos_enc: torch.Tensor, - mask: torch.Tensor, - chunk_size: int = 16, - left_context: int = 0, - right_context: int = 0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Encode chunk of input sequence. - - Args: - x: Conformer input sequences. (B, T, D_block) - pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) - mask: Source mask. (B, T_2) - left_context: Number of frames in left context. - right_context: Number of frames in right context. - - Returns: - x: Conformer output sequences. (B, T, D_block) - pos_enc: Positional embedding sequences. 
(B, 2 * (T - 1), D_block) - - """ - residual = x - - x = self.norm_macaron(x) - x = residual + self.feed_forward_scale * self.feed_forward_macaron(x) - - residual = x - x = self.norm_self_att(x) - if left_context > 0: - key = torch.cat([self.cache[0], x], dim=1) - else: - key = x - val = key - - if right_context > 0: - att_cache = key[:, -(left_context + right_context) : -right_context, :] - else: - att_cache = key[:, -left_context:, :] - x = residual + self.self_att( - x, - key, - val, - pos_enc, - mask, - left_context=left_context, - ) - - residual = x - x = self.norm_conv(x) - x, conv_cache = self.conv_mod( - x, cache=self.cache[1], right_context=right_context - ) - x = residual + x - residual = x - - x = self.norm_feed_forward(x) - x = residual + self.feed_forward_scale * self.feed_forward(x) - - x = self.norm_final(x) - self.cache = [att_cache, conv_cache] - - return x, pos_enc diff --git a/funasr/models/encoder/chunk_encoder_blocks/conv1d.py b/funasr/models/encoder/chunk_encoder_blocks/conv1d.py deleted file mode 100644 index f79cc37b4..000000000 --- a/funasr/models/encoder/chunk_encoder_blocks/conv1d.py +++ /dev/null @@ -1,221 +0,0 @@ -"""Conv1d block for Transducer encoder.""" - -from typing import Optional, Tuple, Union - -import torch - - -class Conv1d(torch.nn.Module): - """Conv1d module definition. - - Args: - input_size: Input dimension. - output_size: Output dimension. - kernel_size: Size of the convolving kernel. - stride: Stride of the convolution. - dilation: Spacing between the kernel points. - groups: Number of blocked connections from input channels to output channels. - bias: Whether to add a learnable bias to the output. - batch_norm: Whether to use batch normalization after convolution. - relu: Whether to use a ReLU activation after convolution. - causal: Whether to use causal convolution (set to True if streaming). - dropout_rate: Dropout rate. 
- - """ - - def __init__( - self, - input_size: int, - output_size: int, - kernel_size: Union[int, Tuple], - stride: Union[int, Tuple] = 1, - dilation: Union[int, Tuple] = 1, - groups: Union[int, Tuple] = 1, - bias: bool = True, - batch_norm: bool = False, - relu: bool = True, - causal: bool = False, - dropout_rate: float = 0.0, - ) -> None: - """Construct a Conv1d object.""" - super().__init__() - - if causal: - self.lorder = kernel_size - 1 - stride = 1 - else: - self.lorder = 0 - stride = stride - - self.conv = torch.nn.Conv1d( - input_size, - output_size, - kernel_size, - stride=stride, - dilation=dilation, - groups=groups, - bias=bias, - ) - - self.dropout = torch.nn.Dropout(p=dropout_rate) - - if relu: - self.relu_func = torch.nn.ReLU() - - if batch_norm: - self.bn = torch.nn.BatchNorm1d(output_size) - - self.out_pos = torch.nn.Linear(input_size, output_size) - - self.input_size = input_size - self.output_size = output_size - - self.relu = relu - self.batch_norm = batch_norm - self.causal = causal - - self.kernel_size = kernel_size - self.padding = dilation * (kernel_size - 1) - self.stride = stride - - self.cache = None - - def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: - """Initialize/Reset Conv1d cache for streaming. - - Args: - left_context: Number of left frames during chunk-by-chunk inference. - device: Device to use for cache tensor. - - """ - self.cache = torch.zeros( - (1, self.input_size, self.kernel_size - 1), device=device - ) - - def forward( - self, - x: torch.Tensor, - pos_enc: torch.Tensor, - mask: Optional[torch.Tensor] = None, - chunk_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Encode input sequences. - - Args: - x: Conv1d input sequences. (B, T, D_in) - pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_in) - mask: Source mask. (B, T) - chunk_mask: Chunk mask. (T_2, T_2) - - Returns: - x: Conv1d output sequences. 
(B, sub(T), D_out) - mask: Source mask. (B, T) or (B, sub(T)) - pos_enc: Positional embedding sequences. - (B, 2 * (T - 1), D_att) or (B, 2 * (sub(T) - 1), D_out) - - """ - x = x.transpose(1, 2) - - if self.lorder > 0: - x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) - else: - mask = self.create_new_mask(mask) - pos_enc = self.create_new_pos_enc(pos_enc) - - x = self.conv(x) - - if self.batch_norm: - x = self.bn(x) - - x = self.dropout(x) - - if self.relu: - x = self.relu_func(x) - - x = x.transpose(1, 2) - - return x, mask, self.out_pos(pos_enc) - - def chunk_forward( - self, - x: torch.Tensor, - pos_enc: torch.Tensor, - mask: torch.Tensor, - left_context: int = 0, - right_context: int = 0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Encode chunk of input sequence. - - Args: - x: Conv1d input sequences. (B, T, D_in) - pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_in) - mask: Source mask. (B, T) - left_context: Number of frames in left context. - right_context: Number of frames in right context. - - Returns: - x: Conv1d output sequences. (B, T, D_out) - pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_out) - - """ - x = torch.cat([self.cache, x.transpose(1, 2)], dim=2) - - if right_context > 0: - self.cache = x[:, :, -(self.lorder + right_context) : -right_context] - else: - self.cache = x[:, :, -self.lorder :] - - x = self.conv(x) - - if self.batch_norm: - x = self.bn(x) - - x = self.dropout(x) - - if self.relu: - x = self.relu_func(x) - - x = x.transpose(1, 2) - - return x, self.out_pos(pos_enc) - - def create_new_mask(self, mask: torch.Tensor) -> torch.Tensor: - """Create new mask for output sequences. - - Args: - mask: Mask of input sequences. (B, T) - - Returns: - mask: Mask of output sequences. 
(B, sub(T)) - - """ - if self.padding != 0: - mask = mask[:, : -self.padding] - - return mask[:, :: self.stride] - - def create_new_pos_enc(self, pos_enc: torch.Tensor) -> torch.Tensor: - """Create new positional embedding vector. - - Args: - pos_enc: Input sequences positional embedding. - (B, 2 * (T - 1), D_in) - - Returns: - pos_enc: Output sequences positional embedding. - (B, 2 * (sub(T) - 1), D_in) - - """ - pos_enc_positive = pos_enc[:, : pos_enc.size(1) // 2 + 1, :] - pos_enc_negative = pos_enc[:, pos_enc.size(1) // 2 :, :] - - if self.padding != 0: - pos_enc_positive = pos_enc_positive[:, : -self.padding, :] - pos_enc_negative = pos_enc_negative[:, : -self.padding, :] - - pos_enc_positive = pos_enc_positive[:, :: self.stride, :] - pos_enc_negative = pos_enc_negative[:, :: self.stride, :] - - pos_enc = torch.cat([pos_enc_positive, pos_enc_negative[:, 1:, :]], dim=1) - - return pos_enc diff --git a/funasr/models/encoder/chunk_encoder_blocks/conv_input.py b/funasr/models/encoder/chunk_encoder_blocks/conv_input.py deleted file mode 100644 index b9bd2fdc2..000000000 --- a/funasr/models/encoder/chunk_encoder_blocks/conv_input.py +++ /dev/null @@ -1,222 +0,0 @@ -"""ConvInput block for Transducer encoder.""" - -from typing import Optional, Tuple, Union - -import torch -import math - -from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len - - -class ConvInput(torch.nn.Module): - """ConvInput module definition. - - Args: - input_size: Input size. - conv_size: Convolution size. - subsampling_factor: Subsampling factor. - vgg_like: Whether to use a VGG-like network. - output_size: Block output dimension. 
- - """ - - def __init__( - self, - input_size: int, - conv_size: Union[int, Tuple], - subsampling_factor: int = 4, - vgg_like: bool = True, - output_size: Optional[int] = None, - ) -> None: - """Construct a ConvInput object.""" - super().__init__() - if vgg_like: - if subsampling_factor == 1: - conv_size1, conv_size2 = conv_size - - self.conv = torch.nn.Sequential( - torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1), - torch.nn.ReLU(), - torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1), - torch.nn.ReLU(), - torch.nn.MaxPool2d((1, 2)), - torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1), - torch.nn.ReLU(), - torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1), - torch.nn.ReLU(), - torch.nn.MaxPool2d((1, 2)), - ) - - output_proj = conv_size2 * ((input_size // 2) // 2) - - self.subsampling_factor = 1 - - self.stride_1 = 1 - - self.create_new_mask = self.create_new_vgg_mask - - else: - conv_size1, conv_size2 = conv_size - - kernel_1 = int(subsampling_factor / 2) - - self.conv = torch.nn.Sequential( - torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1), - torch.nn.ReLU(), - torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1), - torch.nn.ReLU(), - torch.nn.MaxPool2d((kernel_1, 2)), - torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1), - torch.nn.ReLU(), - torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1), - torch.nn.ReLU(), - torch.nn.MaxPool2d((2, 2)), - ) - - output_proj = conv_size2 * ((input_size // 2) // 2) - - self.subsampling_factor = subsampling_factor - - self.create_new_mask = self.create_new_vgg_mask - - self.stride_1 = kernel_1 - - else: - if subsampling_factor == 1: - self.conv = torch.nn.Sequential( - torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]), - torch.nn.ReLU(), - torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]), - torch.nn.ReLU(), - ) - - output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2) - - self.subsampling_factor = subsampling_factor - 
self.kernel_2 = 3 - self.stride_2 = 1 - - self.create_new_mask = self.create_new_conv2d_mask - - else: - kernel_2, stride_2, conv_2_output_size = sub_factor_to_params( - subsampling_factor, - input_size, - ) - - self.conv = torch.nn.Sequential( - torch.nn.Conv2d(1, conv_size, 3, 2), - torch.nn.ReLU(), - torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2), - torch.nn.ReLU(), - ) - - output_proj = conv_size * conv_2_output_size - - self.subsampling_factor = subsampling_factor - self.kernel_2 = kernel_2 - self.stride_2 = stride_2 - - self.create_new_mask = self.create_new_conv2d_mask - - self.vgg_like = vgg_like - self.min_frame_length = 7 - - if output_size is not None: - self.output = torch.nn.Linear(output_proj, output_size) - self.output_size = output_size - else: - self.output = None - self.output_size = output_proj - - def forward( - self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Encode input sequences. - - Args: - x: ConvInput input sequences. (B, T, D_feats) - mask: Mask of input sequences. (B, 1, T) - - Returns: - x: ConvInput output sequences. (B, sub(T), D_out) - mask: Mask of output sequences. (B, 1, sub(T)) - - """ - if mask is not None: - mask = self.create_new_mask(mask) - olens = max(mask.eq(0).sum(1)) - - b, t, f = x.size() - x = x.unsqueeze(1) # (b. 1. t. 
f) - - if chunk_size is not None: - max_input_length = int( - chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor) )) - ) - x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x) - x = list(x) - x = torch.stack(x, dim=0) - N_chunks = max_input_length // ( chunk_size * self.subsampling_factor) - x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f) - - x = self.conv(x) - - _, c, _, f = x.size() - if chunk_size is not None: - x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:] - else: - x = x.transpose(1, 2).contiguous().view(b, -1, c * f) - - if self.output is not None: - x = self.output(x) - - return x, mask[:,:olens][:,:x.size(1)] - - def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor: - """Create a new mask for VGG output sequences. - - Args: - mask: Mask of input sequences. (B, T) - - Returns: - mask: Mask of output sequences. (B, sub(T)) - - """ - if self.subsampling_factor > 1: - vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2 )) - mask = mask[:, :vgg1_t_len][:, ::self.subsampling_factor // 2] - - vgg2_t_len = mask.size(1) - (mask.size(1) % 2) - mask = mask[:, :vgg2_t_len][:, ::2] - else: - mask = mask - - return mask - - def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor: - """Create new conformer mask for Conv2d output sequences. - - Args: - mask: Mask of input sequences. (B, T) - - Returns: - mask: Mask of output sequences. (B, sub(T)) - - """ - if self.subsampling_factor > 1: - return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2] - else: - return mask - - def get_size_before_subsampling(self, size: int) -> int: - """Return the original size before subsampling for a given size. - - Args: - size: Number of frames after subsampling. - - Returns: - : Number of frames before subsampling. 
- - """ - return size * self.subsampling_factor diff --git a/funasr/models/encoder/chunk_encoder_blocks/linear_input.py b/funasr/models/encoder/chunk_encoder_blocks/linear_input.py deleted file mode 100644 index 9bb9698a7..000000000 --- a/funasr/models/encoder/chunk_encoder_blocks/linear_input.py +++ /dev/null @@ -1,52 +0,0 @@ -"""LinearInput block for Transducer encoder.""" - -from typing import Optional, Tuple, Union - -import torch - -class LinearInput(torch.nn.Module): - """ConvInput module definition. - - Args: - input_size: Input size. - conv_size: Convolution size. - subsampling_factor: Subsampling factor. - vgg_like: Whether to use a VGG-like network. - output_size: Block output dimension. - - """ - - def __init__( - self, - input_size: int, - output_size: Optional[int] = None, - subsampling_factor: int = 1, - ) -> None: - """Construct a ConvInput object.""" - super().__init__() - self.embed = torch.nn.Sequential( - torch.nn.Linear(input_size, output_size), - torch.nn.LayerNorm(output_size), - torch.nn.Dropout(0.1), - ) - self.subsampling_factor = subsampling_factor - self.min_frame_length = 1 - - def forward( - self, x: torch.Tensor, mask: Optional[torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor]: - - x = self.embed(x) - return x, mask - - def get_size_before_subsampling(self, size: int) -> int: - """Return the original size before subsampling for a given size. - - Args: - size: Number of frames after subsampling. - - Returns: - : Number of frames before subsampling. 
- - """ - return size diff --git a/funasr/models/encoder/chunk_encoder_modules/__init__.py b/funasr/models/encoder/chunk_encoder_modules/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/funasr/models/encoder/chunk_encoder_modules/attention.py b/funasr/models/encoder/chunk_encoder_modules/attention.py deleted file mode 100644 index 53e708750..000000000 --- a/funasr/models/encoder/chunk_encoder_modules/attention.py +++ /dev/null @@ -1,246 +0,0 @@ -"""Multi-Head attention layers with relative positional encoding.""" - -import math -from typing import Optional, Tuple - -import torch - - -class RelPositionMultiHeadedAttention(torch.nn.Module): - """RelPositionMultiHeadedAttention definition. - - Args: - num_heads: Number of attention heads. - embed_size: Embedding size. - dropout_rate: Dropout rate. - - """ - - def __init__( - self, - num_heads: int, - embed_size: int, - dropout_rate: float = 0.0, - simplified_attention_score: bool = False, - ) -> None: - """Construct an MultiHeadedAttention object.""" - super().__init__() - - self.d_k = embed_size // num_heads - self.num_heads = num_heads - - assert self.d_k * num_heads == embed_size, ( - "embed_size (%d) must be divisible by num_heads (%d)", - (embed_size, num_heads), - ) - - self.linear_q = torch.nn.Linear(embed_size, embed_size) - self.linear_k = torch.nn.Linear(embed_size, embed_size) - self.linear_v = torch.nn.Linear(embed_size, embed_size) - - self.linear_out = torch.nn.Linear(embed_size, embed_size) - - if simplified_attention_score: - self.linear_pos = torch.nn.Linear(embed_size, num_heads) - - self.compute_att_score = self.compute_simplified_attention_score - else: - self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False) - - self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k)) - self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k)) - torch.nn.init.xavier_uniform_(self.pos_bias_u) - torch.nn.init.xavier_uniform_(self.pos_bias_v) - - 
self.compute_att_score = self.compute_attention_score - - self.dropout = torch.nn.Dropout(p=dropout_rate) - self.attn = None - - def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor: - """Compute relative positional encoding. - - Args: - x: Input sequence. (B, H, T_1, 2 * T_1 - 1) - left_context: Number of frames in left context. - - Returns: - x: Output sequence. (B, H, T_1, T_2) - - """ - batch_size, n_heads, time1, n = x.shape - time2 = time1 + left_context - - batch_stride, n_heads_stride, time1_stride, n_stride = x.stride() - - return x.as_strided( - (batch_size, n_heads, time1, time2), - (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride), - storage_offset=(n_stride * (time1 - 1)), - ) - - def compute_simplified_attention_score( - self, - query: torch.Tensor, - key: torch.Tensor, - pos_enc: torch.Tensor, - left_context: int = 0, - ) -> torch.Tensor: - """Simplified attention score computation. - - Reference: https://github.com/k2-fsa/icefall/pull/458 - - Args: - query: Transformed query tensor. (B, H, T_1, d_k) - key: Transformed key tensor. (B, H, T_2, d_k) - pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) - left_context: Number of frames in left context. - - Returns: - : Attention score. (B, H, T_1, T_2) - - """ - pos_enc = self.linear_pos(pos_enc) - - matrix_ac = torch.matmul(query, key.transpose(2, 3)) - - matrix_bd = self.rel_shift( - pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1), - left_context=left_context, - ) - - return (matrix_ac + matrix_bd) / math.sqrt(self.d_k) - - def compute_attention_score( - self, - query: torch.Tensor, - key: torch.Tensor, - pos_enc: torch.Tensor, - left_context: int = 0, - ) -> torch.Tensor: - """Attention score computation. - - Args: - query: Transformed query tensor. (B, H, T_1, d_k) - key: Transformed key tensor. (B, H, T_2, d_k) - pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) - left_context: Number of frames in left context. 
- - Returns: - : Attention score. (B, H, T_1, T_2) - - """ - p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k) - - query = query.transpose(1, 2) - q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2) - q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2) - - matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1)) - - matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1)) - matrix_bd = self.rel_shift(matrix_bd, left_context=left_context) - - return (matrix_ac + matrix_bd) / math.sqrt(self.d_k) - - def forward_qkv( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Transform query, key and value. - - Args: - query: Query tensor. (B, T_1, size) - key: Key tensor. (B, T_2, size) - v: Value tensor. (B, T_2, size) - - Returns: - q: Transformed query tensor. (B, H, T_1, d_k) - k: Transformed key tensor. (B, H, T_2, d_k) - v: Transformed value tensor. (B, H, T_2, d_k) - - """ - n_batch = query.size(0) - - q = ( - self.linear_q(query) - .view(n_batch, -1, self.num_heads, self.d_k) - .transpose(1, 2) - ) - k = ( - self.linear_k(key) - .view(n_batch, -1, self.num_heads, self.d_k) - .transpose(1, 2) - ) - v = ( - self.linear_v(value) - .view(n_batch, -1, self.num_heads, self.d_k) - .transpose(1, 2) - ) - - return q, k, v - - def forward_attention( - self, - value: torch.Tensor, - scores: torch.Tensor, - mask: torch.Tensor, - chunk_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Compute attention context vector. - - Args: - value: Transformed value. (B, H, T_2, d_k) - scores: Attention score. (B, H, T_1, T_2) - mask: Source mask. (B, T_2) - chunk_mask: Chunk mask. (T_1, T_1) - - Returns: - attn_output: Transformed value weighted by attention score. 
(B, T_1, H * d_k) - - """ - batch_size = scores.size(0) - mask = mask.unsqueeze(1).unsqueeze(2) - if chunk_mask is not None: - mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask - scores = scores.masked_fill(mask, float("-inf")) - self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) - - attn_output = self.dropout(self.attn) - attn_output = torch.matmul(attn_output, value) - - attn_output = self.linear_out( - attn_output.transpose(1, 2) - .contiguous() - .view(batch_size, -1, self.num_heads * self.d_k) - ) - - return attn_output - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - pos_enc: torch.Tensor, - mask: torch.Tensor, - chunk_mask: Optional[torch.Tensor] = None, - left_context: int = 0, - ) -> torch.Tensor: - """Compute scaled dot product attention with rel. positional encoding. - - Args: - query: Query tensor. (B, T_1, size) - key: Key tensor. (B, T_2, size) - value: Value tensor. (B, T_2, size) - pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) - mask: Source mask. (B, T_2) - chunk_mask: Chunk mask. (T_1, T_1) - left_context: Number of frames in left context. - - Returns: - : Output tensor. (B, T_1, H * d_k) - - """ - q, k, v = self.forward_qkv(query, key, value) - scores = self.compute_att_score(q, k, pos_enc, left_context=left_context) - return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask) diff --git a/funasr/models/encoder/chunk_encoder_modules/convolution.py b/funasr/models/encoder/chunk_encoder_modules/convolution.py deleted file mode 100644 index 012538a7d..000000000 --- a/funasr/models/encoder/chunk_encoder_modules/convolution.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Convolution modules for X-former blocks.""" - -from typing import Dict, Optional, Tuple - -import torch - - -class ConformerConvolution(torch.nn.Module): - """ConformerConvolution module definition. - - Args: - channels: The number of channels. - kernel_size: Size of the convolving kernel. 
- activation: Type of activation function. - norm_args: Normalization module arguments. - causal: Whether to use causal convolution (set to True if streaming). - - """ - - def __init__( - self, - channels: int, - kernel_size: int, - activation: torch.nn.Module = torch.nn.ReLU(), - norm_args: Dict = {}, - causal: bool = False, - ) -> None: - """Construct an ConformerConvolution object.""" - super().__init__() - - assert (kernel_size - 1) % 2 == 0 - - self.kernel_size = kernel_size - - self.pointwise_conv1 = torch.nn.Conv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - ) - - if causal: - self.lorder = kernel_size - 1 - padding = 0 - else: - self.lorder = 0 - padding = (kernel_size - 1) // 2 - - self.depthwise_conv = torch.nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=padding, - groups=channels, - ) - self.norm = torch.nn.BatchNorm1d(channels, **norm_args) - self.pointwise_conv2 = torch.nn.Conv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - ) - - self.activation = activation - - def forward( - self, - x: torch.Tensor, - cache: Optional[torch.Tensor] = None, - right_context: int = 0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute convolution module. - - Args: - x: ConformerConvolution input sequences. (B, T, D_hidden) - cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden) - right_context: Number of frames in right context. - - Returns: - x: ConformerConvolution output sequences. (B, T, D_hidden) - cache: ConformerConvolution output cache. 
(1, conv_kernel, D_hidden) - - """ - x = self.pointwise_conv1(x.transpose(1, 2)) - x = torch.nn.functional.glu(x, dim=1) - - if self.lorder > 0: - if cache is None: - x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) - else: - x = torch.cat([cache, x], dim=2) - - if right_context > 0: - cache = x[:, :, -(self.lorder + right_context) : -right_context] - else: - cache = x[:, :, -self.lorder :] - - x = self.depthwise_conv(x) - x = self.activation(self.norm(x)) - - x = self.pointwise_conv2(x).transpose(1, 2) - - return x, cache - - -class ConvolutionalSpatialGatingUnit(torch.nn.Module): - """Convolutional Spatial Gating Unit module definition. - - Args: - size: Initial size to determine the number of channels. - kernel_size: Size of the convolving kernel. - norm_class: Normalization module class. - norm_args: Normalization module arguments. - dropout_rate: Dropout rate. - causal: Whether to use causal convolution (set to True if streaming). - - """ - - def __init__( - self, - size: int, - kernel_size: int, - norm_class: torch.nn.Module = torch.nn.LayerNorm, - norm_args: Dict = {}, - dropout_rate: float = 0.0, - causal: bool = False, - ) -> None: - """Construct a ConvolutionalSpatialGatingUnit object.""" - super().__init__() - - channels = size // 2 - - self.kernel_size = kernel_size - - if causal: - self.lorder = kernel_size - 1 - padding = 0 - else: - self.lorder = 0 - padding = (kernel_size - 1) // 2 - - self.conv = torch.nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=padding, - groups=channels, - ) - - self.norm = norm_class(channels, **norm_args) - self.activation = torch.nn.Identity() - - self.dropout = torch.nn.Dropout(dropout_rate) - - def forward( - self, - x: torch.Tensor, - cache: Optional[torch.Tensor] = None, - right_context: int = 0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute convolution module. - - Args: - x: ConvolutionalSpatialGatingUnit input sequences. 
(B, T, D_hidden) - cache: ConvolutionalSpationGatingUnit input cache. - (1, conv_kernel, D_hidden) - right_context: Number of frames in right context. - - Returns: - x: ConvolutionalSpatialGatingUnit output sequences. (B, T, D_hidden // 2) - - """ - x_r, x_g = x.chunk(2, dim=-1) - - x_g = self.norm(x_g).transpose(1, 2) - - if self.lorder > 0: - if cache is None: - x_g = torch.nn.functional.pad(x_g, (self.lorder, 0), "constant", 0.0) - else: - x_g = torch.cat([cache, x_g], dim=2) - - if right_context > 0: - cache = x_g[:, :, -(self.lorder + right_context) : -right_context] - else: - cache = x_g[:, :, -self.lorder :] - - x_g = self.conv(x_g).transpose(1, 2) - - x = self.dropout(x_r * self.activation(x_g)) - - return x, cache diff --git a/funasr/models/encoder/chunk_encoder_modules/multi_blocks.py b/funasr/models/encoder/chunk_encoder_modules/multi_blocks.py deleted file mode 100644 index 14aca8b6d..000000000 --- a/funasr/models/encoder/chunk_encoder_modules/multi_blocks.py +++ /dev/null @@ -1,105 +0,0 @@ -"""MultiBlocks for encoder architecture.""" - -from typing import Dict, List, Optional - -import torch - - -class MultiBlocks(torch.nn.Module): - """MultiBlocks definition. - - Args: - block_list: Individual blocks of the encoder architecture. - output_size: Architecture output size. - norm_class: Normalization module class. - norm_args: Normalization module arguments. - - """ - - def __init__( - self, - block_list: List[torch.nn.Module], - output_size: int, - norm_class: torch.nn.Module = torch.nn.LayerNorm, - norm_args: Optional[Dict] = None, - ) -> None: - """Construct a MultiBlocks object.""" - super().__init__() - - self.blocks = torch.nn.ModuleList(block_list) - self.norm_blocks = norm_class(output_size, **norm_args) - - self.num_blocks = len(block_list) - - def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: - """Initialize/Reset encoder streaming cache. 
- - Args: - left_context: Number of left frames during chunk-by-chunk inference. - device: Device to use for cache tensor. - - """ - for idx in range(self.num_blocks): - self.blocks[idx].reset_streaming_cache(left_context, device) - - def forward( - self, - x: torch.Tensor, - pos_enc: torch.Tensor, - mask: torch.Tensor, - chunk_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward each block of the encoder architecture. - - Args: - x: MultiBlocks input sequences. (B, T, D_block_1) - pos_enc: Positional embedding sequences. - mask: Source mask. (B, T) - chunk_mask: Chunk mask. (T_2, T_2) - - Returns: - x: Output sequences. (B, T, D_block_N) - - """ - for block_index, block in enumerate(self.blocks): - x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask) - - x = self.norm_blocks(x) - - return x - - def chunk_forward( - self, - x: torch.Tensor, - pos_enc: torch.Tensor, - mask: torch.Tensor, - chunk_size: int = 0, - left_context: int = 0, - right_context: int = 0, - ) -> torch.Tensor: - """Forward each block of the encoder architecture. - - Args: - x: MultiBlocks input sequences. (B, T, D_block_1) - pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att) - mask: Source mask. (B, T_2) - left_context: Number of frames in left context. - right_context: Number of frames in right context. - - Returns: - x: MultiBlocks output sequences. 
(B, T, D_block_N) - - """ - for block_idx, block in enumerate(self.blocks): - x, pos_enc = block.chunk_forward( - x, - pos_enc, - mask, - chunk_size=chunk_size, - left_context=left_context, - right_context=right_context, - ) - - x = self.norm_blocks(x) - - return x diff --git a/funasr/models/encoder/chunk_encoder_modules/positional_encoding.py b/funasr/models/encoder/chunk_encoder_modules/positional_encoding.py deleted file mode 100644 index 5b56e2671..000000000 --- a/funasr/models/encoder/chunk_encoder_modules/positional_encoding.py +++ /dev/null @@ -1,91 +0,0 @@ -"""Positional encoding modules.""" - -import math - -import torch - -from funasr.modules.embedding import _pre_hook - - -class RelPositionalEncoding(torch.nn.Module): - """Relative positional encoding. - - Args: - size: Module size. - max_len: Maximum input length. - dropout_rate: Dropout rate. - - """ - - def __init__( - self, size: int, dropout_rate: float = 0.0, max_len: int = 5000 - ) -> None: - """Construct a RelativePositionalEncoding object.""" - super().__init__() - - self.size = size - - self.pe = None - self.dropout = torch.nn.Dropout(p=dropout_rate) - - self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - self._register_load_state_dict_pre_hook(_pre_hook) - - def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None: - """Reset positional encoding. - - Args: - x: Input sequences. (B, T, ?) - left_context: Number of frames in left context. 
- - """ - time1 = x.size(1) + left_context - - if self.pe is not None: - if self.pe.size(1) >= time1 * 2 - 1: - if self.pe.dtype != x.dtype or self.pe.device != x.device: - self.pe = self.pe.to(device=x.device, dtype=x.dtype) - return - - pe_positive = torch.zeros(time1, self.size) - pe_negative = torch.zeros(time1, self.size) - - position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.size, 2, dtype=torch.float32) - * -(math.log(10000.0) / self.size) - ) - - pe_positive[:, 0::2] = torch.sin(position * div_term) - pe_positive[:, 1::2] = torch.cos(position * div_term) - pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0) - - pe_negative[:, 0::2] = torch.sin(-1 * position * div_term) - pe_negative[:, 1::2] = torch.cos(-1 * position * div_term) - pe_negative = pe_negative[1:].unsqueeze(0) - - self.pe = torch.cat([pe_positive, pe_negative], dim=1).to( - dtype=x.dtype, device=x.device - ) - - def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor: - """Compute positional encoding. - - Args: - x: Input sequences. (B, T, ?) - left_context: Number of frames in left context. - - Returns: - pos_enc: Positional embedding sequences. (B, 2 * (T - 1), ?) 
- - """ - self.extend_pe(x, left_context=left_context) - - time1 = x.size(1) + left_context - - pos_enc = self.pe[ - :, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1) - ] - pos_enc = self.dropout(pos_enc) - - return pos_enc diff --git a/funasr/models/encoder/chunk_encoder_utils/building.py b/funasr/models/encoder/chunk_encoder_utils/building.py deleted file mode 100644 index 21611aa19..000000000 --- a/funasr/models/encoder/chunk_encoder_utils/building.py +++ /dev/null @@ -1,352 +0,0 @@ -"""Set of methods to build Transducer encoder architecture.""" - -from typing import Any, Dict, List, Optional, Union - -from funasr.modules.activation import get_activation -from funasr.models.encoder.chunk_encoder_blocks.branchformer import Branchformer -from funasr.models.encoder.chunk_encoder_blocks.conformer import Conformer -from funasr.models.encoder.chunk_encoder_blocks.conv1d import Conv1d -from funasr.models.encoder.chunk_encoder_blocks.conv_input import ConvInput -from funasr.models.encoder.chunk_encoder_blocks.linear_input import LinearInput -from funasr.models.encoder.chunk_encoder_modules.attention import ( # noqa: H301 - RelPositionMultiHeadedAttention, -) -from funasr.models.encoder.chunk_encoder_modules.convolution import ( # noqa: H301 - ConformerConvolution, - ConvolutionalSpatialGatingUnit, -) -from funasr.models.encoder.chunk_encoder_modules.multi_blocks import MultiBlocks -from funasr.models.encoder.chunk_encoder_modules.normalization import get_normalization -from funasr.models.encoder.chunk_encoder_modules.positional_encoding import ( # noqa: H301 - RelPositionalEncoding, -) -from funasr.modules.positionwise_feed_forward import ( - PositionwiseFeedForward, -) - - -def build_main_parameters( - pos_wise_act_type: str = "swish", - conv_mod_act_type: str = "swish", - pos_enc_dropout_rate: float = 0.0, - pos_enc_max_len: int = 5000, - simplified_att_score: bool = False, - norm_type: str = "layer_norm", - conv_mod_norm_type: str = "layer_norm", 
- after_norm_eps: Optional[float] = None, - after_norm_partial: Optional[float] = None, - dynamic_chunk_training: bool = False, - short_chunk_threshold: float = 0.75, - short_chunk_size: int = 25, - left_chunk_size: int = 0, - time_reduction_factor: int = 1, - unified_model_training: bool = False, - default_chunk_size: int = 16, - jitter_range: int =4, - **activation_parameters, -) -> Dict[str, Any]: - """Build encoder main parameters. - - Args: - pos_wise_act_type: Conformer position-wise feed-forward activation type. - conv_mod_act_type: Conformer convolution module activation type. - pos_enc_dropout_rate: Positional encoding dropout rate. - pos_enc_max_len: Positional encoding maximum length. - simplified_att_score: Whether to use simplified attention score computation. - norm_type: X-former normalization module type. - conv_mod_norm_type: Conformer convolution module normalization type. - after_norm_eps: Epsilon value for the final normalization. - after_norm_partial: Value for the final normalization with RMSNorm. - dynamic_chunk_training: Whether to use dynamic chunk training. - short_chunk_threshold: Threshold for dynamic chunk selection. - short_chunk_size: Minimum number of frames during dynamic chunk training. - left_chunk_size: Number of frames in left context. - **activations_parameters: Parameters of the activation functions. 
- (See espnet2/asr_transducer/activation.py) - - Returns: - : Main encoder parameters - - """ - main_params = {} - - main_params["pos_wise_act"] = get_activation( - pos_wise_act_type, **activation_parameters - ) - - main_params["conv_mod_act"] = get_activation( - conv_mod_act_type, **activation_parameters - ) - - main_params["pos_enc_dropout_rate"] = pos_enc_dropout_rate - main_params["pos_enc_max_len"] = pos_enc_max_len - - main_params["simplified_att_score"] = simplified_att_score - - main_params["norm_type"] = norm_type - main_params["conv_mod_norm_type"] = conv_mod_norm_type - - ( - main_params["after_norm_class"], - main_params["after_norm_args"], - ) = get_normalization(norm_type, eps=after_norm_eps, partial=after_norm_partial) - - main_params["dynamic_chunk_training"] = dynamic_chunk_training - main_params["short_chunk_threshold"] = max(0, short_chunk_threshold) - main_params["short_chunk_size"] = max(0, short_chunk_size) - main_params["left_chunk_size"] = max(0, left_chunk_size) - - main_params["unified_model_training"] = unified_model_training - main_params["default_chunk_size"] = max(0, default_chunk_size) - main_params["jitter_range"] = max(0, jitter_range) - - main_params["time_reduction_factor"] = time_reduction_factor - - return main_params - - -def build_positional_encoding( - block_size: int, configuration: Dict[str, Any] -) -> RelPositionalEncoding: - """Build positional encoding block. - - Args: - block_size: Input/output size. - configuration: Positional encoding configuration. - - Returns: - : Positional encoding module. - - """ - return RelPositionalEncoding( - block_size, - configuration.get("pos_enc_dropout_rate", 0.0), - max_len=configuration.get("pos_enc_max_len", 5000), - ) - - -def build_input_block( - input_size: int, - configuration: Dict[str, Union[str, int]], -) -> ConvInput: - """Build encoder input block. - - Args: - input_size: Input size. - configuration: Input block configuration. - - Returns: - : ConvInput block function. 
- - """ - if configuration["linear"]: - return LinearInput( - input_size, - configuration["output_size"], - configuration["subsampling_factor"], - ) - else: - return ConvInput( - input_size, - configuration["conv_size"], - configuration["subsampling_factor"], - vgg_like=configuration["vgg_like"], - output_size=configuration["output_size"], - ) - - -def build_branchformer_block( - configuration: List[Dict[str, Any]], - main_params: Dict[str, Any], -) -> Conformer: - """Build Branchformer block. - - Args: - configuration: Branchformer block configuration. - main_params: Encoder main parameters. - - Returns: - : Branchformer block function. - - """ - hidden_size = configuration["hidden_size"] - linear_size = configuration["linear_size"] - - dropout_rate = configuration.get("dropout_rate", 0.0) - - conv_mod_norm_class, conv_mod_norm_args = get_normalization( - main_params["conv_mod_norm_type"], - eps=configuration.get("conv_mod_norm_eps"), - partial=configuration.get("conv_mod_norm_partial"), - ) - - conv_mod_args = ( - linear_size, - configuration["conv_mod_kernel_size"], - conv_mod_norm_class, - conv_mod_norm_args, - dropout_rate, - main_params["dynamic_chunk_training"], - ) - - mult_att_args = ( - configuration.get("heads", 4), - hidden_size, - configuration.get("att_dropout_rate", 0.0), - main_params["simplified_att_score"], - ) - - norm_class, norm_args = get_normalization( - main_params["norm_type"], - eps=configuration.get("norm_eps"), - partial=configuration.get("norm_partial"), - ) - - return lambda: Branchformer( - hidden_size, - linear_size, - RelPositionMultiHeadedAttention(*mult_att_args), - ConvolutionalSpatialGatingUnit(*conv_mod_args), - norm_class=norm_class, - norm_args=norm_args, - dropout_rate=dropout_rate, - ) - - -def build_conformer_block( - configuration: List[Dict[str, Any]], - main_params: Dict[str, Any], -) -> Conformer: - """Build Conformer block. - - Args: - configuration: Conformer block configuration. 
- main_params: Encoder main parameters. - - Returns: - : Conformer block function. - - """ - hidden_size = configuration["hidden_size"] - linear_size = configuration["linear_size"] - - pos_wise_args = ( - hidden_size, - linear_size, - configuration.get("pos_wise_dropout_rate", 0.0), - main_params["pos_wise_act"], - ) - - conv_mod_norm_args = { - "eps": configuration.get("conv_mod_norm_eps", 1e-05), - "momentum": configuration.get("conv_mod_norm_momentum", 0.1), - } - - conv_mod_args = ( - hidden_size, - configuration["conv_mod_kernel_size"], - main_params["conv_mod_act"], - conv_mod_norm_args, - main_params["dynamic_chunk_training"] or main_params["unified_model_training"], - ) - - mult_att_args = ( - configuration.get("heads", 4), - hidden_size, - configuration.get("att_dropout_rate", 0.0), - main_params["simplified_att_score"], - ) - - norm_class, norm_args = get_normalization( - main_params["norm_type"], - eps=configuration.get("norm_eps"), - partial=configuration.get("norm_partial"), - ) - - return lambda: Conformer( - hidden_size, - RelPositionMultiHeadedAttention(*mult_att_args), - PositionwiseFeedForward(*pos_wise_args), - PositionwiseFeedForward(*pos_wise_args), - ConformerConvolution(*conv_mod_args), - norm_class=norm_class, - norm_args=norm_args, - dropout_rate=configuration.get("dropout_rate", 0.0), - ) - - -def build_conv1d_block( - configuration: List[Dict[str, Any]], - causal: bool, -) -> Conv1d: - """Build Conv1d block. - - Args: - configuration: Conv1d block configuration. - - Returns: - : Conv1d block function. 
- - """ - return lambda: Conv1d( - configuration["input_size"], - configuration["output_size"], - configuration["kernel_size"], - stride=configuration.get("stride", 1), - dilation=configuration.get("dilation", 1), - groups=configuration.get("groups", 1), - bias=configuration.get("bias", True), - relu=configuration.get("relu", True), - batch_norm=configuration.get("batch_norm", False), - causal=causal, - dropout_rate=configuration.get("dropout_rate", 0.0), - ) - - -def build_body_blocks( - configuration: List[Dict[str, Any]], - main_params: Dict[str, Any], - output_size: int, -) -> MultiBlocks: - """Build encoder body blocks. - - Args: - configuration: Body blocks configuration. - main_params: Encoder main parameters. - output_size: Architecture output size. - - Returns: - MultiBlocks function encapsulation all encoder blocks. - - """ - fn_modules = [] - extended_conf = [] - - for c in configuration: - if c.get("num_blocks") is not None: - extended_conf += c["num_blocks"] * [ - {c_i: c[c_i] for c_i in c if c_i != "num_blocks"} - ] - else: - extended_conf += [c] - - for i, c in enumerate(extended_conf): - block_type = c["block_type"] - - if block_type == "branchformer": - module = build_branchformer_block(c, main_params) - elif block_type == "conformer": - module = build_conformer_block(c, main_params) - elif block_type == "conv1d": - module = build_conv1d_block(c, main_params["dynamic_chunk_training"]) - else: - raise NotImplementedError - - fn_modules.append(module) - - return MultiBlocks( - [fn() for fn in fn_modules], - output_size, - norm_class=main_params["after_norm_class"], - norm_args=main_params["after_norm_args"], - ) diff --git a/funasr/models/encoder/chunk_encoder_utils/validation.py b/funasr/models/encoder/chunk_encoder_utils/validation.py deleted file mode 100644 index 1103cb93f..000000000 --- a/funasr/models/encoder/chunk_encoder_utils/validation.py +++ /dev/null @@ -1,171 +0,0 @@ -"""Set of methods to validate encoder architecture.""" - -from typing 
import Any, Dict, List, Tuple - -from funasr.modules.nets_utils import sub_factor_to_params - - -def validate_block_arguments( - configuration: Dict[str, Any], - block_id: int, - previous_block_output: int, -) -> Tuple[int, int]: - """Validate block arguments. - - Args: - configuration: Architecture configuration. - block_id: Block ID. - previous_block_output: Previous block output size. - - Returns: - input_size: Block input size. - output_size: Block output size. - - """ - block_type = configuration.get("block_type") - - if block_type is None: - raise ValueError( - "Block %d in encoder doesn't have a type assigned. " % block_id - ) - - if block_type in ["branchformer", "conformer"]: - if configuration.get("linear_size") is None: - raise ValueError( - "Missing 'linear_size' argument for X-former block (ID: %d)" % block_id - ) - - if configuration.get("conv_mod_kernel_size") is None: - raise ValueError( - "Missing 'conv_mod_kernel_size' argument for X-former block (ID: %d)" - % block_id - ) - - input_size = configuration.get("hidden_size") - output_size = configuration.get("hidden_size") - - elif block_type == "conv1d": - output_size = configuration.get("output_size") - - if output_size is None: - raise ValueError( - "Missing 'output_size' argument for Conv1d block (ID: %d)" % block_id - ) - - if configuration.get("kernel_size") is None: - raise ValueError( - "Missing 'kernel_size' argument for Conv1d block (ID: %d)" % block_id - ) - - input_size = configuration["input_size"] = previous_block_output - else: - raise ValueError("Block type: %s is not supported." % block_type) - - return input_size, output_size - - -def validate_input_block( - configuration: Dict[str, Any], body_first_conf: Dict[str, Any], input_size: int -) -> int: - """Validate input block. - - Args: - configuration: Encoder input block configuration. - body_first_conf: Encoder first body block configuration. - input_size: Encoder input block input size. 
- - Return: - output_size: Encoder input block output size. - - """ - vgg_like = configuration.get("vgg_like", False) - linear = configuration.get("linear", False) - next_block_type = body_first_conf.get("block_type") - allowed_next_block_type = ["branchformer", "conformer", "conv1d"] - - if next_block_type is None or (next_block_type not in allowed_next_block_type): - return -1 - - if configuration.get("subsampling_factor") is None: - configuration["subsampling_factor"] = 4 - - if vgg_like: - conv_size = configuration.get("conv_size", (64, 128)) - - if isinstance(conv_size, int): - conv_size = (conv_size, conv_size) - else: - conv_size = configuration.get("conv_size", None) - - if isinstance(conv_size, tuple): - conv_size = conv_size[0] - - if next_block_type == "conv1d": - if vgg_like: - output_size = conv_size[1] * ((input_size // 2) // 2) - else: - if conv_size is None: - conv_size = body_first_conf.get("output_size", 64) - - sub_factor = configuration["subsampling_factor"] - - _, _, conv_osize = sub_factor_to_params(sub_factor, input_size) - assert ( - conv_osize > 0 - ), "Conv2D output size is <1 with input size %d and subsampling %d" % ( - input_size, - sub_factor, - ) - - output_size = conv_osize * conv_size - - configuration["output_size"] = None - else: - output_size = body_first_conf.get("hidden_size") - - if conv_size is None: - conv_size = output_size - - configuration["output_size"] = output_size - - configuration["conv_size"] = conv_size - configuration["vgg_like"] = vgg_like - configuration["linear"] = linear - - return output_size - - -def validate_architecture( - input_conf: Dict[str, Any], body_conf: List[Dict[str, Any]], input_size: int -) -> Tuple[int, int]: - """Validate specified architecture is valid. - - Args: - input_conf: Encoder input block configuration. - body_conf: Encoder body blocks configuration. - input_size: Encoder input size. - - Returns: - input_block_osize: Encoder input block output size. - : Encoder body block output size. 
- - """ - input_block_osize = validate_input_block(input_conf, body_conf[0], input_size) - - cmp_io = [] - - for i, b in enumerate(body_conf): - _io = validate_block_arguments( - b, (i + 1), input_block_osize if i == 0 else cmp_io[i - 1][1] - ) - - cmp_io.append(_io) - - for i in range(1, len(cmp_io)): - if cmp_io[(i - 1)][1] != cmp_io[i][0]: - raise ValueError( - "Output/Input mismatch between blocks %d and %d" - " in the encoder body." % ((i - 1), i) - ) - - return input_block_osize, cmp_io[-1][1] diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py index 7c7f66142..c837cf533 100644 --- a/funasr/models/encoder/conformer_encoder.py +++ b/funasr/models/encoder/conformer_encoder.py @@ -8,6 +8,7 @@ from typing import List from typing import Optional from typing import Tuple from typing import Union +from typing import Dict import torch from torch import nn @@ -18,6 +19,7 @@ from funasr.models.encoder.abs_encoder import AbsEncoder from funasr.modules.attention import ( MultiHeadedAttention, # noqa: H301 RelPositionMultiHeadedAttention, # noqa: H301 + RelPositionMultiHeadedAttentionChunk, LegacyRelPositionMultiHeadedAttention, # noqa: H301 ) from funasr.modules.embedding import ( @@ -25,16 +27,24 @@ from funasr.modules.embedding import ( ScaledPositionalEncoding, # noqa: H301 RelPositionalEncoding, # noqa: H301 LegacyRelPositionalEncoding, # noqa: H301 + StreamingRelPositionalEncoding, ) from funasr.modules.layer_norm import LayerNorm +from funasr.modules.normalization import get_normalization from funasr.modules.multi_layer_conv import Conv1dLinear from funasr.modules.multi_layer_conv import MultiLayeredConv1d from funasr.modules.nets_utils import get_activation from funasr.modules.nets_utils import make_pad_mask +from funasr.modules.nets_utils import ( + TooShortUttError, + check_short_utt, + make_chunk_mask, + make_source_mask, +) from funasr.modules.positionwise_feed_forward import ( PositionwiseFeedForward, # noqa: 
H301 ) -from funasr.modules.repeat import repeat +from funasr.modules.repeat import repeat, MultiBlocks from funasr.modules.subsampling import Conv2dSubsampling from funasr.modules.subsampling import Conv2dSubsampling2 from funasr.modules.subsampling import Conv2dSubsampling6 @@ -42,6 +52,8 @@ from funasr.modules.subsampling import Conv2dSubsampling8 from funasr.modules.subsampling import TooShortUttError from funasr.modules.subsampling import check_short_utt from funasr.modules.subsampling import Conv2dSubsamplingPad +from funasr.modules.subsampling import StreamingConvInput + class ConvolutionModule(nn.Module): """ConvolutionModule in Conformer model. @@ -276,6 +288,188 @@ class EncoderLayer(nn.Module): return x, mask +class ChunkEncoderLayer(torch.nn.Module): + """Chunk Conformer module definition. + Args: + block_size: Input/output size. + self_att: Self-attention module instance. + feed_forward: Feed-forward module instance. + feed_forward_macaron: Feed-forward module instance for macaron network. + conv_mod: Convolution module instance. + norm_class: Normalization module class. + norm_args: Normalization module arguments. + dropout_rate: Dropout rate. 
+ """ + + def __init__( + self, + block_size: int, + self_att: torch.nn.Module, + feed_forward: torch.nn.Module, + feed_forward_macaron: torch.nn.Module, + conv_mod: torch.nn.Module, + norm_class: torch.nn.Module = torch.nn.LayerNorm, + norm_args: Dict = {}, + dropout_rate: float = 0.0, + ) -> None: + """Construct a Conformer object.""" + super().__init__() + + self.self_att = self_att + + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.feed_forward_scale = 0.5 + + self.conv_mod = conv_mod + + self.norm_feed_forward = norm_class(block_size, **norm_args) + self.norm_self_att = norm_class(block_size, **norm_args) + + self.norm_macaron = norm_class(block_size, **norm_args) + self.norm_conv = norm_class(block_size, **norm_args) + self.norm_final = norm_class(block_size, **norm_args) + + self.dropout = torch.nn.Dropout(dropout_rate) + + self.block_size = block_size + self.cache = None + + def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: + """Initialize/Reset self-attention and convolution modules cache for streaming. + Args: + left_context: Number of left frames during chunk-by-chunk inference. + device: Device to use for cache tensor. + """ + self.cache = [ + torch.zeros( + (1, left_context, self.block_size), + device=device, + ), + torch.zeros( + ( + 1, + self.block_size, + self.conv_mod.kernel_size - 1, + ), + device=device, + ), + ] + + def forward( + self, + x: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + chunk_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Encode input sequences. + Args: + x: Conformer input sequences. (B, T, D_block) + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) + mask: Source mask. (B, T) + chunk_mask: Chunk mask. (T_2, T_2) + Returns: + x: Conformer output sequences. (B, T, D_block) + mask: Source mask. (B, T) + pos_enc: Positional embedding sequences. 
(B, 2 * (T - 1), D_block) + """ + residual = x + + x = self.norm_macaron(x) + x = residual + self.feed_forward_scale * self.dropout( + self.feed_forward_macaron(x) + ) + + residual = x + x = self.norm_self_att(x) + x_q = x + x = residual + self.dropout( + self.self_att( + x_q, + x, + x, + pos_enc, + mask, + chunk_mask=chunk_mask, + ) + ) + + residual = x + + x = self.norm_conv(x) + x, _ = self.conv_mod(x) + x = residual + self.dropout(x) + residual = x + + x = self.norm_feed_forward(x) + x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x)) + + x = self.norm_final(x) + return x, mask, pos_enc + + def chunk_forward( + self, + x: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + chunk_size: int = 16, + left_context: int = 0, + right_context: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode chunk of input sequence. + Args: + x: Conformer input sequences. (B, T, D_block) + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) + mask: Source mask. (B, T_2) + left_context: Number of frames in left context. + right_context: Number of frames in right context. + Returns: + x: Conformer output sequences. (B, T, D_block) + pos_enc: Positional embedding sequences. 
(B, 2 * (T - 1), D_block) + """ + residual = x + + x = self.norm_macaron(x) + x = residual + self.feed_forward_scale * self.feed_forward_macaron(x) + + residual = x + x = self.norm_self_att(x) + if left_context > 0: + key = torch.cat([self.cache[0], x], dim=1) + else: + key = x + val = key + + if right_context > 0: + att_cache = key[:, -(left_context + right_context) : -right_context, :] + else: + att_cache = key[:, -left_context:, :] + x = residual + self.self_att( + x, + key, + val, + pos_enc, + mask, + left_context=left_context, + ) + + residual = x + x = self.norm_conv(x) + x, conv_cache = self.conv_mod( + x, cache=self.cache[1], right_context=right_context + ) + x = residual + x + residual = x + + x = self.norm_feed_forward(x) + x = residual + self.feed_forward_scale * self.feed_forward(x) + + x = self.norm_final(x) + self.cache = [att_cache, conv_cache] + + return x, pos_enc + class ConformerEncoder(AbsEncoder): """Conformer encoder module. @@ -604,3 +798,447 @@ class ConformerEncoder(AbsEncoder): if len(intermediate_outs) > 0: return (xs_pad, intermediate_outs), olens, None return xs_pad, olens, None + + +class CausalConvolution(torch.nn.Module): + """ConformerConvolution module definition. + Args: + channels: The number of channels. + kernel_size: Size of the convolving kernel. + activation: Type of activation function. + norm_args: Normalization module arguments. + causal: Whether to use causal convolution (set to True if streaming). 
+ """ + + def __init__( + self, + channels: int, + kernel_size: int, + activation: torch.nn.Module = torch.nn.ReLU(), + norm_args: Dict = {}, + causal: bool = False, + ) -> None: + """Construct an ConformerConvolution object.""" + super().__init__() + + assert (kernel_size - 1) % 2 == 0 + + self.kernel_size = kernel_size + + self.pointwise_conv1 = torch.nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + ) + + if causal: + self.lorder = kernel_size - 1 + padding = 0 + else: + self.lorder = 0 + padding = (kernel_size - 1) // 2 + + self.depthwise_conv = torch.nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + ) + self.norm = torch.nn.BatchNorm1d(channels, **norm_args) + self.pointwise_conv2 = torch.nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + ) + + self.activation = activation + + def forward( + self, + x: torch.Tensor, + cache: Optional[torch.Tensor] = None, + right_context: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute convolution module. + Args: + x: ConformerConvolution input sequences. (B, T, D_hidden) + cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden) + right_context: Number of frames in right context. + Returns: + x: ConformerConvolution output sequences. (B, T, D_hidden) + cache: ConformerConvolution output cache. 
(1, conv_kernel, D_hidden) + """ + x = self.pointwise_conv1(x.transpose(1, 2)) + x = torch.nn.functional.glu(x, dim=1) + + if self.lorder > 0: + if cache is None: + x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) + else: + x = torch.cat([cache, x], dim=2) + + if right_context > 0: + cache = x[:, :, -(self.lorder + right_context) : -right_context] + else: + cache = x[:, :, -self.lorder :] + + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x).transpose(1, 2) + + return x, cache + +class ConformerChunkEncoder(torch.nn.Module): + """Encoder module definition. + Args: + input_size: Input size. + body_conf: Encoder body configuration. + input_conf: Encoder input configuration. + main_conf: Encoder main configuration. + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + embed_vgg_like: bool = False, + normalize_before: bool = True, + concat_after: bool = False, + positionwise_layer_type: str = "linear", + positionwise_conv_kernel_size: int = 3, + macaron_style: bool = False, + rel_pos_type: str = "legacy", + pos_enc_layer_type: str = "rel_pos", + selfattention_layer_type: str = "rel_selfattn", + activation_type: str = "swish", + use_cnn_module: bool = True, + zero_triu: bool = False, + norm_type: str = "layer_norm", + cnn_module_kernel: int = 31, + conv_mod_norm_eps: float = 0.00001, + conv_mod_norm_momentum: float = 0.1, + simplified_att_score: bool = False, + dynamic_chunk_training: bool = False, + short_chunk_threshold: float = 0.75, + short_chunk_size: int = 25, + left_chunk_size: int = 0, + time_reduction_factor: int = 1, + unified_model_training: bool = False, + default_chunk_size: int = 16, + jitter_range: int = 4, + subsampling_factor: int = 1, + **activation_parameters, + ) -> None: + 
"""Construct an Encoder object.""" + super().__init__() + + assert check_argument_types() + + self.embed = StreamingConvInput( + input_size, + output_size, + subsampling_factor, + vgg_like=embed_vgg_like, + output_size=output_size, + ) + + self.pos_enc = StreamingRelPositionalEncoding( + output_size, + positional_dropout_rate, + ) + + activation = get_activation( + activation_type, **activation_parameters + ) + + pos_wise_args = ( + output_size, + linear_units, + positional_dropout_rate, + activation, + ) + + conv_mod_norm_args = { + "eps": conv_mod_norm_eps, + "momentum": conv_mod_norm_momentum, + } + + conv_mod_args = ( + output_size, + cnn_module_kernel, + activation, + conv_mod_norm_args, + dynamic_chunk_training or unified_model_training, + ) + + mult_att_args = ( + attention_heads, + output_size, + attention_dropout_rate, + simplified_att_score, + ) + + norm_class, norm_args = get_normalization( + norm_type, + ) + + fn_modules = [] + for _ in range(num_blocks): + module = lambda: ChunkEncoderLayer( + output_size, + RelPositionMultiHeadedAttentionChunk(*mult_att_args), + PositionwiseFeedForward(*pos_wise_args), + PositionwiseFeedForward(*pos_wise_args), + CausalConvolution(*conv_mod_args), + norm_class=norm_class, + norm_args=norm_args, + dropout_rate=dropout_rate, + ) + fn_modules.append(module) + + self.encoders = MultiBlocks( + [fn() for fn in fn_modules], + output_size, + norm_class=norm_class, + norm_args=norm_args, + ) + + self.output_size = output_size + + self.dynamic_chunk_training = dynamic_chunk_training + self.short_chunk_threshold = short_chunk_threshold + self.short_chunk_size = short_chunk_size + self.left_chunk_size = left_chunk_size + + self.unified_model_training = unified_model_training + self.default_chunk_size = default_chunk_size + self.jitter_range = jitter_range + + self.time_reduction_factor = time_reduction_factor + + def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int: + """Return the corresponding number of 
sample for a given chunk size, in frames. + Where size is the number of features frames after applying subsampling. + Args: + size: Number of frames after subsampling. + hop_length: Frontend's hop length + Returns: + : Number of raw samples + """ + return self.embed.get_size_before_subsampling(size) * hop_length + + def get_encoder_input_size(self, size: int) -> int: + """Return the corresponding number of sample for a given chunk size, in frames. + Where size is the number of features frames after applying subsampling. + Args: + size: Number of frames after subsampling. + Returns: + : Number of raw samples + """ + return self.embed.get_size_before_subsampling(size) + + + def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: + """Initialize/Reset encoder streaming cache. + Args: + left_context: Number of frames in left context. + device: Device ID. + """ + return self.encoders.reset_streaming_cache(left_context, device) + + def forward( + self, + x: torch.Tensor, + x_len: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode input sequences. + Args: + x: Encoder input features. (B, T_in, F) + x_len: Encoder input features lengths. (B,) + Returns: + x: Encoder outputs. (B, T_out, D_enc) + x_len: Encoder outputs lenghts. 
(B,) + """ + short_status, limit_size = check_short_utt( + self.embed.subsampling_factor, x.size(1) + ) + + if short_status: + raise TooShortUttError( + f"has {x.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + x.size(1), + limit_size, + ) + + mask = make_source_mask(x_len) + + if self.unified_model_training: + chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item() + x, mask = self.embed(x, mask, chunk_size) + pos_enc = self.pos_enc(x) + chunk_mask = make_chunk_mask( + x.size(1), + chunk_size, + left_chunk_size=self.left_chunk_size, + device=x.device, + ) + x_utt = self.encoders( + x, + pos_enc, + mask, + chunk_mask=None, + ) + x_chunk = self.encoders( + x, + pos_enc, + mask, + chunk_mask=chunk_mask, + ) + + olens = mask.eq(0).sum(1) + if self.time_reduction_factor > 1: + x_utt = x_utt[:,::self.time_reduction_factor,:] + x_chunk = x_chunk[:,::self.time_reduction_factor,:] + olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 + + return x_utt, x_chunk, olens + + elif self.dynamic_chunk_training: + max_len = x.size(1) + chunk_size = torch.randint(1, max_len, (1,)).item() + + if chunk_size > (max_len * self.short_chunk_threshold): + chunk_size = max_len + else: + chunk_size = (chunk_size % self.short_chunk_size) + 1 + + x, mask = self.embed(x, mask, chunk_size) + pos_enc = self.pos_enc(x) + + chunk_mask = make_chunk_mask( + x.size(1), + chunk_size, + left_chunk_size=self.left_chunk_size, + device=x.device, + ) + else: + x, mask = self.embed(x, mask, None) + pos_enc = self.pos_enc(x) + chunk_mask = None + x = self.encoders( + x, + pos_enc, + mask, + chunk_mask=chunk_mask, + ) + + olens = mask.eq(0).sum(1) + if self.time_reduction_factor > 1: + x = x[:,::self.time_reduction_factor,:] + olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 + + return x, olens + + def simu_chunk_forward( + self, + x: torch.Tensor, + 
x_len: torch.Tensor, + chunk_size: int = 16, + left_context: int = 32, + right_context: int = 0, + ) -> torch.Tensor: + short_status, limit_size = check_short_utt( + self.embed.subsampling_factor, x.size(1) + ) + + if short_status: + raise TooShortUttError( + f"has {x.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + x.size(1), + limit_size, + ) + + mask = make_source_mask(x_len) + + x, mask = self.embed(x, mask, chunk_size) + pos_enc = self.pos_enc(x) + chunk_mask = make_chunk_mask( + x.size(1), + chunk_size, + left_chunk_size=self.left_chunk_size, + device=x.device, + ) + + x = self.encoders( + x, + pos_enc, + mask, + chunk_mask=chunk_mask, + ) + olens = mask.eq(0).sum(1) + if self.time_reduction_factor > 1: + x = x[:,::self.time_reduction_factor,:] + + return x + + def chunk_forward( + self, + x: torch.Tensor, + x_len: torch.Tensor, + processed_frames: torch.tensor, + chunk_size: int = 16, + left_context: int = 32, + right_context: int = 0, + ) -> torch.Tensor: + """Encode input sequences as chunks. + Args: + x: Encoder input features. (1, T_in, F) + x_len: Encoder input features lengths. (1,) + processed_frames: Number of frames already seen. + left_context: Number of frames in left context. + right_context: Number of frames in right context. + Returns: + x: Encoder outputs. 
(B, T_out, D_enc) + """ + mask = make_source_mask(x_len) + x, mask = self.embed(x, mask, None) + + if left_context > 0: + processed_mask = ( + torch.arange(left_context, device=x.device) + .view(1, left_context) + .flip(1) + ) + processed_mask = processed_mask >= processed_frames + mask = torch.cat([processed_mask, mask], dim=1) + pos_enc = self.pos_enc(x, left_context=left_context) + x = self.encoders.chunk_forward( + x, + pos_enc, + mask, + chunk_size=chunk_size, + left_context=left_context, + right_context=right_context, + ) + + if right_context > 0: + x = x[:, 0:-right_context, :] + + if self.time_reduction_factor > 1: + x = x[:,::self.time_reduction_factor,:] + return x diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py index 31d5a8775..62020796e 100644 --- a/funasr/modules/attention.py +++ b/funasr/modules/attention.py @@ -11,7 +11,7 @@ import math import numpy import torch from torch import nn - +from typing import Optional, Tuple class MultiHeadedAttention(nn.Module): """Multi-Head Attention layer. @@ -741,3 +741,221 @@ class MultiHeadSelfAttention(nn.Module): scores = torch.matmul(q_h, k_h.transpose(-2, -1)) att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder) return att_outs + +class RelPositionMultiHeadedAttentionChunk(torch.nn.Module): + """RelPositionMultiHeadedAttention definition. + Args: + num_heads: Number of attention heads. + embed_size: Embedding size. + dropout_rate: Dropout rate. 
+ """ + + def __init__( + self, + num_heads: int, + embed_size: int, + dropout_rate: float = 0.0, + simplified_attention_score: bool = False, + ) -> None: + """Construct an MultiHeadedAttention object.""" + super().__init__() + + self.d_k = embed_size // num_heads + self.num_heads = num_heads + + assert self.d_k * num_heads == embed_size, ( + "embed_size (%d) must be divisible by num_heads (%d)", + (embed_size, num_heads), + ) + + self.linear_q = torch.nn.Linear(embed_size, embed_size) + self.linear_k = torch.nn.Linear(embed_size, embed_size) + self.linear_v = torch.nn.Linear(embed_size, embed_size) + + self.linear_out = torch.nn.Linear(embed_size, embed_size) + + if simplified_attention_score: + self.linear_pos = torch.nn.Linear(embed_size, num_heads) + + self.compute_att_score = self.compute_simplified_attention_score + else: + self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False) + + self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k)) + self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + self.compute_att_score = self.compute_attention_score + + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.attn = None + + def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor: + """Compute relative positional encoding. + Args: + x: Input sequence. (B, H, T_1, 2 * T_1 - 1) + left_context: Number of frames in left context. + Returns: + x: Output sequence. 
(B, H, T_1, T_2) + """ + batch_size, n_heads, time1, n = x.shape + time2 = time1 + left_context + + batch_stride, n_heads_stride, time1_stride, n_stride = x.stride() + + return x.as_strided( + (batch_size, n_heads, time1, time2), + (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride), + storage_offset=(n_stride * (time1 - 1)), + ) + + def compute_simplified_attention_score( + self, + query: torch.Tensor, + key: torch.Tensor, + pos_enc: torch.Tensor, + left_context: int = 0, + ) -> torch.Tensor: + """Simplified attention score computation. + Reference: https://github.com/k2-fsa/icefall/pull/458 + Args: + query: Transformed query tensor. (B, H, T_1, d_k) + key: Transformed key tensor. (B, H, T_2, d_k) + pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) + left_context: Number of frames in left context. + Returns: + : Attention score. (B, H, T_1, T_2) + """ + pos_enc = self.linear_pos(pos_enc) + + matrix_ac = torch.matmul(query, key.transpose(2, 3)) + + matrix_bd = self.rel_shift( + pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1), + left_context=left_context, + ) + + return (matrix_ac + matrix_bd) / math.sqrt(self.d_k) + + def compute_attention_score( + self, + query: torch.Tensor, + key: torch.Tensor, + pos_enc: torch.Tensor, + left_context: int = 0, + ) -> torch.Tensor: + """Attention score computation. + Args: + query: Transformed query tensor. (B, H, T_1, d_k) + key: Transformed key tensor. (B, H, T_2, d_k) + pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) + left_context: Number of frames in left context. + Returns: + : Attention score. 
(B, H, T_1, T_2) + """ + p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k) + + query = query.transpose(1, 2) + q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2) + + matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1)) + + matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1)) + matrix_bd = self.rel_shift(matrix_bd, left_context=left_context) + + return (matrix_ac + matrix_bd) / math.sqrt(self.d_k) + + def forward_qkv( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Transform query, key and value. + Args: + query: Query tensor. (B, T_1, size) + key: Key tensor. (B, T_2, size) + v: Value tensor. (B, T_2, size) + Returns: + q: Transformed query tensor. (B, H, T_1, d_k) + k: Transformed key tensor. (B, H, T_2, d_k) + v: Transformed value tensor. (B, H, T_2, d_k) + """ + n_batch = query.size(0) + + q = ( + self.linear_q(query) + .view(n_batch, -1, self.num_heads, self.d_k) + .transpose(1, 2) + ) + k = ( + self.linear_k(key) + .view(n_batch, -1, self.num_heads, self.d_k) + .transpose(1, 2) + ) + v = ( + self.linear_v(value) + .view(n_batch, -1, self.num_heads, self.d_k) + .transpose(1, 2) + ) + + return q, k, v + + def forward_attention( + self, + value: torch.Tensor, + scores: torch.Tensor, + mask: torch.Tensor, + chunk_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Compute attention context vector. + Args: + value: Transformed value. (B, H, T_2, d_k) + scores: Attention score. (B, H, T_1, T_2) + mask: Source mask. (B, T_2) + chunk_mask: Chunk mask. (T_1, T_1) + Returns: + attn_output: Transformed value weighted by attention score. 
(B, T_1, H * d_k) + """ + batch_size = scores.size(0) + mask = mask.unsqueeze(1).unsqueeze(2) + if chunk_mask is not None: + mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask + scores = scores.masked_fill(mask, float("-inf")) + self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) + + attn_output = self.dropout(self.attn) + attn_output = torch.matmul(attn_output, value) + + attn_output = self.linear_out( + attn_output.transpose(1, 2) + .contiguous() + .view(batch_size, -1, self.num_heads * self.d_k) + ) + + return attn_output + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + chunk_mask: Optional[torch.Tensor] = None, + left_context: int = 0, + ) -> torch.Tensor: + """Compute scaled dot product attention with rel. positional encoding. + Args: + query: Query tensor. (B, T_1, size) + key: Key tensor. (B, T_2, size) + value: Value tensor. (B, T_2, size) + pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) + mask: Source mask. (B, T_2) + chunk_mask: Chunk mask. (T_1, T_1) + left_context: Number of frames in left context. + Returns: + : Output tensor. (B, T_1, H * d_k) + """ + q, k, v = self.forward_qkv(query, key, value) + scores = self.compute_att_score(q, k, pos_enc, left_context=left_context) + return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask) diff --git a/funasr/modules/embedding.py b/funasr/modules/embedding.py index 79ca0b2f8..e0070dedb 100644 --- a/funasr/modules/embedding.py +++ b/funasr/modules/embedding.py @@ -423,4 +423,79 @@ class SinusoidalPositionEncoder(torch.nn.Module): outputs = F.pad(outputs, (pad_left, pad_right)) outputs = outputs.transpose(1,2) return outputs - + +class StreamingRelPositionalEncoding(torch.nn.Module): + """Relative positional encoding. + Args: + size: Module size. + max_len: Maximum input length. + dropout_rate: Dropout rate. 
+ """ + + def __init__( + self, size: int, dropout_rate: float = 0.0, max_len: int = 5000 + ) -> None: + """Construct a RelativePositionalEncoding object.""" + super().__init__() + + self.size = size + + self.pe = None + self.dropout = torch.nn.Dropout(p=dropout_rate) + + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + self._register_load_state_dict_pre_hook(_pre_hook) + + def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None: + """Reset positional encoding. + Args: + x: Input sequences. (B, T, ?) + left_context: Number of frames in left context. + """ + time1 = x.size(1) + left_context + + if self.pe is not None: + if self.pe.size(1) >= time1 * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(device=x.device, dtype=x.dtype) + return + + pe_positive = torch.zeros(time1, self.size) + pe_negative = torch.zeros(time1, self.size) + + position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.size, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.size) + ) + + pe_positive[:, 0::2] = torch.sin(position * div_term) + pe_positive[:, 1::2] = torch.cos(position * div_term) + pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0) + + pe_negative[:, 0::2] = torch.sin(-1 * position * div_term) + pe_negative[:, 1::2] = torch.cos(-1 * position * div_term) + pe_negative = pe_negative[1:].unsqueeze(0) + + self.pe = torch.cat([pe_positive, pe_negative], dim=1).to( + dtype=x.dtype, device=x.device + ) + + def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor: + """Compute positional encoding. + Args: + x: Input sequences. (B, T, ?) + left_context: Number of frames in left context. + Returns: + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), ?) 
+ """ + self.extend_pe(x, left_context=left_context) + + time1 = x.size(1) + left_context + + pos_enc = self.pe[ + :, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1) + ] + pos_enc = self.dropout(pos_enc) + + return pos_enc diff --git a/funasr/models/encoder/chunk_encoder_modules/normalization.py b/funasr/modules/normalization.py similarity index 100% rename from funasr/models/encoder/chunk_encoder_modules/normalization.py rename to funasr/modules/normalization.py diff --git a/funasr/modules/repeat.py b/funasr/modules/repeat.py index a3d2676a8..7241dd96b 100644 --- a/funasr/modules/repeat.py +++ b/funasr/modules/repeat.py @@ -6,6 +6,8 @@ """Repeat the same layer definition.""" +from typing import Dict, List, Optional + import torch @@ -31,3 +33,93 @@ def repeat(N, fn): """ return MultiSequential(*[fn(n) for n in range(N)]) + + +class MultiBlocks(torch.nn.Module): + """MultiBlocks definition. + Args: + block_list: Individual blocks of the encoder architecture. + output_size: Architecture output size. + norm_class: Normalization module class. + norm_args: Normalization module arguments. + """ + + def __init__( + self, + block_list: List[torch.nn.Module], + output_size: int, + norm_class: torch.nn.Module = torch.nn.LayerNorm, + norm_args: Optional[Dict] = None, + ) -> None: + """Construct a MultiBlocks object.""" + super().__init__() + + self.blocks = torch.nn.ModuleList(block_list) + self.norm_blocks = norm_class(output_size, **norm_args) + + self.num_blocks = len(block_list) + + def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: + """Initialize/Reset encoder streaming cache. + Args: + left_context: Number of left frames during chunk-by-chunk inference. + device: Device to use for cache tensor. 
+ """ + for idx in range(self.num_blocks): + self.blocks[idx].reset_streaming_cache(left_context, device) + + def forward( + self, + x: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + chunk_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward each block of the encoder architecture. + Args: + x: MultiBlocks input sequences. (B, T, D_block_1) + pos_enc: Positional embedding sequences. + mask: Source mask. (B, T) + chunk_mask: Chunk mask. (T_2, T_2) + Returns: + x: Output sequences. (B, T, D_block_N) + """ + for block_index, block in enumerate(self.blocks): + x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask) + + x = self.norm_blocks(x) + + return x + + def chunk_forward( + self, + x: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + chunk_size: int = 0, + left_context: int = 0, + right_context: int = 0, + ) -> torch.Tensor: + """Forward each block of the encoder architecture. + Args: + x: MultiBlocks input sequences. (B, T, D_block_1) + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att) + mask: Source mask. (B, T_2) + left_context: Number of frames in left context. + right_context: Number of frames in right context. + Returns: + x: MultiBlocks output sequences. 
(B, T, D_block_N) + """ + for block_idx, block in enumerate(self.blocks): + x, pos_enc = block.chunk_forward( + x, + pos_enc, + mask, + chunk_size=chunk_size, + left_context=left_context, + right_context=right_context, + ) + + x = self.norm_blocks(x) + + return x diff --git a/funasr/modules/subsampling.py b/funasr/modules/subsampling.py index d492ccf61..623be65bc 100644 --- a/funasr/modules/subsampling.py +++ b/funasr/modules/subsampling.py @@ -11,6 +11,10 @@ import torch.nn.functional as F from funasr.modules.embedding import PositionalEncoding import logging from funasr.modules.streaming_utils.utils import sequence_mask +from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len +from typing import Optional, Tuple, Union +import math + class TooShortUttError(Exception): """Raised when the utt is too short for subsampling. @@ -407,3 +411,201 @@ class Conv1dSubsampling(torch.nn.Module): var_dict_tf[name_tf].shape)) return var_dict_torch_update +class StreamingConvInput(torch.nn.Module): + """Streaming ConvInput module definition. + Args: + input_size: Input size. + conv_size: Convolution size. + subsampling_factor: Subsampling factor. + vgg_like: Whether to use a VGG-like network. + output_size: Block output dimension. 
+ """ + + def __init__( + self, + input_size: int, + conv_size: Union[int, Tuple], + subsampling_factor: int = 4, + vgg_like: bool = True, + output_size: Optional[int] = None, + ) -> None: + """Construct a ConvInput object.""" + super().__init__() + if vgg_like: + if subsampling_factor == 1: + conv_size1, conv_size2 = conv_size + + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((1, 2)), + torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((1, 2)), + ) + + output_proj = conv_size2 * ((input_size // 2) // 2) + + self.subsampling_factor = 1 + + self.stride_1 = 1 + + self.create_new_mask = self.create_new_vgg_mask + + else: + conv_size1, conv_size2 = conv_size + + kernel_1 = int(subsampling_factor / 2) + + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((kernel_1, 2)), + torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((2, 2)), + ) + + output_proj = conv_size2 * ((input_size // 2) // 2) + + self.subsampling_factor = subsampling_factor + + self.create_new_mask = self.create_new_vgg_mask + + self.stride_1 = kernel_1 + + else: + if subsampling_factor == 1: + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]), + torch.nn.ReLU(), + torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]), + torch.nn.ReLU(), + ) + + output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2) + + self.subsampling_factor = subsampling_factor + 
self.kernel_2 = 3 + self.stride_2 = 1 + + self.create_new_mask = self.create_new_conv2d_mask + + else: + kernel_2, stride_2, conv_2_output_size = sub_factor_to_params( + subsampling_factor, + input_size, + ) + + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, conv_size, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2), + torch.nn.ReLU(), + ) + + output_proj = conv_size * conv_2_output_size + + self.subsampling_factor = subsampling_factor + self.kernel_2 = kernel_2 + self.stride_2 = stride_2 + + self.create_new_mask = self.create_new_conv2d_mask + + self.vgg_like = vgg_like + self.min_frame_length = 7 + + if output_size is not None: + self.output = torch.nn.Linear(output_proj, output_size) + self.output_size = output_size + else: + self.output = None + self.output_size = output_proj + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode input sequences. + Args: + x: ConvInput input sequences. (B, T, D_feats) + mask: Mask of input sequences. (B, 1, T) + Returns: + x: ConvInput output sequences. (B, sub(T), D_out) + mask: Mask of output sequences. (B, 1, sub(T)) + """ + if mask is not None: + mask = self.create_new_mask(mask) + olens = max(mask.eq(0).sum(1)) + + b, t, f = x.size() + x = x.unsqueeze(1) # (b. 1. t. 
f) + + if chunk_size is not None: + max_input_length = int( + chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor) )) + ) + x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x) + x = list(x) + x = torch.stack(x, dim=0) + N_chunks = max_input_length // ( chunk_size * self.subsampling_factor) + x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f) + + x = self.conv(x) + + _, c, _, f = x.size() + if chunk_size is not None: + x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:] + else: + x = x.transpose(1, 2).contiguous().view(b, -1, c * f) + + if self.output is not None: + x = self.output(x) + + return x, mask[:,:olens][:,:x.size(1)] + + def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor: + """Create a new mask for VGG output sequences. + Args: + mask: Mask of input sequences. (B, T) + Returns: + mask: Mask of output sequences. (B, sub(T)) + """ + if self.subsampling_factor > 1: + vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2 )) + mask = mask[:, :vgg1_t_len][:, ::self.subsampling_factor // 2] + + vgg2_t_len = mask.size(1) - (mask.size(1) % 2) + mask = mask[:, :vgg2_t_len][:, ::2] + else: + mask = mask + + return mask + + def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor: + """Create new conformer mask for Conv2d output sequences. + Args: + mask: Mask of input sequences. (B, T) + Returns: + mask: Mask of output sequences. (B, sub(T)) + """ + if self.subsampling_factor > 1: + return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2] + else: + return mask + + def get_size_before_subsampling(self, size: int) -> int: + """Return the original size before subsampling for a given size. + Args: + size: Number of frames after subsampling. + Returns: + : Number of frames before subsampling. 
+ """ + return size * self.subsampling_factor diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py index cae18c169..bb1f99643 100644 --- a/funasr/tasks/asr_transducer.py +++ b/funasr/tasks/asr_transducer.py @@ -24,7 +24,7 @@ from funasr.models.decoder.transformer_decoder import ( from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder -from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder +from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder from funasr.models.e2e_transducer import TransducerModel from funasr.models.e2e_transducer_unified import UnifiedTransducerModel from funasr.models.joint_network import JointNetwork @@ -72,9 +72,9 @@ normalize_choices = ClassChoices( encoder_choices = ClassChoices( "encoder", classes=dict( - encoder=Encoder, + chunk_conformer=ConformerChunkEncoder, ), - default="encoder", + default="chunk_conformer", ) decoder_choices = ClassChoices( From fa25b637b0d257186a8399eb1c530a91f4252702 Mon Sep 17 00:00:00 2001 From: aky15 Date: Fri, 14 Apr 2023 15:44:50 +0800 Subject: [PATCH 12/14] remove some functions --- funasr/models/e2e_transducer.py | 2 +- funasr/models/e2e_transducer_unified.py | 2 +- funasr/models/encoder/conformer_encoder.py | 11 +- funasr/models/joint_network.py | 5 +- .../__init__.py | 0 .../abs_decoder.py | 0 .../rnn_decoder.py | 2 +- .../stateless_decoder.py | 2 +- funasr/modules/activation.py | 213 ------------------ .../beam_search/beam_search_transducer.py | 2 +- funasr/modules/e2e_asr_common.py | 2 +- funasr/modules/normalization.py | 170 -------------- funasr/modules/repeat.py | 3 +- funasr/tasks/asr_transducer.py | 6 +- 14 files changed, 13 insertions(+), 407 deletions(-) rename funasr/models/{rnnt_decoder => rnnt_predictor}/__init__.py (100%) rename funasr/models/{rnnt_decoder => rnnt_predictor}/abs_decoder.py 
(100%) rename funasr/models/{rnnt_decoder => rnnt_predictor}/rnn_decoder.py (99%) rename funasr/models/{rnnt_decoder => rnnt_predictor}/stateless_decoder.py (98%) delete mode 100644 funasr/modules/activation.py delete mode 100644 funasr/modules/normalization.py diff --git a/funasr/models/e2e_transducer.py b/funasr/models/e2e_transducer.py index 8630aec40..460a6d796 100644 --- a/funasr/models/e2e_transducer.py +++ b/funasr/models/e2e_transducer.py @@ -10,7 +10,7 @@ from typeguard import check_argument_types from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.specaug.abs_specaug import AbsSpecAug -from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder +from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder from funasr.models.joint_network import JointNetwork diff --git a/funasr/models/e2e_transducer_unified.py b/funasr/models/e2e_transducer_unified.py index 124bc0938..f79ba57c4 100644 --- a/funasr/models/e2e_transducer_unified.py +++ b/funasr/models/e2e_transducer_unified.py @@ -10,7 +10,7 @@ from typeguard import check_argument_types from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.specaug.abs_specaug import AbsSpecAug -from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder +from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder from funasr.models.joint_network import JointNetwork from funasr.modules.nets_utils import get_transducer_task_io diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py index c837cf533..b7b552ce6 100644 --- a/funasr/models/encoder/conformer_encoder.py +++ b/funasr/models/encoder/conformer_encoder.py @@ -30,7 +30,6 @@ from funasr.modules.embedding import ( 
StreamingRelPositionalEncoding, ) from funasr.modules.layer_norm import LayerNorm -from funasr.modules.normalization import get_normalization from funasr.modules.multi_layer_conv import Conv1dLinear from funasr.modules.multi_layer_conv import MultiLayeredConv1d from funasr.modules.nets_utils import get_activation @@ -940,7 +939,6 @@ class ConformerChunkEncoder(torch.nn.Module): default_chunk_size: int = 16, jitter_range: int = 4, subsampling_factor: int = 1, - **activation_parameters, ) -> None: """Construct an Encoder object.""" super().__init__() @@ -961,7 +959,7 @@ class ConformerChunkEncoder(torch.nn.Module): ) activation = get_activation( - activation_type, **activation_parameters + activation_type ) pos_wise_args = ( @@ -991,9 +989,6 @@ class ConformerChunkEncoder(torch.nn.Module): simplified_att_score, ) - norm_class, norm_args = get_normalization( - norm_type, - ) fn_modules = [] for _ in range(num_blocks): @@ -1003,8 +998,6 @@ class ConformerChunkEncoder(torch.nn.Module): PositionwiseFeedForward(*pos_wise_args), PositionwiseFeedForward(*pos_wise_args), CausalConvolution(*conv_mod_args), - norm_class=norm_class, - norm_args=norm_args, dropout_rate=dropout_rate, ) fn_modules.append(module) @@ -1012,8 +1005,6 @@ class ConformerChunkEncoder(torch.nn.Module): self.encoders = MultiBlocks( [fn() for fn in fn_modules], output_size, - norm_class=norm_class, - norm_args=norm_args, ) self.output_size = output_size diff --git a/funasr/models/joint_network.py b/funasr/models/joint_network.py index 5cabdb4f7..ed827c420 100644 --- a/funasr/models/joint_network.py +++ b/funasr/models/joint_network.py @@ -2,7 +2,7 @@ import torch -from funasr.modules.activation import get_activation +from funasr.modules.nets_utils import get_activation class JointNetwork(torch.nn.Module): @@ -25,7 +25,6 @@ class JointNetwork(torch.nn.Module): decoder_size: int, joint_space_size: int = 256, joint_activation_type: str = "tanh", - **activation_parameters, ) -> None: """Construct a 
JointNetwork object.""" super().__init__() @@ -36,7 +35,7 @@ class JointNetwork(torch.nn.Module): self.lin_out = torch.nn.Linear(joint_space_size, output_size) self.joint_activation = get_activation( - joint_activation_type, **activation_parameters + joint_activation_type ) def forward( diff --git a/funasr/models/rnnt_decoder/__init__.py b/funasr/models/rnnt_predictor/__init__.py similarity index 100% rename from funasr/models/rnnt_decoder/__init__.py rename to funasr/models/rnnt_predictor/__init__.py diff --git a/funasr/models/rnnt_decoder/abs_decoder.py b/funasr/models/rnnt_predictor/abs_decoder.py similarity index 100% rename from funasr/models/rnnt_decoder/abs_decoder.py rename to funasr/models/rnnt_predictor/abs_decoder.py diff --git a/funasr/models/rnnt_decoder/rnn_decoder.py b/funasr/models/rnnt_predictor/rnn_decoder.py similarity index 99% rename from funasr/models/rnnt_decoder/rnn_decoder.py rename to funasr/models/rnnt_predictor/rnn_decoder.py index c4e79511c..0df6fc750 100644 --- a/funasr/models/rnnt_decoder/rnn_decoder.py +++ b/funasr/models/rnnt_predictor/rnn_decoder.py @@ -6,7 +6,7 @@ import torch from typeguard import check_argument_types from funasr.modules.beam_search.beam_search_transducer import Hypothesis -from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder +from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder from funasr.models.specaug.specaug import SpecAug class RNNDecoder(AbsDecoder): diff --git a/funasr/models/rnnt_decoder/stateless_decoder.py b/funasr/models/rnnt_predictor/stateless_decoder.py similarity index 98% rename from funasr/models/rnnt_decoder/stateless_decoder.py rename to funasr/models/rnnt_predictor/stateless_decoder.py index a2e1fc14b..70cd877f2 100644 --- a/funasr/models/rnnt_decoder/stateless_decoder.py +++ b/funasr/models/rnnt_predictor/stateless_decoder.py @@ -6,7 +6,7 @@ import torch from typeguard import check_argument_types from funasr.modules.beam_search.beam_search_transducer import Hypothesis 
-from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder +from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder from funasr.models.specaug.specaug import SpecAug class StatelessDecoder(AbsDecoder): diff --git a/funasr/modules/activation.py b/funasr/modules/activation.py deleted file mode 100644 index 82cda1251..000000000 --- a/funasr/modules/activation.py +++ /dev/null @@ -1,213 +0,0 @@ -"""Activation functions for Transducer.""" - -import torch -from packaging.version import parse as V - - -def get_activation( - activation_type: str, - ftswish_threshold: float = -0.2, - ftswish_mean_shift: float = 0.0, - hardtanh_min_val: int = -1.0, - hardtanh_max_val: int = 1.0, - leakyrelu_neg_slope: float = 0.01, - smish_alpha: float = 1.0, - smish_beta: float = 1.0, - softplus_beta: float = 1.0, - softplus_threshold: int = 20, - swish_beta: float = 1.0, -) -> torch.nn.Module: - """Return activation function. - - Args: - activation_type: Activation function type. - ftswish_threshold: Threshold value for FTSwish activation formulation. - ftswish_mean_shift: Mean shifting value for FTSwish activation formulation. - hardtanh_min_val: Minimum value of the linear region range for HardTanh. - hardtanh_max_val: Maximum value of the linear region range for HardTanh. - leakyrelu_neg_slope: Negative slope value for LeakyReLU activation formulation. - smish_alpha: Alpha value for Smish activation fomulation. - smish_beta: Beta value for Smish activation formulation. - softplus_beta: Beta value for softplus activation formulation in Mish. - softplus_threshold: Values above this revert to a linear function in Mish. - swish_beta: Beta value for Swish variant formulation. - - Returns: - : Activation function. 
- - """ - torch_version = V(torch.__version__) - - activations = { - "ftswish": ( - FTSwish, - {"threshold": ftswish_threshold, "mean_shift": ftswish_mean_shift}, - ), - "hardtanh": ( - torch.nn.Hardtanh, - {"min_val": hardtanh_min_val, "max_val": hardtanh_max_val}, - ), - "leaky_relu": (torch.nn.LeakyReLU, {"negative_slope": leakyrelu_neg_slope}), - "mish": ( - Mish, - { - "softplus_beta": softplus_beta, - "softplus_threshold": softplus_threshold, - "use_builtin": torch_version >= V("1.9"), - }, - ), - "relu": (torch.nn.ReLU, {}), - "selu": (torch.nn.SELU, {}), - "smish": (Smish, {"alpha": smish_alpha, "beta": smish_beta}), - "swish": ( - Swish, - {"beta": swish_beta, "use_builtin": torch_version >= V("1.8")}, - ), - "tanh": (torch.nn.Tanh, {}), - "identity": (torch.nn.Identity, {}), - } - - act_func, act_args = activations[activation_type] - - return act_func(**act_args) - - -class FTSwish(torch.nn.Module): - """Flatten-T Swish activation definition. - - FTSwish(x) = x * sigmoid(x) + threshold - where FTSwish(x) < 0 = threshold - - Reference: https://arxiv.org/abs/1812.06247 - - Args: - threshold: Threshold value for FTSwish activation formulation. (threshold < 0) - mean_shift: Mean shifting value for FTSwish activation formulation. - (applied only if != 0, disabled by default) - - """ - - def __init__(self, threshold: float = -0.2, mean_shift: float = 0) -> None: - super().__init__() - - assert threshold < 0, "FTSwish threshold parameter should be < 0." - - self.threshold = threshold - self.mean_shift = mean_shift - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward computation.""" - x = (x * torch.sigmoid(x)) + self.threshold - x = torch.where(x >= 0, x, torch.tensor([self.threshold], device=x.device)) - - if self.mean_shift != 0: - x.sub_(self.mean_shift) - - return x - - -class Mish(torch.nn.Module): - """Mish activation definition. - - Mish(x) = x * tanh(softplus(x)) - - Reference: https://arxiv.org/abs/1908.08681. 
- - Args: - softplus_beta: Beta value for softplus activation formulation. - (Usually 0 > softplus_beta >= 2) - softplus_threshold: Values above this revert to a linear function. - (Usually 10 > softplus_threshold >= 20) - use_builtin: Whether to use PyTorch activation function if available. - - """ - - def __init__( - self, - softplus_beta: float = 1.0, - softplus_threshold: int = 20, - use_builtin: bool = False, - ) -> None: - super().__init__() - - if use_builtin: - self.mish = torch.nn.Mish() - else: - self.tanh = torch.nn.Tanh() - self.softplus = torch.nn.Softplus( - beta=softplus_beta, threshold=softplus_threshold - ) - - self.mish = lambda x: x * self.tanh(self.softplus(x)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward computation.""" - return self.mish(x) - - -class Smish(torch.nn.Module): - """Smish activation definition. - - Smish(x) = (alpha * x) * tanh(log(1 + sigmoid(beta * x))) - where alpha > 0 and beta > 0 - - Reference: https://www.mdpi.com/2079-9292/11/4/540/htm. - - Args: - alpha: Alpha value for Smish activation fomulation. - (Usually, alpha = 1. If alpha <= 0, set value to 1). - beta: Beta value for Smish activation formulation. - (Usually, beta = 1. If beta <= 0, set value to 1). - - """ - - def __init__(self, alpha: float = 1.0, beta: float = 1.0) -> None: - super().__init__() - - self.tanh = torch.nn.Tanh() - - self.alpha = alpha if alpha > 0 else 1 - self.beta = beta if beta > 0 else 1 - - self.smish = lambda x: (self.alpha * x) * self.tanh( - torch.log(1 + torch.sigmoid((self.beta * x))) - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward computation.""" - return self.smish(x) - - -class Swish(torch.nn.Module): - """Swish activation definition. - - Swish(x) = (beta * x) * sigmoid(x) - where beta = 1 defines standard Swish activation. - - References: - https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1. - E-swish variant: https://arxiv.org/abs/1801.07145. 
- - Args: - beta: Beta parameter for E-Swish. - (beta >= 1. If beta < 1, use standard Swish). - use_builtin: Whether to use PyTorch function if available. - - """ - - def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None: - super().__init__() - - self.beta = beta - - if beta > 1: - self.swish = lambda x: (self.beta * x) * torch.sigmoid(x) - else: - if use_builtin: - self.swish = torch.nn.SiLU() - else: - self.swish = lambda x: x * torch.sigmoid(x) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward computation.""" - return self.swish(x) diff --git a/funasr/modules/beam_search/beam_search_transducer.py b/funasr/modules/beam_search/beam_search_transducer.py index eaf5627f9..49cce92a1 100644 --- a/funasr/modules/beam_search/beam_search_transducer.py +++ b/funasr/modules/beam_search/beam_search_transducer.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch -from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder +from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder from funasr.models.joint_network import JointNetwork diff --git a/funasr/modules/e2e_asr_common.py b/funasr/modules/e2e_asr_common.py index 9b5039c91..3746036ba 100644 --- a/funasr/modules/e2e_asr_common.py +++ b/funasr/modules/e2e_asr_common.py @@ -18,7 +18,7 @@ import six import torch from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer -from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder +from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder from funasr.models.joint_network import JointNetwork def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))): diff --git a/funasr/modules/normalization.py b/funasr/modules/normalization.py deleted file mode 100644 index ae35fd43f..000000000 --- a/funasr/modules/normalization.py +++ /dev/null @@ -1,170 +0,0 @@ -"""Normalization modules for X-former blocks.""" - -from typing import Dict, Optional, 
Tuple - -import torch - - -def get_normalization( - normalization_type: str, - eps: Optional[float] = None, - partial: Optional[float] = None, -) -> Tuple[torch.nn.Module, Dict]: - """Get normalization module and arguments given parameters. - - Args: - normalization_type: Normalization module type. - eps: Value added to the denominator. - partial: Value defining the part of the input used for RMS stats (RMSNorm). - - Return: - : Normalization module class - : Normalization module arguments - - """ - norm = { - "basic_norm": ( - BasicNorm, - {"eps": eps if eps is not None else 0.25}, - ), - "layer_norm": (torch.nn.LayerNorm, {"eps": eps if eps is not None else 1e-12}), - "rms_norm": ( - RMSNorm, - { - "eps": eps if eps is not None else 1e-05, - "partial": partial if partial is not None else -1.0, - }, - ), - "scale_norm": ( - ScaleNorm, - {"eps": eps if eps is not None else 1e-05}, - ), - } - - return norm[normalization_type] - - -class BasicNorm(torch.nn.Module): - """BasicNorm module definition. - - Reference: https://github.com/k2-fsa/icefall/pull/288 - - Args: - normalized_shape: Expected size. - eps: Value added to the denominator for numerical stability. - - """ - - def __init__( - self, - normalized_shape: int, - eps: float = 0.25, - ) -> None: - """Construct a BasicNorm object.""" - super().__init__() - - self.eps = torch.nn.Parameter(torch.tensor(eps).log().detach()) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Compute basic normalization. - - Args: - x: Input sequences. (B, T, D_hidden) - - Returns: - : Output sequences. (B, T, D_hidden) - - """ - scales = (torch.mean(x.pow(2), dim=-1, keepdim=True) + self.eps.exp()) ** -0.5 - - return x * scales - - -class RMSNorm(torch.nn.Module): - """RMSNorm module definition. - - Reference: https://arxiv.org/pdf/1910.07467.pdf - - Args: - normalized_shape: Expected size. - eps: Value added to the denominator for numerical stability. - partial: Value defining the part of the input used for RMS stats. 
- - """ - - def __init__( - self, - normalized_shape: int, - eps: float = 1e-5, - partial: float = 0.0, - ) -> None: - """Construct a RMSNorm object.""" - super().__init__() - - self.normalized_shape = normalized_shape - - self.partial = True if 0 < partial < 1 else False - self.p = partial - self.eps = eps - - self.scale = torch.nn.Parameter(torch.ones(normalized_shape)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Compute RMS normalization. - - Args: - x: Input sequences. (B, T, D_hidden) - - Returns: - x: Output sequences. (B, T, D_hidden) - - """ - if self.partial: - partial_size = int(self.normalized_shape * self.p) - partial_x, _ = torch.split( - x, [partial_size, self.normalized_shape - partial_size], dim=-1 - ) - - norm_x = partial_x.norm(2, dim=-1, keepdim=True) - d_x = partial_size - else: - norm_x = x.norm(2, dim=-1, keepdim=True) - d_x = self.normalized_shape - - rms_x = norm_x * d_x ** (-1.0 / 2) - x = self.scale * (x / (rms_x + self.eps)) - - return x - - -class ScaleNorm(torch.nn.Module): - """ScaleNorm module definition. - - Reference: https://arxiv.org/pdf/1910.05895.pdf - - Args: - normalized_shape: Expected size. - eps: Value added to the denominator for numerical stability. - - """ - - def __init__(self, normalized_shape: int, eps: float = 1e-5) -> None: - """Construct a ScaleNorm object.""" - super().__init__() - - self.eps = eps - self.scale = torch.nn.Parameter(torch.tensor(normalized_shape**0.5)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Compute scale normalization. - - Args: - x: Input sequences. (B, T, D_hidden) - - Returns: - : Output sequences. 
(B, T, D_hidden) - - """ - norm = self.scale / torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps) - - return x * norm diff --git a/funasr/modules/repeat.py b/funasr/modules/repeat.py index 7241dd96b..2b2dac8f3 100644 --- a/funasr/modules/repeat.py +++ b/funasr/modules/repeat.py @@ -49,13 +49,12 @@ class MultiBlocks(torch.nn.Module): block_list: List[torch.nn.Module], output_size: int, norm_class: torch.nn.Module = torch.nn.LayerNorm, - norm_args: Optional[Dict] = None, ) -> None: """Construct a MultiBlocks object.""" super().__init__() self.blocks = torch.nn.ModuleList(block_list) - self.norm_blocks = norm_class(output_size, **norm_args) + self.norm_blocks = norm_class(output_size) self.num_blocks = len(block_list) diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py index bb1f99643..99b3d0c2e 100644 --- a/funasr/tasks/asr_transducer.py +++ b/funasr/tasks/asr_transducer.py @@ -21,9 +21,9 @@ from funasr.models.decoder.transformer_decoder import ( LightweightConvolutionTransformerDecoder, TransformerDecoder, ) -from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder -from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder -from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder +from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder +from funasr.models.rnnt_predictor.rnn_decoder import RNNDecoder +from funasr.models.rnnt_predictor.stateless_decoder import StatelessDecoder from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder from funasr.models.e2e_transducer import TransducerModel from funasr.models.e2e_transducer_unified import UnifiedTransducerModel From b3b4c1bc5bb068c9aff99740e6257d12c6676ff7 Mon Sep 17 00:00:00 2001 From: aky15 Date: Mon, 17 Apr 2023 11:19:14 +0800 Subject: [PATCH 13/14] rename some functions --- funasr/models/{e2e_transducer.py => e2e_asr_transducer.py} | 2 +- ..._transducer_unified.py => e2e_asr_transducer_unified.py} | 2 +- funasr/models/{ => 
joint_net}/joint_network.py | 0 funasr/modules/beam_search/beam_search_transducer.py | 2 +- funasr/modules/e2e_asr_common.py | 2 +- funasr/tasks/asr_transducer.py | 6 +++--- 6 files changed, 7 insertions(+), 7 deletions(-) rename funasr/models/{e2e_transducer.py => e2e_asr_transducer.py} (99%) rename funasr/models/{e2e_transducer_unified.py => e2e_asr_transducer_unified.py} (99%) rename funasr/models/{ => joint_net}/joint_network.py (100%) diff --git a/funasr/models/e2e_transducer.py b/funasr/models/e2e_asr_transducer.py similarity index 99% rename from funasr/models/e2e_transducer.py rename to funasr/models/e2e_asr_transducer.py index 460a6d796..6eb002320 100644 --- a/funasr/models/e2e_transducer.py +++ b/funasr/models/e2e_asr_transducer.py @@ -13,7 +13,7 @@ from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder -from funasr.models.joint_network import JointNetwork +from funasr.models.joint_net.joint_network import JointNetwork from funasr.modules.nets_utils import get_transducer_task_io from funasr.layers.abs_normalize import AbsNormalize from funasr.torch_utils.device_funcs import force_gatherable diff --git a/funasr/models/e2e_transducer_unified.py b/funasr/models/e2e_asr_transducer_unified.py similarity index 99% rename from funasr/models/e2e_transducer_unified.py rename to funasr/models/e2e_asr_transducer_unified.py index f79ba57c4..ad61d12c0 100644 --- a/funasr/models/e2e_transducer_unified.py +++ b/funasr/models/e2e_asr_transducer_unified.py @@ -12,7 +12,7 @@ from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder -from 
funasr.models.joint_network import JointNetwork +from funasr.models.joint_net.joint_network import JointNetwork from funasr.modules.nets_utils import get_transducer_task_io from funasr.layers.abs_normalize import AbsNormalize from funasr.torch_utils.device_funcs import force_gatherable diff --git a/funasr/models/joint_network.py b/funasr/models/joint_net/joint_network.py similarity index 100% rename from funasr/models/joint_network.py rename to funasr/models/joint_net/joint_network.py diff --git a/funasr/modules/beam_search/beam_search_transducer.py b/funasr/modules/beam_search/beam_search_transducer.py index 49cce92a1..8b7e613fe 100644 --- a/funasr/modules/beam_search/beam_search_transducer.py +++ b/funasr/modules/beam_search/beam_search_transducer.py @@ -7,7 +7,7 @@ import numpy as np import torch from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder -from funasr.models.joint_network import JointNetwork +from funasr.models.joint_net.joint_network import JointNetwork @dataclass diff --git a/funasr/modules/e2e_asr_common.py b/funasr/modules/e2e_asr_common.py index 3746036ba..a01cd5ef1 100644 --- a/funasr/modules/e2e_asr_common.py +++ b/funasr/modules/e2e_asr_common.py @@ -19,7 +19,7 @@ import torch from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder -from funasr.models.joint_network import JointNetwork +from funasr.models.joint_net.joint_network import JointNetwork def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))): """End detection. 
diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py index 99b3d0c2e..d4136d068 100644 --- a/funasr/tasks/asr_transducer.py +++ b/funasr/tasks/asr_transducer.py @@ -25,9 +25,9 @@ from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder from funasr.models.rnnt_predictor.rnn_decoder import RNNDecoder from funasr.models.rnnt_predictor.stateless_decoder import StatelessDecoder from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder -from funasr.models.e2e_transducer import TransducerModel -from funasr.models.e2e_transducer_unified import UnifiedTransducerModel -from funasr.models.joint_network import JointNetwork +from funasr.models.e2e_asr_transducer import TransducerModel +from funasr.models.e2e_asr_transducer_unified import UnifiedTransducerModel +from funasr.models.joint_net.joint_network import JointNetwork from funasr.layers.abs_normalize import AbsNormalize from funasr.layers.global_mvn import GlobalMVN from funasr.layers.utterance_mvn import UtteranceMVN From 8672352ecde80a86609fe01195b398ebe77f0ed1 Mon Sep 17 00:00:00 2001 From: aky15 Date: Mon, 17 Apr 2023 16:09:23 +0800 Subject: [PATCH 14/14] merge many functions --- .../conf/train_conformer_rnnt_unified.yaml | 30 +- funasr/bin/asr_inference_rnnt.py | 2 +- funasr/bin/asr_train_transducer.py | 2 +- .../rnnt_decoder.py} | 3 +- funasr/models/e2e_asr_transducer.py | 535 +++++++++++++++- funasr/models/e2e_asr_transducer_unified.py | 586 ------------------ funasr/models/encoder/conformer_encoder.py | 7 +- funasr/models/rnnt_predictor/__init__.py | 0 funasr/models/rnnt_predictor/abs_decoder.py | 110 ---- .../rnnt_predictor/stateless_decoder.py | 145 ----- .../beam_search/beam_search_transducer.py | 3 +- funasr/modules/e2e_asr_common.py | 3 +- funasr/tasks/asr.py | 391 +++++++++++- funasr/tasks/asr_transducer.py | 477 -------------- 14 files changed, 944 insertions(+), 1350 deletions(-) rename funasr/models/{rnnt_predictor/rnn_decoder.py => decoder/rnnt_decoder.py} 
(98%) delete mode 100644 funasr/models/e2e_asr_transducer_unified.py delete mode 100644 funasr/models/rnnt_predictor/__init__.py delete mode 100644 funasr/models/rnnt_predictor/abs_decoder.py delete mode 100644 funasr/models/rnnt_predictor/stateless_decoder.py delete mode 100644 funasr/tasks/asr_transducer.py diff --git a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml index 60f796c75..8a1c40cac 100644 --- a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml +++ b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml @@ -1,32 +1,26 @@ +encoder: chunk_conformer encoder_conf: - main_conf: - pos_wise_act_type: swish - pos_enc_dropout_rate: 0.5 - conv_mod_act_type: swish + activation_type: swish + positional_dropout_rate: 0.5 time_reduction_factor: 2 unified_model_training: true default_chunk_size: 16 jitter_range: 4 left_chunk_size: 0 - input_conf: - block_type: conv2d - conv_size: 512 + embed_vgg_like: false subsampling_factor: 4 - num_frame: 1 - body_conf: - - block_type: conformer - linear_size: 2048 - hidden_size: 512 - heads: 8 + linear_units: 2048 + output_size: 512 + attention_heads: 8 dropout_rate: 0.5 - pos_wise_dropout_rate: 0.5 - att_dropout_rate: 0.5 - conv_mod_kernel_size: 15 + positional_dropout_rate: 0.5 + attention_dropout_rate: 0.5 + cnn_module_kernel: 15 num_blocks: 12 # decoder related -decoder: rnn -decoder_conf: +rnnt_decoder: rnnt +rnnt_decoder_conf: embed_size: 512 hidden_size: 512 embed_dropout_rate: 0.5 diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py index 465f88254..bff87022e 100644 --- a/funasr/bin/asr_inference_rnnt.py +++ b/funasr/bin/asr_inference_rnnt.py @@ -22,7 +22,7 @@ from funasr.modules.beam_search.beam_search_transducer import ( ) from funasr.modules.nets_utils import TooShortUttError from funasr.fileio.datadir_writer import DatadirWriter -from funasr.tasks.asr_transducer import ASRTransducerTask +from funasr.tasks.asr import 
ASRTransducerTask from funasr.tasks.lm import LMTask from funasr.text.build_tokenizer import build_tokenizer from funasr.text.token_id_converter import TokenIDConverter diff --git a/funasr/bin/asr_train_transducer.py b/funasr/bin/asr_train_transducer.py index 9b6d287dd..fe418dbc9 100755 --- a/funasr/bin/asr_train_transducer.py +++ b/funasr/bin/asr_train_transducer.py @@ -2,7 +2,7 @@ import os -from funasr.tasks.asr_transducer import ASRTransducerTask +from funasr.tasks.asr import ASRTransducerTask # for ASR Training diff --git a/funasr/models/rnnt_predictor/rnn_decoder.py b/funasr/models/decoder/rnnt_decoder.py similarity index 98% rename from funasr/models/rnnt_predictor/rnn_decoder.py rename to funasr/models/decoder/rnnt_decoder.py index 0df6fc750..5401ab20c 100644 --- a/funasr/models/rnnt_predictor/rnn_decoder.py +++ b/funasr/models/decoder/rnnt_decoder.py @@ -6,10 +6,9 @@ import torch from typeguard import check_argument_types from funasr.modules.beam_search.beam_search_transducer import Hypothesis -from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder from funasr.models.specaug.specaug import SpecAug -class RNNDecoder(AbsDecoder): +class RNNTDecoder(torch.nn.Module): """RNN decoder module. 
Args: diff --git a/funasr/models/e2e_asr_transducer.py b/funasr/models/e2e_asr_transducer.py index 6eb002320..0cae30605 100644 --- a/funasr/models/e2e_asr_transducer.py +++ b/funasr/models/e2e_asr_transducer.py @@ -10,7 +10,7 @@ from typeguard import check_argument_types from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.specaug.abs_specaug import AbsSpecAug -from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder +from funasr.models.decoder.rnnt_decoder import RNNTDecoder from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder from funasr.models.joint_net.joint_network import JointNetwork @@ -63,9 +63,9 @@ class TransducerModel(AbsESPnetModel): specaug: Optional[AbsSpecAug], normalize: Optional[AbsNormalize], encoder: Encoder, - decoder: AbsDecoder, - att_decoder: Optional[AbsAttDecoder], + decoder: RNNTDecoder, joint_network: JointNetwork, + att_decoder: Optional[AbsAttDecoder] = None, transducer_weight: float = 1.0, fastemit_lambda: float = 0.0, auxiliary_ctc_weight: float = 0.0, @@ -482,3 +482,532 @@ class TransducerModel(AbsESPnetModel): ) return loss_lm + +class UnifiedTransducerModel(AbsESPnetModel): + """ESPnet2ASRTransducerModel module definition. + Args: + vocab_size: Size of complete vocabulary (w/ EOS and blank included). + token_list: List of token + frontend: Frontend module. + specaug: SpecAugment module. + normalize: Normalization module. + encoder: Encoder module. + decoder: Decoder module. + joint_network: Joint Network module. + transducer_weight: Weight of the Transducer loss. + fastemit_lambda: FastEmit lambda value. + auxiliary_ctc_weight: Weight of auxiliary CTC loss. + auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs. + auxiliary_lm_loss_weight: Weight of auxiliary LM loss. + auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing. + ignore_id: Initial padding ID. 
+ sym_space: Space symbol. + sym_blank: Blank Symbol + report_cer: Whether to report Character Error Rate during validation. + report_wer: Whether to report Word Error Rate during validation. + extract_feats_in_collect_stats: Whether to use extract_feats stats collection. + """ + + def __init__( + self, + vocab_size: int, + token_list: Union[Tuple[str, ...], List[str]], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], + encoder: Encoder, + decoder: RNNTDecoder, + joint_network: JointNetwork, + att_decoder: Optional[AbsAttDecoder] = None, + transducer_weight: float = 1.0, + fastemit_lambda: float = 0.0, + auxiliary_ctc_weight: float = 0.0, + auxiliary_att_weight: float = 0.0, + auxiliary_ctc_dropout_rate: float = 0.0, + auxiliary_lm_loss_weight: float = 0.0, + auxiliary_lm_loss_smoothing: float = 0.0, + ignore_id: int = -1, + sym_space: str = "", + sym_blank: str = "", + report_cer: bool = True, + report_wer: bool = True, + sym_sos: str = "", + sym_eos: str = "", + extract_feats_in_collect_stats: bool = True, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + ) -> None: + """Construct an ESPnetASRTransducerModel object.""" + super().__init__() + + assert check_argument_types() + + # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos) + self.blank_id = 0 + + if sym_sos in token_list: + self.sos = token_list.index(sym_sos) + else: + self.sos = vocab_size - 1 + if sym_eos in token_list: + self.eos = token_list.index(sym_eos) + else: + self.eos = vocab_size - 1 + + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.token_list = token_list.copy() + + self.sym_space = sym_space + self.sym_blank = sym_blank + + self.frontend = frontend + self.specaug = specaug + self.normalize = normalize + + self.encoder = encoder + self.decoder = decoder + self.joint_network = joint_network + + self.criterion_transducer = None + self.error_calculator = None + + 
self.use_auxiliary_ctc = auxiliary_ctc_weight > 0 + self.use_auxiliary_att = auxiliary_att_weight > 0 + self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0 + + if self.use_auxiliary_ctc: + self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size) + self.ctc_dropout_rate = auxiliary_ctc_dropout_rate + + if self.use_auxiliary_att: + self.att_decoder = att_decoder + + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + if self.use_auxiliary_lm_loss: + self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size) + self.lm_loss_smoothing = auxiliary_lm_loss_smoothing + + self.transducer_weight = transducer_weight + self.fastemit_lambda = fastemit_lambda + + self.auxiliary_ctc_weight = auxiliary_ctc_weight + self.auxiliary_att_weight = auxiliary_att_weight + self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight + + self.report_cer = report_cer + self.report_wer = report_wer + + self.extract_feats_in_collect_stats = extract_feats_in_collect_stats + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Forward architecture and compute loss(es). + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + text: Label ID sequences. (B, L) + text_lengths: Label ID sequences lengths. (B,) + kwargs: Contains "utts_id". + Return: + loss: Main loss value. + stats: Task statistics. + weight: Task weights. + """ + assert text_lengths.dim() == 1, text_lengths.shape + assert ( + speech.shape[0] + == speech_lengths.shape[0] + == text.shape[0] + == text_lengths.shape[0] + ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) + + batch_size = speech.shape[0] + text = text[:, : text_lengths.max()] + #print(speech.shape) + # 1. 
Encoder + encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths) + + loss_att, loss_att_chunk = 0.0, 0.0 + + if self.use_auxiliary_att: + loss_att, _ = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + loss_att_chunk, _ = self._calc_att_loss( + encoder_out_chunk, encoder_out_lens, text, text_lengths + ) + + # 2. Transducer-related I/O preparation + decoder_in, target, t_len, u_len = get_transducer_task_io( + text, + encoder_out_lens, + ignore_id=self.ignore_id, + ) + + # 3. Decoder + self.decoder.set_device(encoder_out.device) + decoder_out = self.decoder(decoder_in, u_len) + + # 4. Joint Network + joint_out = self.joint_network( + encoder_out.unsqueeze(2), decoder_out.unsqueeze(1) + ) + + joint_out_chunk = self.joint_network( + encoder_out_chunk.unsqueeze(2), decoder_out.unsqueeze(1) + ) + + # 5. Losses + loss_trans_utt, cer_trans, wer_trans = self._calc_transducer_loss( + encoder_out, + joint_out, + target, + t_len, + u_len, + ) + + loss_trans_chunk, cer_trans_chunk, wer_trans_chunk = self._calc_transducer_loss( + encoder_out_chunk, + joint_out_chunk, + target, + t_len, + u_len, + ) + + loss_ctc, loss_ctc_chunk, loss_lm = 0.0, 0.0, 0.0 + + if self.use_auxiliary_ctc: + loss_ctc = self._calc_ctc_loss( + encoder_out, + target, + t_len, + u_len, + ) + loss_ctc_chunk = self._calc_ctc_loss( + encoder_out_chunk, + target, + t_len, + u_len, + ) + + if self.use_auxiliary_lm_loss: + loss_lm = self._calc_lm_loss(decoder_out, target) + + loss_trans = loss_trans_utt + loss_trans_chunk + loss_ctc = loss_ctc + loss_ctc_chunk + loss_att = loss_att + loss_att_chunk + + loss = ( + self.transducer_weight * loss_trans + + self.auxiliary_ctc_weight * loss_ctc + + self.auxiliary_att_weight * loss_att + + self.auxiliary_lm_loss_weight * loss_lm + ) + + stats = dict( + loss=loss.detach(), + loss_transducer=loss_trans_utt.detach(), + loss_transducer_chunk=loss_trans_chunk.detach(), + aux_ctc_loss=loss_ctc.detach() if loss_ctc > 
0.0 else None, + aux_ctc_loss_chunk=loss_ctc_chunk.detach() if loss_ctc_chunk > 0.0 else None, + aux_att_loss=loss_att.detach() if loss_att > 0.0 else None, + aux_att_loss_chunk=loss_att_chunk.detach() if loss_att_chunk > 0.0 else None, + aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None, + cer_transducer=cer_trans, + wer_transducer=wer_trans, + cer_transducer_chunk=cer_trans_chunk, + wer_transducer_chunk=wer_trans_chunk, + ) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def collect_feats( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Dict[str, torch.Tensor]: + """Collect features sequences and features lengths sequences. + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + text: Label ID sequences. (B, L) + text_lengths: Label ID sequences lengths. (B,) + kwargs: Contains "utts_id". + Return: + {}: "feats": Features sequences. (B, T, D_feats), + "feats_lengths": Features sequences lengths. (B,) + """ + if self.extract_feats_in_collect_stats: + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + else: + # Generate dummy stats if extract_feats_in_collect_stats is False + logging.warning( + "Generating dummy stats for feats and feats_lengths, " + "because encoder_conf.extract_feats_in_collect_stats is " + f"{self.extract_feats_in_collect_stats}" + ) + + feats, feats_lengths = speech, speech_lengths + + return {"feats": feats, "feats_lengths": feats_lengths} + + def encode( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encoder speech sequences. + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + Return: + encoder_out: Encoder outputs. 
(B, T, D_enc) + encoder_out_lens: Encoder outputs lengths. (B,) + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Data augmentation + if self.specaug is not None and self.training: + feats, feats_lengths = self.specaug(feats, feats_lengths) + + # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # 4. Forward encoder + encoder_out, encoder_out_chunk, encoder_out_lens = self.encoder(feats, feats_lengths) + + assert encoder_out.size(0) == speech.size(0), ( + encoder_out.size(), + speech.size(0), + ) + assert encoder_out.size(1) <= encoder_out_lens.max(), ( + encoder_out.size(), + encoder_out_lens.max(), + ) + + return encoder_out, encoder_out_chunk, encoder_out_lens + + def _extract_feats( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Extract features sequences and features sequences lengths. + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + Return: + feats: Features sequences. (B, T, D_feats) + feats_lengths: Features sequences lengths. (B,) + """ + assert speech_lengths.dim() == 1, speech_lengths.shape + + # for data-parallel + speech = speech[:, : speech_lengths.max()] + + if self.frontend is not None: + feats, feats_lengths = self.frontend(speech, speech_lengths) + else: + feats, feats_lengths = speech, speech_lengths + + return feats, feats_lengths + + def _calc_transducer_loss( + self, + encoder_out: torch.Tensor, + joint_out: torch.Tensor, + target: torch.Tensor, + t_len: torch.Tensor, + u_len: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]: + """Compute Transducer loss. + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + joint_out: Joint Network output sequences (B, T, U, D_joint) + target: Target label ID sequences. 
(B, L) + t_len: Encoder output sequences lengths. (B,) + u_len: Target label ID sequences lengths. (B,) + Return: + loss_transducer: Transducer loss value. + cer_transducer: Character error rate for Transducer. + wer_transducer: Word Error Rate for Transducer. + """ + if self.criterion_transducer is None: + try: + # from warprnnt_pytorch import RNNTLoss + # self.criterion_transducer = RNNTLoss( + # reduction="mean", + # fastemit_lambda=self.fastemit_lambda, + # ) + from warp_rnnt import rnnt_loss as RNNTLoss + self.criterion_transducer = RNNTLoss + + except ImportError: + logging.error( + "warp-rnnt was not installed." + "Please consult the installation documentation." + ) + exit(1) + + # loss_transducer = self.criterion_transducer( + # joint_out, + # target, + # t_len, + # u_len, + # ) + log_probs = torch.log_softmax(joint_out, dim=-1) + + loss_transducer = self.criterion_transducer( + log_probs, + target, + t_len, + u_len, + reduction="mean", + blank=self.blank_id, + fastemit_lambda=self.fastemit_lambda, + gather=True, + ) + + if not self.training and (self.report_cer or self.report_wer): + if self.error_calculator is None: + self.error_calculator = ErrorCalculator( + self.decoder, + self.joint_network, + self.token_list, + self.sym_space, + self.sym_blank, + report_cer=self.report_cer, + report_wer=self.report_wer, + ) + + cer_transducer, wer_transducer = self.error_calculator(encoder_out, target) + return loss_transducer, cer_transducer, wer_transducer + + return loss_transducer, None, None + + def _calc_ctc_loss( + self, + encoder_out: torch.Tensor, + target: torch.Tensor, + t_len: torch.Tensor, + u_len: torch.Tensor, + ) -> torch.Tensor: + """Compute CTC loss. + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + target: Target label ID sequences. (B, L) + t_len: Encoder output sequences lengths. (B,) + u_len: Target label ID sequences lengths. (B,) + Return: + loss_ctc: CTC loss value. 
+ """ + ctc_in = self.ctc_lin( + torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate) + ) + ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1) + + target_mask = target != 0 + ctc_target = target[target_mask].cpu() + + with torch.backends.cudnn.flags(deterministic=True): + loss_ctc = torch.nn.functional.ctc_loss( + ctc_in, + ctc_target, + t_len, + u_len, + zero_infinity=True, + reduction="sum", + ) + loss_ctc /= target.size(0) + + return loss_ctc + + def _calc_lm_loss( + self, + decoder_out: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: + """Compute LM loss. + Args: + decoder_out: Decoder output sequences. (B, U, D_dec) + target: Target label ID sequences. (B, L) + Return: + loss_lm: LM loss value. + """ + lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size) + lm_target = target.view(-1).type(torch.int64) + + with torch.no_grad(): + true_dist = lm_loss_in.clone() + true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1)) + + # Ignore blank ID (0) + ignore = lm_target == 0 + lm_target = lm_target.masked_fill(ignore, 0) + + true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing)) + + loss_lm = torch.nn.functional.kl_div( + torch.log_softmax(lm_loss_in, dim=1), + true_dist, + reduction="none", + ) + loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size( + 0 + ) + + return loss_lm + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + if hasattr(self, "lang_token_id") and self.lang_token_id is not None: + ys_pad = torch.cat( + [ + self.lang_token_id.repeat(ys_pad.size(0), 1).to(ys_pad.device), + ys_pad, + ], + dim=1, + ) + ys_pad_lens += 1 + + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. 
Forward decoder + decoder_out, _ = self.att_decoder( + encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens + ) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, + ) + + return loss_att, acc_att diff --git a/funasr/models/e2e_asr_transducer_unified.py b/funasr/models/e2e_asr_transducer_unified.py deleted file mode 100644 index ad61d12c0..000000000 --- a/funasr/models/e2e_asr_transducer_unified.py +++ /dev/null @@ -1,586 +0,0 @@ -"""ESPnet2 ASR Transducer model.""" - -import logging -from contextlib import contextmanager -from typing import Dict, List, Optional, Tuple, Union - -import torch -from packaging.version import parse as V -from typeguard import check_argument_types - -from funasr.models.frontend.abs_frontend import AbsFrontend -from funasr.models.specaug.abs_specaug import AbsSpecAug -from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder -from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder -from funasr.models.joint_net.joint_network import JointNetwork -from funasr.modules.nets_utils import get_transducer_task_io -from funasr.layers.abs_normalize import AbsNormalize -from funasr.torch_utils.device_funcs import force_gatherable -from funasr.train.abs_espnet_model import AbsESPnetModel -from funasr.modules.add_sos_eos import add_sos_eos -from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder -from funasr.modules.nets_utils import th_accuracy -from funasr.losses.label_smoothing_loss import ( # noqa: H301 - LabelSmoothingLoss, -) -from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator -if V(torch.__version__) >= V("1.6.0"): - from torch.cuda.amp import autocast -else: - - @contextmanager - def autocast(enabled=True): - yield - - -class UnifiedTransducerModel(AbsESPnetModel): - """ESPnet2ASRTransducerModel module 
definition. - - Args: - vocab_size: Size of complete vocabulary (w/ EOS and blank included). - token_list: List of token - frontend: Frontend module. - specaug: SpecAugment module. - normalize: Normalization module. - encoder: Encoder module. - decoder: Decoder module. - joint_network: Joint Network module. - transducer_weight: Weight of the Transducer loss. - fastemit_lambda: FastEmit lambda value. - auxiliary_ctc_weight: Weight of auxiliary CTC loss. - auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs. - auxiliary_lm_loss_weight: Weight of auxiliary LM loss. - auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing. - ignore_id: Initial padding ID. - sym_space: Space symbol. - sym_blank: Blank Symbol - report_cer: Whether to report Character Error Rate during validation. - report_wer: Whether to report Word Error Rate during validation. - extract_feats_in_collect_stats: Whether to use extract_feats stats collection. - - """ - - def __init__( - self, - vocab_size: int, - token_list: Union[Tuple[str, ...], List[str]], - frontend: Optional[AbsFrontend], - specaug: Optional[AbsSpecAug], - normalize: Optional[AbsNormalize], - encoder: Encoder, - decoder: AbsDecoder, - att_decoder: Optional[AbsAttDecoder], - joint_network: JointNetwork, - transducer_weight: float = 1.0, - fastemit_lambda: float = 0.0, - auxiliary_ctc_weight: float = 0.0, - auxiliary_att_weight: float = 0.0, - auxiliary_ctc_dropout_rate: float = 0.0, - auxiliary_lm_loss_weight: float = 0.0, - auxiliary_lm_loss_smoothing: float = 0.0, - ignore_id: int = -1, - sym_space: str = "", - sym_blank: str = "", - report_cer: bool = True, - report_wer: bool = True, - sym_sos: str = "", - sym_eos: str = "", - extract_feats_in_collect_stats: bool = True, - lsm_weight: float = 0.0, - length_normalized_loss: bool = False, - ) -> None: - """Construct an ESPnetASRTransducerModel object.""" - super().__init__() - - assert check_argument_types() - - # The following labels ID are 
reserved: 0 (blank) and vocab_size - 1 (sos/eos) - self.blank_id = 0 - - if sym_sos in token_list: - self.sos = token_list.index(sym_sos) - else: - self.sos = vocab_size - 1 - if sym_eos in token_list: - self.eos = token_list.index(sym_eos) - else: - self.eos = vocab_size - 1 - - self.vocab_size = vocab_size - self.ignore_id = ignore_id - self.token_list = token_list.copy() - - self.sym_space = sym_space - self.sym_blank = sym_blank - - self.frontend = frontend - self.specaug = specaug - self.normalize = normalize - - self.encoder = encoder - self.decoder = decoder - self.joint_network = joint_network - - self.criterion_transducer = None - self.error_calculator = None - - self.use_auxiliary_ctc = auxiliary_ctc_weight > 0 - self.use_auxiliary_att = auxiliary_att_weight > 0 - self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0 - - if self.use_auxiliary_ctc: - self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size) - self.ctc_dropout_rate = auxiliary_ctc_dropout_rate - - if self.use_auxiliary_att: - self.att_decoder = att_decoder - - self.criterion_att = LabelSmoothingLoss( - size=vocab_size, - padding_idx=ignore_id, - smoothing=lsm_weight, - normalize_length=length_normalized_loss, - ) - - if self.use_auxiliary_lm_loss: - self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size) - self.lm_loss_smoothing = auxiliary_lm_loss_smoothing - - self.transducer_weight = transducer_weight - self.fastemit_lambda = fastemit_lambda - - self.auxiliary_ctc_weight = auxiliary_ctc_weight - self.auxiliary_att_weight = auxiliary_att_weight - self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight - - self.report_cer = report_cer - self.report_wer = report_wer - - self.extract_feats_in_collect_stats = extract_feats_in_collect_stats - - def forward( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - text: torch.Tensor, - text_lengths: torch.Tensor, - **kwargs, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: - """Forward architecture 
and compute loss(es). - - Args: - speech: Speech sequences. (B, S) - speech_lengths: Speech sequences lengths. (B,) - text: Label ID sequences. (B, L) - text_lengths: Label ID sequences lengths. (B,) - kwargs: Contains "utts_id". - - Return: - loss: Main loss value. - stats: Task statistics. - weight: Task weights. - - """ - assert text_lengths.dim() == 1, text_lengths.shape - assert ( - speech.shape[0] - == speech_lengths.shape[0] - == text.shape[0] - == text_lengths.shape[0] - ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) - - batch_size = speech.shape[0] - text = text[:, : text_lengths.max()] - #print(speech.shape) - # 1. Encoder - encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths) - - loss_att, loss_att_chunk = 0.0, 0.0 - - if self.use_auxiliary_att: - loss_att, _ = self._calc_att_loss( - encoder_out, encoder_out_lens, text, text_lengths - ) - loss_att_chunk, _ = self._calc_att_loss( - encoder_out_chunk, encoder_out_lens, text, text_lengths - ) - - # 2. Transducer-related I/O preparation - decoder_in, target, t_len, u_len = get_transducer_task_io( - text, - encoder_out_lens, - ignore_id=self.ignore_id, - ) - - # 3. Decoder - self.decoder.set_device(encoder_out.device) - decoder_out = self.decoder(decoder_in, u_len) - - # 4. Joint Network - joint_out = self.joint_network( - encoder_out.unsqueeze(2), decoder_out.unsqueeze(1) - ) - - joint_out_chunk = self.joint_network( - encoder_out_chunk.unsqueeze(2), decoder_out.unsqueeze(1) - ) - - # 5. 
Losses - loss_trans_utt, cer_trans, wer_trans = self._calc_transducer_loss( - encoder_out, - joint_out, - target, - t_len, - u_len, - ) - - loss_trans_chunk, cer_trans_chunk, wer_trans_chunk = self._calc_transducer_loss( - encoder_out_chunk, - joint_out_chunk, - target, - t_len, - u_len, - ) - - loss_ctc, loss_ctc_chunk, loss_lm = 0.0, 0.0, 0.0 - - if self.use_auxiliary_ctc: - loss_ctc = self._calc_ctc_loss( - encoder_out, - target, - t_len, - u_len, - ) - loss_ctc_chunk = self._calc_ctc_loss( - encoder_out_chunk, - target, - t_len, - u_len, - ) - - if self.use_auxiliary_lm_loss: - loss_lm = self._calc_lm_loss(decoder_out, target) - - loss_trans = loss_trans_utt + loss_trans_chunk - loss_ctc = loss_ctc + loss_ctc_chunk - loss_ctc = loss_att + loss_att_chunk - - loss = ( - self.transducer_weight * loss_trans - + self.auxiliary_ctc_weight * loss_ctc - + self.auxiliary_att_weight * loss_att - + self.auxiliary_lm_loss_weight * loss_lm - ) - - stats = dict( - loss=loss.detach(), - loss_transducer=loss_trans_utt.detach(), - loss_transducer_chunk=loss_trans_chunk.detach(), - aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None, - aux_ctc_loss_chunk=loss_ctc_chunk.detach() if loss_ctc_chunk > 0.0 else None, - aux_att_loss=loss_att.detach() if loss_att > 0.0 else None, - aux_att_loss_chunk=loss_att_chunk.detach() if loss_att_chunk > 0.0 else None, - aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None, - cer_transducer=cer_trans, - wer_transducer=wer_trans, - cer_transducer_chunk=cer_trans_chunk, - wer_transducer_chunk=wer_trans_chunk, - ) - - # force_gatherable: to-device and to-tensor if scalar for DataParallel - loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) - return loss, stats, weight - - def collect_feats( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - text: torch.Tensor, - text_lengths: torch.Tensor, - **kwargs, - ) -> Dict[str, torch.Tensor]: - """Collect features sequences and features lengths sequences. 
- - Args: - speech: Speech sequences. (B, S) - speech_lengths: Speech sequences lengths. (B,) - text: Label ID sequences. (B, L) - text_lengths: Label ID sequences lengths. (B,) - kwargs: Contains "utts_id". - - Return: - {}: "feats": Features sequences. (B, T, D_feats), - "feats_lengths": Features sequences lengths. (B,) - - """ - if self.extract_feats_in_collect_stats: - feats, feats_lengths = self._extract_feats(speech, speech_lengths) - else: - # Generate dummy stats if extract_feats_in_collect_stats is False - logging.warning( - "Generating dummy stats for feats and feats_lengths, " - "because encoder_conf.extract_feats_in_collect_stats is " - f"{self.extract_feats_in_collect_stats}" - ) - - feats, feats_lengths = speech, speech_lengths - - return {"feats": feats, "feats_lengths": feats_lengths} - - def encode( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Encoder speech sequences. - - Args: - speech: Speech sequences. (B, S) - speech_lengths: Speech sequences lengths. (B,) - - Return: - encoder_out: Encoder outputs. (B, T, D_enc) - encoder_out_lens: Encoder outputs lengths. (B,) - - """ - with autocast(False): - # 1. Extract feats - feats, feats_lengths = self._extract_feats(speech, speech_lengths) - - # 2. Data augmentation - if self.specaug is not None and self.training: - feats, feats_lengths = self.specaug(feats, feats_lengths) - - # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN - if self.normalize is not None: - feats, feats_lengths = self.normalize(feats, feats_lengths) - - # 4. 
Forward encoder - encoder_out, encoder_out_chunk, encoder_out_lens = self.encoder(feats, feats_lengths) - - assert encoder_out.size(0) == speech.size(0), ( - encoder_out.size(), - speech.size(0), - ) - assert encoder_out.size(1) <= encoder_out_lens.max(), ( - encoder_out.size(), - encoder_out_lens.max(), - ) - - return encoder_out, encoder_out_chunk, encoder_out_lens - - def _extract_feats( - self, speech: torch.Tensor, speech_lengths: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Extract features sequences and features sequences lengths. - - Args: - speech: Speech sequences. (B, S) - speech_lengths: Speech sequences lengths. (B,) - - Return: - feats: Features sequences. (B, T, D_feats) - feats_lengths: Features sequences lengths. (B,) - - """ - assert speech_lengths.dim() == 1, speech_lengths.shape - - # for data-parallel - speech = speech[:, : speech_lengths.max()] - - if self.frontend is not None: - feats, feats_lengths = self.frontend(speech, speech_lengths) - else: - feats, feats_lengths = speech, speech_lengths - - return feats, feats_lengths - - def _calc_transducer_loss( - self, - encoder_out: torch.Tensor, - joint_out: torch.Tensor, - target: torch.Tensor, - t_len: torch.Tensor, - u_len: torch.Tensor, - ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]: - """Compute Transducer loss. - - Args: - encoder_out: Encoder output sequences. (B, T, D_enc) - joint_out: Joint Network output sequences (B, T, U, D_joint) - target: Target label ID sequences. (B, L) - t_len: Encoder output sequences lengths. (B,) - u_len: Target label ID sequences lengths. (B,) - - Return: - loss_transducer: Transducer loss value. - cer_transducer: Character error rate for Transducer. - wer_transducer: Word Error Rate for Transducer. 
- - """ - if self.criterion_transducer is None: - try: - # from warprnnt_pytorch import RNNTLoss - # self.criterion_transducer = RNNTLoss( - # reduction="mean", - # fastemit_lambda=self.fastemit_lambda, - # ) - from warp_rnnt import rnnt_loss as RNNTLoss - self.criterion_transducer = RNNTLoss - - except ImportError: - logging.error( - "warp-rnnt was not installed." - "Please consult the installation documentation." - ) - exit(1) - - # loss_transducer = self.criterion_transducer( - # joint_out, - # target, - # t_len, - # u_len, - # ) - log_probs = torch.log_softmax(joint_out, dim=-1) - - loss_transducer = self.criterion_transducer( - log_probs, - target, - t_len, - u_len, - reduction="mean", - blank=self.blank_id, - fastemit_lambda=self.fastemit_lambda, - gather=True, - ) - - if not self.training and (self.report_cer or self.report_wer): - if self.error_calculator is None: - self.error_calculator = ErrorCalculator( - self.decoder, - self.joint_network, - self.token_list, - self.sym_space, - self.sym_blank, - report_cer=self.report_cer, - report_wer=self.report_wer, - ) - - cer_transducer, wer_transducer = self.error_calculator(encoder_out, target) - return loss_transducer, cer_transducer, wer_transducer - - return loss_transducer, None, None - - def _calc_ctc_loss( - self, - encoder_out: torch.Tensor, - target: torch.Tensor, - t_len: torch.Tensor, - u_len: torch.Tensor, - ) -> torch.Tensor: - """Compute CTC loss. - - Args: - encoder_out: Encoder output sequences. (B, T, D_enc) - target: Target label ID sequences. (B, L) - t_len: Encoder output sequences lengths. (B,) - u_len: Target label ID sequences lengths. (B,) - - Return: - loss_ctc: CTC loss value. 
- - """ - ctc_in = self.ctc_lin( - torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate) - ) - ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1) - - target_mask = target != 0 - ctc_target = target[target_mask].cpu() - - with torch.backends.cudnn.flags(deterministic=True): - loss_ctc = torch.nn.functional.ctc_loss( - ctc_in, - ctc_target, - t_len, - u_len, - zero_infinity=True, - reduction="sum", - ) - loss_ctc /= target.size(0) - - return loss_ctc - - def _calc_lm_loss( - self, - decoder_out: torch.Tensor, - target: torch.Tensor, - ) -> torch.Tensor: - """Compute LM loss. - - Args: - decoder_out: Decoder output sequences. (B, U, D_dec) - target: Target label ID sequences. (B, L) - - Return: - loss_lm: LM loss value. - - """ - lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size) - lm_target = target.view(-1).type(torch.int64) - - with torch.no_grad(): - true_dist = lm_loss_in.clone() - true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1)) - - # Ignore blank ID (0) - ignore = lm_target == 0 - lm_target = lm_target.masked_fill(ignore, 0) - - true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing)) - - loss_lm = torch.nn.functional.kl_div( - torch.log_softmax(lm_loss_in, dim=1), - true_dist, - reduction="none", - ) - loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size( - 0 - ) - - return loss_lm - - def _calc_att_loss( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ys_pad: torch.Tensor, - ys_pad_lens: torch.Tensor, - ): - if hasattr(self, "lang_token_id") and self.lang_token_id is not None: - ys_pad = torch.cat( - [ - self.lang_token_id.repeat(ys_pad.size(0), 1).to(ys_pad.device), - ys_pad, - ], - dim=1, - ) - ys_pad_lens += 1 - - ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) - ys_in_lens = ys_pad_lens + 1 - - # 1. 
Forward decoder - decoder_out, _ = self.att_decoder( - encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens - ) - - # 2. Compute attention loss - loss_att = self.criterion_att(decoder_out, ys_out_pad) - acc_att = th_accuracy( - decoder_out.view(-1, self.vocab_size), - ys_out_pad, - ignore_label=self.ignore_id, - ) - - return loss_att, acc_att diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py index b7b552ce6..9777ceed6 100644 --- a/funasr/models/encoder/conformer_encoder.py +++ b/funasr/models/encoder/conformer_encoder.py @@ -894,7 +894,7 @@ class CausalConvolution(torch.nn.Module): return x, cache -class ConformerChunkEncoder(torch.nn.Module): +class ConformerChunkEncoder(AbsEncoder): """Encoder module definition. Args: input_size: Input size. @@ -1007,7 +1007,7 @@ class ConformerChunkEncoder(torch.nn.Module): output_size, ) - self.output_size = output_size + self._output_size = output_size self.dynamic_chunk_training = dynamic_chunk_training self.short_chunk_threshold = short_chunk_threshold @@ -1020,6 +1020,9 @@ class ConformerChunkEncoder(torch.nn.Module): self.time_reduction_factor = time_reduction_factor + def output_size(self) -> int: + return self._output_size + def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int: """Return the corresponding number of sample for a given chunk size, in frames. Where size is the number of features frames after applying subsampling. 
diff --git a/funasr/models/rnnt_predictor/__init__.py b/funasr/models/rnnt_predictor/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/funasr/models/rnnt_predictor/abs_decoder.py b/funasr/models/rnnt_predictor/abs_decoder.py deleted file mode 100644 index 5b4a335be..000000000 --- a/funasr/models/rnnt_predictor/abs_decoder.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Abstract decoder definition for Transducer models.""" - -from abc import ABC, abstractmethod -from typing import Any, List, Optional, Tuple - -import torch - - -class AbsDecoder(torch.nn.Module, ABC): - """Abstract decoder module.""" - - @abstractmethod - def forward(self, labels: torch.Tensor) -> torch.Tensor: - """Encode source label sequences. - - Args: - labels: Label ID sequences. (B, L) - - Returns: - dec_out: Decoder output sequences. (B, T, D_dec) - - """ - raise NotImplementedError - - @abstractmethod - def score( - self, - label: torch.Tensor, - label_sequence: List[int], - dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]], - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]]: - """One-step forward hypothesis. - - Args: - label: Previous label. (1, 1) - label_sequence: Current label sequence. - dec_state: Previous decoder hidden states. - ((N, 1, D_dec), (N, 1, D_dec) or None) or None - - Returns: - dec_out: Decoder output sequence. (1, D_dec) or (1, D_emb) - dec_state: Decoder hidden states. - ((N, 1, D_dec), (N, 1, D_dec) or None) or None - - """ - raise NotImplementedError - - @abstractmethod - def batch_score( - self, - hyps: List[Any], - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]]: - """One-step forward hypotheses. - - Args: - hyps: Hypotheses. - - Returns: - dec_out: Decoder output sequences. (B, D_dec) or (B, D_emb) - states: Decoder hidden states. 
- ((N, B, D_dec), (N, B, D_dec) or None) or None - - """ - raise NotImplementedError - - @abstractmethod - def set_device(self, device: torch.Tensor) -> None: - """Set GPU device to use. - - Args: - device: Device ID. - - """ - raise NotImplementedError - - @abstractmethod - def init_state( - self, batch_size: int - ) -> Optional[Tuple[torch.Tensor, Optional[torch.tensor]]]: - """Initialize decoder states. - - Args: - batch_size: Batch size. - - Returns: - : Initial decoder hidden states. - ((N, B, D_dec), (N, B, D_dec) or None) or None - - """ - raise NotImplementedError - - @abstractmethod - def select_state( - self, - states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None, - idx: int = 0, - ) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]: - """Get specified ID state from batch of states, if provided. - - Args: - states: Decoder hidden states. - ((N, B, D_dec), (N, B, D_dec) or None) or None - idx: State ID to extract. - - Returns: - : Decoder hidden state for given ID. - ((N, 1, D_dec), (N, 1, D_dec) or None) or None - - """ - raise NotImplementedError diff --git a/funasr/models/rnnt_predictor/stateless_decoder.py b/funasr/models/rnnt_predictor/stateless_decoder.py deleted file mode 100644 index 70cd877f2..000000000 --- a/funasr/models/rnnt_predictor/stateless_decoder.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Stateless decoder definition for Transducer models.""" - -from typing import List, Optional, Tuple - -import torch -from typeguard import check_argument_types - -from funasr.modules.beam_search.beam_search_transducer import Hypothesis -from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder -from funasr.models.specaug.specaug import SpecAug - -class StatelessDecoder(AbsDecoder): - """Stateless Transducer decoder module. - - Args: - vocab_size: Output size. - embed_size: Embedding size. - embed_dropout_rate: Dropout rate for embedding layer. - embed_pad: Embed/Blank symbol ID. 
- - """ - - def __init__( - self, - vocab_size: int, - embed_size: int = 256, - embed_dropout_rate: float = 0.0, - embed_pad: int = 0, - ) -> None: - """Construct a StatelessDecoder object.""" - super().__init__() - - assert check_argument_types() - - self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad) - self.embed_dropout_rate = torch.nn.Dropout(p=embed_dropout_rate) - - self.output_size = embed_size - self.vocab_size = vocab_size - - self.device = next(self.parameters()).device - self.score_cache = {} - - - - def forward( - self, - labels: torch.Tensor, - label_lens: torch.Tensor, - states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None, - ) -> torch.Tensor: - """Encode source label sequences. - - Args: - labels: Label ID sequences. (B, L) - states: Decoder hidden states. None - - Returns: - dec_embed: Decoder output sequences. (B, U, D_emb) - - """ - dec_embed = self.embed_dropout_rate(self.embed(labels)) - return dec_embed - - def score( - self, - label: torch.Tensor, - label_sequence: List[int], - state: None, - ) -> Tuple[torch.Tensor, None]: - """One-step forward hypothesis. - - Args: - label: Previous label. (1, 1) - label_sequence: Current label sequence. - state: Previous decoder hidden states. None - - Returns: - dec_out: Decoder output sequence. (1, D_emb) - state: Decoder hidden states. None - - """ - str_labels = "_".join(map(str, label_sequence)) - - if str_labels in self.score_cache: - dec_embed = self.score_cache[str_labels] - else: - dec_embed = self.embed(label) - - self.score_cache[str_labels] = dec_embed - - return dec_embed[0], None - - def batch_score( - self, - hyps: List[Hypothesis], - ) -> Tuple[torch.Tensor, None]: - """One-step forward hypotheses. - - Args: - hyps: Hypotheses. - - Returns: - dec_out: Decoder output sequences. (B, D_dec) - states: Decoder hidden states. 
None - - """ - labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device) - dec_embed = self.embed(labels) - - return dec_embed.squeeze(1), None - - def set_device(self, device: torch.device) -> None: - """Set GPU device to use. - - Args: - device: Device ID. - - """ - self.device = device - - def init_state(self, batch_size: int) -> None: - """Initialize decoder states. - - Args: - batch_size: Batch size. - - Returns: - : Initial decoder hidden states. None - - """ - return None - - def select_state(self, states: Optional[torch.Tensor], idx: int) -> None: - """Get specified ID state from decoder hidden states. - - Args: - states: Decoder hidden states. None - idx: State ID to extract. - - Returns: - : Decoder hidden state for given ID. None - - """ - return None diff --git a/funasr/modules/beam_search/beam_search_transducer.py b/funasr/modules/beam_search/beam_search_transducer.py index 8b7e613fe..3eb8e08d0 100644 --- a/funasr/modules/beam_search/beam_search_transducer.py +++ b/funasr/modules/beam_search/beam_search_transducer.py @@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch -from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder from funasr.models.joint_net.joint_network import JointNetwork @@ -68,7 +67,7 @@ class BeamSearchTransducer: def __init__( self, - decoder: AbsDecoder, + decoder, joint_network: JointNetwork, beam_size: int, lm: Optional[torch.nn.Module] = None, diff --git a/funasr/modules/e2e_asr_common.py b/funasr/modules/e2e_asr_common.py index a01cd5ef1..f430fcb43 100644 --- a/funasr/modules/e2e_asr_common.py +++ b/funasr/modules/e2e_asr_common.py @@ -18,7 +18,6 @@ import six import torch from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer -from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder from funasr.models.joint_net.joint_network import JointNetwork def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))): @@ 
-268,7 +267,7 @@ class ErrorCalculatorTransducer: def __init__( self, - decoder: AbsDecoder, + decoder, joint_network: JointNetwork, token_list: List[int], sym_space: str, diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py index e15147332..87db05c67 100644 --- a/funasr/tasks/asr.py +++ b/funasr/tasks/asr.py @@ -38,13 +38,16 @@ from funasr.models.decoder.transformer_decoder import ( from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN from funasr.models.decoder.transformer_decoder import TransformerDecoder from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder +from funasr.models.decoder.rnnt_decoder import RNNTDecoder +from funasr.models.joint_net.joint_network import JointNetwork from funasr.models.e2e_asr import ESPnetASRModel from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer from funasr.models.e2e_tp import TimestampPredictor from funasr.models.e2e_asr_mfcca import MFCCA from funasr.models.e2e_uni_asr import UniASR +from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel from funasr.models.encoder.abs_encoder import AbsEncoder -from funasr.models.encoder.conformer_encoder import ConformerEncoder +from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder from funasr.models.encoder.data2vec_encoder import Data2VecEncoder from funasr.models.encoder.rnn_encoder import RNNEncoder from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt @@ -150,6 +153,7 @@ encoder_choices = ClassChoices( sanm_chunk_opt=SANMEncoderChunkOpt, data2vec_encoder=Data2VecEncoder, mfcca_enc=MFCCAEncoder, + chunk_conformer=ConformerChunkEncoder, ), type_check=AbsEncoder, default="rnn", @@ -207,6 +211,16 @@ decoder_choices2 = ClassChoices( type_check=AbsDecoder, default="rnn", ) + +rnnt_decoder_choices = ClassChoices( + "rnnt_decoder", + classes=dict( + rnnt=RNNTDecoder, + ), + 
type_check=RNNTDecoder, + default="rnnt", +) + predictor_choices = ClassChoices( name="predictor", classes=dict( @@ -1331,3 +1345,378 @@ class ASRTaskAligner(ASRTaskParaformer): ) -> Tuple[str, ...]: retval = ("speech", "text") return retval + + +class ASRTransducerTask(AbsTask): + """ASR Transducer Task definition.""" + + num_optimizers: int = 1 + + class_choices_list = [ + frontend_choices, + specaug_choices, + normalize_choices, + encoder_choices, + rnnt_decoder_choices, + ] + + trainer = Trainer + + @classmethod + def add_task_arguments(cls, parser: argparse.ArgumentParser): + """Add Transducer task arguments. + Args: + cls: ASRTransducerTask object. + parser: Transducer arguments parser. + """ + group = parser.add_argument_group(description="Task related.") + + # required = parser.get_default("required") + # required += ["token_list"] + + group.add_argument( + "--token_list", + type=str_or_none, + default=None, + help="Integer-string mapper for tokens.", + ) + group.add_argument( + "--split_with_space", + type=str2bool, + default=True, + help="whether to split text using ", + ) + group.add_argument( + "--input_size", + type=int_or_none, + default=None, + help="The number of dimensions for input features.", + ) + group.add_argument( + "--init", + type=str_or_none, + default=None, + help="Type of model initialization to use.", + ) + group.add_argument( + "--model_conf", + action=NestedDictAction, + default=get_default_kwargs(TransducerModel), + help="The keyword arguments for the model class.", + ) + # group.add_argument( + # "--encoder_conf", + # action=NestedDictAction, + # default={}, + # help="The keyword arguments for the encoder class.", + # ) + group.add_argument( + "--joint_network_conf", + action=NestedDictAction, + default={}, + help="The keyword arguments for the joint network class.", + ) + group = parser.add_argument_group(description="Preprocess related.") + group.add_argument( + "--use_preprocessor", + type=str2bool, + default=True, + 
help="Whether to apply preprocessing to input data.", + ) + group.add_argument( + "--token_type", + type=str, + default="bpe", + choices=["bpe", "char", "word", "phn"], + help="The type of tokens to use during tokenization.", + ) + group.add_argument( + "--bpemodel", + type=str_or_none, + default=None, + help="The path of the sentencepiece model.", + ) + parser.add_argument( + "--non_linguistic_symbols", + type=str_or_none, + help="The 'non_linguistic_symbols' file path.", + ) + parser.add_argument( + "--cleaner", + type=str_or_none, + choices=[None, "tacotron", "jaconv", "vietnamese"], + default=None, + help="Text cleaner to use.", + ) + parser.add_argument( + "--g2p", + type=str_or_none, + choices=g2p_choices, + default=None, + help="g2p method to use if --token_type=phn.", + ) + parser.add_argument( + "--speech_volume_normalize", + type=float_or_none, + default=None, + help="Normalization value for maximum amplitude scaling.", + ) + parser.add_argument( + "--rir_scp", + type=str_or_none, + default=None, + help="The RIR SCP file path.", + ) + parser.add_argument( + "--rir_apply_prob", + type=float, + default=1.0, + help="The probability of the applied RIR convolution.", + ) + parser.add_argument( + "--noise_scp", + type=str_or_none, + default=None, + help="The path of noise SCP file.", + ) + parser.add_argument( + "--noise_apply_prob", + type=float, + default=1.0, + help="The probability of the applied noise addition.", + ) + parser.add_argument( + "--noise_db_range", + type=str, + default="13_15", + help="The range of the noise decibel level.", + ) + for class_choices in cls.class_choices_list: + # Append -- and --_conf. + # e.g. --decoder and --decoder_conf + class_choices.add_arguments(group) + + @classmethod + def build_collate_fn( + cls, args: argparse.Namespace, train: bool + ) -> Callable[ + [Collection[Tuple[str, Dict[str, np.ndarray]]]], + Tuple[List[str], Dict[str, torch.Tensor]], + ]: + """Build collate function. 
+ Args: + cls: ASRTransducerTask object. + args: Task arguments. + train: Training mode. + Return: + : Callable collate function. + """ + assert check_argument_types() + + return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) + + @classmethod + def build_preprocess_fn( + cls, args: argparse.Namespace, train: bool + ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: + """Build pre-processing function. + Args: + cls: ASRTransducerTask object. + args: Task arguments. + train: Training mode. + Return: + : Callable pre-processing function. + """ + assert check_argument_types() + + if args.use_preprocessor: + retval = CommonPreprocessor( + train=train, + token_type=args.token_type, + token_list=args.token_list, + bpemodel=args.bpemodel, + non_linguistic_symbols=args.non_linguistic_symbols, + text_cleaner=args.cleaner, + g2p_type=args.g2p, + split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False, + rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None, + rir_apply_prob=args.rir_apply_prob + if hasattr(args, "rir_apply_prob") + else 1.0, + noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None, + noise_apply_prob=args.noise_apply_prob + if hasattr(args, "noise_apply_prob") + else 1.0, + noise_db_range=args.noise_db_range + if hasattr(args, "noise_db_range") + else "13_15", + speech_volume_normalize=args.speech_volume_normalize + if hasattr(args, "rir_scp") + else None, + ) + else: + retval = None + + assert check_return_type(retval) + return retval + + @classmethod + def required_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + """Required data depending on task mode. + Args: + cls: ASRTransducerTask object. + train: Training mode. + inference: Inference mode. + Return: + retval: Required task data. 
+ """ + if not inference: + retval = ("speech", "text") + else: + retval = ("speech",) + + return retval + + @classmethod + def optional_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + """Optional data depending on task mode. + Args: + cls: ASRTransducerTask object. + train: Training mode. + inference: Inference mode. + Return: + retval: Optional task data. + """ + retval = () + assert check_return_type(retval) + + return retval + + @classmethod + def build_model(cls, args: argparse.Namespace) -> TransducerModel: + """Required data depending on task mode. + Args: + cls: ASRTransducerTask object. + args: Task arguments. + Return: + model: ASR Transducer model. + """ + assert check_argument_types() + + if isinstance(args.token_list, str): + with open(args.token_list, encoding="utf-8") as f: + token_list = [line.rstrip() for line in f] + + # Overwriting token_list to keep it as "portable". + args.token_list = list(token_list) + elif isinstance(args.token_list, (tuple, list)): + token_list = list(args.token_list) + else: + raise RuntimeError("token_list must be str or list") + vocab_size = len(token_list) + logging.info(f"Vocabulary size: {vocab_size }") + + # 1. frontend + if args.input_size is None: + # Extract features in the model + frontend_class = frontend_choices.get_class(args.frontend) + frontend = frontend_class(**args.frontend_conf) + input_size = frontend.output_size() + else: + # Give features from data-loader + frontend = None + input_size = args.input_size + + # 2. Data augmentation for spectrogram + if args.specaug is not None: + specaug_class = specaug_choices.get_class(args.specaug) + specaug = specaug_class(**args.specaug_conf) + else: + specaug = None + + # 3. Normalization layer + if args.normalize is not None: + normalize_class = normalize_choices.get_class(args.normalize) + normalize = normalize_class(**args.normalize_conf) + else: + normalize = None + + # 4. 
Encoder + + if getattr(args, "encoder", None) is not None: + encoder_class = encoder_choices.get_class(args.encoder) + encoder = encoder_class(input_size, **args.encoder_conf) + else: + encoder = Encoder(input_size, **args.encoder_conf) + encoder_output_size = encoder.output_size() + + # 5. Decoder + rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder) + decoder = rnnt_decoder_class( + vocab_size, + **args.rnnt_decoder_conf, + ) + decoder_output_size = decoder.output_size + + if getattr(args, "decoder", None) is not None: + att_decoder_class = decoder_choices.get_class(args.att_decoder) + + att_decoder = att_decoder_class( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + **args.decoder_conf, + ) + else: + att_decoder = None + # 6. Joint Network + joint_network = JointNetwork( + vocab_size, + encoder_output_size, + decoder_output_size, + **args.joint_network_conf, + ) + + # 7. Build model + + if encoder.unified_model_training: + model = UnifiedTransducerModel( + vocab_size=vocab_size, + token_list=token_list, + frontend=frontend, + specaug=specaug, + normalize=normalize, + encoder=encoder, + decoder=decoder, + att_decoder=att_decoder, + joint_network=joint_network, + **args.model_conf, + ) + + else: + model = TransducerModel( + vocab_size=vocab_size, + token_list=token_list, + frontend=frontend, + specaug=specaug, + normalize=normalize, + encoder=encoder, + decoder=decoder, + att_decoder=att_decoder, + joint_network=joint_network, + **args.model_conf, + ) + + # 8. 
Initialize model + if args.init is not None: + raise NotImplementedError( + "Currently not supported.", + "Initialization part will be reworked in a short future.", + ) + + #assert check_return_type(model) + + return model diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py deleted file mode 100644 index d4136d068..000000000 --- a/funasr/tasks/asr_transducer.py +++ /dev/null @@ -1,477 +0,0 @@ -"""ASR Transducer Task.""" - -import argparse -import logging -from typing import Callable, Collection, Dict, List, Optional, Tuple - -import numpy as np -import torch -from typeguard import check_argument_types, check_return_type - -from funasr.models.frontend.abs_frontend import AbsFrontend -from funasr.models.frontend.default import DefaultFrontend -from funasr.models.frontend.windowing import SlidingWindow -from funasr.models.specaug.abs_specaug import AbsSpecAug -from funasr.models.specaug.specaug import SpecAug -from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder -from funasr.models.decoder.transformer_decoder import ( - DynamicConvolution2DTransformerDecoder, - DynamicConvolutionTransformerDecoder, - LightweightConvolution2DTransformerDecoder, - LightweightConvolutionTransformerDecoder, - TransformerDecoder, -) -from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder -from funasr.models.rnnt_predictor.rnn_decoder import RNNDecoder -from funasr.models.rnnt_predictor.stateless_decoder import StatelessDecoder -from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder -from funasr.models.e2e_asr_transducer import TransducerModel -from funasr.models.e2e_asr_transducer_unified import UnifiedTransducerModel -from funasr.models.joint_net.joint_network import JointNetwork -from funasr.layers.abs_normalize import AbsNormalize -from funasr.layers.global_mvn import GlobalMVN -from funasr.layers.utterance_mvn import UtteranceMVN -from funasr.tasks.abs_task import AbsTask -from funasr.text.phoneme_tokenizer 
import g2p_choices -from funasr.train.class_choices import ClassChoices -from funasr.datasets.collate_fn import CommonCollateFn -from funasr.datasets.preprocessor import CommonPreprocessor -from funasr.train.trainer import Trainer -from funasr.utils.get_default_kwargs import get_default_kwargs -from funasr.utils.nested_dict_action import NestedDictAction -from funasr.utils.types import float_or_none, int_or_none, str2bool, str_or_none - -frontend_choices = ClassChoices( - name="frontend", - classes=dict( - default=DefaultFrontend, - sliding_window=SlidingWindow, - ), - type_check=AbsFrontend, - default="default", -) -specaug_choices = ClassChoices( - "specaug", - classes=dict( - specaug=SpecAug, - ), - type_check=AbsSpecAug, - default=None, - optional=True, -) -normalize_choices = ClassChoices( - "normalize", - classes=dict( - global_mvn=GlobalMVN, - utterance_mvn=UtteranceMVN, - ), - type_check=AbsNormalize, - default="utterance_mvn", - optional=True, -) -encoder_choices = ClassChoices( - "encoder", - classes=dict( - chunk_conformer=ConformerChunkEncoder, - ), - default="chunk_conformer", -) - -decoder_choices = ClassChoices( - "decoder", - classes=dict( - rnn=RNNDecoder, - stateless=StatelessDecoder, - ), - type_check=AbsDecoder, - default="rnn", -) - -att_decoder_choices = ClassChoices( - "att_decoder", - classes=dict( - transformer=TransformerDecoder, - lightweight_conv=LightweightConvolutionTransformerDecoder, - lightweight_conv2d=LightweightConvolution2DTransformerDecoder, - dynamic_conv=DynamicConvolutionTransformerDecoder, - dynamic_conv2d=DynamicConvolution2DTransformerDecoder, - ), - type_check=AbsAttDecoder, - default=None, - optional=True, -) -class ASRTransducerTask(AbsTask): - """ASR Transducer Task definition.""" - - num_optimizers: int = 1 - - class_choices_list = [ - frontend_choices, - specaug_choices, - normalize_choices, - encoder_choices, - decoder_choices, - att_decoder_choices, - ] - - trainer = Trainer - - @classmethod - def 
add_task_arguments(cls, parser: argparse.ArgumentParser): - """Add Transducer task arguments. - Args: - cls: ASRTransducerTask object. - parser: Transducer arguments parser. - """ - group = parser.add_argument_group(description="Task related.") - - # required = parser.get_default("required") - # required += ["token_list"] - - group.add_argument( - "--token_list", - type=str_or_none, - default=None, - help="Integer-string mapper for tokens.", - ) - group.add_argument( - "--split_with_space", - type=str2bool, - default=True, - help="whether to split text using ", - ) - group.add_argument( - "--input_size", - type=int_or_none, - default=None, - help="The number of dimensions for input features.", - ) - group.add_argument( - "--init", - type=str_or_none, - default=None, - help="Type of model initialization to use.", - ) - group.add_argument( - "--model_conf", - action=NestedDictAction, - default=get_default_kwargs(TransducerModel), - help="The keyword arguments for the model class.", - ) - # group.add_argument( - # "--encoder_conf", - # action=NestedDictAction, - # default={}, - # help="The keyword arguments for the encoder class.", - # ) - group.add_argument( - "--joint_network_conf", - action=NestedDictAction, - default={}, - help="The keyword arguments for the joint network class.", - ) - group = parser.add_argument_group(description="Preprocess related.") - group.add_argument( - "--use_preprocessor", - type=str2bool, - default=True, - help="Whether to apply preprocessing to input data.", - ) - group.add_argument( - "--token_type", - type=str, - default="bpe", - choices=["bpe", "char", "word", "phn"], - help="The type of tokens to use during tokenization.", - ) - group.add_argument( - "--bpemodel", - type=str_or_none, - default=None, - help="The path of the sentencepiece model.", - ) - parser.add_argument( - "--non_linguistic_symbols", - type=str_or_none, - help="The 'non_linguistic_symbols' file path.", - ) - parser.add_argument( - "--cleaner", - type=str_or_none, 
- choices=[None, "tacotron", "jaconv", "vietnamese"], - default=None, - help="Text cleaner to use.", - ) - parser.add_argument( - "--g2p", - type=str_or_none, - choices=g2p_choices, - default=None, - help="g2p method to use if --token_type=phn.", - ) - parser.add_argument( - "--speech_volume_normalize", - type=float_or_none, - default=None, - help="Normalization value for maximum amplitude scaling.", - ) - parser.add_argument( - "--rir_scp", - type=str_or_none, - default=None, - help="The RIR SCP file path.", - ) - parser.add_argument( - "--rir_apply_prob", - type=float, - default=1.0, - help="The probability of the applied RIR convolution.", - ) - parser.add_argument( - "--noise_scp", - type=str_or_none, - default=None, - help="The path of noise SCP file.", - ) - parser.add_argument( - "--noise_apply_prob", - type=float, - default=1.0, - help="The probability of the applied noise addition.", - ) - parser.add_argument( - "--noise_db_range", - type=str, - default="13_15", - help="The range of the noise decibel level.", - ) - for class_choices in cls.class_choices_list: - # Append -- and --_conf. - # e.g. --decoder and --decoder_conf - class_choices.add_arguments(group) - - @classmethod - def build_collate_fn( - cls, args: argparse.Namespace, train: bool - ) -> Callable[ - [Collection[Tuple[str, Dict[str, np.ndarray]]]], - Tuple[List[str], Dict[str, torch.Tensor]], - ]: - """Build collate function. - Args: - cls: ASRTransducerTask object. - args: Task arguments. - train: Training mode. - Return: - : Callable collate function. - """ - assert check_argument_types() - - return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) - - @classmethod - def build_preprocess_fn( - cls, args: argparse.Namespace, train: bool - ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: - """Build pre-processing function. - Args: - cls: ASRTransducerTask object. - args: Task arguments. - train: Training mode. - Return: - : Callable pre-processing function. 
- """ - assert check_argument_types() - - if args.use_preprocessor: - retval = CommonPreprocessor( - train=train, - token_type=args.token_type, - token_list=args.token_list, - bpemodel=args.bpemodel, - non_linguistic_symbols=args.non_linguistic_symbols, - text_cleaner=args.cleaner, - g2p_type=args.g2p, - split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False, - rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None, - rir_apply_prob=args.rir_apply_prob - if hasattr(args, "rir_apply_prob") - else 1.0, - noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None, - noise_apply_prob=args.noise_apply_prob - if hasattr(args, "noise_apply_prob") - else 1.0, - noise_db_range=args.noise_db_range - if hasattr(args, "noise_db_range") - else "13_15", - speech_volume_normalize=args.speech_volume_normalize - if hasattr(args, "rir_scp") - else None, - ) - else: - retval = None - - assert check_return_type(retval) - return retval - - @classmethod - def required_data_names( - cls, train: bool = True, inference: bool = False - ) -> Tuple[str, ...]: - """Required data depending on task mode. - Args: - cls: ASRTransducerTask object. - train: Training mode. - inference: Inference mode. - Return: - retval: Required task data. - """ - if not inference: - retval = ("speech", "text") - else: - retval = ("speech",) - - return retval - - @classmethod - def optional_data_names( - cls, train: bool = True, inference: bool = False - ) -> Tuple[str, ...]: - """Optional data depending on task mode. - Args: - cls: ASRTransducerTask object. - train: Training mode. - inference: Inference mode. - Return: - retval: Optional task data. - """ - retval = () - assert check_return_type(retval) - - return retval - - @classmethod - def build_model(cls, args: argparse.Namespace) -> TransducerModel: - """Required data depending on task mode. - Args: - cls: ASRTransducerTask object. - args: Task arguments. - Return: - model: ASR Transducer model. 
- """ - assert check_argument_types() - - if isinstance(args.token_list, str): - with open(args.token_list, encoding="utf-8") as f: - token_list = [line.rstrip() for line in f] - - # Overwriting token_list to keep it as "portable". - args.token_list = list(token_list) - elif isinstance(args.token_list, (tuple, list)): - token_list = list(args.token_list) - else: - raise RuntimeError("token_list must be str or list") - vocab_size = len(token_list) - logging.info(f"Vocabulary size: {vocab_size }") - - # 1. frontend - if args.input_size is None: - # Extract features in the model - frontend_class = frontend_choices.get_class(args.frontend) - frontend = frontend_class(**args.frontend_conf) - input_size = frontend.output_size() - else: - # Give features from data-loader - frontend = None - input_size = args.input_size - - # 2. Data augmentation for spectrogram - if args.specaug is not None: - specaug_class = specaug_choices.get_class(args.specaug) - specaug = specaug_class(**args.specaug_conf) - else: - specaug = None - - # 3. Normalization layer - if args.normalize is not None: - normalize_class = normalize_choices.get_class(args.normalize) - normalize = normalize_class(**args.normalize_conf) - else: - normalize = None - - # 4. Encoder - - if getattr(args, "encoder", None) is not None: - encoder_class = encoder_choices.get_class(args.encoder) - encoder = encoder_class(input_size, **args.encoder_conf) - else: - encoder = Encoder(input_size, **args.encoder_conf) - encoder_output_size = encoder.output_size - - # 5. 
Decoder - decoder_class = decoder_choices.get_class(args.decoder) - decoder = decoder_class( - vocab_size, - **args.decoder_conf, - ) - decoder_output_size = decoder.output_size - - if getattr(args, "att_decoder", None) is not None: - att_decoder_class = att_decoder_choices.get_class(args.att_decoder) - - att_decoder = att_decoder_class( - vocab_size=vocab_size, - encoder_output_size=encoder_output_size, - **args.att_decoder_conf, - ) - else: - att_decoder = None - - # 6. Joint Network - joint_network = JointNetwork( - vocab_size, - encoder_output_size, - decoder_output_size, - **args.joint_network_conf, - ) - - # 7. Build model - - if encoder.unified_model_training: - model = UnifiedTransducerModel( - vocab_size=vocab_size, - token_list=token_list, - frontend=frontend, - specaug=specaug, - normalize=normalize, - encoder=encoder, - decoder=decoder, - att_decoder=att_decoder, - joint_network=joint_network, - **args.model_conf, - ) - - else: - model = TransducerModel( - vocab_size=vocab_size, - token_list=token_list, - frontend=frontend, - specaug=specaug, - normalize=normalize, - encoder=encoder, - decoder=decoder, - att_decoder=att_decoder, - joint_network=joint_network, - **args.model_conf, - ) - - # 8. Initialize model - if args.init is not None: - raise NotImplementedError( - "Currently not supported.", - "Initialization part will be reworked in a short future.", - ) - - #assert check_return_type(model) - - return model