mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update repo
This commit is contained in:
parent
04528cb292
commit
eee72548d7
@ -1,42 +1,23 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
#!/usr/bin/env python3
|
||||
# -*- encoding: utf-8 -*-
|
||||
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
from typing import Dict
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
from typeguard import check_return_type
|
||||
|
||||
from funasr.fileio.datadir_writer import DatadirWriter
|
||||
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
|
||||
from funasr.modules.subsampling import TooShortUttError
|
||||
from funasr.tasks.vad import VADTask
|
||||
from funasr.torch_utils.device_funcs import to_device
|
||||
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
|
||||
from funasr.utils import config_argparse
|
||||
from funasr.utils.cli_utils import get_commandline_args
|
||||
from funasr.utils.types import str2bool
|
||||
from funasr.utils.types import str2triple_str
|
||||
from funasr.utils.types import str_or_none
|
||||
from funasr.utils import asr_utils, wav_utils, postprocess_utils
|
||||
from funasr.build_utils.build_model_from_file import build_model_from_file
|
||||
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
|
||||
|
||||
from funasr.torch_utils.device_funcs import to_device
|
||||
|
||||
|
||||
class Speech2VadSegment:
|
||||
@ -64,8 +45,8 @@ class Speech2VadSegment:
|
||||
assert check_argument_types()
|
||||
|
||||
# 1. Build vad model
|
||||
vad_model, vad_infer_args = VADTask.build_model_from_file(
|
||||
vad_infer_config, vad_model_file, device
|
||||
vad_model, vad_infer_args = build_model_from_file(
|
||||
vad_infer_config, vad_model_file, "None", device, task_name="vad"
|
||||
)
|
||||
frontend = None
|
||||
if vad_infer_args.frontend is not None:
|
||||
@ -128,13 +109,14 @@ class Speech2VadSegment:
|
||||
"in_cache": in_cache
|
||||
}
|
||||
# a. To device
|
||||
#batch = to_device(batch, device=self.device)
|
||||
# batch = to_device(batch, device=self.device)
|
||||
segments_part, in_cache = self.vad_model(**batch)
|
||||
if segments_part:
|
||||
for batch_num in range(0, self.batch_size):
|
||||
segments[batch_num] += segments_part[batch_num]
|
||||
return fbanks, segments
|
||||
|
||||
|
||||
class Speech2VadSegmentOnline(Speech2VadSegment):
|
||||
"""Speech2VadSegmentOnline class
|
||||
|
||||
@ -146,6 +128,7 @@ class Speech2VadSegmentOnline(Speech2VadSegment):
|
||||
[[10, 230], [245, 450], ...]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(Speech2VadSegmentOnline, self).__init__(**kwargs)
|
||||
vad_cmvn_file = kwargs.get('vad_cmvn_file', None)
|
||||
@ -153,7 +136,6 @@ class Speech2VadSegmentOnline(Speech2VadSegment):
|
||||
if self.vad_infer_args.frontend is not None:
|
||||
self.frontend = WavFrontendOnline(cmvn_file=vad_cmvn_file, **self.vad_infer_args.frontend_conf)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
|
||||
@ -198,5 +180,3 @@ class Speech2VadSegmentOnline(Speech2VadSegment):
|
||||
# in_cache.update(batch['in_cache'])
|
||||
# in_cache = {key: value for key, value in batch['in_cache'].items()}
|
||||
return fbanks, segments, in_cache
|
||||
|
||||
|
||||
|
||||
@ -1,59 +1,35 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
#!/usr/bin/env python3
|
||||
# -*- encoding: utf-8 -*-
|
||||
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
torch.set_num_threads(1)
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Union, Dict, Any
|
||||
|
||||
from funasr.utils import config_argparse
|
||||
from funasr.utils.cli_utils import get_commandline_args
|
||||
from funasr.utils.types import str2bool
|
||||
from funasr.utils.types import str2triple_str
|
||||
from funasr.utils.types import str_or_none
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
from typing import Dict
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
from typeguard import check_return_type
|
||||
|
||||
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
|
||||
from funasr.fileio.datadir_writer import DatadirWriter
|
||||
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
|
||||
from funasr.modules.subsampling import TooShortUttError
|
||||
from funasr.tasks.vad import VADTask
|
||||
from funasr.torch_utils.device_funcs import to_device
|
||||
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
|
||||
from funasr.utils import config_argparse
|
||||
from funasr.utils.cli_utils import get_commandline_args
|
||||
from funasr.utils.types import str2bool
|
||||
from funasr.utils.types import str2triple_str
|
||||
from funasr.utils.types import str_or_none
|
||||
from funasr.utils import asr_utils, wav_utils, postprocess_utils
|
||||
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
|
||||
from funasr.bin.vad_infer import Speech2VadSegment, Speech2VadSegmentOnline
|
||||
|
||||
|
||||
def inference_vad(
|
||||
batch_size: int,
|
||||
ngpu: int,
|
||||
@ -75,7 +51,6 @@ def inference_vad(
|
||||
if batch_size > 1:
|
||||
raise NotImplementedError("batch decoding is not implemented")
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||
@ -112,16 +87,14 @@ def inference_vad(
|
||||
if isinstance(raw_inputs, torch.Tensor):
|
||||
raw_inputs = raw_inputs.numpy()
|
||||
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
|
||||
loader = VADTask.build_streaming_iterator(
|
||||
data_path_and_name_and_type,
|
||||
loader = build_streaming_iterator(
|
||||
task_name="vad",
|
||||
preprocess_args=None,
|
||||
data_path_and_name_and_type=data_path_and_name_and_type,
|
||||
dtype=dtype,
|
||||
batch_size=batch_size,
|
||||
key_file=key_file,
|
||||
num_workers=num_workers,
|
||||
preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
|
||||
collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
|
||||
allow_variable_data_keys=allow_variable_data_keys,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
finish_count = 0
|
||||
@ -157,6 +130,7 @@ def inference_vad(
|
||||
|
||||
return _forward
|
||||
|
||||
|
||||
def inference_vad_online(
|
||||
batch_size: int,
|
||||
ngpu: int,
|
||||
@ -176,7 +150,6 @@ def inference_vad_online(
|
||||
):
|
||||
assert check_argument_types()
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||
@ -214,16 +187,14 @@ def inference_vad_online(
|
||||
if isinstance(raw_inputs, torch.Tensor):
|
||||
raw_inputs = raw_inputs.numpy()
|
||||
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
|
||||
loader = VADTask.build_streaming_iterator(
|
||||
data_path_and_name_and_type,
|
||||
loader = build_streaming_iterator(
|
||||
task_name="vad",
|
||||
preprocess_args=None,
|
||||
data_path_and_name_and_type=data_path_and_name_and_type,
|
||||
dtype=dtype,
|
||||
batch_size=batch_size,
|
||||
key_file=key_file,
|
||||
num_workers=num_workers,
|
||||
preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
|
||||
collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
|
||||
allow_variable_data_keys=allow_variable_data_keys,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
finish_count = 0
|
||||
@ -273,8 +244,6 @@ def inference_vad_online(
|
||||
return _forward
|
||||
|
||||
|
||||
|
||||
|
||||
def inference_launch(mode, **kwargs):
|
||||
if mode == "offline":
|
||||
return inference_vad(**kwargs)
|
||||
@ -284,6 +253,7 @@ def inference_launch(mode, **kwargs):
|
||||
logging.info("Unknown decoding mode: {}".format(mode))
|
||||
return None
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = config_argparse.ArgumentParser(
|
||||
description="VAD Decoding",
|
||||
@ -405,5 +375,6 @@ def main(cmd=None):
|
||||
inference_pipeline = inference_launch(**kwargs)
|
||||
return inference_pipeline(kwargs["data_path_and_name_and_type"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -5,7 +5,7 @@ from typeguard import check_argument_types
|
||||
from funasr.datasets.iterable_dataset import IterableESPnetDataset
|
||||
from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
|
||||
from funasr.datasets.small_datasets.preprocessor import build_preprocess
|
||||
from funasr.build_utils.build_model_from_file import build_model_from_file
|
||||
|
||||
|
||||
def build_streaming_iterator(
|
||||
task_name,
|
||||
@ -20,7 +20,7 @@ def build_streaming_iterator(
|
||||
use_collate_fn: bool = True,
|
||||
preprocess_fn=None,
|
||||
ngpu: int = 0,
|
||||
train: bool=False,
|
||||
train: bool = False,
|
||||
) -> DataLoader:
|
||||
"""Build DataLoader using iterable dataset"""
|
||||
assert check_argument_types()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user