|
|
@@ -75,6 +75,7 @@ def assigment_scope_map(dst_state: dict, src_state: dict, scope_map: str=None):
|
|
|
|
|
|
return assignment_map
|
|
|
|
|
|
+
|
|
|
def load_pretrained_model(
|
|
|
path: str,
|
|
|
model: torch.nn.Module,
|
|
|
@@ -94,25 +95,69 @@ def load_pretrained_model(
|
|
|
"""
|
|
|
|
|
|
obj = model
|
|
|
-
|
|
|
+ dst_state = obj.state_dict()
|
|
|
+ # import pdb;
|
|
|
+ # pdb.set_trace()
|
|
|
+ print(f"ckpt: {path}")
|
|
|
if oss_bucket is None:
|
|
|
src_state = torch.load(path, map_location=map_location)
|
|
|
else:
|
|
|
buffer = BytesIO(oss_bucket.get_object(path).read())
|
|
|
src_state = torch.load(buffer, map_location=map_location)
|
|
|
- src_state = src_state["model"] if "model" in src_state else src_state
|
|
|
-
|
|
|
- if excludes is not None:
|
|
|
- for e in excludes.split(","):
|
|
|
- src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}
|
|
|
-
|
|
|
- dst_state = obj.state_dict()
|
|
|
- src_state = assigment_scope_map(dst_state, src_state, scope_map)
|
|
|
-
|
|
|
- if ignore_init_mismatch:
|
|
|
- src_state = filter_state_dict(dst_state, src_state)
|
|
|
-
|
|
|
- logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
|
|
|
- logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
|
|
|
- dst_state.update(src_state)
|
|
|
- obj.load_state_dict(dst_state, strict=True)
|
|
|
+ if "state_dict" in src_state:
|
|
|
+ src_state = src_state["state_dict"]
|
|
|
+
|
|
|
+ for k in dst_state.keys():
|
|
|
+ if not k.startswith("module.") and "module." + k in src_state.keys():
|
|
|
+ k_ddp = "module." + k
|
|
|
+ else:
|
|
|
+ k_ddp = k
|
|
|
+ if k_ddp in src_state:
|
|
|
+ dst_state[k] = src_state[k_ddp]
|
|
|
+ else:
|
|
|
+ print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
|
|
|
+
|
|
|
+ flag = obj.load_state_dict(dst_state, strict=True)
|
|
|
+ print(flag)
|
|
|
+
|
|
|
+# def load_pretrained_model(
|
|
|
+# path: str,
|
|
|
+# model: torch.nn.Module,
|
|
|
+# ignore_init_mismatch: bool,
|
|
|
+# map_location: str = "cpu",
|
|
|
+# oss_bucket=None,
|
|
|
+# scope_map=None,
|
|
|
+# excludes=None,
|
|
|
+# ):
|
|
|
+# """Load a model state and set it to the model.
|
|
|
+#
|
|
|
+# Args:
|
|
|
+# init_param: <file_path>:<src_key>:<dst_key>:<exclude_Keys>
|
|
|
+#
|
|
|
+# Examples:
|
|
|
+#
|
|
|
+# """
|
|
|
+#
|
|
|
+# obj = model
|
|
|
+#
|
|
|
+# if oss_bucket is None:
|
|
|
+# src_state = torch.load(path, map_location=map_location)
|
|
|
+# else:
|
|
|
+# buffer = BytesIO(oss_bucket.get_object(path).read())
|
|
|
+# src_state = torch.load(buffer, map_location=map_location)
|
|
|
+# src_state = src_state["model"] if "model" in src_state else src_state
|
|
|
+#
|
|
|
+# if excludes is not None:
|
|
|
+# for e in excludes.split(","):
|
|
|
+# src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}
|
|
|
+#
|
|
|
+# dst_state = obj.state_dict()
|
|
|
+# src_state = assigment_scope_map(dst_state, src_state, scope_map)
|
|
|
+#
|
|
|
+# if ignore_init_mismatch:
|
|
|
+# src_state = filter_state_dict(dst_state, src_state)
|
|
|
+#
|
|
|
+# logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
|
|
|
+# logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
|
|
|
+# dst_state.update(src_state)
|
|
|
+# obj.load_state_dict(dst_state, strict=True)
|