# ecapa_tdnn_encoder.py — ECAPA-TDNN speaker-embedding encoder.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
  5. class _BatchNorm1d(nn.Module):
  6. def __init__(
  7. self,
  8. input_shape=None,
  9. input_size=None,
  10. eps=1e-05,
  11. momentum=0.1,
  12. affine=True,
  13. track_running_stats=True,
  14. combine_batch_time=False,
  15. skip_transpose=False,
  16. ):
  17. super().__init__()
  18. self.combine_batch_time = combine_batch_time
  19. self.skip_transpose = skip_transpose
  20. if input_size is None and skip_transpose:
  21. input_size = input_shape[1]
  22. elif input_size is None:
  23. input_size = input_shape[-1]
  24. self.norm = nn.BatchNorm1d(
  25. input_size,
  26. eps=eps,
  27. momentum=momentum,
  28. affine=affine,
  29. track_running_stats=track_running_stats,
  30. )
  31. def forward(self, x):
  32. shape_or = x.shape
  33. if self.combine_batch_time:
  34. if x.ndim == 3:
  35. x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
  36. else:
  37. x = x.reshape(
  38. shape_or[0] * shape_or[1], shape_or[3], shape_or[2]
  39. )
  40. elif not self.skip_transpose:
  41. x = x.transpose(-1, 1)
  42. x_n = self.norm(x)
  43. if self.combine_batch_time:
  44. x_n = x_n.reshape(shape_or)
  45. elif not self.skip_transpose:
  46. x_n = x_n.transpose(1, -1)
  47. return x_n
  48. class _Conv1d(nn.Module):
  49. def __init__(
  50. self,
  51. out_channels,
  52. kernel_size,
  53. input_shape=None,
  54. in_channels=None,
  55. stride=1,
  56. dilation=1,
  57. padding="same",
  58. groups=1,
  59. bias=True,
  60. padding_mode="reflect",
  61. skip_transpose=False,
  62. ):
  63. super().__init__()
  64. self.kernel_size = kernel_size
  65. self.stride = stride
  66. self.dilation = dilation
  67. self.padding = padding
  68. self.padding_mode = padding_mode
  69. self.unsqueeze = False
  70. self.skip_transpose = skip_transpose
  71. if input_shape is None and in_channels is None:
  72. raise ValueError("Must provide one of input_shape or in_channels")
  73. if in_channels is None:
  74. in_channels = self._check_input_shape(input_shape)
  75. self.conv = nn.Conv1d(
  76. in_channels,
  77. out_channels,
  78. self.kernel_size,
  79. stride=self.stride,
  80. dilation=self.dilation,
  81. padding=0,
  82. groups=groups,
  83. bias=bias,
  84. )
  85. def forward(self, x):
  86. if not self.skip_transpose:
  87. x = x.transpose(1, -1)
  88. if self.unsqueeze:
  89. x = x.unsqueeze(1)
  90. if self.padding == "same":
  91. x = self._manage_padding(
  92. x, self.kernel_size, self.dilation, self.stride
  93. )
  94. elif self.padding == "causal":
  95. num_pad = (self.kernel_size - 1) * self.dilation
  96. x = F.pad(x, (num_pad, 0))
  97. elif self.padding == "valid":
  98. pass
  99. else:
  100. raise ValueError(
  101. "Padding must be 'same', 'valid' or 'causal'. Got "
  102. + self.padding
  103. )
  104. wx = self.conv(x)
  105. if self.unsqueeze:
  106. wx = wx.squeeze(1)
  107. if not self.skip_transpose:
  108. wx = wx.transpose(1, -1)
  109. return wx
  110. def _manage_padding(
  111. self, x, kernel_size: int, dilation: int, stride: int,
  112. ):
  113. # Detecting input shape
  114. L_in = x.shape[-1]
  115. # Time padding
  116. padding = get_padding_elem(L_in, stride, kernel_size, dilation)
  117. # Applying padding
  118. x = F.pad(x, padding, mode=self.padding_mode)
  119. return x
  120. def _check_input_shape(self, shape):
  121. """Checks the input shape and returns the number of input channels.
  122. """
  123. if len(shape) == 2:
  124. self.unsqueeze = True
  125. in_channels = 1
  126. elif self.skip_transpose:
  127. in_channels = shape[1]
  128. elif len(shape) == 3:
  129. in_channels = shape[2]
  130. else:
  131. raise ValueError(
  132. "conv1d expects 2d, 3d inputs. Got " + str(len(shape))
  133. )
  134. # Kernel size must be odd
  135. if self.kernel_size % 2 == 0:
  136. raise ValueError(
  137. "The field kernel size must be an odd number. Got %s."
  138. % (self.kernel_size)
  139. )
  140. return in_channels
  141. def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
  142. if stride > 1:
  143. n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
  144. L_out = stride * (n_steps - 1) + kernel_size * dilation
  145. padding = [kernel_size // 2, kernel_size // 2]
  146. else:
  147. L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
  148. padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
  149. return padding
  150. # Skip transpose as much as possible for efficiency
  151. class Conv1d(_Conv1d):
  152. def __init__(self, *args, **kwargs):
  153. super().__init__(skip_transpose=True, *args, **kwargs)
  154. class BatchNorm1d(_BatchNorm1d):
  155. def __init__(self, *args, **kwargs):
  156. super().__init__(skip_transpose=True, *args, **kwargs)
  157. def length_to_mask(length, max_len=None, dtype=None, device=None):
  158. assert len(length.shape) == 1
  159. if max_len is None:
  160. max_len = length.max().long().item() # using arange to generate mask
  161. mask = torch.arange(
  162. max_len, device=length.device, dtype=length.dtype
  163. ).expand(len(length), max_len) < length.unsqueeze(1)
  164. if dtype is None:
  165. dtype = length.dtype
  166. if device is None:
  167. device = length.device
  168. mask = torch.as_tensor(mask, dtype=dtype, device=device)
  169. return mask
  170. class TDNNBlock(nn.Module):
  171. def __init__(
  172. self,
  173. in_channels,
  174. out_channels,
  175. kernel_size,
  176. dilation,
  177. activation=nn.ReLU,
  178. groups=1,
  179. ):
  180. super(TDNNBlock, self).__init__()
  181. self.conv = Conv1d(
  182. in_channels=in_channels,
  183. out_channels=out_channels,
  184. kernel_size=kernel_size,
  185. dilation=dilation,
  186. groups=groups,
  187. )
  188. self.activation = activation()
  189. self.norm = BatchNorm1d(input_size=out_channels)
  190. def forward(self, x):
  191. return self.norm(self.activation(self.conv(x)))
  192. class Res2NetBlock(torch.nn.Module):
  193. """An implementation of Res2NetBlock w/ dilation.
  194. Arguments
  195. ---------
  196. in_channels : int
  197. The number of channels expected in the input.
  198. out_channels : int
  199. The number of output channels.
  200. scale : int
  201. The scale of the Res2Net block.
  202. kernel_size: int
  203. The kernel size of the Res2Net block.
  204. dilation : int
  205. The dilation of the Res2Net block.
  206. Example
  207. -------
  208. >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
  209. >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
  210. >>> out_tensor = layer(inp_tensor).transpose(1, 2)
  211. >>> out_tensor.shape
  212. torch.Size([8, 120, 64])
  213. """
  214. def __init__(
  215. self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1
  216. ):
  217. super(Res2NetBlock, self).__init__()
  218. assert in_channels % scale == 0
  219. assert out_channels % scale == 0
  220. in_channel = in_channels // scale
  221. hidden_channel = out_channels // scale
  222. self.blocks = nn.ModuleList(
  223. [
  224. TDNNBlock(
  225. in_channel,
  226. hidden_channel,
  227. kernel_size=kernel_size,
  228. dilation=dilation,
  229. )
  230. for i in range(scale - 1)
  231. ]
  232. )
  233. self.scale = scale
  234. def forward(self, x):
  235. y = []
  236. for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
  237. if i == 0:
  238. y_i = x_i
  239. elif i == 1:
  240. y_i = self.blocks[i - 1](x_i)
  241. else:
  242. y_i = self.blocks[i - 1](x_i + y_i)
  243. y.append(y_i)
  244. y = torch.cat(y, dim=1)
  245. return y
  246. class SEBlock(nn.Module):
  247. """An implementation of squeeze-and-excitation block.
  248. Arguments
  249. ---------
  250. in_channels : int
  251. The number of input channels.
  252. se_channels : int
  253. The number of output channels after squeeze.
  254. out_channels : int
  255. The number of output channels.
  256. Example
  257. -------
  258. >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
  259. >>> se_layer = SEBlock(64, 16, 64)
  260. >>> lengths = torch.rand((8,))
  261. >>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
  262. >>> out_tensor.shape
  263. torch.Size([8, 120, 64])
  264. """
  265. def __init__(self, in_channels, se_channels, out_channels):
  266. super(SEBlock, self).__init__()
  267. self.conv1 = Conv1d(
  268. in_channels=in_channels, out_channels=se_channels, kernel_size=1
  269. )
  270. self.relu = torch.nn.ReLU(inplace=True)
  271. self.conv2 = Conv1d(
  272. in_channels=se_channels, out_channels=out_channels, kernel_size=1
  273. )
  274. self.sigmoid = torch.nn.Sigmoid()
  275. def forward(self, x, lengths=None):
  276. L = x.shape[-1]
  277. if lengths is not None:
  278. mask = length_to_mask(lengths * L, max_len=L, device=x.device)
  279. mask = mask.unsqueeze(1)
  280. total = mask.sum(dim=2, keepdim=True)
  281. s = (x * mask).sum(dim=2, keepdim=True) / total
  282. else:
  283. s = x.mean(dim=2, keepdim=True)
  284. s = self.relu(self.conv1(s))
  285. s = self.sigmoid(self.conv2(s))
  286. return s * x
  287. class AttentiveStatisticsPooling(nn.Module):
  288. """This class implements an attentive statistic pooling layer for each channel.
  289. It returns the concatenated mean and std of the input tensor.
  290. Arguments
  291. ---------
  292. channels: int
  293. The number of input channels.
  294. attention_channels: int
  295. The number of attention channels.
  296. Example
  297. -------
  298. >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
  299. >>> asp_layer = AttentiveStatisticsPooling(64)
  300. >>> lengths = torch.rand((8,))
  301. >>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
  302. >>> out_tensor.shape
  303. torch.Size([8, 1, 128])
  304. """
  305. def __init__(self, channels, attention_channels=128, global_context=True):
  306. super().__init__()
  307. self.eps = 1e-12
  308. self.global_context = global_context
  309. if global_context:
  310. self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
  311. else:
  312. self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
  313. self.tanh = nn.Tanh()
  314. self.conv = Conv1d(
  315. in_channels=attention_channels, out_channels=channels, kernel_size=1
  316. )
  317. def forward(self, x, lengths=None):
  318. """Calculates mean and std for a batch (input tensor).
  319. Arguments
  320. ---------
  321. x : torch.Tensor
  322. Tensor of shape [N, C, L].
  323. """
  324. L = x.shape[-1]
  325. def _compute_statistics(x, m, dim=2, eps=self.eps):
  326. mean = (m * x).sum(dim)
  327. std = torch.sqrt(
  328. (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)
  329. )
  330. return mean, std
  331. if lengths is None:
  332. lengths = torch.ones(x.shape[0], device=x.device)
  333. # Make binary mask of shape [N, 1, L]
  334. mask = length_to_mask(lengths * L, max_len=L, device=x.device)
  335. mask = mask.unsqueeze(1)
  336. # Expand the temporal context of the pooling layer by allowing the
  337. # self-attention to look at global properties of the utterance.
  338. if self.global_context:
  339. # torch.std is unstable for backward computation
  340. # https://github.com/pytorch/pytorch/issues/4320
  341. total = mask.sum(dim=2, keepdim=True).float()
  342. mean, std = _compute_statistics(x, mask / total)
  343. mean = mean.unsqueeze(2).repeat(1, 1, L)
  344. std = std.unsqueeze(2).repeat(1, 1, L)
  345. attn = torch.cat([x, mean, std], dim=1)
  346. else:
  347. attn = x
  348. # Apply layers
  349. attn = self.conv(self.tanh(self.tdnn(attn)))
  350. # Filter out zero-paddings
  351. attn = attn.masked_fill(mask == 0, float("-inf"))
  352. attn = F.softmax(attn, dim=2)
  353. mean, std = _compute_statistics(x, attn)
  354. # Append mean and std of the batch
  355. pooled_stats = torch.cat((mean, std), dim=1)
  356. pooled_stats = pooled_stats.unsqueeze(2)
  357. return pooled_stats
  358. class SERes2NetBlock(nn.Module):
  359. """An implementation of building block in ECAPA-TDNN, i.e.,
  360. TDNN-Res2Net-TDNN-SEBlock.
  361. Arguments
  362. ----------
  363. out_channels: int
  364. The number of output channels.
  365. res2net_scale: int
  366. The scale of the Res2Net block.
  367. kernel_size: int
  368. The kernel size of the TDNN blocks.
  369. dilation: int
  370. The dilation of the Res2Net block.
  371. activation : torch class
  372. A class for constructing the activation layers.
  373. groups: int
  374. Number of blocked connections from input channels to output channels.
  375. Example
  376. -------
  377. >>> x = torch.rand(8, 120, 64).transpose(1, 2)
  378. >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
  379. >>> out = conv(x).transpose(1, 2)
  380. >>> out.shape
  381. torch.Size([8, 120, 64])
  382. """
  383. def __init__(
  384. self,
  385. in_channels,
  386. out_channels,
  387. res2net_scale=8,
  388. se_channels=128,
  389. kernel_size=1,
  390. dilation=1,
  391. activation=torch.nn.ReLU,
  392. groups=1,
  393. ):
  394. super().__init__()
  395. self.out_channels = out_channels
  396. self.tdnn1 = TDNNBlock(
  397. in_channels,
  398. out_channels,
  399. kernel_size=1,
  400. dilation=1,
  401. activation=activation,
  402. groups=groups,
  403. )
  404. self.res2net_block = Res2NetBlock(
  405. out_channels, out_channels, res2net_scale, kernel_size, dilation
  406. )
  407. self.tdnn2 = TDNNBlock(
  408. out_channels,
  409. out_channels,
  410. kernel_size=1,
  411. dilation=1,
  412. activation=activation,
  413. groups=groups,
  414. )
  415. self.se_block = SEBlock(out_channels, se_channels, out_channels)
  416. self.shortcut = None
  417. if in_channels != out_channels:
  418. self.shortcut = Conv1d(
  419. in_channels=in_channels,
  420. out_channels=out_channels,
  421. kernel_size=1,
  422. )
  423. def forward(self, x, lengths=None):
  424. residual = x
  425. if self.shortcut:
  426. residual = self.shortcut(x)
  427. x = self.tdnn1(x)
  428. x = self.res2net_block(x)
  429. x = self.tdnn2(x)
  430. x = self.se_block(x, lengths)
  431. return x + residual
  432. class ECAPA_TDNN(torch.nn.Module):
  433. """An implementation of the speaker embedding model in a paper.
  434. "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
  435. TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).
  436. Arguments
  437. ---------
  438. activation : torch class
  439. A class for constructing the activation layers.
  440. channels : list of ints
  441. Output channels for TDNN/SERes2Net layer.
  442. kernel_sizes : list of ints
  443. List of kernel sizes for each layer.
  444. dilations : list of ints
  445. List of dilations for kernels in each layer.
  446. lin_neurons : int
  447. Number of neurons in linear layers.
  448. groups : list of ints
  449. List of groups for kernels in each layer.
  450. Example
  451. -------
  452. >>> input_feats = torch.rand([5, 120, 80])
  453. >>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192)
  454. >>> outputs = compute_embedding(input_feats)
  455. >>> outputs.shape
  456. torch.Size([5, 1, 192])
  457. """
  458. def __init__(
  459. self,
  460. input_size,
  461. lin_neurons=192,
  462. activation=torch.nn.ReLU,
  463. channels=[512, 512, 512, 512, 1536],
  464. kernel_sizes=[5, 3, 3, 3, 1],
  465. dilations=[1, 2, 3, 4, 1],
  466. attention_channels=128,
  467. res2net_scale=8,
  468. se_channels=128,
  469. global_context=True,
  470. groups=[1, 1, 1, 1, 1],
  471. window_size=20,
  472. window_shift=1,
  473. ):
  474. super().__init__()
  475. assert len(channels) == len(kernel_sizes)
  476. assert len(channels) == len(dilations)
  477. self.channels = channels
  478. self.blocks = nn.ModuleList()
  479. self.window_size = window_size
  480. self.window_shift = window_shift
  481. # The initial TDNN layer
  482. self.blocks.append(
  483. TDNNBlock(
  484. input_size,
  485. channels[0],
  486. kernel_sizes[0],
  487. dilations[0],
  488. activation,
  489. groups[0],
  490. )
  491. )
  492. # SE-Res2Net layers
  493. for i in range(1, len(channels) - 1):
  494. self.blocks.append(
  495. SERes2NetBlock(
  496. channels[i - 1],
  497. channels[i],
  498. res2net_scale=res2net_scale,
  499. se_channels=se_channels,
  500. kernel_size=kernel_sizes[i],
  501. dilation=dilations[i],
  502. activation=activation,
  503. groups=groups[i],
  504. )
  505. )
  506. # Multi-layer feature aggregation
  507. self.mfa = TDNNBlock(
  508. channels[-1],
  509. channels[-1],
  510. kernel_sizes[-1],
  511. dilations[-1],
  512. activation,
  513. groups=groups[-1],
  514. )
  515. # Attentive Statistical Pooling
  516. self.asp = AttentiveStatisticsPooling(
  517. channels[-1],
  518. attention_channels=attention_channels,
  519. global_context=global_context,
  520. )
  521. self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
  522. # Final linear transformation
  523. self.fc = Conv1d(
  524. in_channels=channels[-1] * 2,
  525. out_channels=lin_neurons,
  526. kernel_size=1,
  527. )
  528. def windowed_pooling(self, x, lengths=None):
  529. # x: Batch, Channel, Time
  530. tt = x.shape[2]
  531. num_chunk = int(math.ceil(tt / self.window_shift))
  532. pad = self.window_size // 2
  533. x = F.pad(x, (pad, pad, 0, 0), "reflect")
  534. stat_list = []
  535. for i in range(num_chunk):
  536. # B x C
  537. st, ed = i * self.window_shift, i * self.window_shift + self.window_size
  538. x = self.asp(x[:, :, st: ed],
  539. lengths=torch.clamp(lengths - i, 0, self.window_size)
  540. if lengths is not None else None)
  541. x = self.asp_bn(x)
  542. x = self.fc(x)
  543. stat_list.append(x)
  544. return torch.cat(stat_list, dim=2)
  545. def forward(self, x, lengths=None):
  546. """Returns the embedding vector.
  547. Arguments
  548. ---------
  549. x : torch.Tensor
  550. Tensor of shape (batch, time, channel).
  551. lengths: torch.Tensor
  552. Tensor of shape (batch, )
  553. """
  554. # Minimize transpose for efficiency
  555. x = x.transpose(1, 2)
  556. xl = []
  557. for layer in self.blocks:
  558. try:
  559. x = layer(x, lengths=lengths)
  560. except TypeError:
  561. x = layer(x)
  562. xl.append(x)
  563. # Multi-layer feature aggregation
  564. x = torch.cat(xl[1:], dim=1)
  565. x = self.mfa(x)
  566. if self.window_size is None:
  567. # Attentive Statistical Pooling
  568. x = self.asp(x, lengths=lengths)
  569. x = self.asp_bn(x)
  570. # Final linear transformation
  571. x = self.fc(x)
  572. # x = x.transpose(1, 2)
  573. x = x.squeeze(2) # -> B, C
  574. else:
  575. x = self.windowed_pooling(x, lengths)
  576. x = x.transpose(1, 2) # -> B, T, C
  577. return x