| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531 |
- """Reporter module."""
- import dataclasses
- import datetime
- import logging
- import time
- import warnings
- from collections import defaultdict
- from contextlib import contextmanager
- from distutils.version import LooseVersion
- from typing import ContextManager
- from typing import Dict
- from typing import List
- from typing import Optional
- from typing import Sequence
- from typing import Tuple
- from typing import Union
- import humanfriendly
- import numpy as np
- import torch
- Num = Union[float, int, complex, torch.Tensor, np.ndarray]
- _reserved = {"time", "total_count"}
def to_reported_value(v: "Num", weight: Optional["Num"] = None) -> "ReportedValue":
    """Wrap a raw number as a ReportedValue.

    Args:
        v: The value to report. A torch.Tensor / np.ndarray must contain
            exactly one element; it is converted to a Python scalar.
        weight: Optional weight. When given, a WeightedAverage is returned;
            otherwise a plain Average.

    Returns:
        WeightedAverage if ``weight`` is not None, else Average.

    Raises:
        ValueError: If ``v`` or ``weight`` has more than one element.
    """
    if isinstance(v, (torch.Tensor, np.ndarray)):
        if np.prod(v.shape) != 1:
            raise ValueError(f"v must be 0 or 1 dimension: {len(v.shape)}")
        v = v.item()

    if isinstance(weight, (torch.Tensor, np.ndarray)):
        if np.prod(weight.shape) != 1:
            raise ValueError(f"weight must be 0 or 1 dimension: {len(weight.shape)}")
        weight = weight.item()

    if weight is not None:
        return WeightedAverage(v, weight)
    return Average(v)
def aggregate(values: Sequence["ReportedValue"]) -> "Num":
    """Aggregate a homogeneous sequence of ReportedValue into one number.

    Average entries are reduced with nanmean; WeightedAverage entries are
    reduced as a weighted mean over the finite entries only.

    Args:
        values: Sequence of ReportedValue, all of the same concrete type.

    Returns:
        The aggregated scalar, or nan (with a warning) when nothing usable
        is available.

    Raises:
        ValueError: If the sequence mixes different ReportedValue types.
        NotImplementedError: For an unknown ReportedValue subclass.
    """
    if len(values) == 0:
        warnings.warn("No stats found")
        return np.nan

    # All entries must share one concrete type; mixing is a caller bug.
    for v in values:
        if not isinstance(v, type(values[0])):
            raise ValueError(
                f"Can't use different Reported type together: "
                f"{type(v)} != {type(values[0])}"
            )

    if isinstance(values[0], Average):
        return np.nanmean([v.value for v in values])

    if isinstance(values[0], WeightedAverage):
        # Excludes non finite values
        finite = [
            v for v in values if np.isfinite(v.value) and np.isfinite(v.weight)
        ]
        if len(finite) == 0:
            warnings.warn("No valid stats found")
            return np.nan
        # Calc weighed average. Weights are changed to sum-to-1.
        sum_weights = sum(v.weight for v in finite)
        if sum_weights == 0:
            warnings.warn("weight is zero")
            return np.nan
        sum_value = sum(v.value * v.weight for v in finite)
        return sum_value / sum_weights

    raise NotImplementedError(f"type={type(values[0])}")
def wandb_get_prefix(key: str):
    """Return the wandb section prefix for a stat key.

    Keys beginning with valid/train/attn go to their own sections;
    everything else lands under "metrics/".
    """
    for section in ("valid", "train", "attn"):
        if key.startswith(section):
            return f"{section}/"
    return "metrics/"
class ReportedValue:
    """Marker base class for a single registered statistic value."""
@dataclasses.dataclass(frozen=True)
class Average(ReportedValue):
    """An unweighted scalar observation (reduced with nanmean)."""

    # The raw scalar value for one step.
    value: Num
@dataclasses.dataclass(frozen=True)
class WeightedAverage(ReportedValue):
    """A scalar observation with an associated weight.

    Reduced as a weighted mean over entries with finite value and weight.
    """

    # FIX: was annotated Tuple[Num, Num], but to_reported_value always stores
    # a single scalar here; the annotation is corrected to Num (runtime
    # behavior of the dataclass is unchanged).
    value: Num
    weight: Num
