Skip to content

Commit

Permalink
Fix NaN correctness of sum aggregator and remove as much branching as possible
Browse files Browse the repository at this point in the history
  • Loading branch information
alexowens90 committed Apr 18, 2024
1 parent 60622a8 commit 7167248
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 8 deletions.
21 changes: 15 additions & 6 deletions cpp/arcticdb/processing/aggregation.hpp
Original file line number Diff line number Diff line change
template<typename T>
class SumBucketAggregator {
public:
    /// Adds `value` to the running sum of the current bucket.
    ///
    /// Every call marks the bucket as populated, including calls that pass NaN.
    /// For floating-point `T`, NaN inputs contribute nothing to the sum, so a
    /// bucket that only ever received NaNs finalizes to 0 while still reporting
    /// has_value() == true.
    void push(T value) {
        has_value_ = true;
        if constexpr (std::is_floating_point_v<T>) {
            // Select instead of branch: adding zero leaves the running sum
            // unchanged, which is exactly what skipping the NaN would do.
            sum_ += std::isnan(value) ? default_value_ : value;
        } else {
            sum_ += value;
        }
    }

    /// Returns the sum accumulated since the previous finalize() (or since
    /// construction) and resets the aggregator for the next bucket.
    ///
    /// When nothing has been pushed the result is 0: `sum_` is only modified by
    /// push(), which also sets `has_value_`, so an unpopulated aggregator always
    /// holds the default value and no emptiness check is needed here.
    T finalize() {
        const T res = sum_;
        sum_ = default_value_;
        has_value_ = false;
        return res;
    }

    /// True if push() has been called at least once since the last finalize().
    [[nodiscard]] bool has_value() const {
        return has_value_;
    }

private:
    static constexpr T default_value_{0};
    bool has_value_{false};
    T sum_{0};
};

template<typename T>
Expand Down
32 changes: 30 additions & 2 deletions python/tests/unit/arcticdb/version_store/test_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,36 @@ def test_resampling(lmdb_version_store_v1, freq, date_range, closed, label):
)


def test_resampling_with_nans(lmdb_version_store_v1):
lib = lmdb_version_store_v1
def test_tmp(lmdb_version_store_tiny_segment):
    """Check that a sum-resample over a bucket containing only NaNs matches pandas.

    The two rows are indexed at 0 ns and 1 ns, so a microsecond resample places
    both of them in the same bucket, and every value in that bucket is NaN.
    """
    lib = lmdb_version_store_tiny_segment
    sym = "test_tmp"
    idx = np.array([0, 1], dtype="datetime64[ns]")
    col = np.full(2, np.nan, dtype=np.float64)

    df = pd.DataFrame({"col": col}, index=idx)
    lib.write(sym, df)

    # pandas result is the reference the library output must reproduce
    expected = df.resample("us").agg(
        sum=pd.NamedAgg("col", "sum"),
    )
    expected = expected.reindex(columns=sorted(expected.columns))

    q = QueryBuilder()
    q = q.resample("us").agg(
        {
            "sum": ("col", "sum"),
        }
    )
    received = lib.read(sym, query_builder=q).data
    received = received.reindex(columns=sorted(received.columns))
    assert_frame_equal(expected, received)


def test_resampling_with_nans(lmdb_version_store_tiny_segment):
lib = lmdb_version_store_tiny_segment
sym = "test_resampling_with_nans"
# Create 5 buckets worth of data, each containing 3 values:
# - No nans
Expand Down

0 comments on commit 7167248

Please sign in to comment.