  1. """RNN decoder definition for Transducer models."""
  2. from typing import List, Optional, Tuple
  3. import torch
  4. from typeguard import check_argument_types
  5. from funasr.modules.beam_search.beam_search_transducer import Hypothesis
  6. from funasr.models.specaug.specaug import SpecAug
  7. class RNNTDecoder(torch.nn.Module):
  8. """RNN decoder module.
  9. Args:
  10. vocab_size: Vocabulary size.
  11. embed_size: Embedding size.
  12. hidden_size: Hidden size..
  13. rnn_type: Decoder layers type.
  14. num_layers: Number of decoder layers.
  15. dropout_rate: Dropout rate for decoder layers.
  16. embed_dropout_rate: Dropout rate for embedding layer.
  17. embed_pad: Embedding padding symbol ID.
  18. """
  19. def __init__(
  20. self,
  21. vocab_size: int,
  22. embed_size: int = 256,
  23. hidden_size: int = 256,
  24. rnn_type: str = "lstm",
  25. num_layers: int = 1,
  26. dropout_rate: float = 0.0,
  27. embed_dropout_rate: float = 0.0,
  28. embed_pad: int = 0,
  29. ) -> None:
  30. """Construct a RNNDecoder object."""
  31. super().__init__()
  32. assert check_argument_types()
  33. if rnn_type not in ("lstm", "gru"):
  34. raise ValueError(f"Not supported: rnn_type={rnn_type}")
  35. self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
  36. self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)
  37. rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU
  38. self.rnn = torch.nn.ModuleList(
  39. [rnn_class(embed_size, hidden_size, 1, batch_first=True)]
  40. )
  41. for _ in range(1, num_layers):
  42. self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]
  43. self.dropout_rnn = torch.nn.ModuleList(
  44. [torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
  45. )
  46. self.dlayers = num_layers
  47. self.dtype = rnn_type
  48. self.output_size = hidden_size
  49. self.vocab_size = vocab_size
  50. self.device = next(self.parameters()).device
  51. self.score_cache = {}
  52. def forward(
  53. self,
  54. labels: torch.Tensor,
  55. label_lens: torch.Tensor,
  56. states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
  57. ) -> torch.Tensor:
  58. """Encode source label sequences.
  59. Args:
  60. labels: Label ID sequences. (B, L)
  61. states: Decoder hidden states.
  62. ((N, B, D_dec), (N, B, D_dec) or None) or None
  63. Returns:
  64. dec_out: Decoder output sequences. (B, U, D_dec)
  65. """
  66. if states is None:
  67. states = self.init_state(labels.size(0))
  68. dec_embed = self.dropout_embed(self.embed(labels))
  69. dec_out, states = self.rnn_forward(dec_embed, states)
  70. return dec_out
  71. def rnn_forward(
  72. self,
  73. x: torch.Tensor,
  74. state: Tuple[torch.Tensor, Optional[torch.Tensor]],
  75. ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
  76. """Encode source label sequences.
  77. Args:
  78. x: RNN input sequences. (B, D_emb)
  79. state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
  80. Returns:
  81. x: RNN output sequences. (B, D_dec)
  82. (h_next, c_next): Decoder hidden states.
  83. (N, B, D_dec), (N, B, D_dec) or None)
  84. """
  85. h_prev, c_prev = state
  86. h_next, c_next = self.init_state(x.size(0))
  87. for layer in range(self.dlayers):
  88. if self.dtype == "lstm":
  89. x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
  90. layer
  91. ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
  92. else:
  93. x, h_next[layer : layer + 1] = self.rnn[layer](
  94. x, hx=h_prev[layer : layer + 1]
  95. )
  96. x = self.dropout_rnn[layer](x)
  97. return x, (h_next, c_next)
  98. def score(
  99. self,
  100. label: torch.Tensor,
  101. label_sequence: List[int],
  102. dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
  103. ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
  104. """One-step forward hypothesis.
  105. Args:
  106. label: Previous label. (1, 1)
  107. label_sequence: Current label sequence.
  108. dec_state: Previous decoder hidden states.
  109. ((N, 1, D_dec), (N, 1, D_dec) or None)
  110. Returns:
  111. dec_out: Decoder output sequence. (1, D_dec)
  112. dec_state: Decoder hidden states.
  113. ((N, 1, D_dec), (N, 1, D_dec) or None)
  114. """
  115. str_labels = "_".join(map(str, label_sequence))
  116. if str_labels in self.score_cache:
  117. dec_out, dec_state = self.score_cache[str_labels]
  118. else:
  119. dec_embed = self.embed(label)
  120. dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)
  121. self.score_cache[str_labels] = (dec_out, dec_state)
  122. return dec_out[0], dec_state
  123. def batch_score(
  124. self,
  125. hyps: List[Hypothesis],
  126. ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
  127. """One-step forward hypotheses.
  128. Args:
  129. hyps: Hypotheses.
  130. Returns:
  131. dec_out: Decoder output sequences. (B, D_dec)
  132. states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
  133. """
  134. labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device)
  135. dec_embed = self.embed(labels)
  136. states = self.create_batch_states([h.dec_state for h in hyps])
  137. dec_out, states = self.rnn_forward(dec_embed, states)
  138. return dec_out.squeeze(1), states
  139. def set_device(self, device: torch.device) -> None:
  140. """Set GPU device to use.
  141. Args:
  142. device: Device ID.
  143. """
  144. self.device = device
  145. def init_state(
  146. self, batch_size: int
  147. ) -> Tuple[torch.Tensor, Optional[torch.tensor]]:
  148. """Initialize decoder states.
  149. Args:
  150. batch_size: Batch size.
  151. Returns:
  152. : Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
  153. """
  154. h_n = torch.zeros(
  155. self.dlayers,
  156. batch_size,
  157. self.output_size,
  158. device=self.device,
  159. )
  160. if self.dtype == "lstm":
  161. c_n = torch.zeros(
  162. self.dlayers,
  163. batch_size,
  164. self.output_size,
  165. device=self.device,
  166. )
  167. return (h_n, c_n)
  168. return (h_n, None)
  169. def select_state(
  170. self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
  171. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  172. """Get specified ID state from decoder hidden states.
  173. Args:
  174. states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
  175. idx: State ID to extract.
  176. Returns:
  177. : Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)
  178. """
  179. return (
  180. states[0][:, idx : idx + 1, :],
  181. states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
  182. )
  183. def create_batch_states(
  184. self,
  185. new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
  186. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  187. """Create decoder hidden states.
  188. Args:
  189. new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)]
  190. Returns:
  191. states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
  192. """
  193. return (
  194. torch.cat([s[0] for s in new_states], dim=1),
  195. torch.cat([s[1] for s in new_states], dim=1)
  196. if self.dtype == "lstm"
  197. else None,
  198. )