| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180 |
- #include "bias-lm.h"
- #ifdef _WIN32
- #include "fst-types.cc"
- #endif
- namespace funasr {
- void print(std::queue<StateId> &q) {
- std::queue<StateId> data = q;
- while (!data.empty())
- {
- cout << data.front() << " ";
- data.pop();
- }
- cout << endl;
- }
- void BiasLm::LoadCfgFromYaml(const char* filename, BiasLmOption &opt) {
- YAML::Node config;
- try {
- config = YAML::LoadFile(filename);
- } catch(exception const &e) {
- LOG(INFO) << "Error loading file, yaml file error or not exist.";
- exit(-1);
- }
- try {
- YAML::Node bias_lm_conf = config["bias_lm_conf"];
- opt_.incre_bias_ = bias_lm_conf["increment_weight"].as<float>();
- } catch(exception const &e) {
- }
- }
- void BiasLm::BuildGraph(std::vector<std::vector<int>> &split_id_vec,
- std::vector<float> &custom_weight) {
- if (split_id_vec.empty()) {
- LOG(INFO) << "Skip building biaslm graph, hotword not exits.";
- return ;
- }
- assert(split_id_vec.size() == custom_weight.size());
- // Build prefix tree
- std::unique_ptr<fst::StdVectorFst> prefix_tree(new fst::StdVectorFst());
- StateId start_state = prefix_tree->AddState();
- prefix_tree->SetStart(start_state);
- int id = 0;
- for (auto& x : split_id_vec) {
- StateId state = start_state;
- StateId next_state = state;
- float w = custom_weight[id++];
- std::vector<int> split_id = x;
- for (int j = 0; j < split_id.size(); j++) {
- next_state = prefix_tree->AddState();
- if (j == split_id.size() - 1) {
- prefix_tree->SetFinal(next_state, w);
- }
- prefix_tree->AddArc(state, Arc(split_id[j], split_id[j], opt_.incre_bias_, next_state));
- state = next_state;
- }
- }
- graph_ = std::unique_ptr<fst::StdVectorFst>(new fst::StdVectorFst());
- fst::Determinize(*prefix_tree, graph_.get());
- int num_node = graph_->NumStates();
- node_list_.resize(num_node);
- for (auto& x : split_id_vec) {
- StateId cur_state = 0;
- StateId next_state = 0;
- std::vector<int> split_id = x;
- for (int j = 0; j < split_id.size(); j++) {
- Matcher matcher(*graph_, fst::MATCH_INPUT);
- matcher.SetState(cur_state);
- if (matcher.Find(split_id[j])) {
- next_state = matcher.Value().nextstate;
- if (graph_->Final(next_state) != Weight::Zero()) {
- node_list_[next_state].is_final_ = true;
- }
- node_list_[next_state].score_ = opt_.incre_bias_ * (j + 1);
- cur_state = next_state;
- }
- }
- }
-
- // Build Aho-Corasick Automata
- std::queue<StateId> q;
- Matcher matcher(*graph_, fst::MATCH_INPUT);
- // Back off state of all child nodes of the root node points to the root node
- for (ArcIterator aiter(*graph_, start_state); !aiter.Done(); aiter.Next()) {
- const Arc& arc = aiter.Value();
- node_list_[arc.nextstate].back_off_ = start_state;
- float back_off_score = (node_list_[arc.nextstate].is_final_ ? 0 :
- node_list_[start_state].score_ - node_list_[arc.nextstate].score_);
- graph_->AddArc(arc.nextstate, Arc(0, 0, back_off_score, start_state));
- q.push(arc.nextstate);
- }
- while (!q.empty()) {
- StateId state_id = q.front();
- q.pop();
- for (ArcIterator aiter(*graph_, state_id); !aiter.Done(); aiter.Next()) {
- const Arc& arc = aiter.Value();
- StateId next_state = arc.nextstate;
- StateId temp_state = node_list_[state_id].back_off_;
- if (next_state == start_state || next_state == temp_state) {
- continue;
- }
- while (true) {
- matcher.SetState(temp_state);
- if (matcher.Find(arc.ilabel)) {
- node_list_[next_state].back_off_ = matcher.Value().nextstate;
- break;
- } else if (temp_state == start_state) {
- node_list_[next_state].back_off_ = start_state;
- break;
- }
- temp_state = node_list_[temp_state].back_off_;
- }
- float back_off_score = (node_list_[next_state].is_final_ ? 0 :
- node_list_[node_list_[next_state].back_off_].score_ -
- node_list_[next_state].score_);
- graph_->AddArc(next_state, Arc(0, 0, back_off_score,
- node_list_[next_state].back_off_));
- q.push(next_state);
- }
- }
- fst::ArcSort(graph_.get(), fst::StdILabelCompare());
- //graph_->Write("graph.final.fst");
- }
- float BiasLm::BiasLmScore(const StateId &his_state, const Label &lab, Label &new_state) {
- if (lab < 1 || lab > phn_set_.Size() || !graph_) { return VALUE_ZERO; }
- StateId cur_state = his_state;
- StateId next_state;
- float score = VALUE_ZERO;
- Matcher matcher(*graph_, fst::MATCH_INPUT);
- while (true) {
- StateId prev_state = cur_state;
- matcher.SetState(cur_state);
- if (matcher.Find(lab)) {
- next_state = matcher.Value().nextstate;
- score += matcher.Value().weight.Value();
- if (node_list_[next_state].is_final_) {
- score = score + graph_->Final(next_state).Value();
- }
- cur_state = next_state;
- break;
- } else {
- ArcIterator aiter(*graph_, cur_state);
- const Arc& arc = aiter.Value();
- if (arc.ilabel == 0) {
- score += arc.weight.Value();
- next_state = arc.nextstate;
- cur_state = next_state;
- }
- if (prev_state == ROOT_NODE && cur_state == ROOT_NODE) {
- break;
- }
- }
- }
- new_state = cur_state;
- return score;
- }
- void BiasLm::VocabIdToPhnIdVector(int vocab_id, std::vector<int> &phn_ids) {
- bool is_oov = false;
- phn_ids.clear();
- std::string word = vocab_.Id2String(vocab_id);
- std::vector<std::string> phn_vec;
- Utf8ToCharset(word, phn_vec);
- for (auto& phn : phn_vec) {
- if (!phn_set_.Find(phn)) {
- is_oov = true;
- break;
- } else {
- phn_ids.push_back(phn_set_.String2Id(phn));
- }
- }
- if (is_oov) { phn_ids.clear(); }
- }
- std::string BiasLm::GetPhoneLabel(int phone_id) {
- if (phone_id < 0 || phone_id >= phn_set_.Size()) { return ""; }
- return phn_set_.Id2String(phone_id);
- }
- }
|