소스 검색

export model

游雁 3 년 전
부모
커밋
abd600823b

+ 9 - 19
funasr/export/README.md

@@ -11,33 +11,23 @@ The installation is the same as [funasr](../../README.md)
 
 ## Export onnx format model
 Export model from modelscope
-```python
-from funasr.export.export_model import ASRModelExportParaformer
-
-output_dir = "../export"  # onnx/torchscripts model save path
-export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True)
-export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+```shell
+python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true
 ```
-
-
 Export model from local path
-```python
-export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+```shell
+python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true
 ```
 
 ## Export torchscripts format model
 Export model from modelscope
-```python
-from funasr.export.export_model import ASRModelExportParaformer
-
-output_dir = "../export"  # onnx/torchscripts model save path
-export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False)
-export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+```shell
+python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false
 ```
 
-Export model from local path
-```python
 
-export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+Export model from local path
+```shell
+python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false
 ```
 

+ 11 - 3
funasr/export/export_model.py

@@ -117,7 +117,15 @@ class ASRModelExportParaformer:
         )
 
 if __name__ == '__main__':
-    output_dir = "../export"
-    export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True)
-    export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+    import sys
+    
+    model_path = sys.argv[1]
+    output_dir = sys.argv[2]
+    onnx = sys.argv[3]
+    onnx = onnx.lower()
+    onnx = onnx == 'true'
+    # model_path = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
+    # output_dir = "../export"
+    export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=onnx)
+    export_model.export(model_path)
     # export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')

+ 0 - 50
funasr/export/models/predictor/cif.py

@@ -116,53 +116,3 @@ def cif(hidden, alphas, threshold: float):
 		pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], device=hidden.device)
 		list_ls.append(torch.cat([l, pad_l], 0))
 	return torch.stack(list_ls, 0), fires
-
-
-def CifPredictorV2_test():
-	x = torch.rand([2, 21, 2])
-	x_len = torch.IntTensor([6, 21])
-	
-	mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
-	x = x * mask[:, :, None]
-	
-	predictor_scripts = torch.jit.script(CifPredictorV2(2, 1, 1))
-	# cif_output, cif_length, alphas, cif_peak = predictor_scripts(x, mask=mask[:, None, :])
-	predictor_scripts.save('test.pt')
-	loaded = torch.jit.load('test.pt')
-	cif_output, cif_length, alphas, cif_peak = loaded(x, mask=mask[:, None, :])
-	# print(cif_output)
-	print(predictor_scripts.code)
-	# predictor = CifPredictorV2(2, 1, 1)
-	# cif_output, cif_length, alphas, cif_peak = predictor(x, mask=mask[:, None, :])
-	print(cif_output)
-
-
-def CifPredictorV2_export_test():
-	x = torch.rand([2, 21, 2])
-	x_len = torch.IntTensor([6, 21])
-	
-	mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
-	x = x * mask[:, :, None]
-	
-	# predictor_scripts = torch.jit.script(CifPredictorV2(2, 1, 1))
-	# cif_output, cif_length, alphas, cif_peak = predictor_scripts(x, mask=mask[:, None, :])
-	predictor = CifPredictorV2(2, 1, 1)
-	predictor_trace = torch.jit.trace(predictor, (x, mask[:, None, :]))
-	predictor_trace.save('test_trace.pt')
-	loaded = torch.jit.load('test_trace.pt')
-	
-	x = torch.rand([3, 30, 2])
-	x_len = torch.IntTensor([6, 20, 30])
-	mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
-	x = x * mask[:, :, None]
-	cif_output, cif_length, alphas, cif_peak = loaded(x, mask=mask[:, None, :])
-	print(cif_output)
-	# print(predictor_trace.code)
-	# predictor = CifPredictorV2(2, 1, 1)
-	# cif_output, cif_length, alphas, cif_peak = predictor(x, mask=mask[:, None, :])
-	# print(cif_output)
-
-
-if __name__ == '__main__':
-	# CifPredictorV2_test()
-	CifPredictorV2_export_test()

+ 12 - 2
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/README.md

@@ -20,9 +20,19 @@ cd funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer
    pip install -r requirements.txt
    ```
 3. Export the model.
-    - Export your model([docs](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export))
+   
+   - Export model from modelscope
+   ```shell
+   python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true
+   ```
+   - Export model from local path
+   ```shell
+   python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true
+   ```
+   - For more details, refer to the [docs](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export)
+
 
-4. Run the demo.
+4. Run the demo.
    - Model_dir: the model path, which contains `model.onnx`, `config.yaml`, `am.mvn`.
    - Input: wav format file, support formats: `str, np.ndarray, List[str]`
    - Output: `List[str]`: recognition result.

+ 9 - 0
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/demo.py

@@ -0,0 +1,9 @@
+from paraformer_onnx import Paraformer
+
+model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model = Paraformer(model_dir, batch_size=1)
+
+wav_path = ['/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav']
+
+result = model(wav_path)
+print(result)