| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195 |
- // decoder/faster-decoder.h
- // Copyright 2009-2011 Microsoft Corporation
- // 2013 Johns Hopkins University (author: Daniel Povey)
- // See ../../COPYING for clarification regarding multiple authors
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
- // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
- // MERCHANTABLITY OR NON-INFRINGEMENT.
- // See the Apache 2 License for the specific language governing permissions and
- // limitations under the License.
- #ifndef KALDI_DECODER_FASTER_DECODER_H_
- #define KALDI_DECODER_FASTER_DECODER_H_
- #include "util/stl-utils.h"
- #include "itf/options-itf.h"
- #include "util/hash-list.h"
- #include "fst/fstlib.h"
- #include "itf/decodable-itf.h"
- #include "lat/kaldi-lattice.h" // for CompactLatticeArc
- namespace kaldi {
- struct FasterDecoderOptions {
- BaseFloat beam;
- int32 max_active;
- int32 min_active;
- BaseFloat beam_delta;
- BaseFloat hash_ratio;
- FasterDecoderOptions(): beam(16.0),
- max_active(std::numeric_limits<int32>::max()),
- min_active(20), // This decoder mostly used for
- // alignment, use small default.
- beam_delta(0.5),
- hash_ratio(2.0) { }
- void Register(OptionsItf *opts, bool full) { /// if "full", use obscure
- /// options too.
- /// Depends on program.
- opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate.");
- opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; "
- "more accurate");
- opts->Register("min-active", &min_active,
- "Decoder min active states (don't prune if #active less than this).");
- if (full) {
- opts->Register("beam-delta", &beam_delta,
- "Increment used in decoder [obscure setting]");
- opts->Register("hash-ratio", &hash_ratio,
- "Setting used in decoder to control hash behavior");
- }
- }
- };
- class FasterDecoder {
- public:
- typedef fst::StdArc Arc;
- typedef Arc::Label Label;
- typedef Arc::StateId StateId;
- typedef Arc::Weight Weight;
- FasterDecoder(const fst::Fst<fst::StdArc> &fst,
- const FasterDecoderOptions &config);
- void SetOptions(const FasterDecoderOptions &config) { config_ = config; }
- ~FasterDecoder() { ClearToks(toks_.Clear()); }
- void Decode(DecodableInterface *decodable);
- /// Returns true if a final state was active on the last frame.
- bool ReachedFinal() const;
- /// GetBestPath gets the decoding traceback. If "use_final_probs" is true
- /// AND we reached a final state, it limits itself to final states;
- /// otherwise it gets the most likely token not taking into account
- /// final-probs. Returns true if the output best path was not the empty
- /// FST (will only return false in unusual circumstances where
- /// no tokens survived).
- bool GetBestPath(fst::MutableFst<LatticeArc> *fst_out,
- bool use_final_probs = true);
- /// As a new alternative to Decode(), you can call InitDecoding
- /// and then (possibly multiple times) AdvanceDecoding().
- void InitDecoding();
- /// This will decode until there are no more frames ready in the decodable
- /// object, but if max_num_frames is >= 0 it will decode no more than
- /// that many frames.
- void AdvanceDecoding(DecodableInterface *decodable,
- int32 max_num_frames = -1);
- /// Returns the number of frames already decoded.
- int32 NumFramesDecoded() const { return num_frames_decoded_; }
- protected:
- class Token {
- public:
- Arc arc_; // contains only the graph part of the cost;
- // we can work out the acoustic part from difference between
- // "cost_" and prev->cost_.
- Token *prev_;
- int32 ref_count_;
- // if you are looking for weight_ here, it was removed and now we just have
- // cost_, which corresponds to ConvertToCost(weight_).
- double cost_;
- inline Token(const Arc &arc, BaseFloat ac_cost, Token *prev):
- arc_(arc), prev_(prev), ref_count_(1) {
- if (prev) {
- prev->ref_count_++;
- cost_ = prev->cost_ + arc.weight.Value() + ac_cost;
- } else {
- cost_ = arc.weight.Value() + ac_cost;
- }
- }
- inline Token(const Arc &arc, Token *prev):
- arc_(arc), prev_(prev), ref_count_(1) {
- if (prev) {
- prev->ref_count_++;
- cost_ = prev->cost_ + arc.weight.Value();
- } else {
- cost_ = arc.weight.Value();
- }
- }
- inline bool operator < (const Token &other) {
- return cost_ > other.cost_;
- }
- inline static void TokenDelete(Token *tok) {
- while (--tok->ref_count_ == 0) {
- Token *prev = tok->prev_;
- delete tok;
- if (prev == NULL) return;
- else tok = prev;
- }
- #ifdef KALDI_PARANOID
- KALDI_ASSERT(tok->ref_count_ > 0);
- #endif
- }
- };
- typedef HashList<StateId, Token*>::Elem Elem;
- /// Gets the weight cutoff. Also counts the active tokens.
- double GetCutoff(Elem *list_head, size_t *tok_count,
- BaseFloat *adaptive_beam, Elem **best_elem);
- void PossiblyResizeHash(size_t num_toks);
- // ProcessEmitting returns the likelihood cutoff used.
- // It decodes the frame num_frames_decoded_ of the decodable object
- // and then increments num_frames_decoded_
- double ProcessEmitting(DecodableInterface *decodable);
- // TODO: first time we go through this, could avoid using the queue.
- void ProcessNonemitting(double cutoff);
- // HashList defined in ../util/hash-list.h. It actually allows us to maintain
- // more than one list (e.g. for current and previous frames), but only one of
- // them at a time can be indexed by StateId.
- HashList<StateId, Token*> toks_;
- const fst::Fst<fst::StdArc> &fst_;
- FasterDecoderOptions config_;
- std::vector<const Elem* > queue_; // temp variable used in ProcessNonemitting,
- std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
- // make it class member to avoid internal new/delete.
- // Keep track of the number of frames decoded in the current file.
- int32 num_frames_decoded_;
- // It might seem unclear why we call ClearToks(toks_.Clear()).
- // There are two separate cleanup tasks we need to do at when we start a new file.
- // one is to delete the Token objects in the list; the other is to delete
- // the Elem objects. toks_.Clear() just clears them from the hash and gives ownership
- // to the caller, who then has to call toks_.Delete(e) for each one. It was designed
- // this way for convenience in propagating tokens from one frame to the next.
- void ClearToks(Elem *list);
- KALDI_DISALLOW_COPY_AND_ASSIGN(FasterDecoder);
- };
- } // end namespace kaldi.
- #endif
|