Skip to content

Commit

Permalink
Fix Projection meta (#78)
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora authored May 10, 2023
1 parent c9c10d4 commit 3a45f2e
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 6 deletions.
19 changes: 18 additions & 1 deletion dask_expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
_get_meta_map_partitions,
apply_and_enforce,
is_dataframe_like,
is_index_like,
is_series_like,
)
from dask.utils import M, apply, funcname, import_required
from dask.utils import M, apply, funcname, import_required, is_arraylike

replacement_rules = []

Expand Down Expand Up @@ -83,6 +85,14 @@ def _tree_repr_lines(self, indent=0, recursive=True):

if isinstance(op, pd.core.base.PandasObject):
op = "<pandas>"
elif is_dataframe_like(op):
op = "<dataframe>"
elif is_index_like(op):
op = "<index>"
elif is_series_like(op):
op = "<series>"
elif is_arraylike(op):
op = "<array>"

elif repr(op) != repr(default):
if param:
Expand Down Expand Up @@ -777,6 +787,13 @@ def columns(self):
else:
return self.operand("columns")

@property
def _meta(self):
if is_dataframe_like(self.frame._meta):
return super()._meta
# Avoid column selection for Series/Index
return self.frame._meta

def _node_label_args(self):
return [self.frame, self.operand("columns")]

Expand Down
4 changes: 0 additions & 4 deletions dask_expr/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,6 @@ def chunk_kwargs(self):
min_count=self.min_count,
)

@property
def _meta(self):
return self.frame._meta.sum(**self.chunk_kwargs)

def _simplify_up(self, parent):
if isinstance(parent, Projection):
return self.frame[parent.operand("columns")].sum(*self.operands[1:])
Expand Down
3 changes: 2 additions & 1 deletion dask_expr/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pickle

import dask
import numpy as np
import pandas as pd
import pytest
from dask.dataframe.utils import assert_eq
Expand Down Expand Up @@ -47,7 +48,7 @@ def test_meta_divisions_name():
assert list(df.columns) == list(a.columns)
assert df.npartitions == 2

assert df.x.sum()._meta == 0
assert np.isscalar(df.x.sum()._meta)
assert df.x.sum().npartitions == 1

assert "mul" in df._name
Expand Down

0 comments on commit 3a45f2e

Please sign in to comment.