#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
from torch import nn
  5. class EncoderLayerSANM(nn.Module):
  6. def __init__(
  7. self,
  8. model,
  9. ):
  10. """Construct an EncoderLayer object."""
  11. super().__init__()
  12. self.self_attn = model.self_attn
  13. self.feed_forward = model.feed_forward
  14. self.norm1 = model.norm1
  15. self.norm2 = model.norm2
  16. self.in_size = model.in_size
  17. self.size = model.size
  18. def forward(self, x, mask):
  19. residual = x
  20. x = self.norm1(x)
  21. x = self.self_attn(x, mask)
  22. # x = x / scale
  23. if self.in_size == self.size:
  24. x = x + residual
  25. residual = x
  26. x = self.norm2(x)
  27. x = self.feed_forward(x)
  28. x = x + residual
  29. return x, mask
  30. class EncoderLayerConformer(nn.Module):
  31. def __init__(
  32. self,
  33. model,
  34. ):
  35. """Construct an EncoderLayer object."""
  36. super().__init__()
  37. self.self_attn = model.self_attn
  38. self.feed_forward = model.feed_forward
  39. self.feed_forward_macaron = model.feed_forward_macaron
  40. self.conv_module = model.conv_module
  41. self.norm_ff = model.norm_ff
  42. self.norm_mha = model.norm_mha
  43. self.norm_ff_macaron = model.norm_ff_macaron
  44. self.norm_conv = model.norm_conv
  45. self.norm_final = model.norm_final
  46. self.size = model.size
  47. def forward(self, x, mask):
  48. if isinstance(x, tuple):
  49. x, pos_emb = x[0], x[1]
  50. else:
  51. x, pos_emb = x, None
  52. if self.feed_forward_macaron is not None:
  53. residual = x
  54. x = self.norm_ff_macaron(x)
  55. x = residual + self.feed_forward_macaron(x) * 0.5
  56. residual = x
  57. x = self.norm_mha(x)
  58. x_q = x
  59. if pos_emb is not None:
  60. x_att = self.self_attn(x_q, x, x, pos_emb, mask)
  61. else:
  62. x_att = self.self_attn(x_q, x, x, mask)
  63. x = residual + x_att
  64. if self.conv_module is not None:
  65. residual = x
  66. x = self.norm_conv(x)
  67. x = residual + self.conv_module(x)
  68. residual = x
  69. x = self.norm_ff(x)
  70. x = residual + self.feed_forward(x) * 0.5
  71. x = self.norm_final(x)
  72. if pos_emb is not None:
  73. return (x, pos_emb), mask
  74. return x, mask