# components.py
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. # Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
  6. import torch
  7. import torch.nn.functional as F
  8. import torch.utils.checkpoint as cp
  9. class BasicResBlock(torch.nn.Module):
  10. expansion = 1
  11. def __init__(self, in_planes, planes, stride=1):
  12. super(BasicResBlock, self).__init__()
  13. self.conv1 = torch.nn.Conv2d(in_planes,
  14. planes,
  15. kernel_size=3,
  16. stride=(stride, 1),
  17. padding=1,
  18. bias=False)
  19. self.bn1 = torch.nn.BatchNorm2d(planes)
  20. self.conv2 = torch.nn.Conv2d(planes,
  21. planes,
  22. kernel_size=3,
  23. stride=1,
  24. padding=1,
  25. bias=False)
  26. self.bn2 = torch.nn.BatchNorm2d(planes)
  27. self.shortcut = torch.nn.Sequential()
  28. if stride != 1 or in_planes != self.expansion * planes:
  29. self.shortcut = torch.nn.Sequential(
  30. torch.nn.Conv2d(in_planes,
  31. self.expansion * planes,
  32. kernel_size=1,
  33. stride=(stride, 1),
  34. bias=False),
  35. torch.nn.BatchNorm2d(self.expansion * planes))
  36. def forward(self, x):
  37. out = F.relu(self.bn1(self.conv1(x)))
  38. out = self.bn2(self.conv2(out))
  39. out += self.shortcut(x)
  40. out = F.relu(out)
  41. return out
  42. class FCM(torch.nn.Module):
  43. def __init__(self,
  44. block=BasicResBlock,
  45. num_blocks=[2, 2],
  46. m_channels=32,
  47. feat_dim=80):
  48. super(FCM, self).__init__()
  49. self.in_planes = m_channels
  50. self.conv1 = torch.nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
  51. self.bn1 = torch.nn.BatchNorm2d(m_channels)
  52. self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
  53. self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
  54. self.conv2 = torch.nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
  55. self.bn2 = torch.nn.BatchNorm2d(m_channels)
  56. self.out_channels = m_channels * (feat_dim // 8)
  57. def _make_layer(self, block, planes, num_blocks, stride):
  58. strides = [stride] + [1] * (num_blocks - 1)
  59. layers = []
  60. for stride in strides:
  61. layers.append(block(self.in_planes, planes, stride))
  62. self.in_planes = planes * block.expansion
  63. return torch.nn.Sequential(*layers)
  64. def forward(self, x):
  65. x = x.unsqueeze(1)
  66. out = F.relu(self.bn1(self.conv1(x)))
  67. out = self.layer1(out)
  68. out = self.layer2(out)
  69. out = F.relu(self.bn2(self.conv2(out)))
  70. shape = out.shape
  71. out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
  72. return out
  73. def get_nonlinear(config_str, channels):
  74. nonlinear = torch.nn.Sequential()
  75. for name in config_str.split('-'):
  76. if name == 'relu':
  77. nonlinear.add_module('relu', torch.nn.ReLU(inplace=True))
  78. elif name == 'prelu':
  79. nonlinear.add_module('prelu', torch.nn.PReLU(channels))
  80. elif name == 'batchnorm':
  81. nonlinear.add_module('batchnorm', torch.nn.BatchNorm1d(channels))
  82. elif name == 'batchnorm_':
  83. nonlinear.add_module('batchnorm',
  84. torch.nn.BatchNorm1d(channels, affine=False))
  85. else:
  86. raise ValueError('Unexpected module ({}).'.format(name))
  87. return nonlinear
  88. def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
  89. mean = x.mean(dim=dim)
  90. std = x.std(dim=dim, unbiased=unbiased)
  91. stats = torch.cat([mean, std], dim=-1)
  92. if keepdim:
  93. stats = stats.unsqueeze(dim=dim)
  94. return stats
  95. class StatsPool(torch.nn.Module):
  96. def forward(self, x):
  97. return statistics_pooling(x)
  98. class TDNNLayer(torch.nn.Module):
  99. def __init__(self,
  100. in_channels,
  101. out_channels,
  102. kernel_size,
  103. stride=1,
  104. padding=0,
  105. dilation=1,
  106. bias=False,
  107. config_str='batchnorm-relu'):
  108. super(TDNNLayer, self).__init__()
  109. if padding < 0:
  110. assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
  111. kernel_size)
  112. padding = (kernel_size - 1) // 2 * dilation
  113. self.linear = torch.nn.Conv1d(in_channels,
  114. out_channels,
  115. kernel_size,
  116. stride=stride,
  117. padding=padding,
  118. dilation=dilation,
  119. bias=bias)
  120. self.nonlinear = get_nonlinear(config_str, out_channels)
  121. def forward(self, x):
  122. x = self.linear(x)
  123. x = self.nonlinear(x)
  124. return x
  125. class CAMLayer(torch.nn.Module):
  126. def __init__(self,
  127. bn_channels,
  128. out_channels,
  129. kernel_size,
  130. stride,
  131. padding,
  132. dilation,
  133. bias,
  134. reduction=2):
  135. super(CAMLayer, self).__init__()
  136. self.linear_local = torch.nn.Conv1d(bn_channels,
  137. out_channels,
  138. kernel_size,
  139. stride=stride,
  140. padding=padding,
  141. dilation=dilation,
  142. bias=bias)
  143. self.linear1 = torch.nn.Conv1d(bn_channels, bn_channels // reduction, 1)
  144. self.relu = torch.nn.ReLU(inplace=True)
  145. self.linear2 = torch.nn.Conv1d(bn_channels // reduction, out_channels, 1)
  146. self.sigmoid = torch.nn.Sigmoid()
  147. def forward(self, x):
  148. y = self.linear_local(x)
  149. context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
  150. context = self.relu(self.linear1(context))
  151. m = self.sigmoid(self.linear2(context))
  152. return y * m
  153. def seg_pooling(self, x, seg_len=100, stype='avg'):
  154. if stype == 'avg':
  155. seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
  156. elif stype == 'max':
  157. seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
  158. else:
  159. raise ValueError('Wrong segment pooling type.')
  160. shape = seg.shape
  161. seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1)
  162. seg = seg[..., :x.shape[-1]]
  163. return seg
  164. class CAMDenseTDNNLayer(torch.nn.Module):
  165. def __init__(self,
  166. in_channels,
  167. out_channels,
  168. bn_channels,
  169. kernel_size,
  170. stride=1,
  171. dilation=1,
  172. bias=False,
  173. config_str='batchnorm-relu',
  174. memory_efficient=False):
  175. super(CAMDenseTDNNLayer, self).__init__()
  176. assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
  177. kernel_size)
  178. padding = (kernel_size - 1) // 2 * dilation
  179. self.memory_efficient = memory_efficient
  180. self.nonlinear1 = get_nonlinear(config_str, in_channels)
  181. self.linear1 = torch.nn.Conv1d(in_channels, bn_channels, 1, bias=False)
  182. self.nonlinear2 = get_nonlinear(config_str, bn_channels)
  183. self.cam_layer = CAMLayer(bn_channels,
  184. out_channels,
  185. kernel_size,
  186. stride=stride,
  187. padding=padding,
  188. dilation=dilation,
  189. bias=bias)
  190. def bn_function(self, x):
  191. return self.linear1(self.nonlinear1(x))
  192. def forward(self, x):
  193. if self.training and self.memory_efficient:
  194. x = cp.checkpoint(self.bn_function, x)
  195. else:
  196. x = self.bn_function(x)
  197. x = self.cam_layer(self.nonlinear2(x))
  198. return x
  199. class CAMDenseTDNNBlock(torch.nn.ModuleList):
  200. def __init__(self,
  201. num_layers,
  202. in_channels,
  203. out_channels,
  204. bn_channels,
  205. kernel_size,
  206. stride=1,
  207. dilation=1,
  208. bias=False,
  209. config_str='batchnorm-relu',
  210. memory_efficient=False):
  211. super(CAMDenseTDNNBlock, self).__init__()
  212. for i in range(num_layers):
  213. layer = CAMDenseTDNNLayer(in_channels=in_channels + i * out_channels,
  214. out_channels=out_channels,
  215. bn_channels=bn_channels,
  216. kernel_size=kernel_size,
  217. stride=stride,
  218. dilation=dilation,
  219. bias=bias,
  220. config_str=config_str,
  221. memory_efficient=memory_efficient)
  222. self.add_module('tdnnd%d' % (i + 1), layer)
  223. def forward(self, x):
  224. for layer in self:
  225. x = torch.cat([x, layer(x)], dim=1)
  226. return x
  227. class TransitLayer(torch.nn.Module):
  228. def __init__(self,
  229. in_channels,
  230. out_channels,
  231. bias=True,
  232. config_str='batchnorm-relu'):
  233. super(TransitLayer, self).__init__()
  234. self.nonlinear = get_nonlinear(config_str, in_channels)
  235. self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias)
  236. def forward(self, x):
  237. x = self.nonlinear(x)
  238. x = self.linear(x)
  239. return x
  240. class DenseLayer(torch.nn.Module):
  241. def __init__(self,
  242. in_channels,
  243. out_channels,
  244. bias=False,
  245. config_str='batchnorm-relu'):
  246. super(DenseLayer, self).__init__()
  247. self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias)
  248. self.nonlinear = get_nonlinear(config_str, out_channels)
  249. def forward(self, x):
  250. if len(x.shape) == 2:
  251. x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
  252. else:
  253. x = self.linear(x)
  254. x = self.nonlinear(x)
  255. return x