From 78c78c39a90c62b7c552019043a970e9f85bf378 Mon Sep 17 00:00:00 2001 From: "shixian.shi" Date: Tue, 10 Oct 2023 17:11:15 +0800 Subject: [PATCH] big fix for speaker pipeline --- .../finetune.py | 35 +++++++++++++++++++ .../infer.py | 27 ++++++++++++++ funasr/bin/asr_inference_launch.py | 8 +++-- funasr/utils/modelscope_utils.py | 26 +++++++++++++- 4 files changed, 93 insertions(+), 3 deletions(-) create mode 100644 egs_modelscope/asr_vad_spk/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/finetune.py create mode 100644 egs_modelscope/asr_vad_spk/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/infer.py diff --git a/egs_modelscope/asr_vad_spk/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/finetune.py b/egs_modelscope/asr_vad_spk/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/finetune.py new file mode 100644 index 000000000..52d2b9c41 --- /dev/null +++ b/egs_modelscope/asr_vad_spk/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/finetune.py @@ -0,0 +1,35 @@ +import os +from modelscope.metainfo import Trainers +from modelscope.trainers import build_trainer +from funasr.datasets.ms_dataset import MsDataset + + +def modelscope_finetune(params): + if not os.path.exists(params.output_dir): + os.makedirs(params.output_dir, exist_ok=True) + # dataset split ["train", "validation"] + ds_dict = MsDataset.load(params.data_path) + kwargs = dict( + model=params.model, + model_revision=params.model_revision, + data_dir=ds_dict, + dataset_type=params.dataset_type, + work_dir=params.output_dir, + batch_bins=params.batch_bins, + max_epoch=params.max_epoch, + lr=params.lr) + trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs) + trainer.train() + + +if __name__ == '__main__': + from funasr.utils.modelscope_param import modelscope_args + params = modelscope_args(model="damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn", data_path="./data") + params.output_dir = "./checkpoint" # m模型保存路径 + params.data_path = "./example_data/" # 数据路径 + params.dataset_type = "small" # 小数据量设置small,若数据量大于1000小时,请使用large + params.batch_bins = 2000 # batch size,如果dataset_type="small",batch_bins单位为fbank特征帧数,如果dataset_type="large",batch_bins单位为毫秒, + params.max_epoch = 50 # 最大训练轮数 + params.lr = 0.00005 # 设置学习率 + params.model_revision = "v1.2.1" + modelscope_finetune(params) diff --git a/egs_modelscope/asr_vad_spk/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/infer.py b/egs_modelscope/asr_vad_spk/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/infer.py new file mode 100644 index 000000000..4783ec1b0 --- /dev/null +++ b/egs_modelscope/asr_vad_spk/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/infer.py @@ -0,0 +1,27 @@ +import os +import shutil +import argparse +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks + +def modelscope_infer(args): + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpuid) + inference_pipeline = pipeline( + task=Tasks.auto_speech_recognition, + model=args.model, + output_dir=args.output_dir, + param_dict={"decoding_model": args.decoding_mode, "hotword": args.hotword_txt} + ) + inference_pipeline(audio_in=args.audio_in, batch_size_token=args.batch_size_token) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--model', type=str, default="damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn") + parser.add_argument('--audio_in', type=str, default="./data/test/wav.scp") + parser.add_argument('--output_dir', type=str, default="./results/") + parser.add_argument('--decoding_mode', type=str, default="normal") + parser.add_argument('--hotword_txt', type=str, default=None) + parser.add_argument('--batch_size_token', type=int, default=5000) + parser.add_argument('--gpuid', type=str, default="0") + args = parser.parse_args() + modelscope_infer(args) diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index 128877737..c728d7212 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -55,6 +55,7 @@ from funasr.utils.speaker_utils import (check_audio_list, distribute_spk) from funasr.build_utils.build_model_from_file import build_model_from_file from funasr.utils.cluster_backend import ClusterBackend +from funasr.utils.modelscope_utils import get_cache_dir from tqdm import tqdm def inference_asr( @@ -791,7 +792,7 @@ def inference_paraformer_vad_speaker( time_stamp_writer: bool = True, punc_infer_config: Optional[str] = None, punc_model_file: Optional[str] = None, - sv_model_file: Optional[str] = "~/.cache/modelscope/hub/damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/campplus_cn_common.bin", + sv_model_file: Optional[str] = None, streaming: bool = False, embedding_node: str = "resnet1_dense", sv_threshold: float = 0.9465, @@ -813,6 +814,9 @@ def inference_paraformer_vad_speaker( format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) + if sv_model_file is None: + sv_model_file = "{}/damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/campplus_cn_common.bin".format(get_cache_dir(None)) + if param_dict is not None: hotword_list_or_file = param_dict.get('hotword') else: @@ -937,7 +941,7 @@ def inference_paraformer_vad_speaker( ##### speaker_verification ##### ################################## # load sv model - sv_model_dict = torch.load(sv_model_file.replace("~", os.environ['HOME']), map_location=torch.device('cpu')) + sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu')) sv_model = CAMPPlus() sv_model.load_state_dict(sv_model_dict) sv_model.eval() diff --git a/funasr/utils/modelscope_utils.py b/funasr/utils/modelscope_utils.py index 9712e09e6..417988525 100644 --- a/funasr/utils/modelscope_utils.py +++ b/funasr/utils/modelscope_utils.py @@ -1,5 +1,6 @@ import os from modelscope.hub.snapshot_download import snapshot_download +from pathlib import Path def check_model_dir(model_dir, model_name: str = "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"): @@ -13,4 +14,27 @@ def check_model_dir(model_dir, model_name: str = "damo/speech_fsmn_vad_zh-cn-16k if not os.path.exists(dst): os.symlink(model_dir, dst) - model_dir = snapshot_download(model_name, cache_dir=dst_dir_root) \ No newline at end of file + model_dir = snapshot_download(model_name, cache_dir=dst_dir_root) + +def get_default_cache_dir(): + """ + default base dir: '~/.cache/modelscope' + """ + default_cache_dir = Path.home().joinpath('.cache', 'modelscope') + return default_cache_dir + +def get_cache_dir(model_id): + """cache dir precedence: + function parameter > environment > ~/.cache/modelscope/hub + + Args: + model_id (str, optional): The model id. + + Returns: + str: the model_id dir if model_id not None, otherwise cache root dir. + """ + default_cache_dir = get_default_cache_dir() + base_path = os.getenv('MODELSCOPE_CACHE', + os.path.join(default_cache_dir, 'hub')) + return base_path if model_id is None else os.path.join( + base_path, model_id + '/') \ No newline at end of file