diff --git a/funasr/bin/vad_infer.py b/funasr/bin/vad_infer.py index e1698d03d..a511239c4 100644 --- a/funasr/bin/vad_infer.py +++ b/funasr/bin/vad_infer.py @@ -1,42 +1,23 @@ -# -*- encoding: utf-8 -*- #!/usr/bin/env python3 +# -*- encoding: utf-8 -*- # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. # MIT License (https://opensource.org/licenses/MIT) -import argparse import logging -import os -import sys -import json +import math from pathlib import Path -from typing import Any +from typing import Dict from typing import List -from typing import Optional -from typing import Sequence from typing import Tuple from typing import Union -from typing import Dict -import math import numpy as np import torch from typeguard import check_argument_types -from typeguard import check_return_type -from funasr.fileio.datadir_writer import DatadirWriter -from funasr.modules.scorers.scorer_interface import BatchScorerInterface -from funasr.modules.subsampling import TooShortUttError -from funasr.tasks.vad import VADTask -from funasr.torch_utils.device_funcs import to_device -from funasr.torch_utils.set_all_random_seed import set_all_random_seed -from funasr.utils import config_argparse -from funasr.utils.cli_utils import get_commandline_args -from funasr.utils.types import str2bool -from funasr.utils.types import str2triple_str -from funasr.utils.types import str_or_none -from funasr.utils import asr_utils, wav_utils, postprocess_utils +from funasr.build_utils.build_model_from_file import build_model_from_file from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline - +from funasr.torch_utils.device_funcs import to_device class Speech2VadSegment: @@ -64,8 +45,8 @@ class Speech2VadSegment: assert check_argument_types() # 1. Build vad model - vad_model, vad_infer_args = VADTask.build_model_from_file( - vad_infer_config, vad_model_file, device + vad_model, vad_infer_args = build_model_from_file( + vad_infer_config, vad_model_file, "None", device, task_name="vad" ) frontend = None if vad_infer_args.frontend is not None: @@ -128,13 +109,14 @@ class Speech2VadSegment: "in_cache": in_cache } # a. To device - #batch = to_device(batch, device=self.device) + # batch = to_device(batch, device=self.device) segments_part, in_cache = self.vad_model(**batch) if segments_part: for batch_num in range(0, self.batch_size): segments[batch_num] += segments_part[batch_num] return fbanks, segments + class Speech2VadSegmentOnline(Speech2VadSegment): """Speech2VadSegmentOnline class @@ -146,6 +128,7 @@ class Speech2VadSegmentOnline(Speech2VadSegment): [[10, 230], [245, 450], ...] """ + def __init__(self, **kwargs): super(Speech2VadSegmentOnline, self).__init__(**kwargs) vad_cmvn_file = kwargs.get('vad_cmvn_file', None) @@ -153,7 +136,6 @@ class Speech2VadSegmentOnline(Speech2VadSegment): if self.vad_infer_args.frontend is not None: self.frontend = WavFrontendOnline(cmvn_file=vad_cmvn_file, **self.vad_infer_args.frontend_conf) - @torch.no_grad() def __call__( self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None, @@ -198,5 +180,3 @@ class Speech2VadSegmentOnline(Speech2VadSegment): # in_cache.update(batch['in_cache']) # in_cache = {key: value for key, value in batch['in_cache'].items()} return fbanks, segments, in_cache - - diff --git a/funasr/bin/vad_inference_launch.py b/funasr/bin/vad_inference_launch.py index b17d05863..829f157f4 100644 --- a/funasr/bin/vad_inference_launch.py +++ b/funasr/bin/vad_inference_launch.py @@ -1,59 +1,35 @@ -# -*- encoding: utf-8 -*- #!/usr/bin/env python3 +# -*- encoding: utf-8 -*- # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. # MIT License (https://opensource.org/licenses/MIT) import torch + torch.set_num_threads(1) -import argparse -import logging -import os -import sys -from typing import Union, Dict, Any - -from funasr.utils import config_argparse -from funasr.utils.cli_utils import get_commandline_args -from funasr.utils.types import str2bool -from funasr.utils.types import str2triple_str -from funasr.utils.types import str_or_none - import argparse import logging import os import sys import json -from pathlib import Path -from typing import Any -from typing import List from typing import Optional -from typing import Sequence -from typing import Tuple from typing import Union -from typing import Dict -import math import numpy as np import torch from typeguard import check_argument_types -from typeguard import check_return_type - +from funasr.build_utils.build_streaming_iterator import build_streaming_iterator from funasr.fileio.datadir_writer import DatadirWriter -from funasr.modules.scorers.scorer_interface import BatchScorerInterface -from funasr.modules.subsampling import TooShortUttError -from funasr.tasks.vad import VADTask -from funasr.torch_utils.device_funcs import to_device from funasr.torch_utils.set_all_random_seed import set_all_random_seed from funasr.utils import config_argparse from funasr.utils.cli_utils import get_commandline_args from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none -from funasr.utils import asr_utils, wav_utils, postprocess_utils -from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline from funasr.bin.vad_infer import Speech2VadSegment, Speech2VadSegmentOnline + def inference_vad( batch_size: int, ngpu: int, @@ -75,7 +51,6 @@ def inference_vad( if batch_size > 1: raise NotImplementedError("batch decoding is not implemented") - logging.basicConfig( level=log_level, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", @@ -112,16 +87,14 @@ def inference_vad( if isinstance(raw_inputs, torch.Tensor): raw_inputs = raw_inputs.numpy() data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] - loader = VADTask.build_streaming_iterator( - data_path_and_name_and_type, + loader = build_streaming_iterator( + task_name="vad", + preprocess_args=None, + data_path_and_name_and_type=data_path_and_name_and_type, dtype=dtype, batch_size=batch_size, key_file=key_file, num_workers=num_workers, - preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False), - collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False), - allow_variable_data_keys=allow_variable_data_keys, - inference=True, ) finish_count = 0 @@ -157,6 +130,7 @@ def inference_vad( return _forward + def inference_vad_online( batch_size: int, ngpu: int, @@ -176,7 +150,6 @@ def inference_vad_online( ): assert check_argument_types() - logging.basicConfig( level=log_level, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", @@ -214,16 +187,14 @@ def inference_vad_online( if isinstance(raw_inputs, torch.Tensor): raw_inputs = raw_inputs.numpy() data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] - loader = VADTask.build_streaming_iterator( - data_path_and_name_and_type, + loader = build_streaming_iterator( + task_name="vad", + preprocess_args=None, + data_path_and_name_and_type=data_path_and_name_and_type, dtype=dtype, batch_size=batch_size, key_file=key_file, num_workers=num_workers, - preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False), - collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False), - allow_variable_data_keys=allow_variable_data_keys, - inference=True, ) finish_count = 0 @@ -273,8 +244,6 @@ def inference_vad_online( return _forward - - def inference_launch(mode, **kwargs): if mode == "offline": return inference_vad(**kwargs) @@ -284,6 +253,7 @@ def inference_launch(mode, **kwargs): logging.info("Unknown decoding mode: {}".format(mode)) return None + def get_parser(): parser = config_argparse.ArgumentParser( description="VAD Decoding", @@ -405,5 +375,6 @@ def main(cmd=None): inference_pipeline = inference_launch(**kwargs) return inference_pipeline(kwargs["data_path_and_name_and_type"]) + if __name__ == "__main__": main() diff --git a/funasr/build_utils/build_streaming_iterator.py b/funasr/build_utils/build_streaming_iterator.py index 8c5f7fca7..1b16cf4ec 100644 --- a/funasr/build_utils/build_streaming_iterator.py +++ b/funasr/build_utils/build_streaming_iterator.py @@ -5,7 +5,7 @@ from typeguard import check_argument_types from funasr.datasets.iterable_dataset import IterableESPnetDataset from funasr.datasets.small_datasets.collate_fn import CommonCollateFn from funasr.datasets.small_datasets.preprocessor import build_preprocess -from funasr.build_utils.build_model_from_file import build_model_from_file + def build_streaming_iterator( task_name, @@ -20,7 +20,7 @@ def build_streaming_iterator( use_collate_fn: bool = True, preprocess_fn=None, ngpu: int = 0, - train: bool=False, + train: bool = False, ) -> DataLoader: """Build DataLoader using iterable dataset""" assert check_argument_types()