|
|
@@ -4,6 +4,7 @@ from funasr.torch_utils.device_funcs import to_device
|
|
|
import logging
|
|
|
from tqdm import tqdm
|
|
|
from contextlib import nullcontext
|
|
|
+import torch.distributed as dist
|
|
|
|
|
|
class Trainer:
|
|
|
"""
|
|
|
@@ -80,7 +81,7 @@ class Trainer:
|
|
|
}
|
|
|
# Create output directory if it does not exist
|
|
|
os.makedirs(self.output_dir, exist_ok=True)
|
|
|
- filename = os.path.join(self.output_dir, f'model.{epoch}.pb')
|
|
|
+ filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
|
|
|
torch.save(state, filename)
|
|
|
print(f'Checkpoint saved to {filename}')
|
|
|
|
|
|
@@ -110,8 +111,9 @@ class Trainer:
|
|
|
for epoch in range(self.start_epoch, self.max_epoch + 1):
|
|
|
self._train_epoch(epoch)
|
|
|
# self._validate_epoch(epoch)
|
|
|
- self._save_checkpoint(epoch)
|
|
|
- self.scheduler.step()
|
|
|
+ if dist.get_rank() == 0:
|
|
|
+ self._save_checkpoint(epoch)
|
|
|
+ # self.scheduler.step()
|
|
|
|
|
|
def _train_epoch(self, epoch):
|
|
|
"""
|