Skip to content

Commit

Permalink
Add minimal subset of interchange protocol (dask#1087)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl authored Jun 21, 2024
1 parent 7b786a7 commit 560303c
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 0 deletions.
5 changes: 5 additions & 0 deletions dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2673,6 +2673,11 @@ def __contains__(self, key):
def __iter__(self):
return iter(self._meta)

def __dataframe__(self, *args, **kwargs):
from dask_expr._interchange import DaskDataFrameInterchange

return DaskDataFrameInterchange(self)

@derived_from(pd.DataFrame)
def iterrows(self):
frame = self.optimize()
Expand Down
62 changes: 62 additions & 0 deletions dask_expr/_interchange.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from dask.dataframe._compat import is_string_dtype
from dask.dataframe.dispatch import is_categorical_dtype
from pandas.core.interchange.dataframe_protocol import DtypeKind

from dask_expr._collection import DataFrame

_NP_KINDS = {
"i": DtypeKind.INT,
"u": DtypeKind.UINT,
"f": DtypeKind.FLOAT,
"b": DtypeKind.BOOL,
"U": DtypeKind.STRING,
"M": DtypeKind.DATETIME,
"m": DtypeKind.DATETIME,
}


class DaskDataFrameInterchange:
def __init__(
self, df: DataFrame, nan_as_null: bool = False, allow_copy: bool = True
) -> None:
self._df = df
self._nan_as_null = nan_as_null
self._allow_copy = allow_copy

def get_columns(self):
return [DaskColumn(self._df[name]) for name in self._df.columns]

def column_names(self):
return self._df.columns

def num_columns(self) -> int:
return len(self._df.columns)

def num_rows(self) -> int:
return len(self._df)


class DaskColumn:
def __init__(self, column, allow_copy: bool = True) -> None:
self._col = column
self._allow_copy = allow_copy

def dtype(self) -> tuple[DtypeKind, None, None, None]:
dtype = self._col.dtype

if is_categorical_dtype(dtype):
return (
DtypeKind.CATEGORICAL,
None,
None,
None,
)
elif is_string_dtype(dtype):
return (
DtypeKind.STRING,
None,
None,
None,
)
else:
return _NP_KINDS.get(dtype.kind, None), None, None, None
18 changes: 18 additions & 0 deletions dask_expr/tests/test_interchange.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from pandas.core.interchange.dataframe_protocol import DtypeKind

from dask_expr import from_pandas
from dask_expr.tests._util import _backend_library

# Set DataFrame backend for this module
pd = _backend_library()


def test_interchange_protocol():
pdf = pd.DataFrame({"a": [1, 2, 3], "b": 1})
df = from_pandas(pdf, npartitions=2)
df_int = df.__dataframe__()
pd.testing.assert_index_equal(pdf.columns, df_int.column_names())
assert df_int.num_columns() == 2
assert df_int.num_rows() == 3
column = df_int.get_columns()[0]
assert column.dtype()[0] == DtypeKind.INT

0 comments on commit 560303c

Please sign in to comment.