# prepare_checkpoint.py

  1. import os
  2. import shutil
  3. from modelscope.hub.snapshot_download import snapshot_download
  4. if __name__ == '__main__':
  5. import sys
  6. model = sys.argv[1]
  7. checkpoint_dir = sys.argv[2]
  8. checkpoint_name = sys.argv[3]
  9. try:
  10. pretrained_model_path = snapshot_download(model, cache_dir=checkpoint_dir)
  11. except BaseException:
  12. raise BaseException(f"Please download pretrain model from ModelScope firstly.")
  13. shutil.copy(os.path.join(checkpoint_dir, checkpoint_name), os.path.join(pretrained_model_path, "model.pb"))