From 3a86f6ec3793c89573c08b27161918f0749d0fee Mon Sep 17 00:00:00 2001 From: yoonminnam <53632385+yoonminnam@users.noreply.github.com> Date: Sun, 24 Jan 2021 04:46:52 +0900 Subject: [PATCH] Support parameter control for overlaps join via query hint * Support two param control via query hint: 1) overlaps_bucket_threshold AND 2) overlaps_max_size * Improve the code for parsing / passing query hint * assign default query hint iff query_dag_ is nullptr * Fixup codes / Add test cases --- QueryEngine/CardinalityEstimator.cpp | 2 + QueryEngine/Execute.cpp | 7 +- QueryEngine/Execute.h | 3 +- QueryEngine/ExecutionKernel.h | 1 + QueryEngine/IRCodegen.cpp | 3 +- QueryEngine/JoinHashTable/HashJoin.cpp | 20 ++- QueryEngine/JoinHashTable/HashJoin.h | 3 +- .../JoinHashTable/OverlapsJoinHashTable.cpp | 107 ++++++++++----- .../JoinHashTable/OverlapsJoinHashTable.h | 10 +- QueryEngine/QueryHint.h | 51 ++++++- QueryEngine/QueryRewrite.cpp | 3 + QueryEngine/RelAlgDagBuilder.cpp | 36 ++--- QueryEngine/RelAlgDagBuilder.h | 68 +++++++++- QueryEngine/RelAlgExecutionUnit.h | 2 + QueryEngine/RelAlgExecutor.cpp | 124 ++++++++++-------- QueryEngine/RelAlgExecutor.h | 2 +- QueryRunner/QueryRunner.cpp | 2 +- QueryRunner/QueryRunner.h | 2 +- Tests/SQLHintTest.cpp | 89 +++++++++++-- .../parser/hint/OmniSciHintStrategyTable.java | 5 +- 20 files changed, 409 insertions(+), 131 deletions(-) diff --git a/QueryEngine/CardinalityEstimator.cpp b/QueryEngine/CardinalityEstimator.cpp index 186328d0e5..8a4ada2cac 100644 --- a/QueryEngine/CardinalityEstimator.cpp +++ b/QueryEngine/CardinalityEstimator.cpp @@ -96,6 +96,7 @@ RelAlgExecutionUnit create_ndv_execution_unit(const RelAlgExecutionUnit& ra_exe_ : makeExpr(ra_exe_unit.groupby_exprs), SortInfo{{}, SortAlgorithm::Default, 0, 0}, 0, + ra_exe_unit.query_hint, false, ra_exe_unit.union_all, ra_exe_unit.query_state}; @@ -114,6 +115,7 @@ RelAlgExecutionUnit create_count_all_execution_unit( nullptr, SortInfo{{}, SortAlgorithm::Default, 0, 0}, 0, + ra_exe_unit.query_hint, false, ra_exe_unit.union_all, ra_exe_unit.query_state}; diff --git a/QueryEngine/Execute.cpp b/QueryEngine/Execute.cpp index c8195ecfb8..98c3617420 100644 --- a/QueryEngine/Execute.cpp +++ b/QueryEngine/Execute.cpp @@ -1335,6 +1335,7 @@ RelAlgExecutionUnit replace_scan_limit(const RelAlgExecutionUnit& ra_exe_unit_in ra_exe_unit_in.estimator, ra_exe_unit_in.sort_info, new_scan_limit, + ra_exe_unit_in.query_hint, ra_exe_unit_in.use_bump_allocator, ra_exe_unit_in.union_all, ra_exe_unit_in.query_state}; @@ -3120,7 +3121,8 @@ Executor::JoinHashTableOrError Executor::buildHashTableForQualifier( const std::vector& query_infos, const MemoryLevel memory_level, const HashType preferred_hash_type, - ColumnCacheMap& column_cache) { + ColumnCacheMap& column_cache, + const QueryHint& query_hint) { if (!g_enable_overlaps_hashjoin && qual_bin_oper->is_overlaps_oper()) { return {nullptr, "Overlaps hash join disabled, attempting to fall back to loop join"}; } @@ -3137,7 +3139,8 @@ Executor::JoinHashTableOrError Executor::buildHashTableForQualifier( preferred_hash_type, deviceCountForMemoryLevel(memory_level), column_cache, - this); + this, + query_hint); return {tbl, ""}; } catch (const HashJoinFail& e) { return {nullptr, e.what()}; diff --git a/QueryEngine/Execute.h b/QueryEngine/Execute.h index ff48b5de5f..889376a24e 100644 --- a/QueryEngine/Execute.h +++ b/QueryEngine/Execute.h @@ -832,7 +832,8 @@ class Executor { const std::vector& query_infos, const MemoryLevel memory_level, const HashType preferred_hash_type, - ColumnCacheMap& column_cache); + ColumnCacheMap& column_cache, + const QueryHint& query_hint); void nukeOldState(const bool allow_lazy_fetch, const std::vector& query_infos, const PlanState::DeletedColumnsMap& deleted_cols_map, diff --git a/QueryEngine/ExecutionKernel.h b/QueryEngine/ExecutionKernel.h index c2d3c7bbd4..de3383a3c4 100644 --- a/QueryEngine/ExecutionKernel.h +++ b/QueryEngine/ExecutionKernel.h @@ -43,6 +43,7 @@ class SharedKernelContext { std::vector all_frag_row_offsets_; std::mutex all_frag_row_offsets_mutex_; const std::vector& query_infos_; + const QueryHint query_hint_; }; class ExecutionKernel { diff --git a/QueryEngine/IRCodegen.cpp b/QueryEngine/IRCodegen.cpp index dd5768bf82..159bde3f88 100644 --- a/QueryEngine/IRCodegen.cpp +++ b/QueryEngine/IRCodegen.cpp @@ -518,7 +518,8 @@ std::shared_ptr Executor::buildCurrentLevelHashTable( co.device_type == ExecutorDeviceType::GPU ? MemoryLevel::GPU_LEVEL : MemoryLevel::CPU_LEVEL, HashType::OneToOne, - column_cache); + column_cache, + ra_exe_unit.query_hint); current_level_hash_table = hash_table_or_error.hash_table; } if (hash_table_or_error.hash_table) { diff --git a/QueryEngine/JoinHashTable/HashJoin.cpp b/QueryEngine/JoinHashTable/HashJoin.cpp index 6fa1ffba93..645faad3b1 100644 --- a/QueryEngine/JoinHashTable/HashJoin.cpp +++ b/QueryEngine/JoinHashTable/HashJoin.cpp @@ -222,7 +222,8 @@ std::shared_ptr HashJoin::getInstance( const HashType preferred_hash_type, const int device_count, ColumnCacheMap& column_cache, - Executor* executor) { + Executor* executor, + const QueryHint& query_hint) { auto timer = DEBUG_TIMER(__func__); std::shared_ptr join_hash_table; CHECK_GT(device_count, 0); @@ -232,8 +233,13 @@ std::shared_ptr HashJoin::getInstance( } if (qual_bin_oper->is_overlaps_oper()) { VLOG(1) << "Trying to build geo hash table:"; - join_hash_table = OverlapsJoinHashTable::getInstance( - qual_bin_oper, query_infos, memory_level, device_count, column_cache, executor); + join_hash_table = OverlapsJoinHashTable::getInstance(qual_bin_oper, + query_infos, + memory_level, + device_count, + column_cache, + executor, + query_hint); } else if (dynamic_cast( qual_bin_oper->get_left_operand())) { VLOG(1) << "Trying to build keyed hash table:"; @@ -458,6 +464,7 @@ std::shared_ptr HashJoin::getSyntheticInstance( AllColumnVarsVisitor().visit(qual_bin_oper.get()); auto query_infos = getSyntheticInputTableInfo(cvs, executor); setupSyntheticCaching(cvs, executor); + QueryHint query_hint = QueryHint::defaults(); auto hash_table = HashJoin::getInstance(qual_bin_oper, query_infos, @@ -465,7 +472,8 @@ std::shared_ptr HashJoin::getSyntheticInstance( preferred_hash_type, device_count, column_cache, - executor); + executor, + query_hint); return hash_table; } @@ -481,6 +489,7 @@ std::shared_ptr HashJoin::getSyntheticInstance( AllColumnVarsVisitor().visit(qual_bin_oper.get()); auto query_infos = getSyntheticInputTableInfo(cvs, executor); setupSyntheticCaching(cvs, executor); + QueryHint query_hint = QueryHint::defaults(); auto hash_table = HashJoin::getInstance(qual_bin_oper, query_infos, @@ -488,7 +497,8 @@ std::shared_ptr HashJoin::getSyntheticInstance( preferred_hash_type, device_count, column_cache, - executor); + executor, + query_hint); return hash_table; } diff --git a/QueryEngine/JoinHashTable/HashJoin.h b/QueryEngine/JoinHashTable/HashJoin.h index 66d2e87e95..0c331ffa33 100644 --- a/QueryEngine/JoinHashTable/HashJoin.h +++ b/QueryEngine/JoinHashTable/HashJoin.h @@ -166,7 +166,8 @@ class HashJoin { const HashType preferred_hash_type, const int device_count, ColumnCacheMap& column_cache, - Executor* executor); + Executor* executor, + const QueryHint& query_hint); //! Make hash table from named tables and columns (such as for testing). static std::shared_ptr getSyntheticInstance( diff --git a/QueryEngine/JoinHashTable/OverlapsJoinHashTable.cpp b/QueryEngine/JoinHashTable/OverlapsJoinHashTable.cpp index cda83e2aa9..f590decd9e 100644 --- a/QueryEngine/JoinHashTable/OverlapsJoinHashTable.cpp +++ b/QueryEngine/JoinHashTable/OverlapsJoinHashTable.cpp @@ -40,7 +40,8 @@ std::shared_ptr OverlapsJoinHashTable::getInstance( const Data_Namespace::MemoryLevel memory_level, const int device_count, ColumnCacheMap& column_cache, - Executor* executor) { + Executor* executor, + const QueryHint& query_hint) { decltype(std::chrono::steady_clock::now()) ts1, ts2; auto inner_outer_pairs = normalize_column_pairs( condition.get(), *executor->getCatalog(), executor->getTemporaryTables()); @@ -95,6 +96,9 @@ std::shared_ptr OverlapsJoinHashTable::getInstance( executor, inner_outer_pairs, device_count); + if (query_hint.hint_delivered) { + join_hash_table->registerQueryHint(query_hint); + } try { join_hash_table->reify(layout); } catch (const HashJoinFail& e) { @@ -130,6 +134,34 @@ void OverlapsJoinHashTable::reifyWithLayout(const HashType layout) { if (query_info.fragments.empty()) { return; } + + auto overlaps_max_table_size_bytes = g_overlaps_max_table_size_bytes; + bool use_user_given_bucket_threshold = false; + auto query_hint = getRegisteredQueryHint(); + if (query_hint.hint_delivered) { + if (query_hint.overlaps_bucket_threshold != overlaps_hashjoin_bucket_threshold_) { + VLOG(1) << "User changes a threshold \'overlaps_hashjoin_bucket_threshold\' via " + "query hint: " + << overlaps_hashjoin_bucket_threshold_ << " -> " + << query_hint.overlaps_bucket_threshold; + overlaps_hashjoin_bucket_threshold_ = query_hint.overlaps_bucket_threshold; + use_user_given_bucket_threshold = true; + } + if (query_hint.overlaps_max_size != overlaps_max_table_size_bytes) { + std::ostringstream oss; + oss << "User requests to change a threshold \'overlaps_max_table_size_bytes\' via " + "query hint: " + << overlaps_max_table_size_bytes << " -> " << query_hint.overlaps_max_size; + if (!use_user_given_bucket_threshold) { + overlaps_max_table_size_bytes = query_hint.overlaps_max_size; + } else { + oss << ", but is skipped since the query hint also changes the threshold " + "\'overlaps_hashjoin_bucket_threshold\'"; + } + VLOG(1) << oss.str(); + } + } + std::vector columns_per_device; const auto catalog = executor_->getCatalog(); CHECK(catalog); @@ -175,42 +207,45 @@ void OverlapsJoinHashTable::reifyWithLayout(const HashType layout) { // Auto-tuner: Pre-calculate some possible hash table sizes. std::lock_guard guard(auto_tuner_cache_mutex_); - auto atc = auto_tuner_cache_.find(cache_key); - if (atc != auto_tuner_cache_.end()) { - overlaps_hashjoin_bucket_threshold_ = atc->second; - VLOG(1) << "Auto tuner using cached overlaps hash table size of: " - << overlaps_hashjoin_bucket_threshold_; - } else { - VLOG(1) << "Auto tuning for the overlaps hash table size:"; - // TODO(jclay): Currently, joining on large poly sets - // will lead to lengthy construction times (and large hash tables) - // tune this to account for the characteristics of the data being joined. - const double min_threshold{1e-5}; - const double max_threshold{1}; - double good_threshold{max_threshold}; - for (double threshold = max_threshold; threshold >= min_threshold; - threshold /= 10.0) { - overlaps_hashjoin_bucket_threshold_ = threshold; - size_t entry_count; - size_t emitted_keys_count; - std::tie(entry_count, emitted_keys_count) = - calculateCounts(shard_count, query_info, columns_per_device); - size_t hash_table_size = calculateHashTableSize( - bucket_sizes_for_dimension_.size(), emitted_keys_count, entry_count); - bucket_sizes_for_dimension_.clear(); - VLOG(1) << "Calculated bin threshold of " << std::fixed << threshold - << " giving: entry count " << entry_count << " hash table size " - << hash_table_size; - if (hash_table_size <= g_overlaps_max_table_size_bytes) { - good_threshold = overlaps_hashjoin_bucket_threshold_; - } else { - VLOG(1) << "Rejected bin threshold of " << std::fixed << threshold; - break; + if (!use_user_given_bucket_threshold) { + // auto-tuning is valid iff no query hint is delivered to change bucket threshold + auto atc = auto_tuner_cache_.find(cache_key); + if (atc != auto_tuner_cache_.end()) { + overlaps_hashjoin_bucket_threshold_ = atc->second; + VLOG(1) << "Auto tuner using cached overlaps hash table size of: " + << overlaps_hashjoin_bucket_threshold_; + } else { + VLOG(1) << "Auto tuning for the overlaps hash table size:"; + // TODO(jclay): Currently, joining on large poly sets + // will lead to lengthy construction times (and large hash tables) + // tune this to account for the characteristics of the data being joined. + const double min_threshold{1e-5}; + const double max_threshold{1}; + double good_threshold{max_threshold}; + for (double threshold = max_threshold; threshold >= min_threshold; + threshold /= 10.0) { + overlaps_hashjoin_bucket_threshold_ = threshold; + size_t entry_count; + size_t emitted_keys_count; + std::tie(entry_count, emitted_keys_count) = + calculateCounts(shard_count, query_info, columns_per_device); + size_t hash_table_size = calculateHashTableSize( + bucket_sizes_for_dimension_.size(), emitted_keys_count, entry_count); + bucket_sizes_for_dimension_.clear(); + VLOG(1) << "Calculated bin threshold of " << std::fixed << threshold + << " giving: entry count " << entry_count << " hash table size " + << hash_table_size; + if (hash_table_size <= overlaps_max_table_size_bytes) { + good_threshold = overlaps_hashjoin_bucket_threshold_; + } else { + VLOG(1) << "Rejected bin threshold of " << std::fixed << threshold; + break; + } + } + overlaps_hashjoin_bucket_threshold_ = good_threshold; + if (!cache_key_contains_intermediate_table(cache_key)) { + auto_tuner_cache_[cache_key] = overlaps_hashjoin_bucket_threshold_; } - } - overlaps_hashjoin_bucket_threshold_ = good_threshold; - if (!cache_key_contains_intermediate_table(cache_key)) { - auto_tuner_cache_[cache_key] = overlaps_hashjoin_bucket_threshold_; } } diff --git a/QueryEngine/JoinHashTable/OverlapsJoinHashTable.h b/QueryEngine/JoinHashTable/OverlapsJoinHashTable.h index 71dff8b3a7..0907315f9a 100644 --- a/QueryEngine/JoinHashTable/OverlapsJoinHashTable.h +++ b/QueryEngine/JoinHashTable/OverlapsJoinHashTable.h @@ -39,6 +39,7 @@ class OverlapsJoinHashTable : public HashJoin { , device_count_(device_count) { CHECK_GT(device_count_, 0); hash_tables_for_device_.resize(std::max(device_count_, 1)); + query_hint_ = QueryHint::defaults(); } virtual ~OverlapsJoinHashTable() {} @@ -50,7 +51,8 @@ class OverlapsJoinHashTable : public HashJoin { const Data_Namespace::MemoryLevel memory_level, const int device_count, ColumnCacheMap& column_cache, - Executor* executor); + Executor* executor, + const QueryHint& query_hint); static auto getCacheInvalidator() -> std::function { VLOG(1) << "Invalidate " << auto_tuner_cache_.size() << " cached overlaps hashtable."; @@ -135,6 +137,10 @@ class OverlapsJoinHashTable : public HashJoin { return nullptr; } + const QueryHint& getRegisteredQueryHint() { return query_hint_; } + + void registerQueryHint(const QueryHint& query_hint) { query_hint_ = query_hint; } + static std::map auto_tuner_cache_; static std::mutex auto_tuner_cache_mutex_; @@ -245,4 +251,6 @@ class OverlapsJoinHashTable : public HashJoin { using HashTableCacheValue = std::shared_ptr; static std::unique_ptr> hash_table_cache_; + + QueryHint query_hint_; }; diff --git a/QueryEngine/QueryHint.h b/QueryEngine/QueryHint.h index 3bb50dd9ba..c0b4934998 100644 --- a/QueryEngine/QueryHint.h +++ b/QueryEngine/QueryHint.h @@ -17,8 +17,57 @@ #ifndef OMNISCI_QUERYHINT_H #define OMNISCI_QUERYHINT_H +#include "ThriftHandler/CommandLineOptions.h" + struct QueryHint { - bool cpu_mode{false}; + // for each hint "H", we first define its value as the corresponding system-defined + // default value "D" + // After then, if we detect at least one hint is registered (via hint_delivered), + // we can compare the value btw. "H" and "D" during the query compilation step that H + // is involved and then use the "H" iff "H" != "D" + // since that indicates user-given hint is delivered + // (otherwise, "H" should be the equal to "D") + // note that we should check if H is valid W.R.T the proper value range + // i.e., if H is valid in 0.0 ~ 1.0, then we check that at the point + // when we decide to use H, and use D iff given H does not have a valid value + QueryHint() { + hint_delivered = false; + cpu_mode = false; + overlaps_bucket_threshold = 0.1; + overlaps_max_size = g_overlaps_max_table_size_bytes; + } + + QueryHint& operator=(const QueryHint& other) { + hint_delivered = other.hint_delivered; + cpu_mode = other.cpu_mode; + overlaps_bucket_threshold = other.overlaps_bucket_threshold; + overlaps_max_size = other.overlaps_max_size; + return *this; + } + + QueryHint(const QueryHint& other) { + hint_delivered = other.hint_delivered; + cpu_mode = other.cpu_mode; + overlaps_bucket_threshold = other.overlaps_bucket_threshold; + overlaps_max_size = other.overlaps_max_size; + } + + // set true if at least one query hint is delivered + bool hint_delivered; + + // general query execution + bool cpu_mode; + + // overlaps hash join + double overlaps_bucket_threshold; // defined in "OverlapsJoinHashTable.h" + size_t overlaps_max_size; + + std::unordered_map OMNISCI_SUPPORTED_HINT_CLASS = { + {"cpu_mode", 0}, + {"overlaps_bucket_threshold", 1}, + {"overlaps_max_size", 2}}; + + static QueryHint defaults() { return QueryHint(); } }; #endif // OMNISCI_QUERYHINT_H diff --git a/QueryEngine/QueryRewrite.cpp b/QueryEngine/QueryRewrite.cpp index 19fcff0d4d..54929ff0b2 100644 --- a/QueryEngine/QueryRewrite.cpp +++ b/QueryEngine/QueryRewrite.cpp @@ -77,6 +77,7 @@ RelAlgExecutionUnit QueryRewriter::rewriteOverlapsJoin( ra_exe_unit_in.estimator, ra_exe_unit_in.sort_info, ra_exe_unit_in.scan_limit, + ra_exe_unit_in.query_hint, ra_exe_unit_in.use_bump_allocator}; } @@ -369,6 +370,7 @@ RelAlgExecutionUnit QueryRewriter::rewriteColumnarUpdate( ra_exe_unit_in.estimator, ra_exe_unit_in.sort_info, ra_exe_unit_in.scan_limit, + ra_exe_unit_in.query_hint, ra_exe_unit_in.use_bump_allocator, ra_exe_unit_in.union_all, ra_exe_unit_in.query_state}; @@ -468,6 +470,7 @@ RelAlgExecutionUnit QueryRewriter::rewriteColumnarDelete( ra_exe_unit_in.estimator, ra_exe_unit_in.sort_info, ra_exe_unit_in.scan_limit, + ra_exe_unit_in.query_hint, ra_exe_unit_in.use_bump_allocator, ra_exe_unit_in.union_all, ra_exe_unit_in.query_state}; diff --git a/QueryEngine/RelAlgDagBuilder.cpp b/QueryEngine/RelAlgDagBuilder.cpp index 9ffc1e703d..a238536fb0 100644 --- a/QueryEngine/RelAlgDagBuilder.cpp +++ b/QueryEngine/RelAlgDagBuilder.cpp @@ -1199,43 +1199,47 @@ void bind_inputs(const std::vector>& nodes) noexcept void handleQueryHint(const std::vector>& nodes, RelAlgDagBuilder* dag_builder) noexcept { - QueryHint query_hints; + Hints* hint_delivered = nullptr; for (auto node : nodes) { const auto agg_node = std::dynamic_pointer_cast(node); if (agg_node) { - if (agg_node->hasHintEnabled("cpu_mode")) { - query_hints.cpu_mode = true; + if (agg_node->hasDeliveredHint()) { + hint_delivered = agg_node->getDeliveredHints(); + break; } } const auto project_node = std::dynamic_pointer_cast(node); if (project_node) { - if (project_node->hasHintEnabled("cpu_mode")) { - query_hints.cpu_mode = true; + if (project_node->hasDeliveredHint()) { + hint_delivered = project_node->getDeliveredHints(); + break; } } const auto scan_node = std::dynamic_pointer_cast(node); if (scan_node) { - if (scan_node->hasHintEnabled("cpu_mode")) { - query_hints.cpu_mode = true; + if (scan_node->hasDeliveredHint()) { + hint_delivered = scan_node->getDeliveredHints(); + break; } } const auto join_node = std::dynamic_pointer_cast(node); if (join_node) { - if (join_node->hasHintEnabled("cpu_mode")) { - query_hints.cpu_mode = true; + if (join_node->hasDeliveredHint()) { + hint_delivered = join_node->getDeliveredHints(); + break; } } const auto compound_node = std::dynamic_pointer_cast(node); if (compound_node) { - if (compound_node->hasHintEnabled("cpu_mode")) { - query_hints.cpu_mode = true; + if (compound_node->hasDeliveredHint()) { + hint_delivered = compound_node->getDeliveredHints(); + break; } } } - if (query_hints.cpu_mode) { - VLOG(1) << "A user forces to run the query on the CPU execution mode"; + if (hint_delivered && !hint_delivered->empty()) { + dag_builder->registerQueryHints(hint_delivered); } - dag_builder->registerQueryHints(query_hints); } void mark_nops(const std::vector>& nodes) noexcept { @@ -2575,7 +2579,7 @@ class RelAlgDispatcher { RelAlgDagBuilder::RelAlgDagBuilder(const std::string& query_ra, const Catalog_Namespace::Catalog& cat, const RenderInfo* render_info) - : cat_(cat), render_info_(render_info) { + : cat_(cat), render_info_(render_info), query_hint_(QueryHint::defaults()) { rapidjson::Document query_ast; query_ast.Parse(query_ra.c_str()); VLOG(2) << "Parsing query RA JSON: " << query_ra; @@ -2597,7 +2601,7 @@ RelAlgDagBuilder::RelAlgDagBuilder(RelAlgDagBuilder& root_dag_builder, const rapidjson::Value& query_ast, const Catalog_Namespace::Catalog& cat, const RenderInfo* render_info) - : cat_(cat), render_info_(render_info) { + : cat_(cat), render_info_(render_info), query_hint_(QueryHint::defaults()) { build(query_ast, root_dag_builder); } diff --git a/QueryEngine/RelAlgDagBuilder.h b/QueryEngine/RelAlgDagBuilder.h index 3ddfd4df69..57ef7a2e02 100644 --- a/QueryEngine/RelAlgDagBuilder.h +++ b/QueryEngine/RelAlgDagBuilder.h @@ -840,6 +840,10 @@ class RelScan : public RelAlgNode { return hints_->at(hint_name); } + bool hasDeliveredHint() { return !hints_->empty(); } + + Hints* getDeliveredHints() { return hints_.get(); } + private: const TableDescriptor* td_; const std::vector field_names_; @@ -1004,6 +1008,10 @@ class RelProject : public RelAlgNode, public ModifyManipulationTarget { return hints_->at(hint_name); } + bool hasDeliveredHint() { return !hints_->empty(); } + + Hints* getDeliveredHints() { return hints_.get(); } + private: template void visitScalarExprs(EXPR_VISITOR_FUNCTOR visitor_functor) const { @@ -1113,6 +1121,10 @@ class RelAggregate : public RelAlgNode { return hints_->at(hint_name); } + bool hasDeliveredHint() { return !hints_->empty(); } + + Hints* getDeliveredHints() { return hints_.get(); } + private: const size_t groupby_count_; std::vector> agg_exprs_; @@ -1188,6 +1200,10 @@ class RelJoin : public RelAlgNode { return hints_->at(hint_name); } + bool hasDeliveredHint() { return !hints_->empty(); } + + Hints* getDeliveredHints() { return hints_.get(); } + private: mutable std::unique_ptr condition_; const JoinType join_type_; @@ -1363,6 +1379,10 @@ class RelCompound : public RelAlgNode, public ModifyManipulationTarget { return hints_->at(hint_name); } + bool hasDeliveredHint() { return !hints_->empty(); } + + Hints* getDeliveredHints() { return hints_.get(); } + private: std::unique_ptr filter_expr_; const size_t groupby_count_; @@ -1832,7 +1852,53 @@ class RelAlgDagBuilder : public boost::noncopyable { return subqueries_; } - void registerQueryHints(QueryHint& query_hint) { query_hint_ = query_hint; } + void registerQueryHints(Hints* hints_delivered) { + for (auto& kv : query_hint_.OMNISCI_SUPPORTED_HINT_CLASS) { + auto target = hints_delivered->find(kv.first); + if (target != hints_delivered->end()) { + int target_hint_num = kv.second; + switch (target_hint_num) { + case 0: { // cpu_mode + query_hint_.hint_delivered = true; + query_hint_.cpu_mode = true; + VLOG(1) << "A user forces to run the query on the CPU execution mode"; + break; + } + case 1: { // overlaps_bucket_threshold + CHECK(target->second.getListOptions().size() == 1); + double overlaps_bucket_threshold = + std::stod(target->second.getListOptions()[0]); + if (overlaps_bucket_threshold >= 0.0 && overlaps_bucket_threshold <= 1.0) { + query_hint_.hint_delivered = true; + query_hint_.overlaps_bucket_threshold = overlaps_bucket_threshold; + } else { + VLOG(1) << "Skip the given query hint \"overlaps_bucket_threshold\" (" + << overlaps_bucket_threshold + << ") : the hint value should be within 0.0 ~ 1.0"; + } + break; + } + case 2: { // overlaps_max_size + CHECK(target->second.getListOptions().size() == 1); + std::stringstream ss(target->second.getListOptions()[0]); + int overlaps_max_size; + ss >> overlaps_max_size; + if (overlaps_max_size >= 0) { + query_hint_.hint_delivered = true; + query_hint_.overlaps_max_size = (size_t)overlaps_max_size; + } else { + VLOG(1) << "Skip the query hint \"overlaps_max_size\" (" + << overlaps_max_size + << ") : the hint value should be larger than or equal to zero"; + } + break; + } + default: + break; + } + } + } + } const QueryHint getQueryHints() const { return query_hint_; } diff --git a/QueryEngine/RelAlgExecutionUnit.h b/QueryEngine/RelAlgExecutionUnit.h index 66d6deabc4..a71e4728cc 100644 --- a/QueryEngine/RelAlgExecutionUnit.h +++ b/QueryEngine/RelAlgExecutionUnit.h @@ -27,6 +27,7 @@ #define QUERYENGINE_RELALGEXECUTIONUNIT_H #include "Descriptors/InputDescriptors.h" +#include "QueryHint.h" #include "Shared/sqldefs.h" #include "Shared/toString.h" #include "TableFunctions/TableFunctionOutputBufferSizeType.h" @@ -73,6 +74,7 @@ struct RelAlgExecutionUnit { const std::shared_ptr estimator; const SortInfo sort_info; size_t scan_limit; + QueryHint query_hint; bool use_bump_allocator{false}; // empty if not a UNION, true if UNION ALL, false if regular UNION const std::optional union_all; diff --git a/QueryEngine/RelAlgExecutor.cpp b/QueryEngine/RelAlgExecutor.cpp index 3b4e70066f..93d9f2ab2b 100644 --- a/QueryEngine/RelAlgExecutor.cpp +++ b/QueryEngine/RelAlgExecutor.cpp @@ -1892,7 +1892,8 @@ std::unique_ptr RelAlgExecutor::createWindowFunctionConte query_infos, memory_level, HashType::OneToMany, - column_cache_map); + column_cache_map, + ra_exe_unit.query_hint); if (!join_table_or_err.fail_reason.empty()) { throw std::runtime_error(join_table_or_err.fail_reason); } @@ -2611,6 +2612,7 @@ RelAlgExecutor::WorkUnit RelAlgExecutor::createSortInputWorkUnit( nullptr, {sort_info.order_entries, sort_algorithm, limit, offset}, scan_total_limit, + source_exe_unit.query_hint, source_exe_unit.use_bump_allocator, source_exe_unit.union_all, source_exe_unit.query_state}, @@ -2810,6 +2812,11 @@ ExecutionResult RelAlgExecutor::executeWorkUnit( auto ra_exe_unit = decide_approx_count_distinct_implementation( work_unit.exe_unit, table_infos, executor_, co.device_type, target_exprs_owned_); + + // register query hint if query_dag_ is valid + ra_exe_unit.query_hint = + query_dag_ ? query_dag_->getQueryHints() : QueryHint::defaults(); + auto max_groups_buffer_entry_guess = work_unit.max_groups_buffer_entry_guess; if (is_window_execution_unit(ra_exe_unit)) { CHECK_EQ(table_infos.size(), size_t(1)); @@ -3424,19 +3431,21 @@ RelAlgExecutor::WorkUnit RelAlgExecutor::createCompoundWorkUnit( translator, eo.executor_type); CHECK_EQ(compound->size(), target_exprs.size()); - const RelAlgExecutionUnit exe_unit = {input_descs, - input_col_descs, - quals_cf.simple_quals, - rewrite_quals(quals_cf.quals), - left_deep_join_quals, - groupby_exprs, - target_exprs, - nullptr, - sort_info, - 0, - false, - std::nullopt, - query_state_}; + const RelAlgExecutionUnit exe_unit = { + input_descs, + input_col_descs, + quals_cf.simple_quals, + rewrite_quals(quals_cf.quals), + left_deep_join_quals, + groupby_exprs, + target_exprs, + nullptr, + sort_info, + 0, + query_dag_ ? query_dag_->getQueryHints() : QueryHint::defaults(), + false, + std::nullopt, + query_state_}; auto query_rewriter = std::make_unique(query_infos, executor_); const auto rewritten_exe_unit = query_rewriter->rewrite(exe_unit); const auto targets_meta = get_targets_meta(compound, rewritten_exe_unit.target_exprs); @@ -3669,19 +3678,21 @@ RelAlgExecutor::WorkUnit RelAlgExecutor::createAggregateWorkUnit( target_exprs_owned_, scalar_sources, groupby_exprs, aggregate, translator); const auto targets_meta = get_targets_meta(aggregate, target_exprs); aggregate->setOutputMetainfo(targets_meta); - return {RelAlgExecutionUnit{input_descs, - input_col_descs, - {}, - {}, - {}, - groupby_exprs, - target_exprs, - nullptr, - sort_info, - 0, - false, - std::nullopt, - query_state_}, + return {RelAlgExecutionUnit{ + input_descs, + input_col_descs, + {}, + {}, + {}, + groupby_exprs, + target_exprs, + nullptr, + sort_info, + 0, + query_dag_ ? query_dag_->getQueryHints() : QueryHint::defaults(), + false, + std::nullopt, + query_state_}, aggregate, max_groups_buffer_entry_default_guess, nullptr}; @@ -3739,19 +3750,21 @@ RelAlgExecutor::WorkUnit RelAlgExecutor::createProjectWorkUnit( target_exprs_owned_.end(), target_exprs_owned.begin(), target_exprs_owned.end()); const auto target_exprs = get_exprs_not_owned(target_exprs_owned); - const RelAlgExecutionUnit exe_unit = {input_descs, - input_col_descs, - {}, - {}, - left_deep_join_quals, - {nullptr}, - target_exprs, - nullptr, - sort_info, - 0, - false, - std::nullopt, - query_state_}; + const RelAlgExecutionUnit exe_unit = { + input_descs, + input_col_descs, + {}, + {}, + left_deep_join_quals, + {nullptr}, + target_exprs, + nullptr, + sort_info, + 0, + query_dag_ ? query_dag_->getQueryHints() : QueryHint::defaults(), + false, + std::nullopt, + query_state_}; auto query_rewriter = std::make_unique(query_infos, executor_); const auto rewritten_exe_unit = query_rewriter->rewrite(exe_unit); const auto targets_meta = get_targets_meta(project, rewritten_exe_unit.target_exprs); @@ -3825,19 +3838,21 @@ RelAlgExecutor::WorkUnit RelAlgExecutor::createUnionWorkUnit( << " target_exprs.size()=" << target_exprs.size() << " max_num_tuples=" << max_num_tuples; - const RelAlgExecutionUnit exe_unit = {input_descs, - input_col_descs, - {}, // quals_cf.simple_quals, - {}, // rewrite_quals(quals_cf.quals), - {}, - {nullptr}, - target_exprs, - nullptr, - sort_info, - max_num_tuples, - false, - logical_union->isAll(), - query_state_}; + const RelAlgExecutionUnit exe_unit = { + input_descs, + input_col_descs, + {}, // quals_cf.simple_quals, + {}, // rewrite_quals(quals_cf.quals), + {}, + {nullptr}, + target_exprs, + nullptr, + sort_info, + max_num_tuples, + query_dag_ ? query_dag_->getQueryHints() : QueryHint::defaults(), + false, + logical_union->isAll(), + query_state_}; auto query_rewriter = std::make_unique(query_infos, executor_); const auto rewritten_exe_unit = query_rewriter->rewrite(exe_unit); @@ -4052,7 +4067,8 @@ RelAlgExecutor::WorkUnit RelAlgExecutor::createFilterWorkUnit(const RelFilter* f target_exprs, nullptr, sort_info, - 0}, + 0, + query_dag_ ? query_dag_->getQueryHints() : QueryHint::defaults()}, filter, max_groups_buffer_entry_default_guess, nullptr}; diff --git a/QueryEngine/RelAlgExecutor.h b/QueryEngine/RelAlgExecutor.h index 20aa86db63..b3cbc9e90b 100644 --- a/QueryEngine/RelAlgExecutor.h +++ b/QueryEngine/RelAlgExecutor.h @@ -134,7 +134,7 @@ class RelAlgExecutor : private StorageIOFacility { CHECK(query_dag_); return query_dag_->getSubqueries(); }; - const QueryHint getParsedQueryHints() const { + QueryHint getParsedQueryHints() { CHECK(query_dag_); return query_dag_->getQueryHints(); } diff --git a/QueryRunner/QueryRunner.cpp b/QueryRunner/QueryRunner.cpp index 0c9e53eb6c..97076d2d1f 100644 --- a/QueryRunner/QueryRunner.cpp +++ b/QueryRunner/QueryRunner.cpp @@ -242,7 +242,7 @@ std::string apply_copy_to_shim(const std::string& query_str) { return result; } -QueryHint QueryRunner::getParsedQueryHintofQuery(const std::string& query_str) { +QueryHint QueryRunner::getParsedQueryHint(const std::string& query_str) { CHECK(session_info_); CHECK(!Catalog_Namespace::SysCatalog::instance().isAggregator()); auto query_state = create_query_state(session_info_, query_str); diff --git a/QueryRunner/QueryRunner.h b/QueryRunner/QueryRunner.h index 77cea4271c..9db58614dd 100644 --- a/QueryRunner/QueryRunner.h +++ b/QueryRunner/QueryRunner.h @@ -144,7 +144,7 @@ class QueryRunner { virtual std::vector> runMultipleStatements( const std::string&, const ExecutorDeviceType); - virtual QueryHint getParsedQueryHintofQuery(const std::string&); + virtual QueryHint getParsedQueryHint(const std::string&); virtual void runImport(Parser::CopyTableStmt* import_stmt); virtual std::unique_ptr getLoader( diff --git a/Tests/SQLHintTest.cpp b/Tests/SQLHintTest.cpp index f419a851b3..8d05f07357 100644 --- a/Tests/SQLHintTest.cpp +++ b/Tests/SQLHintTest.cpp @@ -20,6 +20,7 @@ #include "Catalog/Catalog.h" #include "Catalog/DBObject.h" +#include "DBHandlerTestHelpers.h" #include "DataMgr/DataMgr.h" #include "QueryEngine/Execute.h" #include "QueryRunner/QueryRunner.h" @@ -49,6 +50,12 @@ bool skip_tests(const ExecutorDeviceType device_type) { continue; \ } +bool approx_eq(const double v, const double target, const double eps = 0.01) { + const auto v_u64 = *reinterpret_cast(may_alias_ptr(&v)); + const auto target_u64 = *reinterpret_cast(may_alias_ptr(&target)); + return v_u64 == target_u64 || (target - eps < v && v < target + eps); +} + inline void run_ddl_statement(const std::string& create_table_stmt) { QR::get()->runDDLStatement(create_table_stmt); } @@ -65,18 +72,84 @@ TEST(CPU_MODE, ForceToCPUMode) { const auto query_without_cpu_mode_hint = "SELECT * FROM SQL_HINT_DUMMY"; QR::get()->runDDLStatement(drop_table_ddl); QR::get()->runDDLStatement(create_table_ddl); - for (auto dt : {ExecutorDeviceType::CPU, ExecutorDeviceType::GPU}) { - SKIP_NO_GPU(); - if (QR::get()->gpusPresent()) { - auto query_hints = QR::get()->getParsedQueryHintofQuery(query_with_cpu_mode_hint); - CHECK(query_hints.cpu_mode); - query_hints = QR::get()->getParsedQueryHintofQuery(query_without_cpu_mode_hint); - CHECK(!query_hints.cpu_mode); - } + if (QR::get()->gpusPresent()) { + auto query_hints = QR::get()->getParsedQueryHint(query_with_cpu_mode_hint); + CHECK(query_hints.hint_delivered && query_hints.cpu_mode); + query_hints = QR::get()->getParsedQueryHint(query_without_cpu_mode_hint); + CHECK(!query_hints.hint_delivered && !query_hints.cpu_mode); } QR::get()->runDDLStatement(drop_table_ddl); } +TEST(OVERLAPS_JOIN_PARAM, Check_Overlaps_Join_Hint) { + const auto overlaps_join_status_backup = g_enable_overlaps_hashjoin; + g_enable_overlaps_hashjoin = true; + ScopeGuard reset_loop_join_state = [&overlaps_join_status_backup] { + g_enable_overlaps_hashjoin = overlaps_join_status_backup; + }; + + const auto drop_table_ddl_1 = "DROP TABLE IF EXISTS geospatial_test"; + const auto drop_table_ddl_2 = "DROP TABLE IF EXISTS geospatial_inner_join_test"; + const auto create_table_ddl_1 = + "CREATE TABLE geospatial_test(id INT, p POINT, l LINESTRING, poly POLYGON);"; + const auto create_table_ddl_2 = + "CREATE TABLE geospatial_inner_join_test(id INT, p POINT, l LINESTRING, poly " + "POLYGON);"; + + QR::get()->runDDLStatement(drop_table_ddl_1); + QR::get()->runDDLStatement(drop_table_ddl_2); + QR::get()->runDDLStatement(create_table_ddl_1); + QR::get()->runDDLStatement(create_table_ddl_2); + + const auto q1 = + "SELECT /*+ overlaps_bucket_threshold(0.718) */ a.id FROM geospatial_test a INNER " + "JOIN geospatial_inner_join_test b ON ST_Contains(b.poly, a.p);"; + const auto q2 = + "SELECT /*+ overlaps_max_size(2021) */ a.id FROM geospatial_test a INNER JOIN " + "geospatial_inner_join_test b ON ST_Contains(b.poly, a.p);"; + const auto q3 = + "SELECT /*+ overlaps_bucket_threshold(0.718), overlaps_max_size(2021) */ a.id FROM " + "geospatial_test a INNER JOIN geospatial_inner_join_test b ON ST_Contains(b.poly, " + "a.p);"; + const auto query_without_hint = + "SELECT a.id FROM geospatial_test a INNER JOIN geospatial_inner_join_test b ON " + "ST_Contains(b.poly, a.p);"; + const auto wrong_q1 = + "SELECT /*+ overlaps_bucket_threshold(-0.718) */ a.id FROM geospatial_test a INNER " + "JOIN geospatial_inner_join_test b ON ST_Contains(b.poly, a.p);"; + const auto wrong_q2 = + "SELECT /*+ overlaps_bucket_threshold(1.718) */ a.id FROM geospatial_test a INNER " + "JOIN geospatial_inner_join_test b ON ST_Contains(b.poly, a.p);"; + const auto wrong_q3 = + "SELECT /*+ overlaps_max_size(-2021) */ a.id FROM geospatial_test a INNER " + "JOIN geospatial_inner_join_test b ON ST_Contains(b.poly, a.p);"; + + auto q1_hints = QR::get()->getParsedQueryHint(q1); + CHECK(q1_hints.hint_delivered && approx_eq(q1_hints.overlaps_bucket_threshold, 0.718)); + + auto q2_hints = QR::get()->getParsedQueryHint(q2); + CHECK(q2_hints.hint_delivered && (q2_hints.overlaps_max_size == 2021)); + + auto q3_hints = QR::get()->getParsedQueryHint(q3); + CHECK(q3_hints.hint_delivered && (q3_hints.overlaps_max_size == 2021) && + approx_eq(q3_hints.overlaps_bucket_threshold, 0.718)); + + auto query_without_hint_res = QR::get()->getParsedQueryHint(query_without_hint); + CHECK(!query_without_hint_res.hint_delivered); + + auto wrong_q1_hints = QR::get()->getParsedQueryHint(wrong_q1); + CHECK(!wrong_q1_hints.hint_delivered); + + auto wrong_q2_hints = QR::get()->getParsedQueryHint(wrong_q2); + CHECK(!wrong_q2_hints.hint_delivered); + + auto wrong_q3_hints = QR::get()->getParsedQueryHint(wrong_q3); + CHECK(!wrong_q3_hints.hint_delivered); + + QR::get()->runDDLStatement(drop_table_ddl_1); + QR::get()->runDDLStatement(drop_table_ddl_2); +} + int main(int argc, char** argv) { TestHelpers::init_logger_stderr_only(argc, argv); testing::InitGoogleTest(&argc, argv); diff --git a/java/calcite/src/main/java/com/mapd/parser/hint/OmniSciHintStrategyTable.java b/java/calcite/src/main/java/com/mapd/parser/hint/OmniSciHintStrategyTable.java index 42a47dcba2..a7b569a594 100644 --- a/java/calcite/src/main/java/com/mapd/parser/hint/OmniSciHintStrategyTable.java +++ b/java/calcite/src/main/java/com/mapd/parser/hint/OmniSciHintStrategyTable.java @@ -11,6 +11,9 @@ private static HintStrategyTable createHintStrategies() { } static HintStrategyTable createHintStrategies(HintStrategyTable.Builder builder) { - return builder.hintStrategy("cpu_mode", HintPredicates.SET_VAR).build(); + return builder.hintStrategy("cpu_mode", HintPredicates.SET_VAR) + .hintStrategy("overlaps_bucket_threshold", HintPredicates.SET_VAR) + .hintStrategy("overlaps_max_size", HintPredicates.SET_VAR) + .build(); } }