Skip to content

Commit

Permalink
Replace sort with priority queue
Browse files Browse the repository at this point in the history
  • Loading branch information
johnlees committed Mar 16, 2024
1 parent d9e7d83 commit f55c87a
Showing 1 changed file with 36 additions and 11 deletions.
47 changes: 36 additions & 11 deletions src/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <algorithm>
#include <limits>
#include <queue>

#include <H5Cpp.h>
#include <omp.h>
Expand Down Expand Up @@ -314,6 +315,27 @@ void check_sparse_inputs(const std::vector<Reference> &ref_sketches,
}
}

class SparseDist {
public:
SparseDist(float dist, long j) : _dist(dist), _j(j) {}

float dist() {return _dist; }
long j() {return _j; }

friend bool operator<(SparseDist const &a, SparseDist const &b)
{
return a._dist < b._dist;
}
friend bool operator==(SparseDist const &a, SparseDist const &b)
{
return a._dist == b._dist;
}

private:
float _dist;
long _j;
}

sparse_coo query_db_sparse(std::vector<Reference> &ref_sketches,
const std::vector<size_t> &kmer_lengths,
RandomMC &random_chance, const bool jaccard,
Expand Down Expand Up @@ -344,27 +366,32 @@ sparse_coo query_db_sparse(std::vector<Reference> &ref_sketches,
Eigen::MatrixXf kmer_mat = kmer2mat(kmer_lengths);
#pragma omp parallel for schedule(static) num_threads(num_threads) shared(progress)
for (size_t i = 0; i < ref_sketches.size(); i++) {
std::vector<float> row_dists(ref_sketches.size());
std::priority_queue<SparseDist> min_dists;
if (!interrupt) {
for (size_t j = 0; j < ref_sketches.size(); j++) {
float row_dist = std::numeric_limits<float>::infinity();
if (i != j) {
if (jaccard) {
// Need 1-J here to sort correctly
row_dists[j] = 1.0f - ref_sketches[i].jaccard_dist(
row_dist = 1.0f - ref_sketches[i].jaccard_dist(
ref_sketches[j], kmer_lengths[dist_col], random_chance);
} else {
float core, acc;
std::tie(core, acc) =
ref_sketches[i].core_acc_dist<RandomMC>(
ref_sketches[j], kmer_mat, random_chance);
if (dist_col == 0) {
row_dists[j] = core;
row_dist = core;
} else {
row_dists[j] = acc;
row_dist = acc;
}
}
if (min_dists.size() < kNN || row_dist < min_dists.top()) {
min_dists.push(SparseDist(row_dist, j));
if (min_dists.size > kNN) {
min_dists.pop();
}
}
} else {
row_dists[j] = std::numeric_limits<float>::infinity();
}
if ((i * ref_sketches.size() + j) % update_every == 0) {
#pragma omp critical
Expand All @@ -377,13 +404,11 @@ sparse_coo query_db_sparse(std::vector<Reference> &ref_sketches,
}
}
long offset = i * kNN;
std::vector<long> ordered_dists = sort_indexes(row_dists);
std::fill_n(i_vec.begin() + offset, kNN, i);
// std::copy_n(ordered_dists.begin(), kNN, j_vec.begin() + offset);

for (int k = 0; k < kNN; ++k) {
j_vec[offset + k] = ordered_dists[k];
dists[offset + k] = row_dists[ordered_dists[k]];
SparseDist entry = min_dists.pop();
j_vec[offset + k] = entry.j();
dists[offset + k] = entry.dist();
}

}
Expand Down

0 comments on commit f55c87a

Please sign in to comment.