encoder_layer.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import torch
  4. from torch import nn
  5. class EncoderLayerSANM(nn.Module):
  6. def __init__(
  7. self,
  8. model,
  9. ):
  10. """Construct an EncoderLayer object."""
  11. super().__init__()
  12. self.self_attn = model.self_attn
  13. self.feed_forward = model.feed_forward
  14. self.norm1 = model.norm1
  15. self.norm2 = model.norm2
  16. self.in_size = model.in_size
  17. self.size = model.size
  18. def forward(self, x, mask):
  19. residual = x
  20. x = self.norm1(x)
  21. x = self.self_attn(x, mask)
  22. if self.in_size == self.size:
  23. x = x + residual
  24. residual = x
  25. x = self.norm2(x)
  26. x = self.feed_forward(x)
  27. x = x + residual
  28. return x, mask
  29. class EncoderLayerConformer(nn.Module):
  30. def __init__(
  31. self,
  32. model,
  33. ):
  34. """Construct an EncoderLayer object."""
  35. super().__init__()
  36. self.self_attn = model.self_attn
  37. self.feed_forward = model.feed_forward
  38. self.feed_forward_macaron = model.feed_forward_macaron
  39. self.conv_module = model.conv_module
  40. self.norm_ff = model.norm_ff
  41. self.norm_mha = model.norm_mha
  42. self.norm_ff_macaron = model.norm_ff_macaron
  43. self.norm_conv = model.norm_conv
  44. self.norm_final = model.norm_final
  45. self.size = model.size
  46. def forward(self, x, mask):
  47. if isinstance(x, tuple):
  48. x, pos_emb = x[0], x[1]
  49. else:
  50. x, pos_emb = x, None
  51. if self.feed_forward_macaron is not None:
  52. residual = x
  53. x = self.norm_ff_macaron(x)
  54. x = residual + self.feed_forward_macaron(x) * 0.5
  55. residual = x
  56. x = self.norm_mha(x)
  57. x_q = x
  58. if pos_emb is not None:
  59. x_att = self.self_attn(x_q, x, x, pos_emb, mask)
  60. else:
  61. x_att = self.self_attn(x_q, x, x, mask)
  62. x = residual + x_att
  63. if self.conv_module is not None:
  64. residual = x
  65. x = self.norm_conv(x)
  66. x = residual + self.conv_module(x)
  67. residual = x
  68. x = self.norm_ff(x)
  69. x = residual + self.feed_forward(x) * 0.5
  70. x = self.norm_final(x)
  71. if pos_emb is not None:
  72. return (x, pos_emb), mask
  73. return x, mask