From 96bae0153cb04c82d6e7ca7cb9654d55eb987567 Mon Sep 17 00:00:00 2001 From: aky15 Date: Wed, 15 Mar 2023 17:34:34 +0800 Subject: [PATCH] rnnt bug fix --- funasr/bin/asr_inference_rnnt.py | 145 ++---------------- .../encoder/blocks/conv_input.py | 9 +- funasr/tasks/abs_task.py | 2 +- 3 files changed, 20 insertions(+), 136 deletions(-) diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py index f651f118d..c8a2916c2 100644 --- a/funasr/bin/asr_inference_rnnt.py +++ b/funasr/bin/asr_inference_rnnt.py @@ -31,7 +31,7 @@ from funasr.torch_utils.set_all_random_seed import set_all_random_seed from funasr.utils import config_argparse from funasr.utils.types import str2bool, str2triple_str, str_or_none from funasr.utils.cli_utils import get_commandline_args - +from funasr.models.frontend.wav_frontend import WavFrontend class Speech2Text: """Speech2Text class for Transducer models. @@ -62,6 +62,7 @@ class Speech2Text: self, asr_train_config: Union[Path, str] = None, asr_model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, beam_search_config: Dict[str, Any] = None, lm_train_config: Union[Path, str] = None, lm_file: Union[Path, str] = None, @@ -86,11 +87,14 @@ class Speech2Text: super().__init__() assert check_argument_types() - asr_model, asr_train_args = ASRTransducerTask.build_model_from_file( - asr_train_config, asr_model_file, device + asr_train_config, asr_model_file, cmvn_file, device ) + frontend = None + if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: + frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) + if quantize_asr_model: if quantize_modules is not None: if not all([q in ["LSTM", "Linear"] for q in quantize_modules]): @@ -156,7 +160,7 @@ class Speech2Text: tokenizer = build_tokenizer(token_type=token_type) converter = TokenIDConverter(token_list=token_list) logging.info(f"Text tokenizer: {tokenizer}") - + self.asr_model = asr_model self.asr_train_args = asr_train_args self.device = device @@ -181,23 +185,13 @@ class Speech2Text: self.simu_streaming = False self.asr_model.encoder.dynamic_chunk_training = False - self.n_fft = asr_train_args.frontend_conf.get("n_fft", 512) - self.hop_length = asr_train_args.frontend_conf.get("hop_length", 128) - - if asr_train_args.frontend_conf.get("win_length", None) is not None: - self.frontend_window_size = asr_train_args.frontend_conf["win_length"] - else: - self.frontend_window_size = self.n_fft - + self.frontend = frontend self.window_size = self.chunk_size + self.right_context - self._raw_ctx = self.asr_model.encoder.get_encoder_input_raw_size( - self.window_size, self.hop_length - ) + self._ctx = self.asr_model.encoder.get_encoder_input_size( self.window_size ) - #self.last_chunk_length = ( # self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 #) * self.hop_length @@ -218,112 +212,6 @@ class Speech2Text: self.num_processed_frames = torch.tensor([[0]], device=self.device) - def apply_frontend( - self, speech: torch.Tensor, is_final: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Forward frontend. - Args: - speech: Speech data. (S) - is_final: Whether speech corresponds to the final (or only) chunk of data. - Returns: - feats: Features sequence. (1, T_in, F) - feats_lengths: Features sequence length. (1, T_in, F) - """ - if self.frontend_cache is not None: - speech = torch.cat([self.frontend_cache["waveform_buffer"], speech], dim=0) - - if is_final: - if self.streaming and speech.size(0) < self.last_chunk_length: - pad = torch.zeros( - self.last_chunk_length - speech.size(0), dtype=speech.dtype - ) - speech = torch.cat([speech, pad], dim=0) - - speech_to_process = speech - waveform_buffer = None - else: - n_frames = ( - speech.size(0) - (self.frontend_window_size - self.hop_length) - ) // self.hop_length - - n_residual = ( - speech.size(0) - (self.frontend_window_size - self.hop_length) - ) % self.hop_length - - speech_to_process = speech.narrow( - 0, - 0, - (self.frontend_window_size - self.hop_length) - + n_frames * self.hop_length, - ) - - waveform_buffer = speech.narrow( - 0, - speech.size(0) - - (self.frontend_window_size - self.hop_length) - - n_residual, - (self.frontend_window_size - self.hop_length) + n_residual, - ).clone() - - speech_to_process = speech_to_process.unsqueeze(0).to( - getattr(torch, self.dtype) - ) - lengths = speech_to_process.new_full( - [1], dtype=torch.long, fill_value=speech_to_process.size(1) - ) - batch = {"speech": speech_to_process, "speech_lengths": lengths} - batch = to_device(batch, device=self.device) - - feats, feats_lengths = self.asr_model._extract_feats(**batch) - if self.asr_model.normalize is not None: - feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths) - - if is_final: - if self.frontend_cache is None: - pass - else: - feats = feats.narrow( - 1, - math.ceil( - math.ceil(self.frontend_window_size / self.hop_length) / 2 - ), - feats.size(1) - - math.ceil( - math.ceil(self.frontend_window_size / self.hop_length) / 2 - ), - ) - else: - if self.frontend_cache is None: - feats = feats.narrow( - 1, - 0, - feats.size(1) - - math.ceil( - math.ceil(self.frontend_window_size / self.hop_length) / 2 - ), - ) - else: - feats = feats.narrow( - 1, - math.ceil( - math.ceil(self.frontend_window_size / self.hop_length) / 2 - ), - feats.size(1) - - 2 - * math.ceil( - math.ceil(self.frontend_window_size / self.hop_length) / 2 - ), - ) - - feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) - - if is_final: - self.frontend_cache = None - else: - self.frontend_cache = {"waveform_buffer": waveform_buffer} - - return feats, feats_lengths - @torch.no_grad() def streaming_decode( self, @@ -410,14 +298,9 @@ class Speech2Text: if isinstance(speech, np.ndarray): speech = torch.tensor(speech) - # lengths: (1,) - # feats, feats_length = self.apply_frontend(speech) feats = speech.unsqueeze(0).to(getattr(torch, self.dtype)) - # lengths: (1,) feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) - # print(feats.shape) - # print(feats_lengths) if self.asr_model.normalize is not None: feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths) @@ -495,6 +378,7 @@ def inference( data_path_and_name_and_type: Sequence[Tuple[str, str, str]], asr_train_config: Optional[str], asr_model_file: Optional[str], + cmvn_file: Optional[str], beam_search_config: Optional[dict], lm_train_config: Optional[str], lm_file: Optional[str], @@ -562,7 +446,6 @@ def inference( device = "cuda" else: device = "cpu" - # 1. Set random-seed set_all_random_seed(seed) @@ -570,6 +453,7 @@ def inference( speech2text_kwargs = dict( asr_train_config=asr_train_config, asr_model_file=asr_model_file, + cmvn_file=cmvn_file, beam_search_config=beam_search_config, lm_train_config=lm_train_config, lm_file=lm_file, @@ -719,6 +603,11 @@ def get_parser(): type=str, help="ASR model parameter file", ) + group.add_argument( + "--cmvn_file", + type=str, + help="Global cmvn file", + ) group.add_argument( "--lm_train_config", type=str, diff --git a/funasr/models_transducer/encoder/blocks/conv_input.py b/funasr/models_transducer/encoder/blocks/conv_input.py index 931d0f0eb..c68c73b3d 100644 --- a/funasr/models_transducer/encoder/blocks/conv_input.py +++ b/funasr/models_transducer/encoder/blocks/conv_input.py @@ -120,7 +120,7 @@ class ConvInput(torch.nn.Module): self.create_new_mask = self.create_new_conv2d_mask self.vgg_like = vgg_like - self.min_frame_length = 2 + self.min_frame_length = 7 if output_size is not None: self.output = torch.nn.Linear(output_proj, output_size) @@ -218,9 +218,4 @@ class ConvInput(torch.nn.Module): : Number of frames before subsampling. """ - if self.subsampling_factor > 1: - if self.vgg_like: - return ((size * 2) * self.stride_1) + 1 - - return ((size + 2) * 2) + (self.kernel_2 - 1) * self.stride_2 - return size + return size * self.subsampling_factor diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py index e0884cef6..cc5b70886 100644 --- a/funasr/tasks/abs_task.py +++ b/funasr/tasks/abs_task.py @@ -1576,7 +1576,7 @@ class AbsTask(ABC): preprocess=iter_options.preprocess_fn, max_cache_size=iter_options.max_cache_size, max_cache_fd=iter_options.max_cache_fd, - dest_sample_rate=args.frontend_conf["fs"], + dest_sample_rate=args.frontend_conf["fs"] if args.frontend_conf else 16000, ) cls.check_task_requirements( dataset, args.allow_variable_data_keys, train=iter_options.train