aggregate_stats_dirs.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. #!/usr/bin/env python3
  2. import argparse
  3. import logging
  4. import sys
  5. from pathlib import Path
  6. from typing import Iterable
  7. from typing import Union
  8. import numpy as np
  9. from funasr.utils.cli_utils import get_commandline_args
  10. def aggregate_stats_dirs(
  11. input_dir: Iterable[Union[str, Path]],
  12. output_dir: Union[str, Path],
  13. log_level: str,
  14. skip_sum_stats: bool,
  15. ):
  16. logging.basicConfig(
  17. level=log_level,
  18. format="%(asctime)s (%(module)s:%(lineno)d) (levelname)s: %(message)s",
  19. )
  20. input_dirs = [Path(p) for p in input_dir]
  21. output_dir = Path(output_dir)
  22. for mode in ["train", "valid"]:
  23. with (input_dirs[0] / mode / "batch_keys").open("r", encoding="utf-8") as f:
  24. batch_keys = [line.strip() for line in f if line.strip() != ""]
  25. with (input_dirs[0] / mode / "stats_keys").open("r", encoding="utf-8") as f:
  26. stats_keys = [line.strip() for line in f if line.strip() != ""]
  27. (output_dir / mode).mkdir(parents=True, exist_ok=True)
  28. for key in batch_keys:
  29. with (output_dir / mode / f"{key}_shape").open(
  30. "w", encoding="utf-8"
  31. ) as fout:
  32. for idir in input_dirs:
  33. with (idir / mode / f"{key}_shape").open(
  34. "r", encoding="utf-8"
  35. ) as fin:
  36. # Read to the last in order to sort keys
  37. # because the order can be changed if num_workers>=1
  38. lines = fin.readlines()
  39. lines = sorted(lines, key=lambda x: x.split()[0])
  40. for line in lines:
  41. fout.write(line)
  42. for key in stats_keys:
  43. if not skip_sum_stats:
  44. sum_stats = None
  45. for idir in input_dirs:
  46. stats = np.load(idir / mode / f"{key}_stats.npz")
  47. if sum_stats is None:
  48. sum_stats = dict(**stats)
  49. else:
  50. for k in stats:
  51. sum_stats[k] += stats[k]
  52. np.savez(output_dir / mode / f"{key}_stats.npz", **sum_stats)
  53. # if --write_collected_feats=true
  54. p = Path(mode) / "collect_feats" / f"{key}.scp"
  55. scp = input_dirs[0] / p
  56. if scp.exists():
  57. (output_dir / p).parent.mkdir(parents=True, exist_ok=True)
  58. with (output_dir / p).open("w", encoding="utf-8") as fout:
  59. for idir in input_dirs:
  60. with (idir / p).open("r", encoding="utf-8") as fin:
  61. for line in fin:
  62. fout.write(line)
  63. def get_parser() -> argparse.ArgumentParser:
  64. parser = argparse.ArgumentParser(
  65. description="Aggregate statistics directories into one directory",
  66. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  67. )
  68. parser.add_argument(
  69. "--log_level",
  70. type=lambda x: x.upper(),
  71. default="INFO",
  72. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  73. help="The verbose level of logging",
  74. )
  75. parser.add_argument(
  76. "--skip_sum_stats",
  77. default=False,
  78. action="store_true",
  79. help="Skip computing the sum of statistics.",
  80. )
  81. parser.add_argument("--input_dir", action="append", help="Input directories")
  82. parser.add_argument("--output_dir", required=True, help="Output directory")
  83. return parser
  84. def main(cmd=None):
  85. print(get_commandline_args(), file=sys.stderr)
  86. parser = get_parser()
  87. args = parser.parse_args(cmd)
  88. kwargs = vars(args)
  89. aggregate_stats_dirs(**kwargs)
# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()