Răsfoiți Sursa

add ssl support for cpp websocket (#553)

zhaomingwork 2 ani în urmă
părinte
comite
21c590ad67

+ 5 - 3
funasr/runtime/websocket/CMakeLists.txt

@@ -55,10 +55,12 @@ add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/src src)
 include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog)
 set(BUILD_TESTING OFF)
 add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
- 
+
+# install openssl first apt-get install libssl-dev
+find_package(OpenSSL REQUIRED)
 
 add_executable(websocketmain "websocketmain.cpp" "websocketsrv.cpp")
 add_executable(websocketclient "websocketclient.cpp")
 
-target_link_libraries(websocketclient PUBLIC funasr)
-target_link_libraries(websocketmain PUBLIC funasr)
+target_link_libraries(websocketclient PUBLIC funasr ssl crypto)
+target_link_libraries(websocketmain PUBLIC funasr ssl crypto)

+ 17 - 5
funasr/runtime/websocket/readme.md

@@ -33,7 +33,12 @@ sudo apt-get install libopenblas-dev #ubuntu
 ```
 
 ### Build runtime
+required openssl lib
+
 ```shell
+#install openssl lib first
+apt-get install libssl-dev
+
 git clone https://github.com/alibaba-damo-academy/FunASR.git && cd funasr/runtime/websocket
 mkdir build && cd build
 cmake  -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0
@@ -43,11 +48,12 @@ make
 
 ```shell
 cd bin
-./websocketmain  [--model_thread_num <int>] [--decoder_thread_num <int>]
+   ./websocketmain  [--model_thread_num <int>] [--decoder_thread_num <int>]
                     [--io_thread_num <int>] [--port <int>] [--listen_ip
                     <string>] [--punc-quant <string>] [--punc-dir <string>]
                     [--vad-quant <string>] [--vad-dir <string>] [--quantize
-                    <string>] --model-dir <string> [--] [--version] [-h]
+                    <string>] --model-dir <string> [--keyfile <string>]
+                    [--certfile <string>] [--] [--version] [-h]
 Where:
    --model-dir <string>
      (required)  the asr model path, which contains model.onnx, config.yaml, am.mvn
@@ -70,6 +76,10 @@ Where:
      number of threads for network io, default:8
    --port <int>
      listen port, default:8889
+   --certfile <string>
+     path of certficate for WSS connection. if it is empty, it will be in WS mode.
+   --keyfile <string>
+     path of keyfile for WSS connection
   
    Required:  --model-dir <string>
    If use vad, please add: --vad-dir <string>
@@ -81,14 +91,16 @@ example:
 ## Run websocket client test
 
 ```shell
-Usage: websocketclient server_ip port wav_path threads_num
+Usage: ./websocketclient server_ip port wav_path threads_num is_ssl
+
+is_ssl is 1 means use wss connection, or use ws connection
 
 example:
 
-websocketclient 127.0.0.1 8889 funasr/runtime/websocket/test.pcm.wav 64
+websocketclient 127.0.0.1 8889 funasr/runtime/websocket/test.pcm.wav 64 0
 
 result json, example like:
