Skip to content

Commit

Permalink
Implement second phase rank drop limit for hit collector.
Browse files Browse the repository at this point in the history
  • Loading branch information
toregge committed May 31, 2024
1 parent a479163 commit bc8cf5d
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 20 deletions.
73 changes: 73 additions & 0 deletions searchlib/src/tests/hitcollector/hitcollector_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ using namespace search::fef;
using namespace search::queryeval;

using ScoreMap = std::map<uint32_t, feature_t>;
using DocidVector = std::vector<uint32_t>;
using RankedHitVector = std::vector<RankedHit>;

using Ranges = std::pair<Scores, Scores>;

Expand Down Expand Up @@ -574,4 +576,75 @@ TEST(HitCollectorTest, require_that_hits_can_be_added_out_of_order_only_after_pa
checkResult(*rs, nullptr);
}

struct RankDropFixture {
uint32_t _docid_limit;
HitCollector _hc;
std::vector<uint32_t> _dropped;
RankDropFixture(uint32_t docid_limit, uint32_t max_hits_size)
: _docid_limit(docid_limit),
_hc(docid_limit, max_hits_size)
{
}
void add(std::vector<RankedHit> hits) {
for (const auto& hit : hits) {
_hc.addHit(hit.getDocId(), hit.getRank());
}
}
void rerank(ScoreMap score_map, size_t count) {
PredefinedScorer scorer(score_map);
EXPECT_EQ(count, do_reRank(scorer, _hc, count));
}
std::unique_ptr<BitVector> make_bv(DocidVector docids) {
auto bv = BitVector::create(_docid_limit);
for (auto& docid : docids) {
bv->setBit(docid);
}
return bv;
}

void setup() {
add({{5, 1100},{10, 1200},{11, 1300},{12, 1400},{14, 500},{15, 900},{16,1000}});
rerank({{11,14},{12,13}}, 2);
}
void check_result(std::optional<double> rank_drop_limit, RankedHitVector exp_array,
std::unique_ptr<BitVector> exp_bv, DocidVector exp_dropped) {
auto rs = _hc.get_result_set(rank_drop_limit, &_dropped);
checkResult(*rs, exp_array);
checkResult(*rs, exp_bv.get());
EXPECT_EQ(exp_dropped, _dropped);
}
};

TEST(HitCollectorTest, require_that_second_phase_rank_drop_limit_is_enforced)
{
RankDropFixture f(10000, 10);
f.setup();
f.check_result(9.0, {{5,11},{10,12},{11,14},{12,13},{16,10}},
{}, {14, 15});
}

TEST(HitCollectorTest, require_that_docid_vector_is_used)
{
RankDropFixture f(10000, 4);
f.setup();
f.check_result(13.0, {{11,14}},
{}, {5,10,12,14,15,16});
}

TEST(HitCollectorTest, require_that_bitvector_is_not_dropped_without_rank_drop_limit)
{
RankDropFixture f(20, 4);
f.setup();
f.check_result(std::nullopt, {{5,11},{10,12},{11,14},{12,13}},
f.make_bv({5,10,11,12,14,15,16}), {});
}

TEST(HitCollectorTest, require_that_bitvector_is_dropped_with_rank_drop_limit)
{
RankDropFixture f(20, 4);
f.setup();
f.check_result(9.0, {{5,11},{10,12},{11,14},{12,13}},
{}, {14,15,16});
}

GTEST_MAIN_RUN_ALL_TESTS()
157 changes: 137 additions & 20 deletions searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,57 +219,150 @@ class RerankRescorer {
}
};

template <typename Rescorer>
class SimpleHitAdder {
protected:
ResultSet& _rs;
public:
SimpleHitAdder(ResultSet& rs)
: _rs(rs)
{
}
void add(uint32_t docid, double rank_value) {
_rs.push_back({docid, rank_value});
}
};

class ConditionalHitAdder : public SimpleHitAdder {
protected:
double _second_phase_rank_drop_limit;
public:
ConditionalHitAdder(ResultSet& rs, double second_phase_rank_drop_limit)
: SimpleHitAdder(rs),
_second_phase_rank_drop_limit(second_phase_rank_drop_limit)
{
}
void add(uint32_t docid, double rank_value) {
if (rank_value > _second_phase_rank_drop_limit) {
_rs.push_back({docid, rank_value});
}
}
};

