stft.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. from distutils.version import LooseVersion
  2. from typing import Optional
  3. from typing import Tuple
  4. from typing import Union
  5. import torch
  6. from torch_complex.tensor import ComplexTensor
  7. from typeguard import check_argument_types
  8. from funasr.modules.nets_utils import make_pad_mask
  9. from funasr.layers.complex_utils import is_complex
  10. from funasr.layers.inversible_interface import InversibleInterface
  11. import librosa
  12. import numpy as np
# Version flags used to adapt torch.stft keyword arguments across PyTorch
# releases (e.g. `return_complex` exists only from torch>=1.7).
is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
is_torch_1_7_plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")
  15. class Stft(torch.nn.Module, InversibleInterface):
  16. def __init__(
  17. self,
  18. n_fft: int = 512,
  19. win_length: int = None,
  20. hop_length: int = 128,
  21. window: Optional[str] = "hann",
  22. center: bool = True,
  23. normalized: bool = False,
  24. onesided: bool = True,
  25. ):
  26. assert check_argument_types()
  27. super().__init__()
  28. self.n_fft = n_fft
  29. if win_length is None:
  30. self.win_length = n_fft
  31. else:
  32. self.win_length = win_length
  33. self.hop_length = hop_length
  34. self.center = center
  35. self.normalized = normalized
  36. self.onesided = onesided
  37. if window is not None and not hasattr(torch, f"{window}_window"):
  38. raise ValueError(f"{window} window is not implemented")
  39. self.window = window
  40. def extra_repr(self):
  41. return (
  42. f"n_fft={self.n_fft}, "
  43. f"win_length={self.win_length}, "
  44. f"hop_length={self.hop_length}, "
  45. f"center={self.center}, "
  46. f"normalized={self.normalized}, "
  47. f"onesided={self.onesided}"
  48. )
  49. def forward(
  50. self, input: torch.Tensor, ilens: torch.Tensor = None
  51. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  52. """STFT forward function.
  53. Args:
  54. input: (Batch, Nsamples) or (Batch, Nsample, Channels)
  55. ilens: (Batch)
  56. Returns:
  57. output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
  58. """
  59. bs = input.size(0)
  60. if input.dim() == 3:
  61. multi_channel = True
  62. # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
  63. input = input.transpose(1, 2).reshape(-1, input.size(1))
  64. else:
  65. multi_channel = False
  66. # NOTE(kamo):
  67. # The default behaviour of torch.stft is compatible with librosa.stft
  68. # about padding and scaling.
  69. # Note that it's different from scipy.signal.stft
  70. # output: (Batch, Freq, Frames, 2=real_imag)
  71. # or (Batch, Channel, Freq, Frames, 2=real_imag)
  72. if self.window is not None:
  73. window_func = getattr(torch, f"{self.window}_window")
  74. window = window_func(
  75. self.win_length, dtype=input.dtype, device=input.device
  76. )
  77. else:
  78. window = None
  79. # For the compatibility of ARM devices, which do not support
  80. # torch.stft() due to the lake of MKL.
  81. if input.is_cuda or torch.backends.mkl.is_available():
  82. stft_kwargs = dict(
  83. n_fft=self.n_fft,
  84. win_length=self.win_length,
  85. hop_length=self.hop_length,
  86. center=self.center,
  87. window=window,
  88. normalized=self.normalized,
  89. onesided=self.onesided,
  90. )
  91. if is_torch_1_7_plus:
  92. stft_kwargs["return_complex"] = False
  93. output = torch.stft(input, **stft_kwargs)
  94. else:
  95. if self.training:
  96. raise NotImplementedError(
  97. "stft is implemented with librosa on this device, which does not "
  98. "support the training mode."
  99. )
  100. # use stft_kwargs to flexibly control different PyTorch versions' kwargs
  101. stft_kwargs = dict(
  102. n_fft=self.n_fft,
  103. win_length=self.win_length,
  104. hop_length=self.hop_length,
  105. center=self.center,
  106. window=window,
  107. )
  108. if window is not None:
  109. # pad the given window to n_fft
  110. n_pad_left = (self.n_fft - window.shape[0]) // 2
  111. n_pad_right = self.n_fft - window.shape[0] - n_pad_left
  112. stft_kwargs["window"] = torch.cat(
  113. [torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0
  114. ).numpy()
  115. else:
  116. win_length = (
  117. self.win_length if self.win_length is not None else self.n_fft
  118. )
  119. stft_kwargs["window"] = torch.ones(win_length)
  120. output = []
  121. # iterate over istances in a batch
  122. for i, instance in enumerate(input):
  123. stft = librosa.stft(input[i].numpy(), **stft_kwargs)
  124. output.append(torch.tensor(np.stack([stft.real, stft.imag], -1)))
  125. output = torch.stack(output, 0)
  126. if not self.onesided:
  127. len_conj = self.n_fft - output.shape[1]
  128. conj = output[:, 1 : 1 + len_conj].flip(1)
  129. conj[:, :, :, -1].data *= -1
  130. output = torch.cat([output, conj], 1)
  131. if self.normalized:
  132. output = output * (stft_kwargs["window"].shape[0] ** (-0.5))
  133. # output: (Batch, Freq, Frames, 2=real_imag)
  134. # -> (Batch, Frames, Freq, 2=real_imag)
  135. output = output.transpose(1, 2)
  136. if multi_channel:
  137. # output: (Batch * Channel, Frames, Freq, 2=real_imag)
  138. # -> (Batch, Frame, Channel, Freq, 2=real_imag)
  139. output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
  140. 1, 2
  141. )
  142. if ilens is not None:
  143. if self.center:
  144. pad = self.n_fft // 2
  145. ilens = ilens + 2 * pad
  146. olens = (ilens - self.n_fft) // self.hop_length + 1
  147. output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
  148. else:
  149. olens = None
  150. return output, olens
  151. def inverse(
  152. self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor = None
  153. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  154. """Inverse STFT.
  155. Args:
  156. input: Tensor(batch, T, F, 2) or ComplexTensor(batch, T, F)
  157. ilens: (batch,)
  158. Returns:
  159. wavs: (batch, samples)
  160. ilens: (batch,)
  161. """
  162. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  163. istft = torch.functional.istft
  164. else:
  165. try:
  166. import torchaudio
  167. except ImportError:
  168. raise ImportError(
  169. "Please install torchaudio>=0.3.0 or use torch>=1.6.0"
  170. )
  171. if not hasattr(torchaudio.functional, "istft"):
  172. raise ImportError(
  173. "Please install torchaudio>=0.3.0 or use torch>=1.6.0"
  174. )
  175. istft = torchaudio.functional.istft
  176. if self.window is not None:
  177. window_func = getattr(torch, f"{self.window}_window")
  178. if is_complex(input):
  179. datatype = input.real.dtype
  180. else:
  181. datatype = input.dtype
  182. window = window_func(self.win_length, dtype=datatype, device=input.device)
  183. else:
  184. window = None
  185. if is_complex(input):
  186. input = torch.stack([input.real, input.imag], dim=-1)
  187. elif input.shape[-1] != 2:
  188. raise TypeError("Invalid input type")
  189. input = input.transpose(1, 2)
  190. wavs = istft(
  191. input,
  192. n_fft=self.n_fft,
  193. hop_length=self.hop_length,
  194. win_length=self.win_length,
  195. window=window,
  196. center=self.center,
  197. normalized=self.normalized,
  198. onesided=self.onesided,
  199. length=ilens.max() if ilens is not None else ilens,
  200. )
  201. return wavs, ilens