-{"text":"一二三四五六七八九十一二三四五六七八九十"}
+{"mode":"offline","text":"欢迎大家来体验达摩院推出的语音识别模型","wav_name":"wav2"}
 ```
 
 

+ 54 - 14
funasr/runtime/websocket/websocketclient.cpp

@@ -10,7 +10,7 @@
 #define ASIO_STANDALONE 1
 #include <websocketpp/client.hpp>
 #include <websocketpp/common/thread.hpp>
-#include <websocketpp/config/asio_no_tls_client.hpp>
+#include <websocketpp/config/asio_client.hpp>
 
 #include "audio.h"
 #include "nlohmann/json.hpp"
@@ -26,14 +26,37 @@ void wait_a_bit() {
 #endif
 }
 typedef websocketpp::config::asio_client::message_type::ptr message_ptr;
-
+typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context>
+    context_ptr;
+using websocketpp::lib::bind;
+using websocketpp::lib::placeholders::_1;
+using websocketpp::lib::placeholders::_2;
+context_ptr on_tls_init(websocketpp::connection_hdl) {
+  context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
+      asio::ssl::context::sslv23);
+
+  try {
+    ctx->set_options(
+        asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 |
+        asio::ssl::context::no_sslv3 | asio::ssl::context::single_dh_use);
+
+  } catch (std::exception& e) {
+    std::cout << e.what() << std::endl;
+  }
+  return ctx;
+}
+// template for tls or not config
+template <typename T>
 class websocket_client {
  public:
-  typedef websocketpp::client<websocketpp::config::asio_client> client;
+  // typedef websocketpp::client<T> client;
+  // typedef websocketpp::client<websocketpp::config::asio_tls_client>
+  // wss_client;
   typedef websocketpp::lib::lock_guard<websocketpp::lib::mutex> scoped_lock;
 
-  websocket_client() : m_open(false), m_done(false) {
+  websocket_client(int is_ssl) : m_open(false), m_done(false) {
     // set up access channels to only log interesting things
+
     m_client.clear_access_channels(websocketpp::log::alevel::all);
     m_client.set_access_channels(websocketpp::log::alevel::connect);
     m_client.set_access_channels(websocketpp::log::alevel::disconnect);
@@ -65,10 +88,12 @@ class websocket_client {
     }
   }
   // This method will block until the connection is complete
+
   void run(const std::string& uri, const std::string& wav_path) {
     // Create a new connection to the given URI
     websocketpp::lib::error_code ec;
-    client::connection_ptr con = m_client.get_connection(uri, ec);
+    typename websocketpp::client<T>::connection_ptr con =
+        m_client.get_connection(uri, ec);
     if (ec) {
       m_client.get_alog().write(websocketpp::log::alevel::app,
                                 "Get Connection Error: " + ec.message());
@@ -84,7 +109,8 @@ class websocket_client {
     m_client.connect(con);
 
     // Create a thread to run the ASIO io_service event loop
-    websocketpp::lib::thread asio_thread(&client::run, &m_client);
+    websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
+                                         &m_client);
 
     send_wav_data();
     asio_thread.join();
@@ -201,9 +227,9 @@ class websocket_client {
                   ec);
     wait_a_bit();
   }
+  websocketpp::client<T> m_client;
 
  private:
-  client m_client;
   websocketpp::connection_hdl m_hdl;
   websocketpp::lib::mutex m_lock;
   std::string wav_path;
@@ -212,22 +238,36 @@ class websocket_client {
 };
 
 int main(int argc, char* argv[]) {
-  if (argc < 5) {
-    printf("Usage: %s server_ip port wav_path threads_num\n", argv[0]);
+  if (argc < 6) {
+    printf("Usage: %s server_ip port wav_path threads_num is_ssl\n", argv[0]);
     exit(-1);
   }
   std::string server_ip = argv[1];
   std::string port = argv[2];
   std::string wav_path = argv[3];
   int threads_num = atoi(argv[4]);
+  int is_ssl = atoi(argv[5]);
   std::vector<websocketpp::lib::thread> client_threads;
-
-  std::string uri = "ws://" + server_ip + ":" + port;
+  std::string uri = "";
+  if (is_ssl == 1) {
+    uri = "wss://" + server_ip + ":" + port;
+  } else {
+    uri = "ws://" + server_ip + ":" + port;
+  }
 
   for (size_t i = 0; i < threads_num; i++) {
-    client_threads.emplace_back([uri, wav_path]() {
-      websocket_client c;
-      c.run(uri, wav_path);
+    client_threads.emplace_back([uri, wav_path, is_ssl]() {
+      if (is_ssl == 1) {
+        websocket_client<websocketpp::config::asio_tls_client> c(is_ssl);
+
+        c.m_client.set_tls_init_handler(bind(&on_tls_init, ::_1));
+
+        c.run(uri, wav_path);
+      } else {
+        websocket_client<websocketpp::config::asio_client> c(is_ssl);
+
+        c.run(uri, wav_path);
+      }
     });
   }
 

+ 53 - 12
funasr/runtime/websocket/websocketmain.cpp

@@ -64,6 +64,14 @@ int main(int argc, char* argv[]) {
     TCLAP::ValueArg<int> model_thread_num("", "model_thread_num",
                                           "model_thread_num", false, 1, "int");
 
+    TCLAP::ValueArg<std::string> certfile("", "certfile", "certfile", false, "",
+                                          "string");
+    TCLAP::ValueArg<std::string> keyfile("", "keyfile", "keyfile", false, "",
+                                         "string");
+
+    cmd.add(certfile);
+    cmd.add(keyfile);
+
     cmd.add(model_dir);
     cmd.add(quantize);
     cmd.add(vad_dir);
@@ -97,6 +105,14 @@ int main(int argc, char* argv[]) {
 
     std::vector<std::thread> decoder_threads;
 
+    std::string s_certfile = certfile.getValue();
+    std::string s_keyfile = keyfile.getValue();
+
+    bool is_ssl = false;
+    if (!s_certfile.empty()) {
+      is_ssl = true;
+    }
+
     auto conn_guard = asio::make_work_guard(
         io_decoder);  // make sure threads can wait in the queue
 
@@ -105,30 +121,55 @@ int main(int argc, char* argv[]) {
       decoder_threads.emplace_back([&io_decoder]() { io_decoder.run(); });
     }
 
-    server server_;       // server for websocket
-    server_.init_asio();  // init asio
-    server_.set_reuse_addr(
-        true);  // reuse address as we create multiple threads
+    server server_;  // server for websocket
+    wss_server wss_server_;
+    if (is_ssl) {
+      wss_server_.init_asio();  // init asio
+      wss_server_.set_reuse_addr(
+          true);  // reuse address as we create multiple threads
 
-    // list on port for accept
-    server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
+      // list on port for accept
+      wss_server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
+      WebSocketServer websocket_srv(
+          io_decoder, is_ssl, nullptr, &wss_server_, s_certfile,
+          s_keyfile);  // websocket server for asr engine
+      websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
+
+    } else {
+      server_.init_asio();  // init asio
+      server_.set_reuse_addr(
+          true);  // reuse address as we create multiple threads
+
+      // list on port for accept
+      server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
+      WebSocketServer websocket_srv(
+          io_decoder, is_ssl, &server_, nullptr, s_certfile,
+          s_keyfile);  // websocket server for asr engine
+      websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
+    }
 
-    WebSocketServer websocket_srv(io_decoder,
-                                  &server_);  // websocket server for asr engine
-    websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
     std::cout << "asr model init finished. listen on port:" << s_port
               << std::endl;
 
     // Start the ASIO network io_service run loop
     if (s_io_thread_num == 1) {
-      server_.run();
+      if (is_ssl) {
+        wss_server_.run();
+      } else {
+        server_.run();
+      }
     } else {
       typedef websocketpp::lib::shared_ptr<websocketpp::lib::thread> thread_ptr;
       std::vector<thread_ptr> ts;
       // create threads for io network
       for (size_t i = 0; i < s_io_thread_num; i++) {
-        ts.push_back(websocketpp::lib::make_shared<websocketpp::lib::thread>(
-            &server::run, &server_));
+        if (is_ssl) {
+          ts.push_back(websocketpp::lib::make_shared<websocketpp::lib::thread>(
+              &wss_server::run, &wss_server_));
+        } else {
+          ts.push_back(websocketpp::lib::make_shared<websocketpp::lib::thread>(
+              &server::run, &server_));
+        }
       }
       // wait for theads
       for (size_t i = 0; i < s_io_thread_num; i++) {

+ 57 - 5
funasr/runtime/websocket/websocketsrv.cpp

@@ -16,6 +16,44 @@
 #include <utility>
 #include <vector>
 
+context_ptr WebSocketServer::on_tls_init(tls_mode mode,
+                                         websocketpp::connection_hdl hdl,
+                                         std::string& s_certfile,
+                                         std::string& s_keyfile) {
+  namespace asio = websocketpp::lib::asio;
+
+  std::cout << "on_tls_init called with hdl: " << hdl.lock().get() << std::endl;
+  std::cout << "using TLS mode: "
+            << (mode == MOZILLA_MODERN ? "Mozilla Modern"
+                                       : "Mozilla Intermediate")
+            << std::endl;
+
+  context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
+      asio::ssl::context::sslv23);
+
+  try {
+    if (mode == MOZILLA_MODERN) {
+      // Modern disables TLSv1
+      ctx->set_options(
+          asio::ssl::context::default_workarounds |
+          asio::ssl::context::no_sslv2 | asio::ssl::context::no_sslv3 |
+          asio::ssl::context::no_tlsv1 | asio::ssl::context::single_dh_use);
+    } else {
+      ctx->set_options(asio::ssl::context::default_workarounds |
+                       asio::ssl::context::no_sslv2 |
+                       asio::ssl::context::no_sslv3 |
+                       asio::ssl::context::single_dh_use);
+    }
+
+    ctx->use_certificate_chain_file(s_certfile);
+    ctx->use_private_key_file(s_keyfile, asio::ssl::context::pem);
+
+  } catch (std::exception& e) {
+    std::cout << "Exception: " << e.what() << std::endl;
+  }
+  return ctx;
+}
+
 // feed buffer to asr engine for decoder
 void WebSocketServer::do_decoder(const std::vector<char>& buffer,
                                  websocketpp::connection_hdl& hdl,
@@ -40,8 +78,13 @@ void WebSocketServer::do_decoder(const std::vector<char>& buffer,
       jsonresult["wav_name"] = msg["wav_name"];
 
       // send the json to client
-      server_->send(hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
-                    ec);
+      if (is_ssl) {
+        wss_server_->send(hdl, jsonresult.dump(),
+                          websocketpp::frame::opcode::text, ec);
+      } else {
+        server_->send(hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
+                      ec);
+      }
 
       std::cout << "buffer.size=" << buffer.size()
                 << ",result json=" << jsonresult.dump() << std::endl;
@@ -83,10 +126,19 @@ void WebSocketServer::check_and_clean_connection() {
   auto iter = data_map.begin();
   while (iter != data_map.end()) {  // loop to find closed connection
     websocketpp::connection_hdl hdl = iter->first;
-    server::connection_ptr con = server_->get_con_from_hdl(hdl);
-    if (con->get_state() != 1) {  // session::state::open ==1
-      to_remove.push_back(hdl);
+
+    if (is_ssl) {
+      wss_server::connection_ptr con = wss_server_->get_con_from_hdl(hdl);
+      if (con->get_state() != 1) {  // session::state::open ==1
+        to_remove.push_back(hdl);
+      }
+    } else {
+      server::connection_ptr con = server_->get_con_from_hdl(hdl);
+      if (con->get_state() != 1) {  // session::state::open ==1
+        to_remove.push_back(hdl);
+      }
     }
+
     iter++;
   }
   for (auto hdl : to_remove) {

+ 57 - 19
funasr/runtime/websocket/websocketsrv.h

@@ -25,7 +25,7 @@
 #include <fstream>
 #include <functional>
 #include <websocketpp/common/thread.hpp>
-#include <websocketpp/config/asio_no_tls.hpp>
+#include <websocketpp/config/asio.hpp>
 #include <websocketpp/server.hpp>
 
 #include "asio.hpp"
@@ -34,12 +34,16 @@
 #include "nlohmann/json.hpp"
 #include "tclap/CmdLine.h"
 typedef websocketpp::server<websocketpp::config::asio> server;
+typedef websocketpp::server<websocketpp::config::asio_tls> wss_server;
 typedef server::message_ptr message_ptr;
 using websocketpp::lib::bind;
 using websocketpp::lib::placeholders::_1;
 using websocketpp::lib::placeholders::_2;
+
 typedef websocketpp::lib::lock_guard<websocketpp::lib::mutex> scoped_lock;
 typedef websocketpp::lib::unique_lock<websocketpp::lib::mutex> unique_lock;
+typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context>
+    context_ptr;
 
 typedef struct {
   std::string msg;
@@ -51,25 +55,55 @@ typedef struct {
   std::shared_ptr<std::vector<char>> samples;
 } FUNASR_MESSAGE;
 
+// See https://wiki.mozilla.org/Security/Server_Side_TLS for more details about
+// the TLS modes. The code below demonstrates how to implement both the modern
+enum tls_mode { MOZILLA_INTERMEDIATE = 1, MOZILLA_MODERN = 2 };
 class WebSocketServer {
  public:
-  WebSocketServer(asio::io_context& io_decoder, server* server_)
-      : io_decoder_(io_decoder), server_(server_) {
-    // set message handle
-    server_->set_message_handler(
-        [this](websocketpp::connection_hdl hdl, message_ptr msg) {
-          on_message(hdl, msg);
-        });
-    // set open handle
-    server_->set_open_handler(
-        [this](websocketpp::connection_hdl hdl) { on_open(hdl); });
-    // set close handle
-    server_->set_close_handler(
-        [this](websocketpp::connection_hdl hdl) { on_close(hdl); });
-    // begin accept
-    server_->start_accept();
-    // not print log
-    server_->clear_access_channels(websocketpp::log::alevel::all);
+  WebSocketServer(asio::io_context& io_decoder, bool is_ssl, server* server,
+                  wss_server* wss_server, std::string& s_certfile,
+                  std::string& s_keyfile)
+      : io_decoder_(io_decoder),
+        is_ssl(is_ssl),
+        server_(server),
+        wss_server_(wss_server) {
+    if (is_ssl) {
+      std::cout << "certfile path is " << s_certfile << std::endl;
+      wss_server->set_tls_init_handler(
+          bind<context_ptr>(&WebSocketServer::on_tls_init, this,
+                            MOZILLA_INTERMEDIATE, ::_1, s_certfile, s_keyfile));
+      wss_server_->set_message_handler(
+          [this](websocketpp::connection_hdl hdl, message_ptr msg) {
+            on_message(hdl, msg);
+          });
+      // set open handle
+      wss_server_->set_open_handler(
+          [this](websocketpp::connection_hdl hdl) { on_open(hdl); });
+      // set close handle
+      wss_server_->set_close_handler(
+          [this](websocketpp::connection_hdl hdl) { on_close(hdl); });
+      // begin accept
+      wss_server_->start_accept();
+      // not print log
+      wss_server_->clear_access_channels(websocketpp::log::alevel::all);
+
+    } else {
+      // set message handle
+      server_->set_message_handler(
+          [this](websocketpp::connection_hdl hdl, message_ptr msg) {
+            on_message(hdl, msg);
+          });
+      // set open handle
+      server_->set_open_handler(
+          [this](websocketpp::connection_hdl hdl) { on_open(hdl); });
+      // set close handle
+      server_->set_close_handler(
+          [this](websocketpp::connection_hdl hdl) { on_close(hdl); });
+      // begin accept
+      server_->start_accept();
+      // not print log
+      server_->clear_access_channels(websocketpp::log::alevel::all);
+    }
   }
   void do_decoder(const std::vector<char>& buffer,
                   websocketpp::connection_hdl& hdl, const nlohmann::json& msg);
@@ -78,6 +112,8 @@ class WebSocketServer {
   void on_message(websocketpp::connection_hdl hdl, message_ptr msg);
   void on_open(websocketpp::connection_hdl hdl);
   void on_close(websocketpp::connection_hdl hdl);
+  context_ptr on_tls_init(tls_mode mode, websocketpp::connection_hdl hdl,
+                          std::string& s_certfile, std::string& s_keyfile);
 
  private:
   void check_and_clean_connection();
@@ -85,7 +121,9 @@ class WebSocketServer {
   // std::ofstream fout;
   FUNASR_HANDLE asr_hanlde;  // asr engine handle
   bool isonline = false;  // online or offline engine, now only support offline
-  server* server_;        // websocket server
+  bool is_ssl = true;
+  server* server_;          // websocket server
+  wss_server* wss_server_;  // websocket server
 
   // use map to keep the received samples data from one connection in offline
   // engine. if for online engline, a data struct is needed(TODO)