websocketsrv.cpp 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. /**
  2. * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
  3. * Reserved. MIT License (https://opensource.org/licenses/MIT)
  4. */
  5. /* 2022-2023 by zhaomingwork */
  6. // websocket server for asr engine
  7. // Takes some ideas from https://github.com/k2-fsa/sherpa-onnx
  8. // (online-websocket-server-impl.cc), thanks. The websocket server has two
  9. // thread pools: one for handling network data and one for the ASR decoder.
  10. // Currently only the offline engine is supported.
  11. #include "websocketsrv.h"
  12. #include <thread>
  13. #include <utility>
  14. #include <vector>
  15. // feed buffer to asr engine for decoder
  16. void WebSocketServer::do_decoder(const std::vector<char>& buffer,
  17. websocketpp::connection_hdl& hdl) {
  18. try {
  19. int num_samples = buffer.size(); // the size of the buf
  20. if (!buffer.empty()) {
  21. // fout.write(buffer.data(), buffer.size());
  22. // feed data to asr engine
  23. FUNASR_RESULT Result = FunOfflineInferBuffer(
  24. asr_hanlde, buffer.data(), buffer.size(), RASR_NONE, NULL, 16000);
  25. std::string asr_result =
  26. ((FUNASR_RECOG_RESULT*)Result)->msg; // get decode result
  27. websocketpp::lib::error_code ec;
  28. nlohmann::json jsonresult; // result json
  29. jsonresult["text"] = asr_result; // put result in 'text'
  30. // send the json to client
  31. server_->send(hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
  32. ec);
  33. std::cout << "buffer.size=" << buffer.size()
  34. << ",result json=" << jsonresult.dump() << std::endl;
  35. if (!isonline) {
  36. // close the client if it is not online asr
  37. server_->close(hdl, websocketpp::close::status::normal, "DONE", ec);
  38. // fout.close();
  39. }
  40. }
  41. } catch (std::exception const& e) {
  42. std::cerr << "Error: " << e.what() << std::endl;
  43. }
  44. }
  45. void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
  46. scoped_lock guard(m_lock); // for threads safty
  47. check_and_clean_connection(); // remove closed connection
  48. sample_map.emplace(
  49. hdl, std::make_shared<std::vector<char>>()); // put a new data vector for
  50. // new connection
  51. std::cout << "on_open, active connections: " << sample_map.size()
  52. << std::endl;
  53. }
  54. void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
  55. scoped_lock guard(m_lock);
  56. sample_map.erase(hdl); // remove data vector when connection is closed
  57. std::cout << "on_close, active connections: " << sample_map.size()
  58. << std::endl;
  59. }
  60. // remove closed connection
  61. void WebSocketServer::check_and_clean_connection() {
  62. std::vector<websocketpp::connection_hdl> to_remove; // remove list
  63. auto iter = sample_map.begin();
  64. while (iter != sample_map.end()) { // loop to find closed connection
  65. websocketpp::connection_hdl hdl = iter->first;
  66. server::connection_ptr con = server_->get_con_from_hdl(hdl);
  67. if (con->get_state() != 1) { // session::state::open ==1
  68. to_remove.push_back(hdl);
  69. }
  70. iter++;
  71. }
  72. for (auto hdl : to_remove) {
  73. sample_map.erase(hdl);
  74. std::cout << "remove one connection " << std::endl;
  75. }
  76. }
  77. void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
  78. message_ptr msg) {
  79. unique_lock lock(m_lock);
  80. // find the sample data vector according to one connection
  81. std::shared_ptr<std::vector<char>> sample_data_p = nullptr;
  82. auto it = sample_map.find(hdl);
  83. if (it != sample_map.end()) {
  84. sample_data_p = it->second;
  85. }
  86. lock.unlock();
  87. if (sample_data_p == nullptr) {
  88. std::cout << "error when fetch sample data vector" << std::endl;
  89. return;
  90. }
  91. const std::string& payload = msg->get_payload(); // get msg type
  92. switch (msg->get_opcode()) {
  93. case websocketpp::frame::opcode::text:
  94. if (payload == "Done") {
  95. std::cout << "client done" << std::endl;
  96. if (isonline) {
  97. // do_close(ws);
  98. } else {
  99. // for offline, send all receive data to decoder engine
  100. asio::post(io_decoder_, std::bind(&WebSocketServer::do_decoder, this,
  101. std::move(*(sample_data_p.get())),
  102. std::move(hdl)));
  103. }
  104. }
  105. break;
  106. case websocketpp::frame::opcode::binary: {
  107. // recived binary data
  108. const auto* pcm_data = static_cast<const char*>(payload.data());
  109. int32_t num_samples = payload.size();
  110. if (isonline) {
  111. // if online TODO(zhaoming) still not done
  112. std::vector<char> s(pcm_data, pcm_data + num_samples);
  113. asio::post(io_decoder_, std::bind(&WebSocketServer::do_decoder, this,
  114. std::move(s), std::move(hdl)));
  115. } else {
  116. // for offline, we add receive data to end of the sample data vector
  117. sample_data_p->insert(sample_data_p->end(), pcm_data,
  118. pcm_data + num_samples);
  119. }
  120. break;
  121. }
  122. default:
  123. break;
  124. }
  125. }
  126. // init asr model
  127. void WebSocketServer::initAsr(std::map<std::string, std::string>& model_path,
  128. int thread_num) {
  129. try {
  130. // init model with api
  131. asr_hanlde = FunOfflineInit(model_path, thread_num);
  132. std::cout << "model ready" << std::endl;
  133. } catch (const std::exception& e) {
  134. std::cout << e.what() << std::endl;
  135. }
  136. }