# conformer_chunk_encoder.py
  1. """Conformer encoder definition."""
  2. import logging
  3. from typing import Union, Dict, List, Tuple, Optional
  4. import torch
  5. from torch import nn
  6. from funasr.models.bat.attention import (
  7. RelPositionMultiHeadedAttentionChunk,
  8. )
  9. from funasr.models.transformer.embedding import (
  10. StreamingRelPositionalEncoding,
  11. )
  12. from funasr.models.transformer.layer_norm import LayerNorm
  13. from funasr.models.transformer.utils.nets_utils import get_activation
  14. from funasr.models.transformer.utils.nets_utils import (
  15. TooShortUttError,
  16. check_short_utt,
  17. make_chunk_mask,
  18. make_source_mask,
  19. )
  20. from funasr.models.transformer.positionwise_feed_forward import (
  21. PositionwiseFeedForward,
  22. )
  23. from funasr.models.transformer.utils.repeat import repeat, MultiBlocks
  24. from funasr.models.transformer.utils.subsampling import TooShortUttError
  25. from funasr.models.transformer.utils.subsampling import check_short_utt
  26. from funasr.models.transformer.utils.subsampling import StreamingConvInput
  27. from funasr.register import tables
class ChunkEncoderLayer(nn.Module):
    """Chunk Conformer module definition.

    One Conformer block in the macaron arrangement:
    FFN/2 -> self-attention -> convolution -> FFN/2, each sub-module wrapped
    with a pre-norm and a residual connection, followed by a final norm.
    Supports a chunk-masked full-sequence forward() and a cached streaming
    chunk_forward().

    Args:
        block_size: Input/output size.
        self_att: Self-attention module instance.
        feed_forward: Feed-forward module instance.
        feed_forward_macaron: Feed-forward module instance for macaron network.
        conv_mod: Convolution module instance.
        norm_class: Normalization module class.
        norm_args: Normalization module arguments.
        dropout_rate: Dropout rate.
    """

    def __init__(
        self,
        block_size: int,
        self_att: torch.nn.Module,
        feed_forward: torch.nn.Module,
        feed_forward_macaron: torch.nn.Module,
        conv_mod: torch.nn.Module,
        norm_class: torch.nn.Module = LayerNorm,
        norm_args: Dict = {},
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a Conformer object."""
        super().__init__()

        self.self_att = self_att
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        # Macaron style: each of the two feed-forward modules contributes half.
        self.feed_forward_scale = 0.5

        self.conv_mod = conv_mod

        # One pre-norm per sub-module plus a final post-norm.
        self.norm_feed_forward = norm_class(block_size, **norm_args)
        self.norm_self_att = norm_class(block_size, **norm_args)
        self.norm_macaron = norm_class(block_size, **norm_args)
        self.norm_conv = norm_class(block_size, **norm_args)
        self.norm_final = norm_class(block_size, **norm_args)

        self.dropout = torch.nn.Dropout(dropout_rate)

        self.block_size = block_size
        # Streaming state: [attention cache, convolution cache]. Populated by
        # reset_streaming_cache() and refreshed by chunk_forward().
        self.cache = None

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset self-attention and convolution modules cache for streaming.

        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensor.
        """
        self.cache = [
            # Attention cache: last `left_context` frames. (1, left_context, block_size)
            torch.zeros(
                (1, left_context, self.block_size),
                device=device,
            ),
            # Convolution cache: left history for the causal depthwise conv.
            # (1, block_size, kernel_size - 1)
            torch.zeros(
                (
                    1,
                    self.block_size,
                    self.conv_mod.kernel_size - 1,
                ),
                device=device,
            ),
        ]

    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode input sequences.

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)

        Returns:
            x: Conformer output sequences. (B, T, D_block)
            mask: Source mask. (B, T)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        # Macaron feed-forward with half-scaled residual.
        residual = x
        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.dropout(
            self.feed_forward_macaron(x)
        )

        # Self-attention; chunk_mask (when given) restricts attention context.
        residual = x
        x = self.norm_self_att(x)
        x_q = x
        x = residual + self.dropout(
            self.self_att(
                x_q,
                x,
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )
        )

        # Convolution module; its cache output is irrelevant in full-sequence mode.
        residual = x
        x = self.norm_conv(x)
        x, _ = self.conv_mod(x)
        x = residual + self.dropout(x)

        # Second half-scaled feed-forward.
        residual = x
        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))

        x = self.norm_final(x)

        return x, mask, pos_enc

    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 0,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode chunk of input sequence.

        Same computation as forward() but consumes/updates the streaming
        caches set by reset_streaming_cache(). Dropout is omitted here
        (presumably an inference-only path — confirm).

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T_2)
            chunk_size: Number of frames per chunk (unused in this layer).
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: Conformer output sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        # Macaron feed-forward (no dropout on the streaming path).
        residual = x
        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)

        # Self-attention over [cached left context | current chunk].
        residual = x
        x = self.norm_self_att(x)

        if left_context > 0:
            key = torch.cat([self.cache[0], x], dim=1)
        else:
            key = x
        val = key

        # Save the frames that will serve as left context for the next chunk,
        # excluding any look-ahead (right_context) frames.
        # NOTE(review): with left_context == 0 the slice below is key[:, -0:, :],
        # i.e. the whole key; presumably left_context > 0 whenever caching is
        # active — confirm against the streaming caller.
        if right_context > 0:
            att_cache = key[:, -(left_context + right_context) : -right_context, :]
        else:
            att_cache = key[:, -left_context:, :]

        x = residual + self.self_att(
            x,
            key,
            val,
            pos_enc,
            mask,
            left_context=left_context,
        )

        # Convolution module with its own rolling cache.
        residual = x
        x = self.norm_conv(x)
        x, conv_cache = self.conv_mod(
            x, cache=self.cache[1], right_context=right_context
        )
        x = residual + x

        # Second half-scaled feed-forward.
        residual = x
        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.feed_forward(x)

        x = self.norm_final(x)

        # Persist both caches for the next chunk.
        self.cache = [att_cache, conv_cache]

        return x, pos_enc
  185. class CausalConvolution(nn.Module):
  186. """ConformerConvolution module definition.
  187. Args:
  188. channels: The number of channels.
  189. kernel_size: Size of the convolving kernel.
  190. activation: Type of activation function.
  191. norm_args: Normalization module arguments.
  192. causal: Whether to use causal convolution (set to True if streaming).
  193. """
  194. def __init__(
  195. self,
  196. channels: int,
  197. kernel_size: int,
  198. activation: torch.nn.Module = torch.nn.ReLU(),
  199. norm_args: Dict = {},
  200. causal: bool = False,
  201. ) -> None:
  202. """Construct an ConformerConvolution object."""
  203. super().__init__()
  204. assert (kernel_size - 1) % 2 == 0
  205. self.kernel_size = kernel_size
  206. self.pointwise_conv1 = torch.nn.Conv1d(
  207. channels,
  208. 2 * channels,
  209. kernel_size=1,
  210. stride=1,
  211. padding=0,
  212. )
  213. if causal:
  214. self.lorder = kernel_size - 1
  215. padding = 0
  216. else:
  217. self.lorder = 0
  218. padding = (kernel_size - 1) // 2
  219. self.depthwise_conv = torch.nn.Conv1d(
  220. channels,
  221. channels,
  222. kernel_size,
  223. stride=1,
  224. padding=padding,
  225. groups=channels,
  226. )
  227. self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
  228. self.pointwise_conv2 = torch.nn.Conv1d(
  229. channels,
  230. channels,
  231. kernel_size=1,
  232. stride=1,
  233. padding=0,
  234. )
  235. self.activation = activation
  236. def forward(
  237. self,
  238. x: torch.Tensor,
  239. cache: Optional[torch.Tensor] = None,
  240. right_context: int = 0,
  241. ) -> Tuple[torch.Tensor, torch.Tensor]:
  242. """Compute convolution module.
  243. Args:
  244. x: ConformerConvolution input sequences. (B, T, D_hidden)
  245. cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden)
  246. right_context: Number of frames in right context.
  247. Returns:
  248. x: ConformerConvolution output sequences. (B, T, D_hidden)
  249. cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden)
  250. """
  251. x = self.pointwise_conv1(x.transpose(1, 2))
  252. x = torch.nn.functional.glu(x, dim=1)
  253. if self.lorder > 0:
  254. if cache is None:
  255. x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
  256. else:
  257. x = torch.cat([cache, x], dim=2)
  258. if right_context > 0:
  259. cache = x[:, :, -(self.lorder + right_context) : -right_context]
  260. else:
  261. cache = x[:, :, -self.lorder :]
  262. x = self.depthwise_conv(x)
  263. x = self.activation(self.norm(x))
  264. x = self.pointwise_conv2(x).transpose(1, 2)
  265. return x, cache
