seq_rnn_lm.py

  1. """Sequential implementation of Recurrent Neural Network Language Model."""
  2. from typing import Tuple
  3. from typing import Union
  4. import torch
  5. import torch.nn as nn
  6. from funasr.train.abs_model import AbsLM
  7. class SequentialRNNLM(AbsLM):
  8. """Sequential RNNLM.
  9. See also:
  10. https://github.com/pytorch/examples/blob/4581968193699de14b56527296262dd76ab43557/word_language_model/model.py
  11. """
  12. def __init__(
        self,
        vocab_size: int,
        unit: int = 650,
        nhid: Optional[int] = None,
        nlayers: int = 2,
        dropout_rate: float = 0.0,
        tie_weights: bool = False,
        rnn_type: str = "lstm",
        ignore_id: int = 0,
    ):
        super().__init__()

        ninp = unit
        if nhid is None:
            nhid = unit
        rnn_type = rnn_type.upper()

        self.drop = nn.Dropout(dropout_rate)
        self.encoder = nn.Embedding(vocab_size, ninp, padding_idx=ignore_id)
        if rnn_type in ["LSTM", "GRU"]:
            rnn_class = getattr(nn, rnn_type)
            self.rnn = rnn_class(
                ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True
            )
        else:
            try:
                nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
            except KeyError:
                raise ValueError(
                    "An invalid value for `rnn_type` was supplied; "
                    "options are ['LSTM', 'GRU', 'RNN_TANH', 'RNN_RELU']"
                )
            self.rnn = nn.RNN(
                ninp,
                nhid,
                nlayers,
                nonlinearity=nonlinearity,
                dropout=dropout_rate,
                batch_first=True,
            )
        self.decoder = nn.Linear(nhid, vocab_size)

        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models"
        # (Press & Wolf 2016) https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers:
        #  A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            if nhid != ninp:
                raise ValueError(
                    "When using the tied flag, nhid must be equal to emsize"
                )
            self.decoder.weight = self.encoder.weight

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

    def zero_state(self):
  69. """Initialize LM state filled with zero values."""
  70. if isinstance(self.rnn, torch.nn.LSTM):
  71. h = torch.zeros((self.nlayers, self.nhid), dtype=torch.float)
  72. c = torch.zeros((self.nlayers, self.nhid), dtype=torch.float)
  73. state = h, c
  74. else:
  75. state = torch.zeros((self.nlayers, self.nhid), dtype=torch.float)
  76. return state
  77. def forward(
        self, input: torch.Tensor, hidden: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute token logits and the next hidden state.

        Args:
            input: int64 token ids with shape (batch, length).
            hidden: RNN hidden state (an (h, c) tuple for LSTM).
        """
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        # Flatten (batch, length, nhid) -> (batch * length, nhid) before the
        # output projection, then restore the original layout.
        decoded = self.decoder(
            output.contiguous().view(output.size(0) * output.size(1), output.size(2))
        )
        return (
            decoded.view(output.size(0), output.size(1), decoded.size(1)),
            hidden,
        )

    def score(
        self,
        y: torch.Tensor,
        state: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
        """Score a new token.

        Args:
            y: 1D torch.int64 prefix tokens.
            state: Scorer state for the prefix tokens.
            x: 2D encoder feature that generates ys.

        Returns:
            Tuple of
                torch.float32 scores for the next token (n_vocab)
                and the next state for ys.
        """
        # Only the last prefix token is fed; the recurrent state carries
        # the rest of the prefix.
        y, new_state = self(y[-1].view(1, 1), state)
        logp = y.log_softmax(dim=-1).view(-1)
        return logp, new_state

    def batch_score(
        self, ys: torch.Tensor, states: torch.Tensor, xs: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score a batch of new tokens.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for the prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchified scores for the next token with shape of `(n_batch, n_vocab)`
                and the next state list for ys.
        """
        if states[0] is None:
            states = None
        elif isinstance(self.rnn, torch.nn.LSTM):
            # states: Batch x 2 x (Nlayers, Dim) -> 2 x (Nlayers, Batch, Dim)
            h = torch.stack([h for h, c in states], dim=1)
            c = torch.stack([c for h, c in states], dim=1)
            states = h, c
        else:
            # states: Batch x (Nlayers, Dim) -> (Nlayers, Batch, Dim)
            states = torch.stack(states, dim=1)

        ys, states = self(ys[:, -1:], states)
        # ys: (Batch, 1, Nvocab) -> (Batch, Nvocab)
        assert ys.size(1) == 1, ys.shape
        ys = ys.squeeze(1)
        logp = ys.log_softmax(dim=-1)

        # states: change back to a batch-first, per-hypothesis list
        if isinstance(self.rnn, torch.nn.LSTM):
            # h, c: (Nlayers, Batch, Dim)
            h, c = states
            # states: Batch x 2 x (Nlayers, Dim)
            states = [(h[:, i], c[:, i]) for i in range(h.size(1))]
        else:
            # states: (Nlayers, Batch, Dim) -> Batch x (Nlayers, Dim)
            states = [states[:, i] for i in range(states.size(1))]
        return logp, states
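
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes funasr is
# importable so that AbsLM resolves, and uses made-up sizes (vocab_size=50,
# unit=32, 3-token prefixes) purely to illustrate the scorer interface above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    lm = SequentialRNNLM(vocab_size=50, unit=32, nlayers=2, rnn_type="lstm")
    lm.eval()

    with torch.no_grad():
        # Single-hypothesis scoring: the initial scorer state is None, so the
        # RNN falls back to its default zero state inside forward().
        y = torch.tensor([1, 7, 3], dtype=torch.int64)  # hypothetical prefix
        x = torch.zeros(10, 32)  # dummy encoder feature; unused by this LM
        logp, state = lm.score(y, None, x)
        print(logp.shape)  # torch.Size([50]): log-probabilities over the vocab

        # Batched scoring for two hypotheses: states start as a list of None
        # and come back as a per-hypothesis list of (h, c) tuples.
        ys = torch.tensor([[1, 7, 3], [1, 4, 9]], dtype=torch.int64)
        xs = torch.zeros(2, 10, 32)
        logp, states = lm.batch_score(ys, [None, None], xs)
        print(logp.shape, len(states))  # torch.Size([2, 50]) 2

        # Feeding the returned states back advances each hypothesis by one token.
        ys = torch.cat([ys, logp.argmax(dim=-1, keepdim=True)], dim=1)
        logp, states = lm.batch_score(ys, states, xs)
        print(logp.shape)  # torch.Size([2, 50])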