mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
big fix for speaker pipeline
This commit is contained in:
parent
f974935484
commit
78c78c39a9
@ -0,0 +1,35 @@
|
||||
import os
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.trainers import build_trainer
|
||||
from funasr.datasets.ms_dataset import MsDataset
|
||||
|
||||
|
||||
def modelscope_finetune(params):
|
||||
if not os.path.exists(params.output_dir):
|
||||
os.makedirs(params.output_dir, exist_ok=True)
|
||||
# dataset split ["train", "validation"]
|
||||
ds_dict = MsDataset.load(params.data_path)
|
||||
kwargs = dict(
|
||||
model=params.model,
|
||||
model_revision=params.model_revision,
|
||||
data_dir=ds_dict,
|
||||
dataset_type=params.dataset_type,
|
||||
work_dir=params.output_dir,
|
||||
batch_bins=params.batch_bins,
|
||||
max_epoch=params.max_epoch,
|
||||
lr=params.lr)
|
||||
trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
|
||||
trainer.train()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from funasr.utils.modelscope_param import modelscope_args
|
||||
params = modelscope_args(model="damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn", data_path="./data")
|
||||
params.output_dir = "./checkpoint" # m模型保存路径
|
||||
params.data_path = "./example_data/" # 数据路径
|
||||
params.dataset_type = "small" # 小数据量设置small,若数据量大于1000小时,请使用large
|
||||
params.batch_bins = 2000 # batch size,如果dataset_type="small",batch_bins单位为fbank特征帧数,如果dataset_type="large",batch_bins单位为毫秒,
|
||||
params.max_epoch = 50 # 最大训练轮数
|
||||
params.lr = 0.00005 # 设置学习率
|
||||
params.model_revision = "v1.2.1"
|
||||
modelscope_finetune(params)
|
||||
@ -0,0 +1,27 @@
|
||||
import os
|
||||
import shutil
|
||||
import argparse
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
def modelscope_infer(args):
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpuid)
|
||||
inference_pipeline = pipeline(
|
||||
task=Tasks.auto_speech_recognition,
|
||||
model=args.model,
|
||||
output_dir=args.output_dir,
|
||||
param_dict={"decoding_model": args.decoding_mode, "hotword": args.hotword_txt}
|
||||
)
|
||||
inference_pipeline(audio_in=args.audio_in, batch_size_token=args.batch_size_token)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model', type=str, default="damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn")
|
||||
parser.add_argument('--audio_in', type=str, default="./data/test/wav.scp")
|
||||
parser.add_argument('--output_dir', type=str, default="./results/")
|
||||
parser.add_argument('--decoding_mode', type=str, default="normal")
|
||||
parser.add_argument('--hotword_txt', type=str, default=None)
|
||||
parser.add_argument('--batch_size_token', type=int, default=5000)
|
||||
parser.add_argument('--gpuid', type=str, default="0")
|
||||
args = parser.parse_args()
|
||||
modelscope_infer(args)
|
||||
@ -55,6 +55,7 @@ from funasr.utils.speaker_utils import (check_audio_list,
|
||||
distribute_spk)
|
||||
from funasr.build_utils.build_model_from_file import build_model_from_file
|
||||
from funasr.utils.cluster_backend import ClusterBackend
|
||||
from funasr.utils.modelscope_utils import get_cache_dir
|
||||
from tqdm import tqdm
|
||||
|
||||
def inference_asr(
|
||||
@ -791,7 +792,7 @@ def inference_paraformer_vad_speaker(
|
||||
time_stamp_writer: bool = True,
|
||||
punc_infer_config: Optional[str] = None,
|
||||
punc_model_file: Optional[str] = None,
|
||||
sv_model_file: Optional[str] = "~/.cache/modelscope/hub/damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/campplus_cn_common.bin",
|
||||
sv_model_file: Optional[str] = None,
|
||||
streaming: bool = False,
|
||||
embedding_node: str = "resnet1_dense",
|
||||
sv_threshold: float = 0.9465,
|
||||
@ -813,6 +814,9 @@ def inference_paraformer_vad_speaker(
|
||||
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||
)
|
||||
|
||||
if sv_model_file is None:
|
||||
sv_model_file = "{}/damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/campplus_cn_common.bin".format(get_cache_dir(None))
|
||||
|
||||
if param_dict is not None:
|
||||
hotword_list_or_file = param_dict.get('hotword')
|
||||
else:
|
||||
@ -937,7 +941,7 @@ def inference_paraformer_vad_speaker(
|
||||
##### speaker_verification #####
|
||||
##################################
|
||||
# load sv model
|
||||
sv_model_dict = torch.load(sv_model_file.replace("~", os.environ['HOME']), map_location=torch.device('cpu'))
|
||||
sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu'))
|
||||
sv_model = CAMPPlus()
|
||||
sv_model.load_state_dict(sv_model_dict)
|
||||
sv_model.eval()
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import os
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def check_model_dir(model_dir, model_name: str = "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"):
|
||||
@ -13,4 +14,27 @@ def check_model_dir(model_dir, model_name: str = "damo/speech_fsmn_vad_zh-cn-16k
|
||||
if not os.path.exists(dst):
|
||||
os.symlink(model_dir, dst)
|
||||
|
||||
model_dir = snapshot_download(model_name, cache_dir=dst_dir_root)
|
||||
model_dir = snapshot_download(model_name, cache_dir=dst_dir_root)
|
||||
|
||||
def get_default_cache_dir():
|
||||
"""
|
||||
default base dir: '~/.cache/modelscope'
|
||||
"""
|
||||
default_cache_dir = Path.home().joinpath('.cache', 'modelscope')
|
||||
return default_cache_dir
|
||||
|
||||
def get_cache_dir(model_id):
|
||||
"""cache dir precedence:
|
||||
function parameter > environment > ~/.cache/modelscope/hub
|
||||
|
||||
Args:
|
||||
model_id (str, optional): The model id.
|
||||
|
||||
Returns:
|
||||
str: the model_id dir if model_id not None, otherwise cache root dir.
|
||||
"""
|
||||
default_cache_dir = get_default_cache_dir()
|
||||
base_path = os.getenv('MODELSCOPE_CACHE',
|
||||
os.path.join(default_cache_dir, 'hub'))
|
||||
return base_path if model_id is None else os.path.join(
|
||||
base_path, model_id + '/')
|
||||
Loading…
Reference in New Issue
Block a user