decoder_layer.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import torch
  4. from torch import nn
  5. class DecoderLayerSANM(nn.Module):
  6. def __init__(
  7. self,
  8. model
  9. ):
  10. super().__init__()
  11. self.self_attn = model.self_attn
  12. self.src_attn = model.src_attn
  13. self.feed_forward = model.feed_forward
  14. self.norm1 = model.norm1
  15. self.norm2 = model.norm2 if hasattr(model, 'norm2') else None
  16. self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
  17. self.size = model.size
  18. def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
  19. residual = tgt
  20. tgt = self.norm1(tgt)
  21. tgt = self.feed_forward(tgt)
  22. x = tgt
  23. if self.self_attn is not None:
  24. tgt = self.norm2(tgt)
  25. x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
  26. x = residual + x
  27. if self.src_attn is not None:
  28. residual = x
  29. x = self.norm3(x)
  30. x = residual + self.src_attn(x, memory, memory_mask)
  31. return x, tgt_mask, memory, memory_mask, cache