mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
add lora finetune code
This commit is contained in:
parent
7ac54b3c97
commit
7a19c52602
@ -55,7 +55,7 @@ def build_trainer(modelscope_dict,
|
||||
scheduler_conf=None,
|
||||
specaug=None,
|
||||
specaug_conf=None,
|
||||
param_dict=None,
|
||||
meta_dict=None,
|
||||
**kwargs):
|
||||
mode = modelscope_dict['mode']
|
||||
args, ASRTask = parse_args(mode=mode)
|
||||
@ -144,8 +144,9 @@ def build_trainer(modelscope_dict,
|
||||
args.patience = None
|
||||
args.local_rank = local_rank
|
||||
args.distributed = distributed
|
||||
for key, value in kwargs.items():
|
||||
args.key = value
|
||||
if meta_dict is not None:
|
||||
for key, value in meta_dict.items():
|
||||
args.key = value
|
||||
ASRTask.finetune_args = args
|
||||
|
||||
return ASRTask
|
||||
|
||||
@ -489,7 +489,7 @@ def get_parser():
|
||||
"--lora_bias",
|
||||
type=str,
|
||||
default="none",
|
||||
help="oss bucket.",
|
||||
help="lora bias.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
@ -71,6 +71,7 @@ from funasr.utils.types import str_or_int
|
||||
from funasr.utils.types import str_or_none
|
||||
from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
|
||||
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
|
||||
from funasr.modules.lora.utils import mark_only_lora_as_trainable
|
||||
|
||||
try:
|
||||
import wandb
|
||||
@ -952,6 +953,18 @@ class AbsTask(ABC):
|
||||
default=None,
|
||||
help="oss bucket.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--enable_lora",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Apply lora for finetuning.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--lora_bias",
|
||||
type=str,
|
||||
default="none",
|
||||
help="lora bias.",
|
||||
)
|
||||
|
||||
cls.trainer.add_arguments(parser)
|
||||
cls.add_task_arguments(parser)
|
||||
@ -1246,6 +1259,8 @@ class AbsTask(ABC):
|
||||
dtype=getattr(torch, args.train_dtype),
|
||||
device="cuda" if args.ngpu > 0 else "cpu",
|
||||
)
|
||||
if args.enable_lora:
|
||||
mark_only_lora_as_trainable(model, args.lora_bias)
|
||||
for t in args.freeze_param:
|
||||
for k, p in model.named_parameters():
|
||||
if k.startswith(t + ".") or k == t:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user