# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import copy
from typing import Any, List, Optional, Tuple

import torch
from torch import nn

import whisper

from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables
  10. @tables.register("decoder_classes", "OpenAIWhisperDecoderWarp")
  11. class OpenAIWhisperDecoderWarp(nn.Module):
  12. """Transformer-based Speech-to-Text Decoder from OpenAI's Whisper Model:
  13. URL: https://github.com/openai/whisper
  14. """
  15. def __init__(
  16. self,
  17. dropout_rate: float = 0.0,
  18. whisper_model: str = "small",
  19. download_dir: str = None,
  20. use_padmask: bool = False,
  21. ):
  22. super().__init__()
  23. assert whisper_model in whisper.available_models()
  24. _model = whisper.load_model(
  25. whisper_model, download_root=download_dir, device="cpu"
  26. )
  27. self.decoders = copy.deepcopy(_model.decoder)
  28. attention_dim = self.decoders.token_embedding.embedding_dim
  29. # note that originally Whisper doesn't use dropouts
  30. self.dropout = torch.nn.Dropout(dropout_rate)
  31. self.decoders.train()
  32. del _model
  33. self.use_padmask = use_padmask
  34. def forward(
  35. self,
  36. hs_pad: torch.Tensor,
  37. hlens: torch.Tensor,
  38. ys_in_pad: torch.Tensor,
  39. ys_in_lens: torch.Tensor,
  40. ) -> Tuple[torch.Tensor, torch.Tensor]:
  41. """Forward decoder.
  42. Args:
  43. hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
  44. hlens: (batch)
  45. ys_in_pad:
  46. input token ids, int64 (batch, maxlen_out)
  47. if input_layer == "embed"
  48. input tensor (batch, maxlen_out, #mels) in the other cases
  49. ys_in_lens: (batch)
  50. Returns:
  51. (tuple): tuple containing:
  52. x: decoded token score before softmax (batch, maxlen_out, token)
  53. if use_output_layer is True,
  54. olens: (batch, )
  55. """
  56. tgt, memory = ys_in_pad, hs_pad
  57. tgt = (
  58. self.decoders.token_embedding(tgt)
  59. + self.decoders.positional_embedding[: tgt.size(1)]
  60. )
  61. tgt = self.dropout(tgt)
  62. x = tgt.to(memory.dtype)
  63. if self.use_padmask:
  64. memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
  65. else:
  66. memory_mask = None
  67. for layer, block in enumerate(self.decoders.blocks):
  68. x = block(x, memory, mask=self.decoders.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
  69. if layer < len(self.decoders.blocks) - 1:
  70. x = self.dropout(x)
  71. x = self.decoders.ln(x)
  72. x = (
  73. x @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
  74. ).float()
  75. return x, ys_in_lens
  76. def forward_one_step(
  77. self,
  78. tgt: torch.Tensor,
  79. tgt_mask: torch.Tensor,
  80. memory: torch.Tensor,
  81. cache: List[torch.Tensor] = None,
  82. ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
  83. """Forward one step.
  84. Args:
  85. tgt: input token ids, int64 (batch, maxlen_out)
  86. tgt_mask: input token mask, (batch, maxlen_out)
  87. dtype=torch.uint8 in PyTorch 1.2-
  88. dtype=torch.bool in PyTorch 1.2+ (include 1.2)
  89. memory: encoded memory, float32 (batch, maxlen_in, feat)
  90. cache: cached output list of (batch, max_time_out-1, size)
  91. Returns:
  92. y, cache: NN output value and cache per `self.decoders`.
  93. y.shape` is (batch, maxlen_out, token)
  94. NOTE (Shih-Lun):
  95. cache implementation is ignored for now
  96. for simplicity & correctness
  97. """
  98. x = (
  99. self.decoders.token_embedding(tgt)
  100. + self.decoders.positional_embedding[: tgt.size(1)]
  101. )
  102. x = self.dropout(x)
  103. x = x.to(memory.dtype)
  104. for layer, block in enumerate(self.decoders.blocks):
  105. x = block(x, memory, mask=self.decoders.mask)
  106. if layer < len(self.decoders.blocks) - 1:
  107. x = self.dropout(x)
  108. x = self.decoders.ln(x)
  109. y = x[:, -1]
  110. y = (
  111. y @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
  112. ).float()
  113. y = torch.log_softmax(y, dim=-1)
  114. return y, None
  115. def score(self, ys, state, x):
  116. """Score."""
  117. logp, state = self.forward_one_step(
  118. ys.unsqueeze(0), torch.empty(0), x.unsqueeze(0), cache=state # dummy mask
  119. )
  120. return logp.squeeze(0), state
  121. def batch_score(
  122. self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
  123. ) -> Tuple[torch.Tensor, List[Any]]:
  124. """Score new token batch.
  125. Args:
  126. ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
  127. states (List[Any]): Scorer states for prefix tokens.
  128. xs (torch.Tensor):
  129. The encoder feature that generates ys (n_batch, xlen, n_feat).
  130. Returns:
  131. tuple[torch.Tensor, List[Any]]: Tuple of
  132. batchfied scores for next token with shape of `(n_batch, n_vocab)`
  133. and next state list for ys.
  134. """
  135. # batch decoding, dummy mask is passed
  136. logp, states = self.forward_one_step(ys, torch.empty(0), xs, cache=None)
  137. return logp, None