Przeglądaj źródła

Merge pull request #703 from alibaba-damo-academy/dev_wjm

fix load bug
jmwang66 2 lat temu
rodzic
commit
92dec9b331

+ 1 - 4
funasr/bin/asr_inference_launch.py

@@ -1367,10 +1367,7 @@ def inference_transducer(
         left_context=left_context,
         right_context=right_context,
     )
-    speech2text = Speech2TextTransducer.from_pretrained(
-        model_tag=model_tag,
-        **speech2text_kwargs,
-    )
+    speech2text = Speech2TextTransducer(**speech2text_kwargs)
 
     def _forward(data_path_and_name_and_type,
                  raw_inputs: Union[np.ndarray, torch.Tensor] = None,

+ 3 - 12
funasr/bin/diar_inference_launch.py

@@ -92,10 +92,7 @@ def inference_sond(
             embedding_node="resnet1_dense"
         )
         logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
-        speech2xvector = Speech2Xvector.from_pretrained(
-            model_tag=model_tag,
-            **speech2xvector_kwargs,
-        )
+        speech2xvector = Speech2Xvector(**speech2xvector_kwargs)
         speech2xvector.sv_model.eval()
 
     # 2b. Build speech2diar
@@ -109,10 +106,7 @@ def inference_sond(
         dur_threshold=dur_threshold,
     )
     logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
-    speech2diar = Speech2DiarizationSOND.from_pretrained(
-        model_tag=model_tag,
-        **speech2diar_kwargs,
-    )
+    speech2diar = Speech2DiarizationSOND(**speech2diar_kwargs)
     speech2diar.diar_model.eval()
 
     def output_results_str(results: dict, uttid: str):
@@ -257,10 +251,7 @@ def inference_eend(
         dtype=dtype,
     )
     logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
-    speech2diar = Speech2DiarizationEEND.from_pretrained(
-        model_tag=model_tag,
-        **speech2diar_kwargs,
-    )
+    speech2diar = Speech2DiarizationEEND(**speech2diar_kwargs)
     speech2diar.diar_model.eval()
 
     def output_results_str(results: dict, uttid: str):

+ 12 - 4
tests/test_asr_inference_pipeline.py

@@ -119,20 +119,28 @@ class TestParaformerInferencePipelines(unittest.TestCase):
     def test_paraformer_large_online_common(self):
         inference_pipeline = pipeline(
             task=Tasks.auto_speech_recognition,
-            model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online')
+            model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
+            model_revision='v1.0.6',
+            update_model=False,
+            mode="paraformer_fake_streaming"
+        )
         rec_result = inference_pipeline(
             audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
         logger.info("asr inference result: {0}".format(rec_result))
-        assert rec_result["text"] == "欢迎大 家来 体验达 摩院推 出的 语音识 别模 型"
+        assert rec_result["text"] == "欢迎大家来体验达摩院推出的语音识别模型"
 
     def test_paraformer_online_common(self):
         inference_pipeline = pipeline(
             task=Tasks.auto_speech_recognition,
-            model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online')
+            model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online',
+            model_revision='v1.0.6',
+            update_model=False,
+            mode="paraformer_fake_streaming"
+        )
         rec_result = inference_pipeline(
             audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
         logger.info("asr inference result: {0}".format(rec_result))
-        assert rec_result["text"] == "欢迎 大家来 体验达 摩院推 出的 语音识 别模 型"
+        assert rec_result["text"] == "欢迎大家来体验达摩院推出的语音识别模型"
 
     def test_paraformer_tiny_commandword(self):
         inference_pipeline = pipeline(