From f55c87ae056abfe3c72c101074405ac36ee3e1ea Mon Sep 17 00:00:00 2001 From: John Lees Date: Sat, 16 Mar 2024 12:24:15 +0000 Subject: [PATCH] Replace sort with priority queue --- src/api.cpp | 47 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/src/api.cpp b/src/api.cpp index 9fc2d9cf..1118731f 100644 --- a/src/api.cpp +++ b/src/api.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include @@ -314,6 +315,27 @@ void check_sparse_inputs(const std::vector &ref_sketches, } } +class SparseDist { + public: + SparseDist(float dist, long j) : _dist(dist), _j(j) {} + + float dist() {return _dist; } + long j() {return _j; } + + friend bool operator<(SparseDist const &a, SparseDist const &b) + { + return a._dist < b._dist; + } + friend bool operator==(SparseDist const &a, SparseDist const &b) + { + return a._dist == b._dist; + } + + private: + float _dist; + long _j; +} + sparse_coo query_db_sparse(std::vector &ref_sketches, const std::vector &kmer_lengths, RandomMC &random_chance, const bool jaccard, @@ -344,13 +366,14 @@ sparse_coo query_db_sparse(std::vector &ref_sketches, Eigen::MatrixXf kmer_mat = kmer2mat(kmer_lengths); #pragma omp parallel for schedule(static) num_threads(num_threads) shared(progress) for (size_t i = 0; i < ref_sketches.size(); i++) { - std::vector row_dists(ref_sketches.size()); + std::priority_queue min_dists; if (!interrupt) { for (size_t j = 0; j < ref_sketches.size(); j++) { + float row_dist = std::numeric_limits::infinity(); if (i != j) { if (jaccard) { // Need 1-J here to sort correctly - row_dists[j] = 1.0f - ref_sketches[i].jaccard_dist( + row_dist = 1.0f - ref_sketches[i].jaccard_dist( ref_sketches[j], kmer_lengths[dist_col], random_chance); } else { float core, acc; @@ -358,13 +381,17 @@ sparse_coo query_db_sparse(std::vector &ref_sketches, ref_sketches[i].core_acc_dist( ref_sketches[j], kmer_mat, random_chance); if (dist_col == 0) { - row_dists[j] = core; + row_dist = core; } else { - row_dists[j] = acc; + row_dist = acc; + } + } + if (min_dists.size() < kNN || row_dist < min_dists.top()) { + min_dists.push(SparseDist(row_dist, j)); + if (min_dists.size > kNN) { + min_dists.pop(); } } - } else { - row_dists[j] = std::numeric_limits::infinity(); } if ((i * ref_sketches.size() + j) % update_every == 0) { #pragma omp critical @@ -377,13 +404,11 @@ sparse_coo query_db_sparse(std::vector &ref_sketches, } } long offset = i * kNN; - std::vector ordered_dists = sort_indexes(row_dists); std::fill_n(i_vec.begin() + offset, kNN, i); - // std::copy_n(ordered_dists.begin(), kNN, j_vec.begin() + offset); - for (int k = 0; k < kNN; ++k) { - j_vec[offset + k] = ordered_dists[k]; - dists[offset + k] = row_dists[ordered_dists[k]]; + SparseDist entry = min_dists.pop(); + j_vec[offset + k] = entry.j(); + dists[offset + k] = entry.dist(); } }