simple-decoder.h 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. // decoder/simple-decoder.h
  2. // Copyright 2009-2013 Microsoft Corporation; Lukas Burget;
  3. // Saarland University (author: Arnab Ghoshal);
  4. // Johns Hopkins University (author: Daniel Povey)
  5. // See ../../COPYING for clarification regarding multiple authors
  6. //
  7. // Licensed under the Apache License, Version 2.0 (the "License");
  8. // you may not use this file except in compliance with the License.
  9. // You may obtain a copy of the License at
  10. //
  11. // http://www.apache.org/licenses/LICENSE-2.0
  12. //
  13. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  15. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  16. // MERCHANTABLITY OR NON-INFRINGEMENT.
  17. // See the Apache 2 License for the specific language governing permissions and
  18. // limitations under the License.
  19. #ifndef KALDI_DECODER_SIMPLE_DECODER_H_
  20. #define KALDI_DECODER_SIMPLE_DECODER_H_
  21. #include "util/stl-utils.h"
  22. #include "fst/fstlib.h"
  23. #include "lat/kaldi-lattice.h"
  24. #include "itf/decodable-itf.h"
  25. namespace kaldi {
  26. /** Simplest possible decoder, included largely for didactic purposes and as a
  27. means to debug more highly optimized decoders. See \ref decoders_simple
  28. for more information.
  29. */
  30. class SimpleDecoder {
  31. public:
  32. typedef fst::StdArc StdArc;
  33. typedef StdArc::Weight StdWeight;
  34. typedef StdArc::Label Label;
  35. typedef StdArc::StateId StateId;
  36. SimpleDecoder(const fst::Fst<fst::StdArc> &fst, BaseFloat beam): fst_(fst), beam_(beam) { }
  37. ~SimpleDecoder();
  38. /// Decode this utterance.
  39. /// Returns true if any tokens reached the end of the file (regardless of
  40. /// whether they are in a final state); query ReachedFinal() after Decode()
  41. /// to see whether we reached a final state.
  42. bool Decode(DecodableInterface *decodable);
  43. bool ReachedFinal() const;
  44. // GetBestPath gets the decoding traceback. If "use_final_probs" is true
  45. // AND we reached a final state, it limits itself to final states;
  46. // otherwise it gets the most likely token not taking into account final-probs.
  47. // fst_out will be empty (Start() == kNoStateId) if nothing was available due to
  48. // search error.
  49. // If Decode() returned true, it is safe to assume GetBestPath will return true.
  50. // It returns true if the output lattice was nonempty (i.e. had states in it);
  51. // using the return value is deprecated.
  52. bool GetBestPath(Lattice *fst_out, bool use_final_probs = true) const;
  53. /// *** The next functions are from the "new interface". ***
  54. /// FinalRelativeCost() serves the same function as ReachedFinal(), but gives
  55. /// more information. It returns the difference between the best (final-cost plus
  56. /// cost) of any token on the final frame, and the best cost of any token
  57. /// on the final frame. If it is infinity it means no final-states were present
  58. /// on the final frame. It will usually be nonnegative.
  59. BaseFloat FinalRelativeCost() const;
  60. /// InitDecoding initializes the decoding, and should only be used if you
  61. /// intend to call AdvanceDecoding(). If you call Decode(), you don't need
  62. /// to call this. You can call InitDecoding if you have already decoded an
  63. /// utterance and want to start with a new utterance.
  64. void InitDecoding();
  65. /// This will decode until there are no more frames ready in the decodable
  66. /// object, but if max_num_frames is >= 0 it will decode no more than
  67. /// that many frames. If it returns false, then no tokens are alive,
  68. /// which is a kind of error state.
  69. void AdvanceDecoding(DecodableInterface *decodable,
  70. int32 max_num_frames = -1);
  71. /// Returns the number of frames already decoded.
  72. int32 NumFramesDecoded() const { return num_frames_decoded_; }
  73. private:
  74. class Token {
  75. public:
  76. LatticeArc arc_; // We use LatticeArc so that we can separately
  77. // store the acoustic and graph cost, in case
  78. // we need to produce lattice-formatted output.
  79. Token *prev_;
  80. int32 ref_count_;
  81. double cost_; // accumulated total cost up to this point.
  82. Token(const StdArc &arc,
  83. BaseFloat acoustic_cost,
  84. Token *prev): prev_(prev), ref_count_(1) {
  85. arc_.ilabel = arc.ilabel;
  86. arc_.olabel = arc.olabel;
  87. arc_.weight = LatticeWeight(arc.weight.Value(), acoustic_cost);
  88. arc_.nextstate = arc.nextstate;
  89. if (prev) {
  90. prev->ref_count_++;
  91. cost_ = prev->cost_ + (arc.weight.Value() + acoustic_cost);
  92. } else {
  93. cost_ = arc.weight.Value() + acoustic_cost;
  94. }
  95. }
  96. bool operator < (const Token &other) {
  97. return cost_ > other.cost_;
  98. }
  99. static void TokenDelete(Token *tok) {
  100. while (--tok->ref_count_ == 0) {
  101. Token *prev = tok->prev_;
  102. delete tok;
  103. if (prev == NULL) return;
  104. else tok = prev;
  105. }
  106. #ifdef KALDI_PARANOID
  107. KALDI_ASSERT(tok->ref_count_ > 0);
  108. #endif
  109. }
  110. };
  111. // ProcessEmitting decodes the frame num_frames_decoded_ of the
  112. // decodable object, then increments num_frames_decoded_.
  113. void ProcessEmitting(DecodableInterface *decodable);
  114. void ProcessNonemitting();
  115. unordered_map<StateId, Token*> cur_toks_;
  116. unordered_map<StateId, Token*> prev_toks_;
  117. const fst::Fst<fst::StdArc> &fst_;
  118. BaseFloat beam_;
  119. // Keep track of the number of frames decoded in the current file.
  120. int32 num_frames_decoded_;
  121. static void ClearToks(unordered_map<StateId, Token*> &toks);
  122. static void PruneToks(BaseFloat beam, unordered_map<StateId, Token*> *toks);
  123. KALDI_DISALLOW_COPY_AND_ASSIGN(SimpleDecoder);
  124. };
  125. } // end namespace kaldi.
  126. #endif