lattice-simple-decoder.h 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. // decoder/lattice-simple-decoder.h
  2. // Copyright 2009-2012 Microsoft Corporation
  3. // 2012-2014 Johns Hopkins University (Author: Daniel Povey)
  4. // 2014 Guoguo Chen
  5. // See ../../COPYING for clarification regarding multiple authors
  6. //
  7. // Licensed under the Apache License, Version 2.0 (the "License");
  8. // you may not use this file except in compliance with the License.
  9. // You may obtain a copy of the License at
  10. //
  11. // http://www.apache.org/licenses/LICENSE-2.0
  12. //
  13. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  15. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  16. // MERCHANTABLITY OR NON-INFRINGEMENT.
  17. // See the Apache 2 License for the specific language governing permissions and
  18. // limitations under the License.
  19. #ifndef KALDI_DECODER_LATTICE_SIMPLE_DECODER_H_
  20. #define KALDI_DECODER_LATTICE_SIMPLE_DECODER_H_
#include "util/stl-utils.h"
#include "fst/fstlib.h"
#include "itf/decodable-itf.h"
#include "fstext/fstext-lib.h"
#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"

#include <algorithm>
#include <limits>
#include <vector>
  28. namespace kaldi {
  29. struct LatticeSimpleDecoderConfig {
  30. BaseFloat beam;
  31. BaseFloat lattice_beam;
  32. int32 prune_interval;
  33. bool determinize_lattice; // not inspected by this class... used in
  34. // command-line program.
  35. bool prune_lattice;
  36. BaseFloat beam_ratio;
  37. BaseFloat prune_scale; // Note: we don't make this configurable on the command line,
  38. // it's not a very important parameter. It affects the
  39. // algorithm that prunes the tokens as we go.
  40. fst::DeterminizeLatticePhonePrunedOptions det_opts;
  41. LatticeSimpleDecoderConfig(): beam(16.0),
  42. lattice_beam(10.0),
  43. prune_interval(25),
  44. determinize_lattice(true),
  45. beam_ratio(0.9),
  46. prune_scale(0.1) { }
  47. void Register(OptionsItf *opts) {
  48. det_opts.Register(opts);
  49. opts->Register("beam", &beam, "Decoding beam.");
  50. opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam");
  51. opts->Register("prune-interval", &prune_interval, "Interval (in frames) at "
  52. "which to prune tokens");
  53. opts->Register("determinize-lattice", &determinize_lattice, "If true, "
  54. "determinize the lattice (in a special sense, keeping only "
  55. "best pdf-sequence for each word-sequence).");
  56. }
  57. void Check() const {
  58. KALDI_ASSERT(beam > 0.0 && lattice_beam > 0.0 && prune_interval > 0);
  59. }
  60. };
  61. /** Simplest possible decoder, included largely for didactic purposes and as a
  62. means to debug more highly optimized decoders. See \ref decoders_simple
  63. for more information.
  64. */
  65. class LatticeSimpleDecoder {
  66. public:
  67. typedef fst::StdArc Arc;
  68. typedef Arc::Label Label;
  69. typedef Arc::StateId StateId;
  70. typedef Arc::Weight Weight;
  71. // instantiate this class once for each thing you have to decode.
  72. LatticeSimpleDecoder(const fst::Fst<fst::StdArc> &fst,
  73. const LatticeSimpleDecoderConfig &config):
  74. fst_(fst), config_(config), num_toks_(0) { config.Check(); }
  75. ~LatticeSimpleDecoder() { ClearActiveTokens(); }
  76. const LatticeSimpleDecoderConfig &GetOptions() const {
  77. return config_;
  78. }
  79. // Returns true if any kind of traceback is available (not necessarily from
  80. // a final state).
  81. bool Decode(DecodableInterface *decodable);
  82. /// says whether a final-state was active on the last frame. If it was not, the
  83. /// lattice (or traceback) will end with states that are not final-states.
  84. bool ReachedFinal() const {
  85. return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity();
  86. }
  87. /// InitDecoding initializes the decoding, and should only be used if you
  88. /// intend to call AdvanceDecoding(). If you call Decode(), you don't need
  89. /// to call this. You can call InitDecoding if you have already decoded an
  90. /// utterance and want to start with a new utterance.
  91. void InitDecoding();
  92. /// This function may be optionally called after AdvanceDecoding(), when you
  93. /// do not plan to decode any further. It does an extra pruning step that
  94. /// will help to prune the lattices output by GetLattice and (particularly)
  95. /// GetRawLattice more accurately, particularly toward the end of the
  96. /// utterance. It does this by using the final-probs in pruning (if any
  97. /// final-state survived); it also does a final pruning step that visits all
  98. /// states (the pruning that is done during decoding may fail to prune states
  99. /// that are within kPruningScale = 0.1 outside of the beam). If you call
  100. /// this, you cannot call AdvanceDecoding again (it will fail), and you
  101. /// cannot call GetLattice() and related functions with use_final_probs =
  102. /// false.
  103. /// Used to be called PruneActiveTokensFinal().
  104. void FinalizeDecoding();
  105. /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives
  106. /// more information. It returns the difference between the best (final-cost
  107. /// plus cost) of any token on the final frame, and the best cost of any token
  108. /// on the final frame. If it is infinity it means no final-states were
  109. /// present on the final frame. It will usually be nonnegative. If it not
  110. /// too positive (e.g. < 5 is my first guess, but this is not tested) you can
  111. /// take it as a good indication that we reached the final-state with
  112. /// reasonable likelihood.
  113. BaseFloat FinalRelativeCost() const;
  114. // Outputs an FST corresponding to the single best path
  115. // through the lattice. Returns true if result is nonempty
  116. // (using the return status is deprecated, it will become void).
  117. // If "use_final_probs" is true AND we reached the final-state
  118. // of the graph then it will include those as final-probs, else
  119. // it will treat all final-probs as one.
  120. bool GetBestPath(Lattice *lat,
  121. bool use_final_probs = true) const;
  122. // Outputs an FST corresponding to the raw, state-level
  123. // tracebacks. Returns true if result is nonempty
  124. // (using the return status is deprecated, it will become void).
  125. // If "use_final_probs" is true AND we reached the final-state
  126. // of the graph then it will include those as final-probs, else
  127. // it will treat all final-probs as one.
  128. bool GetRawLattice(Lattice *lat,
  129. bool use_final_probs = true) const;
  130. // This function is now deprecated, since now we do determinization from
  131. // outside the LatticeTrackingDecoder class.
  132. // Outputs an FST corresponding to the lattice-determinized
  133. // lattice (one path per word sequence). [will become deprecated,
  134. // users should determinize themselves.]
  135. bool GetLattice(CompactLattice *clat,
  136. bool use_final_probs = true) const;
  137. inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; }
  138. private:
  139. struct Token;
  140. // ForwardLinks are the links from a token to a token on the next frame.
  141. // or sometimes on the current frame (for input-epsilon links).
  142. struct ForwardLink {
  143. Token *next_tok; // the next token [or NULL if represents final-state]
  144. Label ilabel; // ilabel on link.
  145. Label olabel; // olabel on link.
  146. BaseFloat graph_cost; // graph cost of traversing link (contains LM, etc.)
  147. BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing link
  148. ForwardLink *next; // next in singly-linked list of forward links from a
  149. // token.
  150. ForwardLink(Token *next_tok, Label ilabel, Label olabel,
  151. BaseFloat graph_cost, BaseFloat acoustic_cost,
  152. ForwardLink *next):
  153. next_tok(next_tok), ilabel(ilabel), olabel(olabel),
  154. graph_cost(graph_cost), acoustic_cost(acoustic_cost),
  155. next(next) { }
  156. };
  157. // Token is what's resident in a particular state at a particular time.
  158. // In this decoder a Token actually contains *forward* links.
  159. // When first created, a Token just has the (total) cost. We add forward
  160. // links from it when we process the next frame.
  161. struct Token {
  162. BaseFloat tot_cost; // would equal weight.Value()... cost up to this point.
  163. BaseFloat extra_cost; // >= 0. After calling PruneForwardLinks, this equals
  164. // the minimum difference between the cost of the best path this is on,
  165. // and the cost of the absolute best path, under the assumption
  166. // that any of the currently active states at the decoding front may
  167. // eventually succeed (e.g. if you were to take the currently active states
  168. // one by one and compute this difference, and then take the minimum).
  169. ForwardLink *links; // Head of singly linked list of ForwardLinks
  170. Token *next; // Next in list of tokens for this frame.
  171. Token(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLink *links,
  172. Token *next): tot_cost(tot_cost), extra_cost(extra_cost), links(links),
  173. next(next) { }
  174. Token() {}
  175. void DeleteForwardLinks() {
  176. ForwardLink *l = links, *m;
  177. while (l != NULL) {
  178. m = l->next;
  179. delete l;
  180. l = m;
  181. }
  182. links = NULL;
  183. }
  184. };
  185. // head and tail of per-frame list of Tokens (list is in topological order),
  186. // and something saying whether we ever pruned it using PruneForwardLinks.
  187. struct TokenList {
  188. Token *toks;
  189. bool must_prune_forward_links;
  190. bool must_prune_tokens;
  191. TokenList(): toks(NULL), must_prune_forward_links(true),
  192. must_prune_tokens(true) { }
  193. };
  194. // FindOrAddToken either locates a token in cur_toks_, or if necessary inserts a new,
  195. // empty token (i.e. with no forward links) for the current frame. [note: it's
  196. // inserted if necessary into cur_toks_ and also into the singly linked list
  197. // of tokens active on this frame (whose head is at active_toks_[frame]).
  198. //
  199. // Returns the Token pointer. Sets "changed" (if non-NULL) to true
  200. // if the token was newly created or the cost changed.
  201. inline Token *FindOrAddToken(StateId state, int32 frame_plus_one,
  202. BaseFloat tot_cost, bool emitting, bool *changed);
  203. // delta is the amount by which the extra_costs must
  204. // change before it sets "extra_costs_changed" to true. If delta is larger,
  205. // we'll tend to go back less far toward the beginning of the file.
  206. void PruneForwardLinks(int32 frame, bool *extra_costs_changed,
  207. bool *links_pruned,
  208. BaseFloat delta);
  209. // PruneForwardLinksFinal is a version of PruneForwardLinks that we call
  210. // on the final frame. If there are final tokens active, it uses the final-probs
  211. // for pruning, otherwise it treats all tokens as final.
  212. void PruneForwardLinksFinal();
  213. // Prune away any tokens on this frame that have no forward links. [we don't do
  214. // this in PruneForwardLinks because it would give us a problem with dangling
  215. // pointers].
  216. void PruneTokensForFrame(int32 frame);
  217. // Go backwards through still-alive tokens, pruning them if the
  218. // forward+backward cost is more than lat_beam away from the best path. It's
  219. // possible to prove that this is "correct" in the sense that we won't lose
  220. // anything outside of lat_beam, regardless of what happens in the future.
  221. // delta controls when it considers a cost to have changed enough to continue
  222. // going backward and propagating the change. larger delta -> will recurse
  223. // less far.
  224. void PruneActiveTokens(BaseFloat delta);
  225. void ProcessEmitting(DecodableInterface *decodable);
  226. void ProcessNonemitting();
  227. void ClearActiveTokens(); // a cleanup routine, at utt end/begin
  228. // This function computes the final-costs for tokens active on the final
  229. // frame. It outputs to final-costs, if non-NULL, a map from the Token*
  230. // pointer to the final-prob of the corresponding state, or zero for all states if
  231. // none were final. It outputs to final_relative_cost, if non-NULL, the
  232. // difference between the best forward-cost including the final-prob cost, and
  233. // the best forward-cost without including the final-prob cost (this will
  234. // usually be positive), or infinity if there were no final-probs. It outputs
  235. // to final_best_cost, if non-NULL, the lowest for any token t active on the
  236. // final frame, of t + final-cost[t], where final-cost[t] is the final-cost
  237. // in the graph of the state corresponding to token t, or zero if there
  238. // were no final-probs active on the final frame.
  239. // You cannot call this after FinalizeDecoding() has been called; in that
  240. // case you should get the answer from class-member variables.
  241. void ComputeFinalCosts(unordered_map<Token*, BaseFloat> *final_costs,
  242. BaseFloat *final_relative_cost,
  243. BaseFloat *final_best_cost) const;
  244. // PruneCurrentTokens deletes the tokens from the "toks" map, but not
  245. // from the active_toks_ list, which could cause dangling forward pointers
  246. // (will delete it during regular pruning operation).
  247. void PruneCurrentTokens(BaseFloat beam, unordered_map<StateId, Token*> *toks);
  248. unordered_map<StateId, Token*> cur_toks_;
  249. unordered_map<StateId, Token*> prev_toks_;
  250. std::vector<TokenList> active_toks_; // Lists of tokens, indexed by
  251. // frame_plus_one
  252. const fst::Fst<fst::StdArc> &fst_;
  253. LatticeSimpleDecoderConfig config_;
  254. int32 num_toks_; // current total #toks allocated...
  255. bool warned_;
  256. /// decoding_finalized_ is true if someone called FinalizeDecoding(). [note,
  257. /// calling this is optional]. If true, it's forbidden to decode more. Also,
  258. /// if this is set, then the output of ComputeFinalCosts() is in the next
  259. /// three variables. The reason we need to do this is that after
  260. /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some
  261. /// of the tokens on the last frame are freed, so we free the list from
  262. /// cur_toks_ to avoid having dangling pointers hanging around.
  263. bool decoding_finalized_;
  264. /// For the meaning of the next 3 variables, see the comment for
  265. /// decoding_finalized_ above., and ComputeFinalCosts().
  266. unordered_map<Token*, BaseFloat> final_costs_;
  267. BaseFloat final_relative_cost_;
  268. BaseFloat final_best_cost_;
  269. };
  270. } // end namespace kaldi.
  271. #endif