Sfoglia il codice sorgente

Add quantized ONNX model support to the onnxruntime runtime; add tester_rtf

lyblsgo 3 anni fa
parent
commit
c8f3af8855

+ 1 - 1
funasr/runtime/onnxruntime/include/Model.h

@@ -13,5 +13,5 @@ class Model {
     virtual std::string rescoring() = 0;
 };
 
-Model *create_model(const char *path,int nThread=0);
+Model *create_model(const char *path,int nThread=0,bool quantize=false);
 #endif

+ 5 - 24
funasr/runtime/onnxruntime/include/librapidasrapi.h

@@ -1,33 +1,20 @@
 #pragma once
 
-
 #ifdef WIN32
-
-
 #ifdef _RPASR_API_EXPORT
-
 #define  _RAPIDASRAPI __declspec(dllexport)
 #else
 #define  _RAPIDASRAPI __declspec(dllimport)
 #endif
-	
-
 #else
-#define _RAPIDASRAPI  
+#define _RAPIDASRAPI
 #endif
 
-
-
-
-
 #ifndef _WIN32
-
 #define RPASR_CALLBCK_PREFIX __attribute__((__stdcall__))
-
 #else
 #define RPASR_CALLBCK_PREFIX __stdcall
 #endif
-	
 
 #ifdef __cplusplus 
 
@@ -35,16 +22,13 @@ extern "C" {
 #endif
 
 typedef void* RPASR_HANDLE;
-
 typedef void* RPASR_RESULT;
-
 typedef unsigned char RPASR_BOOL;
 
 #define RPASR_TRUE 1
 #define RPASR_FALSE 0
 #define QM_DEFAULT_THREAD_NUM  4
 
-
 typedef enum
 {
  RASR_NONE=-1,
@@ -55,7 +39,6 @@ typedef enum
 }RPASR_MODE;
 
 typedef enum {
-
 	RPASR_MODEL_PADDLE = 0,
 	RPASR_MODEL_PADDLE_2 = 1,
 	RPASR_MODEL_K2 = 2,
@@ -63,17 +46,15 @@ typedef enum {
 
 }RPASR_MODEL_TYPE;
 
-
 typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step.
 	
-	// APIs for qmasr
-
-_RAPIDASRAPI RPASR_HANDLE  RapidAsrInit(const char* szModelDir, int nThread);
-
+// APIs for qmasr
+_RAPIDASRAPI RPASR_HANDLE  RapidAsrInit(const char* szModelDir, int nThread, bool quantize);
 
 
 // if not give a fnCallback ,it should be NULL 
 _RAPIDASRAPI RPASR_RESULT	RapidAsrRecogBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback);
+
 _RAPIDASRAPI RPASR_RESULT	RapidAsrRecogPCMBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback);
 
 _RAPIDASRAPI RPASR_RESULT	RapidAsrRecogPCMFile(RPASR_HANDLE handle, const char* szFileName, RPASR_MODE Mode, QM_CALLBACK fnCallback);
