@@ -183,14 +183,14 @@ class Trainer:
                 raise RuntimeError(
                     "Require torch>=1.6.0 for Automatic Mixed Precision"
                 )
-            if trainer_options.sharded_ddp:
-                if fairscale is None:
-                    raise RuntimeError(
-                        "Requiring fairscale. Do 'pip install fairscale'"
-                    )
-                scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
-            else:
-                scaler = GradScaler()
+            # if trainer_options.sharded_ddp:
+            #     if fairscale is None:
+            #         raise RuntimeError(
+            #             "Requiring fairscale. Do 'pip install fairscale'"
+            #         )
+            #     scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
+            # else:
+            scaler = GradScaler()
         else:
             scaler = None
 
@@ -295,10 +295,10 @@ class Trainer:
                     )
                 elif isinstance(scheduler, AbsEpochStepScheduler):
                     scheduler.step()
-            if trainer_options.sharded_ddp:
-                for optimizer in optimizers:
-                    if isinstance(optimizer, fairscale.optim.oss.OSS):
-                        optimizer.consolidate_state_dict()
+            # if trainer_options.sharded_ddp:
+            #     for optimizer in optimizers:
+            #         if isinstance(optimizer, fairscale.optim.oss.OSS):
+            #             optimizer.consolidate_state_dict()
 
             if not distributed_option.distributed or distributed_option.dist_rank == 0:
                 # 3. Report the results
@@ -306,8 +306,8 @@ class Trainer:
                 if train_summary_writer is not None:
                     reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
                     reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
-                if trainer_options.use_wandb:
-                    reporter.wandb_log()
+                # if trainer_options.use_wandb:
+                #     reporter.wandb_log()
 
                 # save tensorboard on oss
                 if trainer_options.use_pai and train_summary_writer is not None:
@@ -412,25 +412,25 @@ class Trainer:
                     "The best model has been updated: " + ", ".join(_improved)
                 )
 
-            log_model = (
-                trainer_options.wandb_model_log_interval > 0
-                and iepoch % trainer_options.wandb_model_log_interval == 0
-            )
-            if log_model and trainer_options.use_wandb:
-                import wandb
-
-                logging.info("Logging Model on this epoch :::::")
-                artifact = wandb.Artifact(
-                    name=f"model_{wandb.run.id}",
-                    type="model",
-                    metadata={"improved": _improved},
-                )
-                artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
-                aliases = [
-                    f"epoch-{iepoch}",
-                    "best" if best_epoch == iepoch else "",
-                ]
-                wandb.log_artifact(artifact, aliases=aliases)
+            # log_model = (
+            #     trainer_options.wandb_model_log_interval > 0
+            #     and iepoch % trainer_options.wandb_model_log_interval == 0
+            # )
+            # if log_model and trainer_options.use_wandb:
+            #     import wandb
+            #
+            #     logging.info("Logging Model on this epoch :::::")
+            #     artifact = wandb.Artifact(
+            #         name=f"model_{wandb.run.id}",
+            #         type="model",
+            #         metadata={"improved": _improved},
+            #     )
+            #     artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
+            #     aliases = [
+            #         f"epoch-{iepoch}",
+            #         "best" if best_epoch == iepoch else "",
+            #     ]
+            #     wandb.log_artifact(artifact, aliases=aliases)
 
             # 6. Remove the model files excluding n-best epoch and latest epoch
             _removed = []
@@ -529,9 +529,9 @@ class Trainer:
         grad_clip = options.grad_clip
         grad_clip_type = options.grad_clip_type
         log_interval = options.log_interval
-        no_forward_run = options.no_forward_run
+        # no_forward_run = options.no_forward_run
         ngpu = options.ngpu
-        use_wandb = options.use_wandb
+        # use_wandb = options.use_wandb
         distributed = distributed_option.distributed
 
         if log_interval is None:
@@ -559,9 +559,9 @@ class Trainer:
                     break
 
             batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
-            if no_forward_run:
-                all_steps_are_invalid = False
-                continue
+            # if no_forward_run:
+            #     all_steps_are_invalid = False
+            #     continue
 
             with autocast(scaler is not None):
                 with reporter.measure_time("forward_time"):
@@ -737,8 +737,8 @@ class Trainer:
                 logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
                 if summary_writer is not None:
                     reporter.tensorboard_add_scalar(summary_writer, -log_interval)
-                if use_wandb:
-                    reporter.wandb_log()
+                # if use_wandb:
+                #     reporter.wandb_log()
 
             if max_update_stop:
                 break
@@ -760,7 +760,7 @@ class Trainer:
     ) -> None:
         assert check_argument_types()
         ngpu = options.ngpu
-        no_forward_run = options.no_forward_run
+        # no_forward_run = options.no_forward_run
         distributed = distributed_option.distributed
 
         model.eval()
@@ -776,8 +776,8 @@ class Trainer:
                     break
 
             batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
-            if no_forward_run:
-                continue
+            # if no_forward_run:
+            #     continue
 
             retval = model(**batch)
             if isinstance(retval, dict):
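
Note: with the fairscale branch commented out in the first hunk, torch.cuda.amp's
GradScaler is the only scaler the Trainer constructs, paired with the
`with autocast(scaler is not None):` context kept in train_one_epoch. For
reference, a minimal sketch of the scale/step/update cycle this pattern follows;
the model, optimizer, and batches below are illustrative stand-ins, not the
Trainer's actual state, and a CUDA device is assumed:

    import torch
    from torch.cuda.amp import GradScaler, autocast

    model = torch.nn.Linear(16, 4).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    batches = [{"x": torch.randn(8, 16).cuda()} for _ in range(4)]

    scaler = GradScaler()  # the non-sharded scaler this patch falls back to
    for batch in batches:
        optimizer.zero_grad()
        with autocast(scaler is not None):  # forward pass in mixed precision
            loss = model(batch["x"]).sum()
        scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
        scaler.step(optimizer)         # unscales grads, then optimizer.step()
        scaler.update()                # adapt the scale factor for the next step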