|
|
@@ -92,10 +92,7 @@ def inference_sond(
|
|
|
embedding_node="resnet1_dense"
|
|
|
)
|
|
|
logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
|
|
|
- speech2xvector = Speech2Xvector.from_pretrained(
|
|
|
- model_tag=model_tag,
|
|
|
- **speech2xvector_kwargs,
|
|
|
- )
|
|
|
+ speech2xvector = Speech2Xvector(**speech2xvector_kwargs)
|
|
|
speech2xvector.sv_model.eval()
|
|
|
|
|
|
# 2b. Build speech2diar
|
|
|
@@ -109,10 +106,7 @@ def inference_sond(
|
|
|
dur_threshold=dur_threshold,
|
|
|
)
|
|
|
logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
|
|
|
- speech2diar = Speech2DiarizationSOND.from_pretrained(
|
|
|
- model_tag=model_tag,
|
|
|
- **speech2diar_kwargs,
|
|
|
- )
|
|
|
+ speech2diar = Speech2DiarizationSOND(**speech2diar_kwargs)
|
|
|
speech2diar.diar_model.eval()
|
|
|
|
|
|
def output_results_str(results: dict, uttid: str):
|
|
|
@@ -257,10 +251,7 @@ def inference_eend(
|
|
|
dtype=dtype,
|
|
|
)
|
|
|
logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
|
|
|
- speech2diar = Speech2DiarizationEEND.from_pretrained(
|
|
|
- model_tag=model_tag,
|
|
|
- **speech2diar_kwargs,
|
|
|
- )
|
|
|
+ speech2diar = Speech2DiarizationEEND(**speech2diar_kwargs)
|
|
|
speech2diar.diar_model.eval()
|
|
|
|
|
|
def output_results_str(results: dict, uttid: str):
|