
add lora finetune code

haoneng.lhn, 2 years ago
commit 7ac54b3c97
3 changed files with 42 additions and 10 deletions
  1. funasr/models/decoder/sanm_decoder.py  (+5, -1)
  2. funasr/models/encoder/sanm_encoder.py  (+12, -0)
  3. funasr/modules/attention.py  (+25, -9)

funasr/models/decoder/sanm_decoder.py  (+5, -1)

@@ -833,6 +833,10 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
         att_layer_num: int = 6,
         kernel_size: int = 21,
         sanm_shfit: int = 0,
+        lora_list: List[str] = None,
+        lora_rank: int = 8,
+        lora_alpha: int = 16,
+        lora_dropout: float = 0.1,
         tf2torch_tensor_name_prefix_torch: str = "decoder",
         tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
     ):
@@ -885,7 +889,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
                     attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                 ),
                 MultiHeadedAttentionCrossAtt(
-                    attention_heads, attention_dim, src_attention_dropout_rate
+                    attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
                 ),
                 PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                 dropout_rate,
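
For reference, a minimal usage sketch of the extended MultiHeadedAttentionCrossAtt constructor that the decoder layer now calls. The argument order comes from the attention.py diff below; the dimensions and LoRA settings are placeholder values:

from funasr.modules.attention import MultiHeadedAttentionCrossAtt

cross_att = MultiHeadedAttentionCrossAtt(
    4,                      # attention_heads
    256,                    # attention_dim
    0.1,                    # src_attention_dropout_rate
    ["q", "k", "v", "o"],   # lora_list: which projections get LoRA adapters
    8,                      # lora_rank
    16,                     # lora_alpha
    0.1,                    # lora_dropout
)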

funasr/models/encoder/sanm_encoder.py  (+12, -0)

@@ -146,6 +146,10 @@ class SANMEncoder(AbsEncoder):
         interctc_use_conditioning: bool = False,
         kernel_size : int = 11,
         sanm_shfit : int = 0,
+        lora_list: List[str] = None,
+        lora_rank: int = 8,
+        lora_alpha: int = 16,
+        lora_dropout: float = 0.1,
         selfattention_layer_type: str = "sanm",
         tf2torch_tensor_name_prefix_torch: str = "encoder",
         tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
@@ -229,6 +233,10 @@ class SANMEncoder(AbsEncoder):
                 attention_dropout_rate,
                 kernel_size,
                 sanm_shfit,
+                lora_list,
+                lora_rank,
+                lora_alpha,
+                lora_dropout,
             )
 
             encoder_selfattn_layer_args = (
@@ -238,6 +246,10 @@ class SANMEncoder(AbsEncoder):
                 attention_dropout_rate,
                 kernel_size,
                 sanm_shfit,
+                lora_list,
+                lora_rank,
+                lora_alpha,
+                lora_dropout,
             )
         self.encoders0 = repeat(
             1,
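
The four new entries extend the positional args tuples that feed each SANM self-attention layer. A hedged sketch of how the extended tuple lines up with the new MultiHeadedAttentionSANM parameters (sizes are placeholders; the leading tuple entries are truncated in the hunk above):

from funasr.modules.attention import MultiHeadedAttentionSANM

encoder_selfattn_layer_args = (
    4,                      # attention_heads -> n_head
    512,                    # in_feat
    512,                    # n_feat (output_size)
    0.1,                    # attention_dropout_rate -> dropout_rate
    11,                     # kernel_size
    0,                      # sanm_shfit
    ["q", "k", "v", "o"],   # lora_list
    8,                      # lora_rank
    16,                     # lora_alpha
    0.1,                    # lora_dropout
)
self_attn = MultiHeadedAttentionSANM(*encoder_selfattn_layer_args)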

funasr/modules/attention.py  (+25, -9)

@@ -15,6 +15,7 @@ from typing import Optional, Tuple
 
 import torch.nn.functional as F
 from funasr.modules.nets_utils import make_pad_mask
+import funasr.modules.lora.layers as lora
 
 class MultiHeadedAttention(nn.Module):
     """Multi-Head Attention layer.
@@ -321,7 +322,7 @@ class MultiHeadedAttentionSANM(nn.Module):
 
     """
 
-    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
+    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
         """Construct an MultiHeadedAttention object."""
         super(MultiHeadedAttentionSANM, self).__init__()
         assert n_feat % n_head == 0
@@ -331,8 +332,16 @@ class MultiHeadedAttentionSANM(nn.Module):
         # self.linear_q = nn.Linear(n_feat, n_feat)
         # self.linear_k = nn.Linear(n_feat, n_feat)
         # self.linear_v = nn.Linear(n_feat, n_feat)
-        self.linear_out = nn.Linear(n_feat, n_feat)
-        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+        if lora_list is not None:
+            if "o" in lora_list:
+                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_out = nn.Linear(n_feat, n_feat)
+            lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
+            self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
+        else:
+            self.linear_out = nn.Linear(n_feat, n_feat)
+            self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
         self.attn = None
         self.dropout = nn.Dropout(p=dropout_rate)
 
@@ -543,18 +552,25 @@ class MultiHeadedAttentionCrossAtt(nn.Module):
 
     """
 
-    def __init__(self, n_head, n_feat, dropout_rate, encoder_output_size=None):
+    def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
         """Construct an MultiHeadedAttention object."""
         super(MultiHeadedAttentionCrossAtt, self).__init__()
         assert n_feat % n_head == 0
         # We assume d_v always equals d_k
         self.d_k = n_feat // n_head
         self.h = n_head
-        self.linear_q = nn.Linear(n_feat, n_feat)
-        # self.linear_k = nn.Linear(n_feat, n_feat)
-        # self.linear_v = nn.Linear(n_feat, n_feat)
-        self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
-        self.linear_out = nn.Linear(n_feat, n_feat)
+        if lora_list is not None:
+            if "q" in lora_list:
+                self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_q = nn.Linear(n_feat, n_feat)
+            lora_kv_list = ["k" in lora_list, "v" in lora_list]
+            self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2,
+                                  r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
+            if "o" in lora_list:
+                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_out = nn.Linear(n_feat, n_feat)
+        else:
+            self.linear_q = nn.Linear(n_feat, n_feat)
+            self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+            self.linear_out = nn.Linear(n_feat, n_feat)
         self.attn = None
         self.dropout = nn.Dropout(p=dropout_rate)
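
The lora.Linear and lora.MergedLinear layers used above come from the new funasr.modules.lora.layers module, which is not part of this diff. Assuming it follows the reference loralib design that these constructor arguments mirror (r, lora_alpha, lora_dropout, enable_lora), a self-contained sketch of the underlying idea is:

import torch
import torch.nn as nn

class LoRALinearSketch(nn.Module):
    # Illustrative LoRA linear layer: y = W x + (lora_alpha / r) * B A dropout(x).
    def __init__(self, in_features, out_features, r=8, lora_alpha=16, lora_dropout=0.1):
        super().__init__()
        # Frozen stand-in for the pretrained nn.Linear weight.
        self.weight = nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
        # Trainable low-rank factors; lora_B starts at zero, so the update is initially a no-op.
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        nn.init.kaiming_uniform_(self.lora_A, a=5 ** 0.5)
        self.scaling = lora_alpha / r
        self.lora_dropout = nn.Dropout(p=lora_dropout)

    def forward(self, x):
        frozen = nn.functional.linear(x, self.weight)
        update = self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T
        return frozen + update * self.scaling

layer = LoRALinearSketch(256, 256, r=8, lora_alpha=16, lora_dropout=0.1)
y = layer(torch.randn(2, 10, 256))   # same output shape as nn.Linear(256, 256)

Because only lora_A and lora_B require gradients, fine-tuning updates a small fraction of the parameters while the pretrained projection weights stay frozen. loralib-style packages usually also provide a helper such as mark_only_lora_as_trainable to freeze the rest of the model; wiring that into the training loop is not shown in this commit.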