class TrackingConditionalHitAdder : public ConditionalHitAdder {
std::vector<uint32_t>& _dropped;
public:
TrackingConditionalHitAdder(ResultSet& rs, double second_phase_rank_drop_limit, std::vector<uint32_t>& dropped)
: ConditionalHitAdder(rs, second_phase_rank_drop_limit),
_dropped(dropped)
{
}
void add(uint32_t docid, double rank_value) {
if (rank_value > _second_phase_rank_drop_limit) {
_rs.push_back({docid, rank_value});
} else {
_dropped.emplace_back(docid);
}
}
};

template <typename HitAdder, typename Rescorer>
void
add_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, Rescorer rescorer)
add_rescored_hits(HitAdder hit_adder, const std::vector<HitCollector::Hit>& hits, Rescorer rescorer)
{
for (auto& hit : hits) {
rs.push_back({hit.first, rescorer.rescore(hit.first, hit.second)});
hit_adder.add(hit.first, rescorer.rescore(hit.first, hit.second));
}
}

template <typename Rescorer>
template <typename HitAdder, typename Rescorer>
void
add_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<HitCollector::Hit>& reranked_hits, Rescorer rescorer)
add_rescored_hits(HitAdder hit_adder, const std::vector<HitCollector::Hit>& hits, const std::vector<HitCollector::Hit>& reranked_hits, Rescorer rescorer)
{
if (reranked_hits.empty()) {
add_rescored_hits(rs, hits, rescorer);
add_rescored_hits(hit_adder, hits, rescorer);
} else {
add_rescored_hits(rs, hits, RerankRescorer(rescorer, reranked_hits));
add_rescored_hits(hit_adder, hits, RerankRescorer(rescorer, reranked_hits));
}
}

template <typename Rescorer>
void
mixin_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<uint32_t>& docids, double default_value, Rescorer rescorer)
add_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<HitCollector::Hit>& reranked_hits, std::optional<double> second_phase_rank_drop_limit, std::vector<uint32_t>* dropped, Rescorer rescorer)
{
if (second_phase_rank_drop_limit.has_value()) {
if (dropped != nullptr) {
add_rescored_hits(TrackingConditionalHitAdder(rs, second_phase_rank_drop_limit.value(), *dropped), hits, reranked_hits, rescorer);
} else {
add_rescored_hits(ConditionalHitAdder(rs, second_phase_rank_drop_limit.value()), hits, reranked_hits, rescorer);
}
} else {
add_rescored_hits(SimpleHitAdder(rs), hits, reranked_hits, rescorer);
}
}

template <typename HitAdder, typename Rescorer>
void
mixin_rescored_hits(HitAdder hit_adder, const std::vector<HitCollector::Hit>& hits, const std::vector<uint32_t>& docids, double default_value, Rescorer rescorer)
{
auto hits_cur = hits.begin();
auto hits_end = hits.end();
for (auto docid : docids) {
if (hits_cur != hits_end && docid == hits_cur->first) {
rs.push_back({docid, rescorer.rescore(docid, hits_cur->second)});
hit_adder.add(docid, rescorer.rescore(docid, hits_cur->second));
++hits_cur;
} else {
rs.push_back({docid, default_value});
hit_adder.add(docid, default_value);
}
}
}

template <typename Rescorer>
template <typename HitAdder, typename Rescorer>
void
mixin_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<uint32_t>& docids, double default_value, const std::vector<HitCollector::Hit>& reranked_hits, Rescorer rescorer)
mixin_rescored_hits(HitAdder hit_adder, const std::vector<HitCollector::Hit>& hits, const std::vector<uint32_t>& docids, double default_value, const std::vector<HitCollector::Hit>& reranked_hits, Rescorer rescorer)
{
if (reranked_hits.empty()) {
mixin_rescored_hits(rs, hits, docids, default_value, rescorer);
mixin_rescored_hits(hit_adder, hits, docids, default_value, rescorer);
} else {
mixin_rescored_hits(rs, hits, docids, default_value, RerankRescorer(rescorer, reranked_hits));
mixin_rescored_hits(hit_adder, hits, docids, default_value, RerankRescorer(rescorer, reranked_hits));
}
}

