嘉渊 %!s(int64=2) %!d(string=hai) anos
pai
achega
ba3d455b21
Modificáronse 2 ficheiros con 4 adicións e 16 borrados
  1. 1 4
      funasr/bin/asr_inference_launch.py
  2. 3 12
      funasr/bin/diar_inference_launch.py

+ 1 - 4
funasr/bin/asr_inference_launch.py

@@ -1349,10 +1349,7 @@ def inference_transducer(
         left_context=left_context,
         right_context=right_context,
     )
-    speech2text = Speech2TextTransducer.from_pretrained(
-        model_tag=model_tag,
-        **speech2text_kwargs,
-    )
+    speech2text = Speech2TextTransducer(**speech2text_kwargs)
 
     def _forward(data_path_and_name_and_type,
                  raw_inputs: Union[np.ndarray, torch.Tensor] = None,

+ 3 - 12
funasr/bin/diar_inference_launch.py

@@ -92,10 +92,7 @@ def inference_sond(
             embedding_node="resnet1_dense"
         )
         logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
-        speech2xvector = Speech2Xvector.from_pretrained(
-            model_tag=model_tag,
-            **speech2xvector_kwargs,
-        )
+        speech2xvector = Speech2Xvector(**speech2xvector_kwargs)
         speech2xvector.sv_model.eval()
 
     # 2b. Build speech2diar
@@ -109,10 +106,7 @@ def inference_sond(
         dur_threshold=dur_threshold,
     )
     logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
-    speech2diar = Speech2DiarizationSOND.from_pretrained(
-        model_tag=model_tag,
-        **speech2diar_kwargs,
-    )
+    speech2diar = Speech2DiarizationSOND(**speech2diar_kwargs)
     speech2diar.diar_model.eval()
 
     def output_results_str(results: dict, uttid: str):
@@ -257,10 +251,7 @@ def inference_eend(
         dtype=dtype,
     )
     logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
-    speech2diar = Speech2DiarizationEEND.from_pretrained(
-        model_tag=model_tag,
-        **speech2diar_kwargs,
-    )
+    speech2diar = Speech2DiarizationEEND(**speech2diar_kwargs)
     speech2diar.diar_model.eval()
 
     def output_results_str(results: dict, uttid: str):