@@ -83,8 +64,8 @@ _RAPIDASRAPI RPASR_RESULT	RapidAsrRecogFile(RPASR_HANDLE handle, const char* szW
 _RAPIDASRAPI const char*	RapidAsrGetResult(RPASR_RESULT Result,int nIndex);
 
 _RAPIDASRAPI const int		RapidAsrGetRetNumber(RPASR_RESULT Result);
-_RAPIDASRAPI void			RapidAsrFreeResult(RPASR_RESULT Result);
 
+_RAPIDASRAPI void			RapidAsrFreeResult(RPASR_RESULT Result);
 
 _RAPIDASRAPI void			RapidAsrUninit(RPASR_HANDLE Handle);
 

+ 2 - 3
funasr/runtime/onnxruntime/src/Model.cpp

@@ -1,11 +1,10 @@
 #include "precomp.h"
 
-Model *create_model(const char *path,int nThread)
+Model *create_model(const char *path, int nThread, bool quantize)
 {
     Model *mm;
 
-
-    mm = new paraformer::ModelImp(path, nThread);
+    mm = new paraformer::ModelImp(path, nThread, quantize);
 
     return mm;
 }

+ 2 - 33
funasr/runtime/onnxruntime/src/librapidasrapi.cpp

@@ -4,24 +4,16 @@
 extern "C" {
 #endif
 
-
 	// APIs for qmasr
-	_RAPIDASRAPI RPASR_HANDLE  RapidAsrInit(const char* szModelDir, int nThreadNum)
+	_RAPIDASRAPI RPASR_HANDLE  RapidAsrInit(const char* szModelDir, int nThreadNum, bool quantize)
 	{
-
-
-		Model* mm = create_model(szModelDir, nThreadNum); 
-
+		Model* mm = create_model(szModelDir, nThreadNum, quantize);
 		return mm;
 	}
 
-
 	_RAPIDASRAPI RPASR_RESULT RapidAsrRecogBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback)
 	{
-
-
 		Model* pRecogObj = (Model*)handle;
-
 		if (!pRecogObj)
 			return nullptr;
 
@@ -46,15 +38,12 @@ extern "C" {
 				fnCallback(nStep, nTotal);
 		}
 
-
 		return pResult;
 	}
 
 	_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback)
 	{
-
 		Model* pRecogObj = (Model*)handle;
-
 		if (!pRecogObj)
 			return nullptr;
 
@@ -79,16 +68,12 @@ extern "C" {
 				fnCallback(nStep, nTotal);
 		}
 
-
 		return pResult;
-
 	}
 
 	_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMFile(RPASR_HANDLE handle, const char* szFileName, RPASR_MODE Mode, QM_CALLBACK fnCallback)
 	{
-
 		Model* pRecogObj = (Model*)handle;
-
 		if (!pRecogObj)
 			return nullptr;
 
@@ -113,15 +98,12 @@ extern "C" {
 				fnCallback(nStep, nTotal);
 		}
 
-
 		return pResult;
-
 	}
 
 	_RAPIDASRAPI RPASR_RESULT RapidAsrRecogFile(RPASR_HANDLE handle, const char* szWavfile, RPASR_MODE Mode, QM_CALLBACK fnCallback)
 	{
 		Model* pRecogObj = (Model*)handle;
-
 		if (!pRecogObj)
 			return nullptr;
 
@@ -146,9 +128,6 @@ extern "C" {
 				fnCallback(nStep, nTotal);
 		}
 	
-	
-
-
 		return pResult;
 	}
 
@@ -158,7 +137,6 @@ extern "C" {
 			return 0;
 
 		return 1;
-		
 	}
 
 
@@ -168,7 +146,6 @@ extern "C" {
 			return 0.0f;
 
 		return ((RPASR_RECOG_RESULT*)Result)->snippet_time;
-
 	}
 
 	_RAPIDASRAPI const char* RapidAsrGetResult(RPASR_RESULT Result,int nIndex)
@@ -178,34 +155,26 @@ extern "C" {
 			return nullptr;
 
 		return pResult->msg.c_str();
-	
 	}
 
 	_RAPIDASRAPI void RapidAsrFreeResult(RPASR_RESULT Result)
 	{
-
 		if (Result)
 		{
 			delete (RPASR_RECOG_RESULT*)Result;
-
 		}
 	}
 
 	_RAPIDASRAPI void RapidAsrUninit(RPASR_HANDLE handle)
 	{
-
 		Model* pRecogObj = (Model*)handle;
 
-
 		if (!pRecogObj)
 			return;
 
 		delete pRecogObj;
-
 	}
 
-
-
 #ifdef __cplusplus 
 
 }

+ 10 - 3
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp

@@ -3,10 +3,17 @@
 using namespace std;
 using namespace paraformer;
 