template <typename Rescorer>
void
mixin_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<uint32_t>& docids, double default_value, const std::vector<HitCollector::Hit>& reranked_hits, std::optional<double> second_phase_rank_drop_limit, std::vector<uint32_t>* dropped, Rescorer rescorer)
{
if (second_phase_rank_drop_limit.has_value()) {
if (dropped != nullptr) {
mixin_rescored_hits(TrackingConditionalHitAdder(rs, second_phase_rank_drop_limit.value(), *dropped), hits, docids, default_value, reranked_hits, rescorer);
} else {
mixin_rescored_hits(ConditionalHitAdder(rs, second_phase_rank_drop_limit.value()), hits, docids, default_value, reranked_hits, rescorer);
}
} else {
mixin_rescored_hits(SimpleHitAdder(rs), hits, docids, default_value, reranked_hits, rescorer);
}
}

void
add_bitvector_to_dropped(std::vector<uint32_t>& dropped, vespalib::ConstArrayRef<RankedHit> hits, const BitVector& bv)
{
auto hits_cur = hits.begin();
auto hits_end = hits.end();
auto docid = bv.getFirstTrueBit();
auto docid_limit = bv.size();
while (docid < docid_limit) {
if (hits_cur != hits_end && hits_cur->getDocId() == docid) {
++hits_cur;
} else {
dropped.emplace_back(docid);
}
docid = bv.getNextTrueBit(docid + 1);
}
}

}

std::unique_ptr<ResultSet>
HitCollector::getResultSet()
HitCollector::get_result_set(std::optional<double> second_phase_rank_drop_limit, std::vector<uint32_t>* dropped)
{
/*
* Use default_rank_value (i.e. -HUGE_VAL) when hit collector saves
Expand All @@ -280,34 +373,58 @@ HitCollector::getResultSet()
bool needReScore = FirstPhaseRescorer::need_rescore(_ranges);
FirstPhaseRescorer rescorer(_ranges);

if (dropped != nullptr) {
dropped->clear();
}

// destroys the heap property or score sort order
sortHitsByDocId();

auto rs = std::make_unique<ResultSet>();
if ( ! _collector->isDocIdCollector() ) {
if ( ! _collector->isDocIdCollector() ||
(second_phase_rank_drop_limit.has_value() &&
(_bitVector || dropped == nullptr))) {
rs->allocArray(_hits.size());
auto* dropped_or_null = dropped;
if (second_phase_rank_drop_limit.has_value() && _bitVector) {
dropped_or_null = nullptr;
}
if (needReScore) {
add_rescored_hits(*rs, _hits, _reRankedHits, rescorer);
add_rescored_hits(*rs, _hits, _reRankedHits, second_phase_rank_drop_limit, dropped_or_null, rescorer);
} else {
add_rescored_hits(*rs, _hits, _reRankedHits, NoRescorer());
add_rescored_hits(*rs, _hits, _reRankedHits, second_phase_rank_drop_limit, dropped_or_null, NoRescorer());
}
} else {
if (_unordered) {
std::sort(_docIdVector.begin(), _docIdVector.end());
}
rs->allocArray(_docIdVector.size());
if (needReScore) {
mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, _reRankedHits, rescorer);
mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, _reRankedHits, second_phase_rank_drop_limit, dropped, rescorer);
} else {
mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, _reRankedHits, NoRescorer());
mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, _reRankedHits, second_phase_rank_drop_limit, dropped, NoRescorer());
}
}

if (second_phase_rank_drop_limit.has_value() && _bitVector) {
if (dropped != nullptr) {
assert(dropped->empty());
add_bitvector_to_dropped(*dropped, {rs->getArray(), rs->getArrayUsed()}, *_bitVector);
}
_bitVector.reset();
}

if (_bitVector) {
rs->setBitOverflow(std::move(_bitVector));
}

return rs;
}

std::unique_ptr<ResultSet>
HitCollector::getResultSet()
{
return get_result_set(std::nullopt, nullptr);
}

}
4 changes: 4 additions & 0 deletions searchlib/src/vespa/searchlib/queryeval/hitcollector.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <vespa/searchlib/common/resultset.h>
#include <vespa/vespalib/util/sort.h>
#include <algorithm>
#include <optional>
#include <vector>

namespace search::queryeval {
Expand Down Expand Up @@ -166,6 +167,9 @@ class HitCollector {
const std::pair<Scores, Scores> &getRanges() const { return _ranges; }
void setRanges(const std::pair<Scores, Scores> &ranges);

std::unique_ptr<ResultSet>
get_result_set(std::optional<double> second_phase_rank_drop_limit, std::vector<uint32_t>* dropped);

/**
* Returns a result set based on the content of this collector.
* Invoking this method will destroy the heap property of the
Expand Down

0 comments on commit bc8cf5d

Please sign in to comment.