# train.py
  1. #!/usr/bin/env python3
  2. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  3. # MIT License (https://opensource.org/licenses/MIT)
  4. import argparse
  5. import logging
  6. import os
  7. import sys
  8. from io import BytesIO
  9. import torch
  10. from funasr.build_utils.build_args import build_args
  11. from funasr.build_utils.build_dataloader import build_dataloader
  12. from funasr.build_utils.build_distributed import build_distributed
  13. from funasr.build_utils.build_model import build_model
  14. from funasr.build_utils.build_optimizer import build_optimizer
  15. from funasr.build_utils.build_scheduler import build_scheduler
  16. from funasr.build_utils.build_trainer import build_trainer
  17. from funasr.text.phoneme_tokenizer import g2p_choices
  18. from funasr.torch_utils.load_pretrained_model import load_pretrained_model
  19. from funasr.torch_utils.model_summary import model_summary
  20. from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
  21. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  22. from funasr.utils.nested_dict_action import NestedDictAction
  23. from funasr.utils.prepare_data import prepare_data
  24. from funasr.utils.types import int_or_none
  25. from funasr.utils.types import str2bool
  26. from funasr.utils.types import str_or_none
  27. from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
  28. def get_parser():
  29. parser = argparse.ArgumentParser(
  30. description="FunASR Common Training Parser",
  31. )
  32. # common configuration
  33. parser.add_argument("--output_dir", help="model save path")
  34. parser.add_argument(
  35. "--ngpu",
  36. type=int,
  37. default=0,
  38. help="The number of gpus. 0 indicates CPU mode",
  39. )
  40. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  41. parser.add_argument("--task_name", type=str, default="asr", help="Name for different tasks")
  42. # ddp related
  43. parser.add_argument(
  44. "--dist_backend",
  45. default="nccl",
  46. type=str,
  47. help="distributed backend",
  48. )
  49. parser.add_argument(
  50. "--dist_init_method",
  51. type=str,
  52. default="env://",
  53. help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
  54. '"WORLD_SIZE", and "RANK" are referred.',
  55. )
  56. parser.add_argument(
  57. "--dist_world_size",
  58. type=int,
  59. default=1,
  60. help="number of nodes for distributed training",
  61. )
  62. parser.add_argument(
  63. "--dist_rank",
  64. type=int,
  65. default=None,
  66. help="node rank for distributed training",
  67. )
  68. parser.add_argument(
  69. "--local_rank",
  70. type=int,
  71. default=None,
  72. help="local rank for distributed training",
  73. )
  74. parser.add_argument(
  75. "--dist_master_addr",
  76. default=None,
  77. type=str_or_none,
  78. help="The master address for distributed training. "
  79. "This value is used when dist_init_method == 'env://'",
  80. )
  81. parser.add_argument(
  82. "--dist_master_port",
  83. default=None,
  84. type=int_or_none,
  85. help="The master port for distributed training"
  86. "This value is used when dist_init_method == 'env://'",
  87. )
  88. parser.add_argument(
  89. "--dist_launcher",
  90. default=None,
  91. type=str_or_none,
  92. choices=["slurm", "mpi", None],
  93. help="The launcher type for distributed training",
  94. )
  95. parser.add_argument(
  96. "--multiprocessing_distributed",
  97. default=True,
  98. type=str2bool,
  99. help="Use multi-processing distributed training to launch "
  100. "N processes per node, which has N GPUs. This is the "
  101. "fastest way to use PyTorch for either single node or "
  102. "multi node data parallel training",
  103. )
  104. parser.add_argument(
  105. "--unused_parameters",
  106. type=str2bool,
  107. default=False,
  108. help="Whether to use the find_unused_parameters in "
  109. "torch.nn.parallel.DistributedDataParallel ",
  110. )
  111. parser.add_argument(
  112. "--gpu_id",
  113. type=int,
  114. default=0,
  115. help="local gpu id.",
  116. )
  117. # cudnn related
  118. parser.add_argument(
  119. "--cudnn_enabled",
  120. type=str2bool,
  121. default=torch.backends.cudnn.enabled,
  122. help="Enable CUDNN",
  123. )
  124. parser.add_argument(
  125. "--cudnn_benchmark",
  126. type=str2bool,
  127. default=torch.backends.cudnn.benchmark,
  128. help="Enable cudnn-benchmark mode",
  129. )
  130. parser.add_argument(
  131. "--cudnn_deterministic",
  132. type=str2bool,
  133. default=True,
  134. help="Enable cudnn-deterministic mode",
  135. )
  136. # trainer related
  137. parser.add_argument(
  138. "--max_epoch",
  139. type=int,
  140. default=40,
  141. help="The maximum number epoch to train",
  142. )
  143. parser.add_argument(
  144. "--max_update",
  145. type=int,
  146. default=sys.maxsize,
  147. help="The maximum number update step to train",
  148. )
  149. parser.add_argument(
  150. "--batch_interval",
  151. type=int,
  152. default=10000,
  153. help="The batch interval for saving model.",
  154. )
  155. parser.add_argument(
  156. "--patience",
  157. type=int_or_none,
  158. default=None,
  159. help="Number of epochs to wait without improvement "
  160. "before stopping the training",
  161. )
  162. parser.add_argument(
  163. "--val_scheduler_criterion",
  164. type=str,
  165. nargs=2,
  166. default=("valid", "loss"),
  167. help="The criterion used for the value given to the lr scheduler. "
  168. 'Give a pair referring the phase, "train" or "valid",'
  169. 'and the criterion name. The mode specifying "min" or "max" can '
  170. "be changed by --scheduler_conf",
  171. )
  172. parser.add_argument(
  173. "--early_stopping_criterion",
  174. type=str,
  175. nargs=3,
  176. default=("valid", "loss", "min"),
  177. help="The criterion used for judging of early stopping. "
  178. 'Give a pair referring the phase, "train" or "valid",'
  179. 'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
  180. )
  181. parser.add_argument(
  182. "--best_model_criterion",
  183. nargs="+",
  184. default=[
  185. ("train", "loss", "min"),
  186. ("valid", "loss", "min"),
  187. ("train", "acc", "max"),
  188. ("valid", "acc", "max"),
  189. ],
  190. help="The criterion used for judging of the best model. "
  191. 'Give a pair referring the phase, "train" or "valid",'
  192. 'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
  193. )
  194. parser.add_argument(
  195. "--keep_nbest_models",
  196. type=int,
  197. nargs="+",
  198. default=[10],
  199. help="Remove previous snapshots excluding the n-best scored epochs",
  200. )
  201. parser.add_argument(
  202. "--nbest_averaging_interval",
  203. type=int,
  204. default=0,
  205. help="The epoch interval to apply model averaging and save nbest models",
  206. )
  207. parser.add_argument(
  208. "--grad_clip",
  209. type=float,
  210. default=5.0,
  211. help="Gradient norm threshold to clip",
  212. )
  213. parser.add_argument(
  214. "--grad_clip_type",
  215. type=float,
  216. default=2.0,
  217. help="The type of the used p-norm for gradient clip. Can be inf",
  218. )
  219. parser.add_argument(
  220. "--grad_noise",
  221. type=str2bool,
  222. default=False,
  223. help="The flag to switch to use noise injection to "
  224. "gradients during training",
  225. )
  226. parser.add_argument(
  227. "--accum_grad",
  228. type=int,
  229. default=1,
  230. help="The number of gradient accumulation",
  231. )
  232. parser.add_argument(
  233. "--resume",
  234. type=str2bool,
  235. default=False,
  236. help="Enable resuming if checkpoint is existing",
  237. )
  238. parser.add_argument(
  239. "--train_dtype",
  240. default="float32",
  241. choices=["float16", "float32", "float64"],
  242. help="Data type for training.",
  243. )
  244. parser.add_argument(
  245. "--use_amp",
  246. type=str2bool,
  247. default=False,
  248. help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
  249. )
  250. parser.add_argument(
  251. "--log_interval",
  252. default=None,
  253. help="Show the logs every the number iterations in each epochs at the "
  254. "training phase. If None is given, it is decided according the number "
  255. "of training samples automatically .",
  256. )
  257. parser.add_argument(
  258. "--use_tensorboard",
  259. type=str2bool,
  260. default=True,
  261. help="Enable tensorboard logging",
  262. )
  263. # pretrained model related
  264. parser.add_argument(
  265. "--init_param",
  266. type=str,
  267. action="append",
  268. default=[],
  269. help="Specify the file path used for initialization of parameters. "
  270. "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
  271. "where file_path is the model file path, "
  272. "src_key specifies the key of model states to be used in the model file, "
  273. "dst_key specifies the attribute of the model to be initialized, "
  274. "and exclude_keys excludes keys of model states for the initialization."
  275. "e.g.\n"
  276. " # Load all parameters"
  277. " --init_param some/where/model.pb\n"
  278. " # Load only decoder parameters"
  279. " --init_param some/where/model.pb:decoder:decoder\n"
  280. " # Load only decoder parameters excluding decoder.embed"
  281. " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
  282. " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
  283. )
  284. parser.add_argument(
  285. "--ignore_init_mismatch",
  286. type=str2bool,
  287. default=False,
  288. help="Ignore size mismatch when loading pre-trained model",
  289. )
  290. parser.add_argument(
  291. "--freeze_param",
  292. type=str,
  293. default=[],
  294. action="append",
  295. help="Freeze parameters",
  296. )
  297. # dataset related
  298. parser.add_argument(
  299. "--dataset_type",
  300. type=str,
  301. default="small",
  302. help="whether to use dataloader for large dataset",
  303. )
  304. parser.add_argument(
  305. "--dataset_conf",
  306. action=NestedDictAction,
  307. default=dict(),
  308. help=f"The keyword arguments for dataset",
  309. )
  310. parser.add_argument(
  311. "--data_dir",
  312. type=str,
  313. default=None,
  314. help="root path of data",
  315. )
  316. parser.add_argument(
  317. "--train_set",
  318. type=str,
  319. default="train",
  320. help="train dataset",
  321. )
  322. parser.add_argument(
  323. "--valid_set",
  324. type=str,
  325. default="validation",
  326. help="dev dataset",
  327. )
  328. parser.add_argument(
  329. "--data_file_names",
  330. type=str,
  331. default="wav.scp,text",
  332. help="input data files",
  333. )
  334. parser.add_argument(
  335. "--speed_perturb",
  336. type=float,
  337. nargs="+",
  338. default=None,
  339. help="speed perturb",
  340. )
  341. parser.add_argument(
  342. "--use_preprocessor",
  343. type=str2bool,
  344. default=True,
  345. help="Apply preprocessing to data or not",
  346. )
  347. # optimization related
  348. parser.add_argument(
  349. "--optim",
  350. type=lambda x: x.lower(),
  351. default="adam",
  352. help="The optimizer type",
  353. )
  354. parser.add_argument(
  355. "--optim_conf",
  356. action=NestedDictAction,
  357. default=dict(),
  358. help="The keyword arguments for optimizer",
  359. )
  360. parser.add_argument(
  361. "--scheduler",
  362. type=lambda x: str_or_none(x.lower()),
  363. default=None,
  364. help="The lr scheduler type",
  365. )
  366. parser.add_argument(
  367. "--scheduler_conf",
  368. action=NestedDictAction,
  369. default=dict(),
  370. help="The keyword arguments for lr scheduler",
  371. )
  372. # most task related
  373. parser.add_argument(
  374. "--init",
  375. type=lambda x: str_or_none(x.lower()),
  376. default=None,
  377. help="The initialization method",
  378. choices=[
  379. "chainer",
  380. "xavier_uniform",
  381. "xavier_normal",
  382. "kaiming_uniform",
  383. "kaiming_normal",
  384. None,
  385. ],
  386. )
  387. parser.add_argument(
  388. "--token_list",
  389. type=str_or_none,
  390. default=None,
  391. help="A text mapping int-id to token",
  392. )
  393. parser.add_argument(
  394. "--token_type",
  395. type=str,
  396. default="bpe",
  397. choices=["bpe", "char", "word"],
  398. help="",
  399. )
  400. parser.add_argument(
  401. "--bpemodel",
  402. type=str_or_none,
  403. default=None,
  404. help="The model file fo sentencepiece",
  405. )
  406. parser.add_argument(
  407. "--cleaner",
  408. type=str_or_none,
  409. choices=[None, "tacotron", "jaconv", "vietnamese"],
  410. default=None,
  411. help="Apply text cleaning",
  412. )
  413. parser.add_argument(
  414. "--g2p",
  415. type=str_or_none,
  416. choices=g2p_choices,
  417. default=None,
  418. help="Specify g2p method if --token_type=phn",
  419. )
  420. # pai related
  421. parser.add_argument(
  422. "--use_pai",
  423. type=str2bool,
  424. default=False,
  425. help="flag to indicate whether training on PAI",
  426. )
  427. parser.add_argument(
  428. "--simple_ddp",
  429. type=str2bool,
  430. default=False,
  431. )
  432. parser.add_argument(
  433. "--num_worker_count",
  434. type=int,
  435. default=1,
  436. help="The number of machines on PAI.",
  437. )
  438. parser.add_argument(
  439. "--access_key_id",
  440. type=str,
  441. default=None,
  442. help="The username for oss.",
  443. )
  444. parser.add_argument(
  445. "--access_key_secret",
  446. type=str,
  447. default=None,
  448. help="The password for oss.",
  449. )
  450. parser.add_argument(
  451. "--endpoint",
  452. type=str,
  453. default=None,
  454. help="The endpoint for oss.",
  455. )
  456. parser.add_argument(
  457. "--bucket_name",
  458. type=str,
  459. default=None,
  460. help="The bucket name for oss.",
  461. )
  462. parser.add_argument(
  463. "--oss_bucket",
  464. default=None,
  465. help="oss bucket.",
  466. )
  467. return parser
