Procházet zdrojové kódy

Dev server hotwords (#1033)

* add server use hotwords (#1027)

* change send data size

* 增加菜单栏和APK下载地址

* 忽略SSL证书验证

* add server use hotwords

* 增加官方的服务地址作为默认地址

* adapt for server hotwords

---------

Co-authored-by: 夜雨飘零 <yeyupiaoling@foxmail.com>
Yabin Li před 2 roky
rodič
revize
6eb0dfced4

+ 20 - 11
funasr/runtime/android/AndroidClient/app/src/main/java/com/yeyupiaoling/androidclient/MainActivity.java

@@ -39,6 +39,8 @@ public class MainActivity extends AppCompatActivity {
     public static final String TAG = MainActivity.class.getSimpleName();
     // WebSocket地址
     public String ASR_HOST = "";
+    // 官方WebSocket地址
+    public static final String DEFAULT_HOST = "wss://101.37.77.25:10088";
     // 发送的JSON数据
     public static final String MODE = "2pass";
     public static final String CHUNK_SIZE = "5, 10, 5";
@@ -61,7 +63,6 @@ public class MainActivity extends AppCompatActivity {
     // 控件
     private Button recordBtn;
     private TextView resultText;
-    private WebSocket webSocket;
 
     @SuppressLint("ClickableViewAccessibility")
     @Override
@@ -106,8 +107,8 @@ public class MainActivity extends AppCompatActivity {
             ASR_HOST = uri;
         }
         // 读取热词
-        String hotWords = sharedPreferences.getString("hotwords", "");
-        if (!hotWords.equals("")) {
+        String hotWords = sharedPreferences.getString("hotwords", null);
+        if (hotWords != null) {
             this.hotWords = hotWords;
         }
     }
@@ -150,6 +151,14 @@ public class MainActivity extends AppCompatActivity {
                 editor.apply();
             }
         });
+        builder.setNeutralButton("使用官方服务", (dialog, id) -> {
+            ASR_HOST = DEFAULT_HOST;
+            input.setText(DEFAULT_HOST);
+            Toast.makeText(MainActivity.this, "WebSocket地址:" + ASR_HOST, Toast.LENGTH_SHORT).show();
+            SharedPreferences.Editor editor = sharedPreferences.edit();
+            editor.putString("uri", ASR_HOST);
+            editor.apply();
+        });
         AlertDialog dialog = builder.create();
         dialog.show();
     }
@@ -166,12 +175,10 @@ public class MainActivity extends AppCompatActivity {
         builder.setView(view);
         builder.setPositiveButton("确定", (dialog, id) -> {
             String hotwords = input.getText().toString();
-            if (!hotwords.equals("")) {
-                this.hotWords = hotwords;
-                SharedPreferences.Editor editor = sharedPreferences.edit();
-                editor.putString("hotwords", hotwords);
-                editor.apply();
-            }
+            this.hotWords = hotwords;
+            SharedPreferences.Editor editor = sharedPreferences.edit();
+            editor.putString("hotwords", hotwords);
+            editor.apply();
         });
         AlertDialog dialog = builder.create();
         dialog.show();
@@ -225,7 +232,7 @@ public class MainActivity extends AppCompatActivity {
         Request request = new Request.Builder()
                 .url(ASR_HOST)
                 .build();
-        webSocket = client.newWebSocket(request, new WebSocketListener() {
+        WebSocket webSocket = client.newWebSocket(request, new WebSocketListener() {
 
             @Override
             public void onOpen(@NonNull WebSocket webSocket, @NonNull Response response) {
@@ -311,7 +318,9 @@ public class MainActivity extends AppCompatActivity {
             obj.put("chunk_size", array);
             obj.put("chunk_interval", CHUNK_INTERVAL);
             obj.put("wav_name", "default");
-            obj.put("hotwords", hotWords);
+            if (!hotWords.equals("")) {
+                obj.put("hotwords", hotWords);
+            }
             obj.put("wav_format", "pcm");
             obj.put("is_speaking", isSpeaking);
             return obj.toString();

+ 3 - 1
funasr/runtime/docs/SDK_advanced_guide_offline.md

@@ -83,7 +83,8 @@ nohup bash run_server.sh \
   --io-thread-num  8 \
   --port 10095 \
   --certfile  ../../../ssl_key/server.crt \
-  --keyfile ../../../ssl_key/server.key > log.out 2>&1 &
+  --keyfile ../../../ssl_key/server.key \
+  --hotwordsfile ../../hotwords.txt > log.out 2>&1 &
  ```
 
 Introduction to run_server.sh parameters: 
@@ -102,6 +103,7 @@ Introduction to run_server.sh parameters:
 --io-thread-num: Number of IO threads that the server starts. Default is 1.
 --certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. If you want to close ssl,set 0
 --keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key. 
+--hotwordsfile   Hotword file path, one line for each hot word, if the client provides hot words, then combined with the hot words provided by the client. Default is ../../hotwords.txt
 ```
 
 ### Shutting Down the FunASR Service

+ 3 - 1
funasr/runtime/docs/SDK_advanced_guide_offline_en.md

@@ -79,7 +79,8 @@ nohup bash run_server.sh \
   --io-thread-num  8 \
   --port 10095 \
   --certfile  ../../../ssl_key/server.crt \
-  --keyfile ../../../ssl_key/server.key
+  --keyfile ../../../ssl_key/server.key \
+  --hotwordsfile ../../hotwords.txt
  ```
 
 Introduction to run_server.sh parameters: 
@@ -98,6 +99,7 @@ Introduction to run_server.sh parameters:
 --io-thread-num: Number of IO threads that the server starts. Default is 1.
 --certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. If you want to close ssl,set 0
 --keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key. 
+--hotwordsfile   Hotword file path, one line for each hot word, if the client provides hot words, then combined with the hot words provided by the client. Default is ../../hotwords.txt
 ```
 
 ### Shutting Down the FunASR Service

+ 3 - 1
funasr/runtime/docs/SDK_advanced_guide_offline_en_zh.md

@@ -162,7 +162,8 @@ nohup bash run_server.sh \
   --io-thread-num  8 \
   --port 10095 \
   --certfile  ../../../ssl_key/server.crt \
-  --keyfile ../../../ssl_key/server.key > log.out 2>&1 &
+  --keyfile ../../../ssl_key/server.key \
+  --hotwordsfile ../../hotwords.txt > log.out 2>&1 &
  ```
 **run_server.sh命令参数介绍**
 ```text
@@ -179,6 +180,7 @@ nohup bash run_server.sh \
 --io-thread-num  服务端启动的IO线程数,默认为 1
 --certfile  ssl的证书文件,默认为:../../../ssl_key/server.crt,如果需要关闭ssl,参数设置为0
 --keyfile   ssl的密钥文件,默认为:../../../ssl_key/server.key
+--hotwordsfile   热词文件路径,每一个热词一行,如果客户端提供热词,则与客户端提供的热词合并一起使用。默认为:../../hotwords.txt
 ```
 
 ### 关闭FunASR服务

+ 3 - 1
funasr/runtime/docs/SDK_advanced_guide_offline_zh.md

@@ -165,7 +165,8 @@ nohup bash run_server.sh \
   --io-thread-num  8 \
   --port 10095 \
   --certfile  ../../../ssl_key/server.crt \
-  --keyfile ../../../ssl_key/server.key > log.out 2>&1 &
+  --keyfile ../../../ssl_key/server.key \
+  --hotwordsfile ../../hotwords.txt > log.out 2>&1 &
  ```
 **run_server.sh命令参数介绍**
 ```text
@@ -182,6 +183,7 @@ nohup bash run_server.sh \
 --io-thread-num  服务端启动的IO线程数,默认为 1
 --certfile  ssl的证书文件,默认为:../../../ssl_key/server.crt,如果需要关闭ssl,参数设置为0
 --keyfile   ssl的密钥文件,默认为:../../../ssl_key/server.key
+--hotwordsfile   热词文件路径,每一个热词一行,如果客户端提供热词,则与客户端提供的热词合并一起使用。默认为:../../hotwords.txt
 ```
 
 ### 关闭FunASR服务

+ 3 - 1
funasr/runtime/docs/SDK_advanced_guide_online.md

@@ -72,7 +72,8 @@ nohup bash run_server_2pass.sh \
   --io-thread-num  8 \
   --port 10095 \
   --certfile  ../../../ssl_key/server.crt \
-  --keyfile ../../../ssl_key/server.key > log.out 2>&1 &
+  --keyfile ../../../ssl_key/server.key \
+  --hotwordsfile ../../hotwords.txt > log.out 2>&1 &
 
 # If you want to close ssl,please add:--certfile 0
 # If you want to deploy the timestamp or hotword model, please set --model-dir to the corresponding model:
@@ -97,6 +98,7 @@ nohup bash run_server_2pass.sh \
 --io-thread-num: Number of IO threads that the server starts. Default is 1.
 --certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. If you want to close ssl,set 0
 --keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key. 
+--hotwordsfile   Hotword file path, one line for each hot word, if the client provides hot words, then combined with the hot words provided by the client. Default is ../../hotwords.txt
 ```
 
 ### Shutting Down the FunASR Service

+ 5 - 2
funasr/runtime/docs/SDK_advanced_guide_online_zh.md

@@ -31,7 +31,8 @@ nohup bash run_server_2pass.sh \
   --model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx  \
   --online-model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx  \
   --punc-dir damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx \
-  --itn-dir thuduj12/fst_itn_zh > log.out 2>&1 &
+  --itn-dir thuduj12/fst_itn_zh \
+  --hotwordsfile ../../hotwords.txt > log.out 2>&1 &
 
 # 如果您想关闭ssl,增加参数:--certfile 0
 # 如果您想使用时间戳或者热词模型进行部署,请设置--model-dir为对应模型:
@@ -80,7 +81,8 @@ nohup bash run_server_2pass.sh \
   --io-thread-num  8 \
   --port 10095 \
   --certfile  ../../../ssl_key/server.crt \
-  --keyfile ../../../ssl_key/server.key > log.out 2>&1 &
+  --keyfile ../../../ssl_key/server.key \
+  --hotwordsfile ../../hotwords.txt > log.out 2>&1 &
  ```
 **run_server_2pass.sh命令参数介绍**
 ```text
@@ -98,6 +100,7 @@ nohup bash run_server_2pass.sh \
 --io-thread-num  服务端启动的IO线程数,默认为 1
 --certfile  ssl的证书文件,默认为:../../../ssl_key/server.crt,如果需要关闭ssl,参数设置为0
 --keyfile   ssl的密钥文件,默认为:../../../ssl_key/server.key
+--hotwordsfile   热词文件路径,每一个热词一行,如果客户端提供热词,则与客户端提供的热词合并一起使用。默认为:../../hotwords.txt
 ```
 
 ### 关闭FunASR服务

+ 5 - 2
funasr/runtime/run_server.sh

@@ -9,6 +9,7 @@ io_thread_num=8
 port=10095
 certfile="../../../ssl_key/server.crt"
 keyfile="../../../ssl_key/server.key"
+hotwordsfile="../../hotwords.txt"
 
 . ../../egs/aishell/transformer/utils/parse_options.sh || exit 1;
 
@@ -24,7 +25,8 @@ if [ -z "$certfile" ] || [ "$certfile" -eq 0 ]; then
   --io-thread-num  ${io_thread_num} \
   --port ${port} \
   --certfile  "" \
-  --keyfile ""
+  --keyfile "" \
+  --hotwordsfile ${hotwordsfile}
 else
 ./funasr-wss-server  \
   --download-model-dir ${download_model_dir} \
@@ -36,5 +38,6 @@ else
   --io-thread-num  ${io_thread_num} \
   --port ${port} \
   --certfile  ${certfile} \
-  --keyfile ${keyfile}
+  --keyfile ${keyfile} \
+  --hotwordsfile ${hotwordsfile}
 fi

+ 5 - 2
funasr/runtime/run_server_2pass.sh

@@ -10,6 +10,7 @@ io_thread_num=8
 port=10095
 certfile="../../../ssl_key/server.crt"
 keyfile="../../../ssl_key/server.key"
+hotwordsfile="../../hotwords.txt"
 
 . ../../egs/aishell/transformer/utils/parse_options.sh || exit 1;
 
@@ -26,7 +27,8 @@ if [ -z "$certfile" ] || [ "$certfile" -eq 0 ]; then
   --io-thread-num  ${io_thread_num} \
   --port ${port} \
   --certfile  "" \
-  --keyfile ""
+  --keyfile "" \
+  --hotwordsfile ${hotwordsfile}
 else
 ./funasr-wss-server-2pass  \
   --download-model-dir ${download_model_dir} \
@@ -39,5 +41,6 @@ else
   --io-thread-num  ${io_thread_num} \
   --port ${port} \
   --certfile  ${certfile} \
-  --keyfile ${keyfile}
+  --keyfile ${keyfile} \
+  --hotwordsfile ${hotwordsfile}
 fi

+ 29 - 3
funasr/runtime/websocket/bin/funasr-wss-server-2pass.cpp

@@ -14,6 +14,9 @@
 #include <unistd.h>
 #include "websocket-server-2pass.h"
 
+#include <fstream>
+std::string hotwords = "";
+
 using namespace std;
 void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
               std::map<std::string, std::string>& model_path) {
@@ -109,6 +112,15 @@ int main(int argc, char* argv[]) {
         "connection",
         false, "../../../ssl_key/server.key", "string");
 
+    TCLAP::ValueArg<std::string> hotwordsfile(
+        "", "hotwordsfile",
+        "default: ../../hotwords.txt, path of hotwordsfile"
+        "connection",
+        false, "../../hotwords.txt", "string");
+
+    // add file
+    cmd.add(hotwordsfile);
+
     cmd.add(certfile);
     cmd.add(keyfile);
 
@@ -417,6 +429,21 @@ int main(int argc, char* argv[]) {
     std::string s_certfile = certfile.getValue();
     std::string s_keyfile = keyfile.getValue();
 
+    std::string s_hotwordsfile = hotwordsfile.getValue();
+    std::string line;
+    std::ifstream file(s_hotwordsfile);
+    LOG(INFO) << "hotwordsfile path: " << s_hotwordsfile;
+
+    if (file.is_open()) {
+        while (getline(file, line)) {
+            hotwords += line+HOTWORD_SEP;
+        }
+        LOG(INFO) << "hotwords: " << hotwords;
+        file.close();
+    } else {
+        LOG(ERROR) << "Unable to open hotwords file: " << s_hotwordsfile;
+    }
+
     bool is_ssl = false;
     if (!s_certfile.empty()) {
       is_ssl = true;
@@ -460,8 +487,7 @@ int main(int argc, char* argv[]) {
       websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
     }
 
-    std::cout << "asr model init finished. listen on port:" << s_port
-              << std::endl;
+    LOG(INFO) << "asr model init finished. listen on port:" << s_port;
 
     // Start the ASIO network io_service run loop
     std::vector<std::thread> ts;
@@ -480,7 +506,7 @@ int main(int argc, char* argv[]) {
     }
 
   } catch (std::exception const& e) {
-    std::cerr << "Error: " << e.what() << std::endl;
+    LOG(ERROR) << "Error: " << e.what();
   }
 
   return 0;

+ 29 - 3
funasr/runtime/websocket/bin/funasr-wss-server.cpp

@@ -13,6 +13,9 @@
 #include "websocket-server.h"
 #include <unistd.h>
 
+#include <fstream>
+std::string hotwords = "";
+
 using namespace std;
 void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
               std::map<std::string, std::string>& model_path) {
@@ -95,6 +98,15 @@ int main(int argc, char* argv[]) {
         "default: ../../../ssl_key/server.key, path of keyfile for WSS connection", 
         false, "../../../ssl_key/server.key", "string");
 
+    TCLAP::ValueArg<std::string> hotwordsfile(
+        "", "hotwordsfile",
+        "default: ../../hotwords.txt, path of hotwordsfile"
+        "connection",
+        false, "../../hotwords.txt", "string");
+
+    // add file
+    cmd.add(hotwordsfile);
+
     cmd.add(certfile);
     cmd.add(keyfile);
 
@@ -331,6 +343,21 @@ int main(int argc, char* argv[]) {
     std::string s_certfile = certfile.getValue();
     std::string s_keyfile = keyfile.getValue();
 
+    std::string s_hotwordsfile = hotwordsfile.getValue();
+    std::string line;
+    std::ifstream file(s_hotwordsfile);
+    LOG(INFO) << "hotwordsfile path: " << s_hotwordsfile;
+
+    if (file.is_open()) {
+        while (getline(file, line)) {
+            hotwords += line+HOTWORD_SEP;
+        }
+        LOG(INFO) << "hotwords: " << hotwords;
+        file.close();
+    } else {
+        LOG(ERROR) << "Unable to open hotwords file: " << s_hotwordsfile;
+    }
+
     bool is_ssl = false;
     if (!s_certfile.empty()) {
       is_ssl = true;
@@ -374,8 +401,7 @@ int main(int argc, char* argv[]) {
       websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
     }
 
-    std::cout << "asr model init finished. listen on port:" << s_port
-              << std::endl;
+    LOG(INFO) << "asr model init finished. listen on port:" << s_port;
 
     // Start the ASIO network io_service run loop
     std::vector<std::thread> ts;
@@ -394,7 +420,7 @@ int main(int argc, char* argv[]) {
     }
 
   } catch (std::exception const& e) {
-    std::cerr << "Error: " << e.what() << std::endl;
+    LOG(ERROR) << "Error: " << e.what();
   }
 
   return 0;

+ 15 - 4
funasr/runtime/websocket/bin/websocket-server-2pass.cpp

@@ -15,7 +15,9 @@
 #include <thread>
 #include <utility>
 #include <vector>
-#include <chrono>
+
+extern std::string hotwords;
+
 context_ptr WebSocketServer::on_tls_init(tls_mode mode,
                                          websocketpp::connection_hdl hdl,
                                          std::string& s_certfile,
@@ -370,17 +372,26 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
           msg_data->msg["hotwords"] = jsonresult["hotwords"];
           if (!msg_data->msg["hotwords"].empty()) {
             std::string hw = msg_data->msg["hotwords"];
-            LOG(INFO)<<"hotwords: " << hw;
-            std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
+            hw = hw + " " + hotwords;
+            LOG(INFO) << "hotwords: " << hw;
+            std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
             msg_data->hotwords_embedding =
                 std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
           }
-        }else{
+        } else {
+          if (hotwords.empty()) {
             std::string hw = "";
             LOG(INFO)<<"hotwords: " << hw;
             std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
             msg_data->hotwords_embedding =
                 std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+          }else {
+            std::string hw = hotwords;
+            LOG(INFO) << "hotwords: " << hw;
+            std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
+            msg_data->hotwords_embedding =
+                std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+          }
         }
       }
       if (jsonresult.contains("audio_fs")) {

+ 14 - 3
funasr/runtime/websocket/bin/websocket-server.cpp

@@ -16,6 +16,8 @@
 #include <utility>
 #include <vector>
 
+extern std::string hotwords;
+
 context_ptr WebSocketServer::on_tls_init(tls_mode mode,
                                          websocketpp::connection_hdl hdl,
                                          std::string& s_certfile,
@@ -266,17 +268,26 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
           msg_data->msg["hotwords"] = jsonresult["hotwords"];
           if (!msg_data->msg["hotwords"].empty()) {
             std::string hw = msg_data->msg["hotwords"];
-            LOG(INFO)<<"hotwords: " << hw;
-            std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw);
+            hw = hw + " " + hotwords;
+            LOG(INFO) << "hotwords: " << hw;
+            std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(asr_hanlde, hw);
             msg_data->hotwords_embedding =
                 std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
           }
-        }else{
+        } else {
+          if (hotwords.empty()) {
             std::string hw = "";
             LOG(INFO)<<"hotwords: " << hw;
             std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw);
             msg_data->hotwords_embedding =
                 std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+          }else {
+            std::string hw = hotwords;
+            LOG(INFO) << "hotwords: " << hw;
+            std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw);
+            msg_data->hotwords_embedding =
+                std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+          }
         }
       }
       if (jsonresult.contains("audio_fs")) {

+ 2 - 0
funasr/runtime/websocket/hotwords.txt

@@ -0,0 +1,2 @@
+阿里巴巴
+通义实验室