# rnnt_decoder.py
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import torch
  6. from typing import List, Optional, Tuple
  7. from funasr.register import tables
  8. from funasr.models.specaug.specaug import SpecAug
  9. from funasr.models.transducer.beam_search_transducer import Hypothesis
  10. @tables.register("decoder_classes", "rnnt_decoder")
  11. class RNNTDecoder(torch.nn.Module):
  12. """RNN decoder module.
  13. Args:
  14. vocab_size: Vocabulary size.
  15. embed_size: Embedding size.
  16. hidden_size: Hidden size..
  17. rnn_type: Decoder layers type.
  18. num_layers: Number of decoder layers.
  19. dropout_rate: Dropout rate for decoder layers.
  20. embed_dropout_rate: Dropout rate for embedding layer.
  21. embed_pad: Embedding padding symbol ID.
  22. """
  23. def __init__(
  24. self,
  25. vocab_size: int,
  26. embed_size: int = 256,
  27. hidden_size: int = 256,
  28. rnn_type: str = "lstm",
  29. num_layers: int = 1,
  30. dropout_rate: float = 0.0,
  31. embed_dropout_rate: float = 0.0,
  32. embed_pad: int = 0,
  33. use_embed_mask: bool = False,
  34. ) -> None:
  35. """Construct a RNNDecoder object."""
  36. super().__init__()
  37. if rnn_type not in ("lstm", "gru"):
  38. raise ValueError(f"Not supported: rnn_type={rnn_type}")
  39. self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
  40. self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)
  41. rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU
  42. self.rnn = torch.nn.ModuleList(
  43. [rnn_class(embed_size, hidden_size, 1, batch_first=True)]
  44. )
  45. for _ in range(1, num_layers):
  46. self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]
  47. self.dropout_rnn = torch.nn.ModuleList(
  48. [torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
  49. )
  50. self.dlayers = num_layers
  51. self.dtype = rnn_type
  52. self.output_size = hidden_size
  53. self.vocab_size = vocab_size
  54. self.device = next(self.parameters()).device
  55. self.score_cache = {}
  56. self.use_embed_mask = use_embed_mask
  57. if self.use_embed_mask:
  58. self._embed_mask = SpecAug(
  59. time_mask_width_range=3,
  60. num_time_mask=4,
  61. apply_freq_mask=False,
  62. apply_time_warp=False
  63. )
  64. def forward(
  65. self,
  66. labels: torch.Tensor,
  67. label_lens: torch.Tensor,
  68. states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
  69. ) -> torch.Tensor:
  70. """Encode source label sequences.
  71. Args:
  72. labels: Label ID sequences. (B, L)
  73. states: Decoder hidden states.
  74. ((N, B, D_dec), (N, B, D_dec) or None) or None
  75. Returns:
  76. dec_out: Decoder output sequences. (B, U, D_dec)
  77. """
  78. if states is None:
  79. states = self.init_state(labels.size(0))
  80. dec_embed = self.dropout_embed(self.embed(labels))
  81. if self.use_embed_mask and self.training:
  82. dec_embed = self._embed_mask(dec_embed, label_lens)[0]
  83. dec_out, states = self.rnn_forward(dec_embed, states)
  84. return dec_out
  85. def rnn_forward(
  86. self,
  87. x: torch.Tensor,
  88. state: Tuple[torch.Tensor, Optional[torch.Tensor]],
  89. ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
  90. """Encode source label sequences.
  91. Args:
  92. x: RNN input sequences. (B, D_emb)
  93. state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
  94. Returns:
  95. x: RNN output sequences. (B, D_dec)
  96. (h_next, c_next): Decoder hidden states.
  97. (N, B, D_dec), (N, B, D_dec) or None)
  98. """
  99. h_prev, c_prev = state
  100. h_next, c_next = self.init_state(x.size(0))
  101. for layer in range(self.dlayers):
  102. if self.dtype == "lstm":
  103. x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
  104. layer
  105. ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
  106. else:
  107. x, h_next[layer : layer + 1] = self.rnn[layer](
  108. x, hx=h_prev[layer : layer + 1]
  109. )
  110. x = self.dropout_rnn[layer](x)
  111. return x, (h_next, c_next)
  112. def score(
  113. self,
  114. label: torch.Tensor,
  115. label_sequence: List[int],
  116. dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
  117. ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
  118. """One-step forward hypothesis.
  119. Args:
  120. label: Previous label. (1, 1)
  121. label_sequence: Current label sequence.
  122. dec_state: Previous decoder hidden states.
  123. ((N, 1, D_dec), (N, 1, D_dec) or None)
  124. Returns:
  125. dec_out: Decoder output sequence. (1, D_dec)
  126. dec_state: Decoder hidden states.
  127. ((N, 1, D_dec), (N, 1, D_dec) or None)
  128. """
  129. str_labels = "_".join(map(str, label_sequence))
  130. if str_labels in self.score_cache:
  131. dec_out, dec_state = self.score_cache[str_labels]
  132. else:
  133. dec_embed = self.embed(label)
  134. dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)
  135. self.score_cache[str_labels] = (dec_out, dec_state)
  136. return dec_out[0], dec_state
  137. def batch_score(
  138. self,
  139. hyps: List[Hypothesis],
  140. ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
  141. """One-step forward hypotheses.
  142. Args:
  143. hyps: Hypotheses.
  144. Returns:
  145. dec_out: Decoder output sequences. (B, D_dec)
  146. states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
  147. """
  148. labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device)
  149. dec_embed = self.embed(labels)
  150. states = self.create_batch_states([h.dec_state for h in hyps])
  151. dec_out, states = self.rnn_forward(dec_embed, states)
  152. return dec_out.squeeze(1), states
  153. def set_device(self, device: torch.device) -> None:
  154. """Set GPU device to use.
  155. Args:
  156. device: Device ID.
  157. """
  158. self.device = device
  159. def init_state(
  160. self, batch_size: int
  161. ) -> Tuple[torch.Tensor, Optional[torch.tensor]]:
  162. """Initialize decoder states.
  163. Args:
  164. batch_size: Batch size.
  165. Returns:
  166. : Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
  167. """
  168. h_n = torch.zeros(
  169. self.dlayers,
  170. batch_size,
  171. self.output_size,
  172. device=self.device,
  173. )
  174. if self.dtype == "lstm":
  175. c_n = torch.zeros(
  176. self.dlayers,
  177. batch_size,
  178. self.output_size,
  179. device=self.device,
  180. )
  181. return (h_n, c_n)
  182. return (h_n, None)
  183. def select_state(
  184. self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
  185. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  186. """Get specified ID state from decoder hidden states.
  187. Args:
  188. states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
  189. idx: State ID to extract.
  190. Returns:
  191. : Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)
  192. """
  193. return (
  194. states[0][:, idx : idx + 1, :],
  195. states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
  196. )
  197. def create_batch_states(
  198. self,
  199. new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
  200. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  201. """Create decoder hidden states.
  202. Args:
  203. new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)]
  204. Returns:
  205. states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
  206. """
  207. return (
  208. torch.cat([s[0] for s in new_states], dim=1),
  209. torch.cat([s[1] for s in new_states], dim=1)
  210. if self.dtype == "lstm"
  211. else None,
  212. )