diff --git a/src/index/sparse/sparse_inverted_index.h b/src/index/sparse/sparse_inverted_index.h index 06665894a..cb891c89e 100644 --- a/src/index/sparse/sparse_inverted_index.h +++ b/src/index/sparse/sparse_inverted_index.h @@ -654,16 +654,16 @@ class InvertedIndex : public BaseInvertedIndex { std::sort(cursors.begin(), cursors.end(), [](auto& x, auto& y) { return x->cur_vec_id_ < y->cur_vec_id_; }); }; sort_cursors(); + float upper_bound = 0; + size_t pivot_cursor_id = 0; while (true) { float threshold = heap.full() ? heap.top().val : 0; - float upper_bound = 0; - size_t pivot; bool found_pivot = false; - for (pivot = 0; pivot < valid_q_dim; ++pivot) { - if (cursors[pivot]->loc_ >= cursors[pivot]->lut_size_) { + for (; pivot_cursor_id < valid_q_dim; ++pivot_cursor_id) { + if (cursors[pivot_cursor_id]->loc_ >= cursors[pivot_cursor_id]->lut_size_) { break; } - upper_bound += cursors[pivot]->max_score_; + upper_bound += cursors[pivot_cursor_id]->max_score_; if (upper_bound > threshold) { found_pivot = true; break; @@ -672,24 +672,35 @@ class InvertedIndex : public BaseInvertedIndex { if (!found_pivot) { break; } - table_t pivot_id = cursors[pivot]->cur_vec_id_; - if (pivot_id == cursors[0]->cur_vec_id_) { - float score = 0; + table_t pivot_vec_id = cursors[pivot_cursor_id]->cur_vec_id_; + if (pivot_vec_id == cursors[0]->cur_vec_id_) { + float score_sum = 0; for (auto& cursor : cursors) { - if (cursor->cur_vec_id_ != pivot_id) { + if (cursor->cur_vec_id_ != pivot_vec_id) { break; } T cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursor->cur_vec_id_) : 0; - score += cursor->q_value_ * computer(cursor->cur_vec_val(), cur_vec_sum); + score_sum += cursor->q_value_ * computer(cursor->cur_vec_val(), cur_vec_sum); cursor->next(); } - heap.push(pivot_id, score); + heap.push(pivot_vec_id, score_sum); sort_cursors(); + pivot_cursor_id = 0; + upper_bound = 0; } else { - size_t next_list = pivot; - for (; cursors[next_list]->cur_vec_id_ == pivot_id; --next_list) { + size_t next_list = pivot_cursor_id; + for (; cursors[next_list]->cur_vec_id_ == pivot_vec_id; --next_list) { + } + cursors[next_list]->seek(pivot_vec_id); + if (cursors[next_list]->cur_vec_id_ > pivot_vec_id) { + upper_bound -= cursors[next_list]->max_score_; + upper_bound -= cursors[pivot_cursor_id]->max_score_; + --pivot_cursor_id; + } else { + // the max_score_ of the pivot will be added again in the next loop, + // so it needs to be subtracted here. + upper_bound -= cursors[pivot_cursor_id]->max_score_; } - cursors[next_list]->seek(pivot_id); for (size_t i = next_list + 1; i < valid_q_dim; ++i) { if (cursors[i]->cur_vec_id_ >= cursors[i - 1]->cur_vec_id_) { break;