# eres2net.py
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
ERes2Net incorporates both local and global feature fusion techniques to improve performance.
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate the global signal.
ERes2Net-Large is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
recognition performance. The expansion, baseWidth, and scale parameters can be modified to obtain optimal performance.
"""
  10. import math
  11. import torch
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. import funasr.models.sond.pooling.pooling_layers as pooling_layers
  15. from funasr.models.eres2net.fusion import AFF
  16. class ReLU(nn.Hardtanh):
  17. def __init__(self, inplace=False):
  18. super(ReLU, self).__init__(0, 20, inplace)
  19. def __repr__(self):
  20. inplace_str = 'inplace' if self.inplace else ''
  21. return self.__class__.__name__ + ' (' \
  22. + inplace_str + ')'
  23. def conv1x1(in_planes, out_planes, stride=1):
  24. "1x1 convolution without padding"
  25. return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
  26. padding=0, bias=False)
  27. def conv3x3(in_planes, out_planes, stride=1):
  28. "3x3 convolution with padding"
  29. return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
  30. padding=1, bias=False)
  31. class BasicBlockERes2Net(nn.Module):
  32. expansion = 2
  33. def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
  34. super(BasicBlockERes2Net, self).__init__()
  35. width = int(math.floor(planes * (baseWidth / 64.0)))
  36. self.conv1 = conv1x1(in_planes, width * scale, stride)
  37. self.bn1 = nn.BatchNorm2d(width * scale)
  38. self.nums = scale
  39. convs = []
  40. bns = []
  41. for i in range(self.nums):
  42. convs.append(conv3x3(width, width))
  43. bns.append(nn.BatchNorm2d(width))
  44. self.convs = nn.ModuleList(convs)
  45. self.bns = nn.ModuleList(bns)
  46. self.relu = ReLU(inplace=True)
  47. self.conv3 = conv1x1(width * scale, planes * self.expansion)
  48. self.bn3 = nn.BatchNorm2d(planes * self.expansion)
  49. self.shortcut = nn.Sequential()
  50. if stride != 1 or in_planes != self.expansion * planes:
  51. self.shortcut = nn.Sequential(
  52. nn.Conv2d(in_planes,
  53. self.expansion * planes,
  54. kernel_size=1,
  55. stride=stride,
  56. bias=False),
  57. nn.BatchNorm2d(self.expansion * planes))
  58. self.stride = stride
  59. self.width = width
  60. self.scale = scale
  61. def forward(self, x):
  62. residual = x
  63. out = self.conv1(x)
  64. out = self.bn1(out)
  65. out = self.relu(out)
  66. spx = torch.split(out, self.width, 1)
  67. for i in range(self.nums):
  68. if i == 0:
  69. sp = spx[i]
  70. else:
  71. sp = sp + spx[i]
  72. sp = self.convs[i](sp)
  73. sp = self.relu(self.bns[i](sp))
  74. if i == 0:
  75. out = sp
  76. else:
  77. out = torch.cat((out, sp), 1)
  78. out = self.conv3(out)
  79. out = self.bn3(out)
  80. residual = self.shortcut(x)
  81. out += residual
  82. out = self.relu(out)
  83. return out
  84. class BasicBlockERes2Net_diff_AFF(nn.Module):
  85. expansion = 2
  86. def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
  87. super(BasicBlockERes2Net_diff_AFF, self).__init__()
  88. width = int(math.floor(planes * (baseWidth / 64.0)))
  89. self.conv1 = conv1x1(in_planes, width * scale, stride)
  90. self.bn1 = nn.BatchNorm2d(width * scale)
  91. self.nums = scale
  92. convs = []
  93. fuse_models = []
  94. bns = []
  95. for i in range(self.nums):
  96. convs.append(conv3x3(width, width))
  97. bns.append(nn.BatchNorm2d(width))
  98. for j in range(self.nums - 1):
  99. fuse_models.append(AFF(channels=width))
  100. self.convs = nn.ModuleList(convs)
  101. self.bns = nn.ModuleList(bns)
  102. self.fuse_models = nn.ModuleList(fuse_models)
  103. self.relu = ReLU(inplace=True)
  104. self.conv3 = conv1x1(width * scale, planes * self.expansion)
  105. self.bn3 = nn.BatchNorm2d(planes * self.expansion)
  106. self.shortcut = nn.Sequential()
  107. if stride != 1 or in_planes != self.expansion * planes:
  108. self.shortcut = nn.Sequential(
  109. nn.Conv2d(in_planes,
  110. self.expansion * planes,
  111. kernel_size=1,
  112. stride=stride,
  113. bias=False),
  114. nn.BatchNorm2d(self.expansion * planes))
  115. self.stride = stride
  116. self.width = width
  117. self.scale = scale
  118. def forward(self, x):
  119. residual = x
  120. out = self.conv1(x)
  121. out = self.bn1(out)
  122. out = self.relu(out)
  123. spx = torch.split(out, self.width, 1)
  124. for i in range(self.nums):
  125. if i == 0:
  126. sp = spx[i]
  127. else:
  128. sp = self.fuse_models[i - 1](sp, spx[i])
  129. sp = self.convs[i](sp)
  130. sp = self.relu(self.bns[i](sp))
  131. if i == 0:
  132. out = sp
  133. else:
  134. out = torch.cat((out, sp), 1)
  135. out = self.conv3(out)
  136. out = self.bn3(out)
  137. residual = self.shortcut(x)
  138. out += residual
  139. out = self.relu(out)
  140. return out
  141. class ERes2Net(nn.Module):
  142. def __init__(self,
  143. block=BasicBlockERes2Net,
  144. block_fuse=BasicBlockERes2Net_diff_AFF,
  145. num_blocks=[3, 4, 6, 3],
  146. m_channels=32,
  147. feat_dim=80,
  148. embedding_size=192,
  149. pooling_func='TSTP',
  150. two_emb_layer=False):
  151. super(ERes2Net, self).__init__()
  152. self.in_planes = m_channels
  153. self.feat_dim = feat_dim
  154. self.embedding_size = embedding_size
  155. self.stats_dim = int(feat_dim / 8) * m_channels * 8
  156. self.two_emb_layer = two_emb_layer
  157. self.conv1 = nn.Conv2d(1,
  158. m_channels,
  159. kernel_size=3,
  160. stride=1,
  161. padding=1,
  162. bias=False)
  163. self.bn1 = nn.BatchNorm2d(m_channels)
  164. self.layer1 = self._make_layer(block,
  165. m_channels,
  166. num_blocks[0],
  167. stride=1)
  168. self.layer2 = self._make_layer(block,
  169. m_channels * 2,
  170. num_blocks[1],
  171. stride=2)
  172. self.layer3 = self._make_layer(block_fuse,
  173. m_channels * 4,
  174. num_blocks[2],
  175. stride=2)
  176. self.layer4 = self._make_layer(block_fuse,
  177. m_channels * 8,
  178. num_blocks[3],
  179. stride=2)
  180. # Downsampling module for each layer
  181. self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1,
  182. bias=False)
  183. self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2,
  184. bias=False)
  185. self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2,
  186. bias=False)
  187. # Bottom-up fusion module
  188. self.fuse_mode12 = AFF(channels=m_channels * 4)
  189. self.fuse_mode123 = AFF(channels=m_channels * 8)
  190. self.fuse_mode1234 = AFF(channels=m_channels * 16)
  191. self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
  192. self.pool = getattr(pooling_layers, pooling_func)(
  193. in_dim=self.stats_dim * block.expansion)
  194. self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
  195. embedding_size)
  196. if self.two_emb_layer:
  197. self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
  198. self.seg_2 = nn.Linear(embedding_size, embedding_size)
  199. else:
  200. self.seg_bn_1 = nn.Identity()
  201. self.seg_2 = nn.Identity()
  202. def _make_layer(self, block, planes, num_blocks, stride):
  203. strides = [stride] + [1] * (num_blocks - 1)
  204. layers = []
  205. for stride in strides:
  206. layers.append(block(self.in_planes, planes, stride))
  207. self.in_planes = planes * block.expansion
  208. return nn.Sequential(*layers)
  209. def forward(self, x):
  210. x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
  211. x = x.unsqueeze_(1)
  212. out = F.relu(self.bn1(self.conv1(x)))
  213. out1 = self.layer1(out)
  214. out2 = self.layer2(out1)
  215. out1_downsample = self.layer1_downsample(out1)
  216. fuse_out12 = self.fuse_mode12(out2, out1_downsample)
  217. out3 = self.layer3(out2)
  218. fuse_out12_downsample = self.layer2_downsample(fuse_out12)
  219. fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
  220. out4 = self.layer4(out3)
  221. fuse_out123_downsample = self.layer3_downsample(fuse_out123)
  222. fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
  223. stats = self.pool(fuse_out1234)
  224. embed_a = self.seg_1(stats)
  225. if self.two_emb_layer:
  226. out = F.relu(embed_a)
  227. out = self.seg_bn_1(out)
  228. embed_b = self.seg_2(out)
  229. return embed_b
  230. else:
  231. return embed_a
  232. class BasicBlockRes2Net(nn.Module):
  233. expansion = 2
  234. def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
  235. super(BasicBlockRes2Net, self).__init__()
  236. width = int(math.floor(planes * (baseWidth / 64.0)))
  237. self.conv1 = conv1x1(in_planes, width * scale, stride)
  238. self.bn1 = nn.BatchNorm2d(width * scale)
  239. self.nums = scale - 1
  240. convs = []
  241. bns = []
  242. for i in range(self.nums):
  243. convs.append(conv3x3(width, width))
  244. bns.append(nn.BatchNorm2d(width))
  245. self.convs = nn.ModuleList(convs)
  246. self.bns = nn.ModuleList(bns)
  247. self.relu = ReLU(inplace=True)
  248. self.conv3 = conv1x1(width * scale, planes * self.expansion)
  249. self.bn3 = nn.BatchNorm2d(planes * self.expansion)
  250. self.shortcut = nn.Sequential()
  251. if stride != 1 or in_planes != self.expansion * planes:
  252. self.shortcut = nn.Sequential(
  253. nn.Conv2d(in_planes,
  254. self.expansion * planes,
  255. kernel_size=1,
  256. stride=stride,
  257. bias=False),
  258. nn.BatchNorm2d(self.expansion * planes))
  259. self.stride = stride
  260. self.width = width
  261. self.scale = scale
  262. def forward(self, x):
  263. residual = x
  264. out = self.conv1(x)
  265. out = self.bn1(out)
  266. out = self.relu(out)
  267. spx = torch.split(out, self.width, 1)
  268. for i in range(self.nums):
  269. if i == 0:
  270. sp = spx[i]
  271. else:
  272. sp = sp + spx[i]
  273. sp = self.convs[i](sp)
  274. sp = self.relu(self.bns[i](sp))
  275. if i == 0:
  276. out = sp
  277. else:
  278. out = torch.cat((out, sp), 1)
  279. out = torch.cat((out, spx[self.nums]), 1)
  280. out = self.conv3(out)
  281. out = self.bn3(out)
  282. residual = self.shortcut(x)
  283. out += residual
  284. out = self.relu(out)
  285. return out
  286. class Res2Net(nn.Module):
  287. def __init__(self,
  288. block=BasicBlockRes2Net,
  289. num_blocks=[3, 4, 6, 3],
  290. m_channels=32,
  291. feat_dim=80,
  292. embedding_size=192,
  293. pooling_func='TSTP',
  294. two_emb_layer=False):
  295. super(Res2Net, self).__init__()
  296. self.in_planes = m_channels
  297. self.feat_dim = feat_dim
  298. self.embedding_size = embedding_size
  299. self.stats_dim = int(feat_dim / 8) * m_channels * 8
  300. self.two_emb_layer = two_emb_layer
  301. self.conv1 = nn.Conv2d(1,
  302. m_channels,
  303. kernel_size=3,
  304. stride=1,
  305. padding=1,
  306. bias=False)
  307. self.bn1 = nn.BatchNorm2d(m_channels)
  308. self.layer1 = self._make_layer(block,
  309. m_channels,
  310. num_blocks[0],
  311. stride=1)
  312. self.layer2 = self._make_layer(block,
  313. m_channels * 2,
  314. num_blocks[1],
  315. stride=2)
  316. self.layer3 = self._make_layer(block,
  317. m_channels * 4,
  318. num_blocks[2],
  319. stride=2)
  320. self.layer4 = self._make_layer(block,
  321. m_channels * 8,
  322. num_blocks[3],
  323. stride=2)
  324. self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
  325. self.pool = getattr(pooling_layers, pooling_func)(
  326. in_dim=self.stats_dim * block.expansion)
  327. self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
  328. embedding_size)
  329. if self.two_emb_layer:
  330. self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
  331. self.seg_2 = nn.Linear(embedding_size, embedding_size)
  332. else:
  333. self.seg_bn_1 = nn.Identity()
  334. self.seg_2 = nn.Identity()
  335. def _make_layer(self, block, planes, num_blocks, stride):
  336. strides = [stride] + [1] * (num_blocks - 1)
  337. layers = []
  338. for stride in strides:
  339. layers.append(block(self.in_planes, planes, stride))
  340. self.in_planes = planes * block.expansion
  341. return nn.Sequential(*layers)
  342. def forward(self, x):
  343. x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
  344. x = x.unsqueeze_(1)
  345. out = F.relu(self.bn1(self.conv1(x)))
  346. out = self.layer1(out)
  347. out = self.layer2(out)
  348. out = self.layer3(out)
  349. out = self.layer4(out)
  350. stats = self.pool(out)
  351. embed_a = self.seg_1(stats)
  352. if self.two_emb_layer:
  353. out = F.relu(embed_a)
  354. out = self.seg_bn_1(out)
  355. embed_b = self.seg_2(out)
  356. return embed_b
  357. else:
  358. return embed_a