wav_utils.py

# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
from typing import Any, Dict, Union

import kaldiio
import librosa
import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi


def ndarray_resample(audio_in: np.ndarray,
                     fs_in: int = 16000,
                     fs_out: int = 16000) -> np.ndarray:
    """Resample a numpy waveform from fs_in to fs_out (no-op if the rates match)."""
    audio_out = audio_in
    if fs_in != fs_out:
        audio_out = librosa.resample(audio_in, orig_sr=fs_in, target_sr=fs_out)
    return audio_out


def torch_resample(audio_in: torch.Tensor,
                   fs_in: int = 16000,
                   fs_out: int = 16000) -> torch.Tensor:
    """Resample a torch waveform from fs_in to fs_out (no-op if the rates match)."""
    audio_out = audio_in
    if fs_in != fs_out:
        audio_out = torchaudio.transforms.Resample(orig_freq=fs_in,
                                                   new_freq=fs_out)(audio_in)
    return audio_out


def extract_CMVN_featrures(mvn_file):
    """
    Extract global CMVN statistics from a Kaldi cmvn.ark file
    (falls back to the text-format parser below if loading the ark fails).
    """
    if not os.path.exists(mvn_file):
        return None
    try:
        cmvn = kaldiio.load_mat(mvn_file)
        means = []
        variance = []
        # Row 0 holds the per-dimension sums with the frame count in the
        # last column; row 1 holds the per-dimension sums of squares.
        for i in range(cmvn.shape[1] - 1):
            means.append(float(cmvn[0][i]))
        count = float(cmvn[0][-1])
        for i in range(cmvn.shape[1] - 1):
            variance.append(float(cmvn[1][i]))
        for i in range(len(means)):
            means[i] /= count
            variance[i] = variance[i] / count - means[i] * means[i]
            if variance[i] < 1.0e-20:
                variance[i] = 1.0e-20
            variance[i] = 1.0 / math.sqrt(variance[i])
        cmvn = np.array([means, variance])
        return cmvn
    except Exception:
        cmvn = extract_CMVN_features_txt(mvn_file)
        return cmvn
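
# Note on usage, derived from the arithmetic above: row 0 of the returned
# 2 x D array holds the per-dimension means and row 1 holds the inverse
# standard deviations, so normalization would typically be applied as
# (feat - cmvn[0]) * cmvn[1].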


def extract_CMVN_features_txt(mvn_file):  # noqa
    """Parse global CMVN statistics from a Kaldi nnet-style text file."""
    with open(mvn_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    add_shift_list = []
    rescale_list = []
    for i in range(len(lines)):
        line_item = lines[i].split()
        if line_item[0] == '<AddShift>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                add_shift_line = line_item[3:(len(line_item) - 1)]
                add_shift_list = list(add_shift_line)
                continue
        elif line_item[0] == '<Rescale>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                rescale_line = line_item[3:(len(line_item) - 1)]
                rescale_list = list(rescale_line)
                continue
    add_shift_list_f = [float(s) for s in add_shift_list]
    rescale_list_f = [float(s) for s in rescale_list]
    cmvn = np.array([add_shift_list_f, rescale_list_f])
    return cmvn
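
# Sketch of the text layout this parser expects, inferred from the slicing
# above (dimensions and values are illustrative, one entry per feature dim):
#   <AddShift> 80 80
#   <LearnRateCoef> 0 [ -7.2 -7.9 ... ]
#   <Rescale> 80 80
#   <LearnRateCoef> 0 [ 0.30 0.28 ... ]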


def build_LFR_features(inputs, m=7, n=6):  # noqa
    """
    Build low frame rate (LFR) features by stacking and skipping frames.
    If m = 1 and n = 1, the original features are returned unchanged.
    If m = 1 and n > 1, it only skips frames.
    If m > 1 and n = 1, it only stacks frames (right context only).
    If m > 1 and n > 1, it performs full LFR.
    Args:
        inputs: T x D np.ndarray of frame-level features
        m: number of frames to stack
        n: number of frames to skip
    """
    LFR_inputs = []
    T = inputs.shape[0]
    T_lfr = int(np.ceil(T / n))
    # Pad (m - 1) // 2 copies of the first frame on the left so that each
    # stacked window is roughly centered on its original frame.
    left_padding = np.tile(inputs[0], ((m - 1) // 2, 1))
    inputs = np.vstack((left_padding, inputs))
    T = T + (m - 1) // 2
    for i in range(T_lfr):
        if m <= T - i * n:
            LFR_inputs.append(np.hstack(inputs[i * n:i * n + m]))
        else:  # process last LFR frame
            num_padding = m - (T - i * n)
            frame = np.hstack(inputs[i * n:])
            for _ in range(num_padding):
                frame = np.hstack((frame, inputs[-1]))
            LFR_inputs.append(frame)
    return np.vstack(LFR_inputs)
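
# Worked example (illustrative): with the defaults m=7 and n=6, a (100, 80)
# fbank matrix is padded on the left with 3 copies of the first frame and
# stacked into ceil(100 / 6) = 17 output frames of 7 * 80 = 560 dims each,
# i.e. the output shape is (17, 560).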


def compute_fbank(wav_file,
                  num_mel_bins=80,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0,
                  is_pcm=False,
                  fs: Union[int, Dict[Any, int]] = 16000):
    """
    Compute Kaldi-style fbank features from a wav file path or, when is_pcm
    is True, from raw PCM16 bytes. fs is either a single sampling rate or a
    dict with 'audio_fs' (input rate) and 'model_fs' (target rate).
    """
    audio_sr: int = 16000
    model_sr: int = 16000
    if isinstance(fs, int):
        model_sr = fs
        audio_sr = fs
    else:
        model_sr = fs['model_fs']
        audio_sr = fs['audio_fs']
    if is_pcm is True:
        # Convert PCM16 bytes to float32 in [-1, 1), then resample.
        value = wav_file
        middle_data = np.frombuffer(value, dtype=np.int16)
        middle_data = np.asarray(middle_data)
        if middle_data.dtype.kind not in 'iu':
            raise TypeError("'middle_data' must be an array of integers")
        dtype = np.dtype('float32')
        if dtype.kind != 'f':
            raise TypeError("'dtype' must be a floating point type")
        i = np.iinfo(middle_data.dtype)
        abs_max = 2**(i.bits - 1)
        offset = i.min + abs_max
        waveform = np.frombuffer(
            (middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
        waveform = ndarray_resample(waveform, audio_sr, model_sr)
        waveform = torch.from_numpy(waveform.reshape(1, -1))
    else:
        # Load PCM from the wav file, rescale to int16 range, and resample.
        waveform, audio_sr = torchaudio.load(wav_file)
        waveform = waveform * (1 << 15)
        waveform = torch_resample(waveform, audio_sr, model_sr)
    mat = kaldi.fbank(
        waveform,
        num_mel_bins=num_mel_bins,
        frame_length=frame_length,
        frame_shift=frame_shift,
        dither=dither,
        energy_floor=0.0,
        window_type='hamming',
        sample_frequency=model_sr)
    input_feats = mat
    return input_feats
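

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module. It assumes a
    # 16 kHz mono wav file exists at the hypothetical path below and that
    # the model also expects 16 kHz input.
    feats = compute_fbank('example.wav', fs=16000)  # (T, 80) torch.Tensor
    lfr_feats = build_LFR_features(feats.numpy(), m=7, n=6)  # (ceil(T/6), 560)
    print(feats.shape, lfr_feats.shape)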