游雁 2 лет назад
Родитель
Сommit
1c2eb051cd
1 измененных файлов с 5 добавлено и 3 удалено
  1. 5 3
      funasr/cli/trainer.py

+ 5 - 3
funasr/cli/trainer.py

@@ -4,6 +4,7 @@ from funasr.torch_utils.device_funcs import to_device
 import logging
 from tqdm import tqdm
 from contextlib import nullcontext
+import torch.distributed as dist
 
 class Trainer:
 	"""
@@ -80,7 +81,7 @@ class Trainer:
 		}
 		# Create output directory if it does not exist
 		os.makedirs(self.output_dir, exist_ok=True)
-		filename = os.path.join(self.output_dir, f'model.{epoch}.pb')
+		filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
 		torch.save(state, filename)
 		print(f'Checkpoint saved to {filename}')
 	
@@ -110,8 +111,9 @@ class Trainer:
 		for epoch in range(self.start_epoch, self.max_epoch + 1):
 			self._train_epoch(epoch)
 			# self._validate_epoch(epoch)
-			self._save_checkpoint(epoch)
-			self.scheduler.step()
+			if dist.get_rank() == 0:
+				self._save_checkpoint(epoch)
+			# self.scheduler.step()
 	
 	def _train_epoch(self, epoch):
 		"""