# seq_rnn.py
  1. import os
  2. import torch
  3. import torch.nn as nn
  4. class SequentialRNNLM(nn.Module):
  5. def __init__(self, model, **kwargs):
  6. super().__init__()
  7. self.encoder = model.encoder
  8. self.rnn = model.rnn
  9. self.rnn_type = model.rnn_type
  10. self.decoder = model.decoder
  11. self.nlayers = model.nlayers
  12. self.nhid = model.nhid
  13. self.model_name = "seq_rnnlm"
  14. def forward(self, y, hidden1, hidden2=None):
  15. # batch_score function.
  16. emb = self.encoder(y)
  17. if self.rnn_type == "LSTM":
  18. output, (hidden1, hidden2) = self.rnn(emb, (hidden1, hidden2))
  19. else:
  20. output, hidden1 = self.rnn(emb, hidden1)
  21. decoded = self.decoder(
  22. output.contiguous().view(output.size(0) * output.size(1), output.size(2))
  23. )
  24. if self.rnn_type == "LSTM":
  25. return (
  26. decoded.view(output.size(0), output.size(1), decoded.size(1)),
  27. hidden1,
  28. hidden2,
  29. )
  30. else:
  31. return (
  32. decoded.view(output.size(0), output.size(1), decoded.size(1)),
  33. hidden1,
  34. )
  35. def get_dummy_inputs(self):
  36. tgt = torch.LongTensor([0, 1]).unsqueeze(0)
  37. hidden = torch.randn(self.nlayers, 1, self.nhid)
  38. if self.rnn_type == "LSTM":
  39. return (tgt, hidden, hidden)
  40. else:
  41. return (tgt, hidden)
  42. def get_input_names(self):
  43. if self.rnn_type == "LSTM":
  44. return ["x", "in_hidden1", "in_hidden2"]
  45. else:
  46. return ["x", "in_hidden1"]
  47. def get_output_names(self):
  48. if self.rnn_type == "LSTM":
  49. return ["y", "out_hidden1", "out_hidden2"]
  50. else:
  51. return ["y", "out_hidden1"]
  52. def get_dynamic_axes(self):
  53. ret = {
  54. "x": {0: "x_batch", 1: "x_length"},
  55. "y": {0: "y_batch"},
  56. "in_hidden1": {1: "hidden1_batch"},
  57. "out_hidden1": {1: "out_hidden1_batch"},
  58. }
  59. if self.rnn_type == "LSTM":
  60. ret.update(
  61. {
  62. "in_hidden2": {1: "hidden2_batch"},
  63. "out_hidden2": {1: "out_hidden2_batch"},
  64. }
  65. )
  66. return ret
  67. def get_model_config(self, path):
  68. return {
  69. "use_lm": True,
  70. "model_path": os.path.join(path, f"{self.model_name}.onnx"),
  71. "lm_type": "SequentialRNNLM",
  72. "rnn_type": self.rnn_type,
  73. "nhid": self.nhid,
  74. "nlayers": self.nlayers,
  75. }