joint_network.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import torch
  6. from funasr.register import tables
  7. from funasr.models.transformer.utils.nets_utils import get_activation
  8. @tables.register("joint_network_classes", "joint_network")
  9. class JointNetwork(torch.nn.Module):
  10. """Transducer joint network module.
  11. Args:
  12. output_size: Output size.
  13. encoder_size: Encoder output size.
  14. decoder_size: Decoder output size..
  15. joint_space_size: Joint space size.
  16. joint_act_type: Type of activation for joint network.
  17. **activation_parameters: Parameters for the activation function.
  18. """
  19. def __init__(
  20. self,
  21. output_size: int,
  22. encoder_size: int,
  23. decoder_size: int,
  24. joint_space_size: int = 256,
  25. joint_activation_type: str = "tanh",
  26. ) -> None:
  27. """Construct a JointNetwork object."""
  28. super().__init__()
  29. self.lin_enc = torch.nn.Linear(encoder_size, joint_space_size)
  30. self.lin_dec = torch.nn.Linear(decoder_size, joint_space_size, bias=False)
  31. self.lin_out = torch.nn.Linear(joint_space_size, output_size)
  32. self.joint_activation = get_activation(
  33. joint_activation_type
  34. )
  35. def forward(
  36. self,
  37. enc_out: torch.Tensor,
  38. dec_out: torch.Tensor,
  39. project_input: bool = True,
  40. ) -> torch.Tensor:
  41. """Joint computation of encoder and decoder hidden state sequences.
  42. Args:
  43. enc_out: Expanded encoder output state sequences (B, T, 1, D_enc)
  44. dec_out: Expanded decoder output state sequences (B, 1, U, D_dec)
  45. Returns:
  46. joint_out: Joint output state sequences. (B, T, U, D_out)
  47. """
  48. if project_input:
  49. joint_out = self.joint_activation(self.lin_enc(enc_out) + self.lin_dec(dec_out))
  50. else:
  51. joint_out = self.joint_activation(enc_out + dec_out)
  52. return self.lin_out(joint_out)