big fix for speaker pipeline

This commit is contained in:
shixian.shi 2023-10-10 17:11:15 +08:00
parent f974935484
commit 78c78c39a9
4 changed files with 93 additions and 3 deletions

View File

@ -0,0 +1,35 @@
import os
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from funasr.datasets.ms_dataset import MsDataset
def modelscope_finetune(params):
if not os.path.exists(params.output_dir):
os.makedirs(params.output_dir, exist_ok=True)
# dataset split ["train", "validation"]
ds_dict = MsDataset.load(params.data_path)
kwargs = dict(
model=params.model,
model_revision=params.model_revision,
data_dir=ds_dict,
dataset_type=params.dataset_type,
work_dir=params.output_dir,
batch_bins=params.batch_bins,
max_epoch=params.max_epoch,
lr=params.lr)
trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
trainer.train()
if __name__ == '__main__':
from funasr.utils.modelscope_param import modelscope_args
params = modelscope_args(model="damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn", data_path="./data")
params.output_dir = "./checkpoint" # m模型保存路径
params.data_path = "./example_data/" # 数据路径
params.dataset_type = "small" # 小数据量设置small若数据量大于1000小时请使用large
params.batch_bins = 2000 # batch size如果dataset_type="small"batch_bins单位为fbank特征帧数如果dataset_type="large"batch_bins单位为毫秒
params.max_epoch = 50 # 最大训练轮数
params.lr = 0.00005 # 设置学习率
params.model_revision = "v1.2.1"
modelscope_finetune(params)

View File

@ -0,0 +1,27 @@
import os
import shutil
import argparse
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
def modelscope_infer(args):
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpuid)
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model=args.model,
output_dir=args.output_dir,
param_dict={"decoding_model": args.decoding_mode, "hotword": args.hotword_txt}
)
inference_pipeline(audio_in=args.audio_in, batch_size_token=args.batch_size_token)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default="damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn")
parser.add_argument('--audio_in', type=str, default="./data/test/wav.scp")
parser.add_argument('--output_dir', type=str, default="./results/")
parser.add_argument('--decoding_mode', type=str, default="normal")
parser.add_argument('--hotword_txt', type=str, default=None)
parser.add_argument('--batch_size_token', type=int, default=5000)
parser.add_argument('--gpuid', type=str, default="0")
args = parser.parse_args()
modelscope_infer(args)

View File

@ -55,6 +55,7 @@ from funasr.utils.speaker_utils import (check_audio_list,
distribute_spk)
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.utils.cluster_backend import ClusterBackend
from funasr.utils.modelscope_utils import get_cache_dir
from tqdm import tqdm
def inference_asr(
@ -791,7 +792,7 @@ def inference_paraformer_vad_speaker(
time_stamp_writer: bool = True,
punc_infer_config: Optional[str] = None,
punc_model_file: Optional[str] = None,
sv_model_file: Optional[str] = "~/.cache/modelscope/hub/damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/campplus_cn_common.bin",
sv_model_file: Optional[str] = None,
streaming: bool = False,
embedding_node: str = "resnet1_dense",
sv_threshold: float = 0.9465,
@ -813,6 +814,9 @@ def inference_paraformer_vad_speaker(
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if sv_model_file is None:
sv_model_file = "{}/damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/campplus_cn_common.bin".format(get_cache_dir(None))
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
else:
@ -937,7 +941,7 @@ def inference_paraformer_vad_speaker(
##### speaker_verification #####
##################################
# load sv model
sv_model_dict = torch.load(sv_model_file.replace("~", os.environ['HOME']), map_location=torch.device('cpu'))
sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu'))
sv_model = CAMPPlus()
sv_model.load_state_dict(sv_model_dict)
sv_model.eval()

View File

@ -1,5 +1,6 @@
import os
from modelscope.hub.snapshot_download import snapshot_download
from pathlib import Path
def check_model_dir(model_dir, model_name: str = "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"):
@ -13,4 +14,27 @@ def check_model_dir(model_dir, model_name: str = "damo/speech_fsmn_vad_zh-cn-16k
if not os.path.exists(dst):
os.symlink(model_dir, dst)
model_dir = snapshot_download(model_name, cache_dir=dst_dir_root)
model_dir = snapshot_download(model_name, cache_dir=dst_dir_root)
def get_default_cache_dir():
"""
default base dir: '~/.cache/modelscope'
"""
default_cache_dir = Path.home().joinpath('.cache', 'modelscope')
return default_cache_dir
def get_cache_dir(model_id):
"""cache dir precedence:
function parameter > environment > ~/.cache/modelscope/hub
Args:
model_id (str, optional): The model id.
Returns:
str: the model_id dir if model_id not None, otherwise cache root dir.
"""
default_cache_dir = get_default_cache_dir()
base_path = os.getenv('MODELSCOPE_CACHE',
os.path.join(default_cache_dir, 'hub'))
return base_path if model_id is None else os.path.join(
base_path, model_id + '/')