class SubReporter:
    """Per-epoch statistics collector used inside Reporter.

    See the docstring of Reporter for the usage.
    """

    def __init__(self, key: str, epoch: int, total_count: int):
        self.key = key
        self.epoch = epoch
        self.start_time = time.perf_counter()
        # key2 -> list of ReportedValue, one entry per registered step
        self.stats = defaultdict(list)
        self._finished = False
        # Number of registered steps over ALL epochs (carried across epochs)
        self.total_count = total_count
        # Number of registered steps in this epoch only
        self.count = 0
        # key2 names registered since the last next()
        self._seen_keys_in_the_step = set()

    def get_total_count(self) -> int:
        """Returns the number of iterations over all epochs."""
        return self.total_count

    def get_epoch(self) -> int:
        """Returns the epoch this SubReporter observes."""
        return self.epoch

    def next(self):
        """Close up this step and reset state for the next step"""
        for key, stats_list in self.stats.items():
            if key not in self._seen_keys_in_the_step:
                # Fill nan value if the key is not registered in this step,
                # so every key2 list stays aligned to self.count entries.
                if isinstance(stats_list[0], WeightedAverage):
                    stats_list.append(to_reported_value(np.nan, 0))
                elif isinstance(stats_list[0], Average):
                    stats_list.append(to_reported_value(np.nan))
                else:
                    raise NotImplementedError(f"type={type(stats_list[0])}")
            assert len(stats_list) == self.count, (len(stats_list), self.count)
        self._seen_keys_in_the_step = set()

    def register(
        self,
        stats: Dict[str, Optional[Union["Num", Dict[str, "Num"]]]],
        weight: Optional["Num"] = None,
    ) -> None:
        """Register one step's stats.

        Args:
            stats: Mapping of stat name to value; None values become nan.
            weight: Optional weight applied to every stat in this call.

        Raises:
            RuntimeError: If already finished, if a reserved name ("time",
                "total_count") is used, or if a key is registered twice in
                one step.
        """
        if self._finished:
            raise RuntimeError("Already finished")
        if len(self._seen_keys_in_the_step) == 0:
            # Increment count as the first register in this step
            self.total_count += 1
            self.count += 1

        for key2, v in stats.items():
            if key2 in _reserved:
                raise RuntimeError(f"{key2} is reserved.")
            if key2 in self._seen_keys_in_the_step:
                raise RuntimeError(f"{key2} is registered twice.")
            if v is None:
                v = np.nan
            r = to_reported_value(v, weight)
            if key2 not in self.stats:
                # If it's the first time to register the key,
                # append nan values in front of the the value
                # to make it same length to the other stats
                # e.g.
                # stat A: [0.4, 0.3, 0.5]
                # stat B: [nan, nan, 0.2]
                nan = to_reported_value(np.nan, None if weight is None else 0)
                self.stats[key2].extend(
                    r if i == self.count - 1 else nan for i in range(self.count)
                )
            else:
                self.stats[key2].append(r)
            self._seen_keys_in_the_step.add(key2)

    def log_message(
        self,
        start: Optional[int] = None,
        end: Optional[int] = None,
        num_updates: Optional[int] = None,
    ) -> str:
        """Format the aggregated stats of steps [start, end) as a log string.

        Negative ``start`` counts from the end; an empty range yields "".
        """
        if self._finished:
            raise RuntimeError("Already finished")
        if start is None:
            start = 0
        if start < 0:
            start = self.count + start
        if end is None:
            end = self.count

        if self.count == 0 or start == end:
            return ""

        message = f"{self.epoch}epoch:{self.key}:" f"{start + 1}-{end}batch:"
        if num_updates is not None:
            message += f"{num_updates}num_updates: "

        for idx, (key2, stats_list) in enumerate(self.stats.items()):
            assert len(stats_list) == self.count, (len(stats_list), self.count)
            # values: List[ReportedValue]
            values = stats_list[start:end]
            # FIX: the separator condition was
            # "idx != 0 and idx != len(stats_list)", which compared the key
            # index against the per-key step count and glued two entries
            # together whenever they happened to coincide. A comma belongs
            # before every entry but the first.
            if idx != 0:
                message += ", "

            v = aggregate(values)
            if abs(v) > 1.0e3:
                message += f"{key2}={v:.3e}"
            elif abs(v) > 1.0e-3:
                message += f"{key2}={v:.3f}"
            else:
                message += f"{key2}={v:.3e}"
        return message

    def tensorboard_add_scalar(self, summary_writer, start: Optional[int] = None):
        """Write the aggregate of steps [start, count) for each key to tensorboard."""
        if start is None:
            start = 0
        if start < 0:
            start = self.count + start

        for key2, stats_list in self.stats.items():
            assert len(stats_list) == self.count, (len(stats_list), self.count)
            # values: List[ReportedValue]
            values = stats_list[start:]
            v = aggregate(values)
            summary_writer.add_scalar(f"{key2}", v, self.total_count)

    def wandb_log(self, start: Optional[int] = None):
        """Log the aggregate of steps [start, count) for each key to wandb."""
        import wandb

        if start is None:
            start = 0
        if start < 0:
            start = self.count + start

        d = {}
        for key2, stats_list in self.stats.items():
            assert len(stats_list) == self.count, (len(stats_list), self.count)
            # values: List[ReportedValue]
            values = stats_list[start:]
            v = aggregate(values)
            d[wandb_get_prefix(key2) + key2] = v
        d["iteration"] = self.total_count
        wandb.log(d)

    def finished(self) -> None:
        """Mark this SubReporter closed; further register/log calls raise."""
        self._finished = True

    @contextmanager
    def measure_time(self, name: str):
        """Context manager registering the elapsed wall time under ``name``."""
        start = time.perf_counter()
        yield start
        t = time.perf_counter() - start
        self.register({name: t})

    def measure_iter_time(self, iterable, name: str):
        """Yield from ``iterable``, registering each item's fetch time under ``name``."""
        iterator = iter(iterable)
        while True:
            try:
                start = time.perf_counter()
                retval = next(iterator)
                t = time.perf_counter() - start
                self.register({name: t})
                yield retval
            except StopIteration:
                break
