[SPARK-50194][SS][PYTHON] Integration of New Timer API and Initial State API with Timer

### What changes were proposed in this pull request?

On the Scala side, the timer API was changed to use a separate `handleExpiredTimer` function inside `StatefulProcessor`; this PR makes the corresponding change to the Python timer API so it matches the Scala API. It also adds a timer parameter to the `handleInitialState` function to support registering timers for initial state rows in the first batch.

### Why are the changes needed?

This change keeps the Python API consistent with the Scala-side API changes in #48553.

### Does this PR introduce _any_ user-facing change?

Yes.
We add a new user-defined function to explicitly handle expired timers:
```
def handleExpiredTimer(
        self, key: Any, timer_values: TimerValues, expired_timer_info: ExpiredTimerInfo
) -> Iterator["PandasDataFrameLike"]
```
We also add a new timer parameter so that users can register timers for keys that exist in the initial state:
```
def handleInitialState(
        self,
        key: Any,
        initialState: "PandasDataFrameLike",
        timer_values: TimerValues) -> None
```
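
For illustration, here is a minimal sketch of a processor wiring both hooks together. The `TimeoutProcessor` name, the 10-second timeout, and the output columns are hypothetical; `registerTimer` is the existing handle API, and the grouping key is assumed to arrive as a tuple:
```
import pandas as pd
from typing import Any, Iterator

from pyspark.sql.streaming.stateful_processor import (
    ExpiredTimerInfo,
    StatefulProcessor,
    StatefulProcessorHandle,
    TimerValues,
)


class TimeoutProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        self.handle = handle

    def handleInitialState(
        self, key: Any, initialState: pd.DataFrame, timer_values: TimerValues
    ) -> None:
        # New in this PR: timer_values makes it possible to register a
        # timer for keys that only appear in the initial state (batch 0).
        self.handle.registerTimer(
            timer_values.get_current_processing_time_in_ms() + 10000
        )

    def handleInputRows(
        self, key: Any, rows: Iterator[pd.DataFrame], timer_values: TimerValues
    ) -> Iterator[pd.DataFrame]:
        # Pass-through; a real processor would read and update state here.
        yield from rows

    def handleExpiredTimer(
        self, key: Any, timer_values: TimerValues, expired_timer_info: ExpiredTimerInfo
    ) -> Iterator[pd.DataFrame]:
        # Invoked once per fired timer for this key; emit a marker row.
        yield pd.DataFrame(
            {
                "id": [key[0]],
                "expiredTimestampInMs": [expired_timer_info.get_expiry_time_in_ms()],
            }
        )

    def close(self) -> None:
        pass
```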

### How was this patch tested?

Add a new test in `test_pandas_transform_with_state`

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #48838 from jingz-db/python-new-timer.

Lead-authored-by: jingz-db <[email protected]>
Co-authored-by: Jing Zhan <[email protected]>
Co-authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
3 people committed Nov 28, 2024
1 parent bb994d1 commit e6252d6
Showing 8 changed files with 363 additions and 224 deletions.
107 changes: 64 additions & 43 deletions python/pyspark/sql/pandas/group_ops.py
@@ -35,6 +35,7 @@
TimerValues,
)
from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
from pyspark.sql.types import StructType, _parse_datatype_string

if TYPE_CHECKING:
@@ -503,58 +504,59 @@ def transformWithStateInPandas(
if isinstance(outputStructType, str):
outputStructType = cast(StructType, _parse_datatype_string(outputStructType))

def handle_data_with_timers(
def handle_data_rows(
statefulProcessorApiClient: StatefulProcessorApiClient,
key: Any,
inputRows: Iterator["PandasDataFrameLike"],
inputRows: Optional[Iterator["PandasDataFrameLike"]] = None,
) -> Iterator["PandasDataFrameLike"]:
statefulProcessorApiClient.set_implicit_key(key)
if timeMode != "none":
batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp()

batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps(
timeMode
)

# process with data rows
if inputRows is not None:
data_iter = statefulProcessor.handleInputRows(
key, inputRows, TimerValues(batch_timestamp, watermark_timestamp)
)
return data_iter
else:
batch_timestamp = -1
watermark_timestamp = -1
# process with invalid expiry timer info and emit data rows
data_iter = statefulProcessor.handleInputRows(
key,
inputRows,
TimerValues(batch_timestamp, watermark_timestamp),
ExpiredTimerInfo(False),
return iter([])

def handle_expired_timers(
statefulProcessorApiClient: StatefulProcessorApiClient,
) -> Iterator["PandasDataFrameLike"]:
batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps(
timeMode
)
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.DATA_PROCESSED)

if timeMode == "processingtime":
if timeMode.lower() == "processingtime":
expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator(
batch_timestamp
)
elif timeMode == "eventtime":
elif timeMode.lower() == "eventtime":
expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator(
watermark_timestamp
)
else:
expiry_list_iter = iter([[]])

result_iter_list = [data_iter]
# process with valid expiry time info and with empty input rows,
# only timer related rows will be emitted
# process with expiry timers, only timer related rows will be emitted
for expiry_list in expiry_list_iter:
for key_obj, expiry_timestamp in expiry_list:
result_iter_list.append(
statefulProcessor.handleInputRows(
key_obj,
iter([]),
TimerValues(batch_timestamp, watermark_timestamp),
ExpiredTimerInfo(True, expiry_timestamp),
)
)
# TODO(SPARK-49603) set the handle state in the lazily initialized iterator

result = itertools.chain(*result_iter_list)
return result
statefulProcessorApiClient.set_implicit_key(key_obj)
for pd in statefulProcessor.handleExpiredTimer(
key=key_obj,
timer_values=TimerValues(batch_timestamp, watermark_timestamp),
expired_timer_info=ExpiredTimerInfo(expiry_timestamp),
):
yield pd
statefulProcessorApiClient.delete_timer(expiry_timestamp)

def transformWithStateUDF(
statefulProcessorApiClient: StatefulProcessorApiClient,
mode: TransformWithStateInPandasFuncMode,
key: Any,
inputRows: Iterator["PandasDataFrameLike"],
) -> Iterator["PandasDataFrameLike"]:
@@ -566,19 +568,28 @@ def transformWithStateUDF(
StatefulProcessorHandleState.INITIALIZED
)

# Key is None when we have processed all the input data from the worker and ready to
# proceed with the cleanup steps.
if key is None:
if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
statefulProcessorApiClient.set_handle_state(
StatefulProcessorHandleState.DATA_PROCESSED
)
result = handle_expired_timers(statefulProcessorApiClient)
return result
elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
statefulProcessorApiClient.set_handle_state(
StatefulProcessorHandleState.TIMER_PROCESSED
)
statefulProcessorApiClient.remove_implicit_key()
statefulProcessor.close()
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED)
return iter([])

result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
return result
else:
# mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
result = handle_data_rows(statefulProcessorApiClient, key, inputRows)
return result

def transformWithStateWithInitStateUDF(
statefulProcessorApiClient: StatefulProcessorApiClient,
mode: TransformWithStateInPandasFuncMode,
key: Any,
inputRows: Iterator["PandasDataFrameLike"],
initialStates: Optional[Iterator["PandasDataFrameLike"]] = None,
@@ -603,20 +614,30 @@ def transformWithStateWithInitStateUDF(
StatefulProcessorHandleState.INITIALIZED
)

# Key is None when we have processed all the input data from the worker and ready to
# proceed with the cleanup steps.
if key is None:
if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
statefulProcessorApiClient.set_handle_state(
StatefulProcessorHandleState.DATA_PROCESSED
)
result = handle_expired_timers(statefulProcessorApiClient)
return result
elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
statefulProcessorApiClient.remove_implicit_key()
statefulProcessor.close()
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED)
return iter([])
else:
# mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps(
timeMode
)

# only process initial state if first batch and initial state is not None
if initialStates is not None:
for cur_initial_state in initialStates:
statefulProcessorApiClient.set_implicit_key(key)
# TODO(SPARK-50194) integration with new timer API with initial state
statefulProcessor.handleInitialState(key, cur_initial_state)
statefulProcessor.handleInitialState(
key, cur_initial_state, TimerValues(batch_timestamp, watermark_timestamp)
)

# if we don't have input rows for the given key but only have initial state
# for the grouping key, the inputRows iterator could be empty
Expand All @@ -629,7 +650,7 @@ def transformWithStateWithInitStateUDF(
inputRows = itertools.chain([first], inputRows)

if not input_rows_empty:
result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
result = handle_data_rows(statefulProcessorApiClient, key, inputRows)
else:
result = iter([])

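In effect, each invocation of the UDF is now dispatched on an explicit func mode instead of inferring the cleanup phase from `key is None`. Condensed from the logic above (a sketch, not the verbatim code):
```
def dispatch(statefulProcessorApiClient, mode, key, inputRows):
    if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
        # All data for this batch has been processed; fire expired timers.
        statefulProcessorApiClient.set_handle_state(
            StatefulProcessorHandleState.DATA_PROCESSED
        )
        return handle_expired_timers(statefulProcessorApiClient)
    elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
        # Final sentinel: tear down the processor and close the handle.
        statefulProcessorApiClient.remove_implicit_key()
        statefulProcessor.close()
        statefulProcessorApiClient.set_handle_state(
            StatefulProcessorHandleState.CLOSED
        )
        return iter([])
    else:  # TransformWithStateInPandasFuncMode.PROCESS_DATA
        return handle_data_rows(statefulProcessorApiClient, key, inputRows)
```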
13 changes: 11 additions & 2 deletions python/pyspark/sql/pandas/serializers.py
@@ -36,6 +36,7 @@
_create_converter_from_pandas,
_create_converter_to_pandas,
)
from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
from pyspark.sql.types import (
DataType,
StringType,
@@ -1197,7 +1198,11 @@ def generate_data_batches(batches):
data_batches = generate_data_batches(_batches)

for k, g in groupby(data_batches, key=lambda x: x[0]):
yield (k, g)
yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)

yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)

yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)

def dump_stream(self, iterator, stream):
"""
@@ -1281,4 +1286,8 @@ def flatten_columns(cur_batch, col_name):
data_batches = generate_data_batches(_batches)

for k, g in groupby(data_batches, key=lambda x: x[0]):
yield (k, g)
yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)

yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)

yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)
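
With this change, both `load_stream` implementations emit `(mode, key, data)` tuples: one `PROCESS_DATA` tuple per grouping key, followed by two key-less sentinels. A hypothetical consuming loop (the function name and wiring are illustrative only):
```
def run_over_stream(serializer, stream, udf, api_client):
    # The serializer yields (mode, key, data); the trailing sentinels
    # (PROCESS_TIMER, COMPLETE) arrive with key=None and data=None.
    for mode, key, data in serializer.load_stream(stream):
        yield from udf(api_client, mode, key, data)
```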
49 changes: 35 additions & 14 deletions python/pyspark/sql/streaming/stateful_processor.py
@@ -105,21 +105,13 @@ def get_current_watermark_in_ms(self) -> int:

class ExpiredTimerInfo:
"""
Class used for arbitrary stateful operations with transformWithState to access expired timer
info. When is_valid is false, the expiry timestamp is invalid.
Class used to provide access to an expired timer's expiry time.
.. versionadded:: 4.0.0
"""

def __init__(self, is_valid: bool, expiry_time_in_ms: int = -1) -> None:
self._is_valid = is_valid
def __init__(self, expiry_time_in_ms: int = -1) -> None:
self._expiry_time_in_ms = expiry_time_in_ms

def is_valid(self) -> bool:
"""
Whether the expiry info is valid.
"""
return self._is_valid

def get_expiry_time_in_ms(self) -> int:
"""
Get the timestamp for expired timer, return timestamp in millisecond.
@@ -398,7 +390,6 @@ def handleInputRows(
key: Any,
rows: Iterator["PandasDataFrameLike"],
timer_values: TimerValues,
expired_timer_info: ExpiredTimerInfo,
) -> Iterator["PandasDataFrameLike"]:
"""
Function that will allow users to interact with input data rows along with the grouping key.
@@ -420,11 +411,29 @@ def handleInputRows(
timer_values: TimerValues
Timer value for the current batch that processes the input rows.
Users can get the processing or event time timestamp from TimerValues.
expired_timer_info: ExpiredTimerInfo
Timestamp of expired timers on the grouping key.
"""
...

def handleExpiredTimer(
self, key: Any, timer_values: TimerValues, expired_timer_info: ExpiredTimerInfo
) -> Iterator["PandasDataFrameLike"]:
"""
Optional to implement. Returns an empty iterator if not defined.
Function that will be invoked when a timer is fired for a given key. Users can choose to
evict state, register new timers and optionally provide output rows.
Parameters
----------
key : Any
grouping key.
timer_values: TimerValues
Timer value for the current batch that processes the input rows.
Users can get the processing or event time timestamp from TimerValues.
expired_timer_info: ExpiredTimerInfo
Instance of ExpiredTimerInfo that provides access to expired timer.
"""
return iter([])

@abstractmethod
def close(self) -> None:
"""
@@ -433,9 +442,21 @@ def close(self) -> None:
"""
...

def handleInitialState(self, key: Any, initialState: "PandasDataFrameLike") -> None:
def handleInitialState(
self, key: Any, initialState: "PandasDataFrameLike", timer_values: TimerValues
) -> None:
"""
Optional to implement. Acts as a no-op if not defined or if there is no initial state input.
Function that will be invoked only in the first batch for users to process initial states.
Parameters
----------
key : Any
grouping key.
initialState: :class:`pandas.DataFrame`
One dataframe in the initial state associated with the key.
timer_values: TimerValues
Timer value for the current batch that processes the input rows.
Users can get the processing or event time timestamp from TimerValues.
"""
pass
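To tie it together, a hedged sketch of wiring such a processor into `transformWithStateInPandas` (the input DataFrame `df`, its `id` grouping column, the output schema, and `initial_state_df` are all hypothetical; `initialState` is assumed to take grouped data, matching the existing parameter):
```
result_df = df.groupBy("id").transformWithStateInPandas(
    statefulProcessor=TimeoutProcessor(),
    outputStructType="id string, expiredTimestampInMs long",
    outputMode="Update",
    timeMode="processingTime",
    initialState=initial_state_df.groupBy("id"),
)
```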
