Feature/named aggs #1468

Merged: 16 commits, Apr 2, 2024
cpp/arcticdb/processing/clause.cpp (24 additions, 22 deletions)

@@ -326,34 +326,36 @@ Composite<EntityIds> ProjectClause::process(Composite<EntityIds>&& entity_ids) const {
 }
 
 AggregationClause::AggregationClause(const std::string& grouping_column,
-                                     const std::unordered_map<std::string,
-                                                              std::string>& aggregations):
-        grouping_column_(grouping_column),
-        aggregation_map_(aggregations) {
+                                     const std::vector<NamedAggregator>& named_aggregators):
+        grouping_column_(grouping_column) {
     clause_info_.can_combine_with_column_selection_ = false;
     clause_info_.new_index_ = grouping_column_;
     clause_info_.input_columns_ = std::make_optional<std::unordered_set<std::string>>({grouping_column_});
     clause_info_.modifies_output_descriptor_ = true;
-    for (const auto& [column_name, aggregation_operator]: aggregations) {
-        auto [_, inserted] = clause_info_.input_columns_->insert(column_name);
-        user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(inserted,
-                                                              "Cannot perform two aggregations over the same column: {}",
-                                                              column_name);
-        auto typed_column_name = ColumnName(column_name);
-        if (aggregation_operator == "sum") {
-            aggregators_.emplace_back(SumAggregator(typed_column_name, typed_column_name));
-        } else if (aggregation_operator == "mean") {
-            aggregators_.emplace_back(MeanAggregator(typed_column_name, typed_column_name));
-        } else if (aggregation_operator == "max") {
-            aggregators_.emplace_back(MaxAggregator(typed_column_name, typed_column_name));
-        } else if (aggregation_operator == "min") {
-            aggregators_.emplace_back(MinAggregator(typed_column_name, typed_column_name));
-        } else if (aggregation_operator == "count") {
-            aggregators_.emplace_back(CountAggregator(typed_column_name, typed_column_name));
+    str_ = "AGGREGATE {";
+    for (const auto& named_aggregator: named_aggregators) {
+        str_.append(fmt::format("{}: ({}, {}), ",
+                                named_aggregator.output_column_name_,
+                                named_aggregator.input_column_name_,
+                                named_aggregator.aggregation_operator_));
+        clause_info_.input_columns_->insert(named_aggregator.input_column_name_);
+        auto typed_input_column_name = ColumnName(named_aggregator.input_column_name_);
+        auto typed_output_column_name = ColumnName(named_aggregator.output_column_name_);
+        if (named_aggregator.aggregation_operator_ == "sum") {
+            aggregators_.emplace_back(SumAggregator(typed_input_column_name, typed_output_column_name));
+        } else if (named_aggregator.aggregation_operator_ == "mean") {
+            aggregators_.emplace_back(MeanAggregator(typed_input_column_name, typed_output_column_name));
+        } else if (named_aggregator.aggregation_operator_ == "max") {
+            aggregators_.emplace_back(MaxAggregator(typed_input_column_name, typed_output_column_name));
+        } else if (named_aggregator.aggregation_operator_ == "min") {
+            aggregators_.emplace_back(MinAggregator(typed_input_column_name, typed_output_column_name));
+        } else if (named_aggregator.aggregation_operator_ == "count") {
+            aggregators_.emplace_back(CountAggregator(typed_input_column_name, typed_output_column_name));
         } else {
-            user_input::raise<ErrorCode::E_INVALID_USER_ARGUMENT>("Unknown aggregation operator provided: {}", aggregation_operator);
+            user_input::raise<ErrorCode::E_INVALID_USER_ARGUMENT>("Unknown aggregation operator provided: {}", named_aggregator.aggregation_operator_);
        }
     }
+    str_.append("}");
 }
 
 Composite<EntityIds> AggregationClause::process(Composite<EntityIds>&& entity_ids) const {
@@ -528,7 +530,7 @@ Composite<EntityIds> AggregationClause::process(Composite<EntityIds>&& entity_ids) const {
 }
 
 [[nodiscard]] std::string AggregationClause::to_string() const {
-    return fmt::format("AGGREGATE {}", aggregation_map_);
+    return str_;
 }
 
 [[nodiscard]] Composite<EntityIds> RemoveColumnPartitioningClause::process(Composite<EntityIds>&& entity_ids) const {
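For orientation, here is a minimal sketch of the display string the new constructor assembles into `str_` and that `to_string()` now returns. This is not ArcticDB code and the column names are made up; note that each entry keeps a trailing `", "` separator, matching the fmt format string above.

```python
# Sketch only: mirrors how the constructor builds str_ from
# NamedAggregator{operator, input column, output column} triples.
named_aggregators = [
    ("sum", "price", "total"),    # (operator, input column, output column)
    ("count", "price", "n"),
]
s = "AGGREGATE {"
for operator, input_col, output_col in named_aggregators:
    s += f"{output_col}: ({input_col}, {operator}), "
s += "}"
print(s)  # AGGREGATE {total: (price, sum), n: (price, count), }
```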
cpp/arcticdb/processing/clause.hpp (8 additions, 3 deletions)

@@ -350,21 +350,26 @@ inline StreamDescriptor empty_descriptor(arcticdb::proto::descriptors::IndexDesc
     return StreamDescriptor{StreamId{id}, IndexDescriptor{field_count, type}, std::make_shared<FieldCollection>()};
 }
 
+struct NamedAggregator {
+    std::string aggregation_operator_;
+    std::string input_column_name_;
+    std::string output_column_name_;
+};
+
 struct AggregationClause {
     ClauseInfo clause_info_;
     std::shared_ptr<ComponentManager> component_manager_;
     ProcessingConfig processing_config_;
     std::string grouping_column_;
-    std::unordered_map<std::string, std::string> aggregation_map_;
     std::vector<GroupingAggregator> aggregators_;
+    std::string str_;
 
     AggregationClause() = delete;
 
     ARCTICDB_MOVE_COPY_DEFAULT(AggregationClause)
 
     AggregationClause(const std::string& grouping_column,
-                      const std::unordered_map<std::string,
-                                               std::string>& aggregations);
+                      const std::vector<NamedAggregator>& aggregations);
 
     [[noreturn]] std::vector<std::vector<size_t>> structure_for_processing(
         ARCTICDB_UNUSED const std::vector<RangesAndKey>&,
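The new struct is the bridge between the Python-facing dict and the C++ aggregators. As a rough Python analogue (an illustration, not part of the codebase), it is a plain triple:

```python
# Not ArcticDB code: the Python-side picture of NamedAggregator, assuming the
# same three fields and the five operators clause.cpp accepts.
from typing import NamedTuple

class NamedAggregator(NamedTuple):
    aggregation_operator: str  # one of "sum", "mean", "max", "min", "count"
    input_column_name: str
    output_column_name: str

# Sum a hypothetical "price" column into an output column named "total":
agg = NamedAggregator("sum", "price", "total")
```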
cpp/arcticdb/processing/test/test_clause.cpp (24 additions, 4 deletions)

@@ -62,7 +62,12 @@ TEST(Clause, AggregationEmptyColumn) {
     using namespace arcticdb;
     auto component_manager = std::make_shared<ComponentManager>();
 
-    AggregationClause aggregation("int_repeated_values", {{"empty_sum", "sum"}, {"empty_min", "min"}, {"empty_max", "max"}, {"empty_mean", "mean"}, {"empty_count", "count"}});
+    AggregationClause aggregation("int_repeated_values",
+                                  {{"sum", "empty_sum", "empty_sum"},
+                                   {"min", "empty_min", "empty_min"},
+                                   {"max", "empty_max", "empty_max"},
+                                   {"mean", "empty_mean", "empty_mean"},
+                                   {"count", "empty_count", "empty_count"}});
     aggregation.set_component_manager(component_manager);
 
     size_t num_rows{100};
@@ -132,7 +137,12 @@ TEST(Clause, AggregationColumn)
     using namespace arcticdb;
     auto component_manager = std::make_shared<ComponentManager>();
 
-    AggregationClause aggregation("int_repeated_values", {{"sum_int", "sum"}, {"min_int", "min"}, {"max_int", "max"}, {"mean_int", "mean"}, {"count_int", "count"}});
+    AggregationClause aggregation("int_repeated_values",
+                                  {{"sum", "sum_int", "sum_int"},
+                                   {"min", "min_int", "min_int"},
+                                   {"max", "max_int", "max_int"},
+                                   {"mean", "mean_int", "mean_int"},
+                                   {"count", "count_int", "count_int"}});
     aggregation.set_component_manager(component_manager);
 
     size_t num_rows{100};
@@ -159,7 +169,12 @@ TEST(Clause, AggregationSparseColumn)
     using namespace arcticdb;
     auto component_manager = std::make_shared<ComponentManager>();
 
-    AggregationClause aggregation("int_repeated_values", {{"sum_int", "sum"}, {"min_int", "min"}, {"max_int", "max"}, {"mean_int", "mean"}, {"count_int", "count"}});
+    AggregationClause aggregation("int_repeated_values",
+                                  {{"sum", "sum_int", "sum_int"},
+                                   {"min", "min_int", "min_int"},
+                                   {"max", "max_int", "max_int"},
+                                   {"mean", "mean_int", "mean_int"},
+                                   {"count", "count_int", "count_int"}});
     aggregation.set_component_manager(component_manager);
 
     size_t num_rows{100};
@@ -217,7 +232,12 @@ TEST(Clause, AggregationSparseGroupby) {
     using namespace arcticdb;
     auto component_manager = std::make_shared<ComponentManager>();
 
-    AggregationClause aggregation("int_sparse_repeated_values", {{"sum_int", "sum"}, {"min_int", "min"}, {"max_int", "max"}, {"mean_int", "mean"}, {"count_int", "count"}});
+    AggregationClause aggregation("int_sparse_repeated_values",
+                                  {{"sum", "sum_int", "sum_int"},
+                                   {"min", "min_int", "min_int"},
+                                   {"max", "max_int", "max_int"},
+                                   {"mean", "mean_int", "mean_int"},
+                                   {"count", "count_int", "count_int"}});
     aggregation.set_component_manager(component_manager);
 
     size_t num_rows{100};
cpp/arcticdb/version/python_bindings.cpp (17 additions, 1 deletion)

@@ -270,7 +270,23 @@ void register_bindings(py::module &version, py::exception<arcticdb::ArcticException>& base_exception) {
         .def("__str__", &GroupByClause::to_string);
 
     py::class_<AggregationClause, std::shared_ptr<AggregationClause>>(version, "AggregationClause")
-        .def(py::init<std::string, std::unordered_map<std::string, std::string>>())
+        .def(py::init([](
+                const std::string& grouping_colum,
+                const std::unordered_map<std::string, std::variant<std::string, std::pair<std::string, std::string>>> aggregations) {
+            std::vector<NamedAggregator> named_aggregators;
+            for (const auto& [output_column_name, var_agg_named_agg]: aggregations) {
+                util::variant_match(
+                    var_agg_named_agg,
+                    [&named_aggregators, &output_column_name] (const std::string& agg_operator) {
+                        named_aggregators.emplace_back(agg_operator, output_column_name, output_column_name);
+                    },
+                    [&named_aggregators, &output_column_name] (const std::pair<std::string, std::string>& input_col_and_agg) {
+                        named_aggregators.emplace_back(input_col_and_agg.second, input_col_and_agg.first, output_column_name);
+                    }
+                );
+            }
+            return AggregationClause(grouping_colum, named_aggregators);
+        }))
        .def("__str__", &AggregationClause::to_string);
 
     py::enum_<RowRangeClause::RowRangeType>(version, "RowRangeType")
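The lambda above is what turns the Python agg dict into NamedAggregator instances. A sketch of the same normalization in plain Python (illustration only; the function name and sample columns are assumptions):

```python
# Each dict value is either a bare operator name or an
# (input column, operator) pair; both collapse to an
# (operator, input, output) triple in NamedAggregator field order.
from typing import Dict, List, Tuple, Union

def to_named_aggregators(
    aggregations: Dict[str, Union[str, Tuple[str, str]]],
) -> List[Tuple[str, str, str]]:
    named = []
    for output_column, spec in aggregations.items():
        if isinstance(spec, str):
            # Bare operator: read from and write to the same column name.
            named.append((spec, output_column, output_column))
        else:
            input_column, operator = spec
            named.append((operator, input_column, output_column))
    return named

# {"total": ("price", "sum"), "price": "mean"} maps to
# [("sum", "price", "total"), ("mean", "price", "price")]
```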
python/arcticdb/version_store/processing.py (34 additions, 7 deletions)

@@ -12,7 +12,7 @@
 import numpy as np
 import pandas as pd
 
-from typing import Dict, NamedTuple
+from typing import Dict, NamedTuple, Tuple, Union
 
 from arcticdb.exceptions import ArcticNativeException, UserInputException
 from arcticdb.version_store._normalization import normalize_dt_range_to_ts
@@ -451,9 +451,9 @@ def groupby(self, name: str):
         >>> q = q.groupby("grouping_column").agg({"to_mean": "mean"})
         >>> lib.write("symbol", df)
         >>> lib.read("symbol", query_builder=q).data
-                  to_mean
+                  to_mean
         group_1  1.666667
-        group_2       NaN
+        group_2       2.2
 
         Max over one group:
 
@@ -485,9 +485,27 @@ def groupby(self, name: str):
         >>> q = q.groupby("grouping_column").agg({"to_max": "max", "to_mean": "mean"})
         >>> lib.write("symbol", df)
         >>> lib.read("symbol", query_builder=q).data
-                 to_max  to_mean
+                 to_max   to_mean
         group_1     2.5  1.666667
 
+        Min and max over one column, mean over another:
+        >>> df = pd.DataFrame(
+                {
+                    "grouping_column": ["group_1", "group_1", "group_1", "group_2", "group_2"],
+                    "agg_1": [1, 2, 3, 4, 5],
+                    "agg_2": [1.1, 1.4, 2.5, np.nan, 2.2],
+                },
+                index=np.arange(5),
+            )
+        >>> q = adb.QueryBuilder()
+        >>> q = q.groupby("grouping_column")
+        >>> q = q.agg({"agg_1_min": ("agg_1", "min"), "agg_1_max": ("agg_1", "max"), "agg_2": "mean"})
+        >>> lib.write("symbol", df)
+        >>> lib.read("symbol", query_builder=q).data
+                 agg_1_min  agg_1_max     agg_2
+        group_1          1          3  1.666667
+        group_2          4          5  2.2
+
         Returns
         -------
         QueryBuilder
@@ -497,14 +515,23 @@ def groupby(self, name: str):
         self._python_clauses.append(PythonGroupByClause(name))
         return self
 
-    def agg(self, aggregations: Dict[str, str]):
+    def agg(self, aggregations: Dict[str, Union[str, Tuple[str, str]]]):
         # Only makes sense if previous stage is a group-by
         check(
             len(self.clauses) and isinstance(self.clauses[-1], _GroupByClause),
             f"Aggregation only makes sense after groupby",
         )
-        for v in aggregations.values():
-            v = v.lower()
+        for k, v in aggregations.items():
+            check(isinstance(v, (str, tuple)), f"Values in agg dict expected to be strings or tuples, received {v} of type {type(v)}")
+            if isinstance(v, str):
+                aggregations[k] = v.lower()
+            elif isinstance(v, tuple):
+                check(
+                    len(v) == 2 and (isinstance(v[0], str) and isinstance(v[1], str)),
+                    f"Tuple values in agg dict expected to have 2 string elements, received {v}"
+                )
+                aggregations[k] = (v[0], v[1].lower())
+
         self.clauses.append(_AggregationClause(self.clauses[-1].grouping_column, aggregations))
         self._python_clauses.append(PythonAggregationClause(aggregations))
         return self
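Taken together, the validation above means agg() accepts exactly two value shapes. A short usage sketch (hypothetical column names; as in the docstring, the query is applied via lib.read(symbol, query_builder=q)):

```python
import arcticdb as adb

q = adb.QueryBuilder()
q = q.groupby("grouping_column").agg({
    "price": "MEAN",                # bare operator, case-insensitive; output keeps the input name
    "price_max": ("price", "max"),  # (input column, operator): aggregates "price" into "price_max"
})
# Any other value shape, e.g. {"price": ["mean"]} or a 3-tuple, fails the
# isinstance/length checks above with a user-input error.
```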
python/tests/unit/arcticdb/version_store/test_aggregation.py (31 additions, 0 deletions)

@@ -451,6 +451,37 @@ def test_mean_aggregation_float(local_object_version_store):
     assert_frame_equal(res.data, df)
 
 
+def test_named_agg(lmdb_version_store_tiny_segment):
+    lib = lmdb_version_store_tiny_segment
+    sym = "test_named_agg"
+    gen = np.random.default_rng()
+    df = DataFrame(
+        {
+            "grouping_column": [1, 1, 1, 2, 3, 4],
+            "agg_column": gen.random(6)
+        }
+    )
+    lib.write(sym, df)
+    expected = df.groupby("grouping_column").agg(
+        agg_column_sum=pd.NamedAgg("agg_column", "sum"),
+        agg_column_mean=pd.NamedAgg("agg_column", "mean"),
+        agg_column=pd.NamedAgg("agg_column", "min"),
+    )
+    expected = expected.reindex(columns=sorted(expected.columns))
+    q = QueryBuilder()
+    q = q.groupby("grouping_column").agg(
+        {
+            "agg_column_sum": ("agg_column", "sum"),
+            "agg_column_mean": ("agg_column", "MEAN"),
+            "agg_column": "MIN",
+        }
+    )
+    received = lib.read(sym, query_builder=q).data
+    received.sort_index(inplace=True)
+    received = received.reindex(columns=sorted(received.columns))
+    assert_frame_equal(expected, received)
+
+
 def test_max_minus_one(lmdb_version_store):
     symbol = "minus_one"
     lib = lmdb_version_store
@@ -399,7 +399,7 @@ def test_querybuilder_pickling():
     q = q.groupby("col1")
 
     # PythonAggregationClause
-    q = q.agg({"col2": "sum"})
+    q = q.agg({"col2": "sum", "new_col": ("col2", "mean")})
 
     import pickle