# build_trainer.py
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
from io import BytesIO

import torch
import yaml

from funasr.build_utils.build_args import build_args
from funasr.build_utils.build_dataloader import build_dataloader
from funasr.build_utils.build_distributed import build_distributed
from funasr.build_utils.build_model import build_model
from funasr.build_utils.build_optimizer import build_optimizer
from funasr.build_utils.build_scheduler import build_scheduler
from funasr.build_utils.build_trainer import build_trainer as build_trainer_modelscope
from funasr.modules.lora.utils import mark_only_lora_as_trainable
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.prepare_data import prepare_data
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
  29. def update_dct(fin_configs, root):
  30. if root == {}:
  31. return {}
  32. for root_key, root_value in root.items():
  33. if not isinstance(root[root_key], dict):
  34. fin_configs[root_key] = root[root_key]
  35. else:
  36. if root_key in fin_configs.keys():
  37. result = update_dct(fin_configs[root_key], root[root_key])
  38. fin_configs[root_key] = result
  39. else:
  40. fin_configs[root_key] = root[root_key]
  41. return fin_configs
  42. def get_parser():
  43. parser = argparse.ArgumentParser(
  44. description="FunASR Common Training Parser",
  45. )
  46. # common configuration
  47. parser.add_argument("--output_dir", help="model save path")
  48. parser.add_argument(
  49. "--ngpu",
  50. type=int,
  51. default=0,
  52. help="The number of gpus. 0 indicates CPU mode",
  53. )
  54. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  55. parser.add_argument("--task_name", type=str, default="asr", help="Name for different tasks")
  56. # ddp related
  57. parser.add_argument(
  58. "--dist_backend",
  59. default="nccl",
  60. type=str,
  61. help="distributed backend",
  62. )
  63. parser.add_argument(
  64. "--dist_init_method",
  65. type=str,
  66. default="env://",
  67. help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
  68. '"WORLD_SIZE", and "RANK" are referred.',
  69. )
  70. parser.add_argument(
  71. "--dist_world_size",
  72. type=int,
  73. default=1,
  74. help="number of nodes for distributed training",
  75. )
  76. parser.add_argument(
  77. "--dist_rank",
  78. type=int,
  79. default=None,
  80. help="node rank for distributed training",
  81. )
  82. parser.add_argument(
  83. "--local_rank",
  84. type=int,
  85. default=None,
  86. help="local rank for distributed training",
  87. )
  88. parser.add_argument(
  89. "--dist_master_addr",
  90. default=None,
  91. type=str_or_none,
  92. help="The master address for distributed training. "
  93. "This value is used when dist_init_method == 'env://'",
  94. )
  95. parser.add_argument(
  96. "--dist_master_port",
  97. default=None,
  98. type=int_or_none,
  99. help="The master port for distributed training"
  100. "This value is used when dist_init_method == 'env://'",
  101. )
  102. parser.add_argument(
  103. "--dist_launcher",
  104. default=None,
  105. type=str_or_none,
  106. choices=["slurm", "mpi", None],
  107. help="The launcher type for distributed training",
  108. )
  109. parser.add_argument(
  110. "--multiprocessing_distributed",
  111. default=True,
  112. type=str2bool,
  113. help="Use multi-processing distributed training to launch "
  114. "N processes per node, which has N GPUs. This is the "
  115. "fastest way to use PyTorch for either single node or "
  116. "multi node data parallel training",
  117. )
  118. parser.add_argument(
  119. "--unused_parameters",
  120. type=str2bool,
  121. default=False,
  122. help="Whether to use the find_unused_parameters in "
  123. "torch.nn.parallel.DistributedDataParallel ",
  124. )
  125. parser.add_argument(
  126. "--gpu_id",
  127. type=int,
  128. default=0,
  129. help="local gpu id.",
  130. )
  131. # cudnn related
  132. parser.add_argument(
  133. "--cudnn_enabled",
  134. type=str2bool,
  135. default=torch.backends.cudnn.enabled,
  136. help="Enable CUDNN",
  137. )
  138. parser.add_argument(
  139. "--cudnn_benchmark",
  140. type=str2bool,
  141. default=torch.backends.cudnn.benchmark,
  142. help="Enable cudnn-benchmark mode",
  143. )
  144. parser.add_argument(
  145. "--cudnn_deterministic",
  146. type=str2bool,
  147. default=True,
  148. help="Enable cudnn-deterministic mode",
  149. )
  150. # trainer related
  151. parser.add_argument(
  152. "--max_epoch",
  153. type=int,
  154. default=40,
  155. help="The maximum number epoch to train",
  156. )
  157. parser.add_argument(
  158. "--max_update",
  159. type=int,
  160. default=sys.maxsize,
  161. help="The maximum number update step to train",
  162. )
  163. parser.add_argument(
  164. "--batch_interval",
  165. type=int,
  166. default=10000,
  167. help="The batch interval for saving model.",
  168. )
  169. parser.add_argument(
  170. "--patience",
  171. type=int_or_none,
  172. default=None,
  173. help="Number of epochs to wait without improvement "
  174. "before stopping the training",
  175. )
  176. parser.add_argument(
  177. "--val_scheduler_criterion",
  178. type=str,
  179. nargs=2,
  180. default=("valid", "loss"),
  181. help="The criterion used for the value given to the lr scheduler. "
  182. 'Give a pair referring the phase, "train" or "valid",'
  183. 'and the criterion name. The mode specifying "min" or "max" can '
  184. "be changed by --scheduler_conf",
  185. )
  186. parser.add_argument(
  187. "--early_stopping_criterion",
  188. type=str,
  189. nargs=3,
  190. default=("valid", "loss", "min"),
  191. help="The criterion used for judging of early stopping. "
  192. 'Give a pair referring the phase, "train" or "valid",'
  193. 'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
  194. )
  195. parser.add_argument(
  196. "--best_model_criterion",
  197. nargs="+",
  198. default=[
  199. ("train", "loss", "min"),
  200. ("valid", "loss", "min"),
  201. ("train", "acc", "max"),
  202. ("valid", "acc", "max"),
  203. ],
  204. help="The criterion used for judging of the best model. "
  205. 'Give a pair referring the phase, "train" or "valid",'
  206. 'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
  207. )
  208. parser.add_argument(
  209. "--keep_nbest_models",
  210. type=int,
  211. nargs="+",
  212. default=[10],
  213. help="Remove previous snapshots excluding the n-best scored epochs",
  214. )
  215. parser.add_argument(
  216. "--nbest_averaging_interval",
  217. type=int,
  218. default=0,
  219. help="The epoch interval to apply model averaging and save nbest models",
  220. )
  221. parser.add_argument(
  222. "--grad_clip",
  223. type=float,
  224. default=5.0,
  225. help="Gradient norm threshold to clip",
  226. )
  227. parser.add_argument(
  228. "--grad_clip_type",
  229. type=float,
  230. default=2.0,
  231. help="The type of the used p-norm for gradient clip. Can be inf",
  232. )
  233. parser.add_argument(
  234. "--grad_noise",
  235. type=str2bool,
  236. default=False,
  237. help="The flag to switch to use noise injection to "
  238. "gradients during training",
  239. )
  240. parser.add_argument(
  241. "--accum_grad",
  242. type=int,
  243. default=1,
  244. help="The number of gradient accumulation",
  245. )
  246. parser.add_argument(
  247. "--resume",
  248. type=str2bool,
  249. default=False,
  250. help="Enable resuming if checkpoint is existing",
  251. )
  252. parser.add_argument(
  253. "--train_dtype",
  254. default="float32",
  255. choices=["float16", "float32", "float64"],
  256. help="Data type for training.",
  257. )
  258. parser.add_argument(
  259. "--use_amp",
  260. type=str2bool,
  261. default=False,
  262. help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
  263. )
  264. parser.add_argument(
  265. "--log_interval",
  266. default=None,
  267. help="Show the logs every the number iterations in each epochs at the "
  268. "training phase. If None is given, it is decided according the number "
  269. "of training samples automatically .",
  270. )
  271. parser.add_argument(
  272. "--use_tensorboard",
  273. type=str2bool,
  274. default=True,
  275. help="Enable tensorboard logging",
  276. )
  277. # pretrained model related
  278. parser.add_argument(
  279. "--init_param",
  280. type=str,
  281. action="append",
  282. default=[],
  283. help="Specify the file path used for initialization of parameters. "
  284. "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
  285. "where file_path is the model file path, "
  286. "src_key specifies the key of model states to be used in the model file, "
  287. "dst_key specifies the attribute of the model to be initialized, "
  288. "and exclude_keys excludes keys of model states for the initialization."
  289. "e.g.\n"
  290. " # Load all parameters"
  291. " --init_param some/where/model.pb\n"
  292. " # Load only decoder parameters"
  293. " --init_param some/where/model.pb:decoder:decoder\n"
  294. " # Load only decoder parameters excluding decoder.embed"
  295. " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
  296. " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
  297. )
  298. parser.add_argument(
  299. "--ignore_init_mismatch",
  300. type=str2bool,
  301. default=False,
  302. help="Ignore size mismatch when loading pre-trained model",
  303. )
  304. parser.add_argument(
  305. "--freeze_param",
  306. type=str,
  307. default=[],
  308. action="append",
  309. help="Freeze parameters",
  310. )
  311. # dataset related
  312. parser.add_argument(
  313. "--dataset_type",
  314. type=str,
  315. default="small",
  316. help="whether to use dataloader for large dataset",
  317. )
  318. parser.add_argument(
  319. "--dataset_conf",
  320. action=NestedDictAction,
  321. default=dict(),
  322. help=f"The keyword arguments for dataset",
  323. )
  324. parser.add_argument(
  325. "--data_dir",
  326. type=str,
  327. default=None,
  328. help="root path of data",
  329. )
  330. parser.add_argument(
  331. "--train_set",
  332. type=str,
  333. default="train",
  334. help="train dataset",
  335. )
  336. parser.add_argument(
  337. "--valid_set",
  338. type=str,
  339. default="validation",
  340. help="dev dataset",
  341. )
  342. parser.add_argument(
  343. "--data_file_names",
  344. type=str,
  345. default="wav.scp,text",
  346. help="input data files",
  347. )
  348. parser.add_argument(
  349. "--speed_perturb",
  350. type=float,
  351. nargs="+",
  352. default=None,
  353. help="speed perturb",
  354. )
  355. parser.add_argument(
  356. "--use_preprocessor",
  357. type=str2bool,
  358. default=True,
  359. help="Apply preprocessing to data or not",
  360. )
  361. # optimization related
  362. parser.add_argument(
  363. "--optim",
  364. type=lambda x: x.lower(),
  365. default="adam",
  366. help="The optimizer type",
  367. )
  368. parser.add_argument(
  369. "--optim_conf",
  370. action=NestedDictAction,
  371. default=dict(),
  372. help="The keyword arguments for optimizer",
  373. )
  374. parser.add_argument(
  375. "--scheduler",
  376. type=lambda x: str_or_none(x.lower()),
  377. default=None,
  378. help="The lr scheduler type",
  379. )
  380. parser.add_argument(
  381. "--scheduler_conf",
  382. action=NestedDictAction,
  383. default=dict(),
  384. help="The keyword arguments for lr scheduler",
  385. )
  386. # most task related
  387. parser.add_argument(
  388. "--init",
  389. type=lambda x: str_or_none(x.lower()),
  390. default=None,
  391. help="The initialization method",
  392. choices=[
  393. "chainer",
  394. "xavier_uniform",
  395. "xavier_normal",
  396. "kaiming_uniform",
  397. "kaiming_normal",
  398. None,
  399. ],
  400. )
  401. parser.add_argument(
  402. "--token_list",
  403. type=str_or_none,
  404. default=None,
  405. help="A text mapping int-id to token",
  406. )
  407. parser.add_argument(
  408. "--token_type",
  409. type=str,
  410. default="bpe",
  411. choices=["bpe", "char", "word"],
  412. help="",
  413. )
  414. parser.add_argument(
  415. "--bpemodel",
  416. type=str_or_none,
  417. default=None,
  418. help="The model file fo sentencepiece",
  419. )
  420. parser.add_argument(
  421. "--cleaner",
  422. type=str_or_none,
  423. choices=[None, "tacotron", "jaconv", "vietnamese"],
  424. default=None,
  425. help="Apply text cleaning",
  426. )
  427. parser.add_argument(
  428. "--g2p",
  429. type=str_or_none,
  430. choices=g2p_choices,
  431. default=None,
  432. help="Specify g2p method if --token_type=phn",
  433. )
  434. # pai related
  435. parser.add_argument(
  436. "--use_pai",
  437. type=str2bool,
  438. default=False,
  439. help="flag to indicate whether training on PAI",
  440. )
  441. parser.add_argument(
  442. "--simple_ddp",
  443. type=str2bool,
  444. default=False,
  445. )
  446. parser.add_argument(
  447. "--num_worker_count",
  448. type=int,
  449. default=1,
  450. help="The number of machines on PAI.",
  451. )
  452. parser.add_argument(
  453. "--access_key_id",
  454. type=str,
  455. default=None,
  456. help="The username for oss.",
  457. )
  458. parser.add_argument(
  459. "--access_key_secret",
  460. type=str,
  461. default=None,
  462. help="The password for oss.",
  463. )
  464. parser.add_argument(
  465. "--endpoint",
  466. type=str,
  467. default=None,
  468. help="The endpoint for oss.",
  469. )
  470. parser.add_argument(
  471. "--bucket_name",
  472. type=str,
  473. default=None,
  474. help="The bucket name for oss.",
  475. )
  476. parser.add_argument(
  477. "--oss_bucket",
  478. default=None,
  479. help="oss bucket.",
  480. )
  481. parser.add_argument(
  482. "--enable_lora",
  483. type=str2bool,
  484. default=False,
  485. help="Apply lora for finetuning.",
  486. )
  487. parser.add_argument(
  488. "--lora_bias",
  489. type=str,
  490. default="none",
  491. help="lora bias.",
  492. )
  493. return parser
