trainer.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886
  1. # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. """Trainer module."""
  4. import argparse
  5. from contextlib import contextmanager
  6. import dataclasses
  7. from dataclasses import is_dataclass
  8. from distutils.version import LooseVersion
  9. import logging
  10. from pathlib import Path
  11. import time
  12. from typing import Dict
  13. from typing import Iterable
  14. from typing import List
  15. from typing import Optional
  16. from typing import Sequence
  17. from typing import Tuple
  18. from typing import Union
  19. import humanfriendly
  20. import oss2
  21. from io import BytesIO
  22. import os
  23. import numpy as np
  24. import torch
  25. import torch.nn
  26. import torch.optim
  27. from typeguard import check_argument_types
  28. from funasr.iterators.abs_iter_factory import AbsIterFactory
  29. from funasr.main_funcs.average_nbest_models import average_nbest_models
  30. from funasr.main_funcs.calculate_all_attentions import calculate_all_attentions
  31. from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
  32. from funasr.schedulers.abs_scheduler import AbsEpochStepScheduler
  33. from funasr.schedulers.abs_scheduler import AbsScheduler
  34. from funasr.schedulers.abs_scheduler import AbsValEpochStepScheduler
  35. from funasr.torch_utils.add_gradient_noise import add_gradient_noise
  36. from funasr.torch_utils.device_funcs import to_device
  37. from funasr.torch_utils.recursive_op import recursive_average
  38. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  39. from funasr.models.base_model import FunASRModel
  40. from funasr.train.distributed_utils import DistributedOption
  41. from funasr.train.reporter import Reporter
  42. from funasr.train.reporter import SubReporter
  43. from funasr.utils.build_dataclass import build_dataclass
  44. from funasr.utils.kwargs2args import kwargs2args
  45. if torch.distributed.is_available():
  46. from torch.distributed import ReduceOp
  47. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  48. from torch.cuda.amp import autocast
  49. from torch.cuda.amp import GradScaler
  50. else:
  51. # Nothing to do if torch<1.6.0
  52. @contextmanager
  53. def autocast(enabled=True):
  54. yield
  55. GradScaler = None
  56. try:
  57. import fairscale
  58. except ImportError:
  59. fairscale = None
