compute_fbank.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. from kaldiio import WriteHelper
  2. import argparse
  3. import numpy as np
  4. import json
  5. import torch
  6. import torchaudio
  7. import torchaudio.compliance.kaldi as kaldi
  8. def compute_fbank(wav_file,
  9. num_mel_bins=80,
  10. frame_length=25,
  11. frame_shift=10,
  12. dither=0.0,
  13. resample_rate=16000,
  14. speed=1.0,
  15. window_type="hamming"):
  16. waveform, sample_rate = torchaudio.load(wav_file)
  17. if resample_rate != sample_rate:
  18. waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
  19. new_freq=resample_rate)(waveform)
  20. if speed != 1.0:
  21. waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
  22. waveform, resample_rate,
  23. [['speed', str(speed)], ['rate', str(resample_rate)]]
  24. )
  25. waveform = waveform * (1 << 15)
  26. mat = kaldi.fbank(waveform,
  27. num_mel_bins=num_mel_bins,
  28. frame_length=frame_length,
  29. frame_shift=frame_shift,
  30. dither=dither,
  31. energy_floor=0.0,
  32. window_type=window_type,
  33. sample_frequency=resample_rate)
  34. return mat.numpy()
  35. def get_parser():
  36. parser = argparse.ArgumentParser(
  37. description="computer features",
  38. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  39. )
  40. parser.add_argument(
  41. "--wav-lists",
  42. "-w",
  43. default=False,
  44. required=True,
  45. type=str,
  46. help="input wav lists",
  47. )
  48. parser.add_argument(
  49. "--text-files",
  50. "-t",
  51. default=False,
  52. required=True,
  53. type=str,
  54. help="input text files",
  55. )
  56. parser.add_argument(
  57. "--dims",
  58. "-d",
  59. default=80,
  60. type=int,
  61. help="feature dims",
  62. )
  63. parser.add_argument(
  64. "--max-lengths",
  65. "-m",
  66. default=1500,
  67. type=int,
  68. help="max frame numbers",
  69. )
  70. parser.add_argument(
  71. "--sample-frequency",
  72. "-s",
  73. default=16000,
  74. type=int,
  75. help="sample frequency",
  76. )
  77. parser.add_argument(
  78. "--speed-perturb",
  79. "-p",
  80. default="1.0",
  81. type=str,
  82. help="speed perturb",
  83. )
  84. parser.add_argument(
  85. "--ark-index",
  86. "-a",
  87. default=1,
  88. required=True,
  89. type=int,
  90. help="ark index",
  91. )
  92. parser.add_argument(
  93. "--output-dir",
  94. "-o",
  95. default=False,
  96. required=True,
  97. type=str,
  98. help="output dir",
  99. )
  100. parser.add_argument(
  101. "--window-type",
  102. default="hamming",
  103. required=False,
  104. type=str,
  105. help="window type"
  106. )
  107. return parser
  108. def main():
  109. parser = get_parser()
  110. args = parser.parse_args()
  111. ark_file = args.output_dir + "/ark/feats." + str(args.ark_index) + ".ark"
  112. scp_file = args.output_dir + "/ark/feats." + str(args.ark_index) + ".scp"
  113. text_file = args.output_dir + "/txt/text." + str(args.ark_index) + ".txt"
  114. feats_shape_file = args.output_dir + "/ark/len." + str(args.ark_index)
  115. text_shape_file = args.output_dir + "/txt/len." + str(args.ark_index)
  116. ark_writer = WriteHelper('ark,scp:{},{}'.format(ark_file, scp_file))
  117. text_writer = open(text_file, 'w')
  118. feats_shape_writer = open(feats_shape_file, 'w')
  119. text_shape_writer = open(text_shape_file, 'w')
  120. speed_perturb_list = args.speed_perturb.split(',')
  121. for speed in speed_perturb_list:
  122. with open(args.wav_lists, 'r', encoding='utf-8') as wavfile:
  123. with open(args.text_files, 'r', encoding='utf-8') as textfile:
  124. for wav, text in zip(wavfile, textfile):
  125. s_w = wav.strip().split()
  126. wav_id = s_w[0]
  127. wav_file = s_w[1]
  128. s_t = text.strip().split()
  129. text_id = s_t[0]
  130. txt = s_t[1:]
  131. fbank = compute_fbank(wav_file,
  132. num_mel_bins=args.dims,
  133. resample_rate=args.sample_frequency,
  134. speed=float(speed),
  135. window_type=args.window_type
  136. )
  137. feats_dims = fbank.shape[1]
  138. feats_lens = fbank.shape[0]
  139. if feats_lens >= args.max_lengths:
  140. continue
  141. txt_lens = len(txt)
  142. if speed == "1.0":
  143. wav_id_sp = wav_id
  144. else:
  145. wav_id_sp = wav_id + "_sp" + speed
  146. feats_shape_writer.write(wav_id_sp + " " + str(feats_lens) + "," + str(feats_dims) + '\n')
  147. text_shape_writer.write(wav_id_sp + " " + str(txt_lens) + '\n')
  148. text_writer.write(wav_id_sp + " " + " ".join(txt) + '\n')
  149. ark_writer(wav_id_sp, fbank)
  150. if __name__ == '__main__':
  151. main()