# xformer_decoder.py — ONNX-export wrapper for a Transformer-style decoder (funasr).
  1. import os
  2. import torch
  3. import torch.nn as nn
  4. from funasr.modules.attention import MultiHeadedAttention
  5. from funasr.export.models.modules.decoder_layer import DecoderLayer as OnnxDecoderLayer
  6. from funasr.export.models.language_models.embed import Embedding
  7. from funasr.export.models.modules.multihead_att import \
  8. OnnxMultiHeadedAttention
  9. from funasr.export.utils.torch_function import MakePadMask, subsequent_mask
class XformerDecoder(nn.Module):
    """ONNX-export wrapper around a Transformer-style (xformer) decoder.

    Wraps a trained funasr decoder module: replaces its standard
    multi-head attention layers with ONNX-friendly equivalents
    (``OnnxMultiHeadedAttention`` / ``OnnxDecoderLayer``) and exposes the
    metadata the export pipeline needs (dummy inputs, input/output names,
    dynamic axes, model config).
    """

    def __init__(self,
                 model,
                 max_seq_len = 512,
                 model_name = 'decoder',
                 onnx: bool = True,):
        """
        Args:
            model: trained decoder to wrap; must expose ``embed``,
                ``decoders`` (list of DecoderLayer), ``normalize_before``,
                ``after_norm`` and ``output_layer``.
            max_seq_len (int): maximum target length used to size the
                embedding and mask helpers.
            model_name (str): base name for the exported ``.onnx`` file.
            onnx (bool): when True, use ``MakePadMask`` for mask creation.
        """
        super().__init__()
        self.embed = Embedding(model.embed, max_seq_len)
        self.model = model
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            # NOTE(review): subsequent_mask is called as a plain function in
            # forward() with a single size argument; invoking it here with
            # (max_seq_len, flip=False) and storing the *result* as a mask
            # factory looks inconsistent — confirm its signature in
            # funasr.export.utils.torch_function. Also note that
            # self.make_pad_mask is never used in forward() below.
            self.make_pad_mask = subsequent_mask(max_seq_len, flip=False)
        # Cache attention geometry from the first layer (all layers are
        # assumed homogeneous).
        if isinstance(self.model.decoders[0].self_attn, MultiHeadedAttention):
            self.num_heads = self.model.decoders[0].self_attn.h
            self.hidden_size = self.model.decoders[0].self_attn.linear_out.out_features
        # replace multi-head attention module into customized module.
        for i, d in enumerate(self.model.decoders):
            # d is DecoderLayer
            if isinstance(d.self_attn, MultiHeadedAttention):
                d.self_attn = OnnxMultiHeadedAttention(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttention):
                d.src_attn = OnnxMultiHeadedAttention(d.src_attn)
            # Wrap the whole layer so it runs with ONNX-compatible ops.
            self.model.decoders[i] = OnnxDecoderLayer(d)
        self.model_name = model_name

    def prepare_mask(self, mask):
        """Derive the two mask layouts consumed by the ONNX decoder layers.

        Args:
            mask: a 2-D (B, T) or 3-D (B, L, T) 0/1 mask tensor.

        Returns:
            Tuple of:
                mask_3d_btd: input mask with a trailing singleton dim.
                mask_4d_bhlt: additive attention bias — masked positions
                    become -10000.0, kept positions 0.
        """
        # NOTE(review): mask_3d_btd is computed before the rank check, so for
        # a 3-D input it becomes 4-D — confirm downstream layers expect that.
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            # (B, T) -> broadcastable (B, 1, 1, T)
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            # (B, L, T) -> broadcastable (B, 1, L, T)
            mask_4d_bhlt = 1 - mask[:, None, :]
        # Convert the inverted 0/1 mask into a large negative additive bias.
        mask_4d_bhlt = mask_4d_bhlt * -10000.0
        return mask_3d_btd, mask_4d_bhlt

    def forward(self,
                tgt,
                memory,
                cache):
        """Run one decoding pass.

        Args:
            tgt: target token ids; the causal mask is sized by
                ``tgt.size(-1)``.
            memory: encoder output fed to the source attention.
            cache: list of per-layer cache tensors, one per decoder layer.

        Returns:
            Tuple of (y, new_cache) where y is the log-softmax over the
            output layer for the last position, and new_cache is the list
            of per-layer outputs to feed back on the next step.
        """
        mask = subsequent_mask(tgt.size(-1)).unsqueeze(0) # (B, T)
        x = self.embed(tgt)
        mask = self.prepare_mask(mask)
        new_cache = []
        for c, decoder in zip(cache, self.model.decoders):
            x, mask = decoder(x, mask, memory, None, c)
            # NOTE(review): the full layer output (not a trimmed state) is
            # cached — confirm this matches OnnxDecoderLayer's cache contract.
            new_cache.append(x)
        # NOTE(review): drops the first time step before taking the last
        # position — presumably to skip the cached prefix; verify.
        x = x[:, 1:, :]
        if self.model.normalize_before:
            y = self.model.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.model.output_layer is not None:
            y = torch.log_softmax(self.model.output_layer(y), dim=-1)
        return y, new_cache

    def get_dummy_inputs(self, enc_size):
        """Build dummy (tgt, memory, cache) tensors for torch.onnx.export.

        Args:
            enc_size (int): encoder output feature dimension.
        """
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        cache_num = len(self.model.decoders)
        cache = [
            torch.zeros((1, 1, self.model.decoders[0].size))
            for _ in range(cache_num)
        ]
        return (tgt, memory, cache)

    def is_optimizable(self):
        """Whether the exported graph may be run through ONNX optimizers."""
        return True

    def get_input_names(self):
        """ONNX input names: tgt, memory, then one cache per layer."""
        cache_num = len(self.model.decoders)
        return ["tgt", "memory"] + [
            "cache_%d" % i for i in range(cache_num)
        ]

    def get_output_names(self):
        """ONNX output names: y, then one updated cache per layer."""
        cache_num = len(self.model.decoders)
        return ["y"] + ["out_cache_%d" % i for i in range(cache_num)]

    def get_dynamic_axes(self):
        """Dynamic-axis spec for torch.onnx.export (batch/length axes)."""
        ret = {
            "tgt": {0: "tgt_batch", 1: "tgt_length"},
            "memory": {0: "memory_batch", 1: "memory_length"},
        }
        cache_num = len(self.model.decoders)
        ret.update(
            {
                # Axis 2 is the growing time dimension of each cache tensor.
                "cache_%d" % d: {0: "cache_%d_batch" % d, 2: "cache_%d_length" % d}
                for d in range(cache_num)
            }
        )
        return ret

    def get_model_config(self, path):
        """Config dict consumed by the runtime loader.

        Args:
            path: directory where the exported .onnx file will live.
        """
        return {
            "dec_type": "XformerDecoder",
            "model_path": os.path.join(path, f"{self.model_name}.onnx"),
            "n_layers": len(self.model.decoders),
            "odim": self.model.decoders[0].size,
        }