|
@@ -22,6 +22,7 @@ def download_from_ms(**kwargs):
|
|
|
|
|
|
|
|
config = os.path.join(model_or_path, "config.yaml")
|
|
config = os.path.join(model_or_path, "config.yaml")
|
|
|
if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
|
|
if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
|
|
|
|
|
+
|
|
|
config = OmegaConf.load(config)
|
|
config = OmegaConf.load(config)
|
|
|
kwargs = OmegaConf.merge(config, kwargs)
|
|
kwargs = OmegaConf.merge(config, kwargs)
|
|
|
init_param = os.path.join(model_or_path, "model.pb")
|
|
init_param = os.path.join(model_or_path, "model.pb")
|
|
@@ -39,8 +40,7 @@ def download_from_ms(**kwargs):
|
|
|
kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
|
|
kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
|
|
|
if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
|
|
if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
|
|
|
kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
|
|
kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
|
|
|
- else:# configuration.json
|
|
|
|
|
- assert os.path.exists(os.path.join(model_or_path, "configuration.json"))
|
|
|
|
|
|
|
+ elif os.path.exists(os.path.join(model_or_path, "configuration.json")):
|
|
|
with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
|
|
with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
|
|
|
conf_json = json.load(f)
|
|
conf_json = json.load(f)
|
|
|
cfg = {}
|
|
cfg = {}
|