# collect_stats.py
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np
import torch
from torch.nn.parallel import data_parallel
from torch.utils.data import DataLoader

from funasr.fileio.datadir_writer import DatadirWriter
from funasr.fileio.npy_scp import NpyScpWriter
from funasr.models.base_model import FunASRModel
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
  18. @torch.no_grad()
  19. def collect_stats(
  20. model: FunASRModel,
  21. train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
  22. valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
  23. output_dir: Path,
  24. ngpu: Optional[int],
  25. log_interval: Optional[int],
  26. write_collected_feats: bool,
  27. ) -> None:
  28. """Perform on collect_stats mode.
  29. Running for deriving the shape information from data
  30. and gathering statistics.
  31. This method is used before executing train().
  32. """
  33. npy_scp_writers = {}
  34. for itr, mode in zip([train_iter, valid_iter], ["train", "valid"]):
  35. if log_interval is None:
  36. try:
  37. log_interval = max(len(itr) // 20, 10)
  38. except TypeError:
  39. log_interval = 100
  40. sum_dict = defaultdict(lambda: 0)
  41. sq_dict = defaultdict(lambda: 0)
  42. count_dict = defaultdict(lambda: 0)
  43. with DatadirWriter(output_dir / mode) as datadir_writer:
  44. for iiter, (keys, batch) in enumerate(itr, 1):
  45. batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
  46. # 1. Write shape file
  47. for name in batch:
  48. if name.endswith("_lengths"):
  49. continue
  50. for i, (key, data) in enumerate(zip(keys, batch[name])):
  51. if f"{name}_lengths" in batch:
  52. lg = int(batch[f"{name}_lengths"][i])
  53. data = data[:lg]
  54. datadir_writer[f"{name}_shape"][key] = ",".join(
  55. map(str, data.shape)
  56. )
  57. # 2. Extract feats
  58. if ngpu <= 1:
  59. data = model.collect_feats(**batch)
  60. else:
  61. # Note that data_parallel can parallelize only "forward()"
  62. data = data_parallel(
  63. ForwardAdaptor(model, "collect_feats"),
  64. (),
  65. range(ngpu),
  66. module_kwargs=batch,
  67. )
  68. # 3. Calculate sum and square sum
  69. for key, v in data.items():
  70. for i, (uttid, seq) in enumerate(zip(keys, v.cpu().numpy())):
  71. # Truncate zero-padding region
  72. if f"{key}_lengths" in data:
  73. length = data[f"{key}_lengths"][i]
  74. # seq: (Length, Dim, ...)
  75. seq = seq[:length]
  76. else:
  77. # seq: (Dim, ...) -> (1, Dim, ...)
  78. seq = seq[None]
  79. # Accumulate value, its square, and count
  80. sum_dict[key] += seq.sum(0)
  81. sq_dict[key] += (seq**2).sum(0)
  82. count_dict[key] += len(seq)
  83. # 4. [Option] Write derived features as npy format file.
  84. if write_collected_feats:
  85. # Instantiate NpyScpWriter for the first iteration
  86. if (key, mode) not in npy_scp_writers:
  87. p = output_dir / mode / "collect_feats"
  88. npy_scp_writers[(key, mode)] = NpyScpWriter(
  89. p / f"data_{key}", p / f"{key}.scp"
  90. )
  91. # Save array as npy file
  92. npy_scp_writers[(key, mode)][uttid] = seq
  93. if iiter % log_interval == 0:
  94. logging.info(f"Niter: {iiter}")
  95. for key in sum_dict:
  96. np.savez(
  97. output_dir / mode / f"{key}_stats.npz",
  98. count=count_dict[key],
  99. sum=sum_dict[key],
  100. sum_square=sq_dict[key],
  101. )
  102. # batch_keys and stats_keys are used by aggregate_stats_dirs.py
  103. with (output_dir / mode / "batch_keys").open("w", encoding="utf-8") as f:
  104. f.write(
  105. "\n".join(filter(lambda x: not x.endswith("_lengths"), batch)) + "\n"
  106. )
  107. with (output_dir / mode / "stats_keys").open("w", encoding="utf-8") as f:
  108. f.write("\n".join(sum_dict) + "\n")