resnet34_encoder.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853
  1. import torch
  2. from torch.nn import functional as F
  3. from funasr.models.encoder.abs_encoder import AbsEncoder
  4. from typing import Tuple, Optional
  5. from funasr.models.pooling.statistic_pooling import statistic_pooling, windowed_statistic_pooling
  6. from collections import OrderedDict
  7. import logging
  8. import numpy as np
  9. class BasicLayer(torch.nn.Module):
  10. def __init__(self, in_filters: int, filters: int, stride: int, bn_momentum: float = 0.5):
  11. super().__init__()
  12. self.stride = stride
  13. self.in_filters = in_filters
  14. self.filters = filters
  15. self.bn1 = torch.nn.BatchNorm2d(in_filters, eps=1e-3, momentum=bn_momentum, affine=True)
  16. self.relu1 = torch.nn.ReLU()
  17. self.conv1 = torch.nn.Conv2d(in_filters, filters, 3, stride, bias=False)
  18. self.bn2 = torch.nn.BatchNorm2d(filters, eps=1e-3, momentum=bn_momentum, affine=True)
  19. self.relu2 = torch.nn.ReLU()
  20. self.conv2 = torch.nn.Conv2d(filters, filters, 3, 1, bias=False)
  21. if in_filters != filters or stride > 1:
  22. self.conv_sc = torch.nn.Conv2d(in_filters, filters, 1, stride, bias=False)
  23. self.bn_sc = torch.nn.BatchNorm2d(filters, eps=1e-3, momentum=bn_momentum, affine=True)
  24. def proper_padding(self, x, stride):
  25. # align padding mode to tf.layers.conv2d with padding_mod="same"
  26. if stride == 1:
  27. return F.pad(x, (1, 1, 1, 1), "constant", 0)
  28. elif stride == 2:
  29. h, w = x.size(2), x.size(3)
  30. # (left, right, top, bottom)
  31. return F.pad(x, (w % 2, 1, h % 2, 1), "constant", 0)
  32. def forward(self, xs_pad, ilens):
  33. identity = xs_pad
  34. if self.in_filters != self.filters or self.stride > 1:
  35. identity = self.conv_sc(identity)
  36. identity = self.bn_sc(identity)
  37. xs_pad = self.relu1(self.bn1(xs_pad))
  38. xs_pad = self.proper_padding(xs_pad, self.stride)
  39. xs_pad = self.conv1(xs_pad)
  40. xs_pad = self.relu2(self.bn2(xs_pad))
  41. xs_pad = self.proper_padding(xs_pad, 1)
  42. xs_pad = self.conv2(xs_pad)
  43. if self.stride == 2:
  44. ilens = (ilens + 1) // self.stride
  45. return xs_pad + identity, ilens
  46. class BasicBlock(torch.nn.Module):
  47. def __init__(self, in_filters, filters, num_layer, stride, bn_momentum=0.5):
  48. super().__init__()
  49. self.num_layer = num_layer
  50. for i in range(num_layer):
  51. layer = BasicLayer(in_filters if i == 0 else filters, filters,
  52. stride if i == 0 else 1, bn_momentum)
  53. self.add_module("layer_{}".format(i), layer)
  54. def forward(self, xs_pad, ilens):
  55. for i in range(self.num_layer):
  56. xs_pad, ilens = self._modules["layer_{}".format(i)](xs_pad, ilens)
  57. return xs_pad, ilens
  58. class ResNet34(AbsEncoder):
  59. def __init__(
  60. self,
  61. input_size,
  62. use_head_conv=True,
  63. batchnorm_momentum=0.5,
  64. use_head_maxpool=False,
  65. num_nodes_pooling_layer=256,
  66. layers_in_block=(3, 4, 6, 3),
  67. filters_in_block=(32, 64, 128, 256),
  68. ):
  69. super(ResNet34, self).__init__()
  70. self.use_head_conv = use_head_conv
  71. self.use_head_maxpool = use_head_maxpool
  72. self.num_nodes_pooling_layer = num_nodes_pooling_layer
  73. self.layers_in_block = layers_in_block
  74. self.filters_in_block = filters_in_block
  75. self.input_size = input_size
  76. pre_filters = filters_in_block[0]
  77. if use_head_conv:
  78. self.pre_conv = torch.nn.Conv2d(1, pre_filters, 3, 1, 1, bias=False, padding_mode="zeros")
  79. self.pre_conv_bn = torch.nn.BatchNorm2d(pre_filters, eps=1e-3, momentum=batchnorm_momentum)
  80. if use_head_maxpool:
  81. self.head_maxpool = torch.nn.MaxPool2d(3, 1, padding=1)
  82. for i in range(len(layers_in_block)):
  83. if i == 0:
  84. in_filters = pre_filters if self.use_head_conv else 1
  85. else:
  86. in_filters = filters_in_block[i-1]
  87. block = BasicBlock(in_filters,
  88. filters=filters_in_block[i],
  89. num_layer=layers_in_block[i],
  90. stride=1 if i == 0 else 2,
  91. bn_momentum=batchnorm_momentum)
  92. self.add_module("block_{}".format(i), block)
  93. self.resnet0_dense = torch.nn.Conv2d(filters_in_block[-1], num_nodes_pooling_layer, 1)
  94. self.resnet0_bn = torch.nn.BatchNorm2d(num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum)
  95. self.time_ds_ratio = 8
  96. def output_size(self) -> int:
  97. return self.num_nodes_pooling_layer
  98. def forward(
  99. self,
  100. xs_pad: torch.Tensor,
  101. ilens: torch.Tensor,
  102. prev_states: torch.Tensor = None
  103. ) -> Tuple[torch.Tensor, torch.Tensor]:
  104. features = xs_pad
  105. assert features.size(-1) == self.input_size, \
  106. "Dimension of features {} doesn't match the input_size {}.".format(features.size(-1), self.input_size)
  107. features = torch.unsqueeze(features, dim=1)
  108. if self.use_head_conv:
  109. features = self.pre_conv(features)
  110. features = self.pre_conv_bn(features)
  111. features = F.relu(features)
  112. if self.use_head_maxpool:
  113. features = self.head_maxpool(features)
  114. resnet_outs, resnet_out_lens = features, ilens
  115. for i in range(len(self.layers_in_block)):
  116. block = self._modules["block_{}".format(i)]
  117. resnet_outs, resnet_out_lens = block(resnet_outs, resnet_out_lens)
  118. features = self.resnet0_dense(resnet_outs)
  119. features = F.relu(features)
  120. features = self.resnet0_bn(features)
  121. return features, resnet_out_lens
