# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Trainer module."""
import argparse
import dataclasses
import logging
import os
import time
from contextlib import contextmanager
from dataclasses import is_dataclass
from distutils.version import LooseVersion
from io import BytesIO
from pathlib import Path
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import humanfriendly
import oss2
import torch
import torch.nn
import torch.optim

from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.main_funcs.average_nbest_models import average_nbest_models
from funasr.models.base_model import FunASRModel
from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
from funasr.schedulers.abs_scheduler import AbsEpochStepScheduler
from funasr.schedulers.abs_scheduler import AbsScheduler
from funasr.schedulers.abs_scheduler import AbsValEpochStepScheduler
from funasr.torch_utils.add_gradient_noise import add_gradient_noise
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.recursive_op import recursive_average
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.train.distributed_utils import DistributedOption
from funasr.train.reporter import Reporter
from funasr.train.reporter import SubReporter
from funasr.utils.build_dataclass import build_dataclass

if torch.distributed.is_available():
    from torch.distributed import ReduceOp

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
    from torch.cuda.amp import GradScaler
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield

    GradScaler = None

try:
    import fairscale
except ImportError:
    fairscale = None


@dataclasses.dataclass
class TrainerOptions:
    ngpu: int
    resume: bool
    use_amp: bool
    train_dtype: str
    grad_noise: bool
    accum_grad: int
    grad_clip: float
    grad_clip_type: float
    log_interval: Optional[int]
    # no_forward_run: bool
    use_tensorboard: bool
    # use_wandb: bool
    output_dir: Union[Path, str]
    max_epoch: int
    max_update: int
    seed: int
    # sharded_ddp: bool
    patience: Optional[int]
    keep_nbest_models: Union[int, List[int]]
    nbest_averaging_interval: int
    early_stopping_criterion: Sequence[str]
    best_model_criterion: Sequence[Sequence[str]]
    val_scheduler_criterion: Sequence[str]
    unused_parameters: bool
    # wandb_model_log_interval: int
    use_pai: bool
    oss_bucket: Union[oss2.Bucket, None]
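

# NOTE: TrainerOptions is populated from the parsed command-line arguments via
# build_dataclass(TrainerOptions, args) in Trainer.build_options() below, i.e.
# each field is expected to be available as an attribute of the same name on
# the argparse.Namespace.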
class Trainer:
    """Trainer."""

    def __init__(
        self,
        args,
        model: FunASRModel,
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        train_dataloader: AbsIterFactory,
        valid_dataloader: AbsIterFactory,
        distributed_option: DistributedOption,
    ):
        self.trainer_options = self.build_options(args)
        self.model = model
        self.optimizers = optimizers
        self.schedulers = schedulers
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.distributed_option = distributed_option

    def build_options(self, args: argparse.Namespace) -> TrainerOptions:
        """Build options consumed by train(), eval()."""
        return build_dataclass(TrainerOptions, args)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """Reserved for future development of another Trainer."""
        pass
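
    # The checkpoint written by run() stores the model, reporter, optimizer,
    # scheduler and (optionally) AMP GradScaler states; resume() loads each of
    # them back in place so that training can continue from where it stopped.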
    def resume(
        self,
        checkpoint: Union[str, Path],
        model: torch.nn.Module,
        reporter: Reporter,
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        ngpu: int = 0,
    ):
        states = torch.load(
            checkpoint,
            map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
        )
        model.load_state_dict(states["model"])
        reporter.load_state_dict(states["reporter"])
        for optimizer, state in zip(optimizers, states["optimizers"]):
            optimizer.load_state_dict(state)
        for scheduler, state in zip(schedulers, states["schedulers"]):
            if scheduler is not None:
                scheduler.load_state_dict(state)
        if scaler is not None:
            if states["scaler"] is None:
                logging.warning("scaler state is not found")
            else:
                scaler.load_state_dict(states["scaler"])
        logging.info(f"The training was resumed using {checkpoint}")
    def run(self) -> None:
        """Perform training. This method performs the main process of training."""
        # NOTE(kamo): Don't check the type more strictly as long as
        # trainer_options is a dataclass.
        model = self.model
        optimizers = self.optimizers
        schedulers = self.schedulers
        train_dataloader = self.train_dataloader
        valid_dataloader = self.valid_dataloader
        trainer_options = self.trainer_options
        distributed_option = self.distributed_option

        assert is_dataclass(trainer_options), type(trainer_options)
        assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))

        if isinstance(trainer_options.keep_nbest_models, int):
            keep_nbest_models = [trainer_options.keep_nbest_models]
        else:
            if len(trainer_options.keep_nbest_models) == 0:
                logging.warning("No keep_nbest_models is given. Change to [1]")
                trainer_options.keep_nbest_models = [1]
            keep_nbest_models = trainer_options.keep_nbest_models

        output_dir = Path(trainer_options.output_dir)
        reporter = Reporter()
        if trainer_options.use_amp:
            if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
                raise RuntimeError(
                    "Require torch>=1.6.0 for Automatic Mixed Precision"
                )
            # if trainer_options.sharded_ddp:
            #     if fairscale is None:
            #         raise RuntimeError(
            #             "Requiring fairscale. Do 'pip install fairscale'"
            #         )
            #     scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
            # else:
            scaler = GradScaler()
        else:
            scaler = None
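
        # With AMP enabled, GradScaler scales the loss before backward() and
        # unscales the gradients before each optimizer step (see
        # train_one_epoch); with scaler=None the plain backward()/step() path
        # is used instead.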
        if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
            self.resume(
                checkpoint=output_dir / "checkpoint.pb",
                model=model,
                optimizers=optimizers,
                schedulers=schedulers,
                reporter=reporter,
                scaler=scaler,
                ngpu=trainer_options.ngpu,
            )

        start_epoch = reporter.get_epoch() + 1
        if start_epoch == trainer_options.max_epoch + 1:
            logging.warning(
                f"The training has already reached max_epoch: {start_epoch}"
            )

        if distributed_option.distributed:
            dp_model = torch.nn.parallel.DistributedDataParallel(
                model, find_unused_parameters=trainer_options.unused_parameters
            )
        elif distributed_option.ngpu > 1:
            dp_model = torch.nn.parallel.DataParallel(
                model,
                device_ids=list(range(distributed_option.ngpu)),
            )
        else:
            # NOTE(kamo): DataParallel also should work with ngpu=1,
            # but for debuggability it's better to keep this block.
            dp_model = model

        if trainer_options.use_tensorboard and (
            not distributed_option.distributed or distributed_option.dist_rank == 0
        ):
            from torch.utils.tensorboard import SummaryWriter

            if trainer_options.use_pai:
                train_summary_writer = SummaryWriter(
                    os.path.join(trainer_options.output_dir, "tensorboard/train")
                )
                valid_summary_writer = SummaryWriter(
                    os.path.join(trainer_options.output_dir, "tensorboard/valid")
                )
            else:
                train_summary_writer = SummaryWriter(
                    str(output_dir / "tensorboard" / "train")
                )
                valid_summary_writer = SummaryWriter(
                    str(output_dir / "tensorboard" / "valid")
                )
        else:
            train_summary_writer = None
        start_time = time.perf_counter()
        for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
            if iepoch != start_epoch:
                logging.info(
                    "{}/{}epoch started. Estimated time to finish: {}".format(
                        iepoch,
                        trainer_options.max_epoch,
                        humanfriendly.format_timespan(
                            (time.perf_counter() - start_time)
                            / (iepoch - start_epoch)
                            * (trainer_options.max_epoch - iepoch + 1)
                        ),
                    )
                )
            else:
                logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started")
            set_all_random_seed(trainer_options.seed + iepoch)

            reporter.set_epoch(iepoch)
            # 1. Train and validation for one-epoch
            with reporter.observe("train") as sub_reporter:
                all_steps_are_invalid, max_update_stop = self.train_one_epoch(
                    model=dp_model,
                    optimizers=optimizers,
                    schedulers=schedulers,
                    iterator=train_dataloader.build_iter(iepoch),
                    reporter=sub_reporter,
                    scaler=scaler,
                    summary_writer=train_summary_writer,
                    options=trainer_options,
                    distributed_option=distributed_option,
                )

            with reporter.observe("valid") as sub_reporter:
                self.validate_one_epoch(
                    model=dp_model,
                    iterator=valid_dataloader.build_iter(iepoch),
                    reporter=sub_reporter,
                    options=trainer_options,
                    distributed_option=distributed_option,
                )
            # 2. LR Scheduler step
            for scheduler in schedulers:
                if isinstance(scheduler, AbsValEpochStepScheduler):
                    scheduler.step(
                        reporter.get_value(*trainer_options.val_scheduler_criterion)
                    )
                elif isinstance(scheduler, AbsEpochStepScheduler):
                    scheduler.step()
            # if trainer_options.sharded_ddp:
            #     for optimizer in optimizers:
            #         if isinstance(optimizer, fairscale.optim.oss.OSS):
            #             optimizer.consolidate_state_dict()

            if not distributed_option.distributed or distributed_option.dist_rank == 0:
                # 3. Report the results
                logging.info(reporter.log_message())
                if train_summary_writer is not None:
                    reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
                    reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
                # if trainer_options.use_wandb:
                #     reporter.wandb_log()

                # save tensorboard on oss
                if trainer_options.use_pai and train_summary_writer is not None:

                    def write_tensorboard_summary(summary_writer_path, oss_bucket):
                        file_list = []
                        for root, dirs, files in os.walk(summary_writer_path, topdown=False):
                            for name in files:
                                file_full_path = os.path.join(root, name)
                                file_list.append(file_full_path)
                        for file_full_path in file_list:
                            with open(file_full_path, "rb") as f:
                                oss_bucket.put_object(file_full_path, f)

                    write_tensorboard_summary(
                        os.path.join(trainer_options.output_dir, "tensorboard/train"),
                        trainer_options.oss_bucket,
                    )
                    write_tensorboard_summary(
                        os.path.join(trainer_options.output_dir, "tensorboard/valid"),
                        trainer_options.oss_bucket,
                    )
                # 4. Save/Update the checkpoint
                if trainer_options.use_pai:
                    buffer = BytesIO()
                    torch.save(
                        {
                            "model": model.state_dict(),
                            "reporter": reporter.state_dict(),
                            "optimizers": [o.state_dict() for o in optimizers],
                            "schedulers": [
                                s.state_dict() if s is not None else None
                                for s in schedulers
                            ],
                            "scaler": scaler.state_dict() if scaler is not None else None,
                            "ema_model": model.encoder.ema.model.state_dict()
                            if hasattr(model.encoder, "ema") and model.encoder.ema is not None
                            else None,
                        },
                        buffer,
                    )
                    trainer_options.oss_bucket.put_object(
                        os.path.join(trainer_options.output_dir, "checkpoint.pb"),
                        buffer.getvalue(),
                    )
                else:
                    torch.save(
                        {
                            "model": model.state_dict(),
                            "reporter": reporter.state_dict(),
                            "optimizers": [o.state_dict() for o in optimizers],
                            "schedulers": [
                                s.state_dict() if s is not None else None
                                for s in schedulers
                            ],
                            "scaler": scaler.state_dict() if scaler is not None else None,
                        },
                        output_dir / "checkpoint.pb",
                    )
                # 5. Save and log the model and update the link to the best model
                if trainer_options.use_pai:
                    buffer = BytesIO()
                    torch.save(model.state_dict(), buffer)
                    trainer_options.oss_bucket.put_object(
                        os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"),
                        buffer.getvalue(),
                    )
                else:
                    torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb")

                # Creates a sym link latest.pb -> {iepoch}epoch.pb
                if trainer_options.use_pai:
                    p = os.path.join(trainer_options.output_dir, "latest.pb")
                    if trainer_options.oss_bucket.object_exists(p):
                        trainer_options.oss_bucket.delete_object(p)
                    trainer_options.oss_bucket.copy_object(
                        trainer_options.oss_bucket.bucket_name,
                        os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"),
                        p,
                    )
                else:
                    p = output_dir / "latest.pb"
                    if p.is_symlink() or p.exists():
                        p.unlink()
                    p.symlink_to(f"{iepoch}epoch.pb")
                _improved = []
                for _phase, k, _mode in trainer_options.best_model_criterion:
                    # e.g. _phase, k, _mode = "train", "loss", "min"
                    if reporter.has(_phase, k):
                        best_epoch = reporter.get_best_epoch(_phase, k, _mode)
                        # Creates sym links if it's the best result
                        if best_epoch == iepoch:
                            if trainer_options.use_pai:
                                p = os.path.join(
                                    trainer_options.output_dir, f"{_phase}.{k}.best.pb"
                                )
                                if trainer_options.oss_bucket.object_exists(p):
                                    trainer_options.oss_bucket.delete_object(p)
                                trainer_options.oss_bucket.copy_object(
                                    trainer_options.oss_bucket.bucket_name,
                                    os.path.join(
                                        trainer_options.output_dir, f"{iepoch}epoch.pb"
                                    ),
                                    p,
                                )
                            else:
                                p = output_dir / f"{_phase}.{k}.best.pb"
                                if p.is_symlink() or p.exists():
                                    p.unlink()
                                p.symlink_to(f"{iepoch}epoch.pb")
                            _improved.append(f"{_phase}.{k}")
                if len(_improved) == 0:
                    logging.info("There are no improvements in this epoch")
                else:
                    logging.info(
                        "The best model has been updated: " + ", ".join(_improved)
                    )
                # log_model = (
                #     trainer_options.wandb_model_log_interval > 0
                #     and iepoch % trainer_options.wandb_model_log_interval == 0
                # )
                # if log_model and trainer_options.use_wandb:
                #     import wandb
                #
                #     logging.info("Logging Model on this epoch :::::")
                #     artifact = wandb.Artifact(
                #         name=f"model_{wandb.run.id}",
                #         type="model",
                #         metadata={"improved": _improved},
                #     )
                #     artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
                #     aliases = [
                #         f"epoch-{iepoch}",
                #         "best" if best_epoch == iepoch else "",
                #     ]
                #     wandb.log_artifact(artifact, aliases=aliases)
                # 6. Remove the model files excluding n-best epoch and latest epoch
                _removed = []
                # Get the union set of the n-best among multiple criterion
                nbests = set().union(
                    *[
                        set(reporter.sort_epochs(ph, k, m)[: max(keep_nbest_models)])
                        for ph, k, m in trainer_options.best_model_criterion
                        if reporter.has(ph, k)
                    ]
                )

                # Generated n-best averaged model
                if (
                    trainer_options.nbest_averaging_interval > 0
                    and iepoch % trainer_options.nbest_averaging_interval == 0
                ):
                    average_nbest_models(
                        reporter=reporter,
                        output_dir=output_dir,
                        best_model_criterion=trainer_options.best_model_criterion,
                        nbest=keep_nbest_models,
                        suffix=f"till{iepoch}epoch",
                        oss_bucket=trainer_options.oss_bucket,
                        pai_output_dir=trainer_options.output_dir,
                    )

                for e in range(1, iepoch):
                    if trainer_options.use_pai:
                        p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
                        if trainer_options.oss_bucket.object_exists(p) and e not in nbests:
                            trainer_options.oss_bucket.delete_object(p)
                            _removed.append(str(p))
                    else:
                        p = output_dir / f"{e}epoch.pb"
                        if p.exists() and e not in nbests:
                            p.unlink()
                            _removed.append(str(p))
                if len(_removed) != 0:
                    logging.info("The model files were removed: " + ", ".join(_removed))
            # 7. If no update has happened, stop the training
            if all_steps_are_invalid:
                logging.warning(
                    f"The gradients at all steps are invalid in this epoch. "
                    f"Something seems wrong. This training was stopped at {iepoch}epoch"
                )
                break

            if max_update_stop:
                logging.info(
                    f"Stopping training due to "
                    f"num_updates: {trainer_options.num_updates} >= max_update: {trainer_options.max_update}"
                )
                break

            # 8. Check early stopping
            if trainer_options.patience is not None:
                if reporter.check_early_stopping(
                    trainer_options.patience, *trainer_options.early_stopping_criterion
                ):
                    break
        else:
            logging.info(
                f"The training was finished at {trainer_options.max_epoch} epochs"
            )

        # Generated n-best averaged model
        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            average_nbest_models(
                reporter=reporter,
                output_dir=output_dir,
                best_model_criterion=trainer_options.best_model_criterion,
                nbest=keep_nbest_models,
                oss_bucket=trainer_options.oss_bucket,
                pai_output_dir=trainer_options.output_dir,
            )
    def train_one_epoch(
        self,
        model: torch.nn.Module,
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        reporter: SubReporter,
        summary_writer,
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> Tuple[bool, bool]:
        grad_noise = options.grad_noise
        accum_grad = options.accum_grad
        grad_clip = options.grad_clip
        grad_clip_type = options.grad_clip_type
        log_interval = options.log_interval
        # no_forward_run = options.no_forward_run
        ngpu = options.ngpu
        # use_wandb = options.use_wandb
        distributed = distributed_option.distributed
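
        # When log_interval is not given, log roughly 20 times per epoch but no
        # more often than every 10 iterations; fall back to 100 if the iterator
        # has no length (e.g. a streaming iterator raising TypeError on len()).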
        if log_interval is None:
            try:
                log_interval = max(len(iterator) // 20, 10)
            except TypeError:
                log_interval = 100

        model.train()
        all_steps_are_invalid = True
        max_update_stop = False
        # [For distributed] Because iteration counts are not always equal between
        # processes, send a stop flag to the other processes when the iterator is exhausted
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")

        start_time = time.perf_counter()
        for iiter, (_, batch) in enumerate(
            reporter.measure_iter_time(iterator, "iter_time"), 1
        ):
            assert isinstance(batch, dict), type(batch)

            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            # if no_forward_run:
            #     all_steps_are_invalid = False
            #     continue

            with autocast(scaler is not None):
                with reporter.measure_time("forward_time"):
                    retval = model(**batch)

                    # Note(kamo):
                    # Supporting two patterns for the returned value from the model
                    #   a. dict type
                    if isinstance(retval, dict):
                        loss = retval["loss"]
                        stats = retval["stats"]
                        weight = retval["weight"]
                        optim_idx = retval.get("optim_idx")
                        if optim_idx is not None and not isinstance(optim_idx, int):
                            if not isinstance(optim_idx, torch.Tensor):
                                raise RuntimeError(
                                    "optim_idx must be int or 1dim torch.Tensor, "
                                    f"but got {type(optim_idx)}"
                                )
                            if optim_idx.dim() >= 2:
                                raise RuntimeError(
                                    "optim_idx must be int or 1dim torch.Tensor, "
                                    f"but got {optim_idx.dim()}dim tensor"
                                )
                            if optim_idx.dim() == 1:
                                for v in optim_idx:
                                    if v != optim_idx[0]:
                                        raise RuntimeError(
                                            "optim_idx must be 1dim tensor "
                                            "having same values for all entries"
                                        )
                                optim_idx = optim_idx[0].item()
                            else:
                                optim_idx = optim_idx.item()

                    #   b. tuple or list type
                    else:
                        loss, stats, weight = retval
                        optim_idx = None

                stats = {k: v for k, v in stats.items() if v is not None}
                if ngpu > 1 or distributed:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()

                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed)

                    # Now weight is summation over all workers
                    loss /= weight
                if distributed:
                    # NOTE(kamo): Multiply world_size because DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= torch.distributed.get_world_size()

                loss /= accum_grad
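                # The loss is divided by accum_grad above so that accumulating
                # gradients over `accum_grad` consecutive mini-batches (one
                # backward() per mini-batch, one optimizer step every
                # `accum_grad` iterations below) matches the gradient of a
                # single larger batch.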

            reporter.register(stats, weight)

            with reporter.measure_time("backward_time"):
                if scaler is not None:
                    # Scales loss. Calls backward() on scaled loss
                    # to create scaled gradients.
                    # Backward passes under autocast are not recommended.
                    # Backward ops run in the same dtype autocast chose
                    # for corresponding forward ops.
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
            if iiter % accum_grad == 0:
                if scaler is not None:
                    # Unscales the gradients of optimizer's assigned params in-place
                    for iopt, optimizer in enumerate(optimizers):
                        if optim_idx is not None and iopt != optim_idx:
                            continue
                        scaler.unscale_(optimizer)

                # gradient noise injection
                if grad_noise:
                    add_gradient_noise(
                        model,
                        reporter.get_total_count(),
                        duration=100,
                        eta=1.0,
                        scale_factor=0.55,
                    )

                # compute the gradient norm to check if it is normal or not
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_norm=grad_clip,
                    norm_type=grad_clip_type,
                )
                # In PyTorch<=1.4, clip_grad_norm_ returns a float value
                if not isinstance(grad_norm, torch.Tensor):
                    grad_norm = torch.tensor(grad_norm)

                if not torch.isfinite(grad_norm):
                    logging.warning(
                        f"The grad norm is {grad_norm}. Skipping updating the model."
                    )
                    # Must invoke scaler.update() if unscale_() is used in the iteration
                    # to avoid the following error:
                    #   RuntimeError: unscale_() has already been called
                    #   on this optimizer since the last update().
                    # Note that if the gradient has inf/nan values,
                    # scaler.step skips optimizer.step().
                    if scaler is not None:
                        for iopt, optimizer in enumerate(optimizers):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            scaler.step(optimizer)
                            scaler.update()
                else:
                    all_steps_are_invalid = False
                    with reporter.measure_time("optim_step_time"):
                        for iopt, (optimizer, scheduler) in enumerate(
                            zip(optimizers, schedulers)
                        ):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            if scaler is not None:
                                # scaler.step() first unscales the gradients of
                                # the optimizer's assigned params.
                                scaler.step(optimizer)
                                # Updates the scale for next iteration.
                                scaler.update()
                            else:
                                optimizer.step()
                            if isinstance(scheduler, AbsBatchStepScheduler):
                                scheduler.step()
                for iopt, optimizer in enumerate(optimizers):
                    if optim_idx is not None and iopt != optim_idx:
                        continue
                    optimizer.zero_grad()
                # Register lr and train/load time[sec/step],
                # where step refers to accum_grad * mini-batch
                reporter.register(
                    dict(
                        {
                            f"optim{i}_lr{j}": pg["lr"]
                            for i, optimizer in enumerate(optimizers)
                            for j, pg in enumerate(optimizer.param_groups)
                            if "lr" in pg
                        },
                        train_time=time.perf_counter() - start_time,
                    ),
                )
                start_time = time.perf_counter()

                # update num_updates
                if distributed:
                    if hasattr(model.module, "num_updates"):
                        model.module.set_num_updates(model.module.get_num_updates() + 1)
                        options.num_updates = model.module.get_num_updates()
                        if model.module.get_num_updates() >= options.max_update:
                            max_update_stop = True
                else:
                    if hasattr(model, "num_updates"):
                        model.set_num_updates(model.get_num_updates() + 1)
                        options.num_updates = model.get_num_updates()
                        if model.get_num_updates() >= options.max_update:
                            max_update_stop = True

            # NOTE(kamo): Call log_message() after next()
            reporter.next()
            if iiter % log_interval == 0:
                num_updates = options.num_updates if hasattr(options, "num_updates") else None
                logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
                if summary_writer is not None:
                    reporter.tensorboard_add_scalar(summary_writer, -log_interval)
                # if use_wandb:
                #     reporter.wandb_log()

            if max_update_stop:
                break
        else:
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)

        return all_steps_are_invalid, max_update_stop
    @torch.no_grad()
    def validate_one_epoch(
        self,
        model: torch.nn.Module,
        iterator: Iterable[Dict[str, torch.Tensor]],
        reporter: SubReporter,
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> None:
        ngpu = options.ngpu
        # no_forward_run = options.no_forward_run
        distributed = distributed_option.distributed

        model.eval()

        # [For distributed] Because iteration counts are not always equal between
        # processes, send a stop flag to the other processes when the iterator is exhausted
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
        for (_, batch) in iterator:
            assert isinstance(batch, dict), type(batch)
            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            # if no_forward_run:
            #     continue

            retval = model(**batch)
            if isinstance(retval, dict):
                stats = retval["stats"]
                weight = retval["weight"]
            else:
                _, stats, weight = retval
            if ngpu > 1 or distributed:
                # Apply weighted averaging for stats.
                # if distributed, this method can also apply all_reduce()
                stats, weight = recursive_average(stats, weight, distributed)

            reporter.register(stats, weight)
            reporter.next()
        else:
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)


def build_trainer(
    args,
    model: FunASRModel,
    optimizers: Sequence[torch.optim.Optimizer],
    schedulers: Sequence[Optional[AbsScheduler]],
    train_dataloader: AbsIterFactory,
    valid_dataloader: AbsIterFactory,
    distributed_option: DistributedOption,
):
    trainer = Trainer(
        args=args,
        model=model,
        optimizers=optimizers,
        schedulers=schedulers,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        distributed_option=distributed_option,
    )
    return trainer
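

# A minimal usage sketch (hypothetical: `args`, `model`, `optim`, `sched`,
# `train_iter_factory`, `valid_iter_factory` and `dist_option` are assumed to
# be built elsewhere by the task/launcher code; they are not defined here):
#
#     trainer = build_trainer(
#         args=args,
#         model=model,
#         optimizers=[optim],
#         schedulers=[sched],
#         train_dataloader=train_iter_factory,
#         valid_dataloader=valid_iter_factory,
#         distributed_option=dist_option,
#     )
#     trainer.run()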