encoder_layer.py 799 B

12345678910111213141516171819202122232425262728293031323334353637
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import torch
  4. from torch import nn
  class EncoderLayerSANM(nn.Module):
      """Encoder layer that re-hosts the sub-modules of an existing layer.

      ``model`` must expose ``self_attn``, ``feed_forward``, ``norm1``,
      ``norm2`` and ``size``. Each is attached to this module by reference
      (nothing is cloned), so parameters are shared with ``model``.
      """

      def __init__(
          self,
          model,
      ):
          """Adopt ``model``'s attention, feed-forward, norms and ``size``."""
          super().__init__()
          # Assign in the original fixed order: setattr routes through
          # nn.Module.__setattr__, so nn.Module values are registered as
          # children exactly as direct attribute assignment would do.
          for attr_name in ("self_attn", "feed_forward", "norm1", "norm2", "size"):
              setattr(self, attr_name, getattr(model, attr_name))

      @staticmethod
      def _add_if_same_width(out, skip):
          """Return ``out + skip`` when both match along dim 2, else ``out``."""
          if out.size(2) != skip.size(2):
              return out
          return out + skip

      def forward(self, x, mask):
          """Apply attention, then feed-forward, each as a pre-norm residual block.

          Args:
              x: input tensor with at least 3 dims (dim 2 is inspected);
                 presumably (batch, time, feature) — TODO confirm with callers.
              mask: forwarded to ``self_attn`` and returned unchanged.

          Returns:
              Tuple ``(output, mask)``.
          """
          # Each sub-layer sees the normalized input, but the residual uses the
          # un-normalized one. The skip connection is dropped whenever the
          # sub-layer changes the dim-2 width, since the shapes could not be summed.
          attn_out = self.self_attn(self.norm1(x), mask)
          hidden = self._add_if_same_width(attn_out, x)
          ff_out = self.feed_forward(self.norm2(hidden))
          return self._add_if_same_width(ff_out, hidden), mask