if __name__ == '__main__':
    # Entry point: parse common args, let the task append its own args,
    # set up reproducibility + DDP + logging, build model/optim/data, train.
    parser = get_parser()
    args, extra_task_params = parser.parse_known_args()
    if extra_task_params:
        # Task-specific CLI options (unknown to the common parser) are
        # folded into `args` by the task's own argument builder.
        args = build_args(args, parser, extra_task_params)

    # set random seed and cudnn behavior for reproducibility
    set_all_random_seed(args.seed)
    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark
    torch.backends.cudnn.deterministic = args.cudnn_deterministic

    # ddp init — pin this process to one GPU before initializing the
    # process group; "distributed" means multi-GPU or multi-node.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    args.distributed = args.ngpu > 1 or args.dist_world_size > 1
    distributed_option = build_distributed(args)

    # for logging: only rank 0 (or non-distributed runs) logs at INFO;
    # other ranks are silenced to ERROR to avoid duplicated output.
    if not distributed_option.distributed or distributed_option.dist_rank == 0:
        logging.basicConfig(
            level="INFO",
            format=f"[{os.uname()[1].split('.')[0]}]"
            f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level="ERROR",
            format=f"[{os.uname()[1].split('.')[0]}]"
            f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )

    # prepare files for dataloader
    prepare_data(args, distributed_option)

    # build the task model and move it to the requested dtype/device
    model = build_model(args)
    model = model.to(
        dtype=getattr(torch, args.train_dtype),
        device="cuda" if args.ngpu > 0 else "cpu",
    )

    # freeze any parameter whose name equals a --freeze_param entry or
    # lives under it as a submodule prefix ("t." matches children only).
    for t in args.freeze_param:
        for k, p in model.named_parameters():
            if k.startswith(t + ".") or k == t:
                logging.info(f"Setting {k}.requires_grad = False")
                p.requires_grad = False

    optimizers = build_optimizer(args, model=model)
    schedulers = build_scheduler(args, optimizers)
    logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
                                                                   distributed_option.dist_rank,
                                                                   distributed_option.local_rank))
    logging.info(pytorch_cudnn_version())
    logging.info("Args: {}".format(args))
    logging.info(model_summary(model))
    logging.info("Optimizer: {}".format(optimizers))
    logging.info("Scheduler: {}".format(schedulers))

    # dump args to config.yaml (rank 0 only)
    if not distributed_option.distributed or distributed_option.dist_rank == 0:
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "config.yaml"), "w") as f:
            logging.info("Saving the configuration in {}/{}".format(args.output_dir, "config.yaml"))
            if args.use_pai:
                # PAI path: serialize the config via torch.save and upload
                # it to OSS as config.dict instead of writing YAML locally.
                # NOTE(review): on this branch `f` is never written, so an
                # empty config.yaml is still created — confirm intended.
                # NOTE(review): presumably args.oss_bucket is populated by
                # the PAI/OSS setup (it defaults to None) — verify upstream.
                buffer = BytesIO()
                torch.save({"config": vars(args)}, buffer)
                args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue())
            else:
                yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)

    # load pretrained parameters after the config dump; each --init_param
    # spec is "<file_path>:<src_key>:<dst_key>:<exclude_keys>".
    for p in args.init_param:
        logging.info(f"Loading pretrained params from {p}")
        load_pretrained_model(
            model=model,
            init_param=p,
            ignore_init_mismatch=args.ignore_init_mismatch,
            map_location=f"cuda:{torch.cuda.current_device()}"
            if args.ngpu > 0
            else "cpu",
            oss_bucket=args.oss_bucket,
        )

    # dataloader for training/validation
    train_dataloader, valid_dataloader = build_dataloader(args)

    # Trainer, including model, optimizers, etc.
    trainer = build_trainer(
        args=args,
        model=model,
        optimizers=optimizers,
        schedulers=schedulers,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        distributed_option=distributed_option
    )
    trainer.run()