# whisper_frontend.py

from typing import Optional, Tuple

import torch
import torch.nn as nn
import whisper
from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
from funasr.register import tables
from torch.nn.utils.rnn import pad_sequence


@tables.register("frontend_classes", "WhisperFrontend")
class WhisperFrontend(nn.Module):
    """Speech representation frontend using the log-Mel feature extractor
    from OpenAI's Whisper model: https://github.com/openai/whisper
    """

    def __init__(
        self,
        fs: int = 16000,
        whisper_model: Optional[str] = None,
        do_pad_trim: bool = True,
        n_mels: int = 80,
    ):
        super().__init__()
        # Whisper's feature extractor is hard-wired to 16 kHz input.
        assert fs == 16000
        self.fs = fs
        self.n_fft = N_FFT  # 400-sample window (25 ms at 16 kHz)
        self.win_length = N_FFT
        self.hop_length = HOP_LENGTH  # 160-sample hop (10 ms)
        self.pad_samples = N_SAMPLES  # 480000 samples = 30 s chunks
        self.frame_shift = self.hop_length
        self.lfr_n = 1  # no low-frame-rate (LFR) stacking
        self.n_mels = n_mels
        # The large-v3 family of checkpoints uses 128 Mel bins instead of 80.
        if whisper_model in ("large-v3", "large"):
            self.n_mels = 128
        self.mel_filters = whisper.audio.mel_filters
        self.do_pad_trim = do_pad_trim
        if do_pad_trim:
            self.pad_or_trim = whisper.pad_or_trim
        # assert whisper_model in whisper.available_models()

    def output_size(self) -> int:
        return self.n_mels

    def log_mel_spectrogram(
        self,
        audio: torch.Tensor,
        ilens: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        window = torch.hann_window(self.win_length).to(audio.device)
        stft = torch.stft(
            audio, self.n_fft, self.hop_length, window=window, return_complex=True
        )
        # whisper deletes the last frame by default (Shih-Lun)
        magnitudes = stft[..., :-1].abs() ** 2

        filters = self.mel_filters(audio.device, self.n_mels)
        mel_spec = filters @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()

        if ilens is not None:
            # Number of valid frames given the true (unpadded) input length.
            olens = ilens // self.hop_length
        else:
            olens = None

        # Whisper's normalization: limit the dynamic range to 80 dB below the
        # per-utterance peak, then shift/scale into roughly [-1, 1].
        log_spec = torch.maximum(
            log_spec,
            log_spec.view(audio.size(0), -1).max(dim=-1)[0][:, None, None] - 8.0,
        )
        log_spec = (log_spec + 4.0) / 4.0

        return log_spec, olens

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            if self.do_pad_trim:
                # Pad (or trim) every utterance to Whisper's fixed 30 s window.
                feat = self.pad_or_trim(input[i], self.pad_samples)
            else:
                feat = input[i]
            # Compute features and the valid frame count for utterance i.
            feat, feat_len = self.log_mel_spectrogram(feat[None, :], input_lengths[i])
            feats.append(feat[0])
            feats_lens.append(feat_len)

        feats_lens = torch.as_tensor(feats_lens)
        if batch_size == 1:
            feats_pad = feats[0][None, :, :]
        else:
            # Each feature is (n_mels, frames); pad_sequence pads along dim 0,
            # so transpose to (frames, n_mels) first, then restore the layout.
            feats_pad = pad_sequence(
                [f.transpose(0, 1) for f in feats],
                batch_first=True,
                padding_value=0.0,
            ).transpose(1, 2)

        return feats_pad, feats_lens
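

# Minimal usage sketch (an illustrative addition, not part of the FunASR
# module; it assumes the `openai-whisper` and `funasr` packages imported above
# are installed). Two random 16 kHz "utterances" are fed through the frontend.
# With do_pad_trim=True every utterance is padded to Whisper's 30 s window
# (480000 samples), so each padded feature spans 3000 frames (hop = 160, with
# the last STFT frame dropped), while feats_lens reports true lengths in frames.
if __name__ == "__main__":
    frontend = WhisperFrontend(fs=16000, n_mels=80)
    audio = torch.randn(2, 16000 * 5)               # (batch, samples)
    lengths = torch.tensor([16000 * 5, 16000 * 4])  # valid samples per utterance
    feats, feats_lens = frontend(audio, lengths)
    print(feats.shape)  # torch.Size([2, 80, 3000])
    print(feats_lens)   # tensor([500, 400]), i.e. lengths // 160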