Просмотр исходного кода

add cpp websocket online 2pass srv

zhaoming 2 лет назад
Родитель
Сommit
e9862461bc

+ 2 - 0
funasr/runtime/websocket/CMakeLists.txt

@@ -58,7 +58,9 @@ add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
 find_package(OpenSSL REQUIRED)
 
 add_executable(funasr-wss-server "funasr-wss-server.cpp" "websocket-server.cpp")
+add_executable(funasr-wss-server-2pass "funasr-wss-server-2pass.cpp" "websocket-server-2pass.cpp")
 add_executable(funasr-wss-client "funasr-wss-client.cpp")
 
 target_link_libraries(funasr-wss-client PUBLIC funasr ssl crypto)
 target_link_libraries(funasr-wss-server PUBLIC funasr ssl crypto)
+target_link_libraries(funasr-wss-server-2pass PUBLIC funasr ssl crypto)

+ 419 - 0
funasr/runtime/websocket/funasr-wss-server-2pass.cpp

@@ -0,0 +1,419 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License  (https://opensource.org/licenses/MIT)
+ */
+/* 2022-2023 by zhaomingwork */
+
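+// io server
+// Usage: funasr-wss-server-2pass [--download-model-dir <string>]
+//            [--offline-model-dir <string>] [--online-model-dir <string>]
+//            [--offline-model-revision <string>]
+//            [--online-model-revision <string>] [--quantize <string>]
+//            [--vad-dir <string>] [--vad-revision <string>]
+//            [--vad-quant <string>] [--punc-dir <string>]
+//            [--punc-revision <string>] [--punc-quant <string>]
+//            [--listen-ip <string>] [--port <int>] [--io-thread-num <int>]
+//            [--decoder-thread-num <int>] [--model-thread-num <int>]
+//            [--certfile <string>] [--keyfile <string>] [--] [--version] [-h]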
+#include <unistd.h>
+#include "websocket-server-2pass.h"
+
+using namespace std;
+// Copies the parsed value of one command-line argument into model_path under
+// `key` (an existing entry for `key` is left untouched) and logs the pair.
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
+              std::map<std::string, std::string>& model_path) {
+  const std::string& parsed = value_arg.getValue();
+  model_path.emplace(key, parsed);
+  LOG(INFO) << key << " : " << parsed;
+}
+/**
+ * Entry point of the 2pass websocket ASR server.
+ *
+ * Parses the command line, optionally exports/downloads the VAD, offline ASR,
+ * online ASR and PUNC models through the FunASR python download tool, then
+ * starts a decoder thread pool and a network io thread pool and serves
+ * websocket (WS or WSS) connections until the io loops exit.
+ *
+ * @param argc number of command-line arguments
+ * @param argv command-line arguments, see the TCLAP definitions below
+ * @return 0 whenever the end of main is reached; a missing model file
+ *         terminates the process through exit(-1).
+ */
+int main(int argc, char* argv[]) {
+  try {
+    google::InitGoogleLogging(argv[0]);
+    FLAGS_logtostderr = true;
+
+    TCLAP::CmdLine cmd("funasr-wss-server", ' ', "1.0");
+    TCLAP::ValueArg<std::string> download_model_dir(
+        "", "download-model-dir",
+        "Download model from Modelscope to download_model_dir", false,
+        "/workspace/models", "string");
+    TCLAP::ValueArg<std::string> offline_model_dir(
+        "", OFFLINE_MODEL_DIR,
+        "default: /workspace/models/offline_asr, the asr model path, which "
+        "contains model_quant.onnx, config.yaml, am.mvn",
+        false, "/workspace/models/offline_asr", "string");
+    TCLAP::ValueArg<std::string> online_model_dir(
+        "", ONLINE_MODEL_DIR,
+        "default: /workspace/models/online_asr, the asr model path, which "
+        "contains model_quant.onnx, config.yaml, am.mvn",
+        false, "/workspace/models/online_asr", "string");
+
+    TCLAP::ValueArg<std::string> offline_model_revision(
+        "", "offline-model-revision", "ASR offline model revision", false,
+        "v1.2.1", "string");
+
+    TCLAP::ValueArg<std::string> online_model_revision(
+        "", "online-model-revision", "ASR online model revision", false,
+        "v1.0.4", "string");
+
+    TCLAP::ValueArg<std::string> quantize(
+        "", QUANTIZE,
+        "true (Default), load the model of model_quant.onnx in model_dir. If "
+        "set "
+        "false, load the model of model.onnx in model_dir",
+        false, "true", "string");
+    TCLAP::ValueArg<std::string> vad_dir(
+        "", VAD_DIR,
+        "default: /workspace/models/vad, the vad model path, which contains "
+        "model_quant.onnx, vad.yaml, vad.mvn",
+        false, "/workspace/models/vad", "string");
+    TCLAP::ValueArg<std::string> vad_revision(
+        "", "vad-revision", "VAD model revision", false, "v1.2.0", "string");
+    TCLAP::ValueArg<std::string> vad_quant(
+        "", VAD_QUANT,
+        "true (Default), load the model of model_quant.onnx in vad_dir. If set "
+        "false, load the model of model.onnx in vad_dir",
+        false, "true", "string");
+    TCLAP::ValueArg<std::string> punc_dir(
+        "", PUNC_DIR,
+        "default: /workspace/models/punc, the punc model path, which contains "
+        "model_quant.onnx, punc.yaml",
+        false, "/workspace/models/punc", "string");
+    TCLAP::ValueArg<std::string> punc_revision(
+        "", "punc-revision", "PUNC model revision", false, "0.4.7", "string");
+    TCLAP::ValueArg<std::string> punc_quant(
+        "", PUNC_QUANT,
+        "true (Default), load the model of model_quant.onnx in punc_dir. If "
+        "set "
+        "false, load the model of model.onnx in punc_dir",
+        false, "true", "string");
+
+    TCLAP::ValueArg<std::string> listen_ip("", "listen-ip", "listen ip", false,
+                                           "0.0.0.0", "string");
+    TCLAP::ValueArg<int> port("", "port", "port", false, 10095, "int");
+    TCLAP::ValueArg<int> io_thread_num("", "io-thread-num", "io thread num",
+                                       false, 8, "int");
+    TCLAP::ValueArg<int> decoder_thread_num(
+        "", "decoder-thread-num", "decoder thread num", false, 8, "int");
+    TCLAP::ValueArg<int> model_thread_num("", "model-thread-num",
+                                          "model thread num", false, 1, "int");
+
+    TCLAP::ValueArg<std::string> certfile(
+        "", "certfile",
+        "default: ../../../ssl_key/server.crt, path of certficate for WSS "
+        "connection. if it is empty, it will be in WS mode.",
+        false, "../../../ssl_key/server.crt", "string");
+    TCLAP::ValueArg<std::string> keyfile(
+        "", "keyfile",
+        "default: ../../../ssl_key/server.key, path of keyfile for WSS "
+        "connection",
+        false, "../../../ssl_key/server.key", "string");
+
+    cmd.add(certfile);
+    cmd.add(keyfile);
+
+    cmd.add(download_model_dir);
+    cmd.add(offline_model_dir);
+    cmd.add(online_model_dir);
+    cmd.add(offline_model_revision);
+    cmd.add(online_model_revision);
+    cmd.add(quantize);
+    cmd.add(vad_dir);
+    cmd.add(vad_revision);
+    cmd.add(vad_quant);
+    cmd.add(punc_dir);
+    cmd.add(punc_revision);
+    cmd.add(punc_quant);
+
+    cmd.add(listen_ip);
+    cmd.add(port);
+    cmd.add(io_thread_num);
+    cmd.add(decoder_thread_num);
+    cmd.add(model_thread_num);
+    cmd.parse(argc, argv);
+
+    // Collect every model-related option into one key/value map that is
+    // later handed to the ASR engine.
+    std::map<std::string, std::string> model_path;
+    GetValue(offline_model_dir, OFFLINE_MODEL_DIR, model_path);
+    GetValue(online_model_dir, ONLINE_MODEL_DIR, model_path);
+    GetValue(quantize, QUANTIZE, model_path);
+    GetValue(vad_dir, VAD_DIR, model_path);
+    GetValue(vad_quant, VAD_QUANT, model_path);
+    GetValue(punc_dir, PUNC_DIR, model_path);
+    GetValue(punc_quant, PUNC_QUANT, model_path);
+
+    GetValue(offline_model_revision, "offline-model-revision", model_path);
+    GetValue(online_model_revision, "online-model-revision", model_path);
+    GetValue(vad_revision, "vad-revision", model_path);
+    GetValue(punc_revision, "punc-revision", model_path);
+
+    // Download model from Modelscope
+    // NOTE(review): the model paths are spliced unquoted into a shell command
+    // passed to system(); they come from the operator's command line.
+    try {
+      std::string s_download_model_dir = download_model_dir.getValue();
+
+      std::string s_vad_path = model_path[VAD_DIR];
+      std::string s_vad_quant = model_path[VAD_QUANT];
+      std::string s_offline_asr_path = model_path[OFFLINE_MODEL_DIR];
+      std::string s_online_asr_path = model_path[ONLINE_MODEL_DIR];
+      std::string s_asr_quant = model_path[QUANTIZE];
+      std::string s_punc_path = model_path[PUNC_DIR];
+      std::string s_punc_quant = model_path[PUNC_QUANT];
+
+      std::string python_cmd =
+          "python -m funasr.utils.runtime_sdk_download_tool --type onnx --quantize True ";
+
+      // ---- VAD model ----
+      if (vad_dir.isSet() && !s_vad_path.empty()) {
+        std::string python_cmd_vad;
+        std::string down_vad_path;
+        std::string down_vad_model;
+
+        if (access(s_vad_path.c_str(), F_OK) == 0) {
+          // local
+          python_cmd_vad = python_cmd + " --model-name " + s_vad_path +
+                           " --export-dir ./ " + " --model_revision " +
+                           model_path["vad-revision"];
+          down_vad_path = s_vad_path;
+        } else {
+          // modelscope
+          LOG(INFO) << "Download model: " << s_vad_path
+                    << " from modelscope: ";
+          python_cmd_vad = python_cmd + " --model-name " + s_vad_path +
+                           " --export-dir " + s_download_model_dir +
+                           " --model_revision " + model_path["vad-revision"];
+          down_vad_path = s_download_model_dir + "/" + s_vad_path;
+        }
+
+        int ret = system(python_cmd_vad.c_str());
+        if (ret != 0) {
+          LOG(INFO) << "Failed to download model from modelscope. If you set local vad model path, you can ignore the errors.";
+        }
+        down_vad_model = down_vad_path + "/model_quant.onnx";
+        if (s_vad_quant == "false" || s_vad_quant == "False" ||
+            s_vad_quant == "FALSE") {
+          down_vad_model = down_vad_path + "/model.onnx";
+        }
+
+        if (access(down_vad_model.c_str(), F_OK) != 0) {
+          LOG(ERROR) << down_vad_model << " do not exists.";
+          exit(-1);
+        } else {
+          model_path[VAD_DIR] = down_vad_path;
+          LOG(INFO) << "Set " << VAD_DIR << " : " << model_path[VAD_DIR];
+        }
+      } else {
+        LOG(INFO) << "VAD model is not set, use default.";
+      }
+
+      // ---- offline ASR model ----
+      if (offline_model_dir.isSet() && !s_offline_asr_path.empty()) {
+        std::string python_cmd_asr;
+        std::string down_asr_path;
+        std::string down_asr_model;
+
+        if (access(s_offline_asr_path.c_str(), F_OK) == 0) {
+          // local
+          python_cmd_asr = python_cmd + " --model-name " + s_offline_asr_path +
+                           " --export-dir ./ " + " --model_revision " +
+                           model_path["offline-model-revision"];
+          down_asr_path = s_offline_asr_path;
+        } else {
+          // modelscope
+          LOG(INFO) << "Download model: " << s_offline_asr_path
+                    << " from modelscope : ";
+          python_cmd_asr = python_cmd + " --model-name " + s_offline_asr_path +
+                           " --export-dir " + s_download_model_dir +
+                           " --model_revision " +
+                           model_path["offline-model-revision"];
+          down_asr_path = s_download_model_dir + "/" + s_offline_asr_path;
+        }
+
+        int ret = system(python_cmd_asr.c_str());
+        if (ret != 0) {
+          LOG(INFO) << "Failed to download model from modelscope. If you set local asr model path, you can ignore the errors.";
+        }
+        down_asr_model = down_asr_path + "/model_quant.onnx";
+        if (s_asr_quant == "false" || s_asr_quant == "False" ||
+            s_asr_quant == "FALSE") {
+          down_asr_model = down_asr_path + "/model.onnx";
+        }
+
+        if (access(down_asr_model.c_str(), F_OK) != 0) {
+          LOG(ERROR) << down_asr_model << " do not exists.";
+          exit(-1);
+        } else {
+          // Store the resolved directory under the key it was read from, so a
+          // downloaded model is actually picked up. MODEL_DIR is still set as
+          // before for any consumer that reads that key.
+          model_path[OFFLINE_MODEL_DIR] = down_asr_path;
+          model_path[MODEL_DIR] = down_asr_path;
+          LOG(INFO) << "Set " << OFFLINE_MODEL_DIR << " : "
+                    << model_path[OFFLINE_MODEL_DIR];
+        }
+      } else {
+        LOG(INFO) << "ASR Offline model is not set, use default.";
+      }
+
+      // ---- online ASR model ----
+      if (online_model_dir.isSet() && !s_online_asr_path.empty()) {
+        std::string python_cmd_asr;
+        std::string down_asr_path;
+        std::string down_asr_model;
+
+        if (access(s_online_asr_path.c_str(), F_OK) == 0) {
+          // local
+          python_cmd_asr = python_cmd + " --model-name " + s_online_asr_path +
+                           " --export-dir ./ " + " --model_revision " +
+                           model_path["online-model-revision"];
+          down_asr_path = s_online_asr_path;
+        } else {
+          // modelscope
+          LOG(INFO) << "Download model: " << s_online_asr_path
+                    << " from modelscope : ";
+          python_cmd_asr = python_cmd + " --model-name " + s_online_asr_path +
+                           " --export-dir " + s_download_model_dir +
+                           " --model_revision " +
+                           model_path["online-model-revision"];
+          down_asr_path = s_download_model_dir + "/" + s_online_asr_path;
+        }
+
+        int ret = system(python_cmd_asr.c_str());
+        if (ret != 0) {
+          LOG(INFO) << "Failed to download model from modelscope. If you set local asr model path,  you can ignore the errors.";
+        }
+        down_asr_model = down_asr_path + "/model_quant.onnx";
+        if (s_asr_quant == "false" || s_asr_quant == "False" ||
+            s_asr_quant == "FALSE") {
+          down_asr_model = down_asr_path + "/model.onnx";
+        }
+
+        if (access(down_asr_model.c_str(), F_OK) != 0) {
+          LOG(ERROR) << down_asr_model << " do not exists.";
+          exit(-1);
+        } else {
+          // The online model gets its own key; writing MODEL_DIR here would
+          // overwrite the offline model directory stored above.
+          model_path[ONLINE_MODEL_DIR] = down_asr_path;
+          LOG(INFO) << "Set " << ONLINE_MODEL_DIR << " : "
+                    << model_path[ONLINE_MODEL_DIR];
+        }
+      } else {
+        LOG(INFO) << "ASR online model is not set, use default.";
+      }
+
+      // ---- PUNC model ----
+      if (punc_dir.isSet() && !s_punc_path.empty()) {
+        std::string python_cmd_punc;
+        std::string down_punc_path;
+        std::string down_punc_model;
+
+        if (access(s_punc_path.c_str(), F_OK) == 0) {
+          // local
+          python_cmd_punc = python_cmd + " --model-name " + s_punc_path +
+                            " --export-dir ./ " + " --model_revision " +
+                            model_path["punc-revision"];
+          down_punc_path = s_punc_path;
+        } else {
+          // modelscope
+          LOG(INFO) << "Download model: " << s_punc_path
+                    << " from modelscope : ";
+          // The key must match the one stored by GetValue ("punc-revision",
+          // without a trailing space), otherwise the revision is empty.
+          python_cmd_punc = python_cmd + " --model-name " + s_punc_path +
+                            " --export-dir " + s_download_model_dir +
+                            " --model_revision " + model_path["punc-revision"];
+          down_punc_path = s_download_model_dir + "/" + s_punc_path;
+        }
+
+        int ret = system(python_cmd_punc.c_str());
+        if (ret != 0) {
+          LOG(INFO) << "Failed to download model from modelscope. If you set local punc model path, you can ignore the errors.";
+        }
+        down_punc_model = down_punc_path + "/model_quant.onnx";
+        if (s_punc_quant == "false" || s_punc_quant == "False" ||
+            s_punc_quant == "FALSE") {
+          down_punc_model = down_punc_path + "/model.onnx";
+        }
+
+        if (access(down_punc_model.c_str(), F_OK) != 0) {
+          LOG(ERROR) << down_punc_model << " do not exists.";
+          exit(-1);
+        } else {
+          model_path[PUNC_DIR] = down_punc_path;
+          LOG(INFO) << "Set " << PUNC_DIR << " : " << model_path[PUNC_DIR];
+        }
+      } else {
+        LOG(INFO) << "PUNC model is not set, use default.";
+      }
+
+    } catch (std::exception const& e) {
+      LOG(ERROR) << "Error: " << e.what();
+    }
+
+    std::string s_listen_ip = listen_ip.getValue();
+    int s_port = port.getValue();
+    int s_io_thread_num = io_thread_num.getValue();
+    int s_decoder_thread_num = decoder_thread_num.getValue();
+
+    int s_model_thread_num = model_thread_num.getValue();
+
+    asio::io_context io_decoder;  // context for decoding
+    asio::io_context io_server;   // context for server
+
+    std::vector<std::thread> decoder_threads;
+
+    std::string s_certfile = certfile.getValue();
+    std::string s_keyfile = keyfile.getValue();
+
+    // An empty certfile selects plain WS, anything else selects WSS.
+    bool is_ssl = false;
+    if (!s_certfile.empty()) {
+      is_ssl = true;
+    }
+
+    auto conn_guard = asio::make_work_guard(
+        io_decoder);  // make sure threads can wait in the queue
+    auto server_guard = asio::make_work_guard(
+        io_server);  // make sure threads can wait in the queue
+    // create threads pool
+    for (int32_t i = 0; i < s_decoder_thread_num; ++i) {
+      decoder_threads.emplace_back([&io_decoder]() { io_decoder.run(); });
+    }
+
+    server server_;          // endpoint for plain websocket (WS)
+    wss_server wss_server_;  // endpoint for TLS websocket (WSS)
+    if (is_ssl) {
+      wss_server_.init_asio(&io_server);  // init asio
+      wss_server_.set_reuse_addr(
+          true);  // reuse address as we create multiple threads
+
+      // listen on port for accept
+      wss_server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
+    } else {
+      server_.init_asio(&io_server);  // init asio
+      server_.set_reuse_addr(
+          true);  // reuse address as we create multiple threads
+
+      // listen on port for accept
+      server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
+    }
+
+    // websocket server for asr engine. It is declared at this scope (not
+    // inside the if/else above) so that it stays alive while the io loops
+    // below are running; its member functions handle the connection events.
+    WebSocketServer websocket_srv(io_decoder, is_ssl,
+                                  is_ssl ? nullptr : &server_,
+                                  is_ssl ? &wss_server_ : nullptr, s_certfile,
+                                  s_keyfile);
+    websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
+
+    std::cout << "asr model init finished. listen on port:" << s_port
+              << std::endl;
+
+    // Start the ASIO network io_service run loop
+    std::vector<std::thread> ts;
+    // create threads for io network
+    for (int i = 0; i < s_io_thread_num; ++i) {
+      ts.emplace_back([&io_server]() { io_server.run(); });
+    }
+    // wait for the io threads
+    for (auto& t : ts) {
+      t.join();
+    }
+
+    // wait for the decoder threads
+    for (auto& t : decoder_threads) {
+      t.join();
+    }
+
+  } catch (std::exception const& e) {
+    std::cerr << "Error: " << e.what() << std::endl;
+  }
+
+  return 0;
+}

+ 12 - 0
funasr/runtime/websocket/readme.md

@@ -116,6 +116,18 @@ Export Detailed Introduction([docs](https://github.com/alibaba-damo-academy/Fu
   --punc-dir ./export/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
 ```
 
+##### Start the 2pass Service
+```shell
+./funasr-wss-server-2pass  \
+  --download-model-dir /workspace/models \
+  --offline-model-dir ./export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
+  --vad-dir ./export/damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
+  --punc-dir ./export/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx \
+  --online-model-dir ./export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online \
+  --quantize false
+```
+
+
 ### Client Usage
 
 

+ 370 - 0
funasr/runtime/websocket/websocket-server-2pass.cpp

@@ -0,0 +1,370 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License  (https://opensource.org/licenses/MIT)
+ */
+/* 2022-2023 by zhaomingwork */
+
+// websocket server for asr engine
+// take some ideas from https://github.com/k2-fsa/sherpa-onnx
+// online-websocket-server-impl.cc, thanks. The websocket server has two threads
+// pools, one for handle network data and one for asr decoder.
+// The decoder handles the "offline", "online" and "2pass" modes requested by
+// the client (default: "2pass").
+
+#include "websocket-server-2pass.h"
+
+#include <thread>
+#include <utility>
+#include <vector>
+
+// Builds the TLS context used for a WSS connection.
+//
+// `mode` selects the protocol options (MOZILLA_MODERN additionally disables
+// TLSv1); `s_certfile` / `s_keyfile` are the certificate chain and PEM private
+// key files. Failures while configuring the context are logged and the
+// context is returned as-is.
+context_ptr WebSocketServer::on_tls_init(tls_mode mode,
+                                         websocketpp::connection_hdl hdl,
+                                         std::string& s_certfile,
+                                         std::string& s_keyfile) {
+  namespace asio = websocketpp::lib::asio;
+
+  const bool modern = (mode == MOZILLA_MODERN);
+
+  LOG(INFO) << "on_tls_init called with hdl: " << hdl.lock().get();
+  LOG(INFO) << "using TLS mode: "
+            << (modern ? "Mozilla Modern" : "Mozilla Intermediate");
+
+  context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
+      asio::ssl::context::sslv23);
+
+  try {
+    // Options shared by both modes.
+    auto opts = asio::ssl::context::default_workarounds |
+                asio::ssl::context::no_sslv2 | asio::ssl::context::no_sslv3 |
+                asio::ssl::context::single_dh_use;
+    if (modern) {
+      // Modern disables TLSv1
+      opts |= asio::ssl::context::no_tlsv1;
+    }
+    ctx->set_options(opts);
+
+    ctx->use_certificate_chain_file(s_certfile);
+    ctx->use_private_key_file(s_keyfile, asio::ssl::context::pem);
+  } catch (std::exception& e) {
+    LOG(INFO) << "Exception: " << e.what();
+  }
+  return ctx;
+}
+
+// Converts one engine result into the json reply sent to the client.
+//
+// The online text and the 2pass text read from `result` are appended to the
+// per-connection accumulators `online_res` / `tpass_res`. The reply carries
+// the new online text in "text", the accumulated 2pass text in
+// "offline_text", and "wav_name" when `msg` contains that key. `result` is
+// freed before returning.
+nlohmann::json handle_result(FUNASR_RESULT result, std::string& online_res,
+                             std::string& tpass_res, nlohmann::json msg) {
+  const std::string online_piece = FunASRGetResult(result, 0);
+  online_res.append(online_piece);
+  if (!online_res.empty()) {
+    LOG(INFO) << "online_res :" << online_res;
+  }
+
+  const std::string tpass_piece = FunASRGetTpassResult(result, 0);
+  tpass_res.append(tpass_piece);
+  if (!tpass_res.empty()) {
+    LOG(INFO) << "offline results : " << tpass_res;
+  }
+
+  nlohmann::json reply;
+  reply["text"] = online_piece;       // only the newly decoded online text
+  reply["offline_text"] = tpass_res;  // accumulated 2pass text
+  if (msg.contains("wav_name")) {
+    reply["wav_name"] = msg["wav_name"];
+  }
+
+  FunASRFreeResult(result);
+  return reply;
+}
+// feed buffer to asr engine for decoder
+//
+// Splits `buffer` into chunks of 1600*2 bytes, feeds each chunk to
+// FunTpassInferBuffer and sends every result back on `hdl` as json text.
+// When `is_final` is set, the bytes left over after chunking are fed with the
+// engine's final flag and the reply is sent as well.
+//
+//   buffer              audio bytes to decode (consumed by this call)
+//   hdl                 connection the replies are sent to
+//   msg                 per-connection settings ("mode", "wav_name", ...);
+//                       "mode" is set to "2pass" when it is missing
+//   punc_cache          per-connection cache passed through to the engine
+//   thread_lock         per-connection mutex, held for the whole call
+//   is_final            true when the client signalled the end of the audio
+//   tpass_online_handle per-connection online handle
+//   online_res/tpass_res per-connection accumulated result strings
+//
+// Exceptions are caught and printed to stderr.
+void WebSocketServer::do_decoder(
+    std::vector<char>& buffer, websocketpp::connection_hdl& hdl,
+    nlohmann::json& msg, std::vector<std::vector<std::string>>& punc_cache,
+    websocketpp::lib::mutex& thread_lock, bool& is_final,
+    FUNASR_HANDLE& tpass_online_handle, std::string& online_res,
+    std::string& tpass_res) {
+  // lock for each connection
+  scoped_lock guard(thread_lock);
+
+  try {
+    if (!buffer.empty()) {
+      FUNASR_RESULT Result = nullptr;
+
+      // 0: offline, 1: online, 2: 2pass (default, also used for unknown names)
+      int asr_mode_ = 2;
+      if (msg.contains("mode")) {
+        std::string modeltype = msg["mode"];
+        if (modeltype == "offline") {
+          asr_mode_ = 0;
+        } else if (modeltype == "online") {
+          asr_mode_ = 1;
+        } else if (modeltype == "2pass") {
+          asr_mode_ = 2;
+        }
+      } else {
+        // default value
+        msg["mode"] = "2pass";
+        asr_mode_ = 2;
+      }
+
+      // loop to send chunk_size 1600*2 data to asr engine.
+      // TODO: chunk_size need get from client
+      const size_t chunk_bytes = 1600 * 2;
+      while (buffer.size() >= chunk_bytes) {
+        std::vector<char> subvector(buffer.begin(),
+                                    buffer.begin() + chunk_bytes);
+        buffer.erase(buffer.begin(), buffer.begin() + chunk_bytes);
+
+        Result = FunTpassInferBuffer(
+            tpass_handle, tpass_online_handle, subvector.data(),
+            subvector.size(), punc_cache, false, 16000, "pcm",
+            static_cast<ASR_TYPE>(asr_mode_));
+        if (Result) {
+          websocketpp::lib::error_code ec;
+
+          // Pass the whole settings object: handle_result looks up
+          // "wav_name" inside it. Passing msg["wav_name"] dropped the name
+          // from every reply and inserted a null "wav_name" into msg.
+          nlohmann::json jsonresult =
+              handle_result(Result, online_res, tpass_res, msg);
+
+          // NOTE(review): streaming chunks are flagged is_final=true and the
+          // last chunk below is_final=false; kept as-is because clients may
+          // depend on it - confirm the intended protocol.
+          jsonresult["is_final"] = true;
+          jsonresult["mode"] = msg["mode"];
+          // NOTE(review): size() of a json string is 1 even when the string
+          // is empty, so this check does not filter empty texts.
+          if (jsonresult["text"].size() > 0) {
+            if (is_ssl) {
+              wss_server_->send(hdl, jsonresult.dump(),
+                                websocketpp::frame::opcode::text, ec);
+            } else {
+              server_->send(hdl, jsonresult.dump(),
+                            websocketpp::frame::opcode::text, ec);
+            }
+          }
+        }
+      }
+      // if it is in final message
+      if (is_final && buffer.size() > 0) {
+        LOG(INFO) << "is final, the buffer size=" << buffer.size();
+
+        Result = FunTpassInferBuffer(
+            tpass_handle, tpass_online_handle, buffer.data(), buffer.size(),
+            punc_cache, true, 16000, "pcm", static_cast<ASR_TYPE>(asr_mode_));
+
+        if (Result) {
+          websocketpp::lib::error_code ec;
+
+          nlohmann::json jsonresult =
+              handle_result(Result, online_res, tpass_res, msg);
+          jsonresult["is_final"] = false;
+          jsonresult["mode"] = msg["mode"];
+          // every mode except "online" reports the 2pass text as "text"
+          if (asr_mode_ != 1) {
+            jsonresult["text"] = jsonresult["offline_text"];
+          }
+          if (jsonresult["offline_text"].size() > 0) {
+            if (is_ssl) {
+              wss_server_->send(hdl, jsonresult.dump(),
+                                websocketpp::frame::opcode::text, ec);
+            } else {
+              server_->send(hdl, jsonresult.dump(),
+                            websocketpp::frame::opcode::text, ec);
+            }
+          }
+        }
+      }
+    }
+
+  } catch (std::exception const& e) {
+    std::cerr << "Error: " << e.what() << std::endl;
+  }
+}
+
+// Called when a new connection is opened: drops entries of connections that
+// are no longer open and registers fresh per-connection state (sample buffer,
+// lock, settings, punctuation cache and online handle) in data_map.
+void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
+  scoped_lock guard(m_lock);     // for threads safety
+  check_and_clean_connection();  // remove closed connection
+
+  std::shared_ptr<FUNASR_MESSAGE> data_msg =
+      std::make_shared<FUNASR_MESSAGE>();  // put a new data vector for new
+                                           // connection
+  data_msg->samples = std::make_shared<std::vector<char>>();
+  // NOTE(review): this mutex is never deleted in the code visible here.
+  data_msg->thread_lock = new websocketpp::lib::mutex();
+
+  data_msg->msg = nlohmann::json::parse("{}");
+  data_msg->msg["wav_format"] = "pcm";
+  data_msg->punc_cache =
+      std::make_shared<std::vector<std::vector<std::string>>>(2);
+  std::vector<int> chunk_size = {5, 10, 5};  // TODO, need get from client
+  // One online handle per connection. The former second call
+  // FunTpassOnlineInit(tpass_handle) discarded its returned handle and has
+  // been removed.
+  FUNASR_HANDLE tpass_online_handle =
+      FunTpassOnlineInit(tpass_handle, chunk_size);
+  data_msg->tpass_online_handle = tpass_online_handle;
+  data_map.emplace(hdl, data_msg);
+  LOG(INFO) << "on_open, active connections: " << data_map.size();
+}
+
+// Called when a connection is closed: forgets its per-connection state and
+// logs how many connections remain registered.
+void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
+  scoped_lock guard(m_lock);
+
+  // remove data vector when connection is closed
+  data_map.erase(hdl);
+
+  const auto remaining = data_map.size();
+  LOG(INFO) << "on_close, active connections: " << remaining;
+}
+
+// remove closed connection
+//
+// Walks data_map and erases every entry whose connection is not in the open
+// state. A handle that can no longer be resolved to a connection is treated
+// as closed too; the error_code overload of get_con_from_hdl is used so that
+// such a handle does not raise an exception out of this function.
+// The caller (on_open) holds m_lock while this runs.
+void WebSocketServer::check_and_clean_connection() {
+  std::vector<websocketpp::connection_hdl> to_remove;  // remove list
+
+  for (const auto& entry : data_map) {  // loop to find closed connection
+    const websocketpp::connection_hdl& hdl = entry.first;
+    websocketpp::lib::error_code ec;
+    bool is_open = false;
+
+    if (is_ssl) {
+      wss_server::connection_ptr con = wss_server_->get_con_from_hdl(hdl, ec);
+      is_open = !ec && con &&
+                con->get_state() == websocketpp::session::state::open;
+    } else {
+      server::connection_ptr con = server_->get_con_from_hdl(hdl, ec);
+      is_open = !ec && con &&
+                con->get_state() == websocketpp::session::state::open;
+    }
+
+    if (!is_open) {
+      to_remove.push_back(hdl);
+    }
+  }
+
+  // erase after the walk so the map is not modified while iterating
+  for (const auto& hdl : to_remove) {
+    data_map.erase(hdl);
+    LOG(INFO) << "remove one connection ";
+  }
+}
+// Handles one websocket message of a connection.
+//
+// Text messages are json control messages: "wav_name", "mode" and
+// "wav_format" are copied into the per-connection settings, and
+// "is_speaking": false or "is_finished": true posts the buffered audio to the
+// decoder pool as the final chunk. Binary messages are audio bytes: they are
+// appended to the per-connection buffer and, when `isonline` is set, whole
+// multiples of 1600*2 bytes are posted to the decoder pool right away.
+//
+// Messages for unknown connections and malformed json are logged and ignored.
+void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
+                                 message_ptr msg) {
+  unique_lock lock(m_lock);
+  // find the sample data vector according to one connection
+
+  std::shared_ptr<FUNASR_MESSAGE> msg_data = nullptr;
+
+  auto it_data = data_map.find(hdl);
+  if (it_data != data_map.end()) {
+    msg_data = it_data->second;
+  }
+
+  // Unknown connection: bail out before touching msg_data (the lock is
+  // released by its destructor).
+  if (msg_data == nullptr) {
+    LOG(INFO) << "error when fetch sample data vector";
+    return;
+  }
+
+  std::shared_ptr<std::vector<char>> sample_data_p = msg_data->samples;
+  std::shared_ptr<std::vector<std::vector<std::string>>> punc_cache_p =
+      msg_data->punc_cache;
+  websocketpp::lib::mutex* thread_lock_p = msg_data->thread_lock;
+
+  lock.unlock();
+
+  if (sample_data_p == nullptr) {
+    LOG(INFO) << "error when fetch sample data vector";
+    return;
+  }
+
+  const std::string& payload = msg->get_payload();  // get msg type
+
+  switch (msg->get_opcode()) {
+    case websocketpp::frame::opcode::text: {
+      // The payload comes from the client: a parse error must not escape
+      // from this handler.
+      nlohmann::json jsonresult;
+      try {
+        jsonresult = nlohmann::json::parse(payload);
+      } catch (const std::exception& e) {
+        LOG(ERROR) << "invalid json message: " << e.what();
+        break;
+      }
+      if (!jsonresult.is_object()) {
+        LOG(ERROR) << "json message is not an object";
+        break;
+      }
+
+      if (jsonresult.contains("wav_name")) {
+        msg_data->msg["wav_name"] = jsonresult["wav_name"];
+      }
+      if (jsonresult.contains("mode")) {
+        msg_data->msg["mode"] = jsonresult["mode"];
+      }
+      if (jsonresult.contains("wav_format")) {
+        msg_data->msg["wav_format"] = jsonresult["wav_format"];
+      }
+      LOG(INFO) << "jsonresult=" << jsonresult << "msg_data->msg"
+                << msg_data->msg;
+
+      const bool stopped_speaking = jsonresult.contains("is_speaking") &&
+                                    jsonresult["is_speaking"] == false;
+      const bool finished = jsonresult.contains("is_finished") &&
+                            jsonresult["is_finished"] == true;
+      if (stopped_speaking || finished) {
+        LOG(INFO) << "client done";
+
+        // if it is in final message, post the sample_data to decode
+        // NOTE(review): the task refers to members of *msg_data through
+        // std::ref without sharing ownership of msg_data - confirm they
+        // cannot be destroyed while the task is pending.
+        asio::post(
+            io_decoder_,
+            std::bind(&WebSocketServer::do_decoder, this,
+                      std::move(*sample_data_p), std::move(hdl),
+                      std::ref(msg_data->msg), std::ref(*punc_cache_p),
+                      std::ref(*thread_lock_p), true,
+                      std::ref(msg_data->tpass_online_handle),
+                      std::ref(msg_data->online_res),
+                      std::ref(msg_data->tpass_res)));
+      }
+      break;
+    }
+    case websocketpp::frame::opcode::binary: {
+      // received binary data
+      const auto* pcm_data = static_cast<const char*>(payload.data());
+      const size_t num_bytes = payload.size();
+
+      // put received data at the end of the sample data vector
+      sample_data_p->insert(sample_data_p->end(), pcm_data,
+                            pcm_data + num_bytes);
+
+      if (isonline) {
+        // need to split data to required chunksize(1600*2)
+        const size_t setpsize = 1600 * 2;  // TODO, need get from client
+        // if sample_data size > setpsize, we post data to decode
+        if (sample_data_p->size() > setpsize) {
+          // make sure the subvector size is an integer multiple of setpsize
+          const size_t post_bytes =
+              (sample_data_p->size() / setpsize) * setpsize;
+          std::vector<char> subvector(sample_data_p->begin(),
+                                      sample_data_p->begin() + post_bytes);
+          // keep remain in sample_data
+          sample_data_p->erase(sample_data_p->begin(),
+                               sample_data_p->begin() + post_bytes);
+          // post to decode
+          asio::post(io_decoder_,
+                     std::bind(&WebSocketServer::do_decoder, this,
+                               std::move(subvector), std::move(hdl),
+                               std::ref(msg_data->msg),
+                               std::ref(*punc_cache_p),
+                               std::ref(*thread_lock_p), false,
+                               std::ref(msg_data->tpass_online_handle),
+                               std::ref(msg_data->online_res),
+                               std::ref(msg_data->tpass_res)));
+        }
+      }
+      // for offline, the data just stays buffered until the final message
+
+      break;
+    }
+    default:
+      break;
+  }
+}
+
+// init asr model
+//
+// Creates the offline handle and the 2pass handle from `model_path` using
+// `thread_num` threads. The process exits with -1 when the 2pass handle could
+// not be created; exceptions are logged and swallowed.
+void WebSocketServer::initAsr(std::map<std::string, std::string>& model_path,
+                              int thread_num) {
+  try {
+    // init model with api
+    asr_handle = FunOfflineInit(model_path, thread_num);
+    LOG(INFO) << "model successfully inited";
+
+    tpass_handle = FunTpassInit(model_path, thread_num);
+    if (tpass_handle) {
+      return;
+    }
+
+    LOG(ERROR) << "FunTpassInit init failed";
+    exit(-1);
+  } catch (const std::exception& e) {
+    LOG(INFO) << e.what();
+  }
+}

+ 148 - 0
funasr/runtime/websocket/websocket-server-2pass.h

@@ -0,0 +1,148 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License  (https://opensource.org/licenses/MIT)
+ */
+/* 2022-2023 by zhaomingwork */
+
+// WebSocket server for the ASR engine.
+// Borrows some ideas from https://github.com/k2-fsa/sherpa-onnx
+// online-websocket-server-impl.cc, thanks. The websocket server has two thread
+// pools: one handles network data and the other runs the ASR decoder.
+// NOTE(review): the original note said "now only support offline engine", but
+// this header also declares online/2pass state -- confirm whether it is stale.
+
+#ifndef WEBSOCKET_SERVER_H_
+#define WEBSOCKET_SERVER_H_
+
+#include <iostream>
+#include <map>
+#include <memory>
+#include <string>
+#include <thread>
+#include <utility>
+#define ASIO_STANDALONE 1  // not boost
+#include <glog/logging.h>
+
+#include <fstream>
+#include <functional>
+#include <websocketpp/common/thread.hpp>
+#include <websocketpp/config/asio.hpp>
+#include <websocketpp/server.hpp>
+
+#include "asio.hpp"
+#include "com-define.h"
+#include "funasrruntime.h"
+#include "nlohmann/json.hpp"
+#include "tclap/CmdLine.h"
+// Endpoint types: the non-TLS (config::asio) and TLS (config::asio_tls)
+// websocketpp servers, plus the message pointer handed to message handlers.
+using server = websocketpp::server<websocketpp::config::asio>;
+using wss_server = websocketpp::server<websocketpp::config::asio_tls>;
+using message_ptr = server::message_ptr;
+
+// Locking helpers and the TLS context pointer built on websocketpp::lib.
+using scoped_lock = websocketpp::lib::lock_guard<websocketpp::lib::mutex>;
+using unique_lock = websocketpp::lib::unique_lock<websocketpp::lib::mutex>;
+using context_ptr =
+    websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context>;
+
+// bind and its placeholders, made visible at global scope.
+using websocketpp::lib::bind;
+using websocketpp::lib::placeholders::_1;
+using websocketpp::lib::placeholders::_2;
+
+// Plain data holder pairing a text message with a snippet time value.
+// NOTE(review): the unit of snippet_time is not visible here -- confirm.
+struct FUNASR_RECOG_RESULT {
+  std::string msg;     // message text
+  float snippet_time;  // not initialized by default; set before reading
+};
+
+// Per-connection state record; WebSocketServer keeps one per connection
+// handle in its data_map.
+//
+// Declared as a named struct instead of `typedef struct { ... } NAME;`
+// because the members use default member initializers, which C++20 does not
+// allow in an unnamed class that receives its name from a typedef. The type
+// name seen by callers is unchanged.
+struct FUNASR_MESSAGE {
+  // JSON state for the connection; handed to do_decoder by reference.
+  nlohmann::json msg;
+  // Shared byte buffer; presumably the vector on_message appends received
+  // audio to -- confirm against the .cpp.
+  std::shared_ptr<std::vector<char>> samples;
+  // Shared cache handed to do_decoder by reference; presumably punctuation
+  // context -- confirm.
+  std::shared_ptr<std::vector<std::vector<std::string>>> punc_cache;
+  // Lock for each connection. Raw pointer: who allocates and frees it is not
+  // visible in this header. Null until assigned.
+  websocketpp::lib::mutex* thread_lock = nullptr;
+  // Per-connection engine handle handed to do_decoder. Null until assigned.
+  FUNASR_HANDLE tpass_online_handle = nullptr;
+  // Result strings handed to do_decoder by reference; start empty.
+  std::string online_res = "";
+  std::string tpass_res = "";
+};
+
+// TLS configuration profiles selectable in on_tls_init. See
+// https://wiki.mozilla.org/Security/Server_Side_TLS for more details about
+// the TLS modes.
+enum tls_mode {
+  MOZILLA_INTERMEDIATE = 1,
+  MOZILLA_MODERN = 2
+};
+// WebSocket front end for the ASR engine.
+//
+// Registers the message/open/close callbacks on either the non-TLS or the TLS
+// websocketpp endpoint (selected by is_ssl) and keeps per-connection state in
+// data_map. Decoder work is posted to the io_context passed as io_decoder.
+class WebSocketServer {
+ public:
+  // io_decoder: io_context used for decoder tasks; stored by reference, so it
+  //             must outlive this object.
+  // is_ssl:     true to configure wss_server, false to configure server.
+  // server:     non-TLS endpoint; dereferenced only when is_ssl is false.
+  // wss_server: TLS endpoint; dereferenced only when is_ssl is true.
+  // s_certfile, s_keyfile: paths handed to on_tls_init when is_ssl is true.
+  //
+  // The registered callbacks capture `this`, so the object must stay alive
+  // for as long as the configured endpoint can invoke them.
+  WebSocketServer(asio::io_context& io_decoder, bool is_ssl, server* server,
+                  wss_server* wss_server, std::string& s_certfile,
+                  std::string& s_keyfile)
+      : io_decoder_(io_decoder),
+        is_ssl(is_ssl),
+        server_(server),
+        wss_server_(wss_server) {
+    if (is_ssl) {
+      std::cout << "certfile path is " << s_certfile << std::endl;
+      // bind stores its own copies of the two path strings, so the handler
+      // does not depend on the lifetime of the constructor arguments.
+      wss_server_->set_tls_init_handler(
+          bind<context_ptr>(&WebSocketServer::on_tls_init, this,
+                            MOZILLA_INTERMEDIATE, ::_1, s_certfile, s_keyfile));
+      setup_endpoint(wss_server_);
+    } else {
+      setup_endpoint(server_);
+    }
+  }
+
+  // Decoder task body (defined in the .cpp). on_message posts it to
+  // io_decoder_ with a chunk of buffered bytes plus references to the
+  // connection's FUNASR_MESSAGE fields and its lock.
+  // NOTE(review): is_final presumably marks the last chunk of a stream --
+  // confirm against the definition.
+  void do_decoder(std::vector<char>& buffer, websocketpp::connection_hdl& hdl,
+                  nlohmann::json& msg,
+                  std::vector<std::vector<std::string>>& punc_cache,
+                  websocketpp::lib::mutex& thread_lock, bool& is_final,
+                  FUNASR_HANDLE& tpass_online_handle, std::string& online_res,
+                  std::string& tpass_res);
+
+  // Creates asr_handle and tpass_handle from the model configuration.
+  void initAsr(std::map<std::string, std::string>& model_path, int thread_num);
+  // Endpoint callbacks registered by the constructor.
+  void on_message(websocketpp::connection_hdl hdl, message_ptr msg);
+  void on_open(websocketpp::connection_hdl hdl);
+  void on_close(websocketpp::connection_hdl hdl);
+  // TLS init callback; the constructor binds it with MOZILLA_INTERMEDIATE and
+  // the certificate/key paths.
+  context_ptr on_tls_init(tls_mode mode, websocketpp::connection_hdl hdl,
+                          std::string& s_certfile, std::string& s_keyfile);
+
+ private:
+  // Registers the message/open/close callbacks on `endpoint`, starts
+  // accepting connections and then disables its access-log channels. Shared
+  // by the TLS and non-TLS paths so both are configured identically.
+  template <typename Endpoint>
+  void setup_endpoint(Endpoint* endpoint) {
+    // set message handle
+    endpoint->set_message_handler(
+        [this](websocketpp::connection_hdl hdl, message_ptr msg) {
+          on_message(hdl, msg);
+        });
+    // set open handle
+    endpoint->set_open_handler(
+        [this](websocketpp::connection_hdl hdl) { on_open(hdl); });
+    // set close handle
+    endpoint->set_close_handler(
+        [this](websocketpp::connection_hdl hdl) { on_close(hdl); });
+    // begin accept
+    endpoint->start_accept();
+    // not print log
+    endpoint->clear_access_channels(websocketpp::log::alevel::all);
+  }
+
+  void check_and_clean_connection();
+  asio::io_context& io_decoder_;  // threads for asr decoder
+  // std::ofstream fout;
+  // Engine handles assigned by initAsr. They start as null so they are never
+  // read in an indeterminate state if initialization is skipped or throws.
+  FUNASR_HANDLE asr_handle = nullptr;  // asr engine handle
+  FUNASR_HANDLE tpass_handle = nullptr;
+  // online or offline engine.
+  // NOTE(review): the original comment said "now only support offline";
+  // confirm whether that still holds for the 2pass server.
+  bool isonline = true;
+  bool is_ssl = true;
+  server* server_;          // non-TLS websocket server
+  wss_server* wss_server_;  // TLS websocket server
+
+  // Per-connection state, keyed by connection handle.
+  std::map<websocketpp::connection_hdl, std::shared_ptr<FUNASR_MESSAGE>,
+           std::owner_less<websocketpp::connection_hdl>>
+      data_map;
+  websocketpp::lib::mutex m_lock;  // mutex for data_map
+};
+
+#endif  // WEBSOCKET_SERVER_H_