#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# feedforward.py
import torch
import torch.nn as nn
  5. class PositionwiseFeedForward(nn.Module):
  6. def __init__(self, model):
  7. super().__init__()
  8. self.w_1 = model.w_1
  9. self.w_2 = model.w_2
  10. self.activation = model.activation
  11. def forward(self, x):
  12. x = self.activation(self.w_1(x))
  13. x = self.w_2(x)
  14. return x
  15. class PositionwiseFeedForwardDecoderSANM(nn.Module):
  16. def __init__(self, model):
  17. super().__init__()
  18. self.w_1 = model.w_1
  19. self.w_2 = model.w_2
  20. self.activation = model.activation
  21. self.norm = model.norm
  22. def forward(self, x):
  23. x = self.activation(self.w_1(x))
  24. x = self.w_2(self.norm(x))
  25. return x