# conformer_encoder.py
  1. # Copyright 2020 Tomoki Hayashi
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. """Conformer encoder definition."""
  4. import logging
  5. from typing import List
  6. from typing import Optional
  7. from typing import Tuple
  8. from typing import Union
  9. from typing import Dict
  10. import torch
  11. from torch import nn
  12. from typeguard import check_argument_types
  13. from funasr.models.ctc import CTC
  14. from funasr.models.encoder.abs_encoder import AbsEncoder
  15. from funasr.modules.attention import (
  16. MultiHeadedAttention, # noqa: H301
  17. RelPositionMultiHeadedAttention, # noqa: H301
  18. RelPositionMultiHeadedAttentionChunk,
  19. LegacyRelPositionMultiHeadedAttention, # noqa: H301
  20. )
  21. from funasr.modules.embedding import (
  22. PositionalEncoding, # noqa: H301
  23. ScaledPositionalEncoding, # noqa: H301
  24. RelPositionalEncoding, # noqa: H301
  25. LegacyRelPositionalEncoding, # noqa: H301
  26. StreamingRelPositionalEncoding,
  27. )
  28. from funasr.modules.layer_norm import LayerNorm
  29. from funasr.modules.multi_layer_conv import Conv1dLinear
  30. from funasr.modules.multi_layer_conv import MultiLayeredConv1d
  31. from funasr.modules.nets_utils import get_activation
  32. from funasr.modules.nets_utils import make_pad_mask
  33. from funasr.modules.nets_utils import (
  34. TooShortUttError,
  35. check_short_utt,
  36. make_chunk_mask,
  37. make_source_mask,
  38. )
  39. from funasr.modules.positionwise_feed_forward import (
  40. PositionwiseFeedForward, # noqa: H301
  41. )
  42. from funasr.modules.repeat import repeat, MultiBlocks
  43. from funasr.modules.subsampling import Conv2dSubsampling
  44. from funasr.modules.subsampling import Conv2dSubsampling2
  45. from funasr.modules.subsampling import Conv2dSubsampling6
  46. from funasr.modules.subsampling import Conv2dSubsampling8
  47. from funasr.modules.subsampling import TooShortUttError
  48. from funasr.modules.subsampling import check_short_utt
  49. from funasr.modules.subsampling import Conv2dSubsamplingPad
  50. from funasr.modules.subsampling import StreamingConvInput
  51. class ConvolutionModule(nn.Module):
  52. """ConvolutionModule in Conformer model.
  53. Args:
  54. channels (int): The number of channels of conv layers.
  55. kernel_size (int): Kernerl size of conv layers.
  56. """
  57. def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
  58. """Construct an ConvolutionModule object."""
  59. super(ConvolutionModule, self).__init__()
  60. # kernerl_size should be a odd number for 'SAME' padding
  61. assert (kernel_size - 1) % 2 == 0
  62. self.pointwise_conv1 = nn.Conv1d(
  63. channels,
  64. 2 * channels,
  65. kernel_size=1,
  66. stride=1,
  67. padding=0,
  68. bias=bias,
  69. )
  70. self.depthwise_conv = nn.Conv1d(
  71. channels,
  72. channels,
  73. kernel_size,
  74. stride=1,
  75. padding=(kernel_size - 1) // 2,
  76. groups=channels,
  77. bias=bias,
  78. )
  79. self.norm = nn.BatchNorm1d(channels)
  80. self.pointwise_conv2 = nn.Conv1d(
  81. channels,
  82. channels,
  83. kernel_size=1,
  84. stride=1,
  85. padding=0,
  86. bias=bias,
  87. )
  88. self.activation = activation
  89. def forward(self, x):
  90. """Compute convolution module.
  91. Args:
  92. x (torch.Tensor): Input tensor (#batch, time, channels).
  93. Returns:
  94. torch.Tensor: Output tensor (#batch, time, channels).
  95. """
  96. # exchange the temporal dimension and the feature dimension
  97. x = x.transpose(1, 2)
  98. # GLU mechanism
  99. x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
  100. x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
  101. # 1D Depthwise Conv
  102. x = self.depthwise_conv(x)
  103. x = self.activation(self.norm(x))
  104. x = self.pointwise_conv2(x)
  105. return x.transpose(1, 2)
