diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py index a87e2849..5eb95e02 100644 --- a/dask_expr/_collection.py +++ b/dask_expr/_collection.py @@ -2673,6 +2673,11 @@ def __contains__(self, key): def __iter__(self): return iter(self._meta) + def __dataframe__(self, *args, **kwargs): + from dask_expr._interchange import DaskDataFrameInterchange + + return DaskDataFrameInterchange(self) + @derived_from(pd.DataFrame) def iterrows(self): frame = self.optimize() diff --git a/dask_expr/_interchange.py b/dask_expr/_interchange.py new file mode 100644 index 00000000..69cf923d --- /dev/null +++ b/dask_expr/_interchange.py @@ -0,0 +1,62 @@ +from dask.dataframe._compat import is_string_dtype +from dask.dataframe.dispatch import is_categorical_dtype +from pandas.core.interchange.dataframe_protocol import DtypeKind + +from dask_expr._collection import DataFrame + +_NP_KINDS = { + "i": DtypeKind.INT, + "u": DtypeKind.UINT, + "f": DtypeKind.FLOAT, + "b": DtypeKind.BOOL, + "U": DtypeKind.STRING, + "M": DtypeKind.DATETIME, + "m": DtypeKind.DATETIME, +} + + +class DaskDataFrameInterchange: + def __init__( + self, df: DataFrame, nan_as_null: bool = False, allow_copy: bool = True + ) -> None: + self._df = df + self._nan_as_null = nan_as_null + self._allow_copy = allow_copy + + def get_columns(self): + return [DaskColumn(self._df[name]) for name in self._df.columns] + + def column_names(self): + return self._df.columns + + def num_columns(self) -> int: + return len(self._df.columns) + + def num_rows(self) -> int: + return len(self._df) + + +class DaskColumn: + def __init__(self, column, allow_copy: bool = True) -> None: + self._col = column + self._allow_copy = allow_copy + + def dtype(self) -> tuple[DtypeKind, None, None, None]: + dtype = self._col.dtype + + if is_categorical_dtype(dtype): + return ( + DtypeKind.CATEGORICAL, + None, + None, + None, + ) + elif is_string_dtype(dtype): + return ( + DtypeKind.STRING, + None, + None, + None, + ) + else: + return _NP_KINDS.get(dtype.kind, None), None, None, None diff --git a/dask_expr/tests/test_interchange.py b/dask_expr/tests/test_interchange.py new file mode 100644 index 00000000..3db7c9d7 --- /dev/null +++ b/dask_expr/tests/test_interchange.py @@ -0,0 +1,18 @@ +from pandas.core.interchange.dataframe_protocol import DtypeKind + +from dask_expr import from_pandas +from dask_expr.tests._util import _backend_library + +# Set DataFrame backend for this module +pd = _backend_library() + + +def test_interchange_protocol(): + pdf = pd.DataFrame({"a": [1, 2, 3], "b": 1}) + df = from_pandas(pdf, npartitions=2) + df_int = df.__dataframe__() + pd.testing.assert_index_equal(pdf.columns, df_int.column_names()) + assert df_int.num_columns() == 2 + assert df_int.num_rows() == 3 + column = df_int.get_columns()[0] + assert column.dtype()[0] == DtypeKind.INT