Skip to content

Commit

Permalink
improve take
Browse files Browse the repository at this point in the history
  • Loading branch information
Dousir9 committed Sep 21, 2023
1 parent f1c83ad commit 83c6586
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 12 deletions.
4 changes: 2 additions & 2 deletions src/common/hashtable/src/hashjoin_hashtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ where

fn contains(&self, key_ref: &Self::Key) -> bool {
let index = key_ref.hash() as usize & self.hash_mask;
let mut raw_entry_ptr = self.pointers[index];
let mut raw_entry_ptr = unsafe { *self.pointers.get_unchecked(index) };
loop {
if raw_entry_ptr == 0 {
break;
Expand All @@ -132,7 +132,7 @@ where
) -> (usize, u64) {
let index = key_ref.hash() as usize & self.hash_mask;
let origin = occupied;
let mut raw_entry_ptr = self.pointers[index];
let mut raw_entry_ptr = unsafe { *self.pointers.get_unchecked(index) };
loop {
if raw_entry_ptr == 0 || occupied >= capacity {
break;
Expand Down
4 changes: 2 additions & 2 deletions src/common/hashtable/src/hashjoin_string_hashtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ where A: Allocator + Clone + 'static

fn contains(&self, key_ref: &Self::Key) -> bool {
let index = key_ref.fast_hash() as usize & self.hash_mask;
let mut raw_entry_ptr = self.pointers[index];
let mut raw_entry_ptr = unsafe { *self.pointers.get_unchecked(index) };
loop {
if raw_entry_ptr == 0 {
break;
Expand Down Expand Up @@ -125,7 +125,7 @@ where A: Allocator + Clone + 'static
) -> (usize, u64) {
let index = key_ref.fast_hash() as usize & self.hash_mask;
let origin = occupied;
let mut raw_entry_ptr = self.pointers[index];
let mut raw_entry_ptr = unsafe { *self.pointers.get_unchecked(index) };
loop {
if raw_entry_ptr == 0 || occupied >= capacity {
break;
Expand Down
13 changes: 8 additions & 5 deletions src/query/expression/src/kernels/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,12 @@ impl Column {
let num_rows = indices.len();
let mut builder: Vec<T> = Vec::with_capacity(num_rows);
let ptr = builder.as_mut_ptr();
let col_ptr = col.as_slice().as_ptr();
// # Safety
// `i` must be less than `num_rows` and the capacity of builder is `num_rows`.
unsafe {
for (i, index) in indices.iter().enumerate() {
std::ptr::write(ptr.add(i), col[index.to_usize()]);
std::ptr::copy_nonoverlapping(col_ptr.add(index.to_usize()), ptr.add(i), 1);
}
builder.set_len(num_rows);
}
Expand Down Expand Up @@ -227,15 +228,17 @@ impl Column {
offsets.push(0);
let items_ptr = items.as_mut_ptr();
let offsets_ptr = unsafe { offsets.as_mut_ptr().add(1) };

let col_offset = col.offsets().as_slice();
let col_data_ptr = col.data().as_slice().as_ptr();
let mut data_size = 0;
for (i, index) in indices.iter().enumerate() {
let item = unsafe { col.index_unchecked(index.to_usize()) };
data_size += item.len() as u64;
let start = unsafe { *col_offset.get_unchecked(index.to_usize()) } as usize;
let len = unsafe { *col_offset.get_unchecked(index.to_usize() + 1) as usize } - start;
data_size += len as u64;
// # Safety
// `i` must be less than the capacity of Vec.
unsafe {
std::ptr::write(items_ptr.add(i), (item.as_ptr() as u64, item.len()));
std::ptr::write(items_ptr.add(i), (col_data_ptr.add(start) as u64, len));
std::ptr::write(offsets_ptr.add(i), data_size);
}
}
Expand Down
11 changes: 8 additions & 3 deletions src/query/expression/src/kernels/take_compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,14 @@ impl Column {
let ptr = builder.as_mut_ptr();
let mut offset = 0;
let mut remain;
for (index, cnt) in indices {
let col_ptr = col.as_slice().as_ptr();
for (index, cnt) in indices.iter() {
if *cnt == 1 {
// # Safety
// offset + 1 <= num_rows
unsafe { std::ptr::write(ptr.add(offset), col[*index as usize]) };
unsafe {
std::ptr::copy_nonoverlapping(col_ptr.add(*index as usize), ptr.add(offset), 1)
};
offset += 1;
continue;
}
Expand All @@ -204,7 +207,9 @@ impl Column {
let base_offset = offset;
// # Safety
// base_offset + 1 <= num_rows
unsafe { std::ptr::write(ptr.add(offset), col[*index as usize]) };
unsafe {
std::ptr::copy_nonoverlapping(col_ptr.add(*index as usize), ptr.add(offset), 1)
};
remain = *cnt as usize;
// Since cnt > 0, then 31 - cnt.leading_zeros() >= 0.
let max_segment = 1 << (31 - cnt.leading_zeros());
Expand Down

0 comments on commit 83c6586

Please sign in to comment.