|
|
@@ -583,11 +583,14 @@ class Trainer:
|
|
|
if rank == 0:
|
|
|
if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")):
|
|
|
num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates()
|
|
|
- if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None) and options.use_pai:
|
|
|
- buffer = BytesIO()
|
|
|
- torch.save(model.state_dict(), buffer)
|
|
|
- options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}batch.pth"), buffer.getvalue())
|
|
|
-
|
|
|
+ if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None):
|
|
|
+ if options.use_pai:
|
|
|
+ buffer = BytesIO()
|
|
|
+ torch.save(model.state_dict(), buffer)
|
|
|
+ options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue())
|
|
|
+ else:
|
|
|
+ torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
|
|
|
+
|
|
|
if distributed:
|
|
|
torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
|
|
|
if iterator_stop > 0:
|