decoder_layer.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import torch
  4. from torch import nn
  5. class DecoderLayerSANM(nn.Module):
  6. def __init__(
  7. self,
  8. model
  9. ):
  10. super().__init__()
  11. self.self_attn = model.self_attn
  12. self.src_attn = model.src_attn
  13. self.feed_forward = model.feed_forward
  14. self.norm1 = model.norm1
  15. self.norm2 = model.norm2 if hasattr(model, 'norm2') else None
  16. self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
  17. self.size = model.size
  18. def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
  19. residual = tgt
  20. tgt = self.norm1(tgt)
  21. tgt = self.feed_forward(tgt)
  22. x = tgt
  23. if self.self_attn is not None:
  24. tgt = self.norm2(tgt)
  25. x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
  26. x = residual + x
  27. if self.src_attn is not None:
  28. residual = x
  29. x = self.norm3(x)
  30. x = residual + self.src_attn(x, memory, memory_mask)
  31. return x, tgt_mask, memory, memory_mask, cache
  32. class DecoderLayer(nn.Module):
  33. def __init__(self, model):
  34. super().__init__()
  35. self.self_attn = model.self_attn
  36. self.src_attn = model.src_attn
  37. self.feed_forward = model.feed_forward
  38. self.norm1 = model.norm1
  39. self.norm2 = model.norm2
  40. self.norm3 = model.norm3
  41. def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
  42. residual = tgt
  43. tgt = self.norm1(tgt)
  44. tgt_q = tgt
  45. tgt_q_mask = tgt_mask
  46. x = residual + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)
  47. residual = x
  48. x = self.norm2(x)
  49. x = residual + self.src_attn(x, memory, memory, memory_mask)
  50. residual = x
  51. x = self.norm3(x)
  52. x = residual + self.feed_forward(x)
  53. return x, tgt_mask, memory, memory_mask