bias-lm.cpp 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. #include "bias-lm.h"
  2. #ifdef _WIN32
  3. #include "fst-types.cc"
  4. #endif
  5. namespace funasr {
  6. void print(std::queue<StateId> &q) {
  7. std::queue<StateId> data = q;
  8. while (!data.empty())
  9. {
  10. cout << data.front() << " ";
  11. data.pop();
  12. }
  13. cout << endl;
  14. }
  15. void BiasLm::LoadCfgFromYaml(const char* filename, BiasLmOption &opt) {
  16. YAML::Node config;
  17. try {
  18. config = YAML::LoadFile(filename);
  19. } catch(exception const &e) {
  20. LOG(INFO) << "Error loading file, yaml file error or not exist.";
  21. exit(-1);
  22. }
  23. try {
  24. YAML::Node bias_lm_conf = config["bias_lm_conf"];
  25. opt_.incre_bias_ = bias_lm_conf["increment_weight"].as<float>();
  26. } catch(exception const &e) {
  27. }
  28. }
  29. void BiasLm::BuildGraph(std::vector<std::vector<int>> &split_id_vec,
  30. std::vector<float> &custom_weight) {
  31. if (split_id_vec.empty()) {
  32. LOG(INFO) << "Skip building biaslm graph, hotword not exits.";
  33. return ;
  34. }
  35. assert(split_id_vec.size() == custom_weight.size());
  36. // Build prefix tree
  37. std::unique_ptr<fst::StdVectorFst> prefix_tree(new fst::StdVectorFst());
  38. StateId start_state = prefix_tree->AddState();
  39. prefix_tree->SetStart(start_state);
  40. int id = 0;
  41. for (auto& x : split_id_vec) {
  42. StateId state = start_state;
  43. StateId next_state = state;
  44. float w = custom_weight[id++];
  45. std::vector<int> split_id = x;
  46. for (int j = 0; j < split_id.size(); j++) {
  47. next_state = prefix_tree->AddState();
  48. if (j == split_id.size() - 1) {
  49. prefix_tree->SetFinal(next_state, w);
  50. }
  51. prefix_tree->AddArc(state, Arc(split_id[j], split_id[j], opt_.incre_bias_, next_state));
  52. state = next_state;
  53. }
  54. }
  55. graph_ = std::unique_ptr<fst::StdVectorFst>(new fst::StdVectorFst());
  56. fst::Determinize(*prefix_tree, graph_.get());
  57. int num_node = graph_->NumStates();
  58. node_list_.resize(num_node);
  59. for (auto& x : split_id_vec) {
  60. StateId cur_state = 0;
  61. StateId next_state = 0;
  62. std::vector<int> split_id = x;
  63. for (int j = 0; j < split_id.size(); j++) {
  64. Matcher matcher(*graph_, fst::MATCH_INPUT);
  65. matcher.SetState(cur_state);
  66. if (matcher.Find(split_id[j])) {
  67. next_state = matcher.Value().nextstate;
  68. if (graph_->Final(next_state) != Weight::Zero()) {
  69. node_list_[next_state].is_final_ = true;
  70. }
  71. node_list_[next_state].score_ = opt_.incre_bias_ * (j + 1);
  72. cur_state = next_state;
  73. }
  74. }
  75. }
  76. // Build Aho-Corasick Automata
  77. std::queue<StateId> q;
  78. Matcher matcher(*graph_, fst::MATCH_INPUT);
  79. // Back off state of all child nodes of the root node points to the root node
  80. for (ArcIterator aiter(*graph_, start_state); !aiter.Done(); aiter.Next()) {
  81. const Arc& arc = aiter.Value();
  82. node_list_[arc.nextstate].back_off_ = start_state;
  83. float back_off_score = (node_list_[arc.nextstate].is_final_ ? 0 :
  84. node_list_[start_state].score_ - node_list_[arc.nextstate].score_);
  85. graph_->AddArc(arc.nextstate, Arc(0, 0, back_off_score, start_state));
  86. q.push(arc.nextstate);
  87. }
  88. while (!q.empty()) {
  89. StateId state_id = q.front();
  90. q.pop();
  91. for (ArcIterator aiter(*graph_, state_id); !aiter.Done(); aiter.Next()) {
  92. const Arc& arc = aiter.Value();
  93. StateId next_state = arc.nextstate;
  94. StateId temp_state = node_list_[state_id].back_off_;
  95. if (next_state == start_state || next_state == temp_state) {
  96. continue;
  97. }
  98. while (true) {
  99. matcher.SetState(temp_state);
  100. if (matcher.Find(arc.ilabel)) {
  101. node_list_[next_state].back_off_ = matcher.Value().nextstate;
  102. break;
  103. } else if (temp_state == start_state) {
  104. node_list_[next_state].back_off_ = start_state;
  105. break;
  106. }
  107. temp_state = node_list_[temp_state].back_off_;
  108. }
  109. float back_off_score = (node_list_[next_state].is_final_ ? 0 :
  110. node_list_[node_list_[next_state].back_off_].score_ -
  111. node_list_[next_state].score_);
  112. graph_->AddArc(next_state, Arc(0, 0, back_off_score,
  113. node_list_[next_state].back_off_));
  114. q.push(next_state);
  115. }
  116. }
  117. fst::ArcSort(graph_.get(), fst::StdILabelCompare());
  118. //graph_->Write("graph.final.fst");
  119. }
  120. float BiasLm::BiasLmScore(const StateId &his_state, const Label &lab, Label &new_state) {
  121. if (lab < 1 || lab > phn_set_.Size() || !graph_) { return VALUE_ZERO; }
  122. StateId cur_state = his_state;
  123. StateId next_state;
  124. float score = VALUE_ZERO;
  125. Matcher matcher(*graph_, fst::MATCH_INPUT);
  126. while (true) {
  127. StateId prev_state = cur_state;
  128. matcher.SetState(cur_state);
  129. if (matcher.Find(lab)) {
  130. next_state = matcher.Value().nextstate;
  131. score += matcher.Value().weight.Value();
  132. if (node_list_[next_state].is_final_) {
  133. score = score + graph_->Final(next_state).Value();
  134. }
  135. cur_state = next_state;
  136. break;
  137. } else {
  138. ArcIterator aiter(*graph_, cur_state);
  139. const Arc& arc = aiter.Value();
  140. if (arc.ilabel == 0) {
  141. score += arc.weight.Value();
  142. next_state = arc.nextstate;
  143. cur_state = next_state;
  144. }
  145. if (prev_state == ROOT_NODE && cur_state == ROOT_NODE) {
  146. break;
  147. }
  148. }
  149. }
  150. new_state = cur_state;
  151. return score;
  152. }
  153. void BiasLm::VocabIdToPhnIdVector(int vocab_id, std::vector<int> &phn_ids) {
  154. bool is_oov = false;
  155. phn_ids.clear();
  156. std::string word = vocab_.Id2String(vocab_id);
  157. std::vector<std::string> phn_vec;
  158. Utf8ToCharset(word, phn_vec);
  159. for (auto& phn : phn_vec) {
  160. if (!phn_set_.Find(phn)) {
  161. is_oov = true;
  162. break;
  163. } else {
  164. phn_ids.push_back(phn_set_.String2Id(phn));
  165. }
  166. }
  167. if (is_oov) { phn_ids.clear(); }
  168. }
  169. std::string BiasLm::GetPhoneLabel(int phone_id) {
  170. if (phone_id < 0 || phone_id >= phn_set_.Size()) { return ""; }
  171. return phn_set_.Id2String(phone_id);
  172. }
  173. }