download_from_hub.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import os
  2. import json
  3. from omegaconf import OmegaConf, DictConfig
  4. from funasr.download.name_maps_from_hub import name_maps_ms, name_maps_hf, name_maps_openai
  def download_model(**kwargs):
      """Resolve the ``model`` entry of ``kwargs`` according to the ``hub`` entry.

      * ``hub == "ms"`` (the default): the work is delegated to
        ``download_from_ms`` and its result is returned.
      * ``hub == "openai"``: when ``model`` is an existing local path, that path
        becomes ``model_path`` and ``model`` is replaced by ``"WhisperWarp"``;
        otherwise ``model`` is treated as a model name, translated through
        ``name_maps_openai`` when it is listed there, and stored as
        ``model_path``.
      * ``hub == "hf"`` or any other value: ``kwargs`` is returned unchanged.

      Args:
          **kwargs: Model options. ``hub`` and ``model`` are read here.

      Returns:
          The (possibly updated) keyword arguments.
      """
      hub_name = kwargs.get("hub", "ms")

      if hub_name == "ms":
          return download_from_ms(**kwargs)

      if hub_name != "openai":
          # "hf" is accepted but nothing is done for it here; unknown hubs
          # fall through the same way.
          return kwargs

      target = kwargs.get("model")
      if os.path.exists(target):
          # Local path.
          kwargs["model"] = "WhisperWarp"
      elif target in name_maps_openai:
          # Model name with a known alias.
          target = name_maps_openai[target]
      kwargs["model_path"] = target
      return kwargs
  def download_from_ms(**kwargs):
      """Locate (or download) a ModelScope model and fold its config into ``kwargs``.

      The ``model`` entry is first translated through ``name_maps_ms``. If the
      result is not an existing path and the caller supplied no ``model_path``,
      the model is fetched with ``get_or_download_model_dir``; a failed download
      is printed and otherwise ignored.

      Two directory layouts are then recognised:

      * ``configuration.json`` is present: the paths listed under its
        ``file_path_metas`` entry are resolved against the model directory,
        overlaid with ``kwargs``, and merged on top of the YAML file named by
        the resulting ``config`` entry.
      * ``config.yaml`` and ``model.pt`` are present: the YAML is merged with
        ``kwargs`` and the asset files found in the directory (token list,
        ``seg_dict``, ``bpe.model``, ``am.mvn``, ``jieba_usr_dict``) are wired
        into the configuration.

      In both layouts ``model`` is finally replaced by the ``model`` entry of
      the loaded YAML.

      Args:
          **kwargs: Model options. Keys read here: ``model``,
              ``model_revision``, ``model_path``, ``is_training`` and
              ``check_latest``.

      Returns:
          The merged options. An OmegaConf ``DictConfig`` is converted to a
          plain container with interpolations resolved.
      """
      model_or_path = kwargs.get("model")
      if model_or_path in name_maps_ms:
          model_or_path = name_maps_ms[model_or_path]
      model_revision = kwargs.get("model_revision")

      if not os.path.exists(model_or_path) and "model_path" not in kwargs:
          try:
              model_or_path = get_or_download_model_dir(
                  model_or_path,
                  model_revision,
                  is_training=kwargs.get("is_training"),
                  check_latest=kwargs.get("check_latest", True),
              )
          except Exception as e:
              # Deliberately best-effort: report the failure and continue with
              # the unresolved name.
              print(f"Download: {model_or_path} failed!: {e}")

      # A caller-supplied model_path always wins.
      kwargs.setdefault("model_path", model_or_path)

      configuration_json = os.path.join(model_or_path, "configuration.json")
      config_yaml = os.path.join(model_or_path, "config.yaml")
      model_pt = os.path.join(model_or_path, "model.pt")

      if os.path.exists(configuration_json):
          with open(configuration_json, 'r', encoding='utf-8') as f:
              conf_json = json.load(f)

          cfg = {}
          if "file_path_metas" in conf_json:
              add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
          # Explicit caller options override paths discovered in the model dir.
          cfg.update(kwargs)
          config = OmegaConf.load(cfg["config"])
          kwargs = OmegaConf.merge(config, cfg)
          kwargs["model"] = config["model"]
      elif os.path.exists(config_yaml) and os.path.exists(model_pt):
          config = OmegaConf.load(config_yaml)
          kwargs = OmegaConf.merge(config, kwargs)

          # The branch condition only guarantees "model.pt". Keep the
          # "model.pb" name when that file is actually present, otherwise use
          # the checkpoint that was verified to exist instead of a dangling path.
          init_param = os.path.join(model_or_path, "model.pb")
          if not os.path.exists(init_param):
              init_param = model_pt
          kwargs["init_param"] = init_param

          # tokens.json is checked after tokens.txt, so it wins when both exist.
          if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
              kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
          if os.path.exists(os.path.join(model_or_path, "tokens.json")):
              kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
          if os.path.exists(os.path.join(model_or_path, "seg_dict")):
              kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
          if os.path.exists(os.path.join(model_or_path, "bpe.model")):
              kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
          kwargs["model"] = config["model"]
          if os.path.exists(os.path.join(model_or_path, "am.mvn")):
              kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
          if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
              kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")

      if isinstance(kwargs, DictConfig):
          kwargs = OmegaConf.to_container(kwargs, resolve=True)
      return kwargs
  def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg=None):
      """Resolve relative file names in ``file_path_metas`` against a model directory.

      Walks ``file_path_metas`` recursively. Every string value is joined onto
      ``model_or_path``; when the resulting path exists it is stored in ``cfg``
      under the same key. Nested dicts are mirrored as nested dicts in ``cfg``.
      Values of any other type, and strings whose joined path does not exist,
      are skipped.

      Args:
          model_or_path: Directory the relative file names are joined onto.
          file_path_metas: Mapping of keys to relative file names or nested
              mappings. Anything that is not a dict is ignored.
          cfg: Dict to fill in place. A fresh dict is created when omitted, so
              results are never shared between calls.

      Returns:
          ``cfg`` (the dict passed in, or the newly created one).
      """
      # Use a None sentinel rather than a `{}` default: a mutable default would
      # be shared by every call that omits `cfg`. Compare with `is None` so an
      # explicitly passed empty dict is still filled in place.
      if cfg is None:
          cfg = {}

      if isinstance(file_path_metas, dict):
          for k, v in file_path_metas.items():
              if isinstance(v, str):
                  p = os.path.join(model_or_path, v)
                  if os.path.exists(p):
                      cfg[k] = p
              elif isinstance(v, dict):
                  if k not in cfg:
                      cfg[k] = {}
                  add_file_root_path(model_or_path, v, cfg[k])

      return cfg
  def get_or_download_model_dir(
      model,
      model_revision=None,
      is_training=False,
      check_latest=True,
  ):
      """Get the local model directory, downloading the model if necessary.

      Args:
          model (str): model id or path to a local model file or directory.
          model_revision (str, optional): model version, passed to
              ``snapshot_download`` as ``revision``.
          is_training (bool): selects ``Invoke.LOCAL_TRAINER`` instead of
              ``Invoke.PIPELINE`` for the user-agent tag sent to ModelScope.
          check_latest (bool): when ``model`` exists locally, also ask
              ModelScope whether the local copy is the latest version. A
              failure of that check is printed and ignored.

      Returns:
          For an existing local path: the path itself when it is a directory,
          otherwise its parent directory. For anything else: the value
          returned by ``snapshot_download``.
      """
      is_local = os.path.exists(model)
      if is_local:
          model_cache_dir = model if os.path.isdir(model) else os.path.dirname(model)
          if not check_latest:
              # A local model with no freshness check needs nothing from the
              # hub. Return before importing modelscope; a local path is not a
              # model id and must not be handed to snapshot_download.
              return model_cache_dir

      # Imported lazily so modelscope is only required when the hub is used.
      from modelscope.hub.check_model import check_local_model_is_latest
      from modelscope.hub.snapshot_download import snapshot_download
      from modelscope.utils.constant import Invoke, ThirdParty

      key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
      user_agent = {
          Invoke.KEY: key,
          ThirdParty.KEY: "funasr",
      }

      if is_local:
          try:
              check_local_model_is_latest(model_cache_dir, user_agent=user_agent)
          except Exception:
              # Best effort only; `except Exception` (not a bare except) lets
              # KeyboardInterrupt and SystemExit propagate.
              print("could not check the latest version")
      else:
          model_cache_dir = snapshot_download(
              model,
              revision=model_revision,
              user_agent=user_agent,
          )
      return model_cache_dir