trainer.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840
  1. # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. """Trainer module."""
  4. import argparse
  5. from contextlib import contextmanager
  6. import dataclasses
  7. from dataclasses import is_dataclass
  8. from distutils.version import LooseVersion
  9. import logging
  10. from pathlib import Path
  11. import time
  12. from typing import Dict
  13. from typing import Iterable
  14. from typing import List
  15. from typing import Optional
  16. from typing import Sequence
  17. from typing import Tuple
  18. from typing import Union
  19. import humanfriendly
  20. import oss2
  21. from io import BytesIO
  22. import os
  23. import numpy as np
  24. import torch
  25. import torch.nn
  26. import torch.optim
  27. from typeguard import check_argument_types
  28. from funasr.iterators.abs_iter_factory import AbsIterFactory
  29. from funasr.main_funcs.average_nbest_models import average_nbest_models
  30. from funasr.main_funcs.calculate_all_attentions import calculate_all_attentions
  31. from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
  32. from funasr.schedulers.abs_scheduler import AbsEpochStepScheduler
  33. from funasr.schedulers.abs_scheduler import AbsScheduler
  34. from funasr.schedulers.abs_scheduler import AbsValEpochStepScheduler
  35. from funasr.torch_utils.add_gradient_noise import add_gradient_noise
  36. from funasr.torch_utils.device_funcs import to_device
  37. from funasr.torch_utils.recursive_op import recursive_average
  38. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  39. from funasr.train.abs_espnet_model import AbsESPnetModel
  40. from funasr.train.distributed_utils import DistributedOption
  41. from funasr.train.reporter import Reporter
  42. from funasr.train.reporter import SubReporter
  43. from funasr.utils.build_dataclass import build_dataclass
if torch.distributed.is_available():
    # Only imported when torch.distributed is available; used by the
    # all_reduce() calls that share iterator-stop flags between ranks.
    from torch.distributed import ReduceOp
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
    from torch.cuda.amp import GradScaler
else:
    # Nothing to do if torch<1.6.0
    # Fallback: a no-op ``autocast`` context manager, and ``GradScaler`` set
    # to None (Trainer.run() refuses use_amp on torch<1.6.0 anyway).
    @contextmanager
    def autocast(enabled=True):
        yield

    GradScaler = None
try:
    import fairscale
except ImportError:
    # fairscale is optional; Trainer needs it only when sharded_ddp is set.
    fairscale = None
