@@ -582,10 +582,16 @@ class Trainer:
                 if num_batch_updates % batch_interval == 0:
                     if options.use_pai and options.oss_bucket is not None:
                         buffer = BytesIO()
-                        torch.save(model.state_dict(), buffer)
+                        if hasattr(model, "module"):
+                            torch.save(model.module.state_dict(), buffer)
+                        else:
+                            torch.save(model.state_dict(), buffer)
                         options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue())
                     else:
-                        torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
+                        if hasattr(model, "module"):
+                            torch.save(model.module.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
+                        else:
+                            torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
 
             if distributed:
                 torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)