|
@@ -952,10 +952,13 @@ def inference_paraformer_vad_speaker(
|
|
|
##### speaker_verification #####
|
|
##### speaker_verification #####
|
|
|
##################################
|
|
##################################
|
|
|
# load sv model
|
|
# load sv model
|
|
|
- sv_model_dict = torch.load(sv_model_file)
|
|
|
|
|
- sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
|
|
|
|
|
if ngpu > 0:
|
|
if ngpu > 0:
|
|
|
|
|
+ sv_model_dict = torch.load(sv_model_file)
|
|
|
|
|
+ sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
|
|
|
sv_model.cuda()
|
|
sv_model.cuda()
|
|
|
|
|
+ else:
|
|
|
|
|
+ sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu'))
|
|
|
|
|
+ sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
|
|
|
sv_model.load_state_dict(sv_model_dict)
|
|
sv_model.load_state_dict(sv_model_dict)
|
|
|
print(f'load sv model params: {sv_model_file}')
|
|
print(f'load sv model params: {sv_model_file}')
|
|
|
sv_model.eval()
|
|
sv_model.eval()
|