Browse Source

[feature] support a 2-pass gRPC C++ server and Python client; the client can switch the decoding mode between offline, online, and 2-pass

boji123 2 years ago
parent
commit
a8f9253214

+ 3 - 17
funasr/runtime/grpc/CMakeLists.txt

@@ -1,21 +1,7 @@
-# Copyright 2018 gRPC authors.
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+# Reserved. MIT License  (https://opensource.org/licenses/MIT)
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-# cmake build file for C++ paraformer example.
-# Assumes protobuf and gRPC have been installed using cmake.
-# See cmake_externalproject/CMakeLists.txt for all-in-one cmake build
-# that automatically builds all the dependencies before building paraformer.
+# 2023 by burkliu(刘柏基) liubaiji@xverse.cn
 
 cmake_minimum_required(VERSION 3.10)
 

+ 3 - 1
funasr/runtime/grpc/build.sh

@@ -4,7 +4,9 @@ rm build -rf
 mkdir -p build
 cd build
 
-cmake -DCMAKE_BUILD_TYPE=release ../ \
+mode=debug #[debug|release]
+
+cmake -DCMAKE_BUILD_TYPE=$mode ../ \
   -DONNXRUNTIME_DIR=/cfs/user/burkliu/work2023/FunASR/funasr/runtime/onnxruntime/onnxruntime-linux-x64-1.14.0 \
   -DFFMPEG_DIR=/cfs/user/burkliu/work2023/FunASR/funasr/runtime/onnxruntime/ffmpeg-N-111383-g20b8688092-linux64-gpl-shared
 cmake --build . -j 4

+ 183 - 79
funasr/runtime/grpc/paraformer-server.cc