def build_trainer(modelscope_dict,
                  data_dir,
                  output_dir,
                  train_set="train",
                  dev_set="validation",
                  distributed=False,
                  dataset_type="small",
                  batch_bins=None,
                  max_epoch=None,
                  optim=None,
                  lr=None,
                  scheduler=None,
                  scheduler_conf=None,
                  specaug=None,
                  specaug_conf=None,
                  mate_params=None,
                  **kwargs):
    """Build a fully-configured FunASR trainer for ModelScope finetuning.

    Merges the pretrained model's config with the finetune config, applies
    caller overrides, initializes DDP/logging/seeds, builds the model,
    optimizer, scheduler and dataloaders, loads pretrained weights, and
    returns the trainer object from ``build_trainer_modelscope``.

    Args:
        modelscope_dict: dict of asset paths; keys used here:
            'am_model_config', 'finetune_config', 'init_model', 'cmvn_file',
            'seg_dict', and optionally 'bpemodel'.
        data_dir: root directory of the training data.
        output_dir: directory for checkpoints and the dumped config.yaml.
        train_set / dev_set: subdirectory names of the train/dev splits.
        distributed: ignored as given — recomputed below from --local_rank.
        dataset_type: "small" or "large" dataloader pipeline.
        batch_bins / max_epoch / optim / lr / scheduler / scheduler_conf /
        specaug / specaug_conf: optional overrides applied onto args.
        mate_params: optional dict of extra overrides (presumably a typo for
            "meta_params" — kept for interface compatibility); may contain
            'lora_params' and 'init_param' with special handling.
        **kwargs: accepted and ignored.

    Returns:
        The trainer built by ``build_trainer_modelscope``.
    """
    parser = get_parser()
    # Known args come from the common parser; task-specific extras are
    # resolved by build_args.
    args, extra_task_params = parser.parse_known_args()
    args = build_args(args, parser, extra_task_params)
    # The `distributed` parameter is deliberately overwritten: presence of
    # --local_rank (set by the DDP launcher) is the source of truth.
    if args.local_rank is not None:
        distributed = True
    else:
        distributed = False
    args.local_rank = args.local_rank if args.local_rank is not None else 0
    local_rank = args.local_rank
    # Pin this process to a single GPU by narrowing CUDA_VISIBLE_DEVICES.
    if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
        gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[args.local_rank])
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
    config = modelscope_dict['am_model_config']
    finetune_config = modelscope_dict['finetune_config']
    init_param = modelscope_dict['init_model']
    cmvn_file = modelscope_dict['cmvn_file']
    seg_dict_file = modelscope_dict['seg_dict']
    if 'bpemodel' in modelscope_dict:
        bpemodel = modelscope_dict['bpemodel']
    else:
        bpemodel = None
    # overwrite parameters: finetune config takes precedence over the
    # pretrained model's config (merged via update_dct below).
    with open(config) as f:
        configs = yaml.safe_load(f)
    with open(finetune_config) as f:
        finetune_configs = yaml.safe_load(f)
    # set data_types for the large-dataset pipeline if not given explicitly
    if dataset_type == "large":
        # finetune_configs["dataset_conf"]["data_types"] = "sound,text"
        if 'data_types' not in finetune_configs['dataset_conf']:
            finetune_configs["dataset_conf"]["data_types"] = "sound,text"
    finetune_configs = update_dct(configs, finetune_configs)
    # Push every merged config key that matches an argparse attribute.
    for key, value in finetune_configs.items():
        if hasattr(args, key):
            setattr(args, key, value)
    # mate_params overrides win over config-file values.
    if mate_params is not None:
        for key, value in mate_params.items():
            if hasattr(args, key):
                setattr(args, key, value)
    if mate_params is not None and "lora_params" in mate_params:
        # NOTE(review): updates `configs`, which is the same dict object as
        # finetune_configs after update_dct — confirm this aliasing is relied on.
        lora_params = mate_params['lora_params']
        configs['encoder_conf'].update(lora_params)
        configs['decoder_conf'].update(lora_params)
    args.dataset_type = dataset_type
    args.init_param = [init_param]
    if mate_params is not None and "init_param" in mate_params:
        if len(mate_params["init_param"]) != 0:
            args.init_param = mate_params["init_param"]
    args.cmvn_file = cmvn_file
    if os.path.exists(seg_dict_file):
        args.seg_dict_file = seg_dict_file
    else:
        args.seg_dict_file = None
    if bpemodel is not None and os.path.exists(bpemodel):
        args.bpemodel = bpemodel
    else:
        args.bpemodel = None
    args.data_dir = data_dir
    args.train_set = train_set
    args.dev_set = dev_set
    args.output_dir = output_dir
    args.gpu_id = args.local_rank
    args.config = finetune_config
    # Hard-coded choices for the ModelScope finetune path.
    args.use_pai = False
    args.batch_type = "length"
    args.oss_bucket = None
    args.input_size = None
    if distributed:
        args.distributed = True
        args.simple_ddp = True
    else:
        args.distributed = False
    # One GPU per process (CUDA_VISIBLE_DEVICES was narrowed above).
    args.ngpu = 1
    # Apply explicit caller overrides, if any.
    if optim is not None:
        args.optim = optim
    if lr is not None:
        args.optim_conf["lr"] = lr
    if scheduler is not None:
        args.scheduler = scheduler
    if scheduler_conf is not None:
        args.scheduler_conf = scheduler_conf
    if specaug is not None:
        args.specaug = specaug
    if specaug_conf is not None:
        args.specaug_conf = specaug_conf
    if max_epoch is not None:
        args.max_epoch = max_epoch
    if batch_bins is not None:
        if args.dataset_type == "small":
            args.batch_bins = batch_bins
            args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
        elif args.dataset_type == "large":
            args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
        else:
            raise ValueError(f"Not supported dataset_type={args.dataset_type}")
    # YAML may encode "no value" as one of these strings; normalize to None.
    if args.normalize in ["null", "none", "None"]:
        args.normalize = None
    if args.patience in ["null", "none", "None"]:
        args.patience = None
    args.local_rank = local_rank
    # set random seed
    set_all_random_seed(args.seed)
    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark
    torch.backends.cudnn.deterministic = args.cudnn_deterministic
    # ddp init
    distributed_option = build_distributed(args)
    # for logging: only rank 0 logs at INFO; other ranks stay quiet (ERROR).
    if not distributed_option.distributed or distributed_option.dist_rank == 0:
        logging.basicConfig(
            level="INFO",
            format=f"[{os.uname()[1].split('.')[0]}]"
            f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level="ERROR",
            format=f"[{os.uname()[1].split('.')[0]}]"
            f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    # prepare files for dataloader
    prepare_data(args, distributed_option)
    model = build_model(args)
    model = model.to(
        dtype=getattr(torch, args.train_dtype),
        device="cuda" if args.ngpu > 0 else "cpu",
    )
    # With LoRA enabled, freeze everything except the LoRA adapters.
    if args.enable_lora:
        mark_only_lora_as_trainable(model, args.lora_bias)
    # Freeze any parameter whose name equals, or is nested under, an entry
    # of --freeze_param.
    for t in args.freeze_param:
        for k, p in model.named_parameters():
            if k.startswith(t + ".") or k == t:
                logging.info(f"Setting {k}.requires_grad = False")
                p.requires_grad = False
    optimizers = build_optimizer(args, model=model)
    schedulers = build_scheduler(args, optimizers)
    logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
                                                                   distributed_option.dist_rank,
                                                                   distributed_option.local_rank))
    logging.info(pytorch_cudnn_version())
    logging.info("Args: {}".format(args))
    logging.info(model_summary(model))
    logging.info("Optimizer: {}".format(optimizers))
    logging.info("Scheduler: {}".format(schedulers))
    # dump args to config.yaml (rank 0 only)
    if not distributed_option.distributed or distributed_option.dist_rank == 0:
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "config.yaml"), "w") as f:
            logging.info("Saving the configuration in {}/{}".format(args.output_dir, "config.yaml"))
            # NOTE(review): args.use_pai was forced to False above, so the
            # OSS branch looks unreachable here — confirm.
            if args.use_pai:
                buffer = BytesIO()
                torch.save({"config": vars(args)}, buffer)
                args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue())
            else:
                yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
    # Load pretrained weights into the freshly-built model.
    for p in args.init_param:
        logging.info(f"Loading pretrained params from {p}")
        load_pretrained_model(
            model=model,
            init_param=p,
            ignore_init_mismatch=args.ignore_init_mismatch,
            map_location=f"cuda:{torch.cuda.current_device()}"
            if args.ngpu > 0
            else "cpu",
            oss_bucket=args.oss_bucket,
        )
    # dataloader for training/validation
    train_dataloader, valid_dataloader = build_dataloader(args)
    # Trainer, including model, optimizers, etc.
    trainer = build_trainer_modelscope(
        args=args,
        model=model,
        optimizers=optimizers,
        schedulers=schedulers,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        distributed_option=distributed_option
    )
    return trainer