Ver código fonte

add sentence timestamp

雾聪 2 anos atrás
pai
commit
f72914003a

+ 4 - 0
runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp

@@ -83,6 +83,10 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
             if(stamp !=""){
                 LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << stamp;
             }
+            string stamp_sents = FunASRGetStampSents(result);
+            if(stamp_sents !=""){
+                LOG(INFO)<< wav_ids[i] <<" : "<<stamp_sents;
+            }
             float snippet_time = FunASRGetRetSnippetTime(result);
             n_total_length += snippet_time;
             FunASRFreeResult(result);

+ 4 - 0
runtime/onnxruntime/bin/funasr-onnx-offline.cpp

@@ -172,6 +172,10 @@ int main(int argc, char** argv)
             if(stamp !=""){
                 LOG(INFO)<< wav_id <<" : "<<stamp;
             }
+            string stamp_sents = FunASRGetStampSents(result);
+            if(stamp_sents !=""){
+                LOG(INFO)<< wav_id <<" : "<<stamp_sents;
+            }
             snippet_time += FunASRGetRetSnippetTime(result);
             FunASRFreeResult(result);
         }

+ 1 - 0
runtime/onnxruntime/include/funasrruntime.h

@@ -68,6 +68,7 @@ _FUNASRAPI FUNASR_RESULT	FunASRInfer(FUNASR_HANDLE handle, const char* sz_filena
 
 _FUNASRAPI const char*	FunASRGetResult(FUNASR_RESULT result,int n_index);
 _FUNASRAPI const char*	FunASRGetStamp(FUNASR_RESULT result);
+_FUNASRAPI const char*	FunASRGetStampSents(FUNASR_RESULT result);
 _FUNASRAPI const char*	FunASRGetTpassResult(FUNASR_RESULT result,int n_index);
 _FUNASRAPI const int	FunASRGetRetNumber(FUNASR_RESULT result);
 _FUNASRAPI void			FunASRFreeResult(FUNASR_RESULT result);

+ 1 - 0
runtime/onnxruntime/src/commonfunc.h

@@ -9,6 +9,7 @@ typedef struct
 {
     std::string msg;
     std::string stamp;
+    std::string stamp_sents;
     std::string tpass_msg;
     float snippet_time;
 }FUNASR_RECOG_RESULT;

+ 18 - 2
runtime/onnxruntime/src/funasrruntime.cpp

@@ -303,7 +303,9 @@
 			p_result->msg = msg_itn;
 		}
 #endif
-
+		if (!(p_result->stamp).empty()){
+			p_result->stamp_sents = funasr::TimestampSentence(p_result->msg, p_result->stamp);
+		}
 		return p_result;
 	}
 
@@ -399,6 +401,9 @@
 			p_result->msg = msg_itn;
 		}
 #endif
+		if (!(p_result->stamp).empty()){
+			p_result->stamp_sents = funasr::TimestampSentence(p_result->msg, p_result->stamp);
+		}
 		return p_result;
 	}
 
@@ -546,7 +551,9 @@
 				p_result->tpass_msg = msg_itn;
 			}
 #endif
-
+			if (!(p_result->stamp).empty()){
+				p_result->stamp_sents = funasr::TimestampSentence(p_result->tpass_msg, p_result->stamp);
+			}
 			if(frame != NULL){
 				delete frame;
 				frame = NULL;
@@ -603,6 +610,15 @@
 		return p_result->stamp.c_str();
 	}
 