@@ -1,93 +1,192 @@
-#include "paraformer-server.h"
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License  (https://opensource.org/licenses/MIT)
+ */
+/* 2023 by burkliu(刘柏基) liubaiji@xverse.cn */
 
-using paraformer::Request;
-using paraformer::Response;
-using paraformer::ASR;
+#include "paraformer-server.h"
 
 GrpcEngine::GrpcEngine(
   grpc::ServerReaderWriter<Response, Request>* stream,
   std::shared_ptr<FUNASR_HANDLE> asr_handler)
   : stream_(std::move(stream)),
-    asr_handler_(std::move(asr_handler)) {}
+    asr_handler_(std::move(asr_handler)) {
 
-void GrpcEngine::operator()() {
-  Request request;
-  while (stream_->Read(&request)) {
-    Response respond;
-    respond.set_user(request.user());
-    respond.set_language(request.language());
-
-    if (request.isend()) {
-      std::cout << "asr end" << std::endl;
-      respond.set_sentence(R"({"success": true, "detail": "asr end"})");
-      respond.set_action("terminate");
-      stream_->Write(respond);
-    } else if (request.speaking()) {
-      if (request.audio_data().size() > 0) {
-        auto& buf = client_buffers[request.user()];
-        buf.insert(buf.end(), request.audio_data().begin(), request.audio_data().end());
+  request_ = std::make_shared<Request>();
+}
+
+void GrpcEngine::DecodeThreadFunc() {
+  FUNASR_HANDLE tpass_online_handler = FunTpassOnlineInit(*asr_handler_, chunk_size_);
+  int step = (sampling_rate_ * step_duration_ms_ / 1000) * 2; // int16 = 2bytes;
+  std::vector<std::vector<std::string>> punc_cache(2);
+
+  bool is_final = false;
+  std::string online_result = "";
+  std::string tpass_result = "";
+
+  LOG(INFO) << "Decoder init, start decoding loop with mode";
+
+  while (true) {
+    if (audio_buffer_.length() > step || is_end_) {
+      if (audio_buffer_.length() <= step && is_end_) {
+        is_final = true;
+        step = audio_buffer_.length();
       }
-      respond.set_sentence(R"({"success": true, "detail": "speaking"})");
-      respond.set_action("speaking");
-      stream_->Write(respond);
-    } else {
-      if (client_buffers.count(request.user()) == 0 && request.audio_data().size() == 0) {
-        respond.set_sentence(R"({"success": true, "detail": "waiting_for_voice"})");
-        respond.set_action("waiting");
-        stream_->Write(respond);
-      } else {
-        auto begin_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
-        if (request.audio_data().size() > 0) {
-          auto& buf = client_buffers[request.user()];
-          buf.insert(buf.end(), request.audio_data().begin(), request.audio_data().end());
+
+      FUNASR_RESULT result = FunTpassInferBuffer(*asr_handler_,
+                                                 tpass_online_handler,
+                                                 audio_buffer_.c_str(),
+                                                 step,
+                                                 punc_cache,
+                                                 is_final,
+                                                 sampling_rate_,
+                                                 encoding_,
+                                                 mode_);
+      audio_buffer_ = audio_buffer_.substr(step);
+
+      if (result) {
+        std::string online_message = FunASRGetResult(result, 0);
+        online_result += online_message;
+        if(online_message != ""){
+          Response response;
+          response.set_mode(DecodeMode::online);
+          response.set_text(online_message);
+          response.set_is_final(is_final);
+          stream_->Write(response);
+          LOG(INFO) << "send online results: " << online_message;
         }
-        std::string tmp_data = this->client_buffers[request.user()];
-
-        int data_len_int = tmp_data.length();
-        std::string data_len = std::to_string(data_len_int);
-        std::stringstream ss;
-        ss << R"({"success": true, "detail": "decoding data: )" << data_len << R"( bytes")"  << R"("})";
-
-        respond.set_sentence(ss.str());
-        respond.set_action("decoding");
-        stream_->Write(respond);
-
-        // start recoginize
-        std::string asr_result;
-        if (tmp_data.length() < 800) { //min input_len for asr model
-          asr_result = "";
-          std::cout << "error: data_is_not_long_enough" << std::endl;
-        } else {
-          FUNASR_RESULT result = FunOfflineInferBuffer(*asr_handler_, tmp_data.c_str(), data_len_int, RASR_NONE, NULL, 16000);
-          asr_result = ((FUNASR_RECOG_RESULT*) result)->msg;
+        std::string tpass_message = FunASRGetTpassResult(result, 0);
+        tpass_result += tpass_message;
+        if(tpass_message != ""){
+          Response response;
+          response.set_mode(DecodeMode::two_pass);
+          response.set_text(tpass_message);
+          response.set_is_final(is_final);
+          stream_->Write(response);
+          LOG(INFO) << "send offline results: " << tpass_message;
         }
+        FunASRFreeResult(result);
+      }
+
+      if (is_final) {
+        FunTpassOnlineUninit(tpass_online_handler);
+        break;
+      }
+    }
+    sleep(0.001);
+  }
+}
 
-        auto end_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
-        std::string delay_str = std::to_string(end_time - begin_time);
-        std::cout << "user: " << request.user() << " , delay(ms): " << delay_str << ", text: " << asr_result << std::endl;
-        std::stringstream ss2;
-        ss2 << R"({"success": true, "detail": "finish_sentence","server_delay_ms":)" << delay_str << R"(,"text":")" << asr_result << R"("})";
+void GrpcEngine::OnSpeechStart() {
+  if (request_->chunk_size_size() == 3) {
+    for (int i = 0; i < 3; i++) {
+      chunk_size_[i] = int(request_->chunk_size(i));
+    }
+  }
+  std::string chunk_size_str;
+  for (int i = 0; i < 3; i++) {
+    chunk_size_str = " " + chunk_size_[i];
+  }
+  LOG(INFO) << "chunk_size is" << chunk_size_str;
 
-        respond.set_sentence(ss2.str());
-        respond.set_action("finish");
-        stream_->Write(respond);
+  if (request_->sampling_rate() != 0) {
+    sampling_rate_ = request_->sampling_rate();
+  }
+  LOG(INFO) << "sampling_rate is " << sampling_rate_;
+
+  switch(request_->wav_format()) {
+    case WavFormat::pcm: encoding_ = "pcm";
+  }
+  LOG(INFO) << "encoding is " << encoding_;
+
+  std::string mode_str;
+  LOG(INFO) << request_->mode() << DecodeMode::offline << DecodeMode::online << DecodeMode::two_pass;
+  switch(request_->mode()) {
+    case DecodeMode::offline:
+      mode_ = ASR_OFFLINE;
+      mode_str = "offline";
+      break;
+    case DecodeMode::online:
+      mode_ = ASR_ONLINE;
+      mode_str = "online";
+      break;
+    case DecodeMode::two_pass:
+      mode_ = ASR_TWO_PASS;
+      mode_str = "two_pass";
+      break;
+  }
+  LOG(INFO) << "decode mode is " << mode_str;
+  
+  decode_thread_ = std::make_shared<std::thread>(&GrpcEngine::DecodeThreadFunc, this);
+  is_start_ = true;
+}
+
+void GrpcEngine::OnSpeechData() {
+  audio_buffer_ += request_->audio_data();
+}
+
+void GrpcEngine::OnSpeechEnd() {
+  is_end_ = true;
+  LOG(INFO) << "Read all pcm data, wait for decoding thread";
+  if (decode_thread_ != nullptr) {
+    decode_thread_->join();
+  }
+}
+
+void GrpcEngine::operator()() {
+  try {
+    LOG(INFO) << "start engine main loop";
+    while (stream_->Read(request_.get())) {
+      LOG(INFO) << "receive data";
+      if (!is_start_) {
+        OnSpeechStart();
+      }
+      OnSpeechData();
+      if (request_->is_final()) {
+        OnSpeechEnd();
+        break;
       }
     }
+    LOG(INFO) << "Connect finish";
+  } catch (std::exception const& e) {
+    LOG(ERROR) << e.what();
   }
 }
 
-GrpcService::GrpcService(std::map<std::string, std::string>& config, int num_thread)
+GrpcService::GrpcService(std::map<std::string, std::string>& config, int onnx_thread)
   : config_(config) {
 
-  asr_handler_ = std::make_shared<FUNASR_HANDLE>(std::move(FunOfflineInit(config_, num_thread)));
-  std::cout << "GrpcService model loades" << std::endl;
+  asr_handler_ = std::make_shared<FUNASR_HANDLE>(std::move(FunTpassInit(config_, onnx_thread)));
+  LOG(INFO) << "GrpcService model loaded";
+
+  std::vector<int> chunk_size = {5, 10, 5};
+  FUNASR_HANDLE tmp_online_handler = FunTpassOnlineInit(*asr_handler_, chunk_size);
+  int sampling_rate = 16000;
+  int buffer_len = sampling_rate * 1;
+  std::string tmp_data(buffer_len, '0');
+  std::vector<std::vector<std::string>> punc_cache(2);
+  bool is_final = true;
+  std::string encoding = "pcm";
+  FUNASR_RESULT result = FunTpassInferBuffer(*asr_handler_,
+                                             tmp_online_handler,
+                                             tmp_data.c_str(),
+                                             buffer_len,
+                                             punc_cache,
+                                             is_final,
+                                             buffer_len,
+                                             encoding,
+                                             ASR_TWO_PASS);
+  if (result) {
+      FunASRFreeResult(result);
+  }
+  FunTpassOnlineUninit(tmp_online_handler);
+  LOG(INFO) << "GrpcService model warmup";
 }
 
 grpc::Status GrpcService::Recognize(
   grpc::ServerContext* context,
   grpc::ServerReaderWriter<Response, Request>* stream) {
-
-  LOG(INFO) << "Get Recognize request" << std::endl;
+  LOG(INFO) << "Get Recognize request";
   GrpcEngine engine(
     stream,
     asr_handler_
@@ -106,29 +205,34 @@ void GetValue(TCLAP::ValueArg<std::string>& value_arg, std::string key, std::map
 }
 
 int main(int argc, char* argv[]) {
-  google::InitGoogleLogging(argv[0]);
   FLAGS_logtostderr = true;
+  google::InitGoogleLogging(argv[0]);
 
-  TCLAP::CmdLine cmd("paraformer-server", ' ', "1.0");
-  TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the asr model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
-  TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
-  TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
-  TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string");
-  TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
-  TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string");
+  TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0");
+  TCLAP::ValueArg<std::string>  offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
+  TCLAP::ValueArg<std::string>  online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
+  TCLAP::ValueArg<std::string>  quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
+  TCLAP::ValueArg<std::string>  vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
+  TCLAP::ValueArg<std::string>  vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
+  TCLAP::ValueArg<std::string>  punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
+  TCLAP::ValueArg<std::string>  punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
+  TCLAP::ValueArg<std::int32_t>  onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
   TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string");
 
-  cmd.add(model_dir);
+  cmd.add(offline_model_dir);
+  cmd.add(online_model_dir);
   cmd.add(quantize);
   cmd.add(vad_dir);
   cmd.add(vad_quant);
   cmd.add(punc_dir);
   cmd.add(punc_quant);
+  cmd.add(onnx_thread);
   cmd.add(port_id);
   cmd.parse(argc, argv);
 
   std::map<std::string, std::string> config;
-  GetValue(model_dir, MODEL_DIR, config);
+  GetValue(offline_model_dir, OFFLINE_MODEL_DIR, config);
+  GetValue(online_model_dir, ONLINE_MODEL_DIR, config);
   GetValue(quantize, QUANTIZE, config);
   GetValue(vad_dir, VAD_DIR, config);
   GetValue(vad_quant, VAD_QUANT, config);
@@ -140,18 +244,18 @@ int main(int argc, char* argv[]) {
   try {
     port = config.at(PORT_ID);
   } catch(std::exception const &e) {
-    std::cout << ("Error when read port.") << std::endl;
+    LOG(INFO) << ("Error when read port.");
     exit(0);
   }
   std::string server_address;
   server_address = "0.0.0.0:" + port;
-  GrpcService service(config, 1);
+  GrpcService service(config, onnx_thread);
 
   grpc::ServerBuilder builder;
   builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
   builder.RegisterService(&service);
   std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
-  std::cout << "Server listening on " << server_address << std::endl;
+  LOG(INFO) << "Server listening on " << server_address;
   server->Wait();
 
   return 0;

+ 27 - 15
funasr/runtime/grpc/paraformer-server.h

@@ -1,26 +1,22 @@
-#include <algorithm>
-#include <chrono>
-#include <cmath>
-#include <iostream>
-#include <sstream>
-#include <memory>
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License  (https://opensource.org/licenses/MIT)
+ */
+/* 2023 by burkliu(刘柏基) liubaiji@xverse.cn */
+
 #include <string>
-#include <unordered_map>
-#include <chrono>
 #include <thread>
+#include <unistd.h>
 
-#include <grpc/grpc.h>
-#include <grpcpp/server.h>
-#include <grpcpp/server_builder.h>
-#include <grpcpp/server_context.h>
-#include <grpcpp/security/server_credentials.h>
-
+#include "grpcpp/server_builder.h"
 #include "paraformer.grpc.pb.h"
 #include "funasrruntime.h"
 #include "tclap/CmdLine.h"
 #include "com-define.h"
 #include "glog/logging.h"
 
+using paraformer::WavFormat;
+using paraformer::DecodeMode;
 using paraformer::Request;
 using paraformer::Response;
 using paraformer::ASR;
@@ -37,9 +33,25 @@ class GrpcEngine {
   void operator()();
 
  private:
+  void DecodeThreadFunc();
+  void OnSpeechStart();
+  void OnSpeechData();
+  void OnSpeechEnd();
+
   grpc::ServerReaderWriter<Response, Request>* stream_;
+  std::shared_ptr<Request> request_;
+  std::shared_ptr<Response> response_;
   std::shared_ptr<FUNASR_HANDLE> asr_handler_;
-  std::unordered_map<std::string, std::string> client_buffers;
+  std::string audio_buffer_;
+  std::shared_ptr<std::thread> decode_thread_ = nullptr;
+  bool is_start_ = false;
+  bool is_end_ = false;
+
+  std::vector<int> chunk_size_ = {5, 10, 5};
+  int sampling_rate_ = 16000;
+  std::string encoding_;
+  ASR_TYPE mode_ = ASR_TWO_PASS;
+  int step_duration_ms_ = 100;
 };
 
 class GrpcService final : public ASR::Service {

+ 12 - 0
funasr/runtime/grpc/run_server.sh

@@ -0,0 +1,12 @@
+#!/bin/bash
+
+./build/bin/paraformer-server \
+  --port-id 10100 \
+  --offline-model-dir /cfs/user/burkliu/data/funasr_models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
+  --online-model-dir /cfs/user/burkliu/data/funasr_models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online \
+  --quantize true \
+  --vad-dir /cfs/user/burkliu/data/funasr_models/damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
+  --vad-quant true \
+  --punc-dir /cfs/user/burkliu/data/funasr_models/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727 \
+  --punc-quant true \
+  2>&1

+ 0 - 17
funasr/runtime/python/grpc/grpc_client.py

@@ -1,17 +0,0 @@
-import queue
-import paraformer_pb2
-
-def transcribe_audio_bytes(stub, chunk, user='zksz', language='zh-CN', speaking = True, isEnd = False):
-    req = paraformer_pb2.Request()
-    if chunk is not None:
-        req.audio_data = chunk
-    req.user = user
-    req.language = language
-    req.speaking = speaking
-    req.isEnd = isEnd
-    my_queue = queue.SimpleQueue()
-    my_queue.put(req) 
-    return  stub.Recognize(iter(my_queue.get, None))
-
-
-

+ 80 - 58
funasr/runtime/python/grpc/grpc_main_client.py

@@ -1,62 +1,84 @@
-import grpc
-import json
-import time
-import asyncio
-import soundfile as sf
+import logging
 import argparse
+import soundfile as sf
+import asyncio
+import time
 
-from grpc_client import transcribe_audio_bytes
-from paraformer_pb2_grpc import ASRStub
-
-# send the audio data once
-async def grpc_rec(wav_scp, grpc_uri, asr_user, language):
-    with grpc.insecure_channel(grpc_uri) as channel:
-        stub = ASRStub(channel)
-        for line in wav_scp:
-            wav_file = line.split()[1]
-            wav, _ = sf.read(wav_file, dtype='int16')
-            
-            b = time.time()
-            response = transcribe_audio_bytes(stub, wav.tobytes(), user=asr_user, language=language, speaking=False, isEnd=False)
-            resp = response.next()
-            text = ''
-            if 'decoding' == resp.action:
-                resp = response.next()
-                if 'finish' == resp.action:
-                    text = json.loads(resp.sentence)['text']
-            response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking=False, isEnd=True)
-            res= {'text': text, 'time': time.time() - b}
-            print(res)
-
-async def test(args):
-    wav_scp = open(args.wav_scp, "r").readlines()
-    uri = '{}:{}'.format(args.host, args.port)
-    res = await grpc_rec(wav_scp, uri, args.user_allowed, language = 'zh-CN')
+import grpc
+import paraformer_pb2_grpc
+from paraformer_pb2 import Request, WavFormat, DecodeMode
 
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host",
-                        type=str,
-                        default="127.0.0.1",
-                        required=False,
-                        help="grpc server host ip")
-    parser.add_argument("--port",
-                        type=int,
-                        default=10108,
-                        required=False,
-                        help="grpc server port")              
-    parser.add_argument("--user_allowed",
-                        type=str,
-                        default="project1_user1",
-                        help="allowed user for grpc client")
-    parser.add_argument("--sample_rate",
-                        type=int,
-                        default=16000,
-                        help="audio sample_rate from client") 
-    parser.add_argument("--wav_scp",
-                        type=str,
-                        required=True,
-                        help="audio wav scp")                    
-    args = parser.parse_args()
+class GrpcClient:
+  def __init__(self, wav_path, uri, mode):
+    self.wav, self.sr = sf.read(wav_path, dtype='int16')
+    self.wav_format = WavFormat.pcm
+    self.audio_chunk_duration = 1000 # ms
+    self.audio_chunk_size = int(self.sr * self.audio_chunk_duration / 1000)
+    self.send_interval = 100 # ms
+
+    # connect to grpc server
+    channel = grpc.insecure_channel(uri)
+    self.stub = paraformer_pb2_grpc.ASRStub(channel)
     
-    asyncio.run(test(args))
+    # start request
+    for respond in self.stub.Recognize(self.request_iterator(
+      audio_chunk_size=self.audio_chunk_size, mode = mode)):
+
+      logging.info("[receive] mode {}, text {}, is final {}".format(
+        DecodeMode.Name(respond.mode), respond.text, respond.is_final))
+
+  def request_iterator(self, audio_chunk_size, mode = DecodeMode.two_pass):
+    is_first_pack = True
+    is_final = False
+    for start in range(0, len(self.wav), audio_chunk_size):
+      request = Request()
+      audio_chunk = self.wav[start:start+audio_chunk_size]
+
+      if is_first_pack:
+        is_first_pack = False
+        request.sampling_rate = 16000
+        request.mode = mode
+        request.wav_format = self.wav_format
+        if request.mode == DecodeMode.two_pass or request.mode == DecodeMode.online:
+          request.chunk_size.extend([5, 10, 5])
+
+      if start + audio_chunk_size >= len(self.wav):
+        is_final = True
+      request.is_final = is_final
+      request.audio_data = audio_chunk.tobytes()
+      logging.info("[request] audio_data len {}, is final {}".format(
+        len(request.audio_data), request.is_final)) # int16 = 2bytes
+      time.sleep(self.send_interval/1000)
+      yield request
+
+if __name__ == '__main__':
+  logging.basicConfig(filename="", format="%(asctime)s %(message)s", level=logging.INFO)
+  parser = argparse.ArgumentParser()
+  parser.add_argument("--host",
+                      type=str,
+                      default="127.0.0.1",
+                      required=False,
+                      help="grpc server host ip")
+  parser.add_argument("--port",
+                      type=int,
+                      default=10100,
+                      required=False,
+                      help="grpc server port")
+  parser.add_argument("--sample_rate",
+                      type=int,
+                      default=16000,
+                      help="audio sample_rate from client")
+  parser.add_argument("--wav_path",
+                      type=str,
+                      required=True,
+                      help="audio wav path")
+  args = parser.parse_args()
+
+  for mode in [DecodeMode.offline, DecodeMode.online, DecodeMode.two_pass]:
+    mode_name = DecodeMode.Name(mode)
+    logging.info("[request] start requesting with mode {}".format(mode_name))
+
+    st = time.time()
+    uri = '{}:{}'.format(args.host, args.port)
+    client = GrpcClient(args.wav_path, uri, mode)
+    logging.info("mode {} time pass: {}".format(mode_name, time.time() - st))

+ 0 - 112
funasr/runtime/python/grpc/grpc_main_client_mic.py

@@ -1,112 +0,0 @@
-import pyaudio
-import grpc
-import json
-import webrtcvad
-import time
-import asyncio
-import argparse
-
-from grpc_client import transcribe_audio_bytes
-from paraformer_pb2_grpc import ASRStub
-
-async def deal_chunk(sig_mic):
-    global stub,SPEAKING,asr_user,language,sample_rate
-    if vad.is_speech(sig_mic, sample_rate): #speaking
-        SPEAKING = True
-        response = transcribe_audio_bytes(stub, sig_mic, user=asr_user, language=language, speaking = True, isEnd = False) #speaking, send audio to server.
-    else: #silence   
-        begin_time = 0
-        if SPEAKING: #means we have some audio recorded, send recognize order to server.
-            SPEAKING = False
-            begin_time = int(round(time.time() * 1000))            
-            response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking = False, isEnd = False) #speak end, call server for recognize one sentence
-            resp = response.next()           
-            if "decoding" == resp.action:   
-                resp = response.next() #TODO, blocking operation may leads to miss some audio clips. C++ multi-threading is preferred.
-                if "finish" == resp.action:        
-                    end_time = int(round(time.time() * 1000))
-                    print (json.loads(resp.sentence))
-                    print ("delay in ms: %d " % (end_time - begin_time))
-                else:
-                    pass
-        
-
-async def record(host,port,sample_rate,mic_chunk,record_seconds,asr_user,language):
-    with grpc.insecure_channel('{}:{}'.format(host, port)) as channel:
-        global stub
-        stub = ASRStub(channel)
-        for i in range(0, int(sample_rate / mic_chunk * record_seconds)):
-     
-            sig_mic = stream.read(mic_chunk,exception_on_overflow = False) 
-            await asyncio.create_task(deal_chunk(sig_mic))
-
-        #end grpc
-        response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking = False, isEnd = True)
-        print (response.next().action)
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host",
-                        type=str,
-                        default="127.0.0.1",
-                        required=True,
-                        help="grpc server host ip")
-                        
-    parser.add_argument("--port",
-                        type=int,
-                        default=10095,
-                        required=True,
-                        help="grpc server port")              
-                        
-    parser.add_argument("--user_allowed",
-                        type=str,
-                        default="project1_user1",
-                        help="allowed user for grpc client")
-                        
-    parser.add_argument("--sample_rate",
-                        type=int,
-                        default=16000,
-                        help="audio sample_rate from client")    
-
-    parser.add_argument("--mic_chunk",
-                        type=int,
-                        default=160,
-                        help="chunk size for mic")  
-
-    parser.add_argument("--record_seconds",
-                        type=int,
-                        default=120,
-                        help="run specified seconds then exit ")                       
-
-    args = parser.parse_args()
-    
-
-    SPEAKING = False
-    asr_user = args.user_allowed
-    sample_rate = args.sample_rate
-    language = 'zh-CN'  
-    
-
-    vad = webrtcvad.Vad()
-    vad.set_mode(1)
-
-    FORMAT = pyaudio.paInt16
-    CHANNELS = 1
-    p = pyaudio.PyAudio()
-    
-    stream = p.open(format=FORMAT,
-                channels=CHANNELS,
-                rate=args.sample_rate,
-                input=True,
-                frames_per_buffer=args.mic_chunk)
-                
-    print("* recording")
-    asyncio.run(record(args.host,args.port,args.sample_rate,args.mic_chunk,args.record_seconds,args.user_allowed,language))
-    stream.stop_stream()
-    stream.close()
-    p.terminate()
-    print("recording stop")
-
-    
-

+ 0 - 68
funasr/runtime/python/grpc/grpc_main_server.py

@@ -1,68 +0,0 @@
-import grpc
-from concurrent import futures
-import argparse
-
-import paraformer_pb2_grpc
-from grpc_server import ASRServicer
-
-def serve(args):
-      server = grpc.server(futures.ThreadPoolExecutor(max_workers=10),
-                        # interceptors=(AuthInterceptor('Bearer mysecrettoken'),)
-                           )
-      paraformer_pb2_grpc.add_ASRServicer_to_server(
-          ASRServicer(args.user_allowed, args.model, args.sample_rate, args.backend, args.onnx_dir, vad_model=args.vad_model, punc_model=args.punc_model), server)
-      port = "[::]:" + str(args.port)
-      server.add_insecure_port(port)
-      server.start()
-      print("grpc server started!")
-      server.wait_for_termination()
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--port",
-                        type=int,
-                        default=10095,
-                        required=True,
-                        help="grpc server port")
-                        
-    parser.add_argument("--user_allowed",
-                        type=str,
-                        default="project1_user1|project1_user2|project2_user3",
-                        help="allowed user for grpc client")
-                        
-    parser.add_argument("--model",
-                        type=str,
-                        default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
-                        help="model from modelscope")
-    parser.add_argument("--vad_model",
-                        type=str,
-                        default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                        help="model from modelscope")
-    
-    parser.add_argument("--punc_model",
-                        type=str,
-                        default="",
-                        help="model from modelscope")
-    
-    parser.add_argument("--sample_rate",
-                        type=int,
-                        default=16000,
-                        help="audio sample_rate from client")
-
-    parser.add_argument("--backend",
-                        type=str,
-                        default="pipeline",
-                        choices=("pipeline", "onnxruntime"),
-                        help="backend, optional modelscope pipeline or onnxruntime")
-
-    parser.add_argument("--onnx_dir",
-                        type=str,
-                        default="/nfs/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
-                        help="onnx model dir")
-    
-                        
-
-
-    args = parser.parse_args()
-
-    serve(args)

+ 0 - 132
funasr/runtime/python/grpc/grpc_server.py

@@ -1,132 +0,0 @@
-from concurrent import futures
-import grpc
-import json
-import time
-
-import paraformer_pb2_grpc
-from paraformer_pb2 import Response
-
-
-class ASRServicer(paraformer_pb2_grpc.ASRServicer):
-    def __init__(self, user_allowed, model, sample_rate, backend, onnx_dir, vad_model='', punc_model=''):
-        print("ASRServicer init")
-        self.backend = backend
-        self.init_flag = 0
-        self.client_buffers = {}
-        self.client_transcription = {}
-        self.auth_user = user_allowed.split("|")
-        if self.backend == "pipeline":
-            try:
-                from modelscope.pipelines import pipeline
-                from modelscope.utils.constant import Tasks
-            except ImportError:
-                raise ImportError(f"Please install modelscope")
-            self.inference_16k_pipeline = pipeline(task=Tasks.auto_speech_recognition, model=model, vad_model=vad_model, punc_model=punc_model)
-        elif self.backend == "onnxruntime":
-            try:
-                from funasr_onnx import Paraformer
-            except ImportError:
-                raise ImportError(f"Please install onnxruntime environment")
-            self.inference_16k_pipeline = Paraformer(model_dir=onnx_dir)
-        self.sample_rate = sample_rate
-
-    def clear_states(self, user):
-        self.clear_buffers(user)
-        self.clear_transcriptions(user)
-
-    def clear_buffers(self, user):
-        if user in self.client_buffers:
-            del self.client_buffers[user]
-
-    def clear_transcriptions(self, user):
-        if user in self.client_transcription:
-            del self.client_transcription[user]
-
-    def disconnect(self, user):
-        self.clear_states(user)
-        print("Disconnecting user: %s" % str(user))
-
-    def Recognize(self, request_iterator, context):
-        
-            
-        for req in request_iterator:
-            if req.user not in self.auth_user:
-                result = {}
-                result["success"] = False
-                result["detail"] = "Not Authorized user: %s " % req.user
-                result["text"] = ""
-                yield Response(sentence=json.dumps(result), user=req.user, action="terminate", language=req.language)
-            elif req.isEnd: #end grpc
-                print("asr end")
-                self.disconnect(req.user)
-                result = {}
-                result["success"] = True
-                result["detail"] = "asr end"
-                result["text"] = ""
-                yield Response(sentence=json.dumps(result), user=req.user, action="terminate",language=req.language)
-            elif req.speaking: #continue speaking
-                if req.audio_data is not None and len(req.audio_data) > 0:
-                    if req.user in self.client_buffers:
-                        self.client_buffers[req.user] += req.audio_data #append audio
-                    else:
-                        self.client_buffers[req.user] = req.audio_data
-                result = {}
-                result["success"] = True
-                result["detail"] = "speaking"
-                result["text"] = ""
-                yield Response(sentence=json.dumps(result), user=req.user, action="speaking", language=req.language)
-            elif not req.speaking: #silence
-                if req.user not in self.client_buffers:
-                    result = {}
-                    result["success"] = True
-                    result["detail"] = "waiting_for_more_voice"
-                    result["text"] = ""
-                    yield Response(sentence=json.dumps(result), user=req.user, action="waiting", language=req.language)
-                else:
-                    begin_time = int(round(time.time() * 1000))
-                    tmp_data = self.client_buffers[req.user]
-                    self.clear_states(req.user)
-                    result = {}
-                    result["success"] = True
-                    result["detail"] = "decoding data: %d bytes" % len(tmp_data)
-                    result["text"] = ""
-                    yield Response(sentence=json.dumps(result), user=req.user, action="decoding", language=req.language)
-                    if len(tmp_data) < 9600: #min input_len for asr model , 300ms
-                        end_time = int(round(time.time() * 1000))
-                        delay_str = str(end_time - begin_time)
-                        result = {}
-                        result["success"] = True
-                        result["detail"] = "waiting_for_more_voice"
-                        result["server_delay_ms"] = delay_str
-                        result["text"] = ""
-                        print ("user: %s , delay(ms): %s, info: %s " % (req.user, delay_str, "waiting_for_more_voice"))
-                        yield Response(sentence=json.dumps(result), user=req.user, action="waiting", language=req.language)
-                    else:
-                        if self.backend == "pipeline":
-                            asr_result = self.inference_16k_pipeline(audio_in=tmp_data, audio_fs = self.sample_rate)
-                            if "text" in asr_result:
-                                asr_result = asr_result['text']
-                            else:
-                                asr_result = ""
-                        elif self.backend == "onnxruntime":
-                            from funasr_onnx.utils.frontend import load_bytes
-                            array = load_bytes(tmp_data)
-                            asr_result = self.inference_16k_pipeline(array)[0]
-                        end_time = int(round(time.time() * 1000))
-                        delay_str = str(end_time - begin_time)
-                        print ("user: %s , delay(ms): %s, text: %s " % (req.user, delay_str, asr_result))
-                        result = {}
-                        result["success"] = True
-                        result["detail"] = "finish_sentence"
-                        result["server_delay_ms"] = delay_str
-                        result["text"] = asr_result
-                        yield Response(sentence=json.dumps(result), user=req.user, action="finish", language=req.language)
-            else:
-                result = {}
-                result["success"] = False 
-                result["detail"] = "error, no condition matched! Unknown reason."
-                result["text"] = ""
-                self.disconnect(req.user)
-                yield Response(sentence=json.dumps(result), user=req.user, action="terminate", language=req.language)
-                
-

+ 0 - 30
funasr/runtime/python/grpc/paraformer_pb2.py

@@ -1,30 +0,0 @@
-# -*- coding: utf-8 -*-
-# Generated by the protocol buffer compiler.  DO NOT EDIT!
-# source: paraformer.proto
-"""Generated protocol buffer code."""
-from google.protobuf.internal import builder as _builder
-from google.protobuf import descriptor as _descriptor
-from google.protobuf import descriptor_pool as _descriptor_pool
-from google.protobuf import symbol_database as _symbol_database
-# @@protoc_insertion_point(imports)
-
-_sym_db = _symbol_database.Default()
-
-
-
-
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10paraformer.proto\x12\nparaformer\"^\n\x07Request\x12\x12\n\naudio_data\x18\x01 \x01(\x0c\x12\x0c\n\x04user\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x10\n\x08speaking\x18\x04 \x01(\x08\x12\r\n\x05isEnd\x18\x05 \x01(\x08\"L\n\x08Response\x12\x10\n\x08sentence\x18\x01 \x01(\t\x12\x0c\n\x04user\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x0e\n\x06\x61\x63tion\x18\x04 \x01(\t2C\n\x03\x41SR\x12<\n\tRecognize\x12\x13.paraformer.Request\x1a\x14.paraformer.Response\"\x00(\x01\x30\x01\x42\x16\n\x07\x65x.grpc\xa2\x02\nparaformerb\x06proto3')
-
-_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'paraformer_pb2', globals())
-if _descriptor._USE_C_DESCRIPTORS == False:
-
-  DESCRIPTOR._options = None
-  DESCRIPTOR._serialized_options = b'\n\007ex.grpc\242\002\nparaformer'
-  _REQUEST._serialized_start=32
-  _REQUEST._serialized_end=126
-  _RESPONSE._serialized_start=128
-  _RESPONSE._serialized_end=204
-  _ASR._serialized_start=206
-  _ASR._serialized_end=273
-# @@protoc_insertion_point(module_scope)

+ 24 - 9
funasr/runtime/python/grpc/proto/paraformer.proto

@@ -1,3 +1,8 @@
+// Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+// Reserved. MIT License  (https://opensource.org/licenses/MIT)
+//
+// 2023 by burkliu(刘柏基) liubaiji@xverse.cn
+
 syntax = "proto3";
 
 option objc_class_prefix = "paraformer";
@@ -8,17 +13,27 @@ service ASR {
   rpc Recognize (stream Request) returns (stream Response) {}
 }
 
+enum WavFormat {
+  pcm = 0;
+}
+
+enum DecodeMode {
+  offline = 0;
+  online = 1;
+  two_pass = 2;
+}
+
 message Request {
-  bytes audio_data = 1;
-  string user = 2;
-  string language = 3;
-  bool speaking = 4;
-  bool isEnd = 5;
+  DecodeMode mode = 1;
+  WavFormat wav_format = 2;
+  int32 sampling_rate = 3;
+  repeated int32 chunk_size = 4;
+  bool is_final = 5;
+  bytes audio_data = 6;
 }
 
 message Response {
-  string sentence = 1;
-  string user = 2;
-  string language = 3;
-  string action = 4;
+  DecodeMode mode = 1;
+  string text = 2;
+  bool is_final = 3;
 }

+ 0 - 0
funasr/runtime/python/grpc/requirements_client.txt → funasr/runtime/python/grpc/requirements.txt


+ 0 - 2
funasr/runtime/python/grpc/requirements_server.txt

@@ -1,2 +0,0 @@
-grpcio
-grpcio-tools