@dataclasses.dataclass
class TrainerOptions:
    """Options consumed by ``Trainer``.

    Built from parsed command-line arguments by ``Trainer.build_options()``
    (via ``build_dataclass``). ``Trainer.train_one_epoch()`` additionally sets
    a ``num_updates`` attribute at runtime when the model tracks its update
    count; it is not a declared field.
    """

    ngpu: int  # > 0 puts batches/stop-flags on "cuda", otherwise "cpu"
    resume: bool  # reload <output_dir>/checkpoint.pb (if present) before training
    use_amp: bool  # mixed precision with a GradScaler (requires torch>=1.6.0)
    train_dtype: str  # not read by the Trainer in this file
    grad_noise: bool  # inject gradient noise (add_gradient_noise) before clipping
    accum_grad: int  # mini-batches per optimizer step; the loss is divided by it
    grad_clip: float  # max_norm passed to torch.nn.utils.clip_grad_norm_
    grad_clip_type: float  # norm_type passed to torch.nn.utils.clip_grad_norm_
    log_interval: Optional[int]  # iterations between logs; None -> derived from len(iterator) (fallback 100)
    no_forward_run: bool  # only iterate the data; skip forward/backward
    use_tensorboard: bool  # write tensorboard summaries on the main process
    use_wandb: bool  # log reporter values (and optionally models) to wandb
    output_dir: Union[Path, str]  # checkpoint/model directory (OSS key prefix when use_pai)
    max_epoch: int  # last epoch to train (inclusive)
    max_update: int  # stop once the model's num_updates reaches this value
    seed: int  # base seed; every epoch reseeds with seed + epoch
    sharded_ddp: bool  # use fairscale ShardedDataParallel / ShardedGradScaler
    patience: Optional[int]  # passed to reporter.check_early_stopping(); None disables it
    keep_nbest_models: Union[int, List[int]]  # n-best epoch models kept; [] is replaced by [1]
    nbest_averaging_interval: int  # average n-best models every N epochs; <= 0 disables the periodic averaging
    early_stopping_criterion: Sequence[str]  # forwarded to reporter.check_early_stopping() after patience
    best_model_criterion: Sequence[Sequence[str]]  # (phase, key, mode) triples, e.g. ("train", "loss", "min")
    val_scheduler_criterion: Sequence[str]  # forwarded to reporter.get_value() to step AbsValEpochStepScheduler
    unused_parameters: bool  # find_unused_parameters for DistributedDataParallel
    wandb_model_log_interval: int  # log the epoch model as a wandb artifact every N epochs; <= 0 disables
    use_pai: bool  # store checkpoints/models in oss_bucket instead of the local filesystem
    oss_bucket: Union[oss2.Bucket, None]  # OSS bucket used when use_pai is set
    batch_interval: int  # save the model as {num_updates}step.pb every N updates; <= 0 disables
    bias_grad_times: float  # gradient multiplier for contextual-bias parameters; 1.0 disables
  91. class Trainer:
  92. """Trainer having a optimizer.
  93. If you'd like to use multiple optimizers, then inherit this class
  94. and override the methods if necessary - at least "train_one_epoch()"
  95. >>> class TwoOptimizerTrainer(Trainer):
  96. ... @classmethod
  97. ... def add_arguments(cls, parser):
  98. ... ...
  99. ...
  100. ... @classmethod
  101. ... def train_one_epoch(cls, model, optimizers, ...):
  102. ... loss1 = model.model1(...)
  103. ... loss1.backward()
  104. ... optimizers[0].step()
  105. ...
  106. ... loss2 = model.model2(...)
  107. ... loss2.backward()
  108. ... optimizers[1].step()
  109. """
  110. def __init__(self):
  111. raise RuntimeError("This class can't be instantiated.")
  112. @classmethod
  113. def build_options(cls, args: argparse.Namespace) -> TrainerOptions:
  114. """Build options consumed by train(), eval()"""
  115. assert check_argument_types()
  116. return build_dataclass(TrainerOptions, args)
  117. @classmethod
  118. def add_arguments(cls, parser: argparse.ArgumentParser):
  119. """Reserved for future development of another Trainer"""
  120. pass
  121. @staticmethod
  122. def resume(
  123. checkpoint: Union[str, Path],
  124. model: torch.nn.Module,
  125. reporter: Reporter,
  126. optimizers: Sequence[torch.optim.Optimizer],
  127. schedulers: Sequence[Optional[AbsScheduler]],
  128. scaler: Optional[GradScaler],
  129. ngpu: int = 0,
  130. oss_bucket=None,
  131. ):
  132. if oss_bucket is None:
  133. if os.path.exists(checkpoint):
  134. states = torch.load(
  135. checkpoint,
  136. map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
  137. )
  138. else:
  139. return 0
  140. else:
  141. if oss_bucket.object_exists(checkpoint):
  142. buffer = BytesIO(oss_bucket.get_object(checkpoint).read())
  143. states = torch.load(buffer, map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",)
  144. else:
  145. return 0
  146. model.load_state_dict(states["model"])
  147. reporter.load_state_dict(states["reporter"])
  148. for optimizer, state in zip(optimizers, states["optimizers"]):
  149. optimizer.load_state_dict(state)
  150. for scheduler, state in zip(schedulers, states["schedulers"]):
  151. if scheduler is not None:
  152. scheduler.load_state_dict(state)
  153. if scaler is not None:
  154. if states["scaler"] is None:
  155. logging.warning("scaler state is not found")
  156. else:
  157. scaler.load_state_dict(states["scaler"])
  158. logging.info(f"The training was resumed using {checkpoint}")
  159. @classmethod
  160. def run(
  161. cls,
  162. model: FunASRModel,
  163. optimizers: Sequence[torch.optim.Optimizer],
  164. schedulers: Sequence[Optional[AbsScheduler]],
  165. train_iter_factory: AbsIterFactory,
  166. valid_iter_factory: AbsIterFactory,
  167. trainer_options,
  168. distributed_option: DistributedOption,
  169. ) -> None:
  170. """Perform training. This method performs the main process of training."""
  171. assert check_argument_types()
  172. # NOTE(kamo): Don't check the type more strictly as far trainer_options
  173. assert is_dataclass(trainer_options), type(trainer_options)
  174. assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))
  175. if isinstance(trainer_options.keep_nbest_models, int):
  176. keep_nbest_models = [trainer_options.keep_nbest_models]
  177. else:
  178. if len(trainer_options.keep_nbest_models) == 0:
  179. logging.warning("No keep_nbest_models is given. Change to [1]")
  180. trainer_options.keep_nbest_models = [1]
  181. keep_nbest_models = trainer_options.keep_nbest_models
  182. output_dir = Path(trainer_options.output_dir)
  183. reporter = Reporter()
  184. if trainer_options.use_amp:
  185. if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
  186. raise RuntimeError(
  187. "Require torch>=1.6.0 for Automatic Mixed Precision"
  188. )
  189. if trainer_options.sharded_ddp:
  190. if fairscale is None:
  191. raise RuntimeError(
  192. "Requiring fairscale. Do 'pip install fairscale'"
  193. )
  194. scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
  195. else:
  196. scaler = GradScaler()
  197. else:
  198. scaler = None
  199. if trainer_options.resume:
  200. cls.resume(
  201. checkpoint=os.path.join(trainer_options.output_dir, "checkpoint.pb") if trainer_options.use_pai else output_dir / "checkpoint.pb",
  202. model=model,
  203. optimizers=optimizers,
  204. schedulers=schedulers,
  205. reporter=reporter,
  206. scaler=scaler,
  207. ngpu=trainer_options.ngpu,
  208. oss_bucket=trainer_options.oss_bucket if trainer_options.use_pai else None,
  209. )
  210. start_epoch = reporter.get_epoch() + 1
  211. if start_epoch == trainer_options.max_epoch + 1:
  212. logging.warning(
  213. f"The training has already reached at max_epoch: {start_epoch}"
  214. )
  215. if distributed_option.distributed:
  216. if trainer_options.sharded_ddp:
  217. dp_model = fairscale.nn.data_parallel.ShardedDataParallel(
  218. module=model,
  219. sharded_optimizer=optimizers,
  220. )
  221. else:
  222. dp_model = torch.nn.parallel.DistributedDataParallel(
  223. model, find_unused_parameters=trainer_options.unused_parameters)
  224. elif distributed_option.ngpu > 1:
  225. dp_model = torch.nn.parallel.DataParallel(
  226. model,
  227. device_ids=list(range(distributed_option.ngpu)),
  228. )
  229. else:
  230. # NOTE(kamo): DataParallel also should work with ngpu=1,
  231. # but for debuggability it's better to keep this block.
  232. dp_model = model
  233. if trainer_options.use_tensorboard and (
  234. not distributed_option.distributed or distributed_option.dist_rank == 0
  235. ):
  236. from torch.utils.tensorboard import SummaryWriter
  237. if trainer_options.use_pai:
  238. train_summary_writer = SummaryWriter(
  239. os.path.join(trainer_options.output_dir, "tensorboard/train")
  240. )
  241. valid_summary_writer = SummaryWriter(
  242. os.path.join(trainer_options.output_dir, "tensorboard/valid")
  243. )
  244. else:
  245. train_summary_writer = SummaryWriter(
  246. str(output_dir / "tensorboard" / "train")
  247. )
  248. valid_summary_writer = SummaryWriter(
  249. str(output_dir / "tensorboard" / "valid")
  250. )
  251. else:
  252. train_summary_writer = None
  253. start_time = time.perf_counter()
  254. for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
  255. if iepoch != start_epoch:
  256. logging.info(
  257. "{}/{}epoch started. Estimated time to finish: {}".format(
  258. iepoch,
  259. trainer_options.max_epoch,
  260. humanfriendly.format_timespan(
  261. (time.perf_counter() - start_time)
  262. / (iepoch - start_epoch)
  263. * (trainer_options.max_epoch - iepoch + 1)
  264. ),
  265. )
  266. )
  267. else:
  268. logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started")
  269. set_all_random_seed(trainer_options.seed + iepoch)
  270. reporter.set_epoch(iepoch)
  271. # 1. Train and validation for one-epoch
  272. with reporter.observe("train") as sub_reporter:
  273. all_steps_are_invalid, max_update_stop = cls.train_one_epoch(
  274. model=dp_model,
  275. optimizers=optimizers,
  276. schedulers=schedulers,
  277. iterator=train_iter_factory.build_iter(iepoch),
  278. reporter=sub_reporter,
  279. scaler=scaler,
  280. summary_writer=train_summary_writer,
  281. options=trainer_options,
  282. distributed_option=distributed_option,
  283. )
  284. with reporter.observe("valid") as sub_reporter:
  285. cls.validate_one_epoch(
  286. model=dp_model,
  287. iterator=valid_iter_factory.build_iter(iepoch),
  288. reporter=sub_reporter,
  289. options=trainer_options,
  290. distributed_option=distributed_option,
  291. )
  292. # 2. LR Scheduler step
  293. for scheduler in schedulers:
  294. if isinstance(scheduler, AbsValEpochStepScheduler):
  295. scheduler.step(
  296. reporter.get_value(*trainer_options.val_scheduler_criterion)
  297. )
  298. elif isinstance(scheduler, AbsEpochStepScheduler):
  299. scheduler.step()
  300. if trainer_options.sharded_ddp:
  301. for optimizer in optimizers:
  302. if isinstance(optimizer, fairscale.optim.oss.OSS):
  303. optimizer.consolidate_state_dict()
  304. if not distributed_option.distributed or distributed_option.dist_rank == 0:
  305. # 3. Report the results
  306. logging.info(reporter.log_message())
  307. if train_summary_writer is not None:
  308. reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
  309. reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
  310. if trainer_options.use_wandb:
  311. reporter.wandb_log()
  312. # save tensorboard on oss
  313. if trainer_options.use_pai and train_summary_writer is not None:
  314. def write_tensorboard_summary(summary_writer_path, oss_bucket):
  315. file_list = []
  316. for root, dirs, files in os.walk(summary_writer_path, topdown=False):
  317. for name in files:
  318. file_full_path = os.path.join(root, name)
  319. file_list.append(file_full_path)
  320. for file_full_path in file_list:
  321. with open(file_full_path, "rb") as f:
  322. oss_bucket.put_object(file_full_path, f)
  323. write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/train"), trainer_options.oss_bucket)
  324. write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/valid"), trainer_options.oss_bucket)
  325. # 4. Save/Update the checkpoint
  326. if trainer_options.use_pai:
  327. buffer = BytesIO()
  328. torch.save(
  329. {
  330. "model": model.state_dict(),
  331. "reporter": reporter.state_dict(),
  332. "optimizers": [o.state_dict() for o in optimizers],
  333. "schedulers": [
  334. s.state_dict() if s is not None else None
  335. for s in schedulers
  336. ],
  337. "scaler": scaler.state_dict() if scaler is not None else None,
  338. "ema_model": model.encoder.ema.model.state_dict()
  339. if hasattr(model.encoder, "ema") and model.encoder.ema is not None else None,
  340. },
  341. buffer,
  342. )
  343. trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir, "checkpoint.pb"), buffer.getvalue())
  344. else:
  345. torch.save(
  346. {
  347. "model": model.state_dict(),
  348. "reporter": reporter.state_dict(),
  349. "optimizers": [o.state_dict() for o in optimizers],
  350. "schedulers": [
  351. s.state_dict() if s is not None else None
  352. for s in schedulers
  353. ],
  354. "scaler": scaler.state_dict() if scaler is not None else None,
  355. },
  356. output_dir / "checkpoint.pb",
  357. )
  358. # 5. Save and log the model and update the link to the best model
  359. if trainer_options.use_pai:
  360. buffer = BytesIO()
  361. torch.save(model.state_dict(), buffer)
  362. trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir,
  363. f"{iepoch}epoch.pb"),buffer.getvalue())
  364. else:
  365. torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb")
  366. # Creates a sym link latest.pb -> {iepoch}epoch.pb
  367. if trainer_options.use_pai:
  368. p = os.path.join(trainer_options.output_dir, "latest.pb")
  369. if trainer_options.oss_bucket.object_exists(p):
  370. trainer_options.oss_bucket.delete_object(p)
  371. trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
  372. os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"), p)
  373. else:
  374. p = output_dir / "latest.pb"
  375. if p.is_symlink() or p.exists():
  376. p.unlink()
  377. p.symlink_to(f"{iepoch}epoch.pb")
  378. _improved = []
  379. for _phase, k, _mode in trainer_options.best_model_criterion:
  380. # e.g. _phase, k, _mode = "train", "loss", "min"
  381. if reporter.has(_phase, k):
  382. best_epoch = reporter.get_best_epoch(_phase, k, _mode)
  383. # Creates sym links if it's the best result
  384. if best_epoch == iepoch:
  385. if trainer_options.use_pai:
  386. p = os.path.join(trainer_options.output_dir, f"{_phase}.{k}.best.pb")
  387. if trainer_options.oss_bucket.object_exists(p):
  388. trainer_options.oss_bucket.delete_object(p)
  389. trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
  390. os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"),p)
  391. else:
  392. p = output_dir / f"{_phase}.{k}.best.pb"
  393. if p.is_symlink() or p.exists():
  394. p.unlink()
  395. p.symlink_to(f"{iepoch}epoch.pb")
  396. _improved.append(f"{_phase}.{k}")
  397. if len(_improved) == 0:
  398. logging.info("There are no improvements in this epoch")
  399. else:
  400. logging.info(
  401. "The best model has been updated: " + ", ".join(_improved)
  402. )
  403. log_model = (
  404. trainer_options.wandb_model_log_interval > 0
  405. and iepoch % trainer_options.wandb_model_log_interval == 0
  406. )
  407. if log_model and trainer_options.use_wandb:
  408. import wandb
  409. logging.info("Logging Model on this epoch :::::")
  410. artifact = wandb.Artifact(
  411. name=f"model_{wandb.run.id}",
  412. type="model",
  413. metadata={"improved": _improved},
  414. )
  415. artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
  416. aliases = [
  417. f"epoch-{iepoch}",
  418. "best" if best_epoch == iepoch else "",
  419. ]
  420. wandb.log_artifact(artifact, aliases=aliases)
  421. # 6. Remove the model files excluding n-best epoch and latest epoch
  422. _removed = []
  423. # Get the union set of the n-best among multiple criterion
  424. nbests = set().union(
  425. *[
  426. set(reporter.sort_epochs(ph, k, m)[: max(keep_nbest_models)])
  427. for ph, k, m in trainer_options.best_model_criterion
  428. if reporter.has(ph, k)
  429. ]
  430. )
  431. # Generated n-best averaged model
  432. if (
  433. trainer_options.nbest_averaging_interval > 0
  434. and iepoch % trainer_options.nbest_averaging_interval == 0
  435. ):
  436. average_nbest_models(
  437. reporter=reporter,
  438. output_dir=output_dir,
  439. best_model_criterion=trainer_options.best_model_criterion,
  440. nbest=keep_nbest_models,
  441. suffix=f"till{iepoch}epoch",
  442. oss_bucket=trainer_options.oss_bucket,
  443. pai_output_dir=trainer_options.output_dir,
  444. )
  445. for e in range(1, iepoch):
  446. if trainer_options.use_pai:
  447. p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
  448. if trainer_options.oss_bucket.object_exists(p) and e not in nbests:
  449. trainer_options.oss_bucket.delete_object(p)
  450. _removed.append(str(p))
  451. else:
  452. p = output_dir / f"{e}epoch.pb"
  453. if p.exists() and e not in nbests:
  454. p.unlink()
  455. _removed.append(str(p))
  456. if len(_removed) != 0:
  457. logging.info("The model files were removed: " + ", ".join(_removed))
  458. # 7. If any updating haven't happened, stops the training
  459. if all_steps_are_invalid:
  460. logging.warning(
  461. f"The gradients at all steps are invalid in this epoch. "
  462. f"Something seems wrong. This training was stopped at {iepoch}epoch"
  463. )
  464. break
  465. if max_update_stop:
  466. logging.info(
  467. f"Stopping training due to "
  468. f"num_updates: {trainer_options.num_updates} >= max_update: {trainer_options.max_update}"
  469. )
  470. break
  471. # 8. Check early stopping
  472. if trainer_options.patience is not None:
  473. if reporter.check_early_stopping(
  474. trainer_options.patience, *trainer_options.early_stopping_criterion
  475. ):
  476. break
  477. else:
  478. logging.info(
  479. f"The training was finished at {trainer_options.max_epoch} epochs "
  480. )
  481. # Generated n-best averaged model
  482. if not distributed_option.distributed or distributed_option.dist_rank == 0:
  483. average_nbest_models(
  484. reporter=reporter,
  485. output_dir=output_dir,
  486. best_model_criterion=trainer_options.best_model_criterion,
  487. nbest=keep_nbest_models,
  488. oss_bucket=trainer_options.oss_bucket,
  489. pai_output_dir=trainer_options.output_dir,
  490. )
  491. @classmethod
  492. def train_one_epoch(
  493. cls,
  494. model: torch.nn.Module,
  495. iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
  496. optimizers: Sequence[torch.optim.Optimizer],
  497. schedulers: Sequence[Optional[AbsScheduler]],
  498. scaler: Optional[GradScaler],
  499. reporter: SubReporter,
  500. summary_writer,
  501. options: TrainerOptions,
  502. distributed_option: DistributedOption,
  503. ) -> Tuple[bool, bool]:
  504. assert check_argument_types()
  505. grad_noise = options.grad_noise
  506. accum_grad = options.accum_grad
  507. grad_clip = options.grad_clip
  508. grad_clip_type = options.grad_clip_type
  509. log_interval = options.log_interval
  510. no_forward_run = options.no_forward_run
  511. ngpu = options.ngpu
  512. use_wandb = options.use_wandb
  513. bias_grad_times = options.bias_grad_times
  514. distributed = distributed_option.distributed
  515. if bias_grad_times != 1.0:
  516. logging.warning("Using bias_grad_times: {} for gradient scaling".format(bias_grad_times))
  517. if log_interval is None:
  518. try:
  519. log_interval = max(len(iterator) // 20, 10)
  520. except TypeError:
  521. log_interval = 100
  522. model.train()
  523. all_steps_are_invalid = True
  524. max_update_stop = False
  525. # [For distributed] Because iteration counts are not always equals between
  526. # processes, send stop-flag to the other processes if iterator is finished
  527. iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
  528. #get the rank
  529. rank = distributed_option.dist_rank
  530. #get the num batch updates
  531. num_batch_updates = 0
  532. #ouput dir
  533. output_dir = Path(options.output_dir)
  534. #batch interval
  535. batch_interval = options.batch_interval
  536. start_time = time.perf_counter()
  537. for iiter, (_, batch) in enumerate(
  538. reporter.measure_iter_time(iterator, "iter_time"), 1
  539. ):
  540. assert isinstance(batch, dict), type(batch)
  541. if batch_interval > 0 and (not distributed_option.distributed or rank == 0):
  542. if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")):
  543. num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates()
  544. if num_batch_updates % batch_interval == 0:
  545. if options.use_pai and options.oss_bucket is not None:
  546. buffer = BytesIO()
  547. if hasattr(model, "module"):
  548. torch.save(model.module.state_dict(), buffer)
  549. else:
  550. torch.save(model.state_dict(), buffer)
  551. options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue())
  552. else:
  553. if hasattr(model, "module"):
  554. torch.save(model.module.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
  555. else:
  556. torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
  557. if distributed:
  558. torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
  559. if iterator_stop > 0:
  560. break
  561. batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
  562. if no_forward_run:
  563. all_steps_are_invalid = False
  564. continue
  565. if iiter == 1 and summary_writer is not None:
  566. try:
  567. args = kwargs2args(model.forward, batch)
  568. except (ValueError, TypeError):
  569. logging.warning(
  570. "inpect.signature() is failed for the model. "
  571. "The graph can't be added for tensorboard."
  572. )
  573. else:
  574. try:
  575. summary_writer.add_graph(model, args, use_strict_trace=False)
  576. except Exception:
  577. logging.warning(
  578. "summary_writer.add_graph() is failed for the model. "
  579. "The graph can't be added for tensorboard."
  580. )
  581. del args
  582. with autocast(scaler is not None):
  583. with reporter.measure_time("forward_time"):
  584. retval = model(**batch)
  585. # Note(kamo):
  586. # Supporting two patterns for the returned value from the model
  587. # a. dict type
  588. if isinstance(retval, dict):
  589. loss = retval["loss"]
  590. stats = retval["stats"]
  591. weight = retval["weight"]
  592. optim_idx = retval.get("optim_idx")
  593. if optim_idx is not None and not isinstance(optim_idx, int):
  594. if not isinstance(optim_idx, torch.Tensor):
  595. raise RuntimeError(
  596. "optim_idx must be int or 1dim torch.Tensor, "
  597. f"but got {type(optim_idx)}"
  598. )
  599. if optim_idx.dim() >= 2:
  600. raise RuntimeError(
  601. "optim_idx must be int or 1dim torch.Tensor, "
  602. f"but got {optim_idx.dim()}dim tensor"
  603. )
  604. if optim_idx.dim() == 1:
  605. for v in optim_idx:
  606. if v != optim_idx[0]:
  607. raise RuntimeError(
  608. "optim_idx must be 1dim tensor "
  609. "having same values for all entries"
  610. )
  611. optim_idx = optim_idx[0].item()
  612. else:
  613. optim_idx = optim_idx.item()
  614. # b. tuple or list type
  615. else:
  616. loss, stats, weight = retval
  617. optim_idx = None
  618. stats = {k: v for k, v in stats.items() if v is not None}
  619. if ngpu > 1 or distributed:
  620. # Apply weighted averaging for loss and stats
  621. loss = (loss * weight.type(loss.dtype)).sum()
  622. # if distributed, this method can also apply all_reduce()
  623. stats, weight = recursive_average(stats, weight, distributed)
  624. # Now weight is summation over all workers
  625. loss /= weight
  626. if distributed:
  627. # NOTE(kamo): Multiply world_size because DistributedDataParallel
  628. # automatically normalizes the gradient by world_size.
  629. loss *= torch.distributed.get_world_size()
  630. loss /= accum_grad
  631. reporter.register(stats, weight)
  632. with reporter.measure_time("backward_time"):
  633. if scaler is not None:
  634. # Scales loss. Calls backward() on scaled loss
  635. # to create scaled gradients.
  636. # Backward passes under autocast are not recommended.
  637. # Backward ops run in the same dtype autocast chose
  638. # for corresponding forward ops.
  639. scaler.scale(loss).backward()
  640. else:
  641. loss.backward()
  642. if iiter % accum_grad == 0:
  643. if scaler is not None:
  644. # Unscales the gradients of optimizer's assigned params in-place
  645. for iopt, optimizer in enumerate(optimizers):
  646. if optim_idx is not None and iopt != optim_idx:
  647. continue
  648. scaler.unscale_(optimizer)
  649. # gradient noise injection
  650. if grad_noise:
  651. add_gradient_noise(
  652. model,
  653. reporter.get_total_count(),
  654. duration=100,
  655. eta=1.0,
  656. scale_factor=0.55,
  657. )
  658. # for contextual training
  659. if bias_grad_times != 1.0:
  660. # contextual related parameter names
  661. cr_pnames = ["bias_encoder", "bias_embed", "decoder.bias_decoder", "decoder.bias_output"]
  662. for name, param in model.named_parameters():
  663. for cr_pname in cr_pnames:
  664. if cr_pname in name:
  665. param.grad *= bias_grad_times
  666. continue
  667. # compute the gradient norm to check if it is normal or not
  668. grad_norm = torch.nn.utils.clip_grad_norm_(
  669. model.parameters(),
  670. max_norm=grad_clip,
  671. norm_type=grad_clip_type,
  672. )
  673. # PyTorch<=1.4, clip_grad_norm_ returns float value
  674. if not isinstance(grad_norm, torch.Tensor):
  675. grad_norm = torch.tensor(grad_norm)
  676. if not torch.isfinite(grad_norm):
  677. logging.warning(
  678. f"The grad norm is {grad_norm}. Skipping updating the model."
  679. )
  680. # Must invoke scaler.update() if unscale_() is used in the iteration
  681. # to avoid the following error:
  682. # RuntimeError: unscale_() has already been called
  683. # on this optimizer since the last update().
  684. # Note that if the gradient has inf/nan values,
  685. # scaler.step skips optimizer.step().
  686. if scaler is not None:
  687. for iopt, optimizer in enumerate(optimizers):
  688. if optim_idx is not None and iopt != optim_idx:
  689. continue
  690. scaler.step(optimizer)
  691. scaler.update()
  692. else:
  693. all_steps_are_invalid = False
  694. with reporter.measure_time("optim_step_time"):
  695. for iopt, (optimizer, scheduler) in enumerate(
  696. zip(optimizers, schedulers)
  697. ):
  698. if optim_idx is not None and iopt != optim_idx:
  699. continue
  700. if scaler is not None:
  701. # scaler.step() first unscales the gradients of
  702. # the optimizer's assigned params.
  703. scaler.step(optimizer)
  704. # Updates the scale for next iteration.
  705. scaler.update()
  706. else:
  707. optimizer.step()
  708. if isinstance(scheduler, AbsBatchStepScheduler):
  709. scheduler.step()
  710. for iopt, optimizer in enumerate(optimizers):
  711. if optim_idx is not None and iopt != optim_idx:
  712. continue
  713. optimizer.zero_grad()
  714. # Register lr and train/load time[sec/step],
  715. # where step refers to accum_grad * mini-batch
  716. reporter.register(
  717. dict(
  718. {
  719. f"optim{i}_lr{j}": pg["lr"]
  720. for i, optimizer in enumerate(optimizers)
  721. for j, pg in enumerate(optimizer.param_groups)
  722. if "lr" in pg
  723. },
  724. train_time=time.perf_counter() - start_time,
  725. ),
  726. )
  727. start_time = time.perf_counter()
  728. # update num_updates
  729. if distributed:
  730. if hasattr(model.module, "num_updates"):
  731. model.module.set_num_updates(model.module.get_num_updates() + 1)
  732. options.num_updates = model.module.get_num_updates()
  733. if model.module.get_num_updates() >= options.max_update:
  734. max_update_stop = True
  735. else:
  736. if hasattr(model, "num_updates"):
  737. model.set_num_updates(model.get_num_updates() + 1)
  738. options.num_updates = model.get_num_updates()
  739. if model.get_num_updates() >= options.max_update:
  740. max_update_stop = True
  741. # NOTE(kamo): Call log_message() after next()
  742. reporter.next()
  743. if iiter % log_interval == 0:
  744. num_updates = options.num_updates if hasattr(options, "num_updates") else None
  745. logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
  746. if summary_writer is not None:
  747. reporter.tensorboard_add_scalar(summary_writer, -log_interval)
  748. if use_wandb:
  749. reporter.wandb_log()
  750. if max_update_stop:
  751. break
  752. else:
  753. if distributed:
  754. iterator_stop.fill_(1)
  755. torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
  756. return all_steps_are_invalid, max_update_stop
  757. @classmethod
  758. @torch.no_grad()
  759. def validate_one_epoch(
  760. cls,
  761. model: torch.nn.Module,
  762. iterator: Iterable[Dict[str, torch.Tensor]],
  763. reporter: SubReporter,
  764. options: TrainerOptions,
  765. distributed_option: DistributedOption,
  766. ) -> None:
  767. assert check_argument_types()
  768. ngpu = options.ngpu
  769. no_forward_run = options.no_forward_run
  770. distributed = distributed_option.distributed
  771. model.eval()
  772. # [For distributed] Because iteration counts are not always equals between
  773. # processes, send stop-flag to the other processes if iterator is finished
  774. iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
  775. for (_, batch) in iterator:
  776. assert isinstance(batch, dict), type(batch)
  777. if distributed:
  778. torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
  779. if iterator_stop > 0:
  780. break
  781. batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
  782. if no_forward_run:
  783. continue
  784. retval = model(**batch)
  785. if isinstance(retval, dict):
  786. stats = retval["stats"]
  787. weight = retval["weight"]
  788. else:
  789. _, stats, weight = retval
  790. if ngpu > 1 or distributed:
  791. # Apply weighted averaging for stats.
  792. # if distributed, this method can also apply all_reduce()
  793. stats, weight = recursive_average(stats, weight, distributed)
  794. reporter.register(stats, weight)
  795. reporter.next()
  796. else:
  797. if distributed:
  798. iterator_stop.fill_(1)
  799. torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)