# quant_noise.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
  7. def quant_noise(module, p, block_size):
  8. """
  9. Wraps modules and applies quantization noise to the weights for
  10. subsequent quantization with Iterative Product Quantization as
  11. described in "Training with Quantization Noise for Extreme Model Compression"
  12. Args:
  13. - module: nn.Module
  14. - p: amount of Quantization Noise
  15. - block_size: size of the blocks for subsequent quantization with iPQ
  16. Remarks:
  17. - Module weights must have the right sizes wrt the block size
  18. - Only Linear, Embedding and Conv2d modules are supported for the moment
  19. - For more detail on how to quantize by blocks with convolutional weights,
  20. see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
  21. - We implement the simplest form of noise here as stated in the paper
  22. which consists in randomly dropping blocks
  23. """
  24. # if no quantization noise, don't register hook
  25. if p <= 0:
  26. return module
  27. # supported modules
  28. assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
  29. # test whether module.weight has the right sizes wrt block_size
  30. is_conv = module.weight.ndim == 4
  31. # 2D matrix
  32. if not is_conv:
  33. assert (
  34. module.weight.size(1) % block_size == 0
  35. ), "Input features must be a multiple of block sizes"
  36. # 4D matrix
  37. else:
  38. # 1x1 convolutions
  39. if module.kernel_size == (1, 1):
  40. assert (
  41. module.in_channels % block_size == 0
  42. ), "Input channels must be a multiple of block sizes"
  43. # regular convolutions
  44. else:
  45. k = module.kernel_size[0] * module.kernel_size[1]
  46. assert k % block_size == 0, "Kernel size must be a multiple of block size"
  47. def _forward_pre_hook(mod, input):
  48. # no noise for evaluation
  49. if mod.training:
  50. if not is_conv:
  51. # gather weight and sizes
  52. weight = mod.weight
  53. in_features = weight.size(1)
  54. out_features = weight.size(0)
  55. # split weight matrix into blocks and randomly drop selected blocks
  56. mask = torch.zeros(
  57. in_features // block_size * out_features, device=weight.device
  58. )
  59. mask.bernoulli_(p)
  60. mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
  61. else:
  62. # gather weight and sizes
  63. weight = mod.weight
  64. in_channels = mod.in_channels
  65. out_channels = mod.out_channels
  66. # split weight matrix into blocks and randomly drop selected blocks
  67. if mod.kernel_size == (1, 1):
  68. mask = torch.zeros(
  69. int(in_channels // block_size * out_channels),
  70. device=weight.device,
  71. )
  72. mask.bernoulli_(p)
  73. mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
  74. else:
  75. mask = torch.zeros(
  76. weight.size(0), weight.size(1), device=weight.device
  77. )
  78. mask.bernoulli_(p)
  79. mask = (
  80. mask.unsqueeze(2)
  81. .unsqueeze(3)
  82. .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
  83. )
  84. # scale weights and apply mask
  85. mask = mask.to(
  86. torch.bool
  87. ) # x.bool() is not currently supported in TorchScript
  88. s = 1 / (1 - p)
  89. mod.weight.data = s * weight.masked_fill(mask, 0)
  90. module.register_forward_pre_hook(_forward_pre_hook)
  91. return module