# wav_utils.py
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import math
  3. import os
  4. import shutil
  5. from multiprocessing import Pool
  6. from typing import Any, Dict, Union
  7. import kaldiio
  8. import librosa
  9. import numpy as np
  10. import torch
  11. import torchaudio
  12. import torchaudio.compliance.kaldi as kaldi
  13. def ndarray_resample(audio_in: np.ndarray,
  14. fs_in: int = 16000,
  15. fs_out: int = 16000) -> np.ndarray:
  16. audio_out = audio_in
  17. if fs_in != fs_out:
  18. audio_out = librosa.resample(audio_in, orig_sr=fs_in, target_sr=fs_out)
  19. return audio_out
  20. def torch_resample(audio_in: torch.Tensor,
  21. fs_in: int = 16000,
  22. fs_out: int = 16000) -> torch.Tensor:
  23. audio_out = audio_in
  24. if fs_in != fs_out:
  25. audio_out = torchaudio.transforms.Resample(orig_freq=fs_in,
  26. new_freq=fs_out)(audio_in)
  27. return audio_out
  28. def extract_CMVN_featrures(mvn_file):
  29. """
  30. extract CMVN from cmvn.ark
  31. """
  32. if not os.path.exists(mvn_file):
  33. return None
  34. try:
  35. cmvn = kaldiio.load_mat(mvn_file)
  36. means = []
  37. variance = []
  38. for i in range(cmvn.shape[1] - 1):
  39. means.append(float(cmvn[0][i]))
  40. count = float(cmvn[0][-1])
  41. for i in range(cmvn.shape[1] - 1):
  42. variance.append(float(cmvn[1][i]))
  43. for i in range(len(means)):
  44. means[i] /= count
  45. variance[i] = variance[i] / count - means[i] * means[i]
  46. if variance[i] < 1.0e-20:
  47. variance[i] = 1.0e-20
  48. variance[i] = 1.0 / math.sqrt(variance[i])
  49. cmvn = np.array([means, variance])
  50. return cmvn
  51. except Exception:
  52. cmvn = extract_CMVN_features_txt(mvn_file)
  53. return cmvn
  54. def extract_CMVN_features_txt(mvn_file): # noqa
  55. with open(mvn_file, 'r', encoding='utf-8') as f:
  56. lines = f.readlines()
  57. add_shift_list = []
  58. rescale_list = []
  59. for i in range(len(lines)):
  60. line_item = lines[i].split()
  61. if line_item[0] == '<AddShift>':
  62. line_item = lines[i + 1].split()
  63. if line_item[0] == '<LearnRateCoef>':
  64. add_shift_line = line_item[3:(len(line_item) - 1)]
  65. add_shift_list = list(add_shift_line)
  66. continue
  67. elif line_item[0] == '<Rescale>':
  68. line_item = lines[i + 1].split()
  69. if line_item[0] == '<LearnRateCoef>':
  70. rescale_line = line_item[3:(len(line_item) - 1)]
  71. rescale_list = list(rescale_line)
  72. continue
  73. add_shift_list_f = [float(s) for s in add_shift_list]
  74. rescale_list_f = [float(s) for s in rescale_list]
  75. cmvn = np.array([add_shift_list_f, rescale_list_f])
  76. return cmvn
  77. def build_LFR_features(inputs, m=7, n=6): # noqa
  78. """
  79. Actually, this implements stacking frames and skipping frames.
  80. if m = 1 and n = 1, just return the origin features.
  81. if m = 1 and n > 1, it works like skipping.
  82. if m > 1 and n = 1, it works like stacking but only support right frames.
  83. if m > 1 and n > 1, it works like LFR.
  84. Args:
  85. inputs_batch: inputs is T x D np.ndarray
  86. m: number of frames to stack
  87. n: number of frames to skip
  88. """
  89. # LFR_inputs_batch = []
  90. # for inputs in inputs_batch:
  91. LFR_inputs = []
  92. T = inputs.shape[0]
  93. T_lfr = int(np.ceil(T / n))
  94. left_padding = np.tile(inputs[0], ((m - 1) // 2, 1))
  95. inputs = np.vstack((left_padding, inputs))
  96. T = T + (m - 1) // 2
  97. for i in range(T_lfr):
  98. if m <= T - i * n:
  99. LFR_inputs.append(np.hstack(inputs[i * n:i * n + m]))
  100. else: # process last LFR frame
  101. num_padding = m - (T - i * n)
  102. frame = np.hstack(inputs[i * n:])
  103. for _ in range(num_padding):
  104. frame = np.hstack((frame, inputs[-1]))
  105. LFR_inputs.append(frame)
  106. return np.vstack(LFR_inputs)
  107. def compute_fbank(wav_file,
  108. num_mel_bins=80,
  109. frame_length=25,
  110. frame_shift=10,
  111. dither=0.0,
  112. is_pcm=False,
  113. fs: Union[int, Dict[Any, int]] = 16000):
  114. audio_sr: int = 16000
  115. model_sr: int = 16000
  116. if isinstance(fs, int):
  117. model_sr = fs
  118. audio_sr = fs
  119. else:
  120. model_sr = fs['model_fs']
  121. audio_sr = fs['audio_fs']
  122. if is_pcm is True:
  123. # byte(PCM16) to float32, and resample
  124. value = wav_file
  125. middle_data = np.frombuffer(value, dtype=np.int16)
  126. middle_data = np.asarray(middle_data)
  127. if middle_data.dtype.kind not in 'iu':
  128. raise TypeError("'middle_data' must be an array of integers")
  129. dtype = np.dtype('float32')
  130. if dtype.kind != 'f':
  131. raise TypeError("'dtype' must be a floating point type")
  132. i = np.iinfo(middle_data.dtype)
  133. abs_max = 2 ** (i.bits - 1)
  134. offset = i.min + abs_max
  135. waveform = np.frombuffer(
  136. (middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
  137. waveform = ndarray_resample(waveform, audio_sr, model_sr)
  138. waveform = torch.from_numpy(waveform.reshape(1, -1))
  139. else:
  140. # load pcm from wav, and resample
  141. waveform, audio_sr = torchaudio.load(wav_file)
  142. waveform = waveform * (1 << 15)
  143. waveform = torch_resample(waveform, audio_sr, model_sr)
  144. mat = kaldi.fbank(waveform,
  145. num_mel_bins=num_mel_bins,
  146. frame_length=frame_length,
  147. frame_shift=frame_shift,
  148. dither=dither,
  149. energy_floor=0.0,
  150. window_type='hamming',
  151. sample_frequency=model_sr)
  152. input_feats = mat
  153. return input_feats
  154. def wav2num_frame(wav_path, frontend_conf):
  155. waveform, sampling_rate = torchaudio.load(wav_path)
  156. speech_length = (waveform.shape[1] / sampling_rate) * 1000.
  157. n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
  158. feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
  159. return n_frames, feature_dim, speech_length
  160. def calc_shape_core(root_path, frontend_conf, speech_length_min, speech_length_max, idx):
  161. wav_scp_file = os.path.join(root_path, "wav.scp.{}".format(idx))
  162. shape_file = os.path.join(root_path, "speech_shape.{}".format(idx))
  163. with open(wav_scp_file) as f:
  164. lines = f.readlines()
  165. with open(shape_file, "w") as f:
  166. for line in lines:
  167. sample_name, wav_path = line.strip().split()
  168. n_frames, feature_dim, speech_length = wav2num_frame(wav_path, frontend_conf)
  169. write_flag = True
  170. if speech_length_min > 0 and speech_length < speech_length_min:
  171. write_flag = False
  172. if speech_length_max > 0 and speech_length > speech_length_max:
  173. write_flag = False
  174. if write_flag:
  175. f.write("{} {},{}\n".format(sample_name, str(int(np.ceil(n_frames))), str(int(feature_dim))))
  176. f.flush()
  177. def calc_shape(data_dir, dataset, frontend_conf, speech_length_min=-1, speech_length_max=-1, nj=32):
  178. shape_path = os.path.join(data_dir, dataset, "shape_files")
  179. if os.path.exists(shape_path):
  180. assert os.path.exists(os.path.join(data_dir, dataset, "speech_shape"))
  181. print('Shape file for small dataset already exists.')
  182. return
  183. os.makedirs(shape_path, exist_ok=True)
  184. # split
  185. wav_scp_file = os.path.join(data_dir, dataset, "wav.scp")
  186. with open(wav_scp_file) as f:
  187. lines = f.readlines()
  188. num_lines = len(lines)
  189. num_job_lines = num_lines // nj
  190. start = 0
  191. for i in range(nj):
  192. end = start + num_job_lines
  193. file = os.path.join(shape_path, "wav.scp.{}".format(str(i + 1)))
  194. with open(file, "w") as f:
  195. if i == nj - 1:
  196. f.writelines(lines[start:])
  197. else:
  198. f.writelines(lines[start:end])
  199. start = end
  200. p = Pool(nj)
  201. for i in range(nj):
  202. p.apply_async(calc_shape_core,
  203. args=(shape_path, frontend_conf, speech_length_min, speech_length_max, str(i + 1)))
  204. print('Generating shape files, please wait a few minutes...')
  205. p.close()
  206. p.join()
  207. # combine
  208. file = os.path.join(data_dir, dataset, "speech_shape")
  209. with open(file, "w") as f:
  210. for i in range(nj):
  211. job_file = os.path.join(shape_path, "speech_shape.{}".format(str(i + 1)))
  212. with open(job_file) as job_f:
  213. lines = job_f.readlines()
  214. f.writelines(lines)
  215. print('Generating shape files done.')
  216. def generate_data_list(data_dir, dataset, nj=100):
  217. split_dir = os.path.join(data_dir, dataset, "split")
  218. if os.path.exists(split_dir):
  219. assert os.path.exists(os.path.join(data_dir, dataset, "data.list"))
  220. print('Data list for large dataset already exists.')
  221. return
  222. os.makedirs(split_dir, exist_ok=True)
  223. with open(os.path.join(data_dir, dataset, "wav.scp")) as f_wav:
  224. wav_lines = f_wav.readlines()
  225. with open(os.path.join(data_dir, dataset, "text")) as f_text:
  226. text_lines = f_text.readlines()
  227. total_num_lines = len(wav_lines)
  228. num_lines = total_num_lines // nj
  229. start_num = 0
  230. for i in range(nj):
  231. end_num = start_num + num_lines
  232. split_dir_nj = os.path.join(split_dir, str(i + 1))
  233. os.mkdir(split_dir_nj)
  234. wav_file = os.path.join(split_dir_nj, 'wav.scp')
  235. text_file = os.path.join(split_dir_nj, "text")
  236. with open(wav_file, "w") as fw, open(text_file, "w") as ft:
  237. if i == nj - 1:
  238. fw.writelines(wav_lines[start_num:])
  239. ft.writelines(text_lines[start_num:])
  240. else:
  241. fw.writelines(wav_lines[start_num:end_num])
  242. ft.writelines(text_lines[start_num:end_num])
  243. start_num = end_num
  244. data_list_file = os.path.join(data_dir, dataset, "data.list")
  245. with open(data_list_file, "w") as f_data:
  246. for i in range(nj):
  247. wav_path = os.path.join(split_dir, str(i + 1), "wav.scp")
  248. text_path = os.path.join(split_dir, str(i + 1), "text")
  249. f_data.write(wav_path + " " + text_path + "\n")
  250. def filter_wav_text(data_dir, dataset):
  251. wav_file = os.path.join(data_dir,dataset,"wav.scp")
  252. text_file = os.path.join(data_dir, dataset, "text")
  253. with open(wav_file) as f_wav, open(text_file) as f_text:
  254. wav_lines = f_wav.readlines()
  255. text_lines = f_text.readlines()
  256. os.rename(wav_file, "{}.bak".format(wav_file))
  257. os.rename(text_file, "{}.bak".format(text_file))
  258. wav_dict = {}
  259. for line in wav_lines:
  260. parts = line.strip().split()
  261. if len(parts) < 2:
  262. continue
  263. sample_name, wav_path = parts
  264. wav_dict[sample_name] = wav_path
  265. text_dict = {}
  266. for line in text_lines:
  267. parts = line.strip().split(" ", 1)
  268. if len(parts) < 2:
  269. continue
  270. sample_name, txt = parts
  271. text_dict[sample_name] = txt
  272. filter_count = 0
  273. with open(wav_file, "w") as f_wav, open(text_file, "w") as f_text:
  274. for sample_name, wav_path in wav_dict.items():
  275. if sample_name in text_dict.keys():
  276. f_wav.write(sample_name + " " + wav_path + "\n")
  277. f_text.write(sample_name + " " + text_dict[sample_name] + "\n")
  278. else:
  279. filter_count += 1
  280. print("{}/{} samples in {} are filtered because of the mismatch between wav.scp and text".format(len(wav_lines), filter_count, dataset))