Skip to content

Commit

Permalink
Use struct instead of messy map in C++ layer
Browse files Browse the repository at this point in the history
  • Loading branch information
alexowens90 committed Apr 2, 2024
1 parent 917cd81 commit 566783c
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 27 deletions.
40 changes: 15 additions & 25 deletions cpp/arcticdb/processing/clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,43 +326,33 @@ Composite<EntityIds> ProjectClause::process(Composite<EntityIds>&& entity_ids) c
}

AggregationClause::AggregationClause(const std::string& grouping_column,
const std::unordered_map<std::string, std::variant<std::string, std::pair<std::string, std::string>>>& aggregations):
const std::vector<NamedAggregator>& named_aggregators):
grouping_column_(grouping_column) {
clause_info_.can_combine_with_column_selection_ = false;
clause_info_.new_index_ = grouping_column_;
clause_info_.input_columns_ = std::make_optional<std::unordered_set<std::string>>({grouping_column_});
clause_info_.modifies_output_descriptor_ = true;
str_ = "AGGREGATE {";
for (const auto& [output_column_name, var_agg_named_agg]: aggregations) {
std::string input_column_name;
std::string aggregation_operator;
util::variant_match(
var_agg_named_agg,
[&input_column_name, &aggregation_operator, &output_column_name] (const std::string& agg_operator) {
input_column_name = output_column_name;
aggregation_operator = agg_operator;
},
[&input_column_name, &aggregation_operator] (const std::pair<std::string, std::string>& input_col_and_agg) {
input_column_name = input_col_and_agg.first;
aggregation_operator = input_col_and_agg.second;
}
);
str_.append(fmt::format("{}: ({}, {}), ", output_column_name, input_column_name, aggregation_operator));
clause_info_.input_columns_->insert(input_column_name);
auto typed_input_column_name = ColumnName(input_column_name);
auto typed_output_column_name = ColumnName(output_column_name);
if (aggregation_operator == "sum") {
for (const auto& named_aggregator: named_aggregators) {
str_.append(fmt::format("{}: ({}, {}), ",
named_aggregator.output_column_name_,
named_aggregator.input_column_name_,
named_aggregator.aggregation_operator_));
clause_info_.input_columns_->insert(named_aggregator.input_column_name_);
auto typed_input_column_name = ColumnName(named_aggregator.input_column_name_);
auto typed_output_column_name = ColumnName(named_aggregator.output_column_name_);
if (named_aggregator.aggregation_operator_ == "sum") {
aggregators_.emplace_back(SumAggregator(typed_input_column_name, typed_output_column_name));
} else if (aggregation_operator == "mean") {
} else if (named_aggregator.aggregation_operator_ == "mean") {
aggregators_.emplace_back(MeanAggregator(typed_input_column_name, typed_output_column_name));
} else if (aggregation_operator == "max") {
} else if (named_aggregator.aggregation_operator_ == "max") {
aggregators_.emplace_back(MaxAggregator(typed_input_column_name, typed_output_column_name));
} else if (aggregation_operator == "min") {
} else if (named_aggregator.aggregation_operator_ == "min") {
aggregators_.emplace_back(MinAggregator(typed_input_column_name, typed_output_column_name));
} else if (aggregation_operator == "count") {
} else if (named_aggregator.aggregation_operator_ == "count") {
aggregators_.emplace_back(CountAggregator(typed_input_column_name, typed_output_column_name));
} else {
user_input::raise<ErrorCode::E_INVALID_USER_ARGUMENT>("Unknown aggregation operator provided: {}", aggregation_operator);
user_input::raise<ErrorCode::E_INVALID_USER_ARGUMENT>("Unknown aggregation operator provided: {}", named_aggregator.aggregation_operator_);
}
}
str_.append("}");
Expand Down
8 changes: 7 additions & 1 deletion cpp/arcticdb/processing/clause.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,12 @@ inline StreamDescriptor empty_descriptor(arcticdb::proto::descriptors::IndexDesc
return StreamDescriptor{StreamId{id}, IndexDescriptor{field_count, type}, std::make_shared<FieldCollection>()};
}

/// Describes a single aggregation requested of an AggregationClause:
/// which operator to apply, which column it reads, and what the resulting
/// column is called.
struct NamedAggregator {
    /// Default-constructs all three names as empty strings.
    NamedAggregator() = default;

    /// Explicit constructor so that the type can be built in place via
    /// `std::vector<NamedAggregator>::emplace_back(op, in, out)`. Without it
    /// the struct is an aggregate, and parenthesized aggregate initialization
    /// is only available from C++20 (and not on all compilers).
    ///
    /// @param aggregation_operator Name of the aggregation, e.g. "sum" or "mean".
    /// @param input_column_name    Column the aggregation is computed from.
    /// @param output_column_name   Column the aggregated values are written to.
    NamedAggregator(const std::string& aggregation_operator,
                    const std::string& input_column_name,
                    const std::string& output_column_name) :
        aggregation_operator_(aggregation_operator),
        input_column_name_(input_column_name),
        output_column_name_(output_column_name) {
    }

    // Member order is kept as (operator, input, output) so that existing
    // brace-initialization `NamedAggregator{op, in, out}` keeps its meaning.
    std::string aggregation_operator_;  // e.g. "sum", "mean", "max", "min", "count"
    std::string input_column_name_;     // column read by the aggregator
    std::string output_column_name_;    // column produced by the aggregator
};

struct AggregationClause {
ClauseInfo clause_info_;
std::shared_ptr<ComponentManager> component_manager_;
Expand All @@ -363,7 +369,7 @@ struct AggregationClause {
ARCTICDB_MOVE_COPY_DEFAULT(AggregationClause)

AggregationClause(const std::string& grouping_column,
const std::unordered_map<std::string, std::variant<std::string, std::pair<std::string, std::string>>>& aggregations);
const std::vector<NamedAggregator>& aggregations);

[[noreturn]] std::vector<std::vector<size_t>> structure_for_processing(
ARCTICDB_UNUSED const std::vector<RangesAndKey>&,
Expand Down
18 changes: 17 additions & 1 deletion cpp/arcticdb/version/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,23 @@ void register_bindings(py::module &version, py::exception<arcticdb::ArcticExcept
.def("__str__", &GroupByClause::to_string);

py::class_<AggregationClause, std::shared_ptr<AggregationClause>>(version, "AggregationClause")
.def(py::init<std::string, std::unordered_map<std::string, std::variant<std::string, std::pair<std::string, std::string>>>>())
.def(py::init([](
const std::string& grouping_colum,
const std::unordered_map<std::string, std::variant<std::string, std::pair<std::string, std::string>>> aggregations) {
std::vector<NamedAggregator> named_aggregators;
for (const auto& [output_column_name, var_agg_named_agg]: aggregations) {
util::variant_match(
var_agg_named_agg,
[&named_aggregators, &output_column_name] (const std::string& agg_operator) {
named_aggregators.emplace_back(agg_operator, output_column_name, output_column_name);
},
[&named_aggregators, &output_column_name] (const std::pair<std::string, std::string>& input_col_and_agg) {
named_aggregators.emplace_back(input_col_and_agg.second, input_col_and_agg.first, output_column_name);
}
);
}
return AggregationClause(grouping_colum, named_aggregators);
}))
.def("__str__", &AggregationClause::to_string);

py::enum_<RowRangeClause::RowRangeType>(version, "RowRangeType")
Expand Down

0 comments on commit 566783c

Please sign in to comment.