# transformer.py
import os

import torch
import torch.nn as nn

from funasr.modules.vgg2l import VGG2L
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.subsampling import (
    Conv2dSubsampling,
    Conv2dSubsampling6,
    Conv2dSubsampling8,
)
from funasr.export.models.modules.encoder_layer import (
    EncoderLayerConformer as OnnxEncoderLayer,
)
from funasr.export.models.language_models.embed import Embedding
from funasr.export.models.modules.multihead_att import OnnxMultiHeadedAttention
from funasr.export.utils.torch_function import MakePadMask

# AbsExportModel is used as a base class below but was never imported in the
# original file; this module path is an assumption and may need adjusting.
from funasr.export.models.abs_model import AbsExportModel

class TransformerLM(nn.Module, AbsExportModel):
    def __init__(self, model, max_seq_len=512, **kwargs):
        super().__init__()
        self.embed = Embedding(model.embed, max_seq_len)
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        # Replace each multi-head attention module with its ONNX-exportable
        # counterpart, then wrap the whole layer.
        for i, d in enumerate(self.encoder.encoders):
            # d is an EncoderLayer
            if isinstance(d.self_attn, MultiHeadedAttention):
                d.self_attn = OnnxMultiHeadedAttention(d.self_attn)
            self.encoder.encoders[i] = OnnxEncoderLayer(d)
        self.model_name = "transformer_lm"
        self.num_heads = self.encoder.encoders[0].self_attn.h
        self.hidden_size = self.encoder.encoders[0].self_attn.linear_out.out_features

    def prepare_mask(self, mask):
        # Turn a 0/1 keep-mask into an additive attention bias: kept
        # positions become 0.0 and masked positions become -10000.0.
        if len(mask.shape) == 2:
            mask = mask[:, None, None, :]  # (B, T) -> (B, 1, 1, T)
        elif len(mask.shape) == 3:
            mask = mask[:, None, :]  # (B, T1, T2) -> (B, 1, T1, T2)
        mask = 1 - mask
        return mask * -10000.0
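    # Example (illustrative, not from the original source): a keep-mask of
    # [[1, 1, 0]] with shape (1, 3) becomes a (1, 1, 1, 3) bias whose entries
    # are 0, 0, and -10000.0; added to the attention logits before softmax,
    # this drives the padded position's attention weight to zero.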

    def forward(self, y, cache):
        feats_length = torch.ones(y.shape).sum(dim=-1).type(torch.long)
        mask = self.make_pad_mask(feats_length)  # (B, T)
        mask = (y != 0) * mask  # additionally mask positions holding pad id 0
        xs = self.embed(y)
        # Equivalent to forward_one_step of the encoder.
        if isinstance(
            self.encoder.embed,
            (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8, VGG2L),
        ):
            xs, mask = self.encoder.embed(xs, mask)
        else:
            xs = self.encoder.embed(xs)
        new_cache = []
        mask = self.prepare_mask(mask)
        for c, e in zip(cache, self.encoder.encoders):
            xs, mask = e(xs, mask, c)
            new_cache.append(xs)
        if self.encoder.normalize_before:
            xs = self.encoder.after_norm(xs)
        # Score only the final position: logits for the next token.
        h = self.decoder(xs[:, -1])
        return h, new_cache
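    # Usage sketch (assumed, not from the original file): one step of
    # incremental LM scoring, starting from the dummy inputs defined below.
    #
    #   tgt, cache = lm.get_dummy_inputs()   # lm is a TransformerLM instance
    #   logits, cache = lm(tgt, cache)       # logits: (B, vocab) next-token scores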

    def get_dummy_inputs(self):
        tgt = torch.LongTensor([1]).unsqueeze(0)  # (1, 1) token-id tensor
        cache = [
            torch.zeros((1, 1, self.encoder.encoders[0].size))
            for _ in range(len(self.encoder.encoders))
        ]
        return (tgt, cache)
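    # Note (illustrative): for an N-layer encoder these dummy inputs line up
    # one-to-one with get_input_names() below, i.e. "tgt" plus one
    # (1, 1, odim) zero tensor per "cache_i".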

    def is_optimizable(self):
        return True

    def get_input_names(self):
        return ["tgt"] + ["cache_%d" % i for i in range(len(self.encoder.encoders))]

    def get_output_names(self):
        return ["y"] + ["out_cache_%d" % i for i in range(len(self.encoder.encoders))]

    def get_dynamic_axes(self):
        # Mark batch and length dimensions as dynamic so the exported ONNX
        # graph accepts variable batch sizes and sequence lengths.
        ret = {"tgt": {0: "tgt_batch", 1: "tgt_length"}}
        ret.update(
            {
                "cache_%d" % d: {0: "cache_%d_batch" % d, 1: "cache_%d_length" % d}
                for d in range(len(self.encoder.encoders))
            }
        )
        ret.update(
            {
                "out_cache_%d" % d: {
                    0: "out_cache_%d_batch" % d,
                    1: "out_cache_%d_length" % d,
                }
                for d in range(len(self.encoder.encoders))
            }
        )
        return ret

    def get_model_config(self, path):
        return {
            "use_lm": True,
            "model_path": os.path.join(path, f"{self.model_name}.onnx"),
            "lm_type": "TransformerLM",
            "odim": self.encoder.encoders[0].size,
            "nlayers": len(self.encoder.encoders),
        }
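
# Export sketch (assumed usage, not part of the original file): the hooks
# above map directly onto torch.onnx.export keyword arguments. `lm_model` and
# `output_dir` are placeholders for a loaded FunASR LM and a target directory.
#
#   m = TransformerLM(lm_model)
#   torch.onnx.export(
#       m,
#       m.get_dummy_inputs(),
#       os.path.join(output_dir, f"{m.model_name}.onnx"),
#       input_names=m.get_input_names(),
#       output_names=m.get_output_names(),
#       dynamic_axes=m.get_dynamic_axes(),
#       opset_version=13,
#   )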