浏览代码

fix eng oov hotwords

雾聪 2 年之前
父节点
当前提交
639ae933aa
共有 2 个文件被更改,包括 7 次插入0 次删除
  1. 6 0
      funasr/runtime/onnxruntime/src/paraformer.cpp
  2. 1 0
      funasr/runtime/onnxruntime/src/seg_dict.cpp

+ 6 - 0
funasr/runtime/onnxruntime/src/paraformer.cpp

@@ -719,6 +719,7 @@ std::vector<std::vector<float>> Paraformer::CompileHotwordEmbedding(std::string
     std::vector<int32_t> hotword_matrix;
     std::vector<int32_t> lengths;
     int hotword_size = 1;
+    int real_hw_size = 0;
     if (!hotwords.empty()) {
       std::vector<std::string> hotword_array = split(hotwords, ' ');
       hotword_size = hotword_array.size() + 1;
@@ -735,6 +736,9 @@ std::vector<std::vector<float>> Paraformer::CompileHotwordEmbedding(std::string
             chars.insert(chars.end(), tokens.begin(), tokens.end());
           }
         }
+        if(chars.size()==0){
+            continue;
+        }
         std::vector<int32_t> hw_vector(max_hotword_len, 0);
         int vector_len = std::min(max_hotword_len, (int)chars.size());
         for (int i=0; i<chars.size(); i++) {
@@ -743,8 +747,10 @@ std::vector<std::vector<float>> Paraformer::CompileHotwordEmbedding(std::string
         }
         std::cout << std::endl;
         lengths.push_back(vector_len);
+        real_hw_size += 1;
         hotword_matrix.insert(hotword_matrix.end(), hw_vector.begin(), hw_vector.end());
       }
+      hotword_size = real_hw_size + 1;
     }
     std::vector<int32_t> blank_vec(max_hotword_len, 0);
     blank_vec[0] = 1;

+ 1 - 0
funasr/runtime/onnxruntime/src/seg_dict.cpp

@@ -40,6 +40,7 @@ std::vector<std::string> SegDict::GetTokensByWord(const std::string &word) {
   if (seg_dict.count(word))
     return seg_dict[word];
   else {
+    LOG(INFO)<< word <<" is OOV!";
     std::vector<string> vec;
     return vec;
   }