faster-decoder.h 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. // decoder/faster-decoder.h
  2. // Copyright 2009-2011 Microsoft Corporation
  3. // 2013 Johns Hopkins University (author: Daniel Povey)
  4. // See ../../COPYING for clarification regarding multiple authors
  5. //
  6. // Licensed under the Apache License, Version 2.0 (the "License");
  7. // you may not use this file except in compliance with the License.
  8. // You may obtain a copy of the License at
  9. //
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. //
  12. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  13. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  14. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  15. // MERCHANTABLITY OR NON-INFRINGEMENT.
  16. // See the Apache 2 License for the specific language governing permissions and
  17. // limitations under the License.
  18. #ifndef KALDI_DECODER_FASTER_DECODER_H_
  19. #define KALDI_DECODER_FASTER_DECODER_H_
  20. #include "util/stl-utils.h"
  21. #include "itf/options-itf.h"
  22. #include "util/hash-list.h"
  23. #include "fst/fstlib.h"
  24. #include "itf/decodable-itf.h"
  25. #include "lat/kaldi-lattice.h" // for CompactLatticeArc
  26. namespace kaldi {
  27. struct FasterDecoderOptions {
  28. BaseFloat beam;
  29. int32 max_active;
  30. int32 min_active;
  31. BaseFloat beam_delta;
  32. BaseFloat hash_ratio;
  33. FasterDecoderOptions(): beam(16.0),
  34. max_active(std::numeric_limits<int32>::max()),
  35. min_active(20), // This decoder mostly used for
  36. // alignment, use small default.
  37. beam_delta(0.5),
  38. hash_ratio(2.0) { }
  39. void Register(OptionsItf *opts, bool full) { /// if "full", use obscure
  40. /// options too.
  41. /// Depends on program.
  42. opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate.");
  43. opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; "
  44. "more accurate");
  45. opts->Register("min-active", &min_active,
  46. "Decoder min active states (don't prune if #active less than this).");
  47. if (full) {
  48. opts->Register("beam-delta", &beam_delta,
  49. "Increment used in decoder [obscure setting]");
  50. opts->Register("hash-ratio", &hash_ratio,
  51. "Setting used in decoder to control hash behavior");
  52. }
  53. }
  54. };
  55. class FasterDecoder {
  56. public:
  57. typedef fst::StdArc Arc;
  58. typedef Arc::Label Label;
  59. typedef Arc::StateId StateId;
  60. typedef Arc::Weight Weight;
  61. FasterDecoder(const fst::Fst<fst::StdArc> &fst,
  62. const FasterDecoderOptions &config);
  63. void SetOptions(const FasterDecoderOptions &config) { config_ = config; }
  64. ~FasterDecoder() { ClearToks(toks_.Clear()); }
  65. void Decode(DecodableInterface *decodable);
  66. /// Returns true if a final state was active on the last frame.
  67. bool ReachedFinal() const;
  68. /// GetBestPath gets the decoding traceback. If "use_final_probs" is true
  69. /// AND we reached a final state, it limits itself to final states;
  70. /// otherwise it gets the most likely token not taking into account
  71. /// final-probs. Returns true if the output best path was not the empty
  72. /// FST (will only return false in unusual circumstances where
  73. /// no tokens survived).
  74. bool GetBestPath(fst::MutableFst<LatticeArc> *fst_out,
  75. bool use_final_probs = true);
  76. /// As a new alternative to Decode(), you can call InitDecoding
  77. /// and then (possibly multiple times) AdvanceDecoding().
  78. void InitDecoding();
  79. /// This will decode until there are no more frames ready in the decodable
  80. /// object, but if max_num_frames is >= 0 it will decode no more than
  81. /// that many frames.
  82. void AdvanceDecoding(DecodableInterface *decodable,
  83. int32 max_num_frames = -1);
  84. /// Returns the number of frames already decoded.
  85. int32 NumFramesDecoded() const { return num_frames_decoded_; }
  86. protected:
  87. class Token {
  88. public:
  89. Arc arc_; // contains only the graph part of the cost;
  90. // we can work out the acoustic part from difference between
  91. // "cost_" and prev->cost_.
  92. Token *prev_;
  93. int32 ref_count_;
  94. // if you are looking for weight_ here, it was removed and now we just have
  95. // cost_, which corresponds to ConvertToCost(weight_).
  96. double cost_;
  97. inline Token(const Arc &arc, BaseFloat ac_cost, Token *prev):
  98. arc_(arc), prev_(prev), ref_count_(1) {
  99. if (prev) {
  100. prev->ref_count_++;
  101. cost_ = prev->cost_ + arc.weight.Value() + ac_cost;
  102. } else {
  103. cost_ = arc.weight.Value() + ac_cost;
  104. }
  105. }
  106. inline Token(const Arc &arc, Token *prev):
  107. arc_(arc), prev_(prev), ref_count_(1) {
  108. if (prev) {
  109. prev->ref_count_++;
  110. cost_ = prev->cost_ + arc.weight.Value();
  111. } else {
  112. cost_ = arc.weight.Value();
  113. }
  114. }
  115. inline bool operator < (const Token &other) {
  116. return cost_ > other.cost_;
  117. }
  118. inline static void TokenDelete(Token *tok) {
  119. while (--tok->ref_count_ == 0) {
  120. Token *prev = tok->prev_;
  121. delete tok;
  122. if (prev == NULL) return;
  123. else tok = prev;
  124. }
  125. #ifdef KALDI_PARANOID
  126. KALDI_ASSERT(tok->ref_count_ > 0);
  127. #endif
  128. }
  129. };
  130. typedef HashList<StateId, Token*>::Elem Elem;
  131. /// Gets the weight cutoff. Also counts the active tokens.
  132. double GetCutoff(Elem *list_head, size_t *tok_count,
  133. BaseFloat *adaptive_beam, Elem **best_elem);
  134. void PossiblyResizeHash(size_t num_toks);
  135. // ProcessEmitting returns the likelihood cutoff used.
  136. // It decodes the frame num_frames_decoded_ of the decodable object
  137. // and then increments num_frames_decoded_
  138. double ProcessEmitting(DecodableInterface *decodable);
  139. // TODO: first time we go through this, could avoid using the queue.
  140. void ProcessNonemitting(double cutoff);
  141. // HashList defined in ../util/hash-list.h. It actually allows us to maintain
  142. // more than one list (e.g. for current and previous frames), but only one of
  143. // them at a time can be indexed by StateId.
  144. HashList<StateId, Token*> toks_;
  145. const fst::Fst<fst::StdArc> &fst_;
  146. FasterDecoderOptions config_;
  147. std::vector<const Elem* > queue_; // temp variable used in ProcessNonemitting,
  148. std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
  149. // make it class member to avoid internal new/delete.
  150. // Keep track of the number of frames decoded in the current file.
  151. int32 num_frames_decoded_;
  152. // It might seem unclear why we call ClearToks(toks_.Clear()).
  153. // There are two separate cleanup tasks we need to do at when we start a new file.
  154. // one is to delete the Token objects in the list; the other is to delete
  155. // the Elem objects. toks_.Clear() just clears them from the hash and gives ownership
  156. // to the caller, who then has to call toks_.Delete(e) for each one. It was designed
  157. // this way for convenience in propagating tokens from one frame to the next.
  158. void ClearToks(Elem *list);
  159. KALDI_DISALLOW_COPY_AND_ASSIGN(FasterDecoder);
  160. };
  161. } // end namespace kaldi.
  162. #endif