Skip to content

Commit

Permalink
fix: nw.Series could not be pickled (#1488)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Dec 2, 2024
1 parent 15b9b7f commit 37638d2
Show file tree
Hide file tree
Showing 38 changed files with 686 additions and 540 deletions.
20 changes: 20 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,26 @@ repos:
language: python
files: ^narwhals/
exclude: ^narwhals/dependencies\.py
# Forbid direct imports of the dtypes modules inside `narwhals/`; code there is
# expected to go through `import_dtypes_module` instead.
# Both patterns use verbose mode (`(?x)`): unescaped whitespace and newlines are
# ignored, so literal spaces are written `\ ` and literal dots `\.`.
- id: dtypes-import
  name: don't import from narwhals.dtypes (use `import_dtypes_module` instead)
  entry: |
    (?x)
    import\ narwhals\.dtypes|
    from\ narwhals\ import\ dtypes|
    from\ narwhals\.dtypes\ import\ [^D_]+|
    import\ narwhals\.stable\.v1\.dtypes|
    from\ narwhals\.stable\.v1\ import\ dtypes|
    from\ narwhals\.stable\.v1\.dtypes\ import
  language: pygrep
  files: ^narwhals/
  # Files listed here are exempt from the check above.
  exclude: |
    (?x)
    ^(
    narwhals/utils\.py|
    narwhals/stable/v1/_dtypes\.py|
    narwhals/.*__init__\.py|
    narwhals/.*typing\.py
    )
- repo: https://github.com/kynan/nbstripout
rev: 0.8.0
hooks:
Expand Down
16 changes: 8 additions & 8 deletions docs/how_it_works.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ import pandas as pd
import narwhals as nw
from narwhals._pandas_like.namespace import PandasLikeNamespace
from narwhals._pandas_like.utils import Implementation
from narwhals.utils import parse_version
from narwhals.utils import parse_version, Version

pn = PandasLikeNamespace(
implementation=Implementation.PANDAS,
backend_version=parse_version(pd.__version__),
dtypes=nw.dtypes,
version=Version.MAIN,
)
print(nw.col("a")._call(pn))
```
Expand All @@ -96,21 +96,21 @@ import narwhals as nw
from narwhals._pandas_like.namespace import PandasLikeNamespace
from narwhals._pandas_like.utils import Implementation
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals.utils import parse_version
from narwhals.utils import parse_version, Version
import pandas as pd

pn = PandasLikeNamespace(
implementation=Implementation.PANDAS,
backend_version=parse_version(pd.__version__),
dtypes=nw.dtypes,
version=Version.MAIN,
)

df_pd = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
df = PandasLikeDataFrame(
df_pd,
implementation=Implementation.PANDAS,
backend_version=parse_version(pd.__version__),
dtypes=nw.dtypes,
version=Version.MAIN,
)
expression = pn.col("a") + 1
result = expression._call(df)
Expand Down Expand Up @@ -193,13 +193,13 @@ import narwhals as nw
from narwhals._pandas_like.namespace import PandasLikeNamespace
from narwhals._pandas_like.utils import Implementation
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals.utils import parse_version
from narwhals.utils import parse_version, Version
import pandas as pd

pn = PandasLikeNamespace(
implementation=Implementation.PANDAS,
backend_version=parse_version(pd.__version__),
dtypes=nw.dtypes,
version=Version.MAIN,
)

df_pd = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
Expand All @@ -214,7 +214,7 @@ backend, and it does so by passing a Narwhals-compliant namespace to `nw.Expr._c
pn = PandasLikeNamespace(
implementation=Implementation.PANDAS,
backend_version=parse_version(pd.__version__),
dtypes=nw.dtypes,
version=Version.MAIN,
)
expr = (nw.col("a") + 1)._call(pn)
print(expr)
Expand Down
34 changes: 18 additions & 16 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import IntoArrowExpr
from narwhals.dtypes import DType
from narwhals.typing import DTypes
from narwhals.utils import Version


class ArrowDataFrame:
Expand All @@ -45,17 +45,19 @@ def __init__(
native_dataframe: pa.Table,
*,
backend_version: tuple[int, ...],
dtypes: DTypes,
version: Version,
) -> None:
self._native_frame = native_dataframe
self._implementation = Implementation.PYARROW
self._backend_version = backend_version
self._dtypes = dtypes
self._version = version

def __narwhals_namespace__(self: Self) -> ArrowNamespace:
from narwhals._arrow.namespace import ArrowNamespace

return ArrowNamespace(backend_version=self._backend_version, dtypes=self._dtypes)
return ArrowNamespace(
backend_version=self._backend_version, version=self._version
)

def __native_namespace__(self: Self) -> ModuleType:
if self._implementation is Implementation.PYARROW:
Expand All @@ -70,14 +72,14 @@ def __narwhals_dataframe__(self: Self) -> Self:
def __narwhals_lazyframe__(self: Self) -> Self:
return self

def _change_dtypes(self: Self, dtypes: DTypes) -> Self:
def _change_dtypes(self: Self, version: Version) -> Self:
return self.__class__(
self._native_frame, backend_version=self._backend_version, dtypes=dtypes
self._native_frame, backend_version=self._backend_version, version=version
)

def _from_native_frame(self: Self, df: pa.Table) -> Self:
return self.__class__(
df, backend_version=self._backend_version, dtypes=self._dtypes
df, backend_version=self._backend_version, version=self._version
)

@property
Expand Down Expand Up @@ -143,7 +145,7 @@ def get_column(self: Self, name: str) -> ArrowSeries:
self._native_frame[name],
name=name,
backend_version=self._backend_version,
dtypes=self._dtypes,
version=self._version,
)

