嘉渊 2 лет назад
Родитель
Сommit
74a2059e9c
1 измененных файлов с 4 добавлено и 0 удалено
  1. 4 0
      funasr/bin/train.py

+ 4 - 0
funasr/bin/train.py

@@ -520,6 +520,10 @@ if __name__ == '__main__':
     prepare_data(args, distributed_option)
 
     model = build_model(args)
+    model = model.to(
+        dtype=getattr(torch, args.train_dtype),
+        device="cuda" if args.ngpu > 0 else "cpu",
+    )
     optimizers = build_optimizer(args, model=model)
     schedulers = build_scheduler(args, optimizers)