Forráskód Böngészése

Merge pull request #345 from alibaba-damo-academy/dev_tmp

fix unit test
hnluo 2 éve
szülő
commit
5563b28a74

+ 2 - 0
tests/test_asr_inference_pipeline.py

@@ -43,6 +43,7 @@ class TestData2vecInferencePipelines(unittest.TestCase):
         rec_result = inference_pipeline(
             audio_in='https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav')
         logger.info("asr inference result: {0}".format(rec_result))
+        assert rec_result["text"] == "每一天都要快乐喔"

     def test_paraformer(self):
         inference_pipeline = pipeline(
@@ -51,6 +52,7 @@ class TestData2vecInferencePipelines(unittest.TestCase):
         rec_result = inference_pipeline(
             audio_in='https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav')
         logger.info("asr inference result: {0}".format(rec_result))
+        assert rec_result["text"] == "每一天都要快乐喔"


 class TestMfccaInferencePipelines(unittest.TestCase):

+ 3 - 5
tests/test_punctuation_pipeline.py

@@ -26,16 +26,14 @@ class TestTransformerInferencePipelines(unittest.TestCase):
         inference_pipeline = pipeline(
             task=Tasks.punctuation,
             model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
-            model_revision="v1.0.0",
         )
         inputs = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益"
         vads = inputs.split("|")
-        cache_out = []
         rec_result_all = "outputs:"
+        param_dict = {"cache": []}
         for vad in vads:
-            rec_result = inference_pipeline(text_in=vad, cache=cache_out)
-            cache_out = rec_result['cache']
-            rec_result_all += rec_result['text']
+            rec_result = inference_pipeline(text_in=vad, param_dict=param_dict)
+            rec_result_all += rec_result["text"]
         logger.info("punctuation inference result: {0}".format(rec_result_all))