finetune.py

import os

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

from funasr.datasets.ms_dataset import MsDataset
from funasr.utils.modelscope_param import modelscope_args


def modelscope_finetune(params):
    if not os.path.exists(params.output_dir):
        os.makedirs(params.output_dir, exist_ok=True)
    # dataset split ["train", "validation"]
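    # NOTE (assumption, not part of the original script): in the FunASR
    # ModelScope finetuning recipes, data_path is expected to contain
    # "train" and "validation" subdirectories, each holding a Kaldi-style
    # "wav.scp" (utterance-id -> wav path) and a "text" file
    # (utterance-id -> transcript).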
    ds_dict = MsDataset.load(params.data_path)
    kwargs = dict(
        model=params.model,
        data_dir=ds_dict,
        dataset_type=params.dataset_type,
        work_dir=params.output_dir,
        batch_bins=params.batch_bins,
        max_epoch=params.max_epoch,
        lr=params.lr)
    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
    trainer.train()


if __name__ == '__main__':
    params = modelscope_args(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", data_path="./data")
    params.output_dir = "./checkpoint"    # directory where the fine-tuned model is saved
    params.data_path = "./example_data/"  # data path
    params.dataset_type = "small"         # use "small" for small datasets; if the data exceeds 1000 hours, use "large"
    params.batch_bins = 2000              # batch size: measured in fbank feature frames when dataset_type="small", in milliseconds when dataset_type="large"
    params.max_epoch = 50                 # maximum number of training epochs
    params.lr = 0.00005                   # learning rate
    modelscope_finetune(params)
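
Once training completes, the model under ./checkpoint can be used for decoding. The following is a minimal inference sketch, assuming the standard ModelScope pipeline API and that the checkpoint directory can be passed in place of a ModelScope model id; the test file ./example_data/example.wav is a hypothetical placeholder:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build an ASR pipeline from the fine-tuned checkpoint directory
# (params.output_dir above); a local directory is assumed to be
# accepted here in place of a hub model id.
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model="./checkpoint")
rec_result = inference_pipeline(audio_in="./example_data/example.wav")  # hypothetical test wav
print(rec_result)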