|
|
@@ -205,9 +205,9 @@ class Trainer:
|
|
|
else:
|
|
|
scaler = None
|
|
|
|
|
|
- if trainer_options.resume and (output_dir / "checkpoint.pth").exists():
|
|
|
+ if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
|
|
|
cls.resume(
|
|
|
- checkpoint=output_dir / "checkpoint.pth",
|
|
|
+ checkpoint=output_dir / "checkpoint.pb",
|
|
|
model=model,
|
|
|
optimizers=optimizers,
|
|
|
schedulers=schedulers,
|
|
|
@@ -361,7 +361,7 @@ class Trainer:
|
|
|
},
|
|
|
buffer,
|
|
|
)
|
|
|
- trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir, "checkpoint.pth"), buffer.getvalue())
|
|
|
+ trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir, "checkpoint.pb"), buffer.getvalue())
|
|
|
else:
|
|
|
torch.save(
|
|
|
{
|
|
|
@@ -374,7 +374,7 @@ class Trainer:
|
|
|
],
|
|
|
"scaler": scaler.state_dict() if scaler is not None else None,
|
|
|
},
|
|
|
- output_dir / "checkpoint.pth",
|
|
|
+ output_dir / "checkpoint.pb",
|
|
|
)
|
|
|
|
|
|
# 5. Save and log the model and update the link to the best model
|
|
|
@@ -382,22 +382,22 @@ class Trainer:
|
|
|
buffer = BytesIO()
|
|
|
torch.save(model.state_dict(), buffer)
|
|
|
trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir,
|
|
|
- f"{iepoch}epoch.pth"),buffer.getvalue())
|
|
|
+ f"{iepoch}epoch.pb"),buffer.getvalue())
|
|
|
else:
|
|
|
- torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pth")
|
|
|
+ torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb")
|
|
|
|
|
|
- # Creates a sym link latest.pth -> {iepoch}epoch.pth
|
|
|
+ # Creates a sym link latest.pb -> {iepoch}epoch.pb
|
|
|
if trainer_options.use_pai:
|
|
|
- p = os.path.join(trainer_options.output_dir, "latest.pth")
|
|
|
+ p = os.path.join(trainer_options.output_dir, "latest.pb")
|
|
|
if trainer_options.oss_bucket.object_exists(p):
|
|
|
trainer_options.oss_bucket.delete_object(p)
|
|
|
trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
|
|
|
- os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pth"), p)
|
|
|
+ os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"), p)
|
|
|
else:
|
|
|
- p = output_dir / "latest.pth"
|
|
|
+ p = output_dir / "latest.pb"
|
|
|
if p.is_symlink() or p.exists():
|
|
|
p.unlink()
|
|
|
- p.symlink_to(f"{iepoch}epoch.pth")
|
|
|
+ p.symlink_to(f"{iepoch}epoch.pb")
|
|
|
|
|
|
_improved = []
|
|
|
for _phase, k, _mode in trainer_options.best_model_criterion:
|
|
|
@@ -407,16 +407,16 @@ class Trainer:
|
|
|
# Creates sym links if it's the best result
|
|
|
if best_epoch == iepoch:
|
|
|
if trainer_options.use_pai:
|
|
|
- p = os.path.join(trainer_options.output_dir, f"{_phase}.{k}.best.pth")
|
|
|
+ p = os.path.join(trainer_options.output_dir, f"{_phase}.{k}.best.pb")
|
|
|
if trainer_options.oss_bucket.object_exists(p):
|
|
|
trainer_options.oss_bucket.delete_object(p)
|
|
|
trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
|
|
|
- os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pth"),p)
|
|
|
+ os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"),p)
|
|
|
else:
|
|
|
- p = output_dir / f"{_phase}.{k}.best.pth"
|
|
|
+ p = output_dir / f"{_phase}.{k}.best.pb"
|
|
|
if p.is_symlink() or p.exists():
|
|
|
p.unlink()
|
|
|
- p.symlink_to(f"{iepoch}epoch.pth")
|
|
|
+ p.symlink_to(f"{iepoch}epoch.pb")
|
|
|
_improved.append(f"{_phase}.{k}")
|
|
|
if len(_improved) == 0:
|
|
|
logging.info("There are no improvements in this epoch")
|
|
|
@@ -438,7 +438,7 @@ class Trainer:
|
|
|
type="model",
|
|
|
metadata={"improved": _improved},
|
|
|
)
|
|
|
- artifact.add_file(str(output_dir / f"{iepoch}epoch.pth"))
|
|
|
+ artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
|
|
|
aliases = [
|
|
|
f"epoch-{iepoch}",
|
|
|
"best" if best_epoch == iepoch else "",
|
|
|
@@ -473,12 +473,12 @@ class Trainer:
|
|
|
|
|
|
for e in range(1, iepoch):
|
|
|
if trainer_options.use_pai:
|
|
|
- p = os.path.join(trainer_options.output_dir, f"{e}epoch.pth")
|
|
|
+ p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
|
|
|
if trainer_options.oss_bucket.object_exists(p) and e not in nbests:
|
|
|
trainer_options.oss_bucket.delete_object(p)
|
|
|
_removed.append(str(p))
|
|
|
else:
|
|
|
- p = output_dir / f"{e}epoch.pth"
|
|
|
+ p = output_dir / f"{e}epoch.pb"
|
|
|
if p.exists() and e not in nbests:
|
|
|
p.unlink()
|
|
|
_removed.append(str(p))
|