utils.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. # ------------------------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
  4. # ------------------------------------------------------------------------------------------
  5. import torch
  6. import torch.nn as nn
  7. from typing import Dict
  8. from .layers import LoRALayer
  9. def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
  10. for n, p in model.named_parameters():
  11. if 'lora_' not in n and 'cif' not in n:
  12. p.requires_grad = False
  13. if bias == 'none':
  14. return
  15. elif bias == 'all':
  16. for n, p in model.named_parameters():
  17. if 'bias' in n:
  18. p.requires_grad = True
  19. elif bias == 'lora_only':
  20. for m in model.modules():
  21. if isinstance(m, LoRALayer) and \
  22. hasattr(m, 'bias') and \
  23. m.bias is not None:
  24. m.bias.requires_grad = True
  25. else:
  26. raise NotImplementedError
  27. def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
  28. my_state_dict = model.state_dict()
  29. if bias == 'none':
  30. return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
  31. elif bias == 'all':
  32. return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k}
  33. elif bias == 'lora_only':
  34. to_return = {}
  35. for k in my_state_dict:
  36. if 'lora_' in k:
  37. to_return[k] = my_state_dict[k]
  38. bias_name = k.split('lora_')[0]+'bias'
  39. if bias_name in my_state_dict:
  40. to_return[bias_name] = my_state_dict[bias_name]
  41. return to_return
  42. else:
  43. raise NotImplementedError