lattice-incremental-online-decoder.cc 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. // decoder/lattice-incremental-online-decoder.cc
  2. // Copyright 2019 Zhehuai Chen
  3. // See ../../COPYING for clarification regarding multiple authors
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  12. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  13. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  14. // MERCHANTABLITY OR NON-INFRINGEMENT.
  15. // See the Apache 2 License for the specific language governing permissions and
  16. // limitations under the License.
  17. // see note at the top of lattice-faster-decoder.cc, about how to maintain this
  18. // file in sync with lattice-faster-decoder.cc
  19. #include "decoder/lattice-incremental-decoder.h"
  20. #include "decoder/lattice-incremental-online-decoder.h"
  21. #include "lat/lattice-functions.h"
  22. #include "base/timer.h"
  23. namespace kaldi {
  24. // Outputs an FST corresponding to the single best path through the lattice.
  25. template <typename FST>
  26. bool LatticeIncrementalOnlineDecoderTpl<FST>::GetBestPath(Lattice *olat,
  27. bool use_final_probs) const {
  28. olat->DeleteStates();
  29. BaseFloat final_graph_cost;
  30. BestPathIterator iter = BestPathEnd(use_final_probs, &final_graph_cost);
  31. if (iter.Done())
  32. return false; // would have printed warning.
  33. StateId state = olat->AddState();
  34. olat->SetFinal(state, LatticeWeight(final_graph_cost, 0.0));
  35. while (!iter.Done()) {
  36. LatticeArc arc;
  37. iter = TraceBackBestPath(iter, &arc);
  38. arc.nextstate = state;
  39. StateId new_state = olat->AddState();
  40. olat->AddArc(new_state, arc);
  41. state = new_state;
  42. }
  43. olat->SetStart(state);
  44. return true;
  45. }
  46. template <typename FST>
  47. typename LatticeIncrementalOnlineDecoderTpl<FST>::BestPathIterator LatticeIncrementalOnlineDecoderTpl<FST>::BestPathEnd(
  48. bool use_final_probs,
  49. BaseFloat *final_cost_out) const {
  50. if (this->decoding_finalized_ && !use_final_probs)
  51. KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
  52. << "BestPathEnd() with use_final_probs == false";
  53. KALDI_ASSERT(this->NumFramesDecoded() > 0 &&
  54. "You cannot call BestPathEnd if no frames were decoded.");
  55. unordered_map<Token*, BaseFloat> final_costs_local;
  56. const unordered_map<Token*, BaseFloat> &final_costs =
  57. (this->decoding_finalized_ ? this->final_costs_ :final_costs_local);
  58. if (!this->decoding_finalized_ && use_final_probs)
  59. this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
  60. // Singly linked list of tokens on last frame (access list through "next"
  61. // pointer).
  62. BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
  63. BaseFloat best_final_cost = 0;
  64. Token *best_tok = NULL;
  65. for (Token *tok = this->active_toks_.back().toks;
  66. tok != NULL; tok = tok->next) {
  67. BaseFloat cost = tok->tot_cost, final_cost = 0.0;
  68. if (use_final_probs && !final_costs.empty()) {
  69. // if we are instructed to use final-probs, and any final tokens were
  70. // active on final frame, include the final-prob in the cost of the token.
  71. typename unordered_map<Token*, BaseFloat>::const_iterator
  72. iter = final_costs.find(tok);
  73. if (iter != final_costs.end()) {
  74. final_cost = iter->second;
  75. cost += final_cost;
  76. } else {
  77. cost = std::numeric_limits<BaseFloat>::infinity();
  78. }
  79. }
  80. if (cost < best_cost) {
  81. best_cost = cost;
  82. best_tok = tok;
  83. best_final_cost = final_cost;
  84. }
  85. }
  86. if (best_tok == NULL) { // this should not happen, and is likely a code error or
  87. // caused by infinities in likelihoods, but I'm not making
  88. // it a fatal error for now.
  89. KALDI_WARN << "No final token found.";
  90. }
  91. if (final_cost_out != NULL)
  92. *final_cost_out = best_final_cost;
  93. return BestPathIterator(best_tok, this->NumFramesDecoded() - 1);
  94. }
  95. template <typename FST>
  96. typename LatticeIncrementalOnlineDecoderTpl<FST>::BestPathIterator LatticeIncrementalOnlineDecoderTpl<FST>::TraceBackBestPath(
  97. BestPathIterator iter, LatticeArc *oarc) const {
  98. KALDI_ASSERT(!iter.Done() && oarc != NULL);
  99. Token *tok = static_cast<Token*>(iter.tok);
  100. int32 cur_t = iter.frame, step_t = 0;
  101. if (tok->backpointer != NULL) {
  102. // retrieve the correct forward link(with the best link cost)
  103. BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
  104. ForwardLinkT *link;
  105. for (link = tok->backpointer->links;
  106. link != NULL; link = link->next) {
  107. if (link->next_tok == tok) { // this is the a to "tok"
  108. BaseFloat graph_cost = link->graph_cost,
  109. acoustic_cost = link->acoustic_cost;
  110. BaseFloat cost = graph_cost + acoustic_cost;
  111. if (cost < best_cost) {
  112. oarc->ilabel = link->ilabel;
  113. oarc->olabel = link->olabel;
  114. if (link->ilabel != 0) {
  115. KALDI_ASSERT(static_cast<size_t>(cur_t) < this->cost_offsets_.size());
  116. acoustic_cost -= this->cost_offsets_[cur_t];
  117. step_t = -1;
  118. } else {
  119. step_t = 0;
  120. }
  121. oarc->weight = LatticeWeight(graph_cost, acoustic_cost);
  122. best_cost = cost;
  123. }
  124. }
  125. }
  126. if (link == NULL &&
  127. best_cost == std::numeric_limits<BaseFloat>::infinity()) { // Did not find correct link.
  128. KALDI_ERR << "Error tracing best-path back (likely "
  129. << "bug in token-pruning algorithm)";
  130. }
  131. } else {
  132. oarc->ilabel = 0;
  133. oarc->olabel = 0;
  134. oarc->weight = LatticeWeight::One(); // zero costs.
  135. }
  136. return BestPathIterator(tok->backpointer, cur_t + step_t);
  137. }
  138. // Instantiate the template for the FST types that we'll need.
  139. template class LatticeIncrementalOnlineDecoderTpl<fst::Fst<fst::StdArc> >;
  140. template class LatticeIncrementalOnlineDecoderTpl<fst::VectorFst<fst::StdArc> >;
  141. template class LatticeIncrementalOnlineDecoderTpl<fst::ConstFst<fst::StdArc> >;
  142. template class LatticeIncrementalOnlineDecoderTpl<fst::ConstGrammarFst >;
  143. template class LatticeIncrementalOnlineDecoderTpl<fst::VectorGrammarFst >;
  144. } // end namespace kaldi.