# seq_rnn_lm.py
"""Sequential implementation of Recurrent Neural Network Language Model."""
from typing import Optional
from typing import Tuple
from typing import Union

import torch
import torch.nn as nn
from typeguard import check_argument_types

from funasr.train.abs_model import AbsLM
  8. class SequentialRNNLM(AbsLM):
  9. """Sequential RNNLM.
  10. See also:
  11. https://github.com/pytorch/examples/blob/4581968193699de14b56527296262dd76ab43557/word_language_model/model.py
  12. """
  13. def __init__(
  14. self,
  15. vocab_size: int,
  16. unit: int = 650,
  17. nhid: int = None,
  18. nlayers: int = 2,
  19. dropout_rate: float = 0.0,
  20. tie_weights: bool = False,
  21. rnn_type: str = "lstm",
  22. ignore_id: int = 0,
  23. ):
  24. assert check_argument_types()
  25. super().__init__()
  26. ninp = unit
  27. if nhid is None:
  28. nhid = unit
  29. rnn_type = rnn_type.upper()
  30. self.drop = nn.Dropout(dropout_rate)
  31. self.encoder = nn.Embedding(vocab_size, ninp, padding_idx=ignore_id)
  32. if rnn_type in ["LSTM", "GRU"]:
  33. rnn_class = getattr(nn, rnn_type)
  34. self.rnn = rnn_class(
  35. ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True
  36. )
  37. else:
  38. try:
  39. nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
  40. except KeyError:
  41. raise ValueError(
  42. """An invalid option for `--model` was supplied,
  43. options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']"""
  44. )
  45. self.rnn = nn.RNN(
  46. ninp,
  47. nhid,
  48. nlayers,
  49. nonlinearity=nonlinearity,
  50. dropout=dropout_rate,
  51. batch_first=True,
  52. )
  53. self.decoder = nn.Linear(nhid, vocab_size)
  54. # Optionally tie weights as in:
  55. # "Using the Output Embedding to Improve Language Models"
  56. # (Press & Wolf 2016) https://arxiv.org/abs/1608.05859
  57. # and
  58. # "Tying Word Vectors and Word Classifiers:
  59. # A Loss Framework for Language Modeling" (Inan et al. 2016)
  60. # https://arxiv.org/abs/1611.01462
  61. if tie_weights:
  62. if nhid != ninp:
  63. raise ValueError(
  64. "When using the tied flag, nhid must be equal to emsize"
  65. )
  66. self.decoder.weight = self.encoder.weight
  67. self.rnn_type = rnn_type
  68. self.nhid = nhid
  69. self.nlayers = nlayers
  70. def zero_state(self):
  71. """Initialize LM state filled with zero values."""
  72. if isinstance(self.rnn, torch.nn.LSTM):
  73. h = torch.zeros((self.nlayers, self.nhid), dtype=torch.float)
  74. c = torch.zeros((self.nlayers, self.nhid), dtype=torch.float)
  75. state = h, c
  76. else:
  77. state = torch.zeros((self.nlayers, self.nhid), dtype=torch.float)
  78. return state
  79. def forward(
  80. self, input: torch.Tensor, hidden: torch.Tensor
  81. ) -> Tuple[torch.Tensor, torch.Tensor]:
  82. emb = self.drop(self.encoder(input))
  83. output, hidden = self.rnn(emb, hidden)
  84. output = self.drop(output)
  85. decoded = self.decoder(
  86. output.contiguous().view(output.size(0) * output.size(1), output.size(2))
  87. )
  88. return (
  89. decoded.view(output.size(0), output.size(1), decoded.size(1)),
  90. hidden,
  91. )
  92. def score(
  93. self,
  94. y: torch.Tensor,
  95. state: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
  96. x: torch.Tensor,
  97. ) -> Tuple[torch.Tensor, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
  98. """Score new token.
  99. Args:
  100. y: 1D torch.int64 prefix tokens.
  101. state: Scorer state for prefix tokens
  102. x: 2D encoder feature that generates ys.
  103. Returns:
  104. Tuple of
  105. torch.float32 scores for next token (n_vocab)
  106. and next state for ys
  107. """
  108. y, new_state = self(y[-1].view(1, 1), state)
  109. logp = y.log_softmax(dim=-1).view(-1)
  110. return logp, new_state
  111. def batch_score(
  112. self, ys: torch.Tensor, states: torch.Tensor, xs: torch.Tensor
  113. ) -> Tuple[torch.Tensor, torch.Tensor]:
  114. """Score new token batch.
  115. Args:
  116. ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
  117. states (List[Any]): Scorer states for prefix tokens.
  118. xs (torch.Tensor):
  119. The encoder feature that generates ys (n_batch, xlen, n_feat).
  120. Returns:
  121. tuple[torch.Tensor, List[Any]]: Tuple of
  122. batchfied scores for next token with shape of `(n_batch, n_vocab)`
  123. and next state list for ys.
  124. """
  125. if states[0] is None:
  126. states = None
  127. elif isinstance(self.rnn, torch.nn.LSTM):
  128. # states: Batch x 2 x (Nlayers, Dim) -> 2 x (Nlayers, Batch, Dim)
  129. h = torch.stack([h for h, c in states], dim=1)
  130. c = torch.stack([c for h, c in states], dim=1)
  131. states = h, c
  132. else:
  133. # states: Batch x (Nlayers, Dim) -> (Nlayers, Batch, Dim)
  134. states = torch.stack(states, dim=1)
  135. ys, states = self(ys[:, -1:], states)
  136. # ys: (Batch, 1, Nvocab) -> (Batch, NVocab)
  137. assert ys.size(1) == 1, ys.shape
  138. ys = ys.squeeze(1)
  139. logp = ys.log_softmax(dim=-1)
  140. # state: Change to batch first
  141. if isinstance(self.rnn, torch.nn.LSTM):
  142. # h, c: (Nlayers, Batch, Dim)
  143. h, c = states
  144. # states: Batch x 2 x (Nlayers, Dim)
  145. states = [(h[:, i], c[:, i]) for i in range(h.size(1))]
  146. else:
  147. # states: (Nlayers, Batch, Dim) -> Batch x (Nlayers, Dim)
  148. states = [states[:, i] for i in range(states.size(1))]
  149. return logp, states