From 5503ba176e8e5c50ad43c3f9ae1353fe50f54e68 Mon Sep 17 00:00:00 2001
From: LvHang
Date: Sun, 24 Feb 2019 22:49:04 -0500
Subject: [PATCH 01/29] Combine ProcessEmitting() and ProcessNonemitting() small fix

---
 src/decoder/Makefile                          |    2 +-
 src/decoder/lattice-faster-decoder-combine.cc | 1096 +++++++++++++++++
 src/decoder/lattice-faster-decoder-combine.h  |  539 ++++++++
 3 files changed, 1636 insertions(+), 1 deletion(-)
 create mode 100644 src/decoder/lattice-faster-decoder-combine.cc
 create mode 100644 src/decoder/lattice-faster-decoder-combine.h

diff --git a/src/decoder/Makefile b/src/decoder/Makefile
index 020fe358fe9..53d469f4860 100644
--- a/src/decoder/Makefile
+++ b/src/decoder/Makefile
@@ -7,7 +7,7 @@ TESTFILES =
 
 OBJFILES = training-graph-compiler.o lattice-simple-decoder.o lattice-faster-decoder.o \
            lattice-faster-online-decoder.o simple-decoder.o faster-decoder.o \
-           decoder-wrappers.o grammar-fst.o decodable-matrix.o
+           decoder-wrappers.o grammar-fst.o decodable-matrix.o lattice-faster-decoder-combine.o
 
 LIBNAME = kaldi-decoder

diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc
new file mode 100644
index 00000000000..8cb6e59564d
--- /dev/null
+++ b/src/decoder/lattice-faster-decoder-combine.cc
@@ -0,0 +1,1096 @@
+// decoder/lattice-faster-decoder-combine.cc
+
+// Copyright 2009-2012  Microsoft Corporation  Mirko Hannemann
+//           2013-2018  Johns Hopkins University (Author: Daniel Povey)
+//                2014  Guoguo Chen
+//                2018  Zhehuai Chen
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include "decoder/lattice-faster-decoder-combine.h"
+#include "lat/lattice-functions.h"
+
+namespace kaldi {
+
+// instantiate this class once for each thing you have to decode.
+template <typename FST, typename Token>
+LatticeFasterDecoderCombineTpl<FST, Token>::LatticeFasterDecoderCombineTpl(
+    const FST &fst,
+    const LatticeFasterDecoderCombineConfig &config):
+    fst_(&fst), delete_fst_(false), config_(config), num_toks_(0) {
+  config.Check();
+}
+
+
+template <typename FST, typename Token>
+LatticeFasterDecoderCombineTpl<FST, Token>::LatticeFasterDecoderCombineTpl(
+    const LatticeFasterDecoderCombineConfig &config, FST *fst):
+    fst_(fst), delete_fst_(true), config_(config), num_toks_(0) {
+  config.Check();
+}
+
+
+template <typename FST, typename Token>
+LatticeFasterDecoderCombineTpl<FST, Token>::~LatticeFasterDecoderCombineTpl() {
+  ClearActiveTokens();
+  if (delete_fst_) delete fst_;
+}
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::InitDecoding() {
+  // clean up from last time:
+  cur_toks_.clear();
+  next_toks_.clear();
+  cost_offsets_.clear();
+  ClearActiveTokens();
+
+  warned_ = false;
+  num_toks_ = 0;
+  decoding_finalized_ = false;
+  final_costs_.clear();
+  StateId start_state = fst_->Start();
+  KALDI_ASSERT(start_state != fst::kNoStateId);
+  active_toks_.resize(1);
+  Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL);
+  active_toks_[0].toks = start_tok;
+  cur_toks_[start_state] = start_tok;  // initialize current tokens map
+  num_toks_++;
+
+  recover_ = false;
+  frame_processed_.resize(1);
+  frame_processed_[0] = false;
+}
+
+// Returns true if any kind of traceback is available (not necessarily from
+// a final state).  It should only very rarely return false; this indicates
+// an unusual search error.
+template <typename FST, typename Token>
+bool LatticeFasterDecoderCombineTpl<FST, Token>::Decode(DecodableInterface *decodable) {
+  InitDecoding();
+
+  // We use 1-based indexing for frames in this decoder (if you view it in
+  // terms of features), but note that the decodable object uses zero-based
+  // numbering, which we have to correct for when we call it.
+
+  while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) {
+    if (NumFramesDecoded() % config_.prune_interval == 0)
+      PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
+    ProcessForFrame(decodable);
+  }
+  // Process non-emitting arcs for the last frame.
+  ProcessNonemitting(false);
+  frame_processed_[active_toks_.size() - 1] = true;  // the last frame is processed.
+
+  FinalizeDecoding();
+
+  // Returns true if we have any kind of traceback available (not necessarily
+  // to the end state; query ReachedFinal() for that).
+  return !active_toks_.empty() && active_toks_.back().toks != NULL;
+}
+
+
+// Outputs an FST corresponding to the single best path through the lattice.
+template <typename FST, typename Token>
+bool LatticeFasterDecoderCombineTpl<FST, Token>::GetBestPath(
+    Lattice *olat,
+    bool use_final_probs) {
+  Lattice raw_lat;
+  GetRawLattice(&raw_lat, use_final_probs);
+  ShortestPath(raw_lat, olat);
+  return (olat->NumStates() != 0);
+}
+
+
+// Outputs an FST corresponding to the raw, state-level lattice
+template <typename FST, typename Token>
+bool LatticeFasterDecoderCombineTpl<FST, Token>::GetRawLattice(
+    Lattice *ofst,
+    bool use_final_probs) {
+  typedef LatticeArc Arc;
+  typedef Arc::StateId StateId;
+  typedef Arc::Weight Weight;
+  typedef Arc::Label Label;
+  // Process the non-emitting arcs for the unfinished last frame.
+  if (!frame_processed_[active_toks_.size() - 1]) {
+    ProcessNonemitting(true);
+  }
+  // Note: you can't use the old interface (Decode()) if you want to
+  // get the lattice with use_final_probs = false.  You'd have to do
+  // InitDecoding() and then AdvanceDecoding().
+  if (decoding_finalized_ && !use_final_probs)
+    KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
+              << "GetRawLattice() with use_final_probs == false";
+
+  unordered_map<Token*, BaseFloat> final_costs_local;
+
+  const unordered_map<Token*, BaseFloat> &final_costs =
+      (decoding_finalized_ ? final_costs_ : final_costs_local);
+  if (!decoding_finalized_ && use_final_probs)
+    ComputeFinalCosts(&final_costs_local, NULL, NULL);
+
+  ofst->DeleteStates();
+  // num-frames plus one (since frames are one-based, and we have
+  // an extra frame for the start-state).
+  int32 num_frames = active_toks_.size() - 1;
+  KALDI_ASSERT(num_frames > 0);
+  const int32 bucket_count = num_toks_/2 + 3;
+  unordered_map<Token*, StateId> tok_map(bucket_count);
+  // First create all states.
+  std::vector<Token*> token_list;
+  for (int32 f = 0; f <= num_frames; f++) {
+    if (active_toks_[f].toks == NULL) {
+      KALDI_WARN << "GetRawLattice: no tokens active on frame " << f
+                 << ": not producing lattice.\n";
+      return false;
+    }
+    TopSortTokens(active_toks_[f].toks, &token_list);
+    for (size_t i = 0; i < token_list.size(); i++)
+      if (token_list[i] != NULL)
+        tok_map[token_list[i]] = ofst->AddState();
+  }
+  // The next statement sets the start state of the output FST.  Because we
+  // topologically sorted the tokens, state zero must be the start-state.
+  ofst->SetStart(0);
+
+  KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:"
+                << tok_map.bucket_count() << " load:" << tok_map.load_factor()
+                << " max:" << tok_map.max_load_factor();
+  // Now create all arcs.
+  for (int32 f = 0; f <= num_frames; f++) {
+    for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) {
+      StateId cur_state = tok_map[tok];
+      for (ForwardLinkT *l = tok->links;
+           l != NULL;
+           l = l->next) {
+        typename unordered_map<Token*, StateId>::const_iterator
+            iter = tok_map.find(l->next_tok);
+        KALDI_ASSERT(iter != tok_map.end());
+        StateId nextstate = iter->second;
+        BaseFloat cost_offset = 0.0;
+        if (l->ilabel != 0) {  // emitting..
+          KALDI_ASSERT(f >= 0 && f < cost_offsets_.size());
+          cost_offset = cost_offsets_[f];
+        }
+        Arc arc(l->ilabel, l->olabel,
+                Weight(l->graph_cost, l->acoustic_cost - cost_offset),
+                nextstate);
+        ofst->AddArc(cur_state, arc);
+      }
+      if (f == num_frames) {
+        if (use_final_probs && !final_costs.empty()) {
+          typename unordered_map<Token*, BaseFloat>::const_iterator
+              iter = final_costs.find(tok);
+          if (iter != final_costs.end())
+            ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
+        } else {
+          ofst->SetFinal(cur_state, LatticeWeight::One());
+        }
+      }
+    }
+  }
+  return (ofst->NumStates() > 0);
+}
+
+
+// This function is now deprecated, since now we do determinization from outside
+// the LatticeFasterDecoder class.  Outputs an FST corresponding to the
+// lattice-determinized lattice (one path per word sequence).
+template <typename FST, typename Token>
+bool LatticeFasterDecoderCombineTpl<FST, Token>::GetLattice(
+    CompactLattice *ofst,
+    bool use_final_probs) {
+  Lattice raw_fst;
+  GetRawLattice(&raw_fst, use_final_probs);
+  Invert(&raw_fst);  // make it so word labels are on the input.
+  // (in phase where we get backward-costs).
+  fst::ILabelCompare<LatticeArc> ilabel_comp;
+  ArcSort(&raw_fst, ilabel_comp);  // sort on ilabel; makes
+  // lattice-determinization more efficient.
+
+  fst::DeterminizeLatticePrunedOptions lat_opts;
+  lat_opts.max_mem = config_.det_opts.max_mem;
+
+  DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts);
+  raw_fst.DeleteStates();  // Free memory-- raw_fst no longer needed.
+  Connect(ofst);  // Remove unreachable states... there might be
+  // a small number of these, in some cases.
+  // Note: if something went wrong and the raw lattice was empty,
+  // we should still get to this point in the code without warnings or failures.
+  return (ofst->NumStates() != 0);
+}
+
+/*
+  A note on the definition of extra_cost.
+
+  extra_cost is used in pruning tokens, to save memory.
+
+  Define the 'forward cost' of a token as zero for any token on the frame
+  we're currently decoding; and for other frames, as the shortest-path cost
+  between that token and a token on the frame we're currently decoding.
+  (by "currently decoding" I mean the most recently processed frame).
+
+  Then define the extra_cost of a token (always >= 0) as the forward-cost of
+  the token minus the smallest forward-cost of any token on the same frame.
+
+  We can use the extra_cost to accurately prune away tokens that we know will
+  never appear in the lattice.  If the extra_cost is greater than the desired
+  lattice beam, the token would provably never appear in the lattice, so we can
+  prune away the token.
+
+  The advantage of storing the extra_cost rather than the forward-cost, is that
+  it is less costly to keep the extra_cost up-to-date when we process new frames.
+  When we process a new frame, *all* the previous frames' forward-costs would change;
+  but in general the extra_cost will change only for a finite number of frames.
+  (Actually we don't update all the extra_costs every time we update a frame; we
+  only do it every 'config_.prune_interval' frames).
+ */
+
+// FindOrAddToken either locates a token in hash of toks_,
+// or if necessary inserts a new, empty token (i.e. with no forward links)
+// for the current frame.  [note: it's inserted if necessary into hash toks_
+// and also into the singly linked list of tokens active on this frame
+// (whose head is at active_toks_[frame]).
+template <typename FST, typename Token>
+inline Token* LatticeFasterDecoderCombineTpl<FST, Token>::FindOrAddToken(
+    StateId state, int32 frame, BaseFloat tot_cost, Token *backpointer,
+    StateIdToTokenMap *token_map, bool *changed) {
+  // Returns the Token pointer.  Sets "changed" (if non-NULL) to true
+  // if the token was newly created or the cost changed.
+  KALDI_ASSERT(frame < active_toks_.size());
+  Token *&toks = active_toks_[frame].toks;
+  typename StateIdToTokenMap::iterator e_found = token_map->find(state);
+  if (e_found == token_map->end()) {  // no such token presently.
+    const BaseFloat extra_cost = 0.0;
+    // tokens on the currently final frame have zero extra_cost
+    // as any of them could end up
+    // on the winning path.
+    Token *new_tok = new Token (tot_cost, extra_cost, NULL, toks, backpointer);
+    // NULL: no forward links yet
+    toks = new_tok;
+    num_toks_++;
+    // insert into the map
+    (*token_map)[state] = new_tok;
+    if (changed) *changed = true;
+    return new_tok;
+  } else {
+    Token *tok = e_found->second;  // There is an existing Token for this state.
+    if (tok->tot_cost > tot_cost) {  // replace old token
+      tok->tot_cost = tot_cost;
+      // SetBackpointer() just does tok->backpointer = backpointer in
+      // the case where Token == BackpointerToken, else nothing.
+      tok->SetBackpointer(backpointer);
+      // we don't allocate a new token, the old stays linked in active_toks_
+      // we only replace the tot_cost
+      // in the current frame, there are no forward links (and no extra_cost)
+      // only in ProcessNonemitting we have to delete forward links
+      // in case we visit a state for the second time
+      // those forward links, that lead to this replaced token before:
+      // they remain and will hopefully be pruned later (PruneForwardLinks...)
+      if (changed) *changed = true;
+    } else {
+      if (changed) *changed = false;
+    }
+    return tok;
+  }
+}
+
+// prunes outgoing links for all tokens in active_toks_[frame]
+// it's called by PruneActiveTokens
+// all links, that have link_extra_cost > lattice_beam are pruned
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::PruneForwardLinks(
+    int32 frame_plus_one, bool *extra_costs_changed,
+    bool *links_pruned, BaseFloat delta) {
+  // delta is the amount by which the extra_costs must change
+  // If delta is larger,  we'll tend to go back less far
+  //    toward the beginning of the file.
+  // extra_costs_changed is set to true if extra_cost was changed for any token
+  // links_pruned is set to true if any link in any token was pruned
+
+  *extra_costs_changed = false;
+  *links_pruned = false;
+  KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
+  if (active_toks_[frame_plus_one].toks == NULL) {  // empty list; should not happen.
+    if (!warned_) {
+      KALDI_WARN << "No tokens alive [doing pruning].. warning first "
+          "time only for each utterance\n";
+      warned_ = true;
+    }
+  }
+
+  // We have to iterate until there is no more change, because the links
+  // are not guaranteed to be in topological order.
+  bool changed = true;  // difference new minus old extra cost >= delta ?
+  while (changed) {
+    changed = false;
+    for (Token *tok = active_toks_[frame_plus_one].toks;
+         tok != NULL; tok = tok->next) {
+      ForwardLinkT *link, *prev_link = NULL;
+      // will recompute tok_extra_cost for tok.
+      BaseFloat tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
+      // tok_extra_cost is the best (min) of link_extra_cost of outgoing links
+      for (link = tok->links; link != NULL; ) {
+        // See if we need to excise this link...
+        Token *next_tok = link->next_tok;
+        BaseFloat link_extra_cost = next_tok->extra_cost +
+            ((tok->tot_cost + link->acoustic_cost + link->graph_cost)
+             - next_tok->tot_cost);  // difference in brackets is >= 0
+        // link_extra_cost is the difference in score between the best paths
+        // through link source state and through link destination state
+        KALDI_ASSERT(link_extra_cost == link_extra_cost);  // check for NaN
+        if (link_extra_cost > config_.lattice_beam) {  // excise link
+          ForwardLinkT *next_link = link->next;
+          if (prev_link != NULL) prev_link->next = next_link;
+          else tok->links = next_link;
+          delete link;
+          link = next_link;  // advance link but leave prev_link the same.
+          *links_pruned = true;
+        } else {  // keep the link and update the tok_extra_cost if needed.
+          if (link_extra_cost < 0.0) {  // this is just a precaution.
+            if (link_extra_cost < -0.01)
+              KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
+            link_extra_cost = 0.0;
+          }
+          if (link_extra_cost < tok_extra_cost)
+            tok_extra_cost = link_extra_cost;
+          prev_link = link;  // move to next link
+          link = link->next;
+        }
+      }  // for all outgoing links
+      if (fabs(tok_extra_cost - tok->extra_cost) > delta)
+        changed = true;  // difference new minus old is bigger than delta
+      tok->extra_cost = tok_extra_cost;
+      // will be +infinity or <= lattice_beam_.
+      // infinity indicates, that no forward link survived pruning
+    }  // for all Token on active_toks_[frame]
+    if (changed) *extra_costs_changed = true;
+
+    // Note: it's theoretically possible that aggressive compiler
+    // optimizations could cause an infinite loop here for small delta and
+    // high-dynamic-range scores.
+  }  // while changed
+}
+
+// PruneForwardLinksFinal is a version of PruneForwardLinks that we call
+// on the final frame.  If there are final tokens active, it uses
+// the final-probs for pruning, otherwise it treats all tokens as final.
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::PruneForwardLinksFinal() {
+  KALDI_ASSERT(!active_toks_.empty());
+  int32 frame_plus_one = active_toks_.size() - 1;
+
+  if (active_toks_[frame_plus_one].toks == NULL)  // empty list; should not happen.
+    KALDI_WARN << "No tokens alive at end of file";
+
+  typedef typename unordered_map<Token*, BaseFloat>::const_iterator IterType;
+  ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_);
+  decoding_finalized_ = true;
+
+  // Now go through tokens on this frame, pruning forward links...  may have to
+  // iterate a few times until there is no more change, because the list is not
+  // in topological order.  This is a modified version of the code in
+  // PruneForwardLinks, but here we also take account of the final-probs.
+  bool changed = true;
+  BaseFloat delta = 1.0e-05;
+  while (changed) {
+    changed = false;
+    for (Token *tok = active_toks_[frame_plus_one].toks;
+         tok != NULL; tok = tok->next) {
+      ForwardLinkT *link, *prev_link = NULL;
+      // will recompute tok_extra_cost.  It has a term in it that corresponds
+      // to the "final-prob", so instead of initializing tok_extra_cost to infinity
+      // below we set it to the difference between the (score+final_prob) of this token,
+      // and the best such (score+final_prob).
+      BaseFloat final_cost;
+      if (final_costs_.empty()) {
+        final_cost = 0.0;
+      } else {
+        IterType iter = final_costs_.find(tok);
+        if (iter != final_costs_.end())
+          final_cost = iter->second;
+        else
+          final_cost = std::numeric_limits<BaseFloat>::infinity();
+      }
+      BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_;
+      // tok_extra_cost will be a "min" over either directly being final, or
+      // being indirectly final through other links, and the loop below may
+      // decrease its value:
+      for (link = tok->links; link != NULL; ) {
+        // See if we need to excise this link...
+        Token *next_tok = link->next_tok;
+        BaseFloat link_extra_cost = next_tok->extra_cost +
+            ((tok->tot_cost + link->acoustic_cost + link->graph_cost)
+             - next_tok->tot_cost);
+        if (link_extra_cost > config_.lattice_beam) {  // excise link
+          ForwardLinkT *next_link = link->next;
+          if (prev_link != NULL) prev_link->next = next_link;
+          else tok->links = next_link;
+          delete link;
+          link = next_link;  // advance link but leave prev_link the same.
+        } else {  // keep the link and update the tok_extra_cost if needed.
+          if (link_extra_cost < 0.0) {  // this is just a precaution.
+            if (link_extra_cost < -0.01)
+              KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
+            link_extra_cost = 0.0;
+          }
+          if (link_extra_cost < tok_extra_cost)
+            tok_extra_cost = link_extra_cost;
+          prev_link = link;
+          link = link->next;
+        }
+      }
+      // prune away tokens worse than lattice_beam above best path.  This step
+      // was not necessary in the non-final case because then, this case
+      // showed up as having no forward links.  Here, the tok_extra_cost has
+      // an extra component relating to the final-prob.
+      if (tok_extra_cost > config_.lattice_beam)
+        tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
+      // to be pruned in PruneTokensForFrame
+
+      if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta))
+        changed = true;
+      tok->extra_cost = tok_extra_cost;  // will be +infinity or <= lattice_beam_.
+    }
+  }  // while changed
+}
+
+template <typename FST, typename Token>
+BaseFloat LatticeFasterDecoderCombineTpl<FST, Token>::FinalRelativeCost() const {
+  if (!decoding_finalized_) {
+    BaseFloat relative_cost;
+    ComputeFinalCosts(NULL, &relative_cost, NULL);
+    return relative_cost;
+  } else {
+    // we're not allowed to call that function if FinalizeDecoding() has
+    // been called; return a cached value.
+    return final_relative_cost_;
+  }
+}
+
+
+// Prune away any tokens on this frame that have no forward links.
+// [we don't do this in PruneForwardLinks because it would give us
+// a problem with dangling pointers].
+// It's called by PruneActiveTokens if any forward links have been pruned
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::PruneTokensForFrame(
+    int32 frame_plus_one) {
+  KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
+  Token *&toks = active_toks_[frame_plus_one].toks;
+  if (toks == NULL)
+    KALDI_WARN << "No tokens alive [doing pruning]";
+  Token *tok, *next_tok, *prev_tok = NULL;
+  for (tok = toks; tok != NULL; tok = next_tok) {
+    next_tok = tok->next;
+    if (tok->extra_cost == std::numeric_limits<BaseFloat>::infinity()) {
+      // token is unreachable from end of graph; (no forward links survived)
+      // excise tok from list and delete tok.
+      if (prev_tok != NULL) prev_tok->next = tok->next;
+      else toks = tok->next;
+      delete tok;
+      num_toks_--;
+    } else {  // fetch next Token
+      prev_tok = tok;
+    }
+  }
+}
+
+// Go backwards through still-alive tokens, pruning them, starting not from
+// the current frame (where we want to keep all tokens) but from the frame before
+// that.  We go backwards through the frames and stop when we reach a point
+// where the delta-costs are not changing (and the delta controls when we consider
+// a cost to have "not changed").
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::PruneActiveTokens(
+    BaseFloat delta) {
+  int32 cur_frame_plus_one = NumFramesDecoded();
+  int32 num_toks_begin = num_toks_;
+  // The index "f" below represents a "frame plus one", i.e. you'd have to subtract
+  // one to get the corresponding index for the decodable object.
+  for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) {
+    // Reason why we need to prune forward links in this situation:
+    // (1) we have never pruned them (new TokenList)
+    // (2) we have not yet pruned the forward links to the next f,
+    //     after any of those tokens have changed their extra_cost.
+    if (active_toks_[f].must_prune_forward_links) {
+      bool extra_costs_changed = false, links_pruned = false;
+      PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta);
+      if (extra_costs_changed && f > 0)  // any token has changed extra_cost
+        active_toks_[f-1].must_prune_forward_links = true;
+      if (links_pruned)  // any link was pruned
+        active_toks_[f].must_prune_tokens = true;
+      active_toks_[f].must_prune_forward_links = false;  // job done
+    }
+    if (f+1 < cur_frame_plus_one &&  // except for last f (no forward links)
+        active_toks_[f+1].must_prune_tokens) {
+      PruneTokensForFrame(f+1);
+      active_toks_[f+1].must_prune_tokens = false;
+    }
+  }
+  KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin
+                << " to " << num_toks_;
+}
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ComputeFinalCosts(
+    unordered_map<Token*, BaseFloat> *final_costs,
+    BaseFloat *final_relative_cost,
+    BaseFloat *final_best_cost) const {
+  KALDI_ASSERT(!decoding_finalized_);
+  if (final_costs != NULL)
+    final_costs->clear();
+  BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity();
+  BaseFloat best_cost = infinity,
+      best_cost_with_final = infinity;
+
+  // The final tokens are recorded in unordered_map "next_toks_".
+  for (IterType iter = next_toks_.begin(); iter != next_toks_.end(); iter++) {
+    StateId state = iter->first;
+    Token *tok = iter->second;
+    BaseFloat final_cost = fst_->Final(state).Value();
+    BaseFloat cost = tok->tot_cost,
+        cost_with_final = cost + final_cost;
+    best_cost = std::min(cost, best_cost);
+    best_cost_with_final = std::min(cost_with_final, best_cost_with_final);
+    if (final_costs != NULL && final_cost != infinity)
+      (*final_costs)[tok] = final_cost;
+  }
+  if (final_relative_cost != NULL) {
+    if (best_cost == infinity && best_cost_with_final == infinity) {
+      // Likely this will only happen if there are no tokens surviving.
+      // This seems the least bad way to handle it.
+      *final_relative_cost = infinity;
+    } else {
+      *final_relative_cost = best_cost_with_final - best_cost;
+    }
+  }
+  if (final_best_cost != NULL) {
+    if (best_cost_with_final != infinity) {  // final-state exists.
+      *final_best_cost = best_cost_with_final;
+    } else {  // no final-state exists.
+      *final_best_cost = best_cost;
+    }
+  }
+}
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::AdvanceDecoding(
+    DecodableInterface *decodable,
+    int32 max_num_frames) {
+  if (std::is_same<FST, fst::Fst<fst::StdArc> >::value) {
+    // if the type 'FST' is the FST base-class, then see if the FST type of fst_
+    // is actually VectorFst or ConstFst.  If so, call the AdvanceDecoding()
+    // function after casting *this to the more specific type.
+    if (fst_->Type() == "const") {
+      LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>, Token> *this_cast =
+          reinterpret_cast<LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>, Token>* >(this);
+      this_cast->AdvanceDecoding(decodable, max_num_frames);
+      return;
+    } else if (fst_->Type() == "vector") {
+      LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>, Token> *this_cast =
+          reinterpret_cast<LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>, Token>* >(this);
+      this_cast->AdvanceDecoding(decodable, max_num_frames);
+      return;
+    }
+  }
+
+
+  KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ &&
+               "You must call InitDecoding() before AdvanceDecoding");
+  int32 num_frames_ready = decodable->NumFramesReady();
+  // num_frames_ready must be >= num_frames_decoded, or else
+  // the number of frames ready must have decreased (which doesn't
+  // make sense) or the decodable object changed between calls
+  // (which isn't allowed).
+  KALDI_ASSERT(num_frames_ready >= NumFramesDecoded());
+  int32 target_frames_decoded = num_frames_ready;
+  if (max_num_frames >= 0)
+    target_frames_decoded = std::min(target_frames_decoded,
+                                     NumFramesDecoded() + max_num_frames);
+  while (NumFramesDecoded() < target_frames_decoded) {
+    if (NumFramesDecoded() % config_.prune_interval == 0) {
+      PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
+    }
+    ProcessForFrame(decodable);
+  }
+  ProcessNonemitting(false);
+}
+
+// FinalizeDecoding() is a version of PruneActiveTokens that we call
+// (optionally) on the final frame.  Takes into account the final-prob of
+// tokens.  This function used to be called PruneActiveTokensFinal().
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::FinalizeDecoding() {
+  int32 final_frame_plus_one = NumFramesDecoded();
+  int32 num_toks_begin = num_toks_;
+  // PruneForwardLinksFinal() prunes final frame (with final-probs), and
+  // sets decoding_finalized_.
+  PruneForwardLinksFinal();
+  for (int32 f = final_frame_plus_one - 1; f >= 0; f--) {
+    bool b1, b2;  // values not used.
+    BaseFloat dontcare = 0.0;  // delta of zero means we must always update
+    PruneForwardLinks(f, &b1, &b2, dontcare);
+    PruneTokensForFrame(f + 1);
+  }
+  PruneTokensForFrame(0);
+  KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin
+                << " to " << num_toks_;
+}
+
+/// Gets the weight cutoff.
+template <typename FST, typename Token>
+BaseFloat LatticeFasterDecoderCombineTpl<FST, Token>::GetCutoff(
+    const StateIdToTokenMap &toks, BaseFloat *adaptive_beam,
+    StateId *best_elem_id, Token **best_elem) {
+  // positive == high cost == bad.
+  // best_weight is the minimum value.
+  BaseFloat best_weight = std::numeric_limits<BaseFloat>::infinity();
+  if (config_.max_active == std::numeric_limits<int32>::max() &&
+      config_.min_active == 0) {
+    for (IterType iter = toks.begin(); iter != toks.end(); iter++) {
+      BaseFloat w = static_cast<BaseFloat>(iter->second->tot_cost);
+      if (w < best_weight) {
+        best_weight = w;
+        if (best_elem) {
+          *best_elem_id = iter->first;
+          *best_elem = iter->second;
+        }
+      }
+    }
+    if (adaptive_beam != NULL) *adaptive_beam = config_.beam;
+    return best_weight + config_.beam;
+  } else {
+    tmp_array_.clear();
+    for (IterType iter = toks.begin(); iter != toks.end(); iter++) {
+      BaseFloat w = static_cast<BaseFloat>(iter->second->tot_cost);
+      tmp_array_.push_back(w);
+      if (w < best_weight) {
+        best_weight = w;
+        if (best_elem) {
+          *best_elem_id = iter->first;
+          *best_elem = iter->second;
+        }
+      }
+    }
+
+    BaseFloat beam_cutoff = best_weight + config_.beam,
+        min_active_cutoff = std::numeric_limits<BaseFloat>::infinity(),
+        max_active_cutoff = std::numeric_limits<BaseFloat>::infinity();
+
+    KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded()
+                  << " is " << tmp_array_.size();
+
+    if (tmp_array_.size() > static_cast<size_t>(config_.max_active)) {
+      std::nth_element(tmp_array_.begin(),
+                       tmp_array_.begin() + config_.max_active,
+                       tmp_array_.end());
+      max_active_cutoff = tmp_array_[config_.max_active];
+    }
+    if (max_active_cutoff < beam_cutoff) {  // max_active is tighter than beam.
+      if (adaptive_beam)
+        *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta;
+      return max_active_cutoff;
+    }
+    if (tmp_array_.size() > static_cast<size_t>(config_.min_active)) {
+      if (config_.min_active == 0) min_active_cutoff = best_weight;
+      else {
+        std::nth_element(tmp_array_.begin(),
+                         tmp_array_.begin() + config_.min_active,
+                         tmp_array_.size() > static_cast<size_t>(config_.max_active) ?
+                         tmp_array_.begin() + config_.max_active : tmp_array_.end());
+        min_active_cutoff = tmp_array_[config_.min_active];
+      }
+    }
+    if (min_active_cutoff > beam_cutoff) {  // min_active is looser than beam.
+      if (adaptive_beam)
+        *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta;
+      return min_active_cutoff;
+    } else {
+      if (adaptive_beam) *adaptive_beam = config_.beam;
+      return beam_cutoff;
+    }
+  }
+}
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
+    DecodableInterface *decodable) {
+  KALDI_ASSERT(active_toks_.size() > 0);
+  int32 frame = active_toks_.size() - 1;  // frame is the frame-index
+                                          // (zero-based) used to get likelihoods
+                                          // from the decodable object.
+  if (!recover_ && frame_processed_[frame]) {
+    KALDI_ERR << "Maybe the whole utterance has been processed, you shouldn't"
+              << " call ProcessForFrame() again.";
+  } else if (recover_ && !frame_processed_[frame]) {
+    KALDI_ERR << "Should not happen.";
+  }
+
+  // Maybe GetRawLattice() was called in the middle of the utterance, so
+  // active_toks_[frame] has changed.  Recover it.
+  // Notice: as new tokens are added to the head of the TokenList, tok->next
+  // will not be affected.
+  if (recover_) {
+    frame_processed_[frame] = false;
+    for (Token* tok = active_toks_[frame].toks; tok != NULL;) {
+      if (recover_map_.find(tok) != recover_map_.end()) {
+        DeleteForwardLinks(tok);
+        tok->tot_cost = recover_map_[tok];
+        tok->in_current_queue = false;
+        tok = tok->next;
+      } else {
+        DeleteForwardLinks(tok);
+        Token *next_tok = tok->next;
+        delete tok;
+        num_toks_--;
+        tok = next_tok;
+      }
+    }
+    recover_ = false;
+  }
+
+  active_toks_.resize(active_toks_.size() + 1);
+  frame_processed_.resize(frame_processed_.size() + 1);
+
+  cur_toks_.clear();
+  cur_toks_.swap(next_toks_);
+  if (cur_toks_.empty()) {
+    if (!warned_) {
+      KALDI_WARN << "Error, no surviving tokens on frame " << frame;
+      warned_ = true;
+    }
+  }
+
+  BaseFloat adaptive_beam;
+  Token *best_tok = NULL;
+  StateId best_tok_state_id;
+  // "cur_cutoff" is used to constrain the epsilon propagation in the current
+  // frame.  It will not be updated.
+  BaseFloat cur_cutoff = GetCutoff(cur_toks_, &adaptive_beam,
+                                   &best_tok_state_id, &best_tok);
+  KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is "
+                << adaptive_beam;
+
+
+  // pruning "online" before having seen all tokens
+
+  // "next_cutoff" is used to decide whether a new token in the next frame
+  // should be handled or not.  It will be updated along with the further
+  // processing.
+  BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
+  // "cost_offset" contains the acoustic log-likelihoods on the current frame
+  // in order to keep everything in a nice dynamic range.  Reduces roundoff
+  // errors.
+  BaseFloat cost_offset = 0.0;
+
+  // First process the best token to get a hopefully
+  // reasonably tight bound on the next cutoff.  The only
+  // products of the next block are "next_cutoff" and "cost_offset".
+  // Notice: unlike the traditional version, this "best_tok" is chosen from the
+  // tokens created by emitting arcs.  Normally, the best token of a frame
+  // comes from an epsilon (non-emitting) arc, so this best token gives a
+  // looser bound.  Use it to estimate a bound on the next cutoff; the
+  // "next_cutoff" will be updated in further processing.
+  if (best_tok) {
+    cost_offset = - best_tok->tot_cost;
+    for (fst::ArcIterator<FST> aiter(*fst_, best_tok_state_id);
+         !aiter.Done();
+         aiter.Next()) {
+      const Arc &arc = aiter.Value();
+      if (arc.ilabel != 0) {  // propagate..
+        // ac_cost + graph_cost
+        BaseFloat new_weight = arc.weight.Value() + cost_offset -
+            decodable->LogLikelihood(frame, arc.ilabel) + best_tok->tot_cost;
+        if (new_weight + adaptive_beam < next_cutoff)
+          next_cutoff = new_weight + adaptive_beam;
+      }
+    }
+  }
+
+  // Store the offset on the acoustic likelihoods that we're applying.
+  // Could just do cost_offsets_.push_back(cost_offset), but we
+  // do it this way as it's more robust to future code changes.
+  cost_offsets_.resize(frame + 1, 0.0);
+  cost_offsets_[frame] = cost_offset;
+
+  // Build a queue which contains the tokens created by emitting arcs from
+  // the previous frame.
+  std::vector<StateId> cur_queue;
+  for (IterType iter = cur_toks_.begin(); iter != cur_toks_.end(); iter++) {
+    cur_queue.push_back(iter->first);
+    iter->second->in_current_queue = true;
+  }
+
+  // Iterate over the "cur_queue" to process the non-emitting and emitting
+  // arcs in the FST.
+  while (!cur_queue.empty()) {
+    StateId state = cur_queue.back();
+    cur_queue.pop_back();
+
+    KALDI_ASSERT(cur_toks_.find(state) != cur_toks_.end());
+    Token *tok = cur_toks_[state];
+    BaseFloat cur_cost = tok->tot_cost;
+    if (cur_cost > cur_cutoff)  // Don't bother processing successors.
+      continue;
+    // If "tok" has any existing forward links, delete them,
+    // because we're about to regenerate them.  This is a kind
+    // of non-optimality (remember, this is the simple decoder),
+    DeleteForwardLinks(tok);  // necessary when re-visiting
+    tok->links = NULL;
+    for (fst::ArcIterator<FST> aiter(*fst_, state);
+         !aiter.Done();
+         aiter.Next()) {
+      const Arc &arc = aiter.Value();
+      bool changed;
+      if (arc.ilabel == 0) {  // propagate nonemitting
+        BaseFloat graph_cost = arc.weight.Value();
+        BaseFloat tot_cost = cur_cost + graph_cost;
+        if (tot_cost < cur_cutoff) {
+          Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost,
+                                          tok, &cur_toks_, &changed);
+
+          // Add ForwardLink from tok to new_tok.  Put it on the head of
+          // tok->link list
+          tok->links = new ForwardLinkT(new_tok, 0, arc.olabel,
+                                        graph_cost, 0, tok->links);
+
+          // "changed" tells us whether the new token has a different
+          // cost from before, or is new.
+          if (changed && !new_tok->in_current_queue) {
+            cur_queue.push_back(arc.nextstate);
+            new_tok->in_current_queue = true;
+          }
+        }
+      } else {  // propagate emitting
+        BaseFloat graph_cost = arc.weight.Value(),
+            ac_cost = cost_offset - decodable->LogLikelihood(frame, arc.ilabel),
+            cur_cost = tok->tot_cost,
+            tot_cost = cur_cost + ac_cost + graph_cost;
+        if (tot_cost > next_cutoff) continue;
+        else if (tot_cost + adaptive_beam < next_cutoff)
+          next_cutoff = tot_cost + adaptive_beam;  // a tighter boundary for emitting
+
+        // no change flag is needed
+        Token *next_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
+                                         tok, &next_toks_, NULL);
+        // Add ForwardLink from tok to next_tok.  Put it on the head of
+        // tok->link list
+        tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel,
+                                      graph_cost, ac_cost, tok->links);
+      }
+    }  // for all arcs
+    tok->in_current_queue = false;  // out of queue
+  }  // end of while loop
+  frame_processed_[frame] = true;
+  frame_processed_[frame + 1] = false;
+}
+
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting(bool recover) {
+  if (recover) {  // Build the elements which are used to recover.
+    // Set the flag to true so that we will recover the "next_toks_" map in
+    // ProcessForFrame() first.
+    recover_ = true;
+    for (IterType iter = next_toks_.begin(); iter != next_toks_.end(); iter++) {
+      recover_map_[iter->second] = iter->second->tot_cost;
+    }
+  }
+
+  StateIdToTokenMap tmp_toks_(next_toks_);
+  int32 frame = active_toks_.size() - 1;
+  // Build the queue to process non-emitting arcs.
+  std::vector<StateId> cur_queue;
+  for (IterType iter = tmp_toks_.begin(); iter != tmp_toks_.end(); iter++) {
+    if (fst_->NumInputEpsilons(iter->first) != 0) {
+      cur_queue.push_back(iter->first);
+      iter->second->in_current_queue = true;
+    }
+  }
+
+  // "cur_cutoff" is used to constrain the epsilon propagation in the current
+  // frame.  It will not be updated.
+  BaseFloat cur_cutoff = GetCutoff(tmp_toks_, NULL, NULL, NULL);
+
+  while (!cur_queue.empty()) {
+    StateId state = cur_queue.back();
+    cur_queue.pop_back();
+
+    KALDI_ASSERT(tmp_toks_.find(state) != tmp_toks_.end());
+    Token *tok = tmp_toks_[state];
+    BaseFloat cur_cost = tok->tot_cost;
+    if (cur_cost > cur_cutoff)  // Don't bother processing successors.
+      continue;
+    // If "tok" has any existing forward links, delete them,
+    // because we're about to regenerate them.  This is a kind
+    // of non-optimality (remember, this is the simple decoder),
+    DeleteForwardLinks(tok);  // necessary when re-visiting
+    tok->links = NULL;
+    for (fst::ArcIterator<FST> aiter(*fst_, state);
+         !aiter.Done();
+         aiter.Next()) {
+      const Arc &arc = aiter.Value();
+      bool changed;
+      if (arc.ilabel == 0) {  // propagate nonemitting
+        BaseFloat graph_cost = arc.weight.Value();
+        BaseFloat tot_cost = cur_cost + graph_cost;
+        if (tot_cost < cur_cutoff) {
+          Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost,
+                                          tok, &tmp_toks_, &changed);
+
+          // Add ForwardLink from tok to new_tok.  Put it on the head of
+          // tok->link list
+          tok->links = new ForwardLinkT(new_tok, 0, arc.olabel,
+                                        graph_cost, 0, tok->links);
+
+          // "changed" tells us whether the new token has a different
+          // cost from before, or is new.
+          if (changed && !new_tok->in_current_queue) {
+            cur_queue.push_back(arc.nextstate);
+            new_tok->in_current_queue = true;
+          }
+        }
+      }
+    }  // end of for loop
+    tok->in_current_queue = false;
+  }  // end of while loop
+  frame_processed_[active_toks_.size() - 1] = true;  // in case someone calls
+                                                     // GetRawLattice() twice
+                                                     // in a row.
+}
+
+
+
+// static inline
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::DeleteForwardLinks(Token *tok) {
+  ForwardLinkT *l = tok->links, *m;
+  while (l != NULL) {
+    m = l->next;
+    delete l;
+    l = m;
+  }
+  tok->links = NULL;
+}
+
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ClearActiveTokens() {
+  // a cleanup routine, at utt end/begin
+  for (size_t i = 0; i < active_toks_.size(); i++) {
+    // Delete all tokens alive on this frame, and any forward
+    // links they may have.
+    for (Token *tok = active_toks_[i].toks; tok != NULL; ) {
+      DeleteForwardLinks(tok);
+      Token *next_tok = tok->next;
+      delete tok;
+      num_toks_--;
+      tok = next_tok;
+    }
+  }
+  active_toks_.clear();
+  KALDI_ASSERT(num_toks_ == 0);
+}
+
+// static
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::TopSortTokens(
+    Token *tok_list, std::vector<Token*> *topsorted_list) {
+  unordered_map<Token*, int32> token2pos;
+  typedef typename unordered_map<Token*, int32>::iterator IterType;
+  int32 num_toks = 0;
+  for (Token *tok = tok_list; tok != NULL; tok = tok->next)
+    num_toks++;
+  int32 cur_pos = 0;
+  // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0.
+  // This is likely to be closer to topological order than
+  // if we had given them ascending order, because of the way
+  // new tokens are put at the front of the list.
+  for (Token *tok = tok_list; tok != NULL; tok = tok->next)
+    token2pos[tok] = num_toks - ++cur_pos;
+
+  unordered_set<Token*> reprocess;
+
+  for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) {
+    Token *tok = iter->first;
+    int32 pos = iter->second;
+    for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
+      if (link->ilabel == 0) {
+        // We only need to consider epsilon links, since non-epsilon links
+        // transition between frames and this function only needs to sort a list
+        // of tokens from a single frame.
+        IterType following_iter = token2pos.find(link->next_tok);
+        if (following_iter != token2pos.end()) {  // another token on this frame,
+                                                  // so must consider it.
+          int32 next_pos = following_iter->second;
+          if (next_pos < pos) {  // reassign the position of the next Token.
+            following_iter->second = cur_pos++;
+            reprocess.insert(link->next_tok);
+          }
+        }
+      }
+    }
+    // In case we had previously assigned this token to be reprocessed, we can
+    // erase it from that set because it's "happy now" (we just processed it).
+    reprocess.erase(tok);
+  }
+
+  size_t max_loop = 1000000, loop_count;  // max_loop is to detect epsilon cycles.
+  for (loop_count = 0;
+       !reprocess.empty() && loop_count < max_loop; ++loop_count) {
+    std::vector<Token*> reprocess_vec;
+    for (typename unordered_set<Token*>::iterator iter = reprocess.begin();
+         iter != reprocess.end(); ++iter)
+      reprocess_vec.push_back(*iter);
+    reprocess.clear();
+    for (typename std::vector<Token*>::iterator iter = reprocess_vec.begin();
+         iter != reprocess_vec.end(); ++iter) {
+      Token *tok = *iter;
+      int32 pos = token2pos[tok];
+      // Repeat the processing we did above (for comments, see above).
+      for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
+        if (link->ilabel == 0) {
+          IterType following_iter = token2pos.find(link->next_tok);
+          if (following_iter != token2pos.end()) {
+            int32 next_pos = following_iter->second;
+            if (next_pos < pos) {
+              following_iter->second = cur_pos++;
+              reprocess.insert(link->next_tok);
+            }
+          }
+        }
+      }
+    }
+  }
+  KALDI_ASSERT(loop_count < max_loop && "Epsilon loops exist in your decoding "
+               "graph (this is not allowed!)");
+
+  topsorted_list->clear();
+  topsorted_list->resize(cur_pos, NULL);  // create a list with NULLs in between.
+  for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter)
+    (*topsorted_list)[iter->second] = iter->first;
+}
+
+// Instantiate the template for the combination of token types and FST types
+// that we'll need.
+template class LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc>, decoder::StdToken>;
+template class LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>, decoder::StdToken>;
+template class LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>, decoder::StdToken>;
+template class LatticeFasterDecoderCombineTpl<fst::GrammarFst, decoder::StdToken>;
+
+template class LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc>, decoder::BackpointerToken>;
+template class LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>, decoder::BackpointerToken>;
+template class LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>, decoder::BackpointerToken>;
+template class LatticeFasterDecoderCombineTpl<fst::GrammarFst, decoder::BackpointerToken>;
+
+
+} // end namespace kaldi.
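The decoder above is driven either by the one-shot Decode() or by the InitDecoding()/AdvanceDecoding()/FinalizeDecoding() sequence. The following is a minimal batch-usage sketch, not part of the patch; the `decode_fst` and `decodable` objects are assumed to have been built elsewhere (e.g. an HCLG graph and a nnet3 decodable):

```cpp
// Sketch: batch decoding with LatticeFasterDecoderCombineTpl.
// Assumes `decode_fst` (a const fst::Fst<fst::StdArc>&) and `decodable`
// (an object implementing kaldi::DecodableInterface) already exist.
kaldi::LatticeFasterDecoderCombineConfig config;
config.beam = 13.0;          // decoding beam
config.lattice_beam = 6.0;   // lattice-generation beam
kaldi::LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc>,
                                      kaldi::decoder::StdToken>
    decoder(decode_fst, config);  // this constructor does not take ownership

if (decoder.Decode(&decodable)) {  // true if any traceback survived
  kaldi::Lattice best_path;
  decoder.GetBestPath(&best_path);  // GetRawLattice() + ShortestPath()
  // ... read off the word sequence, e.g. with fst::GetLinearSymbolSequence().
}
```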
diff --git a/src/decoder/lattice-faster-decoder-combine.h b/src/decoder/lattice-faster-decoder-combine.h new file mode 100644 index 00000000000..abfdb5c21fc --- /dev/null +++ b/src/decoder/lattice-faster-decoder-combine.h @@ -0,0 +1,539 @@ +// decoder/lattice-faster-decoder.h + +// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann; +// 2013-2014 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ +#define KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ + + +#include "util/stl-utils.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/grammar-fst.h" + +namespace kaldi { + +struct LatticeFasterDecoderCombineConfig { + BaseFloat beam; + int32 max_active; + int32 min_active; + BaseFloat lattice_beam; + int32 prune_interval; + bool determinize_lattice; // not inspected by this class... used in + // command-line program. + BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat hash_ratio; + BaseFloat prune_scale; // Note: we don't make this configurable on the command line, + // it's not a very important parameter. It affects the + // algorithm that prunes the tokens as we go. + // Most of the options inside det_opts are not actually queried by the + // LatticeFasterDecoder class itself, but by the code that calls it, for + // example in the function DecodeUtteranceLatticeFaster. + fst::DeterminizeLatticePhonePrunedOptions det_opts; + + LatticeFasterDecoderCombineConfig(): beam(16.0), + max_active(std::numeric_limits::max()), + min_active(200), + lattice_beam(10.0), + prune_interval(25), + determinize_lattice(true), + beam_delta(0.5), + hash_ratio(2.0), + prune_scale(0.1) { } + void Register(OptionsItf *opts) { + det_opts.Register(opts); + opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); + opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; " + "more accurate"); + opts->Register("min-active", &min_active, "Decoder minimum #active states."); + opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam. Larger->slower, " + "and deeper lattices"); + opts->Register("prune-interval", &prune_interval, "Interval (in frames) at " + "which to prune tokens"); + opts->Register("determinize-lattice", &determinize_lattice, "If true, " + "determinize the lattice (lattice-determinization, keeping only " + "best pdf-sequence for each word-sequence)."); + opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this " + "parameter is obscure and relates to a speedup in the way the " + "max-active constraint is applied. 
Larger is more accurate."); + opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to " + "control hash behavior"); + } + void Check() const { + KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 + && min_active <= max_active + && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0 + && prune_scale > 0.0 && prune_scale < 1.0); + } +}; + +namespace decoder { +// We will template the decoder on the token type as well as the FST type; this +// is a mechanism so that we can use the same underlying decoder code for +// versions of the decoder that support quickly getting the best path +// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also +// those that do not (LatticeFasterDecoder). + + +// ForwardLinks are the links from a token to a token on the next frame. +// or sometimes on the current frame (for input-epsilon links). +template +struct ForwardLink { + using Label = fst::StdArc::Label; + + Token *next_tok; // the next token [or NULL if represents final-state] + Label ilabel; // ilabel on arc + Label olabel; // olabel on arc + BaseFloat graph_cost; // graph cost of traversing arc (contains LM, etc.) + BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing arc + ForwardLink *next; // next in singly-linked list of forward arcs (arcs + // in the state-level lattice) from a token. + inline ForwardLink(Token *next_tok, Label ilabel, Label olabel, + BaseFloat graph_cost, BaseFloat acoustic_cost, + ForwardLink *next): + next_tok(next_tok), ilabel(ilabel), olabel(olabel), + graph_cost(graph_cost), acoustic_cost(acoustic_cost), + next(next) { } +}; + + +struct StdToken { + using ForwardLinkT = ForwardLink; + using Token = StdToken; + + // Standard token type for LatticeFasterDecoder. Each active HCLG + // (decoding-graph) state on each frame has one token. + + // tot_cost is the total (LM + acoustic) cost from the beginning of the + // utterance up to this point. (but see cost_offset_, which is subtracted + // to keep it in a good numerical range). + BaseFloat tot_cost; + + // exta_cost is >= 0. After calling PruneForwardLinks, this equals + // the minimum difference between the cost of the best path, and the cost of + // this is on, and the cost of the absolute best path, under the assumption + // that any of the currently active states at the decoding front may + // eventually succeed (e.g. if you were to take the currently active states + // one by one and compute this difference, and then take the minimum). + BaseFloat extra_cost; + + // 'links' is the head of singly-linked list of ForwardLinks, which is what we + // use for lattice generation. + ForwardLinkT *links; + + //'next' is the next in the singly-linked list of tokens for this frame. + Token *next; + + // identitfy the token is in current queue or not to prevent duplication in + // function ProcessOneFrame(). + bool in_current_queue; + + // This function does nothing and should be optimized out; it's needed + // so we can share the regular LatticeFasterDecoderTpl code and the code + // for LatticeFasterOnlineDecoder that supports fast traceback. + inline void SetBackpointer (Token *backpointer) { } + + // This constructor just ignores the 'backpointer' argument. That argument is + // needed so that we can use the same decoder code for LatticeFasterDecoderTpl + // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a + // fast way to obtain the best path). 
+ inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links, + Token *next, Token *backpointer): + tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next), + in_current_queue(false) { } +}; + +struct BackpointerToken { + using ForwardLinkT = ForwardLink; + using Token = BackpointerToken; + + // BackpointerToken is like Token but also + // Standard token type for LatticeFasterDecoder. Each active HCLG + // (decoding-graph) state on each frame has one token. + + // tot_cost is the total (LM + acoustic) cost from the beginning of the + // utterance up to this point. (but see cost_offset_, which is subtracted + // to keep it in a good numerical range). + BaseFloat tot_cost; + + // exta_cost is >= 0. After calling PruneForwardLinks, this equals + // the minimum difference between the cost of the best path, and the cost of + // this is on, and the cost of the absolute best path, under the assumption + // that any of the currently active states at the decoding front may + // eventually succeed (e.g. if you were to take the currently active states + // one by one and compute this difference, and then take the minimum). + BaseFloat extra_cost; + + // 'links' is the head of singly-linked list of ForwardLinks, which is what we + // use for lattice generation. + ForwardLinkT *links; + + //'next' is the next in the singly-linked list of tokens for this frame. + BackpointerToken *next; + + // Best preceding BackpointerToken (could be a on this frame, connected to + // this via an epsilon transition, or on a previous frame). This is only + // required for an efficient GetBestPath function in + // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation + // (the "links" list is what stores the forward links, for that). + Token *backpointer; + + // identitfy the token is in current queue or not to prevent duplication in + // function ProcessOneFrame(). + bool in_current_queue; + + inline void SetBackpointer (Token *backpointer) { + this->backpointer = backpointer; + } + + inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links, + Token *next, Token *backpointer): + tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next), + backpointer(backpointer), in_current_queue(false) { } +}; + +} // namespace decoder + + +/** This is the "normal" lattice-generating decoder. + See \ref lattices_generation \ref decoders_faster and \ref decoders_simple + for more information. + + The decoder is templated on the FST type and the token type. The token type + will normally be StdToken, but also may be BackpointerToken which is to support + quick lookup of the current best path (see lattice-faster-online-decoder.h) + + The FST you invoke this decoder with is expected to equal + Fst::Fst, a.k.a. StdFst, or GrammarFst. If you invoke it with + FST == StdFst and it notices that the actual FST type is + fst::VectorFst or fst::ConstFst, the decoder object + will internally cast itself to one that is templated on those more specific + types; this is an optimization for speed. + */ +template +class LatticeFasterDecoderCombineTpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using ForwardLinkT = decoder::ForwardLink; + + using StateIdToTokenMap = typename std::unordered_map; + using IterType = typename StateIdToTokenMap::const_iterator; + + // Instantiate this class once for each thing you have to decode. 
+ // This version of the constructor does not take ownership of + // 'fst'. + LatticeFasterDecoderCombineTpl(const FST &fst, + const LatticeFasterDecoderCombineConfig &config); + + // This version of the constructor takes ownership of the fst, and will delete + // it when this object is destroyed. + LatticeFasterDecoderCombineTpl(const LatticeFasterDecoderCombineConfig &config, + FST *fst); + + void SetOptions(const LatticeFasterDecoderCombineConfig &config) { + config_ = config; + } + + const LatticeFasterDecoderCombineConfig &GetOptions() const { + return config_; + } + + ~LatticeFasterDecoderCombineTpl(); + + /// Decodes until there are no more frames left in the "decodable" object.. + /// note, this may block waiting for input if the "decodable" object blocks. + /// Returns true if any kind of traceback is available (not necessarily from a + /// final state). + bool Decode(DecodableInterface *decodable); + + + /// says whether a final-state was active on the last frame. If it was not, the + /// lattice (or traceback) will end with states that are not final-states. + bool ReachedFinal() const { + return FinalRelativeCost() != std::numeric_limits::infinity(); + } + + /// Outputs an FST corresponding to the single best path through the lattice. + /// Returns true if result is nonempty (using the return status is deprecated, + /// it will become void). If "use_final_probs" is true AND we reached the + /// final-state of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. Note: this just calls GetRawLattice() + /// and figures out the shortest path. + bool GetBestPath(Lattice *ofst, + bool use_final_probs = true); + + /// Outputs an FST corresponding to the raw, state-level + /// tracebacks. Returns true if result is nonempty. + /// If "use_final_probs" is true AND we reached the final-state + /// of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. + /// The raw lattice will be topologically sorted. + /// + /// See also GetRawLatticePruned in lattice-faster-online-decoder.h, + /// which also supports a pruning beam, in case for some reason + /// you want it pruned tighter than the regular lattice beam. + /// We could put that here in future needed. + bool GetRawLattice(Lattice *ofst, bool use_final_probs = true); + + + + /// [Deprecated, users should now use GetRawLattice and determinize it + /// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper]. + /// Outputs an FST corresponding to the lattice-determinized + /// lattice (one path per word sequence). Returns true if result is nonempty. + /// If "use_final_probs" is true AND we reached the final-state of the graph + /// then it will include those as final-probs, else it will treat all + /// final-probs as one. + bool GetLattice(CompactLattice *ofst, + bool use_final_probs = true); + + /// InitDecoding initializes the decoding, and should only be used if you + /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to + /// call this. You can also call InitDecoding if you have already decoded an + /// utterance and want to start with a new utterance. + void InitDecoding(); + + /// This will decode until there are no more frames ready in the decodable + /// object. You can keep calling it each time more frames become available. + /// If max_num_frames is specified, it specifies the maximum number of frames + /// the function will decode before returning. 
+ void AdvanceDecoding(DecodableInterface *decodable, + int32 max_num_frames = -1); + + /// This function may be optionally called after AdvanceDecoding(), when you + /// do not plan to decode any further. It does an extra pruning step that + /// will help to prune the lattices output by GetLattice and (particularly) + /// GetRawLattice more accurately, particularly toward the end of the + /// utterance. It does this by using the final-probs in pruning (if any + /// final-state survived); it also does a final pruning step that visits all + /// states (the pruning that is done during decoding may fail to prune states + /// that are within kPruningScale = 0.1 outside of the beam). If you call + /// this, you cannot call AdvanceDecoding again (it will fail), and you + /// cannot call GetLattice() and related functions with use_final_probs = + /// false. + /// Used to be called PruneActiveTokensFinal(). + void FinalizeDecoding(); + + /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives + /// more information. It returns the difference between the best (final-cost + /// plus cost) of any token on the final frame, and the best cost of any token + /// on the final frame. If it is infinity it means no final-states were + /// present on the final frame. It will usually be nonnegative. If it not + /// too positive (e.g. < 5 is my first guess, but this is not tested) you can + /// take it as a good indication that we reached the final-state with + /// reasonable likelihood. + BaseFloat FinalRelativeCost() const; + + + // Returns the number of frames decoded so far. The value returned changes + // whenever we call ProcessForFrame(). + inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } + + protected: + // we make things protected instead of private, as code in + // LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the + // internals. + + // Deletes the elements of the singly linked list tok->links. + inline static void DeleteForwardLinks(Token *tok); + + // head of per-frame list of Tokens (list is in topological order), + // and something saying whether we ever pruned it using PruneForwardLinks. + struct TokenList { + Token *toks; + bool must_prune_forward_links; + bool must_prune_tokens; + TokenList(): toks(NULL), must_prune_forward_links(true), + must_prune_tokens(true) { } + }; + + // FindOrAddToken either locates a token in hash of toks_, or if necessary + // inserts a new, empty token (i.e. with no forward links) for the current + // frame. [note: it's inserted if necessary into hash toks_ and also into the + // singly linked list of tokens active on this frame (whose head is at + // active_toks_[frame]). The frame_plus_one argument is the acoustic frame + // index plus one, which is used to index into the active_toks_ array. + // Returns the Token pointer. Sets "changed" (if non-NULL) to true if the + // token was newly created or the cost changed. + // If Token == StdToken, the 'backpointer' argument has no purpose (and will + // hopefully be optimized out). + inline Token *FindOrAddToken(StateId state, int32 frame, + BaseFloat tot_cost, Token *backpointer, + StateIdToTokenMap *token_map, + bool *changed); + + // prunes outgoing links for all tokens in active_toks_[frame] + // it's called by PruneActiveTokens + // all links, that have link_extra_cost > lattice_beam are pruned + // delta is the amount by which the extra_costs must change + // before we set *extra_costs_changed = true. 
+  // If delta is larger, we'll tend to go back less far
+  //    toward the beginning of the file.
+  // extra_costs_changed is set to true if extra_cost was changed for any token
+  // links_pruned is set to true if any link in any token was pruned
+  void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
+                         bool *links_pruned,
+                         BaseFloat delta);
+
+  // This function computes the final-costs for tokens active on the final
+  // frame.  It outputs to final-costs, if non-NULL, a map from the Token*
+  // pointer to the final-prob of the corresponding state, for all Tokens
+  // that correspond to states that have final-probs.  This map will be
+  // empty if there were no final-probs.  It outputs to
+  // final_relative_cost, if non-NULL, the difference between the best
+  // forward-cost including the final-prob cost, and the best forward-cost
+  // without including the final-prob cost (this will usually be positive), or
+  // infinity if there were no final-probs.  [c.f. FinalRelativeCost(), which
+  // outputs this quantity].  It outputs to final_best_cost, if
+  // non-NULL, the lowest for any token t active on the final frame, of
+  // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in
+  // the graph of the state corresponding to token t, or the best of
+  // forward-cost[t] if there were no final-probs active on the final frame.
+  // You cannot call this after FinalizeDecoding() has been called; in that
+  // case you should get the answer from class-member variables.
+  void ComputeFinalCosts(unordered_map<Token*, BaseFloat> *final_costs,
+                         BaseFloat *final_relative_cost,
+                         BaseFloat *final_best_cost) const;
+
+  // PruneForwardLinksFinal is a version of PruneForwardLinks that we call
+  // on the final frame.  If there are final tokens active, it uses
+  // the final-probs for pruning, otherwise it treats all tokens as final.
+  void PruneForwardLinksFinal();
+
+  // Prune away any tokens on this frame that have no forward links.
+  // [we don't do this in PruneForwardLinks because it would give us
+  // a problem with dangling pointers].
+  // It's called by PruneActiveTokens if any forward links have been pruned
+  void PruneTokensForFrame(int32 frame_plus_one);
+
+
+  // Go backwards through still-alive tokens, pruning them if the
+  // forward+backward cost is more than lat_beam away from the best path.  It's
+  // possible to prove that this is "correct" in the sense that we won't lose
+  // anything outside of lat_beam, regardless of what happens in the future.
+  // delta controls when it considers a cost to have changed enough to continue
+  // going backward and propagating the change.  larger delta -> will recurse
+  // less far.
+  void PruneActiveTokens(BaseFloat delta);
+
+  /// Processes nonemitting (epsilon) arcs and emitting arcs for one frame
+  /// together.  Consider it as a combination of ProcessEmitting() and
+  /// ProcessNonemitting().
+  void ProcessForFrame(DecodableInterface *decodable);
+
+  /// Processes nonemitting (epsilon) arcs for one frame.
+  /// Called once after all frames have been processed, or from GetRawLattice().
+  /// Deals with the tokens in the map "next_toks_", which only contains
+  /// emitting tokens from the previous frame.
+  /// If you call this function anywhere other than at the end of an utterance,
+  /// "recover" should be true.
+  void ProcessNonemitting(bool recover);
+
+  /// The "cur_toks_" and "next_toks_" maps actually allow us to maintain the
+  /// tokens of the current and next frames.  They are indexed by StateId.
+  /// (In contrast, active_toks_ is indexed by frame-index plus one, where the
+  /// frame-index is zero-based, as used in the decodable object.  That is, the
+  /// emitting probs of frame t are accounted for in tokens at active_toks_[t+1].
+  /// The zeroth entry is for the nonemitting transitions at the start of the
+  /// graph.)
+  StateIdToTokenMap cur_toks_;
+  StateIdToTokenMap next_toks_;
+
+  /// When we call GetRawLattice() in the middle of an utterance, we have to
+  /// process non-emitting arcs, so we need to be able to recover the original
+  /// status afterwards.
+  std::unordered_map<Token*, BaseFloat> recover_map_;  // Token pointer to tot_cost
+  bool recover_;
+  /// Indicates whether each frame has been wholly processed or not.  The size
+  /// equals that of active_toks_.
+  std::vector<bool> frame_processed_;
+
+  /// Gets the weight cutoff.
+  /// Notice: in the traditional version, the histogram pruning method is
+  /// applied to the complete token list of one frame.  But in this version, it
+  /// is used on a token list which only contains the emitting part, so the
+  /// max_active and min_active values might effectively be narrowed.
+  BaseFloat GetCutoff(const StateIdToTokenMap& toks,
+                      BaseFloat *adaptive_beam,
+                      StateId *best_elem_id, Token **best_elem);
+
+  std::vector<TokenList> active_toks_;  // Lists of tokens, indexed by
+  // frame (members of TokenList are toks, must_prune_forward_links,
+  // must_prune_tokens).
+  std::vector<StateId> queue_;  // temp variable used in ProcessForFrame for
+                                // epsilon arcs.
+  std::vector<BaseFloat> tmp_array_;  // used in GetCutoff.
+
+  // fst_ is a pointer to the FST we are decoding from.
+  const FST *fst_;
+  // delete_fst_ is true if the pointer fst_ needs to be deleted when this
+  // object is destroyed.
+  bool delete_fst_;
+
+  std::vector<BaseFloat> cost_offsets_;  // This contains, for each
+  // frame, an offset that was added to the acoustic log-likelihoods on that
+  // frame in order to keep everything in a nice dynamic range i.e. close to
+  // zero, to reduce roundoff errors.
+  // Notice: it will only be added to emitting arcs (i.e. cost_offsets_[t] is
+  // added to arcs from "frame t" to "frame t+1").
+  LatticeFasterDecoderCombineConfig config_;
+  int32 num_toks_;  // current total #toks allocated...
+  bool warned_;
+
+  /// decoding_finalized_ is true if someone called FinalizeDecoding().  [note,
+  /// calling this is optional].  If true, it's forbidden to decode more.  Also,
+  /// if this is set, then the output of ComputeFinalCosts() is in the next
+  /// three variables.  The reason we need to do this is that after
+  /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some
+  /// of the tokens on the last frame are freed, so we free the list from toks_
+  /// to avoid having dangling pointers hanging around.
+  bool decoding_finalized_;
+  /// For the meaning of the next 3 variables, see the comment for
+  /// decoding_finalized_ above, and ComputeFinalCosts().
+  unordered_map<Token*, BaseFloat> final_costs_;
+  BaseFloat final_relative_cost_;
+  BaseFloat final_best_cost_;
+
+  // This function takes a singly linked list of tokens for a single frame, and
+  // outputs a list of them in topological order (it will crash if no such order
+  // can be found, which will typically be due to decoding graphs with epsilon
+  // cycles, which are not allowed).  Note: the output list may contain NULLs,
+  // which the caller should pass over; it just happens to be more efficient for
+  // the algorithm to output a list that contains NULLs.
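+  // For example (hypothetical tokens): if A has an epsilon forward link to B
+  // and B to C, any output where A appears before B and B before C is valid,
+  // e.g. [A, NULL, B, C].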
+  static void TopSortTokens(Token *tok_list,
+                            std::vector<Token*> *topsorted_list);
+
+  void ClearActiveTokens();
+
+  KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderCombineTpl);
+};
+
+typedef LatticeFasterDecoderCombineTpl<fst::StdFst, decoder::StdToken> LatticeFasterDecoderCombine;
+
+
+
+} // end namespace kaldi.
+
+#endif

From 4b30697af298863f771b2637f3613bdb67c1c833 Mon Sep 17 00:00:00 2001
From: LvHang
Date: Thu, 28 Feb 2019 01:54:13 -0500
Subject: [PATCH 02/29] add test binary

---
 src/bin/Makefile                              |   2 +-
 src/decoder/decoder-wrappers.cc               | 299 ++++++++++++++++++
 src/decoder/decoder-wrappers.h                |  73 +++++
 src/decoder/lattice-faster-decoder-combine.cc |  25 +-
 src/decoder/lattice-faster-decoder-combine.h  |  10 +-
 5 files changed, 392 insertions(+), 17 deletions(-)

diff --git a/src/bin/Makefile b/src/bin/Makefile
index 7cb01b50120..3fcbef6ad32 100644
--- a/src/bin/Makefile
+++ b/src/bin/Makefile
@@ -22,7 +22,7 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \
    matrix-sum build-pfile-from-ali get-post-on-ali tree-info am-info \
    vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \
    transform-vec align-text matrix-dim post-to-smat compile-graph \
-   compare-int-vector
+   compare-int-vector latgen-faster-mapped-combine

 OBJFILES =

diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc
index ff573c74d15..3c1dbd7ed8d 100644
--- a/src/decoder/decoder-wrappers.cc
+++ b/src/decoder/decoder-wrappers.cc
@@ -546,4 +546,303 @@ void AlignUtteranceWrapper(
   }
 }

+// For lattice-faster-decoder-combine
+DecodeUtteranceLatticeFasterCombineClass::DecodeUtteranceLatticeFasterCombineClass(
+    LatticeFasterDecoderCombine *decoder,
+    DecodableInterface *decodable,
+    const TransitionModel &trans_model,
+    const fst::SymbolTable *word_syms,
+    std::string utt,
+    BaseFloat acoustic_scale,
+    bool determinize,
+    bool allow_partial,
+    Int32VectorWriter *alignments_writer,
+    Int32VectorWriter *words_writer,
+    CompactLatticeWriter *compact_lattice_writer,
+    LatticeWriter *lattice_writer,
+    double *like_sum, // on success, adds likelihood to this.
+    int64 *frame_sum, // on success, adds #frames to this.
+    int32 *num_done, // on success (including partial decode), increments this.
+    int32 *num_err, // on failure, increments this.
+    int32 *num_partial): // If partial decode (final-state not reached), increments this.
+    decoder_(decoder), decodable_(decodable), trans_model_(&trans_model),
+    word_syms_(word_syms), utt_(utt), acoustic_scale_(acoustic_scale),
+    determinize_(determinize), allow_partial_(allow_partial),
+    alignments_writer_(alignments_writer),
+    words_writer_(words_writer),
+    compact_lattice_writer_(compact_lattice_writer),
+    lattice_writer_(lattice_writer),
+    like_sum_(like_sum), frame_sum_(frame_sum),
+    num_done_(num_done), num_err_(num_err),
+    num_partial_(num_partial),
+    computed_(false), success_(false), partial_(false),
+    clat_(NULL), lat_(NULL) { }
+
+
+void DecodeUtteranceLatticeFasterCombineClass::operator () () {
+  // Decoding and lattice determinization happen here.
+  computed_ = true; // Just means this function was called-- a check on the
+  // calling code.
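+  // The overall flow below mirrors DecodeUtteranceLatticeFaster(): decode,
+  // check whether a final-state was reached (honoring --allow-partial), then
+  // build the raw lattice and optionally determinize it.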
+  success_ = true;
+  using fst::VectorFst;
+  if (!decoder_->Decode(decodable_)) {
+    KALDI_WARN << "Failed to decode file " << utt_;
+    success_ = false;
+  }
+  if (!decoder_->ReachedFinal()) {
+    if (allow_partial_) {
+      KALDI_WARN << "Outputting partial output for utterance " << utt_
+                 << " since no final-state reached\n";
+      partial_ = true;
+    } else {
+      KALDI_WARN << "Not producing output for utterance " << utt_
+                 << " since no final-state reached and "
+                 << "--allow-partial=false.\n";
+      success_ = false;
+    }
+  }
+  if (!success_) return;
+
+  // Get lattice, and do determinization if requested.
+  lat_ = new Lattice;
+  decoder_->GetRawLattice(lat_);
+  if (lat_->NumStates() == 0)
+    KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt_;
+  fst::Connect(lat_);
+  if (determinize_) {
+    clat_ = new CompactLattice;
+    if (!DeterminizeLatticePhonePrunedWrapper(
+            *trans_model_,
+            lat_,
+            decoder_->GetOptions().lattice_beam,
+            clat_,
+            decoder_->GetOptions().det_opts))
+      KALDI_WARN << "Determinization finished earlier than the beam for "
+                 << "utterance " << utt_;
+    delete lat_;
+    lat_ = NULL;
+    // We'll write the lattice without acoustic scaling.
+    if (acoustic_scale_ != 0.0)
+      fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale_), clat_);
+  } else {
+    // We'll write the lattice without acoustic scaling.
+    if (acoustic_scale_ != 0.0)
+      fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale_), lat_);
+  }
+}
+
+DecodeUtteranceLatticeFasterCombineClass::~DecodeUtteranceLatticeFasterCombineClass() {
+  if (!computed_)
+    KALDI_ERR << "Destructor called without operator (), error in calling code.";
+
+  if (!success_) {
+    if (num_err_ != NULL) (*num_err_)++;
+  } else { // successful decode.
+    // Getting the one-best output is lightweight enough that we can do it in
+    // the destructor (easier than adding more variables to the class, and
+    // will rarely slow down the main thread.)
+    double likelihood;
+    LatticeWeight weight;
+    int32 num_frames;
+    { // First do some stuff with word-level traceback...
+      // This is basically for diagnostics.
+      fst::VectorFst<LatticeArc> decoded;
+      decoder_->GetBestPath(&decoded);
+      if (decoded.NumStates() == 0) {
+        // Shouldn't really reach this point as already checked success.
+        KALDI_ERR << "Failed to get traceback for utterance " << utt_;
+      }
+      std::vector<int32> alignment;
+      std::vector<int32> words;
+      GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
+      num_frames = alignment.size();
+      if (words_writer_->IsOpen())
+        words_writer_->Write(utt_, words);
+      if (alignments_writer_->IsOpen())
+        alignments_writer_->Write(utt_, alignment);
+      if (word_syms_ != NULL) {
+        std::cerr << utt_ << ' ';
+        for (size_t i = 0; i < words.size(); i++) {
+          std::string s = word_syms_->Find(words[i]);
+          if (s == "")
+            KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
+          std::cerr << s << ' ';
+        }
+        std::cerr << '\n';
+      }
+      likelihood = -(weight.Value1() + weight.Value2());
+    }
+
+    // Output the lattices.
+    if (determinize_) { // CompactLattice output.
+      KALDI_ASSERT(compact_lattice_writer_ != NULL && clat_ != NULL);
+      if (clat_->NumStates() == 0) {
+        KALDI_WARN << "Empty lattice for utterance " << utt_;
+      } else {
+        compact_lattice_writer_->Write(utt_, *clat_);
+      }
+      delete clat_;
+      clat_ = NULL;
+    } else {
+      KALDI_ASSERT(lattice_writer_ != NULL && lat_ != NULL);
+      if (lat_->NumStates() == 0) {
+        KALDI_WARN << "Empty lattice for utterance " << utt_;
+      } else {
+        lattice_writer_->Write(utt_, *lat_);
+      }
+      delete lat_;
+      lat_ = NULL;
+    }
+
+    // Print out logging information.
+    KALDI_LOG << "Log-like per frame for utterance " << utt_ << " is "
+              << (likelihood / num_frames) << " over "
+              << num_frames << " frames.";
+    KALDI_VLOG(2) << "Cost for utterance " << utt_ << " is "
+                  << weight.Value1() << " + " << weight.Value2();
+
+    // Now output the various diagnostic variables.
+    if (like_sum_ != NULL) *like_sum_ += likelihood;
+    if (frame_sum_ != NULL) *frame_sum_ += num_frames;
+    if (num_done_ != NULL) (*num_done_)++;
+    if (partial_ && num_partial_ != NULL) (*num_partial_)++;
+  }
+  // We were given ownership of these two objects that were passed in in
+  // the initializer.
+  delete decoder_;
+  delete decodable_;
+}
+
+
+// Takes care of output.  Returns true on success.
+template <typename FST>
+bool DecodeUtteranceLatticeFasterCombine(
+    LatticeFasterDecoderCombineTpl<FST> &decoder, // not const but is really an input.
+    DecodableInterface &decodable, // not const but is really an input.
+    const TransitionModel &trans_model,
+    const fst::SymbolTable *word_syms,
+    std::string utt,
+    double acoustic_scale,
+    bool determinize,
+    bool allow_partial,
+    Int32VectorWriter *alignment_writer,
+    Int32VectorWriter *words_writer,
+    CompactLatticeWriter *compact_lattice_writer,
+    LatticeWriter *lattice_writer,
+    double *like_ptr) { // puts utterance's like in like_ptr on success.
+  using fst::VectorFst;
+
+  if (!decoder.Decode(&decodable)) {
+    KALDI_WARN << "Failed to decode file " << utt;
+    return false;
+  }
+  if (!decoder.ReachedFinal()) {
+    if (allow_partial) {
+      KALDI_WARN << "Outputting partial output for utterance " << utt
+                 << " since no final-state reached\n";
+    } else {
+      KALDI_WARN << "Not producing output for utterance " << utt
+                 << " since no final-state reached and "
+                 << "--allow-partial=false.\n";
+      return false;
+    }
+  }
+
+  double likelihood;
+  LatticeWeight weight;
+  int32 num_frames;
+  { // First do some stuff with word-level traceback...
+    VectorFst<LatticeArc> decoded;
+    if (!decoder.GetBestPath(&decoded))
+      // Shouldn't really reach this point as already checked success.
+      KALDI_ERR << "Failed to get traceback for utterance " << utt;
+
+    std::vector<int32> alignment;
+    std::vector<int32> words;
+    GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
+    num_frames = alignment.size();
+    if (words_writer->IsOpen())
+      words_writer->Write(utt, words);
+    if (alignment_writer->IsOpen())
+      alignment_writer->Write(utt, alignment);
+    if (word_syms != NULL) {
+      std::cerr << utt << ' ';
+      for (size_t i = 0; i < words.size(); i++) {
+        std::string s = word_syms->Find(words[i]);
+        if (s == "")
+          KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
+        std::cerr << s << ' ';
+      }
+      std::cerr << '\n';
+    }
+    likelihood = -(weight.Value1() + weight.Value2());
+  }
+
+  // Get lattice, and do determinization if requested.
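+  // Note on scaling: lattices are conventionally stored without acoustic
+  // scaling, so the acoustic costs are multiplied by 1/acoustic_scale before
+  // writing; e.g. with --acoustic-scale=0.1 (a hypothetical value) the stored
+  // acoustic costs are 10x the scaled costs used during decoding.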
+  Lattice lat;
+  decoder.GetRawLattice(&lat);
+  if (lat.NumStates() == 0)
+    KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt;
+  fst::Connect(&lat);
+  if (determinize) {
+    CompactLattice clat;
+    if (!DeterminizeLatticePhonePrunedWrapper(
+            trans_model,
+            &lat,
+            decoder.GetOptions().lattice_beam,
+            &clat,
+            decoder.GetOptions().det_opts))
+      KALDI_WARN << "Determinization finished earlier than the beam for "
+                 << "utterance " << utt;
+    // We'll write the lattice without acoustic scaling.
+    if (acoustic_scale != 0.0)
+      fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat);
+    compact_lattice_writer->Write(utt, clat);
+  } else {
+    // We'll write the lattice without acoustic scaling.
+    if (acoustic_scale != 0.0)
+      fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &lat);
+    lattice_writer->Write(utt, lat);
+  }
+  KALDI_LOG << "Log-like per frame for utterance " << utt << " is "
+            << (likelihood / num_frames) << " over "
+            << num_frames << " frames.";
+  KALDI_VLOG(2) << "Cost for utterance " << utt << " is "
+                << weight.Value1() << " + " << weight.Value2();
+  *like_ptr = likelihood;
+  return true;
+}
+
+// Instantiate the template above for the two required FST types.
+template bool DecodeUtteranceLatticeFasterCombine(
+    LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc> > &decoder,
+    DecodableInterface &decodable,
+    const TransitionModel &trans_model,
+    const fst::SymbolTable *word_syms,
+    std::string utt,
+    double acoustic_scale,
+    bool determinize,
+    bool allow_partial,
+    Int32VectorWriter *alignment_writer,
+    Int32VectorWriter *words_writer,
+    CompactLatticeWriter *compact_lattice_writer,
+    LatticeWriter *lattice_writer,
+    double *like_ptr);
+
+template bool DecodeUtteranceLatticeFasterCombine(
+    LatticeFasterDecoderCombineTpl<fst::GrammarFst> &decoder,
+    DecodableInterface &decodable,
+    const TransitionModel &trans_model,
+    const fst::SymbolTable *word_syms,
+    std::string utt,
+    double acoustic_scale,
+    bool determinize,
+    bool allow_partial,
+    Int32VectorWriter *alignment_writer,
+    Int32VectorWriter *words_writer,
+    CompactLatticeWriter *compact_lattice_writer,
+    LatticeWriter *lattice_writer,
+    double *like_ptr);
+
+
 } // end namespace kaldi.
diff --git a/src/decoder/decoder-wrappers.h b/src/decoder/decoder-wrappers.h
index fc81137f356..19d01e4316a 100644
--- a/src/decoder/decoder-wrappers.h
+++ b/src/decoder/decoder-wrappers.h
@@ -23,6 +23,7 @@
 #include "itf/options-itf.h"
 #include "decoder/lattice-faster-decoder.h"
 #include "decoder/lattice-simple-decoder.h"
+#include "decoder/lattice-faster-decoder-combine.h"

 // This header contains declarations from various convenience functions that are called
 // from binary-level programs such as gmm-decode-faster.cc, gmm-align-compiled.cc, and
@@ -196,6 +197,78 @@ bool DecodeUtteranceLatticeSimple(
     double *like_ptr);  // puts utterance's likelihood in like_ptr on success.


+// For lattice-faster-decoder-combine
+template <typename FST>
+bool DecodeUtteranceLatticeFasterCombine(
+    LatticeFasterDecoderCombineTpl<FST> &decoder, // not const but is really an input.
+    DecodableInterface &decodable, // not const but is really an input.
+    const TransitionModel &trans_model,
+    const fst::SymbolTable *word_syms,
+    std::string utt,
+    double acoustic_scale,
+    bool determinize,
+    bool allow_partial,
+    Int32VectorWriter *alignments_writer,
+    Int32VectorWriter *words_writer,
+    CompactLatticeWriter *compact_lattice_writer,
+    LatticeWriter *lattice_writer,
+    double *like_ptr);  // puts utterance's likelihood in like_ptr on success.
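+
+// A rough usage sketch for the class below, following the multi-threaded
+// driver pattern used elsewhere in Kaldi (TaskSequencer from
+// util/kaldi-thread.h); the decoder/decodable objects and the counters are
+// hypothetical:
+//
+//   TaskSequencer<DecodeUtteranceLatticeFasterCombineClass> sequencer(opts);
+//   for each utterance:
+//     // 'decoder' and 'decodable' are new'd; the task takes ownership.
+//     sequencer.Run(new DecodeUtteranceLatticeFasterCombineClass(
+//         decoder, decodable, trans_model, word_syms, utt, acoustic_scale,
+//         determinize, allow_partial, &alignment_writer, &words_writer,
+//         &compact_lattice_writer, &lattice_writer, &tot_like, &frame_count,
+//         &num_done, &num_err, NULL));
+//   sequencer.Wait();  // destructors run in order and write the outputs.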
+ + +class DecodeUtteranceLatticeFasterCombineClass { + public: + // Initializer sets various variables. + // NOTE: we "take ownership" of "decoder" and "decodable". These + // are deleted by the destructor. On error, "num_err" is incremented. + DecodeUtteranceLatticeFasterCombineClass( + LatticeFasterDecoderCombine *decoder, + DecodableInterface *decodable, + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + BaseFloat acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignments_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_sum, // on success, adds likelihood to this. + int64 *frame_sum, // on success, adds #frames to this. + int32 *num_done, // on success (including partial decode), increments this. + int32 *num_err, // on failure, increments this. + int32 *num_partial); // If partial decode (final-state not reached), increments this. + void operator () (); // The decoding happens here. + ~DecodeUtteranceLatticeFasterCombineClass(); // Output happens here. + private: + // The following variables correspond to inputs: + LatticeFasterDecoderCombine *decoder_; + DecodableInterface *decodable_; + const TransitionModel *trans_model_; + const fst::SymbolTable *word_syms_; + std::string utt_; + BaseFloat acoustic_scale_; + bool determinize_; + bool allow_partial_; + Int32VectorWriter *alignments_writer_; + Int32VectorWriter *words_writer_; + CompactLatticeWriter *compact_lattice_writer_; + LatticeWriter *lattice_writer_; + double *like_sum_; + int64 *frame_sum_; + int32 *num_done_; + int32 *num_err_; + int32 *num_partial_; + + // The following variables are stored by the computation. + bool computed_; // operator () was called. + bool success_; // decoding succeeded (possibly partial) + bool partial_; // decoding was partial. + CompactLattice *clat_; // Stored output, if determinize_ == true. + Lattice *lat_; // Stored output, if determinize_ == false. +}; + + } // end namespace kaldi. diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc index 8cb6e59564d..67c4bfe7e8e 100644 --- a/src/decoder/lattice-faster-decoder-combine.cc +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -66,7 +66,7 @@ void LatticeFasterDecoderCombineTpl::InitDecoding() { active_toks_.resize(1); Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL); active_toks_[0].toks = start_tok; - cur_toks_[start_state] = start_tok; // initialize current tokens map + next_toks_[start_state] = start_tok; // initialize current tokens map num_toks_++; recover_ = false; @@ -910,7 +910,7 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting(bool recover int32 frame = active_toks_.size() - 1; // Build the queue to process non-emitting arcs std::vector cur_queue; - for (IterType iter = cur_toks_.begin(); iter != cur_toks_.end(); iter++) { + for (IterType iter = tmp_toks_.begin(); iter != tmp_toks_.end(); iter++) { if (fst_->NumInputEpsilons(iter->first) != 0) { cur_queue.push_back(iter->first); iter->second->in_current_queue = true; @@ -919,7 +919,8 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting(bool recover // "cur_cutoff" is used to constrain the epsilon emittion in current frame. // It will not be updated. 
-  BaseFloat cur_cutoff = GetCutoff(tmp_toks_, NULL, NULL, NULL);
+  BaseFloat adaptive_beam;
+  BaseFloat cur_cutoff = GetCutoff(tmp_toks_, &adaptive_beam, NULL, NULL);
 
   while (!cur_queue.empty()) {
     StateId state = cur_queue.back();
@@ -1082,15 +1083,15 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::TopSortTokens(
 
 // Instantiate the template for the combination of token types and FST types
 // that we'll need.
-template class LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc>, decoder::StdToken>;
-template class LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>, decoder::StdToken >;
-template class LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>, decoder::StdToken >;
-template class LatticeFasterDecoderCombineTpl<fst::GrammarFst, decoder::StdToken>;
-
-template class LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc>, decoder::BackpointerToken>;
-template class LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>, decoder::BackpointerToken >;
-template class LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>, decoder::BackpointerToken >;
-template class LatticeFasterDecoderCombineTpl<fst::GrammarFst, decoder::BackpointerToken>;
+template class LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc>, decodercombine::StdToken>;
+template class LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>, decodercombine::StdToken >;
+template class LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>, decodercombine::StdToken >;
+template class LatticeFasterDecoderCombineTpl<fst::GrammarFst, decodercombine::StdToken>;
+
+template class LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc>, decodercombine::BackpointerToken>;
+template class LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>, decodercombine::BackpointerToken >;
+template class LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>, decodercombine::BackpointerToken >;
+template class LatticeFasterDecoderCombineTpl<fst::GrammarFst, decodercombine::BackpointerToken>;
 
 } // end namespace kaldi.
diff --git a/src/decoder/lattice-faster-decoder-combine.h b/src/decoder/lattice-faster-decoder-combine.h
index abfdb5c21fc..f4d74a5acd8 100644
--- a/src/decoder/lattice-faster-decoder-combine.h
+++ b/src/decoder/lattice-faster-decoder-combine.h
@@ -31,6 +31,7 @@
 #include "lat/determinize-lattice-pruned.h"
 #include "lat/kaldi-lattice.h"
 #include "decoder/grammar-fst.h"
+#include "decoder/lattice-faster-decoder.h"
 
 namespace kaldi {
 
@@ -88,7 +89,8 @@ struct LatticeFasterDecoderCombineConfig {
   }
 };
 
-namespace decoder {
+
+namespace decodercombine {
 // We will template the decoder on the token type as well as the FST type; this
 // is a mechanism so that we can use the same underlying decoder code for
 // versions of the decoder that support quickly getting the best path
@@ -231,14 +233,14 @@ struct BackpointerToken {
    will internally cast itself to one that is templated on those more
    specific types; this is an optimization for speed.
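 
    For example (a sketch of the intended behavior): a
    LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc> > constructed on an
    FST whose actual type is fst::ConstFst<fst::StdArc> is meant to end up
    using the ConstFst-specialized instantiation internally.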
 */
-template <typename FST, typename Token = decoder::StdToken>
+template <typename FST, typename Token = decodercombine::StdToken>
 class LatticeFasterDecoderCombineTpl {
  public:
   using Arc = typename FST::Arc;
   using Label = typename Arc::Label;
   using StateId = typename Arc::StateId;
   using Weight = typename Arc::Weight;
-  using ForwardLinkT = decoder::ForwardLink<Token>;
+  using ForwardLinkT = decodercombine::ForwardLink<Token>;
   using StateIdToTokenMap = typename std::unordered_map<StateId, Token*>;
   using IterType = typename StateIdToTokenMap::const_iterator;
 
@@ -530,7 +532,7 @@ class LatticeFasterDecoderCombineTpl {
   KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderCombineTpl);
 };
 
-typedef LatticeFasterDecoderCombineTpl<fst::StdFst, decoder::StdToken> LatticeFasterDecoderCombine;
+typedef LatticeFasterDecoderCombineTpl<fst::StdFst, decodercombine::StdToken> LatticeFasterDecoderCombine;

From 2538a32c4fa6df9e21a6fd4a5cdd829f45fc2373 Mon Sep 17 00:00:00 2001
From: LvHang
Date: Thu, 28 Feb 2019 02:02:13 -0500
Subject: [PATCH 03/29] add test2

---
 src/bin/latgen-faster-mapped-combine.cc | 179 ++++++++++++++++++++++++
 1 file changed, 179 insertions(+)
 create mode 100644 src/bin/latgen-faster-mapped-combine.cc

diff --git a/src/bin/latgen-faster-mapped-combine.cc b/src/bin/latgen-faster-mapped-combine.cc
new file mode 100644
index 00000000000..ae5946d9e8e
--- /dev/null
+++ b/src/bin/latgen-faster-mapped-combine.cc
@@ -0,0 +1,179 @@
+// bin/latgen-faster-mapped-combine.cc
+
+// Copyright 2009-2012  Microsoft Corporation, Karel Vesely
+//                2013  Johns Hopkins University (author: Daniel Povey)
+//                2014  Guoguo Chen
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
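+
+// Example invocation (hypothetical paths and rspecifiers, mirroring the
+// usage message below):
+//   latgen-faster-mapped-combine --acoustic-scale=0.1 \
+//     --word-symbol-table=words.txt final.mdl HCLG.fst \
+//     ark:loglikes.ark "ark:|gzip -c > lat.1.gz"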
+ + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "tree/context-dep.h" +#include "hmm/transition-model.h" +#include "fstext/fstext-lib.h" +#include "decoder/decoder-wrappers.h" +#include "decoder/decodable-matrix.h" +#include "base/timer.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + using fst::SymbolTable; + using fst::Fst; + using fst::StdArc; + + const char *usage = + "Generate lattices, reading log-likelihoods as matrices\n" + " (model is needed only for the integer mappings in its transition-model)\n" + "Usage: latgen-faster-mapped [options] trans-model-in (fst-in|fsts-rspecifier) loglikes-rspecifier" + " lattice-wspecifier [ words-wspecifier [alignments-wspecifier] ]\n"; + ParseOptions po(usage); + Timer timer; + bool allow_partial = false; + BaseFloat acoustic_scale = 0.1; + LatticeFasterDecoderCombineConfig config; + + std::string word_syms_filename; + config.Register(&po); + po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic likelihoods"); + + po.Register("word-symbol-table", &word_syms_filename, "Symbol table for words [for debug output]"); + po.Register("allow-partial", &allow_partial, "If true, produce output even if end state was not reached."); + + po.Read(argc, argv); + + if (po.NumArgs() < 4 || po.NumArgs() > 6) { + po.PrintUsage(); + exit(1); + } + + std::string model_in_filename = po.GetArg(1), + fst_in_str = po.GetArg(2), + feature_rspecifier = po.GetArg(3), + lattice_wspecifier = po.GetArg(4), + words_wspecifier = po.GetOptArg(5), + alignment_wspecifier = po.GetOptArg(6); + + TransitionModel trans_model; + ReadKaldiObject(model_in_filename, &trans_model); + + bool determinize = config.determinize_lattice; + CompactLatticeWriter compact_lattice_writer; + LatticeWriter lattice_writer; + if (! (determinize ? compact_lattice_writer.Open(lattice_wspecifier) + : lattice_writer.Open(lattice_wspecifier))) + KALDI_ERR << "Could not open table for writing lattices: " + << lattice_wspecifier; + + Int32VectorWriter words_writer(words_wspecifier); + + Int32VectorWriter alignment_writer(alignment_wspecifier); + + fst::SymbolTable *word_syms = NULL; + if (word_syms_filename != "") + if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename))) + KALDI_ERR << "Could not read symbol table from file " + << word_syms_filename; + + double tot_like = 0.0; + kaldi::int64 frame_count = 0; + int num_success = 0, num_fail = 0; + + if (ClassifyRspecifier(fst_in_str, NULL, NULL) == kNoRspecifier) { + SequentialBaseFloatMatrixReader loglike_reader(feature_rspecifier); + // Input FST is just one FST, not a table of FSTs. 
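+      // (For example, a plain filename such as "HCLG.fst" classifies as
+      // kNoRspecifier, while a table rspecifier such as "ark:graphs.fsts"
+      // does not; the names here are hypothetical.)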
+ Fst *decode_fst = fst::ReadFstKaldiGeneric(fst_in_str); + timer.Reset(); + + { + LatticeFasterDecoderCombine decoder(*decode_fst, config); + + for (; !loglike_reader.Done(); loglike_reader.Next()) { + std::string utt = loglike_reader.Key(); + Matrix loglikes (loglike_reader.Value()); + loglike_reader.FreeCurrent(); + if (loglikes.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_fail++; + continue; + } + + DecodableMatrixScaledMapped decodable(trans_model, loglikes, acoustic_scale); + + double like; + if (DecodeUtteranceLatticeFasterCombine( + decoder, decodable, trans_model, word_syms, utt, + acoustic_scale, determinize, allow_partial, &alignment_writer, + &words_writer, &compact_lattice_writer, &lattice_writer, + &like)) { + tot_like += like; + frame_count += loglikes.NumRows(); + num_success++; + } else num_fail++; + } + } + delete decode_fst; // delete this only after decoder goes out of scope. + } else { // We have different FSTs for different utterances. + SequentialTableReader fst_reader(fst_in_str); + RandomAccessBaseFloatMatrixReader loglike_reader(feature_rspecifier); + for (; !fst_reader.Done(); fst_reader.Next()) { + std::string utt = fst_reader.Key(); + if (!loglike_reader.HasKey(utt)) { + KALDI_WARN << "Not decoding utterance " << utt + << " because no loglikes available."; + num_fail++; + continue; + } + const Matrix &loglikes = loglike_reader.Value(utt); + if (loglikes.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_fail++; + continue; + } + LatticeFasterDecoderCombine decoder(fst_reader.Value(), config); + DecodableMatrixScaledMapped decodable(trans_model, loglikes, acoustic_scale); + double like; + if (DecodeUtteranceLatticeFasterCombine( + decoder, decodable, trans_model, word_syms, utt, acoustic_scale, + determinize, allow_partial, &alignment_writer, &words_writer, + &compact_lattice_writer, &lattice_writer, &like)) { + tot_like += like; + frame_count += loglikes.NumRows(); + num_success++; + } else num_fail++; + } + } + + double elapsed = timer.Elapsed(); + KALDI_LOG << "Time taken "<< elapsed + << "s: real-time factor assuming 100 frames/sec is " + << (elapsed*100.0/frame_count); + KALDI_LOG << "Done " << num_success << " utterances, failed for " + << num_fail; + KALDI_LOG << "Overall log-likelihood per frame is " << (tot_like/frame_count) << " over " + << frame_count<<" frames."; + + delete word_syms; + if (num_success != 0) return 0; + else return 1; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} From d0113426c3115de4784523e37f44188dd700c859 Mon Sep 17 00:00:00 2001 From: LvHang Date: Thu, 28 Feb 2019 22:57:37 -0500 Subject: [PATCH 04/29] Update design and comments --- src/decoder/lattice-faster-decoder-combine.cc | 125 +++++++++--------- src/decoder/lattice-faster-decoder-combine.h | 50 ++++--- 2 files changed, 88 insertions(+), 87 deletions(-) diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc index 67c4bfe7e8e..4664e515ea2 100644 --- a/src/decoder/lattice-faster-decoder-combine.cc +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -1,9 +1,10 @@ -// decoder/lattice-faster-decoder.cc +// decoder/lattice-faster-decoder-combine.cc // Copyright 2009-2012 Microsoft Corporation Mirko Hannemann -// 2013-2018 Johns Hopkins University (Author: Daniel Povey) +// 2013-2019 Johns Hopkins University (Author: Daniel Povey) // 2014 Guoguo Chen // 2018 Zhehuai Chen +// 2019 Hang Lyu // See ../../COPYING for clarification regarding 
multiple authors // @@ -68,10 +69,6 @@ void LatticeFasterDecoderCombineTpl::InitDecoding() { active_toks_[0].toks = start_tok; next_toks_[start_state] = start_tok; // initialize current tokens map num_toks_++; - - recover_ = false; - frame_processed_.resize(1); - frame_processed_[0] = false; } // Returns true if any kind of traceback is available (not necessarily from @@ -91,8 +88,7 @@ bool LatticeFasterDecoderCombineTpl::Decode(DecodableInterface *deco ProcessForFrame(decodable); } // Procss non-emitting arcs for the last frame. - ProcessNonemitting(false); - frame_processed_[active_toks_.size() - 1] = true; // the last frame is processed. + ProcessNonemitting(NULL); FinalizeDecoding(); @@ -123,10 +119,6 @@ bool LatticeFasterDecoderCombineTpl::GetRawLattice( typedef Arc::StateId StateId; typedef Arc::Weight Weight; typedef Arc::Label Label; - // Process the non-emitting arcs for the unfinished last frame. - if (!frame_processed_[active_toks_.size() - 1]) { - ProcessNonemitting(true); - } // Note: you can't use the old interface (Decode()) if you want to // get the lattice with use_final_probs = false. You'd have to do // InitDecoding() and then AdvanceDecoding(). @@ -134,6 +126,14 @@ bool LatticeFasterDecoderCombineTpl::GetRawLattice( KALDI_ERR << "You cannot call FinalizeDecoding() and then call " << "GetRawLattice() with use_final_probs == false"; + std::unordered_map *recover_map = NULL; + if (!decoding_finalized_) { + recover_map = new std::unordered_map(); + // Process the non-emitting arcs for the unfinished last frame. + ProcessNonemitting(recover_map); + } + + unordered_map final_costs_local; const unordered_map &final_costs = @@ -201,10 +201,42 @@ bool LatticeFasterDecoderCombineTpl::GetRawLattice( } } } + + if (recover_map) { // recover last token list + RecoverLastTokenList(recover_map); + delete recover_map; + } return (ofst->NumStates() > 0); } +// When GetRawLattice() is called during decoding, the +// active_toks_[last_frame] is changed. To keep the consistency of function +// ProcessForFrame(), recover it. +// Notice: as new token will be added to the head of TokenList, tok->next +// will not be affacted. +template +void LatticeFasterDecoderCombineTpl::RecoverLastTokenList( + std::unordered_map *recover_map) { + if (recover_map) { + for (Token* tok = active_toks_[active_toks_.size() - 1].toks; + tok != NULL;) { + if (recover_map->find(tok) != recover_map->end()) { + DeleteForwardLinks(tok); + tok->tot_cost = (*recover_map)[tok]; + tok->in_current_queue = false; + tok = tok->next; + } else { + DeleteForwardLinks(tok); + Token *next_tok = tok->next; + delete tok; + num_toks_--; + tok = next_tok; + } + } + } +} + // This function is now deprecated, since now we do determinization from outside // the LatticeFasterDecoder class. Outputs an FST corresponding to the // lattice-determinized lattice (one path per word sequence). @@ -258,19 +290,19 @@ bool LatticeFasterDecoderCombineTpl::GetLattice( only do it every 'config_.prune_interval' frames). */ -// FindOrAddToken either locates a token in hash of toks_, +// FindOrAddToken either locates a token in hash map "token_map" // or if necessary inserts a new, empty token (i.e. with no forward links) // for the current frame. [note: it's inserted if necessary into hash toks_ // and also into the singly linked list of tokens active on this frame // (whose head is at active_toks_[frame]). 
template inline Token* LatticeFasterDecoderCombineTpl::FindOrAddToken( - StateId state, int32 frame, BaseFloat tot_cost, Token *backpointer, + StateId state, int32 frame_plus_one, BaseFloat tot_cost, Token *backpointer, StateIdToTokenMap *token_map, bool *changed) { // Returns the Token pointer. Sets "changed" (if non-NULL) to true // if the token was newly created or the cost changed. - KALDI_ASSERT(frame < active_toks_.size()); - Token *&toks = active_toks_[frame].toks; + KALDI_ASSERT(frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; typename StateIdToTokenMap::iterator e_found = token_map->find(state); if (e_found == token_map->end()) { // no such token presently. const BaseFloat extra_cost = 0.0; @@ -626,7 +658,7 @@ void LatticeFasterDecoderCombineTpl::AdvanceDecoding( } ProcessForFrame(decodable); } - ProcessNonemitting(false); + ProcessNonemitting(NULL); } // FinalizeDecoding() is a version of PruneActiveTokens that we call @@ -732,38 +764,7 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( int32 frame = active_toks_.size() - 1; // frame is the frame-index // (zero-based) used to get likelihoods // from the decodable object. - if (!recover_ && frame_processed_[frame]) { - KALDI_ERR << "Maybe the whole utterance has been processed, you shouldn't" - << " call ProcessForFrame() again."; - } else if (recover_ && !frame_processed_[frame]) { - KALDI_ERR << "Should not happen."; - } - - // Maybe called GetRawLattice() in the middle of an utterance. The - // active_toks_[frame] is changed. Recover it. - // Notice: as new token will be added to the head of TokenList, tok->next - // will not be affacted. - if (recover_) { - frame_processed_[frame] = false; - for (Token* tok = active_toks_[frame].toks; tok != NULL;) { - if (recover_map_.find(tok) != recover_map_.end()) { - DeleteForwardLinks(tok); - tok->tot_cost = recover_map_[tok]; - tok->in_current_queue = false; - tok = tok->next; - } else { - DeleteForwardLinks(tok); - Token *next_tok = tok->next; - delete tok; - num_toks_--; - tok = next_tok; - } - } - recover_ = false; - } - active_toks_.resize(active_toks_.size() + 1); - frame_processed_.resize(frame_processed_.size() + 1); cur_toks_.clear(); cur_toks_.swap(next_toks_); @@ -890,27 +891,24 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( } // for all arcs tok->in_current_queue = false; // out of queue } // end of while loop - frame_processed_[frame] = true; - frame_processed_[frame + 1] = false; + KALDI_VLOG(6) << "toks after: " << cur_toks_.size(); } template -void LatticeFasterDecoderCombineTpl::ProcessNonemitting(bool recover) { - if (recover) { // Build the elements which are used to recover - // Set the flag to true so that we will recover "next_toks_" map in - // ProcessForFrame() firstly. 
- recover_ = true; +void LatticeFasterDecoderCombineTpl::ProcessNonemitting( + std::unordered_map *recover_map) { + if (recover_map) { // Build the elements which are used to recover for (IterType iter = next_toks_.begin(); iter != next_toks_.end(); iter++) { - recover_map_[iter->second] = iter->second->tot_cost; + (*recover_map)[iter->second] = iter->second->tot_cost; } } - StateIdToTokenMap tmp_toks_(next_toks_); + StateIdToTokenMap tmp_toks(next_toks_); int32 frame = active_toks_.size() - 1; // Build the queue to process non-emitting arcs std::vector cur_queue; - for (IterType iter = tmp_toks_.begin(); iter != tmp_toks_.end(); iter++) { + for (IterType iter = tmp_toks.begin(); iter != tmp_toks.end(); iter++) { if (fst_->NumInputEpsilons(iter->first) != 0) { cur_queue.push_back(iter->first); iter->second->in_current_queue = true; @@ -920,14 +918,14 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting(bool recover // "cur_cutoff" is used to constrain the epsilon emittion in current frame. // It will not be updated. BaseFloat adaptive_beam; - BaseFloat cur_cutoff = GetCutoff(tmp_toks_, &adaptive_beam, NULL, NULL); + BaseFloat cur_cutoff = GetCutoff(tmp_toks, &adaptive_beam, NULL, NULL); while (!cur_queue.empty()) { StateId state = cur_queue.back(); cur_queue.pop_back(); - KALDI_ASSERT(tmp_toks_.find(state) != tmp_toks_.end()); - Token *tok = tmp_toks_[state]; + KALDI_ASSERT(tmp_toks.find(state) != tmp_toks.end()); + Token *tok = tmp_toks[state]; BaseFloat cur_cost = tok->tot_cost; if (cur_cost > cur_cutoff) // Don't bother processing successors. continue; @@ -946,7 +944,7 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting(bool recover BaseFloat tot_cost = cur_cost + graph_cost; if (tot_cost < cur_cutoff) { Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost, - tok, &tmp_toks_, &changed); + tok, &tmp_toks, &changed); // Add ForwardLink from tok to new_tok. Put it on the head of // tok->link list @@ -964,9 +962,6 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting(bool recover } // end of for loop tok->in_current_queue = false; } // end of while loop - frame_processed_[active_toks_.size() - 1] = true; // in case someone call - // GetRawLattice() twice - // continuously. } diff --git a/src/decoder/lattice-faster-decoder-combine.h b/src/decoder/lattice-faster-decoder-combine.h index f4d74a5acd8..57914acbd38 100644 --- a/src/decoder/lattice-faster-decoder-combine.h +++ b/src/decoder/lattice-faster-decoder-combine.h @@ -1,9 +1,10 @@ -// decoder/lattice-faster-decoder.h +// decoder/lattice-faster-decoder-combine.h // Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann; -// 2013-2014 Johns Hopkins University (Author: Daniel Povey) +// 2013-2019 Johns Hopkins University (Author: Daniel Povey) // 2014 Guoguo Chen // 2018 Zhehuai Chen +// 2019 Hang Lyu // See ../../COPYING for clarification regarding multiple authors // @@ -294,6 +295,9 @@ class LatticeFasterDecoderCombineTpl { /// of the graph then it will include those as final-probs, else /// it will treat all final-probs as one. /// The raw lattice will be topologically sorted. + /// The function can be called during decoding, it will take "next_toks_" map + /// and generate the complete token list for the last frame. Then recover it + /// to ensure the consistency of ProcessForFrame(). 
/// /// See also GetRawLatticePruned in lattice-faster-online-decoder.h, /// which also supports a pruning beam, in case for some reason @@ -373,9 +377,9 @@ class LatticeFasterDecoderCombineTpl { must_prune_tokens(true) { } }; - // FindOrAddToken either locates a token in hash of toks_, or if necessary + // FindOrAddToken either locates a token in hash map "token_map", or if necessary // inserts a new, empty token (i.e. with no forward links) for the current - // frame. [note: it's inserted if necessary into hash toks_ and also into the + // frame. [note: it's inserted if necessary into hash map and also into the // singly linked list of tokens active on this frame (whose head is at // active_toks_[frame]). The frame_plus_one argument is the acoustic frame // index plus one, which is used to index into the active_toks_ array. @@ -383,7 +387,7 @@ class LatticeFasterDecoderCombineTpl { // token was newly created or the cost changed. // If Token == StdToken, the 'backpointer' argument has no purpose (and will // hopefully be optimized out). - inline Token *FindOrAddToken(StateId state, int32 frame, + inline Token *FindOrAddToken(StateId state, int32 frame_plus_one, BaseFloat tot_cost, Token *backpointer, StateIdToTokenMap *token_map, bool *changed); @@ -442,18 +446,28 @@ class LatticeFasterDecoderCombineTpl { // less far. void PruneActiveTokens(BaseFloat delta); - /// Processes nonemitting (epsilon) arcs and emitting arcs for one frame - /// together. Consider it as a combination of ProcessEmitting() and - /// ProcessNonemitting(). + /// Processes non-emitting (epsilon) arcs and emitting arcs for one frame + /// together. It takes the emittion tokens in "cur_toks_" from last frame. + /// Generates non-emitting tokens for current frame and emitting tokens for + /// next frame. void ProcessForFrame(DecodableInterface *decodable); /// Processes nonemitting (epsilon) arcs for one frame. - /// Called once when all frames were processed or in GetRawLattice(). - /// Deal With the tokens in map "next_toks_" which would only contains - /// emittion tokens from previous frame. - /// If you call this function not in the end of an utterance, recover - /// should be true. - void ProcessNonemitting(bool recover); + /// Calls this function once when all frames were processed. + /// Or calls it in GetRawLattice() to generate the complete token list for + /// the last frame. [Deal With the tokens in map "next_toks_" which would + /// only contains emittion tokens from previous frame.] + /// If "recover_map" isn't NULL, we build the recover_map which will be used + /// to recover "active_toks_[last_frame]" token list for the last frame. + void ProcessNonemitting(std::unordered_map *recover_map); + + /// When GetRawLattice() is called during decoding, the + /// active_toks_[last_frame] is changed. To keep the consistency of function + /// ProcessForFrame(), recover it. + /// Notice: as new token will be added to the head of TokenList, tok->next + /// will not be affacted. + void RecoverLastTokenList(std::unordered_map *recover_map); + /// The "cur_toks_" and "next_toks_" actually allow us to maintain current /// and next frames. They are indexed by StateId. It is indexed by frame-index @@ -464,14 +478,6 @@ class LatticeFasterDecoderCombineTpl { StateIdToTokenMap cur_toks_; StateIdToTokenMap next_toks_; - /// When we call GetRawLattice() in the middle of an utterance, we have to - /// process non-emitting arcs so that we need to recover it original status. 
- std::unordered_map recover_map_; // Token pointer to tot_cost - bool recover_; - /// Indicate each frame is processed wholly or not. The size equals to - /// active_toks_. - std::vector frame_processed_; - /// Gets the weight cutoff. /// Notice: In traiditional version, the histogram prunning method is applied /// on a complete token list on one frame. But, in this version, it is used From a758ba43916ea4d6b6ea4e3169aaec81e2849f3b Mon Sep 17 00:00:00 2001 From: LvHang Date: Sat, 2 Mar 2019 16:55:46 -0500 Subject: [PATCH 05/29] update comments and the functions about PNE() --- src/decoder/lattice-faster-decoder-combine.cc | 78 +++++++++---------- src/decoder/lattice-faster-decoder-combine.h | 29 ++++--- 2 files changed, 57 insertions(+), 50 deletions(-) diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc index 4664e515ea2..71533fbbc9a 100644 --- a/src/decoder/lattice-faster-decoder-combine.cc +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -41,6 +41,8 @@ LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( const LatticeFasterDecoderCombineConfig &config, FST *fst): fst_(fst), delete_fst_(true), config_(config), num_toks_(0) { config.Check(); + prev_toks_.reserve(1000); + cur_toks_.reserve(1000); } @@ -53,8 +55,8 @@ LatticeFasterDecoderCombineTpl::~LatticeFasterDecoderCombineTpl() { template void LatticeFasterDecoderCombineTpl::InitDecoding() { // clean up from last time: + prev_toks_.clear(); cur_toks_.clear(); - next_toks_.clear(); cost_offsets_.clear(); ClearActiveTokens(); @@ -67,7 +69,7 @@ void LatticeFasterDecoderCombineTpl::InitDecoding() { active_toks_.resize(1); Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL); active_toks_[0].toks = start_tok; - next_toks_[start_state] = start_tok; // initialize current tokens map + cur_toks_[start_state] = start_tok; // initialize current tokens map num_toks_++; } @@ -87,9 +89,7 @@ bool LatticeFasterDecoderCombineTpl::Decode(DecodableInterface *deco PruneActiveTokens(config_.lattice_beam * config_.prune_scale); ProcessForFrame(decodable); } - // Procss non-emitting arcs for the last frame. - ProcessNonemitting(NULL); - + // A complete token list of the last frame will be generated in FinalizeDecoding() FinalizeDecoding(); // Returns true if we have any kind of traceback available (not necessarily @@ -126,11 +126,10 @@ bool LatticeFasterDecoderCombineTpl::GetRawLattice( KALDI_ERR << "You cannot call FinalizeDecoding() and then call " << "GetRawLattice() with use_final_probs == false"; - std::unordered_map *recover_map = NULL; + std::unordered_map recover_map; if (!decoding_finalized_) { - recover_map = new std::unordered_map(); // Process the non-emitting arcs for the unfinished last frame. - ProcessNonemitting(recover_map); + ProcessNonemitting(&recover_map); } @@ -202,9 +201,8 @@ bool LatticeFasterDecoderCombineTpl::GetRawLattice( } } - if (recover_map) { // recover last token list + if (!decoding_finalized_) { // recover last token list RecoverLastTokenList(recover_map); - delete recover_map; } return (ofst->NumStates() > 0); } @@ -217,13 +215,13 @@ bool LatticeFasterDecoderCombineTpl::GetRawLattice( // will not be affacted. 
template void LatticeFasterDecoderCombineTpl::RecoverLastTokenList( - std::unordered_map *recover_map) { - if (recover_map) { + const std::unordered_map &recover_map) { + if (!recover_map.empty()) { for (Token* tok = active_toks_[active_toks_.size() - 1].toks; tok != NULL;) { - if (recover_map->find(tok) != recover_map->end()) { + if (recover_map.find(tok) != recover_map.end()) { DeleteForwardLinks(tok); - tok->tot_cost = (*recover_map)[tok]; + tok->tot_cost = recover_map.find(tok)->second; tok->in_current_queue = false; tok = tok->next; } else { @@ -588,8 +586,8 @@ void LatticeFasterDecoderCombineTpl::ComputeFinalCosts( BaseFloat best_cost = infinity, best_cost_with_final = infinity; - // The final tokens are recorded in unordered_map "next_toks_". - for (IterType iter = next_toks_.begin(); iter != next_toks_.end(); iter++) { + // The final tokens are recorded in unordered_map "cur_toks_". + for (IterType iter = cur_toks_.begin(); iter != cur_toks_.end(); iter++) { StateId state = iter->first; Token *tok = iter->second; BaseFloat final_cost = fst_->Final(state).Value(); @@ -658,7 +656,6 @@ void LatticeFasterDecoderCombineTpl::AdvanceDecoding( } ProcessForFrame(decodable); } - ProcessNonemitting(NULL); } // FinalizeDecoding() is a version of PruneActiveTokens that we call @@ -686,7 +683,7 @@ void LatticeFasterDecoderCombineTpl::FinalizeDecoding() { template BaseFloat LatticeFasterDecoderCombineTpl::GetCutoff( const StateIdToTokenMap &toks, BaseFloat *adaptive_beam, - StateId *best_elem_id, Token **best_elem) { + StateId *best_state_id, Token **best_token) { // positive == high cost == bad. // best_weight is the minimum value. BaseFloat best_weight = std::numeric_limits::infinity(); @@ -696,9 +693,9 @@ BaseFloat LatticeFasterDecoderCombineTpl::GetCutoff( BaseFloat w = static_cast(iter->second->tot_cost); if (w < best_weight) { best_weight = w; - if (best_elem) { - *best_elem_id = iter->first; - *best_elem = iter->second; + if (best_token) { + *best_state_id = iter->first; + *best_token = iter->second; } } } @@ -711,9 +708,9 @@ BaseFloat LatticeFasterDecoderCombineTpl::GetCutoff( tmp_array_.push_back(w); if (w < best_weight) { best_weight = w; - if (best_elem) { - *best_elem_id = iter->first; - *best_elem = iter->second; + if (best_token) { + *best_state_id = iter->first; + *best_token = iter->second; } } } @@ -722,8 +719,8 @@ BaseFloat LatticeFasterDecoderCombineTpl::GetCutoff( min_active_cutoff = std::numeric_limits::infinity(), max_active_cutoff = std::numeric_limits::infinity(); - KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() - << " is " << tmp_array_.size(); + KALDI_VLOG(6) << "Number of emitting tokens on frame " + << NumFramesDecoded() - 1 << " is " << tmp_array_.size(); if (tmp_array_.size() > static_cast(config_.max_active)) { std::nth_element(tmp_array_.begin(), @@ -766,9 +763,9 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( // from the decodable object. active_toks_.resize(active_toks_.size() + 1); + prev_toks_.swap(cur_toks_); cur_toks_.clear(); - cur_toks_.swap(next_toks_); - if (cur_toks_.empty()) { + if (prev_toks_.empty()) { if (!warned_) { KALDI_WARN << "Error, no surviving tokens on frame " << frame; warned_ = true; @@ -780,7 +777,7 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( StateId best_tok_state_id; // "cur_cutoff" is used to constrain the epsilon emittion in current frame. // It will not be updated. 
- BaseFloat cur_cutoff = GetCutoff(cur_toks_, &adaptive_beam, + BaseFloat cur_cutoff = GetCutoff(prev_toks_, &adaptive_beam, &best_tok_state_id, &best_tok); KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is " << adaptive_beam; @@ -801,7 +798,8 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( // Notice: As the difference between the combine version and the traditional // version, this "best_tok" is choosen from emittion tokens. Normally, the // best token of one frame comes from an epsilon non-emittion. So the best - // token is a looser boundary. Use it to estimate a bound on the next cutoff. + // token is a looser boundary. We use it to estimate a bound on the next + // cutoff and we will update the "next_cutoff" once we have better tokens. // The "next_cutoff" will be updated in further processing. if (best_tok) { cost_offset = - best_tok->tot_cost; @@ -827,7 +825,7 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( // Build a queue which contains the emittion tokens from previous frame. std::vector cur_queue; - for (IterType iter = cur_toks_.begin(); iter != cur_toks_.end(); iter++) { + for (IterType iter = prev_toks_.begin(); iter != prev_toks_.end(); iter++) { cur_queue.push_back(iter->first); iter->second->in_current_queue = true; } @@ -837,9 +835,11 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( StateId state = cur_queue.back(); cur_queue.pop_back(); - KALDI_ASSERT(cur_toks_.find(state) != cur_toks_.end()); - Token *tok = cur_toks_[state]; + KALDI_ASSERT(prev_toks_.find(state) != prev_toks_.end()); + Token *tok = prev_toks_[state]; + BaseFloat cur_cost = tok->tot_cost; + tok->in_current_queue = false; // out of queue if (cur_cost > cur_cutoff) // Don't bother processing successors. continue; // If "tok" has any existing forward links, delete them, @@ -857,7 +857,7 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( BaseFloat tot_cost = cur_cost + graph_cost; if (tot_cost < cur_cutoff) { Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost, - tok, &cur_toks_, &changed); + tok, &prev_toks_, &changed); // Add ForwardLink from tok to new_tok. Put it on the head of // tok->link list @@ -882,16 +882,16 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( // no change flag is needed Token *next_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, - tok, &next_toks_, NULL); + tok, &cur_toks_, NULL); // Add ForwardLink from tok to next_tok. 
Put it on the head of tok->link // list tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel, graph_cost, ac_cost, tok->links); } } // for all arcs - tok->in_current_queue = false; // out of queue } // end of while loop - KALDI_VLOG(6) << "toks after: " << cur_toks_.size(); + KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() - 1 + << " is " << prev_toks_.size(); } @@ -899,12 +899,12 @@ template void LatticeFasterDecoderCombineTpl::ProcessNonemitting( std::unordered_map *recover_map) { if (recover_map) { // Build the elements which are used to recover - for (IterType iter = next_toks_.begin(); iter != next_toks_.end(); iter++) { + for (IterType iter = cur_toks_.begin(); iter != cur_toks_.end(); iter++) { (*recover_map)[iter->second] = iter->second->tot_cost; } } - StateIdToTokenMap tmp_toks(next_toks_); + StateIdToTokenMap tmp_toks(cur_toks_); int32 frame = active_toks_.size() - 1; // Build the queue to process non-emitting arcs std::vector cur_queue; diff --git a/src/decoder/lattice-faster-decoder-combine.h b/src/decoder/lattice-faster-decoder-combine.h index 57914acbd38..99c0540de44 100644 --- a/src/decoder/lattice-faster-decoder-combine.h +++ b/src/decoder/lattice-faster-decoder-combine.h @@ -243,7 +243,9 @@ class LatticeFasterDecoderCombineTpl { using Weight = typename Arc::Weight; using ForwardLinkT = decodercombine::ForwardLink; - using StateIdToTokenMap = typename std::unordered_map; + using StateIdToTokenMap = typename std::unordered_map, std::equal_to, + fst::PoolAllocator > >; using IterType = typename StateIdToTokenMap::const_iterator; // Instantiate this class once for each thing you have to decode. @@ -295,9 +297,10 @@ class LatticeFasterDecoderCombineTpl { /// of the graph then it will include those as final-probs, else /// it will treat all final-probs as one. /// The raw lattice will be topologically sorted. - /// The function can be called during decoding, it will take "next_toks_" map - /// and generate the complete token list for the last frame. Then recover it - /// to ensure the consistency of ProcessForFrame(). + /// The function can be called during decoding, it will process non-emitting + /// arcs from "cur_toks_" map to get tokens from both non-emitting and + /// emitting arcs for getting raw lattice. Then recover it to ensure the + /// consistency of ProcessForFrame(). /// /// See also GetRawLatticePruned in lattice-faster-online-decoder.h, /// which also supports a pruning beam, in case for some reason @@ -447,15 +450,18 @@ class LatticeFasterDecoderCombineTpl { void PruneActiveTokens(BaseFloat delta); /// Processes non-emitting (epsilon) arcs and emitting arcs for one frame - /// together. It takes the emittion tokens in "cur_toks_" from last frame. - /// Generates non-emitting tokens for current frame and emitting tokens for + /// together. It takes the emittion tokens in "prev_toks_" from last frame. + /// Generates non-emitting tokens for previous frame and emitting tokens for /// next frame. + /// Notice: The emitting tokens for the current frame means the token take + /// acoustic scores of the current frame. (i.e. the destnations of emitting + /// arcs.) void ProcessForFrame(DecodableInterface *decodable); /// Processes nonemitting (epsilon) arcs for one frame. /// Calls this function once when all frames were processed. /// Or calls it in GetRawLattice() to generate the complete token list for - /// the last frame. [Deal With the tokens in map "next_toks_" which would + /// the last frame. 
+  /// the last frame. [Deals with the tokens in map "cur_toks_", which would
   /// only contain emitting tokens from the previous frame.]
   /// If "recover_map" isn't NULL, we build the recover_map which will be used
   /// to recover the "active_toks_[last_frame]" token list for the last frame.
   void ProcessNonemitting(std::unordered_map<Token*, BaseFloat> *recover_map);

   /// When GetRawLattice() is called during decoding, the
   /// active_toks_[last_frame] is changed. To keep the consistency of function
   /// ProcessForFrame(), recover it.
   /// Notice: as new tokens will be added to the head of the TokenList,
   /// tok->next will not be affected.
-  void RecoverLastTokenList(std::unordered_map<Token*, BaseFloat> *recover_map);
+  void RecoverLastTokenList(
+      const std::unordered_map<Token*, BaseFloat> &recover_map);

-  /// The "cur_toks_" and "next_toks_" actually allow us to maintain current
+  /// The "prev_toks_" and "cur_toks_" actually allow us to maintain current
   /// and next frames. They are indexed by StateId. It is indexed by frame-index
   /// plus one, where the frame-index is zero-based, as used in the decodable
   /// object. That is, the emitting probs of frame t are accounted for in
   /// tokens at toks_[t+1]. The zeroth frame is for the nonemitting transitions
   /// at the start of the graph.
+  StateIdToTokenMap prev_toks_;
   StateIdToTokenMap cur_toks_;
-  StateIdToTokenMap next_toks_;

   /// Gets the weight cutoff.
   /// Notice: In the traditional version, the histogram pruning method is applied
   ///
@@ -485,7 +492,7 @@ class LatticeFasterDecoderCombineTpl {
   /// and min_active values might be narrowed.
   BaseFloat GetCutoff(const StateIdToTokenMap &toks, BaseFloat *adaptive_beam,
-                      StateId *best_elem_id, Token **best_elem);
+                      StateId *best_state_id, Token **best_token);

   std::vector<TokenList> active_toks_;  // Lists of tokens, indexed by
   // frame (members of TokenList are toks, must_prune_forward_links,

From 85da9980af44edfcd95ea2dcc6b29febbdc2a436 Mon Sep 17 00:00:00 2001
From: LvHang
Date: Sat, 2 Mar 2019 17:39:23 -0500
Subject: [PATCH 06/29] change recover_map to token_orig_cost and document
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/decoder/lattice-faster-decoder-combine.cc | 20 +++++++++----------
 src/decoder/lattice-faster-decoder-combine.h  | 14 +++++++++----
 2 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc
index 71533fbbc9a..1b56bf92904 100644
--- a/src/decoder/lattice-faster-decoder-combine.cc
+++ b/src/decoder/lattice-faster-decoder-combine.cc
@@ -126,10 +126,10 @@ bool LatticeFasterDecoderCombineTpl<FST, Token>::GetRawLattice(
     KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
               << "GetRawLattice() with use_final_probs == false";

-  std::unordered_map<Token*, BaseFloat> recover_map;
+  std::unordered_map<Token*, BaseFloat> token_orig_cost;
   if (!decoding_finalized_) {
     // Process the non-emitting arcs for the unfinished last frame.
-    ProcessNonemitting(&recover_map);
+    ProcessNonemitting(&token_orig_cost);
   }

@@ -202,7 +202,7 @@ bool LatticeFasterDecoderCombineTpl<FST, Token>::GetRawLattice(
   }
   if (!decoding_finalized_) {  // recover last token list
-    RecoverLastTokenList(recover_map);
+    RecoverLastTokenList(token_orig_cost);
   }
   return (ofst->NumStates() > 0);
 }
@@ -215,13 +215,13 @@ bool LatticeFasterDecoderCombineTpl<FST, Token>::GetRawLattice(
 // will not be affected.
 template <typename FST, typename Token>
 void LatticeFasterDecoderCombineTpl<FST, Token>::RecoverLastTokenList(
-    const std::unordered_map<Token*, BaseFloat> &recover_map) {
-  if (!recover_map.empty()) {
+    const std::unordered_map<Token*, BaseFloat> &token_orig_cost) {
+  if (!token_orig_cost.empty()) {
     for (Token* tok = active_toks_[active_toks_.size() - 1].toks;
          tok != NULL;) {
-      if (recover_map.find(tok) != recover_map.end()) {
+      if (token_orig_cost.find(tok) != token_orig_cost.end()) {
         DeleteForwardLinks(tok);
-        tok->tot_cost = recover_map.find(tok)->second;
+        tok->tot_cost = token_orig_cost.find(tok)->second;
         tok->in_current_queue = false;
         tok = tok->next;
       } else {
@@ -897,10 +897,10 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(

 template <typename FST, typename Token>
 void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting(
-    std::unordered_map<Token*, BaseFloat> *recover_map) {
-  if (recover_map) {  // Build the elements which are used to recover
+    std::unordered_map<Token*, BaseFloat> *token_orig_cost) {
+  if (token_orig_cost) {  // Build the elements which are used to recover
     for (IterType iter = cur_toks_.begin(); iter != cur_toks_.end(); iter++) {
-      (*recover_map)[iter->second] = iter->second->tot_cost;
+      (*token_orig_cost)[iter->second] = iter->second->tot_cost;
     }
   }

diff --git a/src/decoder/lattice-faster-decoder-combine.h b/src/decoder/lattice-faster-decoder-combine.h
index 99c0540de44..92e6f4fb682 100644
--- a/src/decoder/lattice-faster-decoder-combine.h
+++ b/src/decoder/lattice-faster-decoder-combine.h
@@ -463,17 +463,23 @@ class LatticeFasterDecoderCombineTpl {
   /// Or calls it in GetRawLattice() to generate the complete token list for
   /// the last frame. [Deals with the tokens in map "cur_toks_", which would
   /// only contain emitting tokens from the previous frame.]
-  /// If "recover_map" isn't NULL, we build the recover_map which will be used
-  /// to recover the "active_toks_[last_frame]" token list for the last frame.
-  void ProcessNonemitting(std::unordered_map<Token*, BaseFloat> *recover_map);
+  /// If the map "token_orig_cost" isn't NULL, we build the map, which will
+  /// be used to recover the "active_toks_[last_frame]" token list for the
+  /// last frame.
+  void ProcessNonemitting(
+      std::unordered_map<Token*, BaseFloat> *token_orig_cost);

   /// When GetRawLattice() is called during decoding, the
   /// active_toks_[last_frame] is changed. To keep the consistency of function
   /// ProcessForFrame(), recover it.
   /// Notice: as new tokens will be added to the head of the TokenList,
   /// tok->next will not be affected.
+  /// "token_orig_cost" is a mapping from token pointer to the tot_cost of the
+  /// token before propagating non-emitting arcs. It is used to recover the
+  /// change of original tokens in the last frame and to remove the new tokens
+  /// which come from propagating non-emitting arcs, so that we can guarantee
+  /// the consistency of function ProcessForFrame().
void RecoverLastTokenList( - const std::unordered_map &recover_map); + const std::unordered_map &token_orig_cost); /// The "prev_toks_" and "cur_toks_" actually allow us to maintain current From 18e8758d9f76bf08bc263dff2fc83f31a29bcc97 Mon Sep 17 00:00:00 2001 From: LvHang Date: Sat, 2 Mar 2019 19:23:27 -0500 Subject: [PATCH 07/29] add a simple test script --- egs/wsj/s5/steps/decode_combine_test.sh | 128 ++++++++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100755 egs/wsj/s5/steps/decode_combine_test.sh diff --git a/egs/wsj/s5/steps/decode_combine_test.sh b/egs/wsj/s5/steps/decode_combine_test.sh new file mode 100755 index 00000000000..53fde71f830 --- /dev/null +++ b/egs/wsj/s5/steps/decode_combine_test.sh @@ -0,0 +1,128 @@ +#!/bin/bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +# Apache 2.0 + +# Begin configuration. +nj=4 +cmd=run.pl +maxactive=7000 +beam=15.0 +lattice_beam=8.0 +expand_beam=30.0 +acwt=1.0 +skip_scoring=false +combine_version=false + +stage=0 +online_ivector_dir= +post_decode_acwt=10.0 +extra_left_context=0 +extra_right_context=0 +extra_left_context_initial=0 +extra_right_context_final=0 +chunk_width=140,100,160 +use_gpu=no +# End configuration. + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh; # source the path. +. parse_options.sh || exit 1; + +if [ $# != 5 ]; then + echo "Usage: steps/decode_combine_test.sh [options] " + echo "... where is assumed to be a sub-directory of the directory" + echo " where the model is." + echo "e.g.: steps/decode_combine_test.sh exp/mono/graph_tgpar data/test_dev93 exp/mono/decode_dev93_tgpr" + echo "" + echo "This script works on CMN + (delta+delta-delta | LDA+MLLT) features; it works out" + echo "what type of features you used (assuming it's one of these two)" + echo "" + echo "main options (for others, see top of script file)" + echo " --config # config containing options" + echo " --nj # number of parallel jobs" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + exit 1; +fi + + +graphdir=$1 +data=$2 +dir=$3 + +srcdir=`dirname $dir`; # The model directory is one level up from decoding directory. +sdata=$data/split$nj; +splice_opts=`cat $srcdir/splice_opts 2>/dev/null` +cmvn_opts=`cat $srcdir/cmvn_opts 2>/dev/null` +delta_opts=`cat $srcdir/delta_opts 2>/dev/null` + +mkdir -p $dir/log +[[ -d $sdata && $data/feats.scp -ot $sdata ]] || split_data.sh $data $nj || exit 1; +echo $nj > $dir/num_jobs + + +for f in $sdata/1/feats.scp $sdata/1/cmvn.scp $srcdir/final.mdl $graphdir/HCLG.fst $oldlm_fst $newlm_fst; do + [ ! -f $f ] && echo "decode_si.sh: no such file $f" && exit 1; +done + + +if [ -f $srcdir/final.mat ]; then feat_type=lda; else feat_type=delta; fi +echo "decode_combine_test.sh: feature type is $feat_type" + +feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- |" + +posteriors="ark,scp:$sdata/JOB/posterior.ark,$sdata/JOB/posterior.scp" +posteriors_scp="scp:$sdata/JOB/posterior.scp" + +if [ ! 
-z "$online_ivector_dir" ]; then + ivector_period=$(cat $online_ivector_dir/ivector_period) || exit 1; + ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector-period=$ivector_period" +fi + +if [ "$post_decode_acwt" == 1.0 ]; then + lat_wspecifier="ark:|gzip -c >$dir/lat.JOB.gz" +else + lat_wspecifier="ark:|lattice-scale --acoustic-scale=$post_decode_acwt ark:- ark:- | gzip -c >$dir/lat.JOB.gz" +fi + +frame_subsampling_opt= +if [ -f $srcdir/frame_subsampling_factor ]; then + # e.g. for 'chain' systems + frame_subsampling_opt="--frame-subsampling-factor=$(cat $srcdir/frame_subsampling_factor)" +fi + +frames_per_chunk=$(echo $chunk_width | cut -d, -f1) +# generate log-likelihood +if [ $stage -le 1 ]; then + $cmd JOB=1:$nj $dir/log/nnet_compute.JOB.log \ + nnet3-compute $ivector_opts $frame_subsampling_opt \ + --acoustic-scale=$acwt \ + --extra-left-context=$extra_left_context \ + --extra-right-context=$extra_right_context \ + --extra-left-context-initial=$extra_left_context_initial \ + --extra-right-context-final=$extra_right_context_final \ + --frames-per-chunk=$frames_per_chunk \ + --use-gpu=$use_gpu --use-priors=true \ + $srcdir/final.mdl "$feats" "$posteriors" +fi + +if [ $stage -le 2 ]; then + suffix= + if $combine_version ; then + suffix="-combine" + fi + $cmd JOB=1:$nj $dir/log/decode.JOB.log \ + latgen-faster-mapped$suffix --max-active=$maxactive --beam=$beam --lattice-beam=$lattice_beam + --acoustic-scale=$acwt --allow-partial=true --word-symbol-table=$graphdir/words.txt \ + $srcdir/final.mdl $graphdir/HCLG.fst "$posteriors_scp" "$lat_wspecifier" || exit 1; +fi + +if ! $skip_scoring ; then + [ ! -x local/score.sh ] && \ + echo "Not scoring because local/score.sh does not exist or not executable." && exit 1; + local/score.sh --cmd "$cmd" $data $graphdir $dir || + { echo "$0: Scoring failed. (ignore by '--skip-scoring true')"; exit 1; } +fi + +exit 0; From b223bb7c56e361f9b0991546362c7c44e0fe4c92 Mon Sep 17 00:00:00 2001 From: LvHang Date: Sat, 2 Mar 2019 19:31:35 -0500 Subject: [PATCH 08/29] small fix --- egs/wsj/s5/steps/decode_combine_test.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/wsj/s5/steps/decode_combine_test.sh b/egs/wsj/s5/steps/decode_combine_test.sh index 53fde71f830..7d53f67faad 100755 --- a/egs/wsj/s5/steps/decode_combine_test.sh +++ b/egs/wsj/s5/steps/decode_combine_test.sh @@ -30,7 +30,7 @@ echo "$0 $@" # Print the command line for logging [ -f ./path.sh ] && . ./path.sh; # source the path. . parse_options.sh || exit 1; -if [ $# != 5 ]; then +if [ $# != 3 ]; then echo "Usage: steps/decode_combine_test.sh [options] " echo "... where is assumed to be a sub-directory of the directory" echo " where the model is." @@ -62,8 +62,8 @@ mkdir -p $dir/log echo $nj > $dir/num_jobs -for f in $sdata/1/feats.scp $sdata/1/cmvn.scp $srcdir/final.mdl $graphdir/HCLG.fst $oldlm_fst $newlm_fst; do - [ ! -f $f ] && echo "decode_si.sh: no such file $f" && exit 1; +for f in $sdata/1/feats.scp $sdata/1/cmvn.scp $srcdir/final.mdl $graphdir/HCLG.fst; do + [ ! 
-f $f ] && echo "decode_combine_test.sh: no such file $f" && exit 1; done @@ -113,7 +113,7 @@ if [ $stage -le 2 ]; then suffix="-combine" fi $cmd JOB=1:$nj $dir/log/decode.JOB.log \ - latgen-faster-mapped$suffix --max-active=$maxactive --beam=$beam --lattice-beam=$lattice_beam + latgen-faster-mapped$suffix --max-active=$maxactive --beam=$beam --lattice-beam=$lattice_beam \ --acoustic-scale=$acwt --allow-partial=true --word-symbol-table=$graphdir/words.txt \ $srcdir/final.mdl $graphdir/HCLG.fst "$posteriors_scp" "$lat_wspecifier" || exit 1; fi From dca20d474432913f6261017ad23cbaa56b6f2e45 Mon Sep 17 00:00:00 2001 From: LvHang Date: Sun, 3 Mar 2019 20:05:41 -0500 Subject: [PATCH 09/29] fix --- src/decoder/lattice-faster-decoder-combine.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc index 1b56bf92904..352cc644113 100644 --- a/src/decoder/lattice-faster-decoder-combine.cc +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -663,6 +663,7 @@ void LatticeFasterDecoderCombineTpl::AdvanceDecoding( // tokens. This function used to be called PruneActiveTokensFinal(). template void LatticeFasterDecoderCombineTpl::FinalizeDecoding() { + ProcessNonemitting(NULL); int32 final_frame_plus_one = NumFramesDecoded(); int32 num_toks_begin = num_toks_; // PruneForwardLinksFinal() prunes final frame (with final-probs), and From 06d38cbb3bdad8df13cfb45330fb18083b2ebb26 Mon Sep 17 00:00:00 2001 From: LvHang Date: Tue, 5 Mar 2019 15:21:03 -0500 Subject: [PATCH 10/29] change queue for speeding up --- src/decoder/lattice-faster-decoder-combine.cc | 47 +++++++++++-------- src/decoder/lattice-faster-decoder-combine.h | 4 +- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc index 352cc644113..b788f6505e6 100644 --- a/src/decoder/lattice-faster-decoder-combine.cc +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -825,16 +825,15 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( cost_offsets_[frame] = cost_offset; // Build a queue which contains the emittion tokens from previous frame. - std::vector cur_queue; for (IterType iter = prev_toks_.begin(); iter != prev_toks_.end(); iter++) { - cur_queue.push_back(iter->first); + cur_queue_.push(iter->first); iter->second->in_current_queue = true; } - // Iterator the "cur_queue" to process non-emittion and emittion arcs in fst. - while (!cur_queue.empty()) { - StateId state = cur_queue.back(); - cur_queue.pop_back(); + // Iterator the "cur_queue_" to process non-emittion and emittion arcs in fst. + while (!cur_queue_.empty()) { + StateId state = cur_queue_.front(); + cur_queue_.pop(); KALDI_ASSERT(prev_toks_.find(state) != prev_toks_.end()); Token *tok = prev_toks_[state]; @@ -868,7 +867,7 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( // "changed" tells us whether the new token has a different // cost from before, or is new. if (changed && !new_tok->in_current_queue) { - cur_queue.push_back(arc.nextstate); + cur_queue_.push(arc.nextstate); new_tok->in_current_queue = true; } } @@ -905,13 +904,20 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( } } - StateIdToTokenMap tmp_toks(cur_toks_); + StateIdToTokenMap *tmp_toks; + if (token_orig_cost) { // "token_orig_cost" isn't NULL. It means we need to + // recover active_toks_[last_frame] and "cur_toks_" + // will be used in the future. 
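+    // For example, GetRawLattice(), when called during decoding, does
+    // roughly:
+    //   std::unordered_map<Token*, BaseFloat> token_orig_cost;
+    //   ProcessNonemitting(&token_orig_cost);  // modifies a copy of cur_toks_
+    //   // ... read off the raw lattice ...
+    //   RecoverLastTokenList(token_orig_cost); // undo last-frame changes
+    // while FinalizeDecoding() passes NULL, so the changes can be applied
+    // to cur_toks_ in place.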
+ tmp_toks = new StateIdToTokenMap(cur_toks_); + } else { + tmp_toks = &cur_toks_; + } + int32 frame = active_toks_.size() - 1; - // Build the queue to process non-emitting arcs - std::vector cur_queue; - for (IterType iter = tmp_toks.begin(); iter != tmp_toks.end(); iter++) { + // Build the queue to process non-emitting arcs. + for (IterType iter = tmp_toks->begin(); iter != tmp_toks->end(); iter++) { if (fst_->NumInputEpsilons(iter->first) != 0) { - cur_queue.push_back(iter->first); + cur_queue_.push(iter->first); iter->second->in_current_queue = true; } } @@ -919,14 +925,14 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( // "cur_cutoff" is used to constrain the epsilon emittion in current frame. // It will not be updated. BaseFloat adaptive_beam; - BaseFloat cur_cutoff = GetCutoff(tmp_toks, &adaptive_beam, NULL, NULL); + BaseFloat cur_cutoff = GetCutoff(*tmp_toks, &adaptive_beam, NULL, NULL); - while (!cur_queue.empty()) { - StateId state = cur_queue.back(); - cur_queue.pop_back(); + while (!cur_queue_.empty()) { + StateId state = cur_queue_.front(); + cur_queue_.pop(); - KALDI_ASSERT(tmp_toks.find(state) != tmp_toks.end()); - Token *tok = tmp_toks[state]; + KALDI_ASSERT(tmp_toks->find(state) != tmp_toks->end()); + Token *tok = (*tmp_toks)[state]; BaseFloat cur_cost = tok->tot_cost; if (cur_cost > cur_cutoff) // Don't bother processing successors. continue; @@ -945,7 +951,7 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( BaseFloat tot_cost = cur_cost + graph_cost; if (tot_cost < cur_cutoff) { Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost, - tok, &tmp_toks, &changed); + tok, tmp_toks, &changed); // Add ForwardLink from tok to new_tok. Put it on the head of // tok->link list @@ -955,7 +961,7 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( // "changed" tells us whether the new token has a different // cost from before, or is new. if (changed && !new_tok->in_current_queue) { - cur_queue.push_back(arc.nextstate); + cur_queue_.push(arc.nextstate); new_tok->in_current_queue = true; } } @@ -963,6 +969,7 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( } // end of for loop tok->in_current_queue = false; } // end of while loop + if (token_orig_cost) delete tmp_toks; } diff --git a/src/decoder/lattice-faster-decoder-combine.h b/src/decoder/lattice-faster-decoder-combine.h index 92e6f4fb682..799cf20c872 100644 --- a/src/decoder/lattice-faster-decoder-combine.h +++ b/src/decoder/lattice-faster-decoder-combine.h @@ -503,8 +503,8 @@ class LatticeFasterDecoderCombineTpl { std::vector active_toks_; // Lists of tokens, indexed by // frame (members of TokenList are toks, must_prune_forward_links, // must_prune_tokens). - std::vector queue_; // temp variable used in ProcessForFrame for - // epsilon arcs. + std::queue cur_queue_; // temp variable used in ProcessForFrame + // and ProcessNonemitting std::vector tmp_array_; // used in GetCutoff. // fst_ is a pointer to the FST we are decoding from. 
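A note on the queue change in the commit above: the decoder now keeps a member
std::queue<StateId> cur_queue_ together with a per-token in_current_queue flag,
so a state sits in the queue at most once at a time and is re-queued only when
FindOrAddToken() reports a changed (improved) cost. Below is a minimal,
stand-alone sketch of that pattern; the toy graph, arc costs, and names are
illustrative only and are not types from the decoder:

    #include <cstdio>
    #include <queue>
    #include <vector>

    // Epsilon-cost relaxation over a tiny graph, mirroring the
    // cur_queue_ / in_current_queue pattern of ProcessNonemitting().
    struct EpsArc { int next_state; float cost; };

    int main() {
      // Hypothetical graph: state -> outgoing epsilon arcs.
      std::vector<std::vector<EpsArc> > eps = {
          {{1, 0.5f}, {2, 2.0f}},  // arcs out of state 0
          {{2, 0.3f}},             // arcs out of state 1
          {}                       // state 2: no epsilon arcs
      };
      const float kInf = 1e30f;
      std::vector<float> cost = {0.0f, kInf, kInf};  // state 0 is the start
      std::vector<bool> in_queue(3, false);

      std::queue<int> q;  // FIFO, like cur_queue_ in the decoder
      q.push(0);
      in_queue[0] = true;
      while (!q.empty()) {
        int s = q.front();
        q.pop();
        in_queue[s] = false;  // mirrors tok->in_current_queue = false
        for (size_t i = 0; i < eps[s].size(); i++) {
          const EpsArc &a = eps[s][i];
          float new_cost = cost[s] + a.cost;
          if (new_cost < cost[a.next_state]) {  // the "changed" case
            cost[a.next_state] = new_cost;
            if (!in_queue[a.next_state]) {  // queue a state at most once
              q.push(a.next_state);
              in_queue[a.next_state] = true;
            }
          }
        }
      }
      for (int s = 0; s < 3; s++)
        std::printf("state %d: best epsilon cost %.2f\n", s, cost[s]);
      return 0;
    }

A FIFO visits states roughly in discovery order, which for mostly forward-going
epsilon structure tends to settle a state's cost before its successors are
expanded, so fewer states take the re-queue path than with the earlier LIFO
std::vector; this is presumably the speedup the commit message refers to.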
From 603b7053a981bd5a688d7e8ae4255c95779da993 Mon Sep 17 00:00:00 2001 From: LvHang Date: Thu, 7 Mar 2019 21:41:29 -0500 Subject: [PATCH 11/29] add hashlist version for test --- ...lattice-faster-decoder-combine-hashlist.cc | 1129 +++++++++++++++++ .../lattice-faster-decoder-combine-hashlist.h | 567 +++++++++ 2 files changed, 1696 insertions(+) create mode 100644 src/decoder/lattice-faster-decoder-combine-hashlist.cc create mode 100644 src/decoder/lattice-faster-decoder-combine-hashlist.h diff --git a/src/decoder/lattice-faster-decoder-combine-hashlist.cc b/src/decoder/lattice-faster-decoder-combine-hashlist.cc new file mode 100644 index 00000000000..c0bf7fb6672 --- /dev/null +++ b/src/decoder/lattice-faster-decoder-combine-hashlist.cc @@ -0,0 +1,1129 @@ +// decoder/lattice-faster-decoder-combine.cc + +// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann +// 2013-2019 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen +// 2019 Hang Lyu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/lattice-faster-decoder-combine.h" +#include "lat/lattice-functions.h" + +namespace kaldi { + +// instantiate this class once for each thing you have to decode. 
+template +LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( + const FST &fst, + const LatticeFasterDecoderCombineConfig &config): + fst_(&fst), delete_fst_(false), config_(config), num_toks_(0) { + config.Check(); + prev_toks_ = new StateIdToTokenMap(); + prev_toks_->SetSize(1000); + cur_toks_ = new StateIdToTokenMap(); + cur_toks_->SetSize(1000); +} + + +template +LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( + const LatticeFasterDecoderCombineConfig &config, FST *fst): + fst_(fst), delete_fst_(true), config_(config), num_toks_(0) { + config.Check(); + prev_toks_ = new StateIdToTokenMap(); + prev_toks_->SetSize(1000); + cur_toks_ = new StateIdToTokenMap(); + cur_toks_->SetSize(1000); +} + + +template +void LatticeFasterDecoderCombineTpl::DeleteElems( + Elem *list, HashList *toks) { + for (Elem *e = list, *e_tail; e != NULL; e = e_tail) { + e_tail = e->tail; + toks->Delete(e); + } +} + + +template +LatticeFasterDecoderCombineTpl::~LatticeFasterDecoderCombineTpl() { + DeleteElems(cur_toks_->Clear(), cur_toks_); + ClearActiveTokens(); + if (delete_fst_) delete fst_; + delete prev_toks_; + delete cur_toks_; +} + +template +void LatticeFasterDecoderCombineTpl::InitDecoding() { + // clean up from last time: + DeleteElems(prev_toks_->Clear(), prev_toks_); + DeleteElems(cur_toks_->Clear(), cur_toks_); + cost_offsets_.clear(); + ClearActiveTokens(); + + warned_ = false; + num_toks_ = 0; + decoding_finalized_ = false; + final_costs_.clear(); + StateId start_state = fst_->Start(); + KALDI_ASSERT(start_state != fst::kNoStateId); + active_toks_.resize(1); + Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL); + active_toks_[0].toks = start_tok; + //cur_toks_[start_state] = start_tok; // initialize current tokens map + cur_toks_->Insert(start_state, start_tok); + num_toks_++; +} + +// Returns true if any kind of traceback is available (not necessarily from +// a final state). It should only very rarely return false; this indicates +// an unusual search error. +template +bool LatticeFasterDecoderCombineTpl::Decode(DecodableInterface *decodable) { + InitDecoding(); + + // We use 1-based indexing for frames in this decoder (if you view it in + // terms of features), but note that the decodable object uses zero-based + // numbering, which we have to correct for when we call it. + + while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) { + if (NumFramesDecoded() % config_.prune_interval == 0) + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + ProcessForFrame(decodable); + } + // A complete token list of the last frame will be generated in FinalizeDecoding() + FinalizeDecoding(); + + // Returns true if we have any kind of traceback available (not necessarily + // to the end state; query ReachedFinal() for that). + return !active_toks_.empty() && active_toks_.back().toks != NULL; +} + + +// Outputs an FST corresponding to the single best path through the lattice. 
+template +bool LatticeFasterDecoderCombineTpl::GetBestPath( + Lattice *olat, + bool use_final_probs) { + Lattice raw_lat; + GetRawLattice(&raw_lat, use_final_probs); + ShortestPath(raw_lat, olat); + return (olat->NumStates() != 0); +} + + +// Outputs an FST corresponding to the raw, state-level lattice +template +bool LatticeFasterDecoderCombineTpl::GetRawLattice( + Lattice *ofst, + bool use_final_probs) { + typedef LatticeArc Arc; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + typedef Arc::Label Label; + // Note: you can't use the old interface (Decode()) if you want to + // get the lattice with use_final_probs = false. You'd have to do + // InitDecoding() and then AdvanceDecoding(). + if (decoding_finalized_ && !use_final_probs) + KALDI_ERR << "You cannot call FinalizeDecoding() and then call " + << "GetRawLattice() with use_final_probs == false"; + + std::unordered_map token_orig_cost; + if (!decoding_finalized_) { + // Process the non-emitting arcs for the unfinished last frame. + ProcessNonemitting(&token_orig_cost); + } + + + unordered_map final_costs_local; + + const unordered_map &final_costs = + (decoding_finalized_ ? final_costs_ : final_costs_local); + if (!decoding_finalized_ && use_final_probs) + ComputeFinalCosts(&final_costs_local, NULL, NULL); + + ofst->DeleteStates(); + // num-frames plus one (since frames are one-based, and we have + // an extra frame for the start-state). + int32 num_frames = active_toks_.size() - 1; + KALDI_ASSERT(num_frames > 0); + const int32 bucket_count = num_toks_/2 + 3; + unordered_map tok_map(bucket_count); + // First create all states. + std::vector token_list; + for (int32 f = 0; f <= num_frames; f++) { + if (active_toks_[f].toks == NULL) { + KALDI_WARN << "GetRawLattice: no tokens active on frame " << f + << ": not producing lattice.\n"; + return false; + } + TopSortTokens(active_toks_[f].toks, &token_list); + for (size_t i = 0; i < token_list.size(); i++) + if (token_list[i] != NULL) + tok_map[token_list[i]] = ofst->AddState(); + } + // The next statement sets the start state of the output FST. Because we + // topologically sorted the tokens, state zero must be the start-state. + ofst->SetStart(0); + + KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:" + << tok_map.bucket_count() << " load:" << tok_map.load_factor() + << " max:" << tok_map.max_load_factor(); + // Now create all arcs. + for (int32 f = 0; f <= num_frames; f++) { + for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { + StateId cur_state = tok_map[tok]; + for (ForwardLinkT *l = tok->links; + l != NULL; + l = l->next) { + typename unordered_map::const_iterator + iter = tok_map.find(l->next_tok); + StateId nextstate = iter->second; + KALDI_ASSERT(iter != tok_map.end()); + BaseFloat cost_offset = 0.0; + if (l->ilabel != 0) { // emitting.. 
+ KALDI_ASSERT(f >= 0 && f < cost_offsets_.size()); + cost_offset = cost_offsets_[f]; + } + Arc arc(l->ilabel, l->olabel, + Weight(l->graph_cost, l->acoustic_cost - cost_offset), + nextstate); + ofst->AddArc(cur_state, arc); + } + if (f == num_frames) { + if (use_final_probs && !final_costs.empty()) { + typename unordered_map::const_iterator + iter = final_costs.find(tok); + if (iter != final_costs.end()) + ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0)); + } else { + ofst->SetFinal(cur_state, LatticeWeight::One()); + } + } + } + } + + if (!decoding_finalized_) { // recover last token list + RecoverLastTokenList(token_orig_cost); + } + return (ofst->NumStates() > 0); +} + + +// When GetRawLattice() is called during decoding, the +// active_toks_[last_frame] is changed. To keep the consistency of function +// ProcessForFrame(), recover it. +// Notice: as new token will be added to the head of TokenList, tok->next +// will not be affacted. +template +void LatticeFasterDecoderCombineTpl::RecoverLastTokenList( + const std::unordered_map &token_orig_cost) { + if (!token_orig_cost.empty()) { + for (const Elem *e = cur_toks_->GetList(); e != NULL; e = e->tail) { + Token *tok = e->val; + if (token_orig_cost.find(tok) != token_orig_cost.end()) { + DeleteForwardLinks(tok); + tok->tot_cost = token_orig_cost.find(tok)->second; + tok->in_current_queue = false; + tok = tok->next; + } else { + DeleteForwardLinks(tok); + Token *next_tok = tok->next; + delete tok; + num_toks_--; + tok = next_tok; + } + } + } +} + +// This function is now deprecated, since now we do determinization from outside +// the LatticeFasterDecoder class. Outputs an FST corresponding to the +// lattice-determinized lattice (one path per word sequence). +template +bool LatticeFasterDecoderCombineTpl::GetLattice( + CompactLattice *ofst, + bool use_final_probs) { + Lattice raw_fst; + GetRawLattice(&raw_fst, use_final_probs); + Invert(&raw_fst); // make it so word labels are on the input. + // (in phase where we get backward-costs). + fst::ILabelCompare ilabel_comp; + ArcSort(&raw_fst, ilabel_comp); // sort on ilabel; makes + // lattice-determinization more efficient. + + fst::DeterminizeLatticePrunedOptions lat_opts; + lat_opts.max_mem = config_.det_opts.max_mem; + + DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts); + raw_fst.DeleteStates(); // Free memory-- raw_fst no longer needed. + Connect(ofst); // Remove unreachable states... there might be + // a small number of these, in some cases. + // Note: if something went wrong and the raw lattice was empty, + // we should still get to this point in the code without warnings or failures. + return (ofst->NumStates() != 0); +} + +/* + A note on the definition of extra_cost. + + extra_cost is used in pruning tokens, to save memory. + + Define the 'forward cost' of a token as zero for any token on the frame + we're currently decoding; and for other frames, as the shortest-path cost + between that token and a token on the frame we're currently decoding. + (by "currently decoding" I mean the most recently processed frame). + + Then define the extra_cost of a token (always >= 0) as the forward-cost of + the token minus the smallest forward-cost of any token on the same frame. + + We can use the extra_cost to accurately prune away tokens that we know will + never appear in the lattice. If the extra_cost is greater than the desired + lattice beam, the token would provably never appear in the lattice, so we can + prune away the token. 
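+
+  A worked example with made-up numbers, using the definitions above:
+  suppose we are currently decoding frame 10 and T is a token on frame 7.
+  If the shortest-path cost from T to any token on frame 10 is 13.0, and
+  the smallest such forward-cost over all tokens on frame 7 is 10.0, then
+  extra_cost(T) = 13.0 - 10.0 = 3.0. With lattice_beam = 8.0 (the value the
+  test script above passes to the decoder), T survives; any frame-7 token
+  whose extra_cost exceeded 8.0 could be deleted without changing the
+  pruned lattice.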
+ + The advantage of storing the extra_cost rather than the forward-cost, is that + it is less costly to keep the extra_cost up-to-date when we process new frames. + When we process a new frame, *all* the previous frames' forward-costs would change; + but in general the extra_cost will change only for a finite number of frames. + (Actually we don't update all the extra_costs every time we update a frame; we + only do it every 'config_.prune_interval' frames). + */ + +// FindOrAddToken either locates a token in hash map "token_map" +// or if necessary inserts a new, empty token (i.e. with no forward links) +// for the current frame. [note: it's inserted if necessary into hash toks_ +// and also into the singly linked list of tokens active on this frame +// (whose head is at active_toks_[frame]). +template +inline Token* LatticeFasterDecoderCombineTpl::FindOrAddToken( + StateId state, int32 frame_plus_one, BaseFloat tot_cost, Token *backpointer, + StateIdToTokenMap *token_map, bool *changed) { + // Returns the Token pointer. Sets "changed" (if non-NULL) to true + // if the token was newly created or the cost changed. + KALDI_ASSERT(frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + Elem *e_found = token_map->Find(state); + if (e_found == NULL) { // no such token presently. + const BaseFloat extra_cost = 0.0; + // tokens on the currently final frame have zero extra_cost + // as any of them could end up + // on the winning path. + Token *new_tok = new Token (tot_cost, extra_cost, NULL, toks, backpointer); + // NULL: no forward links yet + toks = new_tok; + num_toks_++; + // insert into the map + token_map->Insert(state, new_tok); + if (changed) *changed = true; + return new_tok; + } else { + Token *tok = e_found->val; // There is an existing Token for this state. + if (tok->tot_cost > tot_cost) { // replace old token + tok->tot_cost = tot_cost; + // SetBackpointer() just does tok->backpointer = backpointer in + // the case where Token == BackpointerToken, else nothing. + tok->SetBackpointer(backpointer); + // we don't allocate a new token, the old stays linked in active_toks_ + // we only replace the tot_cost + // in the current frame, there are no forward links (and no extra_cost) + // only in ProcessNonemitting we have to delete forward links + // in case we visit a state for the second time + // those forward links, that lead to this replaced token before: + // they remain and will hopefully be pruned later (PruneForwardLinks...) + if (changed) *changed = true; + } else { + if (changed) *changed = false; + } + return tok; + } +} + + +// prunes outgoing links for all tokens in active_toks_[frame] +// it's called by PruneActiveTokens +// all links, that have link_extra_cost > lattice_beam are pruned +template +void LatticeFasterDecoderCombineTpl::PruneForwardLinks( + int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, BaseFloat delta) { + // delta is the amount by which the extra_costs must change + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + + *extra_costs_changed = false; + *links_pruned = false; + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen. + if (!warned_) { + KALDI_WARN << "No tokens alive [doing pruning].. 
warning first " + "time only for each utterance\n"; + warned_ = true; + } + } + + // We have to iterate until there is no more change, because the links + // are not guaranteed to be in topological order. + bool changed = true; // difference new minus old extra cost >= delta ? + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost for tok. + BaseFloat tok_extra_cost = std::numeric_limits::infinity(); + // tok_extra_cost is the best (min) of link_extra_cost of outgoing links + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); // difference in brackets is >= 0 + // link_exta_cost is the difference in score between the best paths + // through link source state and through link destination state + KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + *links_pruned = true; + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; // move to next link + link = link->next; + } + } // for all outgoing links + if (fabs(tok_extra_cost - tok->extra_cost) > delta) + changed = true; // difference new minus old is bigger than delta + tok->extra_cost = tok_extra_cost; + // will be +infinity or <= lattice_beam_. + // infinity indicates, that no forward link survived pruning + } // for all Token on active_toks_[frame] + if (changed) *extra_costs_changed = true; + + // Note: it's theoretically possible that aggressive compiler + // optimizations could cause an infinite loop here for small delta and + // high-dynamic-range scores. + } // while changed +} + +// PruneForwardLinksFinal is a version of PruneForwardLinks that we call +// on the final frame. If there are final tokens active, it uses +// the final-probs for pruning, otherwise it treats all tokens as final. +template +void LatticeFasterDecoderCombineTpl::PruneForwardLinksFinal() { + KALDI_ASSERT(!active_toks_.empty()); + int32 frame_plus_one = active_toks_.size() - 1; + + if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen. + KALDI_WARN << "No tokens alive at end of file"; + + typedef typename unordered_map::const_iterator IterType; + ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_); + decoding_finalized_ = true; + + // Now go through tokens on this frame, pruning forward links... may have to + // iterate a few times until there is no more change, because the list is not + // in topological order. This is a modified version of the code in + // PruneForwardLinks, but here we also take account of the final-probs. 
+ bool changed = true; + BaseFloat delta = 1.0e-05; + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost. It has a term in it that corresponds + // to the "final-prob", so instead of initializing tok_extra_cost to infinity + // below we set it to the difference between the (score+final_prob) of this token, + // and the best such (score+final_prob). + BaseFloat final_cost; + if (final_costs_.empty()) { + final_cost = 0.0; + } else { + IterType iter = final_costs_.find(tok); + if (iter != final_costs_.end()) + final_cost = iter->second; + else + final_cost = std::numeric_limits::infinity(); + } + BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_; + // tok_extra_cost will be a "min" over either directly being final, or + // being indirectly final through other links, and the loop below may + // decrease its value: + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; + link = link->next; + } + } + // prune away tokens worse than lattice_beam above best path. This step + // was not necessary in the non-final case because then, this case + // showed up as having no forward links. Here, the tok_extra_cost has + // an extra component relating to the final-prob. + if (tok_extra_cost > config_.lattice_beam) + tok_extra_cost = std::numeric_limits::infinity(); + // to be pruned in PruneTokensForFrame + + if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) + changed = true; + tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. + } + } // while changed +} + +template +BaseFloat LatticeFasterDecoderCombineTpl::FinalRelativeCost() const { + if (!decoding_finalized_) { + BaseFloat relative_cost; + ComputeFinalCosts(NULL, &relative_cost, NULL); + return relative_cost; + } else { + // we're not allowed to call that function if FinalizeDecoding() has + // been called; return a cached value. + return final_relative_cost_; + } +} + + +// Prune away any tokens on this frame that have no forward links. +// [we don't do this in PruneForwardLinks because it would give us +// a problem with dangling pointers]. 
+// It's called by PruneActiveTokens if any forward links have been pruned +template +void LatticeFasterDecoderCombineTpl::PruneTokensForFrame( + int32 frame_plus_one) { + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + if (toks == NULL) + KALDI_WARN << "No tokens alive [doing pruning]"; + Token *tok, *next_tok, *prev_tok = NULL; + for (tok = toks; tok != NULL; tok = next_tok) { + next_tok = tok->next; + if (tok->extra_cost == std::numeric_limits::infinity()) { + // token is unreachable from end of graph; (no forward links survived) + // excise tok from list and delete tok. + if (prev_tok != NULL) prev_tok->next = tok->next; + else toks = tok->next; + delete tok; + num_toks_--; + } else { // fetch next Token + prev_tok = tok; + } + } +} + +// Go backwards through still-alive tokens, pruning them, starting not from +// the current frame (where we want to keep all tokens) but from the frame before +// that. We go backwards through the frames and stop when we reach a point +// where the delta-costs are not changing (and the delta controls when we consider +// a cost to have "not changed"). +template +void LatticeFasterDecoderCombineTpl::PruneActiveTokens( + BaseFloat delta) { + int32 cur_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // The index "f" below represents a "frame plus one", i.e. you'd have to subtract + // one to get the corresponding index for the decodable object. + for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) { + // Reason why we need to prune forward links in this situation: + // (1) we have never pruned them (new TokenList) + // (2) we have not yet pruned the forward links to the next f, + // after any of those tokens have changed their extra_cost. + if (active_toks_[f].must_prune_forward_links) { + bool extra_costs_changed = false, links_pruned = false; + PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta); + if (extra_costs_changed && f > 0) // any token has changed extra_cost + active_toks_[f-1].must_prune_forward_links = true; + if (links_pruned) // any link was pruned + active_toks_[f].must_prune_tokens = true; + active_toks_[f].must_prune_forward_links = false; // job done + } + if (f+1 < cur_frame_plus_one && // except for last f (no forward links) + active_toks_[f+1].must_prune_tokens) { + PruneTokensForFrame(f+1); + active_toks_[f+1].must_prune_tokens = false; + } + } + KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +template +void LatticeFasterDecoderCombineTpl::ComputeFinalCosts( + unordered_map *final_costs, + BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const { + KALDI_ASSERT(!decoding_finalized_); + if (final_costs != NULL) + final_costs->clear(); + BaseFloat infinity = std::numeric_limits::infinity(); + BaseFloat best_cost = infinity, + best_cost_with_final = infinity; + + // The final tokens are recorded in unordered_map "cur_toks_". 
+ const Elem *final_toks = cur_toks_->GetList(); + while (final_toks != NULL) { + StateId state = final_toks->key; + Token *tok = final_toks->val; + const Elem *next = final_toks->tail; + BaseFloat final_cost = fst_->Final(state).Value(); + BaseFloat cost = tok->tot_cost, + cost_with_final = cost + final_cost; + best_cost = std::min(cost, best_cost); + best_cost_with_final = std::min(cost_with_final, best_cost_with_final); + if (final_costs != NULL && final_cost != infinity) + (*final_costs)[tok] = final_cost; + final_toks = next; + } + if (final_relative_cost != NULL) { + if (best_cost == infinity && best_cost_with_final == infinity) { + // Likely this will only happen if there are no tokens surviving. + // This seems the least bad way to handle it. + *final_relative_cost = infinity; + } else { + *final_relative_cost = best_cost_with_final - best_cost; + } + } + if (final_best_cost != NULL) { + if (best_cost_with_final != infinity) { // final-state exists. + *final_best_cost = best_cost_with_final; + } else { // no final-state exists. + *final_best_cost = best_cost; + } + } +} + +template +void LatticeFasterDecoderCombineTpl::AdvanceDecoding( + DecodableInterface *decodable, + int32 max_num_frames) { + if (std::is_same >::value) { + // if the type 'FST' is the FST base-class, then see if the FST type of fst_ + // is actually VectorFst or ConstFst. If so, call the AdvanceDecoding() + // function after casting *this to the more specific type. + if (fst_->Type() == "const") { + LatticeFasterDecoderCombineTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } else if (fst_->Type() == "vector") { + LatticeFasterDecoderCombineTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } + } + + + KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ && + "You must call InitDecoding() before AdvanceDecoding"); + int32 num_frames_ready = decodable->NumFramesReady(); + // num_frames_ready must be >= num_frames_decoded, or else + // the number of frames ready must have decreased (which doesn't + // make sense) or the decodable object changed between calls + // (which isn't allowed). + KALDI_ASSERT(num_frames_ready >= NumFramesDecoded()); + int32 target_frames_decoded = num_frames_ready; + if (max_num_frames >= 0) + target_frames_decoded = std::min(target_frames_decoded, + NumFramesDecoded() + max_num_frames); + while (NumFramesDecoded() < target_frames_decoded) { + if (NumFramesDecoded() % config_.prune_interval == 0) { + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + } + ProcessForFrame(decodable); + } +} + +// FinalizeDecoding() is a version of PruneActiveTokens that we call +// (optionally) on the final frame. Takes into account the final-prob of +// tokens. This function used to be called PruneActiveTokensFinal(). +template +void LatticeFasterDecoderCombineTpl::FinalizeDecoding() { + ProcessNonemitting(NULL); + int32 final_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // PruneForwardLinksFinal() prunes final frame (with final-probs), and + // sets decoding_finalized_. + PruneForwardLinksFinal(); + for (int32 f = final_frame_plus_one - 1; f >= 0; f--) { + bool b1, b2; // values not used. 
+ BaseFloat dontcare = 0.0; // delta of zero means we must always update + PruneForwardLinks(f, &b1, &b2, dontcare); + PruneTokensForFrame(f + 1); + } + PruneTokensForFrame(0); + DeleteElems(prev_toks_->Clear(), prev_toks_); + DeleteElems(cur_toks_->Clear(), cur_toks_); + KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +/// Gets the weight cutoff. +template +BaseFloat LatticeFasterDecoderCombineTpl::GetCutoff( + const Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, + Elem **best_elem) { + BaseFloat best_weight = std::numeric_limits::infinity(); + // positive == high cost == bad. + size_t count = 0; + if (config_.max_active == std::numeric_limits::max() && + config_.min_active == 0) { + for (const Elem *e = list_head; e != NULL; e = e->tail, count++) { + BaseFloat w = static_cast(e->val->tot_cost); + if (w < best_weight) { + best_weight = w; + if (best_elem) *best_elem = const_cast(e); + } + } + if (tok_count != NULL) *tok_count = count; + if (adaptive_beam != NULL) *adaptive_beam = config_.beam; + return best_weight + config_.beam; + } else { + tmp_array_.clear(); + for (const Elem *e = list_head; e != NULL; e = e->tail, count++) { + BaseFloat w = e->val->tot_cost; + tmp_array_.push_back(w); + if (w < best_weight) { + best_weight = w; + if (best_elem) *best_elem = const_cast(e); + } + } + if (tok_count != NULL) *tok_count = count; + + BaseFloat beam_cutoff = best_weight + config_.beam, + min_active_cutoff = std::numeric_limits::infinity(), + max_active_cutoff = std::numeric_limits::infinity(); + + KALDI_VLOG(5) << "Number of tokens active on frame " << NumFramesDecoded() + << " is " << tmp_array_.size(); + + if (tmp_array_.size() > static_cast(config_.max_active)) { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.max_active, + tmp_array_.end()); + max_active_cutoff = tmp_array_[config_.max_active]; + } + if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam. + if (adaptive_beam) + *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta; + return max_active_cutoff; + } + if (tmp_array_.size() > static_cast(config_.min_active)) { + if (config_.min_active == 0) min_active_cutoff = best_weight; + else { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.min_active, + tmp_array_.size() > static_cast(config_.max_active) ? + tmp_array_.begin() + config_.max_active : + tmp_array_.end()); + min_active_cutoff = tmp_array_[config_.min_active]; + } + } + if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam. + if (adaptive_beam) + *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta; + return min_active_cutoff; + } else { + *adaptive_beam = config_.beam; + return beam_cutoff; + } + } +} + +template +void LatticeFasterDecoderCombineTpl::PossiblyResizeHash( + size_t num_toks) { + size_t new_sz = static_cast(static_cast(num_toks) + * config_.hash_ratio); + if (new_sz > cur_toks_->Size()) { + cur_toks_->SetSize(new_sz); + } +} + +template +void LatticeFasterDecoderCombineTpl::ProcessForFrame( + DecodableInterface *decodable) { + KALDI_ASSERT(active_toks_.size() > 0); + int32 frame = active_toks_.size() - 1; // frame is the frame-index + // (zero-based) used to get likelihoods + // from the decodable object. 
+ active_toks_.resize(active_toks_.size() + 1); + + StateIdToTokenMap *tmp = prev_toks_; + prev_toks_ = cur_toks_; + cur_toks_ = tmp; + DeleteElems(cur_toks_->Clear(), cur_toks_); + + if (prev_toks_->GetList() == NULL) { + if (!warned_) { + KALDI_WARN << "Error, no surviving tokens on frame " << frame; + warned_ = true; + } + } + + Elem *best_elem = NULL; + BaseFloat adaptive_beam; + size_t tok_cnt; + // "cur_cutoff" is used to constrain the epsilon emittion in current frame. + // It will not be updated. + BaseFloat cur_cutoff = GetCutoff(prev_toks_->GetList(), &tok_cnt, + &adaptive_beam, &best_elem); + KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is " + << adaptive_beam; + + PossiblyResizeHash(tok_cnt); + + // pruning "online" before having seen all tokens + + // "next_cutoff" is used to limit a new token in next frame should be handle + // or not. It will be updated along with the further processing. + BaseFloat next_cutoff = std::numeric_limits::infinity(); + // "cost_offset" contains the acoustic log-likelihoods on current frame in + // order to keep everything in a nice dynamic range. Reduce roundoff errors. + BaseFloat cost_offset = 0.0; + + // First process the best token to get a hopefully + // reasonably tight bound on the next cutoff. The only + // products of the next block are "next_cutoff" and "cost_offset". + // Notice: As the difference between the combine version and the traditional + // version, this "best_tok" is choosen from emittion tokens. Normally, the + // best token of one frame comes from an epsilon non-emittion. So the best + // token is a looser boundary. We use it to estimate a bound on the next + // cutoff and we will update the "next_cutoff" once we have better tokens. + // The "next_cutoff" will be updated in further processing. + if (best_elem) { + StateId state = best_elem->key; + Token *best_tok = best_elem->val; + cost_offset = - best_tok->tot_cost; + for(fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { // propagate.. + // ac_cost + graph_cost + BaseFloat new_weight = arc.weight.Value() + cost_offset - + decodable->LogLikelihood(frame, arc.ilabel) + best_tok->tot_cost; + if (new_weight + adaptive_beam < next_cutoff) + next_cutoff = new_weight + adaptive_beam; + } + } + } + + // Store the offset on the acoustic likelihoods that we're applying. + // Could just do cost_offsets_.push_back(cost_offset), but we + // do it this way as it's more robust to future code changes. + cost_offsets_.resize(frame + 1, 0.0); + cost_offsets_[frame] = cost_offset; + + // Build a queue which contains the emittion tokens from previous frame. + for (const Elem *e = prev_toks_->GetList(); e != NULL; e = e->tail) { + cur_queue_.push(e->key); + e->val->in_current_queue = true; + } + + // Iterator the "cur_queue_" to process non-emittion and emittion arcs in fst. + while (!cur_queue_.empty()) { + StateId state = cur_queue_.front(); + cur_queue_.pop(); + + //KALDI_ASSERT(prev_toks_.find(state) != prev_toks_.end()); + //Token *tok = prev_toks_[state]; + Token *tok = prev_toks_->Find(state)->val; + + BaseFloat cur_cost = tok->tot_cost; + tok->in_current_queue = false; // out of queue + if (cur_cost > cur_cutoff) // Don't bother processing successors. + continue; + // If "tok" has any existing forward links, delete them, + // because we're about to regenerate them. 
This is a kind + // of non-optimality (remember, this is the simple decoder), + DeleteForwardLinks(tok); // necessary when re-visiting + tok->links = NULL; + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + bool changed; + if (arc.ilabel == 0) { // propagate nonemitting + BaseFloat graph_cost = arc.weight.Value(); + BaseFloat tot_cost = cur_cost + graph_cost; + if (tot_cost < cur_cutoff) { + Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost, + tok, prev_toks_, &changed); + + // Add ForwardLink from tok to new_tok. Put it on the head of + // tok->link list + tok->links = new ForwardLinkT(new_tok, 0, arc.olabel, + graph_cost, 0, tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new. + if (changed && !new_tok->in_current_queue) { + cur_queue_.push(arc.nextstate); + new_tok->in_current_queue = true; + } + } + } else { // propagate emitting + BaseFloat graph_cost = arc.weight.Value(), + ac_cost = cost_offset - decodable->LogLikelihood(frame, arc.ilabel), + cur_cost = tok->tot_cost, + tot_cost = cur_cost + ac_cost + graph_cost; + if (tot_cost > next_cutoff) continue; + else if (tot_cost + adaptive_beam < next_cutoff) + next_cutoff = tot_cost + adaptive_beam; // a tighter boundary for emitting + + // no change flag is needed + Token *next_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, + tok, cur_toks_, NULL); + // Add ForwardLink from tok to next_tok. Put it on the head of tok->link + // list + tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel, + graph_cost, ac_cost, tok->links); + } + } // for all arcs + } // end of while loop +} + + +template +void LatticeFasterDecoderCombineTpl::ProcessNonemitting( + std::unordered_map *token_orig_cost) { + if (token_orig_cost) { // Build the elements which are used to recover + for (const Elem *e = cur_toks_->GetList(); e != NULL; e = e->tail) { + (*token_orig_cost)[e->val] = e->val->tot_cost; + } + } + + StateIdToTokenMap *tmp_toks = cur_toks_; + + int32 frame = active_toks_.size() - 1; + // Build the queue to process non-emitting arcs. + for (const Elem *e = tmp_toks->GetList(); e != NULL; e = e->tail) { + if (fst_->NumInputEpsilons(e->key) != 0) { + cur_queue_.push(e->key); + e->val->in_current_queue = true; + } + } + + // "cur_cutoff" is used to constrain the epsilon emittion in current frame. + // It will not be updated. + BaseFloat adaptive_beam; + BaseFloat cur_cutoff = GetCutoff(tmp_toks->GetList(), NULL, &adaptive_beam, NULL); + + while (!cur_queue_.empty()) { + StateId state = cur_queue_.front(); + cur_queue_.pop(); + + Token *tok = tmp_toks->Find(state)->val; + BaseFloat cur_cost = tok->tot_cost; + if (cur_cost > cur_cutoff) // Don't bother processing successors. + continue; + // If "tok" has any existing forward links, delete them, + // because we're about to regenerate them. This is a kind + // of non-optimality (remember, this is the simple decoder), + DeleteForwardLinks(tok); // necessary when re-visiting + tok->links = NULL; + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + bool changed; + if (arc.ilabel == 0) { // propagate nonemitting + BaseFloat graph_cost = arc.weight.Value(); + BaseFloat tot_cost = cur_cost + graph_cost; + if (tot_cost < cur_cutoff) { + Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost, + tok, tmp_toks, &changed); + + // Add ForwardLink from tok to new_tok. 
Put it on the head of + // tok->link list + tok->links = new ForwardLinkT(new_tok, 0, arc.olabel, + graph_cost, 0, tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new. + if (changed && !new_tok->in_current_queue) { + cur_queue_.push(arc.nextstate); + new_tok->in_current_queue = true; + } + } + } + } // end of for loop + tok->in_current_queue = false; + } // end of while loop +} + + + +// static inline +template +void LatticeFasterDecoderCombineTpl::DeleteForwardLinks(Token *tok) { + ForwardLinkT *l = tok->links, *m; + while (l != NULL) { + m = l->next; + delete l; + l = m; + } + tok->links = NULL; +} + + +template +void LatticeFasterDecoderCombineTpl::ClearActiveTokens() { + // a cleanup routine, at utt end/begin + for (size_t i = 0; i < active_toks_.size(); i++) { + // Delete all tokens alive on this frame, and any forward + // links they may have. + for (Token *tok = active_toks_[i].toks; tok != NULL; ) { + DeleteForwardLinks(tok); + Token *next_tok = tok->next; + delete tok; + num_toks_--; + tok = next_tok; + } + } + active_toks_.clear(); + KALDI_ASSERT(num_toks_ == 0); +} + +// static +template +void LatticeFasterDecoderCombineTpl::TopSortTokens( + Token *tok_list, std::vector *topsorted_list) { + unordered_map token2pos; + typedef typename unordered_map::iterator IterType; + int32 num_toks = 0; + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + num_toks++; + int32 cur_pos = 0; + // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0. + // This is likely to be in closer to topological order than + // if we had given them ascending order, because of the way + // new tokens are put at the front of the list. + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + token2pos[tok] = num_toks - ++cur_pos; + + unordered_set reprocess; + + for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) { + Token *tok = iter->first; + int32 pos = iter->second; + for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel == 0) { + // We only need to consider epsilon links, since non-epsilon links + // transition between frames and this function only needs to sort a list + // of tokens from a single frame. + IterType following_iter = token2pos.find(link->next_tok); + if (following_iter != token2pos.end()) { // another token on this frame, + // so must consider it. + int32 next_pos = following_iter->second; + if (next_pos < pos) { // reassign the position of the next Token. + following_iter->second = cur_pos++; + reprocess.insert(link->next_tok); + } + } + } + } + // In case we had previously assigned this token to be reprocessed, we can + // erase it from that set because it's "happy now" (we just processed it). + reprocess.erase(tok); + } + + size_t max_loop = 1000000, loop_count; // max_loop is to detect epsilon cycles. + for (loop_count = 0; + !reprocess.empty() && loop_count < max_loop; ++loop_count) { + std::vector reprocess_vec; + for (typename unordered_set::iterator iter = reprocess.begin(); + iter != reprocess.end(); ++iter) + reprocess_vec.push_back(*iter); + reprocess.clear(); + for (typename std::vector::iterator iter = reprocess_vec.begin(); + iter != reprocess_vec.end(); ++iter) { + Token *tok = *iter; + int32 pos = token2pos[tok]; + // Repeat the processing we did above (for comments, see above). 
+ for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel == 0) { + IterType following_iter = token2pos.find(link->next_tok); + if (following_iter != token2pos.end()) { + int32 next_pos = following_iter->second; + if (next_pos < pos) { + following_iter->second = cur_pos++; + reprocess.insert(link->next_tok); + } + } + } + } + } + } + KALDI_ASSERT(loop_count < max_loop && "Epsilon loops exist in your decoding " + "graph (this is not allowed!)"); + + topsorted_list->clear(); + topsorted_list->resize(cur_pos, NULL); // create a list with NULLs in between. + for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) + (*topsorted_list)[iter->second] = iter->first; +} + +// Instantiate the template for the combination of token types and FST types +// that we'll need. +template class LatticeFasterDecoderCombineTpl, decodercombine::StdToken>; +template class LatticeFasterDecoderCombineTpl, decodercombine::StdToken >; +template class LatticeFasterDecoderCombineTpl, decodercombine::StdToken >; +template class LatticeFasterDecoderCombineTpl; + +template class LatticeFasterDecoderCombineTpl , decodercombine::BackpointerToken>; +template class LatticeFasterDecoderCombineTpl, decodercombine::BackpointerToken >; +template class LatticeFasterDecoderCombineTpl, decodercombine::BackpointerToken >; +template class LatticeFasterDecoderCombineTpl; + + +} // end namespace kaldi. diff --git a/src/decoder/lattice-faster-decoder-combine-hashlist.h b/src/decoder/lattice-faster-decoder-combine-hashlist.h new file mode 100644 index 00000000000..e3b45f444de --- /dev/null +++ b/src/decoder/lattice-faster-decoder-combine-hashlist.h @@ -0,0 +1,567 @@ +// decoder/lattice-faster-decoder-combine.h + +// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann; +// 2013-2019 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen +// 2019 Hang Lyu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ +#define KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ + + +#include "util/stl-utils.h" +#include "util/hash-list.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/grammar-fst.h" +#include "decoder/lattice-faster-decoder.h" + +namespace kaldi { + +struct LatticeFasterDecoderCombineConfig { + BaseFloat beam; + int32 max_active; + int32 min_active; + BaseFloat lattice_beam; + int32 prune_interval; + bool determinize_lattice; // not inspected by this class... used in + // command-line program. + BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat hash_ratio; + BaseFloat prune_scale; // Note: we don't make this configurable on the command line, + // it's not a very important parameter. 
It affects the + // algorithm that prunes the tokens as we go. + // Most of the options inside det_opts are not actually queried by the + // LatticeFasterDecoder class itself, but by the code that calls it, for + // example in the function DecodeUtteranceLatticeFaster. + fst::DeterminizeLatticePhonePrunedOptions det_opts; + + LatticeFasterDecoderCombineConfig(): beam(16.0), + max_active(std::numeric_limits::max()), + min_active(200), + lattice_beam(10.0), + prune_interval(25), + determinize_lattice(true), + beam_delta(0.5), + hash_ratio(2.0), + prune_scale(0.1) { } + void Register(OptionsItf *opts) { + det_opts.Register(opts); + opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); + opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; " + "more accurate"); + opts->Register("min-active", &min_active, "Decoder minimum #active states."); + opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam. Larger->slower, " + "and deeper lattices"); + opts->Register("prune-interval", &prune_interval, "Interval (in frames) at " + "which to prune tokens"); + opts->Register("determinize-lattice", &determinize_lattice, "If true, " + "determinize the lattice (lattice-determinization, keeping only " + "best pdf-sequence for each word-sequence)."); + opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this " + "parameter is obscure and relates to a speedup in the way the " + "max-active constraint is applied. Larger is more accurate."); + opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to " + "control hash behavior"); + } + void Check() const { + KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 + && min_active <= max_active + && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0 + && prune_scale > 0.0 && prune_scale < 1.0); + } +}; + + +namespace decodercombine { +// We will template the decoder on the token type as well as the FST type; this +// is a mechanism so that we can use the same underlying decoder code for +// versions of the decoder that support quickly getting the best path +// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also +// those that do not (LatticeFasterDecoder). + + +// ForwardLinks are the links from a token to a token on the next frame. +// or sometimes on the current frame (for input-epsilon links). +template +struct ForwardLink { + using Label = fst::StdArc::Label; + + Token *next_tok; // the next token [or NULL if represents final-state] + Label ilabel; // ilabel on arc + Label olabel; // olabel on arc + BaseFloat graph_cost; // graph cost of traversing arc (contains LM, etc.) + BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing arc + ForwardLink *next; // next in singly-linked list of forward arcs (arcs + // in the state-level lattice) from a token. + inline ForwardLink(Token *next_tok, Label ilabel, Label olabel, + BaseFloat graph_cost, BaseFloat acoustic_cost, + ForwardLink *next): + next_tok(next_tok), ilabel(ilabel), olabel(olabel), + graph_cost(graph_cost), acoustic_cost(acoustic_cost), + next(next) { } +}; + + +struct StdToken { + using ForwardLinkT = ForwardLink; + using Token = StdToken; + + // Standard token type for LatticeFasterDecoder. Each active HCLG + // (decoding-graph) state on each frame has one token. + + // tot_cost is the total (LM + acoustic) cost from the beginning of the + // utterance up to this point. 
(but see cost_offset_, which is subtracted
+ // to keep it in a good numerical range).
+ BaseFloat tot_cost;
+
+ // extra_cost is >= 0. After calling PruneForwardLinks, this equals
+ // the minimum difference between the cost of the best path that this token
+ // is on, and the cost of the absolute best path, under the assumption
+ // that any of the currently active states at the decoding front may
+ // eventually succeed (e.g. if you were to take the currently active states
+ // one by one and compute this difference, and then take the minimum).
+ BaseFloat extra_cost;
+
+ // 'links' is the head of the singly-linked list of ForwardLinks, which is
+ // what we use for lattice generation.
+ ForwardLinkT *links;
+
+ // 'next' is the next in the singly-linked list of tokens for this frame.
+ Token *next;
+
+ // Identifies whether the token is in the current queue or not, to prevent
+ // duplication in function ProcessForFrame().
+ bool in_current_queue;
+
+ // This function does nothing and should be optimized out; it's needed
+ // so we can share the regular LatticeFasterDecoderTpl code and the code
+ // for LatticeFasterOnlineDecoder that supports fast traceback.
+ inline void SetBackpointer (Token *backpointer) { }
+
+ // This constructor just ignores the 'backpointer' argument. That argument is
+ // needed so that we can use the same decoder code for LatticeFasterDecoderTpl
+ // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
+ // fast way to obtain the best path).
+ inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
+ Token *next, Token *backpointer):
+ tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next),
+ in_current_queue(false) { }
+};
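
For illustration, a minimal, self-contained sketch (not part of this patch;
"SimpleTok" and "SimpleLink" are simplified non-template stand-ins for the
structs here) of the two singly linked lists a token participates in: the
per-frame token list threaded through 'next', and the token's own
forward-link list threaded through 'links':

#include <cstddef>

struct SimpleTok;
struct SimpleLink {           // cut-down ForwardLink: destination + next
  SimpleTok *next_tok;
  SimpleLink *next;
};
struct SimpleTok {            // cut-down token: cost + the two list pointers
  float tot_cost;
  SimpleLink *links;
  SimpleTok *next;
};

// Mirrors DeleteForwardLinks(): free the whole 'links' chain of one token.
static void DeleteLinks(SimpleTok *tok) {
  for (SimpleLink *l = tok->links; l != NULL; ) {
    SimpleLink *m = l->next;
    delete l;
    l = m;
  }
  tok->links = NULL;
}

int main() {
  // New tokens are pushed onto the head of the frame's list, so the list is
  // in reverse order of creation (one reason TopSortTokens exists).
  SimpleTok *head = NULL;
  for (int i = 0; i < 3; i++)
    head = new SimpleTok{0.5f * i, NULL, head};
  // One forward link from the head token to the next token on this frame.
  head->links = new SimpleLink{head->next, head->links};
  DeleteLinks(head);
  for (SimpleTok *t = head; t != NULL; ) {
    SimpleTok *n = t->next;
    delete t;
    t = n;
  }
  return 0;
}
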
+
+struct BackpointerToken {
+ using ForwardLinkT = ForwardLink<BackpointerToken>;
+ using Token = BackpointerToken;
+
+ // BackpointerToken is like StdToken, but also stores a backpointer to the
+ // best preceding token. Each active HCLG (decoding-graph) state on each
+ // frame has one token.
+
+ // tot_cost is the total (LM + acoustic) cost from the beginning of the
+ // utterance up to this point. (but see cost_offset_, which is subtracted
+ // to keep it in a good numerical range).
+ BaseFloat tot_cost;
+
+ // extra_cost is >= 0. After calling PruneForwardLinks, this equals
+ // the minimum difference between the cost of the best path that this token
+ // is on, and the cost of the absolute best path, under the assumption
+ // that any of the currently active states at the decoding front may
+ // eventually succeed (e.g. if you were to take the currently active states
+ // one by one and compute this difference, and then take the minimum).
+ BaseFloat extra_cost;
+
+ // 'links' is the head of the singly-linked list of ForwardLinks, which is
+ // what we use for lattice generation.
+ ForwardLinkT *links;
+
+ // 'next' is the next in the singly-linked list of tokens for this frame.
+ BackpointerToken *next;
+
+ // Best preceding BackpointerToken (could be on this frame, connected to
+ // this via an epsilon transition, or on a previous frame). This is only
+ // required for an efficient GetBestPath function in
+ // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
+ // (the "links" list is what stores the forward links, for that).
+ Token *backpointer;
+
+ // Identifies whether the token is in the current queue or not, to prevent
+ // duplication in function ProcessForFrame().
+ bool in_current_queue;
+
+ inline void SetBackpointer (Token *backpointer) {
+ this->backpointer = backpointer;
+ }
+
+ inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
+ Token *next, Token *backpointer):
+ tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next),
+ backpointer(backpointer), in_current_queue(false) { }
+};
+
+} // namespace decodercombine
+
+
+/** This is the "normal" lattice-generating decoder.
+ See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
+ for more information.
+
+ The decoder is templated on the FST type and the token type. The token type
+ will normally be StdToken, but may also be BackpointerToken, which supports
+ quick lookup of the current best path (see lattice-faster-online-decoder.h).
+
+ The FST you invoke this decoder with is expected to be of type
+ fst::Fst<fst::StdArc>, a.k.a. StdFst, or GrammarFst. If you invoke it with
+ FST == StdFst and it notices that the actual FST type is
+ fst::VectorFst<fst::StdArc> or fst::ConstFst<fst::StdArc>, the decoder object
+ will internally cast itself to one that is templated on those more specific
+ types; this is an optimization for speed.
+ */
+template <typename FST, typename Token = decodercombine::StdToken>
+class LatticeFasterDecoderCombineTpl {
+ public:
+ using Arc = typename FST::Arc;
+ using Label = typename Arc::Label;
+ using StateId = typename Arc::StateId;
+ using Weight = typename Arc::Weight;
+ using ForwardLinkT = decodercombine::ForwardLink<Token>;
+
+ using StateIdToTokenMap = HashList<StateId, Token*>;
+ using Elem = typename HashList<StateId, Token*>::Elem;
+ //using StateIdToTokenMap = typename std::unordered_map<StateId, Token*>;
+ //using StateIdToTokenMap = typename std::unordered_map<StateId, Token*,
+ // std::hash<StateId>, std::equal_to<StateId>,
+ // fst::PoolAllocator<std::pair<const StateId, Token*> > >;
+ //using IterType = typename StateIdToTokenMap::const_iterator;
+
+ // Instantiate this class once for each thing you have to decode.
+ // This version of the constructor does not take ownership of
+ // 'fst'.
+ LatticeFasterDecoderCombineTpl(const FST &fst,
+ const LatticeFasterDecoderCombineConfig &config);
+
+ // This version of the constructor takes ownership of the fst, and will delete
+ // it when this object is destroyed.
+ LatticeFasterDecoderCombineTpl(const LatticeFasterDecoderCombineConfig &config,
+ FST *fst);
+
+ void SetOptions(const LatticeFasterDecoderCombineConfig &config) {
+ config_ = config;
+ }
+
+ const LatticeFasterDecoderCombineConfig &GetOptions() const {
+ return config_;
+ }
+
+ ~LatticeFasterDecoderCombineTpl();
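
As a point of reference, the intended calling pattern for this class looks
roughly like the hypothetical driver below (not part of this patch;
"hclg_fst" and "decodable" stand for an already-built decoding graph and an
acoustic-score source, and the beam values are only typical settings):

#include "decoder/lattice-faster-decoder-combine.h"

void DecodeOneUtterance(const fst::Fst<fst::StdArc> &hclg_fst,
                        kaldi::DecodableInterface *decodable,
                        kaldi::Lattice *best_path) {
  kaldi::LatticeFasterDecoderCombineConfig config;
  config.beam = 13.0;         // typical decoding beam
  config.lattice_beam = 6.0;  // typical lattice-generation beam
  kaldi::LatticeFasterDecoderCombine decoder(hclg_fst, config);
  decoder.Decode(decodable);
  if (!decoder.ReachedFinal())
    KALDI_WARN << "No final state reached; traceback may be partial.";
  decoder.GetBestPath(best_path, true /* use_final_probs */);
}
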
+
+ /// Decodes until there are no more frames left in the "decodable" object.
+ /// Note: this may block waiting for input if the "decodable" object blocks.
+ /// Returns true if any kind of traceback is available (not necessarily from a
+ /// final state).
+ bool Decode(DecodableInterface *decodable);
+
+
+ /// Says whether a final-state was active on the last frame. If it was not, the
+ /// lattice (or traceback) will end with states that are not final-states.
+ bool ReachedFinal() const {
+ return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity();
+ }
+
+ /// Outputs an FST corresponding to the single best path through the lattice.
+ /// Returns true if result is nonempty (using the return status is deprecated,
+ /// it will become void). If "use_final_probs" is true AND we reached the
+ /// final-state of the graph then it will include those as final-probs, else
+ /// it will treat all final-probs as one. Note: this just calls GetRawLattice()
+ /// and figures out the shortest path.
+ bool GetBestPath(Lattice *ofst,
+ bool use_final_probs = true);
+
+ /// Outputs an FST corresponding to the raw, state-level
+ /// tracebacks. Returns true if result is nonempty.
+ /// If "use_final_probs" is true AND we reached the final-state
+ /// of the graph then it will include those as final-probs, else
+ /// it will treat all final-probs as one.
+ /// The raw lattice will be topologically sorted.
+ /// The function can be called during decoding; it will process non-emitting
+ /// arcs from the "cur_toks_" map so that the raw lattice contains tokens
+ /// reached by both non-emitting and emitting arcs, and it then recovers the
+ /// last token list to keep ProcessForFrame() consistent.
+ ///
+ /// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
+ /// which also supports a pruning beam, in case for some reason
+ /// you want it pruned tighter than the regular lattice beam.
+ /// We could add that here in future if needed.
+ bool GetRawLattice(Lattice *ofst, bool use_final_probs = true);
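
Since GetLattice() below is deprecated in favor of external determinization,
a hypothetical helper along the following lines shows the recommended flow
(a sketch, not part of this patch; it uses the default determinization
options rather than config_.det_opts):

#include "decoder/lattice-faster-decoder-combine.h"
#include "hmm/transition-model.h"
#include "lat/determinize-lattice-pruned.h"

bool GetDeterminizedLattice(kaldi::LatticeFasterDecoderCombine *decoder,
                            const kaldi::TransitionModel &trans_model,
                            kaldi::BaseFloat lattice_beam,
                            kaldi::CompactLattice *clat) {
  kaldi::Lattice raw_lat;
  if (!decoder->GetRawLattice(&raw_lat, true)) return false;
  fst::DeterminizeLatticePhonePrunedOptions det_opts;
  // Phone-pruned determinization keeps the best pdf-sequence per word
  // sequence; see lat/determinize-lattice-pruned.h for the options.
  return fst::DeterminizeLatticePhonePrunedWrapper(
      trans_model, &raw_lat, lattice_beam, clat, det_opts);
}
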
+
+
+ /// [Deprecated, users should now use GetRawLattice and determinize it
+ /// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper].
+ /// Outputs an FST corresponding to the lattice-determinized
+ /// lattice (one path per word sequence). Returns true if result is nonempty.
+ /// If "use_final_probs" is true AND we reached the final-state of the graph
+ /// then it will include those as final-probs, else it will treat all
+ /// final-probs as one.
+ bool GetLattice(CompactLattice *ofst,
+ bool use_final_probs = true);
+
+ /// InitDecoding initializes the decoding, and should only be used if you
+ /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to
+ /// call this. You can also call InitDecoding if you have already decoded an
+ /// utterance and want to start with a new utterance.
+ void InitDecoding();
+
+ /// This will decode until there are no more frames ready in the decodable
+ /// object. You can keep calling it each time more frames become available.
+ /// If max_num_frames is specified, it specifies the maximum number of frames
+ /// the function will decode before returning.
+ void AdvanceDecoding(DecodableInterface *decodable,
+ int32 max_num_frames = -1);
+
+ /// This function may be optionally called after AdvanceDecoding(), when you
+ /// do not plan to decode any further. It does an extra pruning step that
+ /// will help to prune the lattices output by GetLattice and (particularly)
+ /// GetRawLattice more accurately, particularly toward the end of the
+ /// utterance. It does this by using the final-probs in pruning (if any
+ /// final-state survived); it also does a final pruning step that visits all
+ /// states (the pruning that is done during decoding may fail to prune states
+ /// that are within kPruningScale = 0.1 outside of the beam). If you call
+ /// this, you cannot call AdvanceDecoding again (it will fail), and you
+ /// cannot call GetLattice() and related functions with use_final_probs =
+ /// false.
+ /// Used to be called PruneActiveTokensFinal().
+ void FinalizeDecoding();
+
+ /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives
+ /// more information. It returns the difference between the best (final-cost
+ /// plus cost) of any token on the final frame, and the best cost of any token
+ /// on the final frame. If it is infinity it means no final-states were
+ /// present on the final frame. It will usually be nonnegative. If it is not
+ /// too positive (e.g. < 5 is my first guess, but this is not tested) you can
+ /// take it as a good indication that we reached the final-state with
+ /// reasonable likelihood.
+ BaseFloat FinalRelativeCost() const;
+
+
+ // Returns the number of frames decoded so far. The value returned changes
+ // whenever we call ProcessForFrame().
+ inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; }
+
+ protected:
+ // We make things protected instead of private, as code in
+ // LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the
+ // internals.
+
+ // Deletes the elements of the singly linked list tok->links.
+ inline static void DeleteForwardLinks(Token *tok);
+
+ // Head of the per-frame list of Tokens (the list is in topological order),
+ // and flags saying whether we ever pruned it using PruneForwardLinks /
+ // PruneTokensForFrame.
+ struct TokenList {
+ Token *toks;
+ bool must_prune_forward_links;
+ bool must_prune_tokens;
+ TokenList(): toks(NULL), must_prune_forward_links(true),
+ must_prune_tokens(true) { }
+ };
+
+ // FindOrAddToken either locates a token in hash map "token_map", or if
+ // necessary inserts a new, empty token (i.e. with no forward links) for the
+ // current frame. [Note: it's inserted if necessary into the hash map and
+ // also into the singly linked list of tokens active on this frame (whose
+ // head is at active_toks_[frame]).] The frame_plus_one argument is the
+ // acoustic frame index plus one, which is used to index into the
+ // active_toks_ array. Returns the Token pointer. Sets "changed" (if
+ // non-NULL) to true if the token was newly created or the cost changed.
+ // If Token == StdToken, the 'backpointer' argument has no purpose (and will
+ // hopefully be optimized out).
+ inline Token *FindOrAddToken(StateId state, int32 frame_plus_one,
+ BaseFloat tot_cost, Token *backpointer,
+ StateIdToTokenMap *token_map,
+ bool *changed);
+
+ // Prunes outgoing links for all tokens in active_toks_[frame].
+ // It's called by PruneActiveTokens.
+ // All links that have link_extra_cost > lattice_beam are pruned.
+ // delta is the amount by which the extra_costs must change
+ // before we set *extra_costs_changed = true.
+ // If delta is larger, we'll tend to go back less far
+ // toward the beginning of the file.
+ // extra_costs_changed is set to true if extra_cost was changed for any token;
+ // links_pruned is set to true if any link in any token was pruned.
+ void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
+ bool *links_pruned,
+ BaseFloat delta);
+
+ // This function computes the final-costs for tokens active on the final
+ // frame. It outputs to final-costs, if non-NULL, a map from the Token*
+ // pointer to the final-prob of the corresponding state, for all Tokens
+ // that correspond to states that have final-probs. This map will be
+ // empty if there were no final-probs. It outputs to
+ // final_relative_cost, if non-NULL, the difference between the best
+ // forward-cost including the final-prob cost, and the best forward-cost
+ // without including the final-prob cost (this will usually be positive), or
+ // infinity if there were no final-probs. [c.f. FinalRelativeCost(), which
+ // outputs this quantity]. It outputs to final_best_cost, if
+ // non-NULL, the lowest, for any token t active on the final frame, of
+ // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in
+ // the graph of the state corresponding to token t, or the best of
+ // forward-cost[t] if there were no final-probs active on the final frame.
+ // You cannot call this after FinalizeDecoding() has been called; in that
+ // case you should get the answer from class-member variables.
+ void ComputeFinalCosts(unordered_map<Token*, BaseFloat> *final_costs,
+ BaseFloat *final_relative_cost,
+ BaseFloat *final_best_cost) const;
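
The relative-cost quantity described above can be illustrated on a
simplified token list (a sketch, not the decoder's actual types; the real
code also handles cost offsets and the empty-map case):

#include <algorithm>
#include <cstddef>
#include <limits>

// Stand-in for a frame's token list; 'final_cost' is the final-weight of
// the corresponding HCLG state (infinity if the state is not final).
struct FinalDemoTok { float tot_cost; float final_cost; FinalDemoTok *next; };

// Best (forward-cost + final-cost) minus best forward-cost over the final
// frame; infinity means no token reached a final state. Assumes a
// non-empty list.
float FinalRelativeCostDemo(const FinalDemoTok *toks) {
  const float inf = std::numeric_limits<float>::infinity();
  float best_cost = inf, best_cost_with_final = inf;
  for (const FinalDemoTok *t = toks; t != NULL; t = t->next) {
    best_cost = std::min(best_cost, t->tot_cost);
    best_cost_with_final =
        std::min(best_cost_with_final, t->tot_cost + t->final_cost);
  }
  return best_cost_with_final - best_cost;
}
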
+
+ // PruneForwardLinksFinal is a version of PruneForwardLinks that we call
+ // on the final frame. If there are final tokens active, it uses
+ // the final-probs for pruning, otherwise it treats all tokens as final.
+ void PruneForwardLinksFinal();
+
+ // Prune away any tokens on this frame that have no forward links.
+ // [We don't do this in PruneForwardLinks because it would give us
+ // a problem with dangling pointers.]
+ // It's called by PruneActiveTokens if any forward links have been pruned.
+ void PruneTokensForFrame(int32 frame_plus_one);
+
+
+ // Go backwards through still-alive tokens, pruning them if the
+ // forward+backward cost is more than lat_beam away from the best path. It's
+ // possible to prove that this is "correct" in the sense that we won't lose
+ // anything outside of lat_beam, regardless of what happens in the future.
+ // delta controls when it considers a cost to have changed enough to continue
+ // going backward and propagating the change. Larger delta -> will recurse
+ // less far.
+ void PruneActiveTokens(BaseFloat delta);
+
+ /// Processes non-emitting (epsilon) arcs and emitting arcs for one frame
+ /// together. It takes the emitting tokens in "prev_toks_" from the last
+ /// frame, and generates non-emitting tokens for the previous frame and
+ /// emitting tokens for the next frame.
+ /// Notice: "the emitting tokens for the current frame" means the tokens
+ /// that take the acoustic scores of the current frame (i.e. the
+ /// destinations of emitting arcs).
+ void ProcessForFrame(DecodableInterface *decodable);
+
+ /// Processes nonemitting (epsilon) arcs for one frame.
+ /// Call this function once after all frames have been processed, or call it
+ /// in GetRawLattice() to generate the complete token list for the last
+ /// frame. [It deals with the tokens in map "cur_toks_", which would
+ /// otherwise only contain tokens reached by emitting arcs from the previous
+ /// frame.]
+ /// If the map "token_orig_cost" isn't NULL, we build the map which will
+ /// be used to recover the "active_toks_[last_frame]" token list for the
+ /// last frame.
+ void ProcessNonemitting(std::unordered_map<Token*, BaseFloat> *token_orig_cost);
+
+ /// When GetRawLattice() is called during decoding, the
+ /// active_toks_[last_frame] list is changed. To keep ProcessForFrame()
+ /// consistent, recover it afterwards.
+ /// Notice: as new tokens are added to the head of the TokenList, tok->next
+ /// will not be affected.
+ /// "token_orig_cost" is a mapping from token pointer to the tot_cost of the
+ /// token before propagating non-emitting arcs. It is used to undo the
+ /// changes to the original tokens on the last frame and to remove the new
+ /// tokens that come from propagating non-emitting arcs, so that we can
+ /// guarantee the consistency of function ProcessForFrame().
+ void RecoverLastTokenList(
+ const std::unordered_map<Token*, BaseFloat> &token_orig_cost);
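
A sketch of the save/recover idiom this pair of functions implements
(simplified and generic, not the patch's implementation: the real
RecoverLastTokenList also exploits the fact that snapshot-created tokens sit
at the head of the list, and it maintains num_toks_ and forward links):

#include <cstddef>
#include <unordered_map>

struct Tok { float tot_cost; Tok *next; };

// 'orig_cost' records each pre-existing token's cost before epsilon
// expansion; tokens absent from the map were created only for the lattice
// snapshot and are unlinked and freed.
void Recover(Tok **head, const std::unordered_map<Tok*, float> &orig_cost) {
  Tok **cur = head;
  while (*cur != NULL) {
    std::unordered_map<Tok*, float>::const_iterator it = orig_cost.find(*cur);
    if (it != orig_cost.end()) {  // pre-existing token: restore its cost
      (*cur)->tot_cost = it->second;
      cur = &(*cur)->next;
    } else {                      // snapshot-only token: unlink and free
      Tok *dead = *cur;
      *cur = dead->next;
      delete dead;
    }
  }
}
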
+
+
+ /// The "prev_toks_" and "cur_toks_" maps let us maintain the token lists
+ /// for the previous and the current frame; they are indexed by StateId.
+ /// As with active_toks_, token lists are indexed by frame-index plus one,
+ /// where the frame-index is zero-based, as used in the decodable object:
+ /// the emitting probs of frame t are accounted for in the tokens of list
+ /// t+1, and the zeroth list is for the nonemitting transitions at the
+ /// start of the graph.
+ StateIdToTokenMap *prev_toks_;
+ StateIdToTokenMap *cur_toks_;
+
+ void PossiblyResizeHash(size_t num_toks);
+
+ /// Gets the weight cutoff.
+ /// Notice: in the traditional version, the histogram pruning method is
+ /// applied to a complete token list for one frame. But in this version, it
+ /// is applied to a token list which only contains the emitting part, so the
+ /// max_active and min_active values might need to be narrowed.
+ BaseFloat GetCutoff(const Elem *list_head, size_t *tok_count,
+ BaseFloat *adaptive_beam, Elem **best_elem);
+
+
+ std::vector<TokenList> active_toks_; // Lists of tokens, indexed by
+ // frame (members of TokenList are toks, must_prune_forward_links,
+ // must_prune_tokens).
+ std::queue<StateId> cur_queue_; // temp variable used in ProcessForFrame
+ // and ProcessNonemitting
+ std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
+
+ // fst_ is a pointer to the FST we are decoding from.
+ const FST *fst_;
+ // delete_fst_ is true if the pointer fst_ needs to be deleted when this
+ // object is destroyed.
+ bool delete_fst_;
+
+ std::vector<BaseFloat> cost_offsets_; // This contains, for each
+ // frame, an offset that was added to the acoustic log-likelihoods on that
+ // frame in order to keep everything in a nice dynamic range i.e. close to
+ // zero, to reduce roundoff errors.
+ // Notice: it will only be added to emitting arcs (i.e. cost_offsets_[t] is
+ // added to arcs from "frame t" to "frame t+1").
+ LatticeFasterDecoderCombineConfig config_;
+ int32 num_toks_; // current total #toks allocated...
+ bool warned_;
+
+ /// decoding_finalized_ is true if someone called FinalizeDecoding(). [Note,
+ /// calling this is optional.] If true, it's forbidden to decode more. Also,
+ /// if this is set, then the output of ComputeFinalCosts() is in the next
+ /// three variables. The reason we need to do this is that after
+ /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some
+ /// of the tokens on the last frame are freed, so we free the list from toks_
+ /// to avoid having dangling pointers hanging around.
+ bool decoding_finalized_;
+ /// For the meaning of the next 3 variables, see the comment for
+ /// decoding_finalized_ above, and ComputeFinalCosts().
+ unordered_map<Token*, BaseFloat> final_costs_;
+ BaseFloat final_relative_cost_;
+ BaseFloat final_best_cost_;
+
+ // This function takes a singly linked list of tokens for a single frame, and
+ // outputs a list of them in topological order (it will crash if no such order
+ // can be found, which will typically be due to decoding graphs with epsilon
+ // cycles, which are not allowed). Note: the output list may contain NULLs,
+ // which the caller should pass over; it just happens to be more efficient for
+ // the algorithm to output a list that contains NULLs.
+ static void TopSortTokens(Token *tok_list,
+ std::vector<Token*> *topsorted_list);
+
+ void DeleteElems(Elem *list, HashList<StateId, Token*> *toks);
+ void ClearActiveTokens();
+
+ KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderCombineTpl);
+};
+
+typedef LatticeFasterDecoderCombineTpl<fst::StdFst, decodercombine::StdToken>
+ LatticeFasterDecoderCombine;
+
+
+
+} // end namespace kaldi.
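
To make the cost_offsets_ comment above concrete, here is a small
stand-alone sketch of the bookkeeping (illustrative only, not the decoder
itself; it matches the expressions visible in ProcessForFrame, where
"ac_cost = cost_offset - decodable->LogLikelihood(frame, arc.ilabel)", and
in GetRawLattice, where "l->acoustic_cost - cost_offset" is written out):

#include <cstddef>
#include <vector>

struct OffsetDemo {
  std::vector<float> cost_offsets_;

  // The offset is chosen so the best token's cost stays near zero; minus
  // the best token cost on the frame being processed is one such choice.
  void SetOffset(int frame, float best_tok_cost) {
    if (cost_offsets_.size() <= static_cast<size_t>(frame))
      cost_offsets_.resize(frame + 1, 0.0f);
    cost_offsets_[frame] = -best_tok_cost;
  }

  // Acoustic cost as used during the search on frame 'frame'.
  float SearchAcCost(int frame, float loglike) const {
    return cost_offsets_[frame] - loglike;
  }

  // Cost written into a lattice arc, with the offset removed again.
  float LatticeAcCost(int frame, float stored_ac_cost) const {
    return stored_ac_cost - cost_offsets_[frame];
  }
};
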
+ +#endif From 2a0846347cfd64a6b074cc5d2324b2e7b67b3c48 Mon Sep 17 00:00:00 2001 From: Hang Lyu Date: Wed, 13 Mar 2019 18:19:52 -0400 Subject: [PATCH 12/29] iterator singly-list --- src/decoder/lattice-faster-decoder-combine.cc | 89 ++++++++++--------- src/decoder/lattice-faster-decoder-combine.h | 44 +++++---- 2 files changed, 78 insertions(+), 55 deletions(-) diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc index b788f6505e6..fbb67729828 100644 --- a/src/decoder/lattice-faster-decoder-combine.cc +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -67,7 +67,7 @@ void LatticeFasterDecoderCombineTpl::InitDecoding() { StateId start_state = fst_->Start(); KALDI_ASSERT(start_state != fst::kNoStateId); active_toks_.resize(1); - Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL); + Token *start_tok = new Token(0.0, 0.0, start_state, NULL, NULL, NULL); active_toks_[0].toks = start_tok; cur_toks_[start_state] = start_tok; // initialize current tokens map num_toks_++; @@ -295,19 +295,20 @@ bool LatticeFasterDecoderCombineTpl::GetLattice( // (whose head is at active_toks_[frame]). template inline Token* LatticeFasterDecoderCombineTpl::FindOrAddToken( - StateId state, int32 frame_plus_one, BaseFloat tot_cost, Token *backpointer, - StateIdToTokenMap *token_map, bool *changed) { + StateId state, int32 token_list_index, BaseFloat tot_cost, + Token *backpointer, StateIdToTokenMap *token_map, bool *changed) { // Returns the Token pointer. Sets "changed" (if non-NULL) to true // if the token was newly created or the cost changed. - KALDI_ASSERT(frame_plus_one < active_toks_.size()); - Token *&toks = active_toks_[frame_plus_one].toks; + KALDI_ASSERT(token_list_index < active_toks_.size()); + Token *&toks = active_toks_[token_list_index].toks; typename StateIdToTokenMap::iterator e_found = token_map->find(state); if (e_found == token_map->end()) { // no such token presently. const BaseFloat extra_cost = 0.0; // tokens on the currently final frame have zero extra_cost // as any of them could end up // on the winning path. - Token *new_tok = new Token (tot_cost, extra_cost, NULL, toks, backpointer); + Token *new_tok = new Token (tot_cost, extra_cost, state, + NULL, toks, backpointer); // NULL: no forward links yet toks = new_tok; num_toks_++; @@ -586,10 +587,10 @@ void LatticeFasterDecoderCombineTpl::ComputeFinalCosts( BaseFloat best_cost = infinity, best_cost_with_final = infinity; - // The final tokens are recorded in unordered_map "cur_toks_". - for (IterType iter = cur_toks_.begin(); iter != cur_toks_.end(); iter++) { - StateId state = iter->first; - Token *tok = iter->second; + // The final tokens are recorded in active_toks_[last_frame] + for (Token *tok = active_toks_[active_toks_.size() - 1].toks; tok != NULL; + tok = tok->next) { + StateId state = tok->state_id; BaseFloat final_cost = fst_->Final(state).Value(); BaseFloat cost = tok->tot_cost, cost_with_final = cost + final_cost; @@ -683,20 +684,20 @@ void LatticeFasterDecoderCombineTpl::FinalizeDecoding() { /// Gets the weight cutoff. template BaseFloat LatticeFasterDecoderCombineTpl::GetCutoff( - const StateIdToTokenMap &toks, BaseFloat *adaptive_beam, + const TokenList &token_list, BaseFloat *adaptive_beam, StateId *best_state_id, Token **best_token) { // positive == high cost == bad. // best_weight is the minimum value. 
BaseFloat best_weight = std::numeric_limits::infinity(); if (config_.max_active == std::numeric_limits::max() && config_.min_active == 0) { - for (IterType iter = toks.begin(); iter != toks.end(); iter++) { - BaseFloat w = static_cast(iter->second->tot_cost); + for (Token* tok = token_list.toks; tok != NULL; tok = tok->next) { + BaseFloat w = static_cast(tok->tot_cost); if (w < best_weight) { best_weight = w; if (best_token) { - *best_state_id = iter->first; - *best_token = iter->second; + *best_state_id = tok->state_id; + *best_token = tok; } } } @@ -704,14 +705,14 @@ BaseFloat LatticeFasterDecoderCombineTpl::GetCutoff( return best_weight + config_.beam; } else { tmp_array_.clear(); - for (IterType iter = toks.begin(); iter != toks.end(); iter++) { - BaseFloat w = static_cast(iter->second->tot_cost); + for (Token* tok = token_list.toks; tok != NULL; tok = tok->next) { + BaseFloat w = static_cast(tok->tot_cost); tmp_array_.push_back(w); if (w < best_weight) { best_weight = w; if (best_token) { - *best_state_id = iter->first; - *best_token = iter->second; + *best_state_id = tok->state_id; + *best_token = tok; } } } @@ -778,7 +779,7 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( StateId best_tok_state_id; // "cur_cutoff" is used to constrain the epsilon emittion in current frame. // It will not be updated. - BaseFloat cur_cutoff = GetCutoff(prev_toks_, &adaptive_beam, + BaseFloat cur_cutoff = GetCutoff(active_toks_[frame], &adaptive_beam, &best_tok_state_id, &best_tok); KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is " << adaptive_beam; @@ -825,9 +826,9 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( cost_offsets_[frame] = cost_offset; // Build a queue which contains the emittion tokens from previous frame. - for (IterType iter = prev_toks_.begin(); iter != prev_toks_.end(); iter++) { - cur_queue_.push(iter->first); - iter->second->in_current_queue = true; + for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + cur_queue_.push(tok->state_id); + tok->in_current_queue = true; } // Iterator the "cur_queue_" to process non-emittion and emittion arcs in fst. @@ -898,9 +899,10 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( template void LatticeFasterDecoderCombineTpl::ProcessNonemitting( std::unordered_map *token_orig_cost) { + int32 frame = active_toks_.size() - 1; if (token_orig_cost) { // Build the elements which are used to recover - for (IterType iter = cur_toks_.begin(); iter != cur_toks_.end(); iter++) { - (*token_orig_cost)[iter->second] = iter->second->tot_cost; + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + (*token_orig_cost)[tok] = tok->tot_cost; } } @@ -913,19 +915,18 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( tmp_toks = &cur_toks_; } - int32 frame = active_toks_.size() - 1; // Build the queue to process non-emitting arcs. - for (IterType iter = tmp_toks->begin(); iter != tmp_toks->end(); iter++) { - if (fst_->NumInputEpsilons(iter->first) != 0) { - cur_queue_.push(iter->first); - iter->second->in_current_queue = true; + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + if (fst_->NumInputEpsilons(tok->state_id) != 0) { + cur_queue_.push(tok->state_id); + tok->in_current_queue = true; } } // "cur_cutoff" is used to constrain the epsilon emittion in current frame. // It will not be updated. 
BaseFloat adaptive_beam; - BaseFloat cur_cutoff = GetCutoff(*tmp_toks, &adaptive_beam, NULL, NULL); + BaseFloat cur_cutoff = GetCutoff(active_toks_[frame], &adaptive_beam, NULL, NULL); while (!cur_queue_.empty()) { StateId state = cur_queue_.front(); @@ -1086,15 +1087,23 @@ void LatticeFasterDecoderCombineTpl::TopSortTokens( // Instantiate the template for the combination of token types and FST types // that we'll need. -template class LatticeFasterDecoderCombineTpl, decodercombine::StdToken>; -template class LatticeFasterDecoderCombineTpl, decodercombine::StdToken >; -template class LatticeFasterDecoderCombineTpl, decodercombine::StdToken >; -template class LatticeFasterDecoderCombineTpl; - -template class LatticeFasterDecoderCombineTpl , decodercombine::BackpointerToken>; -template class LatticeFasterDecoderCombineTpl, decodercombine::BackpointerToken >; -template class LatticeFasterDecoderCombineTpl, decodercombine::BackpointerToken >; -template class LatticeFasterDecoderCombineTpl; +template class LatticeFasterDecoderCombineTpl, + decodercombine::StdToken > >; +template class LatticeFasterDecoderCombineTpl, + decodercombine::StdToken > >; +template class LatticeFasterDecoderCombineTpl, + decodercombine::StdToken > >; +template class LatticeFasterDecoderCombineTpl >; + +template class LatticeFasterDecoderCombineTpl , + decodercombine::BackpointerToken > >; +template class LatticeFasterDecoderCombineTpl, + decodercombine::BackpointerToken > >; +template class LatticeFasterDecoderCombineTpl, + decodercombine::BackpointerToken > >; +template class LatticeFasterDecoderCombineTpl >; } // end namespace kaldi. diff --git a/src/decoder/lattice-faster-decoder-combine.h b/src/decoder/lattice-faster-decoder-combine.h index 799cf20c872..900d03520e4 100644 --- a/src/decoder/lattice-faster-decoder-combine.h +++ b/src/decoder/lattice-faster-decoder-combine.h @@ -120,10 +120,11 @@ struct ForwardLink { next(next) { } }; - +template struct StdToken { using ForwardLinkT = ForwardLink; using Token = StdToken; + using StateId = typename Fst::Arc::StateId; // Standard token type for LatticeFasterDecoder. Each active HCLG // (decoding-graph) state on each frame has one token. @@ -141,6 +142,9 @@ struct StdToken { // one by one and compute this difference, and then take the minimum). BaseFloat extra_cost; + // Record the state id of the token + StateId state_id; + // 'links' is the head of singly-linked list of ForwardLinks, which is what we // use for lattice generation. ForwardLinkT *links; @@ -152,6 +156,7 @@ struct StdToken { // function ProcessOneFrame(). bool in_current_queue; + // This function does nothing and should be optimized out; it's needed // so we can share the regular LatticeFasterDecoderTpl code and the code // for LatticeFasterOnlineDecoder that supports fast traceback. @@ -161,15 +166,17 @@ struct StdToken { // needed so that we can use the same decoder code for LatticeFasterDecoderTpl // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a // fast way to obtain the best path). 
- inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links, - Token *next, Token *backpointer): - tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next), - in_current_queue(false) { } + inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, StateId state_id, + ForwardLinkT *links, Token *next, Token *backpointer): + tot_cost(tot_cost), extra_cost(extra_cost), state_id(state_id), + links(links), next(next), in_current_queue(false) { } }; +template struct BackpointerToken { using ForwardLinkT = ForwardLink; using Token = BackpointerToken; + using StateId = typename Fst::Arc::StateId; // BackpointerToken is like Token but also // Standard token type for LatticeFasterDecoder. Each active HCLG @@ -188,6 +195,9 @@ struct BackpointerToken { // one by one and compute this difference, and then take the minimum). BaseFloat extra_cost; + // Record the state id of the token + StateId state_id; + // 'links' is the head of singly-linked list of ForwardLinks, which is what we // use for lattice generation. ForwardLinkT *links; @@ -210,10 +220,12 @@ struct BackpointerToken { this->backpointer = backpointer; } - inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links, + inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost, + StateId state_id, ForwardLinkT *links, Token *next, Token *backpointer): - tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next), - backpointer(backpointer), in_current_queue(false) { } + tot_cost(tot_cost), extra_cost(extra_cost), state_id(state_id), + links(links), next(next), backpointer(backpointer), + in_current_queue(false) { } }; } // namespace decoder @@ -234,7 +246,7 @@ struct BackpointerToken { will internally cast itself to one that is templated on those more specific types; this is an optimization for speed. */ -template +template > class LatticeFasterDecoderCombineTpl { public: using Arc = typename FST::Arc; @@ -243,9 +255,10 @@ class LatticeFasterDecoderCombineTpl { using Weight = typename Arc::Weight; using ForwardLinkT = decodercombine::ForwardLink; - using StateIdToTokenMap = typename std::unordered_map, std::equal_to, - fst::PoolAllocator > >; + using StateIdToTokenMap = typename std::unordered_map; + //using StateIdToTokenMap = typename std::unordered_map, std::equal_to, + // fst::PoolAllocator > >; using IterType = typename StateIdToTokenMap::const_iterator; // Instantiate this class once for each thing you have to decode. @@ -390,7 +403,7 @@ class LatticeFasterDecoderCombineTpl { // token was newly created or the cost changed. // If Token == StdToken, the 'backpointer' argument has no purpose (and will // hopefully be optimized out). - inline Token *FindOrAddToken(StateId state, int32 frame_plus_one, + inline Token *FindOrAddToken(StateId state, int32 token_list_index, BaseFloat tot_cost, Token *backpointer, StateIdToTokenMap *token_map, bool *changed); @@ -496,7 +509,7 @@ class LatticeFasterDecoderCombineTpl { /// on a complete token list on one frame. But, in this version, it is used /// on a token list which only contains the emittion part. So the max_active /// and min_active values might be narrowed. 
- BaseFloat GetCutoff(const StateIdToTokenMap& toks, + BaseFloat GetCutoff(const TokenList &token_list, BaseFloat *adaptive_beam, StateId *best_state_id, Token **best_token); @@ -551,7 +564,8 @@ class LatticeFasterDecoderCombineTpl { KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderCombineTpl); }; -typedef LatticeFasterDecoderCombineTpl LatticeFasterDecoderCombine; +typedef LatticeFasterDecoderCombineTpl > LatticeFasterDecoderCombine; From 768fd96d7bc34dedc0a7a9c5dc384aa0fbe6897b Mon Sep 17 00:00:00 2001 From: Hang Lyu Date: Thu, 14 Mar 2019 18:19:20 -0400 Subject: [PATCH 13/29] small fix. --- ...lattice-faster-decoder-combine-hashlist.cc | 8 +- .../lattice-faster-decoder-combine-hashlist.h | 2 +- ...lattice-faster-decoder-combine-iterlist.cc | 1111 +++++++++++++++++ .../lattice-faster-decoder-combine-iterlist.h | 574 +++++++++ .../lattice-faster-decoder-combine-itermap.cc | 1100 ++++++++++++++++ .../lattice-faster-decoder-combine-itermap.h | 561 +++++++++ src/decoder/lattice-faster-decoder-combine.cc | 2 + 7 files changed, 3353 insertions(+), 5 deletions(-) create mode 100644 src/decoder/lattice-faster-decoder-combine-iterlist.cc create mode 100644 src/decoder/lattice-faster-decoder-combine-iterlist.h create mode 100644 src/decoder/lattice-faster-decoder-combine-itermap.cc create mode 100644 src/decoder/lattice-faster-decoder-combine-itermap.h diff --git a/src/decoder/lattice-faster-decoder-combine-hashlist.cc b/src/decoder/lattice-faster-decoder-combine-hashlist.cc index c0bf7fb6672..bd45a83a3c9 100644 --- a/src/decoder/lattice-faster-decoder-combine-hashlist.cc +++ b/src/decoder/lattice-faster-decoder-combine-hashlist.cc @@ -315,12 +315,12 @@ bool LatticeFasterDecoderCombineTpl::GetLattice( // (whose head is at active_toks_[frame]). template inline Token* LatticeFasterDecoderCombineTpl::FindOrAddToken( - StateId state, int32 frame_plus_one, BaseFloat tot_cost, Token *backpointer, - StateIdToTokenMap *token_map, bool *changed) { + StateId state, int32 token_list_index, BaseFloat tot_cost, + Token *backpointer, StateIdToTokenMap *token_map, bool *changed) { // Returns the Token pointer. Sets "changed" (if non-NULL) to true // if the token was newly created or the cost changed. - KALDI_ASSERT(frame_plus_one < active_toks_.size()); - Token *&toks = active_toks_[frame_plus_one].toks; + KALDI_ASSERT(token_list_index < active_toks_.size()); + Token *&toks = active_toks_[token_list_index].toks; Elem *e_found = token_map->Find(state); if (e_found == NULL) { // no such token presently. const BaseFloat extra_cost = 0.0; diff --git a/src/decoder/lattice-faster-decoder-combine-hashlist.h b/src/decoder/lattice-faster-decoder-combine-hashlist.h index e3b45f444de..ca67cf4c531 100644 --- a/src/decoder/lattice-faster-decoder-combine-hashlist.h +++ b/src/decoder/lattice-faster-decoder-combine-hashlist.h @@ -394,7 +394,7 @@ class LatticeFasterDecoderCombineTpl { // token was newly created or the cost changed. // If Token == StdToken, the 'backpointer' argument has no purpose (and will // hopefully be optimized out). 
- inline Token *FindOrAddToken(StateId state, int32 frame_plus_one, + inline Token *FindOrAddToken(StateId state, int32 token_list_index, BaseFloat tot_cost, Token *backpointer, StateIdToTokenMap *token_map, bool *changed); diff --git a/src/decoder/lattice-faster-decoder-combine-iterlist.cc b/src/decoder/lattice-faster-decoder-combine-iterlist.cc new file mode 100644 index 00000000000..5c87d72fe14 --- /dev/null +++ b/src/decoder/lattice-faster-decoder-combine-iterlist.cc @@ -0,0 +1,1111 @@ +// decoder/lattice-faster-decoder-combine.cc + +// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann +// 2013-2019 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen +// 2019 Hang Lyu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/lattice-faster-decoder-combine.h" +#include "lat/lattice-functions.h" + +namespace kaldi { + +// instantiate this class once for each thing you have to decode. +template +LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( + const FST &fst, + const LatticeFasterDecoderCombineConfig &config): + fst_(&fst), delete_fst_(false), config_(config), num_toks_(0) { + config.Check(); +} + + +template +LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( + const LatticeFasterDecoderCombineConfig &config, FST *fst): + fst_(fst), delete_fst_(true), config_(config), num_toks_(0) { + config.Check(); + prev_toks_.reserve(1000); + cur_toks_.reserve(1000); +} + + +template +LatticeFasterDecoderCombineTpl::~LatticeFasterDecoderCombineTpl() { + ClearActiveTokens(); + if (delete_fst_) delete fst_; + //prev_toks_.clear(); + //cur_toks_.clear(); +} + +template +void LatticeFasterDecoderCombineTpl::InitDecoding() { + // clean up from last time: + prev_toks_.clear(); + cur_toks_.clear(); + cost_offsets_.clear(); + ClearActiveTokens(); + + warned_ = false; + num_toks_ = 0; + decoding_finalized_ = false; + final_costs_.clear(); + StateId start_state = fst_->Start(); + KALDI_ASSERT(start_state != fst::kNoStateId); + active_toks_.resize(1); + Token *start_tok = new Token(0.0, 0.0, start_state, NULL, NULL, NULL); + active_toks_[0].toks = start_tok; + cur_toks_[start_state] = start_tok; // initialize current tokens map + num_toks_++; +} + +// Returns true if any kind of traceback is available (not necessarily from +// a final state). It should only very rarely return false; this indicates +// an unusual search error. +template +bool LatticeFasterDecoderCombineTpl::Decode(DecodableInterface *decodable) { + InitDecoding(); + + // We use 1-based indexing for frames in this decoder (if you view it in + // terms of features), but note that the decodable object uses zero-based + // numbering, which we have to correct for when we call it. 
+ + while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) { + if (NumFramesDecoded() % config_.prune_interval == 0) + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + ProcessForFrame(decodable); + } + // A complete token list of the last frame will be generated in FinalizeDecoding() + FinalizeDecoding(); + + // Returns true if we have any kind of traceback available (not necessarily + // to the end state; query ReachedFinal() for that). + return !active_toks_.empty() && active_toks_.back().toks != NULL; +} + + +// Outputs an FST corresponding to the single best path through the lattice. +template +bool LatticeFasterDecoderCombineTpl::GetBestPath( + Lattice *olat, + bool use_final_probs) { + Lattice raw_lat; + GetRawLattice(&raw_lat, use_final_probs); + ShortestPath(raw_lat, olat); + return (olat->NumStates() != 0); +} + + +// Outputs an FST corresponding to the raw, state-level lattice +template +bool LatticeFasterDecoderCombineTpl::GetRawLattice( + Lattice *ofst, + bool use_final_probs) { + typedef LatticeArc Arc; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + typedef Arc::Label Label; + // Note: you can't use the old interface (Decode()) if you want to + // get the lattice with use_final_probs = false. You'd have to do + // InitDecoding() and then AdvanceDecoding(). + if (decoding_finalized_ && !use_final_probs) + KALDI_ERR << "You cannot call FinalizeDecoding() and then call " + << "GetRawLattice() with use_final_probs == false"; + + std::unordered_map token_orig_cost; + if (!decoding_finalized_) { + // Process the non-emitting arcs for the unfinished last frame. + ProcessNonemitting(&token_orig_cost); + } + + + unordered_map final_costs_local; + + const unordered_map &final_costs = + (decoding_finalized_ ? final_costs_ : final_costs_local); + if (!decoding_finalized_ && use_final_probs) + ComputeFinalCosts(&final_costs_local, NULL, NULL); + + ofst->DeleteStates(); + // num-frames plus one (since frames are one-based, and we have + // an extra frame for the start-state). + int32 num_frames = active_toks_.size() - 1; + KALDI_ASSERT(num_frames > 0); + const int32 bucket_count = num_toks_/2 + 3; + unordered_map tok_map(bucket_count); + // First create all states. + std::vector token_list; + for (int32 f = 0; f <= num_frames; f++) { + if (active_toks_[f].toks == NULL) { + KALDI_WARN << "GetRawLattice: no tokens active on frame " << f + << ": not producing lattice.\n"; + return false; + } + TopSortTokens(active_toks_[f].toks, &token_list); + for (size_t i = 0; i < token_list.size(); i++) + if (token_list[i] != NULL) + tok_map[token_list[i]] = ofst->AddState(); + } + // The next statement sets the start state of the output FST. Because we + // topologically sorted the tokens, state zero must be the start-state. + ofst->SetStart(0); + + KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:" + << tok_map.bucket_count() << " load:" << tok_map.load_factor() + << " max:" << tok_map.max_load_factor(); + // Now create all arcs. + for (int32 f = 0; f <= num_frames; f++) { + for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { + StateId cur_state = tok_map[tok]; + for (ForwardLinkT *l = tok->links; + l != NULL; + l = l->next) { + typename unordered_map::const_iterator + iter = tok_map.find(l->next_tok); + StateId nextstate = iter->second; + KALDI_ASSERT(iter != tok_map.end()); + BaseFloat cost_offset = 0.0; + if (l->ilabel != 0) { // emitting.. 
+ KALDI_ASSERT(f >= 0 && f < cost_offsets_.size()); + cost_offset = cost_offsets_[f]; + } + Arc arc(l->ilabel, l->olabel, + Weight(l->graph_cost, l->acoustic_cost - cost_offset), + nextstate); + ofst->AddArc(cur_state, arc); + } + if (f == num_frames) { + if (use_final_probs && !final_costs.empty()) { + typename unordered_map::const_iterator + iter = final_costs.find(tok); + if (iter != final_costs.end()) + ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0)); + } else { + ofst->SetFinal(cur_state, LatticeWeight::One()); + } + } + } + } + + if (!decoding_finalized_) { // recover last token list + RecoverLastTokenList(token_orig_cost); + } + return (ofst->NumStates() > 0); +} + + +// When GetRawLattice() is called during decoding, the +// active_toks_[last_frame] is changed. To keep the consistency of function +// ProcessForFrame(), recover it. +// Notice: as new token will be added to the head of TokenList, tok->next +// will not be affacted. +template +void LatticeFasterDecoderCombineTpl::RecoverLastTokenList( + const std::unordered_map &token_orig_cost) { + if (!token_orig_cost.empty()) { + for (Token* tok = active_toks_[active_toks_.size() - 1].toks; + tok != NULL;) { + if (token_orig_cost.find(tok) != token_orig_cost.end()) { + DeleteForwardLinks(tok); + tok->tot_cost = token_orig_cost.find(tok)->second; + tok->in_current_queue = false; + tok = tok->next; + } else { + DeleteForwardLinks(tok); + Token *next_tok = tok->next; + delete tok; + num_toks_--; + tok = next_tok; + } + } + } +} + +// This function is now deprecated, since now we do determinization from outside +// the LatticeFasterDecoder class. Outputs an FST corresponding to the +// lattice-determinized lattice (one path per word sequence). +template +bool LatticeFasterDecoderCombineTpl::GetLattice( + CompactLattice *ofst, + bool use_final_probs) { + Lattice raw_fst; + GetRawLattice(&raw_fst, use_final_probs); + Invert(&raw_fst); // make it so word labels are on the input. + // (in phase where we get backward-costs). + fst::ILabelCompare ilabel_comp; + ArcSort(&raw_fst, ilabel_comp); // sort on ilabel; makes + // lattice-determinization more efficient. + + fst::DeterminizeLatticePrunedOptions lat_opts; + lat_opts.max_mem = config_.det_opts.max_mem; + + DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts); + raw_fst.DeleteStates(); // Free memory-- raw_fst no longer needed. + Connect(ofst); // Remove unreachable states... there might be + // a small number of these, in some cases. + // Note: if something went wrong and the raw lattice was empty, + // we should still get to this point in the code without warnings or failures. + return (ofst->NumStates() != 0); +} + +/* + A note on the definition of extra_cost. + + extra_cost is used in pruning tokens, to save memory. + + Define the 'forward cost' of a token as zero for any token on the frame + we're currently decoding; and for other frames, as the shortest-path cost + between that token and a token on the frame we're currently decoding. + (by "currently decoding" I mean the most recently processed frame). + + Then define the extra_cost of a token (always >= 0) as the forward-cost of + the token minus the smallest forward-cost of any token on the same frame. + + We can use the extra_cost to accurately prune away tokens that we know will + never appear in the lattice. If the extra_cost is greater than the desired + lattice beam, the token would provably never appear in the lattice, so we can + prune away the token. 
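
 A worked example of the definition (an editorial illustration with made-up
 numbers): suppose we are currently decoding frame 10, and a token T on
 frame 7 has forward-cost 12.5 (the cheapest way to reach any frame-10
 token from T), while the smallest forward-cost of any frame-7 token is
 10.0. Then extra_cost(T) = 12.5 - 10.0 = 2.5; with lattice_beam = 6.0 the
 token survives, whereas any frame-7 token with forward-cost above 16.0
 (i.e. extra_cost above 6.0) could be pruned without affecting the lattice.
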
+ + The advantage of storing the extra_cost rather than the forward-cost, is that + it is less costly to keep the extra_cost up-to-date when we process new frames. + When we process a new frame, *all* the previous frames' forward-costs would change; + but in general the extra_cost will change only for a finite number of frames. + (Actually we don't update all the extra_costs every time we update a frame; we + only do it every 'config_.prune_interval' frames). + */ + +// FindOrAddToken either locates a token in hash map "token_map" +// or if necessary inserts a new, empty token (i.e. with no forward links) +// for the current frame. [note: it's inserted if necessary into hash toks_ +// and also into the singly linked list of tokens active on this frame +// (whose head is at active_toks_[frame]). +template +inline Token* LatticeFasterDecoderCombineTpl::FindOrAddToken( + StateId state, int32 token_list_index, BaseFloat tot_cost, + Token *backpointer, StateIdToTokenMap *token_map, bool *changed) { + // Returns the Token pointer. Sets "changed" (if non-NULL) to true + // if the token was newly created or the cost changed. + KALDI_ASSERT(token_list_index < active_toks_.size()); + Token *&toks = active_toks_[token_list_index].toks; + typename StateIdToTokenMap::iterator e_found = token_map->find(state); + if (e_found == token_map->end()) { // no such token presently. + const BaseFloat extra_cost = 0.0; + // tokens on the currently final frame have zero extra_cost + // as any of them could end up + // on the winning path. + Token *new_tok = new Token (tot_cost, extra_cost, state, + NULL, toks, backpointer); + // NULL: no forward links yet + toks = new_tok; + num_toks_++; + // insert into the map + (*token_map)[state] = new_tok; + if (changed) *changed = true; + return new_tok; + } else { + Token *tok = e_found->second; // There is an existing Token for this state. + if (tok->tot_cost > tot_cost) { // replace old token + tok->tot_cost = tot_cost; + // SetBackpointer() just does tok->backpointer = backpointer in + // the case where Token == BackpointerToken, else nothing. + tok->SetBackpointer(backpointer); + // we don't allocate a new token, the old stays linked in active_toks_ + // we only replace the tot_cost + // in the current frame, there are no forward links (and no extra_cost) + // only in ProcessNonemitting we have to delete forward links + // in case we visit a state for the second time + // those forward links, that lead to this replaced token before: + // they remain and will hopefully be pruned later (PruneForwardLinks...) + if (changed) *changed = true; + } else { + if (changed) *changed = false; + } + return tok; + } +} + +// prunes outgoing links for all tokens in active_toks_[frame] +// it's called by PruneActiveTokens +// all links, that have link_extra_cost > lattice_beam are pruned +template +void LatticeFasterDecoderCombineTpl::PruneForwardLinks( + int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, BaseFloat delta) { + // delta is the amount by which the extra_costs must change + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + + *extra_costs_changed = false; + *links_pruned = false; + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen. 
+    if (!warned_) {
+      KALDI_WARN << "No tokens alive [doing pruning].. warning first "
+          "time only for each utterance\n";
+      warned_ = true;
+    }
+  }
+
+  // We have to iterate until there is no more change, because the links
+  // are not guaranteed to be in topological order.
+  bool changed = true;  // difference new minus old extra cost >= delta ?
+  while (changed) {
+    changed = false;
+    for (Token *tok = active_toks_[frame_plus_one].toks;
+         tok != NULL; tok = tok->next) {
+      ForwardLinkT *link, *prev_link = NULL;
+      // will recompute tok_extra_cost for tok.
+      BaseFloat tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
+      // tok_extra_cost is the best (min) of link_extra_cost of outgoing links
+      for (link = tok->links; link != NULL; ) {
+        // See if we need to excise this link...
+        Token *next_tok = link->next_tok;
+        BaseFloat link_extra_cost = next_tok->extra_cost +
+            ((tok->tot_cost + link->acoustic_cost + link->graph_cost)
+             - next_tok->tot_cost);  // difference in brackets is >= 0
+        // link_extra_cost is the difference in score between the best paths
+        // through link source state and through link destination state
+        KALDI_ASSERT(link_extra_cost == link_extra_cost);  // check for NaN
+        if (link_extra_cost > config_.lattice_beam) {  // excise link
+          ForwardLinkT *next_link = link->next;
+          if (prev_link != NULL) prev_link->next = next_link;
+          else tok->links = next_link;
+          delete link;
+          link = next_link;  // advance link but leave prev_link the same.
+          *links_pruned = true;
+        } else {  // keep the link and update the tok_extra_cost if needed.
+          if (link_extra_cost < 0.0) {  // this is just a precaution.
+            if (link_extra_cost < -0.01)
+              KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
+            link_extra_cost = 0.0;
+          }
+          if (link_extra_cost < tok_extra_cost)
+            tok_extra_cost = link_extra_cost;
+          prev_link = link;  // move to next link
+          link = link->next;
+        }
+      }  // for all outgoing links
+      if (fabs(tok_extra_cost - tok->extra_cost) > delta)
+        changed = true;  // difference new minus old is bigger than delta
+      tok->extra_cost = tok_extra_cost;
+      // will be +infinity or <= lattice_beam_.
+      // infinity indicates that no forward link survived pruning
+    }  // for all Tokens on active_toks_[frame]
+    if (changed) *extra_costs_changed = true;
+
+    // Note: it's theoretically possible that aggressive compiler
+    // optimizations could cause an infinite loop here for small delta and
+    // high-dynamic-range scores.
+  }  // while changed
+}
+
+// PruneForwardLinksFinal is a version of PruneForwardLinks that we call
+// on the final frame.  If there are final tokens active, it uses
+// the final-probs for pruning, otherwise it treats all tokens as final.
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::PruneForwardLinksFinal() {
+  KALDI_ASSERT(!active_toks_.empty());
+  int32 frame_plus_one = active_toks_.size() - 1;
+
+  if (active_toks_[frame_plus_one].toks == NULL)  // empty list; should not happen.
+    KALDI_WARN << "No tokens alive at end of file";
+
+  typedef typename unordered_map<Token*, BaseFloat>::const_iterator IterType;
+  ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_);
+  decoding_finalized_ = true;
+
+  // Now go through tokens on this frame, pruning forward links...  may have to
+  // iterate a few times until there is no more change, because the list is not
+  // in topological order.  This is a modified version of the code in
+  // PruneForwardLinks, but here we also take account of the final-probs.
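+  // A worked example of the pruning criterion used below (illustrative
+  // numbers only): suppose tok->tot_cost = 10.0, a link has acoustic_cost +
+  // graph_cost = 3.0, next_tok->tot_cost = 12.0, and next_tok->extra_cost
+  // = 0.5.  Then link_extra_cost = 0.5 + ((10.0 + 3.0) - 12.0) = 1.5, and
+  // the link survives iff 1.5 <= config_.lattice_beam.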
+ bool changed = true; + BaseFloat delta = 1.0e-05; + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost. It has a term in it that corresponds + // to the "final-prob", so instead of initializing tok_extra_cost to infinity + // below we set it to the difference between the (score+final_prob) of this token, + // and the best such (score+final_prob). + BaseFloat final_cost; + if (final_costs_.empty()) { + final_cost = 0.0; + } else { + IterType iter = final_costs_.find(tok); + if (iter != final_costs_.end()) + final_cost = iter->second; + else + final_cost = std::numeric_limits::infinity(); + } + BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_; + // tok_extra_cost will be a "min" over either directly being final, or + // being indirectly final through other links, and the loop below may + // decrease its value: + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; + link = link->next; + } + } + // prune away tokens worse than lattice_beam above best path. This step + // was not necessary in the non-final case because then, this case + // showed up as having no forward links. Here, the tok_extra_cost has + // an extra component relating to the final-prob. + if (tok_extra_cost > config_.lattice_beam) + tok_extra_cost = std::numeric_limits::infinity(); + // to be pruned in PruneTokensForFrame + + if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) + changed = true; + tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. + } + } // while changed +} + +template +BaseFloat LatticeFasterDecoderCombineTpl::FinalRelativeCost() const { + if (!decoding_finalized_) { + BaseFloat relative_cost; + ComputeFinalCosts(NULL, &relative_cost, NULL); + return relative_cost; + } else { + // we're not allowed to call that function if FinalizeDecoding() has + // been called; return a cached value. + return final_relative_cost_; + } +} + + +// Prune away any tokens on this frame that have no forward links. +// [we don't do this in PruneForwardLinks because it would give us +// a problem with dangling pointers]. 
+// It's called by PruneActiveTokens if any forward links have been pruned +template +void LatticeFasterDecoderCombineTpl::PruneTokensForFrame( + int32 frame_plus_one) { + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + if (toks == NULL) + KALDI_WARN << "No tokens alive [doing pruning]"; + Token *tok, *next_tok, *prev_tok = NULL; + for (tok = toks; tok != NULL; tok = next_tok) { + next_tok = tok->next; + if (tok->extra_cost == std::numeric_limits::infinity()) { + // token is unreachable from end of graph; (no forward links survived) + // excise tok from list and delete tok. + if (prev_tok != NULL) prev_tok->next = tok->next; + else toks = tok->next; + delete tok; + num_toks_--; + } else { // fetch next Token + prev_tok = tok; + } + } +} + +// Go backwards through still-alive tokens, pruning them, starting not from +// the current frame (where we want to keep all tokens) but from the frame before +// that. We go backwards through the frames and stop when we reach a point +// where the delta-costs are not changing (and the delta controls when we consider +// a cost to have "not changed"). +template +void LatticeFasterDecoderCombineTpl::PruneActiveTokens( + BaseFloat delta) { + int32 cur_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // The index "f" below represents a "frame plus one", i.e. you'd have to subtract + // one to get the corresponding index for the decodable object. + for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) { + // Reason why we need to prune forward links in this situation: + // (1) we have never pruned them (new TokenList) + // (2) we have not yet pruned the forward links to the next f, + // after any of those tokens have changed their extra_cost. 
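+    // Illustrative trace (made-up numbers): with delta = 0.05, if pruning
+    // links on frame f lowers some token's extra_cost from 1.30 to 0.90,
+    // the change (0.40) exceeds delta, so must_prune_forward_links is set
+    // for frame f-1 below and the backward sweep continues; had every
+    // change stayed below delta, the sweep would stop spreading backward.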
+ if (active_toks_[f].must_prune_forward_links) { + bool extra_costs_changed = false, links_pruned = false; + PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta); + if (extra_costs_changed && f > 0) // any token has changed extra_cost + active_toks_[f-1].must_prune_forward_links = true; + if (links_pruned) // any link was pruned + active_toks_[f].must_prune_tokens = true; + active_toks_[f].must_prune_forward_links = false; // job done + } + if (f+1 < cur_frame_plus_one && // except for last f (no forward links) + active_toks_[f+1].must_prune_tokens) { + PruneTokensForFrame(f+1); + active_toks_[f+1].must_prune_tokens = false; + } + } + KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +template +void LatticeFasterDecoderCombineTpl::ComputeFinalCosts( + unordered_map *final_costs, + BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const { + KALDI_ASSERT(!decoding_finalized_); + if (final_costs != NULL) + final_costs->clear(); + BaseFloat infinity = std::numeric_limits::infinity(); + BaseFloat best_cost = infinity, + best_cost_with_final = infinity; + + // The final tokens are recorded in active_toks_[last_frame] + for (Token *tok = active_toks_[active_toks_.size() - 1].toks; tok != NULL; + tok = tok->next) { + StateId state = tok->state_id; + BaseFloat final_cost = fst_->Final(state).Value(); + BaseFloat cost = tok->tot_cost, + cost_with_final = cost + final_cost; + best_cost = std::min(cost, best_cost); + best_cost_with_final = std::min(cost_with_final, best_cost_with_final); + if (final_costs != NULL && final_cost != infinity) + (*final_costs)[tok] = final_cost; + } + if (final_relative_cost != NULL) { + if (best_cost == infinity && best_cost_with_final == infinity) { + // Likely this will only happen if there are no tokens surviving. + // This seems the least bad way to handle it. + *final_relative_cost = infinity; + } else { + *final_relative_cost = best_cost_with_final - best_cost; + } + } + if (final_best_cost != NULL) { + if (best_cost_with_final != infinity) { // final-state exists. + *final_best_cost = best_cost_with_final; + } else { // no final-state exists. + *final_best_cost = best_cost; + } + } +} + +template +void LatticeFasterDecoderCombineTpl::AdvanceDecoding( + DecodableInterface *decodable, + int32 max_num_frames) { + if (std::is_same >::value) { + // if the type 'FST' is the FST base-class, then see if the FST type of fst_ + // is actually VectorFst or ConstFst. If so, call the AdvanceDecoding() + // function after casting *this to the more specific type. + if (fst_->Type() == "const") { + LatticeFasterDecoderCombineTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } else if (fst_->Type() == "vector") { + LatticeFasterDecoderCombineTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } + } + + + KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ && + "You must call InitDecoding() before AdvanceDecoding"); + int32 num_frames_ready = decodable->NumFramesReady(); + // num_frames_ready must be >= num_frames_decoded, or else + // the number of frames ready must have decreased (which doesn't + // make sense) or the decodable object changed between calls + // (which isn't allowed). 
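+  // A calling pattern that satisfies this contract (an illustrative sketch
+  // only; 'source' and the feeding call are hypothetical, and 'decodable'
+  // stands for any DecodableInterface implementation that buffers frames):
+  //   decoder.InitDecoding();
+  //   while (source.HasMoreChunks()) {       // hypothetical audio source
+  //     decodable.AcceptFeatures(...);       // hypothetical feeding call
+  //     decoder.AdvanceDecoding(&decodable);
+  //   }
+  //   decoder.FinalizeDecoding();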
+ KALDI_ASSERT(num_frames_ready >= NumFramesDecoded()); + int32 target_frames_decoded = num_frames_ready; + if (max_num_frames >= 0) + target_frames_decoded = std::min(target_frames_decoded, + NumFramesDecoded() + max_num_frames); + while (NumFramesDecoded() < target_frames_decoded) { + if (NumFramesDecoded() % config_.prune_interval == 0) { + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + } + ProcessForFrame(decodable); + } +} + +// FinalizeDecoding() is a version of PruneActiveTokens that we call +// (optionally) on the final frame. Takes into account the final-prob of +// tokens. This function used to be called PruneActiveTokensFinal(). +template +void LatticeFasterDecoderCombineTpl::FinalizeDecoding() { + ProcessNonemitting(NULL); + int32 final_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // PruneForwardLinksFinal() prunes final frame (with final-probs), and + // sets decoding_finalized_. + PruneForwardLinksFinal(); + for (int32 f = final_frame_plus_one - 1; f >= 0; f--) { + bool b1, b2; // values not used. + BaseFloat dontcare = 0.0; // delta of zero means we must always update + PruneForwardLinks(f, &b1, &b2, dontcare); + PruneTokensForFrame(f + 1); + } + PruneTokensForFrame(0); + KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +/// Gets the weight cutoff. +template +BaseFloat LatticeFasterDecoderCombineTpl::GetCutoff( + const TokenList &token_list, BaseFloat *adaptive_beam, + StateId *best_state_id, Token **best_token) { + // positive == high cost == bad. + // best_weight is the minimum value. + BaseFloat best_weight = std::numeric_limits::infinity(); + if (config_.max_active == std::numeric_limits::max() && + config_.min_active == 0) { + for (Token* tok = token_list.toks; tok != NULL; tok = tok->next) { + BaseFloat w = static_cast(tok->tot_cost); + if (w < best_weight) { + best_weight = w; + if (best_token) { + *best_state_id = tok->state_id; + *best_token = tok; + } + } + } + if (adaptive_beam != NULL) *adaptive_beam = config_.beam; + return best_weight + config_.beam; + } else { + tmp_array_.clear(); + for (Token* tok = token_list.toks; tok != NULL; tok = tok->next) { + BaseFloat w = static_cast(tok->tot_cost); + tmp_array_.push_back(w); + if (w < best_weight) { + best_weight = w; + if (best_token) { + *best_state_id = tok->state_id; + *best_token = tok; + } + } + } + + BaseFloat beam_cutoff = best_weight + config_.beam, + min_active_cutoff = std::numeric_limits::infinity(), + max_active_cutoff = std::numeric_limits::infinity(); + + KALDI_VLOG(6) << "Number of emitting tokens on frame " + << NumFramesDecoded() - 1 << " is " << tmp_array_.size(); + + if (tmp_array_.size() > static_cast(config_.max_active)) { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.max_active, + tmp_array_.end()); + max_active_cutoff = tmp_array_[config_.max_active]; + } + if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam. + if (adaptive_beam) + *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta; + return max_active_cutoff; + } + if (tmp_array_.size() > static_cast(config_.min_active)) { + if (config_.min_active == 0) min_active_cutoff = best_weight; + else { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.min_active, + tmp_array_.size() > static_cast(config_.max_active) ? 
+                         tmp_array_.begin() + config_.max_active : tmp_array_.end());
+        min_active_cutoff = tmp_array_[config_.min_active];
+      }
+    }
+    if (min_active_cutoff > beam_cutoff) {  // min_active is looser than beam.
+      if (adaptive_beam)
+        *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta;
+      return min_active_cutoff;
+    } else {
+      *adaptive_beam = config_.beam;
+      return beam_cutoff;
+    }
+  }
+}
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
+    DecodableInterface *decodable) {
+  KALDI_ASSERT(active_toks_.size() > 0);
+  int32 frame = active_toks_.size() - 1;  // frame is the frame-index
+                                          // (zero-based) used to get likelihoods
+                                          // from the decodable object.
+  active_toks_.resize(active_toks_.size() + 1);
+
+  prev_toks_.swap(cur_toks_);
+  cur_toks_.clear();
+  if (prev_toks_.empty()) {
+    if (!warned_) {
+      KALDI_WARN << "Error, no surviving tokens on frame " << frame;
+      warned_ = true;
+    }
+  }
+
+  BaseFloat adaptive_beam;
+  Token *best_tok = NULL;
+  StateId best_tok_state_id;
+  // "cur_cutoff" is used to constrain the epsilon (non-emitting) expansion
+  // on the current frame.  It will not be updated.
+  BaseFloat cur_cutoff = GetCutoff(active_toks_[frame], &adaptive_beam,
+                                   &best_tok_state_id, &best_tok);
+  KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is "
+                << adaptive_beam;
+
+  // pruning "online" before having seen all tokens
+
+  // "next_cutoff" is used to decide whether a new token on the next frame
+  // should be processed or not.  It will be updated as processing continues.
+  BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
+  // "cost_offset" is an offset applied to the acoustic log-likelihoods on
+  // the current frame in order to keep everything in a nice dynamic range,
+  // which reduces roundoff errors.
+  BaseFloat cost_offset = 0.0;
+
+  // First process the best token to get a hopefully
+  // reasonably tight bound on the next cutoff.  The only
+  // products of the next block are "next_cutoff" and "cost_offset".
+  // Notice: unlike the traditional version, in this combined version
+  // "best_tok" is chosen from the emitting tokens only.  Normally the best
+  // token on a frame is reached through an epsilon (non-emitting) arc, so
+  // this bound is looser.  We use it to estimate a bound on the next
+  // cutoff, and "next_cutoff" will be tightened in further processing once
+  // we have better tokens.
+  if (best_tok) {
+    cost_offset = - best_tok->tot_cost;
+    for (fst::ArcIterator<FST> aiter(*fst_, best_tok_state_id);
+         !aiter.Done();
+         aiter.Next()) {
+      const Arc &arc = aiter.Value();
+      if (arc.ilabel != 0) {  // propagate..
+        // ac_cost + graph_cost
+        BaseFloat new_weight = arc.weight.Value() + cost_offset -
+            decodable->LogLikelihood(frame, arc.ilabel) + best_tok->tot_cost;
+        if (new_weight + adaptive_beam < next_cutoff)
+          next_cutoff = new_weight + adaptive_beam;
+      }
+    }
+  }
+
+  // Store the offset on the acoustic likelihoods that we're applying.
+  // Could just do cost_offsets_.push_back(cost_offset), but we
+  // do it this way as it's more robust to future code changes.
+  cost_offsets_.resize(frame + 1, 0.0);
+  cost_offsets_[frame] = cost_offset;
+
+  // Build a queue containing the tokens on the current frame (so far these
+  // have all been generated by emitting arcs from the previous frame).
+  for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
+    cur_queue_.push(tok->state_id);
+    tok->in_current_queue = true;
+  }
+
+  // Iterate over "cur_queue_" to process both non-emitting and emitting
+  // arcs in the FST.
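+  // Worked example of the offset stored above (illustrative numbers): if
+  // the best emitting token has tot_cost = 123.4, then cost_offset = -123.4,
+  // so the tot_costs computed in the loop below stay near zero instead of
+  // growing with the utterance length; GetRawLattice() subtracts the stored
+  // offset back out when it writes acoustic costs into the lattice.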
+ while (!cur_queue_.empty()) { + StateId state = cur_queue_.front(); + cur_queue_.pop(); + + KALDI_ASSERT(prev_toks_.find(state) != prev_toks_.end()); + Token *tok = prev_toks_[state]; + + BaseFloat cur_cost = tok->tot_cost; + tok->in_current_queue = false; // out of queue + if (cur_cost > cur_cutoff) // Don't bother processing successors. + continue; + // If "tok" has any existing forward links, delete them, + // because we're about to regenerate them. This is a kind + // of non-optimality (remember, this is the simple decoder), + DeleteForwardLinks(tok); // necessary when re-visiting + tok->links = NULL; + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + bool changed; + if (arc.ilabel == 0) { // propagate nonemitting + BaseFloat graph_cost = arc.weight.Value(); + BaseFloat tot_cost = cur_cost + graph_cost; + if (tot_cost < cur_cutoff) { + Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost, + tok, &prev_toks_, &changed); + + // Add ForwardLink from tok to new_tok. Put it on the head of + // tok->link list + tok->links = new ForwardLinkT(new_tok, 0, arc.olabel, + graph_cost, 0, tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new. + if (changed && !new_tok->in_current_queue) { + cur_queue_.push(arc.nextstate); + new_tok->in_current_queue = true; + } + } + } else { // propagate emitting + BaseFloat graph_cost = arc.weight.Value(), + ac_cost = cost_offset - decodable->LogLikelihood(frame, arc.ilabel), + cur_cost = tok->tot_cost, + tot_cost = cur_cost + ac_cost + graph_cost; + if (tot_cost > next_cutoff) continue; + else if (tot_cost + adaptive_beam < next_cutoff) + next_cutoff = tot_cost + adaptive_beam; // a tighter boundary for emitting + + // no change flag is needed + Token *next_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, + tok, &cur_toks_, NULL); + // Add ForwardLink from tok to next_tok. Put it on the head of tok->link + // list + tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel, + graph_cost, ac_cost, tok->links); + } + } // for all arcs + } // end of while loop + KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() - 1 + << " is " << prev_toks_.size(); +} + + +template +void LatticeFasterDecoderCombineTpl::ProcessNonemitting( + std::unordered_map *token_orig_cost) { + int32 frame = active_toks_.size() - 1; + if (token_orig_cost) { // Build the elements which are used to recover + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + (*token_orig_cost)[tok] = tok->tot_cost; + } + } + + StateIdToTokenMap *tmp_toks; + if (token_orig_cost) { // "token_orig_cost" isn't NULL. It means we need to + // recover active_toks_[last_frame] and "cur_toks_" + // will be used in the future. + tmp_toks = new StateIdToTokenMap(cur_toks_); + } else { + tmp_toks = &cur_toks_; + } + + // Build the queue to process non-emitting arcs. + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + if (fst_->NumInputEpsilons(tok->state_id) != 0) { + cur_queue_.push(tok->state_id); + tok->in_current_queue = true; + } + } + + // "cur_cutoff" is used to constrain the epsilon emittion in current frame. + // It will not be updated. 
+ BaseFloat adaptive_beam; + BaseFloat cur_cutoff = GetCutoff(active_toks_[frame], &adaptive_beam, NULL, NULL); + + while (!cur_queue_.empty()) { + StateId state = cur_queue_.front(); + cur_queue_.pop(); + + KALDI_ASSERT(tmp_toks->find(state) != tmp_toks->end()); + Token *tok = (*tmp_toks)[state]; + BaseFloat cur_cost = tok->tot_cost; + if (cur_cost > cur_cutoff) // Don't bother processing successors. + continue; + // If "tok" has any existing forward links, delete them, + // because we're about to regenerate them. This is a kind + // of non-optimality (remember, this is the simple decoder), + DeleteForwardLinks(tok); // necessary when re-visiting + tok->links = NULL; + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + bool changed; + if (arc.ilabel == 0) { // propagate nonemitting + BaseFloat graph_cost = arc.weight.Value(); + BaseFloat tot_cost = cur_cost + graph_cost; + if (tot_cost < cur_cutoff) { + Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost, + tok, tmp_toks, &changed); + + // Add ForwardLink from tok to new_tok. Put it on the head of + // tok->link list + tok->links = new ForwardLinkT(new_tok, 0, arc.olabel, + graph_cost, 0, tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new. + if (changed && !new_tok->in_current_queue) { + cur_queue_.push(arc.nextstate); + new_tok->in_current_queue = true; + } + } + } + } // end of for loop + tok->in_current_queue = false; + } // end of while loop + if (token_orig_cost) delete tmp_toks; +} + + + +// static inline +template +void LatticeFasterDecoderCombineTpl::DeleteForwardLinks(Token *tok) { + ForwardLinkT *l = tok->links, *m; + while (l != NULL) { + m = l->next; + delete l; + l = m; + } + tok->links = NULL; +} + + +template +void LatticeFasterDecoderCombineTpl::ClearActiveTokens() { + // a cleanup routine, at utt end/begin + for (size_t i = 0; i < active_toks_.size(); i++) { + // Delete all tokens alive on this frame, and any forward + // links they may have. + for (Token *tok = active_toks_[i].toks; tok != NULL; ) { + DeleteForwardLinks(tok); + Token *next_tok = tok->next; + delete tok; + num_toks_--; + tok = next_tok; + } + } + active_toks_.clear(); + KALDI_ASSERT(num_toks_ == 0); +} + +// static +template +void LatticeFasterDecoderCombineTpl::TopSortTokens( + Token *tok_list, std::vector *topsorted_list) { + unordered_map token2pos; + typedef typename unordered_map::iterator IterType; + int32 num_toks = 0; + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + num_toks++; + int32 cur_pos = 0; + // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0. + // This is likely to be in closer to topological order than + // if we had given them ascending order, because of the way + // new tokens are put at the front of the list. + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + token2pos[tok] = num_toks - ++cur_pos; + + unordered_set reprocess; + + for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) { + Token *tok = iter->first; + int32 pos = iter->second; + for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel == 0) { + // We only need to consider epsilon links, since non-epsilon links + // transition between frames and this function only needs to sort a list + // of tokens from a single frame. 
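+        // Tiny example (illustrative): for tokens A -> B over an epsilon
+        // link with initial positions pos(A) = 2, pos(B) = 1, the code
+        // below assigns B a fresh, larger position cur_pos and marks B for
+        // reprocessing, so that eventually every epsilon link points from
+        // a smaller position to a larger one, i.e. the list is topological.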
+ IterType following_iter = token2pos.find(link->next_tok); + if (following_iter != token2pos.end()) { // another token on this frame, + // so must consider it. + int32 next_pos = following_iter->second; + if (next_pos < pos) { // reassign the position of the next Token. + following_iter->second = cur_pos++; + reprocess.insert(link->next_tok); + } + } + } + } + // In case we had previously assigned this token to be reprocessed, we can + // erase it from that set because it's "happy now" (we just processed it). + reprocess.erase(tok); + } + + size_t max_loop = 1000000, loop_count; // max_loop is to detect epsilon cycles. + for (loop_count = 0; + !reprocess.empty() && loop_count < max_loop; ++loop_count) { + std::vector reprocess_vec; + for (typename unordered_set::iterator iter = reprocess.begin(); + iter != reprocess.end(); ++iter) + reprocess_vec.push_back(*iter); + reprocess.clear(); + for (typename std::vector::iterator iter = reprocess_vec.begin(); + iter != reprocess_vec.end(); ++iter) { + Token *tok = *iter; + int32 pos = token2pos[tok]; + // Repeat the processing we did above (for comments, see above). + for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel == 0) { + IterType following_iter = token2pos.find(link->next_tok); + if (following_iter != token2pos.end()) { + int32 next_pos = following_iter->second; + if (next_pos < pos) { + following_iter->second = cur_pos++; + reprocess.insert(link->next_tok); + } + } + } + } + } + } + KALDI_ASSERT(loop_count < max_loop && "Epsilon loops exist in your decoding " + "graph (this is not allowed!)"); + + topsorted_list->clear(); + topsorted_list->resize(cur_pos, NULL); // create a list with NULLs in between. + for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) + (*topsorted_list)[iter->second] = iter->first; +} + +// Instantiate the template for the combination of token types and FST types +// that we'll need. +template class LatticeFasterDecoderCombineTpl, + decodercombine::StdToken > >; +template class LatticeFasterDecoderCombineTpl, + decodercombine::StdToken > >; +template class LatticeFasterDecoderCombineTpl, + decodercombine::StdToken > >; +template class LatticeFasterDecoderCombineTpl >; + +template class LatticeFasterDecoderCombineTpl , + decodercombine::BackpointerToken > >; +template class LatticeFasterDecoderCombineTpl, + decodercombine::BackpointerToken > >; +template class LatticeFasterDecoderCombineTpl, + decodercombine::BackpointerToken > >; +template class LatticeFasterDecoderCombineTpl >; + + +} // end namespace kaldi. diff --git a/src/decoder/lattice-faster-decoder-combine-iterlist.h b/src/decoder/lattice-faster-decoder-combine-iterlist.h new file mode 100644 index 00000000000..900d03520e4 --- /dev/null +++ b/src/decoder/lattice-faster-decoder-combine-iterlist.h @@ -0,0 +1,574 @@ +// decoder/lattice-faster-decoder-combine.h + +// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann; +// 2013-2019 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen +// 2019 Hang Lyu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ +#define KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ + + +#include "util/stl-utils.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/grammar-fst.h" +#include "decoder/lattice-faster-decoder.h" + +namespace kaldi { + +struct LatticeFasterDecoderCombineConfig { + BaseFloat beam; + int32 max_active; + int32 min_active; + BaseFloat lattice_beam; + int32 prune_interval; + bool determinize_lattice; // not inspected by this class... used in + // command-line program. + BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat hash_ratio; + BaseFloat prune_scale; // Note: we don't make this configurable on the command line, + // it's not a very important parameter. It affects the + // algorithm that prunes the tokens as we go. + // Most of the options inside det_opts are not actually queried by the + // LatticeFasterDecoder class itself, but by the code that calls it, for + // example in the function DecodeUtteranceLatticeFaster. + fst::DeterminizeLatticePhonePrunedOptions det_opts; + + LatticeFasterDecoderCombineConfig(): beam(16.0), + max_active(std::numeric_limits::max()), + min_active(200), + lattice_beam(10.0), + prune_interval(25), + determinize_lattice(true), + beam_delta(0.5), + hash_ratio(2.0), + prune_scale(0.1) { } + void Register(OptionsItf *opts) { + det_opts.Register(opts); + opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); + opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; " + "more accurate"); + opts->Register("min-active", &min_active, "Decoder minimum #active states."); + opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam. Larger->slower, " + "and deeper lattices"); + opts->Register("prune-interval", &prune_interval, "Interval (in frames) at " + "which to prune tokens"); + opts->Register("determinize-lattice", &determinize_lattice, "If true, " + "determinize the lattice (lattice-determinization, keeping only " + "best pdf-sequence for each word-sequence)."); + opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this " + "parameter is obscure and relates to a speedup in the way the " + "max-active constraint is applied. 
Larger is more accurate.");
+    opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to "
+                   "control hash behavior");
+  }
+  void Check() const {
+    KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0
+                 && min_active <= max_active
+                 && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0
+                 && prune_scale > 0.0 && prune_scale < 1.0);
+  }
+};
+
+
+namespace decodercombine {
+// We will template the decoder on the token type as well as the FST type; this
+// is a mechanism so that we can use the same underlying decoder code for
+// versions of the decoder that support quickly getting the best path
+// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also
+// those that do not (LatticeFasterDecoder).
+
+
+// ForwardLinks are the links from a token to a token on the next frame,
+// or sometimes on the current frame (for input-epsilon links).
+template <typename Token>
+struct ForwardLink {
+  using Label = fst::StdArc::Label;
+
+  Token *next_tok;  // the next token [or NULL if represents final-state]
+  Label ilabel;  // ilabel on arc
+  Label olabel;  // olabel on arc
+  BaseFloat graph_cost;  // graph cost of traversing arc (contains LM, etc.)
+  BaseFloat acoustic_cost;  // acoustic cost (pre-scaled) of traversing arc
+  ForwardLink *next;  // next in singly-linked list of forward arcs (arcs
+                      // in the state-level lattice) from a token.
+  inline ForwardLink(Token *next_tok, Label ilabel, Label olabel,
+                     BaseFloat graph_cost, BaseFloat acoustic_cost,
+                     ForwardLink *next):
+      next_tok(next_tok), ilabel(ilabel), olabel(olabel),
+      graph_cost(graph_cost), acoustic_cost(acoustic_cost),
+      next(next) { }
+};
+
+template <typename Fst>
+struct StdToken {
+  using ForwardLinkT = ForwardLink<StdToken<Fst> >;
+  using Token = StdToken<Fst>;
+  using StateId = typename Fst::Arc::StateId;
+
+  // Standard token type for LatticeFasterDecoder.  Each active HCLG
+  // (decoding-graph) state on each frame has one token.
+
+  // tot_cost is the total (LM + acoustic) cost from the beginning of the
+  // utterance up to this point.  (but see cost_offset_, which is subtracted
+  // to keep it in a good numerical range).
+  BaseFloat tot_cost;
+
+  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
+  // minimum difference between the cost of the best path that this token is
+  // part of, and the cost of the absolute best path, under the assumption
+  // that any of the currently active states at the decoding front may
+  // eventually succeed (e.g. if you were to take the currently active states
+  // one by one and compute this difference, and then take the minimum).
+  BaseFloat extra_cost;
+
+  // Record the state id of the token
+  StateId state_id;
+
+  // 'links' is the head of singly-linked list of ForwardLinks, which is what we
+  // use for lattice generation.
+  ForwardLinkT *links;
+
+  // 'next' is the next in the singly-linked list of tokens for this frame.
+  Token *next;
+
+  // Identifies whether the token is in the current queue or not, to prevent
+  // duplicate entries in ProcessForFrame().
+  bool in_current_queue;
+
+
+  // This function does nothing and should be optimized out; it's needed
+  // so we can share the regular LatticeFasterDecoderTpl code and the code
+  // for LatticeFasterOnlineDecoder that supports fast traceback.
+  inline void SetBackpointer (Token *backpointer) { }
+
+  // This constructor just ignores the 'backpointer' argument.
That argument is
+  // needed so that we can use the same decoder code for LatticeFasterDecoderTpl
+  // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
+  // fast way to obtain the best path).
+  inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, StateId state_id,
+                  ForwardLinkT *links, Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), state_id(state_id),
+      links(links), next(next), in_current_queue(false) { }
+};
+
+template <typename Fst>
+struct BackpointerToken {
+  using ForwardLinkT = ForwardLink<BackpointerToken<Fst> >;
+  using Token = BackpointerToken<Fst>;
+  using StateId = typename Fst::Arc::StateId;
+
+  // BackpointerToken is like StdToken, but also stores a backpointer to the
+  // best preceding token.  Each active HCLG (decoding-graph) state on each
+  // frame has one token.
+
+  // tot_cost is the total (LM + acoustic) cost from the beginning of the
+  // utterance up to this point.  (but see cost_offset_, which is subtracted
+  // to keep it in a good numerical range).
+  BaseFloat tot_cost;
+
+  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
+  // minimum difference between the cost of the best path that this token is
+  // part of, and the cost of the absolute best path, under the assumption
+  // that any of the currently active states at the decoding front may
+  // eventually succeed (e.g. if you were to take the currently active states
+  // one by one and compute this difference, and then take the minimum).
+  BaseFloat extra_cost;
+
+  // Record the state id of the token
+  StateId state_id;
+
+  // 'links' is the head of singly-linked list of ForwardLinks, which is what we
+  // use for lattice generation.
+  ForwardLinkT *links;
+
+  // 'next' is the next in the singly-linked list of tokens for this frame.
+  BackpointerToken *next;
+
+  // Best preceding BackpointerToken (could be a token on this frame, connected
+  // to this one via an epsilon transition, or a token on a previous frame).
+  // This is only required for an efficient GetBestPath function in
+  // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
+  // (the "links" list is what stores the forward links, for that).
+  Token *backpointer;
+
+  // Identifies whether the token is in the current queue or not, to prevent
+  // duplicate entries in ProcessForFrame().
+  bool in_current_queue;
+
+  inline void SetBackpointer (Token *backpointer) {
+    this->backpointer = backpointer;
+  }
+
+  inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost,
+                          StateId state_id, ForwardLinkT *links,
+                          Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), state_id(state_id),
+      links(links), next(next), backpointer(backpointer),
+      in_current_queue(false) { }
+};
+
+}  // namespace decodercombine
+
+
+/** This is the "normal" lattice-generating decoder.
+    See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
+    for more information.
+
+    The decoder is templated on the FST type and the token type.  The token
+    type will normally be StdToken, but may instead be BackpointerToken, which
+    supports quick lookup of the current best path (see
+    lattice-faster-online-decoder.h).
+
+    The FST you invoke this decoder with is expected to equal
+    fst::Fst<fst::StdArc>, a.k.a. StdFst, or GrammarFst.  If you invoke it with
+    FST == StdFst and it notices that the actual FST type is
+    fst::VectorFst<fst::StdArc> or fst::ConstFst<fst::StdArc>, the decoder
+    object will internally cast itself to one that is templated on those more
+    specific types; this is an optimization for speed.
+ */ +template > +class LatticeFasterDecoderCombineTpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using ForwardLinkT = decodercombine::ForwardLink; + + using StateIdToTokenMap = typename std::unordered_map; + //using StateIdToTokenMap = typename std::unordered_map, std::equal_to, + // fst::PoolAllocator > >; + using IterType = typename StateIdToTokenMap::const_iterator; + + // Instantiate this class once for each thing you have to decode. + // This version of the constructor does not take ownership of + // 'fst'. + LatticeFasterDecoderCombineTpl(const FST &fst, + const LatticeFasterDecoderCombineConfig &config); + + // This version of the constructor takes ownership of the fst, and will delete + // it when this object is destroyed. + LatticeFasterDecoderCombineTpl(const LatticeFasterDecoderCombineConfig &config, + FST *fst); + + void SetOptions(const LatticeFasterDecoderCombineConfig &config) { + config_ = config; + } + + const LatticeFasterDecoderCombineConfig &GetOptions() const { + return config_; + } + + ~LatticeFasterDecoderCombineTpl(); + + /// Decodes until there are no more frames left in the "decodable" object.. + /// note, this may block waiting for input if the "decodable" object blocks. + /// Returns true if any kind of traceback is available (not necessarily from a + /// final state). + bool Decode(DecodableInterface *decodable); + + + /// says whether a final-state was active on the last frame. If it was not, the + /// lattice (or traceback) will end with states that are not final-states. + bool ReachedFinal() const { + return FinalRelativeCost() != std::numeric_limits::infinity(); + } + + /// Outputs an FST corresponding to the single best path through the lattice. + /// Returns true if result is nonempty (using the return status is deprecated, + /// it will become void). If "use_final_probs" is true AND we reached the + /// final-state of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. Note: this just calls GetRawLattice() + /// and figures out the shortest path. + bool GetBestPath(Lattice *ofst, + bool use_final_probs = true); + + /// Outputs an FST corresponding to the raw, state-level + /// tracebacks. Returns true if result is nonempty. + /// If "use_final_probs" is true AND we reached the final-state + /// of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. + /// The raw lattice will be topologically sorted. + /// The function can be called during decoding, it will process non-emitting + /// arcs from "cur_toks_" map to get tokens from both non-emitting and + /// emitting arcs for getting raw lattice. Then recover it to ensure the + /// consistency of ProcessForFrame(). + /// + /// See also GetRawLatticePruned in lattice-faster-online-decoder.h, + /// which also supports a pruning beam, in case for some reason + /// you want it pruned tighter than the regular lattice beam. + /// We could put that here in future needed. + bool GetRawLattice(Lattice *ofst, bool use_final_probs = true); + + + + /// [Deprecated, users should now use GetRawLattice and determinize it + /// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper]. + /// Outputs an FST corresponding to the lattice-determinized + /// lattice (one path per word sequence). Returns true if result is nonempty. 
+ /// If "use_final_probs" is true AND we reached the final-state of the graph + /// then it will include those as final-probs, else it will treat all + /// final-probs as one. + bool GetLattice(CompactLattice *ofst, + bool use_final_probs = true); + + /// InitDecoding initializes the decoding, and should only be used if you + /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to + /// call this. You can also call InitDecoding if you have already decoded an + /// utterance and want to start with a new utterance. + void InitDecoding(); + + /// This will decode until there are no more frames ready in the decodable + /// object. You can keep calling it each time more frames become available. + /// If max_num_frames is specified, it specifies the maximum number of frames + /// the function will decode before returning. + void AdvanceDecoding(DecodableInterface *decodable, + int32 max_num_frames = -1); + + /// This function may be optionally called after AdvanceDecoding(), when you + /// do not plan to decode any further. It does an extra pruning step that + /// will help to prune the lattices output by GetLattice and (particularly) + /// GetRawLattice more accurately, particularly toward the end of the + /// utterance. It does this by using the final-probs in pruning (if any + /// final-state survived); it also does a final pruning step that visits all + /// states (the pruning that is done during decoding may fail to prune states + /// that are within kPruningScale = 0.1 outside of the beam). If you call + /// this, you cannot call AdvanceDecoding again (it will fail), and you + /// cannot call GetLattice() and related functions with use_final_probs = + /// false. + /// Used to be called PruneActiveTokensFinal(). + void FinalizeDecoding(); + + /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives + /// more information. It returns the difference between the best (final-cost + /// plus cost) of any token on the final frame, and the best cost of any token + /// on the final frame. If it is infinity it means no final-states were + /// present on the final frame. It will usually be nonnegative. If it not + /// too positive (e.g. < 5 is my first guess, but this is not tested) you can + /// take it as a good indication that we reached the final-state with + /// reasonable likelihood. + BaseFloat FinalRelativeCost() const; + + + // Returns the number of frames decoded so far. The value returned changes + // whenever we call ProcessForFrame(). + inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } + + protected: + // we make things protected instead of private, as code in + // LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the + // internals. + + // Deletes the elements of the singly linked list tok->links. + inline static void DeleteForwardLinks(Token *tok); + + // head of per-frame list of Tokens (list is in topological order), + // and something saying whether we ever pruned it using PruneForwardLinks. + struct TokenList { + Token *toks; + bool must_prune_forward_links; + bool must_prune_tokens; + TokenList(): toks(NULL), must_prune_forward_links(true), + must_prune_tokens(true) { } + }; + + // FindOrAddToken either locates a token in hash map "token_map", or if necessary + // inserts a new, empty token (i.e. with no forward links) for the current + // frame. 
[note: it's inserted if necessary into hash map and also into the + // singly linked list of tokens active on this frame (whose head is at + // active_toks_[frame]). The frame_plus_one argument is the acoustic frame + // index plus one, which is used to index into the active_toks_ array. + // Returns the Token pointer. Sets "changed" (if non-NULL) to true if the + // token was newly created or the cost changed. + // If Token == StdToken, the 'backpointer' argument has no purpose (and will + // hopefully be optimized out). + inline Token *FindOrAddToken(StateId state, int32 token_list_index, + BaseFloat tot_cost, Token *backpointer, + StateIdToTokenMap *token_map, + bool *changed); + + // prunes outgoing links for all tokens in active_toks_[frame] + // it's called by PruneActiveTokens + // all links, that have link_extra_cost > lattice_beam are pruned + // delta is the amount by which the extra_costs must change + // before we set *extra_costs_changed = true. + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, + BaseFloat delta); + + // This function computes the final-costs for tokens active on the final + // frame. It outputs to final-costs, if non-NULL, a map from the Token* + // pointer to the final-prob of the corresponding state, for all Tokens + // that correspond to states that have final-probs. This map will be + // empty if there were no final-probs. It outputs to + // final_relative_cost, if non-NULL, the difference between the best + // forward-cost including the final-prob cost, and the best forward-cost + // without including the final-prob cost (this will usually be positive), or + // infinity if there were no final-probs. [c.f. FinalRelativeCost(), which + // outputs this quanitity]. It outputs to final_best_cost, if + // non-NULL, the lowest for any token t active on the final frame, of + // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in + // the graph of the state corresponding to token t, or the best of + // forward-cost[t] if there were no final-probs active on the final frame. + // You cannot call this after FinalizeDecoding() has been called; in that + // case you should get the answer from class-member variables. + void ComputeFinalCosts(unordered_map *final_costs, + BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const; + + // PruneForwardLinksFinal is a version of PruneForwardLinks that we call + // on the final frame. If there are final tokens active, it uses + // the final-probs for pruning, otherwise it treats all tokens as final. + void PruneForwardLinksFinal(); + + // Prune away any tokens on this frame that have no forward links. + // [we don't do this in PruneForwardLinks because it would give us + // a problem with dangling pointers]. + // It's called by PruneActiveTokens if any forward links have been pruned + void PruneTokensForFrame(int32 frame_plus_one); + + + // Go backwards through still-alive tokens, pruning them if the + // forward+backward cost is more than lat_beam away from the best path. It's + // possible to prove that this is "correct" in the sense that we won't lose + // anything outside of lat_beam, regardless of what happens in the future. 
+  // delta controls when it considers a cost to have changed enough to continue
+  // going backward and propagating the change.  larger delta -> will recurse
+  // less far.
+  void PruneActiveTokens(BaseFloat delta);
+
+  /// Processes non-emitting (epsilon) arcs and emitting arcs for one frame
+  /// together.  It takes the emitting tokens in "prev_toks_" from the last
+  /// frame, generates non-emitting tokens for the previous frame and emitting
+  /// tokens for the next frame.
+  /// Notice: "the emitting tokens for the current frame" means the tokens
+  /// that take the acoustic scores of the current frame (i.e. the
+  /// destinations of emitting arcs).
+  void ProcessForFrame(DecodableInterface *decodable);
+
+  /// Processes nonemitting (epsilon) arcs for one frame.
+  /// Call this function once after all frames have been processed, or call it
+  /// from GetRawLattice() to generate the complete token list for the last
+  /// frame.  [It deals with the tokens in map "cur_toks_", which would
+  /// otherwise contain only the tokens reached by emitting arcs from the
+  /// previous frame.]
+  /// If the map "token_orig_cost" isn't NULL, we build the map which will
+  /// be used to recover the "active_toks_[last_frame]" token list for the
+  /// last frame.
+  void ProcessNonemitting(std::unordered_map<Token*, BaseFloat> *token_orig_cost);
+
+  /// When GetRawLattice() is called during decoding, the
+  /// active_toks_[last_frame] is changed.  To keep the consistency of function
+  /// ProcessForFrame(), recover it.
+  /// Notice: as new tokens are added to the head of the TokenList, tok->next
+  /// will not be affected.
+  /// "token_orig_cost" is a mapping from token pointer to the tot_cost of the
+  /// token before propagating non-emitting arcs.  It is used to undo the
+  /// changes to the original tokens on the last frame and to remove the new
+  /// tokens that came from propagating non-emitting arcs, so that we can
+  /// guarantee the consistency of function ProcessForFrame().
+  void RecoverLastTokenList(
+      const std::unordered_map<Token*, BaseFloat> &token_orig_cost);
+
+
+  /// The "prev_toks_" and "cur_toks_" maps let us maintain the token lists of
+  /// the previous and current frames; they are indexed by StateId.  The frame
+  /// indexing convention matches active_toks_: the emitting probs of frame t
+  /// are accounted for in the tokens at index t+1, and index zero holds the
+  /// tokens for the nonemitting transitions at the start of the graph.
+  StateIdToTokenMap prev_toks_;
+  StateIdToTokenMap cur_toks_;
+
+  /// Gets the weight cutoff.
+  /// Notice: In the traditional version, the histogram pruning method is
+  /// applied to a complete token list for one frame.  In this version it is
+  /// applied to a token list that contains only the emitting part, so the
+  /// effective max_active and min_active values might be narrowed.
+  BaseFloat GetCutoff(const TokenList &token_list,
+                      BaseFloat *adaptive_beam,
+                      StateId *best_state_id, Token **best_token);
+
+  std::vector<TokenList> active_toks_;  // Lists of tokens, indexed by
+  // frame (members of TokenList are toks, must_prune_forward_links,
+  // must_prune_tokens).
+  std::queue<StateId> cur_queue_;  // temp variable used in ProcessForFrame
+                                   // and ProcessNonemitting
+  std::vector<BaseFloat> tmp_array_;  // used in GetCutoff.
+
+  // fst_ is a pointer to the FST we are decoding from.
+  const FST *fst_;
+  // delete_fst_ is true if the pointer fst_ needs to be deleted when this
+  // object is destroyed.
+ bool delete_fst_; + + std::vector cost_offsets_; // This contains, for each + // frame, an offset that was added to the acoustic log-likelihoods on that + // frame in order to keep everything in a nice dynamic range i.e. close to + // zero, to reduce roundoff errors. + // Notice: It will only be added to emitting arcs (i.e. cost_offsets_[t] is + // added to arcs from "frame t" to "frame t+1"). + LatticeFasterDecoderCombineConfig config_; + int32 num_toks_; // current total #toks allocated... + bool warned_; + + /// decoding_finalized_ is true if someone called FinalizeDecoding(). [note, + /// calling this is optional]. If true, it's forbidden to decode more. Also, + /// if this is set, then the output of ComputeFinalCosts() is in the next + /// three variables. The reason we need to do this is that after + /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some + /// of the tokens on the last frame are freed, so we free the list from toks_ + /// to avoid having dangling pointers hanging around. + bool decoding_finalized_; + /// For the meaning of the next 3 variables, see the comment for + /// decoding_finalized_ above., and ComputeFinalCosts(). + unordered_map final_costs_; + BaseFloat final_relative_cost_; + BaseFloat final_best_cost_; + + // This function takes a singly linked list of tokens for a single frame, and + // outputs a list of them in topological order (it will crash if no such order + // can be found, which will typically be due to decoding graphs with epsilon + // cycles, which are not allowed). Note: the output list may contain NULLs, + // which the caller should pass over; it just happens to be more efficient for + // the algorithm to output a list that contains NULLs. + static void TopSortTokens(Token *tok_list, + std::vector *topsorted_list); + + void ClearActiveTokens(); + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderCombineTpl); +}; + +typedef LatticeFasterDecoderCombineTpl > LatticeFasterDecoderCombine; + + + +} // end namespace kaldi. + +#endif diff --git a/src/decoder/lattice-faster-decoder-combine-itermap.cc b/src/decoder/lattice-faster-decoder-combine-itermap.cc new file mode 100644 index 00000000000..6c9d70bb9b3 --- /dev/null +++ b/src/decoder/lattice-faster-decoder-combine-itermap.cc @@ -0,0 +1,1100 @@ +// decoder/lattice-faster-decoder-combine.cc + +// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann +// 2013-2019 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen +// 2019 Hang Lyu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/lattice-faster-decoder-combine.h" +#include "lat/lattice-functions.h" + +namespace kaldi { + +// instantiate this class once for each thing you have to decode. 
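+// For example (an illustrative sketch only; 'hclg' stands for an already
+// loaded decoding graph of type fst::StdFst, and the beam value is made up):
+//   LatticeFasterDecoderCombineConfig config;
+//   config.beam = 13.0;
+//   LatticeFasterDecoderCombine decoder(hclg, config);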
+template +LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( + const FST &fst, + const LatticeFasterDecoderCombineConfig &config): + fst_(&fst), delete_fst_(false), config_(config), num_toks_(0) { + config.Check(); +} + + +template +LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( + const LatticeFasterDecoderCombineConfig &config, FST *fst): + fst_(fst), delete_fst_(true), config_(config), num_toks_(0) { + config.Check(); + prev_toks_.reserve(1000); + cur_toks_.reserve(1000); +} + + +template +LatticeFasterDecoderCombineTpl::~LatticeFasterDecoderCombineTpl() { + ClearActiveTokens(); + if (delete_fst_) delete fst_; +} + +template +void LatticeFasterDecoderCombineTpl::InitDecoding() { + // clean up from last time: + prev_toks_.clear(); + cur_toks_.clear(); + cost_offsets_.clear(); + ClearActiveTokens(); + + warned_ = false; + num_toks_ = 0; + decoding_finalized_ = false; + final_costs_.clear(); + StateId start_state = fst_->Start(); + KALDI_ASSERT(start_state != fst::kNoStateId); + active_toks_.resize(1); + Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL); + active_toks_[0].toks = start_tok; + cur_toks_[start_state] = start_tok; // initialize current tokens map + num_toks_++; +} + +// Returns true if any kind of traceback is available (not necessarily from +// a final state). It should only very rarely return false; this indicates +// an unusual search error. +template +bool LatticeFasterDecoderCombineTpl::Decode(DecodableInterface *decodable) { + InitDecoding(); + + // We use 1-based indexing for frames in this decoder (if you view it in + // terms of features), but note that the decodable object uses zero-based + // numbering, which we have to correct for when we call it. + + while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) { + if (NumFramesDecoded() % config_.prune_interval == 0) + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + ProcessForFrame(decodable); + } + // A complete token list of the last frame will be generated in FinalizeDecoding() + FinalizeDecoding(); + + // Returns true if we have any kind of traceback available (not necessarily + // to the end state; query ReachedFinal() for that). + return !active_toks_.empty() && active_toks_.back().toks != NULL; +} + + +// Outputs an FST corresponding to the single best path through the lattice. +template +bool LatticeFasterDecoderCombineTpl::GetBestPath( + Lattice *olat, + bool use_final_probs) { + Lattice raw_lat; + GetRawLattice(&raw_lat, use_final_probs); + ShortestPath(raw_lat, olat); + return (olat->NumStates() != 0); +} + + +// Outputs an FST corresponding to the raw, state-level lattice +template +bool LatticeFasterDecoderCombineTpl::GetRawLattice( + Lattice *ofst, + bool use_final_probs) { + typedef LatticeArc Arc; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + typedef Arc::Label Label; + // Note: you can't use the old interface (Decode()) if you want to + // get the lattice with use_final_probs = false. You'd have to do + // InitDecoding() and then AdvanceDecoding(). + if (decoding_finalized_ && !use_final_probs) + KALDI_ERR << "You cannot call FinalizeDecoding() and then call " + << "GetRawLattice() with use_final_probs == false"; + + std::unordered_map token_orig_cost; + if (!decoding_finalized_) { + // Process the non-emitting arcs for the unfinished last frame. + ProcessNonemitting(&token_orig_cost); + } + + + unordered_map final_costs_local; + + const unordered_map &final_costs = + (decoding_finalized_ ? 
final_costs_ : final_costs_local);
+  if (!decoding_finalized_ && use_final_probs)
+    ComputeFinalCosts(&final_costs_local, NULL, NULL);
+
+  ofst->DeleteStates();
+  // num-frames plus one (since frames are one-based, and we have
+  // an extra frame for the start-state).
+  int32 num_frames = active_toks_.size() - 1;
+  KALDI_ASSERT(num_frames > 0);
+  const int32 bucket_count = num_toks_/2 + 3;
+  unordered_map<Token*, StateId> tok_map(bucket_count);
+  // First create all states.
+  std::vector<Token*> token_list;
+  for (int32 f = 0; f <= num_frames; f++) {
+    if (active_toks_[f].toks == NULL) {
+      KALDI_WARN << "GetRawLattice: no tokens active on frame " << f
+                 << ": not producing lattice.\n";
+      return false;
+    }
+    TopSortTokens(active_toks_[f].toks, &token_list);
+    for (size_t i = 0; i < token_list.size(); i++)
+      if (token_list[i] != NULL)
+        tok_map[token_list[i]] = ofst->AddState();
+  }
+  // The next statement sets the start state of the output FST.  Because we
+  // topologically sorted the tokens, state zero must be the start-state.
+  ofst->SetStart(0);
+
+  KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:"
+                << tok_map.bucket_count() << " load:" << tok_map.load_factor()
+                << " max:" << tok_map.max_load_factor();
+  // Now create all arcs.
+  for (int32 f = 0; f <= num_frames; f++) {
+    for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) {
+      StateId cur_state = tok_map[tok];
+      for (ForwardLinkT *l = tok->links;
+           l != NULL;
+           l = l->next) {
+        typename unordered_map<Token*, StateId>::const_iterator
+            iter = tok_map.find(l->next_tok);
+        KALDI_ASSERT(iter != tok_map.end());
+        StateId nextstate = iter->second;
+        BaseFloat cost_offset = 0.0;
+        if (l->ilabel != 0) {  // emitting..
+          KALDI_ASSERT(f >= 0 && f < cost_offsets_.size());
+          cost_offset = cost_offsets_[f];
+        }
+        Arc arc(l->ilabel, l->olabel,
+                Weight(l->graph_cost, l->acoustic_cost - cost_offset),
+                nextstate);
+        ofst->AddArc(cur_state, arc);
+      }
+      if (f == num_frames) {
+        if (use_final_probs && !final_costs.empty()) {
+          typename unordered_map<Token*, BaseFloat>::const_iterator
+              iter = final_costs.find(tok);
+          if (iter != final_costs.end())
+            ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
+        } else {
+          ofst->SetFinal(cur_state, LatticeWeight::One());
+        }
+      }
+    }
+  }
+
+  if (!decoding_finalized_) {  // recover the last token list
+    RecoverLastTokenList(token_orig_cost);
+  }
+  return (ofst->NumStates() > 0);
+}
+
+
+// When GetRawLattice() is called during decoding, the
+// active_toks_[last_frame] is changed.  To keep the consistency of function
+// ProcessForFrame(), recover it.
+// Notice: as new tokens are added to the head of the TokenList, tok->next
+// will not be affected.
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::RecoverLastTokenList(
+    const std::unordered_map<Token*, BaseFloat> &token_orig_cost) {
+  if (!token_orig_cost.empty()) {
+    for (Token* tok = active_toks_[active_toks_.size() - 1].toks;
+         tok != NULL;) {
+      if (token_orig_cost.find(tok) != token_orig_cost.end()) {
+        DeleteForwardLinks(tok);
+        tok->tot_cost = token_orig_cost.find(tok)->second;
+        tok->in_current_queue = false;
+        tok = tok->next;
+      } else {
+        DeleteForwardLinks(tok);
+        Token *next_tok = tok->next;
+        delete tok;
+        num_toks_--;
+        tok = next_tok;
+      }
+    }
+  }
+}
+
+// This function is now deprecated, since we now do determinization from
+// outside the LatticeFasterDecoder class.  Outputs an FST corresponding to
+// the lattice-determinized lattice (one path per word sequence).
+template +bool LatticeFasterDecoderCombineTpl::GetLattice( + CompactLattice *ofst, + bool use_final_probs) { + Lattice raw_fst; + GetRawLattice(&raw_fst, use_final_probs); + Invert(&raw_fst); // make it so word labels are on the input. + // (in phase where we get backward-costs). + fst::ILabelCompare ilabel_comp; + ArcSort(&raw_fst, ilabel_comp); // sort on ilabel; makes + // lattice-determinization more efficient. + + fst::DeterminizeLatticePrunedOptions lat_opts; + lat_opts.max_mem = config_.det_opts.max_mem; + + DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts); + raw_fst.DeleteStates(); // Free memory-- raw_fst no longer needed. + Connect(ofst); // Remove unreachable states... there might be + // a small number of these, in some cases. + // Note: if something went wrong and the raw lattice was empty, + // we should still get to this point in the code without warnings or failures. + return (ofst->NumStates() != 0); +} + +/* + A note on the definition of extra_cost. + + extra_cost is used in pruning tokens, to save memory. + + Define the 'forward cost' of a token as zero for any token on the frame + we're currently decoding; and for other frames, as the shortest-path cost + between that token and a token on the frame we're currently decoding. + (by "currently decoding" I mean the most recently processed frame). + + Then define the extra_cost of a token (always >= 0) as the forward-cost of + the token minus the smallest forward-cost of any token on the same frame. + + We can use the extra_cost to accurately prune away tokens that we know will + never appear in the lattice. If the extra_cost is greater than the desired + lattice beam, the token would provably never appear in the lattice, so we can + prune away the token. + + The advantage of storing the extra_cost rather than the forward-cost, is that + it is less costly to keep the extra_cost up-to-date when we process new frames. + When we process a new frame, *all* the previous frames' forward-costs would change; + but in general the extra_cost will change only for a finite number of frames. + (Actually we don't update all the extra_costs every time we update a frame; we + only do it every 'config_.prune_interval' frames). + */ + +// FindOrAddToken either locates a token in hash map "token_map" +// or if necessary inserts a new, empty token (i.e. with no forward links) +// for the current frame. [note: it's inserted if necessary into hash toks_ +// and also into the singly linked list of tokens active on this frame +// (whose head is at active_toks_[frame]). +template +inline Token* LatticeFasterDecoderCombineTpl::FindOrAddToken( + StateId state, int32 token_list_index, BaseFloat tot_cost, + Token *backpointer, StateIdToTokenMap *token_map, bool *changed) { + // Returns the Token pointer. Sets "changed" (if non-NULL) to true + // if the token was newly created or the cost changed. + KALDI_ASSERT(token_list_index < active_toks_.size()); + Token *&toks = active_toks_[token_list_index].toks; + typename StateIdToTokenMap::iterator e_found = token_map->find(state); + if (e_found == token_map->end()) { // no such token presently. + const BaseFloat extra_cost = 0.0; + // tokens on the currently final frame have zero extra_cost + // as any of them could end up + // on the winning path. 
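+    // (Illustrative: if some state is reached here for the first time at
+    // cost 12.3, a new token is created below and entered into the map; if
+    // the same state is later reached at cost 11.9, the "else" branch just
+    // lowers tot_cost to 11.9 and sets *changed.)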
+ Token *new_tok = new Token (tot_cost, extra_cost, NULL, toks, backpointer); + // NULL: no forward links yet + toks = new_tok; + num_toks_++; + // insert into the map + (*token_map)[state] = new_tok; + if (changed) *changed = true; + return new_tok; + } else { + Token *tok = e_found->second; // There is an existing Token for this state. + if (tok->tot_cost > tot_cost) { // replace old token + tok->tot_cost = tot_cost; + // SetBackpointer() just does tok->backpointer = backpointer in + // the case where Token == BackpointerToken, else nothing. + tok->SetBackpointer(backpointer); + // we don't allocate a new token, the old stays linked in active_toks_ + // we only replace the tot_cost + // in the current frame, there are no forward links (and no extra_cost) + // only in ProcessNonemitting we have to delete forward links + // in case we visit a state for the second time + // those forward links, that lead to this replaced token before: + // they remain and will hopefully be pruned later (PruneForwardLinks...) + if (changed) *changed = true; + } else { + if (changed) *changed = false; + } + return tok; + } +} + +// prunes outgoing links for all tokens in active_toks_[frame] +// it's called by PruneActiveTokens +// all links, that have link_extra_cost > lattice_beam are pruned +template +void LatticeFasterDecoderCombineTpl::PruneForwardLinks( + int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, BaseFloat delta) { + // delta is the amount by which the extra_costs must change + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + + *extra_costs_changed = false; + *links_pruned = false; + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen. + if (!warned_) { + KALDI_WARN << "No tokens alive [doing pruning].. warning first " + "time only for each utterance\n"; + warned_ = true; + } + } + + // We have to iterate until there is no more change, because the links + // are not guaranteed to be in topological order. + bool changed = true; // difference new minus old extra cost >= delta ? + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost for tok. + BaseFloat tok_extra_cost = std::numeric_limits::infinity(); + // tok_extra_cost is the best (min) of link_extra_cost of outgoing links + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); // difference in brackets is >= 0 + // link_exta_cost is the difference in score between the best paths + // through link source state and through link destination state + KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + *links_pruned = true; + } else { // keep the link and update the tok_extra_cost if needed. 
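+          // (Sketch of the quantity computed above: link_extra_cost is
+          //  next_tok->extra_cost plus the amount by which going through
+          //  this link worsens the best cost of reaching next_tok, i.e. the
+          //  extra cost of the best complete path forced through this link.)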
+ if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; // move to next link + link = link->next; + } + } // for all outgoing links + if (fabs(tok_extra_cost - tok->extra_cost) > delta) + changed = true; // difference new minus old is bigger than delta + tok->extra_cost = tok_extra_cost; + // will be +infinity or <= lattice_beam_. + // infinity indicates, that no forward link survived pruning + } // for all Token on active_toks_[frame] + if (changed) *extra_costs_changed = true; + + // Note: it's theoretically possible that aggressive compiler + // optimizations could cause an infinite loop here for small delta and + // high-dynamic-range scores. + } // while changed +} + +// PruneForwardLinksFinal is a version of PruneForwardLinks that we call +// on the final frame. If there are final tokens active, it uses +// the final-probs for pruning, otherwise it treats all tokens as final. +template +void LatticeFasterDecoderCombineTpl::PruneForwardLinksFinal() { + KALDI_ASSERT(!active_toks_.empty()); + int32 frame_plus_one = active_toks_.size() - 1; + + if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen. + KALDI_WARN << "No tokens alive at end of file"; + + typedef typename unordered_map::const_iterator IterType; + ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_); + decoding_finalized_ = true; + + // Now go through tokens on this frame, pruning forward links... may have to + // iterate a few times until there is no more change, because the list is not + // in topological order. This is a modified version of the code in + // PruneForwardLinks, but here we also take account of the final-probs. + bool changed = true; + BaseFloat delta = 1.0e-05; + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost. It has a term in it that corresponds + // to the "final-prob", so instead of initializing tok_extra_cost to infinity + // below we set it to the difference between the (score+final_prob) of this token, + // and the best such (score+final_prob). + BaseFloat final_cost; + if (final_costs_.empty()) { + final_cost = 0.0; + } else { + IterType iter = final_costs_.find(tok); + if (iter != final_costs_.end()) + final_cost = iter->second; + else + final_cost = std::numeric_limits::infinity(); + } + BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_; + // tok_extra_cost will be a "min" over either directly being final, or + // being indirectly final through other links, and the loop below may + // decrease its value: + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. 
+ if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; + link = link->next; + } + } + // prune away tokens worse than lattice_beam above best path. This step + // was not necessary in the non-final case because then, this case + // showed up as having no forward links. Here, the tok_extra_cost has + // an extra component relating to the final-prob. + if (tok_extra_cost > config_.lattice_beam) + tok_extra_cost = std::numeric_limits::infinity(); + // to be pruned in PruneTokensForFrame + + if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) + changed = true; + tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. + } + } // while changed +} + +template +BaseFloat LatticeFasterDecoderCombineTpl::FinalRelativeCost() const { + if (!decoding_finalized_) { + BaseFloat relative_cost; + ComputeFinalCosts(NULL, &relative_cost, NULL); + return relative_cost; + } else { + // we're not allowed to call that function if FinalizeDecoding() has + // been called; return a cached value. + return final_relative_cost_; + } +} + + +// Prune away any tokens on this frame that have no forward links. +// [we don't do this in PruneForwardLinks because it would give us +// a problem with dangling pointers]. +// It's called by PruneActiveTokens if any forward links have been pruned +template +void LatticeFasterDecoderCombineTpl::PruneTokensForFrame( + int32 frame_plus_one) { + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + if (toks == NULL) + KALDI_WARN << "No tokens alive [doing pruning]"; + Token *tok, *next_tok, *prev_tok = NULL; + for (tok = toks; tok != NULL; tok = next_tok) { + next_tok = tok->next; + if (tok->extra_cost == std::numeric_limits::infinity()) { + // token is unreachable from end of graph; (no forward links survived) + // excise tok from list and delete tok. + if (prev_tok != NULL) prev_tok->next = tok->next; + else toks = tok->next; + delete tok; + num_toks_--; + } else { // fetch next Token + prev_tok = tok; + } + } +} + +// Go backwards through still-alive tokens, pruning them, starting not from +// the current frame (where we want to keep all tokens) but from the frame before +// that. We go backwards through the frames and stop when we reach a point +// where the delta-costs are not changing (and the delta controls when we consider +// a cost to have "not changed"). +template +void LatticeFasterDecoderCombineTpl::PruneActiveTokens( + BaseFloat delta) { + int32 cur_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // The index "f" below represents a "frame plus one", i.e. you'd have to subtract + // one to get the corresponding index for the decodable object. + for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) { + // Reason why we need to prune forward links in this situation: + // (1) we have never pruned them (new TokenList) + // (2) we have not yet pruned the forward links to the next f, + // after any of those tokens have changed their extra_cost. 
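+    // A small worked example of the quantity being pruned on (illustrative
+    // numbers only): if the best complete path so far costs 500.0, the best
+    // path through some token T costs 503.0, and lattice_beam is 8.0, then
+    // T's extra_cost is 3.0 and T survives; a token whose extra_cost
+    // exceeded 8.0 would provably never appear in the lattice.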
+    if (active_toks_[f].must_prune_forward_links) {
+      bool extra_costs_changed = false, links_pruned = false;
+      PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta);
+      if (extra_costs_changed && f > 0) // any token has changed extra_cost
+        active_toks_[f-1].must_prune_forward_links = true;
+      if (links_pruned) // any link was pruned
+        active_toks_[f].must_prune_tokens = true;
+      active_toks_[f].must_prune_forward_links = false; // job done
+    }
+    if (f+1 < cur_frame_plus_one &&      // except for last f (no forward links)
+        active_toks_[f+1].must_prune_tokens) {
+      PruneTokensForFrame(f+1);
+      active_toks_[f+1].must_prune_tokens = false;
+    }
+  }
+  KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin
+                << " to " << num_toks_;
+}
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ComputeFinalCosts(
+    unordered_map<Token*, BaseFloat> *final_costs,
+    BaseFloat *final_relative_cost,
+    BaseFloat *final_best_cost) const {
+  KALDI_ASSERT(!decoding_finalized_);
+  if (final_costs != NULL)
+    final_costs->clear();
+  BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity();
+  BaseFloat best_cost = infinity,
+      best_cost_with_final = infinity;
+
+  // The final tokens are recorded in unordered_map "cur_toks_".
+  for (IterType iter = cur_toks_.begin(); iter != cur_toks_.end(); iter++) {
+    StateId state = iter->first;
+    Token *tok = iter->second;
+    BaseFloat final_cost = fst_->Final(state).Value();
+    BaseFloat cost = tok->tot_cost,
+        cost_with_final = cost + final_cost;
+    best_cost = std::min(cost, best_cost);
+    best_cost_with_final = std::min(cost_with_final, best_cost_with_final);
+    if (final_costs != NULL && final_cost != infinity)
+      (*final_costs)[tok] = final_cost;
+  }
+  if (final_relative_cost != NULL) {
+    if (best_cost == infinity && best_cost_with_final == infinity) {
+      // Likely this will only happen if there are no tokens surviving.
+      // This seems the least bad way to handle it.
+      *final_relative_cost = infinity;
+    } else {
+      *final_relative_cost = best_cost_with_final - best_cost;
+    }
+  }
+  if (final_best_cost != NULL) {
+    if (best_cost_with_final != infinity) { // final-state exists.
+      *final_best_cost = best_cost_with_final;
+    } else { // no final-state exists.
+      *final_best_cost = best_cost;
+    }
+  }
+}
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::AdvanceDecoding(
+    DecodableInterface *decodable,
+    int32 max_num_frames) {
+  if (std::is_same<FST, fst::Fst<fst::StdArc> >::value) {
+    // if the type 'FST' is the FST base-class, then see if the FST type of fst_
+    // is actually VectorFst or ConstFst.  If so, call the AdvanceDecoding()
+    // function after casting *this to the more specific type.
+    if (fst_->Type() == "const") {
+      LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>, Token> *this_cast =
+          reinterpret_cast<LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>, Token>* >(this);
+      this_cast->AdvanceDecoding(decodable, max_num_frames);
+      return;
+    } else if (fst_->Type() == "vector") {
+      LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>, Token> *this_cast =
+          reinterpret_cast<LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>, Token>* >(this);
+      this_cast->AdvanceDecoding(decodable, max_num_frames);
+      return;
+    }
+  }
+
+
+  KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ &&
+               "You must call InitDecoding() before AdvanceDecoding");
+  int32 num_frames_ready = decodable->NumFramesReady();
+  // num_frames_ready must be >= num_frames_decoded, or else
+  // the number of frames ready must have decreased (which doesn't
+  // make sense) or the decodable object changed between calls
+  // (which isn't allowed).
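+  // (Aside, a typical online calling pattern, as a sketch only --
+  //    decoder.InitDecoding();
+  //    while (/* caller makes more frames ready in `decodable' */)
+  //      decoder.AdvanceDecoding(&decodable);
+  //    decoder.FinalizeDecoding();
+  //  so NumFramesReady() can only grow between calls.)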
+ KALDI_ASSERT(num_frames_ready >= NumFramesDecoded()); + int32 target_frames_decoded = num_frames_ready; + if (max_num_frames >= 0) + target_frames_decoded = std::min(target_frames_decoded, + NumFramesDecoded() + max_num_frames); + while (NumFramesDecoded() < target_frames_decoded) { + if (NumFramesDecoded() % config_.prune_interval == 0) { + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + } + ProcessForFrame(decodable); + } +} + +// FinalizeDecoding() is a version of PruneActiveTokens that we call +// (optionally) on the final frame. Takes into account the final-prob of +// tokens. This function used to be called PruneActiveTokensFinal(). +template +void LatticeFasterDecoderCombineTpl::FinalizeDecoding() { + ProcessNonemitting(NULL); + int32 final_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // PruneForwardLinksFinal() prunes final frame (with final-probs), and + // sets decoding_finalized_. + PruneForwardLinksFinal(); + for (int32 f = final_frame_plus_one - 1; f >= 0; f--) { + bool b1, b2; // values not used. + BaseFloat dontcare = 0.0; // delta of zero means we must always update + PruneForwardLinks(f, &b1, &b2, dontcare); + PruneTokensForFrame(f + 1); + } + PruneTokensForFrame(0); + KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +/// Gets the weight cutoff. +template +BaseFloat LatticeFasterDecoderCombineTpl::GetCutoff( + const StateIdToTokenMap &toks, BaseFloat *adaptive_beam, + StateId *best_state_id, Token **best_token) { + // positive == high cost == bad. + // best_weight is the minimum value. + BaseFloat best_weight = std::numeric_limits::infinity(); + if (config_.max_active == std::numeric_limits::max() && + config_.min_active == 0) { + for (IterType iter = toks.begin(); iter != toks.end(); iter++) { + BaseFloat w = static_cast(iter->second->tot_cost); + if (w < best_weight) { + best_weight = w; + if (best_token) { + *best_state_id = iter->first; + *best_token = iter->second; + } + } + } + if (adaptive_beam != NULL) *adaptive_beam = config_.beam; + return best_weight + config_.beam; + } else { + tmp_array_.clear(); + for (IterType iter = toks.begin(); iter != toks.end(); iter++) { + BaseFloat w = static_cast(iter->second->tot_cost); + tmp_array_.push_back(w); + if (w < best_weight) { + best_weight = w; + if (best_token) { + *best_state_id = iter->first; + *best_token = iter->second; + } + } + } + + BaseFloat beam_cutoff = best_weight + config_.beam, + min_active_cutoff = std::numeric_limits::infinity(), + max_active_cutoff = std::numeric_limits::infinity(); + + KALDI_VLOG(6) << "Number of emitting tokens on frame " + << NumFramesDecoded() - 1 << " is " << tmp_array_.size(); + + if (tmp_array_.size() > static_cast(config_.max_active)) { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.max_active, + tmp_array_.end()); + max_active_cutoff = tmp_array_[config_.max_active]; + } + if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam. + if (adaptive_beam) + *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta; + return max_active_cutoff; + } + if (tmp_array_.size() > static_cast(config_.min_active)) { + if (config_.min_active == 0) min_active_cutoff = best_weight; + else { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.min_active, + tmp_array_.size() > static_cast(config_.max_active) ? 
+            tmp_array_.begin() + config_.max_active : tmp_array_.end());
+        min_active_cutoff = tmp_array_[config_.min_active];
+      }
+    }
+    if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam.
+      if (adaptive_beam)
+        *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta;
+      return min_active_cutoff;
+    } else {
+      *adaptive_beam = config_.beam;
+      return beam_cutoff;
+    }
+  }
+}
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
+    DecodableInterface *decodable) {
+  KALDI_ASSERT(active_toks_.size() > 0);
+  int32 frame = active_toks_.size() - 1; // frame is the frame-index
+                                         // (zero-based) used to get likelihoods
+                                         // from the decodable object.
+  active_toks_.resize(active_toks_.size() + 1);
+
+  prev_toks_.swap(cur_toks_);
+  cur_toks_.clear();
+  if (prev_toks_.empty()) {
+    if (!warned_) {
+      KALDI_WARN << "Error, no surviving tokens on frame " << frame;
+      warned_ = true;
+    }
+  }
+
+  BaseFloat adaptive_beam;
+  Token *best_tok = NULL;
+  StateId best_tok_state_id;
+  // "cur_cutoff" is used to constrain the epsilon expansion on the current
+  // frame.  It will not be updated.
+  BaseFloat cur_cutoff = GetCutoff(prev_toks_, &adaptive_beam,
+                                   &best_tok_state_id, &best_tok);
+  KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is "
+                << adaptive_beam;
+
+
+  // pruning "online" before having seen all tokens
+
+  // "next_cutoff" is used to decide whether a new token on the next frame
+  // should be handled or not.  It will be updated as further tokens are
+  // processed.
+  BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
+  // "cost_offset" is an offset that will be applied to the acoustic
+  // log-likelihoods on the current frame in order to keep everything in a
+  // nice dynamic range and reduce roundoff errors.
+  BaseFloat cost_offset = 0.0;
+
+  // First process the best token to get a hopefully
+  // reasonably tight bound on the next cutoff.  The only
+  // products of the next block are "next_cutoff" and "cost_offset".
+  // Notice: unlike the traditional version, in this combined version
+  // "best_tok" is chosen from the emitting tokens.  Normally the best token
+  // on a frame is reached through an epsilon (non-emitting) arc, so this
+  // best token only gives a loose bound.  We use it to estimate a bound on
+  // the next cutoff, and we will update "next_cutoff" once we find better
+  // tokens during further processing.
+  if (best_tok) {
+    cost_offset = - best_tok->tot_cost;
+    for (fst::ArcIterator<FST> aiter(*fst_, best_tok_state_id);
+         !aiter.Done();
+         aiter.Next()) {
+      const Arc &arc = aiter.Value();
+      if (arc.ilabel != 0) {  // propagate..
+        // ac_cost + graph_cost
+        BaseFloat new_weight = arc.weight.Value() + cost_offset -
+            decodable->LogLikelihood(frame, arc.ilabel) + best_tok->tot_cost;
+        if (new_weight + adaptive_beam < next_cutoff)
+          next_cutoff = new_weight + adaptive_beam;
+      }
+    }
+  }
+
+  // Store the offset on the acoustic likelihoods that we're applying.
+  // Could just do cost_offsets_.push_back(cost_offset), but we
+  // do it this way as it's more robust to future code changes.
+  cost_offsets_.resize(frame + 1, 0.0);
+  cost_offsets_[frame] = cost_offset;
+
+  // Build a queue which contains the emitting tokens from the previous frame.
+  for (IterType iter = prev_toks_.begin(); iter != prev_toks_.end(); iter++) {
+    cur_queue_.push(iter->first);
+    iter->second->in_current_queue = true;
+  }
+
+  // Iterate over "cur_queue_" to process both the non-emitting and the
+  // emitting arcs in the FST.
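+  // Sketch of the invariant maintained by the loop below: every state popped
+  // from "cur_queue_" has a token in "prev_toks_".  Epsilon arcs may update
+  // tokens in "prev_toks_" and re-queue their destination states within the
+  // same frame, while emitting arcs create or improve tokens for the next
+  // frame in "cur_toks_" (and tighten "next_cutoff"); states reached by
+  // emitting arcs are never queued here.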
+  while (!cur_queue_.empty()) {
+    StateId state = cur_queue_.front();
+    cur_queue_.pop();
+
+    KALDI_ASSERT(prev_toks_.find(state) != prev_toks_.end());
+    Token *tok = prev_toks_[state];
+
+    BaseFloat cur_cost = tok->tot_cost;
+    tok->in_current_queue = false;  // out of queue
+    if (cur_cost > cur_cutoff) // Don't bother processing successors.
+      continue;
+    // If "tok" has any existing forward links, delete them,
+    // because we're about to regenerate them.  This is a kind
+    // of non-optimality (remember, this is the simple decoder).
+    DeleteForwardLinks(tok); // necessary when re-visiting
+    tok->links = NULL;
+    for (fst::ArcIterator<FST> aiter(*fst_, state);
+         !aiter.Done();
+         aiter.Next()) {
+      const Arc &arc = aiter.Value();
+      bool changed;
+      if (arc.ilabel == 0) {  // propagate nonemitting
+        BaseFloat graph_cost = arc.weight.Value();
+        BaseFloat tot_cost = cur_cost + graph_cost;
+        if (tot_cost < cur_cutoff) {
+          Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost,
+                                          tok, &prev_toks_, &changed);
+
+          // Add ForwardLink from tok to new_tok.  Put it at the head of
+          // the tok->links list.
+          tok->links = new ForwardLinkT(new_tok, 0, arc.olabel,
+                                        graph_cost, 0, tok->links);
+
+          // "changed" tells us whether the new token has a different
+          // cost from before, or is new.
+          if (changed && !new_tok->in_current_queue) {
+            cur_queue_.push(arc.nextstate);
+            new_tok->in_current_queue = true;
+          }
+        }
+      } else { // propagate emitting
+        BaseFloat graph_cost = arc.weight.Value(),
+            ac_cost = cost_offset - decodable->LogLikelihood(frame, arc.ilabel),
+            cur_cost = tok->tot_cost,
+            tot_cost = cur_cost + ac_cost + graph_cost;
+        if (tot_cost > next_cutoff) continue;
+        else if (tot_cost + adaptive_beam < next_cutoff)
+          next_cutoff = tot_cost + adaptive_beam; // a tighter boundary for emitting
+
+        // no change flag is needed
+        Token *next_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
+                                         tok, &cur_toks_, NULL);
+        // Add ForwardLink from tok to next_tok.  Put it at the head of the
+        // tok->links list.
+        tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel,
+                                      graph_cost, ac_cost, tok->links);
+      }
+    } // for all arcs
+  } // end of while loop
+  KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() - 1
+                << " is " << prev_toks_.size();
+}
+
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting(
+    std::unordered_map<Token*, BaseFloat> *token_orig_cost) {
+  if (token_orig_cost) { // Build the elements which are used to recover
+    for (IterType iter = cur_toks_.begin(); iter != cur_toks_.end(); iter++) {
+      (*token_orig_cost)[iter->second] = iter->second->tot_cost;
+    }
+  }
+
+  StateIdToTokenMap *tmp_toks;
+  if (token_orig_cost) { // "token_orig_cost" isn't NULL.  It means we need to
+                         // recover active_toks_[last_frame] and "cur_toks_"
+                         // will be used in the future.
+    tmp_toks = new StateIdToTokenMap(cur_toks_);
+  } else {
+    tmp_toks = &cur_toks_;
+  }
+
+  int32 frame = active_toks_.size() - 1;
+  // Build the queue to process non-emitting arcs.
+  for (IterType iter = tmp_toks->begin(); iter != tmp_toks->end(); iter++) {
+    if (fst_->NumInputEpsilons(iter->first) != 0) {
+      cur_queue_.push(iter->first);
+      iter->second->in_current_queue = true;
+    }
+  }
+
+  // "cur_cutoff" is used to constrain the epsilon expansion on the current
+  // frame.  It will not be updated.
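+  // Note on the design (see also the comment at GetCutoff() in the header):
+  // at this point "cur_toks_" holds only tokens reached by emitting arcs, so
+  // the cutoff below is computed from a smaller population than in the
+  // traditional decoder; with histogram pruning, max_active/min_active
+  // therefore act on the emitting part only.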
+ BaseFloat adaptive_beam; + BaseFloat cur_cutoff = GetCutoff(*tmp_toks, &adaptive_beam, NULL, NULL); + + while (!cur_queue_.empty()) { + StateId state = cur_queue_.front(); + cur_queue_.pop(); + + KALDI_ASSERT(tmp_toks->find(state) != tmp_toks->end()); + Token *tok = (*tmp_toks)[state]; + BaseFloat cur_cost = tok->tot_cost; + if (cur_cost > cur_cutoff) // Don't bother processing successors. + continue; + // If "tok" has any existing forward links, delete them, + // because we're about to regenerate them. This is a kind + // of non-optimality (remember, this is the simple decoder), + DeleteForwardLinks(tok); // necessary when re-visiting + tok->links = NULL; + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + bool changed; + if (arc.ilabel == 0) { // propagate nonemitting + BaseFloat graph_cost = arc.weight.Value(); + BaseFloat tot_cost = cur_cost + graph_cost; + if (tot_cost < cur_cutoff) { + Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost, + tok, tmp_toks, &changed); + + // Add ForwardLink from tok to new_tok. Put it on the head of + // tok->link list + tok->links = new ForwardLinkT(new_tok, 0, arc.olabel, + graph_cost, 0, tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new. + if (changed && !new_tok->in_current_queue) { + cur_queue_.push(arc.nextstate); + new_tok->in_current_queue = true; + } + } + } + } // end of for loop + tok->in_current_queue = false; + } // end of while loop + if (token_orig_cost) delete tmp_toks; +} + + + +// static inline +template +void LatticeFasterDecoderCombineTpl::DeleteForwardLinks(Token *tok) { + ForwardLinkT *l = tok->links, *m; + while (l != NULL) { + m = l->next; + delete l; + l = m; + } + tok->links = NULL; +} + + +template +void LatticeFasterDecoderCombineTpl::ClearActiveTokens() { + // a cleanup routine, at utt end/begin + for (size_t i = 0; i < active_toks_.size(); i++) { + // Delete all tokens alive on this frame, and any forward + // links they may have. + for (Token *tok = active_toks_[i].toks; tok != NULL; ) { + DeleteForwardLinks(tok); + Token *next_tok = tok->next; + delete tok; + num_toks_--; + tok = next_tok; + } + } + active_toks_.clear(); + KALDI_ASSERT(num_toks_ == 0); +} + +// static +template +void LatticeFasterDecoderCombineTpl::TopSortTokens( + Token *tok_list, std::vector *topsorted_list) { + unordered_map token2pos; + typedef typename unordered_map::iterator IterType; + int32 num_toks = 0; + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + num_toks++; + int32 cur_pos = 0; + // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0. + // This is likely to be in closer to topological order than + // if we had given them ascending order, because of the way + // new tokens are put at the front of the list. + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + token2pos[tok] = num_toks - ++cur_pos; + + unordered_set reprocess; + + for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) { + Token *tok = iter->first; + int32 pos = iter->second; + for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel == 0) { + // We only need to consider epsilon links, since non-epsilon links + // transition between frames and this function only needs to sort a list + // of tokens from a single frame. 
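+        // (Example, illustrative: if token A has an epsilon link to token B
+        //  but B's current position is smaller than A's, B is moved to a
+        //  fresh position past all others (cur_pos++) and queued for
+        //  reprocessing, so the final order places A before B.)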
+        IterType following_iter = token2pos.find(link->next_tok);
+        if (following_iter != token2pos.end()) { // another token on this frame,
+                                                 // so must consider it.
+          int32 next_pos = following_iter->second;
+          if (next_pos < pos) { // reassign the position of the next Token.
+            following_iter->second = cur_pos++;
+            reprocess.insert(link->next_tok);
+          }
+        }
+      }
+    }
+    // In case we had previously assigned this token to be reprocessed, we can
+    // erase it from that set because it's "happy now" (we just processed it).
+    reprocess.erase(tok);
+  }
+
+  size_t max_loop = 1000000, loop_count; // max_loop is to detect epsilon cycles.
+  for (loop_count = 0;
+       !reprocess.empty() && loop_count < max_loop; ++loop_count) {
+    std::vector<Token*> reprocess_vec;
+    for (typename unordered_set<Token*>::iterator iter = reprocess.begin();
+         iter != reprocess.end(); ++iter)
+      reprocess_vec.push_back(*iter);
+    reprocess.clear();
+    for (typename std::vector<Token*>::iterator iter = reprocess_vec.begin();
+         iter != reprocess_vec.end(); ++iter) {
+      Token *tok = *iter;
+      int32 pos = token2pos[tok];
+      // Repeat the processing we did above (for comments, see above).
+      for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
+        if (link->ilabel == 0) {
+          IterType following_iter = token2pos.find(link->next_tok);
+          if (following_iter != token2pos.end()) {
+            int32 next_pos = following_iter->second;
+            if (next_pos < pos) {
+              following_iter->second = cur_pos++;
+              reprocess.insert(link->next_tok);
+            }
+          }
+        }
+      }
+    }
+  }
+  KALDI_ASSERT(loop_count < max_loop && "Epsilon loops exist in your decoding "
+               "graph (this is not allowed!)");
+
+  topsorted_list->clear();
+  topsorted_list->resize(cur_pos, NULL); // create a list with NULLs in between.
+  for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter)
+    (*topsorted_list)[iter->second] = iter->first;
+}
+
+// Instantiate the template for the combination of token types and FST types
+// that we'll need.
+template class LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc>,
+                                              decodercombine::StdToken>;
+template class LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>,
+                                              decodercombine::StdToken >;
+template class LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>,
+                                              decodercombine::StdToken >;
+template class LatticeFasterDecoderCombineTpl<fst::GrammarFst,
+                                              decodercombine::StdToken>;
+
+template class LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc>,
+                                              decodercombine::BackpointerToken>;
+template class LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>,
+                                              decodercombine::BackpointerToken >;
+template class LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>,
+                                              decodercombine::BackpointerToken >;
+template class LatticeFasterDecoderCombineTpl<fst::GrammarFst,
+                                              decodercombine::BackpointerToken>;
+
+
+}  // end namespace kaldi.
diff --git a/src/decoder/lattice-faster-decoder-combine-itermap.h b/src/decoder/lattice-faster-decoder-combine-itermap.h
new file mode 100644
index 00000000000..c0b76e4126d
--- /dev/null
+++ b/src/decoder/lattice-faster-decoder-combine-itermap.h
@@ -0,0 +1,561 @@
+// decoder/lattice-faster-decoder-combine.h
+
+// Copyright 2009-2013  Microsoft Corporation;  Mirko Hannemann;
+//           2013-2019  Johns Hopkins University (Author: Daniel Povey)
+//                2014  Guoguo Chen
+//                2018  Zhehuai Chen
+//                2019  Hang Lyu
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ +#define KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ + + +#include "util/stl-utils.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/grammar-fst.h" +#include "decoder/lattice-faster-decoder.h" + +namespace kaldi { + +struct LatticeFasterDecoderCombineConfig { + BaseFloat beam; + int32 max_active; + int32 min_active; + BaseFloat lattice_beam; + int32 prune_interval; + bool determinize_lattice; // not inspected by this class... used in + // command-line program. + BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat hash_ratio; + BaseFloat prune_scale; // Note: we don't make this configurable on the command line, + // it's not a very important parameter. It affects the + // algorithm that prunes the tokens as we go. + // Most of the options inside det_opts are not actually queried by the + // LatticeFasterDecoder class itself, but by the code that calls it, for + // example in the function DecodeUtteranceLatticeFaster. + fst::DeterminizeLatticePhonePrunedOptions det_opts; + + LatticeFasterDecoderCombineConfig(): beam(16.0), + max_active(std::numeric_limits::max()), + min_active(200), + lattice_beam(10.0), + prune_interval(25), + determinize_lattice(true), + beam_delta(0.5), + hash_ratio(2.0), + prune_scale(0.1) { } + void Register(OptionsItf *opts) { + det_opts.Register(opts); + opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); + opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; " + "more accurate"); + opts->Register("min-active", &min_active, "Decoder minimum #active states."); + opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam. Larger->slower, " + "and deeper lattices"); + opts->Register("prune-interval", &prune_interval, "Interval (in frames) at " + "which to prune tokens"); + opts->Register("determinize-lattice", &determinize_lattice, "If true, " + "determinize the lattice (lattice-determinization, keeping only " + "best pdf-sequence for each word-sequence)."); + opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this " + "parameter is obscure and relates to a speedup in the way the " + "max-active constraint is applied. 
Larger is more accurate.");
+    opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to "
+                   "control hash behavior");
+  }
+  void Check() const {
+    KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0
+                 && min_active <= max_active
+                 && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0
+                 && prune_scale > 0.0 && prune_scale < 1.0);
+  }
+};
+
+
+namespace decodercombine {
+// We will template the decoder on the token type as well as the FST type; this
+// is a mechanism so that we can use the same underlying decoder code for
+// versions of the decoder that support quickly getting the best path
+// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also
+// those that do not (LatticeFasterDecoder).
+
+
+// ForwardLinks are the links from a token to a token on the next frame,
+// or sometimes on the current frame (for input-epsilon links).
+template <typename Token>
+struct ForwardLink {
+  using Label = fst::StdArc::Label;
+
+  Token *next_tok;  // the next token [or NULL if represents final-state]
+  Label ilabel;  // ilabel on arc
+  Label olabel;  // olabel on arc
+  BaseFloat graph_cost;  // graph cost of traversing arc (contains LM, etc.)
+  BaseFloat acoustic_cost;  // acoustic cost (pre-scaled) of traversing arc
+  ForwardLink *next;  // next in singly-linked list of forward arcs (arcs
+                      // in the state-level lattice) from a token.
+  inline ForwardLink(Token *next_tok, Label ilabel, Label olabel,
+                     BaseFloat graph_cost, BaseFloat acoustic_cost,
+                     ForwardLink *next):
+      next_tok(next_tok), ilabel(ilabel), olabel(olabel),
+      graph_cost(graph_cost), acoustic_cost(acoustic_cost),
+      next(next) { }
+};
+
+
+struct StdToken {
+  using ForwardLinkT = ForwardLink<StdToken>;
+  using Token = StdToken;
+
+  // Standard token type for LatticeFasterDecoder.  Each active HCLG
+  // (decoding-graph) state on each frame has one token.
+
+  // tot_cost is the total (LM + acoustic) cost from the beginning of the
+  // utterance up to this point.  (but see cost_offset_, which is subtracted
+  // to keep it in a good numerical range).
+  BaseFloat tot_cost;
+
+  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
+  // minimum difference between the cost of the best path that this token is
+  // on, and the cost of the absolute best path, under the assumption that
+  // any of the currently active states at the decoding front may eventually
+  // succeed (e.g. if you were to take the currently active states one by one
+  // and compute this difference, and then take the minimum).
+  BaseFloat extra_cost;
+
+  // 'links' is the head of singly-linked list of ForwardLinks, which is what
+  // we use for lattice generation.
+  ForwardLinkT *links;
+
+  // 'next' is the next in the singly-linked list of tokens for this frame.
+  Token *next;
+
+  // Identifies whether the token is in the current queue or not, to prevent
+  // duplicates in function ProcessForFrame().
+  bool in_current_queue;
+
+  // This function does nothing and should be optimized out; it's needed
+  // so we can share the regular LatticeFasterDecoderTpl code and the code
+  // for LatticeFasterOnlineDecoder that supports fast traceback.
+  inline void SetBackpointer (Token *backpointer) { }
+
+  // This constructor just ignores the 'backpointer' argument.  That argument is
+  // needed so that we can use the same decoder code for LatticeFasterDecoderTpl
+  // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
+  // fast way to obtain the best path).
+  inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
+                  Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next),
+      in_current_queue(false) { }
+};
+
+struct BackpointerToken {
+  using ForwardLinkT = ForwardLink<BackpointerToken>;
+  using Token = BackpointerToken;
+
+  // BackpointerToken is like StdToken but also stores a backpointer to the
+  // best preceding token.  Each active HCLG (decoding-graph) state on each
+  // frame has one token.
+
+  // tot_cost is the total (LM + acoustic) cost from the beginning of the
+  // utterance up to this point.  (but see cost_offset_, which is subtracted
+  // to keep it in a good numerical range).
+  BaseFloat tot_cost;
+
+  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
+  // minimum difference between the cost of the best path that this token is
+  // on, and the cost of the absolute best path, under the assumption that
+  // any of the currently active states at the decoding front may eventually
+  // succeed (e.g. if you were to take the currently active states one by one
+  // and compute this difference, and then take the minimum).
+  BaseFloat extra_cost;
+
+  // 'links' is the head of singly-linked list of ForwardLinks, which is what
+  // we use for lattice generation.
+  ForwardLinkT *links;
+
+  // 'next' is the next in the singly-linked list of tokens for this frame.
+  BackpointerToken *next;
+
+  // Best preceding BackpointerToken (could be on this frame, connected to
+  // this via an epsilon transition, or on a previous frame).  This is only
+  // required for an efficient GetBestPath function in
+  // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
+  // (the "links" list is what stores the forward links, for that).
+  Token *backpointer;
+
+  // Identifies whether the token is in the current queue or not, to prevent
+  // duplicates in function ProcessForFrame().
+  bool in_current_queue;
+
+  inline void SetBackpointer (Token *backpointer) {
+    this->backpointer = backpointer;
+  }
+
+  inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost,
+                          ForwardLinkT *links, Token *next,
+                          Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next),
+      backpointer(backpointer), in_current_queue(false) { }
+};
+
+}  // namespace decodercombine
+
+
+/** This is the "normal" lattice-generating decoder.
+    See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
+    for more information.
+
+    The decoder is templated on the FST type and the token type.  The token
+    type will normally be StdToken, but also may be BackpointerToken, which is
+    to support quick lookup of the current best path (see
+    lattice-faster-online-decoder.h).
+
+    The FST you invoke this decoder with is expected to be of type
+    fst::Fst<fst::StdArc>, a.k.a. StdFst, or GrammarFst.  If you invoke it
+    with FST == StdFst and it notices that the actual FST type is
+    fst::VectorFst<fst::StdArc> or fst::ConstFst<fst::StdArc>, the decoder
+    object will internally cast itself to one that is templated on those more
+    specific types; this is an optimization for speed.
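+
+    A minimal decoding sketch (illustrative, not prescribed by this header;
+    `decodable' stands for any DecodableInterface implementation and
+    `decode_fst' for an already-loaded graph):
+
+      LatticeFasterDecoderCombineConfig config;
+      LatticeFasterDecoderCombine decoder(decode_fst, config);
+      decoder.Decode(&decodable);
+      Lattice raw_lat;
+      decoder.GetRawLattice(&raw_lat, true);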
+ */
+template <typename FST, typename Token = decodercombine::StdToken>
+class LatticeFasterDecoderCombineTpl {
+ public:
+  using Arc = typename FST::Arc;
+  using Label = typename Arc::Label;
+  using StateId = typename Arc::StateId;
+  using Weight = typename Arc::Weight;
+  using ForwardLinkT = decodercombine::ForwardLink<Token>;
+
+  using StateIdToTokenMap = typename std::unordered_map<StateId, Token*>;
+  //using StateIdToTokenMap = typename std::unordered_map<StateId, Token*,
+  //    std::hash<StateId>, std::equal_to<StateId>,
+  //    fst::PoolAllocator<std::pair<const StateId, Token*> > >;
+  using IterType = typename StateIdToTokenMap::const_iterator;
+
+  // Instantiate this class once for each thing you have to decode.
+  // This version of the constructor does not take ownership of
+  // 'fst'.
+  LatticeFasterDecoderCombineTpl(const FST &fst,
+                                 const LatticeFasterDecoderCombineConfig &config);
+
+  // This version of the constructor takes ownership of the fst, and will delete
+  // it when this object is destroyed.
+  LatticeFasterDecoderCombineTpl(const LatticeFasterDecoderCombineConfig &config,
+                                 FST *fst);
+
+  void SetOptions(const LatticeFasterDecoderCombineConfig &config) {
+    config_ = config;
+  }
+
+  const LatticeFasterDecoderCombineConfig &GetOptions() const {
+    return config_;
+  }
+
+  ~LatticeFasterDecoderCombineTpl();
+
+  /// Decodes until there are no more frames left in the "decodable" object.
+  /// Note, this may block waiting for input if the "decodable" object blocks.
+  /// Returns true if any kind of traceback is available (not necessarily from a
+  /// final state).
+  bool Decode(DecodableInterface *decodable);
+
+
+  /// Says whether a final-state was active on the last frame.  If it was not,
+  /// the lattice (or traceback) will end with states that are not final-states.
+  bool ReachedFinal() const {
+    return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity();
+  }
+
+  /// Outputs an FST corresponding to the single best path through the lattice.
+  /// Returns true if result is nonempty (using the return status is deprecated,
+  /// it will become void).  If "use_final_probs" is true AND we reached the
+  /// final-state of the graph then it will include those as final-probs, else
+  /// it will treat all final-probs as one.  Note: this just calls GetRawLattice()
+  /// and figures out the shortest path.
+  bool GetBestPath(Lattice *ofst,
+                   bool use_final_probs = true);
+
+  /// Outputs an FST corresponding to the raw, state-level
+  /// tracebacks.  Returns true if result is nonempty.
+  /// If "use_final_probs" is true AND we reached the final-state
+  /// of the graph then it will include those as final-probs, else
+  /// it will treat all final-probs as one.
+  /// The raw lattice will be topologically sorted.
+  /// The function can be called during decoding; it will then process the
+  /// non-emitting arcs out of the tokens in the "cur_toks_" map, so that the
+  /// last frame holds tokens reached by both non-emitting and emitting arcs
+  /// before the raw lattice is built, and afterwards it restores the last
+  /// token list to keep ProcessForFrame() consistent.
+  ///
+  /// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
+  /// which also supports a pruning beam, in case for some reason
+  /// you want it pruned tighter than the regular lattice beam.
+  /// We could add that here in the future if needed.
+  bool GetRawLattice(Lattice *ofst, bool use_final_probs = true);
+
+
+
+  /// [Deprecated, users should now use GetRawLattice and determinize it
+  /// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper].
+  /// Outputs an FST corresponding to the lattice-determinized
+  /// lattice (one path per word sequence).  Returns true if result is nonempty.
+ /// If "use_final_probs" is true AND we reached the final-state of the graph + /// then it will include those as final-probs, else it will treat all + /// final-probs as one. + bool GetLattice(CompactLattice *ofst, + bool use_final_probs = true); + + /// InitDecoding initializes the decoding, and should only be used if you + /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to + /// call this. You can also call InitDecoding if you have already decoded an + /// utterance and want to start with a new utterance. + void InitDecoding(); + + /// This will decode until there are no more frames ready in the decodable + /// object. You can keep calling it each time more frames become available. + /// If max_num_frames is specified, it specifies the maximum number of frames + /// the function will decode before returning. + void AdvanceDecoding(DecodableInterface *decodable, + int32 max_num_frames = -1); + + /// This function may be optionally called after AdvanceDecoding(), when you + /// do not plan to decode any further. It does an extra pruning step that + /// will help to prune the lattices output by GetLattice and (particularly) + /// GetRawLattice more accurately, particularly toward the end of the + /// utterance. It does this by using the final-probs in pruning (if any + /// final-state survived); it also does a final pruning step that visits all + /// states (the pruning that is done during decoding may fail to prune states + /// that are within kPruningScale = 0.1 outside of the beam). If you call + /// this, you cannot call AdvanceDecoding again (it will fail), and you + /// cannot call GetLattice() and related functions with use_final_probs = + /// false. + /// Used to be called PruneActiveTokensFinal(). + void FinalizeDecoding(); + + /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives + /// more information. It returns the difference between the best (final-cost + /// plus cost) of any token on the final frame, and the best cost of any token + /// on the final frame. If it is infinity it means no final-states were + /// present on the final frame. It will usually be nonnegative. If it not + /// too positive (e.g. < 5 is my first guess, but this is not tested) you can + /// take it as a good indication that we reached the final-state with + /// reasonable likelihood. + BaseFloat FinalRelativeCost() const; + + + // Returns the number of frames decoded so far. The value returned changes + // whenever we call ProcessForFrame(). + inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } + + protected: + // we make things protected instead of private, as code in + // LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the + // internals. + + // Deletes the elements of the singly linked list tok->links. + inline static void DeleteForwardLinks(Token *tok); + + // head of per-frame list of Tokens (list is in topological order), + // and something saying whether we ever pruned it using PruneForwardLinks. + struct TokenList { + Token *toks; + bool must_prune_forward_links; + bool must_prune_tokens; + TokenList(): toks(NULL), must_prune_forward_links(true), + must_prune_tokens(true) { } + }; + + // FindOrAddToken either locates a token in hash map "token_map", or if necessary + // inserts a new, empty token (i.e. with no forward links) for the current + // frame. 
[note: it's inserted if necessary into hash map and also into the
+  // singly linked list of tokens active on this frame (whose head is at
+  // active_toks_[frame]).  The frame_plus_one argument is the acoustic frame
+  // index plus one, which is used to index into the active_toks_ array.
+  // Returns the Token pointer.  Sets "changed" (if non-NULL) to true if the
+  // token was newly created or the cost changed.
+  // If Token == StdToken, the 'backpointer' argument has no purpose (and will
+  // hopefully be optimized out).
+  inline Token *FindOrAddToken(StateId state, int32 token_list_index,
+                               BaseFloat tot_cost, Token *backpointer,
+                               StateIdToTokenMap *token_map,
+                               bool *changed);
+
+  // prunes outgoing links for all tokens in active_toks_[frame]
+  // it's called by PruneActiveTokens
+  // all links, that have link_extra_cost > lattice_beam are pruned
+  // delta is the amount by which the extra_costs must change
+  // before we set *extra_costs_changed = true.
+  // If delta is larger, we'll tend to go back less far
+  // toward the beginning of the file.
+  // extra_costs_changed is set to true if extra_cost was changed for any token
+  // links_pruned is set to true if any link in any token was pruned
+  void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
+                         bool *links_pruned,
+                         BaseFloat delta);
+
+  // This function computes the final-costs for tokens active on the final
+  // frame.  It outputs to final-costs, if non-NULL, a map from the Token*
+  // pointer to the final-prob of the corresponding state, for all Tokens
+  // that correspond to states that have final-probs.  This map will be
+  // empty if there were no final-probs.  It outputs to
+  // final_relative_cost, if non-NULL, the difference between the best
+  // forward-cost including the final-prob cost, and the best forward-cost
+  // without including the final-prob cost (this will usually be positive), or
+  // infinity if there were no final-probs.  [c.f. FinalRelativeCost(), which
+  // outputs this quantity].  It outputs to final_best_cost, if
+  // non-NULL, the lowest, for any token t active on the final frame, of
+  // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in
+  // the graph of the state corresponding to token t, or the best of
+  // forward-cost[t] if there were no final-probs active on the final frame.
+  // You cannot call this after FinalizeDecoding() has been called; in that
+  // case you should get the answer from class-member variables.
+  void ComputeFinalCosts(unordered_map<Token*, BaseFloat> *final_costs,
+                         BaseFloat *final_relative_cost,
+                         BaseFloat *final_best_cost) const;
+
+  // PruneForwardLinksFinal is a version of PruneForwardLinks that we call
+  // on the final frame.  If there are final tokens active, it uses
+  // the final-probs for pruning, otherwise it treats all tokens as final.
+  void PruneForwardLinksFinal();
+
+  // Prune away any tokens on this frame that have no forward links.
+  // [we don't do this in PruneForwardLinks because it would give us
+  // a problem with dangling pointers].
+  // It's called by PruneActiveTokens if any forward links have been pruned
+  void PruneTokensForFrame(int32 frame_plus_one);
+
+
+  // Go backwards through still-alive tokens, pruning them if the
+  // forward+backward cost is more than lat_beam away from the best path.  It's
+  // possible to prove that this is "correct" in the sense that we won't lose
+  // anything outside of lat_beam, regardless of what happens in the future.
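+  // (In symbols, as a sketch: writing alpha(t) for the forward cost of token
+  //  t and beta(t) for the best cost from t to the current frontier, a token
+  //  is kept only if alpha(t) + beta(t) - best_path_cost <= lattice_beam,
+  //  which is the extra_cost <= lattice_beam test described above.)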
+
+  // Go backwards through still-alive tokens, pruning them if the
+  // forward+backward cost is more than lat_beam away from the best path.  It's
+  // possible to prove that this is "correct" in the sense that we won't lose
+  // anything outside of lat_beam, regardless of what happens in the future.
+  // delta controls when it considers a cost to have changed enough to continue
+  // going backward and propagating the change.  Larger delta -> will recurse
+  // less far.
+  void PruneActiveTokens(BaseFloat delta);
+
+  /// Processes non-emitting (epsilon) arcs and emitting arcs for one frame
+  /// together.  It takes the emitting tokens in "prev_toks_" from the last
+  /// frame, generates non-emitting tokens for the previous frame and emitting
+  /// tokens for the next frame.
+  /// Notice: "the emitting tokens for the current frame" means the tokens
+  /// that take the acoustic scores of the current frame (i.e. the
+  /// destinations of emitting arcs).
+  void ProcessForFrame(DecodableInterface *decodable);
+
+  /// Processes nonemitting (epsilon) arcs for one frame.
+  /// This function is called once after all frames have been processed, or
+  /// from GetRawLattice() to generate the complete token list for the last
+  /// frame.  [It deals with the tokens in the map "cur_toks_", which would
+  /// only contain emitting tokens from the previous frame.]
+  /// If the map "token_orig_cost" isn't NULL, we build the map which will
+  /// be used to recover the "active_toks_[last_frame]" token list for the
+  /// last frame.
+  void ProcessNonemitting(std::unordered_map<Token*, BaseFloat> *token_orig_cost);
+
+  /// When GetRawLattice() is called during decoding, the
+  /// active_toks_[last_frame] list is changed.  To keep the behavior of
+  /// ProcessForFrame() consistent, we recover it.
+  /// Notice: as new tokens are added to the head of the TokenList, tok->next
+  /// will not be affected.
+  /// "token_orig_cost" is a mapping from token pointer to the tot_cost of the
+  /// token before propagating non-emitting arcs.  It is used to undo the
+  /// changes to the original tokens on the last frame and to remove the new
+  /// tokens that came from propagating non-emitting arcs, so that we can
+  /// guarantee the consistency of ProcessForFrame().
+  void RecoverLastTokenList(
+      const std::unordered_map<Token*, BaseFloat> &token_orig_cost);
+
+
+  /// The maps "prev_toks_" and "cur_toks_" allow us to maintain the tokens of
+  /// the current and the next frame; they are indexed by StateId.  (The
+  /// active_toks_ list, by contrast, is indexed by frame-index plus one, where
+  /// the frame-index is zero-based as used in the decodable object.  That is,
+  /// the emitting probs of frame t are accounted for in tokens at toks_[t+1];
+  /// the zeroth entry is for the nonemitting transitions at the start of the
+  /// graph.)
+  StateIdToTokenMap prev_toks_;
+  StateIdToTokenMap cur_toks_;
+
+  /// Gets the weight cutoff.
+  /// Notice: in the traditional version, the histogram pruning method is
+  /// applied to a complete token list for one frame.  But in this version it
+  /// is applied to a token list which only contains the emitting part, so the
+  /// max_active and min_active values might be narrowed.
+  BaseFloat GetCutoff(const StateIdToTokenMap& toks,
+                      BaseFloat *adaptive_beam,
+                      StateId *best_state_id, Token **best_token);
+
+  std::vector<TokenList> active_toks_;  // Lists of tokens, indexed by
+  // frame (members of TokenList are toks, must_prune_forward_links,
+  // must_prune_tokens).
+  std::queue<StateId> cur_queue_;  // temp variable used in ProcessForFrame
+                                   // and ProcessNonemitting
+  std::vector<BaseFloat> tmp_array_;  // used in GetCutoff.
+
+  // fst_ is a pointer to the FST we are decoding from.
+  const FST *fst_;
+  // delete_fst_ is true if the pointer fst_ needs to be deleted when this
+  // object is destroyed.
+ bool delete_fst_; + + std::vector cost_offsets_; // This contains, for each + // frame, an offset that was added to the acoustic log-likelihoods on that + // frame in order to keep everything in a nice dynamic range i.e. close to + // zero, to reduce roundoff errors. + // Notice: It will only be added to emitting arcs (i.e. cost_offsets_[t] is + // added to arcs from "frame t" to "frame t+1"). + LatticeFasterDecoderCombineConfig config_; + int32 num_toks_; // current total #toks allocated... + bool warned_; + + /// decoding_finalized_ is true if someone called FinalizeDecoding(). [note, + /// calling this is optional]. If true, it's forbidden to decode more. Also, + /// if this is set, then the output of ComputeFinalCosts() is in the next + /// three variables. The reason we need to do this is that after + /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some + /// of the tokens on the last frame are freed, so we free the list from toks_ + /// to avoid having dangling pointers hanging around. + bool decoding_finalized_; + /// For the meaning of the next 3 variables, see the comment for + /// decoding_finalized_ above., and ComputeFinalCosts(). + unordered_map final_costs_; + BaseFloat final_relative_cost_; + BaseFloat final_best_cost_; + + // This function takes a singly linked list of tokens for a single frame, and + // outputs a list of them in topological order (it will crash if no such order + // can be found, which will typically be due to decoding graphs with epsilon + // cycles, which are not allowed). Note: the output list may contain NULLs, + // which the caller should pass over; it just happens to be more efficient for + // the algorithm to output a list that contains NULLs. + static void TopSortTokens(Token *tok_list, + std::vector *topsorted_list); + + void ClearActiveTokens(); + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderCombineTpl); +}; + +typedef LatticeFasterDecoderCombineTpl LatticeFasterDecoderCombine; + + + +} // end namespace kaldi. + +#endif diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc index fbb67729828..5c87d72fe14 100644 --- a/src/decoder/lattice-faster-decoder-combine.cc +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -50,6 +50,8 @@ template LatticeFasterDecoderCombineTpl::~LatticeFasterDecoderCombineTpl() { ClearActiveTokens(); if (delete_fst_) delete fst_; + //prev_toks_.clear(); + //cur_toks_.clear(); } template From f9bab34401c729838e4d4df8bb3e5ea6c0b9e4ac Mon Sep 17 00:00:00 2001 From: Hang Lyu Date: Sat, 16 Mar 2019 13:27:15 -0400 Subject: [PATCH 14/29] Heap method head file --- .../lattice-faster-decoder-combine-heap.h | 697 ++++++++++++++++++ 1 file changed, 697 insertions(+) create mode 100644 src/decoder/lattice-faster-decoder-combine-heap.h diff --git a/src/decoder/lattice-faster-decoder-combine-heap.h b/src/decoder/lattice-faster-decoder-combine-heap.h new file mode 100644 index 00000000000..48719aa347c --- /dev/null +++ b/src/decoder/lattice-faster-decoder-combine-heap.h @@ -0,0 +1,697 @@ +// decoder/lattice-faster-decoder-combine.h + +// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann; +// 2013-2019 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen +// 2019 Hang Lyu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ +#define KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ + + +#include "util/stl-utils.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/grammar-fst.h" +#include "decoder/lattice-faster-decoder.h" + +namespace kaldi { + +struct LatticeFasterDecoderCombineConfig { + BaseFloat beam; + int32 max_active; + int32 min_active; + BaseFloat lattice_beam; + int32 prune_interval; + bool determinize_lattice; // not inspected by this class... used in + // command-line program. + BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat hash_ratio; + BaseFloat prune_scale; // Note: we don't make this configurable on the command line, + // it's not a very important parameter. It affects the + // algorithm that prunes the tokens as we go. + // Most of the options inside det_opts are not actually queried by the + // LatticeFasterDecoder class itself, but by the code that calls it, for + // example in the function DecodeUtteranceLatticeFaster. + fst::DeterminizeLatticePhonePrunedOptions det_opts; + + LatticeFasterDecoderCombineConfig(): beam(16.0), + max_active(std::numeric_limits::max()), + min_active(200), + lattice_beam(10.0), + prune_interval(25), + determinize_lattice(true), + beam_delta(0.5), + hash_ratio(2.0), + prune_scale(0.1) { } + void Register(OptionsItf *opts) { + det_opts.Register(opts); + opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); + opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; " + "more accurate"); + opts->Register("min-active", &min_active, "Decoder minimum #active states."); + opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam. Larger->slower, " + "and deeper lattices"); + opts->Register("prune-interval", &prune_interval, "Interval (in frames) at " + "which to prune tokens"); + opts->Register("determinize-lattice", &determinize_lattice, "If true, " + "determinize the lattice (lattice-determinization, keeping only " + "best pdf-sequence for each word-sequence)."); + opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this " + "parameter is obscure and relates to a speedup in the way the " + "max-active constraint is applied. 
Larger is more accurate.");
+    opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to "
+                   "control hash behavior");
+  }
+  void Check() const {
+    KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0
+                 && min_active <= max_active
+                 && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0
+                 && prune_scale > 0.0 && prune_scale < 1.0);
+  }
+};
+
+
+namespace decodercombine {
+// We will template the decoder on the token type as well as the FST type; this
+// is a mechanism so that we can use the same underlying decoder code for
+// versions of the decoder that support quickly getting the best path
+// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also
+// those that do not (LatticeFasterDecoder).
+
+
+// ForwardLinks are the links from a token to a token on the next frame,
+// or sometimes on the current frame (for input-epsilon links).
+template <typename Token>
+struct ForwardLink {
+  using Label = fst::StdArc::Label;
+
+  Token *next_tok;  // the next token [or NULL if represents final-state]
+  Label ilabel;  // ilabel on arc
+  Label olabel;  // olabel on arc
+  BaseFloat graph_cost;  // graph cost of traversing arc (contains LM, etc.)
+  BaseFloat acoustic_cost;  // acoustic cost (pre-scaled) of traversing arc
+  ForwardLink *next;  // next in singly-linked list of forward arcs (arcs
+                      // in the state-level lattice) from a token.
+  inline ForwardLink(Token *next_tok, Label ilabel, Label olabel,
+                     BaseFloat graph_cost, BaseFloat acoustic_cost,
+                     ForwardLink *next):
+      next_tok(next_tok), ilabel(ilabel), olabel(olabel),
+      graph_cost(graph_cost), acoustic_cost(acoustic_cost),
+      next(next) { }
+};
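For orientation, here is a minimal, hypothetical sketch (not part of the patch) of how a singly-linked ForwardLink list like the one above is typically walked and freed; the decoder's DeleteForwardLinks() in the .cc file does essentially this:

    template <typename Token>
    void DeleteLinkListSketch(ForwardLink<Token> *head) {
      // Walk the list from the head, saving 'next' before freeing each node.
      while (head != NULL) {
        ForwardLink<Token> *next = head->next;
        delete head;
        head = next;
      }
    }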
+
+template <typename Fst>
+struct StdToken {
+  using ForwardLinkT = ForwardLink<StdToken<Fst> >;
+  using Token = StdToken<Fst>;
+  using StateId = typename Fst::Arc::StateId;
+
+  // Standard token type for LatticeFasterDecoder.  Each active HCLG
+  // (decoding-graph) state on each frame has one token.
+
+  // tot_cost is the total (LM + acoustic) cost from the beginning of the
+  // utterance up to this point.  (but see cost_offset_, which is subtracted
+  // to keep it in a good numerical range).
+  BaseFloat tot_cost;
+
+  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
+  // minimum difference between the cost of the best path that this token is
+  // on, and the cost of the absolute best path, under the assumption
+  // that any of the currently active states at the decoding front may
+  // eventually succeed (e.g. if you were to take the currently active states
+  // one by one and compute this difference, and then take the minimum).
+  BaseFloat extra_cost;
+
+  // Record the state id of the token.
+  StateId state_id;
+
+  // 'links' is the head of singly-linked list of ForwardLinks, which is what we
+  // use for lattice generation.
+  ForwardLinkT *links;
+
+  // 'next' is the next in the singly-linked list of tokens for this frame.
+  Token *next;
+
+  // Identifies whether the token is in the current heap (-1 means it is not).
+  // It records the position in the current heap, so that fixing the heap
+  // after updating the cost of an existing token is more convenient and
+  // faster.
+  size_t position_in_heap;
+
+  // This function does nothing and should be optimized out; it's needed
+  // so we can share the regular LatticeFasterDecoderTpl code and the code
+  // for LatticeFasterOnlineDecoder that supports fast traceback.
+  inline void SetBackpointer (Token *backpointer) { }
+
+  // This constructor just ignores the 'backpointer' argument.  That argument is
+  // needed so that we can use the same decoder code for LatticeFasterDecoderTpl
+  // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
+  // fast way to obtain the best path).
+  inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, StateId state_id,
+                  ForwardLinkT *links, Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), state_id(state_id),
+      links(links), next(next), position_in_heap(-1) { }
+};
+
+template <typename Fst>
+struct BackpointerToken {
+  using ForwardLinkT = ForwardLink<BackpointerToken<Fst> >;
+  using Token = BackpointerToken<Fst>;
+  using StateId = typename Fst::Arc::StateId;
+
+  // BackpointerToken is like StdToken but also stores a backpointer.
+  // Each active HCLG (decoding-graph) state on each frame has one token.
+
+  // tot_cost is the total (LM + acoustic) cost from the beginning of the
+  // utterance up to this point.  (but see cost_offset_, which is subtracted
+  // to keep it in a good numerical range).
+  BaseFloat tot_cost;
+
+  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
+  // minimum difference between the cost of the best path that this token is
+  // on, and the cost of the absolute best path, under the assumption
+  // that any of the currently active states at the decoding front may
+  // eventually succeed (e.g. if you were to take the currently active states
+  // one by one and compute this difference, and then take the minimum).
+  BaseFloat extra_cost;
+
+  // Record the state id of the token.
+  StateId state_id;
+
+  // 'links' is the head of singly-linked list of ForwardLinks, which is what we
+  // use for lattice generation.
+  ForwardLinkT *links;
+
+  // 'next' is the next in the singly-linked list of tokens for this frame.
+  BackpointerToken *next;
+
+  // Best preceding BackpointerToken (could be a token on this frame, connected
+  // to this one via an epsilon transition, or on a previous frame).  This is
+  // only required for an efficient GetBestPath function in
+  // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
+  // (the "links" list is what stores the forward links, for that).
+  Token *backpointer;
+
+  // Identifies whether the token is in the current heap (-1 means it is not).
+  // It records the position in the current heap, so that fixing the heap
+  // after updating the cost of an existing token is more convenient and
+  // faster.
+  size_t position_in_heap;
+
+  inline void SetBackpointer (Token *backpointer) {
+    this->backpointer = backpointer;
+  }
+
+  inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost,
+                          StateId state_id, ForwardLinkT *links,
+                          Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), state_id(state_id),
+      links(links), next(next), backpointer(backpointer),
+      position_in_heap(-1) { }
+};
+
+}  // namespace decodercombine
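Since the class below is templated on both the FST and the token type, a hypothetical alias (not in the patch) makes the intended instantiation pattern concrete:

    // Decoder over a VectorFst with backpointer tokens, which would support a
    // fast best-path traceback (cf. lattice-faster-online-decoder.h):
    typedef kaldi::LatticeFasterDecoderCombineTpl<
        fst::VectorFst<fst::StdArc>,
        kaldi::decodercombine::BackpointerToken<fst::VectorFst<fst::StdArc> > >
        CombineDecoderWithBackpointers;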
+
+/** This is the "normal" lattice-generating decoder.
+    See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
+    for more information.
+
+    The decoder is templated on the FST type and the token type.  The token
+    type will normally be StdToken, but may also be BackpointerToken, which is
+    to support quick lookup of the current best path (see
+    lattice-faster-online-decoder.h).
+
+    The FST you invoke this decoder with is expected to be of type
+    fst::Fst<fst::StdArc>, a.k.a. StdFst, or GrammarFst.  If you invoke it with
+    FST == StdFst and it notices that the actual FST type is
+    fst::VectorFst<fst::StdArc> or fst::ConstFst<fst::StdArc>, the decoder
+    object will internally cast itself to one that is templated on those more
+    specific types; this is an optimization for speed.
+ */
+template <typename FST, typename Token = decodercombine::StdToken<FST> >
+class LatticeFasterDecoderCombineTpl {
+ public:
+  using Arc = typename FST::Arc;
+  using Label = typename Arc::Label;
+  using StateId = typename Arc::StateId;
+  using Weight = typename Arc::Weight;
+  using ForwardLinkT = decodercombine::ForwardLink<Token>;
+
+  using StateIdToTokenMap = typename std::unordered_map<StateId, Token*>;
+  //using StateIdToTokenMap = typename std::unordered_map<StateId, Token*,
+  //    std::hash<StateId>, std::equal_to<StateId>,
+  //    fst::PoolAllocator<std::pair<const StateId, Token*> > >;
+  using IterType = typename StateIdToTokenMap::const_iterator;
+
+  // Instantiate this class once for each thing you have to decode.
+  // This version of the constructor does not take ownership of
+  // 'fst'.
+  LatticeFasterDecoderCombineTpl(const FST &fst,
+                                 const LatticeFasterDecoderCombineConfig &config);
+
+  // This version of the constructor takes ownership of the fst, and will delete
+  // it when this object is destroyed.
+  LatticeFasterDecoderCombineTpl(const LatticeFasterDecoderCombineConfig &config,
+                                 FST *fst);
+
+  void SetOptions(const LatticeFasterDecoderCombineConfig &config) {
+    config_ = config;
+  }
+
+  const LatticeFasterDecoderCombineConfig &GetOptions() const {
+    return config_;
+  }
+
+  ~LatticeFasterDecoderCombineTpl();
+
+  /// Decodes until there are no more frames left in the "decodable" object.
+  /// Note: this may block waiting for input if the "decodable" object blocks.
+  /// Returns true if any kind of traceback is available (not necessarily from
+  /// a final state).
+  bool Decode(DecodableInterface *decodable);
+
+
+  /// Says whether a final-state was active on the last frame.  If it was not,
+  /// the lattice (or traceback) will end with states that are not final-states.
+  bool ReachedFinal() const {
+    return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity();
+  }
+
+  /// Outputs an FST corresponding to the single best path through the lattice.
+  /// Returns true if result is nonempty (using the return status is deprecated,
+  /// it will become void).  If "use_final_probs" is true AND we reached the
+  /// final-state of the graph then it will include those as final-probs, else
+  /// it will treat all final-probs as one.  Note: this just calls GetRawLattice()
+  /// and figures out the shortest path.
+  bool GetBestPath(Lattice *ofst,
+                   bool use_final_probs = true);
+
+  /// Outputs an FST corresponding to the raw, state-level
+  /// tracebacks.  Returns true if result is nonempty.
+  /// If "use_final_probs" is true AND we reached the final-state
+  /// of the graph then it will include those as final-probs, else
+  /// it will treat all final-probs as one.
+  /// The raw lattice will be topologically sorted.
+  /// The function can be called during decoding; it will process non-emitting
+  /// arcs from the "cur_toks_" map to get tokens from both non-emitting and
+  /// emitting arcs for getting the raw lattice, and then recover the state to
+  /// ensure the consistency of ProcessForFrame().
+  ///
+  /// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
+  /// which also supports a pruning beam, in case for some reason
+  /// you want it pruned tighter than the regular lattice beam.
+  /// We could put that here in future if needed.
+  bool GetRawLattice(Lattice *ofst, bool use_final_probs = true);
+
+
+
+  /// [Deprecated, users should now use GetRawLattice and determinize it
+  /// themselves, e.g.
using DeterminizeLatticePhonePrunedWrapper]. + /// Outputs an FST corresponding to the lattice-determinized + /// lattice (one path per word sequence). Returns true if result is nonempty. + /// If "use_final_probs" is true AND we reached the final-state of the graph + /// then it will include those as final-probs, else it will treat all + /// final-probs as one. + bool GetLattice(CompactLattice *ofst, + bool use_final_probs = true); + + /// InitDecoding initializes the decoding, and should only be used if you + /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to + /// call this. You can also call InitDecoding if you have already decoded an + /// utterance and want to start with a new utterance. + void InitDecoding(); + + /// This will decode until there are no more frames ready in the decodable + /// object. You can keep calling it each time more frames become available. + /// If max_num_frames is specified, it specifies the maximum number of frames + /// the function will decode before returning. + void AdvanceDecoding(DecodableInterface *decodable, + int32 max_num_frames = -1); + + /// This function may be optionally called after AdvanceDecoding(), when you + /// do not plan to decode any further. It does an extra pruning step that + /// will help to prune the lattices output by GetLattice and (particularly) + /// GetRawLattice more accurately, particularly toward the end of the + /// utterance. It does this by using the final-probs in pruning (if any + /// final-state survived); it also does a final pruning step that visits all + /// states (the pruning that is done during decoding may fail to prune states + /// that are within kPruningScale = 0.1 outside of the beam). If you call + /// this, you cannot call AdvanceDecoding again (it will fail), and you + /// cannot call GetLattice() and related functions with use_final_probs = + /// false. + /// Used to be called PruneActiveTokensFinal(). + void FinalizeDecoding(); + + /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives + /// more information. It returns the difference between the best (final-cost + /// plus cost) of any token on the final frame, and the best cost of any token + /// on the final frame. If it is infinity it means no final-states were + /// present on the final frame. It will usually be nonnegative. If it not + /// too positive (e.g. < 5 is my first guess, but this is not tested) you can + /// take it as a good indication that we reached the final-state with + /// reasonable likelihood. + BaseFloat FinalRelativeCost() const; + + + // Returns the number of frames decoded so far. The value returned changes + // whenever we call ProcessForFrame(). + inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } + + protected: + // we make things protected instead of private, as code in + // LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the + // internals. + + // Deletes the elements of the singly linked list tok->links. + inline static void DeleteForwardLinks(Token *tok); + + // head of per-frame list of Tokens (list is in topological order), + // and something saying whether we ever pruned it using PruneForwardLinks. + struct TokenList { + Token *toks; + bool must_prune_forward_links; + bool must_prune_tokens; + TokenList(): toks(NULL), must_prune_forward_links(true), + must_prune_tokens(true) { } + }; + + // It is a minimum heap (A[parent(i)] <= A[i]) since the lower cost the better + // in decoding. 
+  // The node indices of the heap are zero-based:
+  // given a parent index i, the left child is 2*i+1 and the right child is
+  // 2*i+2; given a child index i, the parent is ((i+1) / 2) - 1.
+  struct TokenHeap {
+    std::vector<Token*> elements;
+
+    inline void Siftup(size_t child_index, size_t length) {
+      while (true) {
+        if (child_index == 0) break;  // it doesn't have a parent node
+        KALDI_ASSERT(child_index < length);
+        size_t parent_index = (child_index + 1) / 2 - 1;
+        if (elements[child_index]->tot_cost < elements[parent_index]->tot_cost) {
+          // Update the positions of the tokens (they are about to swap places)
+          elements[parent_index]->position_in_heap = child_index;
+          elements[child_index]->position_in_heap = parent_index;
+          // Swap
+          std::swap(elements[parent_index], elements[child_index]);
+          // Update child_index for the next turn
+          child_index = parent_index;
+        } else {
+          break;  // finish
+        }
+      }
+    }
+
+    inline void Siftdown(size_t parent_index, size_t length) {
+      while (true) {
+        if (parent_index >= length / 2) break;  // no children within 'length'
+        // Prepare indexes
+        size_t left_child_index = parent_index * 2 + 1;
+        size_t right_child_index = parent_index * 2 + 2;
+        size_t smallest = parent_index;
+        // Get the index with the smallest cost among parent and children
+        if (left_child_index < length &&
+            elements[left_child_index]->tot_cost < elements[smallest]->tot_cost) {
+          smallest = left_child_index;
+        }
+        if (right_child_index < length &&
+            elements[right_child_index]->tot_cost < elements[smallest]->tot_cost) {
+          smallest = right_child_index;
+        }
+        // Swap
+        if (smallest != parent_index) {
+          // Update the positions of the tokens (they are about to swap places)
+          elements[smallest]->position_in_heap = parent_index;
+          elements[parent_index]->position_in_heap = smallest;
+          // Swap
+          std::swap(elements[parent_index], elements[smallest]);
+          // Update parent_index for the next turn
+          parent_index = smallest;
+        } else {
+          break;  // finish
+        }
+      }
+    }
+
+    inline bool Empty() {
+      return elements.empty();
+    }
+
+    inline Token* Top() {
+      KALDI_ASSERT(!elements.empty());
+      return elements[0];
+    }
+
+    inline void Pop() {
+      KALDI_ASSERT(!elements.empty());
+      // Mark the popped token as no longer being in the heap
+      elements[0]->position_in_heap = static_cast<size_t>(-1);
+
+      // Swap with the last element of the heap
+      std::swap(elements[0], elements[elements.size() - 1]);
+      elements[0]->position_in_heap = 0;
+
+      // Delete it from the heap
+      elements.erase(elements.end() - 1);
+
+      // Restore the heap property from top to bottom
+      Siftdown(0, elements.size());
+    }
+
+    // Push a new token into the heap
+    inline void Push(Token* tok) {
+      KALDI_ASSERT(tok->position_in_heap == static_cast<size_t>(-1));  // not in heap
+      // Push to the end of the heap
+      tok->position_in_heap = elements.size();
+      elements.push_back(tok);
+      // Restore the heap property
+      Siftup(tok->position_in_heap, elements.size());
+    }
+
+    inline void Clear() {
+      while (!Empty())
+        Pop();
+    }
+
+    // Build the heap.  The complexity of this function is O(n) rather than
+    // O(n log n), as the corresponding series converges.
+    inline void BuildTokenHeap(const TokenList &token_list, size_t num) {
+      KALDI_ASSERT(elements.empty());
+      elements.reserve(num * 1.5);
+      // Add elements
+      for (Token* tok = token_list.toks; tok != NULL; tok = tok->next) {
+        tok->position_in_heap = elements.size();
+        elements.push_back(tok);
+      }
+      if (elements.size() < 2) return;  // already a heap
+      // Establish the heap property with Siftdown, going from the last parent
+      // node down to the root.  The "i-- > 0" idiom avoids wrapping the
+      // unsigned index below zero.
+      size_t length = elements.size();
+      for (size_t i = length / 2; i-- > 0; ) {
+        Siftdown(i, length);
+      }
+    }
+
+    TokenHeap() {
+      elements.resize(0);
+    }
+  };
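To sanity-check the zero-based index arithmetic above, a tiny standalone program (hypothetical, not part of the patch):

    #include <cassert>
    #include <cstddef>

    int main() {
      // For a node at zero-based index i: left child = 2*i + 1,
      // right child = 2*i + 2, parent = (i + 1) / 2 - 1.
      for (std::size_t i = 1; i < 1000; ++i) {
        std::size_t parent = (i + 1) / 2 - 1;
        // Every non-root node must be the left or right child of its parent.
        assert(2 * parent + 1 == i || 2 * parent + 2 == i);
      }
      return 0;
    }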
+
+  // FindOrAddToken either locates a token in the hash map "token_map", or if
+  // necessary inserts a new, empty token (i.e. with no forward links) for the
+  // current frame.  [note: it's inserted if necessary into the hash map and
+  // also into the singly linked list of tokens active on this frame (whose
+  // head is at active_toks_[frame]).]  The frame_plus_one argument is the
+  // acoustic frame index plus one, which is used to index into the
+  // active_toks_ array.
+  // Returns the Token pointer.  Sets "changed" (if non-NULL) to true if the
+  // token was newly created or the cost changed.
+  // If Token == StdToken, the 'backpointer' argument has no purpose (and will
+  // hopefully be optimized out).
+  inline Token *FindOrAddToken(StateId state, int32 token_list_index,
+                               BaseFloat tot_cost, Token *backpointer,
+                               StateIdToTokenMap *token_map,
+                               bool *changed);
+
+  // Prunes outgoing links for all tokens in active_toks_[frame].
+  // It's called by PruneActiveTokens.
+  // All links that have link_extra_cost > lattice_beam are pruned.
+  // delta is the amount by which the extra_costs must change
+  // before we set *extra_costs_changed = true.
+  // If delta is larger, we'll tend to go back less far
+  //    toward the beginning of the file.
+  // extra_costs_changed is set to true if extra_cost was changed for any token.
+  // links_pruned is set to true if any link in any token was pruned.
+  void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
+                         bool *links_pruned,
+                         BaseFloat delta);
+
+  // This function computes the final-costs for tokens active on the final
+  // frame.  It outputs to final-costs, if non-NULL, a map from the Token*
+  // pointer to the final-prob of the corresponding state, for all Tokens
+  // that correspond to states that have final-probs.  This map will be
+  // empty if there were no final-probs.  It outputs to
+  // final_relative_cost, if non-NULL, the difference between the best
+  // forward-cost including the final-prob cost, and the best forward-cost
+  // without including the final-prob cost (this will usually be positive), or
+  // infinity if there were no final-probs.  [c.f. FinalRelativeCost(), which
+  // outputs this quantity].  It outputs to final_best_cost, if
+  // non-NULL, the lowest, for any token t active on the final frame, of
+  // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in
+  // the graph of the state corresponding to token t, or the best of
+  // forward-cost[t] if there were no final-probs active on the final frame.
+  // You cannot call this after FinalizeDecoding() has been called; in that
+  // case you should get the answer from class-member variables.
+  void ComputeFinalCosts(unordered_map<Token*, BaseFloat> *final_costs,
+                         BaseFloat *final_relative_cost,
+                         BaseFloat *final_best_cost) const;
+
+  // PruneForwardLinksFinal is a version of PruneForwardLinks that we call
+  // on the final frame.  If there are final tokens active, it uses
+  // the final-probs for pruning, otherwise it treats all tokens as final.
+  void PruneForwardLinksFinal();
+
+  // Prune away any tokens on this frame that have no forward links.
+  // [we don't do this in PruneForwardLinks because it would give us
+  // a problem with dangling pointers].
+  // It's called by PruneActiveTokens if any forward links have been pruned.
+  void PruneTokensForFrame(int32 frame_plus_one);
+
+  // Go backwards through still-alive tokens, pruning them if the
+  // forward+backward cost is more than lat_beam away from the best path.  It's
+  // possible to prove that this is "correct" in the sense that we won't lose
+  // anything outside of lat_beam, regardless of what happens in the future.
+  // delta controls when it considers a cost to have changed enough to continue
+  // going backward and propagating the change.  Larger delta -> will recurse
+  // less far.
+  void PruneActiveTokens(BaseFloat delta);
+
+  /// Processes non-emitting (epsilon) arcs and emitting arcs for one frame
+  /// together.  It takes the emitting tokens in "prev_toks_" from the last
+  /// frame, generates non-emitting tokens for the previous frame and emitting
+  /// tokens for the next frame.
+  /// Notice: "the emitting tokens for the current frame" means the tokens
+  /// that take the acoustic scores of the current frame (i.e. the
+  /// destinations of emitting arcs).
+  void ProcessForFrame(DecodableInterface *decodable);
+
+  /// Processes nonemitting (epsilon) arcs for one frame.
+  /// This function is called once after all frames have been processed, or
+  /// from GetRawLattice() to generate the complete token list for the last
+  /// frame.  [It deals with the tokens in the map "cur_toks_", which would
+  /// only contain emitting tokens from the previous frame.]
+  /// If the map "token_orig_cost" isn't NULL, we build the map which will
+  /// be used to recover the "active_toks_[last_frame]" token list for the
+  /// last frame.
+  void ProcessNonemitting(std::unordered_map<Token*, BaseFloat> *token_orig_cost);
+
+  /// When GetRawLattice() is called during decoding, the
+  /// active_toks_[last_frame] list is changed.  To keep the behavior of
+  /// ProcessForFrame() consistent, we recover it.
+  /// Notice: as new tokens are added to the head of the TokenList, tok->next
+  /// will not be affected.
+  /// "token_orig_cost" is a mapping from token pointer to the tot_cost of the
+  /// token before propagating non-emitting arcs.  It is used to undo the
+  /// changes to the original tokens on the last frame and to remove the new
+  /// tokens that came from propagating non-emitting arcs, so that we can
+  /// guarantee the consistency of ProcessForFrame().
+  void RecoverLastTokenList(
+      const std::unordered_map<Token*, BaseFloat> &token_orig_cost);
+
+
+  /// The maps "prev_toks_" and "cur_toks_" allow us to maintain the tokens of
+  /// the current and the next frame; they are indexed by StateId.  (The
+  /// active_toks_ list, by contrast, is indexed by frame-index plus one, where
+  /// the frame-index is zero-based as used in the decodable object.  That is,
+  /// the emitting probs of frame t are accounted for in tokens at toks_[t+1];
+  /// the zeroth entry is for the nonemitting transitions at the start of the
+  /// graph.)
+  StateIdToTokenMap prev_toks_;
+  StateIdToTokenMap cur_toks_;
+
+  /// Gets the weight cutoff.
+  /// Notice: in the traditional version, the histogram pruning method is
+  /// applied to a complete token list for one frame.  But in this version it
+  /// is applied to a token list which only contains the emitting part, so the
+  /// max_active and min_active values might be narrowed.
+  BaseFloat GetCutoff(const TokenList &token_list,
+                      BaseFloat *adaptive_beam,
+                      StateId *best_state_id, Token **best_token);
+
+  std::vector<TokenList> active_toks_;  // Lists of tokens, indexed by
+  // frame (members of TokenList are toks, must_prune_forward_links,
+  // must_prune_tokens).
+  TokenHeap cur_heap_;  // temp variable used in ProcessForFrame
+                        // and ProcessNonemitting
+  std::vector<BaseFloat> tmp_array_;  // used in GetCutoff.
+
+  // fst_ is a pointer to the FST we are decoding from.
+  const FST *fst_;
+  // delete_fst_ is true if the pointer fst_ needs to be deleted when this
+  // object is destroyed.
+ bool delete_fst_; + + std::vector cost_offsets_; // This contains, for each + // frame, an offset that was added to the acoustic log-likelihoods on that + // frame in order to keep everything in a nice dynamic range i.e. close to + // zero, to reduce roundoff errors. + // Notice: It will only be added to emitting arcs (i.e. cost_offsets_[t] is + // added to arcs from "frame t" to "frame t+1"). + LatticeFasterDecoderCombineConfig config_; + int32 num_toks_; // current total #toks allocated... + bool warned_; + + /// decoding_finalized_ is true if someone called FinalizeDecoding(). [note, + /// calling this is optional]. If true, it's forbidden to decode more. Also, + /// if this is set, then the output of ComputeFinalCosts() is in the next + /// three variables. The reason we need to do this is that after + /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some + /// of the tokens on the last frame are freed, so we free the list from toks_ + /// to avoid having dangling pointers hanging around. + bool decoding_finalized_; + /// For the meaning of the next 3 variables, see the comment for + /// decoding_finalized_ above., and ComputeFinalCosts(). + unordered_map final_costs_; + BaseFloat final_relative_cost_; + BaseFloat final_best_cost_; + + // This function takes a singly linked list of tokens for a single frame, and + // outputs a list of them in topological order (it will crash if no such order + // can be found, which will typically be due to decoding graphs with epsilon + // cycles, which are not allowed). Note: the output list may contain NULLs, + // which the caller should pass over; it just happens to be more efficient for + // the algorithm to output a list that contains NULLs. + static void TopSortTokens(Token *tok_list, + std::vector *topsorted_list); + + void ClearActiveTokens(); + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderCombineTpl); +}; + +typedef LatticeFasterDecoderCombineTpl > LatticeFasterDecoderCombine; + + + +} // end namespace kaldi. + +#endif From a4a2ddcb7b124f6a8ba0cedd77c26607df34325b Mon Sep 17 00:00:00 2001 From: Hang Lyu Date: Mon, 18 Mar 2019 17:02:18 -0400 Subject: [PATCH 15/29] bucketqueue --- ...tice-faster-decoder-combine-bucketqueue.cc | 1153 +++++++++++++++++ ...ttice-faster-decoder-combine-bucketqueue.h | 638 +++++++++ src/decoder/lattice-faster-decoder-combine.cc | 172 ++- src/decoder/lattice-faster-decoder-combine.h | 76 +- 4 files changed, 1968 insertions(+), 71 deletions(-) create mode 100644 src/decoder/lattice-faster-decoder-combine-bucketqueue.cc create mode 100644 src/decoder/lattice-faster-decoder-combine-bucketqueue.h diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc new file mode 100644 index 00000000000..11c5d01df0a --- /dev/null +++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc @@ -0,0 +1,1153 @@ +// decoder/lattice-faster-decoder-combine.cc + +// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann +// 2013-2019 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen +// 2019 Hang Lyu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include "decoder/lattice-faster-decoder-combine.h"
+#include "lat/lattice-functions.h"
+
+namespace kaldi {
+
+template <typename Token>
+BucketQueue<Token>::BucketQueue(BaseFloat best_cost_estimate,
+                                BaseFloat cost_scale) :
+    cost_scale_(cost_scale) {
+  // NOTE: we reserve plenty of elements to avoid expensive reallocations
+  // later on.  Normally, the size is a little bigger than (adaptive_beam +
+  // 5) * cost_scale.
+  int32 bucket_size = 100;
+  buckets_.resize(bucket_size);
+  bucket_storage_begin_ = std::floor((best_cost_estimate - 15) * cost_scale);
+  first_occupied_bucket_index_ = bucket_storage_begin_ + bucket_size;
+}
+
+template <typename Token>
+void BucketQueue<Token>::Push(Token *tok) {
+  int32 bucket_index = std::floor(tok->tot_cost * cost_scale_);
+  int32 vec_index = bucket_index - bucket_storage_begin_;
+
+  if (vec_index < 0) {
+    KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve"
+               << " more elements in constructor. Push front.";
+    int32 increase_size = -vec_index;
+    std::vector<std::vector<Token*> > tmp(buckets_);
+    buckets_.resize(tmp.size() + increase_size);
+    std::copy(tmp.begin(), tmp.end(), buckets_.begin() + increase_size);
+    // Update start point
+    bucket_storage_begin_ = bucket_index;
+    vec_index = 0;
+  } else if (vec_index >= static_cast<int32>(buckets_.size())) {
+    KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve"
+               << " more elements in constructor. Push back.";
+    buckets_.resize(vec_index + 1);
+  }
+
+  tok->in_queue = true;
+  buckets_[vec_index].push_back(tok);
+  if (vec_index < (first_occupied_bucket_index_ - bucket_storage_begin_))
+    first_occupied_bucket_index_ = vec_index + bucket_storage_begin_;
+}
+
+template <typename Token>
+Token* BucketQueue<Token>::Pop() {
+  int32 vec_index = first_occupied_bucket_index_ - bucket_storage_begin_;
+  Token* best_tok = NULL;
+  while (vec_index < static_cast<int32>(buckets_.size())) {
+    // Remove the best token
+    best_tok = buckets_[vec_index].back();
+    buckets_[vec_index].pop_back();
+
+    if (buckets_[vec_index].empty()) {  // This bucket is empty.  Update
+                                        // first_occupied_bucket_index_
+      int32 next_vec_index = vec_index + 1;
+      for (; next_vec_index < static_cast<int32>(buckets_.size());
+           next_vec_index++) {
+        if (!buckets_[next_vec_index].empty()) break;
+      }
+      first_occupied_bucket_index_ = bucket_storage_begin_ + next_vec_index;
+      vec_index = next_vec_index;
+    }
+
+    if (best_tok->in_queue) {  // This is an effective token
+      best_tok->in_queue = false;
+      break;
+    } else {
+      best_tok = NULL;
+    }
+  }
+  return best_tok;
+}
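A toy illustration (not part of the patch) of the cost-to-bucket mapping used by BucketQueue above: costs are scaled and floored, so nearby costs share a bucket, Pop() drains the lowest-cost non-empty bucket first, and negative indices are what trigger the "Push front" reallocation path.

    #include <cmath>
    #include <cstdio>

    int main() {
      const float cost_scale = 1.0f;
      const float costs[] = {3.7f, 3.2f, -0.5f, 7.9f};
      for (int i = 0; i < 4; i++) {
        // Same formula as BucketQueue::Push(): floor(tot_cost * cost_scale).
        int bucket_index = static_cast<int>(std::floor(costs[i] * cost_scale));
        std::printf("tot_cost %.1f -> bucket %d\n", costs[i], bucket_index);
      }
      return 0;  // 3.7 and 3.2 land in bucket 3; -0.5 lands in bucket -1
    }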
+// Instantiate this class once for each thing you have to decode.
+template <typename FST, typename Token>
+LatticeFasterDecoderCombineTpl<FST, Token>::LatticeFasterDecoderCombineTpl(
+    const FST &fst,
+    const LatticeFasterDecoderCombineConfig &config):
+    fst_(&fst), delete_fst_(false), config_(config), num_toks_(0) {
+  config.Check();
+}
+
+
+template <typename FST, typename Token>
+LatticeFasterDecoderCombineTpl<FST, Token>::LatticeFasterDecoderCombineTpl(
+    const LatticeFasterDecoderCombineConfig &config, FST *fst):
+    fst_(fst), delete_fst_(true), config_(config), num_toks_(0) {
+  config.Check();
+  prev_toks_.reserve(1000);
+  cur_toks_.reserve(1000);
+}
+
+
+template <typename FST, typename Token>
+LatticeFasterDecoderCombineTpl<FST, Token>::~LatticeFasterDecoderCombineTpl() {
+  ClearActiveTokens();
+  if (delete_fst_) delete fst_;
+  //prev_toks_.clear();
+  //cur_toks_.clear();
+}
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::InitDecoding() {
+  // clean up from last time:
+  prev_toks_.clear();
+  cur_toks_.clear();
+  cost_offsets_.clear();
+  ClearActiveTokens();
+
+  warned_ = false;
+  num_toks_ = 0;
+  decoding_finalized_ = false;
+  final_costs_.clear();
+  StateId start_state = fst_->Start();
+  KALDI_ASSERT(start_state != fst::kNoStateId);
+  active_toks_.resize(1);
+  Token *start_tok = new Token(0.0, 0.0, start_state, NULL, NULL, NULL);
+  active_toks_[0].toks = start_tok;
+  cur_toks_[start_state] = start_tok;  // initialize current tokens map
+  num_toks_++;
+  best_token_in_next_frame_ = start_tok;
+}
+
+// Returns true if any kind of traceback is available (not necessarily from
+// a final state).  It should only very rarely return false; this indicates
+// an unusual search error.
+template <typename FST, typename Token>
+bool LatticeFasterDecoderCombineTpl<FST, Token>::Decode(DecodableInterface *decodable) {
+  InitDecoding();
+
+  // We use 1-based indexing for frames in this decoder (if you view it in
+  // terms of features), but note that the decodable object uses zero-based
+  // numbering, which we have to correct for when we call it.
+
+  while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) {
+    if (NumFramesDecoded() % config_.prune_interval == 0)
+      PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
+    ProcessForFrame(decodable);
+  }
+  // A complete token list for the last frame will be generated in
+  // FinalizeDecoding().
+  FinalizeDecoding();
+
+  // Returns true if we have any kind of traceback available (not necessarily
+  // to the end state; query ReachedFinal() for that).
+  return !active_toks_.empty() && active_toks_.back().toks != NULL;
+}
+
+
+// Outputs an FST corresponding to the single best path through the lattice.
+template <typename FST, typename Token>
+bool LatticeFasterDecoderCombineTpl<FST, Token>::GetBestPath(
+    Lattice *olat,
+    bool use_final_probs) {
+  Lattice raw_lat;
+  GetRawLattice(&raw_lat, use_final_probs);
+  ShortestPath(raw_lat, olat);
+  return (olat->NumStates() != 0);
+}
+
+
+// Outputs an FST corresponding to the raw, state-level lattice
+template <typename FST, typename Token>
+bool LatticeFasterDecoderCombineTpl<FST, Token>::GetRawLattice(
+    Lattice *ofst,
+    bool use_final_probs) {
+  typedef LatticeArc Arc;
+  typedef Arc::StateId StateId;
+  typedef Arc::Weight Weight;
+  typedef Arc::Label Label;
+  // Note: you can't use the old interface (Decode()) if you want to
+  // get the lattice with use_final_probs = false.  You'd have to do
+  // InitDecoding() and then AdvanceDecoding().
+  if (decoding_finalized_ && !use_final_probs)
+    KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
+              << "GetRawLattice() with use_final_probs == false";
+
+  std::unordered_map<Token*, BaseFloat> token_orig_cost;
+  if (!decoding_finalized_) {
+    // Process the non-emitting arcs for the unfinished last frame.
+    ProcessNonemitting(&token_orig_cost);
+  }
+
+
+  unordered_map<Token*, BaseFloat> final_costs_local;
+
+  const unordered_map<Token*, BaseFloat> &final_costs =
+      (decoding_finalized_ ? final_costs_ : final_costs_local);
+  if (!decoding_finalized_ && use_final_probs)
+    ComputeFinalCosts(&final_costs_local, NULL, NULL);
+
+  ofst->DeleteStates();
+  // num-frames plus one (since frames are one-based, and we have
+  // an extra frame for the start-state).
+  int32 num_frames = active_toks_.size() - 1;
+  KALDI_ASSERT(num_frames > 0);
+  const int32 bucket_count = num_toks_/2 + 3;
+  unordered_map<Token*, StateId> tok_map(bucket_count);
+  // First create all states.
+  std::vector<Token*> token_list;
+  for (int32 f = 0; f <= num_frames; f++) {
+    if (active_toks_[f].toks == NULL) {
+      KALDI_WARN << "GetRawLattice: no tokens active on frame " << f
+                 << ": not producing lattice.\n";
+      return false;
+    }
+    TopSortTokens(active_toks_[f].toks, &token_list);
+    for (size_t i = 0; i < token_list.size(); i++)
+      if (token_list[i] != NULL)
+        tok_map[token_list[i]] = ofst->AddState();
+  }
+  // The next statement sets the start state of the output FST.  Because we
+  // topologically sorted the tokens, state zero must be the start-state.
+  ofst->SetStart(0);
+
+  KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:"
+                << tok_map.bucket_count() << " load:" << tok_map.load_factor()
+                << " max:" << tok_map.max_load_factor();
+  // Now create all arcs.
+  for (int32 f = 0; f <= num_frames; f++) {
+    for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) {
+      StateId cur_state = tok_map[tok];
+      for (ForwardLinkT *l = tok->links;
+           l != NULL;
+           l = l->next) {
+        typename unordered_map<Token*, StateId>::const_iterator
+            iter = tok_map.find(l->next_tok);
+        KALDI_ASSERT(iter != tok_map.end());  // check before dereferencing
+        StateId nextstate = iter->second;
+        BaseFloat cost_offset = 0.0;
+        if (l->ilabel != 0) {  // emitting..
+          KALDI_ASSERT(f >= 0 && f < static_cast<int32>(cost_offsets_.size()));
+          cost_offset = cost_offsets_[f];
+        }
+        Arc arc(l->ilabel, l->olabel,
+                Weight(l->graph_cost, l->acoustic_cost - cost_offset),
+                nextstate);
+        ofst->AddArc(cur_state, arc);
+      }
+      if (f == num_frames) {
+        if (use_final_probs && !final_costs.empty()) {
+          typename unordered_map<Token*, BaseFloat>::const_iterator
+              iter = final_costs.find(tok);
+          if (iter != final_costs.end())
+            ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
+        } else {
+          ofst->SetFinal(cur_state, LatticeWeight::One());
+        }
+      }
+    }
+  }
+
+  if (!decoding_finalized_) {  // recover last token list
+    RecoverLastTokenList(token_orig_cost);
+  }
+  return (ofst->NumStates() > 0);
+}
+
+
+// When GetRawLattice() is called during decoding, the
+// active_toks_[last_frame] list is changed.  To keep the behavior of
+// ProcessForFrame() consistent, recover it.
+// Notice: as new tokens are added to the head of the TokenList, tok->next
+// will not be affected.
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::RecoverLastTokenList(
+    const std::unordered_map<Token*, BaseFloat> &token_orig_cost) {
+  if (!token_orig_cost.empty()) {
+    for (Token* tok = active_toks_[active_toks_.size() - 1].toks;
+         tok != NULL;) {
+      if (token_orig_cost.find(tok) != token_orig_cost.end()) {
+        DeleteForwardLinks(tok);
+        tok->tot_cost = token_orig_cost.find(tok)->second;
+        tok->in_queue = false;
+        tok = tok->next;
+      } else {
+        DeleteForwardLinks(tok);
+        Token *next_tok = tok->next;
+        delete tok;
+        num_toks_--;
+        tok = next_tok;
+      }
+    }
+  }
+}
+
+// This function is now deprecated, since now we do determinization from
+// outside the LatticeFasterDecoder class.  Outputs an FST corresponding to
+// the lattice-determinized lattice (one path per word sequence).
+template +bool LatticeFasterDecoderCombineTpl::GetLattice( + CompactLattice *ofst, + bool use_final_probs) { + Lattice raw_fst; + GetRawLattice(&raw_fst, use_final_probs); + Invert(&raw_fst); // make it so word labels are on the input. + // (in phase where we get backward-costs). + fst::ILabelCompare ilabel_comp; + ArcSort(&raw_fst, ilabel_comp); // sort on ilabel; makes + // lattice-determinization more efficient. + + fst::DeterminizeLatticePrunedOptions lat_opts; + lat_opts.max_mem = config_.det_opts.max_mem; + + DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts); + raw_fst.DeleteStates(); // Free memory-- raw_fst no longer needed. + Connect(ofst); // Remove unreachable states... there might be + // a small number of these, in some cases. + // Note: if something went wrong and the raw lattice was empty, + // we should still get to this point in the code without warnings or failures. + return (ofst->NumStates() != 0); +} + +/* + A note on the definition of extra_cost. + + extra_cost is used in pruning tokens, to save memory. + + Define the 'forward cost' of a token as zero for any token on the frame + we're currently decoding; and for other frames, as the shortest-path cost + between that token and a token on the frame we're currently decoding. + (by "currently decoding" I mean the most recently processed frame). + + Then define the extra_cost of a token (always >= 0) as the forward-cost of + the token minus the smallest forward-cost of any token on the same frame. + + We can use the extra_cost to accurately prune away tokens that we know will + never appear in the lattice. If the extra_cost is greater than the desired + lattice beam, the token would provably never appear in the lattice, so we can + prune away the token. + + The advantage of storing the extra_cost rather than the forward-cost, is that + it is less costly to keep the extra_cost up-to-date when we process new frames. + When we process a new frame, *all* the previous frames' forward-costs would change; + but in general the extra_cost will change only for a finite number of frames. + (Actually we don't update all the extra_costs every time we update a frame; we + only do it every 'config_.prune_interval' frames). + */ + +// FindOrAddToken either locates a token in hash map "token_map" +// or if necessary inserts a new, empty token (i.e. with no forward links) +// for the current frame. [note: it's inserted if necessary into hash toks_ +// and also into the singly linked list of tokens active on this frame +// (whose head is at active_toks_[frame]). +template +inline Token* LatticeFasterDecoderCombineTpl::FindOrAddToken( + StateId state, int32 token_list_index, BaseFloat tot_cost, + Token *backpointer, StateIdToTokenMap *token_map, bool *changed) { + // Returns the Token pointer. Sets "changed" (if non-NULL) to true + // if the token was newly created or the cost changed. + KALDI_ASSERT(token_list_index < active_toks_.size()); + Token *&toks = active_toks_[token_list_index].toks; + typename StateIdToTokenMap::iterator e_found = token_map->find(state); + if (e_found == token_map->end()) { // no such token presently. + const BaseFloat extra_cost = 0.0; + // tokens on the currently final frame have zero extra_cost + // as any of them could end up + // on the winning path. 
+ Token *new_tok = new Token (tot_cost, extra_cost, state, + NULL, toks, backpointer); + // NULL: no forward links yet + toks = new_tok; + num_toks_++; + // insert into the map + (*token_map)[state] = new_tok; + if (changed) *changed = true; + return new_tok; + } else { + Token *tok = e_found->second; // There is an existing Token for this state. + if (tok->tot_cost > tot_cost) { // replace old token + tok->tot_cost = tot_cost; + // SetBackpointer() just does tok->backpointer = backpointer in + // the case where Token == BackpointerToken, else nothing. + tok->SetBackpointer(backpointer); + // we don't allocate a new token, the old stays linked in active_toks_ + // we only replace the tot_cost + // in the current frame, there are no forward links (and no extra_cost) + // only in ProcessNonemitting we have to delete forward links + // in case we visit a state for the second time + // those forward links, that lead to this replaced token before: + // they remain and will hopefully be pruned later (PruneForwardLinks...) + if (changed) *changed = true; + } else { + if (changed) *changed = false; + } + return tok; + } +} + +// prunes outgoing links for all tokens in active_toks_[frame] +// it's called by PruneActiveTokens +// all links, that have link_extra_cost > lattice_beam are pruned +template +void LatticeFasterDecoderCombineTpl::PruneForwardLinks( + int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, BaseFloat delta) { + // delta is the amount by which the extra_costs must change + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + + *extra_costs_changed = false; + *links_pruned = false; + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen. + if (!warned_) { + KALDI_WARN << "No tokens alive [doing pruning].. warning first " + "time only for each utterance\n"; + warned_ = true; + } + } + + // We have to iterate until there is no more change, because the links + // are not guaranteed to be in topological order. + bool changed = true; // difference new minus old extra cost >= delta ? + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost for tok. + BaseFloat tok_extra_cost = std::numeric_limits::infinity(); + // tok_extra_cost is the best (min) of link_extra_cost of outgoing links + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); // difference in brackets is >= 0 + // link_exta_cost is the difference in score between the best paths + // through link source state and through link destination state + KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + *links_pruned = true; + } else { // keep the link and update the tok_extra_cost if needed. 
+ if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; // move to next link + link = link->next; + } + } // for all outgoing links + if (fabs(tok_extra_cost - tok->extra_cost) > delta) + changed = true; // difference new minus old is bigger than delta + tok->extra_cost = tok_extra_cost; + // will be +infinity or <= lattice_beam_. + // infinity indicates, that no forward link survived pruning + } // for all Token on active_toks_[frame] + if (changed) *extra_costs_changed = true; + + // Note: it's theoretically possible that aggressive compiler + // optimizations could cause an infinite loop here for small delta and + // high-dynamic-range scores. + } // while changed +} + +// PruneForwardLinksFinal is a version of PruneForwardLinks that we call +// on the final frame. If there are final tokens active, it uses +// the final-probs for pruning, otherwise it treats all tokens as final. +template +void LatticeFasterDecoderCombineTpl::PruneForwardLinksFinal() { + KALDI_ASSERT(!active_toks_.empty()); + int32 frame_plus_one = active_toks_.size() - 1; + + if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen. + KALDI_WARN << "No tokens alive at end of file"; + + typedef typename unordered_map::const_iterator IterType; + ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_); + decoding_finalized_ = true; + + // Now go through tokens on this frame, pruning forward links... may have to + // iterate a few times until there is no more change, because the list is not + // in topological order. This is a modified version of the code in + // PruneForwardLinks, but here we also take account of the final-probs. + bool changed = true; + BaseFloat delta = 1.0e-05; + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost. It has a term in it that corresponds + // to the "final-prob", so instead of initializing tok_extra_cost to infinity + // below we set it to the difference between the (score+final_prob) of this token, + // and the best such (score+final_prob). + BaseFloat final_cost; + if (final_costs_.empty()) { + final_cost = 0.0; + } else { + IterType iter = final_costs_.find(tok); + if (iter != final_costs_.end()) + final_cost = iter->second; + else + final_cost = std::numeric_limits::infinity(); + } + BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_; + // tok_extra_cost will be a "min" over either directly being final, or + // being indirectly final through other links, and the loop below may + // decrease its value: + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. 
+ if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; + link = link->next; + } + } + // prune away tokens worse than lattice_beam above best path. This step + // was not necessary in the non-final case because then, this case + // showed up as having no forward links. Here, the tok_extra_cost has + // an extra component relating to the final-prob. + if (tok_extra_cost > config_.lattice_beam) + tok_extra_cost = std::numeric_limits::infinity(); + // to be pruned in PruneTokensForFrame + + if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) + changed = true; + tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. + } + } // while changed +} + +template +BaseFloat LatticeFasterDecoderCombineTpl::FinalRelativeCost() const { + if (!decoding_finalized_) { + BaseFloat relative_cost; + ComputeFinalCosts(NULL, &relative_cost, NULL); + return relative_cost; + } else { + // we're not allowed to call that function if FinalizeDecoding() has + // been called; return a cached value. + return final_relative_cost_; + } +} + + +// Prune away any tokens on this frame that have no forward links. +// [we don't do this in PruneForwardLinks because it would give us +// a problem with dangling pointers]. +// It's called by PruneActiveTokens if any forward links have been pruned +template +void LatticeFasterDecoderCombineTpl::PruneTokensForFrame( + int32 frame_plus_one) { + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + if (toks == NULL) + KALDI_WARN << "No tokens alive [doing pruning]"; + Token *tok, *next_tok, *prev_tok = NULL; + for (tok = toks; tok != NULL; tok = next_tok) { + next_tok = tok->next; + if (tok->extra_cost == std::numeric_limits::infinity()) { + // token is unreachable from end of graph; (no forward links survived) + // excise tok from list and delete tok. + if (prev_tok != NULL) prev_tok->next = tok->next; + else toks = tok->next; + delete tok; + num_toks_--; + } else { // fetch next Token + prev_tok = tok; + } + } +} + +// Go backwards through still-alive tokens, pruning them, starting not from +// the current frame (where we want to keep all tokens) but from the frame before +// that. We go backwards through the frames and stop when we reach a point +// where the delta-costs are not changing (and the delta controls when we consider +// a cost to have "not changed"). +template +void LatticeFasterDecoderCombineTpl::PruneActiveTokens( + BaseFloat delta) { + int32 cur_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // The index "f" below represents a "frame plus one", i.e. you'd have to subtract + // one to get the corresponding index for the decodable object. + for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) { + // Reason why we need to prune forward links in this situation: + // (1) we have never pruned them (new TokenList) + // (2) we have not yet pruned the forward links to the next f, + // after any of those tokens have changed their extra_cost. 
+ if (active_toks_[f].must_prune_forward_links) { + bool extra_costs_changed = false, links_pruned = false; + PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta); + if (extra_costs_changed && f > 0) // any token has changed extra_cost + active_toks_[f-1].must_prune_forward_links = true; + if (links_pruned) // any link was pruned + active_toks_[f].must_prune_tokens = true; + active_toks_[f].must_prune_forward_links = false; // job done + } + if (f+1 < cur_frame_plus_one && // except for last f (no forward links) + active_toks_[f+1].must_prune_tokens) { + PruneTokensForFrame(f+1); + active_toks_[f+1].must_prune_tokens = false; + } + } + KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +template +void LatticeFasterDecoderCombineTpl::ComputeFinalCosts( + unordered_map *final_costs, + BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const { + KALDI_ASSERT(!decoding_finalized_); + if (final_costs != NULL) + final_costs->clear(); + BaseFloat infinity = std::numeric_limits::infinity(); + BaseFloat best_cost = infinity, + best_cost_with_final = infinity; + + // The final tokens are recorded in active_toks_[last_frame] + for (Token *tok = active_toks_[active_toks_.size() - 1].toks; tok != NULL; + tok = tok->next) { + StateId state = tok->state_id; + BaseFloat final_cost = fst_->Final(state).Value(); + BaseFloat cost = tok->tot_cost, + cost_with_final = cost + final_cost; + best_cost = std::min(cost, best_cost); + best_cost_with_final = std::min(cost_with_final, best_cost_with_final); + if (final_costs != NULL && final_cost != infinity) + (*final_costs)[tok] = final_cost; + } + if (final_relative_cost != NULL) { + if (best_cost == infinity && best_cost_with_final == infinity) { + // Likely this will only happen if there are no tokens surviving. + // This seems the least bad way to handle it. + *final_relative_cost = infinity; + } else { + *final_relative_cost = best_cost_with_final - best_cost; + } + } + if (final_best_cost != NULL) { + if (best_cost_with_final != infinity) { // final-state exists. + *final_best_cost = best_cost_with_final; + } else { // no final-state exists. + *final_best_cost = best_cost; + } + } +} + +template +void LatticeFasterDecoderCombineTpl::AdvanceDecoding( + DecodableInterface *decodable, + int32 max_num_frames) { + if (std::is_same >::value) { + // if the type 'FST' is the FST base-class, then see if the FST type of fst_ + // is actually VectorFst or ConstFst. If so, call the AdvanceDecoding() + // function after casting *this to the more specific type. + if (fst_->Type() == "const") { + LatticeFasterDecoderCombineTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } else if (fst_->Type() == "vector") { + LatticeFasterDecoderCombineTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } + } + + + KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ && + "You must call InitDecoding() before AdvanceDecoding"); + int32 num_frames_ready = decodable->NumFramesReady(); + // num_frames_ready must be >= num_frames_decoded, or else + // the number of frames ready must have decreased (which doesn't + // make sense) or the decodable object changed between calls + // (which isn't allowed). 
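+ // For example (illustrative): if 100 frames of features have been made
+ // ready in 'decodable' and 60 frames have been decoded so far, this call
+ // decodes frames 60..99 (or fewer, if max_num_frames limits it).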
+ KALDI_ASSERT(num_frames_ready >= NumFramesDecoded()); + int32 target_frames_decoded = num_frames_ready; + if (max_num_frames >= 0) + target_frames_decoded = std::min(target_frames_decoded, + NumFramesDecoded() + max_num_frames); + while (NumFramesDecoded() < target_frames_decoded) { + if (NumFramesDecoded() % config_.prune_interval == 0) { + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + } + ProcessForFrame(decodable); + } +} + +// FinalizeDecoding() is a version of PruneActiveTokens that we call +// (optionally) on the final frame. Takes into account the final-prob of +// tokens. This function used to be called PruneActiveTokensFinal(). +template +void LatticeFasterDecoderCombineTpl::FinalizeDecoding() { + ProcessNonemitting(NULL); + int32 final_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // PruneForwardLinksFinal() prunes final frame (with final-probs), and + // sets decoding_finalized_. + PruneForwardLinksFinal(); + for (int32 f = final_frame_plus_one - 1; f >= 0; f--) { + bool b1, b2; // values not used. + BaseFloat dontcare = 0.0; // delta of zero means we must always update + PruneForwardLinks(f, &b1, &b2, dontcare); + PruneTokensForFrame(f + 1); + } + PruneTokensForFrame(0); + KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +/// Gets the weight cutoff. +template +BaseFloat LatticeFasterDecoderCombineTpl::GetCutoff( + const TokenList &token_list, const Token* best_token, + BaseFloat *adaptive_beam, BucketQueue *queue) { + BaseFloat best_weight = best_token->tot_cost; + // positive == high cost == bad. + // best_weight is the minimum value. + if (config_.max_active == std::numeric_limits::max() && + config_.min_active == 0) { + for (Token* tok = token_list.toks; tok != NULL; tok = tok->next) { + queue->Push(tok); + } + if (adaptive_beam != NULL) *adaptive_beam = config_.beam; + return best_weight + config_.beam; + } else { + tmp_array_.clear(); + for (Token* tok = token_list.toks; tok != NULL; tok = tok->next) { + BaseFloat w = static_cast(tok->tot_cost); + tmp_array_.push_back(w); + queue->Push(tok); + } + + BaseFloat beam_cutoff = best_weight + config_.beam, + min_active_cutoff = std::numeric_limits::infinity(), + max_active_cutoff = std::numeric_limits::infinity(); + + KALDI_VLOG(6) << "Number of emitting tokens on frame " + << NumFramesDecoded() - 1 << " is " << tmp_array_.size(); + + if (tmp_array_.size() > static_cast(config_.max_active)) { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.max_active, + tmp_array_.end()); + max_active_cutoff = tmp_array_[config_.max_active]; + } + if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam. + if (adaptive_beam) + *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta; + return max_active_cutoff; + } + if (tmp_array_.size() > static_cast(config_.min_active)) { + if (config_.min_active == 0) min_active_cutoff = best_weight; + else { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.min_active, + tmp_array_.size() > static_cast(config_.max_active) ? + tmp_array_.begin() + config_.max_active : tmp_array_.end()); + min_active_cutoff = tmp_array_[config_.min_active]; + } + } + if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam. 
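+ // Toy numbers (illustration only): with best_weight = 10.0 and
+ // beam = 16.0, beam_cutoff = 26.0. If the min_active-th best cost is
+ // 30.0 (> 26.0) we take this branch and widen the beam so that at least
+ // min_active tokens survive; if instead the max_active-th best cost had
+ // been 20.0 (< 26.0), the earlier branch would have tightened the beam,
+ // giving adaptive_beam = 20.0 - 10.0 + beam_delta.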
+ if (adaptive_beam) + *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta; + return min_active_cutoff; + } else { + *adaptive_beam = config_.beam; + return beam_cutoff; + } + } +} + +template +void LatticeFasterDecoderCombineTpl::ProcessForFrame( + DecodableInterface *decodable) { + KALDI_ASSERT(active_toks_.size() > 0); + int32 frame = active_toks_.size() - 1; // frame is the frame-index + // (zero-based) used to get likelihoods + // from the decodable object. + active_toks_.resize(active_toks_.size() + 1); + + prev_toks_.swap(cur_toks_); + cur_toks_.clear(); + if (prev_toks_.empty()) { + if (!warned_) { + KALDI_WARN << "Error, no surviving tokens on frame " << frame; + warned_ = true; + } + } + + KALDI_ASSERT(best_token_in_next_frame_); + BucketQueue cur_queue(best_token_in_next_frame_->tot_cost); + BaseFloat adaptive_beam; + // "cur_cutoff" is used to constrain the epsilon emittion in current frame. + // It will not be updated. + BaseFloat cur_cutoff = GetCutoff(active_toks_[frame], + best_token_in_next_frame_, + &adaptive_beam, &cur_queue); + KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is " + << adaptive_beam; + + // pruning "online" before having seen all tokens + + // "next_cutoff" is used to limit a new token in next frame should be handle + // or not. It will be updated along with the further processing. + BaseFloat next_cutoff = std::numeric_limits::infinity(); + // "cost_offset" contains the acoustic log-likelihoods on current frame in + // order to keep everything in a nice dynamic range. Reduce roundoff errors. + BaseFloat cost_offset = 0.0; + + // First process the best token to get a hopefully + // reasonably tight bound on the next cutoff. The only + // products of the next block are "next_cutoff" and "cost_offset". + // Notice: As the difference between the combine version and the traditional + // version, this "best_tok" is choosen from emittion tokens. Normally, the + // best token of one frame comes from an epsilon non-emittion. So the best + // token is a looser boundary. We use it to estimate a bound on the next + // cutoff and we will update the "next_cutoff" once we have better tokens. + // The "next_cutoff" will be updated in further processing. + Token *best_tok = best_token_in_next_frame_; + StateId best_tok_state_id = best_tok->state_id; + if (best_tok) { + cost_offset = - best_tok->tot_cost; + for (fst::ArcIterator aiter(*fst_, best_tok_state_id); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { // propagate.. + // ac_cost + graph_cost + BaseFloat new_weight = arc.weight.Value() + cost_offset - + decodable->LogLikelihood(frame, arc.ilabel) + best_tok->tot_cost; + if (new_weight + adaptive_beam < next_cutoff) + next_cutoff = new_weight + adaptive_beam; + } + } + } + best_token_in_next_frame_ = NULL; + // Store the offset on the acoustic likelihoods that we're applying. + // Could just do cost_offsets_.push_back(cost_offset), but we + // do it this way as it's more robust to future code changes. + cost_offsets_.resize(frame + 1, 0.0); + cost_offsets_[frame] = cost_offset; + + // Iterator the "cur_queue_" to process non-emittion and emittion arcs in fst. + Token *tok = NULL; + while ((tok = cur_queue.Pop()) != NULL) { + BaseFloat cur_cost = tok->tot_cost; + StateId state = tok->state_id; + if (cur_cost > cur_cutoff) // Don't bother processing successors. + continue; + // If "tok" has any existing forward links, delete them, + // because we're about to regenerate them. 
This is a kind + // of non-optimality (remember, this is the simple decoder), + DeleteForwardLinks(tok); // necessary when re-visiting + tok->links = NULL; + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + bool changed; + if (arc.ilabel == 0) { // propagate nonemitting + BaseFloat graph_cost = arc.weight.Value(); + BaseFloat tot_cost = cur_cost + graph_cost; + if (tot_cost < cur_cutoff) { + Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost, + tok, &prev_toks_, &changed); + + // Add ForwardLink from tok to new_tok. Put it on the head of + // tok->link list + tok->links = new ForwardLinkT(new_tok, 0, arc.olabel, + graph_cost, 0, tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new. + if (changed && !new_tok->in_queue) { + cur_queue.Push(new_tok); + } + } + } else { // propagate emitting + BaseFloat graph_cost = arc.weight.Value(), + ac_cost = cost_offset - decodable->LogLikelihood(frame, arc.ilabel), + cur_cost = tok->tot_cost, + tot_cost = cur_cost + ac_cost + graph_cost; + if (tot_cost > next_cutoff) continue; + else if (tot_cost + adaptive_beam < next_cutoff) { + next_cutoff = tot_cost + adaptive_beam; // a tighter boundary for emitting + } + + // no change flag is needed + Token *next_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, + tok, &cur_toks_, NULL); + // Add ForwardLink from tok to next_tok. Put it on the head of tok->link + // list + tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel, + graph_cost, ac_cost, tok->links); + if (best_token_in_next_frame_ == NULL || + next_tok->tot_cost < best_token_in_next_frame_->tot_cost) { + best_token_in_next_frame_ = next_tok; + } + } + } // for all arcs + } // end of while loop + //KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() - 1 + // << " is " << prev_toks_.size(); +} + + +template +void LatticeFasterDecoderCombineTpl::ProcessNonemitting( + std::unordered_map *token_orig_cost) { + int32 frame = active_toks_.size() - 1; + if (token_orig_cost) { // Build the elements which are used to recover + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + (*token_orig_cost)[tok] = tok->tot_cost; + } + } + + StateIdToTokenMap *tmp_toks; + if (token_orig_cost) { // "token_orig_cost" isn't NULL. It means we need to + // recover active_toks_[last_frame] and "cur_toks_" + // will be used in the future. + tmp_toks = new StateIdToTokenMap(cur_toks_); + } else { + tmp_toks = &cur_toks_; + } + + BucketQueue cur_queue(best_token_in_next_frame_->tot_cost); + // "cur_cutoff" is used to constrain the epsilon emittion in current frame. + // It will not be updated. + BaseFloat adaptive_beam; + BaseFloat cur_cutoff = GetCutoff(active_toks_[frame], + best_token_in_next_frame_, + &adaptive_beam, &cur_queue); + + Token *tok = NULL; + while ((tok = cur_queue.Pop()) != NULL) { + BaseFloat cur_cost = tok->tot_cost; + StateId state = tok->state_id; + if (cur_cost > cur_cutoff) // Don't bother processing successors. + continue; + // If "tok" has any existing forward links, delete them, + // because we're about to regenerate them. 
This is a kind + // of non-optimality (remember, this is the simple decoder), + DeleteForwardLinks(tok); // necessary when re-visiting + tok->links = NULL; + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + bool changed; + if (arc.ilabel == 0) { // propagate nonemitting + BaseFloat graph_cost = arc.weight.Value(); + BaseFloat tot_cost = cur_cost + graph_cost; + if (tot_cost < cur_cutoff) { + Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost, + tok, tmp_toks, &changed); + + // Add ForwardLink from tok to new_tok. Put it on the head of + // tok->link list + tok->links = new ForwardLinkT(new_tok, 0, arc.olabel, + graph_cost, 0, tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new. + if (changed && !new_tok->in_queue) { + cur_queue.Push(new_tok); + } + } + } + } // end of for loop + } // end of while loop + if (token_orig_cost) delete tmp_toks; +} + + + +// static inline +template +void LatticeFasterDecoderCombineTpl::DeleteForwardLinks(Token *tok) { + ForwardLinkT *l = tok->links, *m; + while (l != NULL) { + m = l->next; + delete l; + l = m; + } + tok->links = NULL; +} + + +template +void LatticeFasterDecoderCombineTpl::ClearActiveTokens() { + // a cleanup routine, at utt end/begin + for (size_t i = 0; i < active_toks_.size(); i++) { + // Delete all tokens alive on this frame, and any forward + // links they may have. + for (Token *tok = active_toks_[i].toks; tok != NULL; ) { + DeleteForwardLinks(tok); + Token *next_tok = tok->next; + delete tok; + num_toks_--; + tok = next_tok; + } + } + active_toks_.clear(); + KALDI_ASSERT(num_toks_ == 0); +} + +// static +template +void LatticeFasterDecoderCombineTpl::TopSortTokens( + Token *tok_list, std::vector *topsorted_list) { + unordered_map token2pos; + typedef typename unordered_map::iterator IterType; + int32 num_toks = 0; + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + num_toks++; + int32 cur_pos = 0; + // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0. + // This is likely to be in closer to topological order than + // if we had given them ascending order, because of the way + // new tokens are put at the front of the list. + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + token2pos[tok] = num_toks - ++cur_pos; + + unordered_set reprocess; + + for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) { + Token *tok = iter->first; + int32 pos = iter->second; + for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel == 0) { + // We only need to consider epsilon links, since non-epsilon links + // transition between frames and this function only needs to sort a list + // of tokens from a single frame. + IterType following_iter = token2pos.find(link->next_tok); + if (following_iter != token2pos.end()) { // another token on this frame, + // so must consider it. + int32 next_pos = following_iter->second; + if (next_pos < pos) { // reassign the position of the next Token. + following_iter->second = cur_pos++; + reprocess.insert(link->next_tok); + } + } + } + } + // In case we had previously assigned this token to be reprocessed, we can + // erase it from that set because it's "happy now" (we just processed it). + reprocess.erase(tok); + } + + size_t max_loop = 1000000, loop_count; // max_loop is to detect epsilon cycles. 
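+ // Illustration of the repair loop below (made-up positions): if token A at
+ // position 3 has an epsilon link to token B at position 1, B is moved to a
+ // fresh position past everything assigned so far (cur_pos++) and queued for
+ // reprocessing; this repeats until no epsilon link points "backwards".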
+ for (loop_count = 0;
+ !reprocess.empty() && loop_count < max_loop; ++loop_count) {
+ std::vector<Token*> reprocess_vec;
+ for (typename unordered_set<Token*>::iterator iter = reprocess.begin();
+ iter != reprocess.end(); ++iter)
+ reprocess_vec.push_back(*iter);
+ reprocess.clear();
+ for (typename std::vector<Token*>::iterator iter = reprocess_vec.begin();
+ iter != reprocess_vec.end(); ++iter) {
+ Token *tok = *iter;
+ int32 pos = token2pos[tok];
+ // Repeat the processing we did above (for comments, see above).
+ for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
+ if (link->ilabel == 0) {
+ IterType following_iter = token2pos.find(link->next_tok);
+ if (following_iter != token2pos.end()) {
+ int32 next_pos = following_iter->second;
+ if (next_pos < pos) {
+ following_iter->second = cur_pos++;
+ reprocess.insert(link->next_tok);
+ }
+ }
+ }
+ }
+ }
+ }
+ KALDI_ASSERT(loop_count < max_loop && "Epsilon loops exist in your decoding "
+ "graph (this is not allowed!)");
+
+ topsorted_list->clear();
+ topsorted_list->resize(cur_pos, NULL); // create a list with NULLs in between.
+ for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter)
+ (*topsorted_list)[iter->second] = iter->first;
+}
+
+// Instantiate the template for the combination of token types and FST types
+// that we'll need.
+template class LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc>,
+ decodercombine::StdToken<fst::Fst<fst::StdArc> > >;
+template class LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>,
+ decodercombine::StdToken<fst::VectorFst<fst::StdArc> > >;
+template class LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>,
+ decodercombine::StdToken<fst::ConstFst<fst::StdArc> > >;
+template class LatticeFasterDecoderCombineTpl<fst::GrammarFst,
+ decodercombine::StdToken<fst::GrammarFst> >;
+
+template class LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc>,
+ decodercombine::BackpointerToken<fst::Fst<fst::StdArc> > >;
+template class LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>,
+ decodercombine::BackpointerToken<fst::VectorFst<fst::StdArc> > >;
+template class LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>,
+ decodercombine::BackpointerToken<fst::ConstFst<fst::StdArc> > >;
+template class LatticeFasterDecoderCombineTpl<fst::GrammarFst,
+ decodercombine::BackpointerToken<fst::GrammarFst> >;
+
+
+} // end namespace kaldi.
diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.h b/src/decoder/lattice-faster-decoder-combine-bucketqueue.h
new file mode 100644
index 00000000000..05f36b8aeab
--- /dev/null
+++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.h
@@ -0,0 +1,638 @@
+// decoder/lattice-faster-decoder-combine.h
+
+// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
+// 2013-2019 Johns Hopkins University (Author: Daniel Povey)
+// 2014 Guoguo Chen
+// 2018 Zhehuai Chen
+// 2019 Hang Lyu
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
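+
+// Usage sketch (illustrative only, not part of this patch): the decoder
+// declared in this header is typically constructed over an HCLG graph plus a
+// config, then driven through Init/Advance/Finalize. 'ReadFstKaldiGeneric'
+// and the graph file name below are assumptions for illustration:
+//
+//   fst::Fst<fst::StdArc> *graph = fst::ReadFstKaldiGeneric("HCLG.fst");
+//   LatticeFasterDecoderCombineConfig config;
+//   config.beam = 13.0;
+//   config.lattice_beam = 6.0;
+//   LatticeFasterDecoderCombine decoder(*graph, config);
+//   decoder.InitDecoding();
+//   decoder.AdvanceDecoding(&decodable);  // repeat as frames arrive
+//   decoder.FinalizeDecoding();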
+ +#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ +#define KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ + + +#include "util/stl-utils.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/grammar-fst.h" +#include "decoder/lattice-faster-decoder.h" + +namespace kaldi { + +struct LatticeFasterDecoderCombineConfig { + BaseFloat beam; + int32 max_active; + int32 min_active; + BaseFloat lattice_beam; + int32 prune_interval; + bool determinize_lattice; // not inspected by this class... used in + // command-line program. + BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat hash_ratio; + BaseFloat prune_scale; // Note: we don't make this configurable on the command line, + // it's not a very important parameter. It affects the + // algorithm that prunes the tokens as we go. + // Most of the options inside det_opts are not actually queried by the + // LatticeFasterDecoder class itself, but by the code that calls it, for + // example in the function DecodeUtteranceLatticeFaster. + fst::DeterminizeLatticePhonePrunedOptions det_opts; + + LatticeFasterDecoderCombineConfig(): beam(16.0), + max_active(std::numeric_limits::max()), + min_active(200), + lattice_beam(10.0), + prune_interval(25), + determinize_lattice(true), + beam_delta(0.5), + hash_ratio(2.0), + prune_scale(0.1) { } + void Register(OptionsItf *opts) { + det_opts.Register(opts); + opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); + opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; " + "more accurate"); + opts->Register("min-active", &min_active, "Decoder minimum #active states."); + opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam. Larger->slower, " + "and deeper lattices"); + opts->Register("prune-interval", &prune_interval, "Interval (in frames) at " + "which to prune tokens"); + opts->Register("determinize-lattice", &determinize_lattice, "If true, " + "determinize the lattice (lattice-determinization, keeping only " + "best pdf-sequence for each word-sequence)."); + opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this " + "parameter is obscure and relates to a speedup in the way the " + "max-active constraint is applied. Larger is more accurate."); + opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to " + "control hash behavior"); + } + void Check() const { + KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 + && min_active <= max_active + && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0 + && prune_scale > 0.0 && prune_scale < 1.0); + } +}; + + +namespace decodercombine { +// We will template the decoder on the token type as well as the FST type; this +// is a mechanism so that we can use the same underlying decoder code for +// versions of the decoder that support quickly getting the best path +// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also +// those that do not (LatticeFasterDecoder). + + +// ForwardLinks are the links from a token to a token on the next frame. +// or sometimes on the current frame (for input-epsilon links). 
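+// For example (illustrative): a token for HCLG state s on frame t owns a
+// singly-linked list
+//   tok(s,t) -> ForwardLink{ilabel, olabel, graph_cost, acoustic_cost,
+//                           next_tok = tok(s',t+1)} -> ...
+// where an input-epsilon link instead points at a token on the same frame t.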
+template <typename Token>
+struct ForwardLink {
+ using Label = fst::StdArc::Label;
+
+ Token *next_tok; // the next token [or NULL if represents final-state]
+ Label ilabel; // ilabel on arc
+ Label olabel; // olabel on arc
+ BaseFloat graph_cost; // graph cost of traversing arc (contains LM, etc.)
+ BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing arc
+ ForwardLink *next; // next in singly-linked list of forward arcs (arcs
+ // in the state-level lattice) from a token.
+ inline ForwardLink(Token *next_tok, Label ilabel, Label olabel,
+ BaseFloat graph_cost, BaseFloat acoustic_cost,
+ ForwardLink *next):
+ next_tok(next_tok), ilabel(ilabel), olabel(olabel),
+ graph_cost(graph_cost), acoustic_cost(acoustic_cost),
+ next(next) { }
+};
+
+template <typename Fst>
+struct StdToken {
+ using ForwardLinkT = ForwardLink<StdToken<Fst> >;
+ using Token = StdToken<Fst>;
+ using StateId = typename Fst::Arc::StateId;
+
+ // Standard token type for LatticeFasterDecoderCombine. Each active HCLG
+ // (decoding-graph) state on each frame has one token.
+
+ // tot_cost is the total (LM + acoustic) cost from the beginning of the
+ // utterance up to this point. (but see cost_offset_, which is subtracted
+ // to keep it in a good numerical range).
+ BaseFloat tot_cost;
+
+ // extra_cost is >= 0. After calling PruneForwardLinks, this equals the
+ // minimum difference between the cost of the best path that this token is
+ // on, and the cost of the absolute best path, under the assumption
+ // that any of the currently active states at the decoding front may
+ // eventually succeed (e.g. if you were to take the currently active states
+ // one by one and compute this difference, and then take the minimum).
+ BaseFloat extra_cost;
+
+ // Record the state id of the token.
+ StateId state_id;
+
+ // 'links' is the head of the singly-linked list of ForwardLinks, which is
+ // what we use for lattice generation.
+ ForwardLinkT *links;
+
+ // 'next' is the next in the singly-linked list of tokens for this frame.
+ Token *next;
+
+ // Identifies whether the token is in the current queue or not, to prevent
+ // duplicates in ProcessForFrame().
+ bool in_queue;
+
+
+ // This function does nothing and should be optimized out; it's needed
+ // so we can share the regular LatticeFasterDecoderTpl code and the code
+ // for LatticeFasterOnlineDecoder that supports fast traceback.
+ inline void SetBackpointer(Token *backpointer) { }
+
+ // This constructor just ignores the 'backpointer' argument. That argument is
+ // needed so that we can use the same decoder code for LatticeFasterDecoderTpl
+ // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
+ // fast way to obtain the best path).
+ inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, StateId state_id,
+ ForwardLinkT *links, Token *next, Token *backpointer):
+ tot_cost(tot_cost), extra_cost(extra_cost), state_id(state_id),
+ links(links), next(next), in_queue(false) { }
+};
+
+template <typename Fst>
+struct BackpointerToken {
+ using ForwardLinkT = ForwardLink<BackpointerToken<Fst> >;
+ using Token = BackpointerToken<Fst>;
+ using StateId = typename Fst::Arc::StateId;
+
+ // BackpointerToken is like StdToken, but also stores a 'backpointer' to the
+ // best preceding token, to support fast best-path traceback. Each active
+ // HCLG (decoding-graph) state on each frame has one token.
+
+ // tot_cost is the total (LM + acoustic) cost from the beginning of the
+ // utterance up to this point. (but see cost_offset_, which is subtracted
+ // to keep it in a good numerical range).
+ BaseFloat tot_cost;
+
+ // extra_cost is >= 0. After calling PruneForwardLinks, this equals the
+ // minimum difference between the cost of the best path that this token is
+ // on, and the cost of the absolute best path, under the assumption
+ // that any of the currently active states at the decoding front may
+ // eventually succeed (e.g. if you were to take the currently active states
+ // one by one and compute this difference, and then take the minimum).
+ BaseFloat extra_cost;
+
+ // Record the state id of the token.
+ StateId state_id;
+
+ // 'links' is the head of the singly-linked list of ForwardLinks, which is
+ // what we use for lattice generation.
+ ForwardLinkT *links;
+
+ // 'next' is the next in the singly-linked list of tokens for this frame.
+ BackpointerToken *next;
+
+ // Best preceding BackpointerToken (could be a token on this frame, connected
+ // to this one via an epsilon transition, or on a previous frame). This is
+ // only required for an efficient GetBestPath function in
+ // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
+ // (the "links" list is what stores the forward links, for that).
+ Token *backpointer;
+
+ // Identifies whether the token is in the current queue or not, to prevent
+ // duplicates in ProcessForFrame().
+ bool in_queue;
+
+ inline void SetBackpointer(Token *backpointer) {
+ this->backpointer = backpointer;
+ }
+
+ inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost,
+ StateId state_id, ForwardLinkT *links,
+ Token *next, Token *backpointer):
+ tot_cost(tot_cost), extra_cost(extra_cost), state_id(state_id),
+ links(links), next(next), backpointer(backpointer),
+ in_queue(false) { }
+};
+
+} // namespace decodercombine
+
+
+template <typename Token>
+class BucketQueue {
+ public:
+ /** Constructor. 'cost_scale' is a scale that we multiply the token costs by
+ * before integerizing; a larger value means more buckets.
+ * 'best_cost_estimate' is an estimate of the best (lowest) cost that
+ * we are likely to encounter (e.g. the best cost that we have seen so far).
+ * It is used to initialize 'bucket_storage_begin_'.
+ */
+ BucketQueue(BaseFloat best_cost_estimate, BaseFloat cost_scale = 1.0);
+
+ // Add a Token to the queue; sets the field tok->in_queue to true (it is not
+ // an error if it was already true).
+ // If a Token was already in the queue but its cost improves, you should
+ // just Push it again. It will be added to (possibly) a different bucket, and
+ // the old entry will remain; the old entry is treated as nonexistent when we
+ // try to pop it (we use the in_queue flag, rather than a recorded cost, to
+ // decide whether an entry is still live). This strategy means that you may
+ // not delete Tokens as long as pointers to them might exist in this queue
+ // (hence, it is probably best to only ever have this queue as a local
+ // variable inside a function).
+ void Push(Token *tok);
+
+ // Removes and returns the next Token 'tok' in the queue, or NULL if there
+ // were no Tokens left. Sets tok->in_queue to false for the returned Token.
+ Token* Pop();
+
+ private:
+ // Configuration value that is multiplied by tokens' costs before integerizing
+ // them to determine the bucket index.
+ BaseFloat cost_scale_;
+
+ // buckets_ is a list of Tokens 'tok' for each bucket.
+ // If tok->in_queue is false, then the item is considered as not
+ // existing (this is to avoid having to explicitly remove Tokens when their
+ // costs change).
The index into buckets_ is determined as follows: + // bucket_index = std::floor(tok->cost * cost_scale_); + // vec_index = bucket_index - bucket_storage_begin_; + // then access buckets_[vec_index]. + std::vector > buckets_; + + // The lowest-numbered bucket_index that is occupied (i.e. the first one which + // has any elements). Will be updated as we add or remove tokens. + // If this corresponds to a value past the end of buckets_, we interpret it + // as 'there are no buckets with entries'. + int32 first_occupied_bucket_index_; + + // An offset that determines how we index into the buckets_ vector; + // may be interpreted as a 'bucket_index' that is better than any one that + // we are likely to see. + // In the constructor this will be initialized to something like + // bucket_storage_begin_ = std::floor((best_cost_estimate - 15) * cost_scale) + // which will make it unlikely that we have to change this value in future if + // we get a much better Token (this is expensive because it involves + // reallocating 'buckets_'). + int32 bucket_storage_begin_; +}; + +/** This is the "normal" lattice-generating decoder. + See \ref lattices_generation \ref decoders_faster and \ref decoders_simple + for more information. + + The decoder is templated on the FST type and the token type. The token type + will normally be StdToken, but also may be BackpointerToken which is to support + quick lookup of the current best path (see lattice-faster-online-decoder.h) + + The FST you invoke this decoder with is expected to equal + Fst::Fst, a.k.a. StdFst, or GrammarFst. If you invoke it with + FST == StdFst and it notices that the actual FST type is + fst::VectorFst or fst::ConstFst, the decoder object + will internally cast itself to one that is templated on those more specific + types; this is an optimization for speed. + */ +template > +class LatticeFasterDecoderCombineTpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using ForwardLinkT = decodercombine::ForwardLink; + + using StateIdToTokenMap = typename std::unordered_map; + //using StateIdToTokenMap = typename std::unordered_map, std::equal_to, + // fst::PoolAllocator > >; + using IterType = typename StateIdToTokenMap::const_iterator; + + using BucketQueue = typename kaldi::BucketQueue; + + // Instantiate this class once for each thing you have to decode. + // This version of the constructor does not take ownership of + // 'fst'. + LatticeFasterDecoderCombineTpl(const FST &fst, + const LatticeFasterDecoderCombineConfig &config); + + // This version of the constructor takes ownership of the fst, and will delete + // it when this object is destroyed. + LatticeFasterDecoderCombineTpl(const LatticeFasterDecoderCombineConfig &config, + FST *fst); + + void SetOptions(const LatticeFasterDecoderCombineConfig &config) { + config_ = config; + } + + const LatticeFasterDecoderCombineConfig &GetOptions() const { + return config_; + } + + ~LatticeFasterDecoderCombineTpl(); + + /// Decodes until there are no more frames left in the "decodable" object.. + /// note, this may block waiting for input if the "decodable" object blocks. + /// Returns true if any kind of traceback is available (not necessarily from a + /// final state). + bool Decode(DecodableInterface *decodable); + + + /// says whether a final-state was active on the last frame. If it was not, the + /// lattice (or traceback) will end with states that are not final-states. 
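+ /// For example (illustrative): after Decode(), if ReachedFinal() returns
+ /// false, GetBestPath() still succeeds, but the returned path simply stops
+ /// at a non-final state of the graph.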
+ bool ReachedFinal() const { + return FinalRelativeCost() != std::numeric_limits::infinity(); + } + + /// Outputs an FST corresponding to the single best path through the lattice. + /// Returns true if result is nonempty (using the return status is deprecated, + /// it will become void). If "use_final_probs" is true AND we reached the + /// final-state of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. Note: this just calls GetRawLattice() + /// and figures out the shortest path. + bool GetBestPath(Lattice *ofst, + bool use_final_probs = true); + + /// Outputs an FST corresponding to the raw, state-level + /// tracebacks. Returns true if result is nonempty. + /// If "use_final_probs" is true AND we reached the final-state + /// of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. + /// The raw lattice will be topologically sorted. + /// The function can be called during decoding, it will process non-emitting + /// arcs from "cur_toks_" map to get tokens from both non-emitting and + /// emitting arcs for getting raw lattice. Then recover it to ensure the + /// consistency of ProcessForFrame(). + /// + /// See also GetRawLatticePruned in lattice-faster-online-decoder.h, + /// which also supports a pruning beam, in case for some reason + /// you want it pruned tighter than the regular lattice beam. + /// We could put that here in future needed. + bool GetRawLattice(Lattice *ofst, bool use_final_probs = true); + + + + /// [Deprecated, users should now use GetRawLattice and determinize it + /// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper]. + /// Outputs an FST corresponding to the lattice-determinized + /// lattice (one path per word sequence). Returns true if result is nonempty. + /// If "use_final_probs" is true AND we reached the final-state of the graph + /// then it will include those as final-probs, else it will treat all + /// final-probs as one. + bool GetLattice(CompactLattice *ofst, + bool use_final_probs = true); + + /// InitDecoding initializes the decoding, and should only be used if you + /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to + /// call this. You can also call InitDecoding if you have already decoded an + /// utterance and want to start with a new utterance. + void InitDecoding(); + + /// This will decode until there are no more frames ready in the decodable + /// object. You can keep calling it each time more frames become available. + /// If max_num_frames is specified, it specifies the maximum number of frames + /// the function will decode before returning. + void AdvanceDecoding(DecodableInterface *decodable, + int32 max_num_frames = -1); + + /// This function may be optionally called after AdvanceDecoding(), when you + /// do not plan to decode any further. It does an extra pruning step that + /// will help to prune the lattices output by GetLattice and (particularly) + /// GetRawLattice more accurately, particularly toward the end of the + /// utterance. It does this by using the final-probs in pruning (if any + /// final-state survived); it also does a final pruning step that visits all + /// states (the pruning that is done during decoding may fail to prune states + /// that are within kPruningScale = 0.1 outside of the beam). If you call + /// this, you cannot call AdvanceDecoding again (it will fail), and you + /// cannot call GetLattice() and related functions with use_final_probs = + /// false. 
+ /// Used to be called PruneActiveTokensFinal(). + void FinalizeDecoding(); + + /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives + /// more information. It returns the difference between the best (final-cost + /// plus cost) of any token on the final frame, and the best cost of any token + /// on the final frame. If it is infinity it means no final-states were + /// present on the final frame. It will usually be nonnegative. If it not + /// too positive (e.g. < 5 is my first guess, but this is not tested) you can + /// take it as a good indication that we reached the final-state with + /// reasonable likelihood. + BaseFloat FinalRelativeCost() const; + + + // Returns the number of frames decoded so far. The value returned changes + // whenever we call ProcessForFrame(). + inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } + + protected: + // we make things protected instead of private, as code in + // LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the + // internals. + + // Deletes the elements of the singly linked list tok->links. + inline static void DeleteForwardLinks(Token *tok); + + // head of per-frame list of Tokens (list is in topological order), + // and something saying whether we ever pruned it using PruneForwardLinks. + struct TokenList { + Token *toks; + bool must_prune_forward_links; + bool must_prune_tokens; + TokenList(): toks(NULL), must_prune_forward_links(true), + must_prune_tokens(true) { } + }; + + // FindOrAddToken either locates a token in hash map "token_map", or if necessary + // inserts a new, empty token (i.e. with no forward links) for the current + // frame. [note: it's inserted if necessary into hash map and also into the + // singly linked list of tokens active on this frame (whose head is at + // active_toks_[frame]). The frame_plus_one argument is the acoustic frame + // index plus one, which is used to index into the active_toks_ array. + // Returns the Token pointer. Sets "changed" (if non-NULL) to true if the + // token was newly created or the cost changed. + // If Token == StdToken, the 'backpointer' argument has no purpose (and will + // hopefully be optimized out). + inline Token *FindOrAddToken(StateId state, int32 token_list_index, + BaseFloat tot_cost, Token *backpointer, + StateIdToTokenMap *token_map, + bool *changed); + + // prunes outgoing links for all tokens in active_toks_[frame] + // it's called by PruneActiveTokens + // all links, that have link_extra_cost > lattice_beam are pruned + // delta is the amount by which the extra_costs must change + // before we set *extra_costs_changed = true. + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, + BaseFloat delta); + + // This function computes the final-costs for tokens active on the final + // frame. It outputs to final-costs, if non-NULL, a map from the Token* + // pointer to the final-prob of the corresponding state, for all Tokens + // that correspond to states that have final-probs. This map will be + // empty if there were no final-probs. 
It outputs to + // final_relative_cost, if non-NULL, the difference between the best + // forward-cost including the final-prob cost, and the best forward-cost + // without including the final-prob cost (this will usually be positive), or + // infinity if there were no final-probs. [c.f. FinalRelativeCost(), which + // outputs this quanitity]. It outputs to final_best_cost, if + // non-NULL, the lowest for any token t active on the final frame, of + // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in + // the graph of the state corresponding to token t, or the best of + // forward-cost[t] if there were no final-probs active on the final frame. + // You cannot call this after FinalizeDecoding() has been called; in that + // case you should get the answer from class-member variables. + void ComputeFinalCosts(unordered_map *final_costs, + BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const; + + // PruneForwardLinksFinal is a version of PruneForwardLinks that we call + // on the final frame. If there are final tokens active, it uses + // the final-probs for pruning, otherwise it treats all tokens as final. + void PruneForwardLinksFinal(); + + // Prune away any tokens on this frame that have no forward links. + // [we don't do this in PruneForwardLinks because it would give us + // a problem with dangling pointers]. + // It's called by PruneActiveTokens if any forward links have been pruned + void PruneTokensForFrame(int32 frame_plus_one); + + + // Go backwards through still-alive tokens, pruning them if the + // forward+backward cost is more than lat_beam away from the best path. It's + // possible to prove that this is "correct" in the sense that we won't lose + // anything outside of lat_beam, regardless of what happens in the future. + // delta controls when it considers a cost to have changed enough to continue + // going backward and propagating the change. larger delta -> will recurse + // less far. + void PruneActiveTokens(BaseFloat delta); + + /// Processes non-emitting (epsilon) arcs and emitting arcs for one frame + /// together. It takes the emittion tokens in "prev_toks_" from last frame. + /// Generates non-emitting tokens for previous frame and emitting tokens for + /// next frame. + /// Notice: The emitting tokens for the current frame means the token take + /// acoustic scores of the current frame. (i.e. the destnations of emitting + /// arcs.) + void ProcessForFrame(DecodableInterface *decodable); + + /// Processes nonemitting (epsilon) arcs for one frame. + /// Calls this function once when all frames were processed. + /// Or calls it in GetRawLattice() to generate the complete token list for + /// the last frame. [Deal With the tokens in map "cur_toks_" which would + /// only contains emittion tokens from previous frame.] + /// If the map, "token_orig_cost", isn't NULL, we build the map which will + /// be used to recover "active_toks_[last_frame]" token list for the last + /// frame. + void ProcessNonemitting(std::unordered_map *token_orig_cost); + + /// When GetRawLattice() is called during decoding, the + /// active_toks_[last_frame] is changed. To keep the consistency of function + /// ProcessForFrame(), recover it. + /// Notice: as new token will be added to the head of TokenList, tok->next + /// will not be affacted. + /// "token_orig_cost" is a mapping from token pointer to the tot_cost of the + /// token before propagating non-emitting arcs. 
It is used to recover the + /// change of original tokens in the last frame and remove the new tokens + /// which come from propagating non-emitting arcs, so that we can guarantee + /// the consistency of function ProcessForFrame(). + void RecoverLastTokenList( + const std::unordered_map &token_orig_cost); + + + /// The "prev_toks_" and "cur_toks_" actually allow us to maintain current + /// and next frames. They are indexed by StateId. It is indexed by frame-index + /// plus one, where the frame-index is zero-based, as used in decodable object. + /// That is, the emitting probs of frame t are accounted for in tokens at + /// toks_[t+1]. The zeroth frame is for nonemitting transition at the start of + /// the graph. + StateIdToTokenMap prev_toks_; + StateIdToTokenMap cur_toks_; + + /// Gets the weight cutoff. + /// Notice: In traiditional version, the histogram prunning method is applied + /// on a complete token list on one frame. But, in this version, it is used + /// on a token list which only contains the emittion part. So the max_active + /// and min_active values might be narrowed. + BaseFloat GetCutoff(const TokenList &token_list, const Token* best_token, + BaseFloat *adaptive_beam, + BucketQueue *queue); + + std::vector active_toks_; // Lists of tokens, indexed by + // frame (members of TokenList are toks, must_prune_forward_links, + // must_prune_tokens). + std::queue cur_queue_; // temp variable used in ProcessForFrame + // and ProcessNonemitting + std::vector tmp_array_; // used in GetCutoff. + // Stores the best token in next frame. The tot_cost of it will be used to + // initialize the BucketQueue. + Token* best_token_in_next_frame_; + + // fst_ is a pointer to the FST we are decoding from. + const FST *fst_; + // delete_fst_ is true if the pointer fst_ needs to be deleted when this + // object is destroyed. + bool delete_fst_; + + std::vector cost_offsets_; // This contains, for each + // frame, an offset that was added to the acoustic log-likelihoods on that + // frame in order to keep everything in a nice dynamic range i.e. close to + // zero, to reduce roundoff errors. + // Notice: It will only be added to emitting arcs (i.e. cost_offsets_[t] is + // added to arcs from "frame t" to "frame t+1"). + LatticeFasterDecoderCombineConfig config_; + int32 num_toks_; // current total #toks allocated... + bool warned_; + + /// decoding_finalized_ is true if someone called FinalizeDecoding(). [note, + /// calling this is optional]. If true, it's forbidden to decode more. Also, + /// if this is set, then the output of ComputeFinalCosts() is in the next + /// three variables. The reason we need to do this is that after + /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some + /// of the tokens on the last frame are freed, so we free the list from toks_ + /// to avoid having dangling pointers hanging around. + bool decoding_finalized_; + /// For the meaning of the next 3 variables, see the comment for + /// decoding_finalized_ above., and ComputeFinalCosts(). + unordered_map final_costs_; + BaseFloat final_relative_cost_; + BaseFloat final_best_cost_; + + // This function takes a singly linked list of tokens for a single frame, and + // outputs a list of them in topological order (it will crash if no such order + // can be found, which will typically be due to decoding graphs with epsilon + // cycles, which are not allowed). 
Note: the output list may contain NULLs, + // which the caller should pass over; it just happens to be more efficient for + // the algorithm to output a list that contains NULLs. + static void TopSortTokens(Token *tok_list, + std::vector *topsorted_list); + + void ClearActiveTokens(); + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderCombineTpl); +}; + +typedef LatticeFasterDecoderCombineTpl > LatticeFasterDecoderCombine; + + + +} // end namespace kaldi. + +#endif diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc index 5c87d72fe14..11c5d01df0a 100644 --- a/src/decoder/lattice-faster-decoder-combine.cc +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -26,6 +26,75 @@ namespace kaldi { +template +BucketQueue::BucketQueue(BaseFloat best_cost_estimate, + BaseFloat cost_scale) : + cost_scale_(cost_scale) { + // NOTE: we reserve plenty of elements to avoid expensive reallocations + // later on. Normally, the size is a little bigger than (adaptive_beam + + // 5) * cost_scale. + int32 bucket_size = 100; + buckets_.resize(bucket_size); + bucket_storage_begin_ = std::floor((best_cost_estimate - 15) * cost_scale); + first_occupied_bucket_index_ = bucket_storage_begin_ + bucket_size; +} + +template +void BucketQueue::Push(Token *tok) { + int32 bucket_index = std::floor(tok->tot_cost * cost_scale_); + int32 vec_index = bucket_index - bucket_storage_begin_; + + if (vec_index < 0) { + KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve" + << " more elements in constructor. Push front."; + int32 increase_size = - vec_index; + std::vector > tmp(buckets_); + buckets_.resize(tmp.size() + increase_size); + std::copy(tmp.begin(), tmp.end(), buckets_.begin() + increase_size); + // Update start point + bucket_storage_begin_ = bucket_index; + vec_index = 0; + } else if (vec_index > buckets_.size() - 1) { + KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve" + << " more elements in constructor. Push back."; + buckets_.resize(vec_index + 1); + } + + tok->in_queue = true; + buckets_[vec_index].push_back(tok); + if (vec_index < (first_occupied_bucket_index_ - bucket_storage_begin_)) + first_occupied_bucket_index_ = vec_index + bucket_storage_begin_; +} + +template +Token* BucketQueue::Pop() { + int32 vec_index = first_occupied_bucket_index_ - bucket_storage_begin_; + Token* best_tok = NULL; + while(vec_index < buckets_.size()) { + // Remove the best token + best_tok = buckets_[vec_index].back(); + buckets_[vec_index].pop_back(); + + if (buckets_[vec_index].empty()) { // This bucket is empty. Update + // first_occupied_bucket_index_ + int32 next_vec_index = vec_index + 1; + for(; next_vec_index < buckets_.size(); next_vec_index++) { + if(!buckets_[next_vec_index].empty()) break; + } + first_occupied_bucket_index_ = bucket_storage_begin_ + next_vec_index; + vec_index = next_vec_index; + } + + if (best_tok->in_queue) { // This is a effective token + best_tok->in_queue = false; + break; + } else { + best_tok = NULL; + } + } + return best_tok; +} + // instantiate this class once for each thing you have to decode. 
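+// Usage sketch for the BucketQueue above (toy token type; illustrative, not
+// part of this patch). Re-Push a token after its cost improves; the stale
+// queue entry is skipped at Pop() time via the in_queue flag:
+//
+//   struct ToyToken { BaseFloat tot_cost; bool in_queue; };
+//   ToyToken a = { 5.0, false };
+//   BucketQueue<ToyToken> queue(5.0 /* best_cost_estimate */);
+//   queue.Push(&a);             // goes into bucket floor(5.0 * cost_scale_)
+//   a.tot_cost = 3.0;           // cost improved...
+//   queue.Push(&a);             // ...push again; the old entry goes stale
+//   ToyToken *t = queue.Pop();  // returns &a once; stale entry is skipped
+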
template LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( @@ -73,6 +142,7 @@ void LatticeFasterDecoderCombineTpl::InitDecoding() { active_toks_[0].toks = start_tok; cur_toks_[start_state] = start_tok; // initialize current tokens map num_toks_++; + best_token_in_next_frame_ = start_tok; } // Returns true if any kind of traceback is available (not necessarily from @@ -224,7 +294,7 @@ void LatticeFasterDecoderCombineTpl::RecoverLastTokenList( if (token_orig_cost.find(tok) != token_orig_cost.end()) { DeleteForwardLinks(tok); tok->tot_cost = token_orig_cost.find(tok)->second; - tok->in_current_queue = false; + tok->in_queue = false; tok = tok->next; } else { DeleteForwardLinks(tok); @@ -686,22 +756,15 @@ void LatticeFasterDecoderCombineTpl::FinalizeDecoding() { /// Gets the weight cutoff. template BaseFloat LatticeFasterDecoderCombineTpl::GetCutoff( - const TokenList &token_list, BaseFloat *adaptive_beam, - StateId *best_state_id, Token **best_token) { + const TokenList &token_list, const Token* best_token, + BaseFloat *adaptive_beam, BucketQueue *queue) { + BaseFloat best_weight = best_token->tot_cost; // positive == high cost == bad. // best_weight is the minimum value. - BaseFloat best_weight = std::numeric_limits::infinity(); if (config_.max_active == std::numeric_limits::max() && config_.min_active == 0) { for (Token* tok = token_list.toks; tok != NULL; tok = tok->next) { - BaseFloat w = static_cast(tok->tot_cost); - if (w < best_weight) { - best_weight = w; - if (best_token) { - *best_state_id = tok->state_id; - *best_token = tok; - } - } + queue->Push(tok); } if (adaptive_beam != NULL) *adaptive_beam = config_.beam; return best_weight + config_.beam; @@ -710,13 +773,7 @@ BaseFloat LatticeFasterDecoderCombineTpl::GetCutoff( for (Token* tok = token_list.toks; tok != NULL; tok = tok->next) { BaseFloat w = static_cast(tok->tot_cost); tmp_array_.push_back(w); - if (w < best_weight) { - best_weight = w; - if (best_token) { - *best_state_id = tok->state_id; - *best_token = tok; - } - } + queue->Push(tok); } BaseFloat beam_cutoff = best_weight + config_.beam, @@ -776,17 +833,17 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( } } + KALDI_ASSERT(best_token_in_next_frame_); + BucketQueue cur_queue(best_token_in_next_frame_->tot_cost); BaseFloat adaptive_beam; - Token *best_tok = NULL; - StateId best_tok_state_id; // "cur_cutoff" is used to constrain the epsilon emittion in current frame. // It will not be updated. - BaseFloat cur_cutoff = GetCutoff(active_toks_[frame], &adaptive_beam, - &best_tok_state_id, &best_tok); + BaseFloat cur_cutoff = GetCutoff(active_toks_[frame], + best_token_in_next_frame_, + &adaptive_beam, &cur_queue); KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is " << adaptive_beam; - // pruning "online" before having seen all tokens // "next_cutoff" is used to limit a new token in next frame should be handle @@ -805,6 +862,8 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( // token is a looser boundary. We use it to estimate a bound on the next // cutoff and we will update the "next_cutoff" once we have better tokens. // The "next_cutoff" will be updated in further processing. 
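+ // Numeric illustration (made-up values): if best_token_in_next_frame_ has
+ // tot_cost = 50.0 then cost_offset = -50.0, so for each emitting arc
+ // new_weight reduces to graph_cost minus log-likelihood; if the cheapest
+ // such arc gives new_weight = 3.0, next_cutoff starts at
+ // 3.0 + adaptive_beam and can only tighten as better tokens are expanded.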
+ Token *best_tok = best_token_in_next_frame_; + StateId best_tok_state_id = best_tok->state_id; if (best_tok) { cost_offset = - best_tok->tot_cost; for (fst::ArcIterator aiter(*fst_, best_tok_state_id); @@ -820,29 +879,18 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( } } } - + best_token_in_next_frame_ = NULL; // Store the offset on the acoustic likelihoods that we're applying. // Could just do cost_offsets_.push_back(cost_offset), but we // do it this way as it's more robust to future code changes. cost_offsets_.resize(frame + 1, 0.0); cost_offsets_[frame] = cost_offset; - // Build a queue which contains the emittion tokens from previous frame. - for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { - cur_queue_.push(tok->state_id); - tok->in_current_queue = true; - } - // Iterator the "cur_queue_" to process non-emittion and emittion arcs in fst. - while (!cur_queue_.empty()) { - StateId state = cur_queue_.front(); - cur_queue_.pop(); - - KALDI_ASSERT(prev_toks_.find(state) != prev_toks_.end()); - Token *tok = prev_toks_[state]; - + Token *tok = NULL; + while ((tok = cur_queue.Pop()) != NULL) { BaseFloat cur_cost = tok->tot_cost; - tok->in_current_queue = false; // out of queue + StateId state = tok->state_id; if (cur_cost > cur_cutoff) // Don't bother processing successors. continue; // If "tok" has any existing forward links, delete them, @@ -869,9 +917,8 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( // "changed" tells us whether the new token has a different // cost from before, or is new. - if (changed && !new_tok->in_current_queue) { - cur_queue_.push(arc.nextstate); - new_tok->in_current_queue = true; + if (changed && !new_tok->in_queue) { + cur_queue.Push(new_tok); } } } else { // propagate emitting @@ -880,9 +927,10 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( cur_cost = tok->tot_cost, tot_cost = cur_cost + ac_cost + graph_cost; if (tot_cost > next_cutoff) continue; - else if (tot_cost + adaptive_beam < next_cutoff) + else if (tot_cost + adaptive_beam < next_cutoff) { next_cutoff = tot_cost + adaptive_beam; // a tighter boundary for emitting - + } + // no change flag is needed Token *next_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, &cur_toks_, NULL); @@ -890,11 +938,15 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( // list tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel, graph_cost, ac_cost, tok->links); + if (best_token_in_next_frame_ == NULL || + next_tok->tot_cost < best_token_in_next_frame_->tot_cost) { + best_token_in_next_frame_ = next_tok; + } } } // for all arcs } // end of while loop - KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() - 1 - << " is " << prev_toks_.size(); + //KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() - 1 + // << " is " << prev_toks_.size(); } @@ -917,26 +969,18 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( tmp_toks = &cur_toks_; } - // Build the queue to process non-emitting arcs. - for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { - if (fst_->NumInputEpsilons(tok->state_id) != 0) { - cur_queue_.push(tok->state_id); - tok->in_current_queue = true; - } - } - + BucketQueue cur_queue(best_token_in_next_frame_->tot_cost); // "cur_cutoff" is used to constrain the epsilon emittion in current frame. // It will not be updated. 
 BaseFloat adaptive_beam;
-  BaseFloat cur_cutoff = GetCutoff(active_toks_[frame], &adaptive_beam, NULL, NULL);
+  BaseFloat cur_cutoff = GetCutoff(active_toks_[frame],
+                                   best_token_in_next_frame_,
+                                   &adaptive_beam, &cur_queue);
 
-  while (!cur_queue_.empty()) {
-    StateId state = cur_queue_.front();
-    cur_queue_.pop();
-
-    KALDI_ASSERT(tmp_toks->find(state) != tmp_toks->end());
-    Token *tok = (*tmp_toks)[state];
+  Token *tok = NULL;
+  while ((tok = cur_queue.Pop()) != NULL) {
     BaseFloat cur_cost = tok->tot_cost;
+    StateId state = tok->state_id;
     if (cur_cost > cur_cutoff)  // Don't bother processing successors.
       continue;
     // If "tok" has any existing forward links, delete them,
@@ -963,14 +1007,12 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting(
 
         // "changed" tells us whether the new token has a different
         // cost from before, or is new.
-        if (changed && !new_tok->in_current_queue) {
-          cur_queue_.push(arc.nextstate);
-          new_tok->in_current_queue = true;
+        if (changed && !new_tok->in_queue) {
+          cur_queue.Push(new_tok);
         }
       }
     }
   }  // end of for loop
-    tok->in_current_queue = false;
   }  // end of while loop
   if (token_orig_cost) delete tmp_toks;
 }
diff --git a/src/decoder/lattice-faster-decoder-combine.h b/src/decoder/lattice-faster-decoder-combine.h
index 900d03520e4..05f36b8aeab 100644
--- a/src/decoder/lattice-faster-decoder-combine.h
+++ b/src/decoder/lattice-faster-decoder-combine.h
@@ -154,7 +154,7 @@ struct StdToken {
 
   // identify whether the token is in the current queue or not, to prevent
   // duplication in function ProcessOneFrame().
-  bool in_current_queue;
+  bool in_queue;
 
   // This function does nothing and should be optimized out; it's needed
@@ -169,7 +169,7 @@ struct StdToken {
   inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, StateId state_id,
                   ForwardLinkT *links, Token *next, Token *backpointer):
       tot_cost(tot_cost), extra_cost(extra_cost), state_id(state_id),
-      links(links), next(next), in_current_queue(false) { }
+      links(links), next(next), in_queue(false) { }
 };
 
 template
@@ -214,7 +214,7 @@ struct BackpointerToken {
 
   // identify whether the token is in the current queue or not, to prevent
   // duplication in function ProcessOneFrame().
-  bool in_current_queue;
+  bool in_queue;
 
   inline void SetBackpointer (Token *backpointer) {
     this->backpointer = backpointer;
@@ -225,12 +225,71 @@ struct BackpointerToken {
                            Token *next, Token *backpointer):
       tot_cost(tot_cost), extra_cost(extra_cost), state_id(state_id),
       links(links), next(next), backpointer(backpointer),
-      in_current_queue(false) { }
+      in_queue(false) { }
 };
 
 }  // namespace decoder
 
+template <typename Token>
+class BucketQueue {
+ public:
+  /** Constructor. 'cost_scale' is a scale that we multiply the token costs by
+   *  before integerizing; a larger value means more buckets.
+   *  'best_cost_estimate' is an estimate of the best (lowest) cost that
+   *  we are likely to encounter (e.g. the best cost that we have seen so far).
+   *  It is used to initialize 'bucket_storage_begin_'.
+   */
+  BucketQueue(BaseFloat best_cost_estimate, BaseFloat cost_scale = 1.0);
+
+  // Add a Token to the queue; sets the field tok->in_queue to true (it is not
+  // an error if it was already true).
+  // If a Token was already in the queue but its cost improves, you should
+  // just Push it again. It will be added to (possibly) a different bucket, but
+  // the old entry will remain. The old entry in the queue will be considered as
+  // nonexistent when we try to pop it and notice that the recorded cost
+  // does not match the cost in the Token.
+  // (Actually, we use in_queue to decide whether an entry is nonexistent or
+  // not.) This strategy means that you may not
+  // delete Tokens as long as pointers to them might exist in this queue (hence,
+  // it is probably best to only ever have this queue as a local variable inside
+  // a function).
+  void Push(Token *tok);
+
+  // Removes and returns the next Token 'tok' in the queue, or NULL if there
+  // were no Tokens left. Sets tok->in_queue to false for the returned Token.
+  Token* Pop();
+
+ private:
+  // Configuration value that is multiplied by tokens' costs before integerizing
+  // them to determine the bucket index.
+  BaseFloat cost_scale_;
+
+  // buckets_ is a list of Tokens 'tok' for each bucket.
+  // If tok->in_queue is false, then the item is considered as not
+  // existing (this is to avoid having to explicitly remove Tokens when their
+  // costs change). The index into buckets_ is determined as follows:
+  //   bucket_index = std::floor(tok->cost * cost_scale_);
+  //   vec_index = bucket_index - bucket_storage_begin_;
+  // then access buckets_[vec_index].
+  std::vector<std::vector<Token*> > buckets_;
+
+  // The lowest-numbered bucket_index that is occupied (i.e. the first one which
+  // has any elements). Will be updated as we add or remove tokens.
+  // If this corresponds to a value past the end of buckets_, we interpret it
+  // as 'there are no buckets with entries'.
+  int32 first_occupied_bucket_index_;
+
+  // An offset that determines how we index into the buckets_ vector;
+  // may be interpreted as a 'bucket_index' that is better than any one that
+  // we are likely to see.
+  // In the constructor this will be initialized to something like
+  //   bucket_storage_begin_ = std::floor((best_cost_estimate - 15) * cost_scale)
+  // which will make it unlikely that we have to change this value in future if
+  // we get a much better Token (this is expensive because it involves
+  // reallocating 'buckets_').
+  int32 bucket_storage_begin_;
+};
+
 /** This is the "normal" lattice-generating decoder.
     See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
     for more information.
@@ -261,6 +320,8 @@ class LatticeFasterDecoderCombineTpl {
   //     fst::PoolAllocator > >;
   using IterType = typename StateIdToTokenMap::const_iterator;
 
+  using BucketQueue = typename kaldi::BucketQueue<Token>;
+
   // Instantiate this class once for each thing you have to decode.
   // This version of the constructor does not take ownership of
   // 'fst'.
@@ -509,9 +570,9 @@ class LatticeFasterDecoderCombineTpl {
   /// on a complete token list on one frame. But, in this version, it is used
   /// on a token list which only contains the emission part. So the max_active
   /// and min_active values might be narrowed.
-  BaseFloat GetCutoff(const TokenList &token_list,
+  BaseFloat GetCutoff(const TokenList &token_list, const Token* best_token,
                       BaseFloat *adaptive_beam,
-                      StateId *best_state_id, Token **best_token);
+                      BucketQueue *queue);
 
   std::vector<TokenList> active_toks_;  // Lists of tokens, indexed by
   // frame (members of TokenList are toks, must_prune_forward_links,
   // must_prune_tokens).
   std::queue<StateId> cur_queue_;  // temp variable used in ProcessForFrame
                                    // and ProcessNonemitting
   std::vector<BaseFloat> tmp_array_;  // used in GetCutoff.
+  // Stores the best token on the next frame; its tot_cost will be used to
+  // initialize the BucketQueue.
+  Token* best_token_in_next_frame_;
 
   // fst_ is a pointer to the FST we are decoding from.
const FST *fst_; From 226a69873a0b53b3ae3730a70a267644c6d7aab3 Mon Sep 17 00:00:00 2001 From: Hang Lyu Date: Wed, 20 Mar 2019 18:58:24 -0400 Subject: [PATCH 16/29] bucketqueue without GetCutoff --- ...tice-faster-decoder-combine-bucketqueue.cc | 34 +-- src/decoder/lattice-faster-decoder-combine.cc | 221 ++++++++---------- src/decoder/lattice-faster-decoder-combine.h | 13 +- 3 files changed, 120 insertions(+), 148 deletions(-) diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc index 11c5d01df0a..03ebf70cc10 100644 --- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc +++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc @@ -42,22 +42,28 @@ BucketQueue::BucketQueue(BaseFloat best_cost_estimate, template void BucketQueue::Push(Token *tok) { int32 bucket_index = std::floor(tok->tot_cost * cost_scale_); - int32 vec_index = bucket_index - bucket_storage_begin_; + size_t vec_index = static_cast(bucket_index - bucket_storage_begin_); - if (vec_index < 0) { + if (vec_index >= buckets_.size()) { KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve" - << " more elements in constructor. Push front."; - int32 increase_size = - vec_index; - std::vector > tmp(buckets_); - buckets_.resize(tmp.size() + increase_size); - std::copy(tmp.begin(), tmp.end(), buckets_.begin() + increase_size); - // Update start point - bucket_storage_begin_ = bucket_index; - vec_index = 0; - } else if (vec_index > buckets_.size() - 1) { - KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve" - << " more elements in constructor. Push back."; - buckets_.resize(vec_index + 1); + << " more elements in constructor."; + int32 offset = static_cast(vec_index); + // a margin here (e.g. 10); + int32 increase_size = offset >= 0 ? offset + 1 - buckets_.size() + 10 : + - offset + 10; + buckets_.resize(buckets_.size() + increase_size); + + // Push front + if (offset < 0) { + std::vector > tmp(buckets_); + buckets_.clear(); + for (int32 i = 10 - offset ; i < buckets_.size(); i++) { + buckets_[i].swap(tmp[i + offset - 10]); + } + // Update start point + bucket_storage_begin_ = bucket_index - 10; + vec_index = 10; + } } tok->in_queue = true; diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc index 11c5d01df0a..274eb7d4d45 100644 --- a/src/decoder/lattice-faster-decoder-combine.cc +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -42,22 +42,28 @@ BucketQueue::BucketQueue(BaseFloat best_cost_estimate, template void BucketQueue::Push(Token *tok) { int32 bucket_index = std::floor(tok->tot_cost * cost_scale_); - int32 vec_index = bucket_index - bucket_storage_begin_; + size_t vec_index = static_cast(bucket_index - bucket_storage_begin_); - if (vec_index < 0) { + if (vec_index >= buckets_.size()) { KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve" - << " more elements in constructor. Push front."; - int32 increase_size = - vec_index; - std::vector > tmp(buckets_); - buckets_.resize(tmp.size() + increase_size); - std::copy(tmp.begin(), tmp.end(), buckets_.begin() + increase_size); - // Update start point - bucket_storage_begin_ = bucket_index; - vec_index = 0; - } else if (vec_index > buckets_.size() - 1) { - KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve" - << " more elements in constructor. 
Push back."; - buckets_.resize(vec_index + 1); + << " more elements in constructor."; + int32 offset = static_cast(vec_index); + // a margin here (e.g. 10); + int32 increase_size = offset >= 0 ? offset + 1 - buckets_.size() + 10 : + - offset + 10; + buckets_.resize(buckets_.size() + increase_size); + + // Push front + if (offset < 0) { + std::vector > tmp(buckets_); + buckets_.clear(); + for (int32 i = 10 - offset ; i < buckets_.size(); i++) { + buckets_[i].swap(tmp[i + offset - 10]); + } + // Update start point + bucket_storage_begin_ = bucket_index - 10; + vec_index = 10; + } } tok->in_queue = true; @@ -143,6 +149,7 @@ void LatticeFasterDecoderCombineTpl::InitDecoding() { cur_toks_[start_state] = start_tok; // initialize current tokens map num_toks_++; best_token_in_next_frame_ = start_tok; + adaptive_beam_ = config_.beam; } // Returns true if any kind of traceback is available (not necessarily from @@ -753,67 +760,6 @@ void LatticeFasterDecoderCombineTpl::FinalizeDecoding() { << " to " << num_toks_; } -/// Gets the weight cutoff. -template -BaseFloat LatticeFasterDecoderCombineTpl::GetCutoff( - const TokenList &token_list, const Token* best_token, - BaseFloat *adaptive_beam, BucketQueue *queue) { - BaseFloat best_weight = best_token->tot_cost; - // positive == high cost == bad. - // best_weight is the minimum value. - if (config_.max_active == std::numeric_limits::max() && - config_.min_active == 0) { - for (Token* tok = token_list.toks; tok != NULL; tok = tok->next) { - queue->Push(tok); - } - if (adaptive_beam != NULL) *adaptive_beam = config_.beam; - return best_weight + config_.beam; - } else { - tmp_array_.clear(); - for (Token* tok = token_list.toks; tok != NULL; tok = tok->next) { - BaseFloat w = static_cast(tok->tot_cost); - tmp_array_.push_back(w); - queue->Push(tok); - } - - BaseFloat beam_cutoff = best_weight + config_.beam, - min_active_cutoff = std::numeric_limits::infinity(), - max_active_cutoff = std::numeric_limits::infinity(); - - KALDI_VLOG(6) << "Number of emitting tokens on frame " - << NumFramesDecoded() - 1 << " is " << tmp_array_.size(); - - if (tmp_array_.size() > static_cast(config_.max_active)) { - std::nth_element(tmp_array_.begin(), - tmp_array_.begin() + config_.max_active, - tmp_array_.end()); - max_active_cutoff = tmp_array_[config_.max_active]; - } - if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam. - if (adaptive_beam) - *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta; - return max_active_cutoff; - } - if (tmp_array_.size() > static_cast(config_.min_active)) { - if (config_.min_active == 0) min_active_cutoff = best_weight; - else { - std::nth_element(tmp_array_.begin(), - tmp_array_.begin() + config_.min_active, - tmp_array_.size() > static_cast(config_.max_active) ? - tmp_array_.begin() + config_.max_active : tmp_array_.end()); - min_active_cutoff = tmp_array_[config_.min_active]; - } - } - if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam. - if (adaptive_beam) - *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta; - return min_active_cutoff; - } else { - *adaptive_beam = config_.beam; - return beam_cutoff; - } - } -} template void LatticeFasterDecoderCombineTpl::ProcessForFrame( @@ -834,51 +780,27 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( } KALDI_ASSERT(best_token_in_next_frame_); - BucketQueue cur_queue(best_token_in_next_frame_->tot_cost); - BaseFloat adaptive_beam; - // "cur_cutoff" is used to constrain the epsilon emittion in current frame. 
- // It will not be updated. - BaseFloat cur_cutoff = GetCutoff(active_toks_[frame], - best_token_in_next_frame_, - &adaptive_beam, &cur_queue); - KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is " - << adaptive_beam; - - // pruning "online" before having seen all tokens + BucketQueue cur_queue(best_token_in_next_frame_->tot_cost, config_.cost_scale); + // Add tokens to queue + for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + cur_queue.Push(tok); + } + // Declare a local variable so the compiler can put it in a register, since + // C++ assumes other threads could be modifying class members. + BaseFloat adaptive_beam = adaptive_beam_; + // "cur_cutoff" will be kept to the best-seen-so-far token on this frame + // + adaptive_beam + BaseFloat cur_cutoff = std::numeric_limits::infinity(); // "next_cutoff" is used to limit a new token in next frame should be handle // or not. It will be updated along with the further processing. + // this will be kept updated to the best-seen-so-far token "on next frame" + // + adaptive_beam BaseFloat next_cutoff = std::numeric_limits::infinity(); // "cost_offset" contains the acoustic log-likelihoods on current frame in // order to keep everything in a nice dynamic range. Reduce roundoff errors. - BaseFloat cost_offset = 0.0; - - // First process the best token to get a hopefully - // reasonably tight bound on the next cutoff. The only - // products of the next block are "next_cutoff" and "cost_offset". - // Notice: As the difference between the combine version and the traditional - // version, this "best_tok" is choosen from emittion tokens. Normally, the - // best token of one frame comes from an epsilon non-emittion. So the best - // token is a looser boundary. We use it to estimate a bound on the next - // cutoff and we will update the "next_cutoff" once we have better tokens. - // The "next_cutoff" will be updated in further processing. - Token *best_tok = best_token_in_next_frame_; - StateId best_tok_state_id = best_tok->state_id; - if (best_tok) { - cost_offset = - best_tok->tot_cost; - for (fst::ArcIterator aiter(*fst_, best_tok_state_id); - !aiter.Done(); - aiter.Next()) { - const Arc &arc = aiter.Value(); - if (arc.ilabel != 0) { // propagate.. - // ac_cost + graph_cost - BaseFloat new_weight = arc.weight.Value() + cost_offset - - decodable->LogLikelihood(frame, arc.ilabel) + best_tok->tot_cost; - if (new_weight + adaptive_beam < next_cutoff) - next_cutoff = new_weight + adaptive_beam; - } - } - } + BaseFloat cost_offset = - best_token_in_next_frame_->tot_cost; + best_token_in_next_frame_ = NULL; // Store the offset on the acoustic likelihoods that we're applying. // Could just do cost_offsets_.push_back(cost_offset), but we @@ -888,11 +810,17 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( // Iterator the "cur_queue_" to process non-emittion and emittion arcs in fst. Token *tok = NULL; - while ((tok = cur_queue.Pop()) != NULL) { + int32 num_toks_processed = 0; + int32 max_active = config_.max_active; + for (; num_toks_processed < max_active && (tok = cur_queue.Pop()) != NULL; + num_toks_processed++) { BaseFloat cur_cost = tok->tot_cost; StateId state = tok->state_id; - if (cur_cost > cur_cutoff) // Don't bother processing successors. - continue; + if (cur_cost > cur_cutoff) { // Don't bother processing successors. + break; // This is a priority queue. 
The following tokens will be worse + } else if (cur_cost + adaptive_beam < cur_cutoff) { + cur_cutoff = cur_cost + adaptive_beam; // a tighter boundary + } // If "tok" has any existing forward links, delete them, // because we're about to regenerate them. This is a kind // of non-optimality (remember, this is the simple decoder), @@ -945,8 +873,32 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( } } // for all arcs } // end of while loop - //KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() - 1 - // << " is " << prev_toks_.size(); + + { // This block updates adaptive_beam_ + BaseFloat beam_used_this_frame = adaptive_beam; + Token *tok = cur_queue.Pop(); + if (tok != NULL) { + // The queue would only be nonempty if we hit the max-active constraint. + BaseFloat best_cost_this_frame = cur_cutoff - adaptive_beam; + beam_used_this_frame = tok->tot_cost - best_cost_this_frame; + } + if (num_toks_processed <= config_.min_active) { + // num-toks active is dangerously low, increase the beam even if it + // already exceeds the user-specified beam. + adaptive_beam_ = std::max( + config_.beam, beam_used_this_frame + 2.0 * config_.beam_delta); + } else { + // have adaptive_beam_ approach beam_ in intervals of config_.beam_delta + BaseFloat diff_from_beam = beam_used_this_frame - config_.beam; + if (std::abs(diff_from_beam) < config_.beam_delta) { + adaptive_beam_ = config_.beam; + } else { + // make it close to beam_ + adaptive_beam_ = beam_used_this_frame - + config_.beam_delta * (diff_from_beam > 0 ? 1 : -1); + } + } + } } @@ -969,20 +921,31 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( tmp_toks = &cur_toks_; } - BucketQueue cur_queue(best_token_in_next_frame_->tot_cost); - // "cur_cutoff" is used to constrain the epsilon emittion in current frame. - // It will not be updated. - BaseFloat adaptive_beam; - BaseFloat cur_cutoff = GetCutoff(active_toks_[frame], - best_token_in_next_frame_, - &adaptive_beam, &cur_queue); + BucketQueue cur_queue(best_token_in_next_frame_->tot_cost, config_.cost_scale); + for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + cur_queue.Push(tok); + } + + // Declare a local variable so the compiler can put it in a register, since + // C++ assumes other threads could be modifying class members. + BaseFloat adaptive_beam = adaptive_beam_; + // "cur_cutoff" will be kept to the best-seen-so-far token on this frame + // + adaptive_beam + BaseFloat cur_cutoff = std::numeric_limits::infinity(); Token *tok = NULL; - while ((tok = cur_queue.Pop()) != NULL) { + int32 num_toks_processed = 0; + int32 max_active = config_.max_active; + + for (; num_toks_processed < max_active && (tok = cur_queue.Pop()) != NULL; + num_toks_processed++) { BaseFloat cur_cost = tok->tot_cost; StateId state = tok->state_id; - if (cur_cost > cur_cutoff) // Don't bother processing successors. - continue; + if (cur_cost > cur_cutoff) { // Don't bother processing successors. + break; // This is a priority queue. The following tokens will be worse + } else if (cur_cost + adaptive_beam < cur_cutoff) { + cur_cutoff = cur_cost + adaptive_beam; // a tighter boundary + } // If "tok" has any existing forward links, delete them, // because we're about to regenerate them. 
This is a kind
     // of non-optimality (remember, this is the simple decoder),
diff --git a/src/decoder/lattice-faster-decoder-combine.h b/src/decoder/lattice-faster-decoder-combine.h
index 05f36b8aeab..8abce6260b7 100644
--- a/src/decoder/lattice-faster-decoder-combine.h
+++ b/src/decoder/lattice-faster-decoder-combine.h
@@ -46,6 +46,7 @@ struct LatticeFasterDecoderCombineConfig {
                           // command-line program.
   BaseFloat beam_delta;   // has nothing to do with beam_ratio
   BaseFloat hash_ratio;
+  BaseFloat cost_scale;
   BaseFloat prune_scale;  // Note: we don't make this configurable on the
                           // command line, it's not a very important parameter.
                           // It affects the algorithm that prunes the tokens as we go.
@@ -62,6 +63,7 @@ struct LatticeFasterDecoderCombineConfig {
       determinize_lattice(true),
       beam_delta(0.5),
       hash_ratio(2.0),
+      cost_scale(1.0),
       prune_scale(0.1) { }
   void Register(OptionsItf *opts) {
     det_opts.Register(opts);
@@ -81,6 +83,10 @@ struct LatticeFasterDecoderCombineConfig {
                    "max-active constraint is applied. Larger is more accurate.");
     opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to "
                    "control hash behavior");
+    opts->Register("cost-scale", &cost_scale, "A scale that we multiply the "
+                   "token costs by before integerizing; a larger value means "
+                   "more buckets and a more precise ordering.");
+
   }
   void Check() const {
     KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0
@@ -570,16 +576,11 @@ class LatticeFasterDecoderCombineTpl {
   /// on a complete token list on one frame. But, in this version, it is used
   /// on a token list which only contains the emission part. So the max_active
   /// and min_active values might be narrowed.
-  BaseFloat GetCutoff(const TokenList &token_list, const Token* best_token,
-                      BaseFloat *adaptive_beam,
-                      BucketQueue *queue);
-
   std::vector<TokenList> active_toks_;  // Lists of tokens, indexed by
   // frame (members of TokenList are toks, must_prune_forward_links,
   // must_prune_tokens).
   std::queue<StateId> cur_queue_;  // temp variable used in ProcessForFrame
                                    // and ProcessNonemitting
-  std::vector<BaseFloat> tmp_array_;  // used in GetCutoff.
   // Stores the best token in next frame. The tot_cost of it will be used to
   // initialize the BucketQueue.
   Token* best_token_in_next_frame_;
@@ -614,6 +615,8 @@ class LatticeFasterDecoderCombineTpl {
   BaseFloat final_relative_cost_;
   BaseFloat final_best_cost_;
 
+  BaseFloat adaptive_beam_;  // will be set to beam_ when we start
+
   // This function takes a singly linked list of tokens for a single frame, and
   // outputs a list of them in topological order (it will crash if no such order
   // can be found, which will typically be due to decoding graphs with epsilon
From 82206564461419d6b77f3c29d954cc6682740fae Mon Sep 17 00:00:00 2001
From: Hang Lyu
Date: Thu, 21 Mar 2019 00:14:26 -0400
Subject: [PATCH 17/29] small fix and class member queue

---
 ...tice-faster-decoder-combine-bucketqueue.cc | 290 ++++++++----------
 ...ttice-faster-decoder-combine-bucketqueue.h |  29 +-
 src/decoder/lattice-faster-decoder-combine.cc | 117 ++++---
 src/decoder/lattice-faster-decoder-combine.h  |  16 +-
 4 files changed, 231 insertions(+), 221 deletions(-)

diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc
index 03ebf70cc10..06f1f80b2a8 100644
--- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc
+++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc
@@ -35,8 +35,8 @@ BucketQueue::BucketQueue(BaseFloat best_cost_estimate,
   //  5) * cost_scale.
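// Worked example of the default sizing (illustrative numbers): with
// best_cost_estimate = 0.0 and cost_scale = 1.0 we get
//   bucket_storage_begin_ = floor((0.0 - 15) * 1.0) = -15,
// so the 100 pre-allocated buckets cover bucket indices -15 .. 84, i.e.
// token costs in roughly [-15.0, 85.0). A typical beam of ~16 fits with
// plenty of slack, and only unusually large costs force a reallocation.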
int32 bucket_size = 100; buckets_.resize(bucket_size); - bucket_storage_begin_ = std::floor((best_cost_estimate - 15) * cost_scale); - first_occupied_bucket_index_ = bucket_storage_begin_ + bucket_size; + bucket_storage_begin_ = std::floor((best_cost_estimate - 15) * cost_scale_); + first_occupied_vec_index_ = bucket_size; } template @@ -45,60 +45,69 @@ void BucketQueue::Push(Token *tok) { size_t vec_index = static_cast(bucket_index - bucket_storage_begin_); if (vec_index >= buckets_.size()) { - KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve" - << " more elements in constructor."; - int32 offset = static_cast(vec_index); - // a margin here (e.g. 10); - int32 increase_size = offset >= 0 ? offset + 1 - buckets_.size() + 10 : - - offset + 10; - buckets_.resize(buckets_.size() + increase_size); - - // Push front - if (offset < 0) { - std::vector > tmp(buckets_); - buckets_.clear(); - for (int32 i = 10 - offset ; i < buckets_.size(); i++) { - buckets_[i].swap(tmp[i + offset - 10]); - } - // Update start point - bucket_storage_begin_ = bucket_index - 10; - vec_index = 10; + int32 margin = 10; // a margin which is used to reduce re-allocate + // space frequently + // A cast from unsigned to signed type does not generate a machine-code + // instruction + if (static_cast(vec_index) > 0) { + KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve" + << " more elements in constructor. Push back."; + buckets_.resize(static_cast(vec_index) + margin); + } else { // less than 0 + KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve" + << " more elements in constructor. Push front."; + int32 increase_size = - static_cast(vec_index) + margin; + buckets_.resize(buckets_.size() + increase_size); + // translation + for (size_t i = buckets_.size() - 1; i >= increase_size; i--) { + buckets_[i].swap(buckets_[i - increase_size]); + } + bucket_storage_begin_ = bucket_storage_begin_ - increase_size; + vec_index = increase_size; } } - tok->in_queue = true; buckets_[vec_index].push_back(tok); - if (vec_index < (first_occupied_bucket_index_ - bucket_storage_begin_)) - first_occupied_bucket_index_ = vec_index + bucket_storage_begin_; + if (vec_index < first_occupied_vec_index_) + first_occupied_vec_index_ = vec_index; } template Token* BucketQueue::Pop() { - int32 vec_index = first_occupied_bucket_index_ - bucket_storage_begin_; - Token* best_tok = NULL; - while(vec_index < buckets_.size()) { + int32 vec_index = first_occupied_vec_index_; + while (vec_index < buckets_.size()) { + Token* tok = buckets_[vec_index].back(); // Remove the best token - best_tok = buckets_[vec_index].back(); buckets_[vec_index].pop_back(); - if (buckets_[vec_index].empty()) { // This bucket is empty. Update - // first_occupied_bucket_index_ + if (buckets_[vec_index].empty()) { // This bucket is empty. 
Update vec_index int32 next_vec_index = vec_index + 1; - for(; next_vec_index < buckets_.size(); next_vec_index++) { - if(!buckets_[next_vec_index].empty()) break; + for (; next_vec_index < buckets_.size(); next_vec_index++) { + if (!buckets_[next_vec_index].empty()) break; } - first_occupied_bucket_index_ = bucket_storage_begin_ + next_vec_index; vec_index = next_vec_index; } - if (best_tok->in_queue) { // This is a effective token - best_tok->in_queue = false; - break; - } else { - best_tok = NULL; + if (tok->in_queue) { // This is a effective token + tok->in_queue = false; + first_occupied_vec_index_ = vec_index; + return tok; } } - return best_tok; + return NULL; +} + +template +void BucketQueue::Clear() { + for (size_t i = 0; i < buckets_.size(); i++) { + buckets_[i].clear(); + } + first_occupied_vec_index_ = buckets_.size(); +} + +template +void BucketQueue::SetBegin(BaseFloat best_cost_estimate) { + bucket_storage_begin_ = std::floor((best_cost_estimate - 15) * cost_scale_); } // instantiate this class once for each thing you have to decode. @@ -106,15 +115,19 @@ template LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( const FST &fst, const LatticeFasterDecoderCombineConfig &config): - fst_(&fst), delete_fst_(false), config_(config), num_toks_(0) { + fst_(&fst), delete_fst_(false), config_(config), num_toks_(0), + cur_queue_(0, config_.cost_scale) { config.Check(); + prev_toks_.reserve(1000); + cur_toks_.reserve(1000); } template LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( const LatticeFasterDecoderCombineConfig &config, FST *fst): - fst_(fst), delete_fst_(true), config_(config), num_toks_(0) { + fst_(fst), delete_fst_(true), config_(config), num_toks_(0), + cur_queue_(0, config_.cost_scale) { config.Check(); prev_toks_.reserve(1000); cur_toks_.reserve(1000); @@ -149,6 +162,8 @@ void LatticeFasterDecoderCombineTpl::InitDecoding() { cur_toks_[start_state] = start_tok; // initialize current tokens map num_toks_++; best_token_in_next_frame_ = start_tok; + adaptive_beam_ = config_.beam; + } // Returns true if any kind of traceback is available (not necessarily from @@ -759,67 +774,6 @@ void LatticeFasterDecoderCombineTpl::FinalizeDecoding() { << " to " << num_toks_; } -/// Gets the weight cutoff. -template -BaseFloat LatticeFasterDecoderCombineTpl::GetCutoff( - const TokenList &token_list, const Token* best_token, - BaseFloat *adaptive_beam, BucketQueue *queue) { - BaseFloat best_weight = best_token->tot_cost; - // positive == high cost == bad. - // best_weight is the minimum value. 
- if (config_.max_active == std::numeric_limits::max() && - config_.min_active == 0) { - for (Token* tok = token_list.toks; tok != NULL; tok = tok->next) { - queue->Push(tok); - } - if (adaptive_beam != NULL) *adaptive_beam = config_.beam; - return best_weight + config_.beam; - } else { - tmp_array_.clear(); - for (Token* tok = token_list.toks; tok != NULL; tok = tok->next) { - BaseFloat w = static_cast(tok->tot_cost); - tmp_array_.push_back(w); - queue->Push(tok); - } - - BaseFloat beam_cutoff = best_weight + config_.beam, - min_active_cutoff = std::numeric_limits::infinity(), - max_active_cutoff = std::numeric_limits::infinity(); - - KALDI_VLOG(6) << "Number of emitting tokens on frame " - << NumFramesDecoded() - 1 << " is " << tmp_array_.size(); - - if (tmp_array_.size() > static_cast(config_.max_active)) { - std::nth_element(tmp_array_.begin(), - tmp_array_.begin() + config_.max_active, - tmp_array_.end()); - max_active_cutoff = tmp_array_[config_.max_active]; - } - if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam. - if (adaptive_beam) - *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta; - return max_active_cutoff; - } - if (tmp_array_.size() > static_cast(config_.min_active)) { - if (config_.min_active == 0) min_active_cutoff = best_weight; - else { - std::nth_element(tmp_array_.begin(), - tmp_array_.begin() + config_.min_active, - tmp_array_.size() > static_cast(config_.max_active) ? - tmp_array_.begin() + config_.max_active : tmp_array_.end()); - min_active_cutoff = tmp_array_[config_.min_active]; - } - } - if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam. - if (adaptive_beam) - *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta; - return min_active_cutoff; - } else { - *adaptive_beam = config_.beam; - return beam_cutoff; - } - } -} template void LatticeFasterDecoderCombineTpl::ProcessForFrame( @@ -839,52 +793,29 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( } } - KALDI_ASSERT(best_token_in_next_frame_); - BucketQueue cur_queue(best_token_in_next_frame_->tot_cost); - BaseFloat adaptive_beam; - // "cur_cutoff" is used to constrain the epsilon emittion in current frame. - // It will not be updated. - BaseFloat cur_cutoff = GetCutoff(active_toks_[frame], - best_token_in_next_frame_, - &adaptive_beam, &cur_queue); - KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is " - << adaptive_beam; - - // pruning "online" before having seen all tokens + KALDI_ASSERT(best_token_in_next_frame_); + cur_queue_.Clear(); + cur_queue_.SetBegin(best_token_in_next_frame_->tot_cost); + // Add tokens to queue + for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + cur_queue_.Push(tok); + } + // Declare a local variable so the compiler can put it in a register, since + // C++ assumes other threads could be modifying class members. + BaseFloat adaptive_beam = adaptive_beam_; + // "cur_cutoff" will be kept to the best-seen-so-far token on this frame + // + adaptive_beam + BaseFloat cur_cutoff = std::numeric_limits::infinity(); // "next_cutoff" is used to limit a new token in next frame should be handle // or not. It will be updated along with the further processing. + // this will be kept updated to the best-seen-so-far token "on next frame" + // + adaptive_beam BaseFloat next_cutoff = std::numeric_limits::infinity(); // "cost_offset" contains the acoustic log-likelihoods on current frame in // order to keep everything in a nice dynamic range. Reduce roundoff errors. 
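// Why the offset matters (illustrative numbers): raw acoustic log-likelihoods
// are commonly around -30 per frame, so without an offset the tot_cost of a
// token after 100 frames would sit near 3000 while the differences that
// matter are all within one beam (~16). Adding a per-frame offset of roughly
// +30 keeps tot_cost itself O(beam), which preserves floating-point precision
// and keeps the BucketQueue's bucket indices small.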
- BaseFloat cost_offset = 0.0; - - // First process the best token to get a hopefully - // reasonably tight bound on the next cutoff. The only - // products of the next block are "next_cutoff" and "cost_offset". - // Notice: As the difference between the combine version and the traditional - // version, this "best_tok" is choosen from emittion tokens. Normally, the - // best token of one frame comes from an epsilon non-emittion. So the best - // token is a looser boundary. We use it to estimate a bound on the next - // cutoff and we will update the "next_cutoff" once we have better tokens. - // The "next_cutoff" will be updated in further processing. - Token *best_tok = best_token_in_next_frame_; - StateId best_tok_state_id = best_tok->state_id; - if (best_tok) { - cost_offset = - best_tok->tot_cost; - for (fst::ArcIterator aiter(*fst_, best_tok_state_id); - !aiter.Done(); - aiter.Next()) { - const Arc &arc = aiter.Value(); - if (arc.ilabel != 0) { // propagate.. - // ac_cost + graph_cost - BaseFloat new_weight = arc.weight.Value() + cost_offset - - decodable->LogLikelihood(frame, arc.ilabel) + best_tok->tot_cost; - if (new_weight + adaptive_beam < next_cutoff) - next_cutoff = new_weight + adaptive_beam; - } - } - } + BaseFloat cost_offset = - best_token_in_next_frame_->tot_cost; + best_token_in_next_frame_ = NULL; // Store the offset on the acoustic likelihoods that we're applying. // Could just do cost_offsets_.push_back(cost_offset), but we @@ -894,11 +825,17 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( // Iterator the "cur_queue_" to process non-emittion and emittion arcs in fst. Token *tok = NULL; - while ((tok = cur_queue.Pop()) != NULL) { + int32 num_toks_processed = 0; + int32 max_active = config_.max_active; + for (; num_toks_processed < max_active && (tok = cur_queue_.Pop()) != NULL; + num_toks_processed++) { BaseFloat cur_cost = tok->tot_cost; StateId state = tok->state_id; - if (cur_cost > cur_cutoff) // Don't bother processing successors. - continue; + if (cur_cost > cur_cutoff) { // Don't bother processing successors. + break; // This is a priority queue. The following tokens will be worse + } else if (cur_cost + adaptive_beam < cur_cutoff) { + cur_cutoff = cur_cost + adaptive_beam; // a tighter boundary + } // If "tok" has any existing forward links, delete them, // because we're about to regenerate them. This is a kind // of non-optimality (remember, this is the simple decoder), @@ -924,7 +861,7 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( // "changed" tells us whether the new token has a different // cost from before, or is new. if (changed && !new_tok->in_queue) { - cur_queue.Push(new_tok); + cur_queue_.Push(new_tok); } } } else { // propagate emitting @@ -951,8 +888,32 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( } } // for all arcs } // end of while loop - //KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() - 1 - // << " is " << prev_toks_.size(); + + { // This block updates adaptive_beam_ + BaseFloat beam_used_this_frame = adaptive_beam; + Token *tok = cur_queue_.Pop(); + if (tok != NULL) { + // The queue would only be nonempty if we hit the max-active constraint. + BaseFloat best_cost_this_frame = cur_cutoff - adaptive_beam; + beam_used_this_frame = tok->tot_cost - best_cost_this_frame; + } + if (num_toks_processed <= config_.min_active) { + // num-toks active is dangerously low, increase the beam even if it + // already exceeds the user-specified beam. 
+ adaptive_beam_ = std::max( + config_.beam, beam_used_this_frame + 2.0 * config_.beam_delta); + } else { + // have adaptive_beam_ approach beam_ in intervals of config_.beam_delta + BaseFloat diff_from_beam = beam_used_this_frame - config_.beam; + if (std::abs(diff_from_beam) < config_.beam_delta) { + adaptive_beam_ = config_.beam; + } else { + // make it close to beam_ + adaptive_beam_ = beam_used_this_frame - + config_.beam_delta * (diff_from_beam > 0 ? 1 : -1); + } + } + } } @@ -975,20 +936,33 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( tmp_toks = &cur_toks_; } - BucketQueue cur_queue(best_token_in_next_frame_->tot_cost); - // "cur_cutoff" is used to constrain the epsilon emittion in current frame. - // It will not be updated. - BaseFloat adaptive_beam; - BaseFloat cur_cutoff = GetCutoff(active_toks_[frame], - best_token_in_next_frame_, - &adaptive_beam, &cur_queue); + KALDI_ASSERT(best_token_in_next_frame_); + cur_queue_.Clear(); + cur_queue_.SetBegin(best_token_in_next_frame_->tot_cost); + for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + cur_queue_.Push(tok); + } + + // Declare a local variable so the compiler can put it in a register, since + // C++ assumes other threads could be modifying class members. + BaseFloat adaptive_beam = adaptive_beam_; + // "cur_cutoff" will be kept to the best-seen-so-far token on this frame + // + adaptive_beam + BaseFloat cur_cutoff = std::numeric_limits::infinity(); Token *tok = NULL; - while ((tok = cur_queue.Pop()) != NULL) { + int32 num_toks_processed = 0; + int32 max_active = config_.max_active; + + for (; num_toks_processed < max_active && (tok = cur_queue_.Pop()) != NULL; + num_toks_processed++) { BaseFloat cur_cost = tok->tot_cost; StateId state = tok->state_id; - if (cur_cost > cur_cutoff) // Don't bother processing successors. - continue; + if (cur_cost > cur_cutoff) { // Don't bother processing successors. + break; // This is a priority queue. The following tokens will be worse + } else if (cur_cost + adaptive_beam < cur_cutoff) { + cur_cutoff = cur_cost + adaptive_beam; // a tighter boundary + } // If "tok" has any existing forward links, delete them, // because we're about to regenerate them. This is a kind // of non-optimality (remember, this is the simple decoder), @@ -1014,7 +988,7 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( // "changed" tells us whether the new token has a different // cost from before, or is new. if (changed && !new_tok->in_queue) { - cur_queue.Push(new_tok); + cur_queue_.Push(new_tok); } } } diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.h b/src/decoder/lattice-faster-decoder-combine-bucketqueue.h index 05f36b8aeab..016c3ddd4b0 100644 --- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.h +++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.h @@ -46,6 +46,7 @@ struct LatticeFasterDecoderCombineConfig { // command-line program. BaseFloat beam_delta; // has nothing to do with beam_ratio BaseFloat hash_ratio; + BaseFloat cost_scale; BaseFloat prune_scale; // Note: we don't make this configurable on the command line, // it's not a very important parameter. It affects the // algorithm that prunes the tokens as we go. 
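// Worked example of the effect of cost-scale (illustrative numbers): with
// cost_scale = 1.0, tokens with tot_cost 3.2 and 3.9 land in the same bucket,
// floor(3.2) == floor(3.9) == 3, so their relative pop order is arbitrary.
// With cost_scale = 2.0 they separate, floor(6.4) == 6 vs floor(7.8) == 7:
// a larger scale buys a more nearly best-first pop order at the price of
// more buckets.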
@@ -62,6 +63,7 @@ struct LatticeFasterDecoderCombineConfig { determinize_lattice(true), beam_delta(0.5), hash_ratio(2.0), + cost_scale(1.0), prune_scale(0.1) { } void Register(OptionsItf *opts) { det_opts.Register(opts); @@ -81,6 +83,10 @@ struct LatticeFasterDecoderCombineConfig { "max-active constraint is applied. Larger is more accurate."); opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to " "control hash behavior"); + opts->Register("cost-scale", &cost_scale, "A scale that we multiply the " + "token costs by before intergerizing; a larger value means " + "more buckets and precise."); + } void Check() const { KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 @@ -259,6 +265,13 @@ class BucketQueue { // were no Tokens left. Sets tok->in_queue to false for the returned Token. Token* Pop(); + // Clear all the individual buckets. Set 'first_occupied_vec_index_' to the + // value past the end of buckets_. + void Clear(); + + // Set 'bucket_storage_begin_'. + void SetBegin(BaseFloat best_cost_estimate); + private: // Configuration value that is multiplied by tokens' costs before integerizing // them to determine the bucket index @@ -273,11 +286,11 @@ class BucketQueue { // then access buckets_[vec_index]. std::vector > buckets_; - // The lowest-numbered bucket_index that is occupied (i.e. the first one which + // The lowest-numbered vec_index that is occupied (i.e. the first one which // has any elements). Will be updated as we add or remove tokens. // If this corresponds to a value past the end of buckets_, we interpret it // as 'there are no buckets with entries'. - int32 first_occupied_bucket_index_; + int32 first_occupied_vec_index_; // An offset that determines how we index into the buckets_ vector; // may be interpreted as a 'bucket_index' that is better than any one that @@ -570,16 +583,10 @@ class LatticeFasterDecoderCombineTpl { /// on a complete token list on one frame. But, in this version, it is used /// on a token list which only contains the emittion part. So the max_active /// and min_active values might be narrowed. - BaseFloat GetCutoff(const TokenList &token_list, const Token* best_token, - BaseFloat *adaptive_beam, - BucketQueue *queue); - std::vector active_toks_; // Lists of tokens, indexed by // frame (members of TokenList are toks, must_prune_forward_links, // must_prune_tokens). - std::queue cur_queue_; // temp variable used in ProcessForFrame - // and ProcessNonemitting - std::vector tmp_array_; // used in GetCutoff. + // Stores the best token in next frame. The tot_cost of it will be used to // initialize the BucketQueue. Token* best_token_in_next_frame_; @@ -614,6 +621,10 @@ class LatticeFasterDecoderCombineTpl { BaseFloat final_relative_cost_; BaseFloat final_best_cost_; + BaseFloat adaptive_beam_; // will be set to beam_ when we start + BucketQueue cur_queue_; // temp variable used in + // ProcessForFrame/ProcessNonemitting + // This function takes a singly linked list of tokens for a single frame, and // outputs a list of them in topological order (it will crash if no such order // can be found, which will typically be due to decoding graphs with epsilon diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc index 274eb7d4d45..06f1f80b2a8 100644 --- a/src/decoder/lattice-faster-decoder-combine.cc +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -35,8 +35,8 @@ BucketQueue::BucketQueue(BaseFloat best_cost_estimate, // 5) * cost_scale. 
int32 bucket_size = 100; buckets_.resize(bucket_size); - bucket_storage_begin_ = std::floor((best_cost_estimate - 15) * cost_scale); - first_occupied_bucket_index_ = bucket_storage_begin_ + bucket_size; + bucket_storage_begin_ = std::floor((best_cost_estimate - 15) * cost_scale_); + first_occupied_vec_index_ = bucket_size; } template @@ -45,60 +45,69 @@ void BucketQueue::Push(Token *tok) { size_t vec_index = static_cast(bucket_index - bucket_storage_begin_); if (vec_index >= buckets_.size()) { - KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve" - << " more elements in constructor."; - int32 offset = static_cast(vec_index); - // a margin here (e.g. 10); - int32 increase_size = offset >= 0 ? offset + 1 - buckets_.size() + 10 : - - offset + 10; - buckets_.resize(buckets_.size() + increase_size); - - // Push front - if (offset < 0) { - std::vector > tmp(buckets_); - buckets_.clear(); - for (int32 i = 10 - offset ; i < buckets_.size(); i++) { - buckets_[i].swap(tmp[i + offset - 10]); - } - // Update start point - bucket_storage_begin_ = bucket_index - 10; - vec_index = 10; + int32 margin = 10; // a margin which is used to reduce re-allocate + // space frequently + // A cast from unsigned to signed type does not generate a machine-code + // instruction + if (static_cast(vec_index) > 0) { + KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve" + << " more elements in constructor. Push back."; + buckets_.resize(static_cast(vec_index) + margin); + } else { // less than 0 + KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve" + << " more elements in constructor. Push front."; + int32 increase_size = - static_cast(vec_index) + margin; + buckets_.resize(buckets_.size() + increase_size); + // translation + for (size_t i = buckets_.size() - 1; i >= increase_size; i--) { + buckets_[i].swap(buckets_[i - increase_size]); + } + bucket_storage_begin_ = bucket_storage_begin_ - increase_size; + vec_index = increase_size; } } - tok->in_queue = true; buckets_[vec_index].push_back(tok); - if (vec_index < (first_occupied_bucket_index_ - bucket_storage_begin_)) - first_occupied_bucket_index_ = vec_index + bucket_storage_begin_; + if (vec_index < first_occupied_vec_index_) + first_occupied_vec_index_ = vec_index; } template Token* BucketQueue::Pop() { - int32 vec_index = first_occupied_bucket_index_ - bucket_storage_begin_; - Token* best_tok = NULL; - while(vec_index < buckets_.size()) { + int32 vec_index = first_occupied_vec_index_; + while (vec_index < buckets_.size()) { + Token* tok = buckets_[vec_index].back(); // Remove the best token - best_tok = buckets_[vec_index].back(); buckets_[vec_index].pop_back(); - if (buckets_[vec_index].empty()) { // This bucket is empty. Update - // first_occupied_bucket_index_ + if (buckets_[vec_index].empty()) { // This bucket is empty. 
Update vec_index int32 next_vec_index = vec_index + 1; - for(; next_vec_index < buckets_.size(); next_vec_index++) { - if(!buckets_[next_vec_index].empty()) break; + for (; next_vec_index < buckets_.size(); next_vec_index++) { + if (!buckets_[next_vec_index].empty()) break; } - first_occupied_bucket_index_ = bucket_storage_begin_ + next_vec_index; vec_index = next_vec_index; } - if (best_tok->in_queue) { // This is a effective token - best_tok->in_queue = false; - break; - } else { - best_tok = NULL; + if (tok->in_queue) { // This is a effective token + tok->in_queue = false; + first_occupied_vec_index_ = vec_index; + return tok; } } - return best_tok; + return NULL; +} + +template +void BucketQueue::Clear() { + for (size_t i = 0; i < buckets_.size(); i++) { + buckets_[i].clear(); + } + first_occupied_vec_index_ = buckets_.size(); +} + +template +void BucketQueue::SetBegin(BaseFloat best_cost_estimate) { + bucket_storage_begin_ = std::floor((best_cost_estimate - 15) * cost_scale_); } // instantiate this class once for each thing you have to decode. @@ -106,15 +115,19 @@ template LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( const FST &fst, const LatticeFasterDecoderCombineConfig &config): - fst_(&fst), delete_fst_(false), config_(config), num_toks_(0) { + fst_(&fst), delete_fst_(false), config_(config), num_toks_(0), + cur_queue_(0, config_.cost_scale) { config.Check(); + prev_toks_.reserve(1000); + cur_toks_.reserve(1000); } template LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( const LatticeFasterDecoderCombineConfig &config, FST *fst): - fst_(fst), delete_fst_(true), config_(config), num_toks_(0) { + fst_(fst), delete_fst_(true), config_(config), num_toks_(0), + cur_queue_(0, config_.cost_scale) { config.Check(); prev_toks_.reserve(1000); cur_toks_.reserve(1000); @@ -150,6 +163,7 @@ void LatticeFasterDecoderCombineTpl::InitDecoding() { num_toks_++; best_token_in_next_frame_ = start_tok; adaptive_beam_ = config_.beam; + } // Returns true if any kind of traceback is available (not necessarily from @@ -779,11 +793,12 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( } } - KALDI_ASSERT(best_token_in_next_frame_); - BucketQueue cur_queue(best_token_in_next_frame_->tot_cost, config_.cost_scale); + KALDI_ASSERT(best_token_in_next_frame_); + cur_queue_.Clear(); + cur_queue_.SetBegin(best_token_in_next_frame_->tot_cost); // Add tokens to queue for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { - cur_queue.Push(tok); + cur_queue_.Push(tok); } // Declare a local variable so the compiler can put it in a register, since @@ -812,7 +827,7 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( Token *tok = NULL; int32 num_toks_processed = 0; int32 max_active = config_.max_active; - for (; num_toks_processed < max_active && (tok = cur_queue.Pop()) != NULL; + for (; num_toks_processed < max_active && (tok = cur_queue_.Pop()) != NULL; num_toks_processed++) { BaseFloat cur_cost = tok->tot_cost; StateId state = tok->state_id; @@ -846,7 +861,7 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( // "changed" tells us whether the new token has a different // cost from before, or is new. 
if (changed && !new_tok->in_queue) { - cur_queue.Push(new_tok); + cur_queue_.Push(new_tok); } } } else { // propagate emitting @@ -876,7 +891,7 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( { // This block updates adaptive_beam_ BaseFloat beam_used_this_frame = adaptive_beam; - Token *tok = cur_queue.Pop(); + Token *tok = cur_queue_.Pop(); if (tok != NULL) { // The queue would only be nonempty if we hit the max-active constraint. BaseFloat best_cost_this_frame = cur_cutoff - adaptive_beam; @@ -921,9 +936,11 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( tmp_toks = &cur_toks_; } - BucketQueue cur_queue(best_token_in_next_frame_->tot_cost, config_.cost_scale); + KALDI_ASSERT(best_token_in_next_frame_); + cur_queue_.Clear(); + cur_queue_.SetBegin(best_token_in_next_frame_->tot_cost); for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { - cur_queue.Push(tok); + cur_queue_.Push(tok); } // Declare a local variable so the compiler can put it in a register, since @@ -937,7 +954,7 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( int32 num_toks_processed = 0; int32 max_active = config_.max_active; - for (; num_toks_processed < max_active && (tok = cur_queue.Pop()) != NULL; + for (; num_toks_processed < max_active && (tok = cur_queue_.Pop()) != NULL; num_toks_processed++) { BaseFloat cur_cost = tok->tot_cost; StateId state = tok->state_id; @@ -971,7 +988,7 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( // "changed" tells us whether the new token has a different // cost from before, or is new. if (changed && !new_tok->in_queue) { - cur_queue.Push(new_tok); + cur_queue_.Push(new_tok); } } } diff --git a/src/decoder/lattice-faster-decoder-combine.h b/src/decoder/lattice-faster-decoder-combine.h index 8abce6260b7..016c3ddd4b0 100644 --- a/src/decoder/lattice-faster-decoder-combine.h +++ b/src/decoder/lattice-faster-decoder-combine.h @@ -265,6 +265,13 @@ class BucketQueue { // were no Tokens left. Sets tok->in_queue to false for the returned Token. Token* Pop(); + // Clear all the individual buckets. Set 'first_occupied_vec_index_' to the + // value past the end of buckets_. + void Clear(); + + // Set 'bucket_storage_begin_'. + void SetBegin(BaseFloat best_cost_estimate); + private: // Configuration value that is multiplied by tokens' costs before integerizing // them to determine the bucket index @@ -279,11 +286,11 @@ class BucketQueue { // then access buckets_[vec_index]. std::vector > buckets_; - // The lowest-numbered bucket_index that is occupied (i.e. the first one which + // The lowest-numbered vec_index that is occupied (i.e. the first one which // has any elements). Will be updated as we add or remove tokens. // If this corresponds to a value past the end of buckets_, we interpret it // as 'there are no buckets with entries'. - int32 first_occupied_bucket_index_; + int32 first_occupied_vec_index_; // An offset that determines how we index into the buckets_ vector; // may be interpreted as a 'bucket_index' that is better than any one that @@ -579,8 +586,7 @@ class LatticeFasterDecoderCombineTpl { std::vector active_toks_; // Lists of tokens, indexed by // frame (members of TokenList are toks, must_prune_forward_links, // must_prune_tokens). - std::queue cur_queue_; // temp variable used in ProcessForFrame - // and ProcessNonemitting + // Stores the best token in next frame. The tot_cost of it will be used to // initialize the BucketQueue. 
Token* best_token_in_next_frame_; @@ -616,6 +622,8 @@ class LatticeFasterDecoderCombineTpl { BaseFloat final_best_cost_; BaseFloat adaptive_beam_; // will be set to beam_ when we start + BucketQueue cur_queue_; // temp variable used in + // ProcessForFrame/ProcessNonemitting // This function takes a singly linked list of tokens for a single frame, and // outputs a list of them in topological order (it will crash if no such order From d7a3a9dd16e1d8d4d93b6f029df2c79f881ac1c4 Mon Sep 17 00:00:00 2001 From: Hang Lyu Date: Mon, 25 Mar 2019 20:51:10 -0400 Subject: [PATCH 18/29] some fix and remove SetBegin --- ...tice-faster-decoder-combine-bucketqueue.cc | 45 ++++++++----------- ...ttice-faster-decoder-combine-bucketqueue.h | 7 --- src/decoder/lattice-faster-decoder-combine.cc | 45 ++++++++----------- src/decoder/lattice-faster-decoder-combine.h | 7 --- 4 files changed, 36 insertions(+), 68 deletions(-) diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc index 06f1f80b2a8..4d91f48782f 100644 --- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc +++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc @@ -86,11 +86,11 @@ Token* BucketQueue::Pop() { if (!buckets_[next_vec_index].empty()) break; } vec_index = next_vec_index; + first_occupied_vec_index_ = vec_index; } if (tok->in_queue) { // This is a effective token tok->in_queue = false; - first_occupied_vec_index_ = vec_index; return tok; } } @@ -103,11 +103,7 @@ void BucketQueue::Clear() { buckets_[i].clear(); } first_occupied_vec_index_ = buckets_.size(); -} - -template -void BucketQueue::SetBegin(BaseFloat best_cost_estimate) { - bucket_storage_begin_ = std::floor((best_cost_estimate - 15) * cost_scale_); + bucket_storage_begin_ = -15 * cost_scale_; } // instantiate this class once for each thing you have to decode. @@ -161,8 +157,9 @@ void LatticeFasterDecoderCombineTpl::InitDecoding() { active_toks_[0].toks = start_tok; cur_toks_[start_state] = start_tok; // initialize current tokens map num_toks_++; - best_token_in_next_frame_ = start_tok; adaptive_beam_ = config_.beam; + cost_offsets_.resize(1); + cost_offsets_[0] = 0.0; } @@ -793,9 +790,7 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( } } - KALDI_ASSERT(best_token_in_next_frame_); cur_queue_.Clear(); - cur_queue_.SetBegin(best_token_in_next_frame_->tot_cost); // Add tokens to queue for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { cur_queue_.Push(tok); @@ -814,14 +809,7 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( BaseFloat next_cutoff = std::numeric_limits::infinity(); // "cost_offset" contains the acoustic log-likelihoods on current frame in // order to keep everything in a nice dynamic range. Reduce roundoff errors. - BaseFloat cost_offset = - best_token_in_next_frame_->tot_cost; - - best_token_in_next_frame_ = NULL; - // Store the offset on the acoustic likelihoods that we're applying. - // Could just do cost_offsets_.push_back(cost_offset), but we - // do it this way as it's more robust to future code changes. - cost_offsets_.resize(frame + 1, 0.0); - cost_offsets_[frame] = cost_offset; + BaseFloat cost_offset = cost_offsets_[frame]; // Iterator the "cur_queue_" to process non-emittion and emittion arcs in fst. 
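// Numeric sketch of where cost_offsets_[frame] now comes from (assumed
// values): while frame f is processed, next_cutoff settles at
// best_cost_on_frame_(f+1) + adaptive_beam, and the code further below stores
//   cost_offsets_[f + 1] = adaptive_beam - next_cutoff
//                        = -best_cost_on_frame_(f+1).
// E.g. a best next-frame cost of 27.0 with adaptive_beam = 16.0 gives
// next_cutoff = 43.0 and an offset of -27.0, which re-centers the next
// frame's token costs near zero.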
Token *tok = NULL; @@ -831,7 +819,9 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( num_toks_processed++) { BaseFloat cur_cost = tok->tot_cost; StateId state = tok->state_id; - if (cur_cost > cur_cutoff) { // Don't bother processing successors. + if (cur_cost > cur_cutoff && + num_toks_processed < config_.min_active) { // Don't bother processing + // successors. break; // This is a priority queue. The following tokens will be worse } else if (cur_cost + adaptive_beam < cur_cutoff) { cur_cutoff = cur_cost + adaptive_beam; // a tighter boundary @@ -840,7 +830,6 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( // because we're about to regenerate them. This is a kind // of non-optimality (remember, this is the simple decoder), DeleteForwardLinks(tok); // necessary when re-visiting - tok->links = NULL; for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { @@ -881,14 +870,17 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( // list tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel, graph_cost, ac_cost, tok->links); - if (best_token_in_next_frame_ == NULL || - next_tok->tot_cost < best_token_in_next_frame_->tot_cost) { - best_token_in_next_frame_ = next_tok; - } } } // for all arcs } // end of while loop + // Store the offset on the acoustic likelihoods that we're applying. + // Could just do cost_offsets_.push_back(cost_offset), but we + // do it this way as it's more robust to future code changes. + // Set the cost_offset_ for next frame, it equals "- best_cost_on_next_frame". + cost_offsets_.resize(frame + 2, 0.0); + cost_offsets_[frame + 1] = adaptive_beam - next_cutoff; + { // This block updates adaptive_beam_ BaseFloat beam_used_this_frame = adaptive_beam; Token *tok = cur_queue_.Pop(); @@ -936,9 +928,7 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( tmp_toks = &cur_toks_; } - KALDI_ASSERT(best_token_in_next_frame_); cur_queue_.Clear(); - cur_queue_.SetBegin(best_token_in_next_frame_->tot_cost); for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { cur_queue_.Push(tok); } @@ -958,7 +948,9 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( num_toks_processed++) { BaseFloat cur_cost = tok->tot_cost; StateId state = tok->state_id; - if (cur_cost > cur_cutoff) { // Don't bother processing successors. + if (cur_cost > cur_cutoff && + num_toks_processed < config_.min_active) { // Don't bother processing + // successors. break; // This is a priority queue. The following tokens will be worse } else if (cur_cost + adaptive_beam < cur_cutoff) { cur_cutoff = cur_cost + adaptive_beam; // a tighter boundary @@ -967,7 +959,6 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( // because we're about to regenerate them. This is a kind // of non-optimality (remember, this is the simple decoder), DeleteForwardLinks(tok); // necessary when re-visiting - tok->links = NULL; for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.h b/src/decoder/lattice-faster-decoder-combine-bucketqueue.h index 016c3ddd4b0..49e34789825 100644 --- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.h +++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.h @@ -269,9 +269,6 @@ class BucketQueue { // value past the end of buckets_. void Clear(); - // Set 'bucket_storage_begin_'. 
- void SetBegin(BaseFloat best_cost_estimate); - private: // Configuration value that is multiplied by tokens' costs before integerizing // them to determine the bucket index @@ -587,10 +584,6 @@ class LatticeFasterDecoderCombineTpl { // frame (members of TokenList are toks, must_prune_forward_links, // must_prune_tokens). - // Stores the best token in next frame. The tot_cost of it will be used to - // initialize the BucketQueue. - Token* best_token_in_next_frame_; - // fst_ is a pointer to the FST we are decoding from. const FST *fst_; // delete_fst_ is true if the pointer fst_ needs to be deleted when this diff --git a/src/decoder/lattice-faster-decoder-combine.cc index 06f1f80b2a8..4d91f48782f 100644 --- a/src/decoder/lattice-faster-decoder-combine.cc +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -86,11 +86,11 @@ Token* BucketQueue::Pop() { if (!buckets_[next_vec_index].empty()) break; } vec_index = next_vec_index; + first_occupied_vec_index_ = vec_index; } if (tok->in_queue) { // This is an effective token tok->in_queue = false; - first_occupied_vec_index_ = vec_index; return tok; } } @@ -103,11 +103,7 @@ void BucketQueue::Clear() { buckets_[i].clear(); } first_occupied_vec_index_ = buckets_.size(); -} - -template -void BucketQueue::SetBegin(BaseFloat best_cost_estimate) { - bucket_storage_begin_ = std::floor((best_cost_estimate - 15) * cost_scale_); + bucket_storage_begin_ = -15 * cost_scale_; } // instantiate this class once for each thing you have to decode. @@ -161,8 +157,9 @@ void LatticeFasterDecoderCombineTpl::InitDecoding() { active_toks_[0].toks = start_tok; cur_toks_[start_state] = start_tok; // initialize current tokens map num_toks_++; - best_token_in_next_frame_ = start_tok; adaptive_beam_ = config_.beam; + cost_offsets_.resize(1); + cost_offsets_[0] = 0.0; } @@ -793,9 +790,7 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( } } - KALDI_ASSERT(best_token_in_next_frame_); cur_queue_.Clear(); - cur_queue_.SetBegin(best_token_in_next_frame_->tot_cost); // Add tokens to queue for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { cur_queue_.Push(tok); } @@ -814,14 +809,7 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( BaseFloat next_cutoff = std::numeric_limits::infinity(); // "cost_offset" contains the acoustic log-likelihoods on current frame in // order to keep everything in a nice dynamic range. Reduce roundoff errors. - BaseFloat cost_offset = - best_token_in_next_frame_->tot_cost; - - best_token_in_next_frame_ = NULL; - // Store the offset on the acoustic likelihoods that we're applying. - // Could just do cost_offsets_.push_back(cost_offset), but we - // do it this way as it's more robust to future code changes. - cost_offsets_.resize(frame + 1, 0.0); - cost_offsets_[frame] = cost_offset; + BaseFloat cost_offset = cost_offsets_[frame]; // Iterate over "cur_queue_" to process non-emitting and emitting arcs in the FST. Token *tok = NULL; @@ -831,7 +819,9 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( num_toks_processed++) { BaseFloat cur_cost = tok->tot_cost; StateId state = tok->state_id; - if (cur_cost > cur_cutoff) { // Don't bother processing successors. + if (cur_cost > cur_cutoff && + num_toks_processed < config_.min_active) { // Don't bother processing + // successors. break; // This is a priority queue.
The following tokens will be worse } else if (cur_cost + adaptive_beam < cur_cutoff) { cur_cutoff = cur_cost + adaptive_beam; // a tighter boundary @@ -840,7 +830,6 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( // because we're about to regenerate them. This is a kind // of non-optimality (remember, this is the simple decoder), DeleteForwardLinks(tok); // necessary when re-visiting - tok->links = NULL; for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { @@ -881,14 +870,17 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( // list tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel, graph_cost, ac_cost, tok->links); - if (best_token_in_next_frame_ == NULL || - next_tok->tot_cost < best_token_in_next_frame_->tot_cost) { - best_token_in_next_frame_ = next_tok; - } } } // for all arcs } // end of while loop + // Store the offset on the acoustic likelihoods that we're applying. + // Could just do cost_offsets_.push_back(cost_offset), but we + // do it this way as it's more robust to future code changes. + // Set the cost_offset_ for next frame, it equals "- best_cost_on_next_frame". + cost_offsets_.resize(frame + 2, 0.0); + cost_offsets_[frame + 1] = adaptive_beam - next_cutoff; + { // This block updates adaptive_beam_ BaseFloat beam_used_this_frame = adaptive_beam; Token *tok = cur_queue_.Pop(); @@ -936,9 +928,7 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( tmp_toks = &cur_toks_; } - KALDI_ASSERT(best_token_in_next_frame_); cur_queue_.Clear(); - cur_queue_.SetBegin(best_token_in_next_frame_->tot_cost); for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { cur_queue_.Push(tok); } @@ -958,7 +948,9 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( num_toks_processed++) { BaseFloat cur_cost = tok->tot_cost; StateId state = tok->state_id; - if (cur_cost > cur_cutoff) { // Don't bother processing successors. + if (cur_cost > cur_cutoff && + num_toks_processed < config_.min_active) { // Don't bother processing + // successors. break; // This is a priority queue. The following tokens will be worse } else if (cur_cost + adaptive_beam < cur_cutoff) { cur_cutoff = cur_cost + adaptive_beam; // a tighter boundary @@ -967,7 +959,6 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( // because we're about to regenerate them. This is a kind // of non-optimality (remember, this is the simple decoder), DeleteForwardLinks(tok); // necessary when re-visiting - tok->links = NULL; for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { diff --git a/src/decoder/lattice-faster-decoder-combine.h b/src/decoder/lattice-faster-decoder-combine.h index 016c3ddd4b0..49e34789825 100644 --- a/src/decoder/lattice-faster-decoder-combine.h +++ b/src/decoder/lattice-faster-decoder-combine.h @@ -269,9 +269,6 @@ class BucketQueue { // value past the end of buckets_. void Clear(); - // Set 'bucket_storage_begin_'. - void SetBegin(BaseFloat best_cost_estimate); - private: // Configuration value that is multiplied by tokens' costs before integerizing // them to determine the bucket index @@ -587,10 +584,6 @@ class LatticeFasterDecoderCombineTpl { // frame (members of TokenList are toks, must_prune_forward_links, // must_prune_tokens). - // Stores the best token in next frame. The tot_cost of it will be used to - // initialize the BucketQueue. - Token* best_token_in_next_frame_; - // fst_ is a pointer to the FST we are decoding from. 
const FST *fst_; // delete_fst_ is true if the pointer fst_ needs to be deleted when this From a70f34b26e1a53ddc0a3f93c43484d4820b94962 Mon Sep 17 00:00:00 2001 From: Hang Lyu Date: Mon, 25 Mar 2019 22:29:29 -0400 Subject: [PATCH 19/29] minor fix --- src/decoder/lattice-faster-decoder-combine-bucketqueue.cc | 4 ++-- src/decoder/lattice-faster-decoder-combine.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc index 4d91f48782f..0f1ccf88530 100644 --- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc +++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc @@ -43,7 +43,6 @@ template void BucketQueue::Push(Token *tok) { int32 bucket_index = std::floor(tok->tot_cost * cost_scale_); size_t vec_index = static_cast(bucket_index - bucket_storage_begin_); - if (vec_index >= buckets_.size()) { int32 margin = 10; // a margin which is used to reduce re-allocate // space frequently @@ -63,7 +62,8 @@ void BucketQueue::Push(Token *tok) { buckets_[i].swap(buckets_[i - increase_size]); } bucket_storage_begin_ = bucket_storage_begin_ - increase_size; - vec_index = increase_size; + vec_index = static_cast(vec_index) + increase_size; + first_occupied_vec_index_ = vec_index; } } tok->in_queue = true; diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc index 4d91f48782f..0f1ccf88530 100644 --- a/src/decoder/lattice-faster-decoder-combine.cc +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -43,7 +43,6 @@ template void BucketQueue::Push(Token *tok) { int32 bucket_index = std::floor(tok->tot_cost * cost_scale_); size_t vec_index = static_cast(bucket_index - bucket_storage_begin_); - if (vec_index >= buckets_.size()) { int32 margin = 10; // a margin which is used to reduce re-allocate // space frequently @@ -63,7 +62,8 @@ void BucketQueue::Push(Token *tok) { buckets_[i].swap(buckets_[i - increase_size]); } bucket_storage_begin_ = bucket_storage_begin_ - increase_size; - vec_index = increase_size; + vec_index = static_cast(vec_index) + increase_size; + first_occupied_vec_index_ = vec_index; } } tok->in_queue = true; From b63611938c18e7439251fb7ea234b22f84005e94 Mon Sep 17 00:00:00 2001 From: Hang Lyu Date: Mon, 25 Mar 2019 23:10:23 -0400 Subject: [PATCH 20/29] small fix --- src/decoder/lattice-faster-decoder-combine-bucketqueue.cc | 1 - src/decoder/lattice-faster-decoder-combine.cc | 1 - 2 files changed, 2 deletions(-) diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc index 0f1ccf88530..68584ac991a 100644 --- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc +++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc @@ -103,7 +103,6 @@ void BucketQueue::Clear() { buckets_[i].clear(); } first_occupied_vec_index_ = buckets_.size(); - bucket_storage_begin_ = -15 * cost_scale_; } // instantiate this class once for each thing you have to decode. 
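A note on the reallocation fix in PATCH 19 above: before it, a token whose raw bucket index came out negative was always placed at index 'increase_size' after the front-growth, instead of at its true position 'old index + increase_size'. The following minimal, self-contained sketch illustrates the corrected path; the ToyBucketQueue/Tok names and the plain integer (unscaled) indices are simplifications for illustration, not the decoder's actual template code.

#include <cstdio>
#include <cmath>
#include <vector>

// Toy token: only the cost matters for bucket placement.
struct Tok { float tot_cost; };

struct ToyBucketQueue {
  std::vector<std::vector<Tok*> > buckets;
  int offset;  // plays the role of (negated) bucket_storage_begin_

  ToyBucketQueue() : offset(0) {}

  void Push(Tok *t) {
    int raw = static_cast<int>(std::floor(t->tot_cost)) + offset;
    if (raw < 0) {               // index fell off the front: grow there
      int increase = -raw + 10;  // 10 = margin, as in the patch
      buckets.resize(buckets.size() + increase);
      // Shift the existing buckets back by 'increase' slots.
      for (size_t i = buckets.size() - 1;
           i >= static_cast<size_t>(increase); i--)
        buckets[i].swap(buckets[i - increase]);
      offset += increase;
      raw += increase;           // the fixed line: old index + shift
    } else if (static_cast<size_t>(raw) >= buckets.size()) {
      buckets.resize(raw + 10);  // index fell off the back: grow there
    }
    buckets[raw].push_back(t);
  }
};

int main() {
  ToyBucketQueue q;
  Tok a = { 3.0f }, b = { -2.5f };
  q.Push(&a);  // raw index 3: back-growth, 'a' lands in bucket 3
  q.Push(&b);  // raw index -3: front-growth; 'a' shifts to 16, 'b' lands in 10
  std::printf("buckets: %zu, offset: %d\n", q.buckets.size(), q.offset);
  return 0;
}

With the pre-PATCH-19 code, 'b' would have landed in bucket 13 (equal to 'increase') instead of 10, i.e. as if its cost were 0 rather than -3.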
diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc index 0f1ccf88530..68584ac991a 100644 --- a/src/decoder/lattice-faster-decoder-combine.cc +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -103,7 +103,6 @@ void BucketQueue::Clear() { buckets_[i].clear(); } first_occupied_vec_index_ = buckets_.size(); - bucket_storage_begin_ = -15 * cost_scale_; } // instantiate this class once for each thing you have to decode. From b6abf4339b4f481395b2b4c930811413aee07ef0 Mon Sep 17 00:00:00 2001 From: Hang Lyu Date: Wed, 27 Mar 2019 23:10:04 -0400 Subject: [PATCH 21/29] first_nonempty_bucket_index_ and first_nonempty_bucket_ --- ...tice-faster-decoder-combine-bucketqueue.cc | 111 +++++++++--------- ...ttice-faster-decoder-combine-bucketqueue.h | 75 ++++++------ src/decoder/lattice-faster-decoder-combine.cc | 111 +++++++++--------- src/decoder/lattice-faster-decoder-combine.h | 75 ++++++------ 4 files changed, 186 insertions(+), 186 deletions(-) diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc index 68584ac991a..9da51c67f8c 100644 --- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc +++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc @@ -27,82 +27,87 @@ namespace kaldi { template -BucketQueue::BucketQueue(BaseFloat best_cost_estimate, - BaseFloat cost_scale) : +BucketQueue::BucketQueue(BaseFloat cost_scale) : cost_scale_(cost_scale) { // NOTE: we reserve plenty of elements to avoid expensive reallocations // later on. Normally, the size is a little bigger than (adaptive_beam + - // 5) * cost_scale. + // 15) * cost_scale. int32 bucket_size = 100; buckets_.resize(bucket_size); - bucket_storage_begin_ = std::floor((best_cost_estimate - 15) * cost_scale_); - first_occupied_vec_index_ = bucket_size; + bucket_offset_ = 15 * cost_scale_; + first_nonempty_bucket_index_ = bucket_size - 1; + first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; } template void BucketQueue::Push(Token *tok) { - int32 bucket_index = std::floor(tok->tot_cost * cost_scale_); - size_t vec_index = static_cast(bucket_index - bucket_storage_begin_); - if (vec_index >= buckets_.size()) { + size_t bucket_index = std::floor(tok->tot_cost * cost_scale_) + + bucket_offset_; + if (bucket_index >= buckets_.size()) { int32 margin = 10; // a margin which is used to reduce re-allocate // space frequently - // A cast from unsigned to signed type does not generate a machine-code - // instruction - if (static_cast(vec_index) > 0) { - KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve" - << " more elements in constructor. Push back."; - buckets_.resize(static_cast(vec_index) + margin); + if (static_cast(bucket_index) > 0) { + buckets_.resize(bucket_index + margin); } else { // less than 0 - KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve" - << " more elements in constructor. 
Push front."; - int32 increase_size = - static_cast(vec_index) + margin; - buckets_.resize(buckets_.size() + increase_size); - // translation - for (size_t i = buckets_.size() - 1; i >= increase_size; i--) { - buckets_[i].swap(buckets_[i - increase_size]); - } - bucket_storage_begin_ = bucket_storage_begin_ - increase_size; - vec_index = static_cast(vec_index) + increase_size; - first_occupied_vec_index_ = vec_index; + int32 increase_size = - static_cast(bucket_index) + margin; + buckets_.resize(buckets_.size() + increase_size); + // translation + for (size_t i = buckets_.size() - 1; i >= increase_size; i--) { + buckets_[i].swap(buckets_[i - increase_size]); + } + bucket_offset_ = bucket_offset_ + increase_size * cost_scale_; + bucket_index += increase_size; + first_nonempty_bucket_index_ = bucket_index; + first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; } } tok->in_queue = true; - buckets_[vec_index].push_back(tok); - if (vec_index < first_occupied_vec_index_) - first_occupied_vec_index_ = vec_index; + buckets_[bucket_index].push_back(tok); + if (bucket_index < first_nonempty_bucket_index_) { + first_nonempty_bucket_index_ = bucket_index; + first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; + } } template Token* BucketQueue::Pop() { - int32 vec_index = first_occupied_vec_index_; - while (vec_index < buckets_.size()) { - Token* tok = buckets_[vec_index].back(); - // Remove the best token - buckets_[vec_index].pop_back(); - - if (buckets_[vec_index].empty()) { // This bucket is empty. Update vec_index - int32 next_vec_index = vec_index + 1; - for (; next_vec_index < buckets_.size(); next_vec_index++) { - if (!buckets_[next_vec_index].empty()) break; + while (true) { + if (!first_nonempty_bucket_->empty()) { + Token *ans = first_nonempty_bucket_->back(); + first_nonempty_bucket_->pop_back(); + if (ans->in_queue) { + ans->in_queue = false; + return ans; } - vec_index = next_vec_index; - first_occupied_vec_index_ = vec_index; } + if (first_nonempty_bucket_->empty()) { + // In case, pop an empty BucketQueue + if (first_nonempty_bucket_index_ == buckets_.size() - 1) { + return NULL; + } - if (tok->in_queue) { // This is a effective token - tok->in_queue = false; - return tok; + first_nonempty_bucket_index_++; + for (; first_nonempty_bucket_index_ < buckets_.size() - 1; + first_nonempty_bucket_index_++) { + if (!buckets_[first_nonempty_bucket_index_].empty()) + break; + } + first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; + if (first_nonempty_bucket_index_ == buckets_.size() - 1 && + first_nonempty_bucket_->empty()) { + return NULL; + } } } - return NULL; } template void BucketQueue::Clear() { - for (size_t i = 0; i < buckets_.size(); i++) { + for (size_t i = first_nonempty_bucket_index_; i < buckets_.size(); i++) { buckets_[i].clear(); } - first_occupied_vec_index_ = buckets_.size(); + first_nonempty_bucket_index_ = buckets_.size() - 1; + first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; } // instantiate this class once for each thing you have to decode. 
@@ -111,7 +116,7 @@ LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( const FST &fst, const LatticeFasterDecoderCombineConfig &config): fst_(&fst), delete_fst_(false), config_(config), num_toks_(0), - cur_queue_(0, config_.cost_scale) { + cur_queue_(config_.cost_scale) { config.Check(); prev_toks_.reserve(1000); cur_toks_.reserve(1000); @@ -122,7 +127,7 @@ template LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( const LatticeFasterDecoderCombineConfig &config, FST *fst): fst_(fst), delete_fst_(true), config_(config), num_toks_(0), - cur_queue_(0, config_.cost_scale) { + cur_queue_(config_.cost_scale) { config.Check(); prev_toks_.reserve(1000); cur_toks_.reserve(1000); @@ -133,8 +138,6 @@ template LatticeFasterDecoderCombineTpl::~LatticeFasterDecoderCombineTpl() { ClearActiveTokens(); if (delete_fst_) delete fst_; - //prev_toks_.clear(); - //cur_toks_.clear(); } template @@ -819,7 +822,7 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( BaseFloat cur_cost = tok->tot_cost; StateId state = tok->state_id; if (cur_cost > cur_cutoff && - num_toks_processed < config_.min_active) { // Don't bother processing + num_toks_processed > config_.min_active) { // Don't bother processing // successors. break; // This is a priority queue. The following tokens will be worse } else if (cur_cost + adaptive_beam < cur_cutoff) { @@ -848,7 +851,7 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( // "changed" tells us whether the new token has a different // cost from before, or is new. - if (changed && !new_tok->in_queue) { + if (changed) { cur_queue_.Push(new_tok); } } @@ -948,7 +951,7 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( BaseFloat cur_cost = tok->tot_cost; StateId state = tok->state_id; if (cur_cost > cur_cutoff && - num_toks_processed < config_.min_active) { // Don't bother processing + num_toks_processed > config_.min_active) { // Don't bother processing // successors. break; // This is a priority queue. The following tokens will be worse } else if (cur_cost + adaptive_beam < cur_cutoff) { @@ -977,7 +980,7 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( // "changed" tells us whether the new token has a different // cost from before, or is new. - if (changed && !new_tok->in_queue) { + if (changed) { cur_queue_.Push(new_tok); } } diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.h b/src/decoder/lattice-faster-decoder-combine-bucketqueue.h index 49e34789825..562a6005ef3 100644 --- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.h +++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.h @@ -240,33 +240,29 @@ struct BackpointerToken { template class BucketQueue { public: - /** Constructor. 'cost_scale' is a scale that we multiply the token costs by - * before intergerizing; a larger value means more buckets. - * 'best_cost_estimate' is an estimate of the best (lowest) cost that - * we are likely to encounter (e.g. the best cost that we have seen so far). - * It is used to initialize 'bucket_storage_begin_'. - */ - BucketQueue(BaseFloat best_cost_estimate, BaseFloat cost_scale = 1.0); - - // Add a Token to the queue; sets the field tok->in_queue to true (it is not + // Constructor. 'cost_scale' is a scale that we multiply the token costs by + // before intergerizing; a larger value means more buckets. + // 'bucket_offset_' is initialized to "15 * cost_scale_". 
It is an empirical + // value in case we trigger the re-allocation in normal case, since we do in + // fact normalize costs to be not far from zero on each frame. + BucketQueue(BaseFloat cost_scale = 1.0); + + // Adds Token to the queue; sets the field tok->in_queue to true (it is not // an error if it was already true). // If a Token was already in the queue but its cost improves, you should // just Push it again. It will be added to (possibly) a different bucket, but - // the old entry will remain. The old entry in the queue will be considered as - // nonexistent when we try to pop it and notice that the recorded cost - // does not match the cost in the Token. (Actually, we use in_queue to decide - // an entry is nonexistent or This strategy means that you may not - // delete Tokens as long as pointers to them might exist in this queue (hence, - // it is probably best to only ever have this queue as a local variable inside - // a function). + // the old entry will remain. We use "tok->in_queue" to decide + // an entry is nonexistent or not. When pop a Token off, the field + // 'tok->in_queue' is set to false. So the old entry in the queue will be + // considered as nonexistent when we try to pop it. void Push(Token *tok); // Removes and returns the next Token 'tok' in the queue, or NULL if there // were no Tokens left. Sets tok->in_queue to false for the returned Token. Token* Pop(); - // Clear all the individual buckets. Set 'first_occupied_vec_index_' to the - // value past the end of buckets_. + // Clears all the individual buckets. Sets 'first_nonempty_bucket_index_' to + // the end of buckets_. void Clear(); private: @@ -283,21 +279,20 @@ class BucketQueue { // then access buckets_[vec_index]. std::vector > buckets_; - // The lowest-numbered vec_index that is occupied (i.e. the first one which - // has any elements). Will be updated as we add or remove tokens. - // If this corresponds to a value past the end of buckets_, we interpret it - // as 'there are no buckets with entries'. - int32 first_occupied_vec_index_; - // An offset that determines how we index into the buckets_ vector; - // may be interpreted as a 'bucket_index' that is better than any one that - // we are likely to see. // In the constructor this will be initialized to something like - // bucket_storage_begin_ = std::floor((best_cost_estimate - 15) * cost_scale) - // which will make it unlikely that we have to change this value in future if - // we get a much better Token (this is expensive because it involves - // reallocating 'buckets_'). - int32 bucket_storage_begin_; + // "15 * cost_scale_" which will make it unlikely that we have to change this + // value in future if we get a much better Token (this is expensive because it + // involves reallocating 'buckets_'). + int32 bucket_offset_; + + // first_nonempty_bucket_index_ is an integer in the range [0, + // buckets_.size() - 1] which is not larger than the index of the first + // nonempty element of buckets_. + int32 first_nonempty_bucket_index_; + + // Synchronizes with first_nonempty_bucket_index_. + std::vector *first_nonempty_bucket_; }; /** This is the "normal" lattice-generating decoder. @@ -543,14 +538,16 @@ class LatticeFasterDecoderCombineTpl { void ProcessForFrame(DecodableInterface *decodable); /// Processes nonemitting (epsilon) arcs for one frame. - /// Calls this function once when all frames were processed. - /// Or calls it in GetRawLattice() to generate the complete token list for - /// the last frame. 
[Deal With the tokens in map "cur_toks_" which would - /// only contains emittion tokens from previous frame.] - /// If the map, "token_orig_cost", isn't NULL, we build the map which will - /// be used to recover "active_toks_[last_frame]" token list for the last - /// frame. - void ProcessNonemitting(std::unordered_map *token_orig_cost); + /// This function is called from FinalizeDecoding(), and also from + /// GetRawLattice() if GetRawLattice() is called before FinalizeDecoding() is + /// called. In the latter case, RecoverLastTokenList() is called later by + /// GetRawLattice() to restore the state prior to ProcessNonemitting() being + /// called, since ProcessForFrame() does not expect nonemitting arcs to + /// already have been propagagted. ["token_orig_cost" isn't NULL in the + /// latter case, we build the map which will be used to recover + /// "active_toks_[last_frame]" token list for the last frame.] + void ProcessNonemitting( + std::unordered_map *token_orig_cost); /// When GetRawLattice() is called during decoding, the /// active_toks_[last_frame] is changed. To keep the consistency of function diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc index 68584ac991a..9da51c67f8c 100644 --- a/src/decoder/lattice-faster-decoder-combine.cc +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -27,82 +27,87 @@ namespace kaldi { template -BucketQueue::BucketQueue(BaseFloat best_cost_estimate, - BaseFloat cost_scale) : +BucketQueue::BucketQueue(BaseFloat cost_scale) : cost_scale_(cost_scale) { // NOTE: we reserve plenty of elements to avoid expensive reallocations // later on. Normally, the size is a little bigger than (adaptive_beam + - // 5) * cost_scale. + // 15) * cost_scale. int32 bucket_size = 100; buckets_.resize(bucket_size); - bucket_storage_begin_ = std::floor((best_cost_estimate - 15) * cost_scale_); - first_occupied_vec_index_ = bucket_size; + bucket_offset_ = 15 * cost_scale_; + first_nonempty_bucket_index_ = bucket_size - 1; + first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; } template void BucketQueue::Push(Token *tok) { - int32 bucket_index = std::floor(tok->tot_cost * cost_scale_); - size_t vec_index = static_cast(bucket_index - bucket_storage_begin_); - if (vec_index >= buckets_.size()) { + size_t bucket_index = std::floor(tok->tot_cost * cost_scale_) + + bucket_offset_; + if (bucket_index >= buckets_.size()) { int32 margin = 10; // a margin which is used to reduce re-allocate // space frequently - // A cast from unsigned to signed type does not generate a machine-code - // instruction - if (static_cast(vec_index) > 0) { - KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve" - << " more elements in constructor. Push back."; - buckets_.resize(static_cast(vec_index) + margin); + if (static_cast(bucket_index) > 0) { + buckets_.resize(bucket_index + margin); } else { // less than 0 - KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve" - << " more elements in constructor. 
Push front."; - int32 increase_size = - static_cast(vec_index) + margin; - buckets_.resize(buckets_.size() + increase_size); - // translation - for (size_t i = buckets_.size() - 1; i >= increase_size; i--) { - buckets_[i].swap(buckets_[i - increase_size]); - } - bucket_storage_begin_ = bucket_storage_begin_ - increase_size; - vec_index = static_cast(vec_index) + increase_size; - first_occupied_vec_index_ = vec_index; + int32 increase_size = - static_cast(bucket_index) + margin; + buckets_.resize(buckets_.size() + increase_size); + // translation + for (size_t i = buckets_.size() - 1; i >= increase_size; i--) { + buckets_[i].swap(buckets_[i - increase_size]); + } + bucket_offset_ = bucket_offset_ + increase_size * cost_scale_; + bucket_index += increase_size; + first_nonempty_bucket_index_ = bucket_index; + first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; } } tok->in_queue = true; - buckets_[vec_index].push_back(tok); - if (vec_index < first_occupied_vec_index_) - first_occupied_vec_index_ = vec_index; + buckets_[bucket_index].push_back(tok); + if (bucket_index < first_nonempty_bucket_index_) { + first_nonempty_bucket_index_ = bucket_index; + first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; + } } template Token* BucketQueue::Pop() { - int32 vec_index = first_occupied_vec_index_; - while (vec_index < buckets_.size()) { - Token* tok = buckets_[vec_index].back(); - // Remove the best token - buckets_[vec_index].pop_back(); - - if (buckets_[vec_index].empty()) { // This bucket is empty. Update vec_index - int32 next_vec_index = vec_index + 1; - for (; next_vec_index < buckets_.size(); next_vec_index++) { - if (!buckets_[next_vec_index].empty()) break; + while (true) { + if (!first_nonempty_bucket_->empty()) { + Token *ans = first_nonempty_bucket_->back(); + first_nonempty_bucket_->pop_back(); + if (ans->in_queue) { + ans->in_queue = false; + return ans; } - vec_index = next_vec_index; - first_occupied_vec_index_ = vec_index; } + if (first_nonempty_bucket_->empty()) { + // In case, pop an empty BucketQueue + if (first_nonempty_bucket_index_ == buckets_.size() - 1) { + return NULL; + } - if (tok->in_queue) { // This is a effective token - tok->in_queue = false; - return tok; + first_nonempty_bucket_index_++; + for (; first_nonempty_bucket_index_ < buckets_.size() - 1; + first_nonempty_bucket_index_++) { + if (!buckets_[first_nonempty_bucket_index_].empty()) + break; + } + first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; + if (first_nonempty_bucket_index_ == buckets_.size() - 1 && + first_nonempty_bucket_->empty()) { + return NULL; + } } } - return NULL; } template void BucketQueue::Clear() { - for (size_t i = 0; i < buckets_.size(); i++) { + for (size_t i = first_nonempty_bucket_index_; i < buckets_.size(); i++) { buckets_[i].clear(); } - first_occupied_vec_index_ = buckets_.size(); + first_nonempty_bucket_index_ = buckets_.size() - 1; + first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; } // instantiate this class once for each thing you have to decode. 
@@ -111,7 +116,7 @@ LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( const FST &fst, const LatticeFasterDecoderCombineConfig &config): fst_(&fst), delete_fst_(false), config_(config), num_toks_(0), - cur_queue_(0, config_.cost_scale) { + cur_queue_(config_.cost_scale) { config.Check(); prev_toks_.reserve(1000); cur_toks_.reserve(1000); @@ -122,7 +127,7 @@ template LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( const LatticeFasterDecoderCombineConfig &config, FST *fst): fst_(fst), delete_fst_(true), config_(config), num_toks_(0), - cur_queue_(0, config_.cost_scale) { + cur_queue_(config_.cost_scale) { config.Check(); prev_toks_.reserve(1000); cur_toks_.reserve(1000); @@ -133,8 +138,6 @@ template LatticeFasterDecoderCombineTpl::~LatticeFasterDecoderCombineTpl() { ClearActiveTokens(); if (delete_fst_) delete fst_; - //prev_toks_.clear(); - //cur_toks_.clear(); } template @@ -819,7 +822,7 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( BaseFloat cur_cost = tok->tot_cost; StateId state = tok->state_id; if (cur_cost > cur_cutoff && - num_toks_processed < config_.min_active) { // Don't bother processing + num_toks_processed > config_.min_active) { // Don't bother processing // successors. break; // This is a priority queue. The following tokens will be worse } else if (cur_cost + adaptive_beam < cur_cutoff) { @@ -848,7 +851,7 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( // "changed" tells us whether the new token has a different // cost from before, or is new. - if (changed && !new_tok->in_queue) { + if (changed) { cur_queue_.Push(new_tok); } } @@ -948,7 +951,7 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( BaseFloat cur_cost = tok->tot_cost; StateId state = tok->state_id; if (cur_cost > cur_cutoff && - num_toks_processed < config_.min_active) { // Don't bother processing + num_toks_processed > config_.min_active) { // Don't bother processing // successors. break; // This is a priority queue. The following tokens will be worse } else if (cur_cost + adaptive_beam < cur_cutoff) { @@ -977,7 +980,7 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( // "changed" tells us whether the new token has a different // cost from before, or is new. - if (changed && !new_tok->in_queue) { + if (changed) { cur_queue_.Push(new_tok); } } diff --git a/src/decoder/lattice-faster-decoder-combine.h b/src/decoder/lattice-faster-decoder-combine.h index 49e34789825..562a6005ef3 100644 --- a/src/decoder/lattice-faster-decoder-combine.h +++ b/src/decoder/lattice-faster-decoder-combine.h @@ -240,33 +240,29 @@ struct BackpointerToken { template class BucketQueue { public: - /** Constructor. 'cost_scale' is a scale that we multiply the token costs by - * before intergerizing; a larger value means more buckets. - * 'best_cost_estimate' is an estimate of the best (lowest) cost that - * we are likely to encounter (e.g. the best cost that we have seen so far). - * It is used to initialize 'bucket_storage_begin_'. - */ - BucketQueue(BaseFloat best_cost_estimate, BaseFloat cost_scale = 1.0); - - // Add a Token to the queue; sets the field tok->in_queue to true (it is not + // Constructor. 'cost_scale' is a scale that we multiply the token costs by + // before intergerizing; a larger value means more buckets. + // 'bucket_offset_' is initialized to "15 * cost_scale_". It is an empirical + // value in case we trigger the re-allocation in normal case, since we do in + // fact normalize costs to be not far from zero on each frame. 
+ BucketQueue(BaseFloat cost_scale = 1.0); + + // Adds Token to the queue; sets the field tok->in_queue to true (it is not // an error if it was already true). // If a Token was already in the queue but its cost improves, you should // just Push it again. It will be added to (possibly) a different bucket, but - // the old entry will remain. The old entry in the queue will be considered as - // nonexistent when we try to pop it and notice that the recorded cost - // does not match the cost in the Token. (Actually, we use in_queue to decide - // an entry is nonexistent or This strategy means that you may not - // delete Tokens as long as pointers to them might exist in this queue (hence, - // it is probably best to only ever have this queue as a local variable inside - // a function). + // the old entry will remain. We use "tok->in_queue" to decide + // an entry is nonexistent or not. When pop a Token off, the field + // 'tok->in_queue' is set to false. So the old entry in the queue will be + // considered as nonexistent when we try to pop it. void Push(Token *tok); // Removes and returns the next Token 'tok' in the queue, or NULL if there // were no Tokens left. Sets tok->in_queue to false for the returned Token. Token* Pop(); - // Clear all the individual buckets. Set 'first_occupied_vec_index_' to the - // value past the end of buckets_. + // Clears all the individual buckets. Sets 'first_nonempty_bucket_index_' to + // the end of buckets_. void Clear(); private: @@ -283,21 +279,20 @@ class BucketQueue { // then access buckets_[vec_index]. std::vector > buckets_; - // The lowest-numbered vec_index that is occupied (i.e. the first one which - // has any elements). Will be updated as we add or remove tokens. - // If this corresponds to a value past the end of buckets_, we interpret it - // as 'there are no buckets with entries'. - int32 first_occupied_vec_index_; - // An offset that determines how we index into the buckets_ vector; - // may be interpreted as a 'bucket_index' that is better than any one that - // we are likely to see. // In the constructor this will be initialized to something like - // bucket_storage_begin_ = std::floor((best_cost_estimate - 15) * cost_scale) - // which will make it unlikely that we have to change this value in future if - // we get a much better Token (this is expensive because it involves - // reallocating 'buckets_'). - int32 bucket_storage_begin_; + // "15 * cost_scale_" which will make it unlikely that we have to change this + // value in future if we get a much better Token (this is expensive because it + // involves reallocating 'buckets_'). + int32 bucket_offset_; + + // first_nonempty_bucket_index_ is an integer in the range [0, + // buckets_.size() - 1] which is not larger than the index of the first + // nonempty element of buckets_. + int32 first_nonempty_bucket_index_; + + // Synchronizes with first_nonempty_bucket_index_. + std::vector *first_nonempty_bucket_; }; /** This is the "normal" lattice-generating decoder. @@ -543,14 +538,16 @@ class LatticeFasterDecoderCombineTpl { void ProcessForFrame(DecodableInterface *decodable); /// Processes nonemitting (epsilon) arcs for one frame. - /// Calls this function once when all frames were processed. - /// Or calls it in GetRawLattice() to generate the complete token list for - /// the last frame. [Deal With the tokens in map "cur_toks_" which would - /// only contains emittion tokens from previous frame.] 
- /// If the map, "token_orig_cost", isn't NULL, we build the map which will - /// be used to recover "active_toks_[last_frame]" token list for the last - /// frame. - void ProcessNonemitting(std::unordered_map *token_orig_cost); + /// This function is called from FinalizeDecoding(), and also from + /// GetRawLattice() if GetRawLattice() is called before FinalizeDecoding() is + /// called. In the latter case, RecoverLastTokenList() is called later by + /// GetRawLattice() to restore the state prior to ProcessNonemitting() being + /// called, since ProcessForFrame() does not expect nonemitting arcs to + /// already have been propagagted. ["token_orig_cost" isn't NULL in the + /// latter case, we build the map which will be used to recover + /// "active_toks_[last_frame]" token list for the last frame.] + void ProcessNonemitting( + std::unordered_map *token_orig_cost); /// When GetRawLattice() is called during decoding, the /// active_toks_[last_frame] is changed. To keep the consistency of function From 25907d83bab1a3e74755e786f963eea1383c9125 Mon Sep 17 00:00:00 2001 From: Hang Lyu Date: Thu, 28 Mar 2019 00:58:46 -0400 Subject: [PATCH 22/29] remove RecoverLastFrame() --- ...tice-faster-decoder-combine-bucketqueue.cc | 62 +++---------------- ...ttice-faster-decoder-combine-bucketqueue.h | 24 +------ src/decoder/lattice-faster-decoder-combine.cc | 62 +++---------------- src/decoder/lattice-faster-decoder-combine.h | 24 +------ 4 files changed, 24 insertions(+), 148 deletions(-) diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc index 9da51c67f8c..ed53153193a 100644 --- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc +++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc @@ -218,10 +218,9 @@ bool LatticeFasterDecoderCombineTpl::GetRawLattice( KALDI_ERR << "You cannot call FinalizeDecoding() and then call " << "GetRawLattice() with use_final_probs == false"; - std::unordered_map token_orig_cost; if (!decoding_finalized_) { // Process the non-emitting arcs for the unfinished last frame. - ProcessNonemitting(&token_orig_cost); + ProcessNonemitting(); } @@ -293,40 +292,9 @@ bool LatticeFasterDecoderCombineTpl::GetRawLattice( } } - if (!decoding_finalized_) { // recover last token list - RecoverLastTokenList(token_orig_cost); - } return (ofst->NumStates() > 0); } - -// When GetRawLattice() is called during decoding, the -// active_toks_[last_frame] is changed. To keep the consistency of function -// ProcessForFrame(), recover it. -// Notice: as new token will be added to the head of TokenList, tok->next -// will not be affacted. -template -void LatticeFasterDecoderCombineTpl::RecoverLastTokenList( - const std::unordered_map &token_orig_cost) { - if (!token_orig_cost.empty()) { - for (Token* tok = active_toks_[active_toks_.size() - 1].toks; - tok != NULL;) { - if (token_orig_cost.find(tok) != token_orig_cost.end()) { - DeleteForwardLinks(tok); - tok->tot_cost = token_orig_cost.find(tok)->second; - tok->in_queue = false; - tok = tok->next; - } else { - DeleteForwardLinks(tok); - Token *next_tok = tok->next; - delete tok; - num_toks_--; - tok = next_tok; - } - } - } -} - // This function is now deprecated, since now we do determinization from outside // the LatticeFasterDecoder class. Outputs an FST corresponding to the // lattice-determinized lattice (one path per word sequence). @@ -756,7 +724,7 @@ void LatticeFasterDecoderCombineTpl::AdvanceDecoding( // tokens. 
This function used to be called PruneActiveTokensFinal(). template void LatticeFasterDecoderCombineTpl::FinalizeDecoding() { - ProcessNonemitting(NULL); + ProcessNonemitting(); int32 final_frame_plus_one = NumFramesDecoded(); int32 num_toks_begin = num_toks_; // PruneForwardLinksFinal() prunes final frame (with final-probs), and @@ -912,23 +880,8 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( template -void LatticeFasterDecoderCombineTpl::ProcessNonemitting( - std::unordered_map *token_orig_cost) { +void LatticeFasterDecoderCombineTpl::ProcessNonemitting() { int32 frame = active_toks_.size() - 1; - if (token_orig_cost) { // Build the elements which are used to recover - for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { - (*token_orig_cost)[tok] = tok->tot_cost; - } - } - - StateIdToTokenMap *tmp_toks; - if (token_orig_cost) { // "token_orig_cost" isn't NULL. It means we need to - // recover active_toks_[last_frame] and "cur_toks_" - // will be used in the future. - tmp_toks = new StateIdToTokenMap(cur_toks_); - } else { - tmp_toks = &cur_toks_; - } cur_queue_.Clear(); for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { @@ -971,7 +924,7 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( BaseFloat tot_cost = cur_cost + graph_cost; if (tot_cost < cur_cutoff) { Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost, - tok, tmp_toks, &changed); + tok, &cur_toks_, &changed); // Add ForwardLink from tok to new_tok. Put it on the head of // tok->link list @@ -987,7 +940,12 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( } } // end of for loop } // end of while loop - if (token_orig_cost) delete tmp_toks; + if (!decoding_finalized_) { + // Update cost_offsets_, it equals "- best_cost". + cost_offsets_[frame] = adaptive_beam - cur_cutoff; + // Needn't to update adaptive_beam_, since we still process this frame in + // ProcessForFrame. + } } diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.h b/src/decoder/lattice-faster-decoder-combine-bucketqueue.h index 562a6005ef3..96096a0485f 100644 --- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.h +++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.h @@ -540,28 +540,8 @@ class LatticeFasterDecoderCombineTpl { /// Processes nonemitting (epsilon) arcs for one frame. /// This function is called from FinalizeDecoding(), and also from /// GetRawLattice() if GetRawLattice() is called before FinalizeDecoding() is - /// called. In the latter case, RecoverLastTokenList() is called later by - /// GetRawLattice() to restore the state prior to ProcessNonemitting() being - /// called, since ProcessForFrame() does not expect nonemitting arcs to - /// already have been propagagted. ["token_orig_cost" isn't NULL in the - /// latter case, we build the map which will be used to recover - /// "active_toks_[last_frame]" token list for the last frame.] - void ProcessNonemitting( - std::unordered_map *token_orig_cost); - - /// When GetRawLattice() is called during decoding, the - /// active_toks_[last_frame] is changed. To keep the consistency of function - /// ProcessForFrame(), recover it. - /// Notice: as new token will be added to the head of TokenList, tok->next - /// will not be affacted. - /// "token_orig_cost" is a mapping from token pointer to the tot_cost of the - /// token before propagating non-emitting arcs. 
It is used to recover the - /// change of original tokens in the last frame and remove the new tokens - /// which come from propagating non-emitting arcs, so that we can guarantee - /// the consistency of function ProcessForFrame(). - void RecoverLastTokenList( - const std::unordered_map &token_orig_cost); - + /// called. + void ProcessNonemitting(); /// The "prev_toks_" and "cur_toks_" actually allow us to maintain current /// and next frames. They are indexed by StateId. It is indexed by frame-index diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc index 9da51c67f8c..ed53153193a 100644 --- a/src/decoder/lattice-faster-decoder-combine.cc +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -218,10 +218,9 @@ bool LatticeFasterDecoderCombineTpl::GetRawLattice( KALDI_ERR << "You cannot call FinalizeDecoding() and then call " << "GetRawLattice() with use_final_probs == false"; - std::unordered_map token_orig_cost; if (!decoding_finalized_) { // Process the non-emitting arcs for the unfinished last frame. - ProcessNonemitting(&token_orig_cost); + ProcessNonemitting(); } @@ -293,40 +292,9 @@ bool LatticeFasterDecoderCombineTpl::GetRawLattice( } } - if (!decoding_finalized_) { // recover last token list - RecoverLastTokenList(token_orig_cost); - } return (ofst->NumStates() > 0); } - -// When GetRawLattice() is called during decoding, the -// active_toks_[last_frame] is changed. To keep the consistency of function -// ProcessForFrame(), recover it. -// Notice: as new token will be added to the head of TokenList, tok->next -// will not be affacted. -template -void LatticeFasterDecoderCombineTpl::RecoverLastTokenList( - const std::unordered_map &token_orig_cost) { - if (!token_orig_cost.empty()) { - for (Token* tok = active_toks_[active_toks_.size() - 1].toks; - tok != NULL;) { - if (token_orig_cost.find(tok) != token_orig_cost.end()) { - DeleteForwardLinks(tok); - tok->tot_cost = token_orig_cost.find(tok)->second; - tok->in_queue = false; - tok = tok->next; - } else { - DeleteForwardLinks(tok); - Token *next_tok = tok->next; - delete tok; - num_toks_--; - tok = next_tok; - } - } - } -} - // This function is now deprecated, since now we do determinization from outside // the LatticeFasterDecoder class. Outputs an FST corresponding to the // lattice-determinized lattice (one path per word sequence). @@ -756,7 +724,7 @@ void LatticeFasterDecoderCombineTpl::AdvanceDecoding( // tokens. This function used to be called PruneActiveTokensFinal(). template void LatticeFasterDecoderCombineTpl::FinalizeDecoding() { - ProcessNonemitting(NULL); + ProcessNonemitting(); int32 final_frame_plus_one = NumFramesDecoded(); int32 num_toks_begin = num_toks_; // PruneForwardLinksFinal() prunes final frame (with final-probs), and @@ -912,23 +880,8 @@ void LatticeFasterDecoderCombineTpl::ProcessForFrame( template -void LatticeFasterDecoderCombineTpl::ProcessNonemitting( - std::unordered_map *token_orig_cost) { +void LatticeFasterDecoderCombineTpl::ProcessNonemitting() { int32 frame = active_toks_.size() - 1; - if (token_orig_cost) { // Build the elements which are used to recover - for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { - (*token_orig_cost)[tok] = tok->tot_cost; - } - } - - StateIdToTokenMap *tmp_toks; - if (token_orig_cost) { // "token_orig_cost" isn't NULL. It means we need to - // recover active_toks_[last_frame] and "cur_toks_" - // will be used in the future. 
- tmp_toks = new StateIdToTokenMap(cur_toks_); - } else { - tmp_toks = &cur_toks_; - } cur_queue_.Clear(); for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { @@ -971,7 +924,7 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( BaseFloat tot_cost = cur_cost + graph_cost; if (tot_cost < cur_cutoff) { Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost, - tok, tmp_toks, &changed); + tok, &cur_toks_, &changed); // Add ForwardLink from tok to new_tok. Put it on the head of // tok->link list @@ -987,7 +940,12 @@ void LatticeFasterDecoderCombineTpl::ProcessNonemitting( } } // end of for loop } // end of while loop - if (token_orig_cost) delete tmp_toks; + if (!decoding_finalized_) { + // Update cost_offsets_, it equals "- best_cost". + cost_offsets_[frame] = adaptive_beam - cur_cutoff; + // Needn't to update adaptive_beam_, since we still process this frame in + // ProcessForFrame. + } } diff --git a/src/decoder/lattice-faster-decoder-combine.h b/src/decoder/lattice-faster-decoder-combine.h index 562a6005ef3..96096a0485f 100644 --- a/src/decoder/lattice-faster-decoder-combine.h +++ b/src/decoder/lattice-faster-decoder-combine.h @@ -540,28 +540,8 @@ class LatticeFasterDecoderCombineTpl { /// Processes nonemitting (epsilon) arcs for one frame. /// This function is called from FinalizeDecoding(), and also from /// GetRawLattice() if GetRawLattice() is called before FinalizeDecoding() is - /// called. In the latter case, RecoverLastTokenList() is called later by - /// GetRawLattice() to restore the state prior to ProcessNonemitting() being - /// called, since ProcessForFrame() does not expect nonemitting arcs to - /// already have been propagagted. ["token_orig_cost" isn't NULL in the - /// latter case, we build the map which will be used to recover - /// "active_toks_[last_frame]" token list for the last frame.] - void ProcessNonemitting( - std::unordered_map *token_orig_cost); - - /// When GetRawLattice() is called during decoding, the - /// active_toks_[last_frame] is changed. To keep the consistency of function - /// ProcessForFrame(), recover it. - /// Notice: as new token will be added to the head of TokenList, tok->next - /// will not be affacted. - /// "token_orig_cost" is a mapping from token pointer to the tot_cost of the - /// token before propagating non-emitting arcs. It is used to recover the - /// change of original tokens in the last frame and remove the new tokens - /// which come from propagating non-emitting arcs, so that we can guarantee - /// the consistency of function ProcessForFrame(). - void RecoverLastTokenList( - const std::unordered_map &token_orig_cost); - + /// called. + void ProcessNonemitting(); /// The "prev_toks_" and "cur_toks_" actually allow us to maintain current /// and next frames. They are indexed by StateId. 
It is indexed by frame-index From 26b378a6d1ef6ae1c4246ad266a80e74d5fa60a2 Mon Sep 17 00:00:00 2001 From: Hang Lyu Date: Thu, 28 Mar 2019 23:25:39 -0400 Subject: [PATCH 23/29] do ProcessNonemitting if final-probs are requested --- src/decoder/lattice-faster-decoder-combine-bucketqueue.cc | 2 +- src/decoder/lattice-faster-decoder-combine.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc index ed53153193a..0ef7c822efd 100644 --- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc +++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc @@ -218,7 +218,7 @@ bool LatticeFasterDecoderCombineTpl::GetRawLattice( KALDI_ERR << "You cannot call FinalizeDecoding() and then call " << "GetRawLattice() with use_final_probs == false"; - if (!decoding_finalized_) { + if (!decoding_finalized_ && use_final_probs) { // Process the non-emitting arcs for the unfinished last frame. ProcessNonemitting(); } diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc index ed53153193a..0ef7c822efd 100644 --- a/src/decoder/lattice-faster-decoder-combine.cc +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -218,7 +218,7 @@ bool LatticeFasterDecoderCombineTpl::GetRawLattice( KALDI_ERR << "You cannot call FinalizeDecoding() and then call " << "GetRawLattice() with use_final_probs == false"; - if (!decoding_finalized_) { + if (!decoding_finalized_ && use_final_probs) { // Process the non-emitting arcs for the unfinished last frame. ProcessNonemitting(); } From c66f1bb50526cac100baee87df47d8170c39c364 Mon Sep 17 00:00:00 2001 From: Hang Lyu Date: Fri, 29 Mar 2019 16:39:59 -0400 Subject: [PATCH 24/29] fix --- src/decoder/lattice-faster-decoder-combine-bucketqueue.cc | 1 + src/decoder/lattice-faster-decoder-combine.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc index 0ef7c822efd..1dad465d79f 100644 --- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc +++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc @@ -51,6 +51,7 @@ void BucketQueue::Push(Token *tok) { } else { // less than 0 int32 increase_size = - static_cast(bucket_index) + margin; buckets_.resize(buckets_.size() + increase_size); + first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; // translation for (size_t i = buckets_.size() - 1; i >= increase_size; i--) { buckets_[i].swap(buckets_[i - increase_size]); diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc index 0ef7c822efd..1dad465d79f 100644 --- a/src/decoder/lattice-faster-decoder-combine.cc +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -51,6 +51,7 @@ void BucketQueue::Push(Token *tok) { } else { // less than 0 int32 increase_size = - static_cast(bucket_index) + margin; buckets_.resize(buckets_.size() + increase_size); + first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; // translation for (size_t i = buckets_.size() - 1; i >= increase_size; i--) { buckets_[i].swap(buckets_[i - increase_size]); From 79b007180429979a41cbcb56272198fc744d7c0e Mon Sep 17 00:00:00 2001 From: Hang Lyu Date: Fri, 29 Mar 2019 17:31:56 -0400 Subject: [PATCH 25/29] small fix --- src/decoder/lattice-faster-decoder-combine-bucketqueue.cc | 2 +- src/decoder/lattice-faster-decoder-combine.cc | 2 +- 2 
files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc index 1dad465d79f..e3757e83019 100644 --- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc +++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc @@ -48,10 +48,10 @@ void BucketQueue::Push(Token *tok) { // space frequently if (static_cast(bucket_index) > 0) { buckets_.resize(bucket_index + margin); + first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; } else { // less than 0 int32 increase_size = - static_cast(bucket_index) + margin; buckets_.resize(buckets_.size() + increase_size); - first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; // translation for (size_t i = buckets_.size() - 1; i >= increase_size; i--) { buckets_[i].swap(buckets_[i - increase_size]); diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc index 1dad465d79f..e3757e83019 100644 --- a/src/decoder/lattice-faster-decoder-combine.cc +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -48,10 +48,10 @@ void BucketQueue::Push(Token *tok) { // space frequently if (static_cast(bucket_index) > 0) { buckets_.resize(bucket_index + margin); + first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; } else { // less than 0 int32 increase_size = - static_cast(bucket_index) + margin; buckets_.resize(buckets_.size() + increase_size); - first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; // translation for (size_t i = buckets_.size() - 1; i >= increase_size; i--) { buckets_[i].swap(buckets_[i - increase_size]); From c359fe2cf39c023070e13c0251c5b46e7e9fc7c6 Mon Sep 17 00:00:00 2001 From: Hang Lyu Date: Wed, 3 Apr 2019 20:25:09 -0400 Subject: [PATCH 26/29] fix according to the comments --- ...tice-faster-decoder-combine-bucketqueue.cc | 90 +++++++++---------- ...ttice-faster-decoder-combine-bucketqueue.h | 11 ++- src/decoder/lattice-faster-decoder-combine.cc | 90 +++++++++---------- src/decoder/lattice-faster-decoder-combine.h | 11 ++- 4 files changed, 98 insertions(+), 104 deletions(-) diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc index e3757e83019..3f2b0e8e5cb 100644 --- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc +++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc @@ -48,7 +48,6 @@ void BucketQueue::Push(Token *tok) { // space frequently if (static_cast(bucket_index) > 0) { buckets_.resize(bucket_index + margin); - first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; } else { // less than 0 int32 increase_size = - static_cast(bucket_index) + margin; buckets_.resize(buckets_.size() + increase_size); @@ -56,11 +55,11 @@ void BucketQueue::Push(Token *tok) { for (size_t i = buckets_.size() - 1; i >= increase_size; i--) { buckets_[i].swap(buckets_[i - increase_size]); } - bucket_offset_ = bucket_offset_ + increase_size * cost_scale_; + bucket_offset_ = bucket_offset_ + increase_size; bucket_index += increase_size; first_nonempty_bucket_index_ = bucket_index; - first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; } + first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_]; } tok->in_queue = true; buckets_[bucket_index].push_back(tok); @@ -76,28 +75,22 @@ Token* BucketQueue::Pop() { if (!first_nonempty_bucket_->empty()) { Token *ans = first_nonempty_bucket_->back(); 
From c359fe2cf39c023070e13c0251c5b46e7e9fc7c6 Mon Sep 17 00:00:00 2001
From: Hang Lyu
Date: Wed, 3 Apr 2019 20:25:09 -0400
Subject: [PATCH 26/29] fix according to the comments

---
 ...tice-faster-decoder-combine-bucketqueue.cc | 90 +++++++++----------
 ...ttice-faster-decoder-combine-bucketqueue.h | 11 ++-
 src/decoder/lattice-faster-decoder-combine.cc | 90 +++++++++----------
 src/decoder/lattice-faster-decoder-combine.h  | 11 ++-
 4 files changed, 98 insertions(+), 104 deletions(-)

diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc
index e3757e83019..3f2b0e8e5cb 100644
--- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc
+++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc
@@ -48,7 +48,6 @@ void BucketQueue<Token>::Push(Token *tok) {
     // space frequently
     if (static_cast<int32>(bucket_index) > 0) {
       buckets_.resize(bucket_index + margin);
-      first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
     } else {  // less than 0
       int32 increase_size = - static_cast<int32>(bucket_index) + margin;
       buckets_.resize(buckets_.size() + increase_size);
@@ -56,11 +55,11 @@
       for (size_t i = buckets_.size() - 1; i >= increase_size; i--) {
         buckets_[i].swap(buckets_[i - increase_size]);
       }
-      bucket_offset_ = bucket_offset_ + increase_size * cost_scale_;
+      bucket_offset_ = bucket_offset_ + increase_size;
       bucket_index += increase_size;
       first_nonempty_bucket_index_ = bucket_index;
-      first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
     }
+    first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
   }
   tok->in_queue = true;
   buckets_[bucket_index].push_back(tok);
@@ -76,28 +75,22 @@ Token* BucketQueue<Token>::Pop() {
     if (!first_nonempty_bucket_->empty()) {
       Token *ans = first_nonempty_bucket_->back();
       first_nonempty_bucket_->pop_back();
-      if (ans->in_queue) {
+      if (ans->in_queue) {  // If ans->in_queue is false, this means it is a
+                            // duplicate instance of this Token that was left
+                            // over when a Token's best_cost changed, and the
+                            // Token has already been processed (so conceptually,
+                            // it is not in the queue).
         ans->in_queue = false;
         return ans;
       }
     }
     if (first_nonempty_bucket_->empty()) {
-      // In case, pop an empty BucketQueue
-      if (first_nonempty_bucket_index_ == buckets_.size() - 1) {
-        return NULL;
-      }
-
-      first_nonempty_bucket_index_++;
-      for (; first_nonempty_bucket_index_ < buckets_.size() - 1;
+      for (; first_nonempty_bucket_index_ + 1 < buckets_.size();
           first_nonempty_bucket_index_++) {
-        if (!buckets_[first_nonempty_bucket_index_].empty())
-          break;
+        if (!buckets_[first_nonempty_bucket_index_].empty()) break;
      }
       first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
-      if (first_nonempty_bucket_index_ == buckets_.size() - 1 &&
-          first_nonempty_bucket_->empty()) {
-        return NULL;
-      }
+      if (first_nonempty_bucket_->empty()) return NULL;
     }
   }
 }
@@ -119,8 +112,8 @@ LatticeFasterDecoderCombineTpl<FST, Token>::LatticeFasterDecoderCombineTpl(
     fst_(&fst), delete_fst_(false), config_(config), num_toks_(0),
     cur_queue_(config_.cost_scale) {
   config.Check();
-  prev_toks_.reserve(1000);
   cur_toks_.reserve(1000);
+  next_toks_.reserve(1000);
 }
 
@@ -130,8 +123,8 @@ LatticeFasterDecoderCombineTpl<FST, Token>::LatticeFasterDecoderCombineTpl(
     fst_(fst), delete_fst_(true), config_(config), num_toks_(0),
     cur_queue_(config_.cost_scale) {
   config.Check();
-  prev_toks_.reserve(1000);
   cur_toks_.reserve(1000);
+  next_toks_.reserve(1000);
 }
 
@@ -144,8 +137,8 @@ LatticeFasterDecoderCombineTpl<FST, Token>::~LatticeFasterDecoderCombineTpl() {
 template <typename FST, typename Token>
 void LatticeFasterDecoderCombineTpl<FST, Token>::InitDecoding() {
   // clean up from last time:
-  prev_toks_.clear();
   cur_toks_.clear();
+  next_toks_.clear();
   cost_offsets_.clear();
   ClearActiveTokens();
 
@@ -158,7 +151,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::InitDecoding() {
   active_toks_.resize(1);
   Token *start_tok = new Token(0.0, 0.0, start_state, NULL, NULL, NULL);
   active_toks_[0].toks = start_tok;
-  cur_toks_[start_state] = start_tok;  // initialize current tokens map
+  next_toks_[start_state] = start_tok;  // initialize current tokens map
   num_toks_++;
   adaptive_beam_ = config_.beam;
   cost_offsets_.resize(1);
@@ -747,25 +740,26 @@
 template <typename FST, typename Token>
 void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
     DecodableInterface *decodable) {
   KALDI_ASSERT(active_toks_.size() > 0);
-  int32 frame = active_toks_.size() - 1; // frame is the frame-index
-                                         // (zero-based) used to get likelihoods
-                                         // from the decodable object.
+  int32 cur_frame = active_toks_.size() - 1,  // frame is the frame-index (zero-
+                                              // based) used to get likelihoods
+                                              // from the decodable object.
+        next_frame = cur_frame + 1;
+
   active_toks_.resize(active_toks_.size() + 1);
 
-  prev_toks_.swap(cur_toks_);
-  cur_toks_.clear();
-  if (prev_toks_.empty()) {
+  cur_toks_.swap(next_toks_);
+  next_toks_.clear();
+  if (cur_toks_.empty()) {
     if (!warned_) {
-      KALDI_WARN << "Error, no surviving tokens on frame " << frame;
+      KALDI_WARN << "Error, no surviving tokens on frame " << cur_frame;
       warned_ = true;
     }
   }
 
   cur_queue_.Clear();
   // Add tokens to queue
-  for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
+  for (Token* tok = active_toks_[cur_frame].toks; tok != NULL; tok = tok->next)
     cur_queue_.Push(tok);
-  }
 
   // Declare a local variable so the compiler can put it in a register, since
   // C++ assumes other threads could be modifying class members.
@@ -780,7 +774,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
   BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
   // "cost_offset" contains the acoustic log-likelihoods on current frame in
   // order to keep everything in a nice dynamic range. Reduce roundoff errors.
-  BaseFloat cost_offset = cost_offsets_[frame];
+  BaseFloat cost_offset = cost_offsets_[cur_frame];
 
   // Iterator the "cur_queue_" to process non-emittion and emittion arcs in fst.
   Token *tok = NULL;
@@ -810,8 +804,8 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
         BaseFloat graph_cost = arc.weight.Value();
         BaseFloat tot_cost = cur_cost + graph_cost;
         if (tot_cost < cur_cutoff) {
-          Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost,
-                                          tok, &prev_toks_, &changed);
+          Token *new_tok = FindOrAddToken(arc.nextstate, cur_frame, tot_cost,
+                                          tok, &cur_toks_, &changed);
 
           // Add ForwardLink from tok to new_tok. Put it on the head of
           // tok->link list
@@ -826,17 +820,19 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
       }
     } else {  // propagate emitting
       BaseFloat graph_cost = arc.weight.Value(),
-                ac_cost = cost_offset - decodable->LogLikelihood(frame, arc.ilabel),
+                ac_cost = cost_offset - decodable->LogLikelihood(cur_frame,
+                                                                 arc.ilabel),
                 cur_cost = tok->tot_cost,
                 tot_cost = cur_cost + ac_cost + graph_cost;
       if (tot_cost > next_cutoff) continue;
       else if (tot_cost + adaptive_beam < next_cutoff) {
-        next_cutoff = tot_cost + adaptive_beam; // a tighter boundary for emitting
+        next_cutoff = tot_cost + adaptive_beam;  // a tighter boundary for
+                                                 // emitting
       }
 
       // no change flag is needed
-      Token *next_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
-                                       tok, &cur_toks_, NULL);
+      Token *next_tok = FindOrAddToken(arc.nextstate, next_frame, tot_cost,
+                                       tok, &next_toks_, NULL);
 
       // Add ForwardLink from tok to next_tok. Put it on the head of tok->link
       // list
       tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel,
@@ -849,14 +845,16 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
   // Could just do cost_offsets_.push_back(cost_offset), but we
   // do it this way as it's more robust to future code changes.
   // Set the cost_offset_ for next frame, it equals "- best_cost_on_next_frame".
-  cost_offsets_.resize(frame + 2, 0.0);
-  cost_offsets_[frame + 1] = adaptive_beam - next_cutoff;
+  cost_offsets_.resize(cur_frame + 2, 0.0);
+  cost_offsets_[next_frame] = adaptive_beam - next_cutoff;
 
   {  // This block updates adaptive_beam_
     BaseFloat beam_used_this_frame = adaptive_beam;
     Token *tok = cur_queue_.Pop();
     if (tok != NULL) {
-      // The queue would only be nonempty if we hit the max-active constraint.
+      // We hit the max-active constraint, meaning we effectively pruned to a
+      // beam tighter than 'beam'. Work out what this was, it will be used to
+      // update 'adaptive_beam'.
       BaseFloat best_cost_this_frame = cur_cutoff - adaptive_beam;
       beam_used_this_frame = tok->tot_cost - best_cost_this_frame;
     }
@@ -882,12 +880,12 @@
 
 template <typename FST, typename Token>
 void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting() {
-  int32 frame = active_toks_.size() - 1;
+  int32 cur_frame = active_toks_.size() - 1;
+  StateIdToTokenMap &cur_toks = next_toks_;
 
   cur_queue_.Clear();
-  for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
+  for (Token* tok = active_toks_[cur_frame].toks; tok != NULL; tok = tok->next)
     cur_queue_.Push(tok);
-  }
 
   // Declare a local variable so the compiler can put it in a register, since
   // C++ assumes other threads could be modifying class members.
@@ -924,8 +922,8 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting() {
       BaseFloat graph_cost = arc.weight.Value();
       BaseFloat tot_cost = cur_cost + graph_cost;
       if (tot_cost < cur_cutoff) {
-        Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost,
-                                        tok, &cur_toks_, &changed);
+        Token *new_tok = FindOrAddToken(arc.nextstate, cur_frame, tot_cost,
+                                        tok, &cur_toks, &changed);
 
         // Add ForwardLink from tok to new_tok. Put it on the head of
         // tok->link list
@@ -943,7 +941,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting() {
   }  // end of while loop
   if (!decoding_finalized_) {
     // Update cost_offsets_, it equals "- best_cost".
-    cost_offsets_[frame] = adaptive_beam - cur_cutoff;
+    cost_offsets_[cur_frame] = adaptive_beam - cur_cutoff;
    // Needn't to update adaptive_beam_, since we still process this frame in
    // ProcessForFrame.
   }
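[Editor's sketch, not part of the patch series.] The bucket_offset_ line in the second Push() hunk above is a units fix. Assuming Push() maps a token's cost to a slot roughly as bucket_index = floor(tot_cost * cost_scale_) + bucket_offset_ (this is what the surrounding arithmetic implies; the exact expression is an assumption), both bucket_index and increase_size are already measured in buckets, so scaling the shift by cost_scale_ a second time would leave every later Push() aimed at the wrong slot. A worked example with invented numbers:

    // Hedged walk-through of the index arithmetic; all values invented.
    // cost_scale_ = 2.0, bucket_offset_ = 30 (15 * cost_scale_), margin = 10.
    //
    //   tot_cost = -16.0
    //   bucket_index  = floor(-16.0 * 2.0) + 30 = -2   // off the front
    //   increase_size = -(-2) + 10 = 12                // in bucket units
    //
    // After shifting the contents 12 slots to the right:
    //   bucket_offset_ += 12;          // correct: 30 + 12 = 42
    //   floor(-16.0 * 2.0) + 42 = 10   // same bucket, now a valid index
    //
    // The old code did bucket_offset_ += 12 * cost_scale_ (= 24), giving
    // offset 54 and mapping the same cost to bucket 22 instead of 10.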
diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.h b/src/decoder/lattice-faster-decoder-combine-bucketqueue.h
index 96096a0485f..3dab3818408 100644
--- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.h
+++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.h
@@ -377,9 +377,8 @@ class LatticeFasterDecoderCombineTpl {
   /// it will treat all final-probs as one.
   /// The raw lattice will be topologically sorted.
   /// The function can be called during decoding, it will process non-emitting
-  /// arcs from "cur_toks_" map to get tokens from both non-emitting and
-  /// emitting arcs for getting raw lattice. Then recover it to ensure the
-  /// consistency of ProcessForFrame().
+  /// arcs from "next_toks_" map to get tokens from both non-emitting and
+  /// emitting arcs for getting raw lattice.
   ///
   /// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
   /// which also supports a pruning beam, in case for some reason
@@ -529,7 +528,7 @@ class LatticeFasterDecoderCombineTpl {
   void PruneActiveTokens(BaseFloat delta);
 
   /// Processes non-emitting (epsilon) arcs and emitting arcs for one frame
-  /// together. It takes the emittion tokens in "prev_toks_" from last frame.
+  /// together. It takes the emission tokens in "cur_toks_" from last frame.
   /// Generates non-emitting tokens for previous frame and emitting tokens for
   /// next frame.
   /// Notice: The emitting tokens for the current frame means the token take
@@ -543,14 +542,14 @@ class LatticeFasterDecoderCombineTpl {
   /// called.
   void ProcessNonemitting();
 
-  /// The "prev_toks_" and "cur_toks_" actually allow us to maintain current
+  /// The "cur_toks_" and "next_toks_" actually allow us to maintain current
   /// and next frames. They are indexed by StateId. It is indexed by frame-index
   /// plus one, where the frame-index is zero-based, as used in decodable object.
   /// That is, the emitting probs of frame t are accounted for in tokens at
   /// toks_[t+1]. The zeroth frame is for nonemitting transition at the start of
   /// the graph.
-  StateIdToTokenMap prev_toks_;
   StateIdToTokenMap cur_toks_;
+  StateIdToTokenMap next_toks_;
 
   /// Gets the weight cutoff.
   /// Notice: In traiditional version, the histogram prunning method is applied
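[Editor's sketch, not part of the patch series.] The rename from prev_toks_/cur_toks_ to cur_toks_/next_toks_ is worth pinning down, because ProcessForFrame() touches two frames per call. A schematic of the bookkeeping, drawn from the hunks above (comments only; the real StateIdToTokenMap and FindOrAddToken signatures are as used in the diffs):

    // Per-frame map rotation at the top of ProcessForFrame():
    //   cur_toks_.swap(next_toks_);  // last call's "next" is this call's "cur"
    //   next_toks_.clear();          // to be filled by emitting arcs below
    //
    // Then, for each token popped from cur_queue_:
    //   non-emitting arc: FindOrAddToken(..., cur_frame,  ..., &cur_toks_,  ...)
    //   emitting arc:     FindOrAddToken(..., next_frame, ..., &next_toks_, ...)
    //
    // So epsilon closures stay on the current frame's map, while emitting arcs
    // seed the map that the next ProcessForFrame() call will swap in.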
diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc
index e3757e83019..3f2b0e8e5cb 100644
--- a/src/decoder/lattice-faster-decoder-combine.cc
+++ b/src/decoder/lattice-faster-decoder-combine.cc
@@ -48,7 +48,6 @@ void BucketQueue<Token>::Push(Token *tok) {
     // space frequently
     if (static_cast<int32>(bucket_index) > 0) {
       buckets_.resize(bucket_index + margin);
-      first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
     } else {  // less than 0
       int32 increase_size = - static_cast<int32>(bucket_index) + margin;
       buckets_.resize(buckets_.size() + increase_size);
@@ -56,11 +55,11 @@
       for (size_t i = buckets_.size() - 1; i >= increase_size; i--) {
         buckets_[i].swap(buckets_[i - increase_size]);
       }
-      bucket_offset_ = bucket_offset_ + increase_size * cost_scale_;
+      bucket_offset_ = bucket_offset_ + increase_size;
       bucket_index += increase_size;
       first_nonempty_bucket_index_ = bucket_index;
-      first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
     }
+    first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
   }
   tok->in_queue = true;
   buckets_[bucket_index].push_back(tok);
@@ -76,28 +75,22 @@ Token* BucketQueue<Token>::Pop() {
     if (!first_nonempty_bucket_->empty()) {
       Token *ans = first_nonempty_bucket_->back();
       first_nonempty_bucket_->pop_back();
-      if (ans->in_queue) {
+      if (ans->in_queue) {  // If ans->in_queue is false, this means it is a
+                            // duplicate instance of this Token that was left
+                            // over when a Token's best_cost changed, and the
+                            // Token has already been processed (so conceptually,
+                            // it is not in the queue).
         ans->in_queue = false;
         return ans;
       }
     }
     if (first_nonempty_bucket_->empty()) {
-      // In case, pop an empty BucketQueue
-      if (first_nonempty_bucket_index_ == buckets_.size() - 1) {
-        return NULL;
-      }
-
-      first_nonempty_bucket_index_++;
-      for (; first_nonempty_bucket_index_ < buckets_.size() - 1;
+      for (; first_nonempty_bucket_index_ + 1 < buckets_.size();
           first_nonempty_bucket_index_++) {
-        if (!buckets_[first_nonempty_bucket_index_].empty())
-          break;
+        if (!buckets_[first_nonempty_bucket_index_].empty()) break;
      }
       first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
-      if (first_nonempty_bucket_index_ == buckets_.size() - 1 &&
-          first_nonempty_bucket_->empty()) {
-        return NULL;
-      }
+      if (first_nonempty_bucket_->empty()) return NULL;
     }
   }
 }
@@ -119,8 +112,8 @@ LatticeFasterDecoderCombineTpl<FST, Token>::LatticeFasterDecoderCombineTpl(
     fst_(&fst), delete_fst_(false), config_(config), num_toks_(0),
     cur_queue_(config_.cost_scale) {
   config.Check();
-  prev_toks_.reserve(1000);
   cur_toks_.reserve(1000);
+  next_toks_.reserve(1000);
 }
 
@@ -130,8 +123,8 @@ LatticeFasterDecoderCombineTpl<FST, Token>::LatticeFasterDecoderCombineTpl(
     fst_(fst), delete_fst_(true), config_(config), num_toks_(0),
     cur_queue_(config_.cost_scale) {
   config.Check();
-  prev_toks_.reserve(1000);
   cur_toks_.reserve(1000);
+  next_toks_.reserve(1000);
 }
 
@@ -144,8 +137,8 @@ LatticeFasterDecoderCombineTpl<FST, Token>::~LatticeFasterDecoderCombineTpl() {
 template <typename FST, typename Token>
 void LatticeFasterDecoderCombineTpl<FST, Token>::InitDecoding() {
   // clean up from last time:
-  prev_toks_.clear();
   cur_toks_.clear();
+  next_toks_.clear();
   cost_offsets_.clear();
   ClearActiveTokens();
 
@@ -158,7 +151,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::InitDecoding() {
   active_toks_.resize(1);
   Token *start_tok = new Token(0.0, 0.0, start_state, NULL, NULL, NULL);
   active_toks_[0].toks = start_tok;
-  cur_toks_[start_state] = start_tok;  // initialize current tokens map
+  next_toks_[start_state] = start_tok;  // initialize current tokens map
   num_toks_++;
   adaptive_beam_ = config_.beam;
   cost_offsets_.resize(1);
@@ -747,25 +740,26 @@
 template <typename FST, typename Token>
 void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
     DecodableInterface *decodable) {
   KALDI_ASSERT(active_toks_.size() > 0);
-  int32 frame = active_toks_.size() - 1; // frame is the frame-index
-                                         // (zero-based) used to get likelihoods
-                                         // from the decodable object.
+  int32 cur_frame = active_toks_.size() - 1,  // frame is the frame-index (zero-
+                                              // based) used to get likelihoods
+                                              // from the decodable object.
+        next_frame = cur_frame + 1;
+
   active_toks_.resize(active_toks_.size() + 1);
 
-  prev_toks_.swap(cur_toks_);
-  cur_toks_.clear();
-  if (prev_toks_.empty()) {
+  cur_toks_.swap(next_toks_);
+  next_toks_.clear();
+  if (cur_toks_.empty()) {
     if (!warned_) {
-      KALDI_WARN << "Error, no surviving tokens on frame " << frame;
+      KALDI_WARN << "Error, no surviving tokens on frame " << cur_frame;
       warned_ = true;
     }
   }
 
   cur_queue_.Clear();
   // Add tokens to queue
-  for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
+  for (Token* tok = active_toks_[cur_frame].toks; tok != NULL; tok = tok->next)
     cur_queue_.Push(tok);
-  }
 
   // Declare a local variable so the compiler can put it in a register, since
   // C++ assumes other threads could be modifying class members.
@@ -780,7 +774,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
   BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
   // "cost_offset" contains the acoustic log-likelihoods on current frame in
   // order to keep everything in a nice dynamic range. Reduce roundoff errors.
-  BaseFloat cost_offset = cost_offsets_[frame];
+  BaseFloat cost_offset = cost_offsets_[cur_frame];
 
   // Iterator the "cur_queue_" to process non-emittion and emittion arcs in fst.
   Token *tok = NULL;
@@ -810,8 +804,8 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
         BaseFloat graph_cost = arc.weight.Value();
         BaseFloat tot_cost = cur_cost + graph_cost;
         if (tot_cost < cur_cutoff) {
-          Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost,
-                                          tok, &prev_toks_, &changed);
+          Token *new_tok = FindOrAddToken(arc.nextstate, cur_frame, tot_cost,
+                                          tok, &cur_toks_, &changed);
 
           // Add ForwardLink from tok to new_tok. Put it on the head of
           // tok->link list
@@ -826,17 +820,19 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
       }
     } else {  // propagate emitting
       BaseFloat graph_cost = arc.weight.Value(),
-                ac_cost = cost_offset - decodable->LogLikelihood(frame, arc.ilabel),
+                ac_cost = cost_offset - decodable->LogLikelihood(cur_frame,
+                                                                 arc.ilabel),
                 cur_cost = tok->tot_cost,
                 tot_cost = cur_cost + ac_cost + graph_cost;
       if (tot_cost > next_cutoff) continue;
       else if (tot_cost + adaptive_beam < next_cutoff) {
-        next_cutoff = tot_cost + adaptive_beam; // a tighter boundary for emitting
+        next_cutoff = tot_cost + adaptive_beam;  // a tighter boundary for
+                                                 // emitting
       }
 
       // no change flag is needed
-      Token *next_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
-                                       tok, &cur_toks_, NULL);
+      Token *next_tok = FindOrAddToken(arc.nextstate, next_frame, tot_cost,
+                                       tok, &next_toks_, NULL);
 
       // Add ForwardLink from tok to next_tok. Put it on the head of tok->link
       // list
       tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel,
@@ -849,14 +845,16 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
   // Could just do cost_offsets_.push_back(cost_offset), but we
   // do it this way as it's more robust to future code changes.
   // Set the cost_offset_ for next frame, it equals "- best_cost_on_next_frame".
-  cost_offsets_.resize(frame + 2, 0.0);
-  cost_offsets_[frame + 1] = adaptive_beam - next_cutoff;
+  cost_offsets_.resize(cur_frame + 2, 0.0);
+  cost_offsets_[next_frame] = adaptive_beam - next_cutoff;
 
   {  // This block updates adaptive_beam_
     BaseFloat beam_used_this_frame = adaptive_beam;
     Token *tok = cur_queue_.Pop();
     if (tok != NULL) {
-      // The queue would only be nonempty if we hit the max-active constraint.
+      // We hit the max-active constraint, meaning we effectively pruned to a
+      // beam tighter than 'beam'. Work out what this was, it will be used to
+      // update 'adaptive_beam'.
       BaseFloat best_cost_this_frame = cur_cutoff - adaptive_beam;
       beam_used_this_frame = tok->tot_cost - best_cost_this_frame;
     }
@@ -882,12 +880,12 @@
 
 template <typename FST, typename Token>
 void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting() {
-  int32 frame = active_toks_.size() - 1;
+  int32 cur_frame = active_toks_.size() - 1;
+  StateIdToTokenMap &cur_toks = next_toks_;
 
   cur_queue_.Clear();
-  for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
+  for (Token* tok = active_toks_[cur_frame].toks; tok != NULL; tok = tok->next)
     cur_queue_.Push(tok);
-  }
 
   // Declare a local variable so the compiler can put it in a register, since
   // C++ assumes other threads could be modifying class members.
@@ -924,8 +922,8 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting() {
       BaseFloat graph_cost = arc.weight.Value();
       BaseFloat tot_cost = cur_cost + graph_cost;
       if (tot_cost < cur_cutoff) {
-        Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost,
-                                        tok, &cur_toks_, &changed);
+        Token *new_tok = FindOrAddToken(arc.nextstate, cur_frame, tot_cost,
+                                        tok, &cur_toks, &changed);
 
         // Add ForwardLink from tok to new_tok. Put it on the head of
        // tok->link list
@@ -943,7 +941,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting() {
   }  // end of while loop
   if (!decoding_finalized_) {
     // Update cost_offsets_, it equals "- best_cost".
-    cost_offsets_[frame] = adaptive_beam - cur_cutoff;
+    cost_offsets_[cur_frame] = adaptive_beam - cur_cutoff;
    // Needn't to update adaptive_beam_, since we still process this frame in
    // ProcessForFrame.
   }
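[Editor's sketch, not part of the patch series.] The line cost_offsets_[next_frame] = adaptive_beam - next_cutoff deserves a gloss. By the end of the loop, next_cutoff equals best_emitting_cost + adaptive_beam, so the stored offset is exactly -best_emitting_cost, matching the comment "- best_cost_on_next_frame". A worked example with invented numbers:

    // Worked arithmetic (values invented for illustration):
    //   best emitting tot_cost on next frame = -3407.2
    //   adaptive_beam                         = 16.0
    //   next_cutoff = -3407.2 + 16.0          = -3391.2
    //   cost_offsets_[next_frame] = 16.0 - (-3391.2) = 3407.2  // == -best_cost
    //
    // On the next frame, ac_cost = cost_offset - LogLikelihood(...), so the
    // best path's accumulated cost stays near zero instead of drifting to
    // large magnitudes where BaseFloat would lose precision.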
diff --git a/src/decoder/lattice-faster-decoder-combine.h b/src/decoder/lattice-faster-decoder-combine.h
index 96096a0485f..3dab3818408 100644
--- a/src/decoder/lattice-faster-decoder-combine.h
+++ b/src/decoder/lattice-faster-decoder-combine.h
@@ -377,9 +377,8 @@ class LatticeFasterDecoderCombineTpl {
   /// it will treat all final-probs as one.
   /// The raw lattice will be topologically sorted.
   /// The function can be called during decoding, it will process non-emitting
-  /// arcs from "cur_toks_" map to get tokens from both non-emitting and
-  /// emitting arcs for getting raw lattice. Then recover it to ensure the
-  /// consistency of ProcessForFrame().
+  /// arcs from "next_toks_" map to get tokens from both non-emitting and
+  /// emitting arcs for getting raw lattice.
   ///
   /// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
   /// which also supports a pruning beam, in case for some reason
@@ -529,7 +528,7 @@ class LatticeFasterDecoderCombineTpl {
   void PruneActiveTokens(BaseFloat delta);
 
   /// Processes non-emitting (epsilon) arcs and emitting arcs for one frame
-  /// together. It takes the emittion tokens in "prev_toks_" from last frame.
+  /// together. It takes the emission tokens in "cur_toks_" from last frame.
   /// Generates non-emitting tokens for previous frame and emitting tokens for
   /// next frame.
   /// Notice: The emitting tokens for the current frame means the token take
@@ -543,14 +542,14 @@ class LatticeFasterDecoderCombineTpl {
   /// called.
   void ProcessNonemitting();
 
-  /// The "prev_toks_" and "cur_toks_" actually allow us to maintain current
+  /// The "cur_toks_" and "next_toks_" actually allow us to maintain current
   /// and next frames. They are indexed by StateId. It is indexed by frame-index
   /// plus one, where the frame-index is zero-based, as used in decodable object.
   /// That is, the emitting probs of frame t are accounted for in tokens at
   /// toks_[t+1]. The zeroth frame is for nonemitting transition at the start of
   /// the graph.
-  StateIdToTokenMap prev_toks_;
   StateIdToTokenMap cur_toks_;
+  StateIdToTokenMap next_toks_;
 
   /// Gets the weight cutoff.
   /// Notice: In traiditional version, the histogram prunning method is applied
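[Editor's sketch, not part of the patch series.] The in_queue comment added by this patch describes a lazy decrease-key: when a token's best cost improves, the decoder simply pushes the same Token* again into a better bucket, and whichever copy is popped first "wins"; stale copies left in worse buckets are skipped. A compilable toy version of the pattern (toy types; the real queue derives the bucket from the token's cost):

    #include <cstddef>
    #include <vector>

    struct Tok { double cost; bool in_queue; };

    // The same Tok* may sit in several buckets after re-pushes; the flag
    // ensures only the first (best-bucket) pop returns it.
    std::vector<std::vector<Tok*> > buckets(100);

    void Push(Tok *t, size_t bucket) {
      t->in_queue = true;             // mark the token live in the queue
      buckets[bucket].push_back(t);   // older copies may remain elsewhere
    }

    Tok *PopFrom(size_t bucket) {
      while (!buckets[bucket].empty()) {
        Tok *t = buckets[bucket].back();
        buckets[bucket].pop_back();
        if (t->in_queue) {            // stale duplicates have in_queue == false
          t->in_queue = false;
          return t;
        }
      }
      return NULL;
    }

This trades a little queue growth for never having to search for and remove an existing entry, which is why Pop() must tolerate popping dead entries.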
From 896c5c8994fb3353e0ef196576a724c7df49c228 Mon Sep 17 00:00:00 2001
From: Hang Lyu
Date: Wed, 3 Apr 2019 21:13:04 -0400
Subject: [PATCH 27/29] resize the BucketQueue when a weird long one was caused

---
 src/decoder/lattice-faster-decoder-combine-bucketqueue.cc | 5 ++++-
 src/decoder/lattice-faster-decoder-combine-bucketqueue.h  | 6 ++++++
 src/decoder/lattice-faster-decoder-combine.cc             | 5 ++++-
 src/decoder/lattice-faster-decoder-combine.h              | 6 ++++++
 4 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc
index 3f2b0e8e5cb..2c5b093571c 100644
--- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc
+++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc
@@ -32,11 +32,12 @@ BucketQueue<Token>::BucketQueue(BaseFloat cost_scale) :
   // NOTE: we reserve plenty of elements to avoid expensive reallocations
   // later on. Normally, the size is a little bigger than (adaptive_beam +
   // 15) * cost_scale.
-  int32 bucket_size = 100;
+  int32 bucket_size = (15 + 20) * cost_scale_;
   buckets_.resize(bucket_size);
   bucket_offset_ = 15 * cost_scale_;
   first_nonempty_bucket_index_ = bucket_size - 1;
   first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
+  bucket_size_tolerance_ = bucket_size;
 }
 
 template <typename Token>
@@ -100,6 +101,8 @@ void BucketQueue<Token>::Clear() {
   for (size_t i = first_nonempty_bucket_index_; i < buckets_.size(); i++) {
     buckets_[i].clear();
   }
+  if (buckets_.size() > bucket_size_tolerance_)
+    buckets_.resize(bucket_size_tolerance_);
   first_nonempty_bucket_index_ = buckets_.size() - 1;
   first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
 }
diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.h b/src/decoder/lattice-faster-decoder-combine-bucketqueue.h
index 3dab3818408..094e9765d73 100644
--- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.h
+++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.h
@@ -293,6 +293,12 @@ class BucketQueue {
 
   // Synchronizes with first_nonempty_bucket_index_.
   std::vector<Token*> *first_nonempty_bucket_;
+
+  // If the size of the BucketQueue is larger than "bucket_size_tolerance_", we
+  // will resize it to "bucket_size_tolerance_" in Clear. A weird long
+  // BucketQueue might be caused when the min-active was activated and an
+  // unusually large loglikelihood range was encountered.
+  size_t bucket_size_tolerance_;
 };
 
 /** This is the "normal" lattice-generating decoder.
diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc
index 3f2b0e8e5cb..2c5b093571c 100644
--- a/src/decoder/lattice-faster-decoder-combine.cc
+++ b/src/decoder/lattice-faster-decoder-combine.cc
@@ -32,11 +32,12 @@ BucketQueue<Token>::BucketQueue(BaseFloat cost_scale) :
   // NOTE: we reserve plenty of elements to avoid expensive reallocations
   // later on. Normally, the size is a little bigger than (adaptive_beam +
   // 15) * cost_scale.
-  int32 bucket_size = 100;
+  int32 bucket_size = (15 + 20) * cost_scale_;
   buckets_.resize(bucket_size);
   bucket_offset_ = 15 * cost_scale_;
   first_nonempty_bucket_index_ = bucket_size - 1;
   first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
+  bucket_size_tolerance_ = bucket_size;
 }
 
 template <typename Token>
@@ -100,6 +101,8 @@ void BucketQueue<Token>::Clear() {
   for (size_t i = first_nonempty_bucket_index_; i < buckets_.size(); i++) {
     buckets_[i].clear();
   }
+  if (buckets_.size() > bucket_size_tolerance_)
+    buckets_.resize(bucket_size_tolerance_);
   first_nonempty_bucket_index_ = buckets_.size() - 1;
   first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
 }
diff --git a/src/decoder/lattice-faster-decoder-combine.h b/src/decoder/lattice-faster-decoder-combine.h
index 3dab3818408..094e9765d73 100644
--- a/src/decoder/lattice-faster-decoder-combine.h
+++ b/src/decoder/lattice-faster-decoder-combine.h
@@ -293,6 +293,12 @@ class BucketQueue {
 
   // Synchronizes with first_nonempty_bucket_index_.
   std::vector<Token*> *first_nonempty_bucket_;
+
+  // If the size of the BucketQueue is larger than "bucket_size_tolerance_", we
+  // will resize it to "bucket_size_tolerance_" in Clear. A weird long
+  // BucketQueue might be caused when the min-active was activated and an
+  // unusually large loglikelihood range was encountered.
+  size_t bucket_size_tolerance_;
 };
 
 /** This is the "normal" lattice-generating decoder.
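[Editor's sketch, not part of the patch series.] For sizing intuition: the queue only ever needs to span the active cost range, roughly the adaptive beam (default 16) plus the 15-unit headroom encoded in bucket_offset_, all scaled by cost_scale_; hence (15 + 20) * cost_scale_ with a little slack instead of the old magic constant 100. A hedged check of the arithmetic under assumed defaults (and assuming the index mapping floor(cost * cost_scale_) + bucket_offset_ discussed earlier):

    // With config.beam = 16.0 and cost_scale = 1.0 assumed:
    //   bucket_size    = (15 + 20) * 1.0 = 35 buckets
    //   bucket_offset_ = 15 * 1.0        = 15
    // Then, by the assumed mapping:
    //   cost -15.0 -> bucket 0,  cost 0.0 -> bucket 15,  cost 19.9 -> bucket 34
    // i.e. the initial allocation covers costs in [-15, +20) without resizing,
    // comfortably wider than the default beam of 16.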
From 6e2d27ae807b33c7723c583ea3f291d48892a1bd Mon Sep 17 00:00:00 2001
From: Hang Lyu
Date: Wed, 3 Apr 2019 23:38:09 -0400
Subject: [PATCH 28/29] small fix

---
 src/decoder/lattice-faster-decoder-combine-bucketqueue.cc | 4 +++-
 src/decoder/lattice-faster-decoder-combine.cc             | 4 +++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc
index 2c5b093571c..968ff2a2ce5 100644
--- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc
+++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc
@@ -101,8 +101,10 @@ void BucketQueue<Token>::Clear() {
   for (size_t i = first_nonempty_bucket_index_; i < buckets_.size(); i++) {
     buckets_[i].clear();
   }
-  if (buckets_.size() > bucket_size_tolerance_)
+  if (buckets_.size() > bucket_size_tolerance_) {
     buckets_.resize(bucket_size_tolerance_);
+    bucket_offset_ = 15 * cost_scale_;
+  }
   first_nonempty_bucket_index_ = buckets_.size() - 1;
   first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
 }
diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc
index 2c5b093571c..968ff2a2ce5 100644
--- a/src/decoder/lattice-faster-decoder-combine.cc
+++ b/src/decoder/lattice-faster-decoder-combine.cc
@@ -101,8 +101,10 @@ void BucketQueue<Token>::Clear() {
   for (size_t i = first_nonempty_bucket_index_; i < buckets_.size(); i++) {
     buckets_[i].clear();
   }
-  if (buckets_.size() > bucket_size_tolerance_)
+  if (buckets_.size() > bucket_size_tolerance_) {
     buckets_.resize(bucket_size_tolerance_);
+    bucket_offset_ = 15 * cost_scale_;
+  }
   first_nonempty_bucket_index_ = buckets_.size() - 1;
   first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
 }

From 7dd2ca2f6f31cbff17be9f188bd4601ecb179b0b Mon Sep 17 00:00:00 2001
From: Hang Lyu
Date: Sat, 6 Apr 2019 17:28:00 -0400
Subject: [PATCH 29/29] 1.2 tolerance

---
 src/decoder/lattice-faster-decoder-combine-bucketqueue.cc | 2 +-
 src/decoder/lattice-faster-decoder-combine.cc             | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc
index 968ff2a2ce5..f30fc36b872 100644
--- a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc
+++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc
@@ -37,7 +37,7 @@ BucketQueue<Token>::BucketQueue(BaseFloat cost_scale) :
   bucket_offset_ = 15 * cost_scale_;
   first_nonempty_bucket_index_ = bucket_size - 1;
   first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
-  bucket_size_tolerance_ = bucket_size;
+  bucket_size_tolerance_ = 1.2 * bucket_size;
 }
 
 template <typename Token>
diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc
index 968ff2a2ce5..f30fc36b872 100644
--- a/src/decoder/lattice-faster-decoder-combine.cc
+++ b/src/decoder/lattice-faster-decoder-combine.cc
@@ -37,7 +37,7 @@ BucketQueue<Token>::BucketQueue(BaseFloat cost_scale) :
   bucket_offset_ = 15 * cost_scale_;
   first_nonempty_bucket_index_ = bucket_size - 1;
   first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
-  bucket_size_tolerance_ = bucket_size;
+  bucket_size_tolerance_ = 1.2 * bucket_size;
 }
 
 template <typename Token>
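[Editor's sketch, not part of the patch series.] Taken together, the last three patches pin down the queue's steady-state geometry. PATCH 29 lets the queue drift to 1.2 times its nominal size before Clear() trims it, so a mildly long frame does not trigger a resize every frame. The bucket_offset_ reset in PATCH 28 matters because a front-extension in Push() permanently raises the offset; once the buckets are empty and trimmed, keeping the inflated offset would map ordinary costs past the end of the shortened vector and force an immediate re-grow. A walk-through with invented numbers (again assuming the floor(cost * cost_scale_) + bucket_offset_ mapping):

    // Invented walk-through of the Clear() path after a pathological frame,
    // with cost_scale_ = 1.0:
    //   ctor:    bucket_size = 35, bucket_offset_ = 15, tolerance = 1.2 * 35 = 42
    //   frame N: min-active forces a front growth of 60 buckets
    //            -> buckets_.size() = 95, bucket_offset_ = 75
    //   Clear(): 95 > 42, so resize(42) and reset bucket_offset_ = 15.
    //            Without the reset, a typical cost near 0 would map to
    //            bucket 75, beyond the trimmed size of 42, and the very next
    //            Push() would have to grow the vector all over again.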