Skip to content

Commit

Permalink
vectorized hash join
Browse files Browse the repository at this point in the history
  • Loading branch information
Dousir9 committed Oct 7, 2023
1 parent 28da525 commit ab7139f
Show file tree
Hide file tree
Showing 17 changed files with 356 additions and 364 deletions.
77 changes: 23 additions & 54 deletions src/common/hashtable/src/hashjoin_hashtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub struct RawEntry<K> {
pub struct HashJoinHashTable<K: Keyable, A: Allocator + Clone = MmapAllocator> {
pub(crate) pointers: Box<[u64], A>,
pub(crate) atomic_pointers: *mut AtomicU64,
pub(crate) hash_mask: usize,
pub(crate) hash_mask: u64,
pub(crate) phantom: PhantomData<K>,
}

Expand All @@ -68,7 +68,7 @@ impl<K: Keyable, A: Allocator + Clone + Default> HashJoinHashTable<K, A> {
Box::new_zeroed_slice_in(capacity, Default::default()).assume_init()
},
atomic_pointers: std::ptr::null_mut(),
hash_mask: capacity - 1,
hash_mask: capacity as u64 - 1,
phantom: PhantomData,
};
hashtable.atomic_pointers = unsafe {
Expand All @@ -78,7 +78,7 @@ impl<K: Keyable, A: Allocator + Clone + Default> HashJoinHashTable<K, A> {
}

pub fn insert(&mut self, key: K, raw_entry_ptr: *mut RawEntry<K>) {
let index = key.hash() as usize & self.hash_mask;
let index = (key.hash() & self.hash_mask) as usize;
// # Safety
// `index` is less than the capacity of hash table.
let mut head = unsafe { (*self.atomic_pointers.add(index)).load(Ordering::Relaxed) };
Expand Down Expand Up @@ -107,73 +107,42 @@ where
{
type Key = K;

fn contains(&self, key_ref: &Self::Key) -> bool {
let index = key_ref.hash() as usize & self.hash_mask;
let mut raw_entry_ptr = self.pointers[index];
loop {
if raw_entry_ptr == 0 {
break;
}
let raw_entry = unsafe { &*(raw_entry_ptr as *mut RawEntry<K>) };
if key_ref == &raw_entry.key {
return true;
}
raw_entry_ptr = raw_entry.next;
}
false
// Using hashes to probe hash table and converting them in-place to pointers for memory reuse.
fn probe(&self, hashes: &mut [u64]) {
hashes
.iter_mut()
.for_each(|hash| *hash = self.pointers[(*hash & self.hash_mask) as usize]);
}

fn probe_hash_table(
&self,
key_ref: &Self::Key,
vec_ptr: *mut RowPtr,
mut occupied: usize,
capacity: usize,
) -> (usize, u64) {
let index = key_ref.hash() as usize & self.hash_mask;
let origin = occupied;
let mut raw_entry_ptr = self.pointers[index];
fn next_contains(&self, key: &Self::Key, mut ptr: u64) -> bool {
loop {
if raw_entry_ptr == 0 || occupied >= capacity {
if ptr == 0 {
break;
}
let raw_entry = unsafe { &*(raw_entry_ptr as *mut RawEntry<K>) };
if key_ref == &raw_entry.key {
// # Safety
// occupied is less than the capacity of vec_ptr.
unsafe {
std::ptr::copy_nonoverlapping(
&raw_entry.row_ptr as *const RowPtr,
vec_ptr.add(occupied),
1,
)
};
occupied += 1;
let raw_entry = unsafe { &*(ptr as *mut RawEntry<K>) };
if key == &raw_entry.key {
return true;
}
raw_entry_ptr = raw_entry.next;
}
if occupied > origin {
(occupied - origin, raw_entry_ptr)
} else {
(0, 0)
ptr = raw_entry.next;
}
false
}

fn next_incomplete_ptr(
fn next_probe(
&self,
key_ref: &Self::Key,
mut incomplete_ptr: u64,
key: &Self::Key,
mut ptr: u64,
vec_ptr: *mut RowPtr,
mut occupied: usize,
capacity: usize,
) -> (usize, u64) {
let origin = occupied;
loop {
if incomplete_ptr == 0 || occupied >= capacity {
if ptr == 0 || occupied >= capacity {
break;
}
let raw_entry = unsafe { &*(incomplete_ptr as *mut RawEntry<K>) };
if key_ref == &raw_entry.key {
let raw_entry = unsafe { &*(ptr as *mut RawEntry<K>) };
if key == &raw_entry.key {
// # Safety
// occupied is less than the capacity of vec_ptr.
unsafe {
Expand All @@ -185,10 +154,10 @@ where
};
occupied += 1;
}
incomplete_ptr = raw_entry.next;
ptr = raw_entry.next;
}
if occupied > origin {
(occupied - origin, incomplete_ptr)
(occupied - origin, ptr)
} else {
(0, 0)
}
Expand Down
101 changes: 29 additions & 72 deletions src/common/hashtable/src/hashjoin_string_hashtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub struct StringRawEntry {
pub struct HashJoinStringHashTable<A: Allocator + Clone = MmapAllocator> {
pub(crate) pointers: Box<[u64], A>,
pub(crate) atomic_pointers: *mut AtomicU64,
pub(crate) hash_mask: usize,
pub(crate) hash_mask: u64,
}

unsafe impl<A: Allocator + Clone + Send> Send for HashJoinStringHashTable<A> {}
Expand All @@ -49,7 +49,7 @@ impl<A: Allocator + Clone + Default> HashJoinStringHashTable<A> {
Box::new_zeroed_slice_in(capacity, Default::default()).assume_init()
},
atomic_pointers: std::ptr::null_mut(),
hash_mask: capacity - 1,
hash_mask: capacity as u64 - 1,
};
hashtable.atomic_pointers = unsafe {
std::mem::transmute::<*mut u64, *mut AtomicU64>(hashtable.pointers.as_mut_ptr())
Expand All @@ -58,7 +58,7 @@ impl<A: Allocator + Clone + Default> HashJoinStringHashTable<A> {
}

pub fn insert(&mut self, key: &[u8], raw_entry_ptr: *mut StringRawEntry) {
let index = key.fast_hash() as usize & self.hash_mask;
let index = (key.fast_hash() & self.hash_mask) as usize;
// # Safety
// `index` is less than the capacity of hash table.
let mut head = unsafe { (*self.atomic_pointers.add(index)).load(Ordering::Relaxed) };
Expand All @@ -85,23 +85,28 @@ where A: Allocator + Clone + 'static
{
type Key = [u8];

fn contains(&self, key_ref: &Self::Key) -> bool {
let index = key_ref.fast_hash() as usize & self.hash_mask;
let mut raw_entry_ptr = self.pointers[index];
// Using hashes to probe hash table and converting them in-place to pointers for memory reuse.
fn probe(&self, hashes: &mut [u64]) {
hashes
.iter_mut()
.for_each(|hash| *hash = self.pointers[(*hash & self.hash_mask) as usize]);
}

fn next_contains(&self, key: &Self::Key, mut ptr: u64) -> bool {
loop {
if raw_entry_ptr == 0 {
if ptr == 0 {
break;
}
let raw_entry = unsafe { &*(raw_entry_ptr as *mut StringRawEntry) };
let raw_entry = unsafe { &*(ptr as *mut StringRawEntry) };
// Compare `early` and the length of the string, the size of `early` is 4.
let min_len = std::cmp::min(
STRING_EARLY_SIZE,
std::cmp::min(key_ref.len(), raw_entry.length as usize),
std::cmp::min(key.len(), raw_entry.length as usize),
);
if raw_entry.length as usize == key_ref.len()
&& key_ref[0..min_len] == raw_entry.early[0..min_len]
if raw_entry.length as usize == key.len()
&& key[0..min_len] == raw_entry.early[0..min_len]
{
let key = unsafe {
let key_ref = unsafe {
std::slice::from_raw_parts(
raw_entry.key as *const u8,
raw_entry.length as usize,
Expand All @@ -111,79 +116,31 @@ where A: Allocator + Clone + 'static
return true;
}
}
raw_entry_ptr = raw_entry.next;
ptr = raw_entry.next;
}
false
}

fn probe_hash_table(
&self,
key_ref: &Self::Key,
vec_ptr: *mut RowPtr,
mut occupied: usize,
capacity: usize,
) -> (usize, u64) {
let index = key_ref.fast_hash() as usize & self.hash_mask;
let origin = occupied;
let mut raw_entry_ptr = self.pointers[index];
loop {
if raw_entry_ptr == 0 || occupied >= capacity {
break;
}
let raw_entry = unsafe { &*(raw_entry_ptr as *mut StringRawEntry) };
// Compare `early` and the length of the string, the size of `early` is 4.
let min_len = std::cmp::min(STRING_EARLY_SIZE, key_ref.len());
if raw_entry.length as usize == key_ref.len()
&& key_ref[0..min_len] == raw_entry.early[0..min_len]
{
let key = unsafe {
std::slice::from_raw_parts(
raw_entry.key as *const u8,
raw_entry.length as usize,
)
};
if key == key_ref {
// # Safety
// occupied is less than the capacity of vec_ptr.
unsafe {
std::ptr::copy_nonoverlapping(
&raw_entry.row_ptr as *const RowPtr,
vec_ptr.add(occupied),
1,
)
};
occupied += 1;
}
}
raw_entry_ptr = raw_entry.next;
}
if occupied > origin {
(occupied - origin, raw_entry_ptr)
} else {
(0, 0)
}
}

fn next_incomplete_ptr(
fn next_probe(
&self,
key_ref: &Self::Key,
mut incomplete_ptr: u64,
key: &Self::Key,
mut ptr: u64,
vec_ptr: *mut RowPtr,
mut occupied: usize,
capacity: usize,
) -> (usize, u64) {
let origin = occupied;
loop {
if incomplete_ptr == 0 || occupied >= capacity {
if ptr == 0 || occupied >= capacity {
break;
}
let raw_entry = unsafe { &*(incomplete_ptr as *mut StringRawEntry) };
let raw_entry = unsafe { &*(ptr as *mut StringRawEntry) };
// Compare `early` and the length of the string, the size of `early` is 4.
let min_len = std::cmp::min(STRING_EARLY_SIZE, key_ref.len());
if raw_entry.length as usize == key_ref.len()
&& key_ref[0..min_len] == raw_entry.early[0..min_len]
let min_len = std::cmp::min(STRING_EARLY_SIZE, key.len());
if raw_entry.length as usize == key.len()
&& key[0..min_len] == raw_entry.early[0..min_len]
{
let key = unsafe {
let key_ref = unsafe {
std::slice::from_raw_parts(
raw_entry.key as *const u8,
raw_entry.length as usize,
Expand All @@ -202,10 +159,10 @@ where A: Allocator + Clone + 'static
occupied += 1;
}
}
incomplete_ptr = raw_entry.next;
ptr = raw_entry.next;
}
if occupied > origin {
(occupied - origin, incomplete_ptr)
(occupied - origin, ptr)
} else {
(0, 0)
}
Expand Down
17 changes: 6 additions & 11 deletions src/common/hashtable/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,20 +436,15 @@ pub trait HashtableLike {
pub trait HashJoinHashtableLike {
type Key: ?Sized;

fn contains(&self, key_ref: &Self::Key) -> bool;
// Using hashes to probe hash table and converting them in-place to pointers for memory reuse.
fn probe(&self, hashes: &mut [u64]);

fn probe_hash_table(
&self,
key_ref: &Self::Key,
vec_ptr: *mut RowPtr,
occupied: usize,
capacity: usize,
) -> (usize, u64);
fn next_contains(&self, key: &Self::Key, ptr: u64) -> bool;

fn next_incomplete_ptr(
fn next_probe(
&self,
key_ref: &Self::Key,
incomplete_ptr: u64,
key: &Self::Key,
ptr: u64,
vec_ptr: *mut RowPtr,
occupied: usize,
capacity: usize,
Expand Down
22 changes: 12 additions & 10 deletions src/query/expression/src/kernels/group_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,20 @@ impl DataBlock {
hash_key_types: &[DataType],
efficiently_memory: bool,
) -> Result<HashMethodKind> {
if hash_key_types.len() == 1 {
let typ = hash_key_types[0].clone();
if matches!(typ, DataType::String | DataType::Variant) {
return Ok(HashMethodKind::SingleString(
HashMethodSingleString::default(),
));
}
if hash_key_types.len() == 1
&& matches!(
hash_key_types[0],
DataType::String | DataType::Variant | DataType::Bitmap
)
{
return Ok(HashMethodKind::SingleString(
HashMethodSingleString::default(),
));
}

let mut group_key_len = 0;
for typ in hash_key_types {
let not_null_type = typ.remove_nullable();
for hash_key_type in hash_key_types {
let not_null_type = hash_key_type.remove_nullable();

if not_null_type.is_numeric()
|| not_null_type.is_date_or_date_time()
Expand All @@ -69,7 +71,7 @@ impl DataBlock {
group_key_len += not_null_type.numeric_byte_size().unwrap();

// extra one byte for null flag
if typ.is_nullable() {
if hash_key_type.is_nullable() {
group_key_len += 1;
}
} else {
Expand Down
Loading

0 comments on commit ab7139f

Please sign in to comment.