use window on high cardinality keys
lgbo-ustc committed Nov 25, 2024
1 parent: 2a80d3c · commit: be5837f
Showing 13 changed files with 939 additions and 184 deletions.
@@ -369,6 +369,7 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
     )
   }
 
+  // If the partition keys are high cardinality, the aggregation method is slower.
   def enableConvertWindowGroupLimitToAggregate(): Boolean = {
     SparkEnv.get.conf.getBoolean(
       CHConf.runtimeConfig("enable_window_group_limit_to_aggregate"),
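For context, the flag gating this rewrite is a ClickHouse-backend runtime config. A minimal sketch of scoping it in a test, mirroring the suite below (withSQLConf limits the setting to the enclosed block; nothing here goes beyond what the test itself uses):

// Sketch mirroring the test suite below: enable the rewrite only inside this block.
withSQLConf((
  "spark.gluten.sql.columnar.backend.ch.runtime_config.enable_window_group_limit_to_aggregate",
  "true")) {
  // row_number()-style group-limit queries here may compile to
  // CHAggregateGroupLimitExecTransformer instead of a plain window.
}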
@@ -3162,62 +3162,66 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
   }
 
   test("GLUTEN-7905 get topk of window by aggregate") {
-    def checkWindowGroupLimit(df: DataFrame): Unit = {
-      val expands = collectWithSubqueries(df.queryExecution.executedPlan) {
-        case e: ExpandExecTransformer
-            if (e.child.isInstanceOf[CHAggregateGroupLimitExecTransformer]) =>
-          e
-      }
-      assert(expands.size == 1)
-    }
-    spark.sql("create table test_win_top (a string, b int, c int) using parquet")
-    spark.sql("""
-                |insert into test_win_top values
-                |('a', 3, 3), ('a', 1, 5), ('a', 2, 2), ('a', null, null), ('a', null, 1),
-                |('b', 1, 1), ('b', 2, 1),
-                |('c', 2, 3)
-                |""".stripMargin)
-    compareResultsAgainstVanillaSpark(
-      """
-        |select a, b, c, row_number() over (partition by a order by b desc nulls first) as r
-        |from test_win_top
-        |""".stripMargin,
-      true,
-      checkWindowGroupLimit
-    )
-    compareResultsAgainstVanillaSpark(
-      """
-        |select a, b, c, row_number() over (partition by a order by b desc, c nulls last) as r
-        |from test_win_top
-        |""".stripMargin,
-      true,
-      checkWindowGroupLimit
-    )
-    compareResultsAgainstVanillaSpark(
-      """
-        |select a, b, c, row_number() over (partition by a order by b asc nulls first) as r
-        |from test_win_top
-        |""".stripMargin,
-      true,
-      checkWindowGroupLimit
-    )
-    compareResultsAgainstVanillaSpark(
-      """
-        |select a, b, c, row_number() over (partition by a order by b asc nulls last) as r
-        |from test_win_top
-        |""".stripMargin,
-      true,
-      checkWindowGroupLimit
-    )
-    compareResultsAgainstVanillaSpark(
-      """
-        |select a, b, c, row_number() over (partition by a order by b , c) as r
-        |from test_win_top
-        |""".stripMargin,
-      true,
-      checkWindowGroupLimit
-    )
-    spark.sql("drop table if exists test_win_top")
+    withSQLConf((
+      "spark.gluten.sql.columnar.backend.ch.runtime_config.enable_window_group_limit_to_aggregate",
+      "true")) {
+      def checkWindowGroupLimit(df: DataFrame): Unit = {
+        val expands = collectWithSubqueries(df.queryExecution.executedPlan) {
+          case e: ExpandExecTransformer
+              if (e.child.isInstanceOf[CHAggregateGroupLimitExecTransformer]) =>
+            e
+        }
+        assert(expands.size == 1)
+      }
+      spark.sql("create table test_win_top (a string, b int, c int) using parquet")
+      spark.sql("""
+                  |insert into test_win_top values
+                  |('a', 3, 3), ('a', 1, 5), ('a', 2, 2), ('a', null, null), ('a', null, 1),
+                  |('b', 1, 1), ('b', 2, 1),
+                  |('c', 2, 3)
+                  |""".stripMargin)
+      compareResultsAgainstVanillaSpark(
+        """
+          |select a, b, c, row_number() over (partition by a order by b desc nulls first) as r
+          |from test_win_top
+          |""".stripMargin,
+        true,
+        checkWindowGroupLimit
+      )
+      compareResultsAgainstVanillaSpark(
+        """
+          |select a, b, c, row_number() over (partition by a order by b desc, c nulls last) as r
+          |from test_win_top
+          |""".stripMargin,
+        true,
+        checkWindowGroupLimit
+      )
+      compareResultsAgainstVanillaSpark(
+        """
+          |select a, b, c, row_number() over (partition by a order by b asc nulls first, c) as r
+          |from test_win_top
+          |""".stripMargin,
+        true,
+        checkWindowGroupLimit
+      )
+      compareResultsAgainstVanillaSpark(
+        """
+          |select a, b, c, row_number() over (partition by a order by b asc nulls last) as r
+          |from test_win_top
+          |""".stripMargin,
+        true,
+        checkWindowGroupLimit
+      )
+      compareResultsAgainstVanillaSpark(
+        """
+          |select a, b, c, row_number() over (partition by a order by b , c) as r
+          |from test_win_top
+          |""".stripMargin,
+        true,
+        checkWindowGroupLimit
+      )
+      spark.sql("drop table if exists test_win_top")
+    }
 
   }

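For orientation, a hedged sketch of the user-facing query shape this feature targets: top-k rows per partition key. The rank filter is illustrative and not part of the suite above, which checks plans and full row_number output instead:

// Illustrative top-k-per-group query; the outer filter on r is an assumed
// addition, not taken from the test. Reuses the test_win_top table above.
val top3PerKey = spark.sql("""
  |select a, b, c from (
  |  select a, b, c,
  |         row_number() over (partition by a order by b desc) as r
  |  from test_win_top
  |) t where r <= 3
  |""".stripMargin)
top3PerKey.show()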
29 changes: 6 additions & 23 deletions cpp-ch/local-engine/AggregateFunctions/GroupLimitFunctions.cpp
@@ -41,6 +41,7 @@
 
 #include <Poco/Logger.h>
 #include <Common/logger_useful.h>
+#include "base/defines.h"
 
 namespace DB::ErrorCodes
 {
@@ -72,7 +73,6 @@ struct RowNumGroupArraySortedData
         const auto & pos = sort_order.pos;
         const auto & asc = sort_order.direction;
         const auto & nulls_first = sort_order.nulls_direction;
-        LOG_ERROR(getLogger("GroupLimitFunction"), "xxx pos: {} tuple size: {} {}", pos, rhs.size(), lhs.size());
         bool l_is_null = lhs[pos].isNull();
         bool r_is_null = rhs[pos].isNull();
         if (l_is_null && r_is_null)
@@ -120,25 +120,17 @@ struct RowNumGroupArraySortedData
         values[current_index] = current;
     }
 
-    ALWAYS_INLINE void addElement(const Data & data, const SortOrderFields & sort_orders, size_t max_elements)
+    ALWAYS_INLINE void addElement(const Data && data, const SortOrderFields & sort_orders, size_t max_elements)
     {
         if (values.size() >= max_elements)
         {
-            LOG_ERROR(
-                getLogger("GroupLimitFunction"),
-                "xxxx values size: {}, limit: {}, tuple size: {} {}",
-                values.size(),
-                max_elements,
-                data.size(),
-                values[0].size());
             if (!compare(data, values[0], sort_orders))
                 return;
             values[0] = data;
             heapReplaceTop(sort_orders);
             return;
         }
-        values.push_back(data);
-        LOG_ERROR(getLogger("GroupLimitFunction"), "add new element: {} {}", values.size(), values.back().size());
+        values.emplace_back(std::move(data));
         auto cmp = [&sort_orders](const Data & a, const Data & b) { return compare(a, b, sort_orders); };
         std::push_heap(values.begin(), values.end(), cmp);
     }
@@ -190,7 +182,7 @@ class RowNumGroupArraySorted final : public DB::IAggregateFunctionDataHelper<Row
 public:
     explicit RowNumGroupArraySorted(DB::DataTypePtr data_type, const DB::Array & parameters_)
         : DB::IAggregateFunctionDataHelper<RowNumGroupArraySortedData, RowNumGroupArraySorted>(
-        {data_type}, parameters_, getRowNumReultDataType(data_type))
+              {data_type}, parameters_, getRowNumReultDataType(data_type))
     {
         if (parameters_.size() != 2)
             throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "{} needs two parameters: limit and order clause", getName());
@@ -212,23 +204,14 @@ class RowNumGroupArraySorted final : public DB::IAggregateFunctionDataHelper<Row
     {
         auto & data = this->data(place);
         DB::Tuple data_tuple = (*columns[0])[row_num].safeGet<DB::Tuple>();
-        // const DB::Tuple & data_tuple = *(static_cast<const DB::Tuple *>(&((*columns[0])[row_num])));
-        LOG_ERROR(
-            getLogger("GroupLimitFunction"),
-            "xxx col len: {}, row num: {}, tuple size: {}, type: {}",
-            columns[0]->size(),
-            row_num,
-            data_tuple.size(),
-            (*columns[0])[row_num].getType());
-        ;
-        this->data(place).addElement(data_tuple, sort_order_fields, limit);
+        this->data(place).addElement(std::move(data_tuple), sort_order_fields, limit);
     }
 
     void merge(DB::AggregateDataPtr __restrict place, DB::ConstAggregateDataPtr rhs, DB::Arena * /*arena*/) const override
     {
         auto & rhs_values = this->data(rhs).values;
         for (auto & rhs_element : rhs_values)
-            this->data(place).addElement(rhs_element, sort_order_fields, limit);
+            this->data(place).addElement(std::move(rhs_element), sort_order_fields, limit);
     }
 
     void serialize(DB::ConstAggregateDataPtr __restrict place, DB::WriteBuffer & buf, std::optional<size_t> /* version */) const override
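To make addElement's heap discipline concrete, here is a hedged Scala sketch of the same bounded top-k idea. The class and method names are illustrative, not the C++ API; only the keep-at-most-limit / replace-the-heap-head logic mirrors the code above:

import scala.collection.mutable

// Hedged sketch of addElement's bounded top-k: keep at most `limit` rows in a
// max-heap whose head is the row that would be evicted first under `ord`.
final class GroupTopK[T](limit: Int)(implicit ord: Ordering[T]) {
  private val heap = mutable.PriorityQueue.empty[T] // head = "worst" kept row

  def add(row: T): Unit = {
    if (heap.size >= limit) {
      // Full: displace the head only if the new row ranks ahead of it,
      // mirroring `if (!compare(data, values[0], sort_orders)) return;`.
      if (ord.lt(row, heap.head)) {
        heap.dequeue() // analogous to heapReplaceTop(sort_orders)
        heap.enqueue(row)
      }
    } else {
      heap.enqueue(row) // analogous to emplace_back + std::push_heap
    }
  }

  def result(): List[T] = heap.toList.sorted // final ranked rows per group
}

// Example: new GroupTopK[Int](3) fed 5, 1, 4, 2, 3 keeps List(1, 2, 3).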
1 change: 0 additions & 1 deletion cpp-ch/local-engine/Common/AggregateUtil.cpp
@@ -47,7 +47,6 @@
 extern const SettingsBool enable_memory_bound_merging_of_aggregation_results;
 extern const SettingsUInt64 aggregation_in_order_max_block_bytes;
 extern const SettingsUInt64 group_by_two_level_threshold;
 extern const SettingsFloat min_hit_rate_to_use_consecutive_keys_optimization;
-extern const SettingsMaxThreads max_threads;
 extern const SettingsUInt64 max_block_size;
 }
 
15 changes: 15 additions & 0 deletions cpp-ch/local-engine/Common/ArrayJoinHelper.cpp
@@ -150,6 +150,21 @@ addArrayJoinStep(DB::ContextPtr context, DB::QueryPlan & plan, const DB::Actions
     steps.emplace_back(array_join_step.get());
     plan.addStep(std::move(array_join_step));
     // LOG_DEBUG(logger, "plan2:{}", PlanUtil::explainPlan(*query_plan));
 
+        /// Post-projection after array join(Optional)
+        if (!ignore_actions_dag(splitted_actions_dags.after_array_join))
+        {
+            auto step_after_array_join
+                = std::make_unique<DB::ExpressionStep>(plan.getCurrentHeader(), std::move(splitted_actions_dags.after_array_join));
+            step_after_array_join->setStepDescription("Post-projection In Generate");
+            steps.emplace_back(step_after_array_join.get());
+            plan.addStep(std::move(step_after_array_join));
+            // LOG_DEBUG(logger, "plan3:{}", PlanUtil::explainPlan(*query_plan));
+        }
+    }
+    else
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Expect array join node in actions_dag");
+    }
 
     return steps;
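The ArrayJoinHelper change attaches an optional projection step that runs over the array join's output; the step description "Post-projection In Generate" suggests this corresponds to Spark's Generate operator. As a hedged analogy in Spark terms, it is the difference between evaluating an expression before explode and after it:

import org.apache.spark.sql.functions._

// Hedged analogy only: the added "Post-projection In Generate" step plays the
// role of the select over explode()'s output, not its input.
// Assumes an active SparkSession named `spark`.
val doubled = spark.range(1)
  .select(array(lit(1), lit(2), lit(3)).as("xs"))
  .select(explode(col("xs")).as("x")) // the array join / Generate itself
  .select((col("x") * 2).as("doubled")) // the post-projection after it
doubled.show()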