| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265 |
- /**
- * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
- * Reserved. MIT License (https://opensource.org/licenses/MIT)
- */
- /* 2023 by burkliu(刘柏基) liubaiji@xverse.cn */
#include "paraformer-server.h"

#include <chrono>
#include <cstdlib>
#include <exception>
#include <mutex>
#include <string>
#include <thread>
#include <type_traits>
- GrpcEngine::GrpcEngine(
- grpc::ServerReaderWriter<Response, Request>* stream,
- std::shared_ptr<FUNASR_HANDLE> asr_handler)
- : stream_(std::move(stream)),
- asr_handler_(std::move(asr_handler)) {
- request_ = std::make_shared<Request>();
- }
- void GrpcEngine::DecodeThreadFunc() {
- FUNASR_HANDLE tpass_online_handler = FunTpassOnlineInit(*asr_handler_, chunk_size_);
- int step = (sampling_rate_ * step_duration_ms_ / 1000) * 2; // int16 = 2bytes;
- std::vector<std::vector<std::string>> punc_cache(2);
- bool is_final = false;
- std::string online_result = "";
- std::string tpass_result = "";
- LOG(INFO) << "Decoder init, start decoding loop with mode";
- while (true) {
- if (audio_buffer_.length() > step || is_end_) {
- if (audio_buffer_.length() <= step && is_end_) {
- is_final = true;
- step = audio_buffer_.length();
- }
- FUNASR_RESULT result = FunTpassInferBuffer(*asr_handler_,
- tpass_online_handler,
- audio_buffer_.c_str(),
- step,
- punc_cache,
- is_final,
- sampling_rate_,
- encoding_,
- mode_);
- p_mutex_->lock();
- audio_buffer_ = audio_buffer_.substr(step);
- p_mutex_->unlock();
- if (result) {
- std::string online_message = FunASRGetResult(result, 0);
- online_result += online_message;
- if(online_message != ""){
- Response response;
- response.set_mode(DecodeMode::online);
- response.set_text(online_message);
- response.set_is_final(is_final);
- stream_->Write(response);
- LOG(INFO) << "send online results: " << online_message;
- }
- std::string tpass_message = FunASRGetTpassResult(result, 0);
- tpass_result += tpass_message;
- if(tpass_message != ""){
- Response response;
- response.set_mode(DecodeMode::two_pass);
- response.set_text(tpass_message);
- response.set_is_final(is_final);
- stream_->Write(response);
- LOG(INFO) << "send offline results: " << tpass_message;
- }
- FunASRFreeResult(result);
- }
- if (is_final) {
- FunTpassOnlineUninit(tpass_online_handler);
- break;
- }
- }
- sleep(0.001);
- }
- }
- void GrpcEngine::OnSpeechStart() {
- if (request_->chunk_size_size() == 3) {
- for (int i = 0; i < 3; i++) {
- chunk_size_[i] = int(request_->chunk_size(i));
- }
- }
- std::string chunk_size_str;
- for (int i = 0; i < 3; i++) {
- chunk_size_str = " " + chunk_size_[i];
- }
- LOG(INFO) << "chunk_size is" << chunk_size_str;
- if (request_->sampling_rate() != 0) {
- sampling_rate_ = request_->sampling_rate();
- }
- LOG(INFO) << "sampling_rate is " << sampling_rate_;
- switch(request_->wav_format()) {
- case WavFormat::pcm: encoding_ = "pcm";
- }
- LOG(INFO) << "encoding is " << encoding_;
- std::string mode_str;
- switch(request_->mode()) {
- case DecodeMode::offline:
- mode_ = ASR_OFFLINE;
- mode_str = "offline";
- break;
- case DecodeMode::online:
- mode_ = ASR_ONLINE;
- mode_str = "online";
- break;
- case DecodeMode::two_pass:
- mode_ = ASR_TWO_PASS;
- mode_str = "two_pass";
- break;
- }
- LOG(INFO) << "decode mode is " << mode_str;
-
- decode_thread_ = std::make_shared<std::thread>(&GrpcEngine::DecodeThreadFunc, this);
- is_start_ = true;
- }
- void GrpcEngine::OnSpeechData() {
- p_mutex_->lock();
- audio_buffer_ += request_->audio_data();
- p_mutex_->unlock();
- }
- void GrpcEngine::OnSpeechEnd() {
- is_end_ = true;
- LOG(INFO) << "Read all pcm data, wait for decoding thread";
- if (decode_thread_ != nullptr) {
- decode_thread_->join();
- }
- }
- void GrpcEngine::operator()() {
- try {
- LOG(INFO) << "start engine main loop";
- while (stream_->Read(request_.get())) {
- LOG(INFO) << "receive data";
- if (!is_start_) {
- OnSpeechStart();
- }
- OnSpeechData();
- if (request_->is_final()) {
- break;
- }
- }
- OnSpeechEnd();
- LOG(INFO) << "Connect finish";
- } catch (std::exception const& e) {
- LOG(ERROR) << e.what();
- }
- }
- GrpcService::GrpcService(std::map<std::string, std::string>& config, int onnx_thread)
- : config_(config) {
- asr_handler_ = std::make_shared<FUNASR_HANDLE>(std::move(FunTpassInit(config_, onnx_thread)));
- LOG(INFO) << "GrpcService model loaded";
- std::vector<int> chunk_size = {5, 10, 5};
- FUNASR_HANDLE tmp_online_handler = FunTpassOnlineInit(*asr_handler_, chunk_size);
- int sampling_rate = 16000;
- int buffer_len = sampling_rate * 1;
- std::string tmp_data(buffer_len, '0');
- std::vector<std::vector<std::string>> punc_cache(2);
- bool is_final = true;
- std::string encoding = "pcm";
- FUNASR_RESULT result = FunTpassInferBuffer(*asr_handler_,
- tmp_online_handler,
- tmp_data.c_str(),
- buffer_len,
- punc_cache,
- is_final,
- buffer_len,
- encoding,
- ASR_TWO_PASS);
- if (result) {
- FunASRFreeResult(result);
- }
- FunTpassOnlineUninit(tmp_online_handler);
- LOG(INFO) << "GrpcService model warmup";
- }
- grpc::Status GrpcService::Recognize(
- grpc::ServerContext* context,
- grpc::ServerReaderWriter<Response, Request>* stream) {
- LOG(INFO) << "Get Recognize request";
- GrpcEngine engine(
- stream,
- asr_handler_
- );
- std::thread t(std::move(engine));
- t.join();
- return grpc::Status::OK;
- }
- void GetValue(TCLAP::ValueArg<std::string>& value_arg, std::string key, std::map<std::string, std::string>& config) {
- if (value_arg.isSet()) {
- config.insert({key, value_arg.getValue()});
- LOG(INFO) << key << " : " << value_arg.getValue();
- }
- }
- int main(int argc, char* argv[]) {
- FLAGS_logtostderr = true;
- google::InitGoogleLogging(argv[0]);
- TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0");
- TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
- TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
- TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
- TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
- TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
- TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
- TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
- TCLAP::ValueArg<std::int32_t> onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
- TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string");
- cmd.add(model_dir);
- cmd.add(online_model_dir);
- cmd.add(quantize);
- cmd.add(vad_dir);
- cmd.add(vad_quant);
- cmd.add(punc_dir);
- cmd.add(punc_quant);
- cmd.add(onnx_thread);
- cmd.add(port_id);
- cmd.parse(argc, argv);
- std::map<std::string, std::string> config;
- GetValue(model_dir, MODEL_DIR, config);
- GetValue(online_model_dir, ONLINE_MODEL_DIR, config);
- GetValue(quantize, QUANTIZE, config);
- GetValue(vad_dir, VAD_DIR, config);
- GetValue(vad_quant, VAD_QUANT, config);
- GetValue(punc_dir, PUNC_DIR, config);
- GetValue(punc_quant, PUNC_QUANT, config);
- GetValue(port_id, PORT_ID, config);
- std::string port;
- try {
- port = config.at(PORT_ID);
- } catch(std::exception const &e) {
- LOG(INFO) << ("Error when read port.");
- exit(0);
- }
- std::string server_address;
- server_address = "0.0.0.0:" + port;
- GrpcService service(config, onnx_thread);
- grpc::ServerBuilder builder;
- builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
- builder.RegisterService(&service);
- std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
- LOG(INFO) << "Server listening on " << server_address;
- server->Wait();
- return 0;
- }
|