# train.py — FunASR common training entry script.
  1. #!/usr/bin/env python3
  2. import argparse
  3. import logging
  4. import os
  5. import sys
  6. from io import BytesIO
  7. import torch
  8. from funasr.build_utils.build_args import build_args
  9. from funasr.build_utils.build_dataloader import build_dataloader
  10. from funasr.build_utils.build_distributed import build_distributed
  11. from funasr.build_utils.build_model import build_model
  12. from funasr.build_utils.build_optimizer import build_optimizer
  13. from funasr.build_utils.build_scheduler import build_scheduler
  14. from funasr.build_utils.build_trainer import build_trainer
  15. from funasr.text.phoneme_tokenizer import g2p_choices
  16. from funasr.torch_utils.load_pretrained_model import load_pretrained_model
  17. from funasr.torch_utils.model_summary import model_summary
  18. from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
  19. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  20. from funasr.utils.nested_dict_action import NestedDictAction
  21. from funasr.utils.prepare_data import prepare_data
  22. from funasr.utils.types import int_or_none
  23. from funasr.utils.types import str2bool
  24. from funasr.utils.types import str_or_none
  25. from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
  26. def get_parser():
  27. parser = argparse.ArgumentParser(
  28. description="FunASR Common Training Parser",
  29. )
  30. # common configuration
  31. parser.add_argument("--output_dir", help="model save path")
  32. parser.add_argument(
  33. "--ngpu",
  34. type=int,
  35. default=0,
  36. help="The number of gpus. 0 indicates CPU mode",
  37. )
  38. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  39. parser.add_argument("--task_name", type=str, default="asr", help="Name for different tasks")
  40. # ddp related
  41. parser.add_argument(
  42. "--dist_backend",
  43. default="nccl",
  44. type=str,
  45. help="distributed backend",
  46. )
  47. parser.add_argument(
  48. "--dist_init_method",
  49. type=str,
  50. default="env://",
  51. help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
  52. '"WORLD_SIZE", and "RANK" are referred.',
  53. )
  54. parser.add_argument(
  55. "--dist_world_size",
  56. type=int,
  57. default=1,
  58. help="number of nodes for distributed training",
  59. )
  60. parser.add_argument(
  61. "--dist_rank",
  62. type=int,
  63. default=None,
  64. help="node rank for distributed training",
  65. )
  66. parser.add_argument(
  67. "--local_rank",
  68. type=int,
  69. default=None,
  70. help="local rank for distributed training",
  71. )
  72. parser.add_argument(
  73. "--dist_master_addr",
  74. default=None,
  75. type=str_or_none,
  76. help="The master address for distributed training. "
  77. "This value is used when dist_init_method == 'env://'",
  78. )
  79. parser.add_argument(
  80. "--dist_master_port",
  81. default=None,
  82. type=int_or_none,
  83. help="The master port for distributed training"
  84. "This value is used when dist_init_method == 'env://'",
  85. )
  86. parser.add_argument(
  87. "--dist_launcher",
  88. default=None,
  89. type=str_or_none,
  90. choices=["slurm", "mpi", None],
  91. help="The launcher type for distributed training",
  92. )
  93. parser.add_argument(
  94. "--multiprocessing_distributed",
  95. default=True,
  96. type=str2bool,
  97. help="Use multi-processing distributed training to launch "
  98. "N processes per node, which has N GPUs. This is the "
  99. "fastest way to use PyTorch for either single node or "
  100. "multi node data parallel training",
  101. )
  102. parser.add_argument(
  103. "--unused_parameters",
  104. type=str2bool,
  105. default=False,
  106. help="Whether to use the find_unused_parameters in "
  107. "torch.nn.parallel.DistributedDataParallel ",
  108. )
  109. parser.add_argument(
  110. "--gpu_id",
  111. type=int,
  112. default=0,
  113. help="local gpu id.",
  114. )
  115. # cudnn related
  116. parser.add_argument(
  117. "--cudnn_enabled",
  118. type=str2bool,
  119. default=torch.backends.cudnn.enabled,
  120. help="Enable CUDNN",
  121. )
  122. parser.add_argument(
  123. "--cudnn_benchmark",
  124. type=str2bool,
  125. default=torch.backends.cudnn.benchmark,
  126. help="Enable cudnn-benchmark mode",
  127. )
  128. parser.add_argument(
  129. "--cudnn_deterministic",
  130. type=str2bool,
  131. default=True,
  132. help="Enable cudnn-deterministic mode",
  133. )
  134. # trainer related
  135. parser.add_argument(
  136. "--max_epoch",
  137. type=int,
  138. default=40,
  139. help="The maximum number epoch to train",
  140. )
  141. parser.add_argument(
  142. "--max_update",
  143. type=int,
  144. default=sys.maxsize,
  145. help="The maximum number update step to train",
  146. )
  147. parser.add_argument(
  148. "--batch_interval",
  149. type=int,
  150. default=10000,
  151. help="The batch interval for saving model.",
  152. )
  153. parser.add_argument(
  154. "--patience",
  155. type=int_or_none,
  156. default=None,
  157. help="Number of epochs to wait without improvement "
  158. "before stopping the training",
  159. )
  160. parser.add_argument(
  161. "--val_scheduler_criterion",
  162. type=str,
  163. nargs=2,
  164. default=("valid", "loss"),
  165. help="The criterion used for the value given to the lr scheduler. "
  166. 'Give a pair referring the phase, "train" or "valid",'
  167. 'and the criterion name. The mode specifying "min" or "max" can '
  168. "be changed by --scheduler_conf",
  169. )
  170. parser.add_argument(
  171. "--early_stopping_criterion",
  172. type=str,
  173. nargs=3,
  174. default=("valid", "loss", "min"),
  175. help="The criterion used for judging of early stopping. "
  176. 'Give a pair referring the phase, "train" or "valid",'
  177. 'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
  178. )
  179. parser.add_argument(
  180. "--best_model_criterion",
  181. nargs="+",
  182. default=[
  183. ("train", "loss", "min"),
  184. ("valid", "loss", "min"),
  185. ("train", "acc", "max"),
  186. ("valid", "acc", "max"),
  187. ],
  188. help="The criterion used for judging of the best model. "
  189. 'Give a pair referring the phase, "train" or "valid",'
  190. 'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
  191. )
  192. parser.add_argument(
  193. "--keep_nbest_models",
  194. type=int,
  195. nargs="+",
  196. default=[10],
  197. help="Remove previous snapshots excluding the n-best scored epochs",
  198. )
  199. parser.add_argument(
  200. "--nbest_averaging_interval",
  201. type=int,
  202. default=0,
  203. help="The epoch interval to apply model averaging and save nbest models",
  204. )
  205. parser.add_argument(
  206. "--grad_clip",
  207. type=float,
  208. default=5.0,
  209. help="Gradient norm threshold to clip",
  210. )
  211. parser.add_argument(
  212. "--grad_clip_type",
  213. type=float,
  214. default=2.0,
  215. help="The type of the used p-norm for gradient clip. Can be inf",
  216. )
  217. parser.add_argument(
  218. "--grad_noise",
  219. type=str2bool,
  220. default=False,
  221. help="The flag to switch to use noise injection to "
  222. "gradients during training",
  223. )
  224. parser.add_argument(
  225. "--accum_grad",
  226. type=int,
  227. default=1,
  228. help="The number of gradient accumulation",
  229. )
  230. parser.add_argument(
  231. "--resume",
  232. type=str2bool,
  233. default=False,
  234. help="Enable resuming if checkpoint is existing",
  235. )
  236. parser.add_argument(
  237. "--train_dtype",
  238. default="float32",
  239. choices=["float16", "float32", "float64"],
  240. help="Data type for training.",
  241. )
  242. parser.add_argument(
  243. "--use_amp",
  244. type=str2bool,
  245. default=False,
  246. help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
  247. )
  248. parser.add_argument(
  249. "--log_interval",
  250. default=None,
  251. help="Show the logs every the number iterations in each epochs at the "
  252. "training phase. If None is given, it is decided according the number "
  253. "of training samples automatically .",
  254. )
  255. parser.add_argument(
  256. "--use_tensorboard",
  257. type=str2bool,
  258. default=True,
  259. help="Enable tensorboard logging",
  260. )
  261. # pretrained model related
  262. parser.add_argument(
  263. "--init_param",
  264. type=str,
  265. action="append",
  266. default=[],
  267. help="Specify the file path used for initialization of parameters. "
  268. "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
  269. "where file_path is the model file path, "
  270. "src_key specifies the key of model states to be used in the model file, "
  271. "dst_key specifies the attribute of the model to be initialized, "
  272. "and exclude_keys excludes keys of model states for the initialization."
  273. "e.g.\n"
  274. " # Load all parameters"
  275. " --init_param some/where/model.pb\n"
  276. " # Load only decoder parameters"
  277. " --init_param some/where/model.pb:decoder:decoder\n"
  278. " # Load only decoder parameters excluding decoder.embed"
  279. " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
  280. " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
  281. )
  282. parser.add_argument(
  283. "--ignore_init_mismatch",
  284. type=str2bool,
  285. default=False,
  286. help="Ignore size mismatch when loading pre-trained model",
  287. )
  288. parser.add_argument(
  289. "--freeze_param",
  290. type=str,
  291. default=[],
  292. nargs="*",
  293. help="Freeze parameters",
  294. )
  295. # dataset related
  296. parser.add_argument(
  297. "--dataset_type",
  298. type=str,
  299. default="small",
  300. help="whether to use dataloader for large dataset",
  301. )
  302. parser.add_argument(
  303. "--dataset_conf",
  304. action=NestedDictAction,
  305. default=dict(),
  306. help=f"The keyword arguments for dataset",
  307. )
  308. parser.add_argument(
  309. "--data_dir",
  310. type=str,
  311. default=None,
  312. help="root path of data",
  313. )
  314. parser.add_argument(
  315. "--train_set",
  316. type=str,
  317. default="train",
  318. help="train dataset",
  319. )
  320. parser.add_argument(
  321. "--valid_set",
  322. type=str,
  323. default="validation",
  324. help="dev dataset",
  325. )
  326. parser.add_argument(
  327. "--data_file_names",
  328. type=str,
  329. default="wav.scp,text",
  330. help="input data files",
  331. )
  332. parser.add_argument(
  333. "--speed_perturb",
  334. type=float,
  335. nargs="+",
  336. default=None,
  337. help="speed perturb",
  338. )
  339. parser.add_argument(
  340. "--use_preprocessor",
  341. type=str2bool,
  342. default=True,
  343. help="Apply preprocessing to data or not",
  344. )
  345. # optimization related
  346. parser.add_argument(
  347. "--optim",
  348. type=lambda x: x.lower(),
  349. default="adam",
  350. help="The optimizer type",
  351. )
  352. parser.add_argument(
  353. "--optim_conf",
  354. action=NestedDictAction,
  355. default=dict(),
  356. help="The keyword arguments for optimizer",
  357. )
  358. parser.add_argument(
  359. "--scheduler",
  360. type=lambda x: str_or_none(x.lower()),
  361. default=None,
  362. help="The lr scheduler type",
  363. )
  364. parser.add_argument(
  365. "--scheduler_conf",
  366. action=NestedDictAction,
  367. default=dict(),
  368. help="The keyword arguments for lr scheduler",
  369. )
  370. # most task related
  371. parser.add_argument(
  372. "--init",
  373. type=lambda x: str_or_none(x.lower()),
  374. default=None,
  375. help="The initialization method",
  376. choices=[
  377. "chainer",
  378. "xavier_uniform",
  379. "xavier_normal",
  380. "kaiming_uniform",
  381. "kaiming_normal",
  382. None,
  383. ],
  384. )
  385. parser.add_argument(
  386. "--token_list",
  387. type=str_or_none,
  388. default=None,
  389. help="A text mapping int-id to token",
  390. )
  391. parser.add_argument(
  392. "--token_type",
  393. type=str,
  394. default="bpe",
  395. choices=["bpe", "char", "word"],
  396. help="",
  397. )
  398. parser.add_argument(
  399. "--bpemodel",
  400. type=str_or_none,
  401. default=None,
  402. help="The model file fo sentencepiece",
  403. )
  404. parser.add_argument(
  405. "--cleaner",
  406. type=str_or_none,
  407. choices=[None, "tacotron", "jaconv", "vietnamese"],
  408. default=None,
  409. help="Apply text cleaning",
  410. )
  411. parser.add_argument(
  412. "--g2p",
  413. type=str_or_none,
  414. choices=g2p_choices,
  415. default=None,
  416. help="Specify g2p method if --token_type=phn",
  417. )
  418. # pai related
  419. parser.add_argument(
  420. "--use_pai",
  421. type=str2bool,
  422. default=False,
  423. help="flag to indicate whether training on PAI",
  424. )
  425. parser.add_argument(
  426. "--simple_ddp",
  427. type=str2bool,
  428. default=False,
  429. )
  430. parser.add_argument(
  431. "--num_worker_count",
  432. type=int,
  433. default=1,
  434. help="The number of machines on PAI.",
  435. )
  436. parser.add_argument(
  437. "--access_key_id",
  438. type=str,
  439. default=None,
  440. help="The username for oss.",
  441. )
  442. parser.add_argument(
  443. "--access_key_secret",
  444. type=str,
  445. default=None,
  446. help="The password for oss.",
  447. )
  448. parser.add_argument(
  449. "--endpoint",
  450. type=str,
  451. default=None,
  452. help="The endpoint for oss.",
  453. )
  454. parser.add_argument(
  455. "--bucket_name",
  456. type=str,
  457. default=None,
  458. help="The bucket name for oss.",
  459. )
  460. parser.add_argument(
  461. "--oss_bucket",
  462. default=None,
  463. help="oss bucket.",
  464. )
  465. return parser
