update repo

This commit is contained in:
嘉渊 2023-05-24 11:52:19 +08:00
parent 53a753755b
commit a5b1f9911c

View File

@ -272,8 +272,8 @@ def get_parser():
parser.add_argument(
"--init_param",
type=str,
action="append",
default=[],
nargs="*",
help="Specify the file path used for initialization of parameters. "
"The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
"where file_path is the model file path, "
@ -519,6 +519,12 @@ if __name__ == '__main__':
dtype=getattr(torch, args.train_dtype),
device="cuda" if args.ngpu > 0 else "cpu",
)
for t in args.freeze_param:
for k, p in model.named_parameters():
if k.startswith(t + ".") or k == t:
logging.info(f"Setting {k}.requires_grad = False")
p.requires_grad = False
optimizers = build_optimizer(args, model=model)
schedulers = build_scheduler(args, optimizers)