class EncoderLayer(nn.Module):
    """Encoder layer module.

    One Conformer block: (optional macaron FFN) -> self-attention ->
    (optional convolution) -> FFN, each with a residual connection and
    pre- or post-layer-norm depending on ``normalize_before``.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip residual computation and return input
            as-is with given probability.
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = LayerNorm(size)  # for the FNN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            # Macaron style: each of the two FFN branches contributes half.
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            # Layer is skipped: pass the input through untouched, still
            # appending the cache so the output shape matches the normal path.
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            if pos_emb is not None:
                return (x, pos_emb), mask
            return x, mask

        # whether to use macaron style (first half-FFN before attention)
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
                self.feed_forward_macaron(x)
            )
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        if cache is None:
            x_q = x
        else:
            # Incremental decoding: only the newest frame queries; keys and
            # values still attend over all frames (cache + new frame).
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if pos_emb is not None:
            # Relative-position attention variants take pos_emb explicitly.
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)

        if self.concat_after:
            # x -> x + linear(concat(x, att(x)))
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + stoch_layer_coeff * self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module (second half-FFN, scaled by ff_scale)
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
            self.feed_forward(x)
        )
        if not self.normalize_before:
            x = self.norm_ff(x)

        # Final norm is only created (and applied) when the conv module exists.
        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        if pos_emb is not None:
            return (x, pos_emb), mask
        return x, mask
class ChunkEncoderLayer(torch.nn.Module):
    """Chunk Conformer module definition.

    A pre-norm Conformer block with a streaming path (``chunk_forward``) that
    keeps per-layer attention/convolution caches for chunk-by-chunk inference.

    Args:
        block_size: Input/output size.
        self_att: Self-attention module instance.
        feed_forward: Feed-forward module instance.
        feed_forward_macaron: Feed-forward module instance for macaron network.
        conv_mod: Convolution module instance.
        norm_class: Normalization module class.
        norm_args: Normalization module arguments.
        dropout_rate: Dropout rate.
    """

    def __init__(
        self,
        block_size: int,
        self_att: torch.nn.Module,
        feed_forward: torch.nn.Module,
        feed_forward_macaron: torch.nn.Module,
        conv_mod: torch.nn.Module,
        norm_class: torch.nn.Module = LayerNorm,
        norm_args: Dict = {},
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a Conformer object."""
        super().__init__()
        self.self_att = self_att
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        # Macaron style: each of the two FFN branches contributes half.
        self.feed_forward_scale = 0.5
        self.conv_mod = conv_mod
        self.norm_feed_forward = norm_class(block_size, **norm_args)
        self.norm_self_att = norm_class(block_size, **norm_args)
        self.norm_macaron = norm_class(block_size, **norm_args)
        self.norm_conv = norm_class(block_size, **norm_args)
        self.norm_final = norm_class(block_size, **norm_args)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.block_size = block_size
        # [attention cache, convolution cache]; set by reset_streaming_cache()
        # and updated at the end of every chunk_forward() call.
        self.cache = None

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset self-attention and convolution modules cache for streaming.

        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensor.
        """
        self.cache = [
            # Attention cache: last `left_context` output frames. (1, left_context, D)
            torch.zeros(
                (1, left_context, self.block_size),
                device=device,
            ),
            # Convolution cache: left tail of the causal conv. (1, D, kernel - 1)
            torch.zeros(
                (
                    1,
                    self.block_size,
                    self.conv_mod.kernel_size - 1,
                ),
                device=device,
            ),
        ]

    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode input sequences.

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)

        Returns:
            x: Conformer output sequences. (B, T, D_block)
            mask: Source mask. (B, T)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        # Macaron feed-forward (first half).
        residual = x
        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.dropout(
            self.feed_forward_macaron(x)
        )

        # Chunk-restricted self-attention.
        residual = x
        x = self.norm_self_att(x)
        x_q = x
        x = residual + self.dropout(
            self.self_att(
                x_q,
                x,
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )
        )

        # Convolution module; returns (output, cache) — cache unused offline.
        residual = x
        x = self.norm_conv(x)
        x, _ = self.conv_mod(x)
        x = residual + self.dropout(x)

        # Feed-forward (second half) and final normalization.
        residual = x
        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))

        x = self.norm_final(x)
        return x, mask, pos_enc

    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 0,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode chunk of input sequence.

        NOTE(review): unlike forward(), no dropout is applied on this path —
        presumably it is inference-only; confirm before using it in training.
        The `chunk_size` parameter is currently unused in this method.

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T_2)
            chunk_size: Number of frames per chunk (unused here).
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: Conformer output sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        # Macaron feed-forward (first half).
        residual = x
        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)

        # Self-attention with cached left context as extra keys/values.
        residual = x
        x = self.norm_self_att(x)
        if left_context > 0:
            key = torch.cat([self.cache[0], x], dim=1)
        else:
            key = x
        val = key

        # Frames to cache for the next chunk (look-ahead frames excluded).
        # NOTE(review): when left_context == 0, `key[:, -0:, :]` slices the
        # whole key rather than an empty tensor — verify this is intended.
        if right_context > 0:
            att_cache = key[:, -(left_context + right_context) : -right_context, :]
        else:
            att_cache = key[:, -left_context:, :]
        x = residual + self.self_att(
            x,
            key,
            val,
            pos_enc,
            mask,
            left_context=left_context,
        )

        # Causal convolution with its own cache.
        residual = x
        x = self.norm_conv(x)
        x, conv_cache = self.conv_mod(
            x, cache=self.cache[1], right_context=right_context
        )
        x = residual + x

        # Feed-forward (second half) and final normalization.
        residual = x
        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.feed_forward(x)

        x = self.norm_final(x)
        # Persist caches for the next chunk.
        self.cache = [att_cache, conv_cache]

        return x, pos_enc
class ConformerEncoder(AbsEncoder):
    """Conformer encoder module.

    Args:
        input_size (int): Input dimension.
        output_size (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of decoder blocks.
        dropout_rate (float): Dropout rate.
        attention_dropout_rate (float): Dropout rate in attention.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        input_layer (Union[str, torch.nn.Module]): Input layer type.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            If False, no additional linear will be applied. i.e. x -> x + att(x)
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        rel_pos_type (str): Whether to use the latest relative positional encoding or
            the legacy one. The legacy relative positional encoding will be deprecated
            in the future. More Details can be found in
            https://github.com/espnet/espnet/pull/2816.
        encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
        encoder_attn_layer_type (str): Encoder attention layer type.
        activation_type (str): Encoder activation function type.
        macaron_style (bool): Whether to use macaron style for positionwise layer.
        use_cnn_module (bool): Whether to use convolution module.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
        interctc_layer_idx (List[int]): 1-based block indices after which an
            intermediate CTC output is produced (must all be < num_blocks).
        interctc_use_conditioning (bool): Whether to condition later blocks on the
            intermediate CTC posteriors (self-conditioned CTC).
        stochastic_depth_rate (Union[float, List[float]]): Layer-skip probability;
            either a single value for all blocks or one value per block.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        cnn_module_kernel: int = 31,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        stochastic_depth_rate: Union[float, List[float]] = 0.0,
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        # Map the default "rel_pos"/"rel_selfattn" names to their legacy
        # counterparts when the legacy relative positional encoding is requested.
        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if selfattention_layer_type == "rel_selfattn":
                selfattention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert selfattention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError("unknown rel_pos_type: " + rel_pos_type)

        activation = get_activation(activation_type)
        # The positional-encoding class must match the attention variant
        # selected further below (asserted for the relative variants).
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert selfattention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        # Input frontend: linear projection, Conv2d subsampling (several
        # rates), token embedding, a user-supplied module, or pos-enc only.
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2dpad":
            self.embed = Conv2dSubsamplingPad(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(output_size, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before

        # Position-wise feed-forward layer (class + ctor args) shared by blocks.
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        # Self-attention layer (class + ctor args) shared by blocks.
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)

        convolution_layer = ConvolutionModule
        convolution_layer_args = (output_size, cnn_module_kernel, activation)

        # A scalar stochastic-depth rate is broadcast to every block.
        if isinstance(stochastic_depth_rate, float):
            stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
        if len(stochastic_depth_rate) != num_blocks:
            raise ValueError(
                f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
                f"should be equal to num_blocks ({num_blocks})"
            )

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate[lnum],
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        # Intermediate-CTC configuration; conditioning_layer is assigned
        # externally when interctc_use_conditioning is enabled.
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None

    def output_size(self) -> int:
        """Return the encoder output feature dimension."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.
            ctc (CTC): CTC module, used only for intermediate-CTC conditioning.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size); when
                intermediate CTC is enabled, a tuple
                (output, [(layer_idx, intermediate_out), ...]) instead.
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.
        """
        # Mask of valid (non-padded) frames: (#batch, 1, L).
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
            or isinstance(self.embed, Conv2dSubsamplingPad)
        ):
            # Subsampling frontends need a minimum number of input frames.
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            xs_pad, masks = self.encoders(xs_pad, masks)
        else:
            # Run block-by-block so intermediate CTC taps can be collected.
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs_pad, masks = encoder_layer(xs_pad, masks)
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    if isinstance(encoder_out, tuple):
                        encoder_out = encoder_out[0]
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        # Self-conditioned CTC: feed the CTC posterior back in.
                        ctc_out = ctc.softmax(encoder_out)
                        if isinstance(xs_pad, tuple):
                            x, pos_emb = xs_pad
                            x = x + self.conditioning_layer(ctc_out)
                            xs_pad = (x, pos_emb)
                        else:
                            xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        # Drop the positional embedding half if the encoder carries one.
        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
  709. class CausalConvolution(torch.nn.Module):
  710. """ConformerConvolution module definition.
  711. Args:
  712. channels: The number of channels.
  713. kernel_size: Size of the convolving kernel.
  714. activation: Type of activation function.
  715. norm_args: Normalization module arguments.
  716. causal: Whether to use causal convolution (set to True if streaming).
  717. """
  718. def __init__(
  719. self,
  720. channels: int,
  721. kernel_size: int,
  722. activation: torch.nn.Module = torch.nn.ReLU(),
  723. norm_args: Dict = {},
  724. causal: bool = False,
  725. ) -> None:
  726. """Construct an ConformerConvolution object."""
  727. super().__init__()
  728. assert (kernel_size - 1) % 2 == 0
  729. self.kernel_size = kernel_size
  730. self.pointwise_conv1 = torch.nn.Conv1d(
  731. channels,
  732. 2 * channels,
  733. kernel_size=1,
  734. stride=1,
  735. padding=0,
  736. )
  737. if causal:
  738. self.lorder = kernel_size - 1
  739. padding = 0
  740. else:
  741. self.lorder = 0
  742. padding = (kernel_size - 1) // 2
  743. self.depthwise_conv = torch.nn.Conv1d(
  744. channels,
  745. channels,
  746. kernel_size,
  747. stride=1,
  748. padding=padding,
  749. groups=channels,
  750. )
  751. self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
  752. self.pointwise_conv2 = torch.nn.Conv1d(
  753. channels,
  754. channels,
  755. kernel_size=1,
  756. stride=1,
  757. padding=0,
  758. )
  759. self.activation = activation
  760. def forward(
  761. self,
  762. x: torch.Tensor,
  763. cache: Optional[torch.Tensor] = None,
  764. right_context: int = 0,
  765. ) -> Tuple[torch.Tensor, torch.Tensor]:
  766. """Compute convolution module.
  767. Args:
  768. x: ConformerConvolution input sequences. (B, T, D_hidden)
  769. cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden)
  770. right_context: Number of frames in right context.
  771. Returns:
  772. x: ConformerConvolution output sequences. (B, T, D_hidden)
  773. cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden)
  774. """
  775. x = self.pointwise_conv1(x.transpose(1, 2))
  776. x = torch.nn.functional.glu(x, dim=1)
  777. if self.lorder > 0:
  778. if cache is None:
  779. x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
  780. else:
  781. x = torch.cat([cache, x], dim=2)
  782. if right_context > 0:
  783. cache = x[:, :, -(self.lorder + right_context) : -right_context]
  784. else:
  785. cache = x[:, :, -self.lorder :]
  786. x = self.depthwise_conv(x)
  787. x = self.activation(self.norm(x))
  788. x = self.pointwise_conv2(x).transpose(1, 2)
  789. return x, cache
class ConformerChunkEncoder(AbsEncoder):
    """Conformer encoder with chunk-wise (streaming) attention support.

    Three operating modes, selected at construction time:
      - default: full-utterance (offline) encoding,
      - dynamic_chunk_training: a random chunk size is drawn on each forward,
      - unified_model_training: each batch is encoded twice — once without and
        once with a chunk mask — and both outputs are returned.

    Args:
        input_size: Input feature dimension.
        output_size: Encoder hidden/output dimension.
        attention_heads: Number of self-attention heads.
        linear_units: Hidden units of the position-wise feed-forward modules.
        num_blocks: Number of encoder layers.
        dropout_rate: Dropout rate passed to each encoder layer.
        positional_dropout_rate: Dropout rate of the positional encoding
            (also forwarded to the feed-forward modules, see note below).
        attention_dropout_rate: Dropout rate in the attention modules.
        embed_vgg_like: Whether the input embedding uses a VGG-like network.
        activation_type: Activation function type used throughout.
        cnn_module_kernel: Kernel size of the convolution modules.
        conv_mod_norm_eps: Epsilon of the convolution-module normalization.
        conv_mod_norm_momentum: Momentum of the convolution-module norm.
        simplified_att_score: Whether to use a simplified attention score.
        dynamic_chunk_training: Whether to train with random chunk sizes.
        short_chunk_threshold: Fraction of the utterance length above which a
            drawn chunk size falls back to full-utterance encoding.
        short_chunk_size: Modulus used to fold a drawn chunk size into range.
        left_chunk_size: Number of chunks visible on the left of each chunk.
        time_reduction_factor: Output frame-rate reduction factor (stride).
        unified_model_training: Whether to use unified (dual-pass) training.
        default_chunk_size: Base chunk size for unified training.
        jitter_range: Random jitter (+/-) applied to default_chunk_size.
        subsampling_factor: Subsampling factor of the input embedding.

    NOTE(review): several accepted parameters (normalize_before, concat_after,
    positionwise_layer_type, positionwise_conv_kernel_size, macaron_style,
    rel_pos_type, pos_enc_layer_type, selfattention_layer_type,
    use_cnn_module, zero_triu, norm_type) are never read in __init__ —
    presumably kept for configuration compatibility; confirm before removing.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        embed_vgg_like: bool = False,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        norm_type: str = "layer_norm",
        cnn_module_kernel: int = 31,
        conv_mod_norm_eps: float = 0.00001,
        conv_mod_norm_momentum: float = 0.1,
        simplified_att_score: bool = False,
        dynamic_chunk_training: bool = False,
        short_chunk_threshold: float = 0.75,
        short_chunk_size: int = 25,
        left_chunk_size: int = 0,
        time_reduction_factor: int = 1,
        unified_model_training: bool = False,
        default_chunk_size: int = 16,
        jitter_range: int = 4,
        subsampling_factor: int = 1,
    ) -> None:
        """Construct an Encoder object."""
        super().__init__()

        assert check_argument_types()

        # Streaming-capable convolutional frontend. NOTE(review): output_size
        # is passed both as the second positional argument (presumably the
        # internal conv size) and as the output_size keyword — confirm against
        # StreamingConvInput's signature.
        self.embed = StreamingConvInput(
            input_size,
            output_size,
            subsampling_factor,
            vgg_like=embed_vgg_like,
            output_size=output_size,
        )

        self.pos_enc = StreamingRelPositionalEncoding(
            output_size,
            positional_dropout_rate,
        )

        activation = get_activation(
            activation_type
        )

        # NOTE(review): the feed-forward modules receive
        # positional_dropout_rate rather than dropout_rate — confirm intended.
        pos_wise_args = (
            output_size,
            linear_units,
            positional_dropout_rate,
            activation,
        )

        conv_mod_norm_args = {
            "eps": conv_mod_norm_eps,
            "momentum": conv_mod_norm_momentum,
        }

        # Last element enables causal convolutions whenever any chunked
        # training mode is active (streaming must not see future frames).
        conv_mod_args = (
            output_size,
            cnn_module_kernel,
            activation,
            conv_mod_norm_args,
            dynamic_chunk_training or unified_model_training,
        )

        mult_att_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            simplified_att_score,
        )

        fn_modules = []
        for _ in range(num_blocks):
            # Each lambda builds a fresh layer with fresh sub-modules when
            # called below; all captured args are loop-invariant, so the
            # late-binding-closure pitfall does not apply here.
            module = lambda: ChunkEncoderLayer(
                output_size,
                RelPositionMultiHeadedAttentionChunk(*mult_att_args),
                PositionwiseFeedForward(*pos_wise_args),
                PositionwiseFeedForward(*pos_wise_args),
                CausalConvolution(*conv_mod_args),
                dropout_rate=dropout_rate,
            )
            fn_modules.append(module)

        self.encoders = MultiBlocks(
            [fn() for fn in fn_modules],
            output_size,
        )

        self._output_size = output_size

        # Chunking / streaming configuration.
        self.dynamic_chunk_training = dynamic_chunk_training
        self.short_chunk_threshold = short_chunk_threshold
        self.short_chunk_size = short_chunk_size
        self.left_chunk_size = left_chunk_size

        self.unified_model_training = unified_model_training
        self.default_chunk_size = default_chunk_size
        self.jitter_range = jitter_range

        self.time_reduction_factor = time_reduction_factor

    def output_size(self) -> int:
        """Return the encoder output dimension."""
        return self._output_size

    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
        """Return the corresponding number of sample for a given chunk size, in frames.

        Where size is the number of features frames after applying subsampling.

        Args:
            size: Number of frames after subsampling.
            hop_length: Frontend's hop length

        Returns:
            : Number of raw samples
        """
        return self.embed.get_size_before_subsampling(size) * hop_length

    def get_encoder_input_size(self, size: int) -> int:
        """Return the corresponding number of feature frames for a given chunk size.

        Where size is the number of features frames after applying subsampling.

        Args:
            size: Number of frames after subsampling.

        Returns:
            : Number of frames before subsampling.
        """
        return self.embed.get_size_before_subsampling(size)

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset encoder streaming cache.

        Args:
            left_context: Number of frames in left context.
            device: Device ID.
        """
        return self.encoders.reset_streaming_cache(left_context, device)

    def forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode input sequences.

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)

        Returns:
            In unified training mode: (x_utt, x_chunk, olens) — outputs of the
            full-utterance pass, the chunk-masked pass, and output lengths.
            Otherwise: (x, olens, None) — encoder outputs (B, T_out, D_enc),
            output lengths (B,), and a None placeholder.
        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )

        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )

        mask = make_source_mask(x_len)

        if self.unified_model_training:
            # Jittered chunk size around the configured default.
            chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
            # Dual pass: once offline (no chunk mask), once chunk-masked.
            x_utt = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=None,
            )
            x_chunk = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )

            # Number of unmasked frames per utterance (assumes mask is True at
            # padded positions — confirm against make_source_mask).
            olens = mask.eq(0).sum(1)
            if self.time_reduction_factor > 1:
                # Strided time reduction; ceil-style length update matches the
                # length of a [::factor] slice.
                x_utt = x_utt[:,::self.time_reduction_factor,:]
                x_chunk = x_chunk[:,::self.time_reduction_factor,:]
                olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1

            return x_utt, x_chunk, olens

        elif self.dynamic_chunk_training:
            max_len = x.size(1)
            chunk_size = torch.randint(1, max_len, (1,)).item()

            # Large draws fall back to full-utterance encoding; small draws are
            # folded into [1, short_chunk_size].
            if chunk_size > (max_len * self.short_chunk_threshold):
                chunk_size = max_len
            else:
                chunk_size = (chunk_size % self.short_chunk_size) + 1

            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)

            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
        else:
            # Offline mode: no chunking at all.
            x, mask = self.embed(x, mask, None)
            pos_enc = self.pos_enc(x)
            chunk_mask = None

        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )

        # See note above about the padding convention of `mask`.
        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:,::self.time_reduction_factor,:]
            olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1

        return x, olens, None

    def simu_chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Encode full sequences while simulating chunk-wise attention.

        Unlike forward(), only the encoded sequences are returned.

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)
            chunk_size: Size of the attention chunks.
            left_context: Number of frames in left context.
                NOTE(review): unused in this method — confirm intended.
            right_context: Number of frames in right context.
                NOTE(review): unused in this method — confirm intended.

        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )

        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )

        mask = make_source_mask(x_len)
        x, mask = self.embed(x, mask, chunk_size)
        pos_enc = self.pos_enc(x)

        chunk_mask = make_chunk_mask(
            x.size(1),
            chunk_size,
            left_chunk_size=self.left_chunk_size,
            device=x.device,
        )

        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )

        # NOTE(review): olens is computed but never used or returned here.
        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:,::self.time_reduction_factor,:]

        return x

    def chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        processed_frames: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Encode input sequences as chunks.

        Args:
            x: Encoder input features. (1, T_in, F)
            x_len: Encoder input features lengths. (1,)
            processed_frames: Number of frames already seen.
            chunk_size: Size of the attention chunks.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
        """
        mask = make_source_mask(x_len)
        x, mask = self.embed(x, mask, None)

        if left_context > 0:
            # Mask out left-context slots that extend beyond the number of
            # frames processed so far (True = masked, presumably matching the
            # make_source_mask padding convention — confirm).
            processed_mask = (
                torch.arange(left_context, device=x.device)
                .view(1, left_context)
                .flip(1)
            )
            processed_mask = processed_mask >= processed_frames
            mask = torch.cat([processed_mask, mask], dim=1)

        pos_enc = self.pos_enc(x, left_context=left_context)

        x = self.encoders.chunk_forward(
            x,
            pos_enc,
            mask,
            chunk_size=chunk_size,
            left_context=left_context,
            right_context=right_context,
        )

        if right_context > 0:
            # Lookahead frames are only context; drop them from the output.
            x = x[:, 0:-right_context, :]

        if self.time_reduction_factor > 1:
            x = x[:,::self.time_reduction_factor,:]

        return x