diff --git a/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/finetune.py b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/finetune.py new file mode 100644 index 000000000..7ca1bffa9 --- /dev/null +++ b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/finetune.py @@ -0,0 +1,35 @@ +import os +from modelscope.metainfo import Trainers +from modelscope.trainers import build_trainer +from funasr.datasets.ms_dataset import MsDataset + + +def modelscope_finetune(params): + if not os.path.exists(params.output_dir): + os.makedirs(params.output_dir, exist_ok=True) + # dataset split ["train", "validation"] + ds_dict = MsDataset.load(params.data_path) + kwargs = dict( + model=params.model, + model_revision=params.model_revision, + data_dir=ds_dict, + dataset_type=params.dataset_type, + work_dir=params.output_dir, + batch_bins=params.batch_bins, + max_epoch=params.max_epoch, + lr=params.lr) + trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs) + trainer.train() + + +if __name__ == '__main__': + from funasr.utils.modelscope_param import modelscope_args + params = modelscope_args(model="damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020") + params.output_dir = "./checkpoint" # m模型保存路径 + params.data_path = "./example_data/" # 数据路径 + params.dataset_type = "small" # 小数据量设置small,若数据量大于1000小时,请使用large + params.batch_bins = 1000 # batch size,如果dataset_type="small",batch_bins单位为fbank特征帧数,如果dataset_type="large",batch_bins单位为毫秒, + params.max_epoch = 50 # 最大训练轮数 + params.lr = 0.00005 # 设置学习率 + params.model_revision = "v1.0.1" + modelscope_finetune(params) \ No newline at end of file diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py index 52aa5093e..61af7663d 100644 --- a/funasr/bin/build_trainer.py +++ b/funasr/bin/build_trainer.py @@ -548,6 +548,7 @@ def build_trainer(modelscope_dict, init_param = modelscope_dict['init_model'] cmvn_file = modelscope_dict['cmvn_file'] seg_dict_file = modelscope_dict['seg_dict'] + bpemodel = modelscope_dict['bpemodel'] # overwrite parameters with open(config) as f: @@ -581,6 +582,10 @@ def build_trainer(modelscope_dict, args.seg_dict_file = seg_dict_file else: args.seg_dict_file = None + if os.path.exists(bpemodel): + args.bpemodel = bpemodel + else: + args.bpemodel = None args.data_dir = data_dir args.train_set = train_set args.dev_set = dev_set