griffin_lim.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. #!/usr/bin/env python3
  2. """Griffin-Lim related modules."""
  3. # Copyright 2019 Tomoki Hayashi
  4. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  5. import logging
  6. from distutils.version import LooseVersion
  7. from functools import partial
  8. from typeguard import check_argument_types
  9. from typing import Optional
  10. import librosa
  11. import numpy as np
  12. import torch
  13. EPS = 1e-10
  14. def logmel2linear(
  15. lmspc: np.ndarray,
  16. fs: int,
  17. n_fft: int,
  18. n_mels: int,
  19. fmin: int = None,
  20. fmax: int = None,
  21. ) -> np.ndarray:
  22. """Convert log Mel filterbank to linear spectrogram.
  23. Args:
  24. lmspc: Log Mel filterbank (T, n_mels).
  25. fs: Sampling frequency.
  26. n_fft: The number of FFT points.
  27. n_mels: The number of mel basis.
  28. f_min: Minimum frequency to analyze.
  29. f_max: Maximum frequency to analyze.
  30. Returns:
  31. Linear spectrogram (T, n_fft // 2 + 1).
  32. """
  33. assert lmspc.shape[1] == n_mels
  34. fmin = 0 if fmin is None else fmin
  35. fmax = fs / 2 if fmax is None else fmax
  36. mspc = np.power(10.0, lmspc)
  37. mel_basis = librosa.filters.mel(
  38. sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax
  39. )
  40. inv_mel_basis = np.linalg.pinv(mel_basis)
  41. return np.maximum(EPS, np.dot(inv_mel_basis, mspc.T).T)
  42. def griffin_lim(
  43. spc: np.ndarray,
  44. n_fft: int,
  45. n_shift: int,
  46. win_length: int = None,
  47. window: Optional[str] = "hann",
  48. n_iter: Optional[int] = 32,
  49. ) -> np.ndarray:
  50. """Convert linear spectrogram into waveform using Griffin-Lim.
  51. Args:
  52. spc: Linear spectrogram (T, n_fft // 2 + 1).
  53. n_fft: The number of FFT points.
  54. n_shift: Shift size in points.
  55. win_length: Window length in points.
  56. window: Window function type.
  57. n_iter: The number of iterations.
  58. Returns:
  59. Reconstructed waveform (N,).
  60. """
  61. # assert the size of input linear spectrogram
  62. assert spc.shape[1] == n_fft // 2 + 1
  63. if LooseVersion(librosa.__version__) >= LooseVersion("0.7.0"):
  64. # use librosa's fast Grriffin-Lim algorithm
  65. spc = np.abs(spc.T)
  66. y = librosa.griffinlim(
  67. S=spc,
  68. n_iter=n_iter,
  69. hop_length=n_shift,
  70. win_length=win_length,
  71. window=window,
  72. center=True if spc.shape[1] > 1 else False,
  73. )
  74. else:
  75. # use slower version of Grriffin-Lim algorithm
  76. logging.warning(
  77. "librosa version is old. use slow version of Grriffin-Lim algorithm."
  78. "if you want to use fast Griffin-Lim, please update librosa via "
  79. "`source ./path.sh && pip install librosa==0.7.0`."
  80. )
  81. cspc = np.abs(spc).astype(np.complex).T
  82. angles = np.exp(2j * np.pi * np.random.rand(*cspc.shape))
  83. y = librosa.istft(cspc * angles, n_shift, win_length, window=window)
  84. for i in range(n_iter):
  85. angles = np.exp(
  86. 1j
  87. * np.angle(librosa.stft(y, n_fft, n_shift, win_length, window=window))
  88. )
  89. y = librosa.istft(cspc * angles, n_shift, win_length, window=window)
  90. return y
  91. # TODO(kan-bayashi): write as torch.nn.Module
  92. class Spectrogram2Waveform(object):
  93. """Spectrogram to waveform conversion module."""
  94. def __init__(
  95. self,
  96. n_fft: int,
  97. n_shift: int,
  98. fs: int = None,
  99. n_mels: int = None,
  100. win_length: int = None,
  101. window: Optional[str] = "hann",
  102. fmin: int = None,
  103. fmax: int = None,
  104. griffin_lim_iters: Optional[int] = 8,
  105. ):
  106. """Initialize module.
  107. Args:
  108. fs: Sampling frequency.
  109. n_fft: The number of FFT points.
  110. n_shift: Shift size in points.
  111. n_mels: The number of mel basis.
  112. win_length: Window length in points.
  113. window: Window function type.
  114. f_min: Minimum frequency to analyze.
  115. f_max: Maximum frequency to analyze.
  116. griffin_lim_iters: The number of iterations.
  117. """
  118. assert check_argument_types()
  119. self.fs = fs
  120. self.logmel2linear = (
  121. partial(
  122. logmel2linear, fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax
  123. )
  124. if n_mels is not None
  125. else None
  126. )
  127. self.griffin_lim = partial(
  128. griffin_lim,
  129. n_fft=n_fft,
  130. n_shift=n_shift,
  131. win_length=win_length,
  132. window=window,
  133. n_iter=griffin_lim_iters,
  134. )
  135. self.params = dict(
  136. n_fft=n_fft,
  137. n_shift=n_shift,
  138. win_length=win_length,
  139. window=window,
  140. n_iter=griffin_lim_iters,
  141. )
  142. if n_mels is not None:
  143. self.params.update(fs=fs, n_mels=n_mels, fmin=fmin, fmax=fmax)
  144. def __repr__(self):
  145. retval = f"{self.__class__.__name__}("
  146. for k, v in self.params.items():
  147. retval += f"{k}={v}, "
  148. retval += ")"
  149. return retval
  150. def __call__(self, spc: torch.Tensor) -> torch.Tensor:
  151. """Convert spectrogram to waveform.
  152. Args:
  153. spc: Log Mel filterbank (T_feats, n_mels)
  154. or linear spectrogram (T_feats, n_fft // 2 + 1).
  155. Returns:
  156. Tensor: Reconstructed waveform (T_wav,).
  157. """
  158. device = spc.device
  159. dtype = spc.dtype
  160. spc = spc.cpu().numpy()
  161. if self.logmel2linear is not None:
  162. spc = self.logmel2linear(spc)
  163. wav = self.griffin_lim(spc)
  164. return torch.tensor(wav).to(device=device, dtype=dtype)