# feature_transform.py

from typing import List
from typing import Tuple
from typing import Union

import librosa
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor

from funasr.models.transformer.utils.nets_utils import make_pad_mask


class FeatureTransform(torch.nn.Module):
    def __init__(
        self,
        # Mel options
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        # Normalization
        stats_file: str = None,
        apply_uttmvn: bool = True,
        uttmvn_norm_means: bool = True,
        uttmvn_norm_vars: bool = False,
    ):
        super().__init__()
        self.apply_uttmvn = apply_uttmvn

        self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
        self.stats_file = stats_file
        if stats_file is not None:
            self.global_mvn = GlobalMVN(stats_file)
        else:
            self.global_mvn = None

        # apply_uttmvn is a bool, so test its truth value rather than "is not None"
        if self.apply_uttmvn:
            self.uttmvn = UtteranceMVN(
                norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars
            )
        else:
            self.uttmvn = None

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # x: (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)

        if x.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly
                ch = np.random.randint(x.size(2))
                h = x[:, :, ch, :]
            else:
                # Use the first channel
                h = x[:, :, 0, :]
        else:
            h = x

        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        h = h.real**2 + h.imag**2

        h, _ = self.logmel(h, ilens)
        if self.stats_file is not None:
            h, _ = self.global_mvn(h, ilens)
        if self.apply_uttmvn:
            h, _ = self.uttmvn(h, ilens)

        return h, ilens
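
# Usage sketch (illustrative, not part of the original module): shapes and values
# below are hypothetical. `x` stands for a multi-channel complex STFT; the transform
# returns zero-padded log-mel features together with the unchanged lengths.
#
#   x = ComplexTensor(torch.randn(4, 100, 2, 257), torch.randn(4, 100, 2, 257))
#   ilens = torch.LongTensor([100, 90, 80, 70])
#   feats, ilens = FeatureTransform(n_fft=512, n_mels=80)(x, ilens)
#   # feats: (4, 100, 80)
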
class LogMel(torch.nn.Module):
    """Convert STFT to fbank feats

    The arguments are the same as librosa.filters.mel

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz)
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        norm: {None, 1, np.inf} [scalar]
            if 1, divide the triangular mel weights by the width of the mel band
            (area normalization). Otherwise, leave all the triangles aiming for
            a peak value of 1.0
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        htk: bool = False,
        norm=1,
    ):
        super().__init__()
        _mel_options = dict(
            sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
        )
        self.mel_options = _mel_options
        # Note(kamo): The mel matrix of librosa is different from kaldi.
        melmat = librosa.filters.mel(**_mel_options)
        # melmat: (D2, D1) -> (D1, D2)
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

    def extra_repr(self):
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self, feat: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)
        logmel_feat = (mel_feat + 1e-20).log()
        # Zero padding
        logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
        return logmel_feat, ilens
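
# Shape sketch (illustrative): with n_fft=512 the power spectrum has D1 = 257 bins,
# so a (B, T, 257) input multiplied by the (257, n_mels) melmat buffer yields
# (B, T, n_mels) log-mel features, with frames beyond each ilens entry set to 0.0.
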
class GlobalMVN(torch.nn.Module):
    """Apply global mean and variance normalization

    Args:
        stats_file(str): npy file of a 1-dim array or a text file.
            The first (len(array) - 1) / 2 elements are treated as
            the sum of features, the following elements excluding the last
            are treated as the sum of squared features,
            and the last element equals the number of accumulated frames.
        norm_means(bool): whether to subtract the global mean
        norm_vars(bool): whether to divide by the global standard deviation
        eps(float): floor for the standard deviation
    """

    def __init__(
        self,
        stats_file: str,
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.stats_file = stats_file

        stats = np.load(stats_file)
        stats = stats.astype(float)

        assert (len(stats) - 1) % 2 == 0, stats.shape
        count = stats.flatten()[-1]
        mean = stats[: (len(stats) - 1) // 2] / count
        var = stats[(len(stats) - 1) // 2 : -1] / count - mean * mean
        std = np.maximum(np.sqrt(var), eps)

        self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
        self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))

    def extra_repr(self):
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # x: (B, T, D)
        if self.norm_means:
            x += self.bias.type_as(x)
            # masked_fill is out-of-place, so keep the result to re-zero padded frames
            x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)

        if self.norm_vars:
            x *= self.scale.type_as(x)
        return x, ilens
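
# Stats-file layout sketch (hypothetical numbers): for D-dimensional features the
# array holds [sum(feat) (D values), sum(feat ** 2) (D values), frame count (1)],
# so mean = sum / count and var = sum_sq / count - mean ** 2. For D = 2:
#
#   np.save("stats.npy", np.array([10.0, 20.0, 30.0, 50.0, 10.0]))  # 10 frames
#   mvn = GlobalMVN("stats.npy")  # mean = [1.0, 2.0], var = [2.0, 1.0]
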
class UtteranceMVN(torch.nn.Module):
    def __init__(
        self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.eps = eps

    def extra_repr(self):
        return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        return utterance_mvn(
            x, ilens, norm_means=self.norm_means, norm_vars=self.norm_vars, eps=self.eps
        )


def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.LongTensor,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Apply utterance mean and variance normalization

    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B,)
        norm_means: subtract the per-utterance mean
        norm_vars: divide by the per-utterance standard deviation
        eps: floor for the variance
    """
    ilens_ = ilens.type_as(x)
    # mean over the valid frames of each utterance: (B, D)
    mean = x.sum(dim=1) / ilens_[:, None]

    # Mean-removed features, needed at least for the variance estimate
    x_ = x - mean[:, None, :]
    # Re-zero the padded frames (masked_fill is out-of-place, so keep the result)
    x_ = x_.masked_fill(make_pad_mask(ilens, x_, 1), 0.0)

    if norm_vars:
        var = x_.pow(2).sum(dim=1) / ilens_[:, None]
        var = torch.clamp(var, min=eps)
        std = var.sqrt()[:, None, :]
        # Scale either the mean-removed or the original zero-padded features
        x_ = x_ / std if norm_means else x / std
    elif not norm_means:
        x_ = x
    return x_, ilens
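
# Worked example (illustrative): with x = [[[1.0], [3.0], [0.0]]] and ilens = [2],
# the mean over the two valid frames is 2.0, so with the default settings
# utterance_mvn returns [[[-1.0], [1.0], [0.0]]], the padded third frame staying zero.
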
def feature_transform_for(args, n_fft):
    return FeatureTransform(
        # Mel options
        fs=args.fbank_fs,
        n_fft=n_fft,
        n_mels=args.n_mels,
        fmin=args.fbank_fmin,
        fmax=args.fbank_fmax,
        # Normalization
        stats_file=args.stats_file,
        apply_uttmvn=args.apply_uttmvn,
        uttmvn_norm_means=args.uttmvn_norm_means,
        uttmvn_norm_vars=args.uttmvn_norm_vars,
    )
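

# Minimal smoke test (an illustrative sketch, not part of the original file): builds a
# random single-channel complex STFT and runs the full transform end to end.
if __name__ == "__main__":
    B, T, F = 2, 50, 257  # batch, frames, FFT bins for n_fft=512
    x_real = torch.randn(B, T, F)
    x_imag = torch.randn(B, T, F)
    x_real[1, 35:] = 0.0  # zero the padded frames of the shorter utterance
    x_imag[1, 35:] = 0.0
    x = ComplexTensor(x_real, x_imag)
    ilens = torch.LongTensor([50, 35])
    feats, olens = FeatureTransform(n_fft=512, n_mels=80)(x, ilens)
    print(feats.shape, olens)  # torch.Size([2, 50, 80]) tensor([50, 35])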