# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Trainer module."""
import argparse
import dataclasses
import logging
import os
import time
from contextlib import contextmanager
from dataclasses import is_dataclass
from distutils.version import LooseVersion
from io import BytesIO
from pathlib import Path
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import humanfriendly
import oss2
import torch
import torch.nn
import torch.optim
from typeguard import check_argument_types

from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.main_funcs.average_nbest_models import average_nbest_models
from funasr.models.base_model import FunASRModel
from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
from funasr.schedulers.abs_scheduler import AbsEpochStepScheduler
from funasr.schedulers.abs_scheduler import AbsScheduler
from funasr.schedulers.abs_scheduler import AbsValEpochStepScheduler
from funasr.torch_utils.add_gradient_noise import add_gradient_noise
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.recursive_op import recursive_average
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.train.distributed_utils import DistributedOption
from funasr.train.reporter import Reporter
from funasr.train.reporter import SubReporter
from funasr.utils.build_dataclass import build_dataclass

if torch.distributed.is_available():
    from torch.distributed import ReduceOp

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
    from torch.cuda.amp import GradScaler
else:
    # Nothing to do if torch<1.6.0: autocast becomes a no-op context manager
    @contextmanager
    def autocast(enabled=True):
        yield

    GradScaler = None

try:
    import fairscale
except ImportError:
    fairscale = None

@dataclasses.dataclass
class TrainerOptions:
    ngpu: int
    resume: bool
    use_amp: bool
    train_dtype: str
    grad_noise: bool
    accum_grad: int
    grad_clip: float
    grad_clip_type: float
    log_interval: Optional[int]
    # no_forward_run: bool
    use_tensorboard: bool
    # use_wandb: bool
    output_dir: Union[Path, str]
    max_epoch: int
    max_update: int
    seed: int
    # sharded_ddp: bool
    patience: Optional[int]
    keep_nbest_models: Union[int, List[int]]
    nbest_averaging_interval: int
    early_stopping_criterion: Sequence[str]
    best_model_criterion: Sequence[Sequence[str]]
    val_scheduler_criterion: Sequence[str]
    unused_parameters: bool
    # wandb_model_log_interval: int
    use_pai: bool
    oss_bucket: Union[oss2.Bucket, None]
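
# NOTE: TrainerOptions is filled from the parsed command-line arguments by
# build_dataclass(), which copies the argparse attributes whose names match
# the dataclass fields (see Trainer.build_options below). A minimal sketch;
# the argument values here are illustrative, not FunASR defaults:
#
#     args = argparse.Namespace(
#         ngpu=1, resume=True, use_amp=False, train_dtype="float32",
#         grad_noise=False, accum_grad=1, grad_clip=5.0, grad_clip_type=2.0,
#         log_interval=None, use_tensorboard=True, output_dir="exp/train",
#         max_epoch=50, max_update=2**31, seed=0, patience=None,
#         keep_nbest_models=10, nbest_averaging_interval=0,
#         early_stopping_criterion=("valid", "loss", "min"),
#         best_model_criterion=[("valid", "acc", "max")],
#         val_scheduler_criterion=("valid", "loss"),
#         unused_parameters=False, use_pai=False, oss_bucket=None,
#     )
#     options = build_dataclass(TrainerOptions, args)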

class Trainer:
    """Trainer"""

    def __init__(
        self,
        args,
        model: FunASRModel,
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        train_dataloader: AbsIterFactory,
        valid_dataloader: AbsIterFactory,
        distributed_option: DistributedOption,
    ):
        self.trainer_options = self.build_options(args)
        self.model = model
        self.optimizers = optimizers
        self.schedulers = schedulers
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.distributed_option = distributed_option

    def build_options(self, args: argparse.Namespace) -> TrainerOptions:
        """Build options consumed by train(), eval()"""
        assert check_argument_types()
        return build_dataclass(TrainerOptions, args)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """Reserved for future development of another Trainer"""
        pass

    def resume(
        self,
        checkpoint: Union[str, Path],
        model: torch.nn.Module,
        reporter: Reporter,
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        ngpu: int = 0,
    ):
        states = torch.load(
            checkpoint,
            map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
        )
        model.load_state_dict(states["model"])
        reporter.load_state_dict(states["reporter"])
        for optimizer, state in zip(optimizers, states["optimizers"]):
            optimizer.load_state_dict(state)
        for scheduler, state in zip(schedulers, states["schedulers"]):
            if scheduler is not None:
                scheduler.load_state_dict(state)
        if scaler is not None:
            if states["scaler"] is None:
                logging.warning("scaler state is not found")
            else:
                scaler.load_state_dict(states["scaler"])
        logging.info(f"The training was resumed using {checkpoint}")
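
    # NOTE: resume() assumes the checkpoint layout produced by run() below,
    # i.e. a dict saved with torch.save(). A sketch of the expected structure
    # (the PAI path additionally stores an "ema_model" entry, which resume()
    # does not read):
    #
    #     {
    #         "model": ...,        # model.state_dict()
    #         "reporter": ...,     # Reporter.state_dict()
    #         "optimizers": [...], # one state_dict() per optimizer
    #         "schedulers": [...], # one state_dict() (or None) per scheduler
    #         "scaler": ...,       # GradScaler.state_dict(), or None without AMP
    #     }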

    def run(self) -> None:
        """Perform training. This method performs the main process of training."""
        assert check_argument_types()
        # NOTE(kamo): Don't check the type more strictly as far as
        # trainer_options is concerned
        model = self.model
        optimizers = self.optimizers
        schedulers = self.schedulers
        train_dataloader = self.train_dataloader
        valid_dataloader = self.valid_dataloader
        trainer_options = self.trainer_options
        distributed_option = self.distributed_option
        assert is_dataclass(trainer_options), type(trainer_options)
        assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))

        if isinstance(trainer_options.keep_nbest_models, int):
            keep_nbest_models = [trainer_options.keep_nbest_models]
        else:
            if len(trainer_options.keep_nbest_models) == 0:
                logging.warning("No keep_nbest_models is given. Change to [1]")
                trainer_options.keep_nbest_models = [1]
            keep_nbest_models = trainer_options.keep_nbest_models

        output_dir = Path(trainer_options.output_dir)
        reporter = Reporter()

        if trainer_options.use_amp:
            if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
                raise RuntimeError(
                    "Require torch>=1.6.0 for Automatic Mixed Precision"
                )
            # if trainer_options.sharded_ddp:
            #     if fairscale is None:
            #         raise RuntimeError(
            #             "Requiring fairscale. Do 'pip install fairscale'"
            #         )
            #     scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
            # else:
            scaler = GradScaler()
        else:
            scaler = None

        if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
            self.resume(
                checkpoint=output_dir / "checkpoint.pb",
                model=model,
                optimizers=optimizers,
                schedulers=schedulers,
                reporter=reporter,
                scaler=scaler,
                ngpu=trainer_options.ngpu,
            )

        start_epoch = reporter.get_epoch() + 1
        if start_epoch == trainer_options.max_epoch + 1:
            logging.warning(
                f"The training has already reached max_epoch: {start_epoch}"
            )

        if distributed_option.distributed:
            dp_model = torch.nn.parallel.DistributedDataParallel(
                model, find_unused_parameters=trainer_options.unused_parameters
            )
        elif distributed_option.ngpu > 1:
            dp_model = torch.nn.parallel.DataParallel(
                model,
                device_ids=list(range(distributed_option.ngpu)),
            )
        else:
            # NOTE(kamo): DataParallel also should work with ngpu=1,
            # but for debuggability it's better to keep this block.
            dp_model = model

        if trainer_options.use_tensorboard and (
            not distributed_option.distributed or distributed_option.dist_rank == 0
        ):
            from torch.utils.tensorboard import SummaryWriter

            if trainer_options.use_pai:
                train_summary_writer = SummaryWriter(
                    os.path.join(trainer_options.output_dir, "tensorboard/train")
                )
                valid_summary_writer = SummaryWriter(
                    os.path.join(trainer_options.output_dir, "tensorboard/valid")
                )
            else:
                train_summary_writer = SummaryWriter(
                    str(output_dir / "tensorboard" / "train")
                )
                valid_summary_writer = SummaryWriter(
                    str(output_dir / "tensorboard" / "valid")
                )
        else:
            train_summary_writer = None

        start_time = time.perf_counter()
        for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
            if iepoch != start_epoch:
                logging.info(
                    "{}/{}epoch started. Estimated time to finish: {}".format(
                        iepoch,
                        trainer_options.max_epoch,
                        humanfriendly.format_timespan(
                            (time.perf_counter() - start_time)
                            / (iepoch - start_epoch)
                            * (trainer_options.max_epoch - iepoch + 1)
                        ),
                    )
                )
            else:
                logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started")
            set_all_random_seed(trainer_options.seed + iepoch)

            reporter.set_epoch(iepoch)
            # 1. Train and validation for one epoch
            with reporter.observe("train") as sub_reporter:
                all_steps_are_invalid, max_update_stop = self.train_one_epoch(
                    model=dp_model,
                    optimizers=optimizers,
                    schedulers=schedulers,
                    iterator=train_dataloader.build_iter(iepoch),
                    reporter=sub_reporter,
                    scaler=scaler,
                    summary_writer=train_summary_writer,
                    options=trainer_options,
                    distributed_option=distributed_option,
                )

            with reporter.observe("valid") as sub_reporter:
                self.validate_one_epoch(
                    model=dp_model,
                    iterator=valid_dataloader.build_iter(iepoch),
                    reporter=sub_reporter,
                    options=trainer_options,
                    distributed_option=distributed_option,
                )

            # 2. LR scheduler step
            for scheduler in schedulers:
                if isinstance(scheduler, AbsValEpochStepScheduler):
                    scheduler.step(
                        reporter.get_value(*trainer_options.val_scheduler_criterion)
                    )
                elif isinstance(scheduler, AbsEpochStepScheduler):
                    scheduler.step()
            # if trainer_options.sharded_ddp:
            #     for optimizer in optimizers:
            #         if isinstance(optimizer, fairscale.optim.oss.OSS):
            #             optimizer.consolidate_state_dict()

            if not distributed_option.distributed or distributed_option.dist_rank == 0:
                # 3. Report the results
                logging.info(reporter.log_message())
                if train_summary_writer is not None:
                    reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
                    reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
                # if trainer_options.use_wandb:
                #     reporter.wandb_log()

                # save tensorboard on oss
                if trainer_options.use_pai and train_summary_writer is not None:

                    def write_tensorboard_summary(summary_writer_path, oss_bucket):
                        file_list = []
                        for root, dirs, files in os.walk(
                            summary_writer_path, topdown=False
                        ):
                            for name in files:
                                file_full_path = os.path.join(root, name)
                                file_list.append(file_full_path)
                        for file_full_path in file_list:
                            with open(file_full_path, "rb") as f:
                                oss_bucket.put_object(file_full_path, f)

                    write_tensorboard_summary(
                        os.path.join(trainer_options.output_dir, "tensorboard/train"),
                        trainer_options.oss_bucket,
                    )
                    write_tensorboard_summary(
                        os.path.join(trainer_options.output_dir, "tensorboard/valid"),
                        trainer_options.oss_bucket,
                    )

                # 4. Save/Update the checkpoint
                if trainer_options.use_pai:
                    buffer = BytesIO()
                    torch.save(
                        {
                            "model": model.state_dict(),
                            "reporter": reporter.state_dict(),
                            "optimizers": [o.state_dict() for o in optimizers],
                            "schedulers": [
                                s.state_dict() if s is not None else None
                                for s in schedulers
                            ],
                            "scaler": scaler.state_dict()
                            if scaler is not None
                            else None,
                            "ema_model": model.encoder.ema.model.state_dict()
                            if hasattr(model.encoder, "ema")
                            and model.encoder.ema is not None
                            else None,
                        },
                        buffer,
                    )
                    trainer_options.oss_bucket.put_object(
                        os.path.join(trainer_options.output_dir, "checkpoint.pb"),
                        buffer.getvalue(),
                    )
                else:
                    torch.save(
                        {
                            "model": model.state_dict(),
                            "reporter": reporter.state_dict(),
                            "optimizers": [o.state_dict() for o in optimizers],
                            "schedulers": [
                                s.state_dict() if s is not None else None
                                for s in schedulers
                            ],
                            "scaler": scaler.state_dict()
                            if scaler is not None
                            else None,
                        },
                        output_dir / "checkpoint.pb",
                    )

                # 5. Save and log the model and update the link to the best model
                if trainer_options.use_pai:
                    buffer = BytesIO()
                    torch.save(model.state_dict(), buffer)
                    trainer_options.oss_bucket.put_object(
                        os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"),
                        buffer.getvalue(),
                    )
                else:
                    torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb")

                # Creates a symlink latest.pb -> {iepoch}epoch.pb
                if trainer_options.use_pai:
                    p = os.path.join(trainer_options.output_dir, "latest.pb")
                    if trainer_options.oss_bucket.object_exists(p):
                        trainer_options.oss_bucket.delete_object(p)
                    trainer_options.oss_bucket.copy_object(
                        trainer_options.oss_bucket.bucket_name,
                        os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"),
                        p,
                    )
                else:
                    p = output_dir / "latest.pb"
                    if p.is_symlink() or p.exists():
                        p.unlink()
                    p.symlink_to(f"{iepoch}epoch.pb")

                _improved = []
                for _phase, k, _mode in trainer_options.best_model_criterion:
                    # e.g. _phase, k, _mode = "train", "loss", "min"
                    if reporter.has(_phase, k):
                        best_epoch = reporter.get_best_epoch(_phase, k, _mode)
                        # Creates sym links if it's the best result
                        if best_epoch == iepoch:
                            if trainer_options.use_pai:
                                p = os.path.join(
                                    trainer_options.output_dir,
                                    f"{_phase}.{k}.best.pb",
                                )
                                if trainer_options.oss_bucket.object_exists(p):
                                    trainer_options.oss_bucket.delete_object(p)
                                trainer_options.oss_bucket.copy_object(
                                    trainer_options.oss_bucket.bucket_name,
                                    os.path.join(
                                        trainer_options.output_dir,
                                        f"{iepoch}epoch.pb",
                                    ),
                                    p,
                                )
                            else:
                                p = output_dir / f"{_phase}.{k}.best.pb"
                                if p.is_symlink() or p.exists():
                                    p.unlink()
                                p.symlink_to(f"{iepoch}epoch.pb")
                            _improved.append(f"{_phase}.{k}")
                if len(_improved) == 0:
                    logging.info("There are no improvements in this epoch")
                else:
                    logging.info(
                        "The best model has been updated: " + ", ".join(_improved)
                    )

                # log_model = (
                #     trainer_options.wandb_model_log_interval > 0
                #     and iepoch % trainer_options.wandb_model_log_interval == 0
                # )
                # if log_model and trainer_options.use_wandb:
                #     import wandb
                #
                #     logging.info("Logging Model on this epoch :::::")
                #     artifact = wandb.Artifact(
                #         name=f"model_{wandb.run.id}",
                #         type="model",
                #         metadata={"improved": _improved},
                #     )
                #     artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
                #     aliases = [
                #         f"epoch-{iepoch}",
                #         "best" if best_epoch == iepoch else "",
                #     ]
                #     wandb.log_artifact(artifact, aliases=aliases)

                # 6. Remove the model files excluding the n-best epochs and the latest epoch
                _removed = []
                # Get the union set of the n-best among multiple criterion
                nbests = set().union(
                    *[
                        set(reporter.sort_epochs(ph, k, m)[: max(keep_nbest_models)])
                        for ph, k, m in trainer_options.best_model_criterion
                        if reporter.has(ph, k)
                    ]
                )

                # Generate the n-best averaged model
                if (
                    trainer_options.nbest_averaging_interval > 0
                    and iepoch % trainer_options.nbest_averaging_interval == 0
                ):
                    average_nbest_models(
                        reporter=reporter,
                        output_dir=output_dir,
                        best_model_criterion=trainer_options.best_model_criterion,
                        nbest=keep_nbest_models,
                        suffix=f"till{iepoch}epoch",
                        oss_bucket=trainer_options.oss_bucket,
                        pai_output_dir=trainer_options.output_dir,
                    )

                for e in range(1, iepoch):
                    if trainer_options.use_pai:
                        p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
                        if (
                            trainer_options.oss_bucket.object_exists(p)
                            and e not in nbests
                        ):
                            trainer_options.oss_bucket.delete_object(p)
                            _removed.append(str(p))
                    else:
                        p = output_dir / f"{e}epoch.pb"
                        if p.exists() and e not in nbests:
                            p.unlink()
                            _removed.append(str(p))
                if len(_removed) != 0:
                    logging.info("The model files were removed: " + ", ".join(_removed))

            # 7. If no update has happened at all, stop the training
            if all_steps_are_invalid:
                logging.warning(
                    f"The gradients at all steps are invalid in this epoch. "
                    f"Something seems wrong. This training was stopped at {iepoch}epoch"
                )
                break

            if max_update_stop:
                logging.info(
                    f"Stopping training due to "
                    f"num_updates: {trainer_options.num_updates} >= "
                    f"max_update: {trainer_options.max_update}"
                )
                break

            # 8. Check early stopping
            if trainer_options.patience is not None:
                if reporter.check_early_stopping(
                    trainer_options.patience, *trainer_options.early_stopping_criterion
                ):
                    break
        else:
            logging.info(
                f"The training was finished at {trainer_options.max_epoch} epochs"
            )

        # Generate the n-best averaged model
        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            average_nbest_models(
                reporter=reporter,
                output_dir=output_dir,
                best_model_criterion=trainer_options.best_model_criterion,
                nbest=keep_nbest_models,
                oss_bucket=trainer_options.oss_bucket,
                pai_output_dir=trainer_options.output_dir,
            )
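
    # NOTE: best_model_criterion and the n-best bookkeeping in step 6 above
    # operate on (phase, key, mode) triples. An illustrative configuration
    # (example values, not defaults):
    #
    #     best_model_criterion=[("valid", "loss", "min"), ("valid", "acc", "max")]
    #     keep_nbest_models=[10]
    #
    # With these values, run() keeps the union of the 10 best epochs under
    # each criterion (plus the current epoch) and deletes the remaining
    # "{e}epoch.pb" files.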

    def train_one_epoch(
        self,
        model: torch.nn.Module,
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        reporter: SubReporter,
        summary_writer,
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> Tuple[bool, bool]:
        assert check_argument_types()

        grad_noise = options.grad_noise
        accum_grad = options.accum_grad
        grad_clip = options.grad_clip
        grad_clip_type = options.grad_clip_type
        log_interval = options.log_interval
        # no_forward_run = options.no_forward_run
        ngpu = options.ngpu
        # use_wandb = options.use_wandb
        distributed = distributed_option.distributed

        if log_interval is None:
            try:
                log_interval = max(len(iterator) // 20, 10)
            except TypeError:
                log_interval = 100

        model.train()
        all_steps_are_invalid = True
        max_update_stop = False
        # [For distributed] Because iteration counts are not always equal between
        # processes, send a stop-flag to the other processes if the iterator is finished
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")

        start_time = time.perf_counter()
        for iiter, (_, batch) in enumerate(
            reporter.measure_iter_time(iterator, "iter_time"), 1
        ):
            assert isinstance(batch, dict), type(batch)

            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            # if no_forward_run:
            #     all_steps_are_invalid = False
            #     continue

            with autocast(scaler is not None):
                with reporter.measure_time("forward_time"):
                    retval = model(**batch)

                    # Note(kamo):
                    # Supporting two patterns for the returned value from the model
                    #   a. dict type
                    if isinstance(retval, dict):
                        loss = retval["loss"]
                        stats = retval["stats"]
                        weight = retval["weight"]
                        optim_idx = retval.get("optim_idx")
                        if optim_idx is not None and not isinstance(optim_idx, int):
                            if not isinstance(optim_idx, torch.Tensor):
                                raise RuntimeError(
                                    "optim_idx must be int or 1dim torch.Tensor, "
                                    f"but got {type(optim_idx)}"
                                )
                            if optim_idx.dim() >= 2:
                                raise RuntimeError(
                                    "optim_idx must be int or 1dim torch.Tensor, "
                                    f"but got {optim_idx.dim()}dim tensor"
                                )
                            if optim_idx.dim() == 1:
                                for v in optim_idx:
                                    if v != optim_idx[0]:
                                        raise RuntimeError(
                                            "optim_idx must be 1dim tensor "
                                            "having same values for all entries"
                                        )
                                optim_idx = optim_idx[0].item()
                            else:
                                optim_idx = optim_idx.item()

                    #   b. tuple or list type
                    else:
                        loss, stats, weight = retval
                        optim_idx = None

                stats = {k: v for k, v in stats.items() if v is not None}
                if ngpu > 1 or distributed:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()

                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed)

                    # Now weight is summation over all workers
                    loss /= weight
                    if distributed:
                        # NOTE(kamo): Multiply world_size because DistributedDataParallel
                        # automatically normalizes the gradient by world_size.
                        loss *= torch.distributed.get_world_size()

                loss /= accum_grad

            reporter.register(stats, weight)

            with reporter.measure_time("backward_time"):
                if scaler is not None:
                    # Scales loss. Calls backward() on scaled loss
                    # to create scaled gradients.
                    # Backward passes under autocast are not recommended.
                    # Backward ops run in the same dtype autocast chose
                    # for corresponding forward ops.
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

            if iiter % accum_grad == 0:
                if scaler is not None:
                    # Unscales the gradients of optimizer's assigned params in-place
                    for iopt, optimizer in enumerate(optimizers):
                        if optim_idx is not None and iopt != optim_idx:
                            continue
                        scaler.unscale_(optimizer)

                # gradient noise injection
                if grad_noise:
                    add_gradient_noise(
                        model,
                        reporter.get_total_count(),
                        duration=100,
                        eta=1.0,
                        scale_factor=0.55,
                    )

                # compute the gradient norm to check if it is normal or not
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_norm=grad_clip,
                    norm_type=grad_clip_type,
                )
                # PyTorch<=1.4, clip_grad_norm_ returns float value
                if not isinstance(grad_norm, torch.Tensor):
                    grad_norm = torch.tensor(grad_norm)

                if not torch.isfinite(grad_norm):
                    logging.warning(
                        f"The grad norm is {grad_norm}. Skipping updating the model."
                    )
                    # Must invoke scaler.update() if unscale_() is used in the iteration
                    # to avoid the following error:
                    #   RuntimeError: unscale_() has already been called
                    #   on this optimizer since the last update().
                    # Note that if the gradient has inf/nan values,
                    # scaler.step skips optimizer.step().
                    if scaler is not None:
                        for iopt, optimizer in enumerate(optimizers):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            scaler.step(optimizer)
                            scaler.update()
                else:
                    all_steps_are_invalid = False
                    with reporter.measure_time("optim_step_time"):
                        for iopt, (optimizer, scheduler) in enumerate(
                            zip(optimizers, schedulers)
                        ):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            if scaler is not None:
                                # scaler.step() first unscales the gradients of
                                # the optimizer's assigned params.
                                scaler.step(optimizer)
                                # Updates the scale for next iteration.
                                scaler.update()
                            else:
                                optimizer.step()
                            if isinstance(scheduler, AbsBatchStepScheduler):
                                scheduler.step()

                for iopt, optimizer in enumerate(optimizers):
                    if optim_idx is not None and iopt != optim_idx:
                        continue
                    optimizer.zero_grad()

                # Register lr and train/load time [sec/step],
                # where step refers to accum_grad * mini-batch
                reporter.register(
                    dict(
                        {
                            f"optim{i}_lr{j}": pg["lr"]
                            for i, optimizer in enumerate(optimizers)
                            for j, pg in enumerate(optimizer.param_groups)
                            if "lr" in pg
                        },
                        train_time=time.perf_counter() - start_time,
                    ),
                )
                start_time = time.perf_counter()

                # update num_updates
                if distributed:
                    if hasattr(model.module, "num_updates"):
                        model.module.set_num_updates(
                            model.module.get_num_updates() + 1
                        )
                        options.num_updates = model.module.get_num_updates()
                        if model.module.get_num_updates() >= options.max_update:
                            max_update_stop = True
                else:
                    if hasattr(model, "num_updates"):
                        model.set_num_updates(model.get_num_updates() + 1)
                        options.num_updates = model.get_num_updates()
                        if model.get_num_updates() >= options.max_update:
                            max_update_stop = True

            # NOTE(kamo): Call log_message() after next()
            reporter.next()
            if iiter % log_interval == 0:
                num_updates = (
                    options.num_updates if hasattr(options, "num_updates") else None
                )
                logging.info(
                    reporter.log_message(-log_interval, num_updates=num_updates)
                )
                if summary_writer is not None:
                    reporter.tensorboard_add_scalar(summary_writer, -log_interval)
                # if use_wandb:
                #     reporter.wandb_log()

            if max_update_stop:
                break
        else:
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)

        return all_steps_are_invalid, max_update_stop
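
    # NOTE: with gradient accumulation, train_one_epoch() takes one optimizer
    # step per accum_grad mini-batches, and "loss /= accum_grad" makes the
    # accumulated gradient match that of a single large batch. Illustrative
    # arithmetic (example numbers, not defaults): with batch_size=16 and
    # accum_grad=4, the effective batch size is 16 * 4 = 64, and each
    # backward() contributes loss / 4 to the summed gradient.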

    @torch.no_grad()
    def validate_one_epoch(
        self,
        model: torch.nn.Module,
        iterator: Iterable[Dict[str, torch.Tensor]],
        reporter: SubReporter,
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> None:
        assert check_argument_types()
        ngpu = options.ngpu
        # no_forward_run = options.no_forward_run
        distributed = distributed_option.distributed

        model.eval()

        # [For distributed] Because iteration counts are not always equal between
        # processes, send a stop-flag to the other processes if the iterator is finished
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
        for (_, batch) in iterator:
            assert isinstance(batch, dict), type(batch)
            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break
            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            # if no_forward_run:
            #     continue

            retval = model(**batch)
            if isinstance(retval, dict):
                stats = retval["stats"]
                weight = retval["weight"]
            else:
                _, stats, weight = retval
            if ngpu > 1 or distributed:
                # Apply weighted averaging for stats.
                # if distributed, this method can also apply all_reduce()
                stats, weight = recursive_average(stats, weight, distributed)

            reporter.register(stats, weight)
            reporter.next()
        else:
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
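
# NOTE: train_one_epoch() and validate_one_epoch() accept two return
# conventions from the model's forward(). A sketch of a conforming dict-style
# forward() (hypothetical model code, not part of this module):
#
#     def forward(self, speech, speech_lengths, text, text_lengths):
#         ...
#         stats = {"loss": loss.detach(), "acc": acc}
#         return {
#             "loss": loss,          # scalar tensor used for backward()
#             "stats": stats,        # values to report; None entries are dropped
#             "weight": batch_size,  # weight used for (distributed) averaging
#             # "optim_idx": 0,      # optional: restrict the step to one optimizer
#         }
#
# The tuple convention "return loss, stats, weight" is also supported.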

def build_trainer(
    args,
    model: FunASRModel,
    optimizers: Sequence[torch.optim.Optimizer],
    schedulers: Sequence[Optional[AbsScheduler]],
    train_dataloader: AbsIterFactory,
    valid_dataloader: AbsIterFactory,
    distributed_option: DistributedOption,
):
    trainer = Trainer(
        args=args,
        model=model,
        optimizers=optimizers,
        schedulers=schedulers,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        distributed_option=distributed_option,
    )
    return trainer
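
# Usage sketch (illustrative only; in real FunASR recipes the argument
# plumbing is done by the task classes, and the names below are assumptions):
#
#     model = ...                               # a FunASRModel instance
#     optimizers = [torch.optim.Adam(model.parameters(), lr=1e-3)]
#     schedulers = [None]
#     trainer = build_trainer(
#         args=args,                            # Namespace with the TrainerOptions fields
#         model=model,
#         optimizers=optimizers,
#         schedulers=schedulers,
#         train_dataloader=train_iter_factory,  # an AbsIterFactory
#         valid_dataloader=valid_iter_factory,  # an AbsIterFactory
#         distributed_option=distributed_option,
#     )
#     trainer.run()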