# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

import math
from typing import Optional, List

class LoRALayer:
    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights
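
# The subclasses below all apply the same low-rank update: the frozen weight W is
# augmented by (lora_alpha / r) * B @ A, where A and B are the only trainable
# parameters. A minimal numeric sketch of that update (the function name and shapes
# are illustrative assumptions, not part of the original module):
def _lora_delta_sketch(out_features: int = 8, in_features: int = 32,
                       r: int = 4, lora_alpha: int = 8) -> torch.Tensor:
    lora_A = torch.zeros(r, in_features)     # initialized like lora_A in Linear below
    lora_B = torch.zeros(out_features, r)    # initialized like lora_B in Linear below
    scaling = lora_alpha / r
    # The product starts at zero because one factor is zero-initialized,
    # so training begins from the pre-trained behaviour
    return (lora_B @ lora_A) * scaling       # shape (out_features, in_features)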

class Embedding(nn.Embedding, LoRALayer):
    # LoRA implemented in an embedding layer
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,
                           merge_weights=merge_weights)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A to zero and B from a normal distribution
            # (the roles are swapped relative to the Linear layers below,
            # but the product B @ A still starts at zero)
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        if self.merge_weights and self.merged:
            # Make sure that the weights are not merged
            if self.r > 0:
                self.weight.data -= (self.lora_B @ self.lora_A).T * self.scaling
            self.merged = False

    def eval(self):
        nn.Embedding.eval(self)
        if self.merge_weights and not self.merged:
            # Merge the weights and mark it
            if self.r > 0:
                self.weight.data += (self.lora_B @ self.lora_A).T * self.scaling
            self.merged = True

    def forward(self, x: torch.Tensor):
        if self.r > 0 and not self.merged:
            result = nn.Embedding.forward(self, x)
            after_A = F.embedding(
                x, self.lora_A.T, self.padding_idx, self.max_norm,
                self.norm_type, self.scale_grad_by_freq, self.sparse
            )
            result += (after_A @ self.lora_B.T) * self.scaling
            return result
        else:
            return nn.Embedding.forward(self, x)
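
# A minimal usage sketch for the class above (the function name and sizes are
# illustrative assumptions, not part of the original module): wrap a vocabulary
# embedding with rank-8 LoRA and check that only the low-rank factors are trainable.
def _embedding_usage_sketch():
    emb = Embedding(num_embeddings=1000, embedding_dim=64, r=8, lora_alpha=16)
    tokens = torch.randint(0, 1000, (2, 10))    # (batch, seq_len)
    out = emb(tokens)                           # (2, 10, 64)
    trainable = [n for n, p in emb.named_parameters() if p.requires_grad]
    return out.shape, trainable                 # ['lora_A', 'lora_B']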

class Linear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        def T(w):
            return w.T if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if self.merge_weights and self.merged:
            # Make sure that the weights are not merged
            if self.r > 0:
                self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
            self.merged = False

    def eval(self):
        def T(w):
            return w.T if self.fan_in_fan_out else w
        nn.Linear.eval(self)
        if self.merge_weights and not self.merged:
            # Merge the weights and mark it
            if self.r > 0:
                self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
            self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.T if self.fan_in_fan_out else w
        if self.r > 0 and not self.merged:
            result = F.linear(x, T(self.weight), bias=self.bias)
            result += (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
            return result
        else:
            return F.linear(x, T(self.weight), bias=self.bias)
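
# A minimal usage sketch for the class above (the function name and sizes are
# illustrative assumptions, not part of the original module): the unmerged forward
# pass (frozen W plus the scaled B @ A path) matches the output after eval() folds
# the update into the weight.
def _linear_usage_sketch():
    layer = Linear(in_features=128, out_features=64, r=8, lora_alpha=16, lora_dropout=0.)
    x = torch.randn(4, 128)
    y_unmerged = layer(x)      # W x + b + scaling * B A x
    layer.eval()               # folds B A into layer.weight and sets merged = True
    y_merged = layer(x)        # single matmul against the merged weight
    return torch.allclose(y_unmerged, y_merged, atol=1e-5)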

class MergedLinear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer whose output is split into groups
    # (e.g. a fused qkv projection), so that only some groups get a LoRA update
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        enable_lora: List[bool] = [False],
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        self.enable_lora = enable_lora
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0 and any(enable_lora):
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features)))
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
            )  # weights for Conv1D with groups=sum(enable_lora)
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Compute the indices of the output units that receive a LoRA update
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool
            ).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def zero_pad(self, x):
        # Scatter x (whose last dimension covers only the LoRA-enabled groups)
        # into a tensor whose last dimension spans the full out_features
        result = x.new_zeros((*x.shape[:-1], self.out_features))
        result = result.view(-1, self.out_features)
        result[:, self.lora_ind] = x.reshape(
            -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
        )
        return result.view((*x.shape[:-1], self.out_features))

    def train(self, mode: bool = True):
        def T(w):
            return w.T if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if self.merge_weights and self.merged:
            # Make sure that the weights are not merged
            if self.r > 0 and any(self.enable_lora):
                delta_w = F.conv1d(
                    self.lora_A.data.unsqueeze(0),
                    self.lora_B.data.unsqueeze(-1),
                    groups=sum(self.enable_lora)
                ).squeeze(0)
                # zero_pad expects the output-group dimension last, so pad the
                # transposed delta and then match the storage layout of self.weight
                self.weight.data -= T(self.zero_pad(delta_w.T).T) * self.scaling
            self.merged = False

    def eval(self):
        def T(w):
            return w.T if self.fan_in_fan_out else w
        nn.Linear.eval(self)
        if self.merge_weights and not self.merged:
            # Merge the weights and mark it
            if self.r > 0 and any(self.enable_lora):
                delta_w = F.conv1d(
                    self.lora_A.data.unsqueeze(0),
                    self.lora_B.data.unsqueeze(-1),
                    groups=sum(self.enable_lora)
                ).squeeze(0)
                self.weight.data += T(self.zero_pad(delta_w.T).T) * self.scaling
            self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.T if self.fan_in_fan_out else w
        if self.merged:
            return F.linear(x, T(self.weight), bias=self.bias)
        else:
            result = F.linear(x, T(self.weight), bias=self.bias)
            if self.r > 0:
                after_A = F.linear(self.lora_dropout(x), self.lora_A)
                after_B = F.conv1d(
                    after_A.transpose(-2, -1),
                    self.lora_B.unsqueeze(-1),
                    groups=sum(self.enable_lora)
                ).transpose(-2, -1)
                result += self.zero_pad(after_B) * self.scaling
            return result
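
# A minimal usage sketch for the class above (the function name and sizes are
# illustrative assumptions, not part of the original module): a fused qkv projection
# where only the query and value thirds of the output receive a LoRA update, which is
# what enable_lora=[True, False, True] expresses for out_features = 3 * d_model.
def _merged_linear_usage_sketch(d_model: int = 64):
    qkv = MergedLinear(
        in_features=d_model, out_features=3 * d_model,
        r=8, lora_alpha=16, enable_lora=[True, False, True]
    )
    x = torch.randn(2, 10, d_model)             # (batch, seq_len, d_model)
    out = qkv(x)                                # (2, 10, 3 * d_model); the key slice is left untouched
    return out.shape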

class Conv2d(nn.Conv2d, LoRALayer):
    # LoRA implemented in a 2D convolution layer
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert isinstance(kernel_size, int)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels * kernel_size, r * kernel_size))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        nn.Conv2d.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        nn.Conv2d.train(self, mode)
        if self.merge_weights and self.merged:
            # Make sure that the weights are not merged
            if self.r > 0:
                self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
            self.merged = False

    def eval(self):
        nn.Conv2d.eval(self)
        if self.merge_weights and not self.merged:
            # Merge the weights and mark it
            if self.r > 0:
                self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
            self.merged = True

    def forward(self, x: torch.Tensor):
        if self.r > 0 and not self.merged:
            # Apply the low-rank update on the fly without touching the frozen kernel
            return F.conv2d(
                x,
                self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling,
                self.bias, self.stride, self.padding, self.dilation, self.groups
            )
        return nn.Conv2d.forward(self, x)
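
# A minimal usage sketch for the class above (the function name and sizes are
# illustrative assumptions, not part of the original module): a LoRA-adapted 3x3
# convolution. The factors are stored as 2-D matrices and reshaped to the 4-D kernel
# only when the update is applied; the pre-trained kernel itself stays frozen.
def _conv2d_usage_sketch():
    conv = Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, r=4, lora_alpha=8)
    x = torch.randn(2, 16, 28, 28)
    out = conv(x)                               # (2, 32, 28, 28)
    frozen = not conv.weight.requires_grad      # True: only lora_A, lora_B (and bias) train
    return out.shape, frozen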