Skip to content

Commit

Permalink
sparse: use staged cursor and upper_bound for WAND
Browse files Browse the repository at this point in the history
Signed-off-by: Shawn Wang <[email protected]>
  • Loading branch information
sparknack committed Dec 3, 2024
1 parent 95bce9a commit 9ddd7a7
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions src/index/sparse/sparse_inverted_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -654,16 +654,16 @@ class InvertedIndex : public BaseInvertedIndex<T> {
std::sort(cursors.begin(), cursors.end(), [](auto& x, auto& y) { return x->cur_vec_id_ < y->cur_vec_id_; });
};
sort_cursors();
float upper_bound = 0;
size_t pivot_cursor_id = 0;
while (true) {
float threshold = heap.full() ? heap.top().val : 0;
float upper_bound = 0;
size_t pivot;
bool found_pivot = false;
for (pivot = 0; pivot < valid_q_dim; ++pivot) {
if (cursors[pivot]->loc_ >= cursors[pivot]->lut_size_) {
for (; pivot_cursor_id < valid_q_dim; ++pivot_cursor_id) {
if (cursors[pivot_cursor_id]->loc_ >= cursors[pivot_cursor_id]->lut_size_) {
break;
}
upper_bound += cursors[pivot]->max_score_;
upper_bound += cursors[pivot_cursor_id]->max_score_;
if (upper_bound > threshold) {
found_pivot = true;
break;
Expand All @@ -672,24 +672,33 @@ class InvertedIndex : public BaseInvertedIndex<T> {
if (!found_pivot) {
break;
}
table_t pivot_id = cursors[pivot]->cur_vec_id_;
if (pivot_id == cursors[0]->cur_vec_id_) {
float score = 0;
table_t pivot_vec_id = cursors[pivot_cursor_id]->cur_vec_id_;
if (pivot_vec_id == cursors[0]->cur_vec_id_) {
float score_sum = 0;
for (auto& cursor : cursors) {
if (cursor->cur_vec_id_ != pivot_id) {
if (cursor->cur_vec_id_ != pivot_vec_id) {
break;
}
T cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursor->cur_vec_id_) : 0;
score += cursor->q_value_ * computer(cursor->cur_vec_val(), cur_vec_sum);
score_sum += cursor->q_value_ * computer(cursor->cur_vec_val(), cur_vec_sum);
cursor->next();
}
heap.push(pivot_id, score);
heap.push(pivot_vec_id, score_sum);
sort_cursors();
pivot_cursor_id = 0;
upper_bound = 0;
} else {
size_t next_list = pivot;
for (; cursors[next_list]->cur_vec_id_ == pivot_id; --next_list) {
size_t next_list = pivot_cursor_id;
for (; cursors[next_list]->cur_vec_id_ == pivot_vec_id; --next_list) {
}
cursors[next_list]->seek(pivot_vec_id);
if (cursors[next_list]->cur_vec_id_ > pivot_vec_id) {
upper_bound -= cursors[next_list]->max_score_;
} else {
// the max_score_ of the pivot will be added again in the next loop,
// so it needs to be subtracted here.
upper_bound -= cursors[pivot_cursor_id]->max_score_;
}
cursors[next_list]->seek(pivot_id);
for (size_t i = next_list + 1; i < valid_q_dim; ++i) {
if (cursors[i]->cur_vec_id_ >= cursors[i - 1]->cur_vec_id_) {
break;
Expand Down

0 comments on commit 9ddd7a7

Please sign in to comment.