- import logging
- from pathlib import Path
- from typing import Optional
- from typing import Sequence
- from typing import Union
- import warnings
- import os
- from io import BytesIO
- import torch
- from typing import Collection
- import os
- import torch
- import re
- from collections import OrderedDict
- from functools import cmp_to_key
- # @torch.no_grad()
- # def average_nbest_models(
- # output_dir: Path,
- # best_model_criterion: Sequence[Sequence[str]],
- # nbest: Union[Collection[int], int],
- # suffix: Optional[str] = None,
- # oss_bucket=None,
- # pai_output_dir=None,
- # ) -> None:
- # """Generate averaged model from n-best models
- #
- # Args:
- # output_dir: The directory contains the model file for each epoch
- # reporter: Reporter instance
- # best_model_criterion: Give criterions to decide the best model.
- # e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
- # nbest: Number of best model files to be averaged
- # suffix: A suffix added to the averaged model file name
- # """
- # if isinstance(nbest, int):
- # nbests = [nbest]
- # else:
- # nbests = list(nbest)
- # if len(nbests) == 0:
- # warnings.warn("At least 1 nbest values are required")
- # nbests = [1]
- # if suffix is not None:
- # suffix = suffix + "."
- # else:
- # suffix = ""
- #
- # # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
- # nbest_epochs = [
- # (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
- # for ph, k, m in best_model_criterion
- # if reporter.has(ph, k)
- # ]
- #
- # _loaded = {}
- # for ph, cr, epoch_and_values in nbest_epochs:
- # _nbests = [i for i in nbests if i <= len(epoch_and_values)]
- # if len(_nbests) == 0:
- # _nbests = [1]
- #
- # for n in _nbests:
- # if n == 0:
- # continue
- # elif n == 1:
- # # The averaged model is same as the best model
- # e, _ = epoch_and_values[0]
- # op = output_dir / f"{e}epoch.pb"
- # sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
- # if sym_op.is_symlink() or sym_op.exists():
- # sym_op.unlink()
- # sym_op.symlink_to(op.name)
- # else:
- # op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
- # logging.info(
- # f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
- # )
- #
- # avg = None
- # # 2.a. Averaging model
- # for e, _ in epoch_and_values[:n]:
- # if e not in _loaded:
- # if oss_bucket is None:
- # _loaded[e] = torch.load(
- # output_dir / f"{e}epoch.pb",
- # map_location="cpu",
- # )
- # else:
- # buffer = BytesIO(
- # oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
- # _loaded[e] = torch.load(buffer)
- # states = _loaded[e]
- #
- # if avg is None:
- # avg = states
- # else:
- # # Accumulated
- # for k in avg:
- # avg[k] = avg[k] + states[k]
- # for k in avg:
- # if str(avg[k].dtype).startswith("torch.int"):
- # # For int type, not averaged, but only accumulated.
- # # e.g. BatchNorm.num_batches_tracked
- # # (If there are any cases that requires averaging
- # # or the other reducing method, e.g. max/min, for integer type,
- # # please report.)
- # pass
- # else:
- # avg[k] = avg[k] / n
- #
- # # 2.b. Save the ave model and create a symlink
- # if oss_bucket is None:
- # torch.save(avg, op)
- # else:
- # buffer = BytesIO()
- # torch.save(avg, buffer)
- # oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
- # buffer.getvalue())
- #
- # # 3. *.*.ave.pb is a symlink to the max ave model
- # if oss_bucket is None:
- # op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
- # sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
- # if sym_op.is_symlink() or sym_op.exists():
- # sym_op.unlink()
- # sym_op.symlink_to(op.name)
- def _get_checkpoint_paths(output_dir: str, last_n: int=5):
- """
- Get the paths of the last 'last_n' checkpoints by parsing filenames
- in the output directory.
- """
- # List all files in the output directory
- files = os.listdir(output_dir)
- # Filter out checkpoint files and extract epoch numbers
- checkpoint_files = [f for f in files if f.startswith("model.pt.e")]
- # Sort files by epoch number in descending order
- checkpoint_files.sort(key=lambda x: int(re.search(r'(\d+)', x).group()), reverse=True)
- # Get the last 'last_n' checkpoint paths
- checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
- return checkpoint_paths
@torch.no_grad()
def average_checkpoints(output_dir: str, last_n: int = 5):
    """Average the model weights of the last ``last_n`` checkpoints.

    Floating-point tensors are averaged element-wise; integer tensors
    (e.g. ``BatchNorm.num_batches_tracked``) are summed instead of
    averaged. The result is saved to ``output_dir/model.pt.avg{last_n}``
    wrapped as ``{'state_dict': ...}`` and also returned.

    Args:
        output_dir: Directory containing ``model.pt.e{epoch}`` files.
        last_n: Number of most recent checkpoints to average.

    Returns:
        OrderedDict mapping parameter names to averaged (or summed) tensors.

    Raises:
        RuntimeError: If no checkpoint file could be loaded.
    """
    checkpoint_paths = _get_checkpoint_paths(output_dir, last_n)
    state_dicts = []
    for path in checkpoint_paths:
        if not os.path.isfile(path):
            # Best-effort: skip missing files rather than aborting the
            # whole average. Use logging (already used in this file)
            # instead of a bare print.
            logging.warning("Checkpoint file %s not found.", path)
            continue
        checkpoint = torch.load(path, map_location="cpu")
        # Accept both wrapped checkpoints ({'state_dict': ...}) and raw
        # state_dicts; the old hard ['state_dict'] index raised KeyError
        # on unwrapped checkpoints.
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            state_dicts.append(checkpoint["state_dict"])
        else:
            state_dicts.append(checkpoint)

    if not state_dicts:
        raise RuntimeError("No checkpoints found for averaging.")

    avg_state_dict = OrderedDict()
    for key in state_dicts[0].keys():
        tensors = [sd[key].cpu() for sd in state_dicts]
        if str(tensors[0].dtype).startswith("torch.int"):
            # Integer tensors are counters; accumulate rather than average.
            avg_state_dict[key] = sum(tensors)
        else:
            avg_state_dict[key] = torch.mean(torch.stack(tensors), dim=0)

    torch.save({'state_dict': avg_state_dict},
               os.path.join(output_dir, f"model.pt.avg{last_n}"))
    return avg_state_dict