def __array__(self: Self, dtype: Any, copy: bool | None) -> np.ndarray:
Expand Down Expand Up @@ -186,7 +188,7 @@ def __getitem__(
self._native_frame[item],
name=item,
backend_version=self._backend_version,
dtypes=self._dtypes,
version=self._version,
)
elif (
isinstance(item, tuple)
Expand Down Expand Up @@ -231,14 +233,14 @@ def __getitem__(
self._native_frame[col_name],
name=col_name,
backend_version=self._backend_version,
dtypes=self._dtypes,
version=self._version,
)
selected_rows = select_rows(self._native_frame, item[0])
return ArrowSeries(
selected_rows[col_name],
name=col_name,
backend_version=self._backend_version,
dtypes=self._dtypes,
version=self._version,
)

elif isinstance(item, slice):
Expand Down Expand Up @@ -276,7 +278,7 @@ def __getitem__(
def schema(self: Self) -> dict[str, DType]:
schema = self._native_frame.schema
return {
name: native_to_narwhals_dtype(dtype, self._dtypes)
name: native_to_narwhals_dtype(dtype, self._version)
for name, dtype in zip(schema.names, schema.types)
}

Expand Down Expand Up @@ -463,7 +465,7 @@ def to_dict(
col,
name=name,
backend_version=self._backend_version,
dtypes=self._dtypes,
version=self._version,
)
for name, col in names_and_values
}
Expand Down Expand Up @@ -530,7 +532,7 @@ def collect(self: Self) -> ArrowDataFrame:
return ArrowDataFrame(
self._native_frame,
backend_version=self._backend_version,
dtypes=self._dtypes,
version=self._version,
)

def clone(self: Self) -> Self:
Expand Down Expand Up @@ -611,7 +613,7 @@ def is_duplicated(self: Self) -> ArrowSeries:
is_duplicated,
name="",
backend_version=self._backend_version,
dtypes=self._dtypes,
version=self._version,
)

def is_unique(self: Self) -> ArrowSeries:
Expand All @@ -625,7 +627,7 @@ def is_unique(self: Self) -> ArrowSeries:
pc.invert(is_duplicated),
name="",
backend_version=self._backend_version,
dtypes=self._dtypes,
version=self._version,
)

def unique(
Expand Down
40 changes: 21 additions & 19 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import IntoArrowExpr
from narwhals.dtypes import DType
from narwhals.typing import DTypes
from narwhals.utils import Version


class ArrowExpr:
Expand All @@ -36,7 +36,7 @@ def __init__(
root_names: list[str] | None,
output_names: list[str] | None,
backend_version: tuple[int, ...],
dtypes: DTypes,
version: Version,
) -> None:
self._call = call
self._depth = depth
Expand All @@ -46,7 +46,7 @@ def __init__(
self._output_names = output_names
self._implementation = Implementation.PYARROW
self._backend_version = backend_version
self._dtypes = dtypes
self._version = version

def __repr__(self: Self) -> str: # pragma: no cover
return (
Expand All @@ -62,7 +62,7 @@ def from_column_names(
cls: type[Self],
*column_names: str,
backend_version: tuple[int, ...],
dtypes: DTypes,
version: Version,
) -> Self:
from narwhals._arrow.series import ArrowSeries

Expand All @@ -73,7 +73,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
df._native_frame[column_name],
name=column_name,
backend_version=df._backend_version,
dtypes=df._dtypes,
version=df._version,
)
for column_name in column_names
]
Expand All @@ -91,15 +91,15 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
root_names=list(column_names),
output_names=list(column_names),
backend_version=backend_version,
dtypes=dtypes,
version=version,
)

