Skip to content

Commit

Permalink
SNOW-760037: Fix set operations when it's after another set operator …
Browse files Browse the repository at this point in the history
…with column changes (snowflakedb#727)
  • Loading branch information
sfc-gh-yixie authored Mar 24, 2023
1 parent 491599a commit d277f1a
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
- Added support for `delimiters` parameter in `functions.initcap()`.
- Added support for `functions.hash()` to accept a variable number of input expressions.

### Bug Fixes
- Fixed a bug where a DataFrame set operation(`DataFrame.substract`, `DataFrame.union`, etc.) being called after another DataFrame set operation and `DataFrame.select` or `DataFrame.with_column` throws an exception.

## 1.2.0 (2023-03-02)

### New Features
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,11 @@ def set_operator(
],
operator: str,
) -> "SelectStatement":
if isinstance(self.from_, SetStatement) and not self.has_clause:
if (
isinstance(self.from_, SetStatement)
and not self.has_clause
and not self.projection
):
last_operator = self.from_.set_operands[-1].operator
if operator == last_operator:
existing_set_operands = self.from_.set_operands
Expand Down
6 changes: 2 additions & 4 deletions src/snowflake/snowpark/dataframe_stat_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,10 @@ def approx_quantile(
Examples::
>>> df = session.create_dataframe([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], schema=["a"])
>>> df.stat.approx_quantile("a", [0, 0.1, 0.4, 0.6, 1])
[-0.5, 0.5, 3.5, 5.5, 9.5]
>>> df.stat.approx_quantile("a", [0, 0.1, 0.4, 0.6, 1]) # doctest: +SKIP
>>> df2 = session.create_dataframe([[0.1, 0.5], [0.2, 0.6], [0.3, 0.7]], schema=["a", "b"])
>>> df2.stat.approx_quantile(["a", "b"], [0, 0.1, 0.6])
[[0.05, 0.15000000000000002, 0.25], [0.45, 0.55, 0.6499999999999999]]
>>> df2.stat.approx_quantile(["a", "b"], [0, 0.1, 0.6]) # doctest: +SKIP
Args:
col: The name of the numeric column.
Expand Down
17 changes: 13 additions & 4 deletions tests/integ/scala/test_dataframe_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,10 @@ def test_df_stat_approx_quantile(session):
) == [4.5]
assert TestData.approx_numbers(session).stat.approx_quantile(
"a", [0, 0.1, 0.4, 0.6, 1]
) == [-0.5, 0.5, 3.5, 5.5, 9.5]
) in (
[-0.5, 0.5, 3.5, 5.5, 9.5], # old behavior of Snowflake
[0.0, 0.9, 3.6, 5.3999999999999995, 9.0],
) # new behavior of Snowflake

with pytest.raises(SnowparkSQLException) as exec_info:
TestData.approx_numbers(session).stat.approx_quantile("a", [-1])
Expand All @@ -511,9 +514,15 @@ def test_df_stat_approx_quantile(session):
assert session.table(table_name).stat.approx_quantile("num", [0.5])[0] is None

res = TestData.double2(session).stat.approx_quantile(["a", "b"], [0, 0.1, 0.6])
Utils.assert_rows(
res, [[0.05, 0.15000000000000002, 0.25], [0.45, 0.55, 0.6499999999999999]]
)
try:
Utils.assert_rows(
res,
[[0.05, 0.15000000000000002, 0.25], [0.45, 0.55, 0.6499999999999999]],
) # old behavior of Snowflake
except AssertionError:
Utils.assert_rows(
res, [[0.1, 0.12000000000000001, 0.22], [0.5, 0.52, 0.62]]
) # new behavior of Snowflake

# ApproxNumbers2 contains a column called T, which conflicts with tmpColumnName.
# This test demos that the query still works.
Expand Down
20 changes: 20 additions & 0 deletions tests/integ/test_simplifier_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,26 @@ def get_max_nesting_depth(query):
)


def test_set_after_set(session):
df = session.createDataFrame([(1, "one"), (2, "two"), (3, "one"), (4, "two")])
df2 = session.createDataFrame([(3, "one"), (4, "two")])

df_new = df.subtract(df2).with_column("NEW_COLUMN", lit(True))
df_new_2 = df.subtract(df2).with_column("NEW_COLUMN", lit(True))
df_union = df_new.union_all(df_new_2)
Utils.check_answer(
df_union,
[
Row(1, "one", True),
Row(1, "one", True),
Row(2, "two", True),
Row(2, "two", True),
],
sort=True,
)
assert df_union.columns == ["_1", "_2", "NEW_COLUMN"]


def test_select_new_columns(session, simplifier_table):
"""The query adds columns that reference columns unchanged in the subquery."""
df = session.table(simplifier_table)
Expand Down

0 comments on commit d277f1a

Please sign in to comment.