faster-decoder.cc 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. // decoder/faster-decoder.cc
  2. // Copyright 2009-2011 Microsoft Corporation
  3. // 2012-2013 Johns Hopkins University (author: Daniel Povey)
  4. // See ../../COPYING for clarification regarding multiple authors
  5. //
  6. // Licensed under the Apache License, Version 2.0 (the "License");
  7. // you may not use this file except in compliance with the License.
  8. // You may obtain a copy of the License at
  9. //
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. //
  12. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  13. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  14. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  15. // MERCHANTABILITY OR NON-INFRINGEMENT.
  16. // See the Apache 2 License for the specific language governing permissions and
  17. // limitations under the License.
  18. #include "decoder/faster-decoder.h"
  19. namespace kaldi {
// Constructs a decoder over "fst" using the pruning options in "opts".
// num_frames_decoded_ starts at -1 so AdvanceDecoding() can detect a
// missing InitDecoding() call.
FasterDecoder::FasterDecoder(const fst::Fst<fst::StdArc> &fst,
                             const FasterDecoderOptions &opts):
    fst_(fst), config_(opts), num_frames_decoded_(-1) {
  // Validate the pruning configuration up front.
  KALDI_ASSERT(config_.hash_ratio >= 1.0);  // less doesn't make much sense.
  KALDI_ASSERT(config_.max_active > 1);
  KALDI_ASSERT(config_.min_active >= 0 && config_.min_active < config_.max_active);
  // Initial hash size; PossiblyResizeHash() grows it adaptively later.
  toks_.SetSize(1000);  // just so on the first frame we do something reasonable.
}
  28. void FasterDecoder::InitDecoding() {
  29. // clean up from last time:
  30. ClearToks(toks_.Clear());
  31. StateId start_state = fst_.Start();
  32. KALDI_ASSERT(start_state != fst::kNoStateId);
  33. Arc dummy_arc(0, 0, Weight::One(), start_state);
  34. toks_.Insert(start_state, new Token(dummy_arc, NULL));
  35. ProcessNonemitting(std::numeric_limits<float>::max());
  36. num_frames_decoded_ = 0;
  37. }
  38. void FasterDecoder::Decode(DecodableInterface *decodable) {
  39. InitDecoding();
  40. AdvanceDecoding(decodable);
  41. }
  42. void FasterDecoder::AdvanceDecoding(DecodableInterface *decodable,
  43. int32 max_num_frames) {
  44. KALDI_ASSERT(num_frames_decoded_ >= 0 &&
  45. "You must call InitDecoding() before AdvanceDecoding()");
  46. int32 num_frames_ready = decodable->NumFramesReady();
  47. // num_frames_ready must be >= num_frames_decoded, or else
  48. // the number of frames ready must have decreased (which doesn't
  49. // make sense) or the decodable object changed between calls
  50. // (which isn't allowed).
  51. KALDI_ASSERT(num_frames_ready >= num_frames_decoded_);
  52. int32 target_frames_decoded = num_frames_ready;
  53. if (max_num_frames >= 0)
  54. target_frames_decoded = std::min(target_frames_decoded,
  55. num_frames_decoded_ + max_num_frames);
  56. while (num_frames_decoded_ < target_frames_decoded) {
  57. // note: ProcessEmitting() increments num_frames_decoded_
  58. double weight_cutoff = ProcessEmitting(decodable);
  59. ProcessNonemitting(weight_cutoff);
  60. }
  61. }
  62. bool FasterDecoder::ReachedFinal() const {
  63. for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
  64. if (e->val->cost_ != std::numeric_limits<double>::infinity() &&
  65. fst_.Final(e->key) != Weight::Zero())
  66. return true;
  67. }
  68. return false;
  69. }
  70. bool FasterDecoder::GetBestPath(fst::MutableFst<LatticeArc> *fst_out,
  71. bool use_final_probs) {
  72. // GetBestPath gets the decoding output. If "use_final_probs" is true
  73. // AND we reached a final state, it limits itself to final states;
  74. // otherwise it gets the most likely token not taking into
  75. // account final-probs. fst_out will be empty (Start() == kNoStateId) if
  76. // nothing was available. It returns true if it got output (thus, fst_out
  77. // will be nonempty).
  78. fst_out->DeleteStates();
  79. Token *best_tok = NULL;
  80. bool is_final = ReachedFinal();
  81. if (!is_final) {
  82. for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
  83. if (best_tok == NULL || *best_tok < *(e->val) )
  84. best_tok = e->val;
  85. } else {
  86. double infinity = std::numeric_limits<double>::infinity(),
  87. best_cost = infinity;
  88. for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
  89. double this_cost = e->val->cost_ + fst_.Final(e->key).Value();
  90. if (this_cost < best_cost && this_cost != infinity) {
  91. best_cost = this_cost;
  92. best_tok = e->val;
  93. }
  94. }
  95. }
  96. if (best_tok == NULL) return false; // No output.
  97. std::vector<LatticeArc> arcs_reverse; // arcs in reverse order.
  98. for (Token *tok = best_tok; tok != NULL; tok = tok->prev_) {
  99. BaseFloat tot_cost = tok->cost_ -
  100. (tok->prev_ ? tok->prev_->cost_ : 0.0),
  101. graph_cost = tok->arc_.weight.Value(),
  102. ac_cost = tot_cost - graph_cost;
  103. LatticeArc l_arc(tok->arc_.ilabel,
  104. tok->arc_.olabel,
  105. LatticeWeight(graph_cost, ac_cost),
  106. tok->arc_.nextstate);
  107. arcs_reverse.push_back(l_arc);
  108. }
  109. KALDI_ASSERT(arcs_reverse.back().nextstate == fst_.Start());
  110. arcs_reverse.pop_back(); // that was a "fake" token... gives no info.
  111. StateId cur_state = fst_out->AddState();
  112. fst_out->SetStart(cur_state);
  113. for (ssize_t i = static_cast<ssize_t>(arcs_reverse.size())-1; i >= 0; i--) {
  114. LatticeArc arc = arcs_reverse[i];
  115. arc.nextstate = fst_out->AddState();
  116. fst_out->AddArc(cur_state, arc);
  117. cur_state = arc.nextstate;
  118. }
  119. if (is_final && use_final_probs) {
  120. Weight final_weight = fst_.Final(best_tok->arc_.nextstate);
  121. fst_out->SetFinal(cur_state, LatticeWeight(final_weight.Value(), 0.0));
  122. } else {
  123. fst_out->SetFinal(cur_state, LatticeWeight::One());
  124. }
  125. RemoveEpsLocal(fst_out);
  126. return true;
  127. }
  128. // Gets the weight cutoff. Also counts the active tokens.
  129. double FasterDecoder::GetCutoff(Elem *list_head, size_t *tok_count,
  130. BaseFloat *adaptive_beam, Elem **best_elem) {
  131. double best_cost = std::numeric_limits<double>::infinity();
  132. size_t count = 0;
  133. if (config_.max_active == std::numeric_limits<int32>::max() &&
  134. config_.min_active == 0) {
  135. for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
  136. double w = e->val->cost_;
  137. if (w < best_cost) {
  138. best_cost = w;
  139. if (best_elem) *best_elem = e;
  140. }
  141. }
  142. if (tok_count != NULL) *tok_count = count;
  143. if (adaptive_beam != NULL) *adaptive_beam = config_.beam;
  144. return best_cost + config_.beam;
  145. } else {
  146. tmp_array_.clear();
  147. for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
  148. double w = e->val->cost_;
  149. tmp_array_.push_back(w);
  150. if (w < best_cost) {
  151. best_cost = w;
  152. if (best_elem) *best_elem = e;
  153. }
  154. }
  155. if (tok_count != NULL) *tok_count = count;
  156. double beam_cutoff = best_cost + config_.beam,
  157. min_active_cutoff = std::numeric_limits<double>::infinity(),
  158. max_active_cutoff = std::numeric_limits<double>::infinity();
  159. if (tmp_array_.size() > static_cast<size_t>(config_.max_active)) {
  160. std::nth_element(tmp_array_.begin(),
  161. tmp_array_.begin() + config_.max_active,
  162. tmp_array_.end());
  163. max_active_cutoff = tmp_array_[config_.max_active];
  164. }
  165. if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam.
  166. if (adaptive_beam)
  167. *adaptive_beam = max_active_cutoff - best_cost + config_.beam_delta;
  168. return max_active_cutoff;
  169. }
  170. if (tmp_array_.size() > static_cast<size_t>(config_.min_active)) {
  171. if (config_.min_active == 0) min_active_cutoff = best_cost;
  172. else {
  173. std::nth_element(tmp_array_.begin(),
  174. tmp_array_.begin() + config_.min_active,
  175. tmp_array_.size() > static_cast<size_t>(config_.max_active) ?
  176. tmp_array_.begin() + config_.max_active :
  177. tmp_array_.end());
  178. min_active_cutoff = tmp_array_[config_.min_active];
  179. }
  180. }
  181. if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam.
  182. if (adaptive_beam)
  183. *adaptive_beam = min_active_cutoff - best_cost + config_.beam_delta;
  184. return min_active_cutoff;
  185. } else {
  186. *adaptive_beam = config_.beam;
  187. return beam_cutoff;
  188. }
  189. }
  190. }
  191. void FasterDecoder::PossiblyResizeHash(size_t num_toks) {
  192. size_t new_sz = static_cast<size_t>(static_cast<BaseFloat>(num_toks)
  193. * config_.hash_ratio);
  194. if (new_sz > toks_.Size()) {
  195. toks_.SetSize(new_sz);
  196. }
  197. }
// ProcessEmitting returns the likelihood cutoff used.
// Processes emitting (ilabel != 0) arcs for one frame: takes ownership of
// the current token list, prunes against a cutoff from GetCutoff(), and
// propagates surviving tokens, adding in this frame's acoustic
// log-likelihoods.  Increments num_frames_decoded_.
double FasterDecoder::ProcessEmitting(DecodableInterface *decodable) {
  int32 frame = num_frames_decoded_;
  Elem *last_toks = toks_.Clear();  // detach the token list from the hash.
  size_t tok_cnt;
  BaseFloat adaptive_beam;
  Elem *best_elem = NULL;
  double weight_cutoff = GetCutoff(last_toks, &tok_cnt,
                                   &adaptive_beam, &best_elem);
  KALDI_VLOG(3) << tok_cnt << " tokens active.";
  PossiblyResizeHash(tok_cnt);  // This makes sure the hash is always big enough.
  // This is the cutoff we use after adding in the log-likes (i.e.
  // for the next frame).  This is a bound on the cutoff we will use
  // on the next frame.
  double next_weight_cutoff = std::numeric_limits<double>::infinity();
  // First process the best token to get a hopefully
  // reasonably tight bound on the next cutoff.
  if (best_elem) {
    StateId state = best_elem->key;
    Token *tok = best_elem->val;
    for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
         !aiter.Done();
         aiter.Next()) {
      const Arc &arc = aiter.Value();
      if (arc.ilabel != 0) {  // we'd propagate..
        // Acoustic cost is the negated log-likelihood for this frame/label.
        BaseFloat ac_cost = - decodable->LogLikelihood(frame, arc.ilabel);
        double new_weight = arc.weight.Value() + tok->cost_ + ac_cost;
        if (new_weight + adaptive_beam < next_weight_cutoff)
          next_weight_cutoff = new_weight + adaptive_beam;
      }
    }
  }
  // int32 n = 0, np = 0;
  // the tokens are now owned here, in last_toks, and the hash is empty.
  // 'owned' is a complex thing here; the point is we need to call TokenDelete
  // on each elem 'e' to let toks_ know we're done with them.
  for (Elem *e = last_toks, *e_tail; e != NULL; e = e_tail) {  // loop this way
    // n++;
    // because we delete "e" as we go.
    StateId state = e->key;
    Token *tok = e->val;
    if (tok->cost_ < weight_cutoff) {  // not pruned.
      // np++;
      KALDI_ASSERT(state == tok->arc_.nextstate);
      for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
           !aiter.Done();
           aiter.Next()) {
        Arc arc = aiter.Value();
        if (arc.ilabel != 0) {  // propagate..
          BaseFloat ac_cost = - decodable->LogLikelihood(frame, arc.ilabel);
          double new_weight = arc.weight.Value() + tok->cost_ + ac_cost;
          if (new_weight < next_weight_cutoff) {  // not pruned..
            Token *new_tok = new Token(arc, ac_cost, tok);
            Elem *e_found = toks_.Insert(arc.nextstate, new_tok);
            // Tighten the next-frame cutoff as better paths appear.
            if (new_weight + adaptive_beam < next_weight_cutoff)
              next_weight_cutoff = new_weight + adaptive_beam;
            if (e_found->val != new_tok) {
              // Destination state already had a token: keep whichever of
              // the two is better and delete the other.
              if (*(e_found->val) < *new_tok) {
                Token::TokenDelete(e_found->val);
                e_found->val = new_tok;
              } else {
                Token::TokenDelete(new_tok);
              }
            }
          }
        }
      }
    }
    // Done with this source token: drop our reference and return the
    // list cell to the hash's free list.
    e_tail = e->tail;
    Token::TokenDelete(e->val);
    toks_.Delete(e);
  }
  num_frames_decoded_++;
  return next_weight_cutoff;
}
// TODO: first time we go through this, could avoid using the queue.
// Processes nonemitting (epsilon, ilabel == 0) arcs for one frame,
// propagating tokens until no token with cost below "cutoff" can be
// improved.  queue_ is used as a work-list of hash elements to expand.
void FasterDecoder::ProcessNonemitting(double cutoff) {
  // Processes nonemitting arcs for one frame.
  KALDI_ASSERT(queue_.empty());
  // Seed the work-list with every currently live token.
  for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
    queue_.push_back(e);
  while (!queue_.empty()) {
    const Elem* e = queue_.back();
    queue_.pop_back();
    StateId state = e->key;
    Token *tok = e->val;  // would segfault if state not
    // in toks_ but this can't happen.
    if (tok->cost_ > cutoff) {  // Don't bother processing successors.
      continue;
    }
    KALDI_ASSERT(tok != NULL && state == tok->arc_.nextstate);
    for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
         !aiter.Done();
         aiter.Next()) {
      const Arc &arc = aiter.Value();
      if (arc.ilabel == 0) {  // propagate nonemitting only...
        Token *new_tok = new Token(arc, tok);
        if (new_tok->cost_ > cutoff) {  // prune
          Token::TokenDelete(new_tok);
        } else {
          Elem *e_found = toks_.Insert(arc.nextstate, new_tok);
          if (e_found->val == new_tok) {
            // New state: schedule it for epsilon expansion too.
            queue_.push_back(e_found);
          } else {
            // State already had a token: keep the better one, and
            // re-queue only if the new token improved on it.
            if (*(e_found->val) < *new_tok) {
              Token::TokenDelete(e_found->val);
              e_found->val = new_tok;
              queue_.push_back(e_found);
            } else {
              Token::TokenDelete(new_tok);
            }
          }
        }
      }
    }
  }
}
  315. void FasterDecoder::ClearToks(Elem *list) {
  316. for (Elem *e = list, *e_tail; e != NULL; e = e_tail) {
  317. Token::TokenDelete(e->val);
  318. e_tail = e->tail;
  319. toks_.Delete(e);
  320. }
  321. }
  322. } // end namespace kaldi.