Kaynağa Gözat

add timestamp smooth

雾聪 2 yıl önce
ebeveyn
işleme
d674c29323

+ 2 - 2
runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp

@@ -55,7 +55,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
     for (size_t i = 0; i < 1; i++)
     for (size_t i = 0; i < 1; i++)
     {
     {
         FunOfflineReset(asr_handle, decoder_handle);
         FunOfflineReset(asr_handle, decoder_handle);
-        FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, false, decoder_handle);
+        FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, true, decoder_handle);
         if(result){
         if(result){
             FunASRFreeResult(result);
             FunASRFreeResult(result);
         }
         }
@@ -69,7 +69,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
         }
         }
 
 
         gettimeofday(&start, NULL);
         gettimeofday(&start, NULL);
-        FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[i].c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, false, decoder_handle);
+        FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[i].c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, true, decoder_handle);
 
 
         gettimeofday(&end, NULL);
         gettimeofday(&end, NULL);
         seconds = (end.tv_sec - start.tv_sec);
         seconds = (end.tv_sec - start.tv_sec);

+ 1 - 1
runtime/onnxruntime/bin/funasr-onnx-offline.cpp

@@ -157,7 +157,7 @@ int main(int argc, char** argv)
         auto& wav_file = wav_list[i];
         auto& wav_file = wav_list[i];
         auto& wav_id = wav_ids[i];
         auto& wav_id = wav_ids[i];
         gettimeofday(&start, NULL);
         gettimeofday(&start, NULL);
-        FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, false, decoder_handle);
+        FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, true, decoder_handle);
         gettimeofday(&end, NULL);
         gettimeofday(&end, NULL);
         seconds = (end.tv_sec - start.tv_sec);
         seconds = (end.tv_sec - start.tv_sec);
         taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
         taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);

+ 19 - 0
runtime/onnxruntime/src/funasrruntime.cpp

@@ -294,6 +294,12 @@
 #if !defined(__APPLE__)
 #if !defined(__APPLE__)
 		if(offline_stream->UseITN() && itn){
 		if(offline_stream->UseITN() && itn){
 			string msg_itn = offline_stream->itn_handle->Normalize(p_result->msg);
 			string msg_itn = offline_stream->itn_handle->Normalize(p_result->msg);
+			if(!(p_result->stamp).empty()){
+				std::string new_stamp = funasr::TimestampSmooth(p_result->msg, msg_itn, p_result->stamp);
+				if(!new_stamp.empty()){
+					p_result->stamp = new_stamp;
+				}
+			}			
 			p_result->msg = msg_itn;
 			p_result->msg = msg_itn;
 		}
 		}
 #endif
 #endif
@@ -384,6 +390,12 @@
 #if !defined(__APPLE__)
 #if !defined(__APPLE__)
 		if(offline_stream->UseITN() && itn){
 		if(offline_stream->UseITN() && itn){
 			string msg_itn = offline_stream->itn_handle->Normalize(p_result->msg);
 			string msg_itn = offline_stream->itn_handle->Normalize(p_result->msg);
+			if(!(p_result->stamp).empty()){
+				std::string new_stamp = funasr::TimestampSmooth(p_result->msg, msg_itn, p_result->stamp);
+				if(!new_stamp.empty()){
+					p_result->stamp = new_stamp;
+				}
+			}
 			p_result->msg = msg_itn;
 			p_result->msg = msg_itn;
 		}
 		}
 #endif
 #endif
@@ -524,6 +536,13 @@
 #if !defined(__APPLE__)
 #if !defined(__APPLE__)
 			if(tpass_stream->UseITN() && itn){
 			if(tpass_stream->UseITN() && itn){
 				string msg_itn = tpass_stream->itn_handle->Normalize(msg_punc);
 				string msg_itn = tpass_stream->itn_handle->Normalize(msg_punc);
+				// TimestampSmooth
+				if(!(p_result->stamp).empty()){
+					std::string new_stamp = funasr::TimestampSmooth(p_result->tpass_msg, msg_itn, p_result->stamp);
+					if(!new_stamp.empty()){
+						p_result->stamp = new_stamp;
+					}
+				}
 				p_result->tpass_msg = msg_itn;
 				p_result->tpass_msg = msg_itn;
 			}
 			}
 #endif
 #endif

+ 7 - 2
runtime/onnxruntime/src/paraformer.cpp

