encoder.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279
  1. # Copyright 2020 Tomoki Hayashi
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. """Conformer encoder definition."""
  4. import logging
  5. from typing import Union, Dict, List, Tuple, Optional
  6. import torch
  7. from torch import nn
  8. from funasr.models.ctc.ctc import CTC
  9. from funasr.models.transformer.attention import (
  10. MultiHeadedAttention, # noqa: H301
  11. RelPositionMultiHeadedAttention, # noqa: H301
  12. LegacyRelPositionMultiHeadedAttention, # noqa: H301
  13. RelPositionMultiHeadedAttentionChunk,
  14. )
  15. from funasr.models.transformer.embedding import (
  16. PositionalEncoding, # noqa: H301
  17. ScaledPositionalEncoding, # noqa: H301
  18. RelPositionalEncoding, # noqa: H301
  19. LegacyRelPositionalEncoding, # noqa: H301
  20. StreamingRelPositionalEncoding,
  21. )
  22. from funasr.models.transformer.layer_norm import LayerNorm
  23. from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
  24. from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
  25. from funasr.models.transformer.utils.nets_utils import get_activation
  26. from funasr.models.transformer.utils.nets_utils import make_pad_mask
  27. from funasr.models.transformer.utils.nets_utils import (
  28. TooShortUttError,
  29. check_short_utt,
  30. make_chunk_mask,
  31. make_source_mask,
  32. )
  33. from funasr.models.transformer.positionwise_feed_forward import (
  34. PositionwiseFeedForward, # noqa: H301
  35. )
  36. from funasr.models.transformer.utils.repeat import repeat, MultiBlocks
  37. from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
  38. from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
  39. from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
  40. from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
  41. from funasr.models.transformer.utils.subsampling import TooShortUttError
  42. from funasr.models.transformer.utils.subsampling import check_short_utt
  43. from funasr.models.transformer.utils.subsampling import Conv2dSubsamplingPad
  44. from funasr.models.transformer.utils.subsampling import StreamingConvInput
  45. from funasr.register import tables
  46. import pdb
  47. class ConvolutionModule(nn.Module):
  48. """ConvolutionModule in Conformer model.
  49. Args:
  50. channels (int): The number of channels of conv layers.
  51. kernel_size (int): Kernerl size of conv layers.
  52. """
  53. def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
  54. """Construct an ConvolutionModule object."""
  55. super(ConvolutionModule, self).__init__()
  56. # kernerl_size should be a odd number for 'SAME' padding
  57. assert (kernel_size - 1) % 2 == 0
  58. self.pointwise_conv1 = nn.Conv1d(
  59. channels,
  60. 2 * channels,
  61. kernel_size=1,
  62. stride=1,
  63. padding=0,
  64. bias=bias,
  65. )
  66. self.depthwise_conv = nn.Conv1d(
  67. channels,
  68. channels,
  69. kernel_size,
  70. stride=1,
  71. padding=(kernel_size - 1) // 2,
  72. groups=channels,
  73. bias=bias,
  74. )
  75. self.norm = nn.BatchNorm1d(channels)
  76. self.pointwise_conv2 = nn.Conv1d(
  77. channels,
  78. channels,
  79. kernel_size=1,
  80. stride=1,
  81. padding=0,
  82. bias=bias,
  83. )
  84. self.activation = activation
  85. def forward(self, x):
  86. """Compute convolution module.
  87. Args:
  88. x (torch.Tensor): Input tensor (#batch, time, channels).
  89. Returns:
  90. torch.Tensor: Output tensor (#batch, time, channels).
  91. """
  92. # exchange the temporal dimension and the feature dimension
  93. x = x.transpose(1, 2)
  94. # GLU mechanism
  95. x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
  96. x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
  97. # 1D Depthwise Conv
  98. x = self.depthwise_conv(x)
  99. x = self.activation(self.norm(x))
  100. x = self.pointwise_conv2(x)
  101. return x.transpose(1, 2)
class EncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip residual computation and return input
            as-is with given probability.
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = LayerNorm(size)  # for the FNN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            # Macaron style: each of the two feed-forward modules contributes half.
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        # Relative-position attention passes (x, pos_emb) tuples between layers.
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            # Layer is skipped: pass input through unchanged (cache still prepended).
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            if pos_emb is not None:
                return (x, pos_emb), mask
            return x, mask

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
                self.feed_forward_macaron(x)
            )
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        if cache is None:
            x_q = x
        else:
            # Streaming: only the newest frame is used as the attention query.
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if pos_emb is not None:
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)

        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + stoch_layer_coeff * self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
            self.feed_forward(x)
        )
        if not self.normalize_before:
            x = self.norm_ff(x)

        # Extra final normalization is only applied when the conv module exists.
        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        if pos_emb is not None:
            return (x, pos_emb), mask

        return x, mask
  247. @tables.register("encoder_classes", "ConformerEncoder")
  248. class ConformerEncoder(nn.Module):
  249. """Conformer encoder module.
  250. Args:
  251. input_size (int): Input dimension.
  252. output_size (int): Dimension of attention.
  253. attention_heads (int): The number of heads of multi head attention.
  254. linear_units (int): The number of units of position-wise feed forward.
  255. num_blocks (int): The number of decoder blocks.
  256. dropout_rate (float): Dropout rate.
  257. attention_dropout_rate (float): Dropout rate in attention.
  258. positional_dropout_rate (float): Dropout rate after adding positional encoding.
  259. input_layer (Union[str, torch.nn.Module]): Input layer type.
  260. normalize_before (bool): Whether to use layer_norm before the first block.
  261. concat_after (bool): Whether to concat attention layer's input and output.
  262. If True, additional linear will be applied.
  263. i.e. x -> x + linear(concat(x, att(x)))
  264. If False, no additional linear will be applied. i.e. x -> x + att(x)
  265. positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
  266. positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
  267. rel_pos_type (str): Whether to use the latest relative positional encoding or
  268. the legacy one. The legacy relative positional encoding will be deprecated
  269. in the future. More Details can be found in
  270. https://github.com/espnet/espnet/pull/2816.
  271. encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
  272. encoder_attn_layer_type (str): Encoder attention layer type.
  273. activation_type (str): Encoder activation function type.
  274. macaron_style (bool): Whether to use macaron style for positionwise layer.
  275. use_cnn_module (bool): Whether to use convolution module.
  276. zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
  277. cnn_module_kernel (int): Kernerl size of convolution module.
  278. padding_idx (int): Padding idx for input_layer=embed.
  279. """
  280. def __init__(
  281. self,
  282. input_size: int,
  283. output_size: int = 256,
  284. attention_heads: int = 4,
  285. linear_units: int = 2048,
  286. num_blocks: int = 6,
  287. dropout_rate: float = 0.1,
  288. positional_dropout_rate: float = 0.1,
  289. attention_dropout_rate: float = 0.0,
  290. input_layer: str = "conv2d",
  291. normalize_before: bool = True,
  292. concat_after: bool = False,
  293. positionwise_layer_type: str = "linear",
  294. positionwise_conv_kernel_size: int = 3,
  295. macaron_style: bool = False,
  296. rel_pos_type: str = "legacy",
  297. pos_enc_layer_type: str = "rel_pos",
  298. selfattention_layer_type: str = "rel_selfattn",
  299. activation_type: str = "swish",
  300. use_cnn_module: bool = True,
  301. zero_triu: bool = False,
  302. cnn_module_kernel: int = 31,
  303. padding_idx: int = -1,
  304. interctc_layer_idx: List[int] = [],
  305. interctc_use_conditioning: bool = False,
  306. stochastic_depth_rate: Union[float, List[float]] = 0.0,
  307. ):
  308. super().__init__()
  309. self._output_size = output_size
  310. if rel_pos_type == "legacy":
  311. if pos_enc_layer_type == "rel_pos":
  312. pos_enc_layer_type = "legacy_rel_pos"
  313. if selfattention_layer_type == "rel_selfattn":
  314. selfattention_layer_type = "legacy_rel_selfattn"
  315. elif rel_pos_type == "latest":
  316. assert selfattention_layer_type != "legacy_rel_selfattn"
  317. assert pos_enc_layer_type != "legacy_rel_pos"
  318. else:
  319. raise ValueError("unknown rel_pos_type: " + rel_pos_type)
  320. activation = get_activation(activation_type)
  321. if pos_enc_layer_type == "abs_pos":
  322. pos_enc_class = PositionalEncoding
  323. elif pos_enc_layer_type == "scaled_abs_pos":
  324. pos_enc_class = ScaledPositionalEncoding
  325. elif pos_enc_layer_type == "rel_pos":
  326. assert selfattention_layer_type == "rel_selfattn"
  327. pos_enc_class = RelPositionalEncoding
  328. elif pos_enc_layer_type == "legacy_rel_pos":
  329. assert selfattention_layer_type == "legacy_rel_selfattn"
  330. pos_enc_class = LegacyRelPositionalEncoding
  331. logging.warning(
  332. "Using legacy_rel_pos and it will be deprecated in the future."
  333. )
  334. else:
  335. raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
  336. if input_layer == "linear":
  337. self.embed = torch.nn.Sequential(
  338. torch.nn.Linear(input_size, output_size),
  339. torch.nn.LayerNorm(output_size),
  340. torch.nn.Dropout(dropout_rate),
  341. pos_enc_class(output_size, positional_dropout_rate),
  342. )
  343. elif input_layer == "conv2d":
  344. self.embed = Conv2dSubsampling(
  345. input_size,
  346. output_size,
  347. dropout_rate,
  348. pos_enc_class(output_size, positional_dropout_rate),
  349. )
  350. elif input_layer == "conv2dpad":
  351. self.embed = Conv2dSubsamplingPad(
  352. input_size,
  353. output_size,
  354. dropout_rate,
  355. pos_enc_class(output_size, positional_dropout_rate),
  356. )
  357. elif input_layer == "conv2d2":
  358. self.embed = Conv2dSubsampling2(
  359. input_size,
  360. output_size,
  361. dropout_rate,
  362. pos_enc_class(output_size, positional_dropout_rate),
  363. )
  364. elif input_layer == "conv2d6":
  365. self.embed = Conv2dSubsampling6(
  366. input_size,
  367. output_size,
  368. dropout_rate,
  369. pos_enc_class(output_size, positional_dropout_rate),
  370. )
  371. elif input_layer == "conv2d8":
  372. self.embed = Conv2dSubsampling8(
  373. input_size,
  374. output_size,
  375. dropout_rate,
  376. pos_enc_class(output_size, positional_dropout_rate),
  377. )
  378. elif input_layer == "embed":
  379. self.embed = torch.nn.Sequential(
  380. torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
  381. pos_enc_class(output_size, positional_dropout_rate),
  382. )
  383. elif isinstance(input_layer, torch.nn.Module):
  384. self.embed = torch.nn.Sequential(
  385. input_layer,
  386. pos_enc_class(output_size, positional_dropout_rate),
  387. )
  388. elif input_layer is None:
  389. self.embed = torch.nn.Sequential(
  390. pos_enc_class(output_size, positional_dropout_rate)
  391. )
  392. else:
  393. raise ValueError("unknown input_layer: " + input_layer)
  394. self.normalize_before = normalize_before
  395. if positionwise_layer_type == "linear":
  396. positionwise_layer = PositionwiseFeedForward
  397. positionwise_layer_args = (
  398. output_size,
  399. linear_units,
  400. dropout_rate,
  401. activation,
  402. )
  403. elif positionwise_layer_type == "conv1d":
  404. positionwise_layer = MultiLayeredConv1d
  405. positionwise_layer_args = (
  406. output_size,
  407. linear_units,
  408. positionwise_conv_kernel_size,
  409. dropout_rate,
  410. )
  411. elif positionwise_layer_type == "conv1d-linear":
  412. positionwise_layer = Conv1dLinear
  413. positionwise_layer_args = (
  414. output_size,
  415. linear_units,
  416. positionwise_conv_kernel_size,
  417. dropout_rate,
  418. )
  419. else:
  420. raise NotImplementedError("Support only linear or conv1d.")
  421. if selfattention_layer_type == "selfattn":
  422. encoder_selfattn_layer = MultiHeadedAttention
  423. encoder_selfattn_layer_args = (
  424. attention_heads,
  425. output_size,
  426. attention_dropout_rate,
  427. )
  428. elif selfattention_layer_type == "legacy_rel_selfattn":
  429. assert pos_enc_layer_type == "legacy_rel_pos"
  430. encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
  431. encoder_selfattn_layer_args = (
  432. attention_heads,
  433. output_size,
  434. attention_dropout_rate,
  435. )
  436. logging.warning(
  437. "Using legacy_rel_selfattn and it will be deprecated in the future."
  438. )
  439. elif selfattention_layer_type == "rel_selfattn":
  440. assert pos_enc_layer_type == "rel_pos"
  441. encoder_selfattn_layer = RelPositionMultiHeadedAttention
  442. encoder_selfattn_layer_args = (
  443. attention_heads,
  444. output_size,
  445. attention_dropout_rate,
  446. zero_triu,
  447. )
  448. else:
  449. raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
  450. convolution_layer = ConvolutionModule
  451. convolution_layer_args = (output_size, cnn_module_kernel, activation)
  452. if isinstance(stochastic_depth_rate, float):
  453. stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
  454. if len(stochastic_depth_rate) != num_blocks:
  455. raise ValueError(
  456. f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
  457. f"should be equal to num_blocks ({num_blocks})"
  458. )
  459. self.encoders = repeat(
  460. num_blocks,
  461. lambda lnum: EncoderLayer(
  462. output_size,
  463. encoder_selfattn_layer(*encoder_selfattn_layer_args),
  464. positionwise_layer(*positionwise_layer_args),
  465. positionwise_layer(*positionwise_layer_args) if macaron_style else None,
  466. convolution_layer(*convolution_layer_args) if use_cnn_module else None,
  467. dropout_rate,
  468. normalize_before,
  469. concat_after,
  470. stochastic_depth_rate[lnum],
  471. ),
  472. )
  473. if self.normalize_before:
  474. self.after_norm = LayerNorm(output_size)
  475. self.interctc_layer_idx = interctc_layer_idx
  476. if len(interctc_layer_idx) > 0:
  477. assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
  478. self.interctc_use_conditioning = interctc_use_conditioning
  479. self.conditioning_layer = None
  480. def output_size(self) -> int:
  481. return self._output_size
  482. def forward(
  483. self,
  484. xs_pad: torch.Tensor,
  485. ilens: torch.Tensor,
  486. prev_states: torch.Tensor = None,
  487. ctc: CTC = None,
  488. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  489. """Calculate forward propagation.
  490. Args:
  491. xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
  492. ilens (torch.Tensor): Input length (#batch).
  493. prev_states (torch.Tensor): Not to be used now.
  494. Returns:
  495. torch.Tensor: Output tensor (#batch, L, output_size).
  496. torch.Tensor: Output length (#batch).
  497. torch.Tensor: Not to be used now.
  498. """
  499. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  500. if (
  501. isinstance(self.embed, Conv2dSubsampling)
  502. or isinstance(self.embed, Conv2dSubsampling2)
  503. or isinstance(self.embed, Conv2dSubsampling6)
  504. or isinstance(self.embed, Conv2dSubsampling8)
  505. or isinstance(self.embed, Conv2dSubsamplingPad)
  506. ):
  507. short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
  508. if short_status:
  509. raise TooShortUttError(
  510. f"has {xs_pad.size(1)} frames and is too short for subsampling "
  511. + f"(it needs more than {limit_size} frames), return empty results",
  512. xs_pad.size(1),
  513. limit_size,
  514. )
  515. xs_pad, masks = self.embed(xs_pad, masks)
  516. else:
  517. xs_pad = self.embed(xs_pad)
  518. intermediate_outs = []
  519. if len(self.interctc_layer_idx) == 0:
  520. xs_pad, masks = self.encoders(xs_pad, masks)
  521. else:
  522. for layer_idx, encoder_layer in enumerate(self.encoders):
  523. xs_pad, masks = encoder_layer(xs_pad, masks)
  524. if layer_idx + 1 in self.interctc_layer_idx:
  525. encoder_out = xs_pad
  526. if isinstance(encoder_out, tuple):
  527. encoder_out = encoder_out[0]
  528. # intermediate outputs are also normalized
  529. if self.normalize_before:
  530. encoder_out = self.after_norm(encoder_out)
  531. intermediate_outs.append((layer_idx + 1, encoder_out))
  532. if self.interctc_use_conditioning:
  533. ctc_out = ctc.softmax(encoder_out)
  534. if isinstance(xs_pad, tuple):
  535. x, pos_emb = xs_pad
  536. x = x + self.conditioning_layer(ctc_out)
  537. xs_pad = (x, pos_emb)
  538. else:
  539. xs_pad = xs_pad + self.conditioning_layer(ctc_out)
  540. if isinstance(xs_pad, tuple):
  541. xs_pad = xs_pad[0]
  542. if self.normalize_before:
  543. xs_pad = self.after_norm(xs_pad)
  544. olens = masks.squeeze(1).sum(1)
  545. if len(intermediate_outs) > 0:
  546. return (xs_pad, intermediate_outs), olens, None
  547. return xs_pad, olens, None
  548. class CausalConvolution(torch.nn.Module):
  549. """ConformerConvolution module definition.
  550. Args:
  551. channels: The number of channels.
  552. kernel_size: Size of the convolving kernel.
  553. activation: Type of activation function.
  554. norm_args: Normalization module arguments.
  555. causal: Whether to use causal convolution (set to True if streaming).
  556. """
  557. def __init__(
  558. self,
  559. channels: int,
  560. kernel_size: int,
  561. activation: torch.nn.Module = torch.nn.ReLU(),
  562. norm_args: Dict = {},
  563. causal: bool = False,
  564. ) -> None:
  565. """Construct an ConformerConvolution object."""
  566. super().__init__()
  567. assert (kernel_size - 1) % 2 == 0
  568. self.kernel_size = kernel_size
  569. self.pointwise_conv1 = torch.nn.Conv1d(
  570. channels,
  571. 2 * channels,
  572. kernel_size=1,
  573. stride=1,
  574. padding=0,
  575. )
  576. if causal:
  577. self.lorder = kernel_size - 1
  578. padding = 0
  579. else:
  580. self.lorder = 0
  581. padding = (kernel_size - 1) // 2
  582. self.depthwise_conv = torch.nn.Conv1d(
  583. channels,
  584. channels,
  585. kernel_size,
  586. stride=1,
  587. padding=padding,
  588. groups=channels,
  589. )
  590. self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
  591. self.pointwise_conv2 = torch.nn.Conv1d(
  592. channels,
  593. channels,
  594. kernel_size=1,
  595. stride=1,
  596. padding=0,
  597. )
  598. self.activation = activation
  599. def forward(
  600. self,
  601. x: torch.Tensor,
  602. cache: Optional[torch.Tensor] = None,
  603. right_context: int = 0,
  604. ) -> Tuple[torch.Tensor, torch.Tensor]:
  605. """Compute convolution module.
  606. Args:
  607. x: ConformerConvolution input sequences. (B, T, D_hidden)
  608. cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden)
  609. right_context: Number of frames in right context.
  610. Returns:
  611. x: ConformerConvolution output sequences. (B, T, D_hidden)
  612. cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden)
  613. """
  614. x = self.pointwise_conv1(x.transpose(1, 2))
  615. x = torch.nn.functional.glu(x, dim=1)
  616. if self.lorder > 0:
  617. if cache is None:
  618. x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
  619. else:
  620. x = torch.cat([cache, x], dim=2)
  621. if right_context > 0:
  622. cache = x[:, :, -(self.lorder + right_context) : -right_context]
  623. else:
  624. cache = x[:, :, -self.lorder :]
  625. x = self.depthwise_conv(x)
  626. x = self.activation(self.norm(x))
  627. x = self.pointwise_conv2(x).transpose(1, 2)
  628. return x, cache
  629. class ChunkEncoderLayer(torch.nn.Module):
  630. """Chunk Conformer module definition.
  631. Args:
  632. block_size: Input/output size.
  633. self_att: Self-attention module instance.
  634. feed_forward: Feed-forward module instance.
  635. feed_forward_macaron: Feed-forward module instance for macaron network.
  636. conv_mod: Convolution module instance.
  637. norm_class: Normalization module class.
  638. norm_args: Normalization module arguments.
  639. dropout_rate: Dropout rate.
  640. """
  641. def __init__(
  642. self,
  643. block_size: int,
  644. self_att: torch.nn.Module,
  645. feed_forward: torch.nn.Module,
  646. feed_forward_macaron: torch.nn.Module,
  647. conv_mod: torch.nn.Module,
  648. norm_class: torch.nn.Module = LayerNorm,
  649. norm_args: Dict = {},
  650. dropout_rate: float = 0.0,
  651. ) -> None:
  652. """Construct a Conformer object."""
  653. super().__init__()
  654. self.self_att = self_att
  655. self.feed_forward = feed_forward
  656. self.feed_forward_macaron = feed_forward_macaron
  657. self.feed_forward_scale = 0.5
  658. self.conv_mod = conv_mod
  659. self.norm_feed_forward = norm_class(block_size, **norm_args)
  660. self.norm_self_att = norm_class(block_size, **norm_args)
  661. self.norm_macaron = norm_class(block_size, **norm_args)
  662. self.norm_conv = norm_class(block_size, **norm_args)
  663. self.norm_final = norm_class(block_size, **norm_args)
  664. self.dropout = torch.nn.Dropout(dropout_rate)
  665. self.block_size = block_size
  666. self.cache = None
  667. def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
  668. """Initialize/Reset self-attention and convolution modules cache for streaming.
  669. Args:
  670. left_context: Number of left frames during chunk-by-chunk inference.
  671. device: Device to use for cache tensor.
  672. """
  673. self.cache = [
  674. torch.zeros(
  675. (1, left_context, self.block_size),
  676. device=device,
  677. ),
  678. torch.zeros(
  679. (
  680. 1,
  681. self.block_size,
  682. self.conv_mod.kernel_size - 1,
  683. ),
  684. device=device,
  685. ),
  686. ]
  687. def forward(
  688. self,
  689. x: torch.Tensor,
  690. pos_enc: torch.Tensor,
  691. mask: torch.Tensor,
  692. chunk_mask: Optional[torch.Tensor] = None,
  693. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  694. """Encode input sequences.
  695. Args:
  696. x: Conformer input sequences. (B, T, D_block)
  697. pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
  698. mask: Source mask. (B, T)
  699. chunk_mask: Chunk mask. (T_2, T_2)
  700. Returns:
  701. x: Conformer output sequences. (B, T, D_block)
  702. mask: Source mask. (B, T)
  703. pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
  704. """
  705. residual = x
  706. x = self.norm_macaron(x)
  707. x = residual + self.feed_forward_scale * self.dropout(
  708. self.feed_forward_macaron(x)
  709. )
  710. residual = x
  711. x = self.norm_self_att(x)
  712. x_q = x
  713. x = residual + self.dropout(
  714. self.self_att(
  715. x_q,
  716. x,
  717. x,
  718. pos_enc,
  719. mask,
  720. chunk_mask=chunk_mask,
  721. )
  722. )
  723. residual = x
  724. x = self.norm_conv(x)
  725. x, _ = self.conv_mod(x)
  726. x = residual + self.dropout(x)
  727. residual = x
  728. x = self.norm_feed_forward(x)
  729. x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
  730. x = self.norm_final(x)
  731. return x, mask, pos_enc
  732. def chunk_forward(
  733. self,
  734. x: torch.Tensor,
  735. pos_enc: torch.Tensor,
  736. mask: torch.Tensor,
  737. chunk_size: int = 16,
  738. left_context: int = 0,
  739. right_context: int = 0,
  740. ) -> Tuple[torch.Tensor, torch.Tensor]:
  741. """Encode chunk of input sequence.
  742. Args:
  743. x: Conformer input sequences. (B, T, D_block)
  744. pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
  745. mask: Source mask. (B, T_2)
  746. left_context: Number of frames in left context.
  747. right_context: Number of frames in right context.
  748. Returns:
  749. x: Conformer output sequences. (B, T, D_block)
  750. pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
  751. """
  752. residual = x
  753. x = self.norm_macaron(x)
  754. x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
  755. residual = x
  756. x = self.norm_self_att(x)
  757. if left_context > 0:
  758. key = torch.cat([self.cache[0], x], dim=1)
  759. else:
  760. key = x
  761. val = key
  762. if right_context > 0:
  763. att_cache = key[:, -(left_context + right_context) : -right_context, :]
  764. else:
  765. att_cache = key[:, -left_context:, :]
  766. x = residual + self.self_att(
  767. x,
  768. key,
  769. val,
  770. pos_enc,
  771. mask,
  772. left_context=left_context,
  773. )
  774. residual = x
  775. x = self.norm_conv(x)
  776. x, conv_cache = self.conv_mod(
  777. x, cache=self.cache[1], right_context=right_context
  778. )
  779. x = residual + x
  780. residual = x
  781. x = self.norm_feed_forward(x)
  782. x = residual + self.feed_forward_scale * self.feed_forward(x)
  783. x = self.norm_final(x)
  784. self.cache = [att_cache, conv_cache]
  785. return x, pos_enc
@tables.register("encoder_classes", "ChunkConformerEncoder")
class ConformerChunkEncoder(torch.nn.Module):
    """Streaming-capable Conformer encoder.

    Supports full-context encoding, dynamic-chunk and unified
    (full-context + chunked) training, and chunk-by-chunk streaming
    inference backed by per-layer attention/convolution caches.

    Constructor arguments are documented on ``__init__``.
    """
  795. def __init__(
  796. self,
  797. input_size: int,
  798. output_size: int = 256,
  799. attention_heads: int = 4,
  800. linear_units: int = 2048,
  801. num_blocks: int = 6,
  802. dropout_rate: float = 0.1,
  803. positional_dropout_rate: float = 0.1,
  804. attention_dropout_rate: float = 0.0,
  805. embed_vgg_like: bool = False,
  806. normalize_before: bool = True,
  807. concat_after: bool = False,
  808. positionwise_layer_type: str = "linear",
  809. positionwise_conv_kernel_size: int = 3,
  810. macaron_style: bool = False,
  811. rel_pos_type: str = "legacy",
  812. pos_enc_layer_type: str = "rel_pos",
  813. selfattention_layer_type: str = "rel_selfattn",
  814. activation_type: str = "swish",
  815. use_cnn_module: bool = True,
  816. zero_triu: bool = False,
  817. norm_type: str = "layer_norm",
  818. cnn_module_kernel: int = 31,
  819. conv_mod_norm_eps: float = 0.00001,
  820. conv_mod_norm_momentum: float = 0.1,
  821. simplified_att_score: bool = False,
  822. dynamic_chunk_training: bool = False,
  823. short_chunk_threshold: float = 0.75,
  824. short_chunk_size: int = 25,
  825. left_chunk_size: int = 0,
  826. time_reduction_factor: int = 1,
  827. unified_model_training: bool = False,
  828. default_chunk_size: int = 16,
  829. jitter_range: int = 4,
  830. subsampling_factor: int = 1,
  831. ) -> None:
  832. """Construct an Encoder object."""
  833. super().__init__()
  834. self.embed = StreamingConvInput(
  835. input_size=input_size,
  836. conv_size=output_size,
  837. subsampling_factor=subsampling_factor,
  838. vgg_like=embed_vgg_like,
  839. output_size=output_size,
  840. )
  841. self.pos_enc = StreamingRelPositionalEncoding(
  842. output_size,
  843. positional_dropout_rate,
  844. )
  845. activation = get_activation(
  846. activation_type
  847. )
  848. pos_wise_args = (
  849. output_size,
  850. linear_units,
  851. positional_dropout_rate,
  852. activation,
  853. )
  854. conv_mod_norm_args = {
  855. "eps": conv_mod_norm_eps,
  856. "momentum": conv_mod_norm_momentum,
  857. }
  858. conv_mod_args = (
  859. output_size,
  860. cnn_module_kernel,
  861. activation,
  862. conv_mod_norm_args,
  863. dynamic_chunk_training or unified_model_training,
  864. )
  865. mult_att_args = (
  866. attention_heads,
  867. output_size,
  868. attention_dropout_rate,
  869. simplified_att_score,
  870. )
  871. fn_modules = []
  872. for _ in range(num_blocks):
  873. module = lambda: ChunkEncoderLayer(
  874. output_size,
  875. RelPositionMultiHeadedAttentionChunk(*mult_att_args),
  876. PositionwiseFeedForward(*pos_wise_args),
  877. PositionwiseFeedForward(*pos_wise_args),
  878. CausalConvolution(*conv_mod_args),
  879. dropout_rate=dropout_rate,
  880. )
  881. fn_modules.append(module)
  882. self.encoders = MultiBlocks(
  883. [fn() for fn in fn_modules],
  884. output_size,
  885. )
  886. self._output_size = output_size
  887. self.dynamic_chunk_training = dynamic_chunk_training
  888. self.short_chunk_threshold = short_chunk_threshold
  889. self.short_chunk_size = short_chunk_size
  890. self.left_chunk_size = left_chunk_size
  891. self.unified_model_training = unified_model_training
  892. self.default_chunk_size = default_chunk_size
  893. self.jitter_range = jitter_range
  894. self.time_reduction_factor = time_reduction_factor
  895. def output_size(self) -> int:
  896. return self._output_size
  897. def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
  898. """Return the corresponding number of sample for a given chunk size, in frames.
  899. Where size is the number of features frames after applying subsampling.
  900. Args:
  901. size: Number of frames after subsampling.
  902. hop_length: Frontend's hop length
  903. Returns:
  904. : Number of raw samples
  905. """
  906. return self.embed.get_size_before_subsampling(size) * hop_length
  907. def get_encoder_input_size(self, size: int) -> int:
  908. """Return the corresponding number of sample for a given chunk size, in frames.
  909. Where size is the number of features frames after applying subsampling.
  910. Args:
  911. size: Number of frames after subsampling.
  912. Returns:
  913. : Number of raw samples
  914. """
  915. return self.embed.get_size_before_subsampling(size)
  916. def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
  917. """Initialize/Reset encoder streaming cache.
  918. Args:
  919. left_context: Number of frames in left context.
  920. device: Device ID.
  921. """
  922. return self.encoders.reset_streaming_cache(left_context, device)
def forward(
    self,
    x: torch.Tensor,
    x_len: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Encode input sequences.

    The result depends on the configured training mode:
      * unified_model_training: the input is encoded twice — once with full
        context and once chunk-restricted — returning ``(x_utt, x_chunk, olens)``.
      * dynamic_chunk_training: a chunk size is sampled per call during
        training; returns ``(x, olens, None)``.
      * otherwise: full-context encoding; returns ``(x, olens, None)``.

    Args:
        x: Encoder input features. (B, T_in, F)
        x_len: Encoder input features lengths. (B,)

    Returns:
        See above. Output sequences are (B, T_out, D_enc); lengths are (B,).

    Raises:
        TooShortUttError: If the input is too short for the subsampling module.
    """
    short_status, limit_size = check_short_utt(
        self.embed.subsampling_factor, x.size(1)
    )
    if short_status:
        raise TooShortUttError(
            f"has {x.size(1)} frames and is too short for subsampling "
            + f"(it needs more than {limit_size} frames), return empty results",
            x.size(1),
            limit_size,
        )
    mask = make_source_mask(x_len).to(x.device)
    if self.unified_model_training:
        if self.training:
            # Jitter the chunk size around the default during training.
            chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
        else:
            chunk_size = self.default_chunk_size
        x, mask = self.embed(x, mask, chunk_size)
        pos_enc = self.pos_enc(x)
        chunk_mask = make_chunk_mask(
            x.size(1),
            chunk_size,
            left_chunk_size=self.left_chunk_size,
            device=x.device,
        )
        # Full-context pass (no chunk mask) ...
        x_utt = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=None,
        )
        # ... and chunk-restricted pass over the same embedded input.
        x_chunk = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )
        # NOTE(review): assumes mask is non-zero on padding positions, so
        # counting zeros yields valid lengths — confirm make_source_mask.
        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            # Frame-rate reduction by striding; lengths shrink accordingly.
            x_utt = x_utt[:,::self.time_reduction_factor,:]
            x_chunk = x_chunk[:,::self.time_reduction_factor,:]
            olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
        return x_utt, x_chunk, olens
    elif self.dynamic_chunk_training:
        max_len = x.size(1)
        if self.training:
            # Sample a chunk size uniformly; large draws (above the
            # threshold fraction of max_len) fall back to full context,
            # otherwise a short chunk size is derived by modulo.
            chunk_size = torch.randint(1, max_len, (1,)).item()
            if chunk_size > (max_len * self.short_chunk_threshold):
                chunk_size = max_len
            else:
                chunk_size = (chunk_size % self.short_chunk_size) + 1
        else:
            chunk_size = self.default_chunk_size
        x, mask = self.embed(x, mask, chunk_size)
        pos_enc = self.pos_enc(x)
        chunk_mask = make_chunk_mask(
            x.size(1),
            chunk_size,
            left_chunk_size=self.left_chunk_size,
            device=x.device,
        )
    else:
        # Full-context (non-streaming) encoding: no chunk mask.
        x, mask = self.embed(x, mask, None)
        pos_enc = self.pos_enc(x)
        chunk_mask = None
    x = self.encoders(
        x,
        pos_enc,
        mask,
        chunk_mask=chunk_mask,
    )
    # NOTE(review): same mask-semantics assumption as above.
    olens = mask.eq(0).sum(1)
    if self.time_reduction_factor > 1:
        x = x[:,::self.time_reduction_factor,:]
        olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
    return x, olens, None
  1011. def full_utt_forward(
  1012. self,
  1013. x: torch.Tensor,
  1014. x_len: torch.Tensor,
  1015. ) -> Tuple[torch.Tensor, torch.Tensor]:
  1016. """Encode input sequences.
  1017. Args:
  1018. x: Encoder input features. (B, T_in, F)
  1019. x_len: Encoder input features lengths. (B,)
  1020. Returns:
  1021. x: Encoder outputs. (B, T_out, D_enc)
  1022. x_len: Encoder outputs lenghts. (B,)
  1023. """
  1024. short_status, limit_size = check_short_utt(
  1025. self.embed.subsampling_factor, x.size(1)
  1026. )
  1027. if short_status:
  1028. raise TooShortUttError(
  1029. f"has {x.size(1)} frames and is too short for subsampling "
  1030. + f"(it needs more than {limit_size} frames), return empty results",
  1031. x.size(1),
  1032. limit_size,
  1033. )
  1034. mask = make_source_mask(x_len).to(x.device)
  1035. x, mask = self.embed(x, mask, None)
  1036. pos_enc = self.pos_enc(x)
  1037. x_utt = self.encoders(
  1038. x,
  1039. pos_enc,
  1040. mask,
  1041. chunk_mask=None,
  1042. )
  1043. if self.time_reduction_factor > 1:
  1044. x_utt = x_utt[:,::self.time_reduction_factor,:]
  1045. return x_utt
  1046. def simu_chunk_forward(
  1047. self,
  1048. x: torch.Tensor,
  1049. x_len: torch.Tensor,
  1050. chunk_size: int = 16,
  1051. left_context: int = 32,
  1052. right_context: int = 0,
  1053. ) -> torch.Tensor:
  1054. short_status, limit_size = check_short_utt(
  1055. self.embed.subsampling_factor, x.size(1)
  1056. )
  1057. if short_status:
  1058. raise TooShortUttError(
  1059. f"has {x.size(1)} frames and is too short for subsampling "
  1060. + f"(it needs more than {limit_size} frames), return empty results",
  1061. x.size(1),
  1062. limit_size,
  1063. )
  1064. mask = make_source_mask(x_len)
  1065. x, mask = self.embed(x, mask, chunk_size)
  1066. pos_enc = self.pos_enc(x)
  1067. chunk_mask = make_chunk_mask(
  1068. x.size(1),
  1069. chunk_size,
  1070. left_chunk_size=self.left_chunk_size,
  1071. device=x.device,
  1072. )
  1073. x = self.encoders(
  1074. x,
  1075. pos_enc,
  1076. mask,
  1077. chunk_mask=chunk_mask,
  1078. )
  1079. olens = mask.eq(0).sum(1)
  1080. if self.time_reduction_factor > 1:
  1081. x = x[:,::self.time_reduction_factor,:]
  1082. return x
def chunk_forward(
    self,
    x: torch.Tensor,
    x_len: torch.Tensor,
    processed_frames: torch.Tensor,
    chunk_size: int = 16,
    left_context: int = 32,
    right_context: int = 0,
) -> torch.Tensor:
    """Encode input sequences as chunks (streaming inference).

    Relies on the per-layer caches set up by ``reset_streaming_cache`` and
    updated by each layer's ``chunk_forward``.

    Args:
        x: Encoder input features. (1, T_in, F)
        x_len: Encoder input features lengths. (1,)
        processed_frames: Number of frames already seen.
        chunk_size: Number of frames per chunk.
        left_context: Number of frames in left context.
        right_context: Number of frames in right context (look-ahead).

    Returns:
        x: Encoder outputs. (B, T_out, D_enc)
    """
    mask = make_source_mask(x_len)
    x, mask = self.embed(x, mask, None)
    if left_context > 0:
        # Mask for the cached left-context frames: distance from the most
        # recent frame, so entries >= processed_frames correspond to
        # context that has not been produced yet and must be masked.
        # NOTE(review): assumes make_source_mask marks invalid positions
        # with True so concatenation below is consistent — confirm.
        processed_mask = (
            torch.arange(left_context, device=x.device)
            .view(1, left_context)
            .flip(1)
        )
        processed_mask = processed_mask >= processed_frames
        mask = torch.cat([processed_mask, mask], dim=1)
    # Positional encoding spans the left context plus the current chunk.
    pos_enc = self.pos_enc(x, left_context=left_context)
    x = self.encoders.chunk_forward(
        x,
        pos_enc,
        mask,
        chunk_size=chunk_size,
        left_context=left_context,
        right_context=right_context,
    )
    if right_context > 0:
        # Drop look-ahead frames; they are re-encoded with the next chunk.
        x = x[:, 0:-right_context, :]
    if self.time_reduction_factor > 1:
        x = x[:,::self.time_reduction_factor,:]
    return x