-
Notifications
You must be signed in to change notification settings - Fork 28.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-50194][SS][PYTHON] Integration of New Timer API and Initial State API with Timer #48838
Changes from 7 commits
49fb1cb
585e268
a0a53cf
d998e48
bad8f7a
b7e6f59
0c5ab3f
3d87b0e
96e7226
4c272a5
a69cb6f
53fb7cc
3fa4ef2
d4cd2e1
5e018ff
521722e
f10348c
0bff9eb
cd16fda
91d4c10
571dc44
aaf34a3
2cd4f8e
00b6340
b4f9d91
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,7 @@ | |
# | ||
import itertools | ||
import sys | ||
from typing import Any, Iterator, List, Optional, Union, TYPE_CHECKING, cast | ||
from typing import Any, Iterator, List, Optional, Union, Tuple, TYPE_CHECKING, cast | ||
import warnings | ||
|
||
from pyspark.errors import PySparkTypeError | ||
|
@@ -502,53 +502,59 @@ def transformWithStateInPandas( | |
if isinstance(outputStructType, str): | ||
outputStructType = cast(StructType, _parse_datatype_string(outputStructType)) | ||
|
||
def handle_data_with_timers( | ||
def get_timestamps( | ||
statefulProcessorApiClient: StatefulProcessorApiClient, | ||
key: Any, | ||
inputRows: Iterator["PandasDataFrameLike"], | ||
) -> Iterator["PandasDataFrameLike"]: | ||
statefulProcessorApiClient.set_implicit_key(key) | ||
) -> Tuple[int, int]: | ||
if timeMode != "none": | ||
batch_timestamp = statefulProcessorApiClient.get_batch_timestamp() | ||
watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp() | ||
else: | ||
batch_timestamp = -1 | ||
watermark_timestamp = -1 | ||
# process with invalid expiry timer info and emit data rows | ||
data_iter = statefulProcessor.handleInputRows( | ||
key, | ||
inputRows, | ||
TimerValues(batch_timestamp, watermark_timestamp), | ||
ExpiredTimerInfo(False), | ||
) | ||
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.DATA_PROCESSED) | ||
return batch_timestamp, watermark_timestamp | ||
|
||
def handle_data_with_timers( | ||
statefulProcessorApiClient: StatefulProcessorApiClient, | ||
key: Any, | ||
batch_timestamp: int, | ||
watermark_timestamp: int, | ||
inputRows: Optional[Iterator["PandasDataFrameLike"]] = None, | ||
) -> Iterator["PandasDataFrameLike"]: | ||
statefulProcessorApiClient.set_implicit_key(key) | ||
# process with data rows | ||
if inputRows is not None: | ||
data_iter = statefulProcessor.handleInputRows( | ||
key, inputRows, TimerValues(batch_timestamp, watermark_timestamp) | ||
) | ||
result_iter_list = [data_iter] | ||
statefulProcessorApiClient.set_handle_state( | ||
StatefulProcessorHandleState.DATA_PROCESSED | ||
) | ||
else: | ||
result_iter_list = [] | ||
|
||
if timeMode == "processingtime": | ||
if timeMode.lower() == "processingtime": | ||
expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( | ||
batch_timestamp | ||
) | ||
elif timeMode == "eventtime": | ||
elif timeMode.lower() == "eventtime": | ||
expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( | ||
watermark_timestamp | ||
) | ||
else: | ||
expiry_list_iter = iter([[]]) | ||
|
||
result_iter_list = [data_iter] | ||
# process with valid expiry time info and with empty input rows, | ||
# only timer related rows will be emitted | ||
# process with expiry timers, only timer related rows will be emitted | ||
for expiry_list in expiry_list_iter: | ||
for key_obj, expiry_timestamp in expiry_list: | ||
result_iter_list.append( | ||
statefulProcessor.handleInputRows( | ||
statefulProcessor.handleExpiredTimer( | ||
key_obj, | ||
iter([]), | ||
TimerValues(batch_timestamp, watermark_timestamp), | ||
ExpiredTimerInfo(True, expiry_timestamp), | ||
ExpiredTimerInfo(expiry_timestamp), | ||
) | ||
) | ||
# TODO(SPARK-49603) set the handle state in the lazily initialized iterator | ||
|
||
result = itertools.chain(*result_iter_list) | ||
return result | ||
|
||
|
@@ -573,7 +579,11 @@ def transformWithStateUDF( | |
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) | ||
return iter([]) | ||
|
||
result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows) | ||
batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ideally this shouldn't be called at every key. If we split out the handling of timer expiration from the handling of input rows, we would only need to call this at once. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved this |
||
|
||
result = handle_data_with_timers( | ||
statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows | ||
) | ||
return result | ||
|
||
def transformWithStateWithInitStateUDF( | ||
|
@@ -610,12 +620,15 @@ def transformWithStateWithInitStateUDF( | |
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) | ||
return iter([]) | ||
|
||
batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient) | ||
|
||
# only process initial state if first batch and initial state is not None | ||
if initialStates is not None: | ||
for cur_initial_state in initialStates: | ||
statefulProcessorApiClient.set_implicit_key(key) | ||
# TODO(SPARK-50194) integration with new timer API with initial state | ||
statefulProcessor.handleInitialState(key, cur_initial_state) | ||
statefulProcessor.handleInitialState( | ||
key, cur_initial_state, TimerValues(batch_timestamp, watermark_timestamp) | ||
) | ||
|
||
# if we don't have input rows for the given key but only have initial state | ||
# for the grouping key, the inputRows iterator could be empty | ||
|
@@ -628,10 +641,15 @@ def transformWithStateWithInitStateUDF( | |
inputRows = itertools.chain([first], inputRows) | ||
|
||
if not input_rows_empty: | ||
result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows) | ||
result = handle_data_with_timers( | ||
statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows | ||
) | ||
else: | ||
result = iter([]) | ||
|
||
# if the input rows is empty, we still need to handle the expired timers registered | ||
# in the initial state | ||
result = handle_data_with_timers( | ||
statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, None | ||
) | ||
return result | ||
|
||
if isinstance(outputStructType, str): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,21 +105,13 @@ def get_current_watermark_in_ms(self) -> int: | |
|
||
class ExpiredTimerInfo: | ||
""" | ||
Class used for arbitrary stateful operations with transformWithState to access expired timer | ||
info. When is_valid is false, the expiry timestamp is invalid. | ||
Class used to provide access to expired timer's expiry time. | ||
.. versionadded:: 4.0.0 | ||
""" | ||
|
||
def __init__(self, is_valid: bool, expiry_time_in_ms: int = -1) -> None: | ||
self._is_valid = is_valid | ||
def __init__(self, expiry_time_in_ms: int = -1) -> None: | ||
self._expiry_time_in_ms = expiry_time_in_ms | ||
|
||
def is_valid(self) -> bool: | ||
""" | ||
Whether the expiry info is valid. | ||
""" | ||
return self._is_valid | ||
|
||
def get_expiry_time_in_ms(self) -> int: | ||
""" | ||
Get the timestamp for expired timer, return timestamp in millisecond. | ||
|
@@ -398,7 +390,6 @@ def handleInputRows( | |
key: Any, | ||
rows: Iterator["PandasDataFrameLike"], | ||
timer_values: TimerValues, | ||
expired_timer_info: ExpiredTimerInfo, | ||
) -> Iterator["PandasDataFrameLike"]: | ||
""" | ||
Function that will allow users to interact with input data rows along with the grouping key. | ||
|
@@ -420,11 +411,29 @@ def handleInputRows( | |
timer_values: TimerValues | ||
Timer value for the current batch that process the input rows. | ||
Users can get the processing or event time timestamp from TimerValues. | ||
expired_timer_info: ExpiredTimerInfo | ||
Timestamp of expired timers on the grouping key. | ||
""" | ||
... | ||
|
||
def handleExpiredTimer( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just double check that this method is not required for users to implement, correct? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct. Add a comment line in the docstring to explicitly saying this is optional to implement. |
||
self, key: Any, timer_values: TimerValues, expired_timer_info: ExpiredTimerInfo | ||
) -> Iterator["PandasDataFrameLike"]: | ||
""" | ||
Optional to implement. Will act return an empty iterator if not defined. | ||
Function that will be invoked when a timer is fired for a given key. Users can choose to | ||
evict state, register new timers and optionally provide output rows. | ||
|
||
Parameters | ||
---------- | ||
key : Any | ||
grouping key. | ||
timer_values: TimerValues | ||
Timer value for the current batch that process the input rows. | ||
Users can get the processing or event time timestamp from TimerValues. | ||
expired_timer_info: ExpiredTimerInfo | ||
Instance of ExpiredTimerInfo that provides access to expired timer. | ||
""" | ||
return iter([]) | ||
|
||
@abstractmethod | ||
def close(self) -> None: | ||
""" | ||
|
@@ -433,9 +442,21 @@ def close(self) -> None: | |
""" | ||
... | ||
|
||
def handleInitialState(self, key: Any, initialState: "PandasDataFrameLike") -> None: | ||
def handleInitialState( | ||
self, key: Any, initialState: "PandasDataFrameLike", timer_values: TimerValues | ||
) -> None: | ||
""" | ||
Optional to implement. Will act as no-op if not defined or no initial state input. | ||
Function that will be invoked only in the first batch for users to process initial states. | ||
|
||
Parameters | ||
---------- | ||
key : Any | ||
grouping key. | ||
initialState: :class:`pandas.DataFrame` | ||
One dataframe in the initial state associated with the key. | ||
timer_values: TimerValues | ||
Timer value for the current batch that process the input rows. | ||
Users can get the processing or event time timestamp from TimerValues. | ||
""" | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have confused about this every time. Is this relying on the behavior that expired timer will be removed so we won't list up the same timer as expired multiple times? This is very easy to be forgotten.
If there is any way we can just move this out and do this after we process all input? Can this be done in transformWithStateUDF/transformWithStateWithInitStateUDF with key = null?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks so much for catching this! I made a terrible correctness bug in my prior timer implementation. I now moved all timer handling codes into
serializer.py
where the expired timers are processed per partition.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left an explanation of what is causing the correctness issue in my prior implementation here just in case you are curious: #48838 (comment)