train.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
  1. #!/usr/bin/env python3
  2. import argparse
  3. import logging
  4. import os
  5. import sys
  6. from io import BytesIO
  7. import torch
  8. from funasr.build_utils.build_args import build_args
  9. from funasr.build_utils.build_dataloader import build_dataloader
  10. from funasr.build_utils.build_distributed import build_distributed
  11. from funasr.build_utils.build_model import build_model
  12. from funasr.build_utils.build_optimizer import build_optimizer
  13. from funasr.build_utils.build_scheduler import build_scheduler
  14. from funasr.build_utils.build_trainer import build_trainer
  15. from funasr.text.phoneme_tokenizer import g2p_choices
  16. from funasr.torch_utils.model_summary import model_summary
  17. from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
  18. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  19. from funasr.utils.nested_dict_action import NestedDictAction
  20. from funasr.utils.prepare_data import prepare_data
  21. from funasr.utils.types import int_or_none
  22. from funasr.utils.types import str2bool
  23. from funasr.utils.types import str2triple_str
  24. from funasr.utils.types import str_or_none
  25. from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
  26. def get_parser():
  27. parser = argparse.ArgumentParser(
  28. description="FunASR Common Training Parser",
  29. )
  30. # common configuration
  31. parser.add_argument("--output_dir", help="model save path")
  32. parser.add_argument(
  33. "--ngpu",
  34. type=int,
  35. default=0,
  36. help="The number of gpus. 0 indicates CPU mode",
  37. )
  38. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  39. parser.add_argument("--task_name", type=str, default="asr", help="Name for different tasks")
  40. # ddp related
  41. parser.add_argument(
  42. "--dist_backend",
  43. default="nccl",
  44. type=str,
  45. help="distributed backend",
  46. )
  47. parser.add_argument(
  48. "--dist_init_method",
  49. type=str,
  50. default="env://",
  51. help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
  52. '"WORLD_SIZE", and "RANK" are referred.',
  53. )
  54. parser.add_argument(
  55. "--dist_world_size",
  56. type=int,
  57. default=1,
  58. help="number of nodes for distributed training",
  59. )
  60. parser.add_argument(
  61. "--dist_rank",
  62. type=int,
  63. default=None,
  64. help="node rank for distributed training",
  65. )
  66. parser.add_argument(
  67. "--local_rank",
  68. type=int,
  69. default=None,
  70. help="local rank for distributed training",
  71. )
  72. parser.add_argument(
  73. "--dist_master_addr",
  74. default=None,
  75. type=str_or_none,
  76. help="The master address for distributed training. "
  77. "This value is used when dist_init_method == 'env://'",
  78. )
  79. parser.add_argument(
  80. "--dist_master_port",
  81. default=None,
  82. type=int_or_none,
  83. help="The master port for distributed training"
  84. "This value is used when dist_init_method == 'env://'",
  85. )
  86. parser.add_argument(
  87. "--dist_launcher",
  88. default=None,
  89. type=str_or_none,
  90. choices=["slurm", "mpi", None],
  91. help="The launcher type for distributed training",
  92. )
  93. parser.add_argument(
  94. "--multiprocessing_distributed",
  95. default=True,
  96. type=str2bool,
  97. help="Use multi-processing distributed training to launch "
  98. "N processes per node, which has N GPUs. This is the "
  99. "fastest way to use PyTorch for either single node or "
  100. "multi node data parallel training",
  101. )
  102. parser.add_argument(
  103. "--unused_parameters",
  104. type=str2bool,
  105. default=False,
  106. help="Whether to use the find_unused_parameters in "
  107. "torch.nn.parallel.DistributedDataParallel ",
  108. )
  109. parser.add_argument(
  110. "--gpu_id",
  111. type=int,
  112. default=0,
  113. help="local gpu id.",
  114. )
  115. # cudnn related
  116. parser.add_argument(
  117. "--cudnn_enabled",
  118. type=str2bool,
  119. default=torch.backends.cudnn.enabled,
  120. help="Enable CUDNN",
  121. )
  122. parser.add_argument(
  123. "--cudnn_benchmark",
  124. type=str2bool,
  125. default=torch.backends.cudnn.benchmark,
  126. help="Enable cudnn-benchmark mode",
  127. )
  128. parser.add_argument(
  129. "--cudnn_deterministic",
  130. type=str2bool,
  131. default=True,
  132. help="Enable cudnn-deterministic mode",
  133. )
  134. # trainer related
  135. parser.add_argument(
  136. "--max_epoch",
  137. type=int,
  138. default=40,
  139. help="The maximum number epoch to train",
  140. )
  141. parser.add_argument(
  142. "--max_update",
  143. type=int,
  144. default=sys.maxsize,
  145. help="The maximum number update step to train",
  146. )
  147. parser.add_argument(
  148. "--batch_interval",
  149. type=int,
  150. default=10000,
  151. help="The batch interval for saving model.",
  152. )
  153. parser.add_argument(
  154. "--patience",
  155. default=None,
  156. help="Number of epochs to wait without improvement "
  157. "before stopping the training",
  158. )
  159. parser.add_argument(
  160. "--val_scheduler_criterion",
  161. type=str,
  162. nargs=2,
  163. default=("valid", "loss"),
  164. help="The criterion used for the value given to the lr scheduler. "
  165. 'Give a pair referring the phase, "train" or "valid",'
  166. 'and the criterion name. The mode specifying "min" or "max" can '
  167. "be changed by --scheduler_conf",
  168. )
  169. parser.add_argument(
  170. "--early_stopping_criterion",
  171. type=str,
  172. nargs=3,
  173. default=("valid", "loss", "min"),
  174. help="The criterion used for judging of early stopping. "
  175. 'Give a pair referring the phase, "train" or "valid",'
  176. 'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
  177. )
  178. parser.add_argument(
  179. "--best_model_criterion",
  180. nargs="+",
  181. default=[
  182. ("train", "loss", "min"),
  183. ("valid", "loss", "min"),
  184. ("train", "acc", "max"),
  185. ("valid", "acc", "max"),
  186. ],
  187. help="The criterion used for judging of the best model. "
  188. 'Give a pair referring the phase, "train" or "valid",'
  189. 'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
  190. )
  191. parser.add_argument(
  192. "--keep_nbest_models",
  193. type=int,
  194. nargs="+",
  195. default=[10],
  196. help="Remove previous snapshots excluding the n-best scored epochs",
  197. )
  198. parser.add_argument(
  199. "--nbest_averaging_interval",
  200. type=int,
  201. default=0,
  202. help="The epoch interval to apply model averaging and save nbest models",
  203. )
  204. parser.add_argument(
  205. "--grad_clip",
  206. type=float,
  207. default=5.0,
  208. help="Gradient norm threshold to clip",
  209. )
  210. parser.add_argument(
  211. "--grad_clip_type",
  212. type=float,
  213. default=2.0,
  214. help="The type of the used p-norm for gradient clip. Can be inf",
  215. )
  216. parser.add_argument(
  217. "--grad_noise",
  218. type=str2bool,
  219. default=False,
  220. help="The flag to switch to use noise injection to "
  221. "gradients during training",
  222. )
  223. parser.add_argument(
  224. "--accum_grad",
  225. type=int,
  226. default=1,
  227. help="The number of gradient accumulation",
  228. )
  229. parser.add_argument(
  230. "--resume",
  231. type=str2bool,
  232. default=False,
  233. help="Enable resuming if checkpoint is existing",
  234. )
  235. parser.add_argument(
  236. "--use_amp",
  237. type=str2bool,
  238. default=False,
  239. help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
  240. )
  241. parser.add_argument(
  242. "--log_interval",
  243. default=None,
  244. help="Show the logs every the number iterations in each epochs at the "
  245. "training phase. If None is given, it is decided according the number "
  246. "of training samples automatically .",
  247. )
  248. # pretrained model related
  249. parser.add_argument(
  250. "--init_param",
  251. type=str,
  252. default=[],
  253. nargs="*",
  254. help="Specify the file path used for initialization of parameters. "
  255. "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
  256. "where file_path is the model file path, "
  257. "src_key specifies the key of model states to be used in the model file, "
  258. "dst_key specifies the attribute of the model to be initialized, "
  259. "and exclude_keys excludes keys of model states for the initialization."
  260. "e.g.\n"
  261. " # Load all parameters"
  262. " --init_param some/where/model.pb\n"
  263. " # Load only decoder parameters"
  264. " --init_param some/where/model.pb:decoder:decoder\n"
  265. " # Load only decoder parameters excluding decoder.embed"
  266. " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
  267. " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
  268. )
  269. parser.add_argument(
  270. "--ignore_init_mismatch",
  271. type=str2bool,
  272. default=False,
  273. help="Ignore size mismatch when loading pre-trained model",
  274. )
  275. parser.add_argument(
  276. "--freeze_param",
  277. type=str,
  278. default=[],
  279. nargs="*",
  280. help="Freeze parameters",
  281. )
  282. # dataset related
  283. parser.add_argument(
  284. "--dataset_type",
  285. type=str,
  286. default="small",
  287. help="whether to use dataloader for large dataset",
  288. )
  289. parser.add_argument(
  290. "--dataset_conf",
  291. action=NestedDictAction,
  292. default=dict(),
  293. help=f"The keyword arguments for dataset",
  294. )
  295. parser.add_argument(
  296. "--train_data_file",
  297. type=str,
  298. default=None,
  299. help="train_list for large dataset",
  300. )
  301. parser.add_argument(
  302. "--valid_data_file",
  303. type=str,
  304. default=None,
  305. help="valid_list for large dataset",
  306. )
  307. parser.add_argument(
  308. "--train_data_path_and_name_and_type",
  309. type=str2triple_str,
  310. action="append",
  311. default=[],
  312. help="e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. ",
  313. )
  314. parser.add_argument(
  315. "--valid_data_path_and_name_and_type",
  316. type=str2triple_str,
  317. action="append",
  318. default=[],
  319. )
  320. parser.add_argument(
  321. "--train_shape_file",
  322. type=str,
  323. action="append",
  324. default=[],
  325. )
  326. parser.add_argument(
  327. "--valid_shape_file",
  328. type=str,
  329. action="append",
  330. default=[],
  331. )
  332. parser.add_argument(
  333. "--use_preprocessor",
  334. type=str2bool,
  335. default=True,
  336. help="Apply preprocessing to data or not",
  337. )
  338. # optimization related
  339. parser.add_argument(
  340. "--optim",
  341. type=lambda x: x.lower(),
  342. default="adam",
  343. help="The optimizer type",
  344. )
  345. parser.add_argument(
  346. "--optim_conf",
  347. action=NestedDictAction,
  348. default=dict(),
  349. help="The keyword arguments for optimizer",
  350. )
  351. parser.add_argument(
  352. "--scheduler",
  353. type=lambda x: str_or_none(x.lower()),
  354. default=None,
  355. help="The lr scheduler type",
  356. )
  357. parser.add_argument(
  358. "--scheduler_conf",
  359. action=NestedDictAction,
  360. default=dict(),
  361. help="The keyword arguments for lr scheduler",
  362. )
  363. # most task related
  364. parser.add_argument(
  365. "--init",
  366. type=lambda x: str_or_none(x.lower()),
  367. default=None,
  368. help="The initialization method",
  369. choices=[
  370. "chainer",
  371. "xavier_uniform",
  372. "xavier_normal",
  373. "kaiming_uniform",
  374. "kaiming_normal",
  375. None,
  376. ],
  377. )
  378. parser.add_argument(
  379. "--token_list",
  380. type=str_or_none,
  381. default=None,
  382. help="A text mapping int-id to token",
  383. )
  384. parser.add_argument(
  385. "--token_type",
  386. type=str,
  387. default="bpe",
  388. choices=["bpe", "char", "word"],
  389. help="",
  390. )
  391. parser.add_argument(
  392. "--bpemodel",
  393. type=str_or_none,
  394. default=None,
  395. help="The model file fo sentencepiece",
  396. )
  397. parser.add_argument(
  398. "--cleaner",
  399. type=str_or_none,
  400. choices=[None, "tacotron", "jaconv", "vietnamese"],
  401. default=None,
  402. help="Apply text cleaning",
  403. )
  404. parser.add_argument(
  405. "--g2p",
  406. type=str_or_none,
  407. choices=g2p_choices,
  408. default=None,
  409. help="Specify g2p method if --token_type=phn",
  410. )
  411. # pai related
  412. parser.add_argument(
  413. "--use_pai",
  414. type=str2bool,
  415. default=False,
  416. help="flag to indicate whether training on PAI",
  417. )
  418. parser.add_argument(
  419. "--simple_ddp",
  420. type=str2bool,
  421. default=False,
  422. )
  423. parser.add_argument(
  424. "--num_worker_count",
  425. type=int,
  426. default=1,
  427. help="The number of machines on PAI.",
  428. )
  429. parser.add_argument(
  430. "--access_key_id",
  431. type=str,
  432. default=None,
  433. help="The username for oss.",
  434. )
  435. parser.add_argument(
  436. "--access_key_secret",
  437. type=str,
  438. default=None,
  439. help="The password for oss.",
  440. )
  441. parser.add_argument(
  442. "--endpoint",
  443. type=str,
  444. default=None,
  445. help="The endpoint for oss.",
  446. )
  447. parser.add_argument(
  448. "--bucket_name",
  449. type=str,
  450. default=None,
  451. help="The bucket name for oss.",
  452. )
  453. parser.add_argument(
  454. "--oss_bucket",
  455. default=None,
  456. help="oss bucket.",
  457. )
  458. return parser
  459. if __name__ == '__main__':
  460. parser = get_parser()
  461. args, extra_task_params = parser.parse_known_args()
  462. if extra_task_params:
  463. args = build_args(args, parser, extra_task_params)
  464. # set random seed
  465. set_all_random_seed(args.seed)
  466. torch.backends.cudnn.enabled = args.cudnn_enabled
  467. torch.backends.cudnn.benchmark = args.cudnn_benchmark
  468. torch.backends.cudnn.deterministic = args.cudnn_deterministic
  469. # ddp init
  470. os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
  471. args.distributed = args.ngpu > 1 or args.dist_world_size > 1
  472. distributed_option = build_distributed(args)
  473. # for logging
  474. if not distributed_option.distributed or distributed_option.dist_rank == 0:
  475. logging.basicConfig(
  476. level="INFO",
  477. format=f"[{os.uname()[1].split('.')[0]}]"
  478. f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  479. )
  480. else:
  481. logging.basicConfig(
  482. level="ERROR",
  483. format=f"[{os.uname()[1].split('.')[0]}]"
  484. f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  485. )
  486. # prepare files for dataloader
  487. prepare_data(args, distributed_option)
  488. model = build_model(args)
  489. optimizers = build_optimizer(args, model=model)
  490. schedulers = build_scheduler(args, optimizers)
  491. logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
  492. distributed_option.dist_rank,
  493. distributed_option.local_rank))
  494. logging.info(pytorch_cudnn_version())
  495. logging.info(model_summary(model))
  496. logging.info("Optimizer: {}".format(optimizers))
  497. logging.info("Scheduler: {}".format(schedulers))
  498. # dump args to config.yaml
  499. if not distributed_option.distributed or distributed_option.dist_rank == 0:
  500. os.makedirs(args.output_dir, exist_ok=True)
  501. with open(os.path.join(args.output_dir, "config.yaml"), "w") as f:
  502. logging.info("Saving the configuration in {}/{}".format(args.output_dir, "config.yaml"))
  503. if args.use_pai:
  504. buffer = BytesIO()
  505. torch.save({"config": vars(args)}, buffer)
  506. args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue())
  507. else:
  508. yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
  509. # dataloader for training/validation
  510. train_dataloader, valid_dataloader = build_dataloader(args)
  511. # Trainer, including model, optimizers, etc.
  512. trainer = build_trainer(
  513. args=args,
  514. model=model,
  515. optimizers=optimizers,
  516. schedulers=schedulers,
  517. train_dataloader=train_dataloader,
  518. valid_dataloader=valid_dataloader,
  519. distributed_option=distributed_option
  520. )
  521. trainer.run()