From bda2577b90e143f4b1785852db2cd3c74ffaa744 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Wed, 14 Feb 2024 23:41:09 +0100 Subject: [PATCH] Fix reads from local dir that changes directory --- dask_expr/_collection.py | 19 +++++++++++++++++++ dask_expr/io/csv.py | 2 ++ dask_expr/io/parquet.py | 2 ++ dask_expr/io/tests/test_io.py | 34 +++++++++++++++++++++++++++++++--- 4 files changed, 54 insertions(+), 3 deletions(-) diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py index ce39a1c7e..646b2c3ee 100644 --- a/dask_expr/_collection.py +++ b/dask_expr/_collection.py @@ -3,6 +3,7 @@ import datetime import functools import inspect +import os import warnings from collections.abc import Callable, Hashable, Mapping from numbers import Integral, Number @@ -4526,6 +4527,7 @@ def read_csv( storage_options=storage_options, kwargs=kwargs, header=header, + _cwd=_get_cwd(path, kwargs), ) ) @@ -4551,6 +4553,7 @@ def read_table( storage_options=storage_options, kwargs=kwargs, header=header, + _cwd=_get_cwd(path, kwargs), ) ) @@ -4576,10 +4579,25 @@ def read_fwf( storage_options=storage_options, kwargs=kwargs, header=header, + _cwd=_get_cwd(path, kwargs), ) ) +def _get_protocol(urlpath): + if "://" in urlpath: + protocol, _ = urlpath.split("://", 1) + if len(protocol) > 1: + # excludes Windows paths + return protocol + return None + + +def _get_cwd(path, kwargs): + protocol = kwargs.pop("protocol", None) or _get_protocol(path) or "file" + return os.getcwd() if protocol == "file" else None + + def read_parquet( path=None, columns=None, @@ -4630,6 +4648,7 @@ def read_parquet( filesystem=filesystem, engine=_set_parquet_engine(engine), kwargs=kwargs, + _cwd=_get_cwd(path, kwargs), _series=isinstance(columns, str), ) ) diff --git a/dask_expr/io/csv.py b/dask_expr/io/csv.py index 6be02153d..7bc0618e4 100644 --- a/dask_expr/io/csv.py +++ b/dask_expr/io/csv.py @@ -14,6 +14,7 @@ class ReadCSV(PartitionsFiltered, BlockwiseIO): "_partitions", "storage_options", "kwargs", + "_cwd", # needed for tokenization "_series", ] _defaults = { @@ -24,6 +25,7 @@ class ReadCSV(PartitionsFiltered, BlockwiseIO): "_partitions": None, "storage_options": None, "_series": False, + "_cwd": None, } _absorb_projections = True diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index 8cd9e8c3d..f722a4d8f 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -426,6 +426,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO): "filesystem", "engine", "kwargs", + "_cwd", # needed for tokenization "_partitions", "_series", "_dataset_info_cache", @@ -449,6 +450,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO): "_partitions": None, "_series": False, "_dataset_info_cache": None, + "_cwd": None, } _pq_length_stats = None _absorb_projections = True diff --git a/dask_expr/io/tests/test_io.py b/dask_expr/io/tests/test_io.py index cca5d63c0..62619c314 100644 --- a/dask_expr/io/tests/test_io.py +++ b/dask_expr/io/tests/test_io.py @@ -1,5 +1,6 @@ import glob import os +from pathlib import Path import dask.array as da import dask.dataframe as dd @@ -30,14 +31,14 @@ pd = _backend_library() -def _make_file(dir, format="parquet", df=None): +def _make_file(dir, format="parquet", df=None, **kwargs): fn = os.path.join(str(dir), f"myfile.{format}") if df is None: df = pd.DataFrame({c: range(10) for c in "abcde"}) if format == "csv": - df.to_csv(fn) + df.to_csv(fn, **kwargs) elif format == "parquet": - df.to_parquet(fn) + df.to_parquet(fn, **kwargs) else: ValueError(f"{format} not a supported format") return fn @@ -413,6 +414,33 @@ def test_combine_similar_no_projection_on_one_branch(tmpdir): assert_eq(df, pdf) +@pytest.mark.parametrize( + "fmt, func, kwargs", + [ + ("parquet", read_parquet, {}), + ("csv", read_csv, {"index": False}), + ], +) +def test_chdir_different_files(tmpdir, fmt, func, kwargs): + cwd = os.getcwd() + + try: + pdf = pd.DataFrame({"x": [0, 1, 2, 3] * 4, "y": range(16)}) + os.chdir(tmpdir) + _make_file(tmpdir, format=fmt, df=pdf, **kwargs) + df = func(f"myfile.{fmt}") + + new_dir = Path(tmpdir).joinpath("new_dir") + new_dir.mkdir() + os.chdir(new_dir) + pdf2 = pd.DataFrame({"x": [0, 100, 200, 300] * 4, "y": range(16)}) + _make_file(new_dir, format=fmt, df=pdf2, **kwargs) + df2 = func(f"myfile.{fmt}") + assert_eq(df.sum() + df2.sum(), pd.Series([2424, 240], index=["x", "y"])) + finally: + os.chdir(cwd) + + @pytest.mark.parametrize("meta", [True, False]) @pytest.mark.parametrize("label", [None, "foo"]) @pytest.mark.parametrize("allow_projection", [True, False])