# Copyright 2020 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Conformer encoder definition."""

import logging
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from typing import Dict

import torch
from torch import nn

from funasr.models.ctc import CTC
from funasr.modules.attention import (
    MultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttentionChunk,
    LegacyRelPositionMultiHeadedAttention,  # noqa: H301
)
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.embedding import (
    PositionalEncoding,  # noqa: H301
    ScaledPositionalEncoding,  # noqa: H301
    RelPositionalEncoding,  # noqa: H301
    LegacyRelPositionalEncoding,  # noqa: H301
    StreamingRelPositionalEncoding,
)
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.nets_utils import get_activation
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.nets_utils import (
    TooShortUttError,
    check_short_utt,
    make_chunk_mask,
    make_source_mask,
)
from funasr.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.modules.repeat import repeat, MultiBlocks
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.modules.subsampling import Conv2dSubsamplingPad
from funasr.modules.subsampling import StreamingConvInput


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.

    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of conv layers.
    """

    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(self, x):
        """Compute convolution module.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).

        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, time)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, time)

        # 1D depthwise conv
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))

        x = self.pointwise_conv2(x)

        return x.transpose(1, 2)
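
# A minimal usage sketch of ConvolutionModule (shapes are illustrative and not part of
# the original source): the module keeps the time dimension unchanged and maps
# (batch, time, channels) -> (batch, time, channels).
#
#     conv = ConvolutionModule(channels=256, kernel_size=31)
#     y = conv(torch.randn(2, 100, 256))  # y.shape == (2, 100, 256)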


class EncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, an additional linear layer will be applied,
            i.e. x -> x + linear(concat(x, att(x)));
            if False, no additional linear layer will be applied, i.e. x -> x + att(x).
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip residual computation and return its
            input as-is with the given probability.
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = LayerNorm(size)  # for the FNN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            if pos_emb is not None:
                return (x, pos_emb), mask
            return x, mask

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
                self.feed_forward_macaron(x)
            )
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if pos_emb is not None:
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)

        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + stoch_layer_coeff * self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
            self.feed_forward(x)
        )
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        if pos_emb is not None:
            return (x, pos_emb), mask

        return x, mask
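
# A minimal construction sketch for EncoderLayer (argument values are illustrative,
# not taken from the original source). With relative positional attention the layer
# expects a (features, pos_emb) tuple as input and returns one as output:
#
#     layer = EncoderLayer(
#         size=256,
#         self_attn=RelPositionMultiHeadedAttention(4, 256, 0.0),
#         feed_forward=PositionwiseFeedForward(256, 2048, 0.1),
#         feed_forward_macaron=PositionwiseFeedForward(256, 2048, 0.1),
#         conv_module=ConvolutionModule(256, 31),
#         dropout_rate=0.1,
#     )
#     (y, pos_emb), mask = layer((x, pos_emb), mask)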


class ChunkEncoderLayer(torch.nn.Module):
    """Chunk Conformer module definition.

    Args:
        block_size: Input/output size.
        self_att: Self-attention module instance.
        feed_forward: Feed-forward module instance.
        feed_forward_macaron: Feed-forward module instance for macaron network.
        conv_mod: Convolution module instance.
        norm_class: Normalization module class.
        norm_args: Normalization module arguments.
        dropout_rate: Dropout rate.
    """

    def __init__(
        self,
        block_size: int,
        self_att: torch.nn.Module,
        feed_forward: torch.nn.Module,
        feed_forward_macaron: torch.nn.Module,
        conv_mod: torch.nn.Module,
        norm_class: torch.nn.Module = LayerNorm,
        norm_args: Dict = {},
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a Conformer object."""
        super().__init__()

        self.self_att = self_att
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.feed_forward_scale = 0.5
        self.conv_mod = conv_mod

        self.norm_feed_forward = norm_class(block_size, **norm_args)
        self.norm_self_att = norm_class(block_size, **norm_args)
        self.norm_macaron = norm_class(block_size, **norm_args)
        self.norm_conv = norm_class(block_size, **norm_args)
        self.norm_final = norm_class(block_size, **norm_args)

        self.dropout = torch.nn.Dropout(dropout_rate)

        self.block_size = block_size
        self.cache = None

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset self-attention and convolution modules cache for streaming.

        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensor.
        """
        self.cache = [
            torch.zeros(
                (1, left_context, self.block_size),
                device=device,
            ),
            torch.zeros(
                (
                    1,
                    self.block_size,
                    self.conv_mod.kernel_size - 1,
                ),
                device=device,
            ),
        ]

    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode input sequences.

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)

        Returns:
            x: Conformer output sequences. (B, T, D_block)
            mask: Source mask. (B, T)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        residual = x

        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.dropout(
            self.feed_forward_macaron(x)
        )

        residual = x
        x = self.norm_self_att(x)
        x_q = x
        x = residual + self.dropout(
            self.self_att(
                x_q,
                x,
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )
        )

        residual = x
        x = self.norm_conv(x)
        x, _ = self.conv_mod(x)
        x = residual + self.dropout(x)

        residual = x
        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))

        x = self.norm_final(x)

        return x, mask, pos_enc

    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 0,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode chunk of input sequence.

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T_2)
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: Conformer output sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        residual = x
        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)

        residual = x
        x = self.norm_self_att(x)
        if left_context > 0:
            key = torch.cat([self.cache[0], x], dim=1)
        else:
            key = x
        val = key

        if right_context > 0:
            att_cache = key[:, -(left_context + right_context) : -right_context, :]
        else:
            att_cache = key[:, -left_context:, :]
        x = residual + self.self_att(
            x,
            key,
            val,
            pos_enc,
            mask,
            left_context=left_context,
        )

        residual = x
        x = self.norm_conv(x)
        x, conv_cache = self.conv_mod(
            x, cache=self.cache[1], right_context=right_context
        )
        x = residual + x

        residual = x
        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.feed_forward(x)

        x = self.norm_final(x)
        self.cache = [att_cache, conv_cache]

        return x, pos_enc
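
# Note on the streaming cache (a reading of reset_streaming_cache/chunk_forward above,
# not an addition to the API): self.cache[0] holds the last `left_context` attention
# key/value frames with shape (1, left_context, block_size), and self.cache[1] holds
# the last kernel_size - 1 depthwise-conv input frames with shape
# (1, block_size, kernel_size - 1). Both are carried over between successive calls
# to chunk_forward.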


class ConformerEncoder(AbsEncoder):
    """Conformer encoder module.

    Args:
        input_size (int): Input dimension.
        output_size (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of encoder blocks.
        dropout_rate (float): Dropout rate.
        attention_dropout_rate (float): Dropout rate in attention.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        input_layer (Union[str, torch.nn.Module]): Input layer type.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, an additional linear layer will be applied,
            i.e. x -> x + linear(concat(x, att(x)));
            if False, no additional linear layer will be applied, i.e. x -> x + att(x).
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        rel_pos_type (str): Whether to use the latest relative positional encoding or
            the legacy one. The legacy relative positional encoding will be deprecated
            in the future. More details can be found in
            https://github.com/espnet/espnet/pull/2816.
        encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
        encoder_attn_layer_type (str): Encoder attention layer type.
        activation_type (str): Encoder activation function type.
        macaron_style (bool): Whether to use macaron style for positionwise layer.
        use_cnn_module (bool): Whether to use convolution module.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        cnn_module_kernel: int = 31,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        stochastic_depth_rate: Union[float, List[float]] = 0.0,
    ):
        super().__init__()
        self._output_size = output_size

        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if selfattention_layer_type == "rel_selfattn":
                selfattention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert selfattention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError("unknown rel_pos_type: " + rel_pos_type)

        activation = get_activation(activation_type)

        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert selfattention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2dpad":
            self.embed = Conv2dSubsamplingPad(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(output_size, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before

        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)

        convolution_layer = ConvolutionModule
        convolution_layer_args = (output_size, cnn_module_kernel, activation)

        if isinstance(stochastic_depth_rate, float):
            stochastic_depth_rate = [stochastic_depth_rate] * num_blocks

        if len(stochastic_depth_rate) != num_blocks:
            raise ValueError(
                f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
                f"should be equal to num_blocks ({num_blocks})"
            )
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate[lnum],
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
            or isinstance(self.embed, Conv2dSubsamplingPad)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            xs_pad, masks = self.encoders(xs_pad, masks)
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs_pad, masks = encoder_layer(xs_pad, masks)

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    if isinstance(encoder_out, tuple):
                        encoder_out = encoder_out[0]

                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)

                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)

                        if isinstance(xs_pad, tuple):
                            x, pos_emb = xs_pad
                            x = x + self.conditioning_layer(ctc_out)
                            xs_pad = (x, pos_emb)
                        else:
                            xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
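
# A minimal usage sketch of ConformerEncoder (values are illustrative and not part of
# the original source). With the default "conv2d" input layer, the time axis is
# subsampled by roughly a factor of 4:
#
#     encoder = ConformerEncoder(input_size=80, output_size=256, num_blocks=6)
#     feats = torch.randn(2, 200, 80)        # (batch, frames, feature_dim)
#     feats_len = torch.tensor([200, 160])
#     out, out_lens, _ = encoder(feats, feats_len)
#     # out: (2, ~200 // 4, 256); out_lens: subsampled length per utterance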


class CausalConvolution(torch.nn.Module):
    """ConformerConvolution module definition.

    Args:
        channels: The number of channels.
        kernel_size: Size of the convolving kernel.
        activation: Type of activation function.
        norm_args: Normalization module arguments.
        causal: Whether to use causal convolution (set to True if streaming).
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int,
        activation: torch.nn.Module = torch.nn.ReLU(),
        norm_args: Dict = {},
        causal: bool = False,
    ) -> None:
        """Construct a ConformerConvolution object."""
        super().__init__()

        assert (kernel_size - 1) % 2 == 0
        self.kernel_size = kernel_size

        self.pointwise_conv1 = torch.nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        if causal:
            self.lorder = kernel_size - 1
            padding = 0
        else:
            self.lorder = 0
            padding = (kernel_size - 1) // 2

        self.depthwise_conv = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
        )
        self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
        self.pointwise_conv2 = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        cache: Optional[torch.Tensor] = None,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.

        Args:
            x: ConformerConvolution input sequences. (B, T, D_hidden)
            cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden)
            right_context: Number of frames in right context.

        Returns:
            x: ConformerConvolution output sequences. (B, T, D_hidden)
            cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden)
        """
        x = self.pointwise_conv1(x.transpose(1, 2))
        x = torch.nn.functional.glu(x, dim=1)

        if self.lorder > 0:
            if cache is None:
                x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
            else:
                x = torch.cat([cache, x], dim=2)

                if right_context > 0:
                    cache = x[:, :, -(self.lorder + right_context) : -right_context]
                else:
                    cache = x[:, :, -self.lorder :]

        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))

        x = self.pointwise_conv2(x).transpose(1, 2)

        return x, cache
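
# Note on the causal path (a reading of the code above, not an addition to the API):
# with causal=True, lorder = kernel_size - 1 and the depthwise conv uses no symmetric
# padding; instead the input is left-padded with zeros (or prepended with the cached
# frames during streaming), so each output frame depends only on current and past frames.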


class ConformerChunkEncoder(AbsEncoder):
    """Encoder module definition.

    Args:
        input_size: Input size.
        body_conf: Encoder body configuration.
        input_conf: Encoder input configuration.
        main_conf: Encoder main configuration.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        embed_vgg_like: bool = False,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        norm_type: str = "layer_norm",
        cnn_module_kernel: int = 31,
        conv_mod_norm_eps: float = 0.00001,
        conv_mod_norm_momentum: float = 0.1,
        simplified_att_score: bool = False,
        dynamic_chunk_training: bool = False,
        short_chunk_threshold: float = 0.75,
        short_chunk_size: int = 25,
        left_chunk_size: int = 0,
        time_reduction_factor: int = 1,
        unified_model_training: bool = False,
        default_chunk_size: int = 16,
        jitter_range: int = 4,
        subsampling_factor: int = 1,
    ) -> None:
        """Construct an Encoder object."""
        super().__init__()

        self.embed = StreamingConvInput(
            input_size,
            output_size,
            subsampling_factor,
            vgg_like=embed_vgg_like,
            output_size=output_size,
        )

        self.pos_enc = StreamingRelPositionalEncoding(
            output_size,
            positional_dropout_rate,
        )

        activation = get_activation(activation_type)

        pos_wise_args = (
            output_size,
            linear_units,
            positional_dropout_rate,
            activation,
        )

        conv_mod_norm_args = {
            "eps": conv_mod_norm_eps,
            "momentum": conv_mod_norm_momentum,
        }

        conv_mod_args = (
            output_size,
            cnn_module_kernel,
            activation,
            conv_mod_norm_args,
            dynamic_chunk_training or unified_model_training,
        )

        mult_att_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            simplified_att_score,
        )

        fn_modules = []
        for _ in range(num_blocks):
            module = lambda: ChunkEncoderLayer(
                output_size,
                RelPositionMultiHeadedAttentionChunk(*mult_att_args),
                PositionwiseFeedForward(*pos_wise_args),
                PositionwiseFeedForward(*pos_wise_args),
                CausalConvolution(*conv_mod_args),
                dropout_rate=dropout_rate,
            )
            fn_modules.append(module)

        self.encoders = MultiBlocks(
            [fn() for fn in fn_modules],
            output_size,
        )

        self._output_size = output_size

        self.dynamic_chunk_training = dynamic_chunk_training
        self.short_chunk_threshold = short_chunk_threshold
        self.short_chunk_size = short_chunk_size
        self.left_chunk_size = left_chunk_size

        self.unified_model_training = unified_model_training
        self.default_chunk_size = default_chunk_size
        self.jitter_range = jitter_range

        self.time_reduction_factor = time_reduction_factor

    def output_size(self) -> int:
        return self._output_size

    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
        """Return the corresponding number of raw samples for a given chunk size,
        where size is the number of feature frames after subsampling.

        Args:
            size: Number of frames after subsampling.
            hop_length: Frontend's hop length.

        Returns:
            : Number of raw samples.
        """
        return self.embed.get_size_before_subsampling(size) * hop_length

    def get_encoder_input_size(self, size: int) -> int:
        """Return the corresponding number of feature frames before subsampling
        for a given chunk size, where size is the number of frames after subsampling.

        Args:
            size: Number of frames after subsampling.

        Returns:
            : Number of feature frames before subsampling.
        """
        return self.embed.get_size_before_subsampling(size)

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset encoder streaming cache.

        Args:
            left_context: Number of frames in left context.
            device: Device to use for the cache tensors.
        """
        return self.encoders.reset_streaming_cache(left_context, device)

    def forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input sequences.

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)

        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
            x_len: Encoder outputs lengths. (B,)
        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )
        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )

        mask = make_source_mask(x_len).to(x.device)

        if self.unified_model_training:
            if self.training:
                chunk_size = self.default_chunk_size + torch.randint(
                    -self.jitter_range, self.jitter_range + 1, (1,)
                ).item()
            else:
                chunk_size = self.default_chunk_size
            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
            x_utt = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=None,
            )
            x_chunk = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )

            olens = mask.eq(0).sum(1)
            if self.time_reduction_factor > 1:
                x_utt = x_utt[:, :: self.time_reduction_factor, :]
                x_chunk = x_chunk[:, :: self.time_reduction_factor, :]
                olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
            return x_utt, x_chunk, olens

        elif self.dynamic_chunk_training:
            max_len = x.size(1)
            if self.training:
                chunk_size = torch.randint(1, max_len, (1,)).item()

                if chunk_size > (max_len * self.short_chunk_threshold):
                    chunk_size = max_len
                else:
                    chunk_size = (chunk_size % self.short_chunk_size) + 1
            else:
                chunk_size = self.default_chunk_size

            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
        else:
            x, mask = self.embed(x, mask, None)
            pos_enc = self.pos_enc(x)
            chunk_mask = None

        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )

        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]
            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1

        return x, olens, None

    def full_utt_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input sequences.

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)

        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
            x_len: Encoder outputs lengths. (B,)
        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )
        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )

        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, None)
        pos_enc = self.pos_enc(x)
        x_utt = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=None,
        )

        if self.time_reduction_factor > 1:
            x_utt = x_utt[:, :: self.time_reduction_factor, :]

        return x_utt

    def simu_chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Encode a full input sequence with chunk-restricted attention,
        simulating chunk-wise streaming in a single forward pass."""
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )
        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )

        mask = make_source_mask(x_len)
        x, mask = self.embed(x, mask, chunk_size)
        pos_enc = self.pos_enc(x)
        chunk_mask = make_chunk_mask(
            x.size(1),
            chunk_size,
            left_chunk_size=self.left_chunk_size,
            device=x.device,
        )
        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )

        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]

        return x

    def chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        processed_frames: torch.tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Encode input sequences as chunks.

        Args:
            x: Encoder input features. (1, T_in, F)
            x_len: Encoder input features lengths. (1,)
            processed_frames: Number of frames already seen.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
        """
        mask = make_source_mask(x_len)
        x, mask = self.embed(x, mask, None)

        if left_context > 0:
            processed_mask = (
                torch.arange(left_context, device=x.device)
                .view(1, left_context)
                .flip(1)
            )
            processed_mask = processed_mask >= processed_frames
            mask = torch.cat([processed_mask, mask], dim=1)

        pos_enc = self.pos_enc(x, left_context=left_context)

        x = self.encoders.chunk_forward(
            x,
            pos_enc,
            mask,
            chunk_size=chunk_size,
            left_context=left_context,
            right_context=right_context,
        )

        if right_context > 0:
            x = x[:, 0:-right_context, :]

        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]

        return x
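
# A minimal streaming usage sketch of ConformerChunkEncoder (values and the surrounding
# buffering logic are illustrative, not taken from the original source): reset the
# per-layer caches once, then feed fixed-size feature chunks while tracking how many
# encoded frames have already been consumed.
#
#     encoder = ConformerChunkEncoder(input_size=80, output_size=256, num_blocks=6)
#     encoder.eval()
#     encoder.reset_streaming_cache(left_context=32, device=torch.device("cpu"))
#     processed_frames = torch.tensor(0)
#     for chunk, chunk_len in feature_chunks:  # hypothetical iterator of (1, T, 80) chunks
#         out = encoder.chunk_forward(
#             chunk, chunk_len, processed_frames,
#             chunk_size=16, left_context=32, right_context=0,
#         )
#         processed_frames = processed_frames + out.size(1)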