griffin_lim.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. #!/usr/bin/env python3
  2. """Griffin-Lim related modules."""
  3. # Copyright 2019 Tomoki Hayashi
  4. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  5. import logging
  6. from distutils.version import LooseVersion
  7. from functools import partial
  8. from typing import Optional
  9. import librosa
  10. import numpy as np
  11. import torch
  12. EPS = 1e-10
  13. def logmel2linear(
  14. lmspc: np.ndarray,
  15. fs: int,
  16. n_fft: int,
  17. n_mels: int,
  18. fmin: int = None,
  19. fmax: int = None,
  20. ) -> np.ndarray:
  21. """Convert log Mel filterbank to linear spectrogram.
  22. Args:
  23. lmspc: Log Mel filterbank (T, n_mels).
  24. fs: Sampling frequency.
  25. n_fft: The number of FFT points.
  26. n_mels: The number of mel basis.
  27. f_min: Minimum frequency to analyze.
  28. f_max: Maximum frequency to analyze.
  29. Returns:
  30. Linear spectrogram (T, n_fft // 2 + 1).
  31. """
  32. assert lmspc.shape[1] == n_mels
  33. fmin = 0 if fmin is None else fmin
  34. fmax = fs / 2 if fmax is None else fmax
  35. mspc = np.power(10.0, lmspc)
  36. mel_basis = librosa.filters.mel(
  37. sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax
  38. )
  39. inv_mel_basis = np.linalg.pinv(mel_basis)
  40. return np.maximum(EPS, np.dot(inv_mel_basis, mspc.T).T)
  41. def griffin_lim(
  42. spc: np.ndarray,
  43. n_fft: int,
  44. n_shift: int,
  45. win_length: int = None,
  46. window: Optional[str] = "hann",
  47. n_iter: Optional[int] = 32,
  48. ) -> np.ndarray:
  49. """Convert linear spectrogram into waveform using Griffin-Lim.
  50. Args:
  51. spc: Linear spectrogram (T, n_fft // 2 + 1).
  52. n_fft: The number of FFT points.
  53. n_shift: Shift size in points.
  54. win_length: Window length in points.
  55. window: Window function type.
  56. n_iter: The number of iterations.
  57. Returns:
  58. Reconstructed waveform (N,).
  59. """
  60. # assert the size of input linear spectrogram
  61. assert spc.shape[1] == n_fft // 2 + 1
  62. if LooseVersion(librosa.__version__) >= LooseVersion("0.7.0"):
  63. # use librosa's fast Grriffin-Lim algorithm
  64. spc = np.abs(spc.T)
  65. y = librosa.griffinlim(
  66. S=spc,
  67. n_iter=n_iter,
  68. hop_length=n_shift,
  69. win_length=win_length,
  70. window=window,
  71. center=True if spc.shape[1] > 1 else False,
  72. )
  73. else:
  74. # use slower version of Grriffin-Lim algorithm
  75. logging.warning(
  76. "librosa version is old. use slow version of Grriffin-Lim algorithm."
  77. "if you want to use fast Griffin-Lim, please update librosa via "
  78. "`source ./path.sh && pip install librosa==0.7.0`."
  79. )
  80. cspc = np.abs(spc).astype(np.complex).T
  81. angles = np.exp(2j * np.pi * np.random.rand(*cspc.shape))
  82. y = librosa.istft(cspc * angles, n_shift, win_length, window=window)
  83. for i in range(n_iter):
  84. angles = np.exp(
  85. 1j
  86. * np.angle(librosa.stft(y, n_fft, n_shift, win_length, window=window))
  87. )
  88. y = librosa.istft(cspc * angles, n_shift, win_length, window=window)
  89. return y
  90. # TODO(kan-bayashi): write as torch.nn.Module
  91. class Spectrogram2Waveform(object):
  92. """Spectrogram to waveform conversion module."""
  93. def __init__(
  94. self,
  95. n_fft: int,
  96. n_shift: int,
  97. fs: int = None,
  98. n_mels: int = None,
  99. win_length: int = None,
  100. window: Optional[str] = "hann",
  101. fmin: int = None,
  102. fmax: int = None,
  103. griffin_lim_iters: Optional[int] = 8,
  104. ):
  105. """Initialize module.
  106. Args:
  107. fs: Sampling frequency.
  108. n_fft: The number of FFT points.
  109. n_shift: Shift size in points.
  110. n_mels: The number of mel basis.
  111. win_length: Window length in points.
  112. window: Window function type.
  113. f_min: Minimum frequency to analyze.
  114. f_max: Maximum frequency to analyze.
  115. griffin_lim_iters: The number of iterations.
  116. """
  117. self.fs = fs
  118. self.logmel2linear = (
  119. partial(
  120. logmel2linear, fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax
  121. )
  122. if n_mels is not None
  123. else None
  124. )
  125. self.griffin_lim = partial(
  126. griffin_lim,
  127. n_fft=n_fft,
  128. n_shift=n_shift,
  129. win_length=win_length,
  130. window=window,
  131. n_iter=griffin_lim_iters,
  132. )
  133. self.params = dict(
  134. n_fft=n_fft,
  135. n_shift=n_shift,
  136. win_length=win_length,
  137. window=window,
  138. n_iter=griffin_lim_iters,
  139. )
  140. if n_mels is not None:
  141. self.params.update(fs=fs, n_mels=n_mels, fmin=fmin, fmax=fmax)
  142. def __repr__(self):
  143. retval = f"{self.__class__.__name__}("
  144. for k, v in self.params.items():
  145. retval += f"{k}={v}, "
  146. retval += ")"
  147. return retval
  148. def __call__(self, spc: torch.Tensor) -> torch.Tensor:
  149. """Convert spectrogram to waveform.
  150. Args:
  151. spc: Log Mel filterbank (T_feats, n_mels)
  152. or linear spectrogram (T_feats, n_fft // 2 + 1).
  153. Returns:
  154. Tensor: Reconstructed waveform (T_wav,).
  155. """
  156. device = spc.device
  157. dtype = spc.dtype
  158. spc = spc.cpu().numpy()
  159. if self.logmel2linear is not None:
  160. spc = self.logmel2linear(spc)
  161. wav = self.griffin_lim(spc)
  162. return torch.tensor(wav).to(device=device, dtype=dtype)