|
|
@@ -33,10 +33,10 @@ if __name__ == '__main__':
|
|
|
params.batch_bins = 2000 # batch size,如果dataset_type="small",batch_bins单位为fbank特征帧数,如果dataset_type="large",batch_bins单位为毫秒,
|
|
|
params.max_epoch = 20 # 最大训练轮数
|
|
|
params.lr = 0.0002 # 设置学习率
|
|
|
- init_param = []
|
|
|
- freeze_param = []
|
|
|
- ignore_init_mismatch = True
|
|
|
- use_lora = False
|
|
|
+ init_param = [] # 初始模型路径,默认加载modelscope模型初始化,例如: ["checkpoint/20epoch.pb"]
|
|
|
+ freeze_param = [] # 模型参数freeze, 例如: ["encoder"]
|
|
|
+ ignore_init_mismatch = True # 是否忽略模型参数初始化不匹配
|
|
|
+ use_lora = False # 是否使用lora进行模型微调
|
|
|
params.param_dict = {"init_param":init_param, "freeze_param": freeze_param, "ignore_init_mismatch": ignore_init_mismatch}
|
|
|
if use_lora:
|
|
|
enable_lora = True
|