decoder-wrappers.cc 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665
  1. // decoder/decoder-wrappers.cc
  2. // Copyright 2014 Johns Hopkins University (author: Daniel Povey)
  3. // See ../../COPYING for clarification regarding multiple authors
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  12. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  13. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  14. // MERCHANTABLITY OR NON-INFRINGEMENT.
  15. // See the Apache 2 License for the specific language governing permissions and
  16. // limitations under the License.
  17. #include "decoder/decoder-wrappers.h"
  18. #include "decoder/faster-decoder.h"
  19. #include "decoder/lattice-faster-decoder.h"
  20. #include "decoder/grammar-fst.h"
  21. #include "lat/lattice-functions.h"
  22. namespace kaldi {
  23. DecodeUtteranceLatticeFasterClass::DecodeUtteranceLatticeFasterClass(
  24. LatticeFasterDecoder *decoder,
  25. DecodableInterface *decodable,
  26. const TransitionInformation &trans_model,
  27. const fst::SymbolTable *word_syms,
  28. const std::string &utt,
  29. BaseFloat acoustic_scale,
  30. bool determinize,
  31. bool allow_partial,
  32. Int32VectorWriter *alignments_writer,
  33. Int32VectorWriter *words_writer,
  34. CompactLatticeWriter *compact_lattice_writer,
  35. LatticeWriter *lattice_writer,
  36. double *like_sum, // on success, adds likelihood to this.
  37. int64 *frame_sum, // on success, adds #frames to this.
  38. int32 *num_done, // on success (including partial decode), increments this.
  39. int32 *num_err, // on failure, increments this.
  40. int32 *num_partial): // If partial decode (final-state not reached), increments this.
  41. decoder_(decoder), decodable_(decodable), trans_model_(&trans_model),
  42. word_syms_(word_syms), utt_(utt), acoustic_scale_(acoustic_scale),
  43. determinize_(determinize), allow_partial_(allow_partial),
  44. alignments_writer_(alignments_writer),
  45. words_writer_(words_writer),
  46. compact_lattice_writer_(compact_lattice_writer),
  47. lattice_writer_(lattice_writer),
  48. like_sum_(like_sum), frame_sum_(frame_sum),
  49. num_done_(num_done), num_err_(num_err),
  50. num_partial_(num_partial),
  51. computed_(false), success_(false), partial_(false),
  52. clat_(NULL), lat_(NULL) { }
  53. void DecodeUtteranceLatticeFasterClass::operator () () {
  54. // Decoding and lattice determinization happens here.
  55. computed_ = true; // Just means this function was called-- a check on the
  56. // calling code.
  57. success_ = true;
  58. using fst::VectorFst;
  59. if (!decoder_->Decode(decodable_)) {
  60. KALDI_WARN << "Failed to decode utterance with id " << utt_;
  61. success_ = false;
  62. }
  63. if (!decoder_->ReachedFinal()) {
  64. if (allow_partial_) {
  65. KALDI_WARN << "Outputting partial output for utterance " << utt_
  66. << " since no final-state reached\n";
  67. partial_ = true;
  68. } else {
  69. KALDI_WARN << "Not producing output for utterance " << utt_
  70. << " since no final-state reached and "
  71. << "--allow-partial=false.\n";
  72. success_ = false;
  73. }
  74. }
  75. if (!success_) return;
  76. // Get lattice, and do determinization if requested.
  77. lat_ = new Lattice;
  78. decoder_->GetRawLattice(lat_);
  79. if (lat_->NumStates() == 0)
  80. KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt_;
  81. fst::Connect(lat_);
  82. if (determinize_) {
  83. clat_ = new CompactLattice;
  84. if (!DeterminizeLatticePhonePrunedWrapper(
  85. *trans_model_,
  86. lat_,
  87. decoder_->GetOptions().lattice_beam,
  88. clat_,
  89. decoder_->GetOptions().det_opts))
  90. KALDI_WARN << "Determinization finished earlier than the beam for "
  91. << "utterance " << utt_;
  92. delete lat_;
  93. lat_ = NULL;
  94. // We'll write the lattice without acoustic scaling.
  95. if (acoustic_scale_ != 0.0)
  96. fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale_), clat_);
  97. } else {
  98. // We'll write the lattice without acoustic scaling.
  99. if (acoustic_scale_ != 0.0)
  100. fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale_), lat_);
  101. }
  102. }
  103. DecodeUtteranceLatticeFasterClass::~DecodeUtteranceLatticeFasterClass() {
  104. if (!computed_)
  105. KALDI_ERR << "Destructor called without operator (), error in calling code.";
  106. if (!success_) {
  107. if (num_err_ != NULL) (*num_err_)++;
  108. } else { // successful decode.
  109. // Getting the one-best output is lightweight enough that we can do it in
  110. // the destructor (easier than adding more variables to the class, and
  111. // will rarely slow down the main thread.)
  112. double likelihood;
  113. LatticeWeight weight;
  114. int32 num_frames;
  115. { // First do some stuff with word-level traceback...
  116. // This is basically for diagnostics.
  117. fst::VectorFst<LatticeArc> decoded;
  118. decoder_->GetBestPath(&decoded);
  119. if (decoded.NumStates() == 0) {
  120. // Shouldn't really reach this point as already checked success.
  121. KALDI_ERR << "Failed to get traceback for utterance " << utt_;
  122. }
  123. std::vector<int32> alignment;
  124. std::vector<int32> words;
  125. GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
  126. num_frames = alignment.size();
  127. if (words_writer_->IsOpen())
  128. words_writer_->Write(utt_, words);
  129. if (alignments_writer_->IsOpen())
  130. alignments_writer_->Write(utt_, alignment);
  131. if (word_syms_ != NULL) {
  132. std::cerr << utt_ << ' ';
  133. for (size_t i = 0; i < words.size(); i++) {
  134. std::string s = word_syms_->Find(words[i]);
  135. if (s == "")
  136. KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
  137. std::cerr << s << ' ';
  138. }
  139. std::cerr << '\n';
  140. }
  141. likelihood = -(weight.Value1() + weight.Value2());
  142. }
  143. // Ouptut the lattices.
  144. if (determinize_) { // CompactLattice output.
  145. KALDI_ASSERT(compact_lattice_writer_ != NULL && clat_ != NULL);
  146. if (clat_->NumStates() == 0) {
  147. KALDI_WARN << "Empty lattice for utterance " << utt_;
  148. } else {
  149. compact_lattice_writer_->Write(utt_, *clat_);
  150. }
  151. delete clat_;
  152. clat_ = NULL;
  153. } else {
  154. KALDI_ASSERT(lattice_writer_ != NULL && lat_ != NULL);
  155. if (lat_->NumStates() == 0) {
  156. KALDI_WARN << "Empty lattice for utterance " << utt_;
  157. } else {
  158. lattice_writer_->Write(utt_, *lat_);
  159. }
  160. delete lat_;
  161. lat_ = NULL;
  162. }
  163. // Print out logging information.
  164. KALDI_LOG << "Log-like per frame for utterance " << utt_ << " is "
  165. << (likelihood / num_frames) << " over "
  166. << num_frames << " frames.";
  167. KALDI_VLOG(2) << "Cost for utterance " << utt_ << " is "
  168. << weight.Value1() << " + " << weight.Value2();
  169. // Now output the various diagnostic variables.
  170. if (like_sum_ != NULL) *like_sum_ += likelihood;
  171. if (frame_sum_ != NULL) *frame_sum_ += num_frames;
  172. if (num_done_ != NULL) (*num_done_)++;
  173. if (partial_ && num_partial_ != NULL) (*num_partial_)++;
  174. }
  175. // We were given ownership of these two objects that were passed in in
  176. // the initializer.
  177. delete decoder_;
  178. delete decodable_;
  179. }
  180. template <typename FST>
  181. bool DecodeUtteranceLatticeIncremental(
  182. LatticeIncrementalDecoderTpl<FST> &decoder, // not const but is really an input.
  183. DecodableInterface &decodable, // not const but is really an input.
  184. const TransitionInformation &trans_model,
  185. const fst::SymbolTable *word_syms,
  186. std::string utt,
  187. double acoustic_scale,
  188. bool determinize,
  189. bool allow_partial,
  190. Int32VectorWriter *alignment_writer,
  191. Int32VectorWriter *words_writer,
  192. CompactLatticeWriter *compact_lattice_writer,
  193. LatticeWriter *lattice_writer,
  194. double *like_ptr) { // puts utterance's like in like_ptr on success.
  195. using fst::VectorFst;
  196. if (!decoder.Decode(&decodable)) {
  197. KALDI_WARN << "Failed to decode utterance with id " << utt;
  198. return false;
  199. }
  200. if (!decoder.ReachedFinal()) {
  201. if (allow_partial) {
  202. KALDI_WARN << "Outputting partial output for utterance " << utt
  203. << " since no final-state reached\n";
  204. } else {
  205. KALDI_WARN << "Not producing output for utterance " << utt
  206. << " since no final-state reached and "
  207. << "--allow-partial=false.\n";
  208. return false;
  209. }
  210. }
  211. // Get lattice
  212. CompactLattice clat = decoder.GetLattice(decoder.NumFramesDecoded(), true);
  213. if (clat.NumStates() == 0)
  214. KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt;
  215. double likelihood;
  216. LatticeWeight weight;
  217. int32 num_frames;
  218. { // First do some stuff with word-level traceback...
  219. CompactLattice decoded_clat;
  220. CompactLatticeShortestPath(clat, &decoded_clat);
  221. Lattice decoded;
  222. fst::ConvertLattice(decoded_clat, &decoded);
  223. if (decoded.Start() == fst::kNoStateId)
  224. // Shouldn't really reach this point as already checked success.
  225. KALDI_ERR << "Failed to get traceback for utterance " << utt;
  226. std::vector<int32> alignment;
  227. std::vector<int32> words;
  228. GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
  229. num_frames = alignment.size();
  230. KALDI_ASSERT(num_frames == decoder.NumFramesDecoded());
  231. if (words_writer->IsOpen())
  232. words_writer->Write(utt, words);
  233. if (alignment_writer->IsOpen())
  234. alignment_writer->Write(utt, alignment);
  235. if (word_syms != NULL) {
  236. std::cerr << utt << ' ';
  237. for (size_t i = 0; i < words.size(); i++) {
  238. std::string s = word_syms->Find(words[i]);
  239. if (s == "")
  240. KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
  241. std::cerr << s << ' ';
  242. }
  243. std::cerr << '\n';
  244. }
  245. likelihood = -(weight.Value1() + weight.Value2());
  246. }
  247. // We'll write the lattice without acoustic scaling.
  248. if (acoustic_scale != 0.0)
  249. fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat);
  250. Connect(&clat);
  251. compact_lattice_writer->Write(utt, clat);
  252. KALDI_LOG << "Log-like per frame for utterance " << utt << " is "
  253. << (likelihood / num_frames) << " over "
  254. << num_frames << " frames.";
  255. KALDI_VLOG(2) << "Cost for utterance " << utt << " is "
  256. << weight.Value1() << " + " << weight.Value2();
  257. *like_ptr = likelihood;
  258. return true;
  259. }
  260. // Takes care of output. Returns true on success.
  261. template <typename FST>
  262. bool DecodeUtteranceLatticeFaster(
  263. LatticeFasterDecoderTpl<FST> &decoder, // not const but is really an input.
  264. DecodableInterface &decodable, // not const but is really an input.
  265. const TransitionInformation &trans_model,
  266. const fst::SymbolTable *word_syms,
  267. std::string utt,
  268. double acoustic_scale,
  269. bool determinize,
  270. bool allow_partial,
  271. Int32VectorWriter *alignment_writer,
  272. Int32VectorWriter *words_writer,
  273. CompactLatticeWriter *compact_lattice_writer,
  274. LatticeWriter *lattice_writer,
  275. double *like_ptr) { // puts utterance's like in like_ptr on success.
  276. using fst::VectorFst;
  277. if (!decoder.Decode(&decodable)) {
  278. KALDI_WARN << "Failed to decode utterance with id " << utt;
  279. return false;
  280. }
  281. if (!decoder.ReachedFinal()) {
  282. if (allow_partial) {
  283. KALDI_WARN << "Outputting partial output for utterance " << utt
  284. << " since no final-state reached\n";
  285. } else {
  286. KALDI_WARN << "Not producing output for utterance " << utt
  287. << " since no final-state reached and "
  288. << "--allow-partial=false.\n";
  289. return false;
  290. }
  291. }
  292. double likelihood;
  293. LatticeWeight weight;
  294. int32 num_frames;
  295. { // First do some stuff with word-level traceback...
  296. VectorFst<LatticeArc> decoded;
  297. if (!decoder.GetBestPath(&decoded))
  298. // Shouldn't really reach this point as already checked success.
  299. KALDI_ERR << "Failed to get traceback for utterance " << utt;
  300. std::vector<int32> alignment;
  301. std::vector<int32> words;
  302. GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
  303. num_frames = alignment.size();
  304. if (words_writer->IsOpen())
  305. words_writer->Write(utt, words);
  306. if (alignment_writer->IsOpen())
  307. alignment_writer->Write(utt, alignment);
  308. if (word_syms != NULL) {
  309. std::cerr << utt << ' ';
  310. for (size_t i = 0; i < words.size(); i++) {
  311. std::string s = word_syms->Find(words[i]);
  312. if (s == "")
  313. KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
  314. std::cerr << s << ' ';
  315. }
  316. std::cerr << '\n';
  317. }
  318. likelihood = -(weight.Value1() + weight.Value2());
  319. }
  320. // Get lattice, and do determinization if requested.
  321. Lattice lat;
  322. decoder.GetRawLattice(&lat);
  323. if (lat.NumStates() == 0)
  324. KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt;
  325. fst::Connect(&lat);
  326. if (determinize) {
  327. CompactLattice clat;
  328. if (!DeterminizeLatticePhonePrunedWrapper(
  329. trans_model,
  330. &lat,
  331. decoder.GetOptions().lattice_beam,
  332. &clat,
  333. decoder.GetOptions().det_opts))
  334. KALDI_WARN << "Determinization finished earlier than the beam for "
  335. << "utterance " << utt;
  336. // We'll write the lattice without acoustic scaling.
  337. if (acoustic_scale != 0.0)
  338. fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat);
  339. compact_lattice_writer->Write(utt, clat);
  340. } else {
  341. // We'll write the lattice without acoustic scaling.
  342. if (acoustic_scale != 0.0)
  343. fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &lat);
  344. lattice_writer->Write(utt, lat);
  345. }
  346. KALDI_LOG << "Log-like per frame for utterance " << utt << " is "
  347. << (likelihood / num_frames) << " over "
  348. << num_frames << " frames.";
  349. KALDI_VLOG(2) << "Cost for utterance " << utt << " is "
  350. << weight.Value1() << " + " << weight.Value2();
  351. *like_ptr = likelihood;
  352. return true;
  353. }
  354. // Instantiate the template above for the two required FST types.
  355. template bool DecodeUtteranceLatticeIncremental(
  356. LatticeIncrementalDecoderTpl<fst::Fst<fst::StdArc> > &decoder,
  357. DecodableInterface &decodable,
  358. const TransitionInformation &trans_model,
  359. const fst::SymbolTable *word_syms,
  360. std::string utt,
  361. double acoustic_scale,
  362. bool determinize,
  363. bool allow_partial,
  364. Int32VectorWriter *alignment_writer,
  365. Int32VectorWriter *words_writer,
  366. CompactLatticeWriter *compact_lattice_writer,
  367. LatticeWriter *lattice_writer,
  368. double *like_ptr);
  369. template bool DecodeUtteranceLatticeIncremental(
  370. LatticeIncrementalDecoderTpl<fst::ConstGrammarFst > &decoder,
  371. DecodableInterface &decodable,
  372. const TransitionInformation &trans_model,
  373. const fst::SymbolTable *word_syms,
  374. std::string utt,
  375. double acoustic_scale,
  376. bool determinize,
  377. bool allow_partial,
  378. Int32VectorWriter *alignment_writer,
  379. Int32VectorWriter *words_writer,
  380. CompactLatticeWriter *compact_lattice_writer,
  381. LatticeWriter *lattice_writer,
  382. double *like_ptr);
  383. template bool DecodeUtteranceLatticeFaster(
  384. LatticeFasterDecoderTpl<fst::Fst<fst::StdArc> > &decoder,
  385. DecodableInterface &decodable,
  386. const TransitionInformation &trans_model,
  387. const fst::SymbolTable *word_syms,
  388. std::string utt,
  389. double acoustic_scale,
  390. bool determinize,
  391. bool allow_partial,
  392. Int32VectorWriter *alignment_writer,
  393. Int32VectorWriter *words_writer,
  394. CompactLatticeWriter *compact_lattice_writer,
  395. LatticeWriter *lattice_writer,
  396. double *like_ptr);
  397. template bool DecodeUtteranceLatticeFaster(
  398. LatticeFasterDecoderTpl<fst::ConstGrammarFst > &decoder,
  399. DecodableInterface &decodable,
  400. const TransitionInformation &trans_model,
  401. const fst::SymbolTable *word_syms,
  402. std::string utt,
  403. double acoustic_scale,
  404. bool determinize,
  405. bool allow_partial,
  406. Int32VectorWriter *alignment_writer,
  407. Int32VectorWriter *words_writer,
  408. CompactLatticeWriter *compact_lattice_writer,
  409. LatticeWriter *lattice_writer,
  410. double *like_ptr);
  411. // Takes care of output. Returns true on success.
  412. bool DecodeUtteranceLatticeSimple(
  413. LatticeSimpleDecoder &decoder, // not const but is really an input.
  414. DecodableInterface &decodable, // not const but is really an input.
  415. const TransitionInformation &trans_model,
  416. const fst::SymbolTable *word_syms,
  417. std::string utt,
  418. double acoustic_scale,
  419. bool determinize,
  420. bool allow_partial,
  421. Int32VectorWriter *alignment_writer,
  422. Int32VectorWriter *words_writer,
  423. CompactLatticeWriter *compact_lattice_writer,
  424. LatticeWriter *lattice_writer,
  425. double *like_ptr) { // puts utterance's like in like_ptr on success.
  426. using fst::VectorFst;
  427. if (!decoder.Decode(&decodable)) {
  428. KALDI_WARN << "Failed to decode utterance with id " << utt;
  429. return false;
  430. }
  431. if (!decoder.ReachedFinal()) {
  432. if (allow_partial) {
  433. KALDI_WARN << "Outputting partial output for utterance " << utt
  434. << " since no final-state reached\n";
  435. } else {
  436. KALDI_WARN << "Not producing output for utterance " << utt
  437. << " since no final-state reached and "
  438. << "--allow-partial=false.\n";
  439. return false;
  440. }
  441. }
  442. double likelihood;
  443. LatticeWeight weight = LatticeWeight::Zero();
  444. int32 num_frames;
  445. { // First do some stuff with word-level traceback...
  446. VectorFst<LatticeArc> decoded;
  447. if (!decoder.GetBestPath(&decoded))
  448. // Shouldn't really reach this point as already checked success.
  449. KALDI_ERR << "Failed to get traceback for utterance " << utt;
  450. std::vector<int32> alignment;
  451. std::vector<int32> words;
  452. GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
  453. num_frames = alignment.size();
  454. if (words_writer->IsOpen())
  455. words_writer->Write(utt, words);
  456. if (alignment_writer->IsOpen())
  457. alignment_writer->Write(utt, alignment);
  458. if (word_syms != NULL) {
  459. std::cerr << utt << ' ';
  460. for (size_t i = 0; i < words.size(); i++) {
  461. std::string s = word_syms->Find(words[i]);
  462. if (s == "")
  463. KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
  464. std::cerr << s << ' ';
  465. }
  466. std::cerr << '\n';
  467. }
  468. likelihood = -(weight.Value1() + weight.Value2());
  469. }
  470. // Get lattice, and do determinization if requested.
  471. Lattice lat;
  472. if (!decoder.GetRawLattice(&lat))
  473. KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt;
  474. fst::Connect(&lat);
  475. if (determinize) {
  476. CompactLattice clat;
  477. if (!DeterminizeLatticePhonePrunedWrapper(
  478. trans_model,
  479. &lat,
  480. decoder.GetOptions().lattice_beam,
  481. &clat,
  482. decoder.GetOptions().det_opts))
  483. KALDI_WARN << "Determinization finished earlier than the beam for "
  484. << "utterance " << utt;
  485. // We'll write the lattice without acoustic scaling.
  486. if (acoustic_scale != 0.0)
  487. fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat);
  488. compact_lattice_writer->Write(utt, clat);
  489. } else {
  490. // We'll write the lattice without acoustic scaling.
  491. if (acoustic_scale != 0.0)
  492. fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &lat);
  493. lattice_writer->Write(utt, lat);
  494. }
  495. KALDI_LOG << "Log-like per frame for utterance " << utt << " is "
  496. << (likelihood / num_frames) << " over "
  497. << num_frames << " frames.";
  498. KALDI_VLOG(2) << "Cost for utterance " << utt << " is "
  499. << weight.Value1() << " + " << weight.Value2();
  500. *like_ptr = likelihood;
  501. return true;
  502. }
  503. // see comment in header.
  504. void ModifyGraphForCarefulAlignment(
  505. fst::VectorFst<fst::StdArc> *fst) {
  506. typedef fst::StdArc Arc;
  507. typedef Arc::StateId StateId;
  508. typedef Arc::Label Label;
  509. typedef Arc::Weight Weight;
  510. StateId num_states = fst->NumStates();
  511. if (num_states == 0) {
  512. KALDI_WARN << "Empty FST input.";
  513. return;
  514. }
  515. Weight zero = Weight::Zero();
  516. // fst_rhs will be the right hand side of the Concat operation.
  517. fst::VectorFst<fst::StdArc> fst_rhs(*fst);
  518. // first remove the final-probs from fst_rhs.
  519. for (StateId state = 0; state < num_states; state++)
  520. fst_rhs.SetFinal(state, zero);
  521. StateId pre_initial = fst_rhs.AddState();
  522. Arc to_initial(0, 0, Weight::One(), fst_rhs.Start());
  523. fst_rhs.AddArc(pre_initial, to_initial);
  524. fst_rhs.SetStart(pre_initial);
  525. // make the pre_initial state final with probability one;
  526. // this is equivalent to keeping the final-probs of the first
  527. // FST when we do concat (otherwise they would get deleted).
  528. fst_rhs.SetFinal(pre_initial, Weight::One());
  529. fst::VectorFst<fst::StdArc> fst_concat;
  530. fst::Concat(fst, fst_rhs);
  531. }
  532. void AlignUtteranceWrapper(
  533. const AlignConfig &config,
  534. const std::string &utt,
  535. BaseFloat acoustic_scale, // affects scores written to scores_writer, if
  536. // present
  537. fst::VectorFst<fst::StdArc> *fst, // non-const in case config.careful ==
  538. // true.
  539. DecodableInterface *decodable, // not const but is really an input.
  540. Int32VectorWriter *alignment_writer,
  541. BaseFloatWriter *scores_writer,
  542. int32 *num_done,
  543. int32 *num_error,
  544. int32 *num_retried,
  545. double *tot_like,
  546. int64 *frame_count,
  547. BaseFloatVectorWriter *per_frame_acwt_writer) {
  548. if ((config.retry_beam != 0 && config.retry_beam <= config.beam) ||
  549. config.beam <= 0.0) {
  550. KALDI_ERR << "Beams do not make sense: beam " << config.beam
  551. << ", retry-beam " << config.retry_beam;
  552. }
  553. if (fst->Start() == fst::kNoStateId) {
  554. KALDI_WARN << "Empty decoding graph for " << utt;
  555. if (num_error != NULL) (*num_error)++;
  556. return;
  557. }
  558. if (config.careful)
  559. ModifyGraphForCarefulAlignment(fst);
  560. FasterDecoderOptions decode_opts;
  561. decode_opts.beam = config.beam;
  562. FasterDecoder decoder(*fst, decode_opts);
  563. decoder.Decode(decodable);
  564. bool ans = decoder.ReachedFinal(); // consider only final states.
  565. if (!ans && config.retry_beam != 0.0) {
  566. if (num_retried != NULL) (*num_retried)++;
  567. KALDI_WARN << "Retrying utterance " << utt << " with beam "
  568. << config.retry_beam;
  569. decode_opts.beam = config.retry_beam;
  570. decoder.SetOptions(decode_opts);
  571. decoder.Decode(decodable);
  572. ans = decoder.ReachedFinal();
  573. }
  574. if (!ans) { // Still did not reach final state.
  575. KALDI_WARN << "Did not successfully decode file " << utt << ", len = "
  576. << decodable->NumFramesReady();
  577. if (num_error != NULL) (*num_error)++;
  578. return;
  579. }
  580. fst::VectorFst<LatticeArc> decoded; // linear FST.
  581. decoder.GetBestPath(&decoded);
  582. if (decoded.NumStates() == 0) {
  583. KALDI_WARN << "Error getting best path from decoder (likely a bug)";
  584. if (num_error != NULL) (*num_error)++;
  585. return;
  586. }
  587. std::vector<int32> alignment;
  588. std::vector<int32> words;
  589. LatticeWeight weight;
  590. GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
  591. BaseFloat like = -(weight.Value1()+weight.Value2()) / acoustic_scale;
  592. if (num_done != NULL) (*num_done)++;
  593. if (tot_like != NULL) (*tot_like) += like;
  594. if (frame_count != NULL) (*frame_count) += decodable->NumFramesReady();
  595. if (alignment_writer != NULL && alignment_writer->IsOpen())
  596. alignment_writer->Write(utt, alignment);
  597. if (scores_writer != NULL && scores_writer->IsOpen())
  598. scores_writer->Write(utt, -(weight.Value1()+weight.Value2()));
  599. Vector<BaseFloat> per_frame_loglikes;
  600. if (per_frame_acwt_writer != NULL && per_frame_acwt_writer->IsOpen()) {
  601. GetPerFrameAcousticCosts(decoded, &per_frame_loglikes);
  602. per_frame_loglikes.Scale(-1 / acoustic_scale);
  603. per_frame_acwt_writer->Write(utt, per_frame_loglikes);
  604. }
  605. }
  606. } // end namespace kaldi.