@tables.register("encoder_classes", "ConformerChunkEncoder")
class ConformerChunkEncoder(nn.Module):
    """Conformer encoder built from ChunkEncoderLayer blocks.

    Supports four operating modes:
      - full-utterance encoding (default),
      - dynamic-chunk training (random chunk size per batch),
      - unified training (full-context and chunk-masked passes share weights),
      - cached chunk-by-chunk streaming inference (chunk_forward()).

    Args:
        input_size: Input feature size.
        output_size: Encoder block (and output) size.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        embed_vgg_like: bool = False,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        norm_type: str = "layer_norm",
        cnn_module_kernel: int = 31,
        conv_mod_norm_eps: float = 0.00001,
        conv_mod_norm_momentum: float = 0.1,
        simplified_att_score: bool = False,
        dynamic_chunk_training: bool = False,
        short_chunk_threshold: float = 0.75,
        short_chunk_size: int = 25,
        left_chunk_size: int = 0,
        time_reduction_factor: int = 1,
        unified_model_training: bool = False,
        default_chunk_size: int = 16,
        jitter_range: int = 4,
        subsampling_factor: int = 1,
    ) -> None:
        """Construct an Encoder object.

        Args:
            input_size: Input feature dimension.
            output_size: Encoder block/output dimension.
            attention_heads: Number of self-attention heads.
            linear_units: Hidden units of the position-wise feed-forward modules.
            num_blocks: Number of ChunkEncoderLayer blocks.
            dropout_rate: Dropout rate inside each encoder block.
            positional_dropout_rate: Dropout rate of the positional encoding
                (also forwarded as the feed-forward dropout — see NOTE below).
            attention_dropout_rate: Dropout rate of self-attention.
            embed_vgg_like: Use a VGG-like convolutional input frontend.
            cnn_module_kernel: Kernel size of the convolution modules.
            conv_mod_norm_eps: BatchNorm eps for the convolution modules.
            conv_mod_norm_momentum: BatchNorm momentum for the convolution modules.
            simplified_att_score: Use the simplified attention score variant.
            dynamic_chunk_training: Train with a random chunk size per batch.
            short_chunk_threshold: Fraction of max length above which the
                sampled chunk size falls back to full context.
            short_chunk_size: Modulus used when resampling a short chunk size.
            left_chunk_size: Number of left chunks visible in the chunk mask.
            time_reduction_factor: Output frame subsampling factor.
            unified_model_training: Run both full-context and chunked passes.
            default_chunk_size: Chunk size at inference / unified base size.
            jitter_range: +/- jitter on default_chunk_size in unified training.
            subsampling_factor: Frontend subsampling factor.

        NOTE(review): normalize_before, concat_after, positionwise_layer_type,
        positionwise_conv_kernel_size, macaron_style, rel_pos_type,
        pos_enc_layer_type, selfattention_layer_type, use_cnn_module,
        zero_triu, and norm_type are accepted (config compatibility?) but
        never referenced in this implementation.
        """
        super().__init__()

        # Convolutional subsampling frontend.
        # NOTE(review): output_size is passed both as the 2nd positional
        # argument and as a keyword; the 2nd positional is presumably the
        # conv channel size — confirm against StreamingConvInput's signature.
        self.embed = StreamingConvInput(
            input_size,
            output_size,
            subsampling_factor,
            vgg_like=embed_vgg_like,
            output_size=output_size,
        )

        self.pos_enc = StreamingRelPositionalEncoding(
            output_size,
            positional_dropout_rate,
        )

        activation = get_activation(
            activation_type
        )

        # Arguments for both PositionwiseFeedForward modules per block.
        # NOTE(review): positional_dropout_rate (not dropout_rate) is used as
        # the feed-forward dropout here — confirm this is intentional.
        pos_wise_args = (
            output_size,
            linear_units,
            positional_dropout_rate,
            activation,
        )

        conv_mod_norm_args = {
            "eps": conv_mod_norm_eps,
            "momentum": conv_mod_norm_momentum,
        }

        # Causal convolution is enabled whenever any chunked training is on.
        conv_mod_args = (
            output_size,
            cnn_module_kernel,
            activation,
            conv_mod_norm_args,
            dynamic_chunk_training or unified_model_training,
        )

        mult_att_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            simplified_att_score,
        )

        # One factory per block; each call builds fresh sub-modules, so no
        # parameters are shared between blocks.
        fn_modules = []
        for _ in range(num_blocks):
            module = lambda: ChunkEncoderLayer(
                output_size,
                RelPositionMultiHeadedAttentionChunk(*mult_att_args),
                PositionwiseFeedForward(*pos_wise_args),
                PositionwiseFeedForward(*pos_wise_args),
                CausalConvolution(*conv_mod_args),
                dropout_rate=dropout_rate,
            )
            fn_modules.append(module)

        self.encoders = MultiBlocks(
            [fn() for fn in fn_modules],
            output_size,
        )

        self._output_size = output_size

        self.dynamic_chunk_training = dynamic_chunk_training
        self.short_chunk_threshold = short_chunk_threshold
        self.short_chunk_size = short_chunk_size
        self.left_chunk_size = left_chunk_size

        self.unified_model_training = unified_model_training
        self.default_chunk_size = default_chunk_size
        self.jitter_range = jitter_range

        self.time_reduction_factor = time_reduction_factor

    def output_size(self) -> int:
        """Return the encoder output dimension."""
        return self._output_size

    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
        """Return the corresponding number of sample for a given chunk size, in frames.

        Where size is the number of features frames after applying subsampling.

        Args:
            size: Number of frames after subsampling.
            hop_length: Frontend's hop length

        Returns:
            : Number of raw samples
        """
        return self.embed.get_size_before_subsampling(size) * hop_length

    def get_encoder_input_size(self, size: int) -> int:
        """Return the corresponding number of frames before subsampling.

        Where size is the number of features frames after applying subsampling.

        Args:
            size: Number of frames after subsampling.

        Returns:
            : Number of frames before subsampling.
        """
        return self.embed.get_size_before_subsampling(size)

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset encoder streaming cache.

        Args:
            left_context: Number of frames in left context.
            device: Device ID.
        """
        return self.encoders.reset_streaming_cache(left_context, device)

    def forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode input sequences.

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)

        Returns:
            In unified-training mode: (x_utt, x_chunk, olens) — full-context
            outputs, chunk-masked outputs, and output lengths.
            Otherwise: (x, olens, None) — encoder outputs, output lengths,
            and a None placeholder.
            NOTE(review): the two branches return differently ordered tuples;
            callers must dispatch on the configured mode.
        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )

        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )

        mask = make_source_mask(x_len).to(x.device)

        if self.unified_model_training:
            # Unified mode: one jittered chunk size per call during training,
            # the fixed default size otherwise.
            if self.training:
                chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
            else:
                chunk_size = self.default_chunk_size

            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )

            # Two passes over the same encoders: full-context and chunk-masked.
            x_utt = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=None,
            )
            x_chunk = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )

            # Valid output lengths: mask is True on padding positions.
            olens = mask.eq(0).sum(1)
            if self.time_reduction_factor > 1:
                x_utt = x_utt[:,::self.time_reduction_factor,:]
                x_chunk = x_chunk[:,::self.time_reduction_factor,:]
                olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1

            return x_utt, x_chunk, olens

        elif self.dynamic_chunk_training:
            max_len = x.size(1)
            if self.training:
                # Sample a chunk size; above the threshold use full context,
                # otherwise fold into [1, short_chunk_size].
                chunk_size = torch.randint(1, max_len, (1,)).item()

                if chunk_size > (max_len * self.short_chunk_threshold):
                    chunk_size = max_len
                else:
                    chunk_size = (chunk_size % self.short_chunk_size) + 1
            else:
                chunk_size = self.default_chunk_size

            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)

            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
        else:
            # Plain full-utterance encoding: no chunking at all.
            x, mask = self.embed(x, mask, None)
            pos_enc = self.pos_enc(x)

            chunk_mask = None

        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )

        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:,::self.time_reduction_factor,:]
            olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1

        return x, olens, None

    def full_utt_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> torch.Tensor:
        """Encode input sequences with full (unmasked) context.

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)

        Returns:
            x_utt: Encoder outputs. (B, T_out, D_enc)
        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )

        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )

        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, None)
        pos_enc = self.pos_enc(x)
        x_utt = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=None,
        )
        if self.time_reduction_factor > 1:
            x_utt = x_utt[:,::self.time_reduction_factor,:]

        return x_utt

    def simu_chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Simulate chunked encoding over a full utterance via a chunk mask.

        Unlike chunk_forward(), this runs one pass over the whole sequence
        with a chunk attention mask rather than true cached streaming.

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)
            chunk_size: Number of frames per chunk (after subsampling).
            left_context: Unused here — context is set by self.left_chunk_size.
            right_context: Unused here.

        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )

        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )

        mask = make_source_mask(x_len)
        x, mask = self.embed(x, mask, chunk_size)
        pos_enc = self.pos_enc(x)
        chunk_mask = make_chunk_mask(
            x.size(1),
            chunk_size,
            left_chunk_size=self.left_chunk_size,
            device=x.device,
        )
        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )
        # NOTE(review): olens is computed but never returned — dead code?
        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:,::self.time_reduction_factor,:]

        return x

    def chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        processed_frames: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Encode input sequences as chunks.

        True streaming path: consumes the per-layer caches initialized by
        reset_streaming_cache().

        Args:
            x: Encoder input features. (1, T_in, F)
            x_len: Encoder input features lengths. (1,)
            processed_frames: Number of frames already seen.
            chunk_size: Number of frames in the current chunk.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
        """
        mask = make_source_mask(x_len)
        x, mask = self.embed(x, mask, None)

        if left_context > 0:
            # Mark the left-context slots that have no real history yet
            # (early chunks): positions beyond processed_frames are masked.
            processed_mask = (
                torch.arange(left_context, device=x.device)
                .view(1, left_context)
                .flip(1)
            )
            processed_mask = processed_mask >= processed_frames
            mask = torch.cat([processed_mask, mask], dim=1)

        pos_enc = self.pos_enc(x, left_context=left_context)

        x = self.encoders.chunk_forward(
            x,
            pos_enc,
            mask,
            chunk_size=chunk_size,
            left_context=left_context,
            right_context=right_context,
        )

        # Drop look-ahead frames: they are re-encoded with the next chunk.
        if right_context > 0:
            x = x[:, 0:-right_context, :]

        if self.time_reduction_factor > 1:
            x = x[:,::self.time_reduction_factor,:]

        return x