@@ -9,6 +9,8 @@ from funasr.utils import config_argparse
 from funasr.utils.build_dataloader import build_dataloader
 from funasr.utils.build_distributed import build_distributed
 from funasr.utils.prepare_data import prepare_data
+from funasr.utils.build_optimizer import build_optimizer
+from funasr.utils.build_scheduler import build_scheduler
 from funasr.utils.types import str2bool
|
@@ -355,20 +357,6 @@ if __name__ == '__main__':
         distributed_option.dist_rank,
         distributed_option.local_rank))
 
-    # optimizers = cls.build_optimizers(args, model=model)
-    # schedulers = []
-    # for i, optim in enumerate(optimizers, 1):
-    # suf = "" if i == 1 else str(i)
-    # name = getattr(args, f"scheduler{suf}")
-    # conf = getattr(args, f"scheduler{suf}_conf")
-    # if name is not None:
-    # cls_ = scheduler_classes.get(name)
-    # if cls_ is None:
-    # raise ValueError(
-    # f"must be one of {list(scheduler_classes)}: {name}"
-    # )
-    # scheduler = cls_(optim, **conf)
-    # else:
-    # scheduler = None
-    #
-    # schedulers.append(scheduler)
+    model = build_model(args)
+    optimizers = build_optimizer(args, model=model)
+    schedule = build_scheduler(args)