雾聪 2 лет назад
Родитель
Сommit
d4a021b45b

+ 9 - 8
funasr/runtime/websocket/websocket-server-2pass.cpp

@@ -53,7 +53,7 @@ context_ptr WebSocketServer::on_tls_init(tls_mode mode,
   return ctx;
 }
 
-nlohmann::json handle_result(FUNASR_RESULT result, nlohmann::json msg) {
+nlohmann::json handle_result(FUNASR_RESULT result) {
 
     websocketpp::lib::error_code ec;
     nlohmann::json jsonresult;
@@ -72,10 +72,6 @@ nlohmann::json handle_result(FUNASR_RESULT result, nlohmann::json msg) {
       jsonresult["mode"] = "2pass-offline";    
     }
 
-    if (msg.contains("wav_name")) {
-      jsonresult["wav_name"] = msg["wav_name"];
-    }
-
     return jsonresult;
 }
 // feed buffer to asr engine for decoder
@@ -83,7 +79,7 @@ void WebSocketServer::do_decoder(
     std::vector<char>& buffer, websocketpp::connection_hdl& hdl,
     nlohmann::json& msg, std::vector<std::vector<std::string>>& punc_cache,
     websocketpp::lib::mutex& thread_lock, bool& is_final,
-    FUNASR_HANDLE& tpass_online_handle) {
+    std::string wav_name, FUNASR_HANDLE& tpass_online_handle) {
  
   // lock for each connection
   scoped_lock guard(thread_lock);
@@ -123,7 +119,8 @@ void WebSocketServer::do_decoder(
       if (Result) {
         websocketpp::lib::error_code ec;
         nlohmann::json jsonresult =
-            handle_result(Result, msg);
+            handle_result(Result);
+        jsonresult["wav_name"] = wav_name;
         jsonresult["is_final"] = false;
         if(jsonresult["text"] != "") {
           if (is_ssl) {
@@ -154,7 +151,8 @@ void WebSocketServer::do_decoder(
       if (Result) {
         websocketpp::lib::error_code ec;
         nlohmann::json jsonresult =
-            handle_result(Result, msg);
+            handle_result(Result);
+        jsonresult["wav_name"] = wav_name;
         jsonresult["is_final"] = true;
         if (is_ssl) {
           wss_server_->send(hdl, jsonresult.dump(),
@@ -285,6 +283,7 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
       if (jsonresult.contains("chunk_size")){
         if(msg_data->tpass_online_handle == NULL){
           std::vector<int> chunk_size_vec = jsonresult["chunk_size"].get<std::vector<int>>();
+          LOG(INFO) << "----------------FunTpassOnlineInit----------------------";
           FUNASR_HANDLE tpass_online_handle =
               FunTpassOnlineInit(tpass_handle, chunk_size_vec);
           msg_data->tpass_online_handle = tpass_online_handle;
@@ -303,6 +302,7 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
                       std::move(*(sample_data_p.get())), std::move(hdl),
                       std::ref(msg_data->msg), std::ref(*(punc_cache_p.get())),
                       std::ref(*thread_lock_p), std::move(true),
+                      msg_data->msg["wav_name"],
                       std::ref(msg_data->tpass_online_handle)));
       }
       break;
@@ -333,6 +333,7 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
                                   std::ref(msg_data->msg),
                                   std::ref(*(punc_cache_p.get())),
                                   std::ref(*thread_lock_p), std::move(false),
+                                  msg_data->msg["wav_name"],
                                   std::ref(msg_data->tpass_online_handle)));
         }
       } else {

+ 1 - 0
funasr/runtime/websocket/websocket-server-2pass.h

@@ -115,6 +115,7 @@ class WebSocketServer {
                   nlohmann::json& msg,
                   std::vector<std::vector<std::string>>& punc_cache,
                   websocketpp::lib::mutex& thread_lock, bool& is_final,
+                  std::string wav_name,
                   FUNASR_HANDLE& tpass_online_handle);
 
   void initAsr(std::map<std::string, std::string>& model_path, int thread_num);