decoder-wrappers.h 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. // decoder/decoder-wrappers.h
  2. // Copyright 2014 Johns Hopkins University (author: Daniel Povey)
  3. // See ../../COPYING for clarification regarding multiple authors
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  12. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  13. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  14. // MERCHANTABLITY OR NON-INFRINGEMENT.
  15. // See the Apache 2 License for the specific language governing permissions and
  16. // limitations under the License.
  17. #ifndef KALDI_DECODER_DECODER_WRAPPERS_H_
  18. #define KALDI_DECODER_DECODER_WRAPPERS_H_
  19. #include "itf/options-itf.h"
  20. #include "decoder/lattice-faster-decoder.h"
  21. #include "decoder/lattice-incremental-decoder.h"
  22. #include "decoder/lattice-simple-decoder.h"
  // This header contains declarations of various convenience functions that
  // are called from binary-level programs such as gmm-decode-faster.cc,
  // gmm-align-compiled.cc, and so on.
  26. namespace kaldi {
  27. struct AlignConfig {
  28. BaseFloat beam;
  29. BaseFloat retry_beam;
  30. bool careful;
  31. AlignConfig(): beam(200.0), retry_beam(0.0), careful(false) { }
  32. void Register(OptionsItf *opts) {
  33. opts->Register("beam", &beam, "Decoding beam used in alignment");
  34. opts->Register("retry-beam", &retry_beam,
  35. "Decoding beam for second try at alignment");
  36. opts->Register("careful", &careful,
  37. "If true, do 'careful' alignment, which is better at detecting "
  38. "alignment failure (involves loop to start of decoding graph).");
  39. }
  40. };
  41. /// AlignUtteranceWapper is a wrapper for alignment code used in training, that
  42. /// is called from many different binaries, e.g. gmm-align, gmm-align-compiled,
  43. /// sgmm-align, etc. The writers for alignments and words will only be written
  44. /// to if they are open. The num_done, num_error, num_retried, tot_like and
  45. /// frame_count pointers will (if non-NULL) be incremented or added to, not set,
  46. /// by this function.
  47. void AlignUtteranceWrapper(
  48. const AlignConfig &config,
  49. const std::string &utt,
  50. BaseFloat acoustic_scale, // affects scores written to scores_writer, if
  51. // present
  52. fst::VectorFst<fst::StdArc> *fst, // non-const in case config.careful ==
  53. // true, we add loop.
  54. DecodableInterface *decodable, // not const but is really an input.
  55. Int32VectorWriter *alignment_writer,
  56. BaseFloatWriter *scores_writer,
  57. int32 *num_done,
  58. int32 *num_error,
  59. int32 *num_retried,
  60. double *tot_like,
  61. int64 *frame_count,
  62. BaseFloatVectorWriter *per_frame_acwt_writer = NULL);
  63. /// This function modifies the decoding graph for what we call "careful
  64. /// alignment". The problem we are trying to solve is that if the decoding eats
  65. /// up the words in the graph too fast, it can get stuck at the end, and produce
  66. /// what looks like a valid alignment even though there was really a failure.
  67. /// So what we want to do is to introduce, after the final-states of the graph,
  68. /// a "blind alley" with no final-probs reachable, where the decoding can go to
  69. /// get lost. Our basic idea is to append the decoding-graph to itself using
  70. /// the fst Concat operation; but in order that there should be final-probs at the end of
  71. /// the first but not the second FST, we modify the right-hand argument to the
  72. /// Concat operation so that it has none of the original final-probs, and add
  73. /// a "pre-initial" state that is final.
  74. void ModifyGraphForCarefulAlignment(
  75. fst::VectorFst<fst::StdArc> *fst);
  76. /// TODO
  77. template <typename FST>
  78. bool DecodeUtteranceLatticeIncremental(
  79. LatticeIncrementalDecoderTpl<FST> &decoder, // not const but is really an input.
  80. DecodableInterface &decodable, // not const but is really an input.
  81. const TransitionInformation &trans_model,
  82. const fst::SymbolTable *word_syms,
  83. std::string utt,
  84. double acoustic_scale,
  85. bool determinize,
  86. bool allow_partial,
  87. Int32VectorWriter *alignments_writer,
  88. Int32VectorWriter *words_writer,
  89. CompactLatticeWriter *compact_lattice_writer,
  90. LatticeWriter *lattice_writer,
  91. double *like_ptr); // puts utterance's likelihood in like_ptr on success.
  92. /// This function DecodeUtteranceLatticeFaster is used in several decoders, and
  93. /// we have moved it here. Note: this is really "binary-level" code as it
  94. /// involves table readers and writers; we've just put it here as there is no
  95. /// other obvious place to put it. If determinize == false, it writes to
  96. /// lattice_writer, else to compact_lattice_writer. The writers for
  97. /// alignments and words will only be written to if they are open.
  98. ///
  99. /// Caution: this will only link correctly if FST is either fst::Fst<fst::StdArc>,
  100. /// or fst::GrammarFst, as the template function is defined in the .cc file and
  101. /// only instantiated for those two types.
  102. template <typename FST>
  103. bool DecodeUtteranceLatticeFaster(
  104. LatticeFasterDecoderTpl<FST> &decoder, // not const but is really an input.
  105. DecodableInterface &decodable, // not const but is really an input.
  106. const TransitionInformation &trans_model,
  107. const fst::SymbolTable *word_syms,
  108. std::string utt,
  109. double acoustic_scale,
  110. bool determinize,
  111. bool allow_partial,
  112. Int32VectorWriter *alignments_writer,
  113. Int32VectorWriter *words_writer,
  114. CompactLatticeWriter *compact_lattice_writer,
  115. LatticeWriter *lattice_writer,
  116. double *like_ptr); // puts utterance's likelihood in like_ptr on success.
  117. /// This class basically does the same job as the function
  118. /// DecodeUtteranceLatticeFaster, but in a way that allows us
  119. /// to build a multi-threaded command line program more easily.
  120. /// The main computation takes place in operator (), and the output
  121. /// happens in the destructor.
  122. class DecodeUtteranceLatticeFasterClass {
  123. public:
  124. // Initializer sets various variables.
  125. // NOTE: we "take ownership" of "decoder" and "decodable". These
  126. // are deleted by the destructor. On error, "num_err" is incremented.
  127. DecodeUtteranceLatticeFasterClass(
  128. LatticeFasterDecoder *decoder,
  129. DecodableInterface *decodable,
  130. const TransitionInformation &trans_model,
  131. const fst::SymbolTable *word_syms,
  132. const std::string &utt,
  133. BaseFloat acoustic_scale,
  134. bool determinize,
  135. bool allow_partial,
  136. Int32VectorWriter *alignments_writer,
  137. Int32VectorWriter *words_writer,
  138. CompactLatticeWriter *compact_lattice_writer,
  139. LatticeWriter *lattice_writer,
  140. double *like_sum, // on success, adds likelihood to this.
  141. int64 *frame_sum, // on success, adds #frames to this.
  142. int32 *num_done, // on success (including partial decode), increments this.
  143. int32 *num_err, // on failure, increments this.
  144. int32 *num_partial); // If partial decode (final-state not reached), increments this.
  145. void operator () (); // The decoding happens here.
  146. ~DecodeUtteranceLatticeFasterClass(); // Output happens here.
  147. private:
  148. // The following variables correspond to inputs:
  149. LatticeFasterDecoder *decoder_;
  150. DecodableInterface *decodable_;
  151. const TransitionInformation *trans_model_;
  152. const fst::SymbolTable *word_syms_;
  153. std::string utt_;
  154. BaseFloat acoustic_scale_;
  155. bool determinize_;
  156. bool allow_partial_;
  157. Int32VectorWriter *alignments_writer_;
  158. Int32VectorWriter *words_writer_;
  159. CompactLatticeWriter *compact_lattice_writer_;
  160. LatticeWriter *lattice_writer_;
  161. double *like_sum_;
  162. int64 *frame_sum_;
  163. int32 *num_done_;
  164. int32 *num_err_;
  165. int32 *num_partial_;
  166. // The following variables are stored by the computation.
  167. bool computed_; // operator () was called.
  168. bool success_; // decoding succeeded (possibly partial)
  169. bool partial_; // decoding was partial.
  170. CompactLattice *clat_; // Stored output, if determinize_ == true.
  171. Lattice *lat_; // Stored output, if determinize_ == false.
  172. };
  173. // This function DecodeUtteranceLatticeSimple is used in several decoders, and
  174. // we have moved it here. Note: this is really "binary-level" code as it
  175. // involves table readers and writers; we've just put it here as there is no
  176. // other obvious place to put it. If determinize == false, it writes to
  177. // lattice_writer, else to compact_lattice_writer. The writers for
  178. // alignments and words will only be written to if they are open.
  179. bool DecodeUtteranceLatticeSimple(
  180. LatticeSimpleDecoder &decoder, // not const but is really an input.
  181. DecodableInterface &decodable, // not const but is really an input.
  182. const TransitionInformation &trans_model,
  183. const fst::SymbolTable *word_syms,
  184. std::string utt,
  185. double acoustic_scale,
  186. bool determinize,
  187. bool allow_partial,
  188. Int32VectorWriter *alignments_writer,
  189. Int32VectorWriter *words_writer,
  190. CompactLatticeWriter *compact_lattice_writer,
  191. LatticeWriter *lattice_writer,
  192. double *like_ptr); // puts utterance's likelihood in like_ptr on success.
  193. } // end namespace kaldi.
  194. #endif