Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #4870, spurious error in ProcessNonemitting; queue can validly be empty. #4885

Merged
merged 1 commit into from
Nov 10, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 31 additions & 39 deletions src/decoder/lattice-simple-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ void LatticeSimpleDecoder::InitDecoding() {

bool LatticeSimpleDecoder::Decode(DecodableInterface *decodable) {
InitDecoding();
while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) {

while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) {
if (NumFramesDecoded() % config_.prune_interval == 0)
PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
ProcessEmitting(decodable);
Expand All @@ -57,7 +57,7 @@ bool LatticeSimpleDecoder::Decode(DecodableInterface *decodable) {
ProcessNonemitting();
}
FinalizeDecoding();

// Returns true if we have any kind of traceback available (not necessarily
// to the end state; query ReachedFinal() for that).
return !final_costs_.empty();
Expand Down Expand Up @@ -88,9 +88,9 @@ bool LatticeSimpleDecoder::GetRawLattice(Lattice *ofst,
if (decoding_finalized_ && !use_final_probs)
KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
<< "GetRawLattice() with use_final_probs == false";

unordered_map<Token*, BaseFloat> final_costs_local;

const unordered_map<Token*, BaseFloat> &final_costs =
(decoding_finalized_ ? final_costs_ : final_costs_local);

Expand All @@ -100,7 +100,7 @@ bool LatticeSimpleDecoder::GetRawLattice(Lattice *ofst,
ofst->DeleteStates();
int32 num_frames = NumFramesDecoded();
KALDI_ASSERT(num_frames > 0);
const int32 bucket_count = num_toks_/2 + 3;
const int32 bucket_count = num_toks_/2 + 3;
unordered_map<Token*, StateId> tok_map(bucket_count);
// First create all states.
for (int32 f = 0; f <= num_frames; f++) {
Expand Down Expand Up @@ -169,10 +169,10 @@ bool LatticeSimpleDecoder::GetLattice(
fst::ILabelCompare<LatticeArc> ilabel_comp;
ArcSort(&raw_fst, ilabel_comp); // sort on ilabel; makes
// lattice-determinization more efficient.

fst::DeterminizeLatticePrunedOptions lat_opts;
lat_opts.max_mem = config_.det_opts.max_mem;

DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts);
raw_fst.DeleteStates(); // Free memory-- raw_fst no longer needed.
Connect(ofst); // Remove unreachable states... there might be
Expand All @@ -196,7 +196,7 @@ inline LatticeSimpleDecoder::Token *LatticeSimpleDecoder::FindOrAddToken(
bool emitting, bool *changed) {
KALDI_ASSERT(frame < active_toks_.size());
Token *&toks = active_toks_[frame].toks;

unordered_map<StateId, Token*>::iterator find_iter = cur_toks_.find(state);
if (find_iter == cur_toks_.end()) { // no such token presently.
// Create one.
Expand All @@ -221,7 +221,7 @@ inline LatticeSimpleDecoder::Token *LatticeSimpleDecoder::FindOrAddToken(
return tok;
}
}

// delta is the amount by which the extra_costs must
// change before it sets "extra_costs_changed" to true. If delta is larger,
// we'll tend to go back less far toward the beginning of the file.
Expand All @@ -242,7 +242,7 @@ void LatticeSimpleDecoder::PruneForwardLinks(
warned_ = true;
}
}

bool changed = true;
while (changed) {
changed = false;
Expand Down Expand Up @@ -300,7 +300,7 @@ void LatticeSimpleDecoder::ComputeFinalCosts(
BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity();
BaseFloat best_cost = infinity,
best_cost_with_final = infinity;

for (unordered_map<StateId, Token*>::const_iterator iter = cur_toks_.begin();
iter != cur_toks_.end(); ++iter) {
StateId state = iter->first;
Expand Down Expand Up @@ -336,19 +336,19 @@ void LatticeSimpleDecoder::ComputeFinalCosts(
// on the final frame. If there are final tokens active, it uses the final-probs
// for pruning, otherwise it treats all tokens as final.
void LatticeSimpleDecoder::PruneForwardLinksFinal() {
KALDI_ASSERT(!active_toks_.empty());
KALDI_ASSERT(!active_toks_.empty());
int32 frame_plus_one = active_toks_.size() - 1;

if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen.
KALDI_WARN << "No tokens alive at end of file\n";

typedef unordered_map<Token*, BaseFloat>::const_iterator IterType;
typedef unordered_map<Token*, BaseFloat>::const_iterator IterType;
ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_);
decoding_finalized_ = true;
// We're about to delete some of the tokens active on the final frame, so we
// clear cur_toks_ because otherwise it would then contain dangling pointers.
cur_toks_.clear();

// Now go through tokens on this frame, pruning forward links... may have to
// iterate a few times until there is no more change, because the list is not
// in topological order. This is a modified version of the code in
Expand Down Expand Up @@ -429,7 +429,7 @@ BaseFloat LatticeSimpleDecoder::FinalRelativeCost() const {
return final_relative_cost_;
}
}

// Prune away any tokens on this frame that have no forward links. [we don't do
// this in PruneForwardLinks because it would give us a problem with dangling
// pointers].
Expand All @@ -453,22 +453,22 @@ void LatticeSimpleDecoder::PruneTokensForFrame(int32 frame) {
}
}
}

// Go backwards through still-alive tokens, pruning them, starting not from
// the current frame (where we want to keep all tokens) but from the frame before
// that. We go backwards through the frames and stop when we reach a point
// where the delta-costs are not changing (and the delta controls when we consider
// a cost to have "not changed").
void LatticeSimpleDecoder::PruneActiveTokens(BaseFloat delta) {
int32 cur_frame_plus_one = NumFramesDecoded();
int32 cur_frame_plus_one = NumFramesDecoded();
int32 num_toks_begin = num_toks_;
// The index "f" below represents a "frame plus one", i.e. you'd have to subtract
// one to get the corresponding index for the decodable object.
for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) {
// Reason why we need to prune forward links in this situation:
// (1) we have never pruned them
// (2) we never pruned the forward links on the next frame, which
//
//
if (active_toks_[f].must_prune_forward_links) {
bool extra_costs_changed = false, links_pruned = false;
PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta);
Expand All @@ -478,7 +478,7 @@ void LatticeSimpleDecoder::PruneActiveTokens(BaseFloat delta) {
active_toks_[f].must_prune_tokens = true;
active_toks_[f].must_prune_forward_links = false;
}
if (f+1 < cur_frame_plus_one &&
if (f+1 < cur_frame_plus_one &&
active_toks_[f+1].must_prune_tokens) {
PruneTokensForFrame(f+1);
active_toks_[f+1].must_prune_tokens = false;
Expand All @@ -493,20 +493,20 @@ void LatticeSimpleDecoder::PruneActiveTokens(BaseFloat delta) {
// (optionally) on the final frame. Takes into account the final-prob of
// tokens. This function used to be called PruneActiveTokensFinal().
void LatticeSimpleDecoder::FinalizeDecoding() {
int32 final_frame_plus_one = NumFramesDecoded();
int32 final_frame_plus_one = NumFramesDecoded();
int32 num_toks_begin = num_toks_;
PruneForwardLinksFinal();
for (int32 f = final_frame_plus_one - 1; f >= 0; f--) {
for (int32 f = final_frame_plus_one - 1; f >= 0; f--) {
bool b1, b2; // values not used.
BaseFloat dontcare = 0.0;
PruneForwardLinks(f, &b1, &b2, dontcare);
PruneTokensForFrame(f + 1);
}
PruneTokensForFrame(0);
PruneTokensForFrame(0);
KALDI_VLOG(3) << "pruned tokens from " << num_toks_begin
<< " to " << num_toks_;
}

void LatticeSimpleDecoder::ProcessEmitting(DecodableInterface *decodable) {
int32 frame = active_toks_.size() - 1; // frame is the frame-index
// (zero-based) used to get likelihoods
Expand Down Expand Up @@ -538,9 +538,9 @@ void LatticeSimpleDecoder::ProcessEmitting(DecodableInterface *decodable) {
// AddToken adds the next_tok to cur_toks_ (if not already present).
Token *next_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
true, NULL);

// Add ForwardLink from tok to next_tok (put on head of list tok->links)
tok->links = new ForwardLink(next_tok, arc.ilabel, arc.olabel,
tok->links = new ForwardLink(next_tok, arc.ilabel, arc.olabel,
graph_cost, ac_cost, tok->links);
}
}
Expand All @@ -553,7 +553,7 @@ void LatticeSimpleDecoder::ProcessNonemitting() {
// Note: "frame" is the time-index we just processed, or -1 if
// we are processing the nonemitting transitions before the
// first frame (called from InitDecoding()).

// Processes nonemitting arcs for one frame. Propagates within
// cur_toks_. Note-- this queue structure is is not very optimal as
// it may cause us to process states unnecessarily (e.g. more than once),
Expand All @@ -569,15 +569,9 @@ void LatticeSimpleDecoder::ProcessNonemitting() {
queue.push_back(state);
best_cost = std::min(best_cost, iter->second->tot_cost);
}
if (queue.empty()) {
if (!warned_) {
KALDI_ERR << "Error in ProcessEmitting: no surviving tokens: frame is "
<< frame;
warned_ = true;
}
}

BaseFloat cutoff = best_cost + config_.beam;

while (!queue.empty()) {
StateId state = queue.back();
queue.pop_back();
Expand All @@ -600,10 +594,10 @@ void LatticeSimpleDecoder::ProcessNonemitting() {
bool changed;
Token *new_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
false, &changed);

tok->links = new ForwardLink(new_tok, 0, arc.olabel,
graph_cost, 0, tok->links);

// "changed" tells us whether the new token has a different
// cost from before, or is new [if so, add into queue].
if (changed && fst_.NumInputEpsilons(arc.nextstate) != 0)
Expand Down Expand Up @@ -662,5 +656,3 @@ void LatticeSimpleDecoder::PruneCurrentTokens(BaseFloat beam, unordered_map<Stat


} // end namespace kaldi.