@@ -300,10 +300,15 @@ void Paraformer::InitSegDict(const std::string &seg_dict_model) {
 
 
 Paraformer::~Paraformer()
 Paraformer::~Paraformer()
 {
 {
-    if(vocab)
+    if(vocab){
         delete vocab;
         delete vocab;
-    if(seg_dict)
+    }
+    if(seg_dict){
         delete seg_dict;
         delete seg_dict;
+    }
+    if(phone_set_){
+        delete phone_set_;
+    }
 }
 }
 
 
 void Paraformer::StartUtterance()
 void Paraformer::StartUtterance()

+ 331 - 5
runtime/onnxruntime/src/util.cpp

@@ -247,6 +247,316 @@ void SplitChiEngCharacters(const std::string &input_str,
   }
   }
 }
 }
 
 
+// Timestamp Smooth
+void TimestampAdd(std::deque<string> &alignment_str1, std::string str_word){
+    if(!TimestampIsPunctuation(str_word)){
+        alignment_str1.push_front(str_word);
+    }
+}
+
+bool TimestampIsPunctuation(const std::string& str) {
+    const std::string punctuation = u8",。?、,.?";
+    for (char ch : str) {
+        if (punctuation.find(ch) == std::string::npos) {
+            return false;
+        }
+    }
+    return true;
+}
+
+vector<vector<int>> ParseTimestamps(const std::string& str) {
+    vector<vector<int>> timestamps;
+    std::istringstream ss(str);
+    std::string segment;
+
+    // skip first'['
+    ss.ignore(1);
+
+    while (std::getline(ss, segment, ']')) {
+        std::istringstream segmentStream(segment);
+        std::string number;
+        vector<int> ts;
+
+        // skip'['
+        segmentStream.ignore(1);
+
+        while (std::getline(segmentStream, number, ',')) {
+            ts.push_back(std::stoi(number));
+        }
+        if(ts.size() != 2){
+            LOG(ERROR) << "ParseTimestamps Failed";
+            timestamps.clear();
+            return timestamps;
+        }
+        timestamps.push_back(ts);
+        ss.ignore(1);
+    }
+
+    return timestamps;
+}
+
+bool TimestampIsDigit(U16CHAR_T &u16) {
+    return u16 >= L'0' && u16 <= L'9';
+}
+
+bool TimestampIsAlpha(U16CHAR_T &u16) {
+    return (u16 >= L'A' && u16 <= L'Z') || (u16 >= L'a' && u16 <= L'z');
+}
+
+bool TimestampIsPunctuation(U16CHAR_T &u16) {
+    return (u16 >= 0x21 && u16 <= 0x2F)     // 标准ASCII标点
+        || (u16 >= 0x3A && u16 <= 0x40)     // 标准ASCII标点
+        || (u16 >= 0x5B && u16 <= 0x60)     // 标准ASCII标点
+        || (u16 >= 0x7B && u16 <= 0x7E)     // 标准ASCII标点
+        || (u16 >= 0x2000 && u16 <= 0x206F) // 常用的Unicode标点
+        || (u16 >= 0x3000 && u16 <= 0x303F); // CJK符号和标点
+}
+
+void TimestampSplitChiEngCharacters(const std::string &input_str,
+                                  std::vector<std::string> &characters) {
+  characters.resize(0);
+  std::string eng_word = "";
+  U16CHAR_T space = 0x0020;
+  std::vector<U16CHAR_T> u16_buf;
+  u16_buf.resize(std::max(u16_buf.size(), input_str.size() + 1));
+  U16CHAR_T* pu16 = u16_buf.data();
+  U8CHAR_T * pu8 = (U8CHAR_T*)input_str.data();
+  size_t ilen = input_str.size();
+  size_t len = EncodeConverter::Utf8ToUtf16(pu8, ilen, pu16, ilen + 1);
+  for (size_t i = 0; i < len; i++) {
+    if (EncodeConverter::IsChineseCharacter(pu16[i])) {
+      if(!eng_word.empty()){
+        characters.push_back(eng_word);
+        eng_word = "";
+      }
+      U8CHAR_T u8buf[4];
+      size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
+      u8buf[n] = '\0';
+      characters.push_back((const char*)u8buf);
+    } else if (TimestampIsDigit(pu16[i]) || TimestampIsPunctuation(pu16[i])){
+      if(!eng_word.empty()){
+        characters.push_back(eng_word);
+        eng_word = "";
+      }
+      U8CHAR_T u8buf[4];
+      size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
+      u8buf[n] = '\0';
+      characters.push_back((const char*)u8buf);
+    } else if (pu16[i] == space){
+      if(!eng_word.empty()){
+        characters.push_back(eng_word);
+        eng_word = "";
+      }      
+    }else{
+      U8CHAR_T u8buf[4];
+      size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
+      u8buf[n] = '\0';
+      eng_word += (const char*)u8buf;
+    }
+  }
+  if(!eng_word.empty()){
+    characters.push_back(eng_word);
+    eng_word = "";
+  }
+}
+
+std::string VectorToString(const std::vector<std::vector<int>>& vec) {
+    if(vec.size() == 0){
+        return "";
+    }
+    std::ostringstream out;
+    out << "[";
+
+    for (size_t i = 0; i < vec.size(); ++i) {
+        out << "[";
+        for (size_t j = 0; j < vec[i].size(); ++j) {
+            out << vec[i][j];
+            if (j < vec[i].size() - 1) {
+                out << ",";
+            }
+        }
+        out << "]";
+        if (i < vec.size() - 1) {
+            out << ",";
+        }
+    }
+
+    out << "]";
+    return out.str();
+}
+
+std::string TimestampSmooth(std::string &text, std::string &text_itn, std::string &str_time){
+    vector<vector<int>> timestamps_out;
+    std::string timestamps_str = "";
+    // process string to vector<string>
+    std::vector<std::string> characters;
+    funasr::TimestampSplitChiEngCharacters(text, characters);
+    
+    std::vector<std::string> characters_itn;
+    funasr::TimestampSplitChiEngCharacters(text_itn, characters_itn);
+    
+    //convert string to vector<vector<int>>
+    vector<vector<int>> timestamps = funasr::ParseTimestamps(str_time);
+
+    if (timestamps.size() == 0){
+        LOG(ERROR) << "Timestamp Smooth Failed: Length of timestamp is zero";
+        return timestamps_str;
+    }
+    
+    // edit distance
+    int m = characters.size();
+    int n = characters_itn.size();
+    std::vector<std::vector<int>> dp(m + 1, std::vector<int>(n + 1, 0));
+
+    // init
+    for (int i = 0; i <= m; ++i) {
+        dp[i][0] = i;
+    }
+    for (int j = 0; j <= n; ++j) {
+        dp[0][j] = j;
+    }
+
+    // dp
+    for (int i = 1; i <= m; ++i) {
+        for (int j = 1; j <= n; ++j) {
+            if (characters[i - 1] == characters_itn[j - 1]) {
+                dp[i][j] = dp[i - 1][j - 1];
+            } else {
+                dp[i][j] = std::min({dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]}) + 1;
+            }
+        }
+    }
+
+    // backtrack
+    std::deque<string> alignment_str1, alignment_str2;
+    int i = m, j = n;
+    while (i > 0 || j > 0) {
+        if (i > 0 && j > 0 && dp[i][j] == dp[i - 1][j - 1]) {
+            funasr::TimestampAdd(alignment_str1, characters[i - 1]);
+            funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]);
+            i -= 1;
+            j -= 1;
+        } else if (i > 0 && dp[i][j] == dp[i - 1][j] + 1) {
+            funasr::TimestampAdd(alignment_str1, characters[i - 1]);
+            alignment_str2.push_front("");
+            i -= 1;
+        } else if (j > 0 && dp[i][j] == dp[i][j - 1] + 1) {
+            alignment_str1.push_front("");
+            funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]);
+            j -= 1;
+        } else{
+            funasr::TimestampAdd(alignment_str1, characters[i - 1]);
+            funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]);
+            i -= 1;
+            j -= 1;            
+        }
+    }
+
+    // smooth
+    int itn_count = 0;
+    int idx_tp = 0;
+    int idx_itn = 0;
+    vector<vector<int>> timestamps_tmp;
+    for(int index = 0; index < alignment_str1.size(); index++){
+        if (alignment_str1[index] == alignment_str2[index]){
+            bool subsidy = false;
+            if (itn_count > 0 && timestamps_tmp.size() == 0){
+                if(idx_tp >= timestamps.size()){
+                    LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. ";
+                    return timestamps_str;
+                }
+                timestamps_tmp.push_back(timestamps[idx_tp]);
+                subsidy = true;
+                itn_count++;
+            }
+
+            if (timestamps_tmp.size() > 0){
+                if (itn_count > 0){
+                    int begin = timestamps_tmp[0][0];
+                    int end = timestamps_tmp.back()[1];
+                    int total_time = end - begin;
+                    int interval = total_time / itn_count;
+                    for(int idx_cnt=0; idx_cnt < itn_count; idx_cnt++){
+                        vector<int> ts;
+                        ts.push_back(begin + interval*idx_cnt);
+                        if(idx_cnt == itn_count-1){
+                            ts.push_back(end);
+                        }else {
+                            ts.push_back(begin + interval*(idx_cnt + 1));
+                        }
+                        timestamps_out.push_back(ts);
+                    }
+                }
+                timestamps_tmp.clear();
+            }
+            if(!subsidy){
+                if(idx_tp >= timestamps.size()){
+                    LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. ";
+                    return timestamps_str;
+                }
+                timestamps_out.push_back(timestamps[idx_tp]);
+            }
+            idx_tp++;
+            itn_count = 0;
+        }else{
+            if (!alignment_str1[index].empty()){
+                if(idx_tp >= timestamps.size()){
+                    LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. ";
+                    return timestamps_str;
+                }
+                timestamps_tmp.push_back(timestamps[idx_tp]);
+                idx_tp++;
+            }
+            if (!alignment_str2[index].empty()){
+                itn_count++;
+            }
+        }
+        // count length of itn
+        if (!alignment_str2[index].empty()){
+            idx_itn++;
+        }
+    }
+    {
+        if (itn_count > 0 && timestamps_tmp.size() == 0){
+            if (timestamps_out.size() > 0){
+                timestamps_tmp.push_back(timestamps_out.back());
+                itn_count++;
+                timestamps_out.pop_back();
+            } else{
+                LOG(ERROR) << "Timestamp Smooth Failed: Last itn has no timestamp.";
+                return timestamps_str;
+            }
+        }
+
+        if (timestamps_tmp.size() > 0){
+            if (itn_count > 0){
+                int begin = timestamps_tmp[0][0];
+                int end = timestamps_tmp.back()[1];
+                int total_time = end - begin;
+                int interval = total_time / itn_count;
+                for(int idx_cnt=0; idx_cnt < itn_count; idx_cnt++){
+                    vector<int> ts;
+                    ts.push_back(begin + interval*idx_cnt);
+                    if(idx_cnt == itn_count-1){
+                        ts.push_back(end);
+                    }else {
+                        ts.push_back(begin + interval*(idx_cnt + 1));
+                    }
+                    timestamps_out.push_back(ts);
+                }
+            }
+            timestamps_tmp.clear();
+        }
+    }
+    if(timestamps_out.size() != idx_itn){
+        LOG(ERROR) << "Timestamp Smooth Failed: Timestamp length does not matched.";
+        return timestamps_str;
+    }
+    
+    timestamps_str = VectorToString(timestamps_out);
+    return timestamps_str;
+}
+
 std::vector<std::string> split(const std::string &s, char delim) {
 std::vector<std::string> split(const std::string &s, char delim) {
   std::vector<std::string> elems;
   std::vector<std::string> elems;
   std::stringstream ss(s);
   std::stringstream ss(s);
@@ -333,12 +643,23 @@ string PostProcess(std::vector<string> &raw_char, std::vector<std::vector<float>
             int sub_word = !(word.find("@@") == string::npos);
             int sub_word = !(word.find("@@") == string::npos);
             // process word start and middle part
             // process word start and middle part
             if (sub_word) {
             if (sub_word) {
-                combine += word.erase(word.length() - 2);
-                if(!is_combining){
-                    begin = timestamp_list[i][0];
+                // if badcase: lo@@ chinese
+                if (i == raw_char.size()-1 || i<raw_char.size()-1 && IsChinese(raw_char[i+1])){
+                    word = word.erase(word.length() - 2) + " ";
+                    if (is_combining) {
+                        combine += word;
+                        is_combining = false;
+                        word = combine;
+                        combine = "";
+                    }
+                }else{
+                    combine += word.erase(word.length() - 2);
+                    if(!is_combining){
+                        begin = timestamp_list[i][0];
+                    }
+                    is_combining = true;
+                    continue;
                 }
                 }
-                is_combining = true;
-                continue;
             }
             }
             // process word end part
             // process word end part
             else if (is_combining) {
             else if (is_combining) {
@@ -669,4 +990,9 @@ void ExtractHws(string hws_file, unordered_map<string, int> &hws_map, string& nn
     ifs_hws.close();
     ifs_hws.close();
 }
 }
 
 
+void SmoothTimestamps(std::string &str_punc, std::string &str_itn, std::string &str_timetamp){
+    
+    return;
+}
+
 } // namespace funasr
 } // namespace funasr

+ 12 - 0
runtime/onnxruntime/src/util.h

@@ -3,11 +3,13 @@
 #include <vector>
 #include <vector>
 #include <memory>
 #include <memory>
 #include <unordered_map>
 #include <unordered_map>
+#include <deque>
 #include "tensor.h"
 #include "tensor.h"
 
 
 using namespace std;
 using namespace std;
 
 
 namespace funasr {
 namespace funasr {
+typedef unsigned short          U16CHAR_T;
 extern float *LoadParams(const char *filename);
 extern float *LoadParams(const char *filename);
 
 
 extern void SaveDataFile(const char *filename, void *data, uint32_t len);
 extern void SaveDataFile(const char *filename, void *data, uint32_t len);
@@ -35,6 +37,16 @@ void KeepChineseCharacterAndSplit(const std::string &input_str,
                                   std::vector<std::string> &chinese_characters);
                                   std::vector<std::string> &chinese_characters);
 void SplitChiEngCharacters(const std::string &input_str,
 void SplitChiEngCharacters(const std::string &input_str,
                                   std::vector<std::string> &characters);
                                   std::vector<std::string> &characters);
+void TimestampAdd(std::deque<string> &alignment_str1, std::string str_word);
+vector<vector<int>> ParseTimestamps(const std::string& str);
+bool TimestampIsDigit(U16CHAR_T &u16);
+bool TimestampIsAlpha(U16CHAR_T &u16);
+bool TimestampIsPunctuation(U16CHAR_T &u16);
+bool TimestampIsPunctuation(const std::string& str);
+void TimestampSplitChiEngCharacters(const std::string &input_str,
+                                  std::vector<std::string> &characters);
+std::string VectorToString(const std::vector<std::vector<int>>& vec);                                  
+std::string TimestampSmooth(std::string &text, std::string &text_itn, std::string &str_time);
 
 
 std::vector<std::string> split(const std::string &s, char delim);
 std::vector<std::string> split(const std::string &s, char delim);
 
 

+ 16 - 5
runtime/onnxruntime/src/vocab.cpp

@@ -120,8 +120,8 @@ string Vocab::Vector2StringV2(vector<int> in, std::string language)
     std::string combine = "";
     std::string combine = "";
     std::string unicodeChar = "▁";
     std::string unicodeChar = "▁";
 
 
-    for (auto it = in.begin(); it != in.end(); it++) {
-        string word = vocab[*it];
+    for (i=0; i<in.size(); i++){
+        string word = vocab[in[i]];
         // step1 space character skips
         // step1 space character skips
         if (word == "<s>" || word == "</s>" || word == "<unk>")
         if (word == "<s>" || word == "</s>" || word == "<unk>")
             continue;
             continue;
@@ -146,9 +146,20 @@ string Vocab::Vector2StringV2(vector<int> in, std::string language)
             int sub_word = !(word.find("@@") == string::npos);
             int sub_word = !(word.find("@@") == string::npos);
             // process word start and middle part
             // process word start and middle part
             if (sub_word) {
             if (sub_word) {
-                combine += word.erase(word.length() - 2);
-                is_combining = true;
-                continue;
+                // if badcase: lo@@ chinese
+                if (i == in.size()-1 || i<in.size()-1 && IsChinese(vocab[in[i+1]])){
+                    word = word.erase(word.length() - 2) + " ";
+                    if (is_combining) {
+                        combine += word;
+                        is_combining = false;
+                        word = combine;
+                        combine = "";
+                    }
+                }else{
+                    combine += word.erase(word.length() - 2);
+                    is_combining = true;
+                    continue;
+                }
             }
             }
             // process word end part
             // process word end part
             else if (is_combining) {
             else if (is_combining) {