diff --git a/src/decoder/lattice-simple-decoder.cc b/src/decoder/lattice-simple-decoder.cc index cc8712e854d..87378f93bbd 100644 --- a/src/decoder/lattice-simple-decoder.cc +++ b/src/decoder/lattice-simple-decoder.cc @@ -45,8 +45,8 @@ void LatticeSimpleDecoder::InitDecoding() { bool LatticeSimpleDecoder::Decode(DecodableInterface *decodable) { InitDecoding(); - - while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) { + + while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) { if (NumFramesDecoded() % config_.prune_interval == 0) PruneActiveTokens(config_.lattice_beam * config_.prune_scale); ProcessEmitting(decodable); @@ -57,7 +57,7 @@ bool LatticeSimpleDecoder::Decode(DecodableInterface *decodable) { ProcessNonemitting(); } FinalizeDecoding(); - + // Returns true if we have any kind of traceback available (not necessarily // to the end state; query ReachedFinal() for that). return !final_costs_.empty(); @@ -88,9 +88,9 @@ bool LatticeSimpleDecoder::GetRawLattice(Lattice *ofst, if (decoding_finalized_ && !use_final_probs) KALDI_ERR << "You cannot call FinalizeDecoding() and then call " << "GetRawLattice() with use_final_probs == false"; - + unordered_map final_costs_local; - + const unordered_map &final_costs = (decoding_finalized_ ? final_costs_ : final_costs_local); @@ -100,7 +100,7 @@ bool LatticeSimpleDecoder::GetRawLattice(Lattice *ofst, ofst->DeleteStates(); int32 num_frames = NumFramesDecoded(); KALDI_ASSERT(num_frames > 0); - const int32 bucket_count = num_toks_/2 + 3; + const int32 bucket_count = num_toks_/2 + 3; unordered_map tok_map(bucket_count); // First create all states. for (int32 f = 0; f <= num_frames; f++) { @@ -169,10 +169,10 @@ bool LatticeSimpleDecoder::GetLattice( fst::ILabelCompare ilabel_comp; ArcSort(&raw_fst, ilabel_comp); // sort on ilabel; makes // lattice-determinization more efficient. - + fst::DeterminizeLatticePrunedOptions lat_opts; lat_opts.max_mem = config_.det_opts.max_mem; - + DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts); raw_fst.DeleteStates(); // Free memory-- raw_fst no longer needed. Connect(ofst); // Remove unreachable states... there might be @@ -196,7 +196,7 @@ inline LatticeSimpleDecoder::Token *LatticeSimpleDecoder::FindOrAddToken( bool emitting, bool *changed) { KALDI_ASSERT(frame < active_toks_.size()); Token *&toks = active_toks_[frame].toks; - + unordered_map::iterator find_iter = cur_toks_.find(state); if (find_iter == cur_toks_.end()) { // no such token presently. // Create one. @@ -221,7 +221,7 @@ inline LatticeSimpleDecoder::Token *LatticeSimpleDecoder::FindOrAddToken( return tok; } } - + // delta is the amount by which the extra_costs must // change before it sets "extra_costs_changed" to true. If delta is larger, // we'll tend to go back less far toward the beginning of the file. @@ -242,7 +242,7 @@ void LatticeSimpleDecoder::PruneForwardLinks( warned_ = true; } } - + bool changed = true; while (changed) { changed = false; @@ -300,7 +300,7 @@ void LatticeSimpleDecoder::ComputeFinalCosts( BaseFloat infinity = std::numeric_limits::infinity(); BaseFloat best_cost = infinity, best_cost_with_final = infinity; - + for (unordered_map::const_iterator iter = cur_toks_.begin(); iter != cur_toks_.end(); ++iter) { StateId state = iter->first; @@ -336,19 +336,19 @@ void LatticeSimpleDecoder::ComputeFinalCosts( // on the final frame. If there are final tokens active, it uses the final-probs // for pruning, otherwise it treats all tokens as final. void LatticeSimpleDecoder::PruneForwardLinksFinal() { - KALDI_ASSERT(!active_toks_.empty()); + KALDI_ASSERT(!active_toks_.empty()); int32 frame_plus_one = active_toks_.size() - 1; if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen. KALDI_WARN << "No tokens alive at end of file\n"; - typedef unordered_map::const_iterator IterType; + typedef unordered_map::const_iterator IterType; ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_); decoding_finalized_ = true; // We're about to delete some of the tokens active on the final frame, so we // clear cur_toks_ because otherwise it would then contain dangling pointers. cur_toks_.clear(); - + // Now go through tokens on this frame, pruning forward links... may have to // iterate a few times until there is no more change, because the list is not // in topological order. This is a modified version of the code in @@ -429,7 +429,7 @@ BaseFloat LatticeSimpleDecoder::FinalRelativeCost() const { return final_relative_cost_; } } - + // Prune away any tokens on this frame that have no forward links. [we don't do // this in PruneForwardLinks because it would give us a problem with dangling // pointers]. @@ -453,14 +453,14 @@ void LatticeSimpleDecoder::PruneTokensForFrame(int32 frame) { } } } - + // Go backwards through still-alive tokens, pruning them, starting not from // the current frame (where we want to keep all tokens) but from the frame before // that. We go backwards through the frames and stop when we reach a point // where the delta-costs are not changing (and the delta controls when we consider // a cost to have "not changed"). void LatticeSimpleDecoder::PruneActiveTokens(BaseFloat delta) { - int32 cur_frame_plus_one = NumFramesDecoded(); + int32 cur_frame_plus_one = NumFramesDecoded(); int32 num_toks_begin = num_toks_; // The index "f" below represents a "frame plus one", i.e. you'd have to subtract // one to get the corresponding index for the decodable object. @@ -468,7 +468,7 @@ void LatticeSimpleDecoder::PruneActiveTokens(BaseFloat delta) { // Reason why we need to prune forward links in this situation: // (1) we have never pruned them // (2) we never pruned the forward links on the next frame, which - // + // if (active_toks_[f].must_prune_forward_links) { bool extra_costs_changed = false, links_pruned = false; PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta); @@ -478,7 +478,7 @@ void LatticeSimpleDecoder::PruneActiveTokens(BaseFloat delta) { active_toks_[f].must_prune_tokens = true; active_toks_[f].must_prune_forward_links = false; } - if (f+1 < cur_frame_plus_one && + if (f+1 < cur_frame_plus_one && active_toks_[f+1].must_prune_tokens) { PruneTokensForFrame(f+1); active_toks_[f+1].must_prune_tokens = false; @@ -493,20 +493,20 @@ void LatticeSimpleDecoder::PruneActiveTokens(BaseFloat delta) { // (optionally) on the final frame. Takes into account the final-prob of // tokens. This function used to be called PruneActiveTokensFinal(). void LatticeSimpleDecoder::FinalizeDecoding() { - int32 final_frame_plus_one = NumFramesDecoded(); + int32 final_frame_plus_one = NumFramesDecoded(); int32 num_toks_begin = num_toks_; PruneForwardLinksFinal(); - for (int32 f = final_frame_plus_one - 1; f >= 0; f--) { + for (int32 f = final_frame_plus_one - 1; f >= 0; f--) { bool b1, b2; // values not used. BaseFloat dontcare = 0.0; PruneForwardLinks(f, &b1, &b2, dontcare); PruneTokensForFrame(f + 1); } - PruneTokensForFrame(0); + PruneTokensForFrame(0); KALDI_VLOG(3) << "pruned tokens from " << num_toks_begin << " to " << num_toks_; } - + void LatticeSimpleDecoder::ProcessEmitting(DecodableInterface *decodable) { int32 frame = active_toks_.size() - 1; // frame is the frame-index // (zero-based) used to get likelihoods @@ -538,9 +538,9 @@ void LatticeSimpleDecoder::ProcessEmitting(DecodableInterface *decodable) { // AddToken adds the next_tok to cur_toks_ (if not already present). Token *next_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, true, NULL); - + // Add ForwardLink from tok to next_tok (put on head of list tok->links) - tok->links = new ForwardLink(next_tok, arc.ilabel, arc.olabel, + tok->links = new ForwardLink(next_tok, arc.ilabel, arc.olabel, graph_cost, ac_cost, tok->links); } } @@ -553,7 +553,7 @@ void LatticeSimpleDecoder::ProcessNonemitting() { // Note: "frame" is the time-index we just processed, or -1 if // we are processing the nonemitting transitions before the // first frame (called from InitDecoding()). - + // Processes nonemitting arcs for one frame. Propagates within // cur_toks_. Note-- this queue structure is is not very optimal as // it may cause us to process states unnecessarily (e.g. more than once), @@ -569,15 +569,9 @@ void LatticeSimpleDecoder::ProcessNonemitting() { queue.push_back(state); best_cost = std::min(best_cost, iter->second->tot_cost); } - if (queue.empty()) { - if (!warned_) { - KALDI_ERR << "Error in ProcessEmitting: no surviving tokens: frame is " - << frame; - warned_ = true; - } - } + BaseFloat cutoff = best_cost + config_.beam; - + while (!queue.empty()) { StateId state = queue.back(); queue.pop_back(); @@ -600,10 +594,10 @@ void LatticeSimpleDecoder::ProcessNonemitting() { bool changed; Token *new_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, false, &changed); - + tok->links = new ForwardLink(new_tok, 0, arc.olabel, graph_cost, 0, tok->links); - + // "changed" tells us whether the new token has a different // cost from before, or is new [if so, add into queue]. if (changed && fst_.NumInputEpsilons(arc.nextstate) != 0) @@ -662,5 +656,3 @@ void LatticeSimpleDecoder::PruneCurrentTokens(BaseFloat beam, unordered_map