Skip to content

Commit

Permalink
KD tree bugfixes and improvements
Browse files Browse the repository at this point in the history
- moved k_closest_naive to debug header
- added check_layer to debug header
- made k_closest_state and cmp_pt_dist public
so that user code can compare points by
distance to a fixed point (extremely dirty)
- improved k_closest error reporting and docs
- fix split_i64cu (wasn't fully setting sub-bounds)
- add debug checks
- extend pointcloud test: also tests k_closest and has many compile time params
  • Loading branch information
hacatu committed Feb 13, 2024
1 parent 763b696 commit 38890f1
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 83 deletions.
37 changes: 24 additions & 13 deletions include/crater/kd_check.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,31 @@

#include <crater/kd_tree.h>

/// Check that the subarray [a, b) of a kd tree is partitioned around its middle entry
///
/// The middle entry is the one at index (a + b)/2, and entries are compared using ft->super.cmp.
/// Only this one layer is inspected; the two halves are not checked recursively.
/// @param [in] self: vector holding the entries
/// @param [in] ft: function table providing the entry size and the comparison function
/// @param [in] a, b: bounds of the subarray to check (a inclusive, b exclusive)
/// @return false if any entry before the middle compares greater than the middle entry, or any entry
/// after the middle compares less than it; true otherwise
inline static bool cr8r_kd_check_layer(const cr8r_vec *self, const cr8r_kd_ft *ft, uint64_t a, uint64_t b){
	const uint64_t pivot_idx = (a + b)/2;
	const void *pivot = self->buf + pivot_idx*ft->super.base.size;
	// entries left of the pivot must not compare greater than it
	uint64_t idx = a;
	while(idx < pivot_idx){
		const void *cur = self->buf + idx*ft->super.base.size;
		if(ft->super.cmp(&ft->super.base, cur, pivot) > 0){
			return false;
		}
		++idx;
	}
	// entries right of the pivot must not compare less than it
	idx = pivot_idx + 1;
	while(idx < b){
		const void *cur = self->buf + idx*ft->super.base.size;
		if(ft->super.cmp(&ft->super.base, cur, pivot) < 0){
			return false;
		}
		++idx;
	}
	return true;
}

inline static bool cr8r_kd_check_tree(const cr8r_vec *self, const cr8r_kd_ft *_ft, uint64_t a, uint64_t b){
cr8r_kd_ft ft = *_ft;
while(b > a){
uint64_t mid_idx = (a + b)/2;
const void *mid = self->buf + mid_idx*ft.super.base.size;
for(uint64_t i = a; i < mid_idx; ++i){
const void *ent = self->buf + i*ft.super.base.size;
if(ft.super.cmp(&ft.super.base, ent, mid) > 0){
return false;
}
}
for(uint64_t i = mid_idx + 1; i < b; ++i){
const void *ent = self->buf + i*ft.super.base.size;
if(ft.super.cmp(&ft.super.base, ent, mid) < 0){
return false;
}
if(!cr8r_kd_check_layer(self, &ft, a, b)){
return false;
}
uint64_t mid_idx = (a + b)/2;
// increment depth
++*(uint64_t*)&ft.super.base.data;
if(!cr8r_kd_check_tree(self, &ft, mid_idx + 1, b)){
Expand All @@ -29,3 +37,6 @@ inline static bool cr8r_kd_check_tree(const cr8r_vec *self, const cr8r_kd_ft *_f
return true;
}

bool cr8r_kd_k_closest_naive(cr8r_vec*, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out);


31 changes: 21 additions & 10 deletions include/crater/kd_tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
/// License, v. 2.0. If a copy of the MPL was not distributed with this
/// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "crater/container.h"
#include <stddef.h>
#include <inttypes.h>
#include <stdbool.h>
Expand Down Expand Up @@ -95,6 +94,15 @@ typedef struct{
int64_t tr[3];
} cr8r_kdwin_s2i64;

/// Additional state for { @link cr8r_kd_k_closest }
///
/// { @link cr8r_default_cmp_kd_kcs_pt_dist } recovers this struct from a pointer to its `ft` member,
/// so that comparison function may only be used through a function table stored inside one of these.
typedef struct{
/// copy of the tree's function table; cr8r_kd_k_closest replaces its `super.cmp` with the distance comparison
cr8r_kd_ft ft;
/// the fixed point that distances are measured from
const void *pt;
/// vector collecting the closest entries found so far (the caller's output vector)
cr8r_vec *ents;
/// the number of closest entries requested
uint64_t k;
/// initialized to INFINITY; presumably the largest squared distance among the entries currently kept -- TODO confirm against the visitor
double max_sqdist;
} cr8r_kd_k_closest_state;

/// Type of a visitor callback for a kd tree traversal
///
/// f(cr8r_kd_ft *ft, const void *win, void *ent, void *data)
Expand Down Expand Up @@ -134,24 +142,27 @@ void cr8r_kd_walk(cr8r_vec*, const cr8r_kd_ft *ft, const void *bounds, cr8r_kdvi

/// Find the k closest points to a given point
///
/// The output vec MUST be initialized ({ @link cr8r_vec_init }), but
/// will be { @link cr8r_vec_clear}'d.
/// @param [in] ft: (uint64_t)_ft->super.base.data is interpreted as the depth of the current subarray, and
/// so should be zero when calling this function generally on the whole vector
/// @param [in] bounds: the bounds of the tree
/// @param [in] pt: point to find k closest points to. Note that if pt is in the tree, it will be returned.
/// If there is a tie for the kth closest point, the list of the k closest points may contain any of the tied points,
/// but will always contain exactly k unless there are fewer than k entries in the tree.
/// @param [in] k: the number of points to find
/// @param [out] out: vector where the k closest points will be stored, in no specified order (the current
/// implementation unsurprisingly places them in a minmax heap order)
void cr8r_kd_k_closest(cr8r_vec*, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out);
/// @param [out] out: vector where the k closest points will be stored. Must be allocated but will be cleared before use.
/// The output is in no specified order (the current implementation unsurprisingly places them in a minmax heap order).
/// @return true on success, false on (allocation) failure. Will not allocate if out->cap >= k + 1 already, but may
/// return fewer than k points if fewer than k are available
bool cr8r_kd_k_closest(cr8r_vec*, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out);

/// Find the k closest points to a given point
/// Comparison function that compares two kd tree points based on their distance from a fixed point
///
/// This function is designed to enable testing { @link cr8r_kd_k_closest }. That function prunes its search area: if the max
/// distance of any of the k closest points so far is less than the min distance of the bounding box of the subtree rooted at
/// a node, then that node and its subtree are skipped. This function does not do that and just compares all points.
void cr8r_kd_k_closest_naive(cr8r_vec*, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out);

/// This function MUST be called with _ft pointing at a cr8r_base_ft WITHIN { @link cr8r_kd_k_closest_state }.
/// See the implementation of { @link cr8r_kd_k_closest } for more details.
/// This function will only read from the `pt` field of the `cr8r_kd_k_closest_state` struct.
int cr8r_default_cmp_kd_kcs_pt_dist(const cr8r_base_ft *_ft, const void *a, const void *b);
/// kdft implementation for spherical kd trees in 3 dimensions with i64 coordinates
extern cr8r_kd_ft cr8r_kdft_s2i64;
/// kdft implementation for cuboid kd trees in 3 dimensions with i64 coordinates
Expand Down
48 changes: 27 additions & 21 deletions src/lib/crater/kd_tree.c
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
#include "crater/container.h"
#include "crater/vec.h"
#include <stdlib.h>
#include <string.h>
#include <inttypes.h>
#include <limits.h>
#include <math.h>

#include <crater/kd_tree.h>
#include <crater/kd_check.h>
#include <crater/minmax_heap.h>

static int cmp_depth_i64cu(const cr8r_base_ft *_ft, const void *_a, const void *_b){
Expand Down Expand Up @@ -191,8 +190,8 @@ static void split_i64cu(const cr8r_kd_ft *ft, const void *_self, const void *_ro
cr8r_kdwin_s2i64 *o1 = _o1;
cr8r_kdwin_s2i64 *o2 = _o2;
uint64_t idx = (uint64_t)ft->super.base.data%ft->dim;
memcpy(o1, self, ft->dim*sizeof(int64_t));
memcpy(o2, self, ft->dim*sizeof(int64_t));
memcpy(o1, self, sizeof(cr8r_kdwin_s2i64));
memcpy(o2, self, sizeof(cr8r_kdwin_s2i64));
o1->tr[idx] = root[idx];
o2->bl[idx] = root[idx];
}
Expand Down Expand Up @@ -270,6 +269,11 @@ bool cr8r_kd_ify(cr8r_vec *self, cr8r_kd_ft *_ft, uint64_t a, uint64_t b){
return 0;
}
piv = cr8r_vec_partition_with_median(self, &ft.super, a, b, piv);
#ifdef DEBUG
if(!piv || !cr8r_kd_check_layer(self, &ft, a, b)){
__builtin_trap();
}
#endif
// increment depth
++*(uint64_t*)&ft.super.base.data;
if(!cr8r_kd_ify(self, &ft, mid_idx + 1, b)){
Expand Down Expand Up @@ -312,16 +316,8 @@ void cr8r_kd_walk(cr8r_vec *self, const cr8r_kd_ft *ft, const void *_bounds, cr8
cr8r_kd_walk_r(self, ft, bounds, visitor, data, 0, self->len);
}

// State shared between the k closest search entry points and k_closest_visitor.
typedef struct{
// vector collecting the closest entries found so far (the caller's output vector)
cr8r_vec *ents;
// copy of the tree's function table whose `super.cmp` is replaced by the distance comparison
cr8r_kd_ft ft;
// the fixed point that distances are measured from
const void *pt;
// the number of closest entries requested
uint64_t k;
// initialized to INFINITY; presumably the largest squared distance currently kept -- TODO confirm against the visitor
double max_sqdist;
} k_closest_state;

inline static cr8r_walk_decision k_closest_visitor(cr8r_kd_ft *ft, const void *bounds, void *ent, void *_data){
k_closest_state *data = _data;
cr8r_kd_k_closest_state *data = _data;
char tmp[ft->super.base.size];
if(data->ents->len < data->k){
cr8r_mmheap_push(data->ents, &data->ft.super, ent);
Expand All @@ -335,9 +331,9 @@ inline static cr8r_walk_decision k_closest_visitor(cr8r_kd_ft *ft, const void *b
return CR8R_WALK_SKIP_CHILDREN;
}

inline static int cmp_pt_dist(const cr8r_base_ft *_ft, const void *a, const void *b){
int cr8r_default_cmp_kd_kcs_pt_dist(const cr8r_base_ft *_ft, const void *a, const void *b){
const cr8r_kd_ft *ft = (const cr8r_kd_ft*)_ft;
const k_closest_state *data = CR8R_OUTER(ft, k_closest_state, ft);
const cr8r_kd_k_closest_state *data = CR8R_OUTER(ft, cr8r_kd_k_closest_state, ft);
double a_sqdist = ft->sqdist(ft, data->pt, a);
double b_sqdist = ft->sqdist(ft, data->pt, b);
if(a_sqdist < b_sqdist){
Expand All @@ -348,30 +344,40 @@ inline static int cmp_pt_dist(const cr8r_base_ft *_ft, const void *a, const void
return 0;
}

// Find the k entries of the kd tree `self` closest to `pt` and collect them in `out`.
// `out` is cleared first; returns false if capacity for k + 1 entries cannot be ensured, true otherwise.
void cr8r_kd_k_closest(cr8r_vec *self, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out){
k_closest_state data = {
bool cr8r_kd_k_closest(cr8r_vec *self, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out){
// discard whatever the caller left in the output vector
cr8r_vec_clear(out, &ft->super);
// NOTE(review): one slot more than k is requested; presumably the visitor briefly holds k + 1 entries -- confirm
if(!cr8r_vec_ensure_cap(out, &ft->super, k + 1)){
return false;
}
cr8r_kd_k_closest_state data = {
.ents = out,
.ft = *ft,
.pt = pt,
.k = k,
.max_sqdist = INFINITY
};
data.ft.super.cmp = cmp_pt_dist;
// only the private copy of the function table compares by distance to `pt`; the caller's `ft` is left untouched
data.ft.super.cmp = cr8r_default_cmp_kd_kcs_pt_dist;
// the walk itself uses the caller's `ft`; the visitor receives the state through the data pointer
cr8r_kd_walk(self, ft, bounds, k_closest_visitor, &data);
return true;
}

// Reference version of cr8r_kd_k_closest: passes every entry of `self` to k_closest_visitor in index order
// instead of walking the tree. `out` is cleared first; returns false if capacity for k + 1 entries
// cannot be ensured, true otherwise.
void cr8r_kd_k_closest_naive(cr8r_vec *self, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out){
k_closest_state data = {
bool cr8r_kd_k_closest_naive(cr8r_vec *self, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out){
// discard whatever the caller left in the output vector
cr8r_vec_clear(out, &ft->super);
if(!cr8r_vec_ensure_cap(out, &ft->super, k + 1)){
return false;
}
cr8r_kd_k_closest_state data = {
.ents = out,
.ft = *ft,
.pt = pt,
.k = k,
.max_sqdist = INFINITY
};
data.ft.super.cmp = cmp_pt_dist;
// only the private copy of the function table compares by distance to `pt`
data.ft.super.cmp = cr8r_default_cmp_kd_kcs_pt_dist;
// visit every entry; the visitor's walk decision is ignored, and `bounds` is passed through unchanged for every entry
for(uint64_t i = 0; i < self->len; ++i){
k_closest_visitor(ft, bounds, self->buf + i*ft->super.base.size, &data);
}
return true;
}

cr8r_kd_ft cr8r_kdft_s2i64 = {
Expand Down
Loading

0 comments on commit 38890f1

Please sign in to comment.