mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update repo
This commit is contained in:
parent
53a753755b
commit
a5b1f9911c
@ -272,8 +272,8 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--init_param",
|
||||
type=str,
|
||||
action="append",
|
||||
default=[],
|
||||
nargs="*",
|
||||
help="Specify the file path used for initialization of parameters. "
|
||||
"The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
|
||||
"where file_path is the model file path, "
|
||||
@ -519,6 +519,12 @@ if __name__ == '__main__':
|
||||
dtype=getattr(torch, args.train_dtype),
|
||||
device="cuda" if args.ngpu > 0 else "cpu",
|
||||
)
|
||||
for t in args.freeze_param:
|
||||
for k, p in model.named_parameters():
|
||||
if k.startswith(t + ".") or k == t:
|
||||
logging.info(f"Setting {k}.requires_grad = False")
|
||||
p.requires_grad = False
|
||||
|
||||
optimizers = build_optimizer(args, model=model)
|
||||
schedulers = build_scheduler(args, optimizers)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user