  1. # ------------------------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
  4. # ------------------------------------------------------------------------------------------
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. import math
  9. from typing import Optional, List
  10. class LoRALayer():
  11. def __init__(
  12. self,
  13. r: int,
  14. lora_alpha: int,
  15. lora_dropout: float,
  16. merge_weights: bool,
  17. ):
  18. self.r = r
  19. self.lora_alpha = lora_alpha
  20. # Optional dropout
  21. if lora_dropout > 0.:
  22. self.lora_dropout = nn.Dropout(p=lora_dropout)
  23. else:
  24. self.lora_dropout = lambda x: x
  25. # Mark the weight as unmerged
  26. self.merged = False
  27. self.merge_weights = merge_weights
  28. class Embedding(nn.Embedding, LoRALayer):
  29. # LoRA implemented in a dense layer
  30. def __init__(
  31. self,
  32. num_embeddings: int,
  33. embedding_dim: int,
  34. r: int = 0,
  35. lora_alpha: int = 1,
  36. merge_weights: bool = True,
  37. **kwargs
  38. ):
  39. nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
  40. LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,
  41. merge_weights=merge_weights)
  42. # Actual trainable parameters
  43. if r > 0:
  44. self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
  45. self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
  46. self.scaling = self.lora_alpha / self.r
  47. # Freezing the pre-trained weight matrix
  48. self.weight.requires_grad = False
  49. self.reset_parameters()
  50. def reset_parameters(self):
  51. nn.Embedding.reset_parameters(self)
  52. if hasattr(self, 'lora_A'):
  53. # initialize A the same way as the default for nn.Linear and B to zero
  54. nn.init.zeros_(self.lora_A)
  55. nn.init.normal_(self.lora_B)
  56. def train(self, mode: bool = True):
  57. nn.Embedding.train(self, mode)
  58. if self.merge_weights and self.merged:
  59. # Make sure that the weights are not merged
  60. if self.r > 0:
  61. self.weight.data -= (self.lora_B @ self.lora_A).T * self.scaling
  62. self.merged = False
  63. def eval(self):
  64. nn.Linear.eval(self)
  65. if self.merge_weights and not self.merged:
  66. # Merge the weights and mark it
  67. if self.r > 0:
  68. self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
  69. self.merged = True
  70. def forward(self, x: torch.Tensor):
  71. if self.r > 0 and not self.merged:
  72. result = nn.Embedding.forward(self, x)
  73. if self.r > 0:
  74. after_A = F.embedding(
  75. x, self.lora_A.T, self.padding_idx, self.max_norm,
  76. self.norm_type, self.scale_grad_by_freq, self.sparse
  77. )
  78. result += (after_A @ self.lora_B.T) * self.scaling
  79. return result
  80. else:
  81. return nn.Embedding.forward(self, x)
  82. class Linear(nn.Linear, LoRALayer):
  83. # LoRA implemented in a dense layer
  84. def __init__(
  85. self,
  86. in_features: int,
  87. out_features: int,
  88. r: int = 0,
  89. lora_alpha: int = 1,
  90. lora_dropout: float = 0.,
  91. fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
  92. merge_weights: bool = True,
  93. **kwargs
  94. ):
  95. nn.Linear.__init__(self, in_features, out_features, **kwargs)
  96. LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
  97. merge_weights=merge_weights)
  98. self.fan_in_fan_out = fan_in_fan_out
  99. # Actual trainable parameters
  100. if r > 0:
  101. self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
  102. self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
  103. self.scaling = self.lora_alpha / self.r
  104. # Freezing the pre-trained weight matrix
  105. self.weight.requires_grad = False
  106. self.reset_parameters()
  107. if fan_in_fan_out:
  108. self.weight.data = self.weight.data.T
  109. def reset_parameters(self):
  110. nn.Linear.reset_parameters(self)
  111. if hasattr(self, 'lora_A'):
  112. # initialize A the same way as the default for nn.Linear and B to zero
  113. nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
  114. nn.init.zeros_(self.lora_B)
  115. def train(self, mode: bool = True):
  116. def T(w):
  117. return w.T if self.fan_in_fan_out else w
  118. nn.Linear.train(self, mode)
  119. if self.merge_weights and self.merged:
  120. # Make sure that the weights are not merged
  121. if self.r > 0:
  122. self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
  123. self.merged = False
  124. def eval(self):
  125. def T(w):
  126. return w.T if self.fan_in_fan_out else w
  127. nn.Linear.eval(self)
  128. if self.merge_weights and not self.merged:
  129. # Merge the weights and mark it
  130. if self.r > 0:
  131. self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
  132. self.merged = True
  133. def forward(self, x: torch.Tensor):
  134. def T(w):
  135. return w.T if self.fan_in_fan_out else w
  136. if self.r > 0 and not self.merged:
  137. result = F.linear(x, T(self.weight), bias=self.bias)
  138. if self.r > 0:
  139. result += (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
  140. return result
  141. else:
  142. return F.linear(x, T(self.weight), bias=self.bias)
class MergedLinear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer
    #
    # A single nn.Linear whose out_features rows form len(enable_lora) equal
    # groups (e.g. a fused QKV projection); a low-rank delta is trained only
    # for the groups whose flag in enable_lora is True.
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        enable_lora: List[bool] = [False],  # NOTE(review): mutable default; never mutated below, but a tuple would be safer
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        self.enable_lora = enable_lora
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0 and any(enable_lora):
            # A stacks one (r, in_features) block per enabled group along dim 0.
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features)))
            # B stacks one (group_size, r) block per enabled group along dim 0.
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
            ) # weights for Conv1D with groups=sum(enable_lora)
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Compute the indices
            # Boolean mask over the out_features rows that belong to enabled
            # groups; zero_pad() uses it to scatter the compact delta.
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool
            ).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            # (zero B makes the initial LoRA delta exactly zero).
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def zero_pad(self, x):
        # Scatter the concatenated per-group LoRA outputs (last dim
        # group_size * sum(enable_lora)) into a full out_features-wide tensor,
        # leaving the rows of disabled groups at zero.
        result = x.new_zeros((*x.shape[:-1], self.out_features))
        result = result.view(-1, self.out_features)
        result[:, self.lora_ind] = x.reshape(
            -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
        )
        return result.view((*x.shape[:-1], self.out_features))

    def train(self, mode: bool = True):
        def T(w):
            # Undo the fan_in_fan_out storage transpose when needed.
            return w.T if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if self.merge_weights and self.merged:
            # Make sure that the weights are not merged
            if self.r > 0 and any(self.enable_lora):
                # Grouped 1x1 conv computes B @ A independently per enabled group.
                delta_w = F.conv1d(
                    self.lora_A.data.unsqueeze(0),
                    self.lora_B.data.unsqueeze(-1),
                    groups=sum(self.enable_lora)
                ).squeeze(0)
                self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
            self.merged = False

    def eval(self):
        def T(w):
            # Undo the fan_in_fan_out storage transpose when needed.
            return w.T if self.fan_in_fan_out else w
        nn.Linear.eval(self)
        if self.merge_weights and not self.merged:
            # Merge the weights and mark it
            if self.r > 0 and any(self.enable_lora):
                # Same grouped-conv trick as train(), with the opposite sign.
                delta_w = F.conv1d(
                    self.lora_A.data.unsqueeze(0),
                    self.lora_B.data.unsqueeze(-1),
                    groups=sum(self.enable_lora)
                ).squeeze(0)
                self.weight.data += self.zero_pad(T(delta_w * self.scaling))
            self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.T if self.fan_in_fan_out else w
        if self.merged:
            # Delta already folded into weight; plain dense layer.
            return F.linear(x, T(self.weight), bias=self.bias)
        else:
            result = F.linear(x, T(self.weight), bias=self.bias)
            if self.r > 0:
                after_A = F.linear(self.lora_dropout(x), self.lora_A)
                # Grouped 1x1 conv applies each group's B block to its slice of after_A.
                after_B = F.conv1d(
                    after_A.transpose(-2, -1),
                    self.lora_B.unsqueeze(-1),
                    groups=sum(self.enable_lora)
                ).transpose(-2, -1)
                result += self.zero_pad(after_B) * self.scaling
            return result
  240. class Conv2d(nn.Conv2d, LoRALayer):
  241. # LoRA implemented in a dense layer
  242. def __init__(
  243. self,
  244. in_channels: int,
  245. out_channels: int,
  246. kernel_size: int,
  247. r: int = 0,
  248. lora_alpha: int = 1,
  249. lora_dropout: float = 0.,
  250. merge_weights: bool = True,
  251. **kwargs
  252. ):
  253. nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
  254. LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
  255. merge_weights=merge_weights)
  256. assert type(kernel_size) is int
  257. # Actual trainable parameters
  258. if r > 0:
  259. self.lora_A = nn.Parameter(
  260. self.weight.new_zeros((r*kernel_size, in_channels*kernel_size))
  261. )
  262. self.lora_B = nn.Parameter(
  263. self.weight.new_zeros((out_channels*kernel_size, r*kernel_size))
  264. )
  265. self.scaling = self.lora_alpha / self.r
  266. # Freezing the pre-trained weight matrix
  267. self.weight.requires_grad = False
  268. self.reset_parameters()
  269. def reset_parameters(self):
  270. nn.Conv2d.reset_parameters(self)
  271. if hasattr(self, 'lora_A'):
  272. # initialize A the same way as the default for nn.Linear and B to zero
  273. nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
  274. nn.init.zeros_(self.lora_B)
  275. def train(self, mode: bool = True):
  276. nn.Conv2d.train(self, mode)
  277. if self.merge_weights and self.merged:
  278. # Make sure that the weights are not merged
  279. self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
  280. self.merged = False
  281. def eval(self):
  282. nn.Conv2d.eval(self)
  283. if self.merge_weights and not self.merged:
  284. # Merge the weights and mark it
  285. self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
  286. self.merged = True
  287. def forward(self, x: torch.Tensor):
  288. if self.r > 0 and not self.merged:
  289. return F.conv2d(
  290. x,
  291. self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling,
  292. self.bias, self.stride, self.padding, self.dilation, self.groups
  293. )
  294. return nn.Conv2d.forward(self, x)