lyblsgo 2 лет назад
Родитель
Сommit
0535db1c65

+ 1 - 1
funasr/runtime/onnxruntime/include/Model.h

@@ -15,5 +15,5 @@ class Model {
     virtual std::string AddPunc(const char* szInput)=0;
 };
 
-Model *create_model(const char *path,int nThread=0,bool quantize=false, bool use_vad=false);
+Model *create_model(const char *path,int nThread=0,bool quantize=false, bool use_vad=false, bool use_punc=false);
 #endif

+ 5 - 5
funasr/runtime/onnxruntime/include/libfunasrapi.h

@@ -49,17 +49,17 @@ typedef enum {
 typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step.
 	
 // APIs for funasr
-_FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* szModelDir, int nThread, bool quantize=false, bool use_vad=false);
+_FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* szModelDir, int nThread, bool quantize=false, bool use_vad=false, bool use_punc=false);
 
 
 // if not give a fnCallback ,it should be NULL 
-_FUNASRAPI FUNASR_RESULT	FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
+_FUNASRAPI FUNASR_RESULT	FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
 
-_FUNASRAPI FUNASR_RESULT	FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
+_FUNASRAPI FUNASR_RESULT	FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
 
-_FUNASRAPI FUNASR_RESULT	FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
+_FUNASRAPI FUNASR_RESULT	FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
 
-_FUNASRAPI FUNASR_RESULT	FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
+_FUNASRAPI FUNASR_RESULT	FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
 
 _FUNASRAPI const char*	FunASRGetResult(FUNASR_RESULT Result,int nIndex);
 

+ 2 - 2
funasr/runtime/onnxruntime/readme.md

@@ -59,12 +59,12 @@ Ref to win/
 ## Run the demo
 
 ```shell
-funasr-onnx-offline /path/models_dir /path/wave_file quantize(true or false) use_vad(true or false)
+funasr-onnx-offline /path/models_dir /path/wave_file quantize(true or false) use_vad(true or false) use_punc(true or false)
 ```
 
 The structure of /path/models_dir
 ```
-config.yaml, am.mvn, model.onnx(or model_quant.onnx), (vad_model.onnx, vad.mvn if you use vad)
+config.yaml, am.mvn, model.onnx(or model_quant.onnx), (vad_model.onnx, vad.mvn if you use vad), (punc_model.onnx, punc.yaml if you use vad)
 ```
 
 

+ 7 - 2
funasr/runtime/onnxruntime/src/punc_infer.cpp → funasr/runtime/onnxruntime/src/CT-transformer.cpp

@@ -10,14 +10,19 @@ CTTransformer::CTTransformer(const char* sz_model_dir, int thread_num)
 	string strModelPath = pathAppend(sz_model_dir, PUNC_MODEL_FILE);
 	string strYamlPath = pathAppend(sz_model_dir, PUNC_YAML_FILE);
 
+    try{
 #ifdef _WIN32
 	std::wstring detPath = strToWstr(strModelPath);
     m_session = std::make_unique<Ort::Session>(env_, detPath.c_str(), session_options);
 #else
     m_session = std::make_unique<Ort::Session>(env_, strModelPath.c_str(), session_options);
 #endif
-    // read inputnames outputnames
-    vector<string> m_strInputNames, m_strOutputNames;
+    }
+    catch(exception e)
+    {
+        printf(e.what());
+    }
+    // read inputnames outputnamess
     string strName;
     getInputName(m_session.get(), strName);
     m_strInputNames.push_back(strName.c_str());

+ 1 - 0
funasr/runtime/onnxruntime/src/punc_infer.h → funasr/runtime/onnxruntime/src/CT-transformer.h

@@ -10,6 +10,7 @@ class CTTransformer {
 private:
 
 	CTokenizer m_tokenizer;
+	vector<string> m_strInputNames, m_strOutputNames;
 	vector<const char*> m_szInputNames;
 	vector<const char*> m_szOutputNames;
 

+ 2 - 2
funasr/runtime/onnxruntime/src/Model.cpp

@@ -1,10 +1,10 @@
 #include "precomp.h"
 
-Model *create_model(const char *path, int nThread, bool quantize, bool use_vad)
+Model *create_model(const char *path, int nThread, bool quantize, bool use_vad, bool use_punc)
 {
     Model *mm;
 
-    mm = new paraformer::ModelImp(path, nThread, quantize, use_vad);
+    mm = new paraformer::ModelImp(path, nThread, quantize, use_vad, use_punc);
 
     return mm;
 }

+ 6 - 4
funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp

@@ -11,9 +11,9 @@ using namespace std;
 
 int main(int argc, char *argv[])
 {
-    if (argc < 5)
+    if (argc < 6)
     {
-        printf("Usage: %s /path/to/model_dir /path/to/wav/file quantize(true or false) use_vad(true or false) \n", argv[0]);
+        printf("Usage: %s /path/to/model_dir /path/to/wav/file quantize(true or false) use_vad(true or false) use_punc(true or false)\n", argv[0]);
         exit(-1);
     }
     struct timeval start, end;
@@ -22,9 +22,11 @@ int main(int argc, char *argv[])
     // is quantize
     bool quantize = false;
     bool use_vad = false;
+    bool use_punc = false;
     istringstream(argv[3]) >> boolalpha >> quantize;
     istringstream(argv[4]) >> boolalpha >> use_vad;
-    FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize, use_vad);
+    istringstream(argv[5]) >> boolalpha >> use_punc;
+    FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize, use_vad, use_punc);
 
     if (!AsrHanlde)
     {
@@ -38,7 +40,7 @@ int main(int argc, char *argv[])
     printf("Model initialization takes %lfs.\n", (double)modle_init_micros / 1000000);
 
     gettimeofday(&start, NULL);
-    FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL, use_vad);
+    FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL, use_vad, use_punc);
     gettimeofday(&end, NULL);
 
     float snippet_time = 0.0f;

