Просмотр исходного кода

Merge pull request #334 from alibaba-damo-academy/dev_lyb

Dev lyb
zhifu gao 2 лет назад
Родитель
Сommit
48e6117a3a

+ 1 - 1
funasr/runtime/grpc/CMakeLists.txt

@@ -74,7 +74,7 @@ foreach(_target
     "${_target}.cc")
   target_link_libraries(${_target}
     rg_grpc_proto
-    rapidasr
+    funasr
     ${EXTRA_LIBS}
     ${_REFLECTION}
     ${_GRPC_GRPCPP}

+ 22 - 35
funasr/runtime/grpc/paraformer_server.cc

@@ -15,7 +15,6 @@
 #include "paraformer.grpc.pb.h"
 #include "paraformer_server.h"
 
-
 using grpc::Server;
 using grpc::ServerBuilder;
 using grpc::ServerContext;
@@ -24,48 +23,35 @@ using grpc::ServerReaderWriter;
 using grpc::ServerWriter;
 using grpc::Status;
 
-
 using paraformer::Request;
 using paraformer::Response;
 using paraformer::ASR;
 
 ASRServicer::ASRServicer(const char* model_path, int thread_num, bool quantize) {
-    AsrHanlde=RapidAsrInit(model_path, thread_num, quantize);
+    AsrHanlde=FunASRInit(model_path, thread_num, quantize);
     std::cout << "ASRServicer init" << std::endl;
     init_flag = 0;
 }
 
-void ASRServicer::clear_states(const std::string& user) {
-    clear_buffers(user);
-    clear_transcriptions(user);
-}
-
-void ASRServicer::clear_buffers(const std::string& user) {
-    if (client_buffers.count(user)) {
-        client_buffers.erase(user);
-    }
-}
-
-void ASRServicer::clear_transcriptions(const std::string& user) {
-    if (client_transcription.count(user)) {
-        client_transcription.erase(user);
-    }
-}
-
-void ASRServicer::disconnect(const std::string& user) {
-    clear_states(user);
-    std::cout << "Disconnecting user: " << user << std::endl;
-}
-
 grpc::Status ASRServicer::Recognize(
     grpc::ServerContext* context,
     grpc::ServerReaderWriter<Response, Request>* stream) {
 
     Request req;
+    std::unordered_map<std::string, std::string> client_buffers;
+    std::unordered_map<std::string, std::string> client_transcription;
+
     while (stream->Read(&req)) {
         if (req.isend()) {
             std::cout << "asr end" << std::endl;
-            disconnect(req.user());
+            // disconnect 
+            if (client_buffers.count(req.user())) {
+                client_buffers.erase(req.user());
+            }
+            if (client_transcription.count(req.user())) {
+                client_transcription.erase(req.user());
+            }
+
             Response res;
             res.set_sentence(
                 R"({"success": true, "detail": "asr end"})"
@@ -103,8 +89,14 @@ grpc::Status ASRServicer::Recognize(
                   auto& buf = client_buffers[req.user()];
                   buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end());
                 }
-                std::string tmp_data = this->client_buffers[req.user()];
-                this->clear_states(req.user());
+                std::string tmp_data = client_buffers[req.user()];
+                // clear_states
+                if (client_buffers.count(req.user())) {
+                    client_buffers.erase(req.user());
+                }
+                if (client_transcription.count(req.user())) {
+                    client_transcription.erase(req.user());
+                }
 
                 Response res;
                 res.set_sentence(
@@ -133,14 +125,11 @@ grpc::Status ASRServicer::Recognize(
                     res.set_user(req.user());
                     res.set_action("finish");
                     res.set_language(req.language());
-
-
-
                     stream->Write(res);
                 }
                 else {
-                    RPASR_RESULT Result= RapidAsrRecogPCMBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, RASR_NONE, NULL);
-                    std::string asr_result = ((RPASR_RECOG_RESULT*)Result)->msg;
+                    FUNASR_RESULT Result= FunASRRecogPCMBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, RASR_NONE, NULL);
+                    std::string asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg;
 
                     auto end_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
                     std::string delay_str = std::to_string(end_time - begin_time);
@@ -155,7 +144,6 @@ grpc::Status ASRServicer::Recognize(
                     res.set_action("finish");
                     res.set_language(req.language());
 
-
                     stream->Write(res);
                 }
             }
@@ -173,7 +161,6 @@ grpc::Status ASRServicer::Recognize(
     return Status::OK;
 }
 
-
 void RunServer(const std::string& port, int thread_num, const char* model_path, bool quantize) {
     std::string server_address;
     server_address = "0.0.0.0:" + port;

+ 3 - 9
funasr/runtime/grpc/paraformer_server.h

@@ -15,7 +15,7 @@
 #include <chrono>
 
 #include "paraformer.grpc.pb.h"
-#include "librapidasrapi.h"
+#include "libfunasrapi.h"
 
 
 using grpc::Server;
@@ -35,22 +35,16 @@ typedef struct
 {
     std::string msg;
     float  snippet_time;
-}RPASR_RECOG_RESULT;
+}FUNASR_RECOG_RESULT;
 
 
 class ASRServicer final : public ASR::Service {
   private:
     int init_flag;
-    std::unordered_map<std::string, std::string> client_buffers;
-    std::unordered_map<std::string, std::string> client_transcription;
 
   public:
     ASRServicer(const char* model_path, int thread_num, bool quantize);
-    void clear_states(const std::string& user);
-    void clear_buffers(const std::string& user);
-    void clear_transcriptions(const std::string& user);
-    void disconnect(const std::string& user);
     grpc::Status Recognize(grpc::ServerContext* context, grpc::ServerReaderWriter<Response, Request>* stream);
-    RPASR_HANDLE AsrHanlde;
+    FUNASR_HANDLE AsrHanlde;
 	
 };

+ 77 - 0
funasr/runtime/onnxruntime/include/libfunasrapi.h

@@ -0,0 +1,77 @@
+#pragma once
+
+#ifdef WIN32
+#ifdef _FUNASR_API_EXPORT
+#define  _FUNASRAPI __declspec(dllexport)
+#else
+#define  _FUNASRAPI __declspec(dllimport)
+#endif
+#else
+#define _FUNASRAPI
+#endif
+
+#ifndef _WIN32
+#define FUNASR_CALLBCK_PREFIX __attribute__((__stdcall__))
+#else
+#define FUNASR_CALLBCK_PREFIX __stdcall
+#endif
+
+#ifdef __cplusplus 
+
+extern "C" {
+#endif
+
+typedef void* FUNASR_HANDLE;
+typedef void* FUNASR_RESULT;
+typedef unsigned char FUNASR_BOOL;
+
+#define FUNASR_TRUE 1
+#define FUNASR_FALSE 0
+#define QM_DEFAULT_THREAD_NUM  4
+
+typedef enum
+{
+ RASR_NONE=-1,
+ RASRM_CTC_GREEDY_SEARCH=0,
+ RASRM_CTC_RPEFIX_BEAM_SEARCH = 1,
+ RASRM_ATTENSION_RESCORING = 2,
+ 
+}FUNASR_MODE;
+
+typedef enum {
+	FUNASR_MODEL_PADDLE = 0,
+	FUNASR_MODEL_PADDLE_2 = 1,
+	FUNASR_MODEL_K2 = 2,
+	FUNASR_MODEL_PARAFORMER = 3,
+
+}FUNASR_MODEL_TYPE;
+
+typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step.
+	
+// APIs for qmasr
+_FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* szModelDir, int nThread, bool quantize);
+
+
+// if not give a fnCallback ,it should be NULL 
+_FUNASRAPI FUNASR_RESULT	FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
+
+_FUNASRAPI FUNASR_RESULT	FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
+
+_FUNASRAPI FUNASR_RESULT	FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
+
+_FUNASRAPI FUNASR_RESULT	FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
+
+_FUNASRAPI const char*	FunASRGetResult(FUNASR_RESULT Result,int nIndex);
+
+_FUNASRAPI const int		FunASRGetRetNumber(FUNASR_RESULT Result);
+
+_FUNASRAPI void			FunASRFreeResult(FUNASR_RESULT Result);
+
+_FUNASRAPI void			FunASRUninit(FUNASR_HANDLE Handle);
+
+_FUNASRAPI const float	FunASRGetRetSnippetTime(FUNASR_RESULT Result);
+
+#ifdef __cplusplus 
+
+}
+#endif

+ 0 - 77
funasr/runtime/onnxruntime/include/librapidasrapi.h

@@ -1,77 +0,0 @@
-#pragma once
-
-#ifdef WIN32
-#ifdef _RPASR_API_EXPORT
-#define  _RAPIDASRAPI __declspec(dllexport)
-#else
-#define  _RAPIDASRAPI __declspec(dllimport)
-#endif
-#else
-#define _RAPIDASRAPI
-#endif
-
-#ifndef _WIN32
-#define RPASR_CALLBCK_PREFIX __attribute__((__stdcall__))
-#else
-#define RPASR_CALLBCK_PREFIX __stdcall
-#endif
-
-#ifdef __cplusplus 
-
-extern "C" {
-#endif
-
-typedef void* RPASR_HANDLE;
-typedef void* RPASR_RESULT;
-typedef unsigned char RPASR_BOOL;
-
-#define RPASR_TRUE 1
-#define RPASR_FALSE 0
-#define QM_DEFAULT_THREAD_NUM  4
-
-typedef enum
-{
- RASR_NONE=-1,
- RASRM_CTC_GREEDY_SEARCH=0,
- RASRM_CTC_RPEFIX_BEAM_SEARCH = 1,
- RASRM_ATTENSION_RESCORING = 2,
- 
-}RPASR_MODE;
-
-typedef enum {
-	RPASR_MODEL_PADDLE = 0,
-	RPASR_MODEL_PADDLE_2 = 1,
-	RPASR_MODEL_K2 = 2,
-	RPASR_MODEL_PARAFORMER = 3,
-
-}RPASR_MODEL_TYPE;
-
-typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step.
-	
-// APIs for qmasr
-_RAPIDASRAPI RPASR_HANDLE  RapidAsrInit(const char* szModelDir, int nThread, bool quantize);
-
-
-// if not give a fnCallback ,it should be NULL 
-_RAPIDASRAPI RPASR_RESULT	RapidAsrRecogBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback);
-
-_RAPIDASRAPI RPASR_RESULT	RapidAsrRecogPCMBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback);
-
-_RAPIDASRAPI RPASR_RESULT	RapidAsrRecogPCMFile(RPASR_HANDLE handle, const char* szFileName, RPASR_MODE Mode, QM_CALLBACK fnCallback);
-
-_RAPIDASRAPI RPASR_RESULT	RapidAsrRecogFile(RPASR_HANDLE handle, const char* szWavfile, RPASR_MODE Mode, QM_CALLBACK fnCallback);
-
-_RAPIDASRAPI const char*	RapidAsrGetResult(RPASR_RESULT Result,int nIndex);
-
-_RAPIDASRAPI const int		RapidAsrGetRetNumber(RPASR_RESULT Result);
-
-_RAPIDASRAPI void			RapidAsrFreeResult(RPASR_RESULT Result);
-
-_RAPIDASRAPI void			RapidAsrUninit(RPASR_HANDLE Handle);
-
-_RAPIDASRAPI const float	RapidAsrGetRetSnippetTime(RPASR_RESULT Result);
-
-#ifdef __cplusplus 
-
-}
-#endif

