# encoder.py
"""Encoder for Transducer model."""

from typing import Any, Dict, List, Optional, Tuple

import torch
from typeguard import check_argument_types

from funasr.models_transducer.encoder.building import (
    build_body_blocks,
    build_input_block,
    build_main_parameters,
    build_positional_encoding,
)
from funasr.models_transducer.encoder.validation import validate_architecture
from funasr.models_transducer.utils import (
    TooShortUttError,
    check_short_utt,
    make_chunk_mask,
    make_source_mask,
)
  18. class Encoder(torch.nn.Module):
  19. """Encoder module definition.
  20. Args:
  21. input_size: Input size.
  22. body_conf: Encoder body configuration.
  23. input_conf: Encoder input configuration.
  24. main_conf: Encoder main configuration.
  25. """
  26. def __init__(
  27. self,
  28. input_size: int,
  29. body_conf: List[Dict[str, Any]],
  30. input_conf: Dict[str, Any] = {},
  31. main_conf: Dict[str, Any] = {},
  32. ) -> None:
  33. """Construct an Encoder object."""
  34. super().__init__()
  35. assert check_argument_types()
  36. embed_size, output_size = validate_architecture(
  37. input_conf, body_conf, input_size
  38. )
  39. main_params = build_main_parameters(**main_conf)
  40. self.embed = build_input_block(input_size, input_conf)
  41. self.pos_enc = build_positional_encoding(embed_size, main_params)
  42. self.encoders = build_body_blocks(body_conf, main_params, output_size)
  43. self.output_size = output_size
  44. self.dynamic_chunk_training = main_params["dynamic_chunk_training"]
  45. self.short_chunk_threshold = main_params["short_chunk_threshold"]
  46. self.short_chunk_size = main_params["short_chunk_size"]
  47. self.left_chunk_size = main_params["left_chunk_size"]
  48. self.unified_model_training = main_params["unified_model_training"]
  49. self.default_chunk_size = main_params["default_chunk_size"]
  50. self.jitter_range = main_params["jitter_range"]
  51. self.time_reduction_factor = main_params["time_reduction_factor"]
  52. def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
  53. """Return the corresponding number of sample for a given chunk size, in frames.
  54. Where size is the number of features frames after applying subsampling.
  55. Args:
  56. size: Number of frames after subsampling.
  57. hop_length: Frontend's hop length
  58. Returns:
  59. : Number of raw samples
  60. """
  61. return self.embed.get_size_before_subsampling(size) * hop_length
  62. def get_encoder_input_size(self, size: int) -> int:
  63. """Return the corresponding number of sample for a given chunk size, in frames.
  64. Where size is the number of features frames after applying subsampling.
  65. Args:
  66. size: Number of frames after subsampling.
  67. Returns:
  68. : Number of raw samples
  69. """
  70. return self.embed.get_size_before_subsampling(size)
  71. def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
  72. """Initialize/Reset encoder streaming cache.
  73. Args:
  74. left_context: Number of frames in left context.
  75. device: Device ID.
  76. """
  77. return self.encoders.reset_streaming_cache(left_context, device)
  78. def forward(
  79. self,
  80. x: torch.Tensor,
  81. x_len: torch.Tensor,
  82. ) -> Tuple[torch.Tensor, torch.Tensor]:
  83. """Encode input sequences.
  84. Args:
  85. x: Encoder input features. (B, T_in, F)
  86. x_len: Encoder input features lengths. (B,)
  87. Returns:
  88. x: Encoder outputs. (B, T_out, D_enc)
  89. x_len: Encoder outputs lenghts. (B,)
  90. """
  91. short_status, limit_size = check_short_utt(
  92. self.embed.subsampling_factor, x.size(1)
  93. )
  94. if short_status:
  95. raise TooShortUttError(
  96. f"has {x.size(1)} frames and is too short for subsampling "
  97. + f"(it needs more than {limit_size} frames), return empty results",
  98. x.size(1),
  99. limit_size,
  100. )
  101. mask = make_source_mask(x_len)
  102. if self.unified_model_training:
  103. chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
  104. x, mask = self.embed(x, mask, chunk_size)
  105. pos_enc = self.pos_enc(x)
  106. chunk_mask = make_chunk_mask(
  107. x.size(1),
  108. chunk_size,
  109. left_chunk_size=self.left_chunk_size,
  110. device=x.device,
  111. )
  112. x_utt = self.encoders(
  113. x,
  114. pos_enc,
  115. mask,
  116. chunk_mask=None,
  117. )
  118. x_chunk = self.encoders(
  119. x,
  120. pos_enc,
  121. mask,
  122. chunk_mask=chunk_mask,
  123. )
  124. olens = mask.eq(0).sum(1)
  125. if self.time_reduction_factor > 1:
  126. x_utt = x_utt[:,::self.time_reduction_factor,:]
  127. x_chunk = x_chunk[:,::self.time_reduction_factor,:]
  128. olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
  129. return x_utt, x_chunk, olens
  130. elif self.dynamic_chunk_training:
  131. max_len = x.size(1)
  132. chunk_size = torch.randint(1, max_len, (1,)).item()
  133. if chunk_size > (max_len * self.short_chunk_threshold):
  134. chunk_size = max_len
  135. else:
  136. chunk_size = (chunk_size % self.short_chunk_size) + 1
  137. x, mask = self.embed(x, mask, chunk_size)
  138. pos_enc = self.pos_enc(x)
  139. chunk_mask = make_chunk_mask(
  140. x.size(1),
  141. chunk_size,
  142. left_chunk_size=self.left_chunk_size,
  143. device=x.device,
  144. )
  145. else:
  146. x, mask = self.embed(x, mask, None)
  147. pos_enc = self.pos_enc(x)
  148. chunk_mask = None
  149. x = self.encoders(
  150. x,
  151. pos_enc,
  152. mask,
  153. chunk_mask=chunk_mask,
  154. )
  155. olens = mask.eq(0).sum(1)
  156. if self.time_reduction_factor > 1:
  157. x = x[:,::self.time_reduction_factor,:]
  158. olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
  159. return x, olens
  160. def simu_chunk_forward(
  161. self,
  162. x: torch.Tensor,
  163. x_len: torch.Tensor,
  164. chunk_size: int = 16,
  165. left_context: int = 32,
  166. right_context: int = 0,
  167. ) -> torch.Tensor:
  168. short_status, limit_size = check_short_utt(
  169. self.embed.subsampling_factor, x.size(1)
  170. )
  171. if short_status:
  172. raise TooShortUttError(
  173. f"has {x.size(1)} frames and is too short for subsampling "
  174. + f"(it needs more than {limit_size} frames), return empty results",
  175. x.size(1),
  176. limit_size,
  177. )
  178. mask = make_source_mask(x_len)
  179. x, mask = self.embed(x, mask, chunk_size)
  180. pos_enc = self.pos_enc(x)
  181. chunk_mask = make_chunk_mask(
  182. x.size(1),
  183. chunk_size,
  184. left_chunk_size=self.left_chunk_size,
  185. device=x.device,
  186. )
  187. x = self.encoders(
  188. x,
  189. pos_enc,
  190. mask,
  191. chunk_mask=chunk_mask,
  192. )
  193. olens = mask.eq(0).sum(1)
  194. if self.time_reduction_factor > 1:
  195. x = x[:,::self.time_reduction_factor,:]
  196. return x
  197. def chunk_forward(
  198. self,
  199. x: torch.Tensor,
  200. x_len: torch.Tensor,
  201. processed_frames: torch.tensor,
  202. chunk_size: int = 16,
  203. left_context: int = 32,
  204. right_context: int = 0,
  205. ) -> torch.Tensor:
  206. """Encode input sequences as chunks.
  207. Args:
  208. x: Encoder input features. (1, T_in, F)
  209. x_len: Encoder input features lengths. (1,)
  210. processed_frames: Number of frames already seen.
  211. left_context: Number of frames in left context.
  212. right_context: Number of frames in right context.
  213. Returns:
  214. x: Encoder outputs. (B, T_out, D_enc)
  215. """
  216. mask = make_source_mask(x_len)
  217. x, mask = self.embed(x, mask, None)
  218. if left_context > 0:
  219. processed_mask = (
  220. torch.arange(left_context, device=x.device)
  221. .view(1, left_context)
  222. .flip(1)
  223. )
  224. processed_mask = processed_mask >= processed_frames
  225. mask = torch.cat([processed_mask, mask], dim=1)
  226. pos_enc = self.pos_enc(x, left_context=left_context)
  227. x = self.encoders.chunk_forward(
  228. x,
  229. pos_enc,
  230. mask,
  231. chunk_size=chunk_size,
  232. left_context=left_context,
  233. right_context=right_context,
  234. )
  235. if right_context > 0:
  236. x = x[:, 0:-right_context, :]
  237. if self.time_reduction_factor > 1:
  238. x = x[:,::self.time_reduction_factor,:]
  239. return x