SNOW-704112 explode related methods (snowflakedb#762)
* initial commit

* add examples

* update changelog

* add tests

* add explode in functions.rst

* use user_visible_name for table function name

* fix merge

* use output dict in SnowflakePlan

* make output_dict more general purpose
sfc-gh-aalam authored Apr 3, 2023
1 parent f394fea commit a376a04
Showing 7 changed files with 216 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -4,6 +4,7 @@

### New Features

- Added support for `explode` function in `snowflake.snowpark.functions`.
- Added parameter `skip_upload_on_content_match` when creating UDFs, UDTFs, and stored procedures with `register_from_file` to skip uploading files to a stage when the same files are already present there.

## 1.3.0 (2023-03-28)
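A minimal sketch of the new `skip_upload_on_content_match` parameter noted in the changelog entry above: it is passed through `register_from_file`, so a re-registration can reuse a file that is already staged. The file, function, UDF, and stage names below are made up, and the remaining parameters simply follow the existing `register_from_file` signature:

    from snowflake.snowpark.types import IntegerType

    add_one = session.udf.register_from_file(
        file_path="/tmp/my_udfs.py",          # hypothetical local file
        func_name="add_one",                  # hypothetical function inside that file
        name="add_one_udf",                   # hypothetical UDF name
        return_type=IntegerType(),
        input_types=[IntegerType()],
        is_permanent=True,
        stage_location="@my_stage",           # hypothetical stage
        skip_upload_on_content_match=True,    # skip the upload if the staged copy already matches
    )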
1 change: 1 addition & 0 deletions docs/source/functions.rst
@@ -130,6 +130,7 @@ Functions
    endswith
    equal_nan
    exp
    explode
    expr
    factorial
    first_value
9 changes: 9 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -191,6 +191,7 @@ def __init__(
        # We need to copy this list since we don't want to change it for the
        # previous SnowflakePlan objects
        self.api_calls = api_calls.copy() if api_calls else []
        self._output_dict = None

    def with_subqueries(self, subquery_plans: List["SnowflakePlan"]) -> "SnowflakePlan":
        pre_queries = self.queries[:-1]
@@ -230,6 +231,14 @@ def attributes(self) -> List[Attribute]:
    def output(self) -> List[Attribute]:
        return [Attribute(a.name, a.datatype, a.nullable) for a in self.attributes]

    @property
    def output_dict(self) -> Dict[str, Any]:
        """Cached mapping from each output attribute's name to its (datatype, nullable) pair."""
        if self._output_dict is None:
            self._output_dict = {
                attr.name: (attr.datatype, attr.nullable) for attr in self.output
            }
        return self._output_dict

    def __copy__(self) -> "SnowflakePlan":
        return SnowflakePlan(
            self.queries.copy() if self.queries else [],
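For orientation, the cached mapping added above is keyed by the plan's quoted attribute names, and it is what the new explode helper in table_function.py consults to choose between the array and map code paths. A hypothetical lookup (the column name and types below are invented for illustration, and `plan` stands for any resolved SnowflakePlan):

    from snowflake.snowpark.types import ArrayType

    # e.g. {'"IDX"': (LongType(), False), '"LISTS"': (ArrayType(), True), ...}
    datatype, nullable = plan.output_dict.get('"LISTS"', (None, None))
    isinstance(datatype, ArrayType)  # True -> explode() will project a single VALUE column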
14 changes: 10 additions & 4 deletions src/snowflake/snowpark/dataframe.py
@@ -138,6 +138,8 @@
from snowflake.snowpark.table_function import (
    TableFunctionCall,
    _create_table_function_expression,
    _ExplodeFunctionCall,
    _get_cols_after_explode_join,
    _get_cols_after_join_table,
)
from snowflake.snowpark.types import StringType, StructType, _NumericType
@@ -980,16 +982,20 @@ def select(
                if table_func:
                    raise ValueError(
                        f"At most one table function can be called inside a select(). "
                        f"Called '{table_func.name}' and '{e.name}'."
                        f"Called '{table_func.user_visible_name}' and '{e.user_visible_name}'."
                    )
                table_func = e
                func_expr = _create_table_function_expression(func=table_func)
                join_plan = self._session._analyzer.resolve(
                    TableFunctionJoin(self._plan, func_expr)
                )
                _, new_cols = _get_cols_after_join_table(
                    func_expr, self._plan, join_plan
                )

                if isinstance(e, _ExplodeFunctionCall):
                    new_cols = _get_cols_after_explode_join(e, self._plan)
                else:
                    _, new_cols = _get_cols_after_join_table(
                        func_expr, self._plan, join_plan
                    )
                names.extend(new_cols)
            else:
                raise TypeError(
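The `_ExplodeFunctionCall` branch added to `select()` above is what narrows explode's output: a plain table-function join surfaces every column the underlying function returns, while the explode path keeps only VALUE (arrays) or KEY and VALUE (maps). A rough sketch of the difference, reusing the `df` from the explode docstring below and the pre-existing `table_function` helper; the full FLATTEN column list is shown for orientation only:

    from snowflake.snowpark.functions import explode, table_function

    flatten = table_function("flatten")

    df.select(df.idx, explode(df.lists)).columns
    # expected: ['IDX', 'VALUE']                             (_get_cols_after_explode_join)

    df.select(df.idx, flatten(df.lists)).columns
    # roughly: ['IDX', 'SEQ', 'KEY', 'PATH', 'INDEX', 'VALUE', 'THIS']
    #                                                        (_get_cols_after_join_table)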
56 changes: 56 additions & 0 deletions src/snowflake/snowpark/functions.py
@@ -3,6 +3,8 @@
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from snowflake.snowpark.table_function import TableFunctionCall, _ExplodeFunctionCall

"""
Provides utility and SQL functions that generate :class:`~snowflake.snowpark.Column` expressions that you can pass to :class:`~snowflake.snowpark.DataFrame` transformation methods.
@@ -890,6 +892,60 @@ def approx_percentile_combine(state: ColumnOrName) -> Column:
    return builtin("approx_percentile_combine")(c)


def explode(col: ColumnOrName) -> TableFunctionCall:
    """Flattens a given array or map type column into individual rows. For an array
    input column, the default output column name is ``VALUE``; for a map input column,
    the default output column names are ``KEY`` and ``VALUE``.

    Examples::

        >>> df = session.create_dataframe([[1, [1, 2, 3], {"Ashi Garami": "Single Leg X"}, "Kimura"],
        ...                                [2, [11, 22], {"Sankaku": "Triangle"}, "Coffee"]],
        ...                                schema=["idx", "lists", "maps", "strs"])
        >>> df.select(df.idx, explode(df.lists)).show()
        -------------------
        |"IDX"  |"VALUE"  |
        -------------------
        |1      |1        |
        |1      |2        |
        |1      |3        |
        |2      |11       |
        |2      |22       |
        -------------------
        <BLANKLINE>
        >>> df.select(df.strs, explode(df.maps)).show()
        -----------------------------------------
        |"STRS"  |"KEY"        |"VALUE"         |
        -----------------------------------------
        |Kimura  |Ashi Garami  |"Single Leg X"  |
        |Coffee  |Sankaku      |"Triangle"      |
        -----------------------------------------
        <BLANKLINE>
        >>> df.select(explode(col("lists")).alias("uno")).show()
        ---------
        |"UNO"  |
        ---------
        |1      |
        |2      |
        |3      |
        |11     |
        |22     |
        ---------
        <BLANKLINE>
        >>> df.select(explode('maps').as_("primo", "secundo")).show()
        --------------------------------
        |"PRIMO"      |"SECUNDO"       |
        --------------------------------
        |Ashi Garami  |"Single Leg X"  |
        |Sankaku      |"Triangle"      |
        --------------------------------
        <BLANKLINE>
    """
    return _ExplodeFunctionCall(col)


def grouping(*cols: ColumnOrName) -> Column:
"""
Describes which of a list of expressions are grouped in a row produced by a GROUP BY query.
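Because `explode()` returns a `TableFunctionCall` rather than a `Column`, `DataFrame.select()` treats it as a lateral table-function join and allows at most one table function per select, as enforced in dataframe.py above. An illustrative sketch, again assuming the `df` from the docstring:

    from snowflake.snowpark.functions import explode

    # One explode() alongside ordinary columns is fine:
    rows = df.select(df.strs, explode(df.maps)).collect()

    # A second table function in the same select() is rejected:
    # df.select(explode(df.lists), explode(df.maps))
    # -> ValueError: At most one table function can be called inside a select(). ...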
50 changes: 50 additions & 0 deletions src/snowflake/snowpark/table_function.py
@@ -16,6 +16,7 @@
from snowflake.snowpark._internal.type_utils import ColumnOrName
from snowflake.snowpark._internal.utils import validate_object_name
from snowflake.snowpark.column import Column, _to_col_if_str
from snowflake.snowpark.types import ArrayType, MapType

from ._internal.analyzer.snowflake_plan import SnowflakePlan

@@ -39,6 +40,7 @@ def __init__(
        if func_arguments and func_named_arguments:
            raise ValueError("A table function shouldn't have both args and named args")
        self.name: str = func_name  #: The table function name
        self.user_visible_name: str = func_name
        self.arguments: Iterable[
            ColumnOrName
        ] = func_arguments  #: The positional arguments used to call this table function.
@@ -131,6 +133,18 @@ def alias(self, *aliases: str) -> "TableFunctionCall":
    as_ = alias


class _ExplodeFunctionCall(TableFunctionCall):
    """Internal class that marks an explode function call as a special instance of TableFunctionCall."""

    def __init__(self, col: ColumnOrName) -> None:
        super().__init__("flatten", col)
        # Remember the quoted column name so the analyzer can later look up its type
        # in SnowflakePlan.output_dict.
        if isinstance(col, Column):
            self.col = col._expression.name
        else:
            self.col = quote_name(col)
        self.user_visible_name: str = "explode"


def _create_order_by_expression(e: Union[str, Column]) -> SortOrder:
    if isinstance(e, str):
        return SortOrder(Column(e)._expression, Ascending())
@@ -233,3 +247,39 @@ def get_column_names_from_plan(plan: SnowflakePlan) -> List[str]:
    new_cols = [Column(col)._named() for col in new_cols]

    return old_cols, new_cols


def _get_cols_after_explode_join(
    func: _ExplodeFunctionCall, plan: SnowflakePlan
) -> List:
    """Return the output columns contributed by an explode() call: VALUE for an array
    column, or KEY and VALUE for a map column, honoring any user-provided aliases."""
    explode_col_type = plan.output_dict.get(func.col, [None])[0]

    cols = []
    if isinstance(explode_col_type, ArrayType):
        if func._aliases:
            if len(func._aliases) != 1:
                raise ValueError(
                    f"Invalid number of aliases given for explode. Expecting 1, got {len(func._aliases)}"
                )
            cols.append(Column("VALUE").alias(func._aliases[0])._named())
        else:
            cols.append(Column("VALUE")._named())
    elif isinstance(explode_col_type, MapType):
        if func._aliases:
            if len(func._aliases) != 2:
                raise ValueError(
                    f"Invalid number of aliases given for explode. Expecting 2, got {len(func._aliases)}"
                )
            cols.extend(
                [
                    Column("KEY").as_(func._aliases[0])._named(),
                    Column("VALUE").as_(func._aliases[1])._named(),
                ]
            )
        else:
            cols.extend([Column("KEY")._named(), Column("VALUE")._named()])
    else:
        raise ValueError(
            f"Invalid column type for explode({func.col}). Expected ArrayType() or MapType(), got {explode_col_type}"
        )
    return cols
89 changes: 89 additions & 0 deletions tests/integ/test_dataframe.py
@@ -31,6 +31,7 @@
    col,
    concat,
    count,
    explode,
    lit,
    seq1,
    seq2,
@@ -571,6 +572,94 @@ def process(self, n: int) -> Iterable[Tuple[int, int]]:
    )


def test_explode(session):
    df = session.create_dataframe(
        [[1, [1, 2, 3], {"a": "b"}, "Kimura"]], schema=["idx", "lists", "maps", "strs"]
    )

    # col is str
    expected_result = [
        Row(value="1"),
        Row(value="2"),
        Row(value="3"),
    ]
    Utils.check_answer(df.select(explode("lists")), expected_result)

    expected_result = [Row(key="a", value='"b"')]
    Utils.check_answer(df.select(explode("maps")), expected_result)

    # col is Column
    expected_result = [
        Row(value="1"),
        Row(value="2"),
        Row(value="3"),
    ]
    Utils.check_answer(df.select(explode(col("lists"))), expected_result)

    expected_result = [Row(key="a", value='"b"')]
    Utils.check_answer(df.select(explode(df.maps)), expected_result)

    # with other non table cols
    expected_result = [
        Row(idx=1, value="1"),
        Row(idx=1, value="2"),
        Row(idx=1, value="3"),
    ]
    Utils.check_answer(df.select(df.idx, explode(col("lists"))), expected_result)

    expected_result = [Row(strs="Kimura", key="a", value='"b"')]
    Utils.check_answer(df.select(df.strs, explode(df.maps)), expected_result)

    # with alias
    expected_result = [
        Row(idx=1, uno="1"),
        Row(idx=1, uno="2"),
        Row(idx=1, uno="3"),
    ]
    Utils.check_answer(
        df.select(df.idx, explode(col("lists")).alias("uno")), expected_result
    )

    expected_result = [Row(strs="Kimura", primo="a", secundo='"b"')]
    Utils.check_answer(
        df.select(df.strs, explode(df.maps).as_("primo", "secundo")), expected_result
    )


def test_explode_negative(session):
    df = session.create_dataframe(
        [[1, [1, 2, 3], {"a": "b"}, "Kimura"]], schema=["idx", "lists", "maps", "strs"]
    )
    split_to_table = table_function("split_to_table")

    # mix explode and table function
    with pytest.raises(
        ValueError, match="At most one table function can be called inside"
    ):
        df.select(split_to_table(df.strs, lit("")), explode(df.lists))

    # mismatch in number of alias given array
    with pytest.raises(
        ValueError,
        match="Invalid number of aliases given for explode. Expecting 1, got 2",
    ):
        df.select(explode(df.lists).alias("key", "val"))

    # mismatch in number of alias given map
    with pytest.raises(
        ValueError,
        match="Invalid number of aliases given for explode. Expecting 2, got 1",
    ):
        df.select(explode(df.maps).alias("val"))

    # invalid column type
    with pytest.raises(ValueError, match="Invalid column type for explode"):
        df.select(explode(df.idx))

    with pytest.raises(ValueError, match="Invalid column type for explode"):
        df.select(explode(col("DOES_NOT_EXIST")))


@pytest.mark.udf
def test_with_column(session):
    df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
