add lora finetune code

haoneng.lhn 2023-07-19 16:41:39 +08:00
parent 7ac54b3c97
commit 7a19c52602
3 changed files with 20 additions and 4 deletions

View File

@@ -55,7 +55,7 @@ def build_trainer(modelscope_dict,
                   scheduler_conf=None,
                   specaug=None,
                   specaug_conf=None,
-                  param_dict=None,
+                  meta_dict=None,
                   **kwargs):
     mode = modelscope_dict['mode']
     args, ASRTask = parse_args(mode=mode)
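
The rename from param_dict to meta_dict pairs with the new copy loop in the next hunk: entries of meta_dict are written onto the parsed args. A hypothetical call, with keys chosen to match the flags this commit adds (values are illustrative, not taken from the commit):

    # Hypothetical usage sketch: meta_dict entries end up as attributes on args.
    trainer = build_trainer(
        modelscope_dict,
        meta_dict={"enable_lora": True, "lora_bias": "all"},
    )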
@@ -144,8 +144,9 @@ def build_trainer(modelscope_dict,
     args.patience = None
     args.local_rank = local_rank
     args.distributed = distributed
-    for key, value in kwargs.items():
-        args.key = value
+    if meta_dict is not None:
+        for key, value in meta_dict.items():
+            args.key = value
     ASRTask.finetune_args = args
     return ASRTask
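
Note that `args.key = value` assigns to a literal attribute named `key` rather than to the attribute named by the loop variable, so every iteration overwrites the same field. A minimal sketch of the presumably intended behavior, using setattr (not part of this commit):

    # Sketch: copy each meta_dict entry onto args under its own name.
    if meta_dict is not None:
        for key, value in meta_dict.items():
            setattr(args, key, value)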

View File

@@ -489,7 +489,7 @@ def get_parser():
         "--lora_bias",
         type=str,
         default="none",
-        help="oss bucket.",
+        help="lora bias.",
     )
     return parser
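
The corrected help string still leaves the flag unconstrained. In Microsoft's loralib, which funasr.modules.lora appears to follow, the bias mode is one of "none", "all", or "lora_only"; if desired, argparse's choices could enforce that. A sketch, not part of this commit:

    # Sketch: constrain --lora_bias to the values a loralib-style helper accepts.
    parser.add_argument(
        "--lora_bias",
        type=str,
        default="none",
        choices=["none", "all", "lora_only"],
        help="lora bias.",
    )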

View File

@@ -71,6 +71,7 @@ from funasr.utils.types import str_or_int
 from funasr.utils.types import str_or_none
 from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
 from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
+from funasr.modules.lora.utils import mark_only_lora_as_trainable

 try:
     import wandb
@@ -952,6 +953,18 @@ class AbsTask(ABC):
             default=None,
             help="oss bucket.",
         )
+        group.add_argument(
+            "--enable_lora",
+            type=str2bool,
+            default=False,
+            help="Apply lora for finetuning.",
+        )
+        group.add_argument(
+            "--lora_bias",
+            type=str,
+            default="none",
+            help="lora bias.",
+        )

         cls.trainer.add_arguments(parser)
         cls.add_task_arguments(parser)
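
With this hunk, both flags surface on the parsed namespace, str2bool converting string booleans along the way. An illustrative parse, assuming str2bool follows the usual strtobool convention (flag spellings below are mine, not taken from the commit):

    # Illustrative only: the new flags on a parsed namespace.
    args = parser.parse_args(["--enable_lora", "true", "--lora_bias", "all"])
    # args.enable_lora -> True, args.lora_bias -> "all"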
@@ -1246,6 +1259,8 @@ class AbsTask(ABC):
             dtype=getattr(torch, args.train_dtype),
             device="cuda" if args.ngpu > 0 else "cpu",
         )
+        if args.enable_lora:
+            mark_only_lora_as_trainable(model, args.lora_bias)
         for t in args.freeze_param:
             for k, p in model.named_parameters():
                 if k.startswith(t + ".") or k == t:
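
mark_only_lora_as_trainable is what makes the finetune LoRA-only: it freezes everything except the LoRA parameters before the existing freeze_param loop runs. A sketch of its typical loralib-style implementation (the actual funasr.modules.lora.utils code is not shown in this diff; the lora_only branch here uses a hasattr check in place of loralib's LoRALayer isinstance test):

    # Sketch, modeled on Microsoft's loralib: freeze all non-LoRA weights,
    # then optionally re-enable biases per the requested mode.
    import torch.nn as nn

    def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
        for name, param in model.named_parameters():
            if "lora_" not in name:
                param.requires_grad = False
        if bias == "none":
            return
        if bias == "all":
            for name, param in model.named_parameters():
                if "bias" in name:
                    param.requires_grad = True
        elif bias == "lora_only":
            # Re-enable only the biases of modules that carry LoRA weights.
            for module in model.modules():
                if hasattr(module, "lora_A") and getattr(module, "bias", None) is not None:
                    module.bias.requires_grad = True
        else:
            raise NotImplementedError(f"Unknown bias mode: {bias}")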