# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import copy
from typing import Optional, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F
import whisper

from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.specaug.specaug import SpecAug
from funasr.register import tables

@tables.register("encoder_classes", "OpenAIWhisperEncoderWarp")
class OpenAIWhisperEncoderWarp(nn.Module):
    """Transformer-based speech encoder wrapping OpenAI's Whisper model.

    URL: https://github.com/openai/whisper
    """

    def __init__(
        self,
        dropout_rate: float = 0.0,
        whisper_model: str = "small",
        download_dir: Optional[str] = None,
        use_specaug: bool = False,
        use_padmask: bool = False,
        specaug_conf: Union[dict, None] = None,
    ):
        super().__init__()

        # Note that the original Whisper model does not use dropout.
        self.dropout = torch.nn.Dropout(dropout_rate)

        assert whisper_model in whisper.available_models()
        _model = whisper.load_model(
            whisper_model, download_root=download_dir, device="cpu"
        )
        # Keep a copy of the audio encoder only; the text decoder is discarded.
        self.encoders = copy.deepcopy(_model.encoder)
        self.encoders.train()
        del _model

        if use_specaug:
            self.specaug = SpecAug(**specaug_conf)
        else:
            self.specaug = None
        self.use_padmask = use_padmask

    def whisper_encode(
        self,
        input: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Two-layer conv front-end: conv1 keeps the frame rate, conv2 halves it.
        x = F.gelu(self.encoders.conv1(input))
        x = F.gelu(self.encoders.conv2(x))
        x = x.permute(0, 2, 1)

        n_frames = x.size(1)
        max_pos = self.encoders.positional_embedding.size(0)
        if n_frames <= max_pos:
            x = (x + self.encoders.positional_embedding[: x.size(1), :]).to(x.dtype)
        else:
            # Due to the fixed positional encoding, audio longer than 30 s is truncated.
            x = x[:, :max_pos, :] + self.encoders.positional_embedding

        if ilens is not None:
            # Standard 1-D convolution output-length formula, applied to conv2
            # (conv1 preserves the length).
            olens = (
                1
                + (
                    ilens
                    - self.encoders.conv2.kernel_size[0]
                    + 2 * self.encoders.conv2.padding[0]
                )
                // self.encoders.conv2.stride[0]
            )
            olens = torch.clamp(olens, max=max_pos)
        else:
            olens = None

        if self.use_padmask and olens is not None:
            padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
        else:
            padding_mask = None

        x = self.dropout(x)

        for layer, block in enumerate(self.encoders.blocks):
            # The stock Whisper attention blocks are called without the padding
            # mask; `padding_mask` is computed above but not applied here.
            x = block(x)
            if layer < len(self.encoders.blocks) - 1:
                x = self.dropout(x)

        x = self.encoders.ln_post(x)
        return x, olens
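
    # For the standard Whisper front-end (an illustration, not enforced here):
    # 30 s of audio yields 3000 log-mel frames, and conv2 has kernel 3,
    # stride 2, padding 1, so olens = 1 + (3000 - 3 + 2) // 2 = 1500, which is
    # exactly the number of positions in the encoder's positional embedding.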

    def output_size(self) -> int:
        # The encoder output dimension equals the number of conv2 output channels.
        return self.encoders.conv2.weight.shape[0]

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        feats, feats_lens = xs_pad, ilens

        if self.specaug is not None and self.encoders.training:
            # SpecAug expects (batch, time, feature), so transpose around the call.
            feats = torch.transpose(feats, 1, 2)
            feats, feats_lens = self.specaug(feats, feats_lens)
            feats = torch.transpose(feats, 1, 2)

        xs_pad, olens = self.whisper_encode(feats, feats_lens)
        return xs_pad, olens, None
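

# A minimal usage sketch (not part of the original module). It assumes the
# Whisper "small" checkpoint can be downloaded via `whisper.load_model` and
# that the input features are 80-bin log-mel spectrograms shaped
# (batch, n_mels, frames), e.g. from `whisper.log_mel_spectrogram`.
if __name__ == "__main__":
    encoder = OpenAIWhisperEncoderWarp(whisper_model="small")
    feats = torch.randn(2, 80, 3000)         # dummy 30 s log-mel batch
    feats_lens = torch.tensor([3000, 2500])  # valid frames per utterance
    out, out_lens, _ = encoder(feats, feats_lens)
    # For "small", out is (2, 1500, 768) and out_lens is tensor([1500, 1250]).
    print(out.shape, out_lens)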