From e6252d6c6143f134201a17bb25978a136186cbfa Mon Sep 17 00:00:00 2001
From: jingz-db
Date: Thu, 28 Nov 2024 15:34:37 +0900
Subject: [PATCH] [SPARK-50194][SS][PYTHON] Integration of New Timer API and Initial State API with Timer

### What changes were proposed in this pull request?

On the Scala side, the timer API was reworked to use a separate `handleExpiredTimer` function inside `StatefulProcessor`; this PR makes the corresponding change to the Python timer API so that it matches the Scala API. It also adds a `timer_values` parameter to `handleInitialState` to support registering timers in the first batch for initial state rows.

### Why are the changes needed?

This change keeps the Python API aligned with the Scala side of the APIs: https://github.com/apache/spark/pull/48553

### Does this PR introduce _any_ user-facing change?

Yes. We add a new user-defined function to explicitly handle expired timers:
```
def handleExpiredTimer(
    self, key: Any, timer_values: TimerValues, expired_timer_info: ExpiredTimerInfo
) -> Iterator["PandasDataFrameLike"]
```
We also add a new timer parameter so that users can register timers for keys that exist in the initial state:
```
def handleInitialState(
    self, key: Any, initialState: "PandasDataFrameLike", timer_values: TimerValues
) -> None
```

### How was this patch tested?

Added a new test in `test_pandas_transform_with_state`.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #48838 from jingz-db/python-new-timer.

Lead-authored-by: jingz-db
Co-authored-by: Jing Zhan <135738831+jingz-db@users.noreply.github.com>
Co-authored-by: Jungtaek Lim
Signed-off-by: Jungtaek Lim
---
 python/pyspark/sql/pandas/group_ops.py        | 107 +++++----
 python/pyspark/sql/pandas/serializers.py      |  13 +-
 .../sql/streaming/stateful_processor.py       |  49 ++--
 .../stateful_processor_api_client.py          |  94 +++++---
 .../sql/streaming/stateful_processor_util.py  |  27 +++
 .../test_pandas_transform_with_state.py       | 227 +++++++++++-------
 python/pyspark/worker.py                      |  68 +++---
 ...ransformWithStateInPandasStateServer.scala |   2 +
 8 files changed, 363 insertions(+), 224 deletions(-)
 create mode 100644 python/pyspark/sql/streaming/stateful_processor_util.py

diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index d8f22e434374c..688ad4b05732e 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -35,6 +35,7 @@
     TimerValues,
 )
 from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
+from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
 from pyspark.sql.types import StructType, _parse_datatype_string

 if TYPE_CHECKING:
@@ -503,58 +504,59 @@ def transformWithStateInPandas(
         if isinstance(outputStructType, str):
             outputStructType = cast(StructType, _parse_datatype_string(outputStructType))

-        def handle_data_with_timers(
+        def handle_data_rows(
             statefulProcessorApiClient: StatefulProcessorApiClient,
             key: Any,
-            inputRows: Iterator["PandasDataFrameLike"],
+            inputRows: Optional[Iterator["PandasDataFrameLike"]] = None,
         ) -> Iterator["PandasDataFrameLike"]:
             statefulProcessorApiClient.set_implicit_key(key)
-            if timeMode != "none":
-                batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
-                watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp()
+
+            batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps(
+                timeMode
+            )
+
+            # process with data rows
+            if inputRows is not None:
+                data_iter = statefulProcessor.handleInputRows(
+                    key, inputRows,
TimerValues(batch_timestamp, watermark_timestamp) + ) + return data_iter else: - batch_timestamp = -1 - watermark_timestamp = -1 - # process with invalid expiry timer info and emit data rows - data_iter = statefulProcessor.handleInputRows( - key, - inputRows, - TimerValues(batch_timestamp, watermark_timestamp), - ExpiredTimerInfo(False), + return iter([]) + + def handle_expired_timers( + statefulProcessorApiClient: StatefulProcessorApiClient, + ) -> Iterator["PandasDataFrameLike"]: + batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps( + timeMode ) - statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.DATA_PROCESSED) - if timeMode == "processingtime": + if timeMode.lower() == "processingtime": expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( batch_timestamp ) - elif timeMode == "eventtime": + elif timeMode.lower() == "eventtime": expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( watermark_timestamp ) else: expiry_list_iter = iter([[]]) - result_iter_list = [data_iter] - # process with valid expiry time info and with empty input rows, - # only timer related rows will be emitted + # process with expiry timers, only timer related rows will be emitted for expiry_list in expiry_list_iter: for key_obj, expiry_timestamp in expiry_list: - result_iter_list.append( - statefulProcessor.handleInputRows( - key_obj, - iter([]), - TimerValues(batch_timestamp, watermark_timestamp), - ExpiredTimerInfo(True, expiry_timestamp), - ) - ) - # TODO(SPARK-49603) set the handle state in the lazily initialized iterator - - result = itertools.chain(*result_iter_list) - return result + statefulProcessorApiClient.set_implicit_key(key_obj) + for pd in statefulProcessor.handleExpiredTimer( + key=key_obj, + timer_values=TimerValues(batch_timestamp, watermark_timestamp), + expired_timer_info=ExpiredTimerInfo(expiry_timestamp), + ): + yield pd + statefulProcessorApiClient.delete_timer(expiry_timestamp) def transformWithStateUDF( statefulProcessorApiClient: StatefulProcessorApiClient, + mode: TransformWithStateInPandasFuncMode, key: Any, inputRows: Iterator["PandasDataFrameLike"], ) -> Iterator["PandasDataFrameLike"]: @@ -566,19 +568,28 @@ def transformWithStateUDF( StatefulProcessorHandleState.INITIALIZED ) - # Key is None when we have processed all the input data from the worker and ready to - # proceed with the cleanup steps. 
- if key is None: + if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER: + statefulProcessorApiClient.set_handle_state( + StatefulProcessorHandleState.DATA_PROCESSED + ) + result = handle_expired_timers(statefulProcessorApiClient) + return result + elif mode == TransformWithStateInPandasFuncMode.COMPLETE: + statefulProcessorApiClient.set_handle_state( + StatefulProcessorHandleState.TIMER_PROCESSED + ) statefulProcessorApiClient.remove_implicit_key() statefulProcessor.close() statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) return iter([]) - - result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows) - return result + else: + # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA + result = handle_data_rows(statefulProcessorApiClient, key, inputRows) + return result def transformWithStateWithInitStateUDF( statefulProcessorApiClient: StatefulProcessorApiClient, + mode: TransformWithStateInPandasFuncMode, key: Any, inputRows: Iterator["PandasDataFrameLike"], initialStates: Optional[Iterator["PandasDataFrameLike"]] = None, @@ -603,20 +614,30 @@ def transformWithStateWithInitStateUDF( StatefulProcessorHandleState.INITIALIZED ) - # Key is None when we have processed all the input data from the worker and ready to - # proceed with the cleanup steps. - if key is None: + if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER: + statefulProcessorApiClient.set_handle_state( + StatefulProcessorHandleState.DATA_PROCESSED + ) + result = handle_expired_timers(statefulProcessorApiClient) + return result + elif mode == TransformWithStateInPandasFuncMode.COMPLETE: statefulProcessorApiClient.remove_implicit_key() statefulProcessor.close() statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) return iter([]) + else: + # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA + batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps( + timeMode + ) # only process initial state if first batch and initial state is not None if initialStates is not None: for cur_initial_state in initialStates: statefulProcessorApiClient.set_implicit_key(key) - # TODO(SPARK-50194) integration with new timer API with initial state - statefulProcessor.handleInitialState(key, cur_initial_state) + statefulProcessor.handleInitialState( + key, cur_initial_state, TimerValues(batch_timestamp, watermark_timestamp) + ) # if we don't have input rows for the given key but only have initial state # for the grouping key, the inputRows iterator could be empty @@ -629,7 +650,7 @@ def transformWithStateWithInitStateUDF( inputRows = itertools.chain([first], inputRows) if not input_rows_empty: - result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows) + result = handle_data_rows(statefulProcessorApiClient, key, inputRows) else: result = iter([]) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 5bf07b87400fe..536bf7307065c 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -36,6 +36,7 @@ _create_converter_from_pandas, _create_converter_to_pandas, ) +from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode from pyspark.sql.types import ( DataType, StringType, @@ -1197,7 +1198,11 @@ def generate_data_batches(batches): data_batches = generate_data_batches(_batches) for k, g in groupby(data_batches, key=lambda x: x[0]): - yield (k, g) + yield 
(TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g) + + yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None) + + yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None) def dump_stream(self, iterator, stream): """ @@ -1281,4 +1286,8 @@ def flatten_columns(cur_batch, col_name): data_batches = generate_data_batches(_batches) for k, g in groupby(data_batches, key=lambda x: x[0]): - yield (k, g) + yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g) + + yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None) + + yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None) diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py index 20078c215bace..9caa9304d6a87 100644 --- a/python/pyspark/sql/streaming/stateful_processor.py +++ b/python/pyspark/sql/streaming/stateful_processor.py @@ -105,21 +105,13 @@ def get_current_watermark_in_ms(self) -> int: class ExpiredTimerInfo: """ - Class used for arbitrary stateful operations with transformWithState to access expired timer - info. When is_valid is false, the expiry timestamp is invalid. + Class used to provide access to expired timer's expiry time. .. versionadded:: 4.0.0 """ - def __init__(self, is_valid: bool, expiry_time_in_ms: int = -1) -> None: - self._is_valid = is_valid + def __init__(self, expiry_time_in_ms: int = -1) -> None: self._expiry_time_in_ms = expiry_time_in_ms - def is_valid(self) -> bool: - """ - Whether the expiry info is valid. - """ - return self._is_valid - def get_expiry_time_in_ms(self) -> int: """ Get the timestamp for expired timer, return timestamp in millisecond. @@ -398,7 +390,6 @@ def handleInputRows( key: Any, rows: Iterator["PandasDataFrameLike"], timer_values: TimerValues, - expired_timer_info: ExpiredTimerInfo, ) -> Iterator["PandasDataFrameLike"]: """ Function that will allow users to interact with input data rows along with the grouping key. @@ -420,11 +411,29 @@ def handleInputRows( timer_values: TimerValues Timer value for the current batch that process the input rows. Users can get the processing or event time timestamp from TimerValues. - expired_timer_info: ExpiredTimerInfo - Timestamp of expired timers on the grouping key. """ ... + def handleExpiredTimer( + self, key: Any, timer_values: TimerValues, expired_timer_info: ExpiredTimerInfo + ) -> Iterator["PandasDataFrameLike"]: + """ + Optional to implement. Will act return an empty iterator if not defined. + Function that will be invoked when a timer is fired for a given key. Users can choose to + evict state, register new timers and optionally provide output rows. + + Parameters + ---------- + key : Any + grouping key. + timer_values: TimerValues + Timer value for the current batch that process the input rows. + Users can get the processing or event time timestamp from TimerValues. + expired_timer_info: ExpiredTimerInfo + Instance of ExpiredTimerInfo that provides access to expired timer. + """ + return iter([]) + @abstractmethod def close(self) -> None: """ @@ -433,9 +442,21 @@ def close(self) -> None: """ ... - def handleInitialState(self, key: Any, initialState: "PandasDataFrameLike") -> None: + def handleInitialState( + self, key: Any, initialState: "PandasDataFrameLike", timer_values: TimerValues + ) -> None: """ Optional to implement. Will act as no-op if not defined or no initial state input. Function that will be invoked only in the first batch for users to process initial states. + + Parameters + ---------- + key : Any + grouping key. 
+ initialState: :class:`pandas.DataFrame` + One dataframe in the initial state associated with the key. + timer_values: TimerValues + Timer value for the current batch that process the input rows. + Users can get the processing or event time timestamp from TimerValues. """ pass diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py index 353f75e267962..53704188081c3 100644 --- a/python/pyspark/sql/streaming/stateful_processor_api_client.py +++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py @@ -62,6 +62,10 @@ def __init__(self, state_server_port: int, key_schema: StructType) -> None: # Dictionaries to store the mapping between iterator id and a tuple of pandas DataFrame # and the index of the last row that was read. self.list_timer_iterator_cursors: Dict[str, Tuple["PandasDataFrameLike", int]] = {} + # statefulProcessorApiClient is initialized per batch per partition, + # so we will have new timestamps for a new batch + self._batch_timestamp = -1 + self._watermark_timestamp = -1 def set_handle_state(self, state: StatefulProcessorHandleState) -> None: import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage @@ -266,47 +270,15 @@ def get_expiry_timers_iterator( # TODO(SPARK-49233): Classify user facing errors. raise PySparkRuntimeError(f"Error getting expiry timers: " f"{response_message[1]}") - def get_batch_timestamp(self) -> int: - import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage - - get_processing_time_call = stateMessage.GetProcessingTime() - timer_value_call = stateMessage.TimerValueRequest( - getProcessingTimer=get_processing_time_call - ) - timer_request = stateMessage.TimerRequest(timerValueRequest=timer_value_call) - message = stateMessage.StateRequest(timerRequest=timer_request) - - self._send_proto_message(message.SerializeToString()) - response_message = self._receive_proto_message_with_long_value() - status = response_message[0] - if status != 0: - # TODO(SPARK-49233): Classify user facing errors. - raise PySparkRuntimeError( - f"Error getting processing timestamp: " f"{response_message[1]}" - ) + def get_timestamps(self, time_mode: str) -> Tuple[int, int]: + if time_mode.lower() == "none": + return -1, -1 else: - timestamp = response_message[2] - return timestamp - - def get_watermark_timestamp(self) -> int: - import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage - - get_watermark_call = stateMessage.GetWatermark() - timer_value_call = stateMessage.TimerValueRequest(getWatermark=get_watermark_call) - timer_request = stateMessage.TimerRequest(timerValueRequest=timer_value_call) - message = stateMessage.StateRequest(timerRequest=timer_request) - - self._send_proto_message(message.SerializeToString()) - response_message = self._receive_proto_message_with_long_value() - status = response_message[0] - if status != 0: - # TODO(SPARK-49233): Classify user facing errors. - raise PySparkRuntimeError( - f"Error getting eventtime timestamp: " f"{response_message[1]}" - ) - else: - timestamp = response_message[2] - return timestamp + if self._batch_timestamp == -1: + self._batch_timestamp = self._get_batch_timestamp() + if self._watermark_timestamp == -1: + self._watermark_timestamp = self._get_watermark_timestamp() + return self._batch_timestamp, self._watermark_timestamp def get_map_state( self, @@ -353,6 +325,48 @@ def delete_if_exists(self, state_name: str) -> None: # TODO(SPARK-49233): Classify user facing errors. 
raise PySparkRuntimeError(f"Error deleting state: " f"{response_message[1]}") + def _get_batch_timestamp(self) -> int: + import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage + + get_processing_time_call = stateMessage.GetProcessingTime() + timer_value_call = stateMessage.TimerValueRequest( + getProcessingTimer=get_processing_time_call + ) + timer_request = stateMessage.TimerRequest(timerValueRequest=timer_value_call) + message = stateMessage.StateRequest(timerRequest=timer_request) + + self._send_proto_message(message.SerializeToString()) + response_message = self._receive_proto_message_with_long_value() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError( + f"Error getting processing timestamp: " f"{response_message[1]}" + ) + else: + timestamp = response_message[2] + return timestamp + + def _get_watermark_timestamp(self) -> int: + import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage + + get_watermark_call = stateMessage.GetWatermark() + timer_value_call = stateMessage.TimerValueRequest(getWatermark=get_watermark_call) + timer_request = stateMessage.TimerRequest(timerValueRequest=timer_value_call) + message = stateMessage.StateRequest(timerRequest=timer_request) + + self._send_proto_message(message.SerializeToString()) + response_message = self._receive_proto_message_with_long_value() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError( + f"Error getting eventtime timestamp: " f"{response_message[1]}" + ) + else: + timestamp = response_message[2] + return timestamp + def _send_proto_message(self, message: bytes) -> None: # Writing zero here to indicate message version. This allows us to evolve the message # format or even changing the message protocol in the future. diff --git a/python/pyspark/sql/streaming/stateful_processor_util.py b/python/pyspark/sql/streaming/stateful_processor_util.py new file mode 100644 index 0000000000000..6130a9581bc24 --- /dev/null +++ b/python/pyspark/sql/streaming/stateful_processor_util.py @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from enum import Enum + +# This file places the utilities for transformWithStateInPandas; we have a separate file to avoid +# putting internal classes to the stateful_processor.py file which contains public APIs. 
+ + +class TransformWithStateInPandasFuncMode(Enum): + PROCESS_DATA = 1 + PROCESS_TIMER = 2 + COMPLETE = 3 diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index f385d7cd1abc0..60f2c9348db3f 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -55,6 +55,7 @@ def conf(cls): "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider", ) cfg.set("spark.sql.execution.arrow.transformWithStateInPandas.maxRecordsPerBatch", "2") + cfg.set("spark.sql.session.timeZone", "UTC") return cfg def _prepare_input_data(self, input_path, col1, col2): @@ -558,14 +559,25 @@ def prepare_batch3(input_path): def test_transform_with_state_in_pandas_event_time(self): def check_results(batch_df, batch_id): if batch_id == 0: - assert set(batch_df.sort("id").collect()) == {Row(id="a", timestamp="20")} - elif batch_id == 1: + # watermark for late event = 0 + # watermark for eviction = 0 + # timer is registered with expiration time = 0, hence expired at the same batch assert set(batch_df.sort("id").collect()) == { Row(id="a", timestamp="20"), Row(id="a-expired", timestamp="0"), } + elif batch_id == 1: + # watermark for late event = 0 + # watermark for eviction = 10 (20 - 10) + # timer is registered with expiration time = 10, hence expired at the same batch + assert set(batch_df.sort("id").collect()) == { + Row(id="a", timestamp="4"), + Row(id="a-expired", timestamp="10000"), + } elif batch_id == 2: - # verify that rows and expired timer produce the expected result + # watermark for late event = 10 + # watermark for eviction = 10 (unchanged as 4 < 10) + # timer is registered with expiration time = 10, hence expired at the same batch assert set(batch_df.sort("id").collect()) == { Row(id="a", timestamp="15"), Row(id="a-expired", timestamp="10000"), @@ -578,7 +590,9 @@ def check_results(batch_df, batch_id): EventTimeStatefulProcessor(), check_results ) - def _test_transform_with_state_init_state_in_pandas(self, stateful_processor, check_results): + def _test_transform_with_state_init_state_in_pandas( + self, stateful_processor, check_results, time_mode="None" + ): input_path = tempfile.mkdtemp() self._prepare_test_resource1(input_path) time.sleep(2) @@ -606,7 +620,7 @@ def _test_transform_with_state_init_state_in_pandas(self, stateful_processor, ch statefulProcessor=stateful_processor, outputStructType=output_schema, outputMode="Update", - timeMode="None", + timeMode=time_mode, initialState=initial_state, ) .writeStream.queryName("this_query") @@ -806,6 +820,45 @@ def check_results(batch_df, batch_id): StatefulProcessorChainingOps(), check_results, "eventTime", ["outputTimestamp", "id"] ) + def test_transform_with_state_init_state_with_timers(self): + def check_results(batch_df, batch_id): + if batch_id == 0: + # timers are registered and handled in the first batch for + # rows in initial state; For key=0 and key=3 which contains + # expired timers, both should be handled by handleExpiredTimers + # regardless of whether key exists in the data rows or not + expired_df = batch_df.filter(batch_df["id"].contains("expired")) + data_df = batch_df.filter(~batch_df["id"].contains("expired")) + assert set(expired_df.sort("id").select("id").collect()) == { + Row(id="0-expired"), + Row(id="3-expired"), + } + assert set(data_df.sort("id").collect()) == { + Row(id="0", value=str(789 + 123 + 46)), + Row(id="1", 
value=str(146 + 346)), + } + elif batch_id == 1: + # handleInitialState is only processed in the first batch, + # no more timer is registered so no more expired timers + assert set(batch_df.sort("id").collect()) == { + Row(id="0", value=str(789 + 123 + 46 + 67)), + Row(id="3", value=str(987 + 12)), + } + else: + for q in self.spark.streams.active: + q.stop() + + self._test_transform_with_state_init_state_in_pandas( + StatefulProcessorWithInitialStateTimers(), check_results, "processingTime" + ) + + # run the same test suites again but with single shuffle partition + def test_transform_with_state_with_timers_single_partition(self): + with self.sql_conf({"spark.sql.shuffle.partitions": "1"}): + self.test_transform_with_state_init_state_with_timers() + self.test_transform_with_state_in_pandas_event_time() + self.test_transform_with_state_in_pandas_proc_timer() + class SimpleStatefulProcessorWithInitialState(StatefulProcessor): # this dict is the same as input initial state dataframe @@ -814,10 +867,9 @@ class SimpleStatefulProcessorWithInitialState(StatefulProcessor): def init(self, handle: StatefulProcessorHandle) -> None: state_schema = StructType([StructField("value", IntegerType(), True)]) self.value_state = handle.getValueState("value_state", state_schema) + self.handle = handle - def handleInputRows( - self, key, rows, timer_values, expired_timer_info - ) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: exists = self.value_state.exists() if exists: value_row = self.value_state.get() @@ -840,7 +892,7 @@ def handleInputRows( else: yield pd.DataFrame({"id": key, "value": str(accumulated_value)}) - def handleInitialState(self, key, initialState) -> None: + def handleInitialState(self, key, initialState, timer_values) -> None: init_val = initialState.at[0, "initVal"] self.value_state.update((init_val,)) if len(key) == 1: @@ -850,6 +902,19 @@ def close(self) -> None: pass +class StatefulProcessorWithInitialStateTimers(SimpleStatefulProcessorWithInitialState): + def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]: + self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) + str_key = f"{str(key[0])}-expired" + yield pd.DataFrame( + {"id": (str_key,), "value": str(expired_timer_info.get_expiry_time_in_ms())} + ) + + def handleInitialState(self, key, initialState, timer_values) -> None: + super().handleInitialState(key, initialState, timer_values) + self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() - 1) + + # A stateful processor that output the max event time it has seen. Register timer for # current watermark. Clear max state if timer expires. 
class EventTimeStatefulProcessor(StatefulProcessor): @@ -858,33 +923,30 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.handle = handle self.max_state = handle.getValueState("max_state", state_schema) - def handleInputRows( - self, key, rows, timer_values, expired_timer_info - ) -> Iterator[pd.DataFrame]: - if expired_timer_info.is_valid(): - self.max_state.clear() - self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) - str_key = f"{str(key[0])}-expired" - yield pd.DataFrame( - {"id": (str_key,), "timestamp": str(expired_timer_info.get_expiry_time_in_ms())} - ) + def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]: + self.max_state.clear() + self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) + str_key = f"{str(key[0])}-expired" + yield pd.DataFrame( + {"id": (str_key,), "timestamp": str(expired_timer_info.get_expiry_time_in_ms())} + ) - else: - timestamp_list = [] - for pdf in rows: - # int64 will represent timestamp in nanosecond, restore to second - timestamp_list.extend((pdf["eventTime"].astype("int64") // 10**9).tolist()) + def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + timestamp_list = [] + for pdf in rows: + # int64 will represent timestamp in nanosecond, restore to second + timestamp_list.extend((pdf["eventTime"].astype("int64") // 10**9).tolist()) - if self.max_state.exists(): - cur_max = int(self.max_state.get()[0]) - else: - cur_max = 0 - max_event_time = str(max(cur_max, max(timestamp_list))) + if self.max_state.exists(): + cur_max = int(self.max_state.get()[0]) + else: + cur_max = 0 + max_event_time = str(max(cur_max, max(timestamp_list))) - self.max_state.update((max_event_time,)) - self.handle.registerTimer(timer_values.get_current_watermark_in_ms()) + self.max_state.update((max_event_time,)) + self.handle.registerTimer(timer_values.get_current_watermark_in_ms()) - yield pd.DataFrame({"id": key, "timestamp": max_event_time}) + yield pd.DataFrame({"id": key, "timestamp": max_event_time}) def close(self) -> None: pass @@ -898,54 +960,49 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.handle = handle self.count_state = handle.getValueState("count_state", state_schema) - def handleInputRows( - self, key, rows, timer_values, expired_timer_info - ) -> Iterator[pd.DataFrame]: - if expired_timer_info.is_valid(): - # reset count state each time the timer is expired - timer_list_1 = [e for e in self.handle.listTimers()] - timer_list_2 = [] - idx = 0 - for e in self.handle.listTimers(): - timer_list_2.append(e) - # check multiple iterator on the same grouping key works - assert timer_list_2[idx] == timer_list_1[idx] - idx += 1 - - if len(timer_list_1) > 0: - # before deleting the expiring timers, there are 2 timers - - # one timer we just registered, and one that is going to be deleted - assert len(timer_list_1) == 2 - self.count_state.clear() - self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) - yield pd.DataFrame( - { - "id": key, - "countAsString": str("-1"), - "timeValues": str(expired_timer_info.get_expiry_time_in_ms()), - } - ) + def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]: + # reset count state each time the timer is expired + timer_list_1 = [e for e in self.handle.listTimers()] + timer_list_2 = [] + idx = 0 + for e in self.handle.listTimers(): + timer_list_2.append(e) + # check multiple iterator on the same grouping key works + assert timer_list_2[idx] == timer_list_1[idx] 
+ idx += 1 + + if len(timer_list_1) > 0: + assert len(timer_list_1) == 2 + self.count_state.clear() + self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) + yield pd.DataFrame( + { + "id": key, + "countAsString": str("-1"), + "timeValues": str(expired_timer_info.get_expiry_time_in_ms()), + } + ) + def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + if not self.count_state.exists(): + count = 0 else: - if not self.count_state.exists(): - count = 0 - else: - count = int(self.count_state.get()[0]) + count = int(self.count_state.get()[0]) - if key == ("0",): - self.handle.registerTimer(timer_values.get_current_processing_time_in_ms()) + if key == ("0",): + self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() + 1) - rows_count = 0 - for pdf in rows: - pdf_count = len(pdf) - rows_count += pdf_count + rows_count = 0 + for pdf in rows: + pdf_count = len(pdf) + rows_count += pdf_count - count = count + rows_count + count = count + rows_count - self.count_state.update((str(count),)) - timestamp = str(timer_values.get_current_processing_time_in_ms()) + self.count_state.update((str(count),)) + timestamp = str(timer_values.get_current_processing_time_in_ms()) - yield pd.DataFrame({"id": key, "countAsString": str(count), "timeValues": timestamp}) + yield pd.DataFrame({"id": key, "countAsString": str(count), "timeValues": timestamp}) def close(self) -> None: pass @@ -961,9 +1018,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.temp_state = handle.getValueState("tempState", state_schema) handle.deleteIfExists("tempState") - def handleInputRows( - self, key, rows, timer_values, expired_timer_info - ) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: with self.assertRaisesRegex(PySparkRuntimeError, "Error checking value state exists"): self.temp_state.exists() new_violations = 0 @@ -995,9 +1050,7 @@ class StatefulProcessorChainingOps(StatefulProcessor): def init(self, handle: StatefulProcessorHandle) -> None: pass - def handleInputRows( - self, key, rows, timer_values, expired_timer_info - ) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: for pdf in rows: timestamp_list = pdf["eventTime"].tolist() yield pd.DataFrame({"id": key, "outputTimestamp": timestamp_list[0]}) @@ -1027,9 +1080,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: "ttl-map-state", user_key_schema, state_schema, 10000 ) - def handleInputRows( - self, key, rows, timer_values, expired_timer_info - ) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: count = 0 ttl_count = 0 ttl_list_state_count = 0 @@ -1079,9 +1130,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: state_schema = StructType([StructField("value", IntegerType(), True)]) self.num_violations_state = handle.getValueState("numViolations", state_schema) - def handleInputRows( - self, key, rows, timer_values, expired_timer_info - ) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: count = 0 exists = self.num_violations_state.exists() assert not exists @@ -1105,9 +1154,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.list_state1 = handle.getListState("listState1", state_schema) self.list_state2 = handle.getListState("listState2", state_schema) - def handleInputRows( - self, key, rows, timer_values, expired_timer_info - ) -> 
Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: count = 0 for pdf in rows: list_state_rows = [(120,), (20,)] @@ -1162,9 +1209,7 @@ def init(self, handle: StatefulProcessorHandle): value_schema = StructType([StructField("count", IntegerType(), True)]) self.map_state = handle.getMapState("mapState", key_schema, value_schema) - def handleInputRows( - self, key, rows, timer_values, expired_timer_info - ) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: count = 0 key1 = ("key1",) key2 = ("key2",) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 04f95e9f52648..1ebc04520ecad 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -34,6 +34,7 @@ _deserialize_accumulator, ) from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient +from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode from pyspark.taskcontext import BarrierTaskContext, TaskContext from pyspark.resource import ResourceInformation from pyspark.util import PythonEvalType, local_connect_and_auth @@ -493,36 +494,36 @@ def wrapped(key_series, value_series): def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf): - def wrapped(stateful_processor_api_client, key, value_series_gen): + def wrapped(stateful_processor_api_client, mode, key, value_series_gen): import pandas as pd values = (pd.concat(x, axis=1) for x in value_series_gen) - result_iter = f(stateful_processor_api_client, key, values) + result_iter = f(stateful_processor_api_client, mode, key, values) # TODO(SPARK-49100): add verification that elements in result_iter are # indeed of type pd.DataFrame and confirm to assigned cols return result_iter - return lambda p, k, v: [(wrapped(p, k, v), to_arrow_type(return_type))] + return lambda p, m, k, v: [(wrapped(p, m, k, v), to_arrow_type(return_type))] def wrap_grouped_transform_with_state_pandas_init_state_udf(f, return_type, runner_conf): - def wrapped(stateful_processor_api_client, key, value_series_gen): + def wrapped(stateful_processor_api_client, mode, key, value_series_gen): import pandas as pd state_values_gen, init_states_gen = itertools.tee(value_series_gen, 2) state_values = (df for x, _ in state_values_gen if not (df := pd.concat(x, axis=1)).empty) init_states = (df for _, x in init_states_gen if not (df := pd.concat(x, axis=1)).empty) - result_iter = f(stateful_processor_api_client, key, state_values, init_states) + result_iter = f(stateful_processor_api_client, mode, key, state_values, init_states) # TODO(SPARK-49100): add verification that elements in result_iter are # indeed of type pd.DataFrame and confirm to assigned cols return result_iter - return lambda p, k, v: [(wrapped(p, k, v), to_arrow_type(return_type))] + return lambda p, m, k, v: [(wrapped(p, m, k, v), to_arrow_type(return_type))] def wrap_grouped_map_pandas_udf_with_state(f, return_type): @@ -1697,18 +1698,22 @@ def mapper(a): ser.key_offsets = parsed_offsets[0][0] stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema) - # Create function like this: - # mapper a: f([a[0]], [a[0], a[1]]) def mapper(a): - key = a[0] + mode = a[0] - def values_gen(): - for x in a[1]: - retVal = [x[1][o] for o in parsed_offsets[0][1]] - yield retVal + if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA: + key = a[1] - # This must be generator comprehension - do not materialize. 
- return f(stateful_processor_api_client, key, values_gen()) + def values_gen(): + for x in a[2]: + retVal = [x[1][o] for o in parsed_offsets[0][1]] + yield retVal + + # This must be generator comprehension - do not materialize. + return f(stateful_processor_api_client, mode, key, values_gen()) + else: + # mode == PROCESS_TIMER or mode == COMPLETE + return f(stateful_processor_api_client, mode, None, iter([])) elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF: # We assume there is only one UDF here because grouped map doesn't @@ -1731,16 +1736,22 @@ def values_gen(): stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema) def mapper(a): - key = a[0] + mode = a[0] - def values_gen(): - for x in a[1]: - retVal = [x[1][o] for o in parsed_offsets[0][1]] - initVal = [x[2][o] for o in parsed_offsets[1][1]] - yield retVal, initVal + if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA: + key = a[1] - # This must be generator comprehension - do not materialize. - return f(stateful_processor_api_client, key, values_gen()) + def values_gen(): + for x in a[2]: + retVal = [x[1][o] for o in parsed_offsets[0][1]] + initVal = [x[2][o] for o in parsed_offsets[1][1]] + yield retVal, initVal + + # This must be generator comprehension - do not materialize. + return f(stateful_processor_api_client, mode, key, values_gen()) + else: + # mode == PROCESS_TIMER or mode == COMPLETE + return f(stateful_processor_api_client, mode, None, iter([])) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF: import pyarrow as pa @@ -1958,17 +1969,6 @@ def process(): try: serializer.dump_stream(out_iter, outfile) finally: - # Sending a signal to TransformWithState UDF to perform proper cleanup steps. - if ( - eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF - or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF - ): - # Sending key as None to indicate that process() has finished. - end_iter = func(split_index, iter([(None, None)])) - # Need to materialize the iterator to trigger the cleanup steps, nothing needs - # to be done here. - for _ in end_iter: - pass if hasattr(out_iter, "close"): out_iter.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala index 0373c8607ff2c..2957f4b387580 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala @@ -120,6 +120,8 @@ class TransformWithStateInPandasStateServer( } /** Timer related class variables */ + // An iterator to store all expired timer info. This is meant to be consumed only once per + // partition. This should be called after finishing handling all input rows. private var expiryTimestampIter: Option[Iterator[(Any, Long)]] = if (expiryTimerIterForTest != null) { Option(expiryTimerIterForTest)
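
Illustrative usage of the new API (a sketch, not part of the patch): the processor below handles expired timers in the dedicated `handleExpiredTimer` callback instead of branching on `ExpiredTimerInfo` inside `handleInputRows`, and registers a timer from `handleInitialState` using the newly passed `timer_values`. The `id`/`count` columns, the output schema, and the one-second timer interval are assumptions chosen for illustration.

```
import pandas as pd
from typing import Iterator

from pyspark.sql.streaming.stateful_processor import (
    StatefulProcessor,
    StatefulProcessorHandle,
)
from pyspark.sql.types import IntegerType, StructField, StructType


class CountWithTimeoutProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        state_schema = StructType([StructField("count", IntegerType(), True)])
        self.handle = handle
        self.count_state = handle.getValueState("count_state", state_schema)

    def handleInitialState(self, key, initialState, timer_values) -> None:
        # timer_values is newly available here, so a timer can be registered even for
        # keys that only appear in the initial state ("count" is an assumed column).
        self.count_state.update((int(initialState.at[0, "count"]),))
        self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() + 1000)

    def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
        # expired_timer_info is no longer a parameter; this method only sees data rows.
        count = int(self.count_state.get()[0]) if self.count_state.exists() else 0
        count += sum(len(pdf) for pdf in rows)
        self.count_state.update((count,))
        self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() + 1000)
        yield pd.DataFrame({"id": [key[0]], "count": [str(count)]})

    def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]:
        # Expired timers now arrive through this dedicated callback.
        self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms())
        self.count_state.clear()
        yield pd.DataFrame({"id": [f"{key[0]}-expired"], "count": ["-1"]})

    def close(self) -> None:
        pass
```

The processor would be wired up the same way as before, e.g. `df.groupBy("id").transformWithStateInPandas(statefulProcessor=CountWithTimeoutProcessor(), outputStructType=output_schema, outputMode="Update", timeMode="processingTime", initialState=initial_state)`, where `output_schema` and `initial_state` are assumed to match the columns above.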
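The control-flow change that threads through `serializers.py`, `group_ops.py`, and `worker.py` can be summarized in a small standalone sketch (simplified; the real implementation is in the hunks above). Instead of a `key is None` sentinel, each partition now emits mode-tagged tuples: one `PROCESS_DATA` entry per grouping key, then a single `PROCESS_TIMER` entry for expired timers, then a single `COMPLETE` entry for cleanup.

```
from enum import Enum
from itertools import groupby


class TransformWithStateInPandasFuncMode(Enum):
    PROCESS_DATA = 1
    PROCESS_TIMER = 2
    COMPLETE = 3


def mode_tagged_batches(data_batches):
    # Mirrors the serializer change: tag every grouped (key, data) pair, then append
    # the timer and completion markers once all input for the partition is consumed.
    for key, group in groupby(data_batches, key=lambda x: x[0]):
        yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, key, group)
    yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
    yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)


def dispatch(mode, key, rows, handle_data, handle_timers, cleanup):
    # Mirrors transformWithStateUDF: data, expired timers, and cleanup are driven by
    # the mode marker rather than by a sentinel key.
    if mode is TransformWithStateInPandasFuncMode.PROCESS_DATA:
        return handle_data(key, rows)
    elif mode is TransformWithStateInPandasFuncMode.PROCESS_TIMER:
        return handle_timers()
    else:  # TransformWithStateInPandasFuncMode.COMPLETE
        cleanup()
        return iter([])
```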