# reporter.py
  1. """Reporter module."""
  2. import dataclasses
  3. import datetime
  4. import logging
  5. import time
  6. import warnings
  7. from collections import defaultdict
  8. from contextlib import contextmanager
  9. from distutils.version import LooseVersion
  10. from typing import ContextManager
  11. from typing import Dict
  12. from typing import List
  13. from typing import Optional
  14. from typing import Sequence
  15. from typing import Tuple
  16. from typing import Union
  17. import humanfriendly
  18. import numpy as np
  19. import torch
  20. from typeguard import check_argument_types
  21. from typeguard import check_return_type
# Scalar-like numeric type accepted by the reporter; tensors/arrays are
# converted with .item() and must hold exactly one element.
Num = Union[float, int, complex, torch.Tensor, np.ndarray]
# Stat keys that SubReporter.register() refuses because
# Reporter.finish_epoch() writes them itself.
_reserved = {"time", "total_count"}
  24. def to_reported_value(v: Num, weight: Num = None) -> "ReportedValue":
  25. assert check_argument_types()
  26. if isinstance(v, (torch.Tensor, np.ndarray)):
  27. if np.prod(v.shape) != 1:
  28. raise ValueError(f"v must be 0 or 1 dimension: {len(v.shape)}")
  29. v = v.item()
  30. if isinstance(weight, (torch.Tensor, np.ndarray)):
  31. if np.prod(weight.shape) != 1:
  32. raise ValueError(f"weight must be 0 or 1 dimension: {len(weight.shape)}")
  33. weight = weight.item()
  34. if weight is not None:
  35. retval = WeightedAverage(v, weight)
  36. else:
  37. retval = Average(v)
  38. assert check_return_type(retval)
  39. return retval
  40. def aggregate(values: Sequence["ReportedValue"]) -> Num:
  41. assert check_argument_types()
  42. for v in values:
  43. if not isinstance(v, type(values[0])):
  44. raise ValueError(
  45. f"Can't use different Reported type together: "
  46. f"{type(v)} != {type(values[0])}"
  47. )
  48. if len(values) == 0:
  49. warnings.warn("No stats found")
  50. retval = np.nan
  51. elif isinstance(values[0], Average):
  52. retval = np.nanmean([v.value for v in values])
  53. elif isinstance(values[0], WeightedAverage):
  54. # Excludes non finite values
  55. invalid_indices = set()
  56. for i, v in enumerate(values):
  57. if not np.isfinite(v.value) or not np.isfinite(v.weight):
  58. invalid_indices.add(i)
  59. values = [v for i, v in enumerate(values) if i not in invalid_indices]
  60. if len(values) != 0:
  61. # Calc weighed average. Weights are changed to sum-to-1.
  62. sum_weights = sum(v.weight for i, v in enumerate(values))
  63. sum_value = sum(v.value * v.weight for i, v in enumerate(values))
  64. if sum_weights == 0:
  65. warnings.warn("weight is zero")
  66. retval = np.nan
  67. else:
  68. retval = sum_value / sum_weights
  69. else:
  70. warnings.warn("No valid stats found")
  71. retval = np.nan
  72. else:
  73. raise NotImplementedError(f"type={type(values[0])}")
  74. assert check_return_type(retval)
  75. return retval
  76. def wandb_get_prefix(key: str):
  77. if key.startswith("valid"):
  78. return "valid/"
  79. if key.startswith("train"):
  80. return "train/"
  81. if key.startswith("attn"):
  82. return "attn/"
  83. return "metrics/"
  84. class ReportedValue:
  85. pass
  86. @dataclasses.dataclass(frozen=True)
  87. class Average(ReportedValue):
  88. value: Num
  89. @dataclasses.dataclass(frozen=True)
  90. class WeightedAverage(ReportedValue):
  91. value: Tuple[Num, Num]
  92. weight: Num
  93. class SubReporter:
  94. """This class is used in Reporter.
  95. See the docstring of Reporter for the usage.
  96. """
  97. def __init__(self, key: str, epoch: int, total_count: int):
  98. assert check_argument_types()
  99. self.key = key
  100. self.epoch = epoch
  101. self.start_time = time.perf_counter()
  102. self.stats = defaultdict(list)
  103. self._finished = False
  104. self.total_count = total_count
  105. self.count = 0
  106. self._seen_keys_in_the_step = set()
  107. def get_total_count(self) -> int:
  108. """Returns the number of iterations over all epochs."""
  109. return self.total_count
  110. def get_epoch(self) -> int:
  111. return self.epoch
  112. def next(self):
  113. """Close up this step and reset state for the next step"""
  114. for key, stats_list in self.stats.items():
  115. if key not in self._seen_keys_in_the_step:
  116. # Fill nan value if the key is not registered in this step
  117. if isinstance(stats_list[0], WeightedAverage):
  118. stats_list.append(to_reported_value(np.nan, 0))
  119. elif isinstance(stats_list[0], Average):
  120. stats_list.append(to_reported_value(np.nan))
  121. else:
  122. raise NotImplementedError(f"type={type(stats_list[0])}")
  123. assert len(stats_list) == self.count, (len(stats_list), self.count)
  124. self._seen_keys_in_the_step = set()
  125. def register(
  126. self,
  127. stats: Dict[str, Optional[Union[Num, Dict[str, Num]]]],
  128. weight: Num = None,
  129. ) -> None:
  130. assert check_argument_types()
  131. if self._finished:
  132. raise RuntimeError("Already finished")
  133. if len(self._seen_keys_in_the_step) == 0:
  134. # Increment count as the first register in this step
  135. self.total_count += 1
  136. self.count += 1
  137. for key2, v in stats.items():
  138. if key2 in _reserved:
  139. raise RuntimeError(f"{key2} is reserved.")
  140. if key2 in self._seen_keys_in_the_step:
  141. raise RuntimeError(f"{key2} is registered twice.")
  142. if v is None:
  143. v = np.nan
  144. r = to_reported_value(v, weight)
  145. if key2 not in self.stats:
  146. # If it's the first time to register the key,
  147. # append nan values in front of the the value
  148. # to make it same length to the other stats
  149. # e.g.
  150. # stat A: [0.4, 0.3, 0.5]
  151. # stat B: [nan, nan, 0.2]
  152. nan = to_reported_value(np.nan, None if weight is None else 0)
  153. self.stats[key2].extend(
  154. r if i == self.count - 1 else nan for i in range(self.count)
  155. )
  156. else:
  157. self.stats[key2].append(r)
  158. self._seen_keys_in_the_step.add(key2)
  159. def log_message(self, start: int = None, end: int = None, num_updates: int = None) -> str:
  160. if self._finished:
  161. raise RuntimeError("Already finished")
  162. if start is None:
  163. start = 0
  164. if start < 0:
  165. start = self.count + start
  166. if end is None:
  167. end = self.count
  168. if self.count == 0 or start == end:
  169. return ""
  170. message = f"{self.epoch}epoch:{self.key}:" f"{start + 1}-{end}batch:"
  171. if num_updates is not None:
  172. message += f"{num_updates}num_updates: "
  173. for idx, (key2, stats_list) in enumerate(self.stats.items()):
  174. assert len(stats_list) == self.count, (len(stats_list), self.count)
  175. # values: List[ReportValue]
  176. values = stats_list[start:end]
  177. if idx != 0 and idx != len(stats_list):
  178. message += ", "
  179. v = aggregate(values)
  180. if abs(v) > 1.0e3:
  181. message += f"{key2}={v:.3e}"
  182. elif abs(v) > 1.0e-3:
  183. message += f"{key2}={v:.3f}"
  184. else:
  185. message += f"{key2}={v:.3e}"
  186. return message
  187. def tensorboard_add_scalar(self, summary_writer, start: int = None):
  188. if start is None:
  189. start = 0
  190. if start < 0:
  191. start = self.count + start
  192. for key2, stats_list in self.stats.items():
  193. assert len(stats_list) == self.count, (len(stats_list), self.count)
  194. # values: List[ReportValue]
  195. values = stats_list[start:]
  196. v = aggregate(values)
  197. summary_writer.add_scalar(f"{key2}", v, self.total_count)
  198. def wandb_log(self, start: int = None):
  199. import wandb
  200. if start is None:
  201. start = 0
  202. if start < 0:
  203. start = self.count + start
  204. d = {}
  205. for key2, stats_list in self.stats.items():
  206. assert len(stats_list) == self.count, (len(stats_list), self.count)
  207. # values: List[ReportValue]
  208. values = stats_list[start:]
  209. v = aggregate(values)
  210. d[wandb_get_prefix(key2) + key2] = v
  211. d["iteration"] = self.total_count
  212. wandb.log(d)
  213. def finished(self) -> None:
  214. self._finished = True
  215. @contextmanager
  216. def measure_time(self, name: str):
  217. start = time.perf_counter()
  218. yield start
  219. t = time.perf_counter() - start
  220. self.register({name: t})
  221. def measure_iter_time(self, iterable, name: str):
  222. iterator = iter(iterable)
  223. while True:
  224. try:
  225. start = time.perf_counter()
  226. retval = next(iterator)
  227. t = time.perf_counter() - start
  228. self.register({name: t})
  229. yield retval
  230. except StopIteration:
  231. break
  232. class Reporter:
  233. """Reporter class.
  234. Examples:
  235. >>> reporter = Reporter()
  236. >>> with reporter.observe('train') as sub_reporter:
  237. ... for batch in iterator:
  238. ... stats = dict(loss=0.2)
  239. ... sub_reporter.register(stats)
  240. """
  241. def __init__(self, epoch: int = 0):
  242. assert check_argument_types()
  243. if epoch < 0:
  244. raise ValueError(f"epoch must be 0 or more: {epoch}")
  245. self.epoch = epoch
  246. # stats: Dict[int, Dict[str, Dict[str, float]]]
  247. # e.g. self.stats[epoch]['train']['loss']
  248. self.stats = {}
  249. def get_epoch(self) -> int:
  250. return self.epoch
  251. def set_epoch(self, epoch: int) -> None:
  252. if epoch < 0:
  253. raise ValueError(f"epoch must be 0 or more: {epoch}")
  254. self.epoch = epoch
  255. @contextmanager
  256. def observe(self, key: str, epoch: int = None) -> ContextManager[SubReporter]:
  257. sub_reporter = self.start_epoch(key, epoch)
  258. yield sub_reporter
  259. # Receive the stats from sub_reporter
  260. self.finish_epoch(sub_reporter)
  261. def start_epoch(self, key: str, epoch: int = None) -> SubReporter:
  262. if epoch is not None:
  263. if epoch < 0:
  264. raise ValueError(f"epoch must be 0 or more: {epoch}")
  265. self.epoch = epoch
  266. if self.epoch - 1 not in self.stats or key not in self.stats[self.epoch - 1]:
  267. # If the previous epoch doesn't exist for some reason,
  268. # maybe due to bug, this case also indicates 0-count.
  269. if self.epoch - 1 != 0:
  270. warnings.warn(
  271. f"The stats of the previous epoch={self.epoch - 1}"
  272. f"doesn't exist."
  273. )
  274. total_count = 0
  275. else:
  276. total_count = self.stats[self.epoch - 1][key]["total_count"]
  277. sub_reporter = SubReporter(key, self.epoch, total_count)
  278. # Clear the stats for the next epoch if it exists
  279. self.stats.pop(epoch, None)
  280. return sub_reporter
  281. def finish_epoch(self, sub_reporter: SubReporter) -> None:
  282. if self.epoch != sub_reporter.epoch:
  283. raise RuntimeError(
  284. f"Don't change epoch during observation: "
  285. f"{self.epoch} != {sub_reporter.epoch}"
  286. )
  287. # Calc mean of current stats and set it as previous epochs stats
  288. stats = {}
  289. for key2, values in sub_reporter.stats.items():
  290. v = aggregate(values)
  291. stats[key2] = v
  292. stats["time"] = datetime.timedelta(
  293. seconds=time.perf_counter() - sub_reporter.start_time
  294. )
  295. stats["total_count"] = sub_reporter.total_count
  296. if LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
  297. if torch.cuda.is_initialized():
  298. stats["gpu_max_cached_mem_GB"] = (
  299. torch.cuda.max_memory_reserved() / 2 ** 30
  300. )
  301. else:
  302. if torch.cuda.is_available() and torch.cuda.max_memory_cached() > 0:
  303. stats["gpu_cached_mem_GB"] = torch.cuda.max_memory_cached() / 2 ** 30
  304. self.stats.setdefault(self.epoch, {})[sub_reporter.key] = stats
  305. sub_reporter.finished()
  306. def sort_epochs_and_values(
  307. self, key: str, key2: str, mode: str
  308. ) -> List[Tuple[int, float]]:
  309. """Return the epoch which resulted the best value.
  310. Example:
  311. >>> val = reporter.sort_epochs_and_values('eval', 'loss', 'min')
  312. >>> e_1best, v_1best = val[0]
  313. >>> e_2best, v_2best = val[1]
  314. """
  315. if mode not in ("min", "max"):
  316. raise ValueError(f"mode must min or max: {mode}")
  317. if not self.has(key, key2):
  318. raise KeyError(f"{key}.{key2} is not found: {self.get_all_keys()}")
  319. # iterate from the last epoch
  320. values = [(e, self.stats[e][key][key2]) for e in self.stats]
  321. if mode == "min":
  322. values = sorted(values, key=lambda x: x[1])
  323. else:
  324. values = sorted(values, key=lambda x: -x[1])
  325. return values
  326. def sort_epochs(self, key: str, key2: str, mode: str) -> List[int]:
  327. return [e for e, v in self.sort_epochs_and_values(key, key2, mode)]
  328. def sort_values(self, key: str, key2: str, mode: str) -> List[float]:
  329. return [v for e, v in self.sort_epochs_and_values(key, key2, mode)]
  330. def get_best_epoch(self, key: str, key2: str, mode: str, nbest: int = 0) -> int:
  331. return self.sort_epochs(key, key2, mode)[nbest]
  332. def check_early_stopping(
  333. self,
  334. patience: int,
  335. key1: str,
  336. key2: str,
  337. mode: str,
  338. epoch: int = None,
  339. logger=None,
  340. ) -> bool:
  341. if logger is None:
  342. logger = logging
  343. if epoch is None:
  344. epoch = self.get_epoch()
  345. best_epoch = self.get_best_epoch(key1, key2, mode)
  346. if epoch - best_epoch > patience:
  347. logger.info(
  348. f"[Early stopping] {key1}.{key2} has not been "
  349. f"improved {epoch - best_epoch} epochs continuously. "
  350. f"The training was stopped at {epoch}epoch"
  351. )
  352. return True
  353. else:
  354. return False
  355. def has(self, key: str, key2: str, epoch: int = None) -> bool:
  356. if epoch is None:
  357. epoch = self.get_epoch()
  358. return (
  359. epoch in self.stats
  360. and key in self.stats[epoch]
  361. and key2 in self.stats[epoch][key]
  362. )
  363. def log_message(self, epoch: int = None) -> str:
  364. if epoch is None:
  365. epoch = self.get_epoch()
  366. message = ""
  367. for key, d in self.stats[epoch].items():
  368. _message = ""
  369. for key2, v in d.items():
  370. if v is not None:
  371. if len(_message) != 0:
  372. _message += ", "
  373. if isinstance(v, float):
  374. if abs(v) > 1.0e3:
  375. _message += f"{key2}={v:.3e}"
  376. elif abs(v) > 1.0e-3:
  377. _message += f"{key2}={v:.3f}"
  378. else:
  379. _message += f"{key2}={v:.3e}"
  380. elif isinstance(v, datetime.timedelta):
  381. _v = humanfriendly.format_timespan(v)
  382. _message += f"{key2}={_v}"
  383. else:
  384. _message += f"{key2}={v}"
  385. if len(_message) != 0:
  386. if len(message) == 0:
  387. message += f"{epoch}epoch results: "
  388. else:
  389. message += ", "
  390. message += f"[{key}] {_message}"
  391. return message
  392. def get_value(self, key: str, key2: str, epoch: int = None):
  393. if not self.has(key, key2):
  394. raise KeyError(f"{key}.{key2} is not found in stats: {self.get_all_keys()}")
  395. if epoch is None:
  396. epoch = self.get_epoch()
  397. return self.stats[epoch][key][key2]
  398. def get_keys(self, epoch: int = None) -> Tuple[str, ...]:
  399. """Returns keys1 e.g. train,eval."""
  400. if epoch is None:
  401. epoch = self.get_epoch()
  402. return tuple(self.stats[epoch])
  403. def get_keys2(self, key: str, epoch: int = None) -> Tuple[str, ...]:
  404. """Returns keys2 e.g. loss,acc."""
  405. if epoch is None:
  406. epoch = self.get_epoch()
  407. d = self.stats[epoch][key]
  408. keys2 = tuple(k for k in d if k not in ("time", "total_count"))
  409. return keys2
  410. def get_all_keys(self, epoch: int = None) -> Tuple[Tuple[str, str], ...]:
  411. if epoch is None:
  412. epoch = self.get_epoch()
  413. all_keys = []
  414. for key in self.stats[epoch]:
  415. for key2 in self.stats[epoch][key]:
  416. all_keys.append((key, key2))
  417. return tuple(all_keys)
  418. def tensorboard_add_scalar(
  419. self, summary_writer, epoch: int = None, key1: str = None
  420. ):
  421. if epoch is None:
  422. epoch = self.get_epoch()
  423. total_count = self.stats[epoch]["train"]["total_count"]
  424. if key1 == "train":
  425. summary_writer.add_scalar("iter_epoch", epoch, total_count)
  426. if key1 is not None:
  427. key1_iterator = tuple([key1])
  428. else:
  429. key1_iterator = self.get_keys(epoch)
  430. for key1 in key1_iterator:
  431. for key2 in self.get_keys2(key1):
  432. summary_writer.add_scalar(
  433. f"{key2}", self.stats[epoch][key1][key2], total_count
  434. )
  435. def wandb_log(self, epoch: int = None):
  436. import wandb
  437. if epoch is None:
  438. epoch = self.get_epoch()
  439. d = {}
  440. for key1 in self.get_keys(epoch):
  441. for key2 in self.stats[epoch][key1]:
  442. if key2 in ("time", "total_count"):
  443. continue
  444. key = f"{key1}_{key2}_epoch"
  445. d[wandb_get_prefix(key) + key] = self.stats[epoch][key1][key2]
  446. d["epoch"] = epoch
  447. wandb.log(d)
  448. def state_dict(self):
  449. return {"stats": self.stats, "epoch": self.epoch}
  450. def load_state_dict(self, state_dict: dict):
  451. self.epoch = state_dict["epoch"]
  452. self.stats = state_dict["stats"]