# timm_modules.py
# Layers adapted from timm (pytorch-image-models): DropPath / drop_path,
# the _ntuple helpers, and the Mlp block used by ViT / MLP-Mixer models.
import collections.abc
from functools import partial
from itertools import repeat

import torch.nn as nn
  5. def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
  6. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  7. This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
  8. the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  9. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
  10. changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
  11. 'survival rate' as the argument.
  12. """
  13. if drop_prob == 0. or not training:
  14. return x
  15. keep_prob = 1 - drop_prob
  16. shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  17. random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
  18. if keep_prob > 0.0 and scale_by_keep:
  19. random_tensor.div_(keep_prob)
  20. return x * random_tensor
  21. class DropPath(nn.Module):
  22. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  23. """
  24. def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
  25. super(DropPath, self).__init__()
  26. self.drop_prob = drop_prob
  27. self.scale_by_keep = scale_by_keep
  28. def forward(self, x):
  29. return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
  30. def extra_repr(self):
  31. return f'drop_prob={round(self.drop_prob,3):0.3f}'
  32. # From PyTorch internals
  33. def _ntuple(n):
  34. def parse(x):
  35. if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
  36. return tuple(x)
  37. return tuple(repeat(x, n))
  38. return parse
  39. to_1tuple = _ntuple(1)
  40. to_2tuple = _ntuple(2)
  41. to_3tuple = _ntuple(3)
  42. to_4tuple = _ntuple(4)
  43. to_ntuple = _ntuple
  44. class Mlp(nn.Module):
  45. """ MLP as used in Vision Transformer, MLP-Mixer and related networks
  46. """
  47. def __init__(
  48. self,
  49. in_features,
  50. hidden_features=None,
  51. out_features=None,
  52. act_layer=nn.GELU,
  53. norm_layer=None,
  54. bias=True,
  55. drop=0.,
  56. use_conv=False,
  57. ):
  58. super().__init__()
  59. out_features = out_features or in_features
  60. hidden_features = hidden_features or in_features
  61. bias = to_2tuple(bias)
  62. drop_probs = to_2tuple(drop)
  63. linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
  64. self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
  65. self.act = act_layer()
  66. self.drop1 = nn.Dropout(drop_probs[0])
  67. self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
  68. self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
  69. self.drop2 = nn.Dropout(drop_probs[1])
  70. def forward(self, x):
  71. x = self.fc1(x)
  72. x = self.act(x)
  73. x = self.drop1(x)
  74. x = self.norm(x)
  75. x = self.fc2(x)
  76. x = self.drop2(x)
  77. return x