stft.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. from distutils.version import LooseVersion
  2. from typing import Optional
  3. from typing import Tuple
  4. from typing import Union
  5. import torch
  6. from torch_complex.tensor import ComplexTensor
  7. from funasr.modules.nets_utils import make_pad_mask
  8. from funasr.layers.complex_utils import is_complex
  9. from funasr.layers.inversible_interface import InversibleInterface
  10. import librosa
  11. import numpy as np
  12. is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
  13. is_torch_1_7_plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")
  14. class Stft(torch.nn.Module, InversibleInterface):
  15. def __init__(
  16. self,
  17. n_fft: int = 512,
  18. win_length: int = None,
  19. hop_length: int = 128,
  20. window: Optional[str] = "hann",
  21. center: bool = True,
  22. normalized: bool = False,
  23. onesided: bool = True,
  24. ):
  25. super().__init__()
  26. self.n_fft = n_fft
  27. if win_length is None:
  28. self.win_length = n_fft
  29. else:
  30. self.win_length = win_length
  31. self.hop_length = hop_length
  32. self.center = center
  33. self.normalized = normalized
  34. self.onesided = onesided
  35. if window is not None and not hasattr(torch, f"{window}_window"):
  36. if window.lower() != "povey":
  37. raise ValueError(f"{window} window is not implemented")
  38. self.window = window
  39. def extra_repr(self):
  40. return (
  41. f"n_fft={self.n_fft}, "
  42. f"win_length={self.win_length}, "
  43. f"hop_length={self.hop_length}, "
  44. f"center={self.center}, "
  45. f"normalized={self.normalized}, "
  46. f"onesided={self.onesided}"
  47. )
  48. def forward(
  49. self, input: torch.Tensor, ilens: torch.Tensor = None
  50. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  51. """STFT forward function.
  52. Args:
  53. input: (Batch, Nsamples) or (Batch, Nsample, Channels)
  54. ilens: (Batch)
  55. Returns:
  56. output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
  57. """
  58. bs = input.size(0)
  59. if input.dim() == 3:
  60. multi_channel = True
  61. # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
  62. input = input.transpose(1, 2).reshape(-1, input.size(1))
  63. else:
  64. multi_channel = False
  65. # NOTE(kamo):
  66. # The default behaviour of torch.stft is compatible with librosa.stft
  67. # about padding and scaling.
  68. # Note that it's different from scipy.signal.stft
  69. # output: (Batch, Freq, Frames, 2=real_imag)
  70. # or (Batch, Channel, Freq, Frames, 2=real_imag)
  71. if self.window is not None:
  72. if self.window.lower() == "povey":
  73. window = torch.hann_window(self.win_length, periodic=False,
  74. device=input.device, dtype=input.dtype).pow(0.85)
  75. else:
  76. window_func = getattr(torch, f"{self.window}_window")
  77. window = window_func(
  78. self.win_length, dtype=input.dtype, device=input.device
  79. )
  80. else:
  81. window = None
  82. # For the compatibility of ARM devices, which do not support
  83. # torch.stft() due to the lake of MKL.
  84. if input.is_cuda or torch.backends.mkl.is_available():
  85. stft_kwargs = dict(
  86. n_fft=self.n_fft,
  87. win_length=self.win_length,
  88. hop_length=self.hop_length,
  89. center=self.center,
  90. window=window,
  91. normalized=self.normalized,
  92. onesided=self.onesided,
  93. )
  94. if is_torch_1_7_plus:
  95. stft_kwargs["return_complex"] = False
  96. output = torch.stft(input, **stft_kwargs)
  97. else:
  98. if self.training:
  99. raise NotImplementedError(
  100. "stft is implemented with librosa on this device, which does not "
  101. "support the training mode."
  102. )
  103. # use stft_kwargs to flexibly control different PyTorch versions' kwargs
  104. stft_kwargs = dict(
  105. n_fft=self.n_fft,
  106. win_length=self.win_length,
  107. hop_length=self.hop_length,
  108. center=self.center,
  109. window=window,
  110. )
  111. if window is not None:
  112. # pad the given window to n_fft
  113. n_pad_left = (self.n_fft - window.shape[0]) // 2
  114. n_pad_right = self.n_fft - window.shape[0] - n_pad_left
  115. stft_kwargs["window"] = torch.cat(
  116. [torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0
  117. ).numpy()
  118. else:
  119. win_length = (
  120. self.win_length if self.win_length is not None else self.n_fft
  121. )
  122. stft_kwargs["window"] = torch.ones(win_length)
  123. output = []
  124. # iterate over istances in a batch
  125. for i, instance in enumerate(input):
  126. stft = librosa.stft(input[i].numpy(), **stft_kwargs)
  127. output.append(torch.tensor(np.stack([stft.real, stft.imag], -1)))
  128. output = torch.stack(output, 0)
  129. if not self.onesided:
  130. len_conj = self.n_fft - output.shape[1]
  131. conj = output[:, 1 : 1 + len_conj].flip(1)
  132. conj[:, :, :, -1].data *= -1
  133. output = torch.cat([output, conj], 1)
  134. if self.normalized:
  135. output = output * (stft_kwargs["window"].shape[0] ** (-0.5))
  136. # output: (Batch, Freq, Frames, 2=real_imag)
  137. # -> (Batch, Frames, Freq, 2=real_imag)
  138. output = output.transpose(1, 2)
  139. if multi_channel:
  140. # output: (Batch * Channel, Frames, Freq, 2=real_imag)
  141. # -> (Batch, Frame, Channel, Freq, 2=real_imag)
  142. output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
  143. 1, 2
  144. )
  145. if ilens is not None:
  146. if self.center:
  147. pad = self.n_fft // 2
  148. ilens = ilens + 2 * pad
  149. olens = (ilens - self.n_fft) // self.hop_length + 1
  150. output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
  151. else:
  152. olens = None
  153. return output, olens
  154. def inverse(
  155. self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor = None
  156. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  157. """Inverse STFT.
  158. Args:
  159. input: Tensor(batch, T, F, 2) or ComplexTensor(batch, T, F)
  160. ilens: (batch,)
  161. Returns:
  162. wavs: (batch, samples)
  163. ilens: (batch,)
  164. """
  165. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  166. istft = torch.functional.istft
  167. else:
  168. try:
  169. import torchaudio
  170. except ImportError:
  171. raise ImportError(
  172. "Please install torchaudio>=0.3.0 or use torch>=1.6.0"
  173. )
  174. if not hasattr(torchaudio.functional, "istft"):
  175. raise ImportError(
  176. "Please install torchaudio>=0.3.0 or use torch>=1.6.0"
  177. )
  178. istft = torchaudio.functional.istft
  179. if self.window is not None:
  180. window_func = getattr(torch, f"{self.window}_window")
  181. if is_complex(input):
  182. datatype = input.real.dtype
  183. else:
  184. datatype = input.dtype
  185. window = window_func(self.win_length, dtype=datatype, device=input.device)
  186. else:
  187. window = None
  188. if is_complex(input):
  189. input = torch.stack([input.real, input.imag], dim=-1)
  190. elif input.shape[-1] != 2:
  191. raise TypeError("Invalid input type")
  192. input = input.transpose(1, 2)
  193. wavs = istft(
  194. input,
  195. n_fft=self.n_fft,
  196. hop_length=self.hop_length,
  197. win_length=self.win_length,
  198. window=window,
  199. center=self.center,
  200. normalized=self.normalized,
  201. onesided=self.onesided,
  202. length=ilens.max() if ilens is not None else ilens,
  203. )
  204. return wavs, ilens