# collect_stats.py (4.9 KB)
  1. from collections import defaultdict
  2. import logging
  3. from pathlib import Path
  4. from typing import Dict
  5. from typing import Iterable
  6. from typing import List
  7. from typing import Optional
  8. from typing import Tuple
  9. import numpy as np
  10. import torch
  11. from torch.nn.parallel import data_parallel
  12. from torch.utils.data import DataLoader
  13. from typeguard import check_argument_types
  14. from funasr.fileio.datadir_writer import DatadirWriter
  15. from funasr.fileio.npy_scp import NpyScpWriter
  16. from funasr.torch_utils.device_funcs import to_device
  17. from funasr.torch_utils.forward_adaptor import ForwardAdaptor
  18. from funasr.train.abs_espnet_model import AbsESPnetModel
  19. @torch.no_grad()
  20. def collect_stats(
  21. model: AbsESPnetModel,
  22. train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
  23. valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
  24. output_dir: Path,
  25. ngpu: Optional[int],
  26. log_interval: Optional[int],
  27. write_collected_feats: bool,
  28. ) -> None:
  29. """Perform on collect_stats mode.
  30. Running for deriving the shape information from data
  31. and gathering statistics.
  32. This method is used before executing train().
  33. """
  34. assert check_argument_types()
  35. npy_scp_writers = {}
  36. for itr, mode in zip([train_iter, valid_iter], ["train", "valid"]):
  37. if log_interval is None:
  38. try:
  39. log_interval = max(len(itr) // 20, 10)
  40. except TypeError:
  41. log_interval = 100
  42. sum_dict = defaultdict(lambda: 0)
  43. sq_dict = defaultdict(lambda: 0)
  44. count_dict = defaultdict(lambda: 0)
  45. with DatadirWriter(output_dir / mode) as datadir_writer:
  46. for iiter, (keys, batch) in enumerate(itr, 1):
  47. batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
  48. # 1. Write shape file
  49. for name in batch:
  50. if name.endswith("_lengths"):
  51. continue
  52. for i, (key, data) in enumerate(zip(keys, batch[name])):
  53. if f"{name}_lengths" in batch:
  54. lg = int(batch[f"{name}_lengths"][i])
  55. data = data[:lg]
  56. datadir_writer[f"{name}_shape"][key] = ",".join(
  57. map(str, data.shape)
  58. )
  59. # 2. Extract feats
  60. if ngpu <= 1:
  61. data = model.collect_feats(**batch)
  62. else:
  63. # Note that data_parallel can parallelize only "forward()"
  64. data = data_parallel(
  65. ForwardAdaptor(model, "collect_feats"),
  66. (),
  67. range(ngpu),
  68. module_kwargs=batch,
  69. )
  70. # 3. Calculate sum and square sum
  71. for key, v in data.items():
  72. for i, (uttid, seq) in enumerate(zip(keys, v.cpu().numpy())):
  73. # Truncate zero-padding region
  74. if f"{key}_lengths" in data:
  75. length = data[f"{key}_lengths"][i]
  76. # seq: (Length, Dim, ...)
  77. seq = seq[:length]
  78. else:
  79. # seq: (Dim, ...) -> (1, Dim, ...)
  80. seq = seq[None]
  81. # Accumulate value, its square, and count
  82. sum_dict[key] += seq.sum(0)
  83. sq_dict[key] += (seq**2).sum(0)
  84. count_dict[key] += len(seq)
  85. # 4. [Option] Write derived features as npy format file.
  86. if write_collected_feats:
  87. # Instantiate NpyScpWriter for the first iteration
  88. if (key, mode) not in npy_scp_writers:
  89. p = output_dir / mode / "collect_feats"
  90. npy_scp_writers[(key, mode)] = NpyScpWriter(
  91. p / f"data_{key}", p / f"{key}.scp"
  92. )
  93. # Save array as npy file
  94. npy_scp_writers[(key, mode)][uttid] = seq
  95. if iiter % log_interval == 0:
  96. logging.info(f"Niter: {iiter}")
  97. for key in sum_dict:
  98. np.savez(
  99. output_dir / mode / f"{key}_stats.npz",
  100. count=count_dict[key],
  101. sum=sum_dict[key],
  102. sum_square=sq_dict[key],
  103. )
  104. # batch_keys and stats_keys are used by aggregate_stats_dirs.py
  105. with (output_dir / mode / "batch_keys").open("w", encoding="utf-8") as f:
  106. f.write(
  107. "\n".join(filter(lambda x: not x.endswith("_lengths"), batch)) + "\n"
  108. )
  109. with (output_dir / mode / "stats_keys").open("w", encoding="utf-8") as f:
  110. f.write("\n".join(sum_dict) + "\n")