diff --git a/funasr/bin/train.py b/funasr/bin/train.py index b0d46e7e4..21e1943cd 100755 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -272,8 +272,8 @@ def get_parser(): parser.add_argument( "--init_param", type=str, + action="append", default=[], - nargs="*", help="Specify the file path used for initialization of parameters. " "The format is ':::', " "where file_path is the model file path, " @@ -519,6 +519,12 @@ if __name__ == '__main__': dtype=getattr(torch, args.train_dtype), device="cuda" if args.ngpu > 0 else "cpu", ) + for t in args.freeze_param: + for k, p in model.named_parameters(): + if k.startswith(t + ".") or k == t: + logging.info(f"Setting {k}.requires_grad = False") + p.requires_grad = False + optimizers = build_optimizer(args, model=model) schedulers = build_scheduler(args, optimizers)