@@ -94,7 +94,7 @@ class TrainerOptions:
     wandb_model_log_interval: int
     use_pai: bool
     oss_bucket: Union[oss2.Bucket, None]
-
+    batch_interval: int

 class Trainer:
     """Trainer having a optimizer.
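# ---------------------------------------------------------------------------
# Note: `batch_interval` joins the existing PAI/OSS options (`use_pai`,
# `oss_bucket`) and expresses, in number of parameter updates, how often an
# intermediate checkpoint should be pushed to the bucket.  Below is a minimal
# sketch of such an options object; it is a hypothetical, reduced stand-in for
# TrainerOptions, not the real dataclass.
# ---------------------------------------------------------------------------
import dataclasses
from typing import Optional

import oss2  # Alibaba Cloud OSS SDK, the same dependency used in the patch


@dataclasses.dataclass
class UploadOptions:
    use_pai: bool                      # running on Alibaba PAI?
    oss_bucket: Optional[oss2.Bucket]  # destination bucket, if any
    batch_interval: int                # upload every N parameter updates


# Example: upload every 5000 updates (bucket left unset here).
opts = UploadOptions(use_pai=True, oss_bucket=None, batch_interval=5000)
assert opts.batch_interval > 0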
@@ -186,7 +186,10 @@ class Trainer:
                 logging.warning("No keep_nbest_models is given. Change to [1]")
                 trainer_options.keep_nbest_models = [1]
             keep_nbest_models = trainer_options.keep_nbest_models
-
+
+        # batch_interval must be set and positive
+        assert trainer_options.batch_interval > 0
+
         output_dir = Path(trainer_options.output_dir)
         reporter = Reporter()
         if trainer_options.use_amp:
@@ -560,13 +563,30 @@ class Trainer:
         # [For distributed] Because iteration counts are not always equals between
         # processes, send stop-flag to the other processes if iterator is finished
         iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
-
+
+        # get the rank of the current process
+        rank = distributed_option.dist_rank
+        # number of batch (parameter) updates done so far
+        num_batch_updates = 0
+        # output dir
+        output_dir = Path(options.output_dir)
+        # batch interval
+        batch_interval = options.batch_interval
+        assert batch_interval > 0
+
         start_time = time.perf_counter()
         for iiter, (_, batch) in enumerate(
             reporter.measure_iter_time(iterator, "iter_time"), 1
         ):
             assert isinstance(batch, dict), type(batch)
-
+
+            if rank == 0 and hasattr(model.module, "num_updates"):
+                num_batch_updates = model.module.get_num_updates()
+                if (num_batch_updates % batch_interval == 0) and (options.oss_bucket is not None) and options.use_pai:
+                    buffer = BytesIO()
+                    torch.save(model.state_dict(), buffer)
+                    options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}batch.pth"), buffer.getvalue())
+
             if distributed:
                 torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                 if iterator_stop > 0:
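# ---------------------------------------------------------------------------
# The upload path above serializes the full state_dict() into an in-memory
# buffer and pushes the bytes to OSS with put_object(), so no local file is
# written.  A standalone sketch of the same pattern, including reading the
# checkpoint back; the credentials, endpoint and object key below are
# placeholders, not values used by the trainer.
# ---------------------------------------------------------------------------
from io import BytesIO

import oss2
import torch


def upload_checkpoint(bucket: oss2.Bucket, model: torch.nn.Module, key: str) -> None:
    # Serialize into memory, then push the raw bytes under the given key.
    buffer = BytesIO()
    torch.save(model.state_dict(), buffer)
    bucket.put_object(key, buffer.getvalue())


def download_checkpoint(bucket: oss2.Bucket, model: torch.nn.Module, key: str) -> None:
    # Fetch the object and load the state_dict straight from memory.
    data = bucket.get_object(key).read()
    model.load_state_dict(torch.load(BytesIO(data), map_location="cpu"))


# Example usage:
# auth = oss2.Auth("<access_key_id>", "<access_key_secret>")
# bucket = oss2.Bucket(auth, "https://oss-cn-hangzhou.aliyuncs.com", "my-bucket")
# upload_checkpoint(bucket, model, "exp/train/5000batch.pth")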
@@ -811,4 +831,4 @@ class Trainer:
         else:
             if distributed:
                 iterator_stop.fill_(1)
-                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
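# ---------------------------------------------------------------------------
# The iterator_stop lines framing the new code are the usual distributed
# stop-flag handshake: every rank contributes 0 or 1, the flags are summed
# with all_reduce, and all ranks leave the loop as soon as any rank runs out
# of data.  A minimal sketch of that handshake in isolation; it assumes the
# default process group is already initialized and is not the trainer's code.
# ---------------------------------------------------------------------------
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp


def run_epoch(iterator, device: str = "cpu") -> None:
    # 0 means "this rank still has data"; the flag is flipped to 1 once the
    # local iterator is exhausted, and all_reduce(SUM) spreads that to all ranks.
    iterator_stop = torch.tensor(0, device=device)
    for batch in iterator:
        dist.all_reduce(iterator_stop, ReduceOp.SUM)
        if iterator_stop > 0:
            # Another rank finished first; stop here as well so every rank
            # keeps issuing the same number of collective calls.
            break
        # ... forward / backward / optimizer step would go here ...
    else:
        # This rank finished normally: raise the flag for the others.
        iterator_stop.fill_(1)
        dist.all_reduce(iterator_stop, ReduceOp.SUM)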