|
|
@@ -1150,7 +1150,6 @@ class AbsTask(ABC):
|
|
|
def main_worker(cls, args: argparse.Namespace):
|
|
|
assert check_argument_types()
|
|
|
|
|
|
- args.ngpu = 0
|
|
|
# 0. Init distributed process
|
|
|
distributed_option = build_dataclass(DistributedOption, args)
|
|
|
# Setting distributed_option.dist_rank, etc.
|
|
|
@@ -1253,13 +1252,9 @@ class AbsTask(ABC):
|
|
|
raise RuntimeError(
|
|
|
f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
|
|
|
)
|
|
|
- #model = model.to(
|
|
|
- # dtype=getattr(torch, args.train_dtype),
|
|
|
- # device="cuda" if args.ngpu > 0 else "cpu",
|
|
|
- #)
|
|
|
model = model.to(
|
|
|
dtype=getattr(torch, args.train_dtype),
|
|
|
- device="cpu",
|
|
|
+ device="cuda" if args.ngpu > 0 else "cpu",
|
|
|
)
|
|
|
for t in args.freeze_param:
|
|
|
for k, p in model.named_parameters():
|