resnet34_encoder.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. import torch
  2. from torch.nn import functional as F
  3. from funasr.models.encoder.abs_encoder import AbsEncoder
  4. from typing import Tuple
  5. class BasicLayer(torch.nn.Module):
  6. def __init__(self, in_filters: int, filters: int, stride: int, bn_momentum: float = 0.5):
  7. super().__init__()
  8. self.stride = stride
  9. self.in_filters = in_filters
  10. self.filters = filters
  11. self.bn1 = torch.nn.BatchNorm2d(in_filters, eps=1e-3, momentum=bn_momentum, affine=True)
  12. self.relu1 = torch.nn.ReLU()
  13. self.conv1 = torch.nn.Conv2d(in_filters, filters, 3, stride, bias=False)
  14. self.bn2 = torch.nn.BatchNorm2d(filters, eps=1e-3, momentum=bn_momentum, affine=True)
  15. self.relu2 = torch.nn.ReLU()
  16. self.conv2 = torch.nn.Conv2d(filters, filters, 3, 1, bias=False)
  17. if in_filters != filters or stride > 1:
  18. self.conv_sc = torch.nn.Conv2d(in_filters, filters, 1, stride, bias=False)
  19. self.bn_sc = torch.nn.BatchNorm2d(filters, eps=1e-3, momentum=bn_momentum, affine=True)
  20. def proper_padding(self, x, stride):
  21. # align padding mode to tf.layers.conv2d with padding_mod="same"
  22. if stride == 1:
  23. return F.pad(x, (1, 1, 1, 1), "constant", 0)
  24. elif stride == 2:
  25. h, w = x.size(2), x.size(3)
  26. # (left, right, top, bottom)
  27. return F.pad(x, (w % 2, 1, h % 2, 1), "constant", 0)
  28. def forward(self, xs_pad, ilens):
  29. identity = xs_pad
  30. if self.in_filters != self.filters or self.stride > 1:
  31. identity = self.conv_sc(identity)
  32. identity = self.bn_sc(identity)
  33. xs_pad = self.relu1(self.bn1(xs_pad))
  34. xs_pad = self.proper_padding(xs_pad, self.stride)
  35. xs_pad = self.conv1(xs_pad)
  36. xs_pad = self.relu2(self.bn2(xs_pad))
  37. xs_pad = self.proper_padding(xs_pad, 1)
  38. xs_pad = self.conv2(xs_pad)
  39. if self.stride == 2:
  40. ilens = (ilens + 1) // self.stride
  41. return xs_pad + identity, ilens
  42. class BasicBlock(torch.nn.Module):
  43. def __init__(self, in_filters, filters, num_layer, stride, bn_momentum=0.5):
  44. super().__init__()
  45. self.num_layer = num_layer
  46. for i in range(num_layer):
  47. layer = BasicLayer(in_filters if i == 0 else filters, filters,
  48. stride if i == 0 else 1, bn_momentum)
  49. self.add_module("layer_{}".format(i), layer)
  50. def forward(self, xs_pad, ilens):
  51. for i in range(self.num_layer):
  52. xs_pad, ilens = self._modules["layer_{}".format(i)](xs_pad, ilens)
  53. return xs_pad, ilens
  54. class ResNet34(AbsEncoder):
  55. def __init__(
  56. self,
  57. input_size,
  58. use_head_conv=True,
  59. batchnorm_momentum=0.5,
  60. use_head_maxpool=False,
  61. num_nodes_pooling_layer=256,
  62. layers_in_block=(3, 4, 6, 3),
  63. filters_in_block=(32, 64, 128, 256),
  64. ):
  65. super(ResNet34, self).__init__()
  66. self.use_head_conv = use_head_conv
  67. self.use_head_maxpool = use_head_maxpool
  68. self.num_nodes_pooling_layer = num_nodes_pooling_layer
  69. self.layers_in_block = layers_in_block
  70. self.filters_in_block = filters_in_block
  71. self.input_size = input_size
  72. pre_filters = filters_in_block[0]
  73. if use_head_conv:
  74. self.pre_conv = torch.nn.Conv2d(1, pre_filters, 3, 1, 1, bias=False, padding_mode="zeros")
  75. self.pre_conv_bn = torch.nn.BatchNorm2d(pre_filters, eps=1e-3, momentum=batchnorm_momentum)
  76. if use_head_maxpool:
  77. self.head_maxpool = torch.nn.MaxPool2d(3, 1, padding=1)
  78. for i in range(len(layers_in_block)):
  79. if i == 0:
  80. in_filters = pre_filters if self.use_head_conv else 1
  81. else:
  82. in_filters = filters_in_block[i-1]
  83. block = BasicBlock(in_filters,
  84. filters=filters_in_block[i],
  85. num_layer=layers_in_block[i],
  86. stride=1 if i == 0 else 2,
  87. bn_momentum=batchnorm_momentum)
  88. self.add_module("block_{}".format(i), block)
  89. self.resnet0_dense = torch.nn.Conv2d(filters_in_block[-1], num_nodes_pooling_layer, 1)
  90. self.resnet0_bn = torch.nn.BatchNorm2d(num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum)
  91. def output_size(self) -> int:
  92. return self.num_nodes_pooling_layer
  93. def forward(self, xs_pad: torch.Tensor, ilens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  94. features = xs_pad
  95. assert features.size(-1) == self.input_size, \
  96. "Dimension of features {} doesn't match the input_size {}.".format(features.size(-1), self.input_size)
  97. features = torch.unsqueeze(features, dim=1)
  98. if self.use_head_conv:
  99. features = self.pre_conv(features)
  100. features = self.pre_conv_bn(features)
  101. features = F.relu(features)
  102. if self.use_head_maxpool:
  103. features = self.head_maxpool(features)
  104. resnet_outs, resnet_out_lens = features, ilens
  105. for i in range(len(self.layers_in_block)):
  106. block = self._modules["block_{}".format(i)]
  107. resnet_outs, resnet_out_lens = block(resnet_outs, resnet_out_lens)
  108. features = self.resnet0_dense(resnet_outs)
  109. features = F.relu(features)
  110. features = self.resnet0_bn(features)
  111. return features, ilens // 8