|
|
@@ -1,93 +1,192 @@
|
|
|
-#include "paraformer-server.h"
|
|
|
+/**
|
|
|
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
|
|
|
+ * Reserved. MIT License (https://opensource.org/licenses/MIT)
|
|
|
+ */
|
|
|
+/* 2023 by burkliu(刘柏基) liubaiji@xverse.cn */
|
|
|
|
|
|
-using paraformer::Request;
|
|
|
-using paraformer::Response;
|
|
|
-using paraformer::ASR;
|
|
|
+#include "paraformer-server.h"
|
|
|
|
|
|
GrpcEngine::GrpcEngine(
|
|
|
grpc::ServerReaderWriter<Response, Request>* stream,
|
|
|
std::shared_ptr<FUNASR_HANDLE> asr_handler)
|
|
|
: stream_(std::move(stream)),
|
|
|
- asr_handler_(std::move(asr_handler)) {}
|
|
|
+ asr_handler_(std::move(asr_handler)) {
|
|
|
|
|
|
-void GrpcEngine::operator()() {
|
|
|
- Request request;
|
|
|
- while (stream_->Read(&request)) {
|
|
|
- Response respond;
|
|
|
- respond.set_user(request.user());
|
|
|
- respond.set_language(request.language());
|
|
|
-
|
|
|
- if (request.isend()) {
|
|
|
- std::cout << "asr end" << std::endl;
|
|
|
- respond.set_sentence(R"({"success": true, "detail": "asr end"})");
|
|
|
- respond.set_action("terminate");
|
|
|
- stream_->Write(respond);
|
|
|
- } else if (request.speaking()) {
|
|
|
- if (request.audio_data().size() > 0) {
|
|
|
- auto& buf = client_buffers[request.user()];
|
|
|
- buf.insert(buf.end(), request.audio_data().begin(), request.audio_data().end());
|
|
|
+ request_ = std::make_shared<Request>();
|
|
|
+}
|
|
|
+
|
|
|
+/**
+ * Decoding loop executed on decode_thread_ for one recognition session.
+ *
+ * Creates a per-session online handle, then repeatedly feeds `step` bytes from
+ * the front of audio_buffer_ to FunTpassInferBuffer and writes every non-empty
+ * online / two-pass message back on stream_. Once is_end_ is set and no more
+ * than one step of audio is left, the remainder is decoded as the final chunk,
+ * the online handle is released and the loop exits.
+ *
+ * NOTE(review): audio_buffer_ and is_end_ are also written by OnSpeechData() /
+ * OnSpeechEnd() on the reading thread; no locking is visible in this file -
+ * confirm they are synchronized (mutex / atomic) where they are declared.
+ */
+void GrpcEngine::DecodeThreadFunc() {
+  FUNASR_HANDLE tpass_online_handler = FunTpassOnlineInit(*asr_handler_, chunk_size_);
+  // Bytes consumed per inference call; samples are int16, i.e. 2 bytes each.
+  int step = (sampling_rate_ * step_duration_ms_ / 1000) * 2;
+  std::vector<std::vector<std::string>> punc_cache(2);
+  bool is_final = false;
+
+  LOG(INFO) << "Decoder init, start decoding loop with mode";
+
+  while (true) {
+    if (audio_buffer_.length() > step || is_end_) {
+      if (audio_buffer_.length() <= step && is_end_) {
+        // Last chunk: flush whatever is left, even if it is shorter than a step.
+        is_final = true;
+        step = audio_buffer_.length();
       }
- respond.set_sentence(R"({"success": true, "detail": "speaking"})");
- respond.set_action("speaking");
- stream_->Write(respond);
- } else {
- if (client_buffers.count(request.user()) == 0 && request.audio_data().size() == 0) {
- respond.set_sentence(R"({"success": true, "detail": "waiting_for_voice"})");
- respond.set_action("waiting");
- stream_->Write(respond);
- } else {
- auto begin_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
- if (request.audio_data().size() > 0) {
- auto& buf = client_buffers[request.user()];
- buf.insert(buf.end(), request.audio_data().begin(), request.audio_data().end());
+
+      FUNASR_RESULT result = FunTpassInferBuffer(*asr_handler_,
+                                                 tpass_online_handler,
+                                                 audio_buffer_.c_str(),
+                                                 step,
+                                                 punc_cache,
+                                                 is_final,
+                                                 sampling_rate_,
+                                                 encoding_,
+                                                 mode_);
+      // Drop the consumed bytes in place; step never exceeds the buffer length here.
+      audio_buffer_.erase(0, step);
+
+      if (result) {
+        const std::string online_message = FunASRGetResult(result, 0);
+        if (!online_message.empty()) {
+          Response response;
+          response.set_mode(DecodeMode::online);
+          response.set_text(online_message);
+          response.set_is_final(is_final);
+          stream_->Write(response);
+          LOG(INFO) << "send online results: " << online_message;
         }
- std::string tmp_data = this->client_buffers[request.user()];
-
- int data_len_int = tmp_data.length();
- std::string data_len = std::to_string(data_len_int);
- std::stringstream ss;
- ss << R"({"success": true, "detail": "decoding data: )" << data_len << R"( bytes")" << R"("})";
-
- respond.set_sentence(ss.str());
- respond.set_action("decoding");
- stream_->Write(respond);
-
- // start recoginize
- std::string asr_result;
- if (tmp_data.length() < 800) { //min input_len for asr model
- asr_result = "";
- std::cout << "error: data_is_not_long_enough" << std::endl;
- } else {
- FUNASR_RESULT result = FunOfflineInferBuffer(*asr_handler_, tmp_data.c_str(), data_len_int, RASR_NONE, NULL, 16000);
- asr_result = ((FUNASR_RECOG_RESULT*) result)->msg;
+        const std::string tpass_message = FunASRGetTpassResult(result, 0);
+        if (!tpass_message.empty()) {
+          Response response;
+          response.set_mode(DecodeMode::two_pass);
+          response.set_text(tpass_message);
+          response.set_is_final(is_final);
+          stream_->Write(response);
+          LOG(INFO) << "send offline results: " << tpass_message;
         }
+        FunASRFreeResult(result);
+      }
+
+      if (is_final) {
+        FunTpassOnlineUninit(tpass_online_handler);
+        break;
+      }
+    }
+    // sleep() takes whole seconds, so the former sleep(0.001) truncated to
+    // sleep(0) and this loop spun at full CPU while waiting for more audio.
+    usleep(1000);  // 1 ms
+  }
+}
|
|
|
|
|
|
- auto end_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
|
|
|
- std::string delay_str = std::to_string(end_time - begin_time);
|
|
|
- std::cout << "user: " << request.user() << " , delay(ms): " << delay_str << ", text: " << asr_result << std::endl;
|
|
|
- std::stringstream ss2;
|
|
|
- ss2 << R"({"success": true, "detail": "finish_sentence","server_delay_ms":)" << delay_str << R"(,"text":")" << asr_result << R"("})";
|
|
|
+/**
+ * Handles the first request of a stream: reads the session parameters from
+ * request_ (chunk sizes, sampling rate, wav format, decode mode), logs them,
+ * starts the decoding thread and marks the session as started.
+ *
+ * Chunk sizes are only overridden when the request carries exactly three
+ * values, and the sampling rate only when it is non-zero.
+ */
+void GrpcEngine::OnSpeechStart() {
+  if (request_->chunk_size_size() == 3) {
+    for (int i = 0; i < 3; i++) {
+      chunk_size_[i] = static_cast<int>(request_->chunk_size(i));
+    }
+  }
+  std::string chunk_size_str;
+  for (int i = 0; i < 3; i++) {
+    // Append the value as text. The former `= " " + chunk_size_[i]` did pointer
+    // arithmetic on the string literal and overwrote the string on every pass.
+    chunk_size_str += " " + std::to_string(chunk_size_[i]);
+  }
+  LOG(INFO) << "chunk_size is" << chunk_size_str;

- respond.set_sentence(ss2.str());
- respond.set_action("finish");
- stream_->Write(respond);
+  if (request_->sampling_rate() != 0) {
+    sampling_rate_ = request_->sampling_rate();
+  }
+  LOG(INFO) << "sampling_rate is " << sampling_rate_;
+
+  switch (request_->wav_format()) {
+    case WavFormat::pcm:
+      encoding_ = "pcm";
+      break;
+    default:
+      // Any other format leaves encoding_ at its current value.
+      break;
+  }
+  LOG(INFO) << "encoding is " << encoding_;
+
+  std::string mode_str;
+  LOG(INFO) << request_->mode() << DecodeMode::offline << DecodeMode::online << DecodeMode::two_pass;
+  switch (request_->mode()) {
+    case DecodeMode::offline:
+      mode_ = ASR_OFFLINE;
+      mode_str = "offline";
+      break;
+    case DecodeMode::online:
+      mode_ = ASR_ONLINE;
+      mode_str = "online";
+      break;
+    case DecodeMode::two_pass:
+      mode_ = ASR_TWO_PASS;
+      mode_str = "two_pass";
+      break;
+    default:
+      // Unrecognized value: mode_ is left untouched and mode_str stays empty.
+      break;
+  }
+  LOG(INFO) << "decode mode is " << mode_str;
+
+  // Start decoding only after every parameter the decoder reads has been set.
+  decode_thread_ = std::make_shared<std::thread>(&GrpcEngine::DecodeThreadFunc, this);
+  is_start_ = true;
+}
|
|
|
+
|
|
|
+/** Appends the audio payload of the current request to the pending decode buffer. */
+void GrpcEngine::OnSpeechData() {
+  audio_buffer_.append(request_->audio_data());
+}
|
|
|
+
|
|
|
+/**
+ * Flags the end of the input audio (is_end_, which the decoding loop polls) and
+ * blocks until the decoding thread, if one was started, has finished.
+ */
+void GrpcEngine::OnSpeechEnd() {
+  is_end_ = true;
+  LOG(INFO) << "Read all pcm data, wait for decoding thread";
+  if (!decode_thread_) {
+    return;
+  }
+  decode_thread_->join();
+}
|
|
|
+
|
|
|
+/**
+ * Engine main loop for one stream: reads requests until the client marks one
+ * as final or the stream ends. The first request also starts the session
+ * (OnSpeechStart); every request contributes its audio (OnSpeechData).
+ * std::exception errors raised inside the loop are logged and not rethrown.
+ */
+void GrpcEngine::operator()() {
+  try {
+    LOG(INFO) << "start engine main loop";
+    while (stream_->Read(request_.get())) {
+      LOG(INFO) << "receive data";
+      if (!is_start_) {
+        OnSpeechStart();
+      }
+      OnSpeechData();
+      if (request_->is_final()) {
+        OnSpeechEnd();
+        break;
       }
     }
+    LOG(INFO) << "Connect finish";
+  } catch (std::exception const& e) {
+    LOG(ERROR) << e.what();
   }
+
+  // The loop can also exit without a final request (Read() returned false, or
+  // an exception was caught above). The decoding loop only stops once is_end_
+  // is set, and destroying a still-joinable std::thread terminates the process,
+  // so make sure the decoder is signalled and joined in that case as well.
+  if (is_start_ && !is_end_) {
+    OnSpeechEnd();
+  }
 }
|
|
|
|
|
|
-GrpcService::GrpcService(std::map<std::string, std::string>& config, int num_thread)
+/**
+ * Loads the two-pass models and runs one warm-up inference.
+ *
+ * @param config       model configuration map, forwarded to FunTpassInit
+ * @param onnx_thread  thread count forwarded to FunTpassInit
+ */
+GrpcService::GrpcService(std::map<std::string, std::string>& config, int onnx_thread)
   : config_(config) {

- asr_handler_ = std::make_shared<FUNASR_HANDLE>(std::move(FunOfflineInit(config_, num_thread)));
- std::cout << "GrpcService model loades" << std::endl;
+  // The init call already yields a temporary, so no std::move is needed.
+  asr_handler_ = std::make_shared<FUNASR_HANDLE>(FunTpassInit(config_, onnx_thread));
+  LOG(INFO) << "GrpcService model loaded";
+
+  // Warm-up: push one dummy buffer through a throwaway online handle.
+  std::vector<int> chunk_size = {5, 10, 5};
+  FUNASR_HANDLE tmp_online_handler = FunTpassOnlineInit(*asr_handler_, chunk_size);
+  const int sampling_rate = 16000;
+  const int buffer_len = sampling_rate * 1;
+  std::string tmp_data(buffer_len, '0');
+  std::vector<std::vector<std::string>> punc_cache(2);
+  const bool is_final = true;
+  std::string encoding = "pcm";
+  FUNASR_RESULT result = FunTpassInferBuffer(*asr_handler_,
+                                             tmp_online_handler,
+                                             tmp_data.c_str(),
+                                             buffer_len,
+                                             punc_cache,
+                                             is_final,
+                                             // Sampling-rate argument: this used to pass buffer_len,
+                                             // which only matched because both equal 16000.
+                                             sampling_rate,
+                                             encoding,
+                                             ASR_TWO_PASS);
+  if (result) {
+    FunASRFreeResult(result);
+  }
+  FunTpassOnlineUninit(tmp_online_handler);
+  LOG(INFO) << "GrpcService model warmup";
 }
|
|
|
|
|
|
grpc::Status GrpcService::Recognize(
|
|
|
grpc::ServerContext* context,
|
|
|
grpc::ServerReaderWriter<Response, Request>* stream) {
|
|
|
-
|
|
|
- LOG(INFO) << "Get Recognize request" << std::endl;
|
|
|
+ LOG(INFO) << "Get Recognize request";
|
|
|
GrpcEngine engine(
|
|
|
stream,
|
|
|
asr_handler_
|
|
|
@@ -106,29 +205,34 @@ void GetValue(TCLAP::ValueArg<std::string>& value_arg, std::string key, std::map
|
|
|
}
|
|
|
|
|
|
int main(int argc, char* argv[]) {
|
|
|
- google::InitGoogleLogging(argv[0]);
|
|
|
FLAGS_logtostderr = true;
|
|
|
+ google::InitGoogleLogging(argv[0]);
|
|
|
|
|
|
- TCLAP::CmdLine cmd("paraformer-server", ' ', "1.0");
|
|
|
- TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the asr model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
|
|
|
- TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
|
|
|
- TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
|
|
|
- TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string");
|
|
|
- TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
|
|
|
- TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string");
|
|
|
+ TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0");
|
|
|
+ TCLAP::ValueArg<std::string> offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
|
|
|
+ TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
|
|
|
+ TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
|
|
|
+ TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
|
|
|
+ TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
|
|
|
+ TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
|
|
|
+ TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
|
|
|
+ TCLAP::ValueArg<std::int32_t> onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
|
|
|
TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string");
|
|
|
|
|
|
- cmd.add(model_dir);
|
|
|
+ cmd.add(offline_model_dir);
|
|
|
+ cmd.add(online_model_dir);
|
|
|
cmd.add(quantize);
|
|
|
cmd.add(vad_dir);
|
|
|
cmd.add(vad_quant);
|
|
|
cmd.add(punc_dir);
|
|
|
cmd.add(punc_quant);
|
|
|
+ cmd.add(onnx_thread);
|
|
|
cmd.add(port_id);
|
|
|
cmd.parse(argc, argv);
|
|
|
|
|
|
std::map<std::string, std::string> config;
|
|
|
- GetValue(model_dir, MODEL_DIR, config);
|
|
|
+ GetValue(offline_model_dir, OFFLINE_MODEL_DIR, config);
|
|
|
+ GetValue(online_model_dir, ONLINE_MODEL_DIR, config);
|
|
|
GetValue(quantize, QUANTIZE, config);
|
|
|
GetValue(vad_dir, VAD_DIR, config);
|
|
|
GetValue(vad_quant, VAD_QUANT, config);
|
|
|
@@ -140,18 +244,18 @@ int main(int argc, char* argv[]) {
|
|
|
try {
|
|
|
port = config.at(PORT_ID);
|
|
|
} catch(std::exception const &e) {
|
|
|
- std::cout << ("Error when read port.") << std::endl;
|
|
|
+ LOG(INFO) << ("Error when read port.");
|
|
|
exit(0);
|
|
|
}
|
|
|
std::string server_address;
|
|
|
server_address = "0.0.0.0:" + port;
|
|
|
- GrpcService service(config, 1);
|
|
|
+ GrpcService service(config, onnx_thread);
|
|
|
|
|
|
grpc::ServerBuilder builder;
|
|
|
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
|
|
builder.RegisterService(&service);
|
|
|
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
|
|
|
- std::cout << "Server listening on " << server_address << std::endl;
|
|
|
+ LOG(INFO) << "Server listening on " << server_address;
|
|
|
server->Wait();
|
|
|
|
|
|
return 0;
|