Procházet zdrojové kódy

onnx export funasr_onnx

游雁 před 2 roky
rodič
revize
1c0774fdf3

+ 22 - 36
funasr/runtime/python/onnxruntime/README.md

@@ -1,23 +1,5 @@
 # ONNXRuntime-python
 
-## Export the model
-### Install [modelscope and funasr](https://github.com/alibaba-damo-academy/FunASR#installation)
-
-```shell
-#pip3 install torch torchaudio
-pip install -U modelscope funasr
-# For the users in China, you could install with the command:
-# pip install -U modelscope funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
-pip install torch-quant # Optional, for torchscript quantization
-pip install onnx onnxruntime # Optional, for onnx quantization
-```
-
-### Export [onnx model](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export)
-
-```shell
-python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True
-```
-
 
 ## Install `funasr_onnx`
 
@@ -43,17 +25,18 @@ pip install -e ./
 ### Speech Recognition
 #### Paraformer
  ```python
- from funasr_onnx import Paraformer
+from funasr_onnx import Paraformer
+from pathlib import Path
 
- model_dir = "./export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
- model = Paraformer(model_dir, batch_size=1, quantize=True)
+model_dir = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model = Paraformer(model_dir, batch_size=1, quantize=True)
 
- wav_path = ['./export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav']
+wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav'.format(Path.home())]
 
- result = model(wav_path)
- print(result)
+result = model(wav_path)
+print(result)
  ```
-- `model_dir`: the model path, which contains `model.onnx`, `config.yaml`, `am.mvn`
+- `model_dir`: model_name in modelscope or local path downloaded from modelscope. If the local path is set, it should contain `model.onnx`, `config.yaml`, `am.mvn`
 - `batch_size`: `1` (Default), the batch size duration inference
 - `device_id`: `-1` (Default), infer on CPU. If you want to infer with GPU, set it to gpu_id (Please make sure that you have install the onnxruntime-gpu)
 - `quantize`: `False` (Default), load the model of `model.onnx` in `model_dir`. If set `True`, load the model of `model_quant.onnx` in `model_dir`
@@ -69,15 +52,17 @@ Output: `List[str]`: recognition result
 #### FSMN-VAD
 ```python
 from funasr_onnx import Fsmn_vad
+from pathlib import Path
+
+model_dir = "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+wav_path = '{}/.cache/modelscope/hub/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav'.format(Path.home())
 
-model_dir = "./export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
-wav_path = "./export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav"
 model = Fsmn_vad(model_dir)
 
 result = model(wav_path)
 print(result)
 ```
-- `model_dir`: the model path, which contains `model.onnx`, `config.yaml`, `am.mvn`
+- `model_dir`: model_name in modelscope or local path downloaded from modelscope. If the local path is set, it should contain `model.onnx`, `config.yaml`, `am.mvn`
 - `batch_size`: `1` (Default), the batch size duration inference
 - `device_id`: `-1` (Default), infer on CPU. If you want to infer with GPU, set it to gpu_id (Please make sure that you have install the onnxruntime-gpu)
 - `quantize`: `False` (Default), load the model of `model.onnx` in `model_dir`. If set `True`, load the model of `model_quant.onnx` in `model_dir`
@@ -92,10 +77,11 @@ Output: `List[str]`: recognition result
 ```python
 from funasr_onnx import Fsmn_vad_online
 import soundfile
+from pathlib import Path
 
+model_dir = "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+wav_path = '{}/.cache/modelscope/hub/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav'.format(Path.home())
 
-model_dir = "./export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
-wav_path = "./export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav"
 model = Fsmn_vad_online(model_dir)
 
 
@@ -118,7 +104,7 @@ for sample_offset in range(0, speech_length, min(step, speech_length - sample_of
     if segments_result:
         print(segments_result)
 ```
-- `model_dir`: the model path, which contains `model.onnx`, `config.yaml`, `am.mvn`
+- `model_dir`: model_name in modelscope or local path downloaded from modelscope. If the local path is set, it should contain `model.onnx`, `config.yaml`, `am.mvn`
 - `batch_size`: `1` (Default), the batch size duration inference
 - `device_id`: `-1` (Default), infer on CPU. If you want to infer with GPU, set it to gpu_id (Please make sure that you have install the onnxruntime-gpu)
 - `quantize`: `False` (Default), load the model of `model.onnx` in `model_dir`. If set `True`, load the model of `model_quant.onnx` in `model_dir`
@@ -134,14 +120,14 @@ Output: `List[str]`: recognition result
 ```python
 from funasr_onnx import CT_Transformer
 
-model_dir = "./export/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+model_dir = "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
 model = CT_Transformer(model_dir)
 
 text_in="跨境河流是养育沿岸人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切愿意进一步完善双方联合工作机制凡是中方能做的我们都会去做而且会做得更好我请印度朋友们放心中国在上游的任何开发利用都会经过科学规划和论证兼顾上下游的利益"
 result = model(text_in)
 print(result[0])
 ```
-- `model_dir`: the model path, which contains `model.onnx`, `config.yaml`, `am.mvn`
+- `model_dir`: model_name in modelscope or local path downloaded from modelscope. If the local path is set, it should contain `model.onnx`, `config.yaml`, `am.mvn`
 - `device_id`: `-1` (Default), infer on CPU. If you want to infer with GPU, set it to gpu_id (Please make sure that you have install the onnxruntime-gpu)
 - `quantize`: `False` (Default), load the model of `model.onnx` in `model_dir`. If set `True`, load the model of `model_quant.onnx` in `model_dir`
 - `intra_op_num_threads`: `4` (Default), sets the number of threads used for intraop parallelism on CPU