+ 14 - 14
funasr/runtime/onnxruntime/src/CMakeLists.txt

@@ -6,38 +6,38 @@ set(files ${files1} ${files2} ${files3} ${files4})
 
 # message("${files}")
 
-add_library(rapidasr ${files})
+add_library(funasr ${files})
 
 if(WIN32)
 
         set(EXTRA_LIBS libfftw3f-3 yaml-cpp)
         if(CMAKE_CL_64)
-            target_link_directories(rapidasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x64)
+            target_link_directories(funasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x64)
         else()
-            target_link_directories(rapidasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x86)
+            target_link_directories(funasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x86)
         endif()
-        target_include_directories(rapidasr PUBLIC ${CMAKE_SOURCE_DIR}/win/include )
+        target_include_directories(funasr PUBLIC ${CMAKE_SOURCE_DIR}/win/include )
         
-        target_compile_definitions(rapidasr PUBLIC -D_RPASR_API_EXPORT)
+        target_compile_definitions(funasr PUBLIC -D_FUNASR_API_EXPORT)
 else()
 
     set(EXTRA_LIBS fftw3f pthread yaml-cpp)
-    target_include_directories(rapidasr PUBLIC "/usr/local/opt/fftw/include")
-    target_link_directories(rapidasr PUBLIC "/usr/local/opt/fftw/lib")
+    target_include_directories(funasr PUBLIC "/usr/local/opt/fftw/include")
+    target_link_directories(funasr PUBLIC "/usr/local/opt/fftw/lib")
 
