reporter.py
  1. """Reporter module."""
  2. import dataclasses
  3. import datetime
  4. import logging
  5. import time
  6. import warnings
  7. from collections import defaultdict
  8. from contextlib import contextmanager
  9. from distutils.version import LooseVersion
  10. from typing import ContextManager
  11. from typing import Dict
  12. from typing import List
  13. from typing import Optional
  14. from typing import Sequence
  15. from typing import Tuple
  16. from typing import Union
  17. import humanfriendly
  18. import numpy as np
  19. import torch
  20. Num = Union[float, int, complex, torch.Tensor, np.ndarray]
  21. _reserved = {"time", "total_count"}
  22. def to_reported_value(v: Num, weight: Num = None) -> "ReportedValue":
  23. if isinstance(v, (torch.Tensor, np.ndarray)):
  24. if np.prod(v.shape) != 1:
  25. raise ValueError(f"v must be 0 or 1 dimension: {len(v.shape)}")
  26. v = v.item()
  27. if isinstance(weight, (torch.Tensor, np.ndarray)):
  28. if np.prod(weight.shape) != 1:
  29. raise ValueError(f"weight must be 0 or 1 dimension: {len(weight.shape)}")
  30. weight = weight.item()
  31. if weight is not None:
  32. retval = WeightedAverage(v, weight)
  33. else:
  34. retval = Average(v)
  35. return retval
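
# Usage note: ``to_reported_value(0.5)`` returns ``Average(0.5)`` while
# ``to_reported_value(0.5, weight=8)`` returns ``WeightedAverage(0.5, 8)``;
# single-element tensors/arrays are converted to Python scalars first.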


def aggregate(values: Sequence["ReportedValue"]) -> Num:
    for v in values:
        if not isinstance(v, type(values[0])):
            raise ValueError(
                f"Can't use different Reported type together: "
                f"{type(v)} != {type(values[0])}"
            )

    if len(values) == 0:
        warnings.warn("No stats found")
        retval = np.nan
    elif isinstance(values[0], Average):
        retval = np.nanmean([v.value for v in values])
    elif isinstance(values[0], WeightedAverage):
        # Exclude non-finite values
        invalid_indices = set()
        for i, v in enumerate(values):
            if not np.isfinite(v.value) or not np.isfinite(v.weight):
                invalid_indices.add(i)
        values = [v for i, v in enumerate(values) if i not in invalid_indices]

        if len(values) != 0:
            # Calc weighted average. Weights are changed to sum-to-1.
            sum_weights = sum(v.weight for i, v in enumerate(values))
            sum_value = sum(v.value * v.weight for i, v in enumerate(values))
            if sum_weights == 0:
                warnings.warn("weight is zero")
                retval = np.nan
            else:
                retval = sum_value / sum_weights
        else:
            warnings.warn("No valid stats found")
            retval = np.nan
    else:
        raise NotImplementedError(f"type={type(values[0])}")
    return retval
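
# A minimal example of what ``aggregate`` computes: ``Average`` values reduce to
# their nan-mean, while ``WeightedAverage`` values reduce to a weighted mean whose
# weights are normalized to sum to 1, e.g.
#     aggregate([WeightedAverage(1.0, 1), WeightedAverage(3.0, 3)])  # -> 2.5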


def wandb_get_prefix(key: str):
    if key.startswith("valid"):
        return "valid/"
    if key.startswith("train"):
        return "train/"
    if key.startswith("attn"):
        return "attn/"
    return "metrics/"


class ReportedValue:
    pass


@dataclasses.dataclass(frozen=True)
class Average(ReportedValue):
    value: Num


@dataclasses.dataclass(frozen=True)
class WeightedAverage(ReportedValue):
    value: Tuple[Num, Num]
    weight: Num


class SubReporter:
    """This class is used in Reporter.

    See the docstring of Reporter for the usage.
    """

    def __init__(self, key: str, epoch: int, total_count: int):
        self.key = key
        self.epoch = epoch
        self.start_time = time.perf_counter()
        self.stats = defaultdict(list)
        self._finished = False
        self.total_count = total_count
        self.count = 0
        self._seen_keys_in_the_step = set()

    def get_total_count(self) -> int:
        """Returns the number of iterations over all epochs."""
        return self.total_count

    def get_epoch(self) -> int:
        return self.epoch

    def next(self):
        """Close up this step and reset state for the next step"""
        for key, stats_list in self.stats.items():
            if key not in self._seen_keys_in_the_step:
                # Fill nan value if the key is not registered in this step
                if isinstance(stats_list[0], WeightedAverage):
                    stats_list.append(to_reported_value(np.nan, 0))
                elif isinstance(stats_list[0], Average):
                    stats_list.append(to_reported_value(np.nan))
                else:
                    raise NotImplementedError(f"type={type(stats_list[0])}")

            assert len(stats_list) == self.count, (len(stats_list), self.count)
        self._seen_keys_in_the_step = set()

    def register(
        self,
        stats: Dict[str, Optional[Union[Num, Dict[str, Num]]]],
        weight: Num = None,
    ) -> None:
        if self._finished:
            raise RuntimeError("Already finished")
        if len(self._seen_keys_in_the_step) == 0:
            # Increment count as the first register in this step
            self.total_count += 1
            self.count += 1

        for key2, v in stats.items():
            if key2 in _reserved:
                raise RuntimeError(f"{key2} is reserved.")
            if key2 in self._seen_keys_in_the_step:
                raise RuntimeError(f"{key2} is registered twice.")
            if v is None:
                v = np.nan
            r = to_reported_value(v, weight)

            if key2 not in self.stats:
                # If it's the first time to register the key,
                # prepend nan values so that it has the same length
                # as the other stats, e.g.
                # stat A: [0.4, 0.3, 0.5]
                # stat B: [nan, nan, 0.2]
                nan = to_reported_value(np.nan, None if weight is None else 0)
                self.stats[key2].extend(
                    r if i == self.count - 1 else nan for i in range(self.count)
                )
            else:
                self.stats[key2].append(r)
            self._seen_keys_in_the_step.add(key2)

    def log_message(
        self, start: int = None, end: int = None, num_updates: int = None
    ) -> str:
        if self._finished:
            raise RuntimeError("Already finished")
        if start is None:
            start = 0
        if start < 0:
            start = self.count + start
        if end is None:
            end = self.count

        if self.count == 0 or start == end:
            return ""

        message = f"{self.epoch}epoch:{self.key}:" f"{start + 1}-{end}batch:"
        if num_updates is not None:
            message += f"{num_updates}num_updates: "

        for idx, (key2, stats_list) in enumerate(self.stats.items()):
            assert len(stats_list) == self.count, (len(stats_list), self.count)
            # values: List[ReportedValue]
            values = stats_list[start:end]
            if idx != 0 and idx != len(stats_list):
                message += ", "

            v = aggregate(values)
            if abs(v) > 1.0e3:
                message += f"{key2}={v:.3e}"
            elif abs(v) > 1.0e-3:
                message += f"{key2}={v:.3f}"
            else:
                message += f"{key2}={v:.3e}"
        return message

    def tensorboard_add_scalar(self, summary_writer, start: int = None):
        if start is None:
            start = 0
        if start < 0:
            start = self.count + start

        for key2, stats_list in self.stats.items():
            assert len(stats_list) == self.count, (len(stats_list), self.count)
            # values: List[ReportedValue]
            values = stats_list[start:]
            v = aggregate(values)
            summary_writer.add_scalar(f"{key2}", v, self.total_count)
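
    # ``summary_writer`` above is expected to be a tensorboard-style writer, e.g.
    # ``torch.utils.tensorboard.SummaryWriter`` or any object exposing a
    # compatible ``add_scalar(tag, value, global_step)`` method:
    #
    #     from torch.utils.tensorboard import SummaryWriter
    #     sub_reporter.tensorboard_add_scalar(SummaryWriter("exp/tensorboard"))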

    def wandb_log(self, start: int = None):
        import wandb

        if start is None:
            start = 0
        if start < 0:
            start = self.count + start

        d = {}
        for key2, stats_list in self.stats.items():
            assert len(stats_list) == self.count, (len(stats_list), self.count)
            # values: List[ReportedValue]
            values = stats_list[start:]
            v = aggregate(values)
            d[wandb_get_prefix(key2) + key2] = v
        d["iteration"] = self.total_count
        wandb.log(d)

    def finished(self) -> None:
        self._finished = True

    @contextmanager
    def measure_time(self, name: str):
        start = time.perf_counter()
        yield start
        t = time.perf_counter() - start
        self.register({name: t})

    def measure_iter_time(self, iterable, name: str):
        iterator = iter(iterable)
        while True:
            try:
                start = time.perf_counter()
                retval = next(iterator)
                t = time.perf_counter() - start
                self.register({name: t})
                yield retval
            except StopIteration:
                break
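
    # A minimal sketch of the timing helpers above (``model``, ``batch`` and
    # ``iterator`` are placeholders, not defined in this module):
    #
    #     for batch in sub_reporter.measure_iter_time(iterator, "iter_time"):
    #         with sub_reporter.measure_time("forward_time"):
    #             loss = model(batch)
    #         sub_reporter.register({"loss": loss})
    #         sub_reporter.next()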


class Reporter:
    """Reporter class.

    Examples:

        >>> reporter = Reporter()
        >>> with reporter.observe('train') as sub_reporter:
        ...     for batch in iterator:
        ...         stats = dict(loss=0.2)
        ...         sub_reporter.register(stats)

    """

    def __init__(self, epoch: int = 0):
        if epoch < 0:
            raise ValueError(f"epoch must be 0 or more: {epoch}")
        self.epoch = epoch
        # stats: Dict[int, Dict[str, Dict[str, float]]]
        # e.g. self.stats[epoch]['train']['loss']
        self.stats = {}

    def get_epoch(self) -> int:
        return self.epoch

    def set_epoch(self, epoch: int) -> None:
        if epoch < 0:
            raise ValueError(f"epoch must be 0 or more: {epoch}")
        self.epoch = epoch

    @contextmanager
    def observe(self, key: str, epoch: int = None) -> ContextManager[SubReporter]:
        sub_reporter = self.start_epoch(key, epoch)
        yield sub_reporter
        # Receive the stats from sub_reporter
        self.finish_epoch(sub_reporter)

    def start_epoch(self, key: str, epoch: int = None) -> SubReporter:
        if epoch is not None:
            if epoch < 0:
                raise ValueError(f"epoch must be 0 or more: {epoch}")
            self.epoch = epoch

        if self.epoch - 1 not in self.stats or key not in self.stats[self.epoch - 1]:
            # If the previous epoch doesn't exist for some reason,
            # maybe due to a bug, this case also indicates 0-count.
            if self.epoch - 1 != 0:
                warnings.warn(
                    f"The stats of the previous epoch={self.epoch - 1} "
                    f"doesn't exist."
                )
            total_count = 0
        else:
            total_count = self.stats[self.epoch - 1][key]["total_count"]

        sub_reporter = SubReporter(key, self.epoch, total_count)
        # Clear the stats for the next epoch if it exists
        self.stats.pop(epoch, None)
        return sub_reporter

    def finish_epoch(self, sub_reporter: SubReporter) -> None:
        if self.epoch != sub_reporter.epoch:
            raise RuntimeError(
                f"Don't change epoch during observation: "
                f"{self.epoch} != {sub_reporter.epoch}"
            )

        # Calc mean of current stats and set it as this epoch's stats
        stats = {}
        for key2, values in sub_reporter.stats.items():
            v = aggregate(values)
            stats[key2] = v

        stats["time"] = datetime.timedelta(
            seconds=time.perf_counter() - sub_reporter.start_time
        )
        stats["total_count"] = sub_reporter.total_count
        if LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
            if torch.cuda.is_initialized():
                stats["gpu_max_cached_mem_GB"] = (
                    torch.cuda.max_memory_reserved() / 2 ** 30
                )
        else:
            if torch.cuda.is_available() and torch.cuda.max_memory_cached() > 0:
                stats["gpu_cached_mem_GB"] = torch.cuda.max_memory_cached() / 2 ** 30

        self.stats.setdefault(self.epoch, {})[sub_reporter.key] = stats
        sub_reporter.finished()

    def sort_epochs_and_values(
        self, key: str, key2: str, mode: str
    ) -> List[Tuple[int, float]]:
        """Return the epochs and values sorted from the best value.

        Example:
            >>> val = reporter.sort_epochs_and_values('eval', 'loss', 'min')
            >>> e_1best, v_1best = val[0]
            >>> e_2best, v_2best = val[1]
        """
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be min or max: {mode}")
        if not self.has(key, key2):
            raise KeyError(f"{key}.{key2} is not found: {self.get_all_keys()}")

        # Collect (epoch, value) pairs over all recorded epochs
        values = [(e, self.stats[e][key][key2]) for e in self.stats]

        if mode == "min":
            values = sorted(values, key=lambda x: x[1])
        else:
            values = sorted(values, key=lambda x: -x[1])
        return values

    def sort_epochs(self, key: str, key2: str, mode: str) -> List[int]:
        return [e for e, v in self.sort_epochs_and_values(key, key2, mode)]

    def sort_values(self, key: str, key2: str, mode: str) -> List[float]:
        return [v for e, v in self.sort_epochs_and_values(key, key2, mode)]

    def get_best_epoch(self, key: str, key2: str, mode: str, nbest: int = 0) -> int:
        return self.sort_epochs(key, key2, mode)[nbest]
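
    # For instance, ``reporter.get_best_epoch("valid", "acc", "max")`` gives the
    # epoch with the highest validation accuracy; ``nbest=1`` gives the second
    # best (the "valid"/"acc" keys are illustrative).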

    def check_early_stopping(
        self,
        patience: int,
        key1: str,
        key2: str,
        mode: str,
        epoch: int = None,
        logger=None,
    ) -> bool:
        if logger is None:
            logger = logging
        if epoch is None:
            epoch = self.get_epoch()

        best_epoch = self.get_best_epoch(key1, key2, mode)
        if epoch - best_epoch > patience:
            logger.info(
                f"[Early stopping] {key1}.{key2} has not improved "
                f"for {epoch - best_epoch} consecutive epochs. "
                f"The training was stopped at {epoch}epoch"
            )
            return True
        else:
            return False
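
    # Usage sketch inside a training loop (the patience value and keys are
    # illustrative):
    #
    #     if reporter.check_early_stopping(3, "valid", "loss", "min"):
    #         break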

    def has(self, key: str, key2: str, epoch: int = None) -> bool:
        if epoch is None:
            epoch = self.get_epoch()
        return (
            epoch in self.stats
            and key in self.stats[epoch]
            and key2 in self.stats[epoch][key]
        )

    def log_message(self, epoch: int = None) -> str:
        if epoch is None:
            epoch = self.get_epoch()

        message = ""
        for key, d in self.stats[epoch].items():
            _message = ""
            for key2, v in d.items():
                if v is not None:
                    if len(_message) != 0:
                        _message += ", "
                    if isinstance(v, float):
                        if abs(v) > 1.0e3:
                            _message += f"{key2}={v:.3e}"
                        elif abs(v) > 1.0e-3:
                            _message += f"{key2}={v:.3f}"
                        else:
                            _message += f"{key2}={v:.3e}"
                    elif isinstance(v, datetime.timedelta):
                        _v = humanfriendly.format_timespan(v)
                        _message += f"{key2}={_v}"
                    else:
                        _message += f"{key2}={v}"
            if len(_message) != 0:
                if len(message) == 0:
                    message += f"{epoch}epoch results: "
                else:
                    message += ", "
                message += f"[{key}] {_message}"
        return message

    def get_value(self, key: str, key2: str, epoch: int = None):
        if not self.has(key, key2):
            raise KeyError(
                f"{key}.{key2} is not found in stats: {self.get_all_keys()}"
            )
        if epoch is None:
            epoch = self.get_epoch()
        return self.stats[epoch][key][key2]

    def get_keys(self, epoch: int = None) -> Tuple[str, ...]:
        """Returns keys1 e.g. train,eval."""
        if epoch is None:
            epoch = self.get_epoch()
        return tuple(self.stats[epoch])

    def get_keys2(self, key: str, epoch: int = None) -> Tuple[str, ...]:
        """Returns keys2 e.g. loss,acc."""
        if epoch is None:
            epoch = self.get_epoch()
        d = self.stats[epoch][key]
        keys2 = tuple(k for k in d if k not in ("time", "total_count"))
        return keys2

    def get_all_keys(self, epoch: int = None) -> Tuple[Tuple[str, str], ...]:
        if epoch is None:
            epoch = self.get_epoch()
        all_keys = []
        for key in self.stats[epoch]:
            for key2 in self.stats[epoch][key]:
                all_keys.append((key, key2))
        return tuple(all_keys)

    def tensorboard_add_scalar(
        self, summary_writer, epoch: int = None, key1: str = None
    ):
        if epoch is None:
            epoch = self.get_epoch()

        total_count = self.stats[epoch]["train"]["total_count"]
        if key1 == "train":
            summary_writer.add_scalar("iter_epoch", epoch, total_count)
        if key1 is not None:
            key1_iterator = tuple([key1])
        else:
            key1_iterator = self.get_keys(epoch)
        for key1 in key1_iterator:
            for key2 in self.get_keys2(key1):
                summary_writer.add_scalar(
                    f"{key2}", self.stats[epoch][key1][key2], total_count
                )

    def wandb_log(self, epoch: int = None):
        import wandb

        if epoch is None:
            epoch = self.get_epoch()

        d = {}
        for key1 in self.get_keys(epoch):
            for key2 in self.stats[epoch][key1]:
                if key2 in ("time", "total_count"):
                    continue
                key = f"{key1}_{key2}_epoch"
                d[wandb_get_prefix(key) + key] = self.stats[epoch][key1][key2]
        d["epoch"] = epoch
        wandb.log(d)

    def state_dict(self):
        return {"stats": self.stats, "epoch": self.epoch}

    def load_state_dict(self, state_dict: dict):
        self.epoch = state_dict["epoch"]
        self.stats = state_dict["stats"]
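

if __name__ == "__main__":
    # A minimal usage sketch: drive one "train" epoch with dummy loss values,
    # print the epoch summary, and round-trip the reporter state through
    # state_dict()/load_state_dict().
    reporter = Reporter(epoch=1)
    with reporter.observe("train") as sub_reporter:
        for loss in (0.9, 0.7, 0.5):
            sub_reporter.register({"loss": loss})
            sub_reporter.next()
    print(reporter.log_message())

    restored = Reporter()
    restored.load_state_dict(reporter.state_dict())
    assert restored.get_value("train", "loss") == reporter.get_value("train", "loss")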