|
|
@@ -520,6 +520,10 @@ if __name__ == '__main__':
|
|
|
prepare_data(args, distributed_option)
|
|
|
|
|
|
model = build_model(args)
|
|
|
+ model = model.to(
|
|
|
+ dtype=getattr(torch, args.train_dtype),
|
|
|
+ device="cuda" if args.ngpu > 0 else "cpu",
|
|
|
+ )
|
|
|
optimizers = build_optimizer(args, model=model)
|
|
|
schedulers = build_scheduler(args, optimizers)
|
|
|
|