| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107 |
- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
- import torch
- import torch.nn as nn
def quant_noise(module, p, block_size):
    """
    Wraps modules and applies quantization noise to the weights for
    subsequent quantization with Iterative Product Quantization as
    described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Module (nn.Linear, nn.Embedding or nn.Conv2d)
        - p: amount of Quantization Noise, i.e. the probability of dropping
          a block. Values <= 0 disable the noise entirely. Must be < 1
          because the surviving weights are rescaled by 1 / (1 - p)
        - block_size: size of the blocks for subsequent quantization with iPQ

    Returns:
        - the same module object that was passed in; when p > 0 a forward
          pre-hook applying the noise has been registered on it

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights,
          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper
          which consists in randomly dropping blocks
        - The hook overwrites module.weight.data on every forward pass made
          in training mode; nothing is changed in evaluation mode
        - For convolutions whose kernel is not 1x1, block_size is only
          validated against the kernel area; the noise drops whole kernels
    """
    # if no quantization noise, don't register hook
    if p <= 0:
        return module

    # supported modules
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # a 4D weight means a convolution, otherwise the weight is a 2D matrix
    is_conv = module.weight.ndim == 4

    # test whether module.weight has the right sizes wrt block_size
    if not is_conv:
        # 2D matrix
        assert (
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"
    elif module.kernel_size == (1, 1):
        # 1x1 convolutions
        assert (
            module.in_channels % block_size == 0
        ), "Input channels must be a multiple of block sizes"
    else:
        # regular convolutions
        k = module.kernel_size[0] * module.kernel_size[1]
        assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, inputs):
        # no noise for evaluation
        if not mod.training:
            return

        weight = mod.weight

        if not is_conv:
            in_features = weight.size(1)
            out_features = weight.size(0)

            # one Bernoulli draw per block, then expand every draw over the
            # block_size consecutive input features it covers
            mask = torch.zeros(
                in_features // block_size * out_features, device=weight.device
            )
            mask.bernoulli_(p)
            mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
        elif mod.kernel_size == (1, 1):
            in_channels = mod.in_channels
            out_channels = mod.out_channels

            # one Bernoulli draw per block of block_size input channels
            mask = torch.zeros(
                int(in_channels // block_size * out_channels),
                device=weight.device,
            )
            mask.bernoulli_(p)
            # keep the two trailing singleton dimensions so that the mask
            # lines up with the 4D weight; a 2D mask would be broadcast
            # against the trailing (1, 1) dimensions instead
            mask = mask.repeat_interleave(block_size, -1).view(
                -1, in_channels, 1, 1
            )
        else:
            # one Bernoulli draw per (output, input) channel pair, repeated
            # over the whole kernel so that entire kernels are dropped
            mask = torch.zeros(
                weight.size(0), weight.size(1), device=weight.device
            )
            mask.bernoulli_(p)
            mask = (
                mask.unsqueeze(2)
                .unsqueeze(3)
                .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
            )

        # scale weights and apply mask
        mask = mask.to(
            torch.bool
        )  # x.bool() is not currently supported in TorchScript
        s = 1 / (1 - p)
        mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
|