+		_FUNASRAPI const char* FunASRGetStampSents(FUNASR_RESULT result)
+	{
+		funasr::FUNASR_RECOG_RESULT * p_result = (funasr::FUNASR_RECOG_RESULT*)result;
+		if(!p_result)
+			return nullptr;
+
+		return p_result->stamp_sents.c_str();
+	}
+
 	_FUNASRAPI const char* FunASRGetTpassResult(FUNASR_RESULT result,int n_index)
 	{
 		funasr::FUNASR_RECOG_RESULT * p_result = (funasr::FUNASR_RECOG_RESULT*)result;

+ 72 - 1
runtime/onnxruntime/src/util.cpp

@@ -255,7 +255,8 @@ void TimestampAdd(std::deque<string> &alignment_str1, std::string str_word){
 }
 
 bool TimestampIsPunctuation(const std::string& str) {
-    const std::string punctuation = u8",。?、,.?";
+    const std::string punctuation = u8",。?、,?";
+    // const std::string punctuation = u8",。?、,.?";
     for (char ch : str) {
         if (punctuation.find(ch) == std::string::npos) {
             return false;
@@ -557,6 +558,76 @@ std::string TimestampSmooth(std::string &text, std::string &text_itn, std::strin
     return timestamps_str;
 }
 
+std::string TimestampSentence(std::string &text, std::string &str_time){
+    std::vector<std::string> characters;
+    funasr::TimestampSplitChiEngCharacters(text, characters);
+    vector<vector<int>> timestamps = funasr::ParseTimestamps(str_time);
+    
+    int idx_str = 0, idx_ts = 0;
+    int start = -1, end = -1;
+    std::string text_seg = "";
+    std::string ts_sentences = "";
+    std::string ts_sent = "";
+    vector<vector<int>> ts_seg;
+    while(idx_str < characters.size()){
+        if (TimestampIsPunctuation(characters[idx_str])){
+            if(ts_seg.size() >0){
+                if (ts_seg[0].size() == 2){
+                    start = ts_seg[0][0];
+                }
+                if (ts_seg[ts_seg.size()-1].size() == 2){
+                    end = ts_seg[ts_seg.size()-1][1];
+                }
+            }
+            // format
+            ts_sent += "{'text':'" + text_seg + "',";
+            ts_sent += "'start':'" + to_string(start) + "',";
+            ts_sent += "'end':'" + to_string(end) + "',";
+            ts_sent += "'ts_list':" + VectorToString(ts_seg) + "}";
+            
+            if (idx_str == characters.size()-1){
+                ts_sentences += ts_sent;
+            } else{
+                ts_sentences += ts_sent + ",";
+            }
+
+            // clear
+            idx_str++;
+            text_seg = "";
+            ts_sent = "";
+            start = 0;
+            end = 0;
+            ts_seg.clear();
+        } else if(idx_ts < timestamps.size()) {
+            if (text_seg.empty()){
+                text_seg = characters[idx_str];
+            }else{
+                text_seg += " " + characters[idx_str];
+            }
+            ts_seg.push_back(timestamps[idx_ts]);
+            idx_str++;
+            idx_ts++;
+        }
+    }
+    // for none punc results
+    if(ts_seg.size() >0){
+        if (ts_seg[0].size() == 2){
+            start = ts_seg[0][0];
+        }
+        if (ts_seg[ts_seg.size()-1].size() == 2){
+            end = ts_seg[ts_seg.size()-1][1];
+        }
+        // format
+        ts_sent += "{'text':'" + text_seg + "',";
+        ts_sent += "'start':'" + to_string(start) + "',";
+        ts_sent += "'end':'" + to_string(end) + "',";
+        ts_sent += "'ts_list':" + VectorToString(ts_seg) + "}";
+        ts_sentences += ts_sent;
+    }
+
+    return "[" +ts_sentences + "]";
+}
+
 std::vector<std::string> split(const std::string &s, char delim) {
   std::vector<std::string> elems;
   std::stringstream ss(s);

+ 1 - 1
runtime/onnxruntime/src/util.h

@@ -47,7 +47,7 @@ void TimestampSplitChiEngCharacters(const std::string &input_str,
                                   std::vector<std::string> &characters);
 std::string VectorToString(const std::vector<std::vector<int>>& vec);                                  
 std::string TimestampSmooth(std::string &text, std::string &text_itn, std::string &str_time);
-
+std::string TimestampSentence(std::string &text, std::string &str_time);
 std::vector<std::string> split(const std::string &s, char delim);
 
 template<typename T>

+ 7 - 1
runtime/websocket/bin/websocket-server-2pass.cpp

@@ -80,6 +80,12 @@ nlohmann::json handle_result(FUNASR_RESULT result) {
     jsonresult["timestamp"] = tmp_stamp_msg;
   }
 
+  std::string tmp_stamp_sents = FunASRGetStampSents(result);
+  if (tmp_stamp_sents != "") {
+    LOG(INFO) << "offline stamp_sents : " << tmp_stamp_sents;
+    jsonresult["stamp_sents"] = tmp_stamp_sents;
+  }
+
   return jsonresult;
 }
 // feed buffer to asr engine for decoder
@@ -318,7 +324,7 @@ void WebSocketServer::check_and_clean_connection() {
         data_msg->msg["is_eof"]=true;
         guard_decoder.unlock();
         to_remove.push_back(hdl);
-        LOG(INFO)<<"connection is closed: "<<e.what();
+        LOG(INFO)<<"connection is closed.";
         
       }
       iter++;

+ 6 - 1
runtime/websocket/bin/websocket-server.cpp

@@ -74,6 +74,7 @@ void WebSocketServer::do_decoder(const std::vector<char>& buffer,
     if (!buffer.empty() && hotwords_embedding.size() > 0) {
       std::string asr_result;
       std::string stamp_res;
+      std::string stamp_sents;
       try{
         FUNASR_RESULT Result = FunOfflineInferBuffer(
             asr_handle, buffer.data(), buffer.size(), RASR_NONE, NULL, 
@@ -81,6 +82,7 @@ void WebSocketServer::do_decoder(const std::vector<char>& buffer,
 
         asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg;  // get decode result
         stamp_res = ((FUNASR_RECOG_RESULT*)Result)->stamp;
+        stamp_sents = ((FUNASR_RECOG_RESULT*)Result)->stamp_sents;
         FunASRFreeResult(Result);
       }catch (std::exception const& e) {
         LOG(ERROR) << e.what();
@@ -95,6 +97,9 @@ void WebSocketServer::do_decoder(const std::vector<char>& buffer,
       if(stamp_res != ""){
         jsonresult["timestamp"] = stamp_res;
       }
+      if(stamp_sents != ""){
+        jsonresult["stamp_sents"] = stamp_sents;
+      }
       jsonresult["wav_name"] = wav_name;
 
       // send the json to client
@@ -227,7 +232,7 @@ void WebSocketServer::check_and_clean_connection() {
         data_msg->msg["is_eof"]=true;
         guard_decoder.unlock();
         to_remove.push_back(hdl);
-        LOG(INFO)<<"connection is closed: "<<e.what();
+        LOG(INFO)<<"connection is closed.";
         
       }
       iter++;

+ 1 - 0
runtime/websocket/bin/websocket-server.h

@@ -50,6 +50,7 @@ typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context>
 typedef struct {
     std::string msg="";
     std::string stamp="";
+    std::string stamp_sents;
     std::string tpass_msg="";
     float snippet_time=0;
 } FUNASR_RECOG_RESULT;