# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import copy
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

import whisper

from funasr.models.specaug.specaug import SpecAug
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables


@tables.register("encoder_classes", "OpenAIWhisperEncoderWarp")
class OpenAIWhisperEncoderWarp(nn.Module):
    """Transformer-based speech encoder from OpenAI's Whisper model:
    URL: https://github.com/openai/whisper
    """

    def __init__(
        self,
        dropout_rate: float = 0.0,
        whisper_model: str = "small",
        download_dir: Optional[str] = None,
        use_specaug: bool = False,
        use_padmask: bool = False,
        specaug_conf: Union[dict, None] = None,
    ):
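        """
        Args:
            dropout_rate: dropout applied between encoder blocks
                (the original Whisper uses no dropout).
            whisper_model: model name accepted by whisper.load_model().
            download_dir: cache directory for the pretrained weights.
            use_specaug: apply SpecAug to input features during training.
            use_padmask: compute a padding mask from the output lengths.
            specaug_conf: keyword arguments forwarded to SpecAug.
        """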
        super().__init__()
        # Note that the original Whisper model does not use dropout.
        self.dropout = torch.nn.Dropout(dropout_rate)

        assert whisper_model in whisper.available_models()
        _model = whisper.load_model(
            whisper_model, download_root=download_dir, device="cpu"
        )
        # Keep a deep copy of the encoder only; the decoder is discarded.
        self.encoders = copy.deepcopy(_model.encoder)
        self.encoders.train()
        del _model

        if use_specaug:
            self.specaug = SpecAug(**specaug_conf)
        else:
            self.specaug = None
        self.use_padmask = use_padmask

    def whisper_encode(
        self,
        input: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
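        """Encode log-Mel features with the Whisper encoder.

        Two GELU-activated convolutions (conv2 halves the frame rate with
        stride 2), sinusoidal positional embeddings, the transformer
        blocks, and a final LayerNorm.
        """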
        x = F.gelu(self.encoders.conv1(input))
        x = F.gelu(self.encoders.conv2(x))
        x = x.permute(0, 2, 1)

        n_frames = x.size(1)
        max_pos = self.encoders.positional_embedding.size(0)
        if n_frames <= max_pos:
            x = (x + self.encoders.positional_embedding[: x.size(1), :]).to(x.dtype)
        else:
            # The positional embedding covers at most 30 s of audio, so
            # anything longer is truncated to max_pos frames.
            x = x[:, :max_pos, :] + self.encoders.positional_embedding
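
        # Output lengths follow the standard 1-D convolution formula applied
        # to conv2 (conv1 uses stride 1 with "same" padding, so it preserves
        # the length): olens = 1 + (ilens - kernel + 2 * padding) // stride.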
        if ilens is not None:
            olens = (
                1
                + (
                    ilens
                    - self.encoders.conv2.kernel_size[0]
                    + 2 * self.encoders.conv2.padding[0]
                )
                // self.encoders.conv2.stride[0]
            )
            olens = torch.clamp(olens, max=max_pos)
        else:
            olens = None

        if self.use_padmask and olens is not None:
            padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
        else:
            padding_mask = None
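        # NOTE: padding_mask is computed here but never passed to the encoder
        # blocks below; callers can rely on the returned olens instead.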

        x = self.dropout(x)
        for layer, block in enumerate(self.encoders.blocks):
            x = block(x)
            # Apply dropout between blocks, but not after the last one.
            if layer < len(self.encoders.blocks) - 1:
                x = self.dropout(x)

        x = self.encoders.ln_post(x)
        return x, olens

    def output_size(self) -> int:
        # conv2's output-channel count equals the encoder's model dimension.
        return self.encoders.conv2.weight.shape[0]

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
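        """Optionally apply SpecAug (training only), then encode.

        Args:
            xs_pad: padded log-Mel features, shape (batch, n_mels, frames).
            ilens: number of valid frames in each utterance.

        Returns:
            Encoder output, output lengths, and None (no encoder state).
        """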
        feats, feats_lens = xs_pad, ilens
        if self.specaug is not None and self.encoders.training:
            # SpecAug expects (batch, time, feat); the features arrive as
            # (batch, feat, time), so transpose around the augmentation.
            feats = torch.transpose(feats, 1, 2)
            feats, feats_lens = self.specaug(feats, feats_lens)
            feats = torch.transpose(feats, 1, 2)
        xs_pad, olens = self.whisper_encode(feats, feats_lens)
        return xs_pad, olens, None
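

# A minimal usage sketch (not part of the original module). It assumes the
# openai-whisper package can download the "small" checkpoint and that the
# model expects 80-channel log-Mel features at 100 frames per second, so a
# 30 s utterance is 3000 frames. Shapes and lengths below are illustrative.
if __name__ == "__main__":
    encoder = OpenAIWhisperEncoderWarp(whisper_model="small")
    encoder.eval()

    feats = torch.randn(2, 80, 3000)  # (batch, n_mels, frames)
    ilens = torch.tensor([3000, 1500])  # valid frames per utterance
    with torch.no_grad():
        out, olens, _ = encoder(feats, ilens)

    # conv2 has stride 2, so the frame count is roughly halved.
    print(out.shape)  # e.g. torch.Size([2, 1500, 768]) for "small"
    print(olens)  # e.g. tensor([1500,  750])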