# rnn_encoder.py

from typing import Optional, Sequence, Tuple

import numpy as np
import torch

from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.rnn.encoders import RNN, RNNP


class RNNEncoder(AbsEncoder):
    """RNNEncoder class.

    Args:
        input_size: The number of expected features in the input
        output_size: The number of output features
        hidden_size: The number of hidden features
        rnn_type: Type of recurrent cell, ``lstm`` or ``gru``
        bidirectional: If ``True``, the RNN is bidirectional
        use_projection: Whether to use a projection layer after each RNN layer
        num_layers: Number of recurrent layers
        dropout: Dropout probability
        subsample: Per-layer subsampling factors along the time axis
            (only applied when ``use_projection`` is ``True``)
    """

    def __init__(
        self,
        input_size: int,
        rnn_type: str = "lstm",
        bidirectional: bool = True,
        use_projection: bool = True,
        num_layers: int = 4,
        hidden_size: int = 320,
        output_size: int = 320,
        dropout: float = 0.0,
        subsample: Optional[Sequence[int]] = (2, 2, 1, 1),
    ):
        super().__init__()
        self._output_size = output_size
        self.rnn_type = rnn_type
        self.bidirectional = bidirectional
        self.use_projection = use_projection

        if rnn_type not in {"lstm", "gru"}:
            raise ValueError(f"Not supported rnn_type={rnn_type}")

        if subsample is None:
            subsample = np.ones(num_layers + 1, dtype=np.int32)
        else:
            subsample = subsample[:num_layers]
            # Prepend a 1 and pad with trailing 1s so that
            # len(subsample) == num_layers + 1; the first entry is a
            # placeholder, since only the second and later entries are used.
            subsample = np.pad(
                np.array(subsample, dtype=np.int32),
                [1, num_layers - len(subsample)],
                mode="constant",
                constant_values=1,
            )
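            # Worked example (illustrative, not in the original source):
            # with the defaults num_layers=4 and subsample=(2, 2, 1, 1),
            # the padded array is [1, 2, 2, 1, 1], so the first two RNN
            # layers each halve the time axis for a 4x overall reduction.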

        # RNN/RNNP expect the direction encoded in the type string, e.g. "blstm".
        rnn_type = ("b" if bidirectional else "") + rnn_type
        if use_projection:
            self.enc = torch.nn.ModuleList(
                [
                    RNNP(
                        input_size,
                        num_layers,
                        hidden_size,
                        output_size,
                        subsample,
                        dropout,
                        typ=rnn_type,
                    )
                ]
            )
        else:
            self.enc = torch.nn.ModuleList(
                [
                    RNN(
                        input_size,
                        num_layers,
                        hidden_size,
                        output_size,
                        dropout,
                        typ=rnn_type,
                    )
                ]
            )

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if prev_states is None:
            prev_states = [None] * len(self.enc)
        assert len(prev_states) == len(self.enc)

        current_states = []
        for module, prev_state in zip(self.enc, prev_states):
            xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
            current_states.append(states)

        # Zero out the padded frames in the output.
        if self.use_projection:
            xs_pad.masked_fill_(make_pad_mask(ilens, xs_pad, 1), 0.0)
        else:
            xs_pad = xs_pad.masked_fill(make_pad_mask(ilens, xs_pad, 1), 0.0)

        return xs_pad, ilens, current_states
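

# Usage sketch (illustrative, not part of the original file): a minimal
# smoke test assuming funasr is installed and the imports above resolve.
# Shapes follow the (batch, time, feature) convention used by make_pad_mask.
if __name__ == "__main__":
    encoder = RNNEncoder(input_size=80)  # defaults: 4-layer BLSTM with projection
    xs_pad = torch.randn(2, 100, 80)  # two utterances, padded to 100 frames
    ilens = torch.tensor([100, 60])  # true lengths before padding
    out, olens, _ = encoder(xs_pad, ilens)
    # With subsample=(2, 2, 1, 1) the time axis shrinks roughly 4x,
    # e.g. torch.Size([2, 25, 320]) with lengths around [25, 15].
    print(out.shape, olens)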