[SPARK-48045][PYTHON] Pandas API groupby with multi-agg-relabel ignores as_index=False

### What changes were proposed in this pull request?
When groupby is used in the pandas API on Spark (pyspark.pandas) with relabeled aggregate columns (named aggregation) and as_index=False, the group-by columns are not returned in the resulting DataFrame. This change fixes that bug.

Example:
import pyspark.pandas as ps
ps.DataFrame({"a": [0, 0], "b": [0, 1]}).groupby("a", as_index=False).agg(b_max=("b", "max"))

Result:
   b_max
0      1

Required Result:
   a  b_max
0  0      1

### Why are the changes needed?
The relabeling code path uses only the aggregate columns. With as_index=True this is not an issue, because the group-by columns are included in the index. With as_index=False, the group-by columns must also be carried through the relabeling step.

Please check the commits/PR for the code changes.
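
For reference, plain pandas already keeps the group keys as regular columns in this case, which is the behavior pandas-on-Spark should match (an illustrative snippet, not part of the patch):

```python
import pandas as pd

pdf = pd.DataFrame({"a": [0, 0], "b": [0, 1]})
# With as_index=False, pandas returns the group key "a" as a regular column
# alongside the relabeled aggregate column.
print(pdf.groupby("a", as_index=False).agg(b_max=("b", "max")))
#    a  b_max
# 0  0      1
```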

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
- Passed GA
- Passed Build tests
- Added unit tests covering the scenario from the JIRA ticket plus additional scenarios

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#46391 from sinaiamonkar-sai/SPARK-48045-2.

Authored-by: sai <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
sai authored and HyukjinKwon committed May 8, 2024
1 parent a15adeb commit 67ae239
Showing 2 changed files with 27 additions and 1 deletion.
python/pyspark/pandas/groupby.py (6 additions, 1 deletion)
@@ -308,6 +308,7 @@ def aggregate(
             )

         if not self._as_index:
+            index_cols = psdf._internal.column_labels
             should_drop_index = set(
                 i for i, gkey in enumerate(self._groupkeys) if gkey._psdf is not self._psdf
             )
@@ -322,8 +323,12 @@
                 psdf = psdf.reset_index(level=should_drop_index, drop=drop)
             if len(should_drop_index) < len(self._groupkeys):
                 psdf = psdf.reset_index()
+            index_cols = [c for c in psdf._internal.column_labels if c not in index_cols]
+            if relabeling:
+                psdf = psdf[pd.Index(index_cols + list(order))]
+                psdf.columns = pd.Index([c[0] for c in index_cols] + list(columns))

-        if relabeling:
+        if relabeling and self._as_index:
             psdf = psdf[order]
             psdf.columns = columns  # type: ignore[assignment]
         return psdf
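
To illustrate the effect of the new as_index=False branch, here is a simplified, pandas-only sketch; the frame and labels below are hypothetical stand-ins for the internal state after `reset_index()`, not the actual pyspark.pandas internals:

```python
import pandas as pd

# Hypothetical stand-in for the frame after reset_index(): the former group
# key "a" is back as a regular column next to the aggregate column ("b", "max").
psdf = pd.DataFrame({("a", ""): [0], ("b", "max"): [1]})

index_cols = [("a", "")]   # column labels restored by reset_index() (assumed shape)
order = [("b", "max")]     # aggregate columns in named-aggregation order
columns = ["b_max"]        # user-supplied relabel names

# Keep the group-key columns in front of the relabeled aggregate columns,
# mirroring the selection and renaming added in the diff above.
psdf = psdf[index_cols + list(order)]
psdf.columns = pd.Index([c[0] for c in index_cols] + list(columns))
print(psdf)
#    a  b_max
# 0  0      1
```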
python/pyspark/pandas/tests/groupby/test_groupby.py (21 additions, 0 deletions)
@@ -451,6 +451,27 @@ def test_diff(self):
                 pdf.groupby([("x", "a"), ("x", "b")]).diff().sort_index(),
             )

+    def test_aggregate_relabel_index_false(self):
+        pdf = pd.DataFrame(
+            {
+                "A": [0, 0, 1, 1, 1],
+                "B": ["a", "a", "b", "a", "b"],
+                "C": [10, 15, 10, 20, 30],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(
+            pdf.groupby(["B", "A"], as_index=False)
+            .agg(C_MAX=("C", "max"))
+            .sort_values(["B", "A"])
+            .reset_index(drop=True),
+            psdf.groupby(["B", "A"], as_index=False)
+            .agg(C_MAX=("C", "max"))
+            .sort_values(["B", "A"])
+            .reset_index(drop=True),
+        )
+

 class GroupByTests(
     GroupByTestsMixin,
