websocket-server.cpp 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. /**
  2. * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
  3. * Reserved. MIT License (https://opensource.org/licenses/MIT)
  4. */
  5. /* 2022-2023 by zhaomingwork */
  // WebSocket server for the ASR engine.
  // Takes some ideas from https://github.com/k2-fsa/sherpa-onnx
  // online-websocket-server-impl.cc, thanks. The server uses two thread pools:
  // one for handling network data and one for the ASR decoder.
  // Currently only the offline engine is supported.
  11. #include "websocket-server.h"
  12. #include <thread>
  13. #include <utility>
  14. #include <vector>
  15. context_ptr WebSocketServer::on_tls_init(tls_mode mode,
  16. websocketpp::connection_hdl hdl,
  17. std::string& s_certfile,
  18. std::string& s_keyfile) {
  19. namespace asio = websocketpp::lib::asio;
  20. std::cout << "on_tls_init called with hdl: " << hdl.lock().get() << std::endl;
  21. std::cout << "using TLS mode: "
  22. << (mode == MOZILLA_MODERN ? "Mozilla Modern"
  23. : "Mozilla Intermediate")
  24. << std::endl;
  25. context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
  26. asio::ssl::context::sslv23);
  27. try {
  28. if (mode == MOZILLA_MODERN) {
  29. // Modern disables TLSv1
  30. ctx->set_options(
  31. asio::ssl::context::default_workarounds |
  32. asio::ssl::context::no_sslv2 | asio::ssl::context::no_sslv3 |
  33. asio::ssl::context::no_tlsv1 | asio::ssl::context::single_dh_use);
  34. } else {
  35. ctx->set_options(asio::ssl::context::default_workarounds |
  36. asio::ssl::context::no_sslv2 |
  37. asio::ssl::context::no_sslv3 |
  38. asio::ssl::context::single_dh_use);
  39. }
  40. ctx->use_certificate_chain_file(s_certfile);
  41. ctx->use_private_key_file(s_keyfile, asio::ssl::context::pem);
  42. } catch (std::exception& e) {
  43. std::cout << "Exception: " << e.what() << std::endl;
  44. }
  45. return ctx;
  46. }
  47. // feed buffer to asr engine for decoder
  48. void WebSocketServer::do_decoder(const std::vector<char>& buffer,
  49. websocketpp::connection_hdl& hdl,
  50. const nlohmann::json& msg) {
  51. try {
  52. int num_samples = buffer.size(); // the size of the buf
  53. if (!buffer.empty()) {
  54. // fout.write(buffer.data(), buffer.size());
  55. // feed data to asr engine
  56. FUNASR_RESULT Result = FunOfflineInferBuffer(
  57. asr_hanlde, buffer.data(), buffer.size(), RASR_NONE, NULL, 16000);
  58. std::string asr_result =
  59. ((FUNASR_RECOG_RESULT*)Result)->msg; // get decode result
  60. websocketpp::lib::error_code ec;
  61. nlohmann::json jsonresult; // result json
  62. jsonresult["text"] = asr_result; // put result in 'text'
  63. jsonresult["mode"] = "offline";
  64. jsonresult["wav_name"] = msg["wav_name"];
  65. // send the json to client
  66. if (is_ssl) {
  67. wss_server_->send(hdl, jsonresult.dump(),
  68. websocketpp::frame::opcode::text, ec);
  69. } else {
  70. server_->send(hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
  71. ec);
  72. }
  73. std::cout << "buffer.size=" << buffer.size()
  74. << ",result json=" << jsonresult.dump() << std::endl;
  75. if (!isonline) {
  76. // close the client if it is not online asr
  77. // server_->close(hdl, websocketpp::close::status::normal, "DONE", ec);
  78. // fout.close();
  79. }
  80. }
  81. } catch (std::exception const& e) {
  82. std::cerr << "Error: " << e.what() << std::endl;
  83. }
  84. }
  85. void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
  86. scoped_lock guard(m_lock); // for threads safty
  87. check_and_clean_connection(); // remove closed connection
  88. std::shared_ptr<FUNASR_MESSAGE> data_msg =
  89. std::make_shared<FUNASR_MESSAGE>(); // put a new data vector for new
  90. // connection
  91. data_msg->samples = std::make_shared<std::vector<char>>();
  92. data_msg->msg = nlohmann::json::parse("{}");
  93. data_map.emplace(hdl, data_msg);
  94. std::cout << "on_open, active connections: " << data_map.size() << std::endl;
  95. }
  96. void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
  97. scoped_lock guard(m_lock);
  98. data_map.erase(hdl); // remove data vector when connection is closed
  99. std::cout << "on_close, active connections: " << data_map.size() << std::endl;
  100. }
  101. // remove closed connection
  102. void WebSocketServer::check_and_clean_connection() {
  103. std::vector<websocketpp::connection_hdl> to_remove; // remove list
  104. auto iter = data_map.begin();
  105. while (iter != data_map.end()) { // loop to find closed connection
  106. websocketpp::connection_hdl hdl = iter->first;
  107. if (is_ssl) {
  108. wss_server::connection_ptr con = wss_server_->get_con_from_hdl(hdl);
  109. if (con->get_state() != 1) { // session::state::open ==1
  110. to_remove.push_back(hdl);
  111. }
  112. } else {
  113. server::connection_ptr con = server_->get_con_from_hdl(hdl);
  114. if (con->get_state() != 1) { // session::state::open ==1
  115. to_remove.push_back(hdl);
  116. }
  117. }
  118. iter++;
  119. }
  120. for (auto hdl : to_remove) {
  121. data_map.erase(hdl);
  122. std::cout << "remove one connection " << std::endl;
  123. }
  124. }
  125. void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
  126. message_ptr msg) {
  127. unique_lock lock(m_lock);
  128. // find the sample data vector according to one connection
  129. std::shared_ptr<FUNASR_MESSAGE> msg_data = nullptr;
  130. auto it_data = data_map.find(hdl);
  131. if (it_data != data_map.end()) {
  132. msg_data = it_data->second;
  133. }
  134. std::shared_ptr<std::vector<char>> sample_data_p = msg_data->samples;
  135. lock.unlock();
  136. if (sample_data_p == nullptr) {
  137. std::cout << "error when fetch sample data vector" << std::endl;
  138. return;
  139. }
  140. const std::string& payload = msg->get_payload(); // get msg type
  141. switch (msg->get_opcode()) {
  142. case websocketpp::frame::opcode::text: {
  143. nlohmann::json jsonresult = nlohmann::json::parse(payload);
  144. if (jsonresult["wav_name"] != nullptr) {
  145. msg_data->msg["wav_name"] = jsonresult["wav_name"];
  146. }
  147. if (jsonresult["is_speaking"] == false ||
  148. jsonresult["is_finished"] == true) {
  149. std::cout << "client done" << std::endl;
  150. if (isonline) {
  151. // do_close(ws);
  152. } else {
  153. // add padding to the end of the wav data
  154. std::vector<short> padding(static_cast<short>(0.3 * 16000));
  155. sample_data_p->insert(sample_data_p->end(), padding.data(),
  156. padding.data() + padding.size());
  157. // for offline, send all receive data to decoder engine
  158. asio::post(io_decoder_,
  159. std::bind(&WebSocketServer::do_decoder, this,
  160. std::move(*(sample_data_p.get())),
  161. std::move(hdl), std::move(msg_data->msg)));
  162. }
  163. }
  164. break;
  165. }
  166. case websocketpp::frame::opcode::binary: {
  167. // recived binary data
  168. const auto* pcm_data = static_cast<const char*>(payload.data());
  169. int32_t num_samples = payload.size();
  170. if (isonline) {
  171. // if online TODO(zhaoming) still not done
  172. std::vector<char> s(pcm_data, pcm_data + num_samples);
  173. asio::post(io_decoder_,
  174. std::bind(&WebSocketServer::do_decoder, this, std::move(s),
  175. std::move(hdl), std::move(msg_data->msg)));
  176. } else {
  177. // for offline, we add receive data to end of the sample data vector
  178. sample_data_p->insert(sample_data_p->end(), pcm_data,
  179. pcm_data + num_samples);
  180. }
  181. break;
  182. }
  183. default:
  184. break;
  185. }
  186. }
  187. // init asr model
  188. void WebSocketServer::initAsr(std::map<std::string, std::string>& model_path,
  189. int thread_num) {
  190. try {
  191. // init model with api
  192. asr_hanlde = FunOfflineInit(model_path, thread_num);
  193. std::cout << "model ready" << std::endl;
  194. } catch (const std::exception& e) {
  195. std::cout << e.what() << std::endl;
  196. }
  197. }