ResNet.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
  2. # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. """ Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
  4. ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
  5. The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
  6. The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
  7. ERes2Net-Large is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
  8. recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
  9. """
  10. import torch
  11. import math
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. import funasr.models.whisper_lid.eres2net.pooling_layers as pooling_layers
  15. from funasr.models.whisper_lid.eres2net.fusion import AFF
  16. class ReLU(nn.Hardtanh):
  17. def __init__(self, inplace=False):
  18. super(ReLU, self).__init__(0, 20, inplace)
  19. def __repr__(self):
  20. inplace_str = 'inplace' if self.inplace else ''
  21. return self.__class__.__name__ + ' (' \
  22. + inplace_str + ')'
  23. def conv1x1(in_planes, out_planes, stride=1):
  24. "1x1 convolution without padding"
  25. return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
  26. padding=0, bias=False)
  27. def conv3x3(in_planes, out_planes, stride=1):
  28. "3x3 convolution with padding"
  29. return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
  30. padding=1, bias=False)
  31. class BasicBlockERes2Net(nn.Module):
  32. expansion = 2
  33. def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
  34. super(BasicBlockERes2Net, self).__init__()
  35. width = int(math.floor(planes * (baseWidth / 64.0)))
  36. self.conv1 = conv1x1(in_planes, width * scale, stride)
  37. self.bn1 = nn.BatchNorm2d(width * scale)
  38. self.nums = scale
  39. convs = []
  40. bns = []
  41. for i in range(self.nums):
  42. convs.append(conv3x3(width, width))
  43. bns.append(nn.BatchNorm2d(width))
  44. self.convs = nn.ModuleList(convs)
  45. self.bns = nn.ModuleList(bns)
  46. self.relu = ReLU(inplace=True)
  47. self.conv3 = conv1x1(width * scale, planes * self.expansion)
  48. self.bn3 = nn.BatchNorm2d(planes * self.expansion)
  49. self.shortcut = nn.Sequential()
  50. if stride != 1 or in_planes != self.expansion * planes:
  51. self.shortcut = nn.Sequential(
  52. nn.Conv2d(in_planes,
  53. self.expansion * planes,
  54. kernel_size=1,
  55. stride=stride,
  56. bias=False),
  57. nn.BatchNorm2d(self.expansion * planes))
  58. self.stride = stride
  59. self.width = width
  60. self.scale = scale
  61. def forward(self, x):
  62. residual = x
  63. out = self.conv1(x)
  64. out = self.bn1(out)
  65. out = self.relu(out)
  66. spx = torch.split(out, self.width, 1)
  67. for i in range(self.nums):
  68. if i == 0:
  69. sp = spx[i]
  70. else:
  71. sp = sp + spx[i]
  72. sp = self.convs[i](sp)
  73. sp = self.relu(self.bns[i](sp))
  74. if i == 0:
  75. out = sp
  76. else:
  77. out = torch.cat((out, sp), 1)
  78. out = self.conv3(out)
  79. out = self.bn3(out)
  80. residual = self.shortcut(x)
  81. out += residual
  82. out = self.relu(out)
  83. return out
  84. class BasicBlockERes2Net_diff_AFF(nn.Module):
  85. expansion = 2
  86. def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
  87. super(BasicBlockERes2Net_diff_AFF, self).__init__()
  88. width = int(math.floor(planes * (baseWidth / 64.0)))
  89. self.conv1 = conv1x1(in_planes, width * scale, stride)
  90. self.bn1 = nn.BatchNorm2d(width * scale)
  91. self.nums = scale
  92. convs = []
  93. fuse_models = []
  94. bns = []
  95. for i in range(self.nums):
  96. convs.append(conv3x3(width, width))
  97. bns.append(nn.BatchNorm2d(width))
  98. for j in range(self.nums - 1):
  99. fuse_models.append(AFF(channels=width))
  100. self.convs = nn.ModuleList(convs)
  101. self.bns = nn.ModuleList(bns)
  102. self.fuse_models = nn.ModuleList(fuse_models)
  103. self.relu = ReLU(inplace=True)
  104. self.conv3 = conv1x1(width * scale, planes * self.expansion)
  105. self.bn3 = nn.BatchNorm2d(planes * self.expansion)
  106. self.shortcut = nn.Sequential()
  107. if stride != 1 or in_planes != self.expansion * planes:
  108. self.shortcut = nn.Sequential(
  109. nn.Conv2d(in_planes,
  110. self.expansion * planes,
  111. kernel_size=1,
  112. stride=stride,
  113. bias=False),
  114. nn.BatchNorm2d(self.expansion * planes))
  115. self.stride = stride
  116. self.width = width
  117. self.scale = scale
  118. def forward(self, x):
  119. residual = x
  120. out = self.conv1(x)
  121. out = self.bn1(out)
  122. out = self.relu(out)
  123. spx = torch.split(out, self.width, 1)
  124. for i in range(self.nums):
  125. if i == 0:
  126. sp = spx[i]
  127. else:
  128. sp = self.fuse_models[i - 1](sp, spx[i])
  129. sp = self.convs[i](sp)
  130. sp = self.relu(self.bns[i](sp))
  131. if i == 0:
  132. out = sp
  133. else:
  134. out = torch.cat((out, sp), 1)
  135. out = self.conv3(out)
  136. out = self.bn3(out)
  137. residual = self.shortcut(x)
  138. out += residual
  139. out = self.relu(out)
  140. return out
class ERes2Net(nn.Module):
    """ERes2Net speaker-embedding backbone.

    Four Res2Net stages; the first two use plain ``block`` and the last two
    use the AFF-fusion ``block_fuse``. Global feature fusion (GFF): each
    stage's output is downsampled and fused bottom-up with the next stage's
    output via AFF modules, and the final fused map is pooled into a
    fixed-size utterance embedding.

    Args:
        block: residual block class for stages 1-2.
        block_fuse: residual block class for stages 3-4 (with local fusion).
        num_blocks: number of residual blocks per stage.
        m_channels: base channel count; doubled at each stage.
        feat_dim: input feature (e.g. fbank) dimension F.
        embedding_size: output embedding dimension.
        pooling_func: name of a class in ``pooling_layers`` used for
            temporal statistics pooling.
        two_emb_layer: if True, apply a second BN + Linear embedding head.
    """

    def __init__(self,
                 block=BasicBlockERes2Net,
                 block_fuse=BasicBlockERes2Net_diff_AFF,
                 num_blocks=[3, 4, 6, 3],
                 m_channels=32,
                 feat_dim=80,
                 embedding_size=192,
                 pooling_func='TSTP',
                 two_emb_layer=False):
        super(ERes2Net, self).__init__()
        self.in_planes = m_channels
        self.feat_dim = feat_dim
        self.embedding_size = embedding_size
        # Frequency axis is halved 3 times (stride-2 stages 2-4) while
        # channels grow 8x, so the flattened per-frame dim is F/8 * 8*m.
        self.stats_dim = int(feat_dim / 8) * m_channels * 8
        self.two_emb_layer = two_emb_layer
        self._output_size = embedding_size
        # Stem: single-channel (spectrogram) input -> m_channels feature maps.
        self.conv1 = nn.Conv2d(1,
                               m_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(block,
                                       m_channels,
                                       num_blocks[0],
                                       stride=1)
        self.layer2 = self._make_layer(block,
                                       m_channels * 2,
                                       num_blocks[1],
                                       stride=2)
        self.layer3 = self._make_layer(block_fuse,
                                       m_channels * 4,
                                       num_blocks[2],
                                       stride=2)
        self.layer4 = self._make_layer(block_fuse,
                                       m_channels * 8,
                                       num_blocks[3],
                                       stride=2)
        # Downsampling module for each layer: matches a stage's output to the
        # next stage's spatial size and channel count before fusion.
        self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1,
                                           bias=False)
        self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2,
                                           bias=False)
        self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2,
                                           bias=False)
        # Bottom-up fusion module (global feature fusion across stages).
        self.fuse_mode12 = AFF(channels=m_channels * 4)
        self.fuse_mode123 = AFF(channels=m_channels * 8)
        self.fuse_mode1234 = AFF(channels=m_channels * 16)
        # TAP/TSDP pool a single statistic; other poolings (e.g. TSTP)
        # concatenate mean and std, doubling the pooled dimension.
        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
        self.pool = getattr(pooling_layers, pooling_func)(
            in_dim=self.stats_dim * block.expansion)
        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
                               embedding_size)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
            self.seg_2 = nn.Linear(embedding_size, embedding_size)
        else:
            # Single-head variant: pass-through placeholders.
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first may downsample."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            # Blocks expand channels, so track the running input width.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def output_size(self) -> int:
        """Return the embedding dimension produced by forward()."""
        return self._output_size

    def forward(self, x, ilens):
        """Compute embeddings.

        Args:
            x: (B, T, F) acoustic features.
            ilens: per-utterance valid frame counts (length T axis).

        Returns:
            (B, embedding_size) speaker embeddings.
        """
        # assert x.shape[1] == ilens.max()
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)  # add channel dim: (B,1,F,T)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        # Fuse stage-1 output (downsampled) into stage-2 output, and so on
        # bottom-up through the deeper stages.
        out1_downsample = self.layer1_downsample(out1)
        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
        out3 = self.layer3(out2)
        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
        out4 = self.layer4(out3)
        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
        # Valid time lengths after three stride-2 stages: ceil-halve 3 times.
        olens = (((((ilens - 1) // 2 + 1) - 1) // 2 + 1) - 1) // 2 + 1
        stats = self.pool(fuse_out1234, olens)
        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_b
        else:
            return embed_a
  237. class BasicBlockRes2Net(nn.Module):
  238. expansion = 2
  239. def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
  240. super(BasicBlockRes2Net, self).__init__()
  241. width = int(math.floor(planes * (baseWidth / 64.0)))
  242. self.conv1 = conv1x1(in_planes, width * scale, stride)
  243. self.bn1 = nn.BatchNorm2d(width * scale)
  244. self.nums = scale - 1
  245. convs = []
  246. bns = []
  247. for i in range(self.nums):
  248. convs.append(conv3x3(width, width))
  249. bns.append(nn.BatchNorm2d(width))
  250. self.convs = nn.ModuleList(convs)
  251. self.bns = nn.ModuleList(bns)
  252. self.relu = ReLU(inplace=True)
  253. self.conv3 = conv1x1(width * scale, planes * self.expansion)
  254. self.bn3 = nn.BatchNorm2d(planes * self.expansion)
  255. self.shortcut = nn.Sequential()
  256. if stride != 1 or in_planes != self.expansion * planes:
  257. self.shortcut = nn.Sequential(
  258. nn.Conv2d(in_planes,
  259. self.expansion * planes,
  260. kernel_size=1,
  261. stride=stride,
  262. bias=False),
  263. nn.BatchNorm2d(self.expansion * planes))
  264. self.stride = stride
  265. self.width = width
  266. self.scale = scale
  267. def forward(self, x):
  268. residual = x
  269. out = self.conv1(x)
  270. out = self.bn1(out)
  271. out = self.relu(out)
  272. spx = torch.split(out, self.width, 1)
  273. for i in range(self.nums):
  274. if i == 0:
  275. sp = spx[i]
  276. else:
  277. sp = sp + spx[i]
  278. sp = self.convs[i](sp)
  279. sp = self.relu(self.bns[i](sp))
  280. if i == 0:
  281. out = sp
  282. else:
  283. out = torch.cat((out, sp), 1)
  284. out = torch.cat((out, spx[self.nums]), 1)
  285. out = self.conv3(out)
  286. out = self.bn3(out)
  287. residual = self.shortcut(x)
  288. out += residual
  289. out = self.relu(out)
  290. return out
  291. class Res2Net(nn.Module):
  292. def __init__(self,
  293. block=BasicBlockRes2Net,
  294. num_blocks=[3, 4, 6, 3],
  295. m_channels=32,
  296. feat_dim=80,
  297. embedding_size=192,
  298. pooling_func='TSTP',
  299. two_emb_layer=False):
  300. super(Res2Net, self).__init__()
  301. self.in_planes = m_channels
  302. self.feat_dim = feat_dim
  303. self.embedding_size = embedding_size
  304. self.stats_dim = int(feat_dim / 8) * m_channels * 8
  305. self.two_emb_layer = two_emb_layer
  306. self.conv1 = nn.Conv2d(1,
  307. m_channels,
  308. kernel_size=3,
  309. stride=1,
  310. padding=1,
  311. bias=False)
  312. self.bn1 = nn.BatchNorm2d(m_channels)
  313. self.layer1 = self._make_layer(block,
  314. m_channels,
  315. num_blocks[0],
  316. stride=1)
  317. self.layer2 = self._make_layer(block,
  318. m_channels * 2,
  319. num_blocks[1],
  320. stride=2)
  321. self.layer3 = self._make_layer(block,
  322. m_channels * 4,
  323. num_blocks[2],
  324. stride=2)
  325. self.layer4 = self._make_layer(block,
  326. m_channels * 8,
  327. num_blocks[3],
  328. stride=2)
  329. self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
  330. self.pool = getattr(pooling_layers, pooling_func)(
  331. in_dim=self.stats_dim * block.expansion)
  332. self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
  333. embedding_size)
  334. if self.two_emb_layer:
  335. self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
  336. self.seg_2 = nn.Linear(embedding_size, embedding_size)
  337. else:
  338. self.seg_bn_1 = nn.Identity()
  339. self.seg_2 = nn.Identity()
  340. def _make_layer(self, block, planes, num_blocks, stride):
  341. strides = [stride] + [1] * (num_blocks - 1)
  342. layers = []
  343. for stride in strides:
  344. layers.append(block(self.in_planes, planes, stride))
  345. self.in_planes = planes * block.expansion
  346. return nn.Sequential(*layers)
  347. def forward(self, x):
  348. x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
  349. x = x.unsqueeze_(1)
  350. out = F.relu(self.bn1(self.conv1(x)))
  351. out = self.layer1(out)
  352. out = self.layer2(out)
  353. out = self.layer3(out)
  354. out = self.layer4(out)
  355. stats = self.pool(out)
  356. embed_a = self.seg_1(stats)
  357. if self.two_emb_layer:
  358. out = F.relu(embed_a)
  359. out = self.seg_bn_1(out)
  360. embed_b = self.seg_2(out)
  361. return embed_b
  362. else:
  363. return embed_a