shixian.shi пре 3 година
родитељ
комит
c441eb08c4
1 измењених фајлова са 3 додато и 1 уклоњено
  1. 3 1
      funasr/bin/tp_inference.py

+ 3 - 1
funasr/bin/tp_inference.py

@@ -112,6 +112,9 @@ class SpeechText2Timestamp:
         tp_model, tp_train_args = ASRTask.build_model_from_file(
             timestamp_infer_config, timestamp_model_file, device
         )
+        if 'cuda' in device:
+            tp_model = tp_model.cuda()
+            
         frontend = None
         if tp_train_args.frontend is not None:
             frontend = WavFrontend(cmvn_file=timestamp_cmvn_file, **tp_train_args.frontend_conf)
@@ -240,7 +243,6 @@ def inference_modelscope(
         device = "cuda"
     else:
         device = "cpu"
-
     # 1. Set random-seed
     set_all_random_seed(seed)