@@ -155,7 +141,7 @@ Output: `List[str]`: recognition result
 ```python
 from funasr_onnx import CT_Transformer_VadRealtime
 
-model_dir = "./export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
+model_dir = "damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
 model = CT_Transformer_VadRealtime(model_dir)
 
 text_in  = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流>问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益"
@@ -169,7 +155,7 @@ for vad in vads:
 
 print(rec_result_all)
 ```
-- `model_dir`: the model path, which contains `model.onnx`, `config.yaml`, `am.mvn`
+- `model_dir`: model_name in modelscope or local path downloaded from modelscope. If the local path is set, it should contain `model.onnx`, `config.yaml`, `am.mvn`
 - `device_id`: `-1` (Default), infer on CPU. If you want to infer with GPU, set it to gpu_id (Please make sure that you have install the onnxruntime-gpu)
 - `quantize`: `False` (Default), load the model of `model.onnx` in `model_dir`. If set `True`, load the model of `model_quant.onnx` in `model_dir`
 - `intra_op_num_threads`: `4` (Default), sets the number of threads used for intraop parallelism on CPU
@@ -184,4 +170,4 @@ Please ref to [benchmark](https://github.com/alibaba-damo-academy/FunASR/blob/ma
 
 ## Acknowledge
 1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
-2. We acknowledge [SWHL](https://github.com/RapidAI/RapidASR) for contributing the onnxruntime (for paraformer model).
+2. We partially refer [SWHL](https://github.com/RapidAI/RapidASR) for onnxruntime (only for paraformer model).

+ 0 - 15
funasr/runtime/python/onnxruntime/demo.py

@@ -1,15 +0,0 @@
-from funasr_onnx import Paraformer
-
-
-model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
-
-model = Paraformer(model_dir, batch_size=2, plot_timestamp_to="./", pred_bias=0)  # cpu
-# model = Paraformer(model_dir, batch_size=2, plot_timestamp_to="./", pred_bias=0, device_id=0)  # gpu
-
-# when using paraformer-large-vad-punc model, you can set plot_timestamp_to="./xx.png" to get figure of alignment besides timestamps
-# model = Paraformer(model_dir, batch_size=1, plot_timestamp_to="test.png")
-
-wav_path = "YourPath/xx.wav"
-
-result = model(wav_path)
-print(result)

+ 14 - 0
funasr/runtime/python/onnxruntime/demo_paraformer_offline.py

@@ -0,0 +1,14 @@
+from funasr_onnx import Paraformer
+from pathlib import Path
+
+model_dir = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model = Paraformer(model_dir, batch_size=1, quantize=True)
+# model = Paraformer(model_dir, batch_size=1, device_id=0)  # gpu
+
+# when using paraformer-large-vad-punc model, you can set plot_timestamp_to="./xx.png" to get figure of alignment besides timestamps
+# model = Paraformer(model_dir, batch_size=1, plot_timestamp_to="test.png")
+
+wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav'.format(Path.home())]
+
+result = model(wav_path)
+print(result)

+ 1 - 1
funasr/runtime/python/onnxruntime/demo_punc_offline.py

@@ -1,6 +1,6 @@
 from funasr_onnx import CT_Transformer
 
-model_dir = "../../../export/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+model_dir = "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
 model = CT_Transformer(model_dir)
 
 text_in="跨境河流是养育沿岸人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切愿意进一步完善双方联合工作机制凡是中方能做的我们都会去做而且会做得更好我请印度朋友们放心中国在上游的任何开发利用都会经过科学规划和论证兼顾上下游的利益"

+ 1 - 1
funasr/runtime/python/onnxruntime/demo_punc_online.py

@@ -1,6 +1,6 @@
 from funasr_onnx import CT_Transformer_VadRealtime
 
-model_dir = "../../../export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
+model_dir = "damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
 model = CT_Transformer_VadRealtime(model_dir)
 
 text_in  = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流>问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益"

+ 3 - 4
funasr/runtime/python/onnxruntime/demo_vad_offline.py

@@ -1,11 +1,10 @@
-import soundfile
 from funasr_onnx import Fsmn_vad
+from pathlib import Path
 
+model_dir = "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+wav_path = '{}/.cache/modelscope/hub/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav'.format(Path.home())
 
-model_dir = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
-wav_path = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/vad_example_16k.wav"
 model = Fsmn_vad(model_dir)
 
-#offline vad
 result = model(wav_path)
 print(result)

+ 5 - 5
funasr/runtime/python/onnxruntime/demo_vad_online.py

@@ -1,9 +1,10 @@
-import soundfile
 from funasr_onnx import Fsmn_vad_online
+import soundfile
+from pathlib import Path
 
+model_dir = "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+wav_path = '{}/.cache/modelscope/hub/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav'.format(Path.home())
 
-model_dir = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
-wav_path = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/vad_example_16k.wav"
 model = Fsmn_vad_online(model_dir)
 
 
@@ -24,5 +25,4 @@ for sample_offset in range(0, speech_length, min(step, speech_length - sample_of
     segments_result = model(audio_in=speech[sample_offset: sample_offset + step],
                             param_dict=param_dict)
     if segments_result:
-        print(segments_result)
-
+        print(segments_result)

+ 1 - 1
funasr/runtime/python/onnxruntime/setup.py

@@ -13,7 +13,7 @@ def get_readme():
 
 
 MODULE_NAME = 'funasr_onnx'
-VERSION_NUM = '0.0.10'
+VERSION_NUM = '0.1.0'
 
 setuptools.setup(
     name=MODULE_NAME,