View source code

Merge branch 'main' into dev_wjm_infer

jmwang66 2 years ago
Parent
Commit
7375292887

BIN
docs/images/dingding.jpg


+ 1 - 1
egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py

@@ -17,7 +17,7 @@ inference_diar_pipline = pipeline(
     diar_model_config="sond.yaml",
     model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
     sv_model="damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch",
-    sv_model_revision="master",
+    sv_model_revision="v1.2.2",
 )
 
 # use audio_list as the input, where the first one is the record to be detected

+ 20 - 6
funasr/runtime/onnxruntime/src/offline-stream.cpp

@@ -1,11 +1,11 @@
 #include "precomp.h"
+#include <unistd.h>
 
 namespace funasr {
 OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
 {
     // VAD model
     if(model_path.find(VAD_DIR) != model_path.end()){
-        use_vad = true;
         string vad_model_path;
         string vad_cmvn_path;
         string vad_config_path;
@@ -16,8 +16,16 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
         }
         vad_cmvn_path = PathAppend(model_path.at(VAD_DIR), VAD_CMVN_NAME);
         vad_config_path = PathAppend(model_path.at(VAD_DIR), VAD_CONFIG_NAME);
-        vad_handle = make_unique<FsmnVad>();
-        vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
+        if (access(vad_model_path.c_str(), F_OK) != 0 ||
+            access(vad_cmvn_path.c_str(), F_OK) != 0 ||
+            access(vad_config_path.c_str(), F_OK) != 0)
+        {
+            LOG(INFO) << "VAD model file does not exist, skip loading the VAD model.";
+        } else {
+            vad_handle = make_unique<FsmnVad>();
+            vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
+            use_vad = true;
+        }
     }
 
     // AM model
@@ -39,7 +47,6 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
 
     // PUNC model
     if(model_path.find(PUNC_DIR) != model_path.end()){
-        use_punc = true;
         string punc_model_path;
         string punc_config_path;
     
@@ -49,8 +56,15 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
         }
         punc_config_path = PathAppend(model_path.at(PUNC_DIR), PUNC_CONFIG_NAME);
 
-        punc_handle = make_unique<CTTransformer>();
-        punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
+        if (access(punc_model_path.c_str(), F_OK) != 0 ||
+            access(punc_config_path.c_str(), F_OK) != 0)
+        {
+            LOG(INFO) << "PUNC model file does not exist, skip loading the PUNC model.";
+        } else {
+            punc_handle = make_unique<CTTransformer>();
+            punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
+            use_punc = true;
+        }
     }
 }
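The hunk above makes both optional models truly optional: the VAD/PUNC handles are only created, and `use_vad`/`use_punc` only set, after POSIX `access()` confirms every required file is present. Below is a minimal Python sketch of the same guard pattern; the file names follow the readme further down, and `init_vad` is a hypothetical stand-in for `FsmnVad::InitVad`, not FunASR's API.

```python
import os

def try_load_model(model_dir, required_files, loader):
    """Load a model only if every required file exists; otherwise skip it."""
    paths = [os.path.join(model_dir, name) for name in required_files]
    missing = [p for p in paths if not os.access(p, os.F_OK)]
    if missing:
        print("model files missing, skip loading:", missing)
        return None
    return loader(*paths)

def init_vad(model, cmvn, config):
    # hypothetical stand-in for FsmnVad::InitVad in the C++ code above
    return ("vad", model, cmvn, config)

vad_handle = try_load_model("/workspace/models/vad",
                            ["model.onnx", "vad.mvn", "vad.yaml"], init_vad)
use_vad = vad_handle is not None
print("use_vad =", use_vad)
```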
 

+ 17 - 6
funasr/runtime/python/websocket/wss_client_asr.py

@@ -71,6 +71,8 @@ print(args)
 from queue import Queue
 
 voices = Queue()
+offline_msg_done = False
+
 ibest_writer = None
 if args.output_dir is not None:
     writer = DatadirWriter(args.output_dir)
@@ -158,13 +160,20 @@ async def record_from_scp(chunk_begin, chunk_size):
                 message = json.dumps({"is_speaking": is_speaking})
                 #voices.put(message)
                 await websocket.send(message)
-            # print("data_chunk: ", len(data_chunk))
-            # print(voices.qsize())
+ 
             sleep_duration = 0.001 if args.send_without_sleep else 60 * args.chunk_size[1] / args.chunk_interval / 1000
             await asyncio.sleep(sleep_duration)
+    # once all audio has been sent, close the websocket
     while not voices.empty():
          await asyncio.sleep(1)
     await asyncio.sleep(3)
+    # in offline mode, wait until the final message has been received
+
+    if args.mode == "offline":
+        global offline_msg_done
+        while not offline_msg_done:
+            await asyncio.sleep(1)
+
     await websocket.close()
      
  
@@ -173,7 +182,7 @@ async def record_from_scp(chunk_begin, chunk_size):
  
              
 async def message(id):
-    global websocket,voices
+    global websocket, voices, offline_msg_done
     text_print = ""
     text_print_2pass_online = ""
     text_print_2pass_offline = ""
@@ -183,7 +192,6 @@ async def message(id):
             meg = await websocket.recv()
             meg = json.loads(meg)
             wav_name = meg.get("wav_name", "demo")
-            # print(wav_name)
             text = meg["text"]
             if ibest_writer is not None:
                 ibest_writer["text"][wav_name] = text
@@ -198,6 +206,7 @@ async def message(id):
                 text_print = text_print[-args.words_max_print:]
                 os.system('clear')
                 print("\rpid" + str(id) + ": " + text_print)
+                offline_msg_done = True
             else:
                 if meg["mode"] == "2pass-online":
                     text_print_2pass_online += "{}".format(text)
@@ -233,8 +242,10 @@ async def ws_client(id, chunk_begin, chunk_size):
   if args.audio_in is None:
        chunk_begin=0
        chunk_size=1
-  global websocket,voices
+  global websocket, voices, offline_msg_done
+
   for i in range(chunk_begin,chunk_begin+chunk_size):
+    offline_msg_done = False
     voices = Queue()
     if args.ssl == 1:
         ssl_context = ssl.SSLContext()
@@ -251,7 +262,7 @@ async def ws_client(id, chunk_begin, chunk_size):
         else:
             task = asyncio.create_task(record_microphone())
         #task2 = asyncio.create_task(ws_send())
-        task3 = asyncio.create_task(message(id))
+        task3 = asyncio.create_task(message(str(id) + "_" + str(i)))  # process id + file id
         await asyncio.gather(task, task3)
   exit(0)
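These edits fix a shutdown race in offline mode: the sender used to close the websocket after a fixed sleep, which could drop the connection before the final result arrived. The new global `offline_msg_done` is set by the receiving coroutine and polled by the sender before closing. Here is a runnable sketch of the same hand-off using `asyncio.Event` instead of a polled global (an alternative shape, not the script's actual code):

```python
import asyncio

async def main():
    done = asyncio.Event()  # plays the role of offline_msg_done

    async def receiver():
        await asyncio.sleep(0.1)   # stand-in for websocket.recv()
        print("final offline result received")
        done.set()                 # corresponds to offline_msg_done = True

    async def sender():
        print("all audio chunks sent")
        await done.wait()          # replaces: while not offline_msg_done: sleep(1)
        print("safe to close the websocket")

    await asyncio.gather(receiver(), sender())

asyncio.run(main())
```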
     

+ 4 - 4
funasr/runtime/websocket/CMakeLists.txt

@@ -56,8 +56,8 @@ add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
 # install openssl first apt-get install libssl-dev
 find_package(OpenSSL REQUIRED)
 
-add_executable(funasr-ws-server "funasr-ws-server.cpp" "websocket-server.cpp")
-add_executable(funasr-ws-client "funasr-ws-client.cpp")
+add_executable(funasr-wss-server "funasr-wss-server.cpp" "websocket-server.cpp")
+add_executable(funasr-wss-client "funasr-wss-client.cpp")
 
-target_link_libraries(funasr-ws-client PUBLIC funasr ssl crypto)
-target_link_libraries(funasr-ws-server PUBLIC funasr ssl crypto)
+target_link_libraries(funasr-wss-client PUBLIC funasr ssl crypto)
+target_link_libraries(funasr-wss-server PUBLIC funasr ssl crypto)

+ 26 - 14
funasr/runtime/websocket/funasr-ws-client.cpp → funasr/runtime/websocket/funasr-wss-client.cpp

@@ -5,7 +5,14 @@
 /* 2022-2023 by zhaomingwork */
 
 // client for websocket, support multiple threads
-// Usage: websocketclient server_ip port wav_path threads_num
+// ./funasr-wss-client --server-ip <string>
+//                     --port <string>
+//                     --wav-path <string>
+//                     [--thread-num <int>] 
+//                     [--is-ssl <int>]  [--]
+//                     [--version] [-h]
+// example:
+// ./funasr-wss-client --server-ip 127.0.0.1 --port 8889 --wav-path test.wav --thread-num 1 --is-ssl 0
 
 #define ASIO_STANDALONE 1
 #include <websocketpp/client.hpp>
@@ -55,7 +62,7 @@ context_ptr OnTlsInit(websocketpp::connection_hdl) {
             asio::ssl::context::no_sslv3 | asio::ssl::context::single_dh_use);
 
     } catch (std::exception& e) {
-        std::cout << e.what() << std::endl;
+        LOG(ERROR) << e.what();
     }
     return ctx;
 }
@@ -99,7 +106,16 @@ class WebsocketClient {
         const std::string& payload = msg->get_payload();
         switch (msg->get_opcode()) {
             case websocketpp::frame::opcode::text:
-                std::cout << "on_message = " << payload << std::endl;
+                total_num = total_num + 1;
+                LOG(INFO) << total_num << ", on_message = " << payload;
+                if ((total_num + 1) == wav_index)
+                {
+                    websocketpp::lib::error_code ec;
+                    m_client.close(m_hdl, websocketpp::close::status::going_away, "", ec);
+                    if (ec) {
+                        LOG(ERROR) << "Error closing connection " << ec.message();
+                    }
+                }
         }
     }
 
@@ -132,12 +148,8 @@ class WebsocketClient {
             }
             send_wav_data(wav_list[i], wav_ids[i]);
         }
-        WaitABit();
-		m_client.close(m_hdl,websocketpp::close::status::going_away, "", ec);
-        if (ec) {
-                std::cout << "> Error closing connection " << ec.message() << std::endl;
-            }
-        //send_wav_data();
+        WaitABit(); 
+
         asio_thread.join();
 
     }
@@ -206,7 +218,7 @@ class WebsocketClient {
                 }
             }
             if (wait) {
-                std::cout << "wait.." << m_open << std::endl;
+                LOG(INFO) << "wait.." << m_open;
                 WaitABit();
                 continue;
             }
@@ -236,7 +248,7 @@ class WebsocketClient {
             // send data to server
             m_client.send(m_hdl, iArray, len * sizeof(short),
                           websocketpp::frame::opcode::binary, ec);
-            std::cout << "sended data len=" << len * sizeof(short) << std::endl;
+            LOG(INFO) << "sent data len=" << len * sizeof(short);
             // The most likely error that we will get is that the connection is
             // not in the right state. Usually this means we tried to send a
             // message to a connection that was closed or in the process of
@@ -247,14 +259,13 @@ class WebsocketClient {
                                         "Send Error: " + ec.message());
               break;
             }
-
-            WaitABit();
+            // WaitABit();
         }
         nlohmann::json jsonresult;
         jsonresult["is_speaking"] = false;
         m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
                       ec);
-        WaitABit();
+        // WaitABit();
     }
     websocketpp::client<T> m_client;
 
@@ -263,6 +274,7 @@ class WebsocketClient {
     websocketpp::lib::mutex m_lock;
     bool m_open;
     bool m_done;
+    int total_num = 0;
 };
 
 int main(int argc, char* argv[]) {
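The client now counts incoming results (`total_num`) and closes the connection once every sent wav has been answered (checked against `wav_index`), instead of closing after a fixed `WaitABit()`. A small runnable sketch of that count-based shutdown; `expected`, `on_message`, and `close_connection` are illustrative names, not the client's API:

```python
expected = 3        # e.g. number of wavs sent (the C++ code compares against wav_index)
received = 0
closed = False

def close_connection():
    """Illustrative stand-in for m_client.close(m_hdl, going_away, ...)."""
    global closed
    closed = True

def on_message(payload):
    global received
    received += 1
    print(received, ", on_message =", payload)
    if received == expected:   # all expected results are in
        close_connection()

for i in range(expected):
    on_message("result %d" % i)
assert closed
```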

+ 20 - 19
funasr/runtime/websocket/funasr-ws-server.cpp → funasr/runtime/websocket/funasr-wss-server.cpp

@@ -5,7 +5,7 @@
 /* 2022-2023 by zhaomingwork */
 
 // io server
-// Usage:websocketmain  [--model_thread_num <int>] [--decoder_thread_num <int>]
+// Usage: funasr-wss-server  [--model_thread_num <int>] [--decoder_thread_num <int>]
 //                    [--io_thread_num <int>] [--port <int>] [--listen_ip
 //                    <string>] [--punc-quant <string>] [--punc-dir <string>]
 //                    [--vad-quant <string>] [--vad-dir <string>] [--quantize
@@ -15,44 +15,43 @@
 using namespace std;
 void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
               std::map<std::string, std::string>& model_path) {
-  if (value_arg.isSet()) {
     model_path.insert({key, value_arg.getValue()});
     LOG(INFO) << key << " : " << value_arg.getValue();
-  }
 }
 int main(int argc, char* argv[]) {
   try {
     google::InitGoogleLogging(argv[0]);
     FLAGS_logtostderr = true;
 
-    TCLAP::CmdLine cmd("websocketmain", ' ', "1.0");
+    TCLAP::CmdLine cmd("funasr-wss-server", ' ', "1.0");
     TCLAP::ValueArg<std::string> model_dir(
         "", MODEL_DIR,
-        "the asr model path, which contains model.onnx, config.yaml, am.mvn",
-        true, "", "string");
+        "default: /workspace/models/asr, the asr model path, which contains model.onnx, config.yaml, am.mvn",
+        false, "/workspace/models/asr", "string");
     TCLAP::ValueArg<std::string> quantize(
         "", QUANTIZE,
-        "false (Default), load the model of model.onnx in model_dir. If set "
-        "true, load the model of model_quant.onnx in model_dir",
-        false, "false", "string");
+        "true (Default), load the model of model_quant.onnx in model_dir. If set "
+        "false, load the model of model.onnx in model_dir",
+        false, "true", "string");
     TCLAP::ValueArg<std::string> vad_dir(
         "", VAD_DIR,
-        "the vad model path, which contains model.onnx, vad.yaml, vad.mvn",
-        false, "", "string");
+        "default: /workspace/models/vad, the vad model path, which contains model.onnx, vad.yaml, vad.mvn",
+        false, "/workspace/models/vad", "string");
     TCLAP::ValueArg<std::string> vad_quant(
         "", VAD_QUANT,
-        "false (Default), load the model of model.onnx in vad_dir. If set "
-        "true, load the model of model_quant.onnx in vad_dir",
-        false, "false", "string");
+        "true (Default), load the model of model_quant.onnx in vad_dir. If set "
+        "false, load the model of model.onnx in vad_dir",
+        false, "true", "string");
     TCLAP::ValueArg<std::string> punc_dir(
         "", PUNC_DIR,
-        "the punc model path, which contains model.onnx, punc.yaml", false, "",
+        "default: /workspace/models/punc, the punc model path, which contains model.onnx, punc.yaml", 
+        false, "/workspace/models/punc",
         "string");
     TCLAP::ValueArg<std::string> punc_quant(
         "", PUNC_QUANT,
-        "false (Default), load the model of model.onnx in punc_dir. If set "
-        "true, load the model of model_quant.onnx in punc_dir",
-        false, "false", "string");
+        "true (Default), load the model of model_quant.onnx in punc_dir. If set "
+        "false, load the model of model.onnx in punc_dir",
+        false, "true", "string");
 
     TCLAP::ValueArg<std::string> listen_ip("", "listen_ip", "listen_ip", false,
                                            "0.0.0.0", "string");
@@ -64,10 +63,12 @@ int main(int argc, char* argv[]) {
     TCLAP::ValueArg<int> model_thread_num("", "model_thread_num",
                                           "model_thread_num", false, 1, "int");
 
-    TCLAP::ValueArg<std::string> certfile("", "certfile", "certfile", false, "",
-                                          "string");
-    TCLAP::ValueArg<std::string> keyfile("", "keyfile", "keyfile", false, "",
-                                         "string");
+    TCLAP::ValueArg<std::string> certfile("", "certfile",
+        "default: ../../../ssl_key/server.crt, path of the certificate for a WSS connection. If it is empty, the server runs in WS mode.",
+        false, "../../../ssl_key/server.crt", "string");
+    TCLAP::ValueArg<std::string> keyfile("", "keyfile",
+        "default: ../../../ssl_key/server.key, path of the key file for a WSS connection",
+        false, "../../../ssl_key/server.key", "string");
 
     cmd.add(certfile);
     cmd.add(keyfile);
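The option changes above turn the required `--model-dir` into optional arguments with `/workspace/models/...` defaults, flip the quantize flags to default true, and give the SSL files defaults under `../../../ssl_key/`. A rough `argparse` analog of the new CLI shape (a sketch, not the server's real TCLAP setup):

```python
import argparse

parser = argparse.ArgumentParser(prog="funasr-wss-server")
parser.add_argument("--model-dir", default="/workspace/models/asr",
                    help="asr model path (model.onnx, config.yaml, am.mvn)")
parser.add_argument("--quantize", default="true",
                    help="true (default): load model_quant.onnx; false: load model.onnx")
parser.add_argument("--certfile", default="../../../ssl_key/server.crt",
                    help="certificate for WSS; empty string switches to WS mode")

args = parser.parse_args([])   # no flags given: the defaults apply
print(args.model_dir, args.quantize, args.certfile)
```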

+ 12 - 15
funasr/runtime/websocket/readme.md

@@ -51,7 +51,7 @@ make
 
 ```shell
 cd bin
-   ./funasr-ws-server  [--model_thread_num <int>] [--decoder_thread_num <int>]
+./funasr-wss-server  [--model_thread_num <int>] [--decoder_thread_num <int>]
                     [--io_thread_num <int>] [--port <int>] [--listen_ip
                     <string>] [--punc-quant <string>] [--punc-dir <string>]
                     [--vad-quant <string>] [--vad-dir <string>] [--quantize
@@ -59,19 +59,19 @@ cd bin
                     [--certfile <string>] [--] [--version] [-h]
 Where:
    --model-dir <string>
-     (required)  the asr model path, which contains model.onnx, config.yaml, am.mvn
+     default: /workspace/models/asr, the asr model path, which contains model.onnx, config.yaml, am.mvn
    --quantize <string>
-     false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
+     true (Default), load the model of model_quant.onnx in model_dir. If set false, load the model of model.onnx in model_dir
 
    --vad-dir <string>
-     the vad model path, which contains model.onnx, vad.yaml, vad.mvn
+     default: /workspace/models/vad, the vad model path, which contains model.onnx, vad.yaml, vad.mvn
    --vad-quant <string>
-     false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir
+     true (Default), load the model of model_quant.onnx in vad_dir. If set false, load the model of model.onnx in vad_dir
 
    --punc-dir <string>
-     the punc model path, which contains model.onnx, punc.yaml
+     default: /workspace/models/punc, the punc model path, which contains model.onnx, punc.yaml
    --punc-quant <string>
-     false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir
+     true (Default), load the model of model_quant.onnx in punc_dir. If set false, load the model of model.onnx in punc_dir
 
    --decoder_thread_num <int>
      number of threads for decoder, default:8
@@ -80,21 +80,18 @@ Where:
    --port <int>
      listen port, default:8889
    --certfile <string>
-     path of certficate for WSS connection. if it is empty, it will be in WS mode.
+     default: ../../../ssl_key/server.crt, path of the certificate for a WSS connection. If it is empty, the server runs in WS mode.
    --keyfile <string>
-     path of keyfile for WSS connection
+     default: ../../../ssl_key/server.key, path of the key file for a WSS connection
   
-   Required:  --model-dir <string>
-   If use vad, please add: --vad-dir <string>
-   If use punc, please add: --punc-dir <string>
 example:
-   funasr-ws-server --model-dir /FunASR/funasr/runtime/onnxruntime/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
+./funasr-wss-server --model-dir /FunASR/funasr/runtime/onnxruntime/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
 ```
 
 ## Run websocket client test
 
 ```shell
-./funasr-ws-client  --server-ip <string>
+./funasr-wss-client  --server-ip <string>
                     --port <string>
                     --wav-path <string>
                     [--thread-num <int>] 
@@ -119,7 +116,7 @@ Where:
      is-ssl is 1 means use wss connection, or use ws connection
 
 example:
-./funasr-ws-client --server-ip 127.0.0.1 --port 8889 --wav-path test.wav --thread-num 1 --is-ssl 0
+./funasr-wss-client --server-ip 127.0.0.1 --port 8889 --wav-path test.wav --thread-num 1 --is-ssl 1
 
 result json, example like:
 {"mode":"offline","text":"欢迎大家来体验达摩院推出的语音识别模型","wav_name":"wav2"}

+ 13 - 15
funasr/runtime/websocket/websocket-server.cpp

@@ -22,12 +22,11 @@ context_ptr WebSocketServer::on_tls_init(tls_mode mode,
                                          std::string& s_keyfile) {
   namespace asio = websocketpp::lib::asio;
 
-  std::cout << "on_tls_init called with hdl: " << hdl.lock().get() << std::endl;
-  std::cout << "using TLS mode: "
+  LOG(INFO) << "on_tls_init called with hdl: " << hdl.lock().get();
+  LOG(INFO) << "using TLS mode: "
             << (mode == MOZILLA_MODERN ? "Mozilla Modern"
-                                       : "Mozilla Intermediate")
-            << std::endl;
-
+                                       : "Mozilla Intermediate");
+
   context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
       asio::ssl::context::sslv23);
 
@@ -49,7 +48,7 @@ context_ptr WebSocketServer::on_tls_init(tls_mode mode,
     ctx->use_private_key_file(s_keyfile, asio::ssl::context::pem);
 
   } catch (std::exception& e) {
-    std::cout << "Exception: " << e.what() << std::endl;
+    LOG(ERROR) << "Exception: " << e.what();
   }
   return ctx;
 }
@@ -86,8 +85,7 @@ void WebSocketServer::do_decoder(const std::vector<char>& buffer,
                       ec);
       }
 
-      std::cout << "buffer.size=" << buffer.size()
-                << ",result json=" << jsonresult.dump() << std::endl;
+      LOG(INFO) << "buffer.size=" << buffer.size() << ",result json=" << jsonresult.dump();
       if (!isonline) {
         //  close the client if it is not online asr
         // server_->close(hdl, websocketpp::close::status::normal, "DONE", ec);
@@ -110,14 +108,14 @@ void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
   data_msg->samples = std::make_shared<std::vector<char>>();
   data_msg->msg = nlohmann::json::parse("{}");
   data_map.emplace(hdl, data_msg);
-  std::cout << "on_open, active connections: " << data_map.size() << std::endl;
+  LOG(INFO) << "on_open, active connections: " << data_map.size();
 }
 
 void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
   scoped_lock guard(m_lock);
   data_map.erase(hdl);  // remove data vector when  connection is closed
 
-  std::cout << "on_close, active connections: " << data_map.size() << std::endl;
+  LOG(INFO) << "on_close, active connections: " << data_map.size();
 }
 
 // remove closed connection
@@ -143,7 +141,7 @@ void WebSocketServer::check_and_clean_connection() {
   }
   for (auto hdl : to_remove) {
     data_map.erase(hdl);
-    std::cout << "remove one connection " << std::endl;
+    LOG(INFO) << "remove one connection";
   }
 }
 void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
@@ -161,7 +159,7 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
 
   lock.unlock();
   if (sample_data_p == nullptr) {
-    std::cout << "error when fetch sample data vector" << std::endl;
+    LOG(ERROR) << "error when fetching sample data vector";
     return;
   }
 
@@ -176,7 +174,7 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
 
       if (jsonresult["is_speaking"] == false ||
           jsonresult["is_finished"] == true) {
-        std::cout << "client done" << std::endl;
+        LOG(INFO) << "client done";
 
         if (isonline) {
           // do_close(ws);
@@ -225,9 +223,9 @@ void WebSocketServer::initAsr(std::map<std::string, std::string>& model_path,
     // init model with api
 
     asr_hanlde = FunOfflineInit(model_path, thread_num);
-    std::cout << "model ready" << std::endl;
+    LOG(INFO) << "model successfully initialized";
 
   } catch (const std::exception& e) {
-    std::cout << e.what() << std::endl;
+    LOG(ERROR) << e.what();
   }
 }
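This file's diff is a wholesale migration from `std::cout` to glog's `LOG(INFO)`/`LOG(ERROR)`, so messages carry severity and timestamps and can be filtered or redirected. The analogous move in Python, shown as a tiny runnable sketch, is replacing bare `print()` with the `logging` module:

```python
import logging

logging.basicConfig(level=logging.INFO,
                    format="%(levelname)s %(asctime)s %(message)s")
log = logging.getLogger("websocket-server")

log.info("on_open, active connections: %d", 1)      # was: print(...)
log.error("Error closing connection: %s", "reset")  # errors now stand out
```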

+ 6 - 0
tests/test_asr_inference_pipeline.py

@@ -87,6 +87,7 @@ class TestParaformerInferencePipelines(unittest.TestCase):
         rec_result = inference_pipeline(
             audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_hotword.wav')
         logger.info("asr inference result: {0}".format(rec_result))
+        assert rec_result["text"] == "国务院发展研究中心市场经济研究所副所长邓郁松认为"
 
     def test_paraformer_large_aishell1(self):
         inference_pipeline = pipeline(
@@ -95,6 +96,7 @@ class TestParaformerInferencePipelines(unittest.TestCase):
         rec_result = inference_pipeline(
             audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
         logger.info("asr inference result: {0}".format(rec_result))
+        assert rec_result["text"] == "欢迎大家来体验达摩院推出的语音识别模型"
 
     def test_paraformer_large_aishell2(self):
         inference_pipeline = pipeline(
@@ -103,6 +105,7 @@ class TestParaformerInferencePipelines(unittest.TestCase):
         rec_result = inference_pipeline(
             audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
         logger.info("asr inference result: {0}".format(rec_result))
+        assert rec_result["text"] == "欢迎大家来体验达摩院推出的语音识别模型"
 
     def test_paraformer_large_common(self):
         inference_pipeline = pipeline(
@@ -111,6 +114,7 @@ class TestParaformerInferencePipelines(unittest.TestCase):
         rec_result = inference_pipeline(
             audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
         logger.info("asr inference result: {0}".format(rec_result))
+        assert rec_result["text"] == "欢迎大家来体验达摩院推出的语音识别模型"
 
     def test_paraformer_large_online_common(self):
         inference_pipeline = pipeline(
@@ -119,6 +123,7 @@ class TestParaformerInferencePipelines(unittest.TestCase):
         rec_result = inference_pipeline(
             audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
         logger.info("asr inference result: {0}".format(rec_result))
+        assert rec_result["text"] == "欢迎大 家来 体验达 摩院推 出的 语音识 别模 型"
 
     def test_paraformer_online_common(self):
         inference_pipeline = pipeline(
@@ -127,6 +132,7 @@ class TestParaformerInferencePipelines(unittest.TestCase):
         rec_result = inference_pipeline(
             audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
         logger.info("asr inference result: {0}".format(rec_result))
+        assert rec_result["text"] == "欢迎 大家来 体验达 摩院推 出的 语音识 别模 型"
 
     def test_paraformer_tiny_commandword(self):
         inference_pipeline = pipeline(

+ 1 - 0
tests/test_asr_vad_punc_inference_pipeline.py

@@ -26,6 +26,7 @@ class TestParaformerInferencePipelines(unittest.TestCase):
         rec_result = inference_pipeline(
             audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
         logger.info("asr_vad_punc inference result: {0}".format(rec_result))
+        assert rec_result["text"] == "欢迎大家来体验达摩院推出的语音识别模型。"
 
 
 if __name__ == '__main__':

+ 3 - 4
tests/test_sv_inference_pipeline.py

@@ -24,16 +24,15 @@ class TestXVectorInferencePipelines(unittest.TestCase):
         rec_result = inference_sv_pipline(audio_in=(
             'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav',
             'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_same.wav'))
-        assert abs(rec_result["scores"][0] - 0.85) < 0.1 and abs(rec_result["scores"][1] - 0.14) < 0.1
+        assert abs(rec_result["scores"][0]-0.85) < 0.1 and abs(rec_result["scores"][1]-0.14) < 0.1
         logger.info(f"Similarity {rec_result['scores']}")
-
+    
         # different speaker
         rec_result = inference_sv_pipline(audio_in=(
             'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav',
             'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav'))
-        assert abs(rec_result["scores"][0] - 0.0) < 0.1 and abs(rec_result["scores"][1] - 1.0) < 0.1
+        assert abs(rec_result["scores"][0]-0.0) < 0.1 and abs(rec_result["scores"][1]-1.0) < 0.1
         logger.info(f"Similarity {rec_result['scores']}")
 
-
 if __name__ == '__main__':
     unittest.main()