if __name__ == '__main__':
    # Parse the common options first; any leftover flags are task-specific
    # and get folded into `args` by build_args.
    parser = get_parser()
    args, extra_task_params = parser.parse_known_args()
    if extra_task_params:
        args = build_args(args, parser, extra_task_params)

    # set random seed and propagate the cudnn flags chosen on the CLI
    set_all_random_seed(args.seed)
    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark
    torch.backends.cudnn.deterministic = args.cudnn_deterministic

    # ddp init: pin this process to the requested GPU, then mark the run as
    # distributed when more than one GPU or more than one node is requested.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    args.distributed = args.ngpu > 1 or args.dist_world_size > 1
    distributed_option = build_distributed(args)

    # for logging: only rank 0 logs at INFO; other ranks are limited to ERROR
    # so multi-process runs don't duplicate every message. The format prefixes
    # each line with the short hostname.
    if not distributed_option.distributed or distributed_option.dist_rank == 0:
        logging.basicConfig(
            level="INFO",
            format=f"[{os.uname()[1].split('.')[0]}]"
            f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level="ERROR",
            format=f"[{os.uname()[1].split('.')[0]}]"
            f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )

    # prepare files for dataloader (runs before the model so data manifests
    # exist when build_dataloader is called below)
    prepare_data(args, distributed_option)

    # Build the model and move it to the training dtype/device.
    model = build_model(args)
    model = model.to(
        dtype=getattr(torch, args.train_dtype),
        device="cuda" if args.ngpu > 0 else "cpu",
    )

    # Freeze any parameter whose name equals a --freeze_param entry or lives
    # under it as a submodule prefix ("enc" matches "enc.layer0.w" but not
    # "encoder.w" thanks to the "." suffix check).
    for t in args.freeze_param:
        for k, p in model.named_parameters():
            if k.startswith(t + ".") or k == t:
                logging.info(f"Setting {k}.requires_grad = False")
                p.requires_grad = False

    optimizers = build_optimizer(args, model=model)
    schedulers = build_scheduler(args, optimizers)

    logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
                                                                   distributed_option.dist_rank,
                                                                   distributed_option.local_rank))
    logging.info(pytorch_cudnn_version())
    logging.info("Args: {}".format(args))
    logging.info(model_summary(model))
    logging.info("Optimizer: {}".format(optimizers))
    logging.info("Scheduler: {}".format(schedulers))

    # dump args to config.yaml (rank 0 only). On PAI the config is serialized
    # with torch.save into an in-memory buffer and pushed to OSS instead.
    # NOTE(review): in the PAI branch the locally opened config.yaml is left
    # empty — presumably intentional since the real copy lives in OSS; confirm.
    if not distributed_option.distributed or distributed_option.dist_rank == 0:
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "config.yaml"), "w") as f:
            logging.info("Saving the configuration in {}/{}".format(args.output_dir, "config.yaml"))
            if args.use_pai:
                buffer = BytesIO()
                torch.save({"config": vars(args)}, buffer)
                args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue())
            else:
                yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)

    # Load pretrained parameters after the model is already on its target
    # device; tensors are mapped onto the current CUDA device (or CPU).
    for p in args.init_param:
        logging.info(f"Loading pretrained params from {p}")
        load_pretrained_model(
            model=model,
            init_param=p,
            ignore_init_mismatch=args.ignore_init_mismatch,
            map_location=f"cuda:{torch.cuda.current_device()}"
            if args.ngpu > 0
            else "cpu",
            oss_bucket=args.oss_bucket,
        )

    # dataloader for training/validation
    train_dataloader, valid_dataloader = build_dataloader(args)

    # Trainer, including model, optimizers, etc.; run() drives the full loop.
    trainer = build_trainer(
        args=args,
        model=model,
        optimizers=optimizers,
        schedulers=schedulers,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        distributed_option=distributed_option
    )
    trainer.run()