Explorar el Código

Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
add

游雁 hace 2 años
padre
commit
1a9583748a

+ 4 - 2
funasr/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp

@@ -195,13 +195,14 @@ int main(int argc, char** argv)
     TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0");
     TCLAP::ValueArg<std::string>    offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
     TCLAP::ValueArg<std::string>    online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
-    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
+    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
     TCLAP::ValueArg<std::string>    vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
     TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
     TCLAP::ValueArg<std::string>    punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
     TCLAP::ValueArg<std::string>    punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
     TCLAP::ValueArg<std::string>    asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
     TCLAP::ValueArg<std::int32_t>   onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
+    TCLAP::ValueArg<std::int32_t>   thread_num_("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
 
     TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
 
@@ -215,6 +216,7 @@ int main(int argc, char** argv)
     cmd.add(wav_path);
     cmd.add(asr_mode);
     cmd.add(onnx_thread);
+    cmd.add(thread_num_);
     cmd.parse(argc, argv);
 
     std::map<std::string, std::string> model_path;
@@ -288,7 +290,7 @@ int main(int argc, char** argv)
     long total_time = 0;
     std::vector<std::thread> threads;
 
-    int rtf_threds = 5;
+    int rtf_threds = thread_num_.getValue();
     for (int i = 0; i < rtf_threds; i++)
     {
         threads.emplace_back(thread(runReg, tpass_hanlde, chunk_size, wav_list, wav_ids, &total_length, &total_time, i, (ASR_TYPE)asr_mode_));

+ 6 - 14
funasr/runtime/websocket/websocket-server-2pass.cpp

@@ -53,22 +53,19 @@ context_ptr WebSocketServer::on_tls_init(tls_mode mode,
   return ctx;
 }
 
-nlohmann::json handle_result(FUNASR_RESULT result, std::string& online_res,
-                             std::string& tpass_res, nlohmann::json msg) {
+nlohmann::json handle_result(FUNASR_RESULT result, nlohmann::json msg) {
 
     websocketpp::lib::error_code ec;
     nlohmann::json jsonresult;
     jsonresult["text"]="";
 
     std::string tmp_online_msg = FunASRGetResult(result, 0);
-    online_res += tmp_online_msg;
     if (tmp_online_msg != "") {
       LOG(INFO) << "online_res :" << tmp_online_msg;
       jsonresult["text"] = tmp_online_msg; 
       jsonresult["mode"] = "2pass-online";
     }
     std::string tmp_tpass_msg = FunASRGetTpassResult(result, 0);
-    tpass_res += tmp_tpass_msg;
     if (tmp_tpass_msg != "") {
       LOG(INFO) << "offline results : " << tmp_tpass_msg;
       jsonresult["text"] = tmp_tpass_msg; 
@@ -86,8 +83,7 @@ void WebSocketServer::do_decoder(
     std::vector<char>& buffer, websocketpp::connection_hdl& hdl,
     nlohmann::json& msg, std::vector<std::vector<std::string>>& punc_cache,
     websocketpp::lib::mutex& thread_lock, bool& is_final,
-    FUNASR_HANDLE& tpass_online_handle, std::string& online_res,
-    std::string& tpass_res) {
+    FUNASR_HANDLE& tpass_online_handle) {
  
   // lock for each connection
   scoped_lock guard(thread_lock);
@@ -127,7 +123,7 @@ void WebSocketServer::do_decoder(
       if (Result) {
         websocketpp::lib::error_code ec;
         nlohmann::json jsonresult =
-            handle_result(Result, online_res, tpass_res, msg["wav_name"]);
+            handle_result(Result, msg["wav_name"]);
         jsonresult["is_final"] = false;
         if(jsonresult["text"] != "") {
           if (is_ssl) {
@@ -158,7 +154,7 @@ void WebSocketServer::do_decoder(
       if (Result) {
         websocketpp::lib::error_code ec;
         nlohmann::json jsonresult =
-            handle_result(Result, online_res, tpass_res, msg["wav_name"]);
+            handle_result(Result, msg["wav_name"]);
         jsonresult["is_final"] = true;
         if (is_ssl) {
           wss_server_->send(hdl, jsonresult.dump(),
@@ -306,9 +302,7 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
                       std::move(*(sample_data_p.get())), std::move(hdl),
                       std::ref(msg_data->msg), std::ref(*(punc_cache_p.get())),
                       std::ref(*thread_lock_p), std::move(true),
-                      std::ref(msg_data->tpass_online_handle),
-                      std::ref(msg_data->online_res),
-                      std::ref(msg_data->tpass_res)));
+                      std::ref(msg_data->tpass_online_handle)));
       }
       break;
     }
@@ -338,9 +332,7 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
                                   std::ref(msg_data->msg),
                                   std::ref(*(punc_cache_p.get())),
                                   std::ref(*thread_lock_p), std::move(false),
-                                  std::ref(msg_data->tpass_online_handle),
-                                  std::ref(msg_data->online_res),
-                                  std::ref(msg_data->tpass_res)));
+                                  std::ref(msg_data->tpass_online_handle)));
         }
       } else {
         sample_data_p->insert(sample_data_p->end(), pcm_data,

+ 1 - 2
funasr/runtime/websocket/websocket-server-2pass.h

@@ -115,8 +115,7 @@ class WebSocketServer {
                   nlohmann::json& msg,
                   std::vector<std::vector<std::string>>& punc_cache,
                   websocketpp::lib::mutex& thread_lock, bool& is_final,
-                  FUNASR_HANDLE& tpass_online_handle, std::string& online_res,
-                  std::string& tpass_res);
+                  FUNASR_HANDLE& tpass_online_handle);
 
   void initAsr(std::map<std::string, std::string>& model_path, int thread_num);
   void on_message(websocketpp::connection_hdl hdl, message_ptr msg);