simple-decoder.cc 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. // decoder/simple-decoder.cc
  2. // Copyright 2009-2011 Microsoft Corporation
  3. // 2012-2013 Johns Hopkins University (author: Daniel Povey)
  4. // See ../../COPYING for clarification regarding multiple authors
  5. //
  6. // Licensed under the Apache License, Version 2.0 (the "License");
  7. // you may not use this file except in compliance with the License.
  8. // You may obtain a copy of the License at
  9. //
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. //
  12. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  13. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  14. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  15. // MERCHANTABLITY OR NON-INFRINGEMENT.
  16. // See the Apache 2 License for the specific language governing permissions and
  17. // limitations under the License.
  18. #include "decoder/simple-decoder.h"
  19. #include "fstext/remove-eps-local.h"
  20. #include <algorithm>
  21. namespace kaldi {
  22. SimpleDecoder::~SimpleDecoder() {
  23. ClearToks(cur_toks_);
  24. ClearToks(prev_toks_);
  25. }
  26. bool SimpleDecoder::Decode(DecodableInterface *decodable) {
  27. InitDecoding();
  28. AdvanceDecoding(decodable);
  29. return (!cur_toks_.empty());
  30. }
  31. void SimpleDecoder::InitDecoding() {
  32. // clean up from last time:
  33. ClearToks(cur_toks_);
  34. ClearToks(prev_toks_);
  35. // initialize decoding:
  36. StateId start_state = fst_.Start();
  37. KALDI_ASSERT(start_state != fst::kNoStateId);
  38. StdArc dummy_arc(0, 0, StdWeight::One(), start_state);
  39. cur_toks_[start_state] = new Token(dummy_arc, 0.0, NULL);
  40. num_frames_decoded_ = 0;
  41. ProcessNonemitting();
  42. }
  43. void SimpleDecoder::AdvanceDecoding(DecodableInterface *decodable,
  44. int32 max_num_frames) {
  45. KALDI_ASSERT(num_frames_decoded_ >= 0 &&
  46. "You must call InitDecoding() before AdvanceDecoding()");
  47. int32 num_frames_ready = decodable->NumFramesReady();
  48. // num_frames_ready must be >= num_frames_decoded, or else
  49. // the number of frames ready must have decreased (which doesn't
  50. // make sense) or the decodable object changed between calls
  51. // (which isn't allowed).
  52. KALDI_ASSERT(num_frames_ready >= num_frames_decoded_);
  53. int32 target_frames_decoded = num_frames_ready;
  54. if (max_num_frames >= 0)
  55. target_frames_decoded = std::min(target_frames_decoded,
  56. num_frames_decoded_ + max_num_frames);
  57. while (num_frames_decoded_ < target_frames_decoded) {
  58. // note: ProcessEmitting() increments num_frames_decoded_
  59. ClearToks(prev_toks_);
  60. cur_toks_.swap(prev_toks_);
  61. ProcessEmitting(decodable);
  62. ProcessNonemitting();
  63. PruneToks(beam_, &cur_toks_);
  64. }
  65. }
  66. bool SimpleDecoder::ReachedFinal() const {
  67. for (unordered_map<StateId, Token*>::const_iterator iter = cur_toks_.begin();
  68. iter != cur_toks_.end();
  69. ++iter) {
  70. if (iter->second->cost_ != std::numeric_limits<BaseFloat>::infinity() &&
  71. fst_.Final(iter->first) != StdWeight::Zero())
  72. return true;
  73. }
  74. return false;
  75. }
  76. BaseFloat SimpleDecoder::FinalRelativeCost() const {
  77. // as a special case, if there are no active tokens at all (e.g. some kind of
  78. // pruning failure), return infinity.
  79. double infinity = std::numeric_limits<double>::infinity();
  80. if (cur_toks_.empty())
  81. return infinity;
  82. double best_cost = infinity,
  83. best_cost_with_final = infinity;
  84. for (unordered_map<StateId, Token*>::const_iterator iter = cur_toks_.begin();
  85. iter != cur_toks_.end();
  86. ++iter) {
  87. // Note: Plus is taking the minimum cost, since we're in the tropical
  88. // semiring.
  89. best_cost = std::min(best_cost, iter->second->cost_);
  90. best_cost_with_final = std::min(best_cost_with_final,
  91. iter->second->cost_ +
  92. fst_.Final(iter->first).Value());
  93. }
  94. BaseFloat extra_cost = best_cost_with_final - best_cost;
  95. if (extra_cost != extra_cost) { // NaN. This shouldn't happen; it indicates some
  96. // kind of error, most likely.
  97. KALDI_WARN << "Found NaN (likely search failure in decoding)";
  98. return infinity;
  99. }
  100. // Note: extra_cost will be infinity if no states were final.
  101. return extra_cost;
  102. }
  103. // Outputs an FST corresponding to the single best path
  104. // through the lattice.
  105. bool SimpleDecoder::GetBestPath(Lattice *fst_out, bool use_final_probs) const {
  106. fst_out->DeleteStates();
  107. Token *best_tok = NULL;
  108. bool is_final = ReachedFinal();
  109. if (!is_final) {
  110. for (unordered_map<StateId, Token*>::const_iterator iter = cur_toks_.begin();
  111. iter != cur_toks_.end();
  112. ++iter)
  113. if (best_tok == NULL || *best_tok < *(iter->second) )
  114. best_tok = iter->second;
  115. } else {
  116. double infinity =std::numeric_limits<double>::infinity(),
  117. best_cost = infinity;
  118. for (unordered_map<StateId, Token*>::const_iterator iter = cur_toks_.begin();
  119. iter != cur_toks_.end();
  120. ++iter) {
  121. double this_cost = iter->second->cost_ + fst_.Final(iter->first).Value();
  122. if (this_cost != infinity && this_cost < best_cost) {
  123. best_cost = this_cost;
  124. best_tok = iter->second;
  125. }
  126. }
  127. }
  128. if (best_tok == NULL) return false; // No output.
  129. std::vector<LatticeArc> arcs_reverse; // arcs in reverse order.
  130. for (Token *tok = best_tok; tok != NULL; tok = tok->prev_)
  131. arcs_reverse.push_back(tok->arc_);
  132. KALDI_ASSERT(arcs_reverse.back().nextstate == fst_.Start());
  133. arcs_reverse.pop_back(); // that was a "fake" token... gives no info.
  134. StateId cur_state = fst_out->AddState();
  135. fst_out->SetStart(cur_state);
  136. for (ssize_t i = static_cast<ssize_t>(arcs_reverse.size())-1; i >= 0; i--) {
  137. LatticeArc arc = arcs_reverse[i];
  138. arc.nextstate = fst_out->AddState();
  139. fst_out->AddArc(cur_state, arc);
  140. cur_state = arc.nextstate;
  141. }
  142. if (is_final && use_final_probs)
  143. fst_out->SetFinal(cur_state,
  144. LatticeWeight(fst_.Final(best_tok->arc_.nextstate).Value(),
  145. 0.0));
  146. else
  147. fst_out->SetFinal(cur_state, LatticeWeight::One());
  148. fst::RemoveEpsLocal(fst_out);
  149. return true;
  150. }
  151. void SimpleDecoder::ProcessEmitting(DecodableInterface *decodable) {
  152. int32 frame = num_frames_decoded_;
  153. // Processes emitting arcs for one frame. Propagates from
  154. // prev_toks_ to cur_toks_.
  155. double cutoff = std::numeric_limits<BaseFloat>::infinity();
  156. for (unordered_map<StateId, Token*>::iterator iter = prev_toks_.begin();
  157. iter != prev_toks_.end();
  158. ++iter) {
  159. StateId state = iter->first;
  160. Token *tok = iter->second;
  161. KALDI_ASSERT(state == tok->arc_.nextstate);
  162. for (fst::ArcIterator<fst::Fst<StdArc> > aiter(fst_, state);
  163. !aiter.Done();
  164. aiter.Next()) {
  165. const StdArc &arc = aiter.Value();
  166. if (arc.ilabel != 0) { // propagate..
  167. BaseFloat acoustic_cost = -decodable->LogLikelihood(frame, arc.ilabel);
  168. double total_cost = tok->cost_ + arc.weight.Value() + acoustic_cost;
  169. if (total_cost >= cutoff) continue;
  170. if (total_cost + beam_ < cutoff)
  171. cutoff = total_cost + beam_;
  172. Token *new_tok = new Token(arc, acoustic_cost, tok);
  173. unordered_map<StateId, Token*>::iterator find_iter
  174. = cur_toks_.find(arc.nextstate);
  175. if (find_iter == cur_toks_.end()) {
  176. cur_toks_[arc.nextstate] = new_tok;
  177. } else {
  178. if ( *(find_iter->second) < *new_tok ) {
  179. Token::TokenDelete(find_iter->second);
  180. find_iter->second = new_tok;
  181. } else {
  182. Token::TokenDelete(new_tok);
  183. }
  184. }
  185. }
  186. }
  187. }
  188. num_frames_decoded_++;
  189. }
  190. void SimpleDecoder::ProcessNonemitting() {
  191. // Processes nonemitting arcs for one frame. Propagates within
  192. // cur_toks_.
  193. std::vector<StateId> queue;
  194. double infinity = std::numeric_limits<double>::infinity();
  195. double best_cost = infinity;
  196. for (unordered_map<StateId, Token*>::iterator iter = cur_toks_.begin();
  197. iter != cur_toks_.end();
  198. ++iter) {
  199. queue.push_back(iter->first);
  200. best_cost = std::min(best_cost, iter->second->cost_);
  201. }
  202. double cutoff = best_cost + beam_;
  203. while (!queue.empty()) {
  204. StateId state = queue.back();
  205. queue.pop_back();
  206. Token *tok = cur_toks_[state];
  207. KALDI_ASSERT(tok != NULL && state == tok->arc_.nextstate);
  208. for (fst::ArcIterator<fst::Fst<StdArc> > aiter(fst_, state);
  209. !aiter.Done();
  210. aiter.Next()) {
  211. const StdArc &arc = aiter.Value();
  212. if (arc.ilabel == 0) { // propagate nonemitting only...
  213. const BaseFloat acoustic_cost = 0.0;
  214. Token *new_tok = new Token(arc, acoustic_cost, tok);
  215. if (new_tok->cost_ > cutoff) {
  216. Token::TokenDelete(new_tok);
  217. } else {
  218. unordered_map<StateId, Token*>::iterator find_iter
  219. = cur_toks_.find(arc.nextstate);
  220. if (find_iter == cur_toks_.end()) {
  221. cur_toks_[arc.nextstate] = new_tok;
  222. queue.push_back(arc.nextstate);
  223. } else {
  224. if ( *(find_iter->second) < *new_tok ) {
  225. Token::TokenDelete(find_iter->second);
  226. find_iter->second = new_tok;
  227. queue.push_back(arc.nextstate);
  228. } else {
  229. Token::TokenDelete(new_tok);
  230. }
  231. }
  232. }
  233. }
  234. }
  235. }
  236. }
  237. // static
  238. void SimpleDecoder::ClearToks(unordered_map<StateId, Token*> &toks) {
  239. for (unordered_map<StateId, Token*>::iterator iter = toks.begin();
  240. iter != toks.end(); ++iter) {
  241. Token::TokenDelete(iter->second);
  242. }
  243. toks.clear();
  244. }
  245. // static
  246. void SimpleDecoder::PruneToks(BaseFloat beam, unordered_map<StateId, Token*> *toks) {
  247. if (toks->empty()) {
  248. KALDI_VLOG(2) << "No tokens to prune.\n";
  249. return;
  250. }
  251. double best_cost = std::numeric_limits<double>::infinity();
  252. for (unordered_map<StateId, Token*>::iterator iter = toks->begin();
  253. iter != toks->end(); ++iter)
  254. best_cost = std::min(best_cost, iter->second->cost_);
  255. std::vector<StateId> retained;
  256. double cutoff = best_cost + beam;
  257. for (unordered_map<StateId, Token*>::iterator iter = toks->begin();
  258. iter != toks->end(); ++iter) {
  259. if (iter->second->cost_ < cutoff)
  260. retained.push_back(iter->first);
  261. else
  262. Token::TokenDelete(iter->second);
  263. }
  264. unordered_map<StateId, Token*> tmp;
  265. for (size_t i = 0; i < retained.size(); i++) {
  266. tmp[retained[i]] = (*toks)[retained[i]];
  267. }
  268. KALDI_VLOG(2) << "Pruned to " << (retained.size()) << " toks.\n";
  269. tmp.swap(*toks);
  270. }
  271. } // end namespace kaldi.