diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
index bd30a83c2..0ab2013de 100644
--- a/funasr/bin/build_trainer.py
+++ b/funasr/bin/build_trainer.py
@@ -55,7 +55,7 @@ def build_trainer(modelscope_dict,
                   scheduler_conf=None,
                   specaug=None,
                   specaug_conf=None,
-                  param_dict=None,
+                  meta_dict=None,
                   **kwargs):
     mode = modelscope_dict['mode']
     args, ASRTask = parse_args(mode=mode)
@@ -144,8 +144,9 @@
     args.patience = None
     args.local_rank = local_rank
     args.distributed = distributed
-    for key, value in kwargs.items():
-        args.key = value
+    if meta_dict is not None:
+        for key, value in meta_dict.items():
+            setattr(args, key, value)
 
     ASRTask.finetune_args = args
     return ASRTask
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index c9c0b0204..f5d10c4ac 100755
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -489,7 +489,7 @@
         "--lora_bias",
         type=str,
         default="none",
-        help="oss bucket.",
+        help="lora bias.",
     )
     return parser
 
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 91d33c5e3..f7f13d289 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -71,6 +71,7 @@ from funasr.utils.types import str_or_int
 from funasr.utils.types import str_or_none
 from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
 from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
+from funasr.modules.lora.utils import mark_only_lora_as_trainable
 
 try:
     import wandb
@@ -952,6 +953,18 @@
             default=None,
             help="oss bucket.",
         )
+        group.add_argument(
+            "--enable_lora",
+            type=str2bool,
+            default=False,
+            help="Apply lora for finetuning.",
+        )
+        group.add_argument(
+            "--lora_bias",
+            type=str,
+            default="none",
+            help="lora bias.",
+        )
 
         cls.trainer.add_arguments(parser)
         cls.add_task_arguments(parser)
@@ -1246,6 +1259,8 @@
                 dtype=getattr(torch, args.train_dtype),
                 device="cuda" if args.ngpu > 0 else "cpu",
             )
+            if args.enable_lora:
+                mark_only_lora_as_trainable(model, args.lora_bias)
             for t in args.freeze_param:
                 for k, p in model.named_parameters():
                     if k.startswith(t + ".") or k == t: