|
@@ -571,8 +571,7 @@ class Trainer:
|
|
|
#output dir
|
|
#output dir
|
|
|
output_dir = Path(options.output_dir)
|
|
output_dir = Path(options.output_dir)
|
|
|
#batch interval
|
|
#batch interval
|
|
|
- batch_interval = options.batch_interval
|
|
|
|
|
- assert batch_interval > 0
|
|
|
|
|
|
|
+ batch_interval = options.batch_interval
|
|
|
|
|
|
|
|
start_time = time.perf_counter()
|
|
start_time = time.perf_counter()
|
|
|
for iiter, (_, batch) in enumerate(
|
|
for iiter, (_, batch) in enumerate(
|
|
@@ -580,11 +579,11 @@ class Trainer:
|
|
|
):
|
|
):
|
|
|
assert isinstance(batch, dict), type(batch)
|
|
assert isinstance(batch, dict), type(batch)
|
|
|
|
|
|
|
|
- if rank == 0:
|
|
|
|
|
|
|
+ if batch_interval > 0 and (not distributed_option.distributed or rank == 0):
|
|
|
if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")):
|
|
if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")):
|
|
|
num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates()
|
|
num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates()
|
|
|
- if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None):
|
|
|
|
|
- if options.use_pai:
|
|
|
|
|
|
|
+ if num_batch_updates % batch_interval == 0:
|
|
|
|
|
+ if options.use_pai and options.oss_bucket is not None:
|
|
|
buffer = BytesIO()
|
|
buffer = BytesIO()
|
|
|
torch.save(model.state_dict(), buffer)
|
|
torch.save(model.state_dict(), buffer)
|
|
|
options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue())
|
|
options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue())
|