@@ -15,6 +15,7 @@ from typing import Optional, Tuple
 import torch.nn.functional as F
 from funasr.modules.nets_utils import make_pad_mask
+import funasr.modules.lora.layers as lora


 class MultiHeadedAttention(nn.Module):
     """Multi-Head Attention layer.
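Note: `funasr.modules.lora.layers` follows the loralib-style API, in which `lora.Linear` is a drop-in replacement for `nn.Linear` that adds a trainable low-rank update on top of a frozen base weight. A minimal sketch of the idea, for orientation only (the class below is illustrative, not FunASR's actual implementation):

# Illustrative sketch only, not FunASR's implementation: a LoRA linear layer
# computes y = W x + (lora_alpha / r) * B A x, training only A and B.
import math
import torch
import torch.nn as nn

class LoRALinearSketch(nn.Linear):
    def __init__(self, in_features, out_features, r=8, lora_alpha=16, lora_dropout=0.1):
        super().__init__(in_features, out_features)
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))   # down-projection A
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))  # up-projection B
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))     # B stays zero, so the
        self.scaling = lora_alpha / r                             # update starts at zero
        self.lora_dropout = nn.Dropout(p=lora_dropout)
        self.weight.requires_grad = False                         # base weight is frozen

    def forward(self, x):
        base = nn.functional.linear(x, self.weight, self.bias)
        update = self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T
        return base + update * self.scaling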
@@ -321,7 +322,7 @@ class MultiHeadedAttentionSANM(nn.Module):

     """

-    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
+    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
         """Construct an MultiHeadedAttention object."""
         super(MultiHeadedAttentionSANM, self).__init__()
         assert n_feat % n_head == 0
@@ -331,8 +332,16 @@ class MultiHeadedAttentionSANM(nn.Module):
         # self.linear_q = nn.Linear(n_feat, n_feat)
         # self.linear_k = nn.Linear(n_feat, n_feat)
         # self.linear_v = nn.Linear(n_feat, n_feat)
-        self.linear_out = nn.Linear(n_feat, n_feat)
-        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+        if lora_list is not None:
+            if "o" in lora_list:
+                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_out = nn.Linear(n_feat, n_feat)
+            lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
+            self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
+        else:
+            self.linear_out = nn.Linear(n_feat, n_feat)
+            self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
         self.attn = None
         self.dropout = nn.Dropout(p=dropout_rate)

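`lora.MergedLinear` handles fused projections such as `linear_q_k_v` above: the output dimension is treated as equal slices (Q, K, V here), and `enable_lora` selects which slices receive a low-rank update while the rest keep only the frozen base weight. A hedged usage sketch, assuming the constructor signature from this diff and arbitrary example sizes:

# Hypothetical usage: adapt only the Q and V slices of the fused projection,
# plus the output projection; the sizes are arbitrary examples.
attn = MultiHeadedAttentionSANM(
    n_head=4, in_feat=512, n_feat=512, dropout_rate=0.1,
    kernel_size=11, sanm_shfit=0,
    lora_list=["q", "v", "o"], lora_rank=8, lora_alpha=16, lora_dropout=0.1,
)
# linear_q_k_v is then a lora.MergedLinear with enable_lora=[True, False, True]:
# the K slice of the 3 * n_feat output receives no low-rank update.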
@@ -543,18 +552,29 @@ class MultiHeadedAttentionCrossAtt(nn.Module):

     """

-    def __init__(self, n_head, n_feat, dropout_rate, encoder_output_size=None):
+    def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
         """Construct an MultiHeadedAttention object."""
         super(MultiHeadedAttentionCrossAtt, self).__init__()
         assert n_feat % n_head == 0
         # We assume d_v always equals d_k
         self.d_k = n_feat // n_head
         self.h = n_head
-        self.linear_q = nn.Linear(n_feat, n_feat)
-        # self.linear_k = nn.Linear(n_feat, n_feat)
-        # self.linear_v = nn.Linear(n_feat, n_feat)
-        self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
-        self.linear_out = nn.Linear(n_feat, n_feat)
+        if lora_list is not None:
+            if "q" in lora_list:
+                self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_q = nn.Linear(n_feat, n_feat)
+            lora_kv_list = ["k" in lora_list, "v" in lora_list]
+            self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2,
+                                                r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
+            if "o" in lora_list:
+                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_out = nn.Linear(n_feat, n_feat)
+        else:
+            self.linear_q = nn.Linear(n_feat, n_feat)
+            self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+            self.linear_out = nn.Linear(n_feat, n_feat)
         self.attn = None
         self.dropout = nn.Dropout(p=dropout_rate)

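Once the layers are wired up, fine-tuning normally freezes everything except the LoRA matrices. Upstream loralib ships `mark_only_lora_as_trainable` for this; assuming FunASR's vendored copy exposes the same helper (the import path below is an assumption, not confirmed by this diff), the training setup would look roughly like:

import torch

# Assumed import path; upstream this helper lives at
# loralib.mark_only_lora_as_trainable.
from funasr.modules.lora.utils import mark_only_lora_as_trainable

attn = MultiHeadedAttentionCrossAtt(n_head=4, n_feat=512, dropout_rate=0.1,
                                    lora_list=["q", "k", "v", "o"])
mark_only_lora_as_trainable(attn)  # freeze all non-LoRA parameters
trainable = [p for p in attn.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)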