# wav_utils.py
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import math
  3. import os
  4. import shutil
  5. from multiprocessing import Pool
  6. from typing import Any, Dict, Union
  7. import kaldiio
  8. import librosa
  9. import numpy as np
  10. import torch
  11. import torchaudio
  12. import librosa
  13. import torchaudio.compliance.kaldi as kaldi
  14. def ndarray_resample(audio_in: np.ndarray,
  15. fs_in: int = 16000,
  16. fs_out: int = 16000) -> np.ndarray:
  17. audio_out = audio_in
  18. if fs_in != fs_out:
  19. audio_out = librosa.resample(audio_in, orig_sr=fs_in, target_sr=fs_out)
  20. return audio_out
  21. def torch_resample(audio_in: torch.Tensor,
  22. fs_in: int = 16000,
  23. fs_out: int = 16000) -> torch.Tensor:
  24. audio_out = audio_in
  25. if fs_in != fs_out:
  26. audio_out = torchaudio.transforms.Resample(orig_freq=fs_in,
  27. new_freq=fs_out)(audio_in)
  28. return audio_out
  29. def extract_CMVN_featrures(mvn_file):
  30. """
  31. extract CMVN from cmvn.ark
  32. """
  33. if not os.path.exists(mvn_file):
  34. return None
  35. try:
  36. cmvn = kaldiio.load_mat(mvn_file)
  37. means = []
  38. variance = []
  39. for i in range(cmvn.shape[1] - 1):
  40. means.append(float(cmvn[0][i]))
  41. count = float(cmvn[0][-1])
  42. for i in range(cmvn.shape[1] - 1):
  43. variance.append(float(cmvn[1][i]))
  44. for i in range(len(means)):
  45. means[i] /= count
  46. variance[i] = variance[i] / count - means[i] * means[i]
  47. if variance[i] < 1.0e-20:
  48. variance[i] = 1.0e-20
  49. variance[i] = 1.0 / math.sqrt(variance[i])
  50. cmvn = np.array([means, variance])
  51. return cmvn
  52. except Exception:
  53. cmvn = extract_CMVN_features_txt(mvn_file)
  54. return cmvn
  55. def extract_CMVN_features_txt(mvn_file): # noqa
  56. with open(mvn_file, 'r', encoding='utf-8') as f:
  57. lines = f.readlines()
  58. add_shift_list = []
  59. rescale_list = []
  60. for i in range(len(lines)):
  61. line_item = lines[i].split()
  62. if line_item[0] == '<AddShift>':
  63. line_item = lines[i + 1].split()
  64. if line_item[0] == '<LearnRateCoef>':
  65. add_shift_line = line_item[3:(len(line_item) - 1)]
  66. add_shift_list = list(add_shift_line)
  67. continue
  68. elif line_item[0] == '<Rescale>':
  69. line_item = lines[i + 1].split()
  70. if line_item[0] == '<LearnRateCoef>':
  71. rescale_line = line_item[3:(len(line_item) - 1)]
  72. rescale_list = list(rescale_line)
  73. continue
  74. add_shift_list_f = [float(s) for s in add_shift_list]
  75. rescale_list_f = [float(s) for s in rescale_list]
  76. cmvn = np.array([add_shift_list_f, rescale_list_f])
  77. return cmvn
  78. def build_LFR_features(inputs, m=7, n=6): # noqa
  79. """
  80. Actually, this implements stacking frames and skipping frames.
  81. if m = 1 and n = 1, just return the origin features.
  82. if m = 1 and n > 1, it works like skipping.
  83. if m > 1 and n = 1, it works like stacking but only support right frames.
  84. if m > 1 and n > 1, it works like LFR.
  85. Args:
  86. inputs_batch: inputs is T x D np.ndarray
  87. m: number of frames to stack
  88. n: number of frames to skip
  89. """
  90. # LFR_inputs_batch = []
  91. # for inputs in inputs_batch:
  92. LFR_inputs = []
  93. T = inputs.shape[0]
  94. T_lfr = int(np.ceil(T / n))
  95. left_padding = np.tile(inputs[0], ((m - 1) // 2, 1))
  96. inputs = np.vstack((left_padding, inputs))
  97. T = T + (m - 1) // 2
  98. for i in range(T_lfr):
  99. if m <= T - i * n:
  100. LFR_inputs.append(np.hstack(inputs[i * n:i * n + m]))
  101. else: # process last LFR frame
  102. num_padding = m - (T - i * n)
  103. frame = np.hstack(inputs[i * n:])
  104. for _ in range(num_padding):
  105. frame = np.hstack((frame, inputs[-1]))
  106. LFR_inputs.append(frame)
  107. return np.vstack(LFR_inputs)
  108. def compute_fbank(wav_file,
  109. num_mel_bins=80,
  110. frame_length=25,
  111. frame_shift=10,
  112. dither=0.0,
  113. is_pcm=False,
  114. fs: Union[int, Dict[Any, int]] = 16000):
  115. audio_sr: int = 16000
  116. model_sr: int = 16000
  117. if isinstance(fs, int):
  118. model_sr = fs
  119. audio_sr = fs
  120. else:
  121. model_sr = fs['model_fs']
  122. audio_sr = fs['audio_fs']
  123. if is_pcm is True:
  124. # byte(PCM16) to float32, and resample
  125. value = wav_file
  126. middle_data = np.frombuffer(value, dtype=np.int16)
  127. middle_data = np.asarray(middle_data)
  128. if middle_data.dtype.kind not in 'iu':
  129. raise TypeError("'middle_data' must be an array of integers")
  130. dtype = np.dtype('float32')
  131. if dtype.kind != 'f':
  132. raise TypeError("'dtype' must be a floating point type")
  133. i = np.iinfo(middle_data.dtype)
  134. abs_max = 2 ** (i.bits - 1)
  135. offset = i.min + abs_max
  136. waveform = np.frombuffer(
  137. (middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
  138. waveform = ndarray_resample(waveform, audio_sr, model_sr)
  139. waveform = torch.from_numpy(waveform.reshape(1, -1))
  140. else:
  141. # load pcm from wav, and resample
  142. try:
  143. waveform, audio_sr = torchaudio.load(wav_file)
  144. except:
  145. waveform, audio_sr = librosa.load(wav_file, dtype='float32')
  146. if waveform.ndim == 2:
  147. waveform = waveform[:, 0]
  148. waveform = torch.tensor(np.expand_dims(waveform, axis=0))
  149. waveform = waveform * (1 << 15)
  150. waveform = torch_resample(waveform, audio_sr, model_sr)
  151. mat = kaldi.fbank(waveform,
  152. num_mel_bins=num_mel_bins,
  153. frame_length=frame_length,
  154. frame_shift=frame_shift,
  155. dither=dither,
  156. energy_floor=0.0,
  157. window_type='hamming',
  158. sample_frequency=model_sr)
  159. input_feats = mat
  160. return input_feats
  161. def wav2num_frame(wav_path, frontend_conf):
  162. try:
  163. waveform, sampling_rate = torchaudio.load(wav_path)
  164. except:
  165. waveform, sampling_rate = librosa.load(wav_path)
  166. waveform = torch.tensor(np.expand_dims(waveform, axis=0))
  167. speech_length = (waveform.shape[1] / sampling_rate) * 1000.
  168. n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
  169. feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
  170. return n_frames, feature_dim, speech_length
  171. def calc_shape_core(root_path, frontend_conf, speech_length_min, speech_length_max, idx):
  172. wav_scp_file = os.path.join(root_path, "wav.scp.{}".format(idx))
  173. shape_file = os.path.join(root_path, "speech_shape.{}".format(idx))
  174. with open(wav_scp_file) as f:
  175. lines = f.readlines()
  176. with open(shape_file, "w") as f:
  177. for line in lines:
  178. sample_name, wav_path = line.strip().split()
  179. n_frames, feature_dim, speech_length = wav2num_frame(wav_path, frontend_conf)
  180. write_flag = True
  181. if speech_length_min > 0 and speech_length < speech_length_min:
  182. write_flag = False
  183. if speech_length_max > 0 and speech_length > speech_length_max:
  184. write_flag = False
  185. if write_flag:
  186. f.write("{} {},{}\n".format(sample_name, str(int(np.ceil(n_frames))), str(int(feature_dim))))
  187. f.flush()
  188. def calc_shape(data_dir, dataset, frontend_conf, speech_length_min=-1, speech_length_max=-1, nj=32):
  189. shape_path = os.path.join(data_dir, dataset, "shape_files")
  190. if os.path.exists(shape_path):
  191. assert os.path.exists(os.path.join(data_dir, dataset, "speech_shape"))
  192. print('Shape file for small dataset already exists.')
  193. return
  194. os.makedirs(shape_path, exist_ok=True)
  195. # split
  196. wav_scp_file = os.path.join(data_dir, dataset, "wav.scp")
  197. with open(wav_scp_file) as f:
  198. lines = f.readlines()
  199. num_lines = len(lines)
  200. num_job_lines = num_lines // nj
  201. start = 0
  202. for i in range(nj):
  203. end = start + num_job_lines
  204. file = os.path.join(shape_path, "wav.scp.{}".format(str(i + 1)))
  205. with open(file, "w") as f:
  206. if i == nj - 1:
  207. f.writelines(lines[start:])
  208. else:
  209. f.writelines(lines[start:end])
  210. start = end
  211. p = Pool(nj)
  212. for i in range(nj):
  213. p.apply_async(calc_shape_core,
  214. args=(shape_path, frontend_conf, speech_length_min, speech_length_max, str(i + 1)))
  215. print('Generating shape files, please wait a few minutes...')
  216. p.close()
  217. p.join()
  218. # combine
  219. file = os.path.join(data_dir, dataset, "speech_shape")
  220. with open(file, "w") as f:
  221. for i in range(nj):
  222. job_file = os.path.join(shape_path, "speech_shape.{}".format(str(i + 1)))
  223. with open(job_file) as job_f:
  224. lines = job_f.readlines()
  225. f.writelines(lines)
  226. print('Generating shape files done.')
  227. def generate_data_list(data_dir, dataset, nj=100):
  228. split_dir = os.path.join(data_dir, dataset, "split")
  229. if os.path.exists(split_dir):
  230. assert os.path.exists(os.path.join(data_dir, dataset, "data.list"))
  231. print('Data list for large dataset already exists.')
  232. return
  233. os.makedirs(split_dir, exist_ok=True)
  234. with open(os.path.join(data_dir, dataset, "wav.scp")) as f_wav:
  235. wav_lines = f_wav.readlines()
  236. with open(os.path.join(data_dir, dataset, "text")) as f_text:
  237. text_lines = f_text.readlines()
  238. total_num_lines = len(wav_lines)
  239. num_lines = total_num_lines // nj
  240. start_num = 0
  241. for i in range(nj):
  242. end_num = start_num + num_lines
  243. split_dir_nj = os.path.join(split_dir, str(i + 1))
  244. os.mkdir(split_dir_nj)
  245. wav_file = os.path.join(split_dir_nj, 'wav.scp')
  246. text_file = os.path.join(split_dir_nj, "text")
  247. with open(wav_file, "w") as fw, open(text_file, "w") as ft:
  248. if i == nj - 1:
  249. fw.writelines(wav_lines[start_num:])
  250. ft.writelines(text_lines[start_num:])
  251. else:
  252. fw.writelines(wav_lines[start_num:end_num])
  253. ft.writelines(text_lines[start_num:end_num])
  254. start_num = end_num
  255. data_list_file = os.path.join(data_dir, dataset, "data.list")
  256. with open(data_list_file, "w") as f_data:
  257. for i in range(nj):
  258. wav_path = os.path.join(split_dir, str(i + 1), "wav.scp")
  259. text_path = os.path.join(split_dir, str(i + 1), "text")
  260. f_data.write(wav_path + " " + text_path + "\n")
  261. def filter_wav_text(data_dir, dataset):
  262. wav_file = os.path.join(data_dir,dataset,"wav.scp")
  263. text_file = os.path.join(data_dir, dataset, "text")
  264. with open(wav_file) as f_wav, open(text_file) as f_text:
  265. wav_lines = f_wav.readlines()
  266. text_lines = f_text.readlines()
  267. os.rename(wav_file, "{}.bak".format(wav_file))
  268. os.rename(text_file, "{}.bak".format(text_file))
  269. wav_dict = {}
  270. for line in wav_lines:
  271. parts = line.strip().split()
  272. if len(parts) != 2:
  273. continue
  274. sample_name, wav_path = parts
  275. wav_dict[sample_name] = wav_path
  276. text_dict = {}
  277. for line in text_lines:
  278. parts = line.strip().split()
  279. if len(parts) < 2:
  280. continue
  281. sample_name = parts[0]
  282. text_dict[sample_name] = " ".join(parts[1:]).lower()
  283. filter_count = 0
  284. with open(wav_file, "w") as f_wav, open(text_file, "w") as f_text:
  285. for sample_name, wav_path in wav_dict.items():
  286. if sample_name in text_dict.keys():
  287. f_wav.write(sample_name + " " + wav_path + "\n")
  288. f_text.write(sample_name + " " + text_dict[sample_name] + "\n")
  289. else:
  290. filter_count += 1
  291. print("{}/{} samples in {} are filtered because of the mismatch between wav.scp and text".format(len(wav_lines), filter_count, dataset))