trainer.py

# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Trainer module."""
import argparse
from contextlib import contextmanager
import dataclasses
from dataclasses import is_dataclass
from distutils.version import LooseVersion
import logging
from pathlib import Path
import time
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import humanfriendly
import oss2
from io import BytesIO
import os
import numpy as np
import torch
import torch.nn
import torch.optim

from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.main_funcs.average_nbest_models import average_nbest_models
from funasr.main_funcs.calculate_all_attentions import calculate_all_attentions
from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
from funasr.schedulers.abs_scheduler import AbsEpochStepScheduler
from funasr.schedulers.abs_scheduler import AbsScheduler
from funasr.schedulers.abs_scheduler import AbsValEpochStepScheduler
from funasr.torch_utils.add_gradient_noise import add_gradient_noise
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.recursive_op import recursive_average
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.models.base_model import FunASRModel
from funasr.train.distributed_utils import DistributedOption
from funasr.train.reporter import Reporter
from funasr.train.reporter import SubReporter
from funasr.utils.build_dataclass import build_dataclass
from funasr.utils.kwargs2args import kwargs2args

if torch.distributed.is_available():
    from torch.distributed import ReduceOp

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
    from torch.cuda.amp import GradScaler
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield

    GradScaler = None
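
# Note: with the fallback shim above, call sites can use autocast
# unconditionally, e.g. `with autocast(scaler is not None): ...` below is a
# no-op on torch<1.6.0 and enables mixed-precision casting otherwise.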

try:
    import fairscale
except ImportError:
    fairscale = None


@dataclasses.dataclass
class TrainerOptions:
    ngpu: int
    resume: bool
    use_amp: bool
    train_dtype: str
    grad_noise: bool
    accum_grad: int
    grad_clip: float
    grad_clip_type: float
    log_interval: Optional[int]
    no_forward_run: bool
    use_tensorboard: bool
    use_wandb: bool
    output_dir: Union[Path, str]
    max_epoch: int
    max_update: int
    seed: int
    sharded_ddp: bool
    patience: Optional[int]
    keep_nbest_models: Union[int, List[int]]
    nbest_averaging_interval: int
    early_stopping_criterion: Sequence[str]
    best_model_criterion: Sequence[Sequence[str]]
    val_scheduler_criterion: Sequence[str]
    unused_parameters: bool
    wandb_model_log_interval: int
    use_pai: bool
    oss_bucket: Union[oss2.Bucket, None]
    batch_interval: int
    bias_grad_times: float
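
# A minimal sketch of how TrainerOptions is typically populated: one argparse
# argument per field, mapped onto the dataclass by build_dataclass(). The
# arguments and defaults shown are illustrative assumptions, not the project's
# actual CLI:
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--ngpu", type=int, default=0)
#     parser.add_argument("--max_epoch", type=int, default=40)
#     ...  # one argument per TrainerOptions field
#     args = parser.parse_args()
#     options = Trainer.build_options(args)  # build_dataclass(TrainerOptions, args)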


class Trainer:
    """Trainer having an optimizer.

    If you'd like to use multiple optimizers, then inherit this class
    and override the methods if necessary - at least "train_one_epoch()"

    >>> class TwoOptimizerTrainer(Trainer):
    ...     @classmethod
    ...     def add_arguments(cls, parser):
    ...         ...
    ...
    ...     @classmethod
    ...     def train_one_epoch(cls, model, optimizers, ...):
    ...         loss1 = model.model1(...)
    ...         loss1.backward()
    ...         optimizers[0].step()
    ...
    ...         loss2 = model.model2(...)
    ...         loss2.backward()
    ...         optimizers[1].step()
    """

    def __init__(self):
        raise RuntimeError("This class can't be instantiated.")

    @classmethod
    def build_options(cls, args: argparse.Namespace) -> TrainerOptions:
        """Build options consumed by train(), eval()"""
        return build_dataclass(TrainerOptions, args)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """Reserved for future development of another Trainer"""
        pass

    @staticmethod
    def resume(
        checkpoint: Union[str, Path],
        model: torch.nn.Module,
        reporter: Reporter,
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        ngpu: int = 0,
        oss_bucket=None,
    ):
        if oss_bucket is None:
            if os.path.exists(checkpoint):
                states = torch.load(
                    checkpoint,
                    map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
                )
            else:
                return 0
        else:
            if oss_bucket.object_exists(checkpoint):
                buffer = BytesIO(oss_bucket.get_object(checkpoint).read())
                states = torch.load(
                    buffer,
                    map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
                )
            else:
                return 0
        model.load_state_dict(states["model"])
        reporter.load_state_dict(states["reporter"])
        for optimizer, state in zip(optimizers, states["optimizers"]):
            optimizer.load_state_dict(state)
        for scheduler, state in zip(schedulers, states["schedulers"]):
            if scheduler is not None:
                scheduler.load_state_dict(state)
        if scaler is not None:
            if states["scaler"] is None:
                logging.warning("scaler state is not found")
            else:
                scaler.load_state_dict(states["scaler"])
        logging.info(f"The training was resumed using {checkpoint}")
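
    # A minimal usage sketch for resume() (the checkpoint path and the
    # single-optimizer setup are illustrative assumptions):
    #
    #     reporter = Reporter()
    #     Trainer.resume(
    #         checkpoint="exp/train/checkpoint.pb",
    #         model=model,
    #         reporter=reporter,
    #         optimizers=[torch.optim.Adam(model.parameters())],
    #         schedulers=[None],
    #         scaler=None,
    #         ngpu=0,
    #     )
    #     start_epoch = reporter.get_epoch() + 1  # as done in run() below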

    @classmethod
    def run(
        cls,
        model: FunASRModel,
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        train_iter_factory: AbsIterFactory,
        valid_iter_factory: AbsIterFactory,
        trainer_options,
        distributed_option: DistributedOption,
    ) -> None:
        """Perform training. This method performs the main process of training."""
        # NOTE(kamo): Don't check the type of trainer_options more strictly
        assert is_dataclass(trainer_options), type(trainer_options)
        assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))

        if isinstance(trainer_options.keep_nbest_models, int):
            keep_nbest_models = [trainer_options.keep_nbest_models]
        else:
            if len(trainer_options.keep_nbest_models) == 0:
                logging.warning("keep_nbest_models is empty. Defaulting to [1]")
                trainer_options.keep_nbest_models = [1]
            keep_nbest_models = trainer_options.keep_nbest_models

        output_dir = Path(trainer_options.output_dir)
        reporter = Reporter()
        if trainer_options.use_amp:
            if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
                raise RuntimeError(
                    "torch>=1.6.0 is required for Automatic Mixed Precision"
                )
            if trainer_options.sharded_ddp:
                if fairscale is None:
                    raise RuntimeError(
                        "fairscale is required. Do 'pip install fairscale'"
                    )
                scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
            else:
                scaler = GradScaler()
        else:
            scaler = None

        if trainer_options.resume:
            cls.resume(
                checkpoint=os.path.join(trainer_options.output_dir, "checkpoint.pb")
                if trainer_options.use_pai
                else output_dir / "checkpoint.pb",
                model=model,
                optimizers=optimizers,
                schedulers=schedulers,
                reporter=reporter,
                scaler=scaler,
                ngpu=trainer_options.ngpu,
                oss_bucket=trainer_options.oss_bucket if trainer_options.use_pai else None,
            )

        start_epoch = reporter.get_epoch() + 1
        if start_epoch == trainer_options.max_epoch + 1:
            logging.warning(
                f"The training has already reached max_epoch: {start_epoch}"
            )

        if distributed_option.distributed:
            if trainer_options.sharded_ddp:
                dp_model = fairscale.nn.data_parallel.ShardedDataParallel(
                    module=model,
                    sharded_optimizer=optimizers,
                )
            else:
                dp_model = torch.nn.parallel.DistributedDataParallel(
                    model, find_unused_parameters=trainer_options.unused_parameters
                )
        elif distributed_option.ngpu > 1:
            dp_model = torch.nn.parallel.DataParallel(
                model,
                device_ids=list(range(distributed_option.ngpu)),
            )
        else:
            # NOTE(kamo): DataParallel also should work with ngpu=1,
            # but for debuggability it's better to keep this block.
            dp_model = model

        if trainer_options.use_tensorboard and (
            not distributed_option.distributed or distributed_option.dist_rank == 0
        ):
            from torch.utils.tensorboard import SummaryWriter

            if trainer_options.use_pai:
                train_summary_writer = SummaryWriter(
                    os.path.join(trainer_options.output_dir, "tensorboard/train")
                )
                valid_summary_writer = SummaryWriter(
                    os.path.join(trainer_options.output_dir, "tensorboard/valid")
                )
            else:
                train_summary_writer = SummaryWriter(
                    str(output_dir / "tensorboard" / "train")
                )
                valid_summary_writer = SummaryWriter(
                    str(output_dir / "tensorboard" / "valid")
                )
        else:
            train_summary_writer = None

        start_time = time.perf_counter()
        for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
            if iepoch != start_epoch:
                logging.info(
                    "{}/{}epoch started. Estimated time to finish: {}".format(
                        iepoch,
                        trainer_options.max_epoch,
                        humanfriendly.format_timespan(
                            (time.perf_counter() - start_time)
                            / (iepoch - start_epoch)
                            * (trainer_options.max_epoch - iepoch + 1)
                        ),
                    )
                )
            else:
                logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started")
            set_all_random_seed(trainer_options.seed + iepoch)
            reporter.set_epoch(iepoch)

            # 1. Train and validation for one-epoch
            with reporter.observe("train") as sub_reporter:
                all_steps_are_invalid, max_update_stop = cls.train_one_epoch(
                    model=dp_model,
                    optimizers=optimizers,
                    schedulers=schedulers,
                    iterator=train_iter_factory.build_iter(iepoch),
                    reporter=sub_reporter,
                    scaler=scaler,
                    summary_writer=train_summary_writer,
                    options=trainer_options,
                    distributed_option=distributed_option,
                )

            with reporter.observe("valid") as sub_reporter:
                cls.validate_one_epoch(
                    model=dp_model,
                    iterator=valid_iter_factory.build_iter(iepoch),
                    reporter=sub_reporter,
                    options=trainer_options,
                    distributed_option=distributed_option,
                )

            # 2. LR Scheduler step
            for scheduler in schedulers:
                if isinstance(scheduler, AbsValEpochStepScheduler):
                    scheduler.step(
                        reporter.get_value(*trainer_options.val_scheduler_criterion)
                    )
                elif isinstance(scheduler, AbsEpochStepScheduler):
                    scheduler.step()
            if trainer_options.sharded_ddp:
                for optimizer in optimizers:
                    if isinstance(optimizer, fairscale.optim.oss.OSS):
                        optimizer.consolidate_state_dict()

            if not distributed_option.distributed or distributed_option.dist_rank == 0:
                # 3. Report the results
                logging.info(reporter.log_message())
                if train_summary_writer is not None:
                    reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
                    reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
                if trainer_options.use_wandb:
                    reporter.wandb_log()

                # Save the tensorboard summaries to OSS
                if trainer_options.use_pai and train_summary_writer is not None:
                    def write_tensorboard_summary(summary_writer_path, oss_bucket):
                        file_list = []
                        for root, dirs, files in os.walk(summary_writer_path, topdown=False):
                            for name in files:
                                file_full_path = os.path.join(root, name)
                                file_list.append(file_full_path)
                        for file_full_path in file_list:
                            with open(file_full_path, "rb") as f:
                                oss_bucket.put_object(file_full_path, f)

                    write_tensorboard_summary(
                        os.path.join(trainer_options.output_dir, "tensorboard/train"),
                        trainer_options.oss_bucket,
                    )
                    write_tensorboard_summary(
                        os.path.join(trainer_options.output_dir, "tensorboard/valid"),
                        trainer_options.oss_bucket,
                    )

                # 4. Save/Update the checkpoint
                if trainer_options.use_pai:
                    buffer = BytesIO()
                    torch.save(
                        {
                            "model": model.state_dict(),
                            "reporter": reporter.state_dict(),
                            "optimizers": [o.state_dict() for o in optimizers],
                            "schedulers": [
                                s.state_dict() if s is not None else None
                                for s in schedulers
                            ],
                            "scaler": scaler.state_dict() if scaler is not None else None,
                            "ema_model": model.encoder.ema.model.state_dict()
                            if hasattr(model.encoder, "ema") and model.encoder.ema is not None
                            else None,
                        },
                        buffer,
                    )
                    trainer_options.oss_bucket.put_object(
                        os.path.join(trainer_options.output_dir, "checkpoint.pb"),
                        buffer.getvalue(),
                    )
                else:
                    torch.save(
                        {
                            "model": model.state_dict(),
                            "reporter": reporter.state_dict(),
                            "optimizers": [o.state_dict() for o in optimizers],
                            "schedulers": [
                                s.state_dict() if s is not None else None
                                for s in schedulers
                            ],
                            "scaler": scaler.state_dict() if scaler is not None else None,
                        },
                        output_dir / "checkpoint.pb",
                    )

                # 5. Save and log the model and update the link to the best model
                if trainer_options.use_pai:
                    buffer = BytesIO()
                    torch.save(model.state_dict(), buffer)
                    trainer_options.oss_bucket.put_object(
                        os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"),
                        buffer.getvalue(),
                    )
                else:
                    torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb")

                # Create a symlink latest.pb -> {iepoch}epoch.pb
                if trainer_options.use_pai:
                    p = os.path.join(trainer_options.output_dir, "latest.pb")
                    if trainer_options.oss_bucket.object_exists(p):
                        trainer_options.oss_bucket.delete_object(p)
                    trainer_options.oss_bucket.copy_object(
                        trainer_options.oss_bucket.bucket_name,
                        os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"),
                        p,
                    )
                else:
                    p = output_dir / "latest.pb"
                    if p.is_symlink() or p.exists():
                        p.unlink()
                    p.symlink_to(f"{iepoch}epoch.pb")

                _improved = []
                for _phase, k, _mode in trainer_options.best_model_criterion:
                    # e.g. _phase, k, _mode = "train", "loss", "min"
                    if reporter.has(_phase, k):
                        best_epoch = reporter.get_best_epoch(_phase, k, _mode)
                        # Create symlinks if it's the best result
                        if best_epoch == iepoch:
                            if trainer_options.use_pai:
                                p = os.path.join(
                                    trainer_options.output_dir, f"{_phase}.{k}.best.pb"
                                )
                                if trainer_options.oss_bucket.object_exists(p):
                                    trainer_options.oss_bucket.delete_object(p)
                                trainer_options.oss_bucket.copy_object(
                                    trainer_options.oss_bucket.bucket_name,
                                    os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"),
                                    p,
                                )
                            else:
                                p = output_dir / f"{_phase}.{k}.best.pb"
                                if p.is_symlink() or p.exists():
                                    p.unlink()
                                p.symlink_to(f"{iepoch}epoch.pb")
                            _improved.append(f"{_phase}.{k}")
                if len(_improved) == 0:
                    logging.info("There are no improvements in this epoch")
                else:
                    logging.info(
                        "The best model has been updated: " + ", ".join(_improved)
                    )

                log_model = (
                    trainer_options.wandb_model_log_interval > 0
                    and iepoch % trainer_options.wandb_model_log_interval == 0
                )
                if log_model and trainer_options.use_wandb:
                    import wandb

                    logging.info("Logging Model on this epoch :::::")
                    artifact = wandb.Artifact(
                        name=f"model_{wandb.run.id}",
                        type="model",
                        metadata={"improved": _improved},
                    )
                    artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
                    aliases = [
                        f"epoch-{iepoch}",
                        "best" if best_epoch == iepoch else "",
                    ]
                    wandb.log_artifact(artifact, aliases=aliases)

                # 6. Remove the model files excluding n-best epoch and latest epoch
                _removed = []
                # Get the union set of the n-best among multiple criterion
                nbests = set().union(
                    *[
                        set(reporter.sort_epochs(ph, k, m)[: max(keep_nbest_models)])
                        for ph, k, m in trainer_options.best_model_criterion
                        if reporter.has(ph, k)
                    ]
                )

                # Generate the n-best averaged model
                if (
                    trainer_options.nbest_averaging_interval > 0
                    and iepoch % trainer_options.nbest_averaging_interval == 0
                ):
                    average_nbest_models(
                        reporter=reporter,
                        output_dir=output_dir,
                        best_model_criterion=trainer_options.best_model_criterion,
                        nbest=keep_nbest_models,
                        suffix=f"till{iepoch}epoch",
                        oss_bucket=trainer_options.oss_bucket,
                        pai_output_dir=trainer_options.output_dir,
                    )

                for e in range(1, iepoch):
                    if trainer_options.use_pai:
                        p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
                        if trainer_options.oss_bucket.object_exists(p) and e not in nbests:
                            trainer_options.oss_bucket.delete_object(p)
                            _removed.append(str(p))
                    else:
                        p = output_dir / f"{e}epoch.pb"
                        if p.exists() and e not in nbests:
                            p.unlink()
                            _removed.append(str(p))
                if len(_removed) != 0:
                    logging.info("The model files were removed: " + ", ".join(_removed))

            # 7. If no valid update happened in this epoch, stop the training
            if all_steps_are_invalid:
                logging.warning(
                    f"The gradients at all steps are invalid in this epoch. "
                    f"Something seems wrong. This training was stopped at {iepoch}epoch"
                )
                break

            if max_update_stop:
                logging.info(
                    f"Stopping training due to "
                    f"num_updates: {trainer_options.num_updates} >= max_update: {trainer_options.max_update}"
                )
                break

            # 8. Check early stopping
            if trainer_options.patience is not None:
                if reporter.check_early_stopping(
                    trainer_options.patience, *trainer_options.early_stopping_criterion
                ):
                    break
        else:
            logging.info(
                f"The training was finished at {trainer_options.max_epoch} epochs"
            )

        # Generate the n-best averaged model
        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            average_nbest_models(
                reporter=reporter,
                output_dir=output_dir,
                best_model_criterion=trainer_options.best_model_criterion,
                nbest=keep_nbest_models,
                oss_bucket=trainer_options.oss_bucket,
                pai_output_dir=trainer_options.output_dir,
            )
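
    # A minimal sketch of driving a full run (the iterator factories and the
    # single Adam optimizer are illustrative assumptions; the real wiring lives
    # in funasr's task/launch code):
    #
    #     Trainer.run(
    #         model=model,
    #         optimizers=[torch.optim.Adam(model.parameters())],
    #         schedulers=[None],
    #         train_iter_factory=train_iter_factory,
    #         valid_iter_factory=valid_iter_factory,
    #         trainer_options=Trainer.build_options(args),
    #         distributed_option=distributed_option,
    #     )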

    @classmethod
    def train_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        reporter: SubReporter,
        summary_writer,
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> Tuple[bool, bool]:
        grad_noise = options.grad_noise
        accum_grad = options.accum_grad
        grad_clip = options.grad_clip
        grad_clip_type = options.grad_clip_type
        log_interval = options.log_interval
        no_forward_run = options.no_forward_run
        ngpu = options.ngpu
        use_wandb = options.use_wandb
        bias_grad_times = options.bias_grad_times
        distributed = distributed_option.distributed
        if bias_grad_times != 1.0:
            logging.warning(
                "Using bias_grad_times: {} for gradient scaling".format(bias_grad_times)
            )

        if log_interval is None:
            try:
                log_interval = max(len(iterator) // 20, 10)
            except TypeError:
                log_interval = 100

        model.train()
        all_steps_are_invalid = True
        max_update_stop = False
        # [For distributed] Because iteration counts are not always equal across
        # processes, send a stop-flag to the other processes if the iterator is finished
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
        # Get the rank of this process
        rank = distributed_option.dist_rank
        # Counter of batch-level model updates
        num_batch_updates = 0
        # Output dir
        output_dir = Path(options.output_dir)
        # Interval (in updates) for saving step checkpoints
        batch_interval = options.batch_interval

        start_time = time.perf_counter()
        for iiter, (_, batch) in enumerate(
            reporter.measure_iter_time(iterator, "iter_time"), 1
        ):
            assert isinstance(batch, dict), type(batch)

            if batch_interval > 0 and (not distributed_option.distributed or rank == 0):
                if hasattr(model, "num_updates") or (
                    hasattr(model, "module") and hasattr(model.module, "num_updates")
                ):
                    num_batch_updates = (
                        model.get_num_updates()
                        if hasattr(model, "num_updates")
                        else model.module.get_num_updates()
                    )
                if num_batch_updates % batch_interval == 0:
                    if options.use_pai and options.oss_bucket is not None:
                        buffer = BytesIO()
                        if hasattr(model, "module"):
                            torch.save(model.module.state_dict(), buffer)
                        else:
                            torch.save(model.state_dict(), buffer)
                        options.oss_bucket.put_object(
                            os.path.join(output_dir, f"{num_batch_updates}step.pb"),
                            buffer.getvalue(),
                        )
                    else:
                        if hasattr(model, "module"):
                            torch.save(
                                model.module.state_dict(),
                                os.path.join(output_dir, f"{num_batch_updates}step.pb"),
                            )
                        else:
                            torch.save(
                                model.state_dict(),
                                os.path.join(output_dir, f"{num_batch_updates}step.pb"),
                            )

            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                all_steps_are_invalid = False
                continue

            if iiter == 1 and summary_writer is not None:
                try:
                    args = kwargs2args(model.forward, batch)
                except (ValueError, TypeError):
                    logging.warning(
                        "inspect.signature() failed for the model. "
                        "The graph can't be added for tensorboard."
                    )
                else:
                    try:
                        summary_writer.add_graph(model, args, use_strict_trace=False)
                    except Exception:
                        logging.warning(
                            "summary_writer.add_graph() failed for the model. "
                            "The graph can't be added for tensorboard."
                        )
                    del args

            with autocast(scaler is not None):
                with reporter.measure_time("forward_time"):
                    retval = model(**batch)

                    # Note(kamo):
                    # Supporting two patterns for the returned value from the model
                    #   a. dict type
                    if isinstance(retval, dict):
                        loss = retval["loss"]
                        stats = retval["stats"]
                        weight = retval["weight"]
                        optim_idx = retval.get("optim_idx")
                        if optim_idx is not None and not isinstance(optim_idx, int):
                            if not isinstance(optim_idx, torch.Tensor):
                                raise RuntimeError(
                                    "optim_idx must be int or 1dim torch.Tensor, "
                                    f"but got {type(optim_idx)}"
                                )
                            if optim_idx.dim() >= 2:
                                raise RuntimeError(
                                    "optim_idx must be int or 1dim torch.Tensor, "
                                    f"but got {optim_idx.dim()}dim tensor"
                                )
                            if optim_idx.dim() == 1:
                                for v in optim_idx:
                                    if v != optim_idx[0]:
                                        raise RuntimeError(
                                            "optim_idx must be 1dim tensor "
                                            "having same values for all entries"
                                        )
                                optim_idx = optim_idx[0].item()
                            else:
                                optim_idx = optim_idx.item()
                    #   b. tuple or list type
                    else:
                        loss, stats, weight = retval
                        optim_idx = None

                stats = {k: v for k, v in stats.items() if v is not None}
                if ngpu > 1 or distributed:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()
                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed)
                    # Now weight is summation over all workers
                    loss /= weight
                    if distributed:
                        # NOTE(kamo): Multiply world_size because DistributedDataParallel
                        # automatically normalizes the gradient by world_size.
                        loss *= torch.distributed.get_world_size()
                loss /= accum_grad

            reporter.register(stats, weight)
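
            # The model contract assumed above, sketched for reference
            # (illustrative, not a real model in this module): forward() returns
            # either
            #     {"loss": loss, "stats": {"loss": loss.detach()}, "weight": bsz}
            # or the tuple (loss, stats, weight), where `weight` (typically the
            # batch size) drives the weighted averaging across GPUs/workers.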

            with reporter.measure_time("backward_time"):
                if scaler is not None:
                    # Scales loss. Calls backward() on scaled loss
                    # to create scaled gradients.
                    # Backward passes under autocast are not recommended.
                    # Backward ops run in the same dtype autocast chose
                    # for corresponding forward ops.
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

            if iiter % accum_grad == 0:
                if scaler is not None:
                    # Unscales the gradients of optimizer's assigned params in-place
                    for iopt, optimizer in enumerate(optimizers):
                        if optim_idx is not None and iopt != optim_idx:
                            continue
                        scaler.unscale_(optimizer)

                # gradient noise injection
                if grad_noise:
                    add_gradient_noise(
                        model,
                        reporter.get_total_count(),
                        duration=100,
                        eta=1.0,
                        scale_factor=0.55,
                    )

                # For contextual training: scale the gradients of the
                # contextual-bias parameters
                if bias_grad_times != 1.0:
                    # contextual related parameter names
                    cr_pnames = [
                        "bias_encoder",
                        "bias_embed",
                        "decoder.bias_decoder",
                        "decoder.bias_output",
                    ]
                    for name, param in model.named_parameters():
                        for cr_pname in cr_pnames:
                            if cr_pname in name and param.grad is not None:
                                param.grad *= bias_grad_times
                                # Stop at the first match so a parameter is
                                # never scaled twice
                                break

                # compute the gradient norm to check if it is normal or not
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_norm=grad_clip,
                    norm_type=grad_clip_type,
                )
                # PyTorch<=1.4, clip_grad_norm_ returns float value
                if not isinstance(grad_norm, torch.Tensor):
                    grad_norm = torch.tensor(grad_norm)

                if not torch.isfinite(grad_norm):
                    logging.warning(
                        f"The grad norm is {grad_norm}. Skipping updating the model."
                    )
                    # Must invoke scaler.update() if unscale_() is used in the iteration
                    # to avoid the following error:
                    #   RuntimeError: unscale_() has already been called
                    #   on this optimizer since the last update().
                    # Note that if the gradient has inf/nan values,
                    # scaler.step skips optimizer.step().
                    if scaler is not None:
                        for iopt, optimizer in enumerate(optimizers):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            scaler.step(optimizer)
                            scaler.update()
                else:
                    all_steps_are_invalid = False
                    with reporter.measure_time("optim_step_time"):
                        for iopt, (optimizer, scheduler) in enumerate(
                            zip(optimizers, schedulers)
                        ):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            if scaler is not None:
                                # scaler.step() first unscales the gradients of
                                # the optimizer's assigned params.
                                scaler.step(optimizer)
                                # Updates the scale for next iteration.
                                scaler.update()
                            else:
                                optimizer.step()
                            if isinstance(scheduler, AbsBatchStepScheduler):
                                scheduler.step()
                for iopt, optimizer in enumerate(optimizers):
                    if optim_idx is not None and iopt != optim_idx:
                        continue
                    optimizer.zero_grad()

                # Register lr and train/load time[sec/step],
                # where step refers to accum_grad * mini-batch
                reporter.register(
                    dict(
                        {
                            f"optim{i}_lr{j}": pg["lr"]
                            for i, optimizer in enumerate(optimizers)
                            for j, pg in enumerate(optimizer.param_groups)
                            if "lr" in pg
                        },
                        train_time=time.perf_counter() - start_time,
                    ),
                )
                start_time = time.perf_counter()

                # update num_updates
                if distributed:
                    if hasattr(model.module, "num_updates"):
                        model.module.set_num_updates(model.module.get_num_updates() + 1)
                        options.num_updates = model.module.get_num_updates()
                        if model.module.get_num_updates() >= options.max_update:
                            max_update_stop = True
                else:
                    if hasattr(model, "num_updates"):
                        model.set_num_updates(model.get_num_updates() + 1)
                        options.num_updates = model.get_num_updates()
                        if model.get_num_updates() >= options.max_update:
                            max_update_stop = True

            # NOTE(kamo): Call log_message() after next()
            reporter.next()
            if iiter % log_interval == 0:
                num_updates = options.num_updates if hasattr(options, "num_updates") else None
                logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
                if summary_writer is not None:
                    reporter.tensorboard_add_scalar(summary_writer, -log_interval)
                if use_wandb:
                    reporter.wandb_log()
            if max_update_stop:
                break
        else:
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)

        return all_steps_are_invalid, max_update_stop

    @classmethod
    @torch.no_grad()
    def validate_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Dict[str, torch.Tensor]],
        reporter: SubReporter,
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> None:
        ngpu = options.ngpu
        no_forward_run = options.no_forward_run
        distributed = distributed_option.distributed

        model.eval()

        # [For distributed] Because iteration counts are not always equal across
        # processes, send a stop-flag to the other processes if the iterator is finished
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
        for (_, batch) in iterator:
            assert isinstance(batch, dict), type(batch)
            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                continue

            retval = model(**batch)
            if isinstance(retval, dict):
                stats = retval["stats"]
                weight = retval["weight"]
            else:
                _, stats, weight = retval
            if ngpu > 1 or distributed:
                # Apply weighted averaging for stats.
                # if distributed, this method can also apply all_reduce()
                stats, weight = recursive_average(stats, weight, distributed)

            reporter.register(stats, weight)
            reporter.next()
        else:
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)