paraformer-server.cc 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  /**
   * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
   * Reserved. MIT License (https://opensource.org/licenses/MIT)
   */
  /* 2023 by burkliu(刘柏基) liubaiji@xverse.cn */
  #include "paraformer-server.h"

  #include <chrono>
  #include <cstddef>
  #include <exception>
  #include <mutex>
  #include <string>
  #include <thread>
  7. GrpcEngine::GrpcEngine(
  8. grpc::ServerReaderWriter<Response, Request>* stream,
  9. std::shared_ptr<FUNASR_HANDLE> asr_handler)
  10. : stream_(std::move(stream)),
  11. asr_handler_(std::move(asr_handler)) {
  12. request_ = std::make_shared<Request>();
  13. }
  14. void GrpcEngine::DecodeThreadFunc() {
  15. FUNASR_HANDLE tpass_online_handler = FunTpassOnlineInit(*asr_handler_, chunk_size_);
  16. int step = (sampling_rate_ * step_duration_ms_ / 1000) * 2; // int16 = 2bytes;
  17. std::vector<std::vector<std::string>> punc_cache(2);
  18. bool is_final = false;
  19. std::string online_result = "";
  20. std::string tpass_result = "";
  21. LOG(INFO) << "Decoder init, start decoding loop with mode";
  22. while (true) {
  23. if (audio_buffer_.length() > step || is_end_) {
  24. if (audio_buffer_.length() <= step && is_end_) {
  25. is_final = true;
  26. step = audio_buffer_.length();
  27. }
  28. FUNASR_RESULT result = FunTpassInferBuffer(*asr_handler_,
  29. tpass_online_handler,
  30. audio_buffer_.c_str(),
  31. step,
  32. punc_cache,
  33. is_final,
  34. sampling_rate_,
  35. encoding_,
  36. mode_);
  37. p_mutex_->lock();
  38. audio_buffer_ = audio_buffer_.substr(step);
  39. p_mutex_->unlock();
  40. if (result) {
  41. std::string online_message = FunASRGetResult(result, 0);
  42. online_result += online_message;
  43. if(online_message != ""){
  44. Response response;
  45. response.set_mode(DecodeMode::online);
  46. response.set_text(online_message);
  47. response.set_is_final(is_final);
  48. stream_->Write(response);
  49. LOG(INFO) << "send online results: " << online_message;
  50. }
  51. std::string tpass_message = FunASRGetTpassResult(result, 0);
  52. tpass_result += tpass_message;
  53. if(tpass_message != ""){
  54. Response response;
  55. response.set_mode(DecodeMode::two_pass);
  56. response.set_text(tpass_message);
  57. response.set_is_final(is_final);
  58. stream_->Write(response);
  59. LOG(INFO) << "send offline results: " << tpass_message;
  60. }
  61. FunASRFreeResult(result);
  62. }
  63. if (is_final) {
  64. FunTpassOnlineUninit(tpass_online_handler);
  65. break;
  66. }
  67. }
  68. sleep(0.001);
  69. }
  70. }
  71. void GrpcEngine::OnSpeechStart() {
  72. if (request_->chunk_size_size() == 3) {
  73. for (int i = 0; i < 3; i++) {
  74. chunk_size_[i] = int(request_->chunk_size(i));
  75. }
  76. }
  77. std::string chunk_size_str;
  78. for (int i = 0; i < 3; i++) {
  79. chunk_size_str = " " + chunk_size_[i];
  80. }
  81. LOG(INFO) << "chunk_size is" << chunk_size_str;
  82. if (request_->sampling_rate() != 0) {
  83. sampling_rate_ = request_->sampling_rate();
  84. }
  85. LOG(INFO) << "sampling_rate is " << sampling_rate_;
  86. switch(request_->wav_format()) {
  87. case WavFormat::pcm: encoding_ = "pcm";
  88. }
  89. LOG(INFO) << "encoding is " << encoding_;
  90. std::string mode_str;
  91. switch(request_->mode()) {
  92. case DecodeMode::offline:
  93. mode_ = ASR_OFFLINE;
  94. mode_str = "offline";
  95. break;
  96. case DecodeMode::online:
  97. mode_ = ASR_ONLINE;
  98. mode_str = "online";
  99. break;
  100. case DecodeMode::two_pass:
  101. mode_ = ASR_TWO_PASS;
  102. mode_str = "two_pass";
  103. break;
  104. }
  105. LOG(INFO) << "decode mode is " << mode_str;
  106. decode_thread_ = std::make_shared<std::thread>(&GrpcEngine::DecodeThreadFunc, this);
  107. is_start_ = true;
  108. }
  109. void GrpcEngine::OnSpeechData() {
  110. p_mutex_->lock();
  111. audio_buffer_ += request_->audio_data();
  112. p_mutex_->unlock();
  113. }
  114. void GrpcEngine::OnSpeechEnd() {
  115. is_end_ = true;
  116. LOG(INFO) << "Read all pcm data, wait for decoding thread";
  117. if (decode_thread_ != nullptr) {
  118. decode_thread_->join();
  119. }
  120. }
  121. void GrpcEngine::operator()() {
  122. try {
  123. LOG(INFO) << "start engine main loop";
  124. while (stream_->Read(request_.get())) {
  125. LOG(INFO) << "receive data";
  126. if (!is_start_) {
  127. OnSpeechStart();
  128. }
  129. OnSpeechData();
  130. if (request_->is_final()) {
  131. break;
  132. }
  133. }
  134. OnSpeechEnd();
  135. LOG(INFO) << "Connect finish";
  136. } catch (std::exception const& e) {
  137. LOG(ERROR) << e.what();
  138. }
  139. }
  140. GrpcService::GrpcService(std::map<std::string, std::string>& config, int onnx_thread)
  141. : config_(config) {
  142. asr_handler_ = std::make_shared<FUNASR_HANDLE>(std::move(FunTpassInit(config_, onnx_thread)));
  143. LOG(INFO) << "GrpcService model loaded";
  144. std::vector<int> chunk_size = {5, 10, 5};
  145. FUNASR_HANDLE tmp_online_handler = FunTpassOnlineInit(*asr_handler_, chunk_size);
  146. int sampling_rate = 16000;
  147. int buffer_len = sampling_rate * 1;
  148. std::string tmp_data(buffer_len, '0');
  149. std::vector<std::vector<std::string>> punc_cache(2);
  150. bool is_final = true;
  151. std::string encoding = "pcm";
  152. FUNASR_RESULT result = FunTpassInferBuffer(*asr_handler_,
  153. tmp_online_handler,
  154. tmp_data.c_str(),
  155. buffer_len,
  156. punc_cache,
  157. is_final,
  158. buffer_len,
  159. encoding,
  160. ASR_TWO_PASS);
  161. if (result) {
  162. FunASRFreeResult(result);
  163. }
  164. FunTpassOnlineUninit(tmp_online_handler);
  165. LOG(INFO) << "GrpcService model warmup";
  166. }
  167. grpc::Status GrpcService::Recognize(
  168. grpc::ServerContext* context,
  169. grpc::ServerReaderWriter<Response, Request>* stream) {
  170. LOG(INFO) << "Get Recognize request";
  171. GrpcEngine engine(
  172. stream,
  173. asr_handler_
  174. );
  175. std::thread t(std::move(engine));
  176. t.join();
  177. return grpc::Status::OK;
  178. }
  179. void GetValue(TCLAP::ValueArg<std::string>& value_arg, std::string key, std::map<std::string, std::string>& config) {
  180. if (value_arg.isSet()) {
  181. config.insert({key, value_arg.getValue()});
  182. LOG(INFO) << key << " : " << value_arg.getValue();
  183. }
  184. }
  185. int main(int argc, char* argv[]) {
  186. FLAGS_logtostderr = true;
  187. google::InitGoogleLogging(argv[0]);
  188. TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0");
  189. TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
  190. TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
  191. TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
  192. TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
  193. TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
  194. TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
  195. TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
  196. TCLAP::ValueArg<std::int32_t> onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
  197. TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string");
  198. cmd.add(model_dir);
  199. cmd.add(online_model_dir);
  200. cmd.add(quantize);
  201. cmd.add(vad_dir);
  202. cmd.add(vad_quant);
  203. cmd.add(punc_dir);
  204. cmd.add(punc_quant);
  205. cmd.add(onnx_thread);
  206. cmd.add(port_id);
  207. cmd.parse(argc, argv);
  208. std::map<std::string, std::string> config;
  209. GetValue(model_dir, MODEL_DIR, config);
  210. GetValue(online_model_dir, ONLINE_MODEL_DIR, config);
  211. GetValue(quantize, QUANTIZE, config);
  212. GetValue(vad_dir, VAD_DIR, config);
  213. GetValue(vad_quant, VAD_QUANT, config);
  214. GetValue(punc_dir, PUNC_DIR, config);
  215. GetValue(punc_quant, PUNC_QUANT, config);
  216. GetValue(port_id, PORT_ID, config);
  217. std::string port;
  218. try {
  219. port = config.at(PORT_ID);
  220. } catch(std::exception const &e) {
  221. LOG(INFO) << ("Error when read port.");
  222. exit(0);
  223. }
  224. std::string server_address;
  225. server_address = "0.0.0.0:" + port;
  226. GrpcService service(config, onnx_thread);
  227. grpc::ServerBuilder builder;
  228. builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
  229. builder.RegisterService(&service);
  230. std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
  231. LOG(INFO) << "Server listening on " << server_address;
  232. server->Wait();
  233. return 0;
  234. }