@classmethod
def from_column_indices(
cls: type[Self],
*column_indices: int,
backend_version: tuple[int, ...],
dtypes: DTypes,
version: Version,
) -> Self:
from narwhals._arrow.series import ArrowSeries

Expand All @@ -109,7 +109,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
df._native_frame[column_index],
name=df._native_frame.column_names[column_index],
backend_version=df._backend_version,
dtypes=df._dtypes,
version=df._version,
)
for column_index in column_indices
]
Expand All @@ -121,13 +121,15 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
root_names=None,
output_names=None,
backend_version=backend_version,
dtypes=dtypes,
version=version,
)

def __narwhals_namespace__(self: Self) -> ArrowNamespace:
from narwhals._arrow.namespace import ArrowNamespace

return ArrowNamespace(backend_version=self._backend_version, dtypes=self._dtypes)
return ArrowNamespace(
backend_version=self._backend_version, version=self._version
)

def __narwhals_expr__(self: Self) -> None: ...

Expand Down Expand Up @@ -287,7 +289,7 @@ def alias(self: Self, name: str) -> Self:
root_names=self._root_names,
output_names=[name],
backend_version=self._backend_version,
dtypes=self._dtypes,
version=self._version,
)

def null_count(self: Self) -> Self:
Expand Down Expand Up @@ -406,7 +408,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
root_names=self._root_names,
output_names=self._output_names,
backend_version=self._backend_version,
dtypes=self._dtypes,
version=self._version,
)

def mode(self: Self) -> Self:
Expand Down Expand Up @@ -447,7 +449,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
root_names=self._root_names,
output_names=self._output_names,
backend_version=self._backend_version,
dtypes=self._dtypes,
version=self._version,
)

def is_finite(self: Self) -> Self:
Expand Down Expand Up @@ -726,7 +728,7 @@ def keep(self: Self) -> ArrowExpr:
root_names=root_names,
output_names=root_names,
backend_version=self._expr._backend_version,
dtypes=self._expr._dtypes,
version=self._expr._version,
)

def map(self: Self, function: Callable[[str], str]) -> ArrowExpr:
Expand All @@ -752,7 +754,7 @@ def map(self: Self, function: Callable[[str], str]) -> ArrowExpr:
root_names=root_names,
output_names=output_names,
backend_version=self._expr._backend_version,
dtypes=self._expr._dtypes,
version=self._expr._version,
)

def prefix(self: Self, prefix: str) -> ArrowExpr:
Expand All @@ -776,7 +778,7 @@ def prefix(self: Self, prefix: str) -> ArrowExpr:
root_names=root_names,
output_names=output_names,
backend_version=self._expr._backend_version,
dtypes=self._expr._dtypes,
version=self._expr._version,
)

def suffix(self: Self, suffix: str) -> ArrowExpr:
Expand All @@ -801,7 +803,7 @@ def suffix(self: Self, suffix: str) -> ArrowExpr:
root_names=root_names,
output_names=output_names,
backend_version=self._expr._backend_version,
dtypes=self._expr._dtypes,
version=self._expr._version,
)

def to_lowercase(self: Self) -> ArrowExpr:
Expand All @@ -826,7 +828,7 @@ def to_lowercase(self: Self) -> ArrowExpr:
root_names=root_names,
output_names=output_names,
backend_version=self._expr._backend_version,
dtypes=self._expr._dtypes,
version=self._expr._version,
)

def to_uppercase(self: Self) -> ArrowExpr:
Expand All @@ -851,5 +853,5 @@ def to_uppercase(self: Self) -> ArrowExpr:
root_names=root_names,
output_names=output_names,
backend_version=self._expr._backend_version,
dtypes=self._expr._dtypes,
version=self._expr._version,
)
Loading

0 comments on commit 37638d2

Please sign in to comment.