shixian.shi 2 лет назад
Родитель
Сommit
e54535e5eb
1 измененных файлов с 5 добавлено и 2 удалено
  1. 5 2
      funasr/bin/asr_inference_launch.py

+ 5 - 2
funasr/bin/asr_inference_launch.py

@@ -952,10 +952,13 @@ def inference_paraformer_vad_speaker(
             #####  speaker_verification  #####
             #####  speaker_verification  #####
             ##################################
             ##################################
             # load sv model
             # load sv model
-            sv_model_dict = torch.load(sv_model_file)
-            sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
             if ngpu > 0:
             if ngpu > 0:
+                sv_model_dict = torch.load(sv_model_file)
+                sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
                 sv_model.cuda()
                 sv_model.cuda()
+            else:
+                sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu'))
+                sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
             sv_model.load_state_dict(sv_model_dict)
             sv_model.load_state_dict(sv_model_dict)
             print(f'load sv model params: {sv_model_file}')
             print(f'load sv model params: {sv_model_file}')
             sv_model.eval()
             sv_model.eval()