From e4d26eaa361a1bda7b29b7cbf4ef7162e45dd4f5 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Thu, 21 Jan 2021 10:30:38 -0700 Subject: [PATCH] Remove entry count state from perfect hash table --- .../JoinHashTable/PerfectJoinHashTable.cpp | 31 ++++++++++++++----- .../JoinHashTable/PerfectJoinHashTable.h | 4 +-- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/QueryEngine/JoinHashTable/PerfectJoinHashTable.cpp b/QueryEngine/JoinHashTable/PerfectJoinHashTable.cpp index 2f24b2b154..f72530bf37 100644 --- a/QueryEngine/JoinHashTable/PerfectJoinHashTable.cpp +++ b/QueryEngine/JoinHashTable/PerfectJoinHashTable.cpp @@ -501,10 +501,6 @@ int PerfectJoinHashTable::initHashTableForDevice( #ifndef HAVE_CUDA CHECK_EQ(Data_Namespace::CPU_LEVEL, effective_memory_level); #endif - if (!device_id) { - hash_entry_count_ = hash_entry_info.getNormalizedHashEntryCount(); - } - int err{0}; const int32_t hash_join_invalid_val{-1}; if (effective_memory_level == Data_Namespace::CPU_LEVEL) { @@ -801,18 +797,29 @@ size_t PerfectJoinHashTable::payloadBufferOff() const noexcept { } size_t PerfectJoinHashTable::getComponentBufferSize() const noexcept { - if (hash_type_ == HashType::OneToMany) { - return hash_entry_count_ * sizeof(int32_t); + if (hash_tables_for_device_.empty()) { + return 0; + } + auto hash_table = hash_tables_for_device_.front(); + CHECK(hash_table); + if (hash_table->getLayout() == HashType::OneToMany) { + return hash_table->getEntryCount() * sizeof(int32_t); } else { return 0; } } +HashTable* PerfectJoinHashTable::getHashTableForDevice(const size_t device_id) const { + CHECK_LT(device_id, hash_tables_for_device_.size()); + return hash_tables_for_device_[device_id].get(); +} + std::string PerfectJoinHashTable::toString(const ExecutorDeviceType device_type, const int device_id, bool raw) const { auto buffer = getJoinHashBuffer(device_type, device_id); auto buffer_size = getJoinHashBufferSize(device_type, device_id); + auto hash_table = getHashTableForDevice(device_id); #ifdef HAVE_CUDA std::unique_ptr buffer_copy; if (device_type == ExecutorDeviceType::GPU) { @@ -835,7 +842,7 @@ std::string PerfectJoinHashTable::toString(const ExecutorDeviceType device_type, getHashTypeString(hash_type_), 0, 0, - hash_entry_count_, + hash_table ? hash_table->getEntryCount() : 0, ptr1, ptr2, ptr3, @@ -849,6 +856,7 @@ std::set PerfectJoinHashTable::toSet( const int device_id) const { auto buffer = getJoinHashBuffer(device_type, device_id); auto buffer_size = getJoinHashBufferSize(device_type, device_id); + auto hash_table = getHashTableForDevice(device_id); #ifdef HAVE_CUDA std::unique_ptr buffer_copy; if (device_type == ExecutorDeviceType::GPU) { @@ -867,7 +875,14 @@ std::set PerfectJoinHashTable::toSet( auto ptr2 = ptr1 + offsetBufferOff(); auto ptr3 = ptr1 + countBufferOff(); auto ptr4 = ptr1 + payloadBufferOff(); - return HashTable::toSet(0, 0, hash_entry_count_, ptr1, ptr2, ptr3, ptr4, buffer_size); + return HashTable::toSet(0, + 0, + hash_table ? hash_table->getEntryCount() : 0, + ptr1, + ptr2, + ptr3, + ptr4, + buffer_size); } llvm::Value* PerfectJoinHashTable::codegenSlot(const CompilationOptions& co, diff --git a/QueryEngine/JoinHashTable/PerfectJoinHashTable.h b/QueryEngine/JoinHashTable/PerfectJoinHashTable.h index 75a344fe68..8f24f380dd 100644 --- a/QueryEngine/JoinHashTable/PerfectJoinHashTable.h +++ b/QueryEngine/JoinHashTable/PerfectJoinHashTable.h @@ -146,7 +146,6 @@ class PerfectJoinHashTable : public HashJoin { , query_infos_(query_infos) , memory_level_(memory_level) , hash_type_(preferred_hash_type) - , hash_entry_count_(0) , col_range_(col_range) , executor_(executor) , column_cache_(column_cache) @@ -185,12 +184,13 @@ class PerfectJoinHashTable : public HashJoin { size_t getComponentBufferSize() const noexcept override; + HashTable* getHashTableForDevice(const size_t device_id) const; + std::shared_ptr qual_bin_oper_; std::shared_ptr col_var_; const std::vector& query_infos_; const Data_Namespace::MemoryLevel memory_level_; HashType hash_type_; - size_t hash_entry_count_; std::mutex cpu_hash_table_buff_mutex_; ExpressionRange col_range_;