|
|
@@ -146,7 +146,7 @@ class AutoModel:
|
|
|
set_all_random_seed(kwargs.get("seed", 0))
|
|
|
|
|
|
device = kwargs.get("device", "cuda")
|
|
|
- if not torch.cuda.is_available() or kwargs.get("ngpu", 0):
|
|
|
+ if not torch.cuda.is_available() or kwargs.get("ngpu", 0) == 0:
|
|
|
device = "cpu"
|
|
|
kwargs["batch_size"] = 1
|
|
|
kwargs["device"] = device
|