compute_cmvn.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import argparse
  2. import json
  3. import os
  4. import numpy as np
  5. import torchaudio
  6. import torchaudio.compliance.kaldi as kaldi
  7. import yaml
  8. def get_parser():
  9. parser = argparse.ArgumentParser(
  10. description="computer global cmvn",
  11. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  12. )
  13. parser.add_argument(
  14. "--dim",
  15. default=80,
  16. type=int,
  17. help="feature dimension",
  18. )
  19. parser.add_argument(
  20. "--wav_path",
  21. default=False,
  22. required=True,
  23. type=str,
  24. help="the path of wav scps",
  25. )
  26. parser.add_argument(
  27. "--config_file",
  28. type=str,
  29. help="the config file for computing cmvn",
  30. )
  31. parser.add_argument(
  32. "--idx",
  33. default=1,
  34. required=True,
  35. type=int,
  36. help="index",
  37. )
  38. return parser
  39. def compute_fbank(wav_file,
  40. num_mel_bins=80,
  41. frame_length=25,
  42. frame_shift=10,
  43. dither=0.0,
  44. resample_rate=16000,
  45. speed=1.0,
  46. window_type="hamming"):
  47. waveform, sample_rate = torchaudio.load(wav_file)
  48. if resample_rate != sample_rate:
  49. waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
  50. new_freq=resample_rate)(waveform)
  51. if speed != 1.0:
  52. waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
  53. waveform, resample_rate,
  54. [['speed', str(speed)], ['rate', str(resample_rate)]]
  55. )
  56. waveform = waveform * (1 << 15)
  57. mat = kaldi.fbank(waveform,
  58. num_mel_bins=num_mel_bins,
  59. frame_length=frame_length,
  60. frame_shift=frame_shift,
  61. dither=dither,
  62. energy_floor=0.0,
  63. window_type=window_type,
  64. sample_frequency=resample_rate)
  65. return mat.numpy()
  66. def main():
  67. parser = get_parser()
  68. args = parser.parse_args()
  69. wav_scp_file = os.path.join(args.wav_path, "wav.{}.scp".format(args.idx))
  70. cmvn_file = os.path.join(args.wav_path, "cmvn.{}.json".format(args.idx))
  71. mean_stats = np.zeros(args.dim)
  72. var_stats = np.zeros(args.dim)
  73. total_frames = 0
  74. # with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
  75. # for key, mat in ark_reader:
  76. # mean_stats += np.sum(mat, axis=0)
  77. # var_stats += np.sum(np.square(mat), axis=0)
  78. # total_frames += mat.shape[0]
  79. with open(args.config_file) as f:
  80. configs = yaml.safe_load(f)
  81. frontend_configs = configs.get("frontend_conf", {})
  82. num_mel_bins = frontend_configs.get("n_mels", 80)
  83. frame_length = frontend_configs.get("frame_length", 25)
  84. frame_shift = frontend_configs.get("frame_shift", 10)
  85. window_type = frontend_configs.get("window", "hamming")
  86. resample_rate = frontend_configs.get("fs", 16000)
  87. assert num_mel_bins == args.dim
  88. with open(wav_scp_file) as f:
  89. lines = f.readlines()
  90. for line in lines:
  91. _, wav_file = line.strip().split()
  92. fbank = compute_fbank(wav_file,
  93. num_mel_bins=args.dim,
  94. frame_length=frame_length,
  95. frame_shift=frame_shift,
  96. resample_rate=resample_rate,
  97. window_type=window_type)
  98. mean_stats += np.sum(fbank, axis=0)
  99. var_stats += np.sum(np.square(fbank), axis=0)
  100. total_frames += fbank.shape[0]
  101. cmvn_info = {
  102. 'mean_stats': list(mean_stats.tolist()),
  103. 'var_stats': list(var_stats.tolist()),
  104. 'total_frames': total_frames
  105. }
  106. with open(cmvn_file, 'w') as fout:
  107. fout.write(json.dumps(cmvn_info))
  108. if __name__ == '__main__':
  109. main()