|
|
@@ -272,8 +272,8 @@ def get_parser():
|
|
|
parser.add_argument(
|
|
|
"--init_param",
|
|
|
type=str,
|
|
|
+ action="append",
|
|
|
default=[],
|
|
|
- nargs="*",
|
|
|
help="Specify the file path used for initialization of parameters. "
|
|
|
"The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
|
|
|
"where file_path is the model file path, "
|
|
|
@@ -519,6 +519,12 @@ if __name__ == '__main__':
|
|
|
dtype=getattr(torch, args.train_dtype),
|
|
|
device="cuda" if args.ngpu > 0 else "cpu",
|
|
|
)
|
|
|
+ for t in args.freeze_param:
|
|
|
+ for k, p in model.named_parameters():
|
|
|
+ if k.startswith(t + ".") or k == t:
|
|
|
+ logging.info(f"Setting {k}.requires_grad = False")
|
|
|
+ p.requires_grad = False
|
|
|
+
|
|
|
optimizers = build_optimizer(args, model=model)
|
|
|
schedulers = build_scheduler(args, optimizers)
|
|
|
|