Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
… from `read_csv` & enabling more tests

### What changes were proposed in this pull request?

This PR proposes to remove `squeeze` parameter from `read_csv` to follow the behavior of latest pandas. See pandas-dev/pandas#40413 and pandas-dev/pandas#43427 for detail.

This PR also enables more tests for pandas 2.0.0 and above.

### Why are the changes needed?

To follow the behavior of latest pandas, and increase the test coverage.

### Does this PR introduce _any_ user-facing change?

`squeeze` will be no longer available from `read_csv`. Otherwise, it's test-only.

### How was this patch tested?

Enabling & updating the existing tests.

Closes #42551 from itholic/pandas_remaining_tests.

Authored-by: itholic <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
itholic authored and zhengruifeng committed Aug 22, 2023
1 parent 290b632 commit 8d4ca0a
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 135 deletions.
1 change: 1 addition & 0 deletions python/docs/source/migration_guide/pyspark_upgrade.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Upgrading from PySpark 3.5 to 4.0
* In Spark 4.0, ``sort_columns`` parameter from ``DataFrame.plot`` and `Series.plot`` has been removed from pandas API on Spark.
* In Spark 4.0, the default value of ``regex`` parameter for ``Series.str.replace`` has been changed from ``True`` to ``False`` from pandas API on Spark. Additionally, a single character ``pat`` with ``regex=True`` is now treated as a regular expression instead of a string literal.
* In Spark 4.0, the resulting name from ``value_counts`` for all objects sets to ``'count'`` (or ``'propotion'`` if ``nomalize=True`` was passed) from pandas API on Spark, and the index will be named after the original object.
* In Spark 4.0, ``squeeze`` parameter from ``ps.read_csv`` and ``ps.read_excel`` has been removed from pandas API on Spark.


Upgrading from PySpark 3.3 to 3.4
Expand Down
32 changes: 5 additions & 27 deletions python/pyspark/pandas/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@ def read_csv(
names: Optional[Union[str, List[str]]] = None,
index_col: Optional[Union[str, List[str]]] = None,
usecols: Optional[Union[List[int], List[str], Callable[[str], bool]]] = None,
squeeze: bool = False,
mangle_dupe_cols: bool = True,
dtype: Optional[Union[str, Dtype, Dict[str, Union[str, Dtype]]]] = None,
nrows: Optional[int] = None,
Expand Down Expand Up @@ -262,11 +261,6 @@ def read_csv(
from the document header row(s).
If callable, the callable function will be evaluated against the column names,
returning names where the callable function evaluates to `True`.
squeeze : bool, default False
If the parsed data only contains one column then return a Series.
.. deprecated:: 3.4.0
mangle_dupe_cols : bool, default True
Duplicate columns will be specified as 'X0', 'X1', ... 'XN', rather
than 'X' ... 'X'. Passing in False will cause data to be overwritten if
Expand Down Expand Up @@ -466,10 +460,7 @@ def read_csv(
for col in psdf.columns:
psdf[col] = psdf[col].astype(dtype)

if squeeze and len(psdf.columns) == 1:
return first_series(psdf)
else:
return psdf
return psdf


def read_json(
Expand Down Expand Up @@ -912,7 +903,6 @@ def read_excel(
names: Optional[List] = None,
index_col: Optional[List[int]] = None,
usecols: Optional[Union[int, str, List[Union[int, str]], Callable[[str], bool]]] = None,
squeeze: bool = False,
dtype: Optional[Dict[str, Union[str, Dtype]]] = None,
engine: Optional[str] = None,
converters: Optional[Dict] = None,
Expand Down Expand Up @@ -985,11 +975,6 @@ def read_excel(
* If list of string, then indicates list of column names to be parsed.
* If callable, then evaluate each column name against it and parse the
column if the callable returns ``True``.
squeeze : bool, default False
If the parsed data only contains one column then return a Series.
.. deprecated:: 3.4.0
dtype : Type name or dict of column -> type, default None
Data type for data or columns. E.g. {'a': np.float64, 'b': np.int32}
Use `object` to preserve data as stored in Excel and not interpret dtype.
Expand Down Expand Up @@ -1142,7 +1127,7 @@ def read_excel(
"""

def pd_read_excel(
io_or_bin: Any, sn: Union[str, int, List[Union[str, int]], None], sq: bool
io_or_bin: Any, sn: Union[str, int, List[Union[str, int]], None]
) -> pd.DataFrame:
return pd.read_excel(
io=BytesIO(io_or_bin) if isinstance(io_or_bin, (bytes, bytearray)) else io_or_bin,
Expand All @@ -1151,7 +1136,6 @@ def pd_read_excel(
names=names,
index_col=index_col,
usecols=usecols,
squeeze=sq,
dtype=dtype,
engine=engine,
converters=converters,
Expand Down Expand Up @@ -1181,7 +1165,7 @@ def pd_read_excel(
io_or_bin = io
single_file = True

pdf_or_psers = pd_read_excel(io_or_bin, sn=sheet_name, sq=squeeze)
pdf_or_psers = pd_read_excel(io_or_bin, sn=sheet_name)

if single_file:
if isinstance(pdf_or_psers, dict):
Expand All @@ -1208,9 +1192,7 @@ def read_excel_on_spark(
)

def output_func(pdf: pd.DataFrame) -> pd.DataFrame:
pdf = pd.concat(
[pd_read_excel(bin, sn=sn, sq=False) for bin in pdf[pdf.columns[0]]]
)
pdf = pd.concat([pd_read_excel(bin, sn=sn) for bin in pdf[pdf.columns[0]]])

reset_index = pdf.reset_index()
for name, col in reset_index.items():
Expand All @@ -1231,11 +1213,7 @@ def output_func(pdf: pd.DataFrame) -> pd.DataFrame:
.mapInPandas(lambda iterator: map(output_func, iterator), schema=return_schema)
)

psdf = DataFrame(psdf._internal.with_new_sdf(sdf))
if squeeze and len(psdf.columns) == 1:
return first_series(psdf)
else:
return psdf
return DataFrame(psdf._internal.with_new_sdf(sdf))

if isinstance(pdf_or_psers, dict):
return {
Expand Down
18 changes: 0 additions & 18 deletions python/pyspark/pandas/tests/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,24 +255,6 @@ def test_read_csv_with_sep(self):
actual = ps.read_csv(fn, sep="\t")
self.assert_eq(expected, actual, almost=True)

@unittest.skipIf(
LooseVersion(pd.__version__) >= LooseVersion("2.0.0"),
"TODO(SPARK-43563): Enable CsvTests.test_read_csv_with_squeeze for pandas 2.0.0.",
)
def test_read_csv_with_squeeze(self):
with self.csv_file(self.csv_text) as fn:
expected = pd.read_csv(fn, squeeze=True, usecols=["name"])
actual = ps.read_csv(fn, squeeze=True, usecols=["name"])
self.assert_eq(expected, actual, almost=True)

expected = pd.read_csv(fn, squeeze=True, usecols=["name", "amount"])
actual = ps.read_csv(fn, squeeze=True, usecols=["name", "amount"])
self.assert_eq(expected, actual, almost=True)

expected = pd.read_csv(fn, squeeze=True, usecols=["name", "amount"], index_col=["name"])
actual = ps.read_csv(fn, squeeze=True, usecols=["name", "amount"], index_col=["name"])
self.assert_eq(expected, actual, almost=True)

def test_read_csv_with_mangle_dupe_cols(self):
self.assertRaisesRegex(
ValueError, "mangle_dupe_cols", lambda: ps.read_csv("path", mangle_dupe_cols=False)
Expand Down
16 changes: 3 additions & 13 deletions python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,6 @@ def tearDownClass(cls):
reset_option("compute.ops_on_diff_frames")
super().tearDownClass()

@unittest.skipIf(
LooseVersion(pd.__version__) >= LooseVersion("2.0.0"),
"TODO(SPARK-43460): Enable OpsOnDiffFramesGroupByTests.test_groupby_different_lengths "
"for pandas 2.0.0.",
)
def test_groupby_different_lengths(self):
pdfs1 = [
pd.DataFrame({"c": [4, 2, 7, 3, None, 1, 1, 1, 2], "d": list("abcdefght")}),
Expand Down Expand Up @@ -71,7 +66,7 @@ def sort(df):

self.assert_eq(
sort(psdf1.groupby(psdf2.a, as_index=as_index).sum()),
sort(pdf1.groupby(pdf2.a, as_index=as_index).sum()),
sort(pdf1.groupby(pdf2.a, as_index=as_index).sum(numeric_only=True)),
almost=as_index,
)

Expand All @@ -86,11 +81,6 @@ def sort(df):
almost=as_index,
)

@unittest.skipIf(
LooseVersion(pd.__version__) >= LooseVersion("2.0.0"),
"TODO(SPARK-43459): Enable OpsOnDiffFramesGroupByTests.test_groupby_multiindex_columns "
"for pandas 2.0.0.",
)
def test_groupby_multiindex_columns(self):
pdf1 = pd.DataFrame(
{("y", "c"): [4, 2, 7, 3, None, 1, 1, 1, 2], ("z", "d"): list("abcdefght")}
Expand All @@ -103,7 +93,7 @@ def test_groupby_multiindex_columns(self):

self.assert_eq(
psdf1.groupby(psdf2[("x", "a")]).sum().sort_index(),
pdf1.groupby(pdf2[("x", "a")]).sum().sort_index(),
pdf1.groupby(pdf2[("x", "a")]).sum(numeric_only=True).sort_index(),
)

self.assert_eq(
Expand All @@ -112,7 +102,7 @@ def test_groupby_multiindex_columns(self):
.sort_values(("y", "c"))
.reset_index(drop=True),
pdf1.groupby(pdf2[("x", "a")], as_index=False)
.sum()
.sum(numeric_only=True)
.sort_values(("y", "c"))
.reset_index(drop=True),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,35 @@ def _test_groupby_rolling_func(self, f):
getattr(pdf.groupby(pkey)[["b"]].rolling(2), f)().sort_index(),
)

@unittest.skipIf(
LooseVersion(pd.__version__) >= LooseVersion("2.0.0"),
"TODO(SPARK-43452): Enable RollingTests.test_groupby_rolling_count for pandas 2.0.0.",
)
def test_groupby_rolling_count(self):
self._test_groupby_rolling_func("count")
pser = pd.Series([1, 2, 3], name="a")
pkey = pd.Series([1, 2, 3], name="a")
psser = ps.from_pandas(pser)
kkey = ps.from_pandas(pkey)

# TODO(SPARK-43432): Fix `min_periods` for Rolling.count() to work same as pandas
self.assert_eq(
psser.groupby(kkey).rolling(2).count().sort_index(),
pser.groupby(pkey).rolling(2, min_periods=1).count().sort_index(),
)

pdf = pd.DataFrame({"a": [1, 2, 3, 2], "b": [4.0, 2.0, 3.0, 1.0]})
pkey = pd.Series([1, 2, 3, 2], name="a")
psdf = ps.from_pandas(pdf)
kkey = ps.from_pandas(pkey)

self.assert_eq(
psdf.groupby(kkey).rolling(2).count().sort_index(),
pdf.groupby(pkey).rolling(2, min_periods=1).count().sort_index(),
)
self.assert_eq(
psdf.groupby(kkey)["b"].rolling(2).count().sort_index(),
pdf.groupby(pkey)["b"].rolling(2, min_periods=1).count().sort_index(),
)
self.assert_eq(
psdf.groupby(kkey)[["b"]].rolling(2).count().sort_index(),
pdf.groupby(pkey)[["b"]].rolling(2, min_periods=1).count().sort_index(),
)

def test_groupby_rolling_min(self):
self._test_groupby_rolling_func("min")
Expand Down
Loading

0 comments on commit 8d4ca0a

Please sign in to comment.