stft.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. from distutils.version import LooseVersion
  2. from typing import Optional
  3. from typing import Tuple
  4. from typing import Union
  5. import torch
  6. try:
  7. from torch_complex.tensor import ComplexTensor
  8. except:
  9. print("Please install torch_complex firstly")
  10. from funasr.models.transformer.utils.nets_utils import make_pad_mask
  11. from funasr.frontends.utils.complex_utils import is_complex
  12. import librosa
  13. import numpy as np
  14. is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
  15. is_torch_1_7_plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")
  16. class Stft(torch.nn.Module):
  17. def __init__(
  18. self,
  19. n_fft: int = 512,
  20. win_length: int = None,
  21. hop_length: int = 128,
  22. window: Optional[str] = "hann",
  23. center: bool = True,
  24. normalized: bool = False,
  25. onesided: bool = True,
  26. ):
  27. super().__init__()
  28. self.n_fft = n_fft
  29. if win_length is None:
  30. self.win_length = n_fft
  31. else:
  32. self.win_length = win_length
  33. self.hop_length = hop_length
  34. self.center = center
  35. self.normalized = normalized
  36. self.onesided = onesided
  37. if window is not None and not hasattr(torch, f"{window}_window"):
  38. if window.lower() != "povey":
  39. raise ValueError(f"{window} window is not implemented")
  40. self.window = window
  41. def extra_repr(self):
  42. return (
  43. f"n_fft={self.n_fft}, "
  44. f"win_length={self.win_length}, "
  45. f"hop_length={self.hop_length}, "
  46. f"center={self.center}, "
  47. f"normalized={self.normalized}, "
  48. f"onesided={self.onesided}"
  49. )
  50. def forward(
  51. self, input: torch.Tensor, ilens: torch.Tensor = None
  52. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  53. """STFT forward function.
  54. Args:
  55. input: (Batch, Nsamples) or (Batch, Nsample, Channels)
  56. ilens: (Batch)
  57. Returns:
  58. output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
  59. """
  60. bs = input.size(0)
  61. if input.dim() == 3:
  62. multi_channel = True
  63. # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
  64. input = input.transpose(1, 2).reshape(-1, input.size(1))
  65. else:
  66. multi_channel = False
  67. # NOTE(kamo):
  68. # The default behaviour of torch.stft is compatible with librosa.stft
  69. # about padding and scaling.
  70. # Note that it's different from scipy.signal.stft
  71. # output: (Batch, Freq, Frames, 2=real_imag)
  72. # or (Batch, Channel, Freq, Frames, 2=real_imag)
  73. if self.window is not None:
  74. if self.window.lower() == "povey":
  75. window = torch.hann_window(self.win_length, periodic=False,
  76. device=input.device, dtype=input.dtype).pow(0.85)
  77. else:
  78. window_func = getattr(torch, f"{self.window}_window")
  79. window = window_func(
  80. self.win_length, dtype=input.dtype, device=input.device
  81. )
  82. else:
  83. window = None
  84. # For the compatibility of ARM devices, which do not support
  85. # torch.stft() due to the lake of MKL.
  86. if input.is_cuda or torch.backends.mkl.is_available():
  87. stft_kwargs = dict(
  88. n_fft=self.n_fft,
  89. win_length=self.win_length,
  90. hop_length=self.hop_length,
  91. center=self.center,
  92. window=window,
  93. normalized=self.normalized,
  94. onesided=self.onesided,
  95. )
  96. if is_torch_1_7_plus:
  97. stft_kwargs["return_complex"] = False
  98. output = torch.stft(input, **stft_kwargs)
  99. else:
  100. if self.training:
  101. raise NotImplementedError(
  102. "stft is implemented with librosa on this device, which does not "
  103. "support the training mode."
  104. )
  105. # use stft_kwargs to flexibly control different PyTorch versions' kwargs
  106. stft_kwargs = dict(
  107. n_fft=self.n_fft,
  108. win_length=self.win_length,
  109. hop_length=self.hop_length,
  110. center=self.center,
  111. window=window,
  112. )
  113. if window is not None:
  114. # pad the given window to n_fft
  115. n_pad_left = (self.n_fft - window.shape[0]) // 2
  116. n_pad_right = self.n_fft - window.shape[0] - n_pad_left
  117. stft_kwargs["window"] = torch.cat(
  118. [torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0
  119. ).numpy()
  120. else:
  121. win_length = (
  122. self.win_length if self.win_length is not None else self.n_fft
  123. )
  124. stft_kwargs["window"] = torch.ones(win_length)
  125. output = []
  126. # iterate over istances in a batch
  127. for i, instance in enumerate(input):
  128. stft = librosa.stft(input[i].numpy(), **stft_kwargs)
  129. output.append(torch.tensor(np.stack([stft.real, stft.imag], -1)))
  130. output = torch.stack(output, 0)
  131. if not self.onesided:
  132. len_conj = self.n_fft - output.shape[1]
  133. conj = output[:, 1 : 1 + len_conj].flip(1)
  134. conj[:, :, :, -1].data *= -1
  135. output = torch.cat([output, conj], 1)
  136. if self.normalized:
  137. output = output * (stft_kwargs["window"].shape[0] ** (-0.5))
  138. # output: (Batch, Freq, Frames, 2=real_imag)
  139. # -> (Batch, Frames, Freq, 2=real_imag)
  140. output = output.transpose(1, 2)
  141. if multi_channel:
  142. # output: (Batch * Channel, Frames, Freq, 2=real_imag)
  143. # -> (Batch, Frame, Channel, Freq, 2=real_imag)
  144. output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
  145. 1, 2
  146. )
  147. if ilens is not None:
  148. if self.center:
  149. pad = self.n_fft // 2
  150. ilens = ilens + 2 * pad
  151. olens = (ilens - self.n_fft) // self.hop_length + 1
  152. output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
  153. else:
  154. olens = None
  155. return output, olens
  156. def inverse(
  157. self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor = None
  158. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  159. """Inverse STFT.
  160. Args:
  161. input: Tensor(batch, T, F, 2) or ComplexTensor(batch, T, F)
  162. ilens: (batch,)
  163. Returns:
  164. wavs: (batch, samples)
  165. ilens: (batch,)
  166. """
  167. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  168. istft = torch.functional.istft
  169. else:
  170. try:
  171. import torchaudio
  172. except ImportError:
  173. raise ImportError(
  174. "Please install torchaudio>=0.3.0 or use torch>=1.6.0"
  175. )
  176. if not hasattr(torchaudio.functional, "istft"):
  177. raise ImportError(
  178. "Please install torchaudio>=0.3.0 or use torch>=1.6.0"
  179. )
  180. istft = torchaudio.functional.istft
  181. if self.window is not None:
  182. window_func = getattr(torch, f"{self.window}_window")
  183. if is_complex(input):
  184. datatype = input.real.dtype
  185. else:
  186. datatype = input.dtype
  187. window = window_func(self.win_length, dtype=datatype, device=input.device)
  188. else:
  189. window = None
  190. if is_complex(input):
  191. input = torch.stack([input.real, input.imag], dim=-1)
  192. elif input.shape[-1] != 2:
  193. raise TypeError("Invalid input type")
  194. input = input.transpose(1, 2)
  195. wavs = istft(
  196. input,
  197. n_fft=self.n_fft,
  198. hop_length=self.hop_length,
  199. win_length=self.win_length,
  200. window=window,
  201. center=self.center,
  202. normalized=self.normalized,
  203. onesided=self.onesided,
  204. length=ilens.max() if ilens is not None else ilens,
  205. )
  206. return wavs, ilens