# Note: For training, this implementation is not equivalent to tf because of the kernel_regularizer in tf.layers.
# TODO: implement kernel_regularizer in torch with manual loss addition or weight_decay in the optimizer
  124. class ResNet34_SP_L2Reg(AbsEncoder):
  125. def __init__(
  126. self,
  127. input_size,
  128. use_head_conv=True,
  129. batchnorm_momentum=0.5,
  130. use_head_maxpool=False,
  131. num_nodes_pooling_layer=256,
  132. layers_in_block=(3, 4, 6, 3),
  133. filters_in_block=(32, 64, 128, 256),
  134. tf2torch_tensor_name_prefix_torch="encoder",
  135. tf2torch_tensor_name_prefix_tf="EAND/speech_encoder",
  136. tf_train_steps=720000,
  137. ):
  138. super(ResNet34_SP_L2Reg, self).__init__()
  139. self.use_head_conv = use_head_conv
  140. self.use_head_maxpool = use_head_maxpool
  141. self.num_nodes_pooling_layer = num_nodes_pooling_layer
  142. self.layers_in_block = layers_in_block
  143. self.filters_in_block = filters_in_block
  144. self.input_size = input_size
  145. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  146. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  147. self.tf_train_steps = tf_train_steps
  148. pre_filters = filters_in_block[0]
  149. if use_head_conv:
  150. self.pre_conv = torch.nn.Conv2d(1, pre_filters, 3, 1, 1, bias=False, padding_mode="zeros")
  151. self.pre_conv_bn = torch.nn.BatchNorm2d(pre_filters, eps=1e-3, momentum=batchnorm_momentum)
  152. if use_head_maxpool:
  153. self.head_maxpool = torch.nn.MaxPool2d(3, 1, padding=1)
  154. for i in range(len(layers_in_block)):
  155. if i == 0:
  156. in_filters = pre_filters if self.use_head_conv else 1
  157. else:
  158. in_filters = filters_in_block[i-1]
  159. block = BasicBlock(in_filters,
  160. filters=filters_in_block[i],
  161. num_layer=layers_in_block[i],
  162. stride=1 if i == 0 else 2,
  163. bn_momentum=batchnorm_momentum)
  164. self.add_module("block_{}".format(i), block)
  165. self.resnet0_dense = torch.nn.Conv1d(filters_in_block[-1] * input_size // 8, num_nodes_pooling_layer, 1)
  166. self.resnet0_bn = torch.nn.BatchNorm1d(num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum)
  167. self.time_ds_ratio = 8
  168. def output_size(self) -> int:
  169. return self.num_nodes_pooling_layer
  170. def forward(
  171. self,
  172. xs_pad: torch.Tensor,
  173. ilens: torch.Tensor,
  174. prev_states: torch.Tensor = None
  175. ) -> Tuple[torch.Tensor, torch.Tensor]:
  176. features = xs_pad
  177. assert features.size(-1) == self.input_size, \
  178. "Dimension of features {} doesn't match the input_size {}.".format(features.size(-1), self.input_size)
  179. features = torch.unsqueeze(features, dim=1)
  180. if self.use_head_conv:
  181. features = self.pre_conv(features)
  182. features = self.pre_conv_bn(features)
  183. features = F.relu(features)
  184. if self.use_head_maxpool:
  185. features = self.head_maxpool(features)
  186. resnet_outs, resnet_out_lens = features, ilens
  187. for i in range(len(self.layers_in_block)):
  188. block = self._modules["block_{}".format(i)]
  189. resnet_outs, resnet_out_lens = block(resnet_outs, resnet_out_lens)
  190. # B, C, T, F
  191. bb, cc, tt, ff = resnet_outs.shape
  192. resnet_outs = torch.reshape(resnet_outs.permute(0, 3, 1, 2), [bb, ff*cc, tt])
  193. features = self.resnet0_dense(resnet_outs)
  194. features = F.relu(features)
  195. features = self.resnet0_bn(features)
  196. return features, resnet_out_lens
  197. def gen_tf2torch_map_dict(self):
  198. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  199. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  200. train_steps = self.tf_train_steps
  201. map_dict_local = {
  202. # torch: conv1d.weight in "out_channel in_channel kernel_size"
  203. # tf : conv1d.weight in "kernel_size in_channel out_channel"
  204. # torch: linear.weight in "out_channel in_channel"
  205. # tf : dense.weight in "in_channel out_channel"
  206. "{}.pre_conv.weight".format(tensor_name_prefix_torch):
  207. {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
  208. "squeeze": None,
  209. "transpose": (3, 2, 0, 1),
  210. },
  211. "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
  212. {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
  213. "squeeze": None,
  214. "transpose": None,
  215. },
  216. "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
  217. {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
  218. "squeeze": None,
  219. "transpose": None,
  220. },
  221. "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
  222. {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
  223. "squeeze": None,
  224. "transpose": None,
  225. },
  226. "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
  227. {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
  228. "squeeze": None,
  229. "transpose": None,
  230. },
  231. "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
  232. }
  233. for layer_idx in range(3):
  234. map_dict_local.update({
  235. "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
  236. {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
  237. "squeeze": None,
  238. "transpose": (2, 1, 0) if layer_idx == 0 else (1, 0),
  239. },
  240. "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
  241. {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
  242. "squeeze": None,
  243. "transpose": None,
  244. },
  245. "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
  246. {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
  247. "squeeze": None,
  248. "transpose": None,
  249. },
  250. "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
  251. {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
  252. "squeeze": None,
  253. "transpose": None,
  254. },
  255. "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
  256. {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
  257. "squeeze": None,
  258. "transpose": None,
  259. },
  260. "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
  261. {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
  262. "squeeze": None,
  263. "transpose": None,
  264. },
  265. "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
  266. })
  267. for block_idx in range(len(self.layers_in_block)):
  268. for layer_idx in range(self.layers_in_block[block_idx]):
  269. for i in ["1", "2", "_sc"]:
  270. map_dict_local.update({
  271. "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
  272. {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
  273. "squeeze": None,
  274. "transpose": (3, 2, 0, 1),
  275. },
  276. "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
  277. {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
  278. "squeeze": None,
  279. "transpose": None,
  280. },
  281. "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
  282. {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
  283. "squeeze": None,
  284. "transpose": None,
  285. },
  286. "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
  287. {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
  288. "squeeze": None,
  289. "transpose": None,
  290. },
  291. "{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
  292. {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
  293. "squeeze": None,
  294. "transpose": None,
  295. },
  296. "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
  297. })
  298. return map_dict_local
  299. def convert_tf2torch(self,
  300. var_dict_tf,
  301. var_dict_torch,
  302. ):
  303. map_dict = self.gen_tf2torch_map_dict()
  304. var_dict_torch_update = dict()
  305. for name in sorted(var_dict_torch.keys(), reverse=False):
  306. if name.startswith(self.tf2torch_tensor_name_prefix_torch):
  307. if name in map_dict:
  308. if "num_batches_tracked" not in name:
  309. name_tf = map_dict[name]["name"]
  310. data_tf = var_dict_tf[name_tf]
  311. if map_dict[name]["squeeze"] is not None:
  312. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
  313. if map_dict[name]["transpose"] is not None:
  314. data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
  315. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  316. assert var_dict_torch[name].size() == data_tf.size(), \
  317. "{}, {}, {} != {}".format(name, name_tf,
  318. var_dict_torch[name].size(), data_tf.size())
  319. var_dict_torch_update[name] = data_tf
  320. logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
  321. name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
  322. ))
  323. else:
  324. var_dict_torch_update[name] = torch.Tensor(map_dict[name]).type(torch.int64).to("cpu")
  325. logging.info("torch tensor: {}, manually assigning to: {}".format(
  326. name, map_dict[name]
  327. ))
  328. else:
  329. logging.warning("{} is missed from tf checkpoint".format(name))
  330. return var_dict_torch_update
  331. class ResNet34Diar(ResNet34):
  332. def __init__(
  333. self,
  334. input_size,
  335. embedding_node="resnet1_dense",
  336. use_head_conv=True,
  337. batchnorm_momentum=0.5,
  338. use_head_maxpool=False,
  339. num_nodes_pooling_layer=256,
  340. layers_in_block=(3, 4, 6, 3),
  341. filters_in_block=(32, 64, 128, 256),
  342. num_nodes_resnet1=256,
  343. num_nodes_last_layer=256,
  344. pooling_type="window_shift",
  345. pool_size=20,
  346. stride=1,
  347. tf2torch_tensor_name_prefix_torch="encoder",
  348. tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
  349. ):
  350. """
  351. Author: Speech Lab, Alibaba Group, China
  352. SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
  353. https://arxiv.org/abs/2211.10243
  354. """
  355. super(ResNet34Diar, self).__init__(
  356. input_size,
  357. use_head_conv=use_head_conv,
  358. batchnorm_momentum=batchnorm_momentum,
  359. use_head_maxpool=use_head_maxpool,
  360. num_nodes_pooling_layer=num_nodes_pooling_layer,
  361. layers_in_block=layers_in_block,
  362. filters_in_block=filters_in_block,
  363. )
  364. self.embedding_node = embedding_node
  365. self.num_nodes_resnet1 = num_nodes_resnet1
  366. self.num_nodes_last_layer = num_nodes_last_layer
  367. self.pooling_type = pooling_type
  368. self.pool_size = pool_size
  369. self.stride = stride
  370. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  371. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  372. self.resnet1_dense = torch.nn.Linear(num_nodes_pooling_layer * 2, num_nodes_resnet1)
  373. self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum)
  374. self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
  375. self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum)
  376. def output_size(self) -> int:
  377. if self.embedding_node.startswith("resnet1"):
  378. return self.num_nodes_resnet1
  379. elif self.embedding_node.startswith("resnet2"):
  380. return self.num_nodes_last_layer
  381. return self.num_nodes_pooling_layer
  382. def forward(
  383. self,
  384. xs_pad: torch.Tensor,
  385. ilens: torch.Tensor,
  386. prev_states: torch.Tensor = None,
  387. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  388. endpoints = OrderedDict()
  389. res_out, ilens = super().forward(xs_pad, ilens)
  390. endpoints["resnet0_bn"] = res_out
  391. if self.pooling_type == "frame_gsp":
  392. features = statistic_pooling(res_out, ilens, (3, ))
  393. else:
  394. features, ilens = windowed_statistic_pooling(res_out, ilens, (2, 3), self.pool_size, self.stride)
  395. features = features.transpose(1, 2)
  396. endpoints["pooling"] = features
  397. features = self.resnet1_dense(features)
  398. endpoints["resnet1_dense"] = features
  399. features = F.relu(features)
  400. endpoints["resnet1_relu"] = features
  401. features = self.resnet1_bn(features.transpose(1, 2)).transpose(1, 2)
  402. endpoints["resnet1_bn"] = features
  403. features = self.resnet2_dense(features)
  404. endpoints["resnet2_dense"] = features
  405. features = F.relu(features)
  406. endpoints["resnet2_relu"] = features
  407. features = self.resnet2_bn(features.transpose(1, 2)).transpose(1, 2)
  408. endpoints["resnet2_bn"] = features
  409. return endpoints[self.embedding_node], ilens, None
  410. def gen_tf2torch_map_dict(self):
  411. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  412. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  413. train_steps = 300000
  414. map_dict_local = {
  415. # torch: conv1d.weight in "out_channel in_channel kernel_size"
  416. # tf : conv1d.weight in "kernel_size in_channel out_channel"
  417. # torch: linear.weight in "out_channel in_channel"
  418. # tf : dense.weight in "in_channel out_channel"
  419. "{}.pre_conv.weight".format(tensor_name_prefix_torch):
  420. {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
  421. "squeeze": None,
  422. "transpose": (3, 2, 0, 1),
  423. },
  424. "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
  425. {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
  426. "squeeze": None,
  427. "transpose": None,
  428. },
  429. "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
  430. {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
  431. "squeeze": None,
  432. "transpose": None,
  433. },
  434. "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
  435. {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
  436. "squeeze": None,
  437. "transpose": None,
  438. },
  439. "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
  440. {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
  441. "squeeze": None,
  442. "transpose": None,
  443. },
  444. "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
  445. }
  446. for layer_idx in range(3):
  447. map_dict_local.update({
  448. "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
  449. {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
  450. "squeeze": None,
  451. "transpose": (3, 2, 0, 1) if layer_idx == 0 else (1, 0),
  452. },
  453. "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
  454. {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
  455. "squeeze": None,
  456. "transpose": None,
  457. },
  458. "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
  459. {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
  460. "squeeze": None,
  461. "transpose": None,
  462. },
  463. "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
  464. {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
  465. "squeeze": None,
  466. "transpose": None,
  467. },
  468. "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
  469. {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
  470. "squeeze": None,
  471. "transpose": None,
  472. },
  473. "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
  474. {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
  475. "squeeze": None,
  476. "transpose": None,
  477. },
  478. "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
  479. })
  480. for block_idx in range(len(self.layers_in_block)):
  481. for layer_idx in range(self.layers_in_block[block_idx]):
  482. for i in ["1", "2", "_sc"]:
  483. map_dict_local.update({
  484. "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
  485. {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
  486. "squeeze": None,
  487. "transpose": (3, 2, 0, 1),
  488. },
  489. "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
  490. {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
  491. "squeeze": None,
  492. "transpose": None,
  493. },
  494. "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
  495. {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
  496. "squeeze": None,
  497. "transpose": None,
  498. },
  499. "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
  500. {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
  501. "squeeze": None,
  502. "transpose": None,
  503. },
  504. "{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
  505. {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
  506. "squeeze": None,
  507. "transpose": None,
  508. },
  509. "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
  510. })
  511. return map_dict_local
  512. def convert_tf2torch(self,
  513. var_dict_tf,
  514. var_dict_torch,
  515. ):
  516. map_dict = self.gen_tf2torch_map_dict()
  517. var_dict_torch_update = dict()
  518. for name in sorted(var_dict_torch.keys(), reverse=False):
  519. if name.startswith(self.tf2torch_tensor_name_prefix_torch):
  520. if name in map_dict:
  521. if "num_batches_tracked" not in name:
  522. name_tf = map_dict[name]["name"]
  523. data_tf = var_dict_tf[name_tf]
  524. if map_dict[name]["squeeze"] is not None:
  525. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
  526. if map_dict[name]["transpose"] is not None:
  527. data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
  528. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  529. assert var_dict_torch[name].size() == data_tf.size(), \
  530. "{}, {}, {} != {}".format(name, name_tf,
  531. var_dict_torch[name].size(), data_tf.size())
  532. var_dict_torch_update[name] = data_tf
  533. logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
  534. name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
  535. ))
  536. else:
  537. var_dict_torch_update[name] = torch.Tensor(map_dict[name]).type(torch.int64).to("cpu")
  538. logging.info("torch tensor: {}, manually assigning to: {}".format(
  539. name, map_dict[name]
  540. ))
  541. else:
  542. logging.warning("{} is missed from tf checkpoint".format(name))
  543. return var_dict_torch_update
  544. class ResNet34SpL2RegDiar(ResNet34_SP_L2Reg):
  545. def __init__(
  546. self,
  547. input_size,
  548. embedding_node="resnet1_dense",
  549. use_head_conv=True,
  550. batchnorm_momentum=0.5,
  551. use_head_maxpool=False,
  552. num_nodes_pooling_layer=256,
  553. layers_in_block=(3, 4, 6, 3),
  554. filters_in_block=(32, 64, 128, 256),
  555. num_nodes_resnet1=256,
  556. num_nodes_last_layer=256,
  557. pooling_type="window_shift",
  558. pool_size=20,
  559. stride=1,
  560. tf2torch_tensor_name_prefix_torch="encoder",
  561. tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
  562. ):
  563. """
  564. Author: Speech Lab, Alibaba Group, China
  565. TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
  566. https://arxiv.org/abs/2303.05397
  567. """
  568. super(ResNet34SpL2RegDiar, self).__init__(
  569. input_size,
  570. use_head_conv=use_head_conv,
  571. batchnorm_momentum=batchnorm_momentum,
  572. use_head_maxpool=use_head_maxpool,
  573. num_nodes_pooling_layer=num_nodes_pooling_layer,
  574. layers_in_block=layers_in_block,
  575. filters_in_block=filters_in_block,
  576. )
  577. self.embedding_node = embedding_node
  578. self.num_nodes_resnet1 = num_nodes_resnet1
  579. self.num_nodes_last_layer = num_nodes_last_layer
  580. self.pooling_type = pooling_type
  581. self.pool_size = pool_size
  582. self.stride = stride
  583. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  584. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  585. self.resnet1_dense = torch.nn.Linear(num_nodes_pooling_layer * 2, num_nodes_resnet1)
  586. self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum)
  587. self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
  588. self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum)
  589. def output_size(self) -> int:
  590. if self.embedding_node.startswith("resnet1"):
  591. return self.num_nodes_resnet1
  592. elif self.embedding_node.startswith("resnet2"):
  593. return self.num_nodes_last_layer
  594. return self.num_nodes_pooling_layer
  595. def forward(
  596. self,
  597. xs_pad: torch.Tensor,
  598. ilens: torch.Tensor,
  599. prev_states: torch.Tensor = None,
  600. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  601. endpoints = OrderedDict()
  602. res_out, ilens = super().forward(xs_pad, ilens)
  603. endpoints["resnet0_bn"] = res_out
  604. if self.pooling_type == "frame_gsp":
  605. features = statistic_pooling(res_out, ilens, (2, ))
  606. else:
  607. features, ilens = windowed_statistic_pooling(res_out, ilens, (2, ), self.pool_size, self.stride)
  608. features = features.transpose(1, 2)
  609. endpoints["pooling"] = features
  610. features = self.resnet1_dense(features)
  611. endpoints["resnet1_dense"] = features
  612. features = F.relu(features)
  613. endpoints["resnet1_relu"] = features
  614. features = self.resnet1_bn(features.transpose(1, 2)).transpose(1, 2)
  615. endpoints["resnet1_bn"] = features
  616. features = self.resnet2_dense(features)
  617. endpoints["resnet2_dense"] = features
  618. features = F.relu(features)
  619. endpoints["resnet2_relu"] = features
  620. features = self.resnet2_bn(features.transpose(1, 2)).transpose(1, 2)
  621. endpoints["resnet2_bn"] = features
  622. return endpoints[self.embedding_node], ilens, None
  623. def gen_tf2torch_map_dict(self):
  624. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  625. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  626. train_steps = 720000
  627. map_dict_local = {
  628. # torch: conv1d.weight in "out_channel in_channel kernel_size"
  629. # tf : conv1d.weight in "kernel_size in_channel out_channel"
  630. # torch: linear.weight in "out_channel in_channel"
  631. # tf : dense.weight in "in_channel out_channel"
  632. "{}.pre_conv.weight".format(tensor_name_prefix_torch):
  633. {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
  634. "squeeze": None,
  635. "transpose": (3, 2, 0, 1),
  636. },
  637. "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
  638. {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
  639. "squeeze": None,
  640. "transpose": None,
  641. },
  642. "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
  643. {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
  644. "squeeze": None,
  645. "transpose": None,
  646. },
  647. "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
  648. {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
  649. "squeeze": None,
  650. "transpose": None,
  651. },
  652. "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
  653. {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
  654. "squeeze": None,
  655. "transpose": None,
  656. },
  657. "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
  658. }
  659. for layer_idx in range(3):
  660. map_dict_local.update({
  661. "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
  662. {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
  663. "squeeze": None,
  664. "transpose": (2, 1, 0) if layer_idx == 0 else (1, 0),
  665. },
  666. "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
  667. {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
  668. "squeeze": None,
  669. "transpose": None,
  670. },
  671. "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
  672. {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
  673. "squeeze": None,
  674. "transpose": None,
  675. },
  676. "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
  677. {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
  678. "squeeze": None,
  679. "transpose": None,
  680. },
  681. "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
  682. {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
  683. "squeeze": None,
  684. "transpose": None,
  685. },
  686. "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
  687. {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
  688. "squeeze": None,
  689. "transpose": None,
  690. },
  691. "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
  692. })
  693. for block_idx in range(len(self.layers_in_block)):
  694. for layer_idx in range(self.layers_in_block[block_idx]):
  695. for i in ["1", "2", "_sc"]:
  696. map_dict_local.update({
  697. "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
  698. {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
  699. "squeeze": None,
  700. "transpose": (3, 2, 0, 1),
  701. },
  702. "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
  703. {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
  704. "squeeze": None,
  705. "transpose": None,
  706. },
  707. "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
  708. {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
  709. "squeeze": None,
  710. "transpose": None,
  711. },
  712. "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
  713. {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
  714. "squeeze": None,
  715. "transpose": None,
  716. },
  717. "{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
  718. {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
  719. "squeeze": None,
  720. "transpose": None,
  721. },
  722. "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
  723. })
  724. return map_dict_local
  725. def convert_tf2torch(self,
  726. var_dict_tf,
  727. var_dict_torch,
  728. ):
  729. map_dict = self.gen_tf2torch_map_dict()
  730. var_dict_torch_update = dict()
  731. for name in sorted(var_dict_torch.keys(), reverse=False):
  732. if name.startswith(self.tf2torch_tensor_name_prefix_torch):
  733. if name in map_dict:
  734. if "num_batches_tracked" not in name:
  735. name_tf = map_dict[name]["name"]
  736. data_tf = var_dict_tf[name_tf]
  737. if map_dict[name]["squeeze"] is not None:
  738. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
  739. if map_dict[name]["transpose"] is not None:
  740. data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
  741. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  742. assert var_dict_torch[name].size() == data_tf.size(), \
  743. "{}, {}, {} != {}".format(name, name_tf,
  744. var_dict_torch[name].size(), data_tf.size())
  745. var_dict_torch_update[name] = data_tf
  746. logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
  747. name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
  748. ))
  749. else:
  750. var_dict_torch_update[name] = torch.from_numpy(np.array(map_dict[name])).type(torch.int64).to("cpu")
  751. logging.info("torch tensor: {}, manually assigning to: {}".format(
  752. name, map_dict[name]
  753. ))
  754. else:
  755. logging.warning("{} is missed from tf checkpoint".format(name))
  756. return var_dict_torch_update