@dataclasses.dataclass
class TrainerOptions:
    """Options consumed by ``Trainer.run()``, ``Trainer.train_one_epoch()``
    and ``Trainer.validate_one_epoch()``.

    Normally built from an ``argparse.Namespace`` via ``Trainer.build_options()``.
    Note: ``Trainer.train_one_epoch()`` also sets an undeclared ``num_updates``
    attribute on the instance when the model exposes ``num_updates``.
    """

    ngpu: int  # > 0 places tensors on "cuda", otherwise on "cpu"
    resume: bool  # resume from <output_dir>/checkpoint.pb when that file exists
    use_amp: bool  # mixed precision through a (Sharded)GradScaler; requires torch>=1.6.0
    train_dtype: str  # not read by the Trainer code shown in this file
    grad_noise: bool  # inject gradient noise (add_gradient_noise) before clipping
    accum_grad: int  # mini-batches accumulated per optimizer step (loss is divided by it)
    grad_clip: float  # max_norm passed to clip_grad_norm_
    grad_clip_type: float  # norm_type passed to clip_grad_norm_
    log_interval: Optional[int]  # iterations between train logs; None -> max(len(iterator)//20, 10), or 100
    no_forward_run: bool  # only iterate over batches; skip forward/backward/update
    use_tensorboard: bool  # write TensorBoard summaries (main process only)
    use_wandb: bool  # log to Weights & Biases
    output_dir: Union[Path, str]  # local output directory, or OSS key prefix when use_pai is set
    max_epoch: int  # last epoch to train (inclusive)
    max_update: int  # stop once the model's num_updates reaches this value
    seed: int  # base random seed; every epoch reseeds with seed + epoch
    sharded_ddp: bool  # use fairscale ShardedDataParallel / ShardedGradScaler
    patience: Optional[int]  # early-stopping patience; None disables early stopping
    keep_nbest_models: Union[int, List[int]]  # n-best sizes to average; the max is kept when pruning
    nbest_averaging_interval: int  # average the n-best models every N epochs; <= 0 disables
    early_stopping_criterion: Sequence[str]  # forwarded to Reporter.check_early_stopping()
    best_model_criterion: Sequence[Sequence[str]]  # (phase, key, mode) triples, e.g. ("train", "loss", "min")
    val_scheduler_criterion: Sequence[str]  # forwarded to Reporter.get_value() for AbsValEpochStepScheduler
    unused_parameters: bool  # find_unused_parameters for DistributedDataParallel
    wandb_model_log_interval: int  # log the epoch model as a wandb artifact every N epochs; <= 0 disables
    use_pai: bool  # store checkpoints/models in oss_bucket instead of the local filesystem
    oss_bucket: Union[oss2.Bucket, None]  # OSS bucket used when use_pai is set
    batch_interval: int  # save "<num_updates>step.pb" when num_updates % N == 0; <= 0 disables
  89. class Trainer:
  90. """Trainer having a optimizer.
  91. If you'd like to use multiple optimizers, then inherit this class
  92. and override the methods if necessary - at least "train_one_epoch()"
  93. >>> class TwoOptimizerTrainer(Trainer):
  94. ... @classmethod
  95. ... def add_arguments(cls, parser):
  96. ... ...
  97. ...
  98. ... @classmethod
  99. ... def train_one_epoch(cls, model, optimizers, ...):
  100. ... loss1 = model.model1(...)
  101. ... loss1.backward()
  102. ... optimizers[0].step()
  103. ...
  104. ... loss2 = model.model2(...)
  105. ... loss2.backward()
  106. ... optimizers[1].step()
  107. """
  108. def __init__(self):
  109. raise RuntimeError("This class can't be instantiated.")
  110. @classmethod
  111. def build_options(cls, args: argparse.Namespace) -> TrainerOptions:
  112. """Build options consumed by train(), eval()"""
  113. assert check_argument_types()
  114. return build_dataclass(TrainerOptions, args)
  115. @classmethod
  116. def add_arguments(cls, parser: argparse.ArgumentParser):
  117. """Reserved for future development of another Trainer"""
  118. pass
  119. @staticmethod
  120. def resume(
  121. checkpoint: Union[str, Path],
  122. model: torch.nn.Module,
  123. reporter: Reporter,
  124. optimizers: Sequence[torch.optim.Optimizer],
  125. schedulers: Sequence[Optional[AbsScheduler]],
  126. scaler: Optional[GradScaler],
  127. ngpu: int = 0,
  128. ):
  129. states = torch.load(
  130. checkpoint,
  131. map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
  132. )
  133. model.load_state_dict(states["model"])
  134. reporter.load_state_dict(states["reporter"])
  135. for optimizer, state in zip(optimizers, states["optimizers"]):
  136. optimizer.load_state_dict(state)
  137. for scheduler, state in zip(schedulers, states["schedulers"]):
  138. if scheduler is not None:
  139. scheduler.load_state_dict(state)
  140. if scaler is not None:
  141. if states["scaler"] is None:
  142. logging.warning("scaler state is not found")
  143. else:
  144. scaler.load_state_dict(states["scaler"])
  145. logging.info(f"The training was resumed using {checkpoint}")
  146. @classmethod
  147. def run(
  148. cls,
  149. model: AbsESPnetModel,
  150. optimizers: Sequence[torch.optim.Optimizer],
  151. schedulers: Sequence[Optional[AbsScheduler]],
  152. train_iter_factory: AbsIterFactory,
  153. valid_iter_factory: AbsIterFactory,
  154. trainer_options,
  155. distributed_option: DistributedOption,
  156. ) -> None:
  157. """Perform training. This method performs the main process of training."""
  158. assert check_argument_types()
  159. # NOTE(kamo): Don't check the type more strictly as far trainer_options
  160. assert is_dataclass(trainer_options), type(trainer_options)
  161. assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))
  162. if isinstance(trainer_options.keep_nbest_models, int):
  163. keep_nbest_models = [trainer_options.keep_nbest_models]
  164. else:
  165. if len(trainer_options.keep_nbest_models) == 0:
  166. logging.warning("No keep_nbest_models is given. Change to [1]")
  167. trainer_options.keep_nbest_models = [1]
  168. keep_nbest_models = trainer_options.keep_nbest_models
  169. output_dir = Path(trainer_options.output_dir)
  170. reporter = Reporter()
  171. if trainer_options.use_amp:
  172. if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
  173. raise RuntimeError(
  174. "Require torch>=1.6.0 for Automatic Mixed Precision"
  175. )
  176. if trainer_options.sharded_ddp:
  177. if fairscale is None:
  178. raise RuntimeError(
  179. "Requiring fairscale. Do 'pip install fairscale'"
  180. )
  181. scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
  182. else:
  183. scaler = GradScaler()
  184. else:
  185. scaler = None
  186. if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
  187. cls.resume(
  188. checkpoint=output_dir / "checkpoint.pb",
  189. model=model,
  190. optimizers=optimizers,
  191. schedulers=schedulers,
  192. reporter=reporter,
  193. scaler=scaler,
  194. ngpu=trainer_options.ngpu,
  195. )
  196. start_epoch = reporter.get_epoch() + 1
  197. if start_epoch == trainer_options.max_epoch + 1:
  198. logging.warning(
  199. f"The training has already reached at max_epoch: {start_epoch}"
  200. )
  201. if distributed_option.distributed:
  202. if trainer_options.sharded_ddp:
  203. dp_model = fairscale.nn.data_parallel.ShardedDataParallel(
  204. module=model,
  205. sharded_optimizer=optimizers,
  206. )
  207. else:
  208. dp_model = torch.nn.parallel.DistributedDataParallel(
  209. model, find_unused_parameters=trainer_options.unused_parameters)
  210. elif distributed_option.ngpu > 1:
  211. dp_model = torch.nn.parallel.DataParallel(
  212. model,
  213. device_ids=list(range(distributed_option.ngpu)),
  214. )
  215. else:
  216. # NOTE(kamo): DataParallel also should work with ngpu=1,
  217. # but for debuggability it's better to keep this block.
  218. dp_model = model
  219. if trainer_options.use_tensorboard and (
  220. not distributed_option.distributed or distributed_option.dist_rank == 0
  221. ):
  222. from torch.utils.tensorboard import SummaryWriter
  223. if trainer_options.use_pai:
  224. train_summary_writer = SummaryWriter(
  225. os.path.join(trainer_options.output_dir, "tensorboard/train")
  226. )
  227. valid_summary_writer = SummaryWriter(
  228. os.path.join(trainer_options.output_dir, "tensorboard/valid")
  229. )
  230. else:
  231. train_summary_writer = SummaryWriter(
  232. str(output_dir / "tensorboard" / "train")
  233. )
  234. valid_summary_writer = SummaryWriter(
  235. str(output_dir / "tensorboard" / "valid")
  236. )
  237. else:
  238. train_summary_writer = None
  239. start_time = time.perf_counter()
  240. for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
  241. if iepoch != start_epoch:
  242. logging.info(
  243. "{}/{}epoch started. Estimated time to finish: {}".format(
  244. iepoch,
  245. trainer_options.max_epoch,
  246. humanfriendly.format_timespan(
  247. (time.perf_counter() - start_time)
  248. / (iepoch - start_epoch)
  249. * (trainer_options.max_epoch - iepoch + 1)
  250. ),
  251. )
  252. )
  253. else:
  254. logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started")
  255. set_all_random_seed(trainer_options.seed + iepoch)
  256. reporter.set_epoch(iepoch)
  257. # 1. Train and validation for one-epoch
  258. with reporter.observe("train") as sub_reporter:
  259. all_steps_are_invalid, max_update_stop = cls.train_one_epoch(
  260. model=dp_model,
  261. optimizers=optimizers,
  262. schedulers=schedulers,
  263. iterator=train_iter_factory.build_iter(iepoch),
  264. reporter=sub_reporter,
  265. scaler=scaler,
  266. summary_writer=train_summary_writer,
  267. options=trainer_options,
  268. distributed_option=distributed_option,
  269. )
  270. with reporter.observe("valid") as sub_reporter:
  271. cls.validate_one_epoch(
  272. model=dp_model,
  273. iterator=valid_iter_factory.build_iter(iepoch),
  274. reporter=sub_reporter,
  275. options=trainer_options,
  276. distributed_option=distributed_option,
  277. )
  278. # 2. LR Scheduler step
  279. for scheduler in schedulers:
  280. if isinstance(scheduler, AbsValEpochStepScheduler):
  281. scheduler.step(
  282. reporter.get_value(*trainer_options.val_scheduler_criterion)
  283. )
  284. elif isinstance(scheduler, AbsEpochStepScheduler):
  285. scheduler.step()
  286. if trainer_options.sharded_ddp:
  287. for optimizer in optimizers:
  288. if isinstance(optimizer, fairscale.optim.oss.OSS):
  289. optimizer.consolidate_state_dict()
  290. if not distributed_option.distributed or distributed_option.dist_rank == 0:
  291. # 3. Report the results
  292. logging.info(reporter.log_message())
  293. if train_summary_writer is not None:
  294. reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
  295. reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
  296. if trainer_options.use_wandb:
  297. reporter.wandb_log()
  298. # save tensorboard on oss
  299. if trainer_options.use_pai and train_summary_writer is not None:
  300. def write_tensorboard_summary(summary_writer_path, oss_bucket):
  301. file_list = []
  302. for root, dirs, files in os.walk(summary_writer_path, topdown=False):
  303. for name in files:
  304. file_full_path = os.path.join(root, name)
  305. file_list.append(file_full_path)
  306. for file_full_path in file_list:
  307. with open(file_full_path, "rb") as f:
  308. oss_bucket.put_object(file_full_path, f)
  309. write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/train"), trainer_options.oss_bucket)
  310. write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/valid"), trainer_options.oss_bucket)
  311. # 4. Save/Update the checkpoint
  312. if trainer_options.use_pai:
  313. buffer = BytesIO()
  314. torch.save(
  315. {
  316. "model": model.state_dict(),
  317. "reporter": reporter.state_dict(),
  318. "optimizers": [o.state_dict() for o in optimizers],
  319. "schedulers": [
  320. s.state_dict() if s is not None else None
  321. for s in schedulers
  322. ],
  323. "scaler": scaler.state_dict() if scaler is not None else None,
  324. "ema_model": model.encoder.ema.model.state_dict()
  325. if hasattr(model.encoder, "ema") and model.encoder.ema is not None else None,
  326. },
  327. buffer,
  328. )
  329. trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir, "checkpoint.pb"), buffer.getvalue())
  330. else:
  331. torch.save(
  332. {
  333. "model": model.state_dict(),
  334. "reporter": reporter.state_dict(),
  335. "optimizers": [o.state_dict() for o in optimizers],
  336. "schedulers": [
  337. s.state_dict() if s is not None else None
  338. for s in schedulers
  339. ],
  340. "scaler": scaler.state_dict() if scaler is not None else None,
  341. },
  342. output_dir / "checkpoint.pb",
  343. )
  344. # 5. Save and log the model and update the link to the best model
  345. if trainer_options.use_pai:
  346. buffer = BytesIO()
  347. torch.save(model.state_dict(), buffer)
  348. trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir,
  349. f"{iepoch}epoch.pb"),buffer.getvalue())
  350. else:
  351. torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb")
  352. # Creates a sym link latest.pb -> {iepoch}epoch.pb
  353. if trainer_options.use_pai:
  354. p = os.path.join(trainer_options.output_dir, "latest.pb")
  355. if trainer_options.oss_bucket.object_exists(p):
  356. trainer_options.oss_bucket.delete_object(p)
  357. trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
  358. os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"), p)
  359. else:
  360. p = output_dir / "latest.pb"
  361. if p.is_symlink() or p.exists():
  362. p.unlink()
  363. p.symlink_to(f"{iepoch}epoch.pb")
  364. _improved = []
  365. for _phase, k, _mode in trainer_options.best_model_criterion:
  366. # e.g. _phase, k, _mode = "train", "loss", "min"
  367. if reporter.has(_phase, k):
  368. best_epoch = reporter.get_best_epoch(_phase, k, _mode)
  369. # Creates sym links if it's the best result
  370. if best_epoch == iepoch:
  371. if trainer_options.use_pai:
  372. p = os.path.join(trainer_options.output_dir, f"{_phase}.{k}.best.pb")
  373. if trainer_options.oss_bucket.object_exists(p):
  374. trainer_options.oss_bucket.delete_object(p)
  375. trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
  376. os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"),p)
  377. else:
  378. p = output_dir / f"{_phase}.{k}.best.pb"
  379. if p.is_symlink() or p.exists():
  380. p.unlink()
  381. p.symlink_to(f"{iepoch}epoch.pb")
  382. _improved.append(f"{_phase}.{k}")
  383. if len(_improved) == 0:
  384. logging.info("There are no improvements in this epoch")
  385. else:
  386. logging.info(
  387. "The best model has been updated: " + ", ".join(_improved)
  388. )
  389. log_model = (
  390. trainer_options.wandb_model_log_interval > 0
  391. and iepoch % trainer_options.wandb_model_log_interval == 0
  392. )
  393. if log_model and trainer_options.use_wandb:
  394. import wandb
  395. logging.info("Logging Model on this epoch :::::")
  396. artifact = wandb.Artifact(
  397. name=f"model_{wandb.run.id}",
  398. type="model",
  399. metadata={"improved": _improved},
  400. )
  401. artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
  402. aliases = [
  403. f"epoch-{iepoch}",
  404. "best" if best_epoch == iepoch else "",
  405. ]
  406. wandb.log_artifact(artifact, aliases=aliases)
  407. # 6. Remove the model files excluding n-best epoch and latest epoch
  408. _removed = []
  409. # Get the union set of the n-best among multiple criterion
  410. nbests = set().union(
  411. *[
  412. set(reporter.sort_epochs(ph, k, m)[: max(keep_nbest_models)])
  413. for ph, k, m in trainer_options.best_model_criterion
  414. if reporter.has(ph, k)
  415. ]
  416. )
  417. # Generated n-best averaged model
  418. if (
  419. trainer_options.nbest_averaging_interval > 0
  420. and iepoch % trainer_options.nbest_averaging_interval == 0
  421. ):
  422. average_nbest_models(
  423. reporter=reporter,
  424. output_dir=output_dir,
  425. best_model_criterion=trainer_options.best_model_criterion,
  426. nbest=keep_nbest_models,
  427. suffix=f"till{iepoch}epoch",
  428. oss_bucket=trainer_options.oss_bucket,
  429. pai_output_dir=trainer_options.output_dir,
  430. )
  431. for e in range(1, iepoch):
  432. if trainer_options.use_pai:
  433. p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
  434. if trainer_options.oss_bucket.object_exists(p) and e not in nbests:
  435. trainer_options.oss_bucket.delete_object(p)
  436. _removed.append(str(p))
  437. else:
  438. p = output_dir / f"{e}epoch.pb"
  439. if p.exists() and e not in nbests:
  440. p.unlink()
  441. _removed.append(str(p))
  442. if len(_removed) != 0:
  443. logging.info("The model files were removed: " + ", ".join(_removed))
  444. # 7. If any updating haven't happened, stops the training
  445. if all_steps_are_invalid:
  446. logging.warning(
  447. f"The gradients at all steps are invalid in this epoch. "
  448. f"Something seems wrong. This training was stopped at {iepoch}epoch"
  449. )
  450. break
  451. if max_update_stop:
  452. logging.info(
  453. f"Stopping training due to "
  454. f"num_updates: {trainer_options.num_updates} >= max_update: {trainer_options.max_update}"
  455. )
  456. break
  457. # 8. Check early stopping
  458. if trainer_options.patience is not None:
  459. if reporter.check_early_stopping(
  460. trainer_options.patience, *trainer_options.early_stopping_criterion
  461. ):
  462. break
  463. else:
  464. logging.info(
  465. f"The training was finished at {trainer_options.max_epoch} epochs "
  466. )
  467. # Generated n-best averaged model
  468. if not distributed_option.distributed or distributed_option.dist_rank == 0:
  469. average_nbest_models(
  470. reporter=reporter,
  471. output_dir=output_dir,
  472. best_model_criterion=trainer_options.best_model_criterion,
  473. nbest=keep_nbest_models,
  474. oss_bucket=trainer_options.oss_bucket,
  475. pai_output_dir=trainer_options.output_dir,
  476. )
  477. @classmethod
  478. def train_one_epoch(
  479. cls,
  480. model: torch.nn.Module,
  481. iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
  482. optimizers: Sequence[torch.optim.Optimizer],
  483. schedulers: Sequence[Optional[AbsScheduler]],
  484. scaler: Optional[GradScaler],
  485. reporter: SubReporter,
  486. summary_writer,
  487. options: TrainerOptions,
  488. distributed_option: DistributedOption,
  489. ) -> Tuple[bool, bool]:
  490. assert check_argument_types()
  491. grad_noise = options.grad_noise
  492. accum_grad = options.accum_grad
  493. grad_clip = options.grad_clip
  494. grad_clip_type = options.grad_clip_type
  495. log_interval = options.log_interval
  496. no_forward_run = options.no_forward_run
  497. ngpu = options.ngpu
  498. use_wandb = options.use_wandb
  499. distributed = distributed_option.distributed
  500. if log_interval is None:
  501. try:
  502. log_interval = max(len(iterator) // 20, 10)
  503. except TypeError:
  504. log_interval = 100
  505. model.train()
  506. all_steps_are_invalid = True
  507. max_update_stop = False
  508. # [For distributed] Because iteration counts are not always equals between
  509. # processes, send stop-flag to the other processes if iterator is finished
  510. iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
  511. #get the rank
  512. rank = distributed_option.dist_rank
  513. #get the num batch updates
  514. num_batch_updates = 0
  515. #ouput dir
  516. output_dir = Path(options.output_dir)
  517. #batch interval
  518. batch_interval = options.batch_interval
  519. start_time = time.perf_counter()
  520. for iiter, (_, batch) in enumerate(
  521. reporter.measure_iter_time(iterator, "iter_time"), 1
  522. ):
  523. assert isinstance(batch, dict), type(batch)
  524. if batch_interval > 0 and (not distributed_option.distributed or rank == 0):
  525. if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")):
  526. num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates()
  527. if num_batch_updates % batch_interval == 0:
  528. if options.use_pai and options.oss_bucket is not None:
  529. buffer = BytesIO()
  530. if hasattr(model, "module"):
  531. torch.save(model.module.state_dict(), buffer)
  532. else:
  533. torch.save(model.state_dict(), buffer)
  534. options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue())
  535. else:
  536. if hasattr(model, "module"):
  537. torch.save(model.module.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
  538. else:
  539. torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
  540. if distributed:
  541. torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
  542. if iterator_stop > 0:
  543. break
  544. batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
  545. if no_forward_run:
  546. all_steps_are_invalid = False
  547. continue
  548. with autocast(scaler is not None):
  549. with reporter.measure_time("forward_time"):
  550. retval = model(**batch)
  551. # Note(kamo):
  552. # Supporting two patterns for the returned value from the model
  553. # a. dict type
  554. if isinstance(retval, dict):
  555. loss = retval["loss"]
  556. stats = retval["stats"]
  557. weight = retval["weight"]
  558. optim_idx = retval.get("optim_idx")
  559. if optim_idx is not None and not isinstance(optim_idx, int):
  560. if not isinstance(optim_idx, torch.Tensor):
  561. raise RuntimeError(
  562. "optim_idx must be int or 1dim torch.Tensor, "
  563. f"but got {type(optim_idx)}"
  564. )
  565. if optim_idx.dim() >= 2:
  566. raise RuntimeError(
  567. "optim_idx must be int or 1dim torch.Tensor, "
  568. f"but got {optim_idx.dim()}dim tensor"
  569. )
  570. if optim_idx.dim() == 1:
  571. for v in optim_idx:
  572. if v != optim_idx[0]:
  573. raise RuntimeError(
  574. "optim_idx must be 1dim tensor "
  575. "having same values for all entries"
  576. )
  577. optim_idx = optim_idx[0].item()
  578. else:
  579. optim_idx = optim_idx.item()
  580. # b. tuple or list type
  581. else:
  582. loss, stats, weight = retval
  583. optim_idx = None
  584. stats = {k: v for k, v in stats.items() if v is not None}
  585. if ngpu > 1 or distributed:
  586. # Apply weighted averaging for loss and stats
  587. loss = (loss * weight.type(loss.dtype)).sum()
  588. # if distributed, this method can also apply all_reduce()
  589. stats, weight = recursive_average(stats, weight, distributed)
  590. # Now weight is summation over all workers
  591. loss /= weight
  592. if distributed:
  593. # NOTE(kamo): Multiply world_size because DistributedDataParallel
  594. # automatically normalizes the gradient by world_size.
  595. loss *= torch.distributed.get_world_size()
  596. loss /= accum_grad
  597. reporter.register(stats, weight)
  598. with reporter.measure_time("backward_time"):
  599. if scaler is not None:
  600. # Scales loss. Calls backward() on scaled loss
  601. # to create scaled gradients.
  602. # Backward passes under autocast are not recommended.
  603. # Backward ops run in the same dtype autocast chose
  604. # for corresponding forward ops.
  605. scaler.scale(loss).backward()
  606. else:
  607. loss.backward()
  608. if iiter % accum_grad == 0:
  609. if scaler is not None:
  610. # Unscales the gradients of optimizer's assigned params in-place
  611. for iopt, optimizer in enumerate(optimizers):
  612. if optim_idx is not None and iopt != optim_idx:
  613. continue
  614. scaler.unscale_(optimizer)
  615. # gradient noise injection
  616. if grad_noise:
  617. add_gradient_noise(
  618. model,
  619. reporter.get_total_count(),
  620. duration=100,
  621. eta=1.0,
  622. scale_factor=0.55,
  623. )
  624. # compute the gradient norm to check if it is normal or not
  625. grad_norm = torch.nn.utils.clip_grad_norm_(
  626. model.parameters(),
  627. max_norm=grad_clip,
  628. norm_type=grad_clip_type,
  629. )
  630. # PyTorch<=1.4, clip_grad_norm_ returns float value
  631. if not isinstance(grad_norm, torch.Tensor):
  632. grad_norm = torch.tensor(grad_norm)
  633. if not torch.isfinite(grad_norm):
  634. logging.warning(
  635. f"The grad norm is {grad_norm}. Skipping updating the model."
  636. )
  637. # Must invoke scaler.update() if unscale_() is used in the iteration
  638. # to avoid the following error:
  639. # RuntimeError: unscale_() has already been called
  640. # on this optimizer since the last update().
  641. # Note that if the gradient has inf/nan values,
  642. # scaler.step skips optimizer.step().
  643. if scaler is not None:
  644. for iopt, optimizer in enumerate(optimizers):
  645. if optim_idx is not None and iopt != optim_idx:
  646. continue
  647. scaler.step(optimizer)
  648. scaler.update()
  649. else:
  650. all_steps_are_invalid = False
  651. with reporter.measure_time("optim_step_time"):
  652. for iopt, (optimizer, scheduler) in enumerate(
  653. zip(optimizers, schedulers)
  654. ):
  655. if optim_idx is not None and iopt != optim_idx:
  656. continue
  657. if scaler is not None:
  658. # scaler.step() first unscales the gradients of
  659. # the optimizer's assigned params.
  660. scaler.step(optimizer)
  661. # Updates the scale for next iteration.
  662. scaler.update()
  663. else:
  664. optimizer.step()
  665. if isinstance(scheduler, AbsBatchStepScheduler):
  666. scheduler.step()
  667. for iopt, optimizer in enumerate(optimizers):
  668. if optim_idx is not None and iopt != optim_idx:
  669. continue
  670. optimizer.zero_grad()
  671. # Register lr and train/load time[sec/step],
  672. # where step refers to accum_grad * mini-batch
  673. reporter.register(
  674. dict(
  675. {
  676. f"optim{i}_lr{j}": pg["lr"]
  677. for i, optimizer in enumerate(optimizers)
  678. for j, pg in enumerate(optimizer.param_groups)
  679. if "lr" in pg
  680. },
  681. train_time=time.perf_counter() - start_time,
  682. ),
  683. )
  684. start_time = time.perf_counter()
  685. # update num_updates
  686. if distributed:
  687. if hasattr(model.module, "num_updates"):
  688. model.module.set_num_updates(model.module.get_num_updates() + 1)
  689. options.num_updates = model.module.get_num_updates()
  690. if model.module.get_num_updates() >= options.max_update:
  691. max_update_stop = True
  692. else:
  693. if hasattr(model, "num_updates"):
  694. model.set_num_updates(model.get_num_updates() + 1)
  695. options.num_updates = model.get_num_updates()
  696. if model.get_num_updates() >= options.max_update:
  697. max_update_stop = True
  698. # NOTE(kamo): Call log_message() after next()
  699. reporter.next()
  700. if iiter % log_interval == 0:
  701. num_updates = options.num_updates if hasattr(options, "num_updates") else None
  702. logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
  703. if summary_writer is not None:
  704. reporter.tensorboard_add_scalar(summary_writer, -log_interval)
  705. if use_wandb:
  706. reporter.wandb_log()
  707. if max_update_stop:
  708. break
  709. else:
  710. if distributed:
  711. iterator_stop.fill_(1)
  712. torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
  713. return all_steps_are_invalid, max_update_stop
  714. @classmethod
  715. @torch.no_grad()
  716. def validate_one_epoch(
  717. cls,
  718. model: torch.nn.Module,
  719. iterator: Iterable[Dict[str, torch.Tensor]],
  720. reporter: SubReporter,
  721. options: TrainerOptions,
  722. distributed_option: DistributedOption,
  723. ) -> None:
  724. assert check_argument_types()
  725. ngpu = options.ngpu
  726. no_forward_run = options.no_forward_run
  727. distributed = distributed_option.distributed
  728. model.eval()
  729. # [For distributed] Because iteration counts are not always equals between
  730. # processes, send stop-flag to the other processes if iterator is finished
  731. iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
  732. for (_, batch) in iterator:
  733. assert isinstance(batch, dict), type(batch)
  734. if distributed:
  735. torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
  736. if iterator_stop > 0:
  737. break
  738. batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
  739. if no_forward_run:
  740. continue
  741. retval = model(**batch)
  742. if isinstance(retval, dict):
  743. stats = retval["stats"]
  744. weight = retval["weight"]
  745. else:
  746. _, stats, weight = retval
  747. if ngpu > 1 or distributed:
  748. # Apply weighted averaging for stats.
  749. # if distributed, this method can also apply all_reduce()
  750. stats, weight = recursive_average(stats, weight, distributed)
  751. reporter.register(stats, weight)
  752. reporter.next()
  753. else:
  754. if distributed:
  755. iterator_stop.fill_(1)
  756. torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)