|
|
@@ -11,9 +11,9 @@ from typing import Optional, List
|
|
|
|
|
|
class LoRALayer():
|
|
|
def __init__(
|
|
|
- self,
|
|
|
- r: int,
|
|
|
- lora_alpha: int,
|
|
|
+ self,
|
|
|
+ r: int,
|
|
|
+ lora_alpha: int,
|
|
|
lora_dropout: float,
|
|
|
merge_weights: bool,
|
|
|
):
|
|
|
@@ -61,40 +61,42 @@ class Embedding(nn.Embedding, LoRALayer):
|
|
|
|
|
|
def train(self, mode: bool = True):
|
|
|
nn.Embedding.train(self, mode)
|
|
|
- if mode:
|
|
|
- if self.merge_weights and self.merged:
|
|
|
- # Make sure that the weights are not merged
|
|
|
- if self.r > 0:
|
|
|
- self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
|
|
|
- self.merged = False
|
|
|
- else:
|
|
|
- if self.merge_weights and not self.merged:
|
|
|
- # Merge the weights and mark it
|
|
|
- if self.r > 0:
|
|
|
- self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
|
|
|
- self.merged = True
|
|
|
-
|
|
|
+ if self.merge_weights and self.merged:
|
|
|
+ # Make sure that the weights are not merged
|
|
|
+ if self.r > 0:
|
|
|
+ self.weight.data -= (self.lora_B @ self.lora_A).T * self.scaling
|
|
|
+ self.merged = False
|
|
|
+
|
|
|
+ def eval(self):
|
|
|
+ nn.Linear.eval(self)
|
|
|
+ if self.merge_weights and not self.merged:
|
|
|
+ # Merge the weights and mark it
|
|
|
+ if self.r > 0:
|
|
|
+ self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
|
|
|
+ self.merged = True
|
|
|
+
|
|
|
def forward(self, x: torch.Tensor):
|
|
|
if self.r > 0 and not self.merged:
|
|
|
result = nn.Embedding.forward(self, x)
|
|
|
- after_A = F.embedding(
|
|
|
- x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
|
|
|
- self.norm_type, self.scale_grad_by_freq, self.sparse
|
|
|
- )
|
|
|
- result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
|
|
|
+ if self.r > 0:
|
|
|
+ after_A = F.embedding(
|
|
|
+ x, self.lora_A.T, self.padding_idx, self.max_norm,
|
|
|
+ self.norm_type, self.scale_grad_by_freq, self.sparse
|
|
|
+ )
|
|
|
+ result += (after_A @ self.lora_B.T) * self.scaling
|
|
|
return result
|
|
|
else:
|
|
|
return nn.Embedding.forward(self, x)
|
|
|
-
|
|
|
+
|
|
|
|
|
|
class Linear(nn.Linear, LoRALayer):
|
|
|
# LoRA implemented in a dense layer
|
|
|
def __init__(
|
|
|
- self,
|
|
|
- in_features: int,
|
|
|
- out_features: int,
|
|
|
- r: int = 0,
|
|
|
- lora_alpha: int = 1,
|
|
|
+ self,
|
|
|
+ in_features: int,
|
|
|
+ out_features: int,
|
|
|
+ r: int = 0,
|
|
|
+ lora_alpha: int = 1,
|
|
|
lora_dropout: float = 0.,
|
|
|
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
|
|
merge_weights: bool = True,
|
|
|
@@ -114,7 +116,7 @@ class Linear(nn.Linear, LoRALayer):
|
|
|
self.weight.requires_grad = False
|
|
|
self.reset_parameters()
|
|
|
if fan_in_fan_out:
|
|
|
- self.weight.data = self.weight.data.transpose(0, 1)
|
|
|
+ self.weight.data = self.weight.data.T
|
|
|
|
|
|
def reset_parameters(self):
|
|
|
nn.Linear.reset_parameters(self)
|
|
|
@@ -125,27 +127,31 @@ class Linear(nn.Linear, LoRALayer):
|
|
|
|
|
|
def train(self, mode: bool = True):
|
|
|
def T(w):
|
|
|
- return w.transpose(0, 1) if self.fan_in_fan_out else w
|
|
|
+ return w.T if self.fan_in_fan_out else w
|
|
|
nn.Linear.train(self, mode)
|
|
|
- if mode:
|
|
|
- if self.merge_weights and self.merged:
|
|
|
- # Make sure that the weights are not merged
|
|
|
- if self.r > 0:
|
|
|
- self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
|
|
|
- self.merged = False
|
|
|
- else:
|
|
|
- if self.merge_weights and not self.merged:
|
|
|
- # Merge the weights and mark it
|
|
|
- if self.r > 0:
|
|
|
- self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
|
|
|
- self.merged = True
|
|
|
+ if self.merge_weights and self.merged:
|
|
|
+ # Make sure that the weights are not merged
|
|
|
+ if self.r > 0:
|
|
|
+ self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
|
|
|
+ self.merged = False
|
|
|
+
|
|
|
+ def eval(self):
|
|
|
+ def T(w):
|
|
|
+ return w.T if self.fan_in_fan_out else w
|
|
|
+ nn.Linear.eval(self)
|
|
|
+ if self.merge_weights and not self.merged:
|
|
|
+ # Merge the weights and mark it
|
|
|
+ if self.r > 0:
|
|
|
+ self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
|
|
|
+ self.merged = True
|
|
|
|
|
|
def forward(self, x: torch.Tensor):
|
|
|
def T(w):
|
|
|
- return w.transpose(0, 1) if self.fan_in_fan_out else w
|
|
|
+ return w.T if self.fan_in_fan_out else w
|
|
|
if self.r > 0 and not self.merged:
|
|
|
- result = F.linear(x, T(self.weight), bias=self.bias)
|
|
|
- result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
|
|
|
+ result = F.linear(x, T(self.weight), bias=self.bias)
|
|
|
+ if self.r > 0:
|
|
|
+ result += (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
|
|
|
return result
|
|
|
else:
|
|
|
return F.linear(x, T(self.weight), bias=self.bias)
|
|
|
@@ -154,11 +160,11 @@ class Linear(nn.Linear, LoRALayer):
|
|
|
class MergedLinear(nn.Linear, LoRALayer):
|
|
|
# LoRA implemented in a dense layer
|
|
|
def __init__(
|
|
|
- self,
|
|
|
- in_features: int,
|
|
|
- out_features: int,
|
|
|
- r: int = 0,
|
|
|
- lora_alpha: int = 1,
|
|
|
+ self,
|
|
|
+ in_features: int,
|
|
|
+ out_features: int,
|
|
|
+ r: int = 0,
|
|
|
+ lora_alpha: int = 1,
|
|
|
lora_dropout: float = 0.,
|
|
|
enable_lora: List[bool] = [False],
|
|
|
fan_in_fan_out: bool = False,
|
|
|
@@ -190,7 +196,7 @@ class MergedLinear(nn.Linear, LoRALayer):
|
|
|
self.lora_ind = self.lora_ind.view(-1)
|
|
|
self.reset_parameters()
|
|
|
if fan_in_fan_out:
|
|
|
- self.weight.data = self.weight.data.transpose(0, 1)
|
|
|
+ self.weight.data = self.weight.data.T
|
|
|
|
|
|
def reset_parameters(self):
|
|
|
nn.Linear.reset_parameters(self)
|
|
|
@@ -209,34 +215,37 @@ class MergedLinear(nn.Linear, LoRALayer):
|
|
|
|
|
|
def train(self, mode: bool = True):
|
|
|
def T(w):
|
|
|
- return w.transpose(0, 1) if self.fan_in_fan_out else w
|
|
|
+ return w.T if self.fan_in_fan_out else w
|
|
|
nn.Linear.train(self, mode)
|
|
|
- if mode:
|
|
|
- if self.merge_weights and self.merged:
|
|
|
- # Make sure that the weights are not merged
|
|
|
- if self.r > 0 and any(self.enable_lora):
|
|
|
- delta_w = F.conv1d(
|
|
|
- self.lora_A.data.unsqueeze(0),
|
|
|
- self.lora_B.data.unsqueeze(-1),
|
|
|
- groups=sum(self.enable_lora)
|
|
|
- ).squeeze(0)
|
|
|
- self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
|
|
|
- self.merged = False
|
|
|
- else:
|
|
|
- if self.merge_weights and not self.merged:
|
|
|
- # Merge the weights and mark it
|
|
|
- if self.r > 0 and any(self.enable_lora):
|
|
|
- delta_w = F.conv1d(
|
|
|
- self.lora_A.data.unsqueeze(0),
|
|
|
- self.lora_B.data.unsqueeze(-1),
|
|
|
- groups=sum(self.enable_lora)
|
|
|
- ).squeeze(0)
|
|
|
- self.weight.data += self.zero_pad(T(delta_w * self.scaling))
|
|
|
- self.merged = True
|
|
|
+ if self.merge_weights and self.merged:
|
|
|
+ # Make sure that the weights are not merged
|
|
|
+ if self.r > 0 and any(self.enable_lora):
|
|
|
+ delta_w = F.conv1d(
|
|
|
+ self.lora_A.data.unsqueeze(0),
|
|
|
+ self.lora_B.data.unsqueeze(-1),
|
|
|
+ groups=sum(self.enable_lora)
|
|
|
+ ).squeeze(0)
|
|
|
+ self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
|
|
|
+ self.merged = False
|
|
|
+
|
|
|
+ def eval(self):
|
|
|
+ def T(w):
|
|
|
+ return w.T if self.fan_in_fan_out else w
|
|
|
+ nn.Linear.eval(self)
|
|
|
+ if self.merge_weights and not self.merged:
|
|
|
+ # Merge the weights and mark it
|
|
|
+ if self.r > 0 and any(self.enable_lora):
|
|
|
+ delta_w = F.conv1d(
|
|
|
+ self.lora_A.data.unsqueeze(0),
|
|
|
+ self.lora_B.data.unsqueeze(-1),
|
|
|
+ groups=sum(self.enable_lora)
|
|
|
+ ).squeeze(0)
|
|
|
+ self.weight.data += self.zero_pad(T(delta_w * self.scaling))
|
|
|
+ self.merged = True
|
|
|
|
|
|
def forward(self, x: torch.Tensor):
|
|
|
def T(w):
|
|
|
- return w.transpose(0, 1) if self.fan_in_fan_out else w
|
|
|
+ return w.T if self.fan_in_fan_out else w
|
|
|
if self.merged:
|
|
|
return F.linear(x, T(self.weight), bias=self.bias)
|
|
|
else:
|
|
|
@@ -244,76 +253,71 @@ class MergedLinear(nn.Linear, LoRALayer):
|
|
|
if self.r > 0:
|
|
|
after_A = F.linear(self.lora_dropout(x), self.lora_A)
|
|
|
after_B = F.conv1d(
|
|
|
- after_A.transpose(-2, -1),
|
|
|
- self.lora_B.unsqueeze(-1),
|
|
|
+ after_A.transpose(-2, -1),
|
|
|
+ self.lora_B.unsqueeze(-1),
|
|
|
groups=sum(self.enable_lora)
|
|
|
).transpose(-2, -1)
|
|
|
result += self.zero_pad(after_B) * self.scaling
|
|
|
return result
|
|
|
-
|
|
|
|
|
|
-class ConvLoRA(nn.Module, LoRALayer):
|
|
|
- def __init__(self, conv_module, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
|
|
|
- super(ConvLoRA, self).__init__()
|
|
|
- self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
|
|
|
- LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
|
|
|
- assert isinstance(kernel_size, int)
|
|
|
+
|
|
|
+class Conv2d(nn.Conv2d, LoRALayer):
|
|
|
+ # LoRA implemented in a dense layer
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ in_channels: int,
|
|
|
+ out_channels: int,
|
|
|
+ kernel_size: int,
|
|
|
+ r: int = 0,
|
|
|
+ lora_alpha: int = 1,
|
|
|
+ lora_dropout: float = 0.,
|
|
|
+ merge_weights: bool = True,
|
|
|
+ **kwargs
|
|
|
+ ):
|
|
|
+ nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
|
|
|
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
|
|
|
+ merge_weights=merge_weights)
|
|
|
+ assert type(kernel_size) is int
|
|
|
# Actual trainable parameters
|
|
|
if r > 0:
|
|
|
self.lora_A = nn.Parameter(
|
|
|
- self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
|
|
|
+ self.weight.new_zeros((r*kernel_size, in_channels*kernel_size))
|
|
|
)
|
|
|
self.lora_B = nn.Parameter(
|
|
|
- self.conv.weight.new_zeros((out_channels//self.conv.groups*kernel_size, r*kernel_size))
|
|
|
+ self.weight.new_zeros((out_channels*kernel_size, r*kernel_size))
|
|
|
)
|
|
|
self.scaling = self.lora_alpha / self.r
|
|
|
# Freezing the pre-trained weight matrix
|
|
|
- self.conv.weight.requires_grad = False
|
|
|
+ self.weight.requires_grad = False
|
|
|
self.reset_parameters()
|
|
|
- self.merged = False
|
|
|
|
|
|
def reset_parameters(self):
|
|
|
- self.conv.reset_parameters()
|
|
|
+ nn.Conv2d.reset_parameters(self)
|
|
|
if hasattr(self, 'lora_A'):
|
|
|
# initialize A the same way as the default for nn.Linear and B to zero
|
|
|
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
|
|
nn.init.zeros_(self.lora_B)
|
|
|
|
|
|
- def train(self, mode=True):
|
|
|
- super(ConvLoRA, self).train(mode)
|
|
|
- if mode:
|
|
|
- if self.merge_weights and self.merged:
|
|
|
- if self.r > 0:
|
|
|
- # Make sure that the weights are not merged
|
|
|
- self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
|
|
|
- self.merged = False
|
|
|
- else:
|
|
|
- if self.merge_weights and not self.merged:
|
|
|
- if self.r > 0:
|
|
|
- # Merge the weights and mark it
|
|
|
- self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
|
|
|
- self.merged = True
|
|
|
+ def train(self, mode: bool = True):
|
|
|
+ nn.Conv2d.train(self, mode)
|
|
|
+ if self.merge_weights and self.merged:
|
|
|
+ # Make sure that the weights are not merged
|
|
|
+ self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
|
|
|
+ self.merged = False
|
|
|
|
|
|
- def forward(self, x):
|
|
|
+ def eval(self):
|
|
|
+ nn.Conv2d.eval(self)
|
|
|
+ if self.merge_weights and not self.merged:
|
|
|
+ # Merge the weights and mark it
|
|
|
+ self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
|
|
|
+ self.merged = True
|
|
|
+
|
|
|
+ def forward(self, x: torch.Tensor):
|
|
|
if self.r > 0 and not self.merged:
|
|
|
- return self.conv._conv_forward(
|
|
|
- x,
|
|
|
- self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,
|
|
|
- self.conv.bias
|
|
|
+ return F.conv2d(
|
|
|
+ x,
|
|
|
+ self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling,
|
|
|
+ self.bias, self.stride, self.padding, self.dilation, self.groups
|
|
|
)
|
|
|
- return self.conv(x)
|
|
|
-
|
|
|
-class Conv2d(ConvLoRA):
|
|
|
- def __init__(self, *args, **kwargs):
|
|
|
- super(Conv2d, self).__init__(nn.Conv2d, *args, **kwargs)
|
|
|
-
|
|
|
-class Conv1d(ConvLoRA):
|
|
|
- def __init__(self, *args, **kwargs):
|
|
|
- super(Conv1d, self).__init__(nn.Conv1d, *args, **kwargs)
|
|
|
-
|
|
|
-# Can Extend to other ones like this
|
|
|
-
|
|
|
-class Conv3d(ConvLoRA):
|
|
|
- def __init__(self, *args, **kwargs):
|
|
|
- super(Conv3d, self).__init__(nn.Conv3d, *args, **kwargs)
|
|
|
+ return nn.Conv2d.forward(self, x)
|
|
|
|