|
|
@@ -13,7 +13,10 @@
|
|
|
#include <grpcpp/security/server_credentials.h>
|
|
|
|
|
|
#include "paraformer.grpc.pb.h"
|
|
|
-#include "paraformer_server.h"
|
|
|
+#include "paraformer-server.h"
|
|
|
+#include "tclap/CmdLine.h"
|
|
|
+#include "com-define.h"
|
|
|
+#include "glog/logging.h"
|
|
|
|
|
|
using grpc::Server;
|
|
|
using grpc::ServerBuilder;
|
|
|
@@ -27,31 +30,43 @@ using paraformer::Request;
|
|
|
using paraformer::Response;
|
|
|
using paraformer::ASR;
|
|
|
|
|
|
-ASRServicer::ASRServicer(const char* model_path, int thread_num, bool quantize) {
|
|
|
- AsrHanlde=FunASRInit(model_path, thread_num, quantize);
|
|
|
+ASRServicer::ASRServicer(std::map<std::string, std::string>& model_path) {
|
|
|
+ AsrHanlde=FunASRInit(model_path, 1);
|
|
|
std::cout << "ASRServicer init" << std::endl;
|
|
|
init_flag = 0;
|
|
|
}
|
|
|
|
|
|
+void ASRServicer::clear_states(const std::string& user) {
|
|
|
+ clear_buffers(user);
|
|
|
+ clear_transcriptions(user);
|
|
|
+}
|
|
|
+
|
|
|
+void ASRServicer::clear_buffers(const std::string& user) {
|
|
|
+ if (client_buffers.count(user)) {
|
|
|
+ client_buffers.erase(user);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+void ASRServicer::clear_transcriptions(const std::string& user) {
|
|
|
+ if (client_transcription.count(user)) {
|
|
|
+ client_transcription.erase(user);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+void ASRServicer::disconnect(const std::string& user) {
|
|
|
+ clear_states(user);
|
|
|
+ std::cout << "Disconnecting user: " << user << std::endl;
|
|
|
+}
|
|
|
+
|
|
|
grpc::Status ASRServicer::Recognize(
|
|
|
grpc::ServerContext* context,
|
|
|
grpc::ServerReaderWriter<Response, Request>* stream) {
|
|
|
|
|
|
Request req;
|
|
|
- std::unordered_map<std::string, std::string> client_buffers;
|
|
|
- std::unordered_map<std::string, std::string> client_transcription;
|
|
|
-
|
|
|
while (stream->Read(&req)) {
|
|
|
if (req.isend()) {
|
|
|
std::cout << "asr end" << std::endl;
|
|
|
- // disconnect
|
|
|
- if (client_buffers.count(req.user())) {
|
|
|
- client_buffers.erase(req.user());
|
|
|
- }
|
|
|
- if (client_transcription.count(req.user())) {
|
|
|
- client_transcription.erase(req.user());
|
|
|
- }
|
|
|
-
|
|
|
+ disconnect(req.user());
|
|
|
Response res;
|
|
|
res.set_sentence(
|
|
|
R"({"success": true, "detail": "asr end"})"
|
|
|
@@ -89,14 +104,8 @@ grpc::Status ASRServicer::Recognize(
|
|
|
auto& buf = client_buffers[req.user()];
|
|
|
buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end());
|
|
|
}
|
|
|
- std::string tmp_data = client_buffers[req.user()];
|
|
|
- // clear_states
|
|
|
- if (client_buffers.count(req.user())) {
|
|
|
- client_buffers.erase(req.user());
|
|
|
- }
|
|
|
- if (client_transcription.count(req.user())) {
|
|
|
- client_transcription.erase(req.user());
|
|
|
- }
|
|
|
+ std::string tmp_data = this->client_buffers[req.user()];
|
|
|
+ this->clear_states(req.user());
|
|
|
|
|
|
Response res;
|
|
|
res.set_sentence(
|
|
|
@@ -161,10 +170,17 @@ grpc::Status ASRServicer::Recognize(
|
|
|
return Status::OK;
|
|
|
}
|
|
|
|
|
|
-void RunServer(const std::string& port, int thread_num, const char* model_path, bool quantize) {
|
|
|
+void RunServer(std::map<std::string, std::string>& model_path) {
|
|
|
+ std::string port;
|
|
|
+ try{
|
|
|
+ port = model_path.at(PORT_ID);
|
|
|
+ }catch(std::exception const &e){
|
|
|
+ printf("Error when read port.\n");
|
|
|
+ exit(0);
|
|
|
+ }
|
|
|
std::string server_address;
|
|
|
server_address = "0.0.0.0:" + port;
|
|
|
- ASRServicer service(model_path, thread_num, quantize);
|
|
|
+ ASRServicer service(model_path);
|
|
|
|
|
|
ServerBuilder builder;
|
|
|
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
|
|
@@ -174,16 +190,54 @@ void RunServer(const std::string& port, int thread_num, const char* model_path,
|
|
|
server->Wait();
|
|
|
}
|
|
|
|
|
|
-int main(int argc, char* argv[]) {
|
|
|
- if (argc < 5)
|
|
|
- {
|
|
|
- printf("Usage: %s port thread_num /path/to/model_file quantize(true or false) \n", argv[0]);
|
|
|
- exit(-1);
|
|
|
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, std::string key, std::map<std::string, std::string>& model_path)
|
|
|
+{
|
|
|
+ if (value_arg.isSet()){
|
|
|
+ model_path.insert({key, value_arg.getValue()});
|
|
|
+ LOG(INFO)<< key << " : " << value_arg.getValue();
|
|
|
}
|
|
|
+}
|
|
|
+
|
|
|
+int main(int argc, char* argv[]) {
|
|
|
|
|
|
- // is quantize
|
|
|
- bool quantize = false;
|
|
|
- std::istringstream(argv[4]) >> std::boolalpha >> quantize;
|
|
|
- RunServer(argv[1], atoi(argv[2]), argv[3], quantize);
|
|
|
+ google::InitGoogleLogging(argv[0]);
|
|
|
+ FLAGS_logtostderr = true;
|
|
|
+
|
|
|
+ TCLAP::CmdLine cmd("paraformer-server", ' ', "1.0");
|
|
|
+ TCLAP::ValueArg<std::string> vad_model("", VAD_MODEL_PATH, "vad model path", false, "", "string");
|
|
|
+ TCLAP::ValueArg<std::string> vad_cmvn("", VAD_CMVN_PATH, "vad cmvn path", false, "", "string");
|
|
|
+ TCLAP::ValueArg<std::string> vad_config("", VAD_CONFIG_PATH, "vad config path", false, "", "string");
|
|
|
+
|
|
|
+ TCLAP::ValueArg<std::string> am_model("", AM_MODEL_PATH, "am model path", true, "", "string");
|
|
|
+ TCLAP::ValueArg<std::string> am_cmvn("", AM_CMVN_PATH, "am cmvn path", true, "", "string");
|
|
|
+ TCLAP::ValueArg<std::string> am_config("", AM_CONFIG_PATH, "am config path", true, "", "string");
|
|
|
+
|
|
|
+ TCLAP::ValueArg<std::string> punc_model("", PUNC_MODEL_PATH, "punc model path", false, "", "string");
|
|
|
+ TCLAP::ValueArg<std::string> punc_config("", PUNC_CONFIG_PATH, "punc config path", false, "", "string");
|
|
|
+ TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string");
|
|
|
+
|
|
|
+ cmd.add(vad_model);
|
|
|
+ cmd.add(vad_cmvn);
|
|
|
+ cmd.add(vad_config);
|
|
|
+ cmd.add(am_model);
|
|
|
+ cmd.add(am_cmvn);
|
|
|
+ cmd.add(am_config);
|
|
|
+ cmd.add(punc_model);
|
|
|
+ cmd.add(punc_config);
|
|
|
+ cmd.add(port_id);
|
|
|
+ cmd.parse(argc, argv);
|
|
|
+
|
|
|
+ std::map<std::string, std::string> model_path;
|
|
|
+ GetValue(vad_model, VAD_MODEL_PATH, model_path);
|
|
|
+ GetValue(vad_cmvn, VAD_CMVN_PATH, model_path);
|
|
|
+ GetValue(vad_config, VAD_CONFIG_PATH, model_path);
|
|
|
+ GetValue(am_model, AM_MODEL_PATH, model_path);
|
|
|
+ GetValue(am_cmvn, AM_CMVN_PATH, model_path);
|
|
|
+ GetValue(am_config, AM_CONFIG_PATH, model_path);
|
|
|
+ GetValue(punc_model, PUNC_MODEL_PATH, model_path);
|
|
|
+ GetValue(punc_config, PUNC_CONFIG_PATH, model_path);
|
|
|
+ GetValue(port_id, PORT_ID, model_path);
|
|
|
+
|
|
|
+ RunServer(model_path);
|
|
|
return 0;
|
|
|
}
|