training-graph-compiler.cc 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. // decoder/training-graph-compiler.cc
  2. // Copyright 2009-2011 Microsoft Corporation
  3. // 2018 Johns Hopkins University (author: Daniel Povey)
  4. // 2021 Xiaomi Corporation (Author: Junbo Zhang)
  5. // See ../../COPYING for clarification regarding multiple authors
  6. //
  7. // Licensed under the Apache License, Version 2.0 (the "License");
  8. // you may not use this file except in compliance with the License.
  9. // You may obtain a copy of the License at
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  12. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  13. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  14. // MERCHANTABLITY OR NON-INFRINGEMENT.
  15. // See the Apache 2 License for the specific language governing permissions and
  16. // limitations under the License.
  #include "decoder/training-graph-compiler.h"

  #include <memory>
  #include <vector>

  #include "hmm/hmm-utils.h"  // for GetHTransducer
  19. namespace kaldi {
  20. TrainingGraphCompiler::TrainingGraphCompiler(const TransitionModel &trans_model,
  21. const ContextDependency &ctx_dep, // Does not maintain reference to this.
  22. fst::VectorFst<fst::StdArc> *lex_fst,
  23. const std::vector<int32> &disambig_syms,
  24. const TrainingGraphCompilerOptions &opts):
  25. trans_model_(trans_model), ctx_dep_(ctx_dep), lex_fst_(lex_fst),
  26. disambig_syms_(disambig_syms), opts_(opts) {
  27. using namespace fst;
  28. const std::vector<int32> &phone_syms = trans_model_.GetPhones(); // needed to create context fst.
  29. KALDI_ASSERT(!phone_syms.empty());
  30. KALDI_ASSERT(IsSortedAndUniq(phone_syms));
  31. SortAndUniq(&disambig_syms_);
  32. for (int32 i = 0; i < disambig_syms_.size(); i++)
  33. if (std::binary_search(phone_syms.begin(), phone_syms.end(),
  34. disambig_syms_[i]))
  35. KALDI_ERR << "Disambiguation symbol " << disambig_syms_[i]
  36. << " is also a phone.";
  37. subsequential_symbol_ = 1 + phone_syms.back();
  38. if (!disambig_syms_.empty() && subsequential_symbol_ <= disambig_syms_.back())
  39. subsequential_symbol_ = 1 + disambig_syms_.back();
  40. if (lex_fst == NULL) return;
  41. {
  42. int32 N = ctx_dep.ContextWidth(),
  43. P = ctx_dep.CentralPosition();
  44. if (P != N-1)
  45. AddSubsequentialLoop(subsequential_symbol_, lex_fst_); // This is needed for
  46. // systems with right-context or we will not successfully compose
  47. // with C.
  48. }
  49. { // make sure lexicon is olabel sorted.
  50. fst::OLabelCompare<fst::StdArc> olabel_comp;
  51. fst::ArcSort(lex_fst_, olabel_comp);
  52. }
  53. }
  54. bool TrainingGraphCompiler::CompileGraphFromText(
  55. const std::vector<int32> &transcript,
  56. fst::VectorFst<fst::StdArc> *out_fst) {
  57. using namespace fst;
  58. VectorFst<StdArc> word_fst;
  59. MakeLinearAcceptor(transcript, &word_fst);
  60. return CompileGraph(word_fst, out_fst);
  61. }
  62. bool TrainingGraphCompiler::CompileGraphFromLG(const fst::VectorFst<fst::StdArc> &phone2word_fst,
  63. fst::VectorFst<fst::StdArc> *out_fst) {
  64. using namespace fst;
  65. KALDI_ASSERT(phone2word_fst.Start() != kNoStateId);
  66. const std::vector<int32> &phone_syms = trans_model_.GetPhones(); // needed to create context fst.
  67. // inv_cfst will be expanded on the fly, as needed.
  68. InverseContextFst inv_cfst(subsequential_symbol_,
  69. phone_syms,
  70. disambig_syms_,
  71. ctx_dep_.ContextWidth(),
  72. ctx_dep_.CentralPosition());
  73. VectorFst<StdArc> ctx2word_fst;
  74. ComposeDeterministicOnDemandInverse(phone2word_fst, &inv_cfst, &ctx2word_fst);
  75. // now ctx2word_fst is C * LG, assuming phone2word_fst is written as LG.
  76. KALDI_ASSERT(ctx2word_fst.Start() != kNoStateId);
  77. HTransducerConfig h_cfg;
  78. h_cfg.transition_scale = opts_.transition_scale;
  79. std::vector<int32> disambig_syms_h; // disambiguation symbols on
  80. // input side of H.
  81. VectorFst<StdArc> *H = GetHTransducer(inv_cfst.IlabelInfo(),
  82. ctx_dep_,
  83. trans_model_,
  84. h_cfg,
  85. &disambig_syms_h);
  86. VectorFst<StdArc> &trans2word_fst = *out_fst; // transition-id to word.
  87. TableCompose(*H, ctx2word_fst, &trans2word_fst);
  88. KALDI_ASSERT(trans2word_fst.Start() != kNoStateId);
  89. // Epsilon-removal and determinization combined. This will fail if not determinizable.
  90. DeterminizeStarInLog(&trans2word_fst);
  91. if (!disambig_syms_h.empty()) {
  92. RemoveSomeInputSymbols(disambig_syms_h, &trans2word_fst);
  93. // we elect not to remove epsilons after this phase, as it is
  94. // a little slow.
  95. if (opts_.rm_eps)
  96. RemoveEpsLocal(&trans2word_fst);
  97. }
  98. // Encoded minimization.
  99. MinimizeEncoded(&trans2word_fst);
  100. std::vector<int32> disambig;
  101. bool check_no_self_loops = true;
  102. AddSelfLoops(trans_model_,
  103. disambig,
  104. opts_.self_loop_scale,
  105. opts_.reorder,
  106. check_no_self_loops,
  107. &trans2word_fst);
  108. delete H;
  109. return true;
  110. }
  111. bool TrainingGraphCompiler::CompileGraph(const fst::VectorFst<fst::StdArc> &word_fst,
  112. fst::VectorFst<fst::StdArc> *out_fst) {
  113. using namespace fst;
  114. KALDI_ASSERT(lex_fst_ !=NULL);
  115. KALDI_ASSERT(out_fst != NULL);
  116. VectorFst<StdArc> phone2word_fst;
  117. // TableCompose more efficient than compose.
  118. TableCompose(*lex_fst_, word_fst, &phone2word_fst, &lex_cache_);
  119. return CompileGraphFromLG(phone2word_fst, out_fst);
  120. }
  121. bool TrainingGraphCompiler::CompileGraphsFromText(
  122. const std::vector<std::vector<int32> > &transcripts,
  123. std::vector<fst::VectorFst<fst::StdArc>*> *out_fsts) {
  124. using namespace fst;
  125. std::vector<const VectorFst<StdArc>* > word_fsts(transcripts.size());
  126. for (size_t i = 0; i < transcripts.size(); i++) {
  127. VectorFst<StdArc> *word_fst = new VectorFst<StdArc>();
  128. MakeLinearAcceptor(transcripts[i], word_fst);
  129. word_fsts[i] = word_fst;
  130. }
  131. bool ans = CompileGraphs(word_fsts, out_fsts);
  132. for (size_t i = 0; i < transcripts.size(); i++)
  133. delete word_fsts[i];
  134. return ans;
  135. }
  136. bool TrainingGraphCompiler::CompileGraphs(
  137. const std::vector<const fst::VectorFst<fst::StdArc>* > &word_fsts,
  138. std::vector<fst::VectorFst<fst::StdArc>* > *out_fsts) {
  139. out_fsts->resize(word_fsts.size(), NULL);
  140. for (size_t i = 0; i < word_fsts.size(); i++) {
  141. fst::VectorFst<fst::StdArc> trans2word_fst;
  142. if (!CompileGraph(*(word_fsts[i]), &trans2word_fst)) return false;
  143. (*out_fsts)[i] = trans2word_fst.Copy();
  144. }
  145. return true;
  146. }
  147. } // end namespace kaldi