# test_inference_pipeline.py (2.0 KB)

  1. import unittest
  2. from modelscope.pipelines import pipeline
  3. from modelscope.utils.constant import Tasks
  4. from modelscope.utils.logger import get_logger
  5. logger = get_logger()
  6. class TestInferencePipelines(unittest.TestCase):
  7. def test_funasr_path(self):
  8. import funasr
  9. import os
  10. logger.info("run_dir:{0} ; funasr_path: {1}".format(os.getcwd(), funasr.__file__))
  11. def test_asr_inference_pipeline(self):
  12. inference_pipeline = pipeline(
  13. task=Tasks.auto_speech_recognition,
  14. model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
  15. rec_result = inference_pipeline(
  16. audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
  17. logger.info("asr inference result: {0}".format(rec_result))
  18. def test_asr_inference_pipeline_with_vad_punc(self):
  19. inference_pipeline = pipeline(
  20. task=Tasks.auto_speech_recognition,
  21. model='damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
  22. vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
  23. vad_model_revision="v1.1.8",
  24. punc_model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
  25. punc_model_revision="v1.1.6")
  26. rec_result = inference_pipeline(
  27. audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav')
  28. logger.info("asr inference with vad punc result: {0}".format(rec_result))
  29. def test_vad_inference_pipeline(self):
  30. inference_pipeline = pipeline(
  31. task=Tasks.voice_activity_detection,
  32. model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
  33. model_revision='v1.1.8',
  34. )
  35. segments_result = inference_pipeline(
  36. audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav')
  37. logger.info("vad inference result: {0}".format(segments_result))
  38. if __name__ == '__main__':
  39. unittest.main()