lattice-simple-decoder.cc 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666
  1. // decoder/lattice-simple-decoder.cc
  2. // Copyright 2009-2012 Microsoft Corporation
  3. // 2013-2014 Johns Hopkins University (Author: Daniel Povey)
  4. // 2014 Guoguo Chen
  5. // See ../../COPYING for clarification regarding multiple authors
  6. //
  7. // Licensed under the Apache License, Version 2.0 (the "License");
  8. // you may not use this file except in compliance with the License.
  9. // You may obtain a copy of the License at
  10. //
  11. // http://www.apache.org/licenses/LICENSE-2.0
  12. //
  13. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  15. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  16. // MERCHANTABLITY OR NON-INFRINGEMENT.
  17. // See the Apache 2 License for the specific language governing permissions and
  18. // limitations under the License.
  19. #include "decoder/lattice-simple-decoder.h"
  20. namespace kaldi {
  21. void LatticeSimpleDecoder::InitDecoding() {
  22. // clean up from last time:
  23. cur_toks_.clear();
  24. prev_toks_.clear();
  25. ClearActiveTokens();
  26. warned_ = false;
  27. decoding_finalized_ = false;
  28. final_costs_.clear();
  29. num_toks_ = 0;
  30. StateId start_state = fst_.Start();
  31. KALDI_ASSERT(start_state != fst::kNoStateId);
  32. active_toks_.resize(1);
  33. Token *start_tok = new Token(0.0, 0.0, NULL, NULL);
  34. active_toks_[0].toks = start_tok;
  35. cur_toks_[start_state] = start_tok;
  36. num_toks_++;
  37. ProcessNonemitting();
  38. }
  39. bool LatticeSimpleDecoder::Decode(DecodableInterface *decodable) {
  40. InitDecoding();
  41. while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) {
  42. if (NumFramesDecoded() % config_.prune_interval == 0)
  43. PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
  44. ProcessEmitting(decodable);
  45. // Important to call PruneCurrentTokens before ProcessNonemitting, or we
  46. // would get dangling forward pointers. Anyway, ProcessNonemitting uses the
  47. // beam.
  48. PruneCurrentTokens(config_.beam, &cur_toks_);
  49. ProcessNonemitting();
  50. }
  51. FinalizeDecoding();
  52. // Returns true if we have any kind of traceback available (not necessarily
  53. // to the end state; query ReachedFinal() for that).
  54. return !final_costs_.empty();
  55. }
  56. // Outputs an FST corresponding to the single best path
  57. // through the lattice.
  58. bool LatticeSimpleDecoder::GetBestPath(Lattice *ofst,
  59. bool use_final_probs) const {
  60. fst::VectorFst<LatticeArc> fst;
  61. GetRawLattice(&fst, use_final_probs);
  62. ShortestPath(fst, ofst);
  63. return (ofst->NumStates() > 0);
  64. }
  65. // Outputs an FST corresponding to the raw, state-level
  66. // tracebacks.
  67. bool LatticeSimpleDecoder::GetRawLattice(Lattice *ofst,
  68. bool use_final_probs) const {
  69. typedef LatticeArc Arc;
  70. typedef Arc::StateId StateId;
  71. typedef Arc::Weight Weight;
  72. typedef Arc::Label Label;
  73. // Note: you can't use the old interface (Decode()) if you want to
  74. // get the lattice with use_final_probs = false. You'd have to do
  75. // InitDecoding() and then AdvanceDecoding().
  76. if (decoding_finalized_ && !use_final_probs)
  77. KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
  78. << "GetRawLattice() with use_final_probs == false";
  79. unordered_map<Token*, BaseFloat> final_costs_local;
  80. const unordered_map<Token*, BaseFloat> &final_costs =
  81. (decoding_finalized_ ? final_costs_ : final_costs_local);
  82. if (!decoding_finalized_ && use_final_probs)
  83. ComputeFinalCosts(&final_costs_local, NULL, NULL);
  84. ofst->DeleteStates();
  85. int32 num_frames = NumFramesDecoded();
  86. KALDI_ASSERT(num_frames > 0);
  87. const int32 bucket_count = num_toks_/2 + 3;
  88. unordered_map<Token*, StateId> tok_map(bucket_count);
  89. // First create all states.
  90. for (int32 f = 0; f <= num_frames; f++) {
  91. if (active_toks_[f].toks == NULL) {
  92. KALDI_WARN << "GetRawLattice: no tokens active on frame " << f
  93. << ": not producing lattice.\n";
  94. return false;
  95. }
  96. for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next)
  97. tok_map[tok] = ofst->AddState();
  98. // The next statement sets the start state of the output FST.
  99. // Because we always add new states to the head of the list
  100. // active_toks_[f].toks, and the start state was the first one
  101. // added, it will be the last one added to ofst.
  102. if (f == 0 && ofst->NumStates() > 0)
  103. ofst->SetStart(ofst->NumStates()-1);
  104. }
  105. StateId cur_state = 0; // we rely on the fact that we numbered these
  106. // consecutively (AddState() returns the numbers in order..)
  107. for (int32 f = 0; f <= num_frames; f++) {
  108. for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next,
  109. cur_state++) {
  110. for (ForwardLink *l = tok->links;
  111. l != NULL;
  112. l = l->next) {
  113. unordered_map<Token*, StateId>::const_iterator iter =
  114. tok_map.find(l->next_tok);
  115. StateId nextstate = iter->second;
  116. KALDI_ASSERT(iter != tok_map.end());
  117. Arc arc(l->ilabel, l->olabel,
  118. Weight(l->graph_cost, l->acoustic_cost),
  119. nextstate);
  120. ofst->AddArc(cur_state, arc);
  121. }
  122. if (f == num_frames) {
  123. if (use_final_probs && !final_costs.empty()) {
  124. unordered_map<Token*, BaseFloat>::const_iterator iter =
  125. final_costs.find(tok);
  126. if (iter != final_costs.end())
  127. ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
  128. } else {
  129. ofst->SetFinal(cur_state, LatticeWeight::One());
  130. }
  131. }
  132. }
  133. }
  134. KALDI_ASSERT(cur_state == ofst->NumStates());
  135. return (cur_state != 0);
  136. }
  137. // This function is now deprecated, since now we do determinization from outside
  138. // the LatticeSimpleDecoder class.
  139. // Outputs an FST corresponding to the lattice-determinized
  140. // lattice (one path per word sequence).
  141. bool LatticeSimpleDecoder::GetLattice(
  142. CompactLattice *ofst,
  143. bool use_final_probs) const {
  144. Lattice raw_fst;
  145. GetRawLattice(&raw_fst, use_final_probs);
  146. Invert(&raw_fst); // make it so word labels are on the input.
  147. if (!TopSort(&raw_fst)) // topological sort makes lattice-determinization more efficient
  148. KALDI_WARN << "Topological sorting of state-level lattice failed "
  149. "(probably your lexicon has empty words or your LM has epsilon cycles; this "
  150. " is a bad idea.)";
  151. // (in phase where we get backward-costs).
  152. fst::ILabelCompare<LatticeArc> ilabel_comp;
  153. ArcSort(&raw_fst, ilabel_comp); // sort on ilabel; makes
  154. // lattice-determinization more efficient.
  155. fst::DeterminizeLatticePrunedOptions lat_opts;
  156. lat_opts.max_mem = config_.det_opts.max_mem;
  157. DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts);
  158. raw_fst.DeleteStates(); // Free memory-- raw_fst no longer needed.
  159. Connect(ofst); // Remove unreachable states... there might be
  160. // a small number of these, in some cases.
  161. // Note: if something went wrong and the raw lattice was empty,
  162. // we should still get to this point in the code without warnings or failures.
  163. return (ofst->NumStates() != 0);
  164. }
  165. // FindOrAddToken either locates a token in cur_toks_, or if necessary inserts a new,
  166. // empty token (i.e. with no forward links) for the current frame. [note: it's
  167. // inserted if necessary into cur_toks_ and also into the singly linked list
  168. // of tokens active on this frame (whose head is at active_toks_[frame]).
  169. //
  170. // Returns the Token pointer. Sets "changed" (if non-NULL) to true
  171. // if the token was newly created or the cost changed.
  172. inline LatticeSimpleDecoder::Token *LatticeSimpleDecoder::FindOrAddToken(
  173. StateId state, int32 frame, BaseFloat tot_cost,
  174. bool emitting, bool *changed) {
  175. KALDI_ASSERT(frame < active_toks_.size());
  176. Token *&toks = active_toks_[frame].toks;
  177. unordered_map<StateId, Token*>::iterator find_iter = cur_toks_.find(state);
  178. if (find_iter == cur_toks_.end()) { // no such token presently.
  179. // Create one.
  180. const BaseFloat extra_cost = 0.0;
  181. // tokens on the currently final frame have zero extra_cost
  182. // as any of them could end up
  183. // on the winning path.
  184. Token *new_tok = new Token (tot_cost, extra_cost, NULL, toks);
  185. toks = new_tok;
  186. num_toks_++;
  187. cur_toks_[state] = new_tok;
  188. if (changed) *changed = true;
  189. return new_tok;
  190. } else {
  191. Token *tok = find_iter->second; // There is an existing Token for this state.
  192. if (tok->tot_cost > tot_cost) {
  193. tok->tot_cost = tot_cost;
  194. if (changed) *changed = true;
  195. } else {
  196. if (changed) *changed = false;
  197. }
  198. return tok;
  199. }
  200. }
  201. // delta is the amount by which the extra_costs must
  202. // change before it sets "extra_costs_changed" to true. If delta is larger,
  203. // we'll tend to go back less far toward the beginning of the file.
  204. void LatticeSimpleDecoder::PruneForwardLinks(
  205. int32 frame, bool *extra_costs_changed,
  206. bool *links_pruned, BaseFloat delta) {
  207. // We have to iterate until there is no more change, because the links
  208. // are not guaranteed to be in topological order.
  209. *extra_costs_changed = false;
  210. *links_pruned = false;
  211. KALDI_ASSERT(frame >= 0 && frame < active_toks_.size());
  212. if (active_toks_[frame].toks == NULL ) { // empty list; this should
  213. // not happen.
  214. if (!warned_) {
  215. KALDI_WARN << "No tokens alive [doing pruning].. warning first "
  216. "time only for each utterance\n";
  217. warned_ = true;
  218. }
  219. }
  220. bool changed = true;
  221. while (changed) {
  222. changed = false;
  223. for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
  224. ForwardLink *link, *prev_link = NULL;
  225. // will recompute tok_extra_cost.
  226. BaseFloat tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
  227. for (link = tok->links; link != NULL; ) {
  228. // See if we need to excise this link...
  229. Token *next_tok = link->next_tok;
  230. BaseFloat link_extra_cost = next_tok->extra_cost +
  231. ((tok->tot_cost + link->acoustic_cost + link->graph_cost)
  232. - next_tok->tot_cost);
  233. KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN
  234. if (link_extra_cost > config_.lattice_beam) { // excise link
  235. ForwardLink *next_link = link->next;
  236. if (prev_link != NULL) prev_link->next = next_link;
  237. else tok->links = next_link;
  238. delete link;
  239. link = next_link; // advance link but leave prev_link the same.
  240. *links_pruned = true;
  241. } else { // keep the link and update the tok_extra_cost if needed.
  242. if (link_extra_cost < 0.0) { // this is just a precaution.
  243. if (link_extra_cost < -0.01)
  244. KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
  245. link_extra_cost = 0.0;
  246. }
  247. if (link_extra_cost < tok_extra_cost)
  248. tok_extra_cost = link_extra_cost;
  249. prev_link = link;
  250. link = link->next;
  251. }
  252. }
  253. if (fabs(tok_extra_cost - tok->extra_cost) > delta)
  254. changed = true;
  255. tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_.
  256. }
  257. if (changed) *extra_costs_changed = true;
  258. // Note: it's theoretically possible that aggressive compiler
  259. // optimizations could cause an infinite loop here for small delta and
  260. // high-dynamic-range scores.
  261. }
  262. }
  263. void LatticeSimpleDecoder::ComputeFinalCosts(
  264. unordered_map<Token*, BaseFloat> *final_costs,
  265. BaseFloat *final_relative_cost,
  266. BaseFloat *final_best_cost) const {
  267. KALDI_ASSERT(!decoding_finalized_);
  268. if (final_costs != NULL)
  269. final_costs->clear();
  270. BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity();
  271. BaseFloat best_cost = infinity,
  272. best_cost_with_final = infinity;
  273. for (unordered_map<StateId, Token*>::const_iterator iter = cur_toks_.begin();
  274. iter != cur_toks_.end(); ++iter) {
  275. StateId state = iter->first;
  276. Token *tok = iter->second;
  277. BaseFloat final_cost = fst_.Final(state).Value();
  278. BaseFloat cost = tok->tot_cost,
  279. cost_with_final = cost + final_cost;
  280. best_cost = std::min(cost, best_cost);
  281. best_cost_with_final = std::min(cost_with_final, best_cost_with_final);
  282. if (final_costs != NULL && final_cost != infinity)
  283. (*final_costs)[tok] = final_cost;
  284. }
  285. if (final_relative_cost != NULL) {
  286. if (best_cost == infinity && best_cost_with_final == infinity) {
  287. // Likely this will only happen if there are no tokens surviving.
  288. // This seems the least bad way to handle it.
  289. *final_relative_cost = infinity;
  290. } else {
  291. *final_relative_cost = best_cost_with_final - best_cost;
  292. }
  293. }
  294. if (final_best_cost != NULL) {
  295. if (best_cost_with_final != infinity) { // final-state exists.
  296. *final_best_cost = best_cost_with_final;
  297. } else { // no final-state exists.
  298. *final_best_cost = best_cost;
  299. }
  300. }
  301. }
  302. // PruneForwardLinksFinal is a version of PruneForwardLinks that we call
  303. // on the final frame. If there are final tokens active, it uses the final-probs
  304. // for pruning, otherwise it treats all tokens as final.
  305. void LatticeSimpleDecoder::PruneForwardLinksFinal() {
  306. KALDI_ASSERT(!active_toks_.empty());
  307. int32 frame_plus_one = active_toks_.size() - 1;
  308. if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen.
  309. KALDI_WARN << "No tokens alive at end of file\n";
  310. typedef unordered_map<Token*, BaseFloat>::const_iterator IterType;
  311. ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_);
  312. decoding_finalized_ = true;
  313. // We're about to delete some of the tokens active on the final frame, so we
  314. // clear cur_toks_ because otherwise it would then contain dangling pointers.
  315. cur_toks_.clear();
  316. // Now go through tokens on this frame, pruning forward links... may have to
  317. // iterate a few times until there is no more change, because the list is not
  318. // in topological order. This is a modified version of the code in
  319. // PruneForwardLinks, but here we also take account of the final-probs.
  320. bool changed = true;
  321. BaseFloat delta = 1.0e-05;
  322. while (changed) {
  323. changed = false;
  324. for (Token *tok = active_toks_[frame_plus_one].toks;
  325. tok != NULL; tok = tok->next) {
  326. ForwardLink *link, *prev_link=NULL;
  327. // will recompute tok_extra_cost. It has a term in it that corresponds
  328. // to the "final-prob", so instead of initializing tok_extra_cost to infinity
  329. // below we set it to the difference between the (score+final_prob) of this token,
  330. // and the best such (score+final_prob).
  331. BaseFloat final_cost;
  332. if (final_costs_.empty()) {
  333. final_cost = 0.0;
  334. } else {
  335. IterType iter = final_costs_.find(tok);
  336. if (iter != final_costs_.end())
  337. final_cost = iter->second;
  338. else
  339. final_cost = std::numeric_limits<BaseFloat>::infinity();
  340. }
  341. BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_;
  342. // tok_extra_cost will be a "min" over either directly being final, or
  343. // being indirectly final through other links, and the loop below may
  344. // decrease its value:
  345. for (link = tok->links; link != NULL; ) {
  346. // See if we need to excise this link...
  347. Token *next_tok = link->next_tok;
  348. BaseFloat link_extra_cost = next_tok->extra_cost +
  349. ((tok->tot_cost + link->acoustic_cost + link->graph_cost)
  350. - next_tok->tot_cost);
  351. if (link_extra_cost > config_.lattice_beam) { // excise link
  352. ForwardLink *next_link = link->next;
  353. if (prev_link != NULL) prev_link->next = next_link;
  354. else tok->links = next_link;
  355. delete link;
  356. link = next_link; // advance link but leave prev_link the same.
  357. } else { // keep the link and update the tok_extra_cost if needed.
  358. if (link_extra_cost < 0.0) { // this is just a precaution.
  359. if (link_extra_cost < -0.01)
  360. KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
  361. link_extra_cost = 0.0;
  362. }
  363. if (link_extra_cost < tok_extra_cost)
  364. tok_extra_cost = link_extra_cost;
  365. prev_link = link;
  366. link = link->next;
  367. }
  368. }
  369. // prune away tokens worse than lattice_beam above best path. This step
  370. // was not necessary in the non-final case because then, this case
  371. // showed up as having no forward links. Here, the tok_extra_cost has
  372. // an extra component relating to the final-prob.
  373. if (tok_extra_cost > config_.lattice_beam)
  374. tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
  375. // to be pruned in PruneTokensForFrame
  376. if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta))
  377. changed = true;
  378. tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_.
  379. }
  380. } // while changed
  381. }
  382. BaseFloat LatticeSimpleDecoder::FinalRelativeCost() const {
  383. if (!decoding_finalized_) {
  384. BaseFloat relative_cost;
  385. ComputeFinalCosts(NULL, &relative_cost, NULL);
  386. return relative_cost;
  387. } else {
  388. // we're not allowed to call that function if FinalizeDecoding() has
  389. // been called; return a cached value.
  390. return final_relative_cost_;
  391. }
  392. }
  393. // Prune away any tokens on this frame that have no forward links. [we don't do
  394. // this in PruneForwardLinks because it would give us a problem with dangling
  395. // pointers].
  396. void LatticeSimpleDecoder::PruneTokensForFrame(int32 frame) {
  397. KALDI_ASSERT(frame >= 0 && frame < active_toks_.size());
  398. Token *&toks = active_toks_[frame].toks;
  399. if (toks == NULL)
  400. KALDI_WARN << "No tokens alive [doing pruning]";
  401. Token *tok, *next_tok, *prev_tok = NULL;
  402. for (tok = toks; tok != NULL; tok = next_tok) {
  403. next_tok = tok->next;
  404. if (tok->extra_cost == std::numeric_limits<BaseFloat>::infinity()) {
  405. // Next token is unreachable from end of graph; excise tok from list
  406. // and delete tok.
  407. if (prev_tok != NULL) prev_tok->next = tok->next;
  408. else toks = tok->next;
  409. delete tok;
  410. num_toks_--;
  411. } else {
  412. prev_tok = tok;
  413. }
  414. }
  415. }
  416. // Go backwards through still-alive tokens, pruning them, starting not from
  417. // the current frame (where we want to keep all tokens) but from the frame before
  418. // that. We go backwards through the frames and stop when we reach a point
  419. // where the delta-costs are not changing (and the delta controls when we consider
  420. // a cost to have "not changed").
  421. void LatticeSimpleDecoder::PruneActiveTokens(BaseFloat delta) {
  422. int32 cur_frame_plus_one = NumFramesDecoded();
  423. int32 num_toks_begin = num_toks_;
  424. // The index "f" below represents a "frame plus one", i.e. you'd have to subtract
  425. // one to get the corresponding index for the decodable object.
  426. for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) {
  427. // Reason why we need to prune forward links in this situation:
  428. // (1) we have never pruned them
  429. // (2) we never pruned the forward links on the next frame, which
  430. //
  431. if (active_toks_[f].must_prune_forward_links) {
  432. bool extra_costs_changed = false, links_pruned = false;
  433. PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta);
  434. if (extra_costs_changed && f > 0)
  435. active_toks_[f-1].must_prune_forward_links = true;
  436. if (links_pruned)
  437. active_toks_[f].must_prune_tokens = true;
  438. active_toks_[f].must_prune_forward_links = false;
  439. }
  440. if (f+1 < cur_frame_plus_one &&
  441. active_toks_[f+1].must_prune_tokens) {
  442. PruneTokensForFrame(f+1);
  443. active_toks_[f+1].must_prune_tokens = false;
  444. }
  445. }
  446. KALDI_VLOG(3) << "PruneActiveTokens: pruned tokens from " << num_toks_begin
  447. << " to " << num_toks_;
  448. }
  449. // FinalizeDecoding() is a version of PruneActiveTokens that we call
  450. // (optionally) on the final frame. Takes into account the final-prob of
  451. // tokens. This function used to be called PruneActiveTokensFinal().
  452. void LatticeSimpleDecoder::FinalizeDecoding() {
  453. int32 final_frame_plus_one = NumFramesDecoded();
  454. int32 num_toks_begin = num_toks_;
  455. PruneForwardLinksFinal();
  456. for (int32 f = final_frame_plus_one - 1; f >= 0; f--) {
  457. bool b1, b2; // values not used.
  458. BaseFloat dontcare = 0.0;
  459. PruneForwardLinks(f, &b1, &b2, dontcare);
  460. PruneTokensForFrame(f + 1);
  461. }
  462. PruneTokensForFrame(0);
  463. KALDI_VLOG(3) << "pruned tokens from " << num_toks_begin
  464. << " to " << num_toks_;
  465. }
  466. void LatticeSimpleDecoder::ProcessEmitting(DecodableInterface *decodable) {
  467. int32 frame = active_toks_.size() - 1; // frame is the frame-index
  468. // (zero-based) used to get likelihoods
  469. // from the decodable object.
  470. active_toks_.resize(active_toks_.size() + 1);
  471. prev_toks_.clear();
  472. cur_toks_.swap(prev_toks_);
  473. // Processes emitting arcs for one frame. Propagates from
  474. // prev_toks_ to cur_toks_.
  475. BaseFloat cutoff = std::numeric_limits<BaseFloat>::infinity();
  476. for (unordered_map<StateId, Token*>::iterator iter = prev_toks_.begin();
  477. iter != prev_toks_.end();
  478. ++iter) {
  479. StateId state = iter->first;
  480. Token *tok = iter->second;
  481. for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
  482. !aiter.Done();
  483. aiter.Next()) {
  484. const Arc &arc = aiter.Value();
  485. if (arc.ilabel != 0) { // propagate..
  486. BaseFloat ac_cost = -decodable->LogLikelihood(frame, arc.ilabel),
  487. graph_cost = arc.weight.Value(),
  488. cur_cost = tok->tot_cost,
  489. tot_cost = cur_cost + ac_cost + graph_cost;
  490. if (tot_cost >= cutoff) continue;
  491. else if (tot_cost + config_.beam < cutoff)
  492. cutoff = tot_cost + config_.beam;
  493. // AddToken adds the next_tok to cur_toks_ (if not already present).
  494. Token *next_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
  495. true, NULL);
  496. // Add ForwardLink from tok to next_tok (put on head of list tok->links)
  497. tok->links = new ForwardLink(next_tok, arc.ilabel, arc.olabel,
  498. graph_cost, ac_cost, tok->links);
  499. }
  500. }
  501. }
  502. }
  503. void LatticeSimpleDecoder::ProcessNonemitting() {
  504. KALDI_ASSERT(!active_toks_.empty());
  505. int32 frame = static_cast<int32>(active_toks_.size()) - 2;
  506. // Note: "frame" is the time-index we just processed, or -1 if
  507. // we are processing the nonemitting transitions before the
  508. // first frame (called from InitDecoding()).
  509. // Processes nonemitting arcs for one frame. Propagates within
  510. // cur_toks_. Note-- this queue structure is is not very optimal as
  511. // it may cause us to process states unnecessarily (e.g. more than once),
  512. // but in the baseline code, turning this vector into a set to fix this
  513. // problem did not improve overall speed.
  514. std::vector<StateId> queue;
  515. BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
  516. for (unordered_map<StateId, Token*>::iterator iter = cur_toks_.begin();
  517. iter != cur_toks_.end();
  518. ++iter) {
  519. StateId state = iter->first;
  520. if (fst_.NumInputEpsilons(state) != 0)
  521. queue.push_back(state);
  522. best_cost = std::min(best_cost, iter->second->tot_cost);
  523. }
  524. if (queue.empty()) {
  525. if (!warned_) {
  526. KALDI_ERR << "Error in ProcessEmitting: no surviving tokens: frame is "
  527. << frame;
  528. warned_ = true;
  529. }
  530. }
  531. BaseFloat cutoff = best_cost + config_.beam;
  532. while (!queue.empty()) {
  533. StateId state = queue.back();
  534. queue.pop_back();
  535. Token *tok = cur_toks_[state];
  536. // If "tok" has any existing forward links, delete them,
  537. // because we're about to regenerate them. This is a kind
  538. // of non-optimality (remember, this is the simple decoder),
  539. // but since most states are emitting it's not a huge issue.
  540. tok->DeleteForwardLinks();
  541. tok->links = NULL;
  542. for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
  543. !aiter.Done();
  544. aiter.Next()) {
  545. const Arc &arc = aiter.Value();
  546. if (arc.ilabel == 0) { // propagate nonemitting only...
  547. BaseFloat graph_cost = arc.weight.Value(),
  548. cur_cost = tok->tot_cost,
  549. tot_cost = cur_cost + graph_cost;
  550. if (tot_cost < cutoff) {
  551. bool changed;
  552. Token *new_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
  553. false, &changed);
  554. tok->links = new ForwardLink(new_tok, 0, arc.olabel,
  555. graph_cost, 0, tok->links);
  556. // "changed" tells us whether the new token has a different
  557. // cost from before, or is new [if so, add into queue].
  558. if (changed && fst_.NumInputEpsilons(arc.nextstate) != 0)
  559. queue.push_back(arc.nextstate);
  560. }
  561. }
  562. }
  563. }
  564. }
  565. void LatticeSimpleDecoder::ClearActiveTokens() { // a cleanup routine, at utt end/begin
  566. for (size_t i = 0; i < active_toks_.size(); i++) {
  567. // Delete all tokens alive on this frame, and any forward
  568. // links they may have.
  569. for (Token *tok = active_toks_[i].toks; tok != NULL; ) {
  570. tok->DeleteForwardLinks();
  571. Token *next_tok = tok->next;
  572. delete tok;
  573. num_toks_--;
  574. tok = next_tok;
  575. }
  576. }
  577. active_toks_.clear();
  578. KALDI_ASSERT(num_toks_ == 0);
  579. }
  580. // PruneCurrentTokens deletes the tokens from the "toks" map, but not
  581. // from the active_toks_ list, which could cause dangling forward pointers
  582. // (will delete it during regular pruning operation).
  583. void LatticeSimpleDecoder::PruneCurrentTokens(BaseFloat beam, unordered_map<StateId, Token*> *toks) {
  584. if (toks->empty()) {
  585. KALDI_VLOG(2) << "No tokens to prune.\n";
  586. return;
  587. }
  588. BaseFloat best_cost = 1.0e+10; // positive == high cost == bad.
  589. for (unordered_map<StateId, Token*>::iterator iter = toks->begin();
  590. iter != toks->end(); ++iter) {
  591. best_cost =
  592. std::min(best_cost,
  593. static_cast<BaseFloat>(iter->second->tot_cost));
  594. }
  595. std::vector<StateId> retained;
  596. BaseFloat cutoff = best_cost + beam;
  597. for (unordered_map<StateId, Token*>::iterator iter = toks->begin();
  598. iter != toks->end(); ++iter) {
  599. if (iter->second->tot_cost < cutoff)
  600. retained.push_back(iter->first);
  601. }
  602. unordered_map<StateId, Token*> tmp;
  603. for (size_t i = 0; i < retained.size(); i++) {
  604. tmp[retained[i]] = (*toks)[retained[i]];
  605. }
  606. KALDI_VLOG(2) << "Pruned to "<<(retained.size())<<" toks.\n";
  607. tmp.swap(*toks);
  608. }
  609. } // end namespace kaldi.