[SPARK-50194][SS][PYTHON] Integration of New Timer API and Initial State API with Timer #48838

Closed · 25 commits
107 changes: 64 additions & 43 deletions python/pyspark/sql/pandas/group_ops.py
@@ -35,6 +35,7 @@
TimerValues,
)
from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
+from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
from pyspark.sql.types import StructType, _parse_datatype_string

if TYPE_CHECKING:
@@ -503,58 +504,59 @@ def transformWithStateInPandas(
if isinstance(outputStructType, str):
    outputStructType = cast(StructType, _parse_datatype_string(outputStructType))

-def handle_data_with_timers(
+def handle_data_rows(
    statefulProcessorApiClient: StatefulProcessorApiClient,
    key: Any,
-    inputRows: Iterator["PandasDataFrameLike"],
+    inputRows: Optional[Iterator["PandasDataFrameLike"]] = None,
) -> Iterator["PandasDataFrameLike"]:
    statefulProcessorApiClient.set_implicit_key(key)
-    if timeMode != "none":
-        batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
-        watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp()
-    else:
-        batch_timestamp = -1
-        watermark_timestamp = -1
-    # process with invalid expiry timer info and emit data rows
-    data_iter = statefulProcessor.handleInputRows(
-        key,
-        inputRows,
-        TimerValues(batch_timestamp, watermark_timestamp),
-        ExpiredTimerInfo(False),
-    )
+    batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps(
+        timeMode
+    )
+    # process with data rows
+    if inputRows is not None:
+        data_iter = statefulProcessor.handleInputRows(
+            key, inputRows, TimerValues(batch_timestamp, watermark_timestamp)
+        )
+        return data_iter
+    else:
+        return iter([])

+def handle_expired_timers(
+    statefulProcessorApiClient: StatefulProcessorApiClient,
+) -> Iterator["PandasDataFrameLike"]:
+    batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps(
+        timeMode
+    )
-    statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.DATA_PROCESSED)

-    if timeMode == "processingtime":
+    if timeMode.lower() == "processingtime":
        expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator(
            batch_timestamp
        )
-    elif timeMode == "eventtime":
+    elif timeMode.lower() == "eventtime":
        expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator(
            watermark_timestamp
        )
    else:
        expiry_list_iter = iter([[]])

-    result_iter_list = [data_iter]
-    # process with valid expiry time info and with empty input rows,
-    # only timer related rows will be emitted
+    # process with expiry timers, only timer related rows will be emitted
Contributor: I get confused about this every time. Is this relying on the behavior that an expired timer is removed, so we won't list the same timer as expired multiple times? That is very easy to forget.

Is there any way to move this out and do it after we process all the input? Could this be done in transformWithStateUDF/transformWithStateWithInitStateUDF with key = null?

Contributor (Author): Thanks so much for catching this! I had made a serious correctness bug in my prior timer implementation. I have now moved all timer-handling code into serializers.py, where expired timers are processed per partition.

Contributor (Author): Left an explanation of what caused the correctness issue in my prior implementation here, in case you are curious: #48838 (comment)
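To make the hazard discussed in this thread concrete, here is a toy model (plain Python with invented names, not Spark code) of why listing expired timers inside every key's data pass can fire the same timer more than once per batch, while a single post-data pass that deletes each handled timer fires it exactly once:

# Toy timer store: key -> expiry timestamp in ms.
timers = {"a": 100, "b": 100}

def list_expired(now):
    return [(k, ts) for k, ts in timers.items() if ts <= now]

# Buggy shape: timers are listed while processing each input key,
# so every timer is seen once per input key in the batch.
fired = [k for _input_key in ("a", "b") for k, _ts in list_expired(now=200)]
print(fired)  # ['a', 'b', 'a', 'b'] -- each timer fired twice

# Fixed shape: one timer pass after all input keys, deleting as we go,
# so a timer can never be listed as expired again.
fired = []
for k, _ts in list_expired(now=200):
    fired.append(k)
    del timers[k]
print(fired)  # ['a', 'b']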

    for expiry_list in expiry_list_iter:
        for key_obj, expiry_timestamp in expiry_list:
-            result_iter_list.append(
-                statefulProcessor.handleInputRows(
-                    key_obj,
-                    iter([]),
-                    TimerValues(batch_timestamp, watermark_timestamp),
-                    ExpiredTimerInfo(True, expiry_timestamp),
-                )
-            )
-    # TODO(SPARK-49603) set the handle state in the lazily initialized iterator
-    result = itertools.chain(*result_iter_list)
-    return result
+            statefulProcessorApiClient.set_implicit_key(key_obj)
+            for pd in statefulProcessor.handleExpiredTimer(
+                key=key_obj,
+                timer_values=TimerValues(batch_timestamp, watermark_timestamp),
+                expired_timer_info=ExpiredTimerInfo(expiry_timestamp),
+            ):
+                yield pd
+            statefulProcessorApiClient.delete_timer(expiry_timestamp)

def transformWithStateUDF(
    statefulProcessorApiClient: StatefulProcessorApiClient,
+    mode: TransformWithStateInPandasFuncMode,
    key: Any,
    inputRows: Iterator["PandasDataFrameLike"],
) -> Iterator["PandasDataFrameLike"]:

@@ -566,19 +568,28 @@ def transformWithStateUDF(
        StatefulProcessorHandleState.INITIALIZED
    )

-    # Key is None when we have processed all the input data from the worker and ready to
-    # proceed with the cleanup steps.
-    if key is None:
+    if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
+        statefulProcessorApiClient.set_handle_state(
+            StatefulProcessorHandleState.DATA_PROCESSED
+        )
+        result = handle_expired_timers(statefulProcessorApiClient)
+        return result
+    elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
+        statefulProcessorApiClient.set_handle_state(
+            StatefulProcessorHandleState.TIMER_PROCESSED
+        )
        statefulProcessorApiClient.remove_implicit_key()
        statefulProcessor.close()
        statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED)
        return iter([])
-
-    result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
-    return result
+    else:
+        # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
+        result = handle_data_rows(statefulProcessorApiClient, key, inputRows)
+        return result

def transformWithStateWithInitStateUDF(
    statefulProcessorApiClient: StatefulProcessorApiClient,
+    mode: TransformWithStateInPandasFuncMode,
    key: Any,
    inputRows: Iterator["PandasDataFrameLike"],
    initialStates: Optional[Iterator["PandasDataFrameLike"]] = None,

@@ -603,20 +614,30 @@ def transformWithStateWithInitStateUDF(
        StatefulProcessorHandleState.INITIALIZED
    )

-    # Key is None when we have processed all the input data from the worker and ready to
-    # proceed with the cleanup steps.
-    if key is None:
+    if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
+        statefulProcessorApiClient.set_handle_state(
+            StatefulProcessorHandleState.DATA_PROCESSED
+        )
+        result = handle_expired_timers(statefulProcessorApiClient)
+        return result
+    elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
        statefulProcessorApiClient.remove_implicit_key()
        statefulProcessor.close()
        statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED)
        return iter([])
+    else:
+        # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
+        batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps(
+            timeMode
+        )

        # only process initial state if first batch and initial state is not None
        if initialStates is not None:
            for cur_initial_state in initialStates:
                statefulProcessorApiClient.set_implicit_key(key)
-                # TODO(SPARK-50194) integration with new timer API with initial state
-                statefulProcessor.handleInitialState(key, cur_initial_state)
+                statefulProcessor.handleInitialState(
+                    key, cur_initial_state, TimerValues(batch_timestamp, watermark_timestamp)
+                )

        # if we don't have input rows for the given key but only have initial state
        # for the grouping key, the inputRows iterator could be empty

@@ -629,7 +650,7 @@ def transformWithStateWithInitStateUDF(
            inputRows = itertools.chain([first], inputRows)

        if not input_rows_empty:
-            result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+            result = handle_data_rows(statefulProcessorApiClient, key, inputRows)
        else:
            result = iter([])
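For orientation, here is a sketch of how the reworked API surfaces to users. The parameter names follow this PR's transformWithStateInPandas signature; the source, the schemas, and SessionTimeoutProcessor (a StatefulProcessor subclass sketched further down this page, next to handleExpiredTimer) are illustrative assumptions:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Hypothetical streaming input plus a batch of initial-state rows.
events = spark.readStream.format("rate").load().withColumnRenamed("value", "id")
init = spark.createDataFrame([(1, 5)], ["id", "count"])

out = events.groupBy("id").transformWithStateInPandas(
    statefulProcessor=SessionTimeoutProcessor(),
    outputStructType="id long, expired_at long",  # output schema as a DDL string
    outputMode="Update",
    timeMode="processingTime",                    # enables processing-time timers
    initialState=init.groupBy("id"),              # optional; grouped by the same key
)
query = out.writeStream.format("console").start()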
13 changes: 11 additions & 2 deletions python/pyspark/sql/pandas/serializers.py
@@ -36,6 +36,7 @@
_create_converter_from_pandas,
_create_converter_to_pandas,
)
+from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
from pyspark.sql.types import (
DataType,
StringType,
@@ -1197,7 +1198,11 @@ def generate_data_batches(batches):
data_batches = generate_data_batches(_batches)

    for k, g in groupby(data_batches, key=lambda x: x[0]):
-        yield (k, g)
+        yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)
Contributor: nit: this looks inconsistent? Here we use a tuple with explicit parentheses, and in the class below we don't use them. Not a huge deal if the linter does not complain, but while we are here (the linter is failing)...


+    yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
+    yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)

def dump_stream(self, iterator, stream):
"""
@@ -1281,4 +1286,8 @@ def flatten_columns(cur_batch, col_name):
data_batches = generate_data_batches(_batches)

    for k, g in groupby(data_batches, key=lambda x: x[0]):
-        yield (k, g)
+        yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)
+
+    yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
+    yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)
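Both load_stream implementations now emit the same three-phase protocol per partition. A minimal, standalone sketch of that shape (the enum and function are stand-ins, not the actual Spark classes):

from enum import Enum
from itertools import groupby

class FuncMode(Enum):  # stand-in for TransformWithStateInPandasFuncMode
    PROCESS_DATA = 1
    PROCESS_TIMER = 2
    COMPLETE = 3

def load_stream(records):
    # One PROCESS_DATA event per grouping key, in input order...
    for k, g in groupby(records, key=lambda x: x[0]):
        yield (FuncMode.PROCESS_DATA, k, list(g))
    # ...then exactly one timer pass and one cleanup pass per partition.
    yield (FuncMode.PROCESS_TIMER, None, None)
    yield (FuncMode.COMPLETE, None, None)

for mode, key, rows in load_stream([("a", 1), ("a", 2), ("b", 3)]):
    print(mode.name, key, rows)
# PROCESS_DATA a [('a', 1), ('a', 2)]
# PROCESS_DATA b [('b', 3)]
# PROCESS_TIMER None None
# COMPLETE None None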
49 changes: 35 additions & 14 deletions python/pyspark/sql/streaming/stateful_processor.py
@@ -105,21 +105,13 @@ def get_current_watermark_in_ms(self) -> int:

class ExpiredTimerInfo:
    """
-    Class used for arbitrary stateful operations with transformWithState to access expired timer
-    info. When is_valid is false, the expiry timestamp is invalid.
+    Class used to provide access to an expired timer's expiry time.

    .. versionadded:: 4.0.0
    """

-    def __init__(self, is_valid: bool, expiry_time_in_ms: int = -1) -> None:
-        self._is_valid = is_valid
+    def __init__(self, expiry_time_in_ms: int = -1) -> None:
        self._expiry_time_in_ms = expiry_time_in_ms

-    def is_valid(self) -> bool:
-        """
-        Whether the expiry info is valid.
-        """
-        return self._is_valid
-
    def get_expiry_time_in_ms(self) -> int:
        """
        Get the timestamp of the expired timer, returned in milliseconds.
@@ -398,7 +390,6 @@ def handleInputRows(
    key: Any,
    rows: Iterator["PandasDataFrameLike"],
    timer_values: TimerValues,
-    expired_timer_info: ExpiredTimerInfo,
) -> Iterator["PandasDataFrameLike"]:
"""
Function that will allow users to interact with input data rows along with the grouping key.
@@ -420,11 +411,29 @@ def handleInputRows(
timer_values: TimerValues
    Timer value for the current batch that processes the input rows.
    Users can get the processing or event time timestamp from TimerValues.
-expired_timer_info: ExpiredTimerInfo
-    Timestamp of expired timers on the grouping key.
"""
...

def handleExpiredTimer(
Contributor: Just double-checking: this method is not required for users to implement, correct?

Contributor (Author): Correct. Added a line in the docstring explicitly saying this is optional to implement.

    self, key: Any, timer_values: TimerValues, expired_timer_info: ExpiredTimerInfo
) -> Iterator["PandasDataFrameLike"]:
    """
    Optional to implement. Will return an empty iterator if not defined.
    Function that will be invoked when a timer fires for a given key. Users can choose to
    evict state, register new timers, and optionally provide output rows.

    Parameters
    ----------
    key : Any
        grouping key.
    timer_values: TimerValues
        Timer value for the current batch that processes the input rows.
        Users can get the processing or event time timestamp from TimerValues.
    expired_timer_info: ExpiredTimerInfo
        Instance of ExpiredTimerInfo that provides access to the expired timer.
    """
    return iter([])
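A sketch of what an implementation could look like: a hypothetical session-timeout processor (state-variable setup elided, column names invented) that registers a timer on every input batch and emits a row only when the timer fires:

import pandas as pd

from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle

class SessionTimeoutProcessor(StatefulProcessor):  # hypothetical example
    def init(self, handle: StatefulProcessorHandle) -> None:
        self.handle = handle

    def handleInputRows(self, key, rows, timer_values):
        # Push the timeout 10s past the latest activity for this key.
        expiry = timer_values.get_current_processing_time_in_ms() + 10_000
        self.handle.registerTimer(expiry)
        return iter([])  # emit nothing until the session expires

    def handleExpiredTimer(self, key, timer_values, expired_timer_info):
        # The session timed out: emit one row describing the expiry.
        yield pd.DataFrame(
            {"id": [key[0]], "expired_at": [expired_timer_info.get_expiry_time_in_ms()]}
        )

    def close(self) -> None:
        pass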

@abstractmethod
def close(self) -> None:
"""
Expand All @@ -433,9 +442,21 @@ def close(self) -> None:
"""
...

-def handleInitialState(self, key: Any, initialState: "PandasDataFrameLike") -> None:
+def handleInitialState(
+    self, key: Any, initialState: "PandasDataFrameLike", timer_values: TimerValues
+) -> None:
"""
Optional to implement. Will act as no-op if not defined or no initial state input.
Function that will be invoked only in the first batch for users to process initial states.

Parameters
----------
key : Any
grouping key.
initialState: :class:`pandas.DataFrame`
One dataframe in the initial state associated with the key.
timer_values: TimerValues
Timer value for the current batch that process the input rows.
Users can get the processing or event time timestamp from TimerValues.
"""
pass
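With timer_values now threaded through, a first-batch handler can seed timers from the initial state, which the old handleInitialState(key, initialState) signature could not do. A sketch under the same assumptions and imports as the SessionTimeoutProcessor example above (state and column names invented):

class DownloadTrackerProcessor(StatefulProcessor):  # hypothetical example
    def init(self, handle: StatefulProcessorHandle) -> None:
        self.handle = handle

    def handleInitialState(self, key, initialState, timer_values) -> None:
        # Schedule a follow-up check one minute after the pre-loaded activity,
        # but never in the past relative to the current batch.
        first_seen = int(initialState.at[0, "first_seen_ms"])  # invented column
        now = timer_values.get_current_processing_time_in_ms()
        self.handle.registerTimer(max(first_seen + 60_000, now + 1))

    def handleInputRows(self, key, rows, timer_values):
        return iter([])

    def close(self) -> None:
        pass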