class Reporter:
    """Reporter class.

    Accumulates per-epoch aggregated stats under self.stats[epoch][key][key2].

    Examples:
        >>> reporter = Reporter()
        >>> with reporter.observe('train') as sub_reporter:
        ...     for batch in iterator:
        ...         stats = dict(loss=0.2)
        ...         sub_reporter.register(stats)
    """

    def __init__(self, epoch: int = 0):
        if epoch < 0:
            raise ValueError(f"epoch must be 0 or more: {epoch}")
        self.epoch = epoch
        # stats: Dict[int, Dict[str, Dict[str, float]]]
        # e.g. self.stats[epoch]['train']['loss']
        self.stats = {}

    def get_epoch(self) -> int:
        """Return the current epoch number."""
        return self.epoch

    def set_epoch(self, epoch: int) -> None:
        """Set the current epoch number (must be >= 0)."""
        if epoch < 0:
            raise ValueError(f"epoch must be 0 or more: {epoch}")
        self.epoch = epoch

    @contextmanager
    def observe(self, key: str, epoch: Optional[int] = None) -> ContextManager["SubReporter"]:
        """Context manager yielding a SubReporter to collect one epoch of ``key`` stats.

        NOTE(review): if the body raises, finish_epoch() is skipped and the
        partial stats are discarded — presumably intentional; confirm.
        """
        sub_reporter = self.start_epoch(key, epoch)
        yield sub_reporter
        # Receive the stats from sub_reporter
        self.finish_epoch(sub_reporter)

    def start_epoch(self, key: str, epoch: Optional[int] = None) -> "SubReporter":
        """Create a SubReporter for ``key``, carrying total_count from the previous epoch."""
        if epoch is not None:
            if epoch < 0:
                raise ValueError(f"epoch must be 0 or more: {epoch}")
            self.epoch = epoch

        if self.epoch - 1 not in self.stats or key not in self.stats[self.epoch - 1]:
            # If the previous epoch doesn't exist for some reason,
            # maybe due to bug, this case also indicates 0-count.
            if self.epoch - 1 != 0:
                # FIX: the two f-string fragments were joined without a
                # space, producing e.g. "epoch=3doesn't exist."
                warnings.warn(
                    f"The stats of the previous epoch={self.epoch - 1} "
                    f"doesn't exist."
                )
            total_count = 0
        else:
            total_count = self.stats[self.epoch - 1][key]["total_count"]

        sub_reporter = SubReporter(key, self.epoch, total_count)
        # Clear the stats for the next epoch if it exists.
        # NOTE(review): ``epoch`` (the argument) may be None here, making this
        # a no-op; using self.epoch instead would delete stats already
        # recorded by sibling keys of the same epoch, so it is left as-is —
        # confirm the intent.
        self.stats.pop(epoch, None)
        return sub_reporter

    def finish_epoch(self, sub_reporter: "SubReporter") -> None:
        """Aggregate the SubReporter's stats into self.stats and close it."""
        if self.epoch != sub_reporter.epoch:
            raise RuntimeError(
                f"Don't change epoch during observation: "
                f"{self.epoch} != {sub_reporter.epoch}"
            )

        # Calc mean of current stats and set it as previous epochs stats
        stats = {}
        for key2, values in sub_reporter.stats.items():
            v = aggregate(values)
            stats[key2] = v

        stats["time"] = datetime.timedelta(
            seconds=time.perf_counter() - sub_reporter.start_time
        )
        stats["total_count"] = sub_reporter.total_count
        # NOTE(review): LooseVersion (distutils) is deprecated/removed in
        # newer Pythons; kept as-is to avoid changing module imports.
        if LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
            if torch.cuda.is_initialized():
                stats["gpu_max_cached_mem_GB"] = (
                    torch.cuda.max_memory_reserved() / 2 ** 30
                )
        else:
            if torch.cuda.is_available() and torch.cuda.max_memory_cached() > 0:
                stats["gpu_cached_mem_GB"] = torch.cuda.max_memory_cached() / 2 ** 30

        self.stats.setdefault(self.epoch, {})[sub_reporter.key] = stats
        sub_reporter.finished()

    def sort_epochs_and_values(
        self, key: str, key2: str, mode: str
    ) -> List[Tuple[int, float]]:
        """Return the epoch which resulted the best value.

        Example:
            >>> val = reporter.sort_epochs_and_values('eval', 'loss', 'min')
            >>> e_1best, v_1best = val[0]
            >>> e_2best, v_2best = val[1]
        """
        if mode not in ("min", "max"):
            raise ValueError(f"mode must min or max: {mode}")
        if not self.has(key, key2):
            raise KeyError(f"{key}.{key2} is not found: {self.get_all_keys()}")

        # iterate from the last epoch
        values = [(e, self.stats[e][key][key2]) for e in self.stats]
        if mode == "min":
            values = sorted(values, key=lambda x: x[1])
        else:
            values = sorted(values, key=lambda x: -x[1])
        return values

    def sort_epochs(self, key: str, key2: str, mode: str) -> List[int]:
        """Return epochs sorted best-first by self.stats[...][key][key2]."""
        return [e for e, v in self.sort_epochs_and_values(key, key2, mode)]

    def sort_values(self, key: str, key2: str, mode: str) -> List[float]:
        """Return values sorted best-first."""
        return [v for e, v in self.sort_epochs_and_values(key, key2, mode)]

    def get_best_epoch(self, key: str, key2: str, mode: str, nbest: int = 0) -> int:
        """Return the nbest-th best epoch for the given stat and mode."""
        return self.sort_epochs(key, key2, mode)[nbest]

    def check_early_stopping(
        self,
        patience: int,
        key1: str,
        key2: str,
        mode: str,
        epoch: Optional[int] = None,
        logger=None,
    ) -> bool:
        """Return True (and log) if key1.key2 hasn't improved for > patience epochs."""
        if logger is None:
            logger = logging
        if epoch is None:
            epoch = self.get_epoch()

        best_epoch = self.get_best_epoch(key1, key2, mode)
        if epoch - best_epoch > patience:
            logger.info(
                f"[Early stopping] {key1}.{key2} has not been "
                f"improved {epoch - best_epoch} epochs continuously. "
                f"The training was stopped at {epoch}epoch"
            )
            return True
        else:
            return False

    def has(self, key: str, key2: str, epoch: Optional[int] = None) -> bool:
        """Return True if self.stats[epoch][key][key2] exists."""
        if epoch is None:
            epoch = self.get_epoch()
        return (
            epoch in self.stats
            and key in self.stats[epoch]
            and key2 in self.stats[epoch][key]
        )

    def log_message(self, epoch: Optional[int] = None) -> str:
        """Format all stats of the given epoch as one human-readable line."""
        if epoch is None:
            epoch = self.get_epoch()

        message = ""
        for key, d in self.stats[epoch].items():
            _message = ""
            for key2, v in d.items():
                if v is not None:
                    if len(_message) != 0:
                        _message += ", "
                    if isinstance(v, float):
                        if abs(v) > 1.0e3:
                            _message += f"{key2}={v:.3e}"
                        elif abs(v) > 1.0e-3:
                            _message += f"{key2}={v:.3f}"
                        else:
                            _message += f"{key2}={v:.3e}"
                    elif isinstance(v, datetime.timedelta):
                        _v = humanfriendly.format_timespan(v)
                        _message += f"{key2}={_v}"
                    else:
                        _message += f"{key2}={v}"
            if len(_message) != 0:
                if len(message) == 0:
                    message += f"{epoch}epoch results: "
                else:
                    message += ", "
                message += f"[{key}] {_message}"
        return message

    def get_value(self, key: str, key2: str, epoch: Optional[int] = None):
        """Return self.stats[epoch][key][key2], raising KeyError if absent."""
        # FIX: the existence check previously used the *current* epoch
        # regardless of the requested ``epoch``; resolve epoch first and
        # validate against it.
        if epoch is None:
            epoch = self.get_epoch()
        if not self.has(key, key2, epoch):
            raise KeyError(f"{key}.{key2} is not found in stats: {self.get_all_keys()}")
        return self.stats[epoch][key][key2]

    def get_keys(self, epoch: Optional[int] = None) -> Tuple[str, ...]:
        """Returns keys1 e.g. train,eval."""
        if epoch is None:
            epoch = self.get_epoch()
        return tuple(self.stats[epoch])

    def get_keys2(self, key: str, epoch: Optional[int] = None) -> Tuple[str, ...]:
        """Returns keys2 e.g. loss,acc."""
        if epoch is None:
            epoch = self.get_epoch()
        d = self.stats[epoch][key]
        keys2 = tuple(k for k in d if k not in ("time", "total_count"))
        return keys2

    def get_all_keys(self, epoch: Optional[int] = None) -> Tuple[Tuple[str, str], ...]:
        """Return all (key1, key2) pairs recorded for the epoch."""
        if epoch is None:
            epoch = self.get_epoch()
        all_keys = []
        for key in self.stats[epoch]:
            for key2 in self.stats[epoch][key]:
                all_keys.append((key, key2))
        return tuple(all_keys)

    def tensorboard_add_scalar(
        self, summary_writer, epoch: Optional[int] = None, key1: Optional[str] = None
    ):
        """Write the epoch's stats to tensorboard (all keys1, or just ``key1``)."""
        if epoch is None:
            epoch = self.get_epoch()

        # x-axis is the train iteration count of this epoch
        total_count = self.stats[epoch]["train"]["total_count"]
        if key1 == "train":
            summary_writer.add_scalar("iter_epoch", epoch, total_count)

        if key1 is not None:
            key1_iterator = tuple([key1])
        else:
            key1_iterator = self.get_keys(epoch)
        for key1 in key1_iterator:
            for key2 in self.get_keys2(key1):
                summary_writer.add_scalar(
                    f"{key2}", self.stats[epoch][key1][key2], total_count
                )

    def wandb_log(self, epoch: Optional[int] = None):
        """Log the epoch's stats (excluding time/total_count) to wandb."""
        import wandb

        if epoch is None:
            epoch = self.get_epoch()

        d = {}
        for key1 in self.get_keys(epoch):
            for key2 in self.stats[epoch][key1]:
                if key2 in ("time", "total_count"):
                    continue
                key = f"{key1}_{key2}_epoch"
                d[wandb_get_prefix(key) + key] = self.stats[epoch][key1][key2]
        d["epoch"] = epoch
        wandb.log(d)

    def state_dict(self):
        """Return a serializable snapshot of this reporter."""
        return {"stats": self.stats, "epoch": self.epoch}

    def load_state_dict(self, state_dict: dict):
        """Restore epoch and stats from a state_dict() snapshot."""
        self.epoch = state_dict["epoch"]
        self.stats = state_dict["stats"]
|