Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
Merge pull request #1053 from alibaba-damo-academy/dev_lzr_en
support paraformer-16k-en finetune
Commit 8c904ecadd
@@ -0,0 +1,35 @@
+import os
+from modelscope.metainfo import Trainers
+from modelscope.trainers import build_trainer
+from funasr.datasets.ms_dataset import MsDataset
+
+
+def modelscope_finetune(params):
+    if not os.path.exists(params.output_dir):
+        os.makedirs(params.output_dir, exist_ok=True)
+    # dataset split ["train", "validation"]
+    ds_dict = MsDataset.load(params.data_path)
+    kwargs = dict(
+        model=params.model,
+        model_revision=params.model_revision,
+        data_dir=ds_dict,
+        dataset_type=params.dataset_type,
+        work_dir=params.output_dir,
+        batch_bins=params.batch_bins,
+        max_epoch=params.max_epoch,
+        lr=params.lr)
+    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
+    trainer.train()
+
+
+if __name__ == '__main__':
+    from funasr.utils.modelscope_param import modelscope_args
+    params = modelscope_args(model="damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020")
+    params.output_dir = "./checkpoint"      # path to save the fine-tuned model
+    params.data_path = "./example_data/"    # path to the training data
+    params.dataset_type = "small"           # "small" for small datasets; use "large" if the data exceeds 1000 hours
+    params.batch_bins = 1000                # batch size: fbank feature frames when dataset_type="small", milliseconds when dataset_type="large"
+    params.max_epoch = 50                   # maximum number of training epochs
+    params.lr = 0.00005                     # learning rate
+    params.model_revision = "v1.0.1"
+    modelscope_finetune(params)
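Note: the script above assumes that params.data_path points at a directory MsDataset.load can consume, split into "train" and "validation" as the inline comment says. The check below is a minimal sketch of that assumed layout; the Kaldi-style file names wav.scp and text and the helper name check_data_layout are illustrative assumptions, not something this diff specifies.

import os

def check_data_layout(data_path="./example_data/"):
    # Assumed layout: <data_path>/{train,validation}/{wav.scp,text}, where wav.scp
    # maps "<utt_id> <path_to_wav>" and text maps "<utt_id> <transcript>".
    for split in ("train", "validation"):
        for fname in ("wav.scp", "text"):   # assumed file names, not confirmed by this diff
            fpath = os.path.join(data_path, split, fname)
            if not os.path.exists(fpath):
                raise FileNotFoundError(f"missing {fpath}")
    print("data layout looks ready for MsDataset.load")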
@@ -548,6 +548,7 @@ def build_trainer(modelscope_dict,
     init_param = modelscope_dict['init_model']
     cmvn_file = modelscope_dict['cmvn_file']
     seg_dict_file = modelscope_dict['seg_dict']
+    bpemodel = modelscope_dict['bpemodel']
 
     # overwrite parameters
     with open(config) as f:
@@ -581,6 +582,10 @@ def build_trainer(modelscope_dict,
         args.seg_dict_file = seg_dict_file
     else:
         args.seg_dict_file = None
+    if os.path.exists(bpemodel):
+        args.bpemodel = bpemodel
+    else:
+        args.bpemodel = None
     args.data_dir = data_dir
     args.train_set = train_set
     args.dev_set = dev_set
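Note: the build_trainer changes above pull the English model's BPE model path out of modelscope_dict and set args.bpemodel only when the file exists, falling back to None otherwise. The snippet below is an illustrative sketch of why that guard is useful downstream, using the sentencepiece package directly; it is not FunASR's actual tokenization code, and the helper name maybe_load_bpemodel is hypothetical.

import os
import sentencepiece as spm

def maybe_load_bpemodel(bpemodel_path):
    # Mirror the diff's if/else: only load a SentencePiece BPE model when a valid
    # path was provided; otherwise tokenization falls back to a BPE-less setup.
    if bpemodel_path and os.path.exists(bpemodel_path):
        return spm.SentencePieceProcessor(model_file=bpemodel_path)
    return None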