|
|
@@ -16,6 +16,7 @@ from funasr.build_utils.build_optimizer import build_optimizer
|
|
|
from funasr.build_utils.build_scheduler import build_scheduler
|
|
|
from funasr.build_utils.build_trainer import build_trainer
|
|
|
from funasr.text.phoneme_tokenizer import g2p_choices
|
|
|
+from funasr.torch_utils.load_pretrained_model import load_pretrained_model
|
|
|
from funasr.torch_utils.model_summary import model_summary
|
|
|
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
|
|
|
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
|
|
|
@@ -530,6 +531,18 @@ if __name__ == '__main__':
|
|
|
else:
|
|
|
yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
|
|
|
|
|
|
+ for p in args.init_param:
|
|
|
+ logging.info(f"Loading pretrained params from {p}")
|
|
|
+ load_pretrained_model(
|
|
|
+ model=model,
|
|
|
+ init_param=p,
|
|
|
+ ignore_init_mismatch=args.ignore_init_mismatch,
|
|
|
+ map_location=f"cuda:{torch.cuda.current_device()}"
|
|
|
+ if args.ngpu > 0
|
|
|
+ else "cpu",
|
|
|
+ oss_bucket=args.oss_bucket,
|
|
|
+ )
|
|
|
+
|
|
|
# dataloader for training/validation
|
|
|
train_dataloader, valid_dataloader = build_dataloader(args)
|
|
|
|