# finetune.py (1.2 KB)
import os

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from funasr.datasets.ms_dataset import MsDataset
  5. def modelscope_finetune(params):
  6. if not os.path.exists(params["output_dir"]):
  7. os.makedirs(params["output_dir"], exist_ok=True)
  8. # dataset split ["train", "validation"]
  9. ds_dict = MsDataset.load(params["data_dir"])
  10. kwargs = dict(
  11. model=params["model"],
  12. model_revision=params["model_revision"],
  13. data_dir=ds_dict,
  14. dataset_type=params["dataset_type"],
  15. work_dir=params["output_dir"],
  16. batch_bins=params["batch_bins"],
  17. max_epoch=params["max_epoch"],
  18. lr=params["lr"])
  19. trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
  20. trainer.train()
  21. if __name__ == '__main__':
  22. params = {}
  23. params["output_dir"] = "./checkpoint"
  24. params["data_dir"] = "./data"
  25. params["batch_bins"] = 2000
  26. params["dataset_type"] = "small"
  27. params["max_epoch"] = 50
  28. params["lr"] = 0.00005
  29. params["model"] = "damo/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-online"
  30. params["model_revision"] = None
  31. modelscope_finetune(params)