haoneng.lhn · 2 years ago
Parent
Commit 36c43d4c9f

+ 14 - 2
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/finetune.py

@@ -19,7 +19,8 @@ def modelscope_finetune(params):
         work_dir=params.output_dir,
         batch_bins=params.batch_bins,
         max_epoch=params.max_epoch,
-        lr=params.lr)
+        lr=params.lr,
+        mate_params=params.param_dict)
     trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
     trainer.train()
 
@@ -30,7 +31,18 @@ if __name__ == '__main__':
     params.data_path = "./example_data/"            # path to the training data
     params.dataset_type = "small"                   # use "small" for small datasets; if the data exceeds 1000 hours, use "large"
     params.batch_bins = 2000                        # batch size; with dataset_type="small" the unit is fbank feature frames, with dataset_type="large" it is milliseconds
-    params.max_epoch = 50                           # maximum number of training epochs
+    params.max_epoch = 5                            # maximum number of training epochs
     params.lr = 0.00005                             # learning rate
+    init_param = []
+    freeze_param = []
+    ignore_init_mismatch = True
+    use_lora = False
+    params.param_dict = {"init_param":init_param, "freeze_param": freeze_param, "ignore_init_mismatch": ignore_init_mismatch}
+    if use_lora:
+        enable_lora = True
+        lora_bias = "all"
+        lora_params = {"lora_list":['q','v'], "lora_rank":8, "lora_alpha":16, "lora_dropout":0.1}
+        lora_config = {"enable_lora": enable_lora, "lora_bias": lora_bias, "lora_params": lora_params}
+        params.param_dict.update(lora_config)
     
     modelscope_finetune(params)
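
For reference, with use_lora = True the fully assembled params.param_dict handed to modelscope_finetune would look roughly like the sketch below; the keys and values mirror the snippet above, and the inline comments are interpretive rather than taken from the code.

# Sketch of params.param_dict when use_lora is True (illustrative only).
param_dict = {
    "init_param": [],                  # optional checkpoint paths used to initialize the model
    "freeze_param": [],                # parameters (by name) to keep frozen during finetuning
    "ignore_init_mismatch": True,      # tolerate name/shape mismatches when loading init_param
    "enable_lora": True,
    "lora_bias": "all",
    "lora_params": {
        "lora_list": ["q", "v"],       # which attention projections receive LoRA adapters
        "lora_rank": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.1,
    },
}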

+ 11 - 4
funasr/bin/build_trainer.py

@@ -92,6 +92,14 @@ def build_trainer(modelscope_dict,
     for key, value in finetune_configs.items():
         if hasattr(args, key):
             setattr(args, key, value)
+    if mate_params is not None:
+        for key, value in mate_params.items():
+            if hasattr(args, key):
+                setattr(args, key, value)
+    if mate_params is not None and "lora_params" in mate_params:
+        lora_params = mate_params['lora_params']
+        configs['encoder_conf'].update(lora_params) 
+        configs['decoder_conf'].update(lora_params) 
 
     # prepare data
     args.dataset_type = dataset_type
@@ -106,6 +114,9 @@ def build_trainer(modelscope_dict,
     else:
         raise ValueError(f"Not supported dataset_type={args.dataset_type}")
     args.init_param = [init_param]
+    if mate_params is not None and "init_param" in mate_params:
+        if len(mate_params["init_param"]) != 0:
+            args.init_param = mate_params["init_param"]
     args.cmvn_file = cmvn_file
     if os.path.exists(seg_dict_file):
         args.seg_dict_file = seg_dict_file
@@ -144,10 +155,6 @@ def build_trainer(modelscope_dict,
         args.patience = None
     args.local_rank = local_rank
     args.distributed = distributed
-    if mate_params is not None:
-        for key, value in mate_params.items():
-            if hasattr(args, key):
-                setattr(args, key, value)
     ASRTask.finetune_args = args
 
     return ASRTask
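
The hunks above copy matching mate_params keys onto the parsed args, propagate lora_params into the encoder/decoder configurations, and only replace init_param when a non-empty list is supplied. A minimal, self-contained sketch of that override logic; apply_mate_params, args, and configs are stand-ins, not names from build_trainer.

from argparse import Namespace

def apply_mate_params(args, configs, mate_params):
    # Copy any matching keys from mate_params onto the parsed args.
    if mate_params is None:
        return
    for key, value in mate_params.items():
        if hasattr(args, key):
            setattr(args, key, value)
    # Push the LoRA hyper-parameters into the encoder/decoder configurations.
    if "lora_params" in mate_params:
        configs["encoder_conf"].update(mate_params["lora_params"])
        configs["decoder_conf"].update(mate_params["lora_params"])
    # A non-empty init_param list replaces the default checkpoint path.
    if mate_params.get("init_param"):
        args.init_param = mate_params["init_param"]

args = Namespace(init_param=["model.pb"], freeze_param=[], ignore_init_mismatch=False)
configs = {"encoder_conf": {}, "decoder_conf": {}}
apply_mate_params(args, configs,
                  {"ignore_init_mismatch": True, "lora_params": {"lora_rank": 8}})
# args.ignore_init_mismatch is now True; both confs contain lora_rank=8.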

+ 13 - 3
funasr/modules/attention.py

@@ -338,7 +338,10 @@ class MultiHeadedAttentionSANM(nn.Module):
             else:
                 self.linear_out = nn.Linear(n_feat, n_feat)
             lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
-            self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
+            if lora_qkv_list == [False, False, False]:
+                self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+            else:
+                self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
         else:
             self.linear_out = nn.Linear(n_feat, n_feat)
             self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
@@ -562,11 +565,18 @@ class MultiHeadedAttentionCrossAtt(nn.Module):
         if lora_list is not None:
             if "q" in lora_list:
                 self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_q = nn.Linear(n_feat, n_feat)
             lora_kv_list = ["k" in lora_list, "v" in lora_list]
-            self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2, 
-                                  r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
+            if lora_kv_list == [False, False]:
+                self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+            else:
+                self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2, 
+                                      r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
             if "o" in lora_list:
                 self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_out = nn.Linear(n_feat, n_feat)
         else:
             self.linear_q = nn.Linear(n_feat, n_feat)
             self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
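
The new branches fall back to a plain nn.Linear whenever none of the projections in a fused layer is selected for LoRA. A standalone sketch of that selection logic follows; the dimensions and lora_list value are made up, and the import path is assumed to mirror funasr/modules/lora/layers.py.

import torch.nn as nn
import funasr.modules.lora.layers as lora  # assumption: package path follows the file location

n_feat, lora_rank, lora_alpha, lora_dropout = 512, 8, 16, 0.1
lora_list = ["o"]  # illustrative: only the output projection gets an adapter

lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
if lora_qkv_list == [False, False, False]:
    # nothing in the fused q/k/v projection is adapted, so keep an ordinary linear layer
    linear_q_k_v = nn.Linear(n_feat, n_feat * 3)
else:
    # only the selected projections receive low-rank adapters inside the fused weight
    linear_q_k_v = lora.MergedLinear(n_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha,
                                     lora_dropout=lora_dropout, enable_lora=lora_qkv_list)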

+ 126 - 122
funasr/modules/lora/layers.py

@@ -11,9 +11,9 @@ from typing import Optional, List
 
 class LoRALayer():
     def __init__(
-        self, 
-        r: int, 
-        lora_alpha: int, 
+        self,
+        r: int,
+        lora_alpha: int,
         lora_dropout: float,
         merge_weights: bool,
     ):
@@ -61,40 +61,42 @@ class Embedding(nn.Embedding, LoRALayer):
 
     def train(self, mode: bool = True):
         nn.Embedding.train(self, mode)
-        if mode:
-            if self.merge_weights and self.merged:
-                # Make sure that the weights are not merged
-                if self.r > 0:
-                    self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
-                self.merged = False
-        else:
-            if self.merge_weights and not self.merged:
-                # Merge the weights and mark it
-                if self.r > 0:
-                    self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
-                self.merged = True
-        
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            if self.r > 0:
+                self.weight.data -= (self.lora_B @ self.lora_A).T * self.scaling
+            self.merged = False
+
+    def eval(self):
+        nn.Embedding.eval(self)
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            if self.r > 0:
+                self.weight.data += (self.lora_B @ self.lora_A).T * self.scaling
+            self.merged = True
+
     def forward(self, x: torch.Tensor):
         if self.r > 0 and not self.merged:
             result = nn.Embedding.forward(self, x)
-            after_A = F.embedding(
-                x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
-                self.norm_type, self.scale_grad_by_freq, self.sparse
-            )
-            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
+            if self.r > 0:
+                after_A = F.embedding(
+                    x, self.lora_A.T, self.padding_idx, self.max_norm,
+                    self.norm_type, self.scale_grad_by_freq, self.sparse
+                )
+                result += (after_A @ self.lora_B.T) * self.scaling
             return result
         else:
             return nn.Embedding.forward(self, x)
-            
+
 
 class Linear(nn.Linear, LoRALayer):
     # LoRA implemented in a dense layer
     def __init__(
-        self, 
-        in_features: int, 
-        out_features: int, 
-        r: int = 0, 
-        lora_alpha: int = 1, 
+        self,
+        in_features: int,
+        out_features: int,
+        r: int = 0,
+        lora_alpha: int = 1,
         lora_dropout: float = 0.,
         fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
         merge_weights: bool = True,
@@ -114,7 +116,7 @@ class Linear(nn.Linear, LoRALayer):
             self.weight.requires_grad = False
         self.reset_parameters()
         if fan_in_fan_out:
-            self.weight.data = self.weight.data.transpose(0, 1)
+            self.weight.data = self.weight.data.T
 
     def reset_parameters(self):
         nn.Linear.reset_parameters(self)
@@ -125,27 +127,31 @@ class Linear(nn.Linear, LoRALayer):
 
     def train(self, mode: bool = True):
         def T(w):
-            return w.transpose(0, 1) if self.fan_in_fan_out else w
+            return w.T if self.fan_in_fan_out else w
         nn.Linear.train(self, mode)
-        if mode:
-            if self.merge_weights and self.merged:
-                # Make sure that the weights are not merged
-                if self.r > 0:
-                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
-                self.merged = False
-        else:
-            if self.merge_weights and not self.merged:
-                # Merge the weights and mark it
-                if self.r > 0:
-                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
-                self.merged = True       
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            if self.r > 0:
+                self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+            self.merged = False
+
+    def eval(self):
+        def T(w):
+            return w.T if self.fan_in_fan_out else w
+        nn.Linear.eval(self)
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            if self.r > 0:
+                self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+            self.merged = True
 
     def forward(self, x: torch.Tensor):
         def T(w):
-            return w.transpose(0, 1) if self.fan_in_fan_out else w
+            return w.T if self.fan_in_fan_out else w
         if self.r > 0 and not self.merged:
-            result = F.linear(x, T(self.weight), bias=self.bias)            
-            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
+            result = F.linear(x, T(self.weight), bias=self.bias)
+            if self.r > 0:
+                result += (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
             return result
         else:
             return F.linear(x, T(self.weight), bias=self.bias)
@@ -154,11 +160,11 @@ class Linear(nn.Linear, LoRALayer):
 class MergedLinear(nn.Linear, LoRALayer):
     # LoRA implemented in a dense layer
     def __init__(
-        self, 
-        in_features: int, 
-        out_features: int, 
-        r: int = 0, 
-        lora_alpha: int = 1, 
+        self,
+        in_features: int,
+        out_features: int,
+        r: int = 0,
+        lora_alpha: int = 1,
         lora_dropout: float = 0.,
         enable_lora: List[bool] = [False],
         fan_in_fan_out: bool = False,
@@ -190,7 +196,7 @@ class MergedLinear(nn.Linear, LoRALayer):
             self.lora_ind = self.lora_ind.view(-1)
         self.reset_parameters()
         if fan_in_fan_out:
-            self.weight.data = self.weight.data.transpose(0, 1)
+            self.weight.data = self.weight.data.T
 
     def reset_parameters(self):
         nn.Linear.reset_parameters(self)
@@ -209,34 +215,37 @@ class MergedLinear(nn.Linear, LoRALayer):
 
     def train(self, mode: bool = True):
         def T(w):
-            return w.transpose(0, 1) if self.fan_in_fan_out else w
+            return w.T if self.fan_in_fan_out else w
         nn.Linear.train(self, mode)
-        if mode:
-            if self.merge_weights and self.merged:
-                # Make sure that the weights are not merged
-                if self.r > 0 and any(self.enable_lora):
-                    delta_w = F.conv1d(
-                        self.lora_A.data.unsqueeze(0), 
-                        self.lora_B.data.unsqueeze(-1), 
-                        groups=sum(self.enable_lora)
-                    ).squeeze(0)
-                    self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
-                self.merged = False
-        else:
-            if self.merge_weights and not self.merged:
-                # Merge the weights and mark it
-                if self.r > 0 and any(self.enable_lora):
-                    delta_w = F.conv1d(
-                        self.lora_A.data.unsqueeze(0), 
-                        self.lora_B.data.unsqueeze(-1), 
-                        groups=sum(self.enable_lora)
-                    ).squeeze(0)
-                    self.weight.data += self.zero_pad(T(delta_w * self.scaling))
-                self.merged = True        
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            if self.r > 0 and any(self.enable_lora):
+                delta_w = F.conv1d(
+                    self.lora_A.data.unsqueeze(0),
+                    self.lora_B.data.unsqueeze(-1),
+                    groups=sum(self.enable_lora)
+                ).squeeze(0)
+                self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
+            self.merged = False
+
+    def eval(self):
+        def T(w):
+            return w.T if self.fan_in_fan_out else w
+        nn.Linear.eval(self)
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            if self.r > 0 and any(self.enable_lora):
+                delta_w = F.conv1d(
+                    self.lora_A.data.unsqueeze(0),
+                    self.lora_B.data.unsqueeze(-1),
+                    groups=sum(self.enable_lora)
+                ).squeeze(0)
+                self.weight.data += self.zero_pad(T(delta_w * self.scaling))
+            self.merged = True
 
     def forward(self, x: torch.Tensor):
         def T(w):
-            return w.transpose(0, 1) if self.fan_in_fan_out else w
+            return w.T if self.fan_in_fan_out else w
         if self.merged:
             return F.linear(x, T(self.weight), bias=self.bias)
         else:
@@ -244,76 +253,71 @@ class MergedLinear(nn.Linear, LoRALayer):
             if self.r > 0:
                 after_A = F.linear(self.lora_dropout(x), self.lora_A)
                 after_B = F.conv1d(
-                    after_A.transpose(-2, -1), 
-                    self.lora_B.unsqueeze(-1), 
+                    after_A.transpose(-2, -1),
+                    self.lora_B.unsqueeze(-1),
                     groups=sum(self.enable_lora)
                 ).transpose(-2, -1)
                 result += self.zero_pad(after_B) * self.scaling
             return result
-        
 
-class ConvLoRA(nn.Module, LoRALayer):
-    def __init__(self, conv_module, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
-        super(ConvLoRA, self).__init__()
-        self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
-        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
-        assert isinstance(kernel_size, int)
+
+class Conv2d(nn.Conv2d, LoRALayer):
+    # LoRA implemented in a dense layer
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.,
+        merge_weights: bool = True,
+        **kwargs
+    ):
+        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
+        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+                           merge_weights=merge_weights)
+        assert type(kernel_size) is int
         # Actual trainable parameters
         if r > 0:
             self.lora_A = nn.Parameter(
-                self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
+                self.weight.new_zeros((r*kernel_size, in_channels*kernel_size))
             )
             self.lora_B = nn.Parameter(
-              self.conv.weight.new_zeros((out_channels//self.conv.groups*kernel_size, r*kernel_size))
+                self.weight.new_zeros((out_channels*kernel_size, r*kernel_size))
             )
             self.scaling = self.lora_alpha / self.r
             # Freezing the pre-trained weight matrix
-            self.conv.weight.requires_grad = False
+            self.weight.requires_grad = False
         self.reset_parameters()
-        self.merged = False
 
     def reset_parameters(self):
-        self.conv.reset_parameters()
+        nn.Conv2d.reset_parameters(self)
         if hasattr(self, 'lora_A'):
             # initialize A the same way as the default for nn.Linear and B to zero
             nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
             nn.init.zeros_(self.lora_B)
 
-    def train(self, mode=True):
-        super(ConvLoRA, self).train(mode)
-        if mode:
-            if self.merge_weights and self.merged:
-                if self.r > 0:
-                    # Make sure that the weights are not merged
-                    self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
-                self.merged = False
-        else:
-            if self.merge_weights and not self.merged:
-                if self.r > 0:
-                    # Merge the weights and mark it
-                    self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
-                self.merged = True
+    def train(self, mode: bool = True):
+        nn.Conv2d.train(self, mode)
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
+            self.merged = False
 
-    def forward(self, x):
+    def eval(self):
+        nn.Conv2d.eval(self)
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
+            self.merged = True
+
+    def forward(self, x: torch.Tensor):
         if self.r > 0 and not self.merged:
-            return self.conv._conv_forward(
-                x, 
-                self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,
-                self.conv.bias
+            return F.conv2d(
+                x,
+                self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling,
+                self.bias, self.stride, self.padding, self.dilation, self.groups
             )
-        return self.conv(x)
-
-class Conv2d(ConvLoRA):
-    def __init__(self, *args, **kwargs):
-        super(Conv2d, self).__init__(nn.Conv2d, *args, **kwargs)
-
-class Conv1d(ConvLoRA):
-    def __init__(self, *args, **kwargs):
-        super(Conv1d, self).__init__(nn.Conv1d, *args, **kwargs)
-
-# Can Extend to other ones like this
-
-class Conv3d(ConvLoRA):
-    def __init__(self, *args, **kwargs):
-        super(Conv3d, self).__init__(nn.Conv3d, *args, **kwargs)
+        return nn.Conv2d.forward(self, x)
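
A quick way to sanity-check the split train()/eval() merging logic above is to confirm that a LoRA layer produces the same output whether or not its low-rank delta has been folded into the frozen weight. A minimal sketch, assuming the import path mirrors funasr/modules/lora/layers.py; dimensions are arbitrary.

import torch
import funasr.modules.lora.layers as lora  # assumption: package path follows the file location

torch.manual_seed(0)
layer = lora.Linear(16, 16, r=4, lora_alpha=8, lora_dropout=0.0, merge_weights=True)
# give the adapter a non-trivial delta so the check is not vacuous (lora_B starts at zero)
torch.nn.init.normal_(layer.lora_B, std=0.02)

x = torch.randn(2, 16)

layer.train()            # unmerged path: y = x W^T + scaling * (x A^T) B^T
y_unmerged = layer(x)

layer.eval()             # eval() folds scaling * (B @ A) into the frozen weight
y_merged = layer(x)

assert torch.allclose(y_unmerged, y_merged, atol=1e-5)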