Browse Source

fix paraformer server for new apis

lyblsgo 2 years ago
parent
commit
a539392ad4

+ 12 - 9
funasr/runtime/grpc/CMakeLists.txt

@@ -42,17 +42,23 @@ add_custom_command(
         "${rg_proto}"
       DEPENDS "${rg_proto}")
 
-
 # Include generated *.pb.h files
 include_directories("${CMAKE_CURRENT_BINARY_DIR}")
 
-include_directories(../onnxruntime/include/)
-link_directories(../onnxruntime/build/src/)
-link_directories(../onnxruntime/build/third_party/yaml-cpp/)
-
 link_directories(${ONNXRUNTIME_DIR}/lib)
+
+include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/include/)
+include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp/include/)
+include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank)
+
+add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp yaml-cpp)
+add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank/kaldi-native-fbank/csrc csrc)
 add_subdirectory("../onnxruntime/src" onnx_src)
 
+include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog)
+set(BUILD_TESTING OFF)
+add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
+
 # rg_grpc_proto
 add_library(rg_grpc_proto
   ${rg_grpc_srcs}
@@ -60,16 +66,13 @@ add_library(rg_grpc_proto
   ${rg_proto_srcs}
   ${rg_proto_hdrs})
 
-
-
 target_link_libraries(rg_grpc_proto
   ${_REFLECTION}
   ${_GRPC_GRPCPP}
   ${_PROTOBUF_LIBPROTOBUF})
 
-# Targets paraformer_(server)
 foreach(_target
-  paraformer_server)
+  paraformer-server)
   add_executable(${_target}
     "${_target}.cc")
   target_link_libraries(${_target}

+ 33 - 11
funasr/runtime/grpc/Readme.md

@@ -4,15 +4,6 @@
 
 ### Build [onnxruntime](./onnxruntime_cpp.md) as it's document
 
-```
-#put onnx-lib & onnx-asr-model into /path/to/asrmodel(eg: /data/asrmodel)
-ls /data/asrmodel/
-onnxruntime-linux-x64-1.14.0  speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
-
-#make sure you have config.yaml, am.mvn, model.onnx(or model_quant.onnx) under speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
-
-```
-
 ### Compile and install grpc v1.52.0 in case of grpc bugs
 ```
 export GRPC_INSTALL_DIR=/data/soft/grpc
@@ -46,8 +37,39 @@ source ~/.bashrc
 
 ### Start grpc paraformer server
 ```
-Usage: ./cmake/build/paraformer_server port thread_num /path/to/model_file quantize(true or false)
-./cmake/build/paraformer_server 10108 4 /data/asrmodel/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch false
+./cmake/build/paraformer-server     --port-id <string> [--punc-config
+                                    <string>] [--punc-model <string>]
+                                    --am-config <string> --am-cmvn <string>
+                                    --am-model <string> [--vad-config
+                                    <string>] [--vad-cmvn <string>]
+                                    [--vad-model <string>] [--] [--version]
+                                    [-h]
+Where:
+   --port-id <string>
+     (required)  port id
+
+   --am-config <string>
+     (required)  am config path
+   --am-cmvn <string>
+     (required)  am cmvn path
+   --am-model <string>
+     (required)  am model path
+
+   --punc-config <string>
+     punc config path
+   --punc-model <string>
+     punc model path
+
+   --vad-config <string>
+     vad config path
+   --vad-cmvn <string>
+     vad cmvn path
+   --vad-model <string>
+     vad model path
+
+   Required: --port-id <string> --am-config <string> --am-cmvn <string> --am-model <string> 
+   If use vad, please add: [--vad-config <string>] [--vad-cmvn <string>] [--vad-model <string>]
+   If use punc, please add: [--punc-config <string>] [--punc-model <string>] 
 ```
 
 ## For the client

+ 87 - 33
funasr/runtime/grpc/paraformer_server.cc → funasr/runtime/grpc/paraformer-server.cc

@@ -13,7 +13,10 @@
 #include <grpcpp/security/server_credentials.h>
 
 #include "paraformer.grpc.pb.h"
-#include "paraformer_server.h"
+#include "paraformer-server.h"
+#include "tclap/CmdLine.h"
+#include "com-define.h"
+#include "glog/logging.h"
 
 using grpc::Server;
 using grpc::ServerBuilder;
@@ -27,31 +30,43 @@ using paraformer::Request;
 using paraformer::Response;
 using paraformer::ASR;
 
-ASRServicer::ASRServicer(const char* model_path, int thread_num, bool quantize) {
-    AsrHanlde=FunASRInit(model_path, thread_num, quantize);
+ASRServicer::ASRServicer(std::map<std::string, std::string>& model_path) {
+    AsrHanlde=FunASRInit(model_path, 1);
     std::cout << "ASRServicer init" << std::endl;
     init_flag = 0;
 }
 
+void ASRServicer::clear_states(const std::string& user) {
+    clear_buffers(user);
+    clear_transcriptions(user);
+}
+
+void ASRServicer::clear_buffers(const std::string& user) {
+    if (client_buffers.count(user)) {
+        client_buffers.erase(user);
+    }
+}
+
+void ASRServicer::clear_transcriptions(const std::string& user) {
+    if (client_transcription.count(user)) {
+        client_transcription.erase(user);
+    }
+}
+
+void ASRServicer::disconnect(const std::string& user) {
+    clear_states(user);
+    std::cout << "Disconnecting user: " << user << std::endl;
+}
+
 grpc::Status ASRServicer::Recognize(
     grpc::ServerContext* context,
     grpc::ServerReaderWriter<Response, Request>* stream) {
 
     Request req;
-    std::unordered_map<std::string, std::string> client_buffers;
-    std::unordered_map<std::string, std::string> client_transcription;
-
     while (stream->Read(&req)) {
         if (req.isend()) {
             std::cout << "asr end" << std::endl;
-            // disconnect 
-            if (client_buffers.count(req.user())) {
-                client_buffers.erase(req.user());
-            }
-            if (client_transcription.count(req.user())) {
-                client_transcription.erase(req.user());
-            }
-
+            disconnect(req.user());
             Response res;
             res.set_sentence(
                 R"({"success": true, "detail": "asr end"})"
@@ -89,14 +104,8 @@ grpc::Status ASRServicer::Recognize(
                   auto& buf = client_buffers[req.user()];
                   buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end());
                 }
-                std::string tmp_data = client_buffers[req.user()];
-                // clear_states
-                if (client_buffers.count(req.user())) {
-                    client_buffers.erase(req.user());
-                }
-                if (client_transcription.count(req.user())) {
-                    client_transcription.erase(req.user());
-                }
+                std::string tmp_data = this->client_buffers[req.user()];
+                this->clear_states(req.user());
 
                 Response res;
                 res.set_sentence(
@@ -161,10 +170,17 @@ grpc::Status ASRServicer::Recognize(
     return Status::OK;
 }
 
-void RunServer(const std::string& port, int thread_num, const char* model_path, bool quantize) {
+void RunServer(std::map<std::string, std::string>& model_path) {
+    std::string port;
+    try{
+        port = model_path.at(PORT_ID);
+    }catch(std::exception const &e){
+        printf("Error when read port.\n");
+        exit(0);
+    }
     std::string server_address;
     server_address = "0.0.0.0:" + port;
-    ASRServicer service(model_path, thread_num, quantize);
+    ASRServicer service(model_path);
 
     ServerBuilder builder;
     builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
@@ -174,16 +190,54 @@ void RunServer(const std::string& port, int thread_num, const char* model_path,
     server->Wait();
 }
 
-int main(int argc, char* argv[]) {
-    if (argc < 5)
-    {
-        printf("Usage: %s port thread_num /path/to/model_file quantize(true or false) \n", argv[0]);
-        exit(-1);
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, std::string key, std::map<std::string, std::string>& model_path)
+{
+    if (value_arg.isSet()){
+        model_path.insert({key, value_arg.getValue()});
+        LOG(INFO)<< key << " : " << value_arg.getValue();
     }
+}
+
+int main(int argc, char* argv[]) {
 
-    // is quantize
-    bool quantize = false;
-    std::istringstream(argv[4]) >> std::boolalpha >> quantize;
-    RunServer(argv[1], atoi(argv[2]), argv[3], quantize);
+    google::InitGoogleLogging(argv[0]);
+    FLAGS_logtostderr = true;
+
+    TCLAP::CmdLine cmd("paraformer-server", ' ', "1.0");
+    TCLAP::ValueArg<std::string> vad_model("", VAD_MODEL_PATH, "vad model path", false, "", "string");
+    TCLAP::ValueArg<std::string> vad_cmvn("", VAD_CMVN_PATH, "vad cmvn path", false, "", "string");
+    TCLAP::ValueArg<std::string> vad_config("", VAD_CONFIG_PATH, "vad config path", false, "", "string");
+
+    TCLAP::ValueArg<std::string> am_model("", AM_MODEL_PATH, "am model path", true, "", "string");
+    TCLAP::ValueArg<std::string> am_cmvn("", AM_CMVN_PATH, "am cmvn path", true, "", "string");
+    TCLAP::ValueArg<std::string> am_config("", AM_CONFIG_PATH, "am config path", true, "", "string");
+
+    TCLAP::ValueArg<std::string> punc_model("", PUNC_MODEL_PATH, "punc model path", false, "", "string");
+    TCLAP::ValueArg<std::string> punc_config("", PUNC_CONFIG_PATH, "punc config path", false, "", "string");
+    TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string");
+
+    cmd.add(vad_model);
+    cmd.add(vad_cmvn);
+    cmd.add(vad_config);
+    cmd.add(am_model);
+    cmd.add(am_cmvn);
+    cmd.add(am_config);
+    cmd.add(punc_model);
+    cmd.add(punc_config);
+    cmd.add(port_id);
+    cmd.parse(argc, argv);
+
+    std::map<std::string, std::string> model_path;
+    GetValue(vad_model, VAD_MODEL_PATH, model_path);
+    GetValue(vad_cmvn, VAD_CMVN_PATH, model_path);
+    GetValue(vad_config, VAD_CONFIG_PATH, model_path);
+    GetValue(am_model, AM_MODEL_PATH, model_path);
+    GetValue(am_cmvn, AM_CMVN_PATH, model_path);
+    GetValue(am_config, AM_CONFIG_PATH, model_path);
+    GetValue(punc_model, PUNC_MODEL_PATH, model_path);
+    GetValue(punc_config, PUNC_CONFIG_PATH, model_path);
+    GetValue(port_id, PORT_ID, model_path);
+
+    RunServer(model_path);
     return 0;
 }

+ 7 - 2
funasr/runtime/grpc/paraformer_server.h → funasr/runtime/grpc/paraformer-server.h

@@ -37,13 +37,18 @@ typedef struct
     float  snippet_time;
 }FUNASR_RECOG_RESULT;
 
-
 class ASRServicer final : public ASR::Service {
   private:
     int init_flag;
+    std::unordered_map<std::string, std::string> client_buffers;
+    std::unordered_map<std::string, std::string> client_transcription;
 
   public:
-    ASRServicer(const char* model_path, int thread_num, bool quantize);
+    ASRServicer(std::map<std::string, std::string>& model_path);
+    void clear_states(const std::string& user);
+    void clear_buffers(const std::string& user);
+    void clear_transcriptions(const std::string& user);
+    void disconnect(const std::string& user);
     grpc::Status Recognize(grpc::ServerContext* context, grpc::ServerReaderWriter<Response, Request>* stream);
     FUNASR_HANDLE AsrHanlde;
 	

+ 1 - 0
funasr/runtime/onnxruntime/include/com-define.h

@@ -24,6 +24,7 @@
 #define WAV_PATH "wav-path"
 #define WAV_SCP "wav-scp"
 #define THREAD_NUM "thread-num"
+#define PORT_ID "port-id"
 
 // vad
 #ifndef VAD_SILENCE_DURATION

+ 9 - 2
funasr/runtime/onnxruntime/include/libfunasrapi.h

@@ -47,10 +47,9 @@ typedef enum {
 
 typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
 	
-// APIs for funasr
+// // ASR
 _FUNASRAPI FUNASR_HANDLE  FunASRInit(std::map<std::string, std::string>& model_path, int thread_num);
 
-// if not give a fn_callback ,it should be NULL 
 _FUNASRAPI FUNASR_RESULT	FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback);
 _FUNASRAPI FUNASR_RESULT	FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
 _FUNASRAPI FUNASR_RESULT	FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
@@ -62,6 +61,14 @@ _FUNASRAPI void			FunASRFreeResult(FUNASR_RESULT result);
 _FUNASRAPI void			FunASRUninit(FUNASR_HANDLE handle);
 _FUNASRAPI const float	FunASRGetRetSnippetTime(FUNASR_RESULT result);
 
+// VAD
+_FUNASRAPI FUNASR_HANDLE  FunVadInit(std::map<std::string, std::string>& model_path, int thread_num);
+
+_FUNASRAPI FUNASR_RESULT	FunASRVadBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+_FUNASRAPI FUNASR_RESULT	FunASRVadPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+_FUNASRAPI FUNASR_RESULT	FunASRVadPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+_FUNASRAPI FUNASR_RESULT	FunASRVadFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+
 #ifdef __cplusplus 
 
 }

+ 6 - 0
funasr/runtime/onnxruntime/src/libfunasrapi.cpp

@@ -11,6 +11,12 @@ extern "C" {
 		return mm;
 	}
 
+	_FUNASRAPI FUNASR_HANDLE  FunVadInit(std::map<std::string, std::string>& model_path, int thread_num)
+	{
+		Model* mm = CreateModel(model_path, thread_num);
+		return mm;
+	}
+
 	_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback)
 	{
 		Model* recog_obj = (Model*)handle;

+ 3 - 2
funasr/runtime/onnxruntime/src/precomp.h

@@ -21,8 +21,8 @@ using namespace std;
 // third part
 #include "onnxruntime_run_options_config_keys.h"
 #include "onnxruntime_cxx_api.h"
-#include <kaldi-native-fbank/csrc/feature-fbank.h>
-#include <kaldi-native-fbank/csrc/online-feature.h>
+#include "kaldi-native-fbank/csrc/feature-fbank.h"
+#include "kaldi-native-fbank/csrc/online-feature.h"
 
 // mine
 #include <glog/logging.h>
@@ -40,6 +40,7 @@ using namespace std;
 #include "util.h"
 #include "resample.h"
 #include "model.h"
+#include "vad-model.h"
 #include "paraformer.h"
 #include "libfunasrapi.h"