-    target_include_directories(rapidasr PUBLIC "/usr/local/opt/openblas/include")
-    target_link_directories(rapidasr PUBLIC "/usr/local/opt/openblas/lib")
+    target_include_directories(funasr PUBLIC "/usr/local/opt/openblas/include")
+    target_link_directories(funasr PUBLIC "/usr/local/opt/openblas/lib")
 
-    target_include_directories(rapidasr PUBLIC "/usr/include")
-    target_link_directories(rapidasr PUBLIC "/usr/lib64")
+    target_include_directories(funasr PUBLIC "/usr/include")
+    target_link_directories(funasr PUBLIC "/usr/lib64")
 
-    target_include_directories(rapidasr PUBLIC  ${FFTW3F_INCLUDE_DIR})
-    target_link_directories(rapidasr PUBLIC ${FFTW3F_LIBRARY_DIR})
+    target_include_directories(funasr PUBLIC  ${FFTW3F_INCLUDE_DIR})
+    target_link_directories(funasr PUBLIC ${FFTW3F_LIBRARY_DIR})
     include_directories(${ONNXRUNTIME_DIR}/include)    
 endif()
 
 include_directories(${CMAKE_SOURCE_DIR}/include)
-target_link_libraries(rapidasr PUBLIC onnxruntime ${EXTRA_LIBS})
+target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
 
 
 

+ 8 - 20
funasr/runtime/onnxruntime/src/FeatureExtract.cpp

@@ -5,14 +5,10 @@ using namespace std;
 
 FeatureExtract::FeatureExtract(int mode) : mode(mode)
 {
-    fftw_init();
 }
 
 FeatureExtract::~FeatureExtract()
 {
-    fftwf_free(fft_input);
-    fftwf_free(fft_out);
-    fftwf_destroy_plan(p);
 }
 
 void FeatureExtract::reset()
@@ -26,34 +22,25 @@ int FeatureExtract::size()
     return fqueue.size();
 }
 
-void FeatureExtract::fftw_init()
+void FeatureExtract::insert(fftwf_plan plan, float *din, int len, int flag)
 {
-    int fft_size = 512;
-    fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
-    fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
+    float* fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
+    fftwf_complex* fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
     memset(fft_input, 0, sizeof(float) * fft_size);
-    p = fftwf_plan_dft_r2c_1d(fft_size, fft_input, fft_out, FFTW_ESTIMATE);
-}
 
-void FeatureExtract::insert(float *din, int len, int flag)
-{
     const float *window = (const float *)&window_hex;
     if (mode == 3)
         window = (const float *)&window_hamm_hex;
 
-    int window_size = 400;
-    int fft_size = 512;
-    int window_shift = 160;
-
     speech.load(din, len);
     int i, j;
     float tmp_feature[80];
     if (mode == 0 || mode == 2 || mode == 3) {
-        int ll = (speech.size() - 400) / 160 + 1;
+        int ll = (speech.size() - window_size) / window_shift + 1;
         fqueue.reinit(ll);
     }
 
-    for (i = 0; i <= speech.size() - 400; i = i + window_shift) {
+    for (i = 0; i <= speech.size() - window_size; i = i + window_shift) {
         float tmp_mean = 0;
         for (j = 0; j < window_size; j++) {
             tmp_mean += speech[i + j];
@@ -70,7 +57,7 @@ void FeatureExtract::insert(float *din, int len, int flag)
             pre_val = cur_val;
         }
 
-        fftwf_execute(p);
+        fftwf_execute_dft_r2c(plan, fft_input, fft_out);
 
         melspect((float *)fft_out, tmp_feature);
         int tmp_flag = S_MIDDLE;
@@ -80,6 +67,8 @@ void FeatureExtract::insert(float *din, int len, int flag)
         fqueue.push(tmp_feature, tmp_flag);
     }
     speech.update(i);
+    fftwf_free(fft_input);
+    fftwf_free(fft_out);
 }
 
 bool FeatureExtract::fetch(Tensor<float> *&dout)
@@ -128,7 +117,6 @@ void FeatureExtract::global_cmvn(float *din)
 void FeatureExtract::melspect(float *din, float *dout)
 {
     float fftmag[256];
-//    float tmp;
     const float *melcoe = (const float *)melcoe_hex;
     int i;
     for (i = 0; i < 256; i++) {

+ 6 - 7
funasr/runtime/onnxruntime/src/FeatureExtract.h

@@ -14,12 +14,11 @@ class FeatureExtract {
     SpeechWrap speech;
     FeatureQueue fqueue;
     int mode;
+    int fft_size = 512;
+    int window_size = 400;
+    int window_shift = 160;
 
-    float *fft_input;
-    fftwf_complex *fft_out;
-    fftwf_plan p;
-
-    void fftw_init();
+    //void fftw_init();
     void melspect(float *din, float *dout);
     void global_cmvn(float *din);
 
@@ -27,9 +26,9 @@ class FeatureExtract {
     FeatureExtract(int mode);
     ~FeatureExtract();
     int size();
-    int status();
+    //int status();
     void reset();
-    void insert(float *din, int len, int flag);
+    void insert(fftwf_plan plan, float *din, int len, int flag);
     bool fetch(Tensor<float> *&dout);
 };
 

+ 2 - 2
funasr/runtime/onnxruntime/src/commonfunc.h

@@ -5,7 +5,7 @@ typedef struct
 {
     std::string msg;
     float  snippet_time;
-}RPASR_RECOG_RESULT;
+}FUNASR_RECOG_RESULT;
 
 
 #ifdef _WIN32
@@ -53,4 +53,4 @@ inline void getOutputName(Ort::Session* session, string& outputName, int nIndex
 
         }
     }
-}
+}

+ 21 - 21
funasr/runtime/onnxruntime/src/librapidasrapi.cpp → funasr/runtime/onnxruntime/src/libfunasrapi.cpp

@@ -5,13 +5,13 @@ extern "C" {
 #endif
 
 	// APIs for qmasr
-	_RAPIDASRAPI RPASR_HANDLE  RapidAsrInit(const char* szModelDir, int nThreadNum, bool quantize)
+	_FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* szModelDir, int nThreadNum, bool quantize)
 	{
 		Model* mm = create_model(szModelDir, nThreadNum, quantize);
 		return mm;
 	}
 
-	_RAPIDASRAPI RPASR_RESULT RapidAsrRecogBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback)
+	_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
 	{
 		Model* pRecogObj = (Model*)handle;
 		if (!pRecogObj)
@@ -25,12 +25,12 @@ extern "C" {
 		float* buff;
 		int len;
 		int flag=0;
-		RPASR_RECOG_RESULT* pResult = new RPASR_RECOG_RESULT;
+		FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
 		pResult->snippet_time = audio.get_time_len();
 		int nStep = 0;
 		int nTotal = audio.get_queue_size();
 		while (audio.fetch(buff, len, flag) > 0) {
-			pRecogObj->reset();
+			//pRecogObj->reset();
 			string msg = pRecogObj->forward(buff, len, flag);
 			pResult->msg += msg;
 			nStep++;
@@ -41,7 +41,7 @@ extern "C" {
 		return pResult;
 	}
 
-	_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback)
+	_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
 	{
 		Model* pRecogObj = (Model*)handle;
 		if (!pRecogObj)
@@ -55,12 +55,12 @@ extern "C" {
 		float* buff;
 		int len;
 		int flag = 0;
-		RPASR_RECOG_RESULT* pResult = new RPASR_RECOG_RESULT;
+		FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
 		pResult->snippet_time = audio.get_time_len();
 		int nStep = 0;
 		int nTotal = audio.get_queue_size();
 		while (audio.fetch(buff, len, flag) > 0) {
-			pRecogObj->reset();
+			//pRecogObj->reset();
 			string msg = pRecogObj->forward(buff, len, flag);
 			pResult->msg += msg;
 			nStep++;
@@ -71,7 +71,7 @@ extern "C" {
 		return pResult;
 	}
 
-	_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMFile(RPASR_HANDLE handle, const char* szFileName, RPASR_MODE Mode, QM_CALLBACK fnCallback)
+	_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
 	{
 		Model* pRecogObj = (Model*)handle;
 		if (!pRecogObj)
@@ -85,12 +85,12 @@ extern "C" {
 		float* buff;
 		int len;
 		int flag = 0;
-		RPASR_RECOG_RESULT* pResult = new RPASR_RECOG_RESULT;
+		FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
 		pResult->snippet_time = audio.get_time_len();
 		int nStep = 0;
 		int nTotal = audio.get_queue_size();
 		while (audio.fetch(buff, len, flag) > 0) {
-			pRecogObj->reset();
+			//pRecogObj->reset();
 			string msg = pRecogObj->forward(buff, len, flag);
 			pResult->msg += msg;
 			nStep++;
@@ -101,7 +101,7 @@ extern "C" {
 		return pResult;
 	}
 
-	_RAPIDASRAPI RPASR_RESULT RapidAsrRecogFile(RPASR_HANDLE handle, const char* szWavfile, RPASR_MODE Mode, QM_CALLBACK fnCallback)
+	_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
 	{
 		Model* pRecogObj = (Model*)handle;
 		if (!pRecogObj)
@@ -117,10 +117,10 @@ extern "C" {
 		int flag = 0;
 		int nStep = 0;
 		int nTotal = audio.get_queue_size();
-		RPASR_RECOG_RESULT* pResult = new RPASR_RECOG_RESULT;
+		FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
 		pResult->snippet_time = audio.get_time_len();
 		while (audio.fetch(buff, len, flag) > 0) {
-			pRecogObj->reset();
+			//pRecogObj->reset();
 			string msg = pRecogObj->forward(buff, len, flag);
 			pResult->msg+= msg;
 			nStep++;
@@ -131,7 +131,7 @@ extern "C" {
 		return pResult;
 	}
 
-	_RAPIDASRAPI const int RapidAsrGetRetNumber(RPASR_RESULT Result)
+	_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT Result)
 	{
 		if (!Result)
 			return 0;
@@ -140,32 +140,32 @@ extern "C" {
 	}
 
 
-	_RAPIDASRAPI const float RapidAsrGetRetSnippetTime(RPASR_RESULT Result)
+	_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT Result)
 	{
 		if (!Result)
 			return 0.0f;
 
-		return ((RPASR_RECOG_RESULT*)Result)->snippet_time;
+		return ((FUNASR_RECOG_RESULT*)Result)->snippet_time;
 	}
 
-	_RAPIDASRAPI const char* RapidAsrGetResult(RPASR_RESULT Result,int nIndex)
+	_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT Result,int nIndex)
 	{
-		RPASR_RECOG_RESULT * pResult = (RPASR_RECOG_RESULT*)Result;
+		FUNASR_RECOG_RESULT * pResult = (FUNASR_RECOG_RESULT*)Result;
 		if(!pResult)
 			return nullptr;
 
 		return pResult->msg.c_str();
 	}
 
-	_RAPIDASRAPI void RapidAsrFreeResult(RPASR_RESULT Result)
+	_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT Result)
 	{
 		if (Result)
 		{
-			delete (RPASR_RECOG_RESULT*)Result;
+			delete (FUNASR_RECOG_RESULT*)Result;
 		}
 	}
 
-	_RAPIDASRAPI void RapidAsrUninit(RPASR_HANDLE handle)
+	_FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle)
 	{
 		Model* pRecogObj = (Model*)handle;
 

+ 19 - 9
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp

@@ -18,7 +18,10 @@ ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
     cmvn_path = pathAppend(path, "am.mvn");
     config_path = pathAppend(path, "config.yaml");
 
-    fe = new FeatureExtract(3);
+    fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
+    fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
+    memset(fft_input, 0, sizeof(float) * fft_size);
+    plan = fftwf_plan_dft_r2c_1d(fft_size, fft_input, fft_out, FFTW_ESTIMATE);
 
     //sessionOptions.SetInterOpNumThreads(1);
     sessionOptions.SetIntraOpNumThreads(nNumThread);
@@ -52,8 +55,6 @@ ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
 
 ModelImp::~ModelImp()
 {
-    if(fe)
-        delete fe;
     if (m_session)
     {
         delete m_session;
@@ -61,11 +62,15 @@ ModelImp::~ModelImp()
     }
     if(vocab)
         delete vocab;
+    fftwf_free(fft_input);
+    fftwf_free(fft_out);
+    fftwf_destroy_plan(plan);
+    fftwf_cleanup();
 }
 
 void ModelImp::reset()
 {
-    fe->reset();
+    printf("Not Imp!!!!!!\n");
 }
 
 void ModelImp::apply_lfr(Tensor<float>*& din)
@@ -159,9 +164,10 @@ string ModelImp::greedy_search(float * in, int nLen )
 
 string ModelImp::forward(float* din, int len, int flag)
 {
-
     Tensor<float>* in;
-    fe->insert(din, len, flag);
+    FeatureExtract* fe = new FeatureExtract(3);
+    fe->reset();
+    fe->insert(plan, din, len, flag);
     fe->fetch(in);
     apply_lfr(in);
     apply_cmvn(in);
@@ -192,7 +198,6 @@ string ModelImp::forward(float* din, int len, int flag)
         auto outputTensor = m_session->Run(run_option, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size());
         std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
 
-
         int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
         float* floatData = outputTensor[0].GetTensorMutableData<float>();
         auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
@@ -203,9 +208,14 @@ string ModelImp::forward(float* din, int len, int flag)
         result = "";
     }
 
-
-    if(in)
+    if(in){
         delete in;
+        in = nullptr;
+    }
+    if(fe){
+        delete fe;
+        fe = nullptr;
+    }
 
     return result;
 }

+ 4 - 3
funasr/runtime/onnxruntime/src/paraformer_onnx.h

@@ -8,7 +8,10 @@ namespace paraformer {
 
     class ModelImp : public Model {
     private:
-        FeatureExtract* fe;
+        int fft_size=512;
+        float *fft_input;
+        fftwf_complex *fft_out;
+        fftwf_plan plan;
 
         Vocab* vocab;
         vector<float> means_list;
@@ -34,8 +37,6 @@ namespace paraformer {
         vector<string> m_strInputNames, m_strOutputNames;
         vector<const char*> m_szInputNames;
         vector<const char*> m_szOutputNames;
-        //string m_strInputName, m_strInputNameLen;
-        //string m_strOutputName, m_strOutputNameLen;
 
     public:
         ModelImp(const char* path, int nNumThread=0, bool quantize=false);

+ 1 - 1
funasr/runtime/onnxruntime/src/precomp.h

@@ -46,7 +46,7 @@ using namespace std;
 #include <Audio.h>
 #include "Model.h"
 #include "paraformer_onnx.h"
-#include "librapidasrapi.h"
+#include "libfunasrapi.h"
 
 
 using namespace paraformer;

+ 1 - 1
funasr/runtime/onnxruntime/tester/CMakeLists.txt

@@ -8,7 +8,7 @@ if(WIN32)
     endif()
 endif()
 
-set(EXTRA_LIBS rapidasr)
+set(EXTRA_LIBS funasr)
 
 
 include_directories(${CMAKE_SOURCE_DIR}/include)

+ 9 - 49
funasr/runtime/onnxruntime/tester/tester.cpp

@@ -5,7 +5,7 @@
 #include <win_func.h>
 #endif
 
-#include "librapidasrapi.h"
+#include "libfunasrapi.h"
 
 #include <iostream>
 #include <fstream>
@@ -26,7 +26,7 @@ int main(int argc, char *argv[])
     // is quantize
     bool quantize = false;
     istringstream(argv[3]) >> boolalpha >> quantize;
-    RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum, quantize);
+    FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize);
 
     if (!AsrHanlde)
     {
@@ -42,72 +42,32 @@ int main(int argc, char *argv[])
     gettimeofday(&start, NULL);
     float snippet_time = 0.0f;
 
-    RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL);
+    FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL);
 
     gettimeofday(&end, NULL);
    
     if (Result)
     {
-        string msg = RapidAsrGetResult(Result, 0);
+        string msg = FunASRGetResult(Result, 0);
         setbuf(stdout, NULL);
-        cout << "Result: \"";
-        cout << msg << "\"." << endl;
-        snippet_time = RapidAsrGetRetSnippetTime(Result);
-        RapidAsrFreeResult(Result);
+        printf("Result: %s \n", msg.c_str());
+        snippet_time = FunASRGetRetSnippetTime(Result);
+        FunASRFreeResult(Result);
     }
     else
     {
         cout <<"no return data!";
     }
  
-    //char* buff = nullptr;
-    //int len = 0;
-    //ifstream ifs(argv[2], std::ios::binary | std::ios::in);
-    //if (ifs.is_open())
-    //{
-    //    ifs.seekg(0, std::ios::end);
-    //    len = ifs.tellg();
-    //    ifs.seekg(0, std::ios::beg);
-
-    //    buff = new char[len];
-
-    //    ifs.read(buff, len);
-
-
-    //    //RPASR_RESULT Result = RapidAsrRecogPCMFile(AsrHanlde, argv[2], RASR_NONE, NULL);
-
-    //    RPASR_RESULT Result=RapidAsrRecogPCMBuffer(AsrHanlde, buff,len, RASR_NONE, NULL);
-    //    //RPASR_RESULT Result = RapidAsrRecogPCMFile(AsrHanlde, argv[2], RASR_NONE, NULL);
-    //    gettimeofday(&end, NULL);
-    //   
-    //    if (Result)
-    //    {
-    //        string msg = RapidAsrGetResult(Result, 0);
-    //        setbuf(stdout, NULL);
-    //        cout << "Result: \"";
-    //        cout << msg << endl;
-    //        cout << "\"." << endl;
-    //        snippet_time = RapidAsrGetRetSnippetTime(Result);
-    //        RapidAsrFreeResult(Result);
-    //    }
-    //    else
-    //    {
-    //        cout <<"no return data!";
-    //    }
-  
-    //   
-    //delete[]buff;
-    //}
- 
     printf("Audio length %lfs.\n", (double)snippet_time);
     seconds = (end.tv_sec - start.tv_sec);
     long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
     printf("Model inference takes %lfs.\n", (double)taking_micros / 1000000);
     printf("Model inference RTF: %04lf.\n", (double)taking_micros/ (snippet_time*1000000));
 
-    RapidAsrUninit(AsrHanlde);
+    FunASRUninit(AsrHanlde);
 
     return 0;
 }
 
-    
+    

+ 9 - 9
funasr/runtime/onnxruntime/tester/tester_rtf.cpp

@@ -5,7 +5,7 @@
 #include <win_func.h>
 #endif
 
-#include "librapidasrapi.h"
+#include "libfunasrapi.h"
 
 #include <iostream>
 #include <fstream>
@@ -47,7 +47,7 @@ int main(int argc, char *argv[])
     bool quantize = false;
     istringstream(argv[3]) >> boolalpha >> quantize;
 
-    RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum, quantize);
+    FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize);
     if (!AsrHanlde)
     {
         printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
@@ -61,7 +61,7 @@ int main(int argc, char *argv[])
     // warm up
     for (size_t i = 0; i < 30; i++)
     {
-        RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[0].c_str(), RASR_NONE, NULL);
+        FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, wav_list[0].c_str(), RASR_NONE, NULL);
     }
 
     // forward
@@ -72,19 +72,19 @@ int main(int argc, char *argv[])
     for (size_t i = 0; i < wav_list.size(); i++)
     {
         gettimeofday(&start, NULL);
-        RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL);
+        FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL);
         gettimeofday(&end, NULL);
         seconds = (end.tv_sec - start.tv_sec);
         long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
         total_time += taking_micros;
 
         if(Result){
-            string msg = RapidAsrGetResult(Result, 0);
-            printf("Result: %s \n", msg);
+            string msg = FunASRGetResult(Result, 0);
+            printf("Result: %s \n", msg.c_str());
 
-            snippet_time = RapidAsrGetRetSnippetTime(Result);
+            snippet_time = FunASRGetRetSnippetTime(Result);
             total_length += snippet_time;
-            RapidAsrFreeResult(Result);
+            FunASRFreeResult(Result);
         }else{
             cout <<"No return data!";
         }
@@ -94,6 +94,6 @@ int main(int argc, char *argv[])
     printf("total_time_comput %ld ms.\n", total_time / 1000);
     printf("total_rtf %05lf .\n", (double)total_time/ (total_length*1000000));
 
-    RapidAsrUninit(AsrHanlde);
+    FunASRUninit(AsrHanlde);
     return 0;
 }

+ 62 - 0
funasr/runtime/python/grpc/grpc_main_client.py

@@ -0,0 +1,62 @@
+import grpc
+import json
+import time
+import asyncio
+import soundfile as sf
+import argparse
+
+from grpc_client import transcribe_audio_bytes
+from paraformer_pb2_grpc import ASRStub
+
+# send the audio data once
+async def grpc_rec(wav_scp, grpc_uri, asr_user, language):
+    with grpc.insecure_channel(grpc_uri) as channel:
+        stub = ASRStub(channel)
+        for line in wav_scp:
+            wav_file = line.split()[1]
+            wav, _ = sf.read(wav_file, dtype='int16')
+            
+            b = time.time()
+            response = transcribe_audio_bytes(stub, wav.tobytes(), user=asr_user, language=language, speaking=False, isEnd=False)
+            resp = response.next()
+            text = ''
+            if 'decoding' == resp.action:
+                resp = response.next()
+                if 'finish' == resp.action:
+                    text = json.loads(resp.sentence)['text']
+            response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking=False, isEnd=True)
+            res= {'text': text, 'time': time.time() - b}
+            print(res)
+
+async def test(args):
+    wav_scp = open(args.wav_scp, "r").readlines()
+    uri = '{}:{}'.format(args.host, args.port)
+    res = await grpc_rec(wav_scp, uri, args.user_allowed, language = 'zh-CN')
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host",
+                        type=str,
+                        default="127.0.0.1",
+                        required=False,
+                        help="grpc server host ip")
+    parser.add_argument("--port",
+                        type=int,
+                        default=10108,
+                        required=False,
+                        help="grpc server port")              
+    parser.add_argument("--user_allowed",
+                        type=str,
+                        default="project1_user1",
+                        help="allowed user for grpc client")
+    parser.add_argument("--sample_rate",
+                        type=int,
+                        default=16000,
+                        help="audio sample_rate from client") 
+    parser.add_argument("--wav_scp",
+                        type=str,
+                        required=True,
+                        help="audio wav scp")                    
+    args = parser.parse_args()
+    
+    asyncio.run(test(args))