-ModelImp::ModelImp(const char* path,int nNumThread)
+ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
 {
-    string model_path = pathAppend(path, "model.onnx");
-    string vocab_path = pathAppend(path, "vocab.txt");
+    string model_path;
+    string vocab_path;
+    if(quantize)
+    {
+        model_path = pathAppend(path, "model_quant.onnx");
+    }else{
+        model_path = pathAppend(path, "model.onnx");
+    }
+    vocab_path = pathAppend(path, "vocab.txt");
 
     fe = new FeatureExtract(3);
 

+ 1 - 6
funasr/runtime/onnxruntime/src/paraformer_onnx.h

@@ -4,10 +4,6 @@
 #ifndef PARAFORMER_MODELIMP_H
 #define PARAFORMER_MODELIMP_H
 
-
-
-
-
 namespace paraformer {
 
     class ModelImp : public Model {
@@ -19,7 +15,6 @@ namespace paraformer {
         void apply_lfr(Tensor<float>*& din);
         void apply_cmvn(Tensor<float>* din);
 
-        
         string greedy_search( float* in, int nLen);
 
 #ifdef _WIN_X86
@@ -39,7 +34,7 @@ namespace paraformer {
         //string m_strOutputName, m_strOutputNameLen;
 
     public:
-        ModelImp(const char* path, int nNumThread=0);
+        ModelImp(const char* path, int nNumThread=0, bool quantize=false);
         ~ModelImp();
         void reset();
         string forward_chunk(float* din, int len, int flag);

+ 3 - 0
funasr/runtime/onnxruntime/tester/CMakeLists.txt

@@ -13,8 +13,11 @@ set(EXTRA_LIBS rapidasr)
 
 include_directories(${CMAKE_SOURCE_DIR}/include)
 set(EXECNAME "tester")
+set(EXECNAMERTF "tester_rtf")
 
 add_executable(${EXECNAME} "tester.cpp")
 target_link_libraries(${EXECNAME} PUBLIC ${EXTRA_LIBS})
 
+add_executable(${EXECNAMERTF} "tester_rtf.cpp")
+target_link_libraries(${EXECNAMERTF} PUBLIC ${EXTRA_LIBS})
 

+ 7 - 11
funasr/runtime/onnxruntime/tester/tester.cpp

@@ -9,40 +9,39 @@
 
 #include <iostream>
 #include <fstream>
+#include <sstream>
 using namespace std;
 
 int main(int argc, char *argv[])
 {
 
-    if (argc < 2)
+    if (argc < 4)  // argv[1], argv[2] and argv[3] are all read below
     {
-        printf("Usage: %s /path/to/model_dir /path/to/wav/file", argv[0]);
+        printf("Usage: %s /path/to/model_dir /path/to/wav/file quantize(true or false)", argv[0]);
         exit(-1);
     }
     struct timeval start, end;
     gettimeofday(&start, NULL);
     int nThreadNum = 4;
-    RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum);
+    // Third argument is the quantize flag, parsed as the text "true"/"false" (boolalpha); stays false if parsing fails.
+    bool quantize = false;
+    istringstream(argv[3]) >> boolalpha >> quantize;
+    RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum, quantize);
 
     if (!AsrHanlde)
     {
         printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
         exit(-1);
     }
-    
- 
 
     gettimeofday(&end, NULL);
     long seconds = (end.tv_sec - start.tv_sec);
     long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
     printf("Model initialization takes %lfs.\n", (double)modle_init_micros / 1000000);
 
-
-
     gettimeofday(&start, NULL);
     float snippet_time = 0.0f;
 
-
     RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL);
 
     gettimeofday(&end, NULL);
@@ -62,7 +61,6 @@ int main(int argc, char *argv[])
         cout <<"no return data!";
     }
  
- 
     //char* buff = nullptr;
     //int len = 0;
     //ifstream ifs(argv[2], std::ios::binary | std::ios::in);
@@ -101,13 +99,11 @@ int main(int argc, char *argv[])
     //   
     //delete[]buff;
     //}
-
  
     printf("Audio length %lfs.\n", (double)snippet_time);
     seconds = (end.tv_sec - start.tv_sec);
     long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
     printf("Model inference takes %lfs.\n", (double)taking_micros / 1000000);
-
     printf("Model inference RTF: %04lf.\n", (double)taking_micros/ (snippet_time*1000000));
 
     RapidAsrUninit(AsrHanlde);

+ 8 - 4
funasr/runtime/onnxruntime/tester/tester_rtf.cpp

@@ -16,9 +16,9 @@ using namespace std;
 int main(int argc, char *argv[])
 {
 
-    if (argc < 2)
+    if (argc < 4)
     {
-        printf("Usage: %s /path/to/model_dir /path/to/wav.scp", argv[0]);
+        printf("Usage: %s /path/to/model_dir /path/to/wav.scp quantize(true or false) \n", argv[0]);
         exit(-1);
     }
 
@@ -43,7 +43,11 @@ int main(int argc, char *argv[])
     struct timeval start, end;
     gettimeofday(&start, NULL);
     int nThreadNum = 1;
-    RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum);
+    // is quantize
+    bool quantize = false;
+    istringstream(argv[3]) >> boolalpha >> quantize;
+
+    RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum, quantize);
     if (!AsrHanlde)
     {
         printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
@@ -88,7 +92,7 @@ int main(int argc, char *argv[])
 
     printf("total_time_wav %ld ms.\n", (long)(total_length * 1000));
     printf("total_time_comput %ld ms.\n", total_time / 1000);
-    printf("Model inference RTF: %05lf.\n", (double)total_time/ (total_length*1000000));
+    printf("total_rtf %05lf .\n", (double)total_time/ (total_length*1000000));
 
     RapidAsrUninit(AsrHanlde);
     return 0;