+ 19 - 7
funasr/runtime/onnxruntime/src/libfunasrapi.cpp

@@ -5,13 +5,13 @@ extern "C" {
 #endif
 
 	// APIs for funasr
-	_FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* szModelDir, int nThreadNum, bool quantize, bool use_vad)
+	_FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* szModelDir, int nThreadNum, bool quantize, bool use_vad, bool use_punc)
 	{
-		Model* mm = create_model(szModelDir, nThreadNum, quantize, use_vad);
+		Model* mm = create_model(szModelDir, nThreadNum, quantize, use_vad, use_punc);
 		return mm;
 	}
 
-	_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
+	_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
 	{
 		Model* pRecogObj = (Model*)handle;
 		if (!pRecogObj)
@@ -39,11 +39,15 @@ extern "C" {
 			if (fnCallback)
 				fnCallback(nStep, nTotal);
 		}
+		if(use_punc){
+			string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
+			pResult->msg = punc_res;
+		}
 
 		return pResult;
 	}
 
-	_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
+	_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
 	{
 		Model* pRecogObj = (Model*)handle;
 		if (!pRecogObj)
@@ -70,11 +74,15 @@ extern "C" {
 			if (fnCallback)
 				fnCallback(nStep, nTotal);
 		}
+		if(use_punc){
+			string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
+			pResult->msg = punc_res;
+		}
 
 		return pResult;
 	}
 
-	_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
+	_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
 	{
 		Model* pRecogObj = (Model*)handle;
 		if (!pRecogObj)
@@ -101,11 +109,15 @@ extern "C" {
 			if (fnCallback)
 				fnCallback(nStep, nTotal);
 		}
+		if(use_punc){
+			string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
+			pResult->msg = punc_res;
+		}
 
 		return pResult;
 	}
 
-	_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
+	_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
 	{
 		Model* pRecogObj = (Model*)handle;
 		if (!pRecogObj)
@@ -133,7 +145,7 @@ extern "C" {
 			if (fnCallback)
 				fnCallback(nStep, nTotal);
 		}
-		if(true){
+		if(use_punc){
 			string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
 			pResult->msg = punc_res;
 		}

+ 2 - 3
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp

@@ -3,7 +3,7 @@
 using namespace std;
 using namespace paraformer;
 
-ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad)
+ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad, bool use_punc)
 :env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
     string model_path;
     string cmvn_path;
@@ -18,7 +18,7 @@ ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad)
     }
 
     // PUNC model
-    if(true){
+    if(use_punc){
         puncHandle = make_unique<CTTransformer>(path, nNumThread);
     }
 
@@ -55,7 +55,6 @@ ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad)
     m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
 #endif
 
-    vector<string> m_strInputNames, m_strOutputNames;
     string strName;
     getInputName(m_session.get(), strName);
     m_strInputNames.push_back(strName.c_str());

+ 2 - 1
funasr/runtime/onnxruntime/src/paraformer_onnx.h

@@ -33,11 +33,12 @@ namespace paraformer {
         Ort::Env env_;
         Ort::SessionOptions sessionOptions;
 
+        vector<string> m_strInputNames, m_strOutputNames;
         vector<const char*> m_szInputNames;
         vector<const char*> m_szOutputNames;
 
     public:
-        ModelImp(const char* path, int nNumThread=0, bool quantize=false, bool use_vad=false);
+        ModelImp(const char* path, int nNumThread=0, bool quantize=false, bool use_vad=false, bool use_punc=false);
         ~ModelImp();
         void reset();
         vector<float> FbankKaldi(float sample_rate, const float* waves, int len);

+ 1 - 1
funasr/runtime/onnxruntime/src/precomp.h

@@ -29,7 +29,7 @@ using namespace std;
 #include "commonfunc.h"
 #include "predefine_coe.h"
 #include "tokenizer.h"
-#include "punc_infer.h"
+#include "CT-transformer.h"
 #include "FsmnVad.h"
 #include "e2e_vad.h"
 #include "Vocab.h"