From 38890f1ca43488146274ea0aabec0a982fc1da93 Mon Sep 17 00:00:00 2001 From: Gabriel Eiseman Date: Mon, 12 Feb 2024 23:04:50 -0500 Subject: [PATCH] KD tree bugfixes and improvements - moved k_closest_naive to debug header - added check_layer to debug header - made k_closest_state and cmp_pt_dist public so that user code can compare points by distance to a fixed point (extremely dirty) - improved k_closest error reporting and docs - fix split_i64cu (wasn't fully setting sub-bounds) - add debug checks - extend pointcloud test: also tests k_closest and has many compile time params --- include/crater/kd_check.h | 37 ++++++---- include/crater/kd_tree.h | 31 +++++--- src/lib/crater/kd_tree.c | 48 ++++++------ src/test/pointcloud.c | 151 ++++++++++++++++++++++++++++---------- 4 files changed, 184 insertions(+), 83 deletions(-) diff --git a/include/crater/kd_check.h b/include/crater/kd_check.h index 32bf096..3c5bd72 100644 --- a/include/crater/kd_check.h +++ b/include/crater/kd_check.h @@ -2,23 +2,31 @@ #include +inline static bool cr8r_kd_check_layer(const cr8r_vec *self, const cr8r_kd_ft *ft, uint64_t a, uint64_t b){ + uint64_t mid_idx = (a + b)/2; + const void *mid = self->buf + mid_idx*ft->super.base.size; + for(uint64_t i = a; i < mid_idx; ++i){ + const void *ent = self->buf + i*ft->super.base.size; + if(ft->super.cmp(&ft->super.base, ent, mid) > 0){ + return false; + } + } + for(uint64_t i = mid_idx + 1; i < b; ++i){ + const void *ent = self->buf + i*ft->super.base.size; + if(ft->super.cmp(&ft->super.base, ent, mid) < 0){ + return false; + } + } + return true; +} + inline static bool cr8r_kd_check_tree(const cr8r_vec *self, const cr8r_kd_ft *_ft, uint64_t a, uint64_t b){ cr8r_kd_ft ft = *_ft; while(b > a){ - uint64_t mid_idx = (a + b)/2; - const void *mid = self->buf + mid_idx*ft.super.base.size; - for(uint64_t i = a; i < mid_idx; ++i){ - const void *ent = self->buf + i*ft.super.base.size; - if(ft.super.cmp(&ft.super.base, ent, mid) > 0){ - return false; - } - } - 
for(uint64_t i = mid_idx + 1; i < b; ++i){ - const void *ent = self->buf + i*ft.super.base.size; - if(ft.super.cmp(&ft.super.base, ent, mid) < 0){ - return false; - } + if(!cr8r_kd_check_layer(self, &ft, a, b)){ + return false; } + uint64_t mid_idx = (a + b)/2; // increment depth ++*(uint64_t*)&ft.super.base.data; if(!cr8r_kd_check_tree(self, &ft, mid_idx + 1, b)){ @@ -29,3 +37,6 @@ inline static bool cr8r_kd_check_tree(const cr8r_vec *self, const cr8r_kd_ft *_f return true; } +bool cr8r_kd_k_closest_naive(cr8r_vec*, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out); + + diff --git a/include/crater/kd_tree.h b/include/crater/kd_tree.h index 06e1979..f4817f9 100644 --- a/include/crater/kd_tree.h +++ b/include/crater/kd_tree.h @@ -11,7 +11,6 @@ /// License, v. 2.0. If a copy of the MPL was not distributed with this /// file, You can obtain one at http://mozilla.org/MPL/2.0/. -#include "crater/container.h" #include #include #include @@ -95,6 +94,15 @@ typedef struct{ int64_t tr[3]; } cr8r_kdwin_s2i64; +/// Additional state for { @link cr8r_kd_k_closest } +typedef struct{ + cr8r_kd_ft ft; + const void *pt; + cr8r_vec *ents; + uint64_t k; + double max_sqdist; +} cr8r_kd_k_closest_state; + /// Type of a visitor callback for a kd tree traversal /// /// f(cr8r_kd_ft *ft, const void *win, void *ent, void *data) @@ -134,6 +142,8 @@ void cr8r_kd_walk(cr8r_vec*, const cr8r_kd_ft *ft, const void *bounds, cr8r_kdvi /// Find the k closest points to a given point /// +/// The output vec MUST be initialized ({ @link cr8r_vec_init }), but +/// will be { @link cr8r_vec_clear}'d.
/// @param [in] ft: (uint64_t)_ft->super.base.data is interpreted as the depth of the current subarray, and /// so should be zero when calling this function generally on the whole vector /// @param [in] bounds: the bounds of the tree @@ -141,17 +151,18 @@ void cr8r_kd_walk(cr8r_vec*, const cr8r_kd_ft *ft, const void *bounds, cr8r_kdvi /// If there is a tie for the kth closest point, the list of the k closest points may contain any of the tied points, /// but will always contain exactly k unless there are fewer than k entries in the tree. /// @param [in] k: the number of points to find -/// @param [out] out: vector where the k closest points will be stored, in no specified order (the current -/// implementation unsurprisingly places them in a minmax heap order) -void cr8r_kd_k_closest(cr8r_vec*, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out); +/// @param [out] out: vector where the k closest points will be stored. Must be allocated but will be cleared before use. +/// The output is in no specified order (the current implementation unsurprisingly places them in a minmax heap order) +/// @return true on success, false on (allocation) failure. Will not allocate if out->cap >= k + 1 already, but may +/// return fewer than k points if fewer than k are available +bool cr8r_kd_k_closest(cr8r_vec*, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out); -/// Find the k closest points to a given point +/// Comparison function that compares two kd tree points based on their distance from a fixed point /// -/// This function is designed to enable testing { @link cr8r_kd_k_closest }. That function prunes its search area if the max -/// distance of any of the k closest points so far is less than the min distance of the bounding box of the subtree rooted ata -/// a node, then that node and its subtree can be skipped. This function does not do that and just compares all points.
-void cr8r_kd_k_closest_naive(cr8r_vec*, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out); - +/// This function MUST be called with _ft pointing at a cr8r_base_ft WITHIN { @link cr8r_kd_k_closest_state }. +/// See the implementation of { @link cr8r_kd_k_closest } for more details. +/// This function will only read from the `pt` field of the `cr8r_kd_k_closest_state` struct. +int cr8r_default_cmp_kd_kcs_pt_dist(const cr8r_base_ft *_ft, const void *a, const void *b); /// kdft implementation for spherical kd trees in 3 dimensions with i64 coordinates extern cr8r_kd_ft cr8r_kdft_s2i64; /// kdft implementation for cuboid kd trees in 3 dimensions with i64 coordintates diff --git a/src/lib/crater/kd_tree.c b/src/lib/crater/kd_tree.c index f882ab4..080963b 100644 --- a/src/lib/crater/kd_tree.c +++ b/src/lib/crater/kd_tree.c @@ -1,5 +1,3 @@ -#include "crater/container.h" -#include "crater/vec.h" #include #include #include @@ -7,6 +5,7 @@ #include #include +#include #include static int cmp_depth_i64cu(const cr8r_base_ft *_ft, const void *_a, const void *_b){ @@ -191,8 +190,8 @@ static void split_i64cu(const cr8r_kd_ft *ft, const void *_self, const void *_ro cr8r_kdwin_s2i64 *o1 = _o1; cr8r_kdwin_s2i64 *o2 = _o2; uint64_t idx = (uint64_t)ft->super.base.data%ft->dim; - memcpy(o1, self, ft->dim*sizeof(int64_t)); - memcpy(o2, self, ft->dim*sizeof(int64_t)); + memcpy(o1, self, sizeof(cr8r_kdwin_s2i64)); + memcpy(o2, self, sizeof(cr8r_kdwin_s2i64)); o1->tr[idx] = root[idx]; o2->bl[idx] = root[idx]; } @@ -270,6 +269,11 @@ bool cr8r_kd_ify(cr8r_vec *self, cr8r_kd_ft *_ft, uint64_t a, uint64_t b){ return 0; } piv = cr8r_vec_partition_with_median(self, &ft.super, a, b, piv); +#ifdef DEBUG + if(!piv || !cr8r_kd_check_layer(self, &ft, a, b)){ + __builtin_trap(); + } +#endif // increment depth ++*(uint64_t*)&ft.super.base.data; if(!cr8r_kd_ify(self, &ft, mid_idx + 1, b)){ @@ -312,16 +316,8 @@ void cr8r_kd_walk(cr8r_vec *self, const cr8r_kd_ft *ft, const void 
*_bounds, cr8 cr8r_kd_walk_r(self, ft, bounds, visitor, data, 0, self->len); } -typedef struct{ - cr8r_vec *ents; - cr8r_kd_ft ft; - const void *pt; - uint64_t k; - double max_sqdist; -} k_closest_state; - inline static cr8r_walk_decision k_closest_visitor(cr8r_kd_ft *ft, const void *bounds, void *ent, void *_data){ - k_closest_state *data = _data; + cr8r_kd_k_closest_state *data = _data; char tmp[ft->super.base.size]; if(data->ents->len < data->k){ cr8r_mmheap_push(data->ents, &data->ft.super, ent); @@ -335,9 +331,9 @@ inline static cr8r_walk_decision k_closest_visitor(cr8r_kd_ft *ft, const void *b return CR8R_WALK_SKIP_CHILDREN; } -inline static int cmp_pt_dist(const cr8r_base_ft *_ft, const void *a, const void *b){ +int cr8r_default_cmp_kd_kcs_pt_dist(const cr8r_base_ft *_ft, const void *a, const void *b){ const cr8r_kd_ft *ft = (const cr8r_kd_ft*)_ft; - const k_closest_state *data = CR8R_OUTER(ft, k_closest_state, ft); + const cr8r_kd_k_closest_state *data = CR8R_OUTER(ft, cr8r_kd_k_closest_state, ft); double a_sqdist = ft->sqdist(ft, data->pt, a); double b_sqdist = ft->sqdist(ft, data->pt, b); if(a_sqdist < b_sqdist){ @@ -348,30 +344,40 @@ inline static int cmp_pt_dist(const cr8r_base_ft *_ft, const void *a, const void return 0; } -void cr8r_kd_k_closest(cr8r_vec *self, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out){ - k_closest_state data = { +bool cr8r_kd_k_closest(cr8r_vec *self, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out){ + cr8r_vec_clear(out, &ft->super); + if(!cr8r_vec_ensure_cap(out, &ft->super, k + 1)){ + return false; + } + cr8r_kd_k_closest_state data = { .ents = out, .ft = *ft, .pt = pt, .k = k, .max_sqdist = INFINITY }; - data.ft.super.cmp = cmp_pt_dist; + data.ft.super.cmp = cr8r_default_cmp_kd_kcs_pt_dist; cr8r_kd_walk(self, ft, bounds, k_closest_visitor, &data); + return true; } -void cr8r_kd_k_closest_naive(cr8r_vec *self, cr8r_kd_ft *ft, const void *bounds, const void *pt, 
uint64_t k, cr8r_vec *out){ - k_closest_state data = { +bool cr8r_kd_k_closest_naive(cr8r_vec *self, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out){ + cr8r_vec_clear(out, &ft->super); + if(!cr8r_vec_ensure_cap(out, &ft->super, k + 1)){ + return false; + } + cr8r_kd_k_closest_state data = { .ents = out, .ft = *ft, .pt = pt, .k = k, .max_sqdist = INFINITY }; - data.ft.super.cmp = cmp_pt_dist; + data.ft.super.cmp = cr8r_default_cmp_kd_kcs_pt_dist; for(uint64_t i = 0; i < self->len; ++i){ k_closest_visitor(ft, bounds, self->buf + i*ft->super.base.size, &data); } + return true; } cr8r_kd_ft cr8r_kdft_s2i64 = { diff --git a/src/test/pointcloud.c b/src/test/pointcloud.c index bd9220d..ce8bcff 100644 --- a/src/test/pointcloud.c +++ b/src/test/pointcloud.c @@ -7,54 +7,125 @@ #include #include -int main(){ - uint64_t tested = 0, passed = 0; - cr8r_prng *prng = cr8r_prng_init_lcg(0x555b6745db2f2b85); - cr8r_vec points = {}; - cr8r_vec_init(&points, &cr8r_kdft_c3i64.super, 1000); - fprintf(stderr, "\e[1;34mGenerating 1000 random lattice points within [-1000, 1000]^3\e[0m\n"); - for(uint64_t i = 0; i < points.cap; ++i){ - int64_t point[3]; - for(uint64_t j = 0; j < 3; ++j){ - int64_t x = cr8r_prng_uniform_u64(prng, 0, 2000); - point[j] = x - 1000; - } - cr8r_vec_pushr(&points, &cr8r_kdft_c3i64.super, point); +static cr8r_kd_k_closest_state point_kcs; + +#define NUM_POINTS 1000 +#define BOX_SIZE 2000 +#define KCS_SIZE 2200 +#define KCS_COUNT 50 +#define KCS_TRIALS 50 +#define KD_TRIALS 5 + +static void print_kcs(const cr8r_vec *points){ + for(uint64_t i = 0; i < points->len; ++i){ + int64_t *point = points->buf + i*point_kcs.ft.super.base.size; + double dist = point_kcs.ft.sqdist(&point_kcs.ft, point_kcs.pt, point); + fprintf(stderr, "%.0f: (%"PRIi64", %"PRIi64", %"PRIi64"), ", dist, point[0], point[1], point[2]); } - fprintf(stderr, "\e[1;34mChecking bounds of points\e[0m\n"); - cr8r_kdwin_s2i64 bounds = {}, bounds1 = {}; - 
cr8r_kdwin_bounding_i64x3(&bounds, &points, &cr8r_kdft_c3i64); - memcpy(&bounds1.bl, points.buf, cr8r_kdft_c3i64.super.base.size); - memcpy(&bounds1.tr, points.buf, cr8r_kdft_c3i64.super.base.size); - for(uint64_t i = 1; i < points.len; ++i){ - const int64_t *point = cr8r_vec_get(&points, &cr8r_kdft_c3i64.super, i); + fprintf(stderr, "\n"); +} + +static void get_bounds(cr8r_vec *points, cr8r_kdwin_s2i64 *bounds){ + memcpy(bounds->bl, points->buf, cr8r_kdft_c3i64.super.base.size); + memcpy(bounds->tr, points->buf, cr8r_kdft_c3i64.super.base.size); + for(uint64_t i = 1; i < points->len; ++i){ + const int64_t *point = cr8r_vec_get(points, &cr8r_kdft_c3i64.super, i); for(uint64_t j = 0; j < 3; ++j){ - if(point[j] < bounds1.bl[j]){ - bounds1.bl[j] = point[j]; - }else if(point[j] > bounds1.tr[j]){ - bounds1.tr[j] = point[j]; + if(point[j] < bounds->bl[j]){ + bounds->bl[j] = point[j]; + }else if(point[j] > bounds->tr[j]){ + bounds->tr[j] = point[j]; } } } - ++tested; - if(memcmp(&bounds, &bounds1, sizeof(cr8r_kdwin_s2i64))){ - fprintf(stderr, "\e[1;31mcr8r_kdwin_bounding_i64x3 didn't compute the correct bounds!\e[0m\n"); - }else{ - fprintf(stderr, "\e[1;32mFound bounds (%"PRIi64",%"PRIi64",%"PRIi64"):(%"PRIi64",%"PRIi64",%"PRIi64")\e[0m\n", bounds.bl[0], bounds.bl[1], bounds.bl[2], bounds.tr[0], bounds.tr[1], bounds.tr[2]); - ++passed; - } +} - int status = cr8r_kd_ify(&points, &cr8r_kdft_c3i64, 0, points.len); - ++tested; - if(status){ - if(cr8r_kd_check_tree(&points, &cr8r_kdft_c3i64, 0, points.len)){ - fprintf(stderr, "\e[1;32mKD Tree built successfully!\e[0m\n"); +int main(){ + point_kcs.ft = cr8r_kdft_c3i64; + point_kcs.ft.super.cmp = cr8r_default_cmp_kd_kcs_pt_dist; + uint64_t tested = 0, passed = 0; + cr8r_prng *prng = cr8r_prng_init_lcg(0x555b6745db2f2b85); + cr8r_vec points = {}, res_points1 = {}, res_points2; + cr8r_vec_init(&points, &cr8r_kdft_c3i64.super, NUM_POINTS); + cr8r_vec_init(&res_points1, &point_kcs.ft.super, KCS_COUNT + 1); + 
cr8r_vec_init(&res_points2, &point_kcs.ft.super, KCS_COUNT + 1); + for(uint64_t trial = 0; trial < KD_TRIALS; ++trial){ + cr8r_vec_clear(&points, &cr8r_kdft_c3i64.super); + fprintf(stderr, "\e[1;34mGenerating %1$d random lattice points within [-%2$d, %2$d]^3\e[0m\n", NUM_POINTS, BOX_SIZE/2); + for(uint64_t i = 0; i < points.cap; ++i){ + int64_t point[3]; + for(uint64_t j = 0; j < 3; ++j){ + int64_t x = cr8r_prng_uniform_u64(prng, 0, BOX_SIZE + 1); + point[j] = x - BOX_SIZE/2; + } + cr8r_vec_pushr(&points, &cr8r_kdft_c3i64.super, point); + } + fprintf(stderr, "\e[1;34mChecking bounds of points\e[0m\n"); + cr8r_kdwin_s2i64 bounds = {}, bounds1 = {}; + cr8r_kdwin_bounding_i64x3(&bounds, &points, &cr8r_kdft_c3i64); + get_bounds(&points, &bounds1); + ++tested; + if(memcmp(&bounds, &bounds1, sizeof(cr8r_kdwin_s2i64))){ + fprintf(stderr, "\e[1;31mcr8r_kdwin_bounding_i64x3 didn't compute the correct bounds!\e[0m\n"); + }else{ + fprintf(stderr, "\e[1;32mFound bounds (%"PRIi64",%"PRIi64",%"PRIi64"):(%"PRIi64",%"PRIi64",%"PRIi64")\e[0m\n", bounds.bl[0], bounds.bl[1], bounds.bl[2], bounds.tr[0], bounds.tr[1], bounds.tr[2]); ++passed; + } + + int status = cr8r_kd_ify(&points, &cr8r_kdft_c3i64, 0, points.len); + ++tested; + if(status){ + if(cr8r_kd_check_tree(&points, &cr8r_kdft_c3i64, 0, points.len)){ + fprintf(stderr, "\e[1;32mKD Tree built successfully!\e[0m\n"); + ++passed; + }else{ + fprintf(stderr, "\e[1;31mKD Tree built wrong!\e[0m\n"); + } }else{ - fprintf(stderr, "\e[1;31mKD Tree built wrong!\e[0m\n"); + fprintf(stderr, "\e[1;31mKD Tree build failed!\e[0m\n"); + } + + fprintf(stderr, "\e[1;34mFinding %1$d closest points for %2$d random points in [-%3$d, %3$d]^3 ...\e[0m\n", KCS_COUNT, KCS_TRIALS, KCS_SIZE); + for(uint64_t i = 0; i < KCS_TRIALS; ++i){ + ++tested; + int64_t point[3]; + for(uint64_t j = 0; j < 3; ++j){ + int64_t x = cr8r_prng_uniform_u64(prng, 0, KCS_SIZE + 1); + point[j] = x - KCS_SIZE/2; + } + fprintf(stderr, "\e[1;34m - (%"PRIi64", %"PRIi64", 
%"PRIi64")\e[0m\n", point[0], point[1], point[2]); + cr8r_kd_k_closest(&points, &cr8r_kdft_c3i64, &bounds, point, KCS_COUNT, &res_points1); + if(res_points1.len != KCS_COUNT){ + fprintf(stderr, "\e[1;31mkd_k_closest found %"PRIu64" points (%d requested)\e[0m\n", res_points1.len, KCS_COUNT); + continue; + } + cr8r_kd_k_closest_naive(&points, &cr8r_kdft_c3i64, &bounds, point, KCS_COUNT, &res_points2); + if(res_points2.len != KCS_COUNT){ + fprintf(stderr, "\e[1;31mkd_k_closest_naive found %"PRIu64" points (%d requested)\e[0m\n", res_points2.len, KCS_COUNT); + cr8r_vec_clear(&res_points2, &point_kcs.ft.super); + continue; + } + point_kcs.pt = point; + cr8r_vec_sort(&res_points1, &point_kcs.ft.super); + cr8r_vec_sort(&res_points2, &point_kcs.ft.super); + bool is_same = true; + for(uint64_t j = 0; j < KCS_COUNT; ++j){ + double d_a = point_kcs.ft.sqdist(&point_kcs.ft, point_kcs.pt, cr8r_vec_get(&res_points1, &point_kcs.ft.super, j)); + double d_b = point_kcs.ft.sqdist(&point_kcs.ft, point_kcs.pt, cr8r_vec_get(&res_points2, &point_kcs.ft.super, j)); + if(d_a != d_b){ + if(is_same){ + print_kcs(&res_points1); + print_kcs(&res_points2); + } + is_same = false; + } + } + if(is_same){ + ++passed; + }else{ + fprintf(stderr, "\e[1;31mkd_k_closest did not produce the same points as naive search!\e[0m\n"); + } } - }else{ - fprintf(stderr, "\e[1;32mKD Tree build failed!\e[0m\n"); } if(passed == tested){ @@ -63,6 +134,8 @@ int main(){ fprintf(stderr, "\e[1;31mFailed: passed %"PRIu64"/%"PRIu64" tests\e[0m\n", passed, tested); } cr8r_vec_delete(&points, &cr8r_kdft_c3i64.super); + cr8r_vec_delete(&res_points1, &point_kcs.ft.super); + cr8r_vec_delete(&res_points2, &point_kcs.ft.super); free(prng); }