游雁 1 年之前
父节点
当前提交
497edf4c9d
共有 3 个文件被更改,包括 88 次插入18 次删除
  1. 12 0
      examples/aishell/conformer/infer.sh
  2. 62 17
      funasr/train_utils/load_pretrained_model.py
  3. 14 1
      funasr/train_utils/trainer.py

+ 12 - 0
examples/aishell/conformer/infer.sh

@@ -0,0 +1,12 @@
+
+
+python funasr/bin/inference.py \
+--config-path="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3" \
+--config-name="config.yaml" \
+++init_param="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3/model.pt.ep38" \
+++tokenizer_conf.token_list="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/zh_token_list/char/tokens.txt" \
+++frontend_conf.cmvn_file="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/train/am.mvn" \
+++input="/mnt/nfs/zhifu.gzf/data/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav" \
+++output_dir="./outputs/debug" \
+++device="cpu" \
+

+ 62 - 17
funasr/train_utils/load_pretrained_model.py

@@ -75,6 +75,7 @@ def assigment_scope_map(dst_state: dict, src_state: dict, scope_map: str=None):
 	
 	return assignment_map
 
+
 def load_pretrained_model(
 	path: str,
 	model: torch.nn.Module,
@@ -94,25 +95,69 @@ def load_pretrained_model(
 	"""
 	
 	obj = model
-	
+	dst_state = obj.state_dict()
+	# import pdb;
+	# pdb.set_trace()
+	print(f"ckpt: {path}")
 	if oss_bucket is None:
 		src_state = torch.load(path, map_location=map_location)
 	else:
 		buffer = BytesIO(oss_bucket.get_object(path).read())
 		src_state = torch.load(buffer, map_location=map_location)
-	src_state = src_state["model"] if "model" in src_state else src_state
-	
-	if excludes is not None:
-		for e in excludes.split(","):
-			src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}
-	
-	dst_state = obj.state_dict()
-	src_state = assigment_scope_map(dst_state, src_state, scope_map)
-	
-	if ignore_init_mismatch:
-		src_state = filter_state_dict(dst_state, src_state)
-	
-	logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
-	logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
-	dst_state.update(src_state)
-	obj.load_state_dict(dst_state, strict=True)
+	if "state_dict" in src_state:
+		src_state = src_state["state_dict"]
+		
+	for k in dst_state.keys():
+		if not k.startswith("module.") and "module." + k in src_state.keys():
+			k_ddp = "module." + k
+		else:
+			k_ddp = k
+		if k_ddp in src_state:
+			dst_state[k] = src_state[k_ddp]
+		else:
+			print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
+			
+	flag = obj.load_state_dict(dst_state, strict=True)
+	print(flag)
+
+# def load_pretrained_model(
+# 	path: str,
+# 	model: torch.nn.Module,
+# 	ignore_init_mismatch: bool,
+# 	map_location: str = "cpu",
+# 	oss_bucket=None,
+# 	scope_map=None,
+# 	excludes=None,
+# ):
+# 	"""Load a model state and set it to the model.
+#
+# 	Args:
+# 		init_param: <file_path>:<src_key>:<dst_key>:<exclude_Keys>
+#
+# 	Examples:
+#
+# 	"""
+#
+# 	obj = model
+#
+# 	if oss_bucket is None:
+# 		src_state = torch.load(path, map_location=map_location)
+# 	else:
+# 		buffer = BytesIO(oss_bucket.get_object(path).read())
+# 		src_state = torch.load(buffer, map_location=map_location)
+# 	src_state = src_state["model"] if "model" in src_state else src_state
+#
+# 	if excludes is not None:
+# 		for e in excludes.split(","):
+# 			src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}
+#
+# 	dst_state = obj.state_dict()
+# 	src_state = assigment_scope_map(dst_state, src_state, scope_map)
+#
+# 	if ignore_init_mismatch:
+# 		src_state = filter_state_dict(dst_state, src_state)
+#
+# 	logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
+# 	logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
+# 	dst_state.update(src_state)
+# 	obj.load_state_dict(dst_state, strict=True)

+ 14 - 1
funasr/train_utils/trainer.py

@@ -128,7 +128,20 @@ class Trainer:
         if os.path.isfile(ckpt):
             checkpoint = torch.load(ckpt)
             self.start_epoch = checkpoint['epoch'] + 1
-            self.model.load_state_dict(checkpoint['state_dict'])
+            # self.model.load_state_dict(checkpoint['state_dict'])
+            src_state = checkpoint['state_dict']
+            dst_state = self.model.state_dict()
+            for k in dst_state.keys():
+                if not k.startswith("module.") and "module."+k in src_state.keys():
+                    k_ddp = "module."+k
+                else:
+                    k_ddp = k
+                if k_ddp in src_state.keys():
+                    dst_state[k] = src_state[k_ddp]
+                else:
+                    print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
+
+            self.model.load_state_dict(dst_state)
             self.optim.load_state_dict(checkpoint['optimizer'])
             self.scheduler.load_state_dict(checkpoint['scheduler'])
             print(f"Checkpoint loaded successfully from '{ckpt}'")