// (scraping artifacts removed: file-size banner and line-number gutter)
  1. // decoder/biglm-faster-decoder.h
  2. // Copyright 2009-2011 Microsoft Corporation, Gilles Boulianne
  3. // See ../../COPYING for clarification regarding multiple authors
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  12. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  13. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  14. // MERCHANTABLITY OR NON-INFRINGEMENT.
  15. // See the Apache 2 License for the specific language governing permissions and
  16. // limitations under the License.
  17. #ifndef KALDI_DECODER_BIGLM_FASTER_DECODER_H_
  18. #define KALDI_DECODER_BIGLM_FASTER_DECODER_H_
  19. #include "util/stl-utils.h"
  20. #include "util/hash-list.h"
  21. #include "fst/fstlib.h"
  22. #include "itf/decodable-itf.h"
  23. #include "lat/kaldi-lattice.h" // for CompactLatticeArc
  24. #include "decoder/faster-decoder.h" // for options class
  25. #include "fstext/deterministic-fst.h"
  26. namespace kaldi {
  27. struct BiglmFasterDecoderOptions: public FasterDecoderOptions {
  28. BiglmFasterDecoderOptions() {
  29. min_active = 200;
  30. }
  31. };
  32. /** This is as FasterDecoder, but does online composition between
  33. HCLG and the "difference language model", which is a deterministic
  34. FST that represents the difference between the language model you want
  35. and the language model you compiled HCLG with. The class
  36. DeterministicOnDemandFst follows through the epsilons in G for you
  37. (assuming G is a standard backoff language model) and makes it look
  38. like a determinized FST. Actually, in practice,
  39. DeterministicOnDemandFst operates in a mode where it composes two
  40. G's together; one has negated likelihoods and works by removing the
  41. LM probabilities that you made HCLG with, and one is the language model
  42. you want to use.
  43. */
  44. class BiglmFasterDecoder {
  45. public:
  46. typedef fst::StdArc Arc;
  47. typedef Arc::Label Label;
  48. typedef Arc::StateId StateId;
  49. // A PairId will be constructed as: (StateId in fst) + (StateId in lm_diff_fst) << 32;
  50. typedef uint64 PairId;
  51. typedef Arc::Weight Weight;
  52. // This constructor is the same as for FasterDecoder, except the second
  53. // argument (lm_diff_fst) is new; it's an FST (actually, a
  54. // DeterministicOnDemandFst) that represents the difference in LM scores
  55. // between the LM we want and the LM the decoding-graph "fst" was built with.
  56. // See e.g. gmm-decode-biglm-faster.cc for an example of how this is called.
  57. // Basically, we are using fst o lm_diff_fst (where o is composition)
  58. // as the decoding graph. Instead of having everything indexed by the state in
  59. // "fst", we now index by the pair of states in (fst, lm_diff_fst).
  60. // Whenever we cross a word, we need to propagate the state within
  61. // lm_diff_fst.
  62. BiglmFasterDecoder(const fst::Fst<fst::StdArc> &fst,
  63. const BiglmFasterDecoderOptions &opts,
  64. fst::DeterministicOnDemandFst<fst::StdArc> *lm_diff_fst):
  65. fst_(fst), lm_diff_fst_(lm_diff_fst), opts_(opts), warned_noarc_(false) {
  66. KALDI_ASSERT(opts_.hash_ratio >= 1.0); // less doesn't make much sense.
  67. KALDI_ASSERT(opts_.max_active > 1);
  68. KALDI_ASSERT(fst.Start() != fst::kNoStateId &&
  69. lm_diff_fst->Start() != fst::kNoStateId);
  70. toks_.SetSize(1000); // just so on the first frame we do something reasonable.
  71. }
  72. void SetOptions(const BiglmFasterDecoderOptions &opts) { opts_ = opts; }
  73. ~BiglmFasterDecoder() {
  74. ClearToks(toks_.Clear());
  75. }
  76. void Decode(DecodableInterface *decodable) {
  77. // clean up from last time:
  78. ClearToks(toks_.Clear());
  79. PairId start_pair = ConstructPair(fst_.Start(), lm_diff_fst_->Start());
  80. Arc dummy_arc(0, 0, Weight::One(), fst_.Start()); // actually, the last element of
  81. // the Arcs (fst_.Start(), here) is never needed.
  82. toks_.Insert(start_pair, new Token(dummy_arc, NULL));
  83. ProcessNonemitting(std::numeric_limits<float>::max());
  84. for (int32 frame = 0; !decodable->IsLastFrame(frame-1); frame++) {
  85. BaseFloat weight_cutoff = ProcessEmitting(decodable, frame);
  86. ProcessNonemitting(weight_cutoff);
  87. }
  88. }
  89. bool ReachedFinal() {
  90. for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
  91. PairId state_pair = e->key;
  92. StateId state = PairToState(state_pair),
  93. lm_state = PairToLmState(state_pair);
  94. Weight this_weight =
  95. Times(e->val->weight_,
  96. Times(fst_.Final(state), lm_diff_fst_->Final(lm_state)));
  97. if (this_weight != Weight::Zero())
  98. return true;
  99. }
  100. return false;
  101. }
  102. bool GetBestPath(fst::MutableFst<LatticeArc> *fst_out,
  103. bool use_final_probs = true) {
  104. // GetBestPath gets the decoding output. If "use_final_probs" is true
  105. // AND we reached a final state, it limits itself to final states;
  106. // otherwise it gets the most likely token not taking into
  107. // account final-probs. fst_out will be empty (Start() == kNoStateId) if
  108. // nothing was available. It returns true if it got output (thus, fst_out
  109. // will be nonempty).
  110. fst_out->DeleteStates();
  111. Token *best_tok = NULL;
  112. Weight best_final = Weight::Zero(); // set only if is_final == true. The
  113. // final-prob corresponding to the best final token (i.e. the one with best
  114. // weight best_weight, below).
  115. bool is_final = ReachedFinal();
  116. if (!is_final) {
  117. for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
  118. if (best_tok == NULL || *best_tok < *(e->val) )
  119. best_tok = e->val;
  120. } else {
  121. Weight best_weight = Weight::Zero();
  122. for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
  123. Weight fst_final = fst_.Final(PairToState(e->key)),
  124. lm_final = lm_diff_fst_->Final(PairToLmState(e->key)),
  125. final = Times(fst_final, lm_final);
  126. Weight this_weight = Times(e->val->weight_, final);
  127. if (this_weight != Weight::Zero() &&
  128. this_weight.Value() < best_weight.Value()) {
  129. best_weight = this_weight;
  130. best_final = final;
  131. best_tok = e->val;
  132. }
  133. }
  134. }
  135. if (best_tok == NULL) return false; // No output.
  136. std::vector<LatticeArc> arcs_reverse; // arcs in reverse order.
  137. for (Token *tok = best_tok; tok != NULL; tok = tok->prev_) {
  138. BaseFloat tot_cost = tok->weight_.Value() -
  139. (tok->prev_ ? tok->prev_->weight_.Value() : 0.0),
  140. graph_cost = tok->arc_.weight.Value(),
  141. ac_cost = tot_cost - graph_cost;
  142. LatticeArc l_arc(tok->arc_.ilabel,
  143. tok->arc_.olabel,
  144. LatticeWeight(graph_cost, ac_cost),
  145. tok->arc_.nextstate);
  146. arcs_reverse.push_back(l_arc);
  147. }
  148. KALDI_ASSERT(arcs_reverse.back().nextstate == fst_.Start());
  149. arcs_reverse.pop_back(); // that was a "fake" token... gives no info.
  150. StateId cur_state = fst_out->AddState();
  151. fst_out->SetStart(cur_state);
  152. for (ssize_t i = static_cast<ssize_t>(arcs_reverse.size())-1; i >= 0; i--) {
  153. LatticeArc arc = arcs_reverse[i];
  154. arc.nextstate = fst_out->AddState();
  155. fst_out->AddArc(cur_state, arc);
  156. cur_state = arc.nextstate;
  157. }
  158. if (is_final && use_final_probs) {
  159. fst_out->SetFinal(cur_state, LatticeWeight(best_final.Value(), 0.0));
  160. } else {
  161. fst_out->SetFinal(cur_state, LatticeWeight::One());
  162. }
  163. RemoveEpsLocal(fst_out);
  164. return true;
  165. }
  166. private:
  167. inline PairId ConstructPair(StateId fst_state, StateId lm_state) {
  168. return static_cast<PairId>(fst_state) + (static_cast<PairId>(lm_state) << 32);
  169. }
  170. static inline StateId PairToState(PairId state_pair) {
  171. return static_cast<StateId>(static_cast<uint32>(state_pair));
  172. }
  173. static inline StateId PairToLmState(PairId state_pair) {
  174. return static_cast<StateId>(static_cast<uint32>(state_pair >> 32));
  175. }
  176. class Token {
  177. public:
  178. Arc arc_; // contains only the graph part of the cost,
  179. // including the part in "fst" (== HCLG) plus lm_diff_fst.
  180. // We can work out the acoustic part from difference between
  181. // "weight_" and prev->weight_.
  182. Token *prev_;
  183. int32 ref_count_;
  184. Weight weight_; // weight up to current point.
  185. inline Token(const Arc &arc, Weight &ac_weight, Token *prev):
  186. arc_(arc), prev_(prev), ref_count_(1) {
  187. if (prev) {
  188. prev->ref_count_++;
  189. weight_ = Times(Times(prev->weight_, arc.weight), ac_weight);
  190. } else {
  191. weight_ = Times(arc.weight, ac_weight);
  192. }
  193. }
  194. inline Token(const Arc &arc, Token *prev):
  195. arc_(arc), prev_(prev), ref_count_(1) {
  196. if (prev) {
  197. prev->ref_count_++;
  198. weight_ = Times(prev->weight_, arc.weight);
  199. } else {
  200. weight_ = arc.weight;
  201. }
  202. }
  203. inline bool operator < (const Token &other) {
  204. return weight_.Value() > other.weight_.Value();
  205. // This makes sense for log + tropical semiring.
  206. }
  207. inline ~Token() {
  208. KALDI_ASSERT(ref_count_ == 1);
  209. if (prev_ != NULL) TokenDelete(prev_);
  210. }
  211. inline static void TokenDelete(Token *tok) {
  212. if (tok->ref_count_ == 1) {
  213. delete tok;
  214. } else {
  215. tok->ref_count_--;
  216. }
  217. }
  218. };
  219. typedef HashList<PairId, Token*>::Elem Elem;
  220. /// Gets the weight cutoff. Also counts the active tokens.
  221. BaseFloat GetCutoff(Elem *list_head, size_t *tok_count,
  222. BaseFloat *adaptive_beam, Elem **best_elem) {
  223. BaseFloat best_weight = 1.0e+10; // positive == high cost == bad.
  224. size_t count = 0;
  225. if (opts_.max_active == std::numeric_limits<int32>::max() &&
  226. opts_.min_active == 0) {
  227. for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
  228. BaseFloat w = static_cast<BaseFloat>(e->val->weight_.Value());
  229. if (w < best_weight) {
  230. best_weight = w;
  231. if (best_elem) *best_elem = e;
  232. }
  233. }
  234. if (tok_count != NULL) *tok_count = count;
  235. if (adaptive_beam != NULL) *adaptive_beam = opts_.beam;
  236. return best_weight + opts_.beam;
  237. } else {
  238. tmp_array_.clear();
  239. for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
  240. BaseFloat w = e->val->weight_.Value();
  241. tmp_array_.push_back(w);
  242. if (w < best_weight) {
  243. best_weight = w;
  244. if (best_elem) *best_elem = e;
  245. }
  246. }
  247. if (tok_count != NULL) *tok_count = count;
  248. BaseFloat beam_cutoff = best_weight + opts_.beam,
  249. min_active_cutoff = std::numeric_limits<BaseFloat>::infinity(),
  250. max_active_cutoff = std::numeric_limits<BaseFloat>::infinity();
  251. if (tmp_array_.size() > static_cast<size_t>(opts_.max_active)) {
  252. std::nth_element(tmp_array_.begin(),
  253. tmp_array_.begin() + opts_.max_active,
  254. tmp_array_.end());
  255. max_active_cutoff = tmp_array_[opts_.max_active];
  256. }
  257. if (tmp_array_.size() > static_cast<size_t>(opts_.min_active)) {
  258. if (opts_.min_active == 0) min_active_cutoff = best_weight;
  259. else {
  260. std::nth_element(tmp_array_.begin(),
  261. tmp_array_.begin() + opts_.min_active,
  262. tmp_array_.size() > static_cast<size_t>(opts_.max_active) ?
  263. tmp_array_.begin() + opts_.max_active :
  264. tmp_array_.end());
  265. min_active_cutoff = tmp_array_[opts_.min_active];
  266. }
  267. }
  268. if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam.
  269. if (adaptive_beam)
  270. *adaptive_beam = max_active_cutoff - best_weight + opts_.beam_delta;
  271. return max_active_cutoff;
  272. } else if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam.
  273. if (adaptive_beam)
  274. *adaptive_beam = min_active_cutoff - best_weight + opts_.beam_delta;
  275. return min_active_cutoff;
  276. } else {
  277. *adaptive_beam = opts_.beam;
  278. return beam_cutoff;
  279. }
  280. }
  281. }
  282. void PossiblyResizeHash(size_t num_toks) {
  283. size_t new_sz = static_cast<size_t>(static_cast<BaseFloat>(num_toks)
  284. * opts_.hash_ratio);
  285. if (new_sz > toks_.Size()) {
  286. toks_.SetSize(new_sz);
  287. }
  288. }
  289. inline StateId PropagateLm(StateId lm_state,
  290. Arc *arc) { // returns new LM state.
  291. if (arc->olabel == 0) {
  292. return lm_state; // no change in LM state if no word crossed.
  293. } else { // Propagate in the LM-diff FST.
  294. Arc lm_arc;
  295. bool ans = lm_diff_fst_->GetArc(lm_state, arc->olabel, &lm_arc);
  296. if (!ans) { // this case is unexpected for statistical LMs.
  297. if (!warned_noarc_) {
  298. warned_noarc_ = true;
  299. KALDI_WARN << "No arc available in LM (unlikely to be correct "
  300. "if a statistical language model); will not warn again";
  301. }
  302. arc->weight = Weight::Zero();
  303. return lm_state; // doesn't really matter what we return here; will
  304. // be pruned.
  305. } else {
  306. arc->weight = Times(arc->weight, lm_arc.weight);
  307. arc->olabel = lm_arc.olabel; // probably will be the same.
  308. return lm_arc.nextstate; // return the new LM state.
  309. }
  310. }
  311. }
  312. // ProcessEmitting returns the likelihood cutoff used.
  313. BaseFloat ProcessEmitting(DecodableInterface *decodable, int frame) {
  314. Elem *last_toks = toks_.Clear();
  315. size_t tok_cnt;
  316. BaseFloat adaptive_beam;
  317. Elem *best_elem = NULL;
  318. BaseFloat weight_cutoff = GetCutoff(last_toks, &tok_cnt,
  319. &adaptive_beam, &best_elem);
  320. PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough.
  321. // This is the cutoff we use after adding in the log-likes (i.e.
  322. // for the next frame). This is a bound on the cutoff we will use
  323. // on the next frame.
  324. BaseFloat next_weight_cutoff = 1.0e+10;
  325. // First process the best token to get a hopefully
  326. // reasonably tight bound on the next cutoff.
  327. if (best_elem) {
  328. PairId state_pair = best_elem->key;
  329. StateId state = PairToState(state_pair),
  330. lm_state = PairToLmState(state_pair);
  331. Token *tok = best_elem->val;
  332. for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
  333. !aiter.Done();
  334. aiter.Next()) {
  335. Arc arc = aiter.Value();
  336. if (arc.ilabel != 0) { // we'd propagate..
  337. PropagateLm(lm_state, &arc); // may affect "arc.weight".
  338. // We don't need the return value (the new LM state).
  339. BaseFloat ac_cost = - decodable->LogLikelihood(frame, arc.ilabel),
  340. new_weight = arc.weight.Value() + tok->weight_.Value() + ac_cost;
  341. if (new_weight + adaptive_beam < next_weight_cutoff)
  342. next_weight_cutoff = new_weight + adaptive_beam;
  343. }
  344. }
  345. }
  346. // the tokens are now owned here, in last_toks, and the hash is empty.
  347. // 'owned' is a complex thing here; the point is we need to call toks_.Delete(e)
  348. // on each elem 'e' to let toks_ know we're done with them.
  349. for (Elem *e = last_toks, *e_tail; e != NULL; e = e_tail) { // loop this way
  350. // because we delete "e" as we go.
  351. PairId state_pair = e->key;
  352. StateId state = PairToState(state_pair),
  353. lm_state = PairToLmState(state_pair);
  354. Token *tok = e->val;
  355. if (tok->weight_.Value() < weight_cutoff) { // not pruned.
  356. KALDI_ASSERT(state == tok->arc_.nextstate);
  357. for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
  358. !aiter.Done();
  359. aiter.Next()) {
  360. Arc arc = aiter.Value();
  361. if (arc.ilabel != 0) { // propagate.
  362. StateId next_lm_state = PropagateLm(lm_state, &arc);
  363. Weight ac_weight(-decodable->LogLikelihood(frame, arc.ilabel));
  364. BaseFloat new_weight = arc.weight.Value() + tok->weight_.Value()
  365. + ac_weight.Value();
  366. if (new_weight < next_weight_cutoff) { // not pruned..
  367. PairId next_pair = ConstructPair(arc.nextstate, next_lm_state);
  368. Token *new_tok = new Token(arc, ac_weight, tok);
  369. Elem *e_found = toks_.Insert(next_pair, new_tok);
  370. if (new_weight + adaptive_beam < next_weight_cutoff)
  371. next_weight_cutoff = new_weight + adaptive_beam;
  372. if (e_found->val != new_tok) {
  373. if (*(e_found->val) < *new_tok) {
  374. Token::TokenDelete(e_found->val);
  375. e_found->val = new_tok;
  376. } else {
  377. Token::TokenDelete(new_tok);
  378. }
  379. }
  380. }
  381. }
  382. }
  383. }
  384. e_tail = e->tail;
  385. Token::TokenDelete(e->val);
  386. toks_.Delete(e);
  387. }
  388. return next_weight_cutoff;
  389. }
  390. // TODO: first time we go through this, could avoid using the queue.
  391. void ProcessNonemitting(BaseFloat cutoff) {
  392. // Processes nonemitting arcs for one frame.
  393. KALDI_ASSERT(queue_.empty());
  394. for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
  395. queue_.push_back(e);
  396. while (!queue_.empty()) {
  397. const Elem *e = queue_.back();
  398. queue_.pop_back();
  399. PairId state_pair = e->key;
  400. Token *tok = e->val; // would segfault if state not
  401. // in toks_ but this can't happen.
  402. if (tok->weight_.Value() > cutoff) { // Don't bother processing successors.
  403. continue;
  404. }
  405. KALDI_ASSERT(tok != NULL);
  406. StateId state = PairToState(state_pair),
  407. lm_state = PairToLmState(state_pair);
  408. for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
  409. !aiter.Done();
  410. aiter.Next()) {
  411. const Arc &arc_ref = aiter.Value();
  412. if (arc_ref.ilabel == 0) { // propagate nonemitting only...
  413. Arc arc(arc_ref);
  414. StateId next_lm_state = PropagateLm(lm_state, &arc);
  415. PairId next_pair = ConstructPair(arc.nextstate, next_lm_state);
  416. Token *new_tok = new Token(arc, tok);
  417. if (new_tok->weight_.Value() > cutoff) { // prune
  418. Token::TokenDelete(new_tok);
  419. } else {
  420. Elem *e_found = toks_.Insert(next_pair, new_tok);
  421. if (e_found->val == new_tok) {
  422. queue_.push_back(e_found);
  423. } else {
  424. if ( *(e_found->val) < *new_tok ) {
  425. Token::TokenDelete(e_found->val);
  426. e_found->val = new_tok;
  427. queue_.push_back(e_found);
  428. } else {
  429. Token::TokenDelete(new_tok);
  430. }
  431. }
  432. }
  433. }
  434. }
  435. }
  436. }
  437. // HashList defined in ../util/hash-list.h. It actually allows us to maintain
  438. // more than one list (e.g. for current and previous frames), but only one of
  439. // them at a time can be indexed by PairId.
  440. HashList<PairId, Token*> toks_;
  441. const fst::Fst<fst::StdArc> &fst_;
  442. fst::DeterministicOnDemandFst<fst::StdArc> *lm_diff_fst_;
  443. BiglmFasterDecoderOptions opts_;
  444. bool warned_noarc_;
  445. std::vector<const Elem* > queue_; // temp variable used in ProcessNonemitting,
  446. std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
  447. // make it class member to avoid internal new/delete.
  448. // It might seem unclear why we call ClearToks(toks_.Clear()).
  449. // There are two separate cleanup tasks we need to do at when we start a new file.
  450. // one is to delete the Token objects in the list; the other is to delete
  451. // the Elem objects. toks_.Clear() just clears them from the hash and gives ownership
  452. // to the caller, who then has to call toks_.Delete(e) for each one. It was designed
  453. // this way for convenience in propagating tokens from one frame to the next.
  454. void ClearToks(Elem *list) {
  455. for (Elem *e = list, *e_tail; e != NULL; e = e_tail) {
  456. Token::TokenDelete(e->val);
  457. e_tail = e->tail;
  458. toks_.Delete(e);
  459. }
  460. }
  461. KALDI_DISALLOW_COPY_AND_ASSIGN(BiglmFasterDecoder);
  462. };
  463. } // end namespace kaldi.
  464. #endif