[SPARK-44748][SQL] Query execution for the PARTITION BY clause in UDTF TABLE arguments

### What changes were proposed in this pull request?

This PR implements query execution support for the PARTITION BY and ORDER BY clauses of UDTF TABLE arguments.

* The query planning support was added in [1], [2], and [3]. After those changes, the planner adds a projection to compute the PARTITION BY expressions, plus a repartition operator and a sort operator.
* In this PR, the Python executor receives the indexes of these projected partitioning expressions within the input table's rows and compares their values between consecutive rows (a minimal sketch of this check appears after the reference links below).
* When the values change, this marks a partition boundary, so we call the UDTF instance's `terminate` method, then destroy the instance and create a new one for the next partition.

[1] #42100
[2] #42174
[3] #42351
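
A minimal sketch of that boundary-detection loop (illustration only: the names `run_partitioned`, `make_udtf`, `partition_indexes`, and `rows` are invented here, and the real implementation is the `UDTFWithPartitions` wrapper added to `python/pyspark/worker.py` in this diff):

```
# Sketch, not the actual worker code: call `terminate` and recreate the UDTF
# instance whenever the projected PARTITION BY values change between
# consecutive rows of the already-repartitioned, already-sorted input.
def run_partitioned(make_udtf, partition_indexes, rows):
    udtf = make_udtf()
    prev_key = None
    for row in rows:
        key = [row[i] for i in partition_indexes]
        if prev_key is not None and key != prev_key:
            yield from udtf.terminate() or ()   # close out the finished partition
            udtf = make_udtf()                  # fresh instance for the next one
        yield from udtf.eval(row) or ()
        prev_key = key
    yield from udtf.terminate() or ()           # flush the final partition
```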

Example:

```
# Make a test UDTF that yields one output row containing the value
# consumed from the last input row in the input table or partition.
class TestUDTF:
    def eval(self, row: Row):
        self._last = row['input']
        self._partition_col = row['partition_col']

    def terminate(self):
        yield self._partition_col, self._last

func = udtf(TestUDTF, returnType='partition_col: int, last: int')
self.spark.udtf.register('test_udtf', func)
self.spark.sql('''
    WITH t AS (
        SELECT id AS partition_col, 1 AS input FROM range(0, 2) UNION ALL
        SELECT id AS partition_col, 2 AS input FROM range(0, 2)
    )
    SELECT *
    FROM test_udtf(TABLE(t) PARTITION BY partition_col ORDER BY input)
    ''').collect()

> [Row(partition_col=0, last=2), Row(partition_col=1, last=2)]
```
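
For intuition, the result above can be reproduced in plain Python over the same four input rows; this is only an illustration of the per-partition semantics, not Spark code:

```
from itertools import groupby

# (partition_col, input) pairs, repartitioned by partition_col and
# sorted by input within each partition.
rows = sorted([(0, 1), (1, 1), (0, 2), (1, 2)])

result = []
for key, group in groupby(rows, key=lambda r: r[0]):
    *_, last_row = group               # the last row consumed in this partition
    result.append((key, last_row[1]))  # what terminate() yields per partition
print(result)  # [(0, 2), (1, 2)]
```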

### Why are the changes needed?

This brings full end-to-end execution support for the PARTITION BY and/or ORDER BY clauses of UDTF TABLE arguments.

### Does this PR introduce _any_ user-facing change?

Yes, see above.

### How was this patch tested?

This PR adds end-to-end testing in `test_udtf.py`.

Closes #42420 from dtenedor/inspect-partition-by.

Authored-by: Daniel Tenedorio <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
dtenedor authored and ueshin committed Aug 21, 2023
1 parent 109e9cd commit 7f3a439
Showing 7 changed files with 313 additions and 37 deletions.
167 changes: 167 additions & 0 deletions python/pyspark/sql/tests/test_udtf.py
@@ -1964,6 +1964,173 @@ def eval(self, a, b=100):
            with self.subTest(query_no=i):
                assertDataFrameEqual(df, [Row(a=10, b="z")])

    def test_udtf_with_table_argument_and_partition_by(self):
        class TestUDTF:
            def __init__(self):
                self._sum = 0
                self._partition_col = None

            def eval(self, row: Row):
                self._sum += row["input"]
                if self._partition_col is not None and self._partition_col != row["partition_col"]:
                    # Make sure that all values of the partitioning column are the same
                    # for each row consumed by this method for this instance of the class.
                    raise Exception(
                        f"self._partition_col was {self._partition_col} but the row "
                        + f"value was {row['partition_col']}"
                    )
                self._partition_col = row["partition_col"]

            def terminate(self):
                yield self._partition_col, self._sum

        # This is a basic example.
        func = udtf(TestUDTF, returnType="partition_col: int, total: int")
        self.spark.udtf.register("test_udtf", func)
        self.assertEqual(
            self.spark.sql(
                """
                WITH t AS (
                  SELECT id AS partition_col, 1 AS input FROM range(1, 21)
                  UNION ALL
                  SELECT id AS partition_col, 2 AS input FROM range(1, 21)
                )
                SELECT partition_col, total
                FROM test_udtf(TABLE(t) PARTITION BY partition_col - 1)
                ORDER BY 1, 2
                """
            ).collect(),
            [Row(partition_col=x, total=3) for x in range(1, 21)],
        )

        # These cases partition by constant values.
        for str_first, str_second, result_first, result_second in (
            ("123", "456", 123, 456),
            ("123", "NULL", None, 123),
        ):
            self.assertEqual(
                self.spark.sql(
                    f"""
                    WITH t AS (
                      SELECT {str_first} AS partition_col, id AS input FROM range(0, 2)
                      UNION ALL
                      SELECT {str_second} AS partition_col, id AS input FROM range(0, 2)
                    )
                    SELECT partition_col, total
                    FROM test_udtf(TABLE(t) PARTITION BY partition_col)
                    ORDER BY 1, 2
                    """
                ).collect(),
                [
                    Row(partition_col=result_first, total=1),
                    Row(partition_col=result_second, total=1),
                ],
            )

        # Combine a lateral join with a TABLE argument with PARTITION BY.
        func = udtf(TestUDTF, returnType="partition_col: int, total: int")
        self.spark.udtf.register("test_udtf", func)
        self.assertEqual(
            self.spark.sql(
                """
                WITH t AS (
                  SELECT id AS partition_col, 1 AS input FROM range(1, 3)
                  UNION ALL
                  SELECT id AS partition_col, 2 AS input FROM range(1, 3)
                )
                SELECT v.a, v.b, f.partition_col, f.total
                FROM VALUES (0, 1) AS v(a, b),
                  LATERAL test_udtf(TABLE(t) PARTITION BY partition_col - 1) f
                ORDER BY 1, 2, 3, 4
                """
            ).collect(),
            [Row(a=0, b=1, partition_col=1, total=3), Row(a=0, b=1, partition_col=2, total=3)],
        )

    def test_udtf_with_table_argument_and_partition_by_and_order_by(self):
        class TestUDTF:
            def __init__(self):
                self._last = None
                self._partition_col = None

            def eval(self, row: Row, partition_col: str):
                # Make sure that all values of the partitioning column are the same
                # for each row consumed by this method for this instance of the class.
                if self._partition_col is not None and self._partition_col != row[partition_col]:
                    raise Exception(
                        f"self._partition_col was {self._partition_col} but the row "
                        + f"value was {row[partition_col]}"
                    )
                self._last = row["input"]
                self._partition_col = row[partition_col]

            def terminate(self):
                yield self._partition_col, self._last

        func = udtf(TestUDTF, returnType="partition_col: int, last: int")
        self.spark.udtf.register("test_udtf", func)
        for order_by_str, result_val in (
            ("input ASC", 2),
            ("input + 1 ASC", 2),
            ("input DESC", 1),
            ("input - 1 DESC", 1),
        ):
            self.assertEqual(
                self.spark.sql(
                    f"""
                    WITH t AS (
                      SELECT id AS partition_col, 1 AS input FROM range(1, 21)
                      UNION ALL
                      SELECT id AS partition_col, 2 AS input FROM range(1, 21)
                    )
                    SELECT partition_col, last
                    FROM test_udtf(
                      row => TABLE(t) PARTITION BY partition_col - 1 ORDER BY {order_by_str},
                      partition_col => 'partition_col')
                    ORDER BY 1, 2
                    """
                ).collect(),
                [Row(partition_col=x, last=result_val) for x in range(1, 21)],
            )

    def test_udtf_with_table_argument_with_single_partition(self):
        class TestUDTF:
            def __init__(self):
                self._count = 0
                self._sum = 0
                self._last = None

            def eval(self, row: Row):
                # Make sure that the rows arrive in the expected order.
                if self._last is not None and self._last > row["input"]:
                    raise Exception(
                        f"self._last was {self._last} but the row value was {row['input']}"
                    )
                self._count += 1
                self._last = row["input"]
                self._sum += row["input"]

            def terminate(self):
                yield self._count, self._sum, self._last

        func = udtf(TestUDTF, returnType="count: int, total: int, last: int")
        self.spark.udtf.register("test_udtf", func)
        self.assertEqual(
            self.spark.sql(
                """
                WITH t AS (
                  SELECT id AS partition_col, 1 AS input FROM range(1, 21)
                  UNION ALL
                  SELECT id AS partition_col, 2 AS input FROM range(1, 21)
                )
                SELECT count, total, last
                FROM test_udtf(TABLE(t) WITH SINGLE PARTITION ORDER BY (input, partition_col))
                ORDER BY 1, 2
                """
            ).collect(),
            [Row(count=40, total=60, last=2)],
        )


class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
    @classmethod
89 changes: 85 additions & 4 deletions python/pyspark/worker.py
@@ -23,7 +23,7 @@
import time
from inspect import getfullargspec
import json
from typing import Any, Iterable, Iterator
from typing import Any, Callable, Iterable, Iterator

import traceback
import faulthandler
@@ -53,7 +53,7 @@
    ApplyInPandasWithStateSerializer,
)
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import BinaryType, StringType, StructType, _parse_datatype_json_string
from pyspark.sql.types import BinaryType, Row, StringType, StructType, _parse_datatype_json_string
from pyspark.util import fail_on_stopiteration, try_simplify_traceback
from pyspark import shuffle
from pyspark.errors import PySparkRuntimeError, PySparkTypeError
@@ -609,7 +609,8 @@ def read_udtf(pickleSer, infile, eval_type):
            kwargs_offsets[name] = offset
        else:
            args_offsets.append(offset)

    num_partition_child_indexes = read_int(infile)
    partition_child_indexes = [read_int(infile) for i in range(num_partition_child_indexes)]
    handler = read_command(pickleSer, infile)
    if not isinstance(handler, type):
        raise PySparkRuntimeError(
@@ -623,9 +624,89 @@
f"The return type of a UDTF must be a struct type, but got {type(return_type)}."
)

    class UDTFWithPartitions:
        """
        This implements the logic of a UDTF that accepts an input TABLE argument with one or more
        PARTITION BY expressions.
        For example, let's assume we have a table like:
            CREATE TABLE t (c1 INT, c2 INT) USING delta;
        Then for the following queries:
            SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2);
        The partition_child_indexes will be: 0, 1.
            SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4);
        The partition_child_indexes will be: 0, 2 (where we add a projection for "c2 + 4").
        """

        def __init__(self, create_udtf: Callable, partition_child_indexes: list):
            """
            Creates a new instance of this class to wrap the provided UDTF with another one that
            checks the values of projected partitioning expressions on consecutive rows to figure
            out when the partition boundaries change.
            Parameters
            ----------
            create_udtf: function
                Function to create a new instance of the UDTF to be invoked.
            partition_child_indexes: list
                List of integers identifying zero-based indexes of the columns of the input table
                that contain projected partitioning expressions. This class will inspect these
                values for each pair of consecutive input rows. When they change, this indicates
                the boundary between two partitions, and we will invoke the 'terminate' method on
                the UDTF class instance and then destroy it and create a new one to implement the
                desired partitioning semantics.
            """
            self._create_udtf: Callable = create_udtf
            self._udtf = create_udtf()
            self._prev_arguments: list = list()
            self._partition_child_indexes: list = partition_child_indexes

        def eval(self, *args, **kwargs) -> Iterator:
            changed_partitions = self._check_partition_boundaries(
                list(args) + list(kwargs.values())
            )
            if changed_partitions:
                if self._udtf.terminate is not None:
                    result = self._udtf.terminate()
                    if result is not None:
                        for row in result:
                            yield row
                self._udtf = self._create_udtf()
            if self._udtf.eval is not None:
                result = self._udtf.eval(*args, **kwargs)
                if result is not None:
                    for row in result:
                        yield row

        def terminate(self) -> Iterator:
            if self._udtf.terminate is not None:
                return self._udtf.terminate()
            return iter(())

        def _check_partition_boundaries(self, arguments: list) -> bool:
            result = False
            if len(self._prev_arguments) > 0:
                cur_table_arg = self._get_table_arg(arguments)
                prev_table_arg = self._get_table_arg(self._prev_arguments)
                cur_partitions_args = []
                prev_partitions_args = []
                for i in partition_child_indexes:
                    cur_partitions_args.append(cur_table_arg[i])
                    prev_partitions_args.append(prev_table_arg[i])
                self._prev_arguments = arguments
                result = any(k != v for k, v in zip(cur_partitions_args, prev_partitions_args))
            self._prev_arguments = arguments
            return result

        def _get_table_arg(self, inputs: list) -> Row:
            return [x for x in inputs if type(x) is Row][0]

    # Instantiate the UDTF class.
    try:
        udtf = handler()
        if len(partition_child_indexes) > 0:
            udtf = UDTFWithPartitions(handler, partition_child_indexes)
        else:
            udtf = handler()
    except Exception as e:
        raise PySparkRuntimeError(
            error_class="UDTF_EXEC_ERROR",
@@ -2098,6 +2098,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
          }

          val tableArgs = mutable.ArrayBuffer.empty[LogicalPlan]
          val functionTableSubqueryArgs =
            mutable.ArrayBuffer.empty[FunctionTableSubqueryArgumentExpression]
          val tvf = resolvedFunc.transformAllExpressionsWithPruning(
            _.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION), ruleId) {
            case t: FunctionTableSubqueryArgumentExpression =>
@@ -2110,6 +2112,7 @@
                  "PARTITION BY clause, but only Python table functions support this clause")
              }
              tableArgs.append(SubqueryAlias(alias, t.evaluable))
              functionTableSubqueryArgs.append(t)
              UnresolvedAttribute(Seq(alias, "c"))
          }
          if (tableArgs.nonEmpty) {
@@ -2118,11 +2121,34 @@
                tableArgs.size)
            }
            val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
            // Propagate the column indexes for TABLE arguments to the PythonUDTF instance.
            def assignUDTFPartitionColumnIndexes(
                fn: PythonUDTFPartitionColumnIndexes => LogicalPlan): Option[LogicalPlan] = {
              val indexes: Seq[Int] = functionTableSubqueryArgs.headOption
                .map(_.partitioningExpressionIndexes).getOrElse(Seq.empty)
              if (indexes.nonEmpty) {
                Some(fn(PythonUDTFPartitionColumnIndexes(indexes)))
              } else {
                None
              }
            }
            val tvfWithTableColumnIndexes: LogicalPlan = tvf match {
              case g@Generate(p: PythonUDTF, _, _, _, _, _) =>
                assignUDTFPartitionColumnIndexes(
                  i => g.copy(generator = p.copy(pythonUDTFPartitionColumnIndexes = Some(i))))
                  .getOrElse(g)
              case g@Generate(p: UnresolvedPolymorphicPythonUDTF, _, _, _, _, _) =>
                assignUDTFPartitionColumnIndexes(
                  i => g.copy(generator = p.copy(pythonUDTFPartitionColumnIndexes = Some(i))))
                  .getOrElse(g)
              case _ =>
                tvf
            }
            Project(
              Seq(UnresolvedStar(Some(Seq(alias)))),
              LateralJoin(
                tableArgs.reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)),
                LateralSubquery(SubqueryAlias(alias, tvf)), Inner, None)
                LateralSubquery(SubqueryAlias(alias, tvfWithTableColumnIndexes)), Inner, None)
            )
          } else {
            tvf
@@ -2200,7 +2226,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
          case u: UnresolvedPolymorphicPythonUDTF => withPosition(u) {
            val elementSchema = u.resolveElementSchema(u.func, u.children)
            PythonUDTF(u.name, u.func, elementSchema, u.children,
              u.evalType, u.udfDeterministic, u.resultId)
              u.evalType, u.udfDeterministic, u.resultId, u.pythonUDTFPartitionColumnIndexes)
          }
        }
      }
@@ -163,21 +163,3 @@ case class FunctionTableSubqueryArgumentExpression(

  private lazy val subqueryOutputs: Map[Expression, Int] = plan.output.zipWithIndex.toMap
}

object FunctionTableSubqueryArgumentExpression {
  /**
   * Returns a sequence of zero-based integer indexes identifying the values of a Python UDTF's
   * 'eval' method's *args list that correspond to partitioning columns of the input TABLE argument.
   */
  def partitionChildIndexes(udtfArguments: Seq[Expression]): Seq[Int] = {
    udtfArguments.zipWithIndex.flatMap { case (expr, index) =>
      expr match {
        case f: FunctionTableSubqueryArgumentExpression =>
          f.partitioningExpressionIndexes.map(_ + index)
        case _ =>
          Seq()
      }
    }
  }
}
