rnn_encoder.py

from typing import Optional
from typing import Sequence
from typing import Tuple

import numpy as np
import torch
from typeguard import check_argument_types

from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.rnn.encoders import RNN
from funasr.modules.rnn.encoders import RNNP


class RNNEncoder(AbsEncoder):
    """RNNEncoder class.

    Args:
        input_size: The number of expected features in the input
        output_size: The number of output features
        hidden_size: The number of hidden features
        rnn_type: Type of recurrent cell, ``"lstm"`` or ``"gru"``
        bidirectional: If ``True``, use a bidirectional RNN
        use_projection: Whether to use a projection layer after each RNN layer
        num_layers: Number of recurrent layers
        dropout: Dropout probability
        subsample: Per-layer temporal subsampling factors
    """

    def __init__(
        self,
        input_size: int,
        rnn_type: str = "lstm",
        bidirectional: bool = True,
        use_projection: bool = True,
        num_layers: int = 4,
        hidden_size: int = 320,
        output_size: int = 320,
        dropout: float = 0.0,
        subsample: Optional[Sequence[int]] = (2, 2, 1, 1),
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size
        self.rnn_type = rnn_type
        self.bidirectional = bidirectional
        self.use_projection = use_projection

        if rnn_type not in {"lstm", "gru"}:
            raise ValueError(f"Not supported rnn_type={rnn_type}")

        if subsample is None:
            # No subsampling: one factor of 1 per layer plus the input.
            # Note: np.int was removed in NumPy 1.24; use np.int64 instead.
            subsample = np.ones(num_layers + 1, dtype=np.int64)
        else:
            subsample = subsample[:num_layers]
            # Prepend 1 because the first entry corresponds to the raw input;
            # pad with 1s at the end if fewer factors than layers were given.
            subsample = np.pad(
                np.array(subsample, dtype=np.int64),
                [1, num_layers - len(subsample)],
                mode="constant",
                constant_values=1,
            )
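            # Worked example (illustrative): num_layers=4, subsample=(2, 2, 1, 1)
            # -> padded to [1, 2, 2, 1, 1]; the leading 1 is skipped by RNNP,
            # giving a total time reduction of 2 * 2 = 4.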

        # The "b" prefix selects the bidirectional variant inside RNN/RNNP.
        rnn_type = ("b" if bidirectional else "") + rnn_type
        if use_projection:
            self.enc = torch.nn.ModuleList(
                [
                    RNNP(
                        input_size,
                        num_layers,
                        hidden_size,
                        output_size,
                        subsample,
                        dropout,
                        typ=rnn_type,
                    )
                ]
            )
        else:
            self.enc = torch.nn.ModuleList(
                [
                    RNN(
                        input_size,
                        num_layers,
                        hidden_size,
                        output_size,
                        dropout,
                        typ=rnn_type,
                    )
                ]
            )

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if prev_states is None:
            prev_states = [None] * len(self.enc)
        assert len(prev_states) == len(self.enc)

        current_states = []
        for module, prev_state in zip(self.enc, prev_states):
            xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
            current_states.append(states)

        # Zero out padded positions so downstream modules ignore them.
        if self.use_projection:
            xs_pad.masked_fill_(make_pad_mask(ilens, xs_pad, 1), 0.0)
        else:
            xs_pad = xs_pad.masked_fill(make_pad_mask(ilens, xs_pad, 1), 0.0)

        return xs_pad, ilens, current_states
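

# A minimal usage sketch (not part of the original file): the class and its
# defaults above are real, but the feature dimension, batch shapes, and
# sequence lengths below are illustrative assumptions.
if __name__ == "__main__":
    encoder = RNNEncoder(input_size=80)  # e.g. 80-dim filterbank features
    xs_pad = torch.randn(2, 100, 80)     # (batch, time, feature), zero-padded
    ilens = torch.tensor([100, 60])      # true lengths before padding
    ys_pad, olens, _ = encoder(xs_pad, ilens)
    # With the default subsample=(2, 2, 1, 1) the time axis shrinks by 4, so
    # ys_pad should be roughly (2, 25, 320) and olens about [25, 15].
    print(ys_pad.shape, olens)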