From 49fb1cb2cd3fbe05531005a2fc5f4a12b7eb072d Mon Sep 17 00:00:00 2001 From: jingz-db Date: Tue, 12 Nov 2024 12:11:34 -0800 Subject: [PATCH 01/22] new timer API --- python/pyspark/sql/pandas/group_ops.py | 8 +- .../sql/streaming/stateful_processor.py | 31 ++-- .../test_pandas_transform_with_state.py | 146 ++++++++---------- 3 files changed, 86 insertions(+), 99 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 56efe0676c08f..b2d287b1d9cb9 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -516,10 +516,7 @@ def handle_data_with_timers( watermark_timestamp = -1 # process with invalid expiry timer info and emit data rows data_iter = statefulProcessor.handleInputRows( - key, - inputRows, - TimerValues(batch_timestamp, watermark_timestamp), - ExpiredTimerInfo(False), + key, inputRows, TimerValues(batch_timestamp, watermark_timestamp) ) statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.DATA_PROCESSED) @@ -540,9 +537,8 @@ def handle_data_with_timers( for expiry_list in expiry_list_iter: for key_obj, expiry_timestamp in expiry_list: result_iter_list.append( - statefulProcessor.handleInputRows( + statefulProcessor.handleExpiredTimer( key_obj, - iter([]), TimerValues(batch_timestamp, watermark_timestamp), ExpiredTimerInfo(True, expiry_timestamp), ) diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py index 20078c215bace..562aaed0269f3 100644 --- a/python/pyspark/sql/streaming/stateful_processor.py +++ b/python/pyspark/sql/streaming/stateful_processor.py @@ -105,8 +105,7 @@ def get_current_watermark_in_ms(self) -> int: class ExpiredTimerInfo: """ - Class used for arbitrary stateful operations with transformWithState to access expired timer - info. When is_valid is false, the expiry timestamp is invalid. + Class used to provide access to expired timer's expiry time. .. versionadded:: 4.0.0 """ @@ -114,12 +113,6 @@ def __init__(self, is_valid: bool, expiry_time_in_ms: int = -1) -> None: self._is_valid = is_valid self._expiry_time_in_ms = expiry_time_in_ms - def is_valid(self) -> bool: - """ - Whether the expiry info is valid. - """ - return self._is_valid - def get_expiry_time_in_ms(self) -> int: """ Get the timestamp for expired timer, return timestamp in millisecond. @@ -398,7 +391,6 @@ def handleInputRows( key: Any, rows: Iterator["PandasDataFrameLike"], timer_values: TimerValues, - expired_timer_info: ExpiredTimerInfo, ) -> Iterator["PandasDataFrameLike"]: """ Function that will allow users to interact with input data rows along with the grouping key. @@ -420,10 +412,27 @@ def handleInputRows( timer_values: TimerValues Timer value for the current batch that process the input rows. Users can get the processing or event time timestamp from TimerValues. + """ + return iter([]) + + def handleExpiredTimer( + self, key: Any, timer_values: TimerValues, expired_timer_info: ExpiredTimerInfo + ) -> Iterator["PandasDataFrameLike"]: + """ + Function that will be invoked when a timer is fired for a given key. Users can choose to + evict state, register new timers and optionally provide output rows. + + Parameters + ---------- + key : Any + grouping key. + timer_values: TimerValues + Timer value for the current batch that process the input rows. + Users can get the processing or event time timestamp from TimerValues. expired_timer_info: ExpiredTimerInfo - Timestamp of expired timers on the grouping key. 
+ Instance of ExpiredTimerInfo that provides access to expired timer. """ - ... + return iter([]) @abstractmethod def close(self) -> None: diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 8901f09e9272d..be8116fff990d 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -710,9 +710,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: state_schema = StructType([StructField("value", IntegerType(), True)]) self.value_state = handle.getValueState("value_state", state_schema) - def handleInputRows( - self, key, rows, timer_values, expired_timer_info - ) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: exists = self.value_state.exists() if exists: value_row = self.value_state.get() @@ -753,33 +751,30 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.handle = handle self.max_state = handle.getValueState("max_state", state_schema) - def handleInputRows( - self, key, rows, timer_values, expired_timer_info - ) -> Iterator[pd.DataFrame]: - if expired_timer_info.is_valid(): - self.max_state.clear() - self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) - str_key = f"{str(key[0])}-expired" - yield pd.DataFrame( - {"id": (str_key,), "timestamp": str(expired_timer_info.get_expiry_time_in_ms())} - ) + def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]: + self.max_state.clear() + self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) + str_key = f"{str(key[0])}-expired" + yield pd.DataFrame( + {"id": (str_key,), "timestamp": str(expired_timer_info.get_expiry_time_in_ms())} + ) - else: - timestamp_list = [] - for pdf in rows: - # int64 will represent timestamp in nanosecond, restore to second - timestamp_list.extend((pdf["eventTime"].astype("int64") // 10**9).tolist()) + def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + timestamp_list = [] + for pdf in rows: + # int64 will represent timestamp in nanosecond, restore to second + timestamp_list.extend((pdf["eventTime"].astype("int64") // 10**9).tolist()) - if self.max_state.exists(): - cur_max = int(self.max_state.get()[0]) - else: - cur_max = 0 - max_event_time = str(max(cur_max, max(timestamp_list))) + if self.max_state.exists(): + cur_max = int(self.max_state.get()[0]) + else: + cur_max = 0 + max_event_time = str(max(cur_max, max(timestamp_list))) - self.max_state.update((max_event_time,)) - self.handle.registerTimer(timer_values.get_current_watermark_in_ms()) + self.max_state.update((max_event_time,)) + self.handle.registerTimer(timer_values.get_current_watermark_in_ms()) - yield pd.DataFrame({"id": key, "timestamp": max_event_time}) + yield pd.DataFrame({"id": key, "timestamp": max_event_time}) def close(self) -> None: pass @@ -793,54 +788,51 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.handle = handle self.count_state = handle.getValueState("count_state", state_schema) - def handleInputRows( - self, key, rows, timer_values, expired_timer_info - ) -> Iterator[pd.DataFrame]: - if expired_timer_info.is_valid(): - # reset count state each time the timer is expired - timer_list_1 = [e for e in self.handle.listTimers()] - timer_list_2 = [] - idx = 0 - for e in self.handle.listTimers(): - timer_list_2.append(e) - # check multiple iterator on the same 
grouping key works - assert timer_list_2[idx] == timer_list_1[idx] - idx += 1 - - if len(timer_list_1) > 0: - # before deleting the expiring timers, there are 2 timers - - # one timer we just registered, and one that is going to be deleted - assert len(timer_list_1) == 2 - self.count_state.clear() - self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) - yield pd.DataFrame( - { - "id": key, - "countAsString": str("-1"), - "timeValues": str(expired_timer_info.get_expiry_time_in_ms()), - } - ) + def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]: + # reset count state each time the timer is expired + timer_list_1 = [e for e in self.handle.listTimers()] + timer_list_2 = [] + idx = 0 + for e in self.handle.listTimers(): + timer_list_2.append(e) + # check multiple iterator on the same grouping key works + assert timer_list_2[idx] == timer_list_1[idx] + idx += 1 + + if len(timer_list_1) > 0: + # before deleting the expiring timers, there are 2 timers - + # one timer we just registered, and one that is going to be deleted + assert len(timer_list_1) == 2 + self.count_state.clear() + self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) + yield pd.DataFrame( + { + "id": key, + "countAsString": str("-1"), + "timeValues": str(expired_timer_info.get_expiry_time_in_ms()), + } + ) + def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + if not self.count_state.exists(): + count = 0 else: - if not self.count_state.exists(): - count = 0 - else: - count = int(self.count_state.get()[0]) + count = int(self.count_state.get()[0]) - if key == ("0",): - self.handle.registerTimer(timer_values.get_current_processing_time_in_ms()) + if key == ("0",): + self.handle.registerTimer(timer_values.get_current_processing_time_in_ms()) - rows_count = 0 - for pdf in rows: - pdf_count = len(pdf) - rows_count += pdf_count + rows_count = 0 + for pdf in rows: + pdf_count = len(pdf) + rows_count += pdf_count - count = count + rows_count + count = count + rows_count - self.count_state.update((str(count),)) - timestamp = str(timer_values.get_current_processing_time_in_ms()) + self.count_state.update((str(count),)) + timestamp = str(timer_values.get_current_processing_time_in_ms()) - yield pd.DataFrame({"id": key, "countAsString": str(count), "timeValues": timestamp}) + yield pd.DataFrame({"id": key, "countAsString": str(count), "timeValues": timestamp}) def close(self) -> None: pass @@ -856,9 +848,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.temp_state = handle.getValueState("tempState", state_schema) handle.deleteIfExists("tempState") - def handleInputRows( - self, key, rows, timer_values, expired_timer_info - ) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: with self.assertRaisesRegex(PySparkRuntimeError, "Error checking value state exists"): self.temp_state.exists() new_violations = 0 @@ -907,9 +897,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: "ttl-map-state", user_key_schema, state_schema, 10000 ) - def handleInputRows( - self, key, rows, timer_values, expired_timer_info - ) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: count = 0 ttl_count = 0 ttl_list_state_count = 0 @@ -959,9 +947,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: state_schema = StructType([StructField("value", IntegerType(), True)]) self.num_violations_state = 
handle.getValueState("numViolations", state_schema) - def handleInputRows( - self, key, rows, timer_values, expired_timer_info - ) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: count = 0 exists = self.num_violations_state.exists() assert not exists @@ -985,9 +971,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.list_state1 = handle.getListState("listState1", state_schema) self.list_state2 = handle.getListState("listState2", state_schema) - def handleInputRows( - self, key, rows, timer_values, expired_timer_info - ) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: count = 0 for pdf in rows: list_state_rows = [(120,), (20,)] @@ -1042,9 +1026,7 @@ def init(self, handle: StatefulProcessorHandle): value_schema = StructType([StructField("count", IntegerType(), True)]) self.map_state = handle.getMapState("mapState", key_schema, value_schema) - def handleInputRows( - self, key, rows, timer_values, expired_timer_info - ) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: count = 0 key1 = ("key1",) key2 = ("key2",) From 585e268c58872ec1c8a3d15d2720575678b1c362 Mon Sep 17 00:00:00 2001 From: jingz-db Date: Tue, 12 Nov 2024 13:56:28 -0800 Subject: [PATCH 02/22] add timer integration for initial state --- python/pyspark/sql/pandas/group_ops.py | 32 +++++++++++++------ .../sql/streaming/proto/StateMessage_pb2.py | 4 --- .../sql/streaming/stateful_processor.py | 16 +++++++++- .../test_pandas_transform_with_state.py | 2 +- 4 files changed, 38 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index b2d287b1d9cb9..7c0ff2b8ae82b 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -506,14 +506,10 @@ def handle_data_with_timers( statefulProcessorApiClient: StatefulProcessorApiClient, key: Any, inputRows: Iterator["PandasDataFrameLike"], + batch_timestamp: int, + watermark_timestamp: int, ) -> Iterator["PandasDataFrameLike"]: statefulProcessorApiClient.set_implicit_key(key) - if timeMode != "none": - batch_timestamp = statefulProcessorApiClient.get_batch_timestamp() - watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp() - else: - batch_timestamp = -1 - watermark_timestamp = -1 # process with invalid expiry timer info and emit data rows data_iter = statefulProcessor.handleInputRows( key, inputRows, TimerValues(batch_timestamp, watermark_timestamp) @@ -569,7 +565,15 @@ def transformWithStateUDF( statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) return iter([]) - result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows) + if timeMode != "none": + batch_timestamp = statefulProcessorApiClient.get_batch_timestamp() + watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp() + else: + batch_timestamp = -1 + watermark_timestamp = -1 + + result = handle_data_with_timers( + statefulProcessorApiClient, key, inputRows, batch_timestamp, watermark_timestamp) return result def transformWithStateWithInitStateUDF( @@ -606,12 +610,19 @@ def transformWithStateWithInitStateUDF( statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) return iter([]) + if timeMode != "none": + batch_timestamp = statefulProcessorApiClient.get_batch_timestamp() + watermark_timestamp = 
statefulProcessorApiClient.get_watermark_timestamp() + else: + batch_timestamp = -1 + watermark_timestamp = -1 + # only process initial state if first batch and initial state is not None if initialStates is not None: for cur_initial_state in initialStates: statefulProcessorApiClient.set_implicit_key(key) - # TODO(SPARK-50194) integration with new timer API with initial state - statefulProcessor.handleInitialState(key, cur_initial_state) + statefulProcessor.handleInitialState( + key, cur_initial_state, TimerValues(batch_timestamp, watermark_timestamp)) # if we don't have input rows for the given key but only have initial state # for the grouping key, the inputRows iterator could be empty @@ -624,7 +635,8 @@ def transformWithStateWithInitStateUDF( inputRows = itertools.chain([first], inputRows) if not input_rows_empty: - result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows) + result = handle_data_with_timers( + statefulProcessorApiClient, key, inputRows, batch_timestamp, watermark_timestamp) else: result = iter([]) diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py index 46bed10c45588..589cdb8d371fb 100644 --- a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py +++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py @@ -22,13 +22,9 @@ """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, 5, 27, 3, "", "StateMessage.proto" -) # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py index 562aaed0269f3..6c3335e79cbbf 100644 --- a/python/pyspark/sql/streaming/stateful_processor.py +++ b/python/pyspark/sql/streaming/stateful_processor.py @@ -442,9 +442,23 @@ def close(self) -> None: """ ... - def handleInitialState(self, key: Any, initialState: "PandasDataFrameLike") -> None: + def handleInitialState( + self, + key: Any, + initialState: "PandasDataFrameLike", + timer_values: TimerValues) -> None: """ Optional to implement. Will act as no-op if not defined or no initial state input. Function that will be invoked only in the first batch for users to process initial states. + + Parameters + ---------- + key : Any + grouping key. + initialState: :class:`pandas.DataFrame` + One dataframe in the initial state associated with the key. + timer_values: TimerValues + Timer value for the current batch that process the input rows. + Users can get the processing or event time timestamp from TimerValues. 
""" pass diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index be8116fff990d..a6cbb6f59ecd0 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -733,7 +733,7 @@ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: else: yield pd.DataFrame({"id": key, "value": str(accumulated_value)}) - def handleInitialState(self, key, initialState) -> None: + def handleInitialState(self, key, initialState, timer_values) -> None: init_val = initialState.at[0, "initVal"] self.value_state.update((init_val,)) if len(key) == 1: From a0a53cf318f4cb6f8b37bb3c70ca895eb9c3796b Mon Sep 17 00:00:00 2001 From: jingz-db Date: Wed, 13 Nov 2024 14:05:06 -0800 Subject: [PATCH 03/22] weird bug --- python/pyspark/sql/pandas/group_ops.py | 43 +++++++++++++------ .../sql/streaming/stateful_processor.py | 3 +- .../test_pandas_transform_with_state.py | 42 +++++++++++++++++- python/pyspark/worker.py | 11 ----- ...ransformWithStateInPandasStateServer.scala | 4 +- 5 files changed, 73 insertions(+), 30 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 7c0ff2b8ae82b..3890ad72cd896 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -505,29 +505,32 @@ def transformWithStateInPandas( def handle_data_with_timers( statefulProcessorApiClient: StatefulProcessorApiClient, key: Any, - inputRows: Iterator["PandasDataFrameLike"], batch_timestamp: int, watermark_timestamp: int, + inputRows: Optional[Iterator["PandasDataFrameLike"]] = None, ) -> Iterator["PandasDataFrameLike"]: statefulProcessorApiClient.set_implicit_key(key) # process with invalid expiry timer info and emit data rows - data_iter = statefulProcessor.handleInputRows( - key, inputRows, TimerValues(batch_timestamp, watermark_timestamp) - ) - statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.DATA_PROCESSED) + if inputRows is not None: + data_iter = statefulProcessor.handleInputRows( + key, inputRows, TimerValues(batch_timestamp, watermark_timestamp) + ) + result_iter_list = [data_iter] + statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.DATA_PROCESSED) + else: + result_iter_list = [] - if timeMode == "processingtime": + if timeMode.lower() == "processingtime": expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( batch_timestamp ) - elif timeMode == "eventtime": + elif timeMode.lower() == "eventtime": expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( watermark_timestamp ) else: expiry_list_iter = iter([[]]) - result_iter_list = [data_iter] # process with valid expiry time info and with empty input rows, # only timer related rows will be emitted for expiry_list in expiry_list_iter: @@ -536,19 +539,24 @@ def handle_data_with_timers( statefulProcessor.handleExpiredTimer( key_obj, TimerValues(batch_timestamp, watermark_timestamp), - ExpiredTimerInfo(True, expiry_timestamp), + ExpiredTimerInfo(expiry_timestamp), ) ) # TODO(SPARK-49603) set the handle state in the lazily initialized iterator - result = itertools.chain(*result_iter_list) - return result + print(f"just before return result, key: {key}") + if len(result_iter_list) == 0: + return iter([]) + else: + result = itertools.chain(*result_iter_list) + return result def transformWithStateUDF( 
statefulProcessorApiClient: StatefulProcessorApiClient, key: Any, inputRows: Iterator["PandasDataFrameLike"], ) -> Iterator["PandasDataFrameLike"]: + print(f"I am inside tws with udf, key: {key}\n") handle = StatefulProcessorHandle(statefulProcessorApiClient) if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED: @@ -573,7 +581,7 @@ def transformWithStateUDF( watermark_timestamp = -1 result = handle_data_with_timers( - statefulProcessorApiClient, key, inputRows, batch_timestamp, watermark_timestamp) + statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows) return result def transformWithStateWithInitStateUDF( @@ -594,6 +602,7 @@ def transformWithStateWithInitStateUDF( - `initialStates` is None, while `inputRows` is not empty. This is not first batch. `initialStates` is initialized to the positional value as None. """ + print(f"I am inside tws with udf with init, key: {key}\n") handle = StatefulProcessorHandle(statefulProcessorApiClient) if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED: @@ -604,11 +613,14 @@ def transformWithStateWithInitStateUDF( # Key is None when we have processed all the input data from the worker and ready to # proceed with the cleanup steps. + """ if key is None: + print(f"I am inside tws, key is none") statefulProcessorApiClient.remove_implicit_key() statefulProcessor.close() statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) return iter([]) + """ if timeMode != "none": batch_timestamp = statefulProcessorApiClient.get_batch_timestamp() @@ -636,9 +648,12 @@ def transformWithStateWithInitStateUDF( if not input_rows_empty: result = handle_data_with_timers( - statefulProcessorApiClient, key, inputRows, batch_timestamp, watermark_timestamp) + statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows) else: - result = iter([]) + # if the input rows is empty, we still need to handle the expired timers registered + # in the initial state + result = handle_data_with_timers( + statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, None) return result diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py index 6c3335e79cbbf..7082f0506e21f 100644 --- a/python/pyspark/sql/streaming/stateful_processor.py +++ b/python/pyspark/sql/streaming/stateful_processor.py @@ -109,8 +109,7 @@ class ExpiredTimerInfo: .. 
versionadded:: 4.0.0 """ - def __init__(self, is_valid: bool, expiry_time_in_ms: int = -1) -> None: - self._is_valid = is_valid + def __init__(self, expiry_time_in_ms: int = -1) -> None: self._expiry_time_in_ms = expiry_time_in_ms def get_expiry_time_in_ms(self) -> int: diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index a6cbb6f59ecd0..fbd696e0a71fd 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -103,6 +103,7 @@ def build_test_df_with_3_cols(self, input_path): ) return df_final + """ def _test_transform_with_state_in_pandas_basic( self, stateful_processor, check_results, single_batch=False, timeMode="None" ): @@ -560,8 +561,10 @@ def check_results(batch_df, batch_id): self._test_transform_with_state_in_pandas_event_time( EventTimeStatefulProcessor(), check_results ) + """ - def _test_transform_with_state_init_state_in_pandas(self, stateful_processor, check_results): + def _test_transform_with_state_init_state_in_pandas( + self, stateful_processor, check_results, time_mode="None"): input_path = tempfile.mkdtemp() self._prepare_test_resource1(input_path) time.sleep(2) @@ -589,7 +592,7 @@ def _test_transform_with_state_init_state_in_pandas(self, stateful_processor, ch statefulProcessor=stateful_processor, outputStructType=output_schema, outputMode="Update", - timeMode="None", + timeMode=time_mode, initialState=initial_state, ) .writeStream.queryName("this_query") @@ -604,6 +607,7 @@ def _test_transform_with_state_init_state_in_pandas(self, stateful_processor, ch q.awaitTermination(10) self.assertTrue(q.exception() is None) + """ def test_transform_with_state_init_state_in_pandas(self): def check_results(batch_df, batch_id): if batch_id == 0: @@ -700,6 +704,25 @@ def check_results(batch_df, batch_id): self._test_transform_with_state_non_contiguous_grouping_cols( SimpleStatefulProcessorWithInitialState(), check_results, initial_state ) + """ + + def test_transform_with_state_init_state_with_timers(self): + def check_results(batch_df, batch_id): + if batch_id == 0: + assert set(batch_df.sort("id").collect()) == { + Row(id="0", value=str(789 + 123 + 46)), + Row(id="1", value=str(146 + 346)), + } + else: + raise Exception(f"i am in batch id {batch_id}, batchdf: ${batch_df.collect()}") + assert set(batch_df.sort("id").collect()) == { + Row(id="0", value=str(789 + 123 + 46 + 67)), + Row(id="3", value=str(987 + 12)), + } + + self._test_transform_with_state_init_state_in_pandas( + StatefulProcessorWithInitialStateTimers(), check_results, "processingTime" + ) class SimpleStatefulProcessorWithInitialState(StatefulProcessor): @@ -709,6 +732,7 @@ class SimpleStatefulProcessorWithInitialState(StatefulProcessor): def init(self, handle: StatefulProcessorHandle) -> None: state_schema = StructType([StructField("value", IntegerType(), True)]) self.value_state = handle.getValueState("value_state", state_schema) + self.handle = handle def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: exists = self.value_state.exists() @@ -743,6 +767,20 @@ def close(self) -> None: pass +class StatefulProcessorWithInitialStateTimers(SimpleStatefulProcessorWithInitialState): + + def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]: + self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) + str_key = f"{str(key[0])}-expired" + yield 
pd.DataFrame( + {"id": (str_key,), "timestamp": str(expired_timer_info.get_expiry_time_in_ms())} + ) + + def handleInitialState(self, key, initialState, timer_values) -> None: + super().handleInitialState(key, initialState, timer_values) + self.handle.registerTimer(timer_values.get_current_processing_time_in_ms()) + + # A stateful processor that output the max event time it has seen. Register timer for # current watermark. Clear max state if timer expires. class EventTimeStatefulProcessor(StatefulProcessor): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 04f95e9f52648..10418f0487c94 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -1958,17 +1958,6 @@ def process(): try: serializer.dump_stream(out_iter, outfile) finally: - # Sending a signal to TransformWithState UDF to perform proper cleanup steps. - if ( - eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF - or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF - ): - # Sending key as None to indicate that process() has finished. - end_iter = func(split_index, iter([(None, None)])) - # Need to materialize the iterator to trigger the cleanup steps, nothing needs - # to be done here. - for _ in end_iter: - pass if hasattr(out_iter, "close"): out_iter.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala index 0373c8607ff2c..343f8abc9f7f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala @@ -214,6 +214,7 @@ class TransformWithStateInPandasStateServer( // this implementation is safe val expiryRequest = message.getExpiryTimerRequest() val expiryTimestamp = expiryRequest.getExpiryTimestampMs + println(s"I am getting expired timer on server, timestamp: ${expiryTimestamp}") if (!expiryTimestampIter.isDefined) { expiryTimestampIter = Option(statefulProcessorHandle.getExpiredTimers(expiryTimestamp)) @@ -251,7 +252,7 @@ class TransformWithStateInPandasStateServer( sendResponse(0) case ImplicitGroupingKeyRequest.MethodCase.REMOVEIMPLICITKEY => ImplicitGroupingKeyTracker.removeImplicitKey() - // Reset the list/map state iterators for a new grouping key. + // Reset the list/map state iterators for a new grouping sey. 
iterators = new mutable.HashMap[String, Iterator[Row]]() listTimerIters = new mutable.HashMap[String, Iterator[Long]]() sendResponse(0) @@ -313,6 +314,7 @@ class TransformWithStateInPandasStateServer( case TimerStateCallCommand.MethodCase.REGISTER => val expiryTimestamp = message.getTimerStateCall.getRegister.getExpiryTimestampMs + println(s"I am inside timer register, timestamp: $expiryTimestamp") statefulProcessorHandle.registerTimer(expiryTimestamp) sendResponse(0) case TimerStateCallCommand.MethodCase.DELETE => From bad8f7ad01c092277d7c45835f830b241334ebba Mon Sep 17 00:00:00 2001 From: jingz-db Date: Thu, 14 Nov 2024 11:35:42 -0800 Subject: [PATCH 04/22] test case fixed --- python/pyspark/sql/pandas/group_ops.py | 36 ++++++++----------- .../sql/streaming/proto/StateMessage_pb2.py | 1 + .../sql/streaming/stateful_processor.py | 6 ++-- .../test_pandas_transform_with_state.py | 27 ++++++++------ ...ransformWithStateInPandasStateServer.scala | 4 +-- 5 files changed, 36 insertions(+), 38 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 3890ad72cd896..0bc6e7d3c1059 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -510,13 +510,15 @@ def handle_data_with_timers( inputRows: Optional[Iterator["PandasDataFrameLike"]] = None, ) -> Iterator["PandasDataFrameLike"]: statefulProcessorApiClient.set_implicit_key(key) - # process with invalid expiry timer info and emit data rows + # process with data rows if inputRows is not None: data_iter = statefulProcessor.handleInputRows( key, inputRows, TimerValues(batch_timestamp, watermark_timestamp) ) result_iter_list = [data_iter] - statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.DATA_PROCESSED) + statefulProcessorApiClient.set_handle_state( + StatefulProcessorHandleState.DATA_PROCESSED + ) else: result_iter_list = [] @@ -531,8 +533,7 @@ def handle_data_with_timers( else: expiry_list_iter = iter([[]]) - # process with valid expiry time info and with empty input rows, - # only timer related rows will be emitted + # process with expiry timers, only timer related rows will be emitted for expiry_list in expiry_list_iter: for key_obj, expiry_timestamp in expiry_list: result_iter_list.append( @@ -543,20 +544,14 @@ def handle_data_with_timers( ) ) # TODO(SPARK-49603) set the handle state in the lazily initialized iterator - - print(f"just before return result, key: {key}") - if len(result_iter_list) == 0: - return iter([]) - else: - result = itertools.chain(*result_iter_list) - return result + result = itertools.chain(*result_iter_list) + return result def transformWithStateUDF( statefulProcessorApiClient: StatefulProcessorApiClient, key: Any, inputRows: Iterator["PandasDataFrameLike"], ) -> Iterator["PandasDataFrameLike"]: - print(f"I am inside tws with udf, key: {key}\n") handle = StatefulProcessorHandle(statefulProcessorApiClient) if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED: @@ -581,7 +576,8 @@ def transformWithStateUDF( watermark_timestamp = -1 result = handle_data_with_timers( - statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows) + statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows + ) return result def transformWithStateWithInitStateUDF( @@ -602,7 +598,6 @@ def transformWithStateWithInitStateUDF( - `initialStates` is None, while `inputRows` is not empty. This is not first batch. 
`initialStates` is initialized to the positional value as None. """ - print(f"I am inside tws with udf with init, key: {key}\n") handle = StatefulProcessorHandle(statefulProcessorApiClient) if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED: @@ -613,14 +608,11 @@ def transformWithStateWithInitStateUDF( # Key is None when we have processed all the input data from the worker and ready to # proceed with the cleanup steps. - """ if key is None: - print(f"I am inside tws, key is none") statefulProcessorApiClient.remove_implicit_key() statefulProcessor.close() statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) return iter([]) - """ if timeMode != "none": batch_timestamp = statefulProcessorApiClient.get_batch_timestamp() @@ -634,7 +626,8 @@ def transformWithStateWithInitStateUDF( for cur_initial_state in initialStates: statefulProcessorApiClient.set_implicit_key(key) statefulProcessor.handleInitialState( - key, cur_initial_state, TimerValues(batch_timestamp, watermark_timestamp)) + key, cur_initial_state, TimerValues(batch_timestamp, watermark_timestamp) + ) # if we don't have input rows for the given key but only have initial state # for the grouping key, the inputRows iterator could be empty @@ -648,13 +641,14 @@ def transformWithStateWithInitStateUDF( if not input_rows_empty: result = handle_data_with_timers( - statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows) + statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows + ) else: # if the input rows is empty, we still need to handle the expired timers registered # in the initial state result = handle_data_with_timers( - statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, None) - + statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, None + ) return result if isinstance(outputStructType, str): diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py index 13d1efdc593dc..0a54690513a39 100644 --- a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py +++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py @@ -22,6 +22,7 @@ """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py index 7082f0506e21f..6210ea142d5a4 100644 --- a/python/pyspark/sql/streaming/stateful_processor.py +++ b/python/pyspark/sql/streaming/stateful_processor.py @@ -442,10 +442,8 @@ def close(self) -> None: ... def handleInitialState( - self, - key: Any, - initialState: "PandasDataFrameLike", - timer_values: TimerValues) -> None: + self, key: Any, initialState: "PandasDataFrameLike", timer_values: TimerValues + ) -> None: """ Optional to implement. Will act as no-op if not defined or no initial state input. Function that will be invoked only in the first batch for users to process initial states. 
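The hunk above threads TimerValues into handleInitialState, so a timer can now be registered while the first batch's initial state is processed. A minimal sketch of a processor written against the new signature follows; the class name, state name, and the one-hour timer offset are illustrative assumptions, while the API calls (getValueState, registerTimer, get_current_processing_time_in_ms, the "initVal" column) come from this series.

    import pandas as pd

    from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
    from pyspark.sql.types import StringType, StructField, StructType


    class InitialStateTimerProcessor(StatefulProcessor):
        """Illustrative sketch of the new handleInitialState(key, initialState, timer_values)."""

        def init(self, handle: StatefulProcessorHandle) -> None:
            # "value_state" and its single-column schema are illustrative choices.
            state_schema = StructType([StructField("value", StringType(), True)])
            self.handle = handle
            self.value_state = handle.getValueState("value_state", state_schema)

        def handleInitialState(self, key, initialState, timer_values) -> None:
            # Seed per-key state from the initial-state dataframe; "initVal" matches
            # the column used by the tests in this series.
            self.value_state.update((str(initialState.at[0, "initVal"]),))
            # timer_values is the new parameter: a processing-time timer can be
            # registered while handling initial state (one-hour offset is arbitrary).
            self.handle.registerTimer(
                timer_values.get_current_processing_time_in_ms() + 60 * 60 * 1000
            )

        def handleInputRows(self, key, rows, timer_values):
            for pdf in rows:
                yield pd.DataFrame({"id": key, "countAsString": str(len(pdf))})

        def close(self) -> None:
            pass
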
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index fbd696e0a71fd..4a6d639a9c9c0 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -103,7 +103,6 @@ def build_test_df_with_3_cols(self, input_path): ) return df_final - """ def _test_transform_with_state_in_pandas_basic( self, stateful_processor, check_results, single_batch=False, timeMode="None" ): @@ -561,10 +560,10 @@ def check_results(batch_df, batch_id): self._test_transform_with_state_in_pandas_event_time( EventTimeStatefulProcessor(), check_results ) - """ def _test_transform_with_state_init_state_in_pandas( - self, stateful_processor, check_results, time_mode="None"): + self, stateful_processor, check_results, time_mode="None" + ): input_path = tempfile.mkdtemp() self._prepare_test_resource1(input_path) time.sleep(2) @@ -607,7 +606,6 @@ def _test_transform_with_state_init_state_in_pandas( q.awaitTermination(10) self.assertTrue(q.exception() is None) - """ def test_transform_with_state_init_state_in_pandas(self): def check_results(batch_df, batch_id): if batch_id == 0: @@ -704,17 +702,27 @@ def check_results(batch_df, batch_id): self._test_transform_with_state_non_contiguous_grouping_cols( SimpleStatefulProcessorWithInitialState(), check_results, initial_state ) - """ def test_transform_with_state_init_state_with_timers(self): def check_results(batch_df, batch_id): if batch_id == 0: - assert set(batch_df.sort("id").collect()) == { + # timers are registered and handled in the first batch for + # rows in initial state; For key=0 and key=3 which contains + # expired timers, both should be handled by handleExpiredTimers + # regardless of whether key exists in the data rows or not + expired_df = batch_df.filter(batch_df["id"].contains("expired")) + data_df = batch_df.filter(~batch_df["id"].contains("expired")) + assert set(expired_df.sort("id").select("id").collect()) == { + Row(id="0-expired"), + Row(id="3-expired"), + } + assert set(data_df.sort("id").collect()) == { Row(id="0", value=str(789 + 123 + 46)), Row(id="1", value=str(146 + 346)), } else: - raise Exception(f"i am in batch id {batch_id}, batchdf: ${batch_df.collect()}") + # handleInitialState is only processed in the first batch, + # no more timer is registered so no more expired timers assert set(batch_df.sort("id").collect()) == { Row(id="0", value=str(789 + 123 + 46 + 67)), Row(id="3", value=str(987 + 12)), @@ -768,17 +776,16 @@ def close(self) -> None: class StatefulProcessorWithInitialStateTimers(SimpleStatefulProcessorWithInitialState): - def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]: self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) str_key = f"{str(key[0])}-expired" yield pd.DataFrame( - {"id": (str_key,), "timestamp": str(expired_timer_info.get_expiry_time_in_ms())} + {"id": (str_key,), "value": str(expired_timer_info.get_expiry_time_in_ms())} ) def handleInitialState(self, key, initialState, timer_values) -> None: super().handleInitialState(key, initialState, timer_values) - self.handle.registerTimer(timer_values.get_current_processing_time_in_ms()) + self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() - 1) # A stateful processor that output the max event time it has seen. 
Register timer for diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala index 343f8abc9f7f6..0373c8607ff2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala @@ -214,7 +214,6 @@ class TransformWithStateInPandasStateServer( // this implementation is safe val expiryRequest = message.getExpiryTimerRequest() val expiryTimestamp = expiryRequest.getExpiryTimestampMs - println(s"I am getting expired timer on server, timestamp: ${expiryTimestamp}") if (!expiryTimestampIter.isDefined) { expiryTimestampIter = Option(statefulProcessorHandle.getExpiredTimers(expiryTimestamp)) @@ -252,7 +251,7 @@ class TransformWithStateInPandasStateServer( sendResponse(0) case ImplicitGroupingKeyRequest.MethodCase.REMOVEIMPLICITKEY => ImplicitGroupingKeyTracker.removeImplicitKey() - // Reset the list/map state iterators for a new grouping sey. + // Reset the list/map state iterators for a new grouping key. iterators = new mutable.HashMap[String, Iterator[Row]]() listTimerIters = new mutable.HashMap[String, Iterator[Long]]() sendResponse(0) @@ -314,7 +313,6 @@ class TransformWithStateInPandasStateServer( case TimerStateCallCommand.MethodCase.REGISTER => val expiryTimestamp = message.getTimerStateCall.getRegister.getExpiryTimestampMs - println(s"I am inside timer register, timestamp: $expiryTimestamp") statefulProcessorHandle.registerTimer(expiryTimestamp) sendResponse(0) case TimerStateCallCommand.MethodCase.DELETE => From b7e6f59195f7e55d7bdecee6ac0e8adb9da3a4d7 Mon Sep 17 00:00:00 2001 From: jingz-db Date: Thu, 14 Nov 2024 11:38:30 -0800 Subject: [PATCH 05/22] restore worker file --- python/pyspark/worker.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 10418f0487c94..04f95e9f52648 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -1958,6 +1958,17 @@ def process(): try: serializer.dump_stream(out_iter, outfile) finally: + # Sending a signal to TransformWithState UDF to perform proper cleanup steps. + if ( + eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF + or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF + ): + # Sending key as None to indicate that process() has finished. + end_iter = func(split_index, iter([(None, None)])) + # Need to materialize the iterator to trigger the cleanup steps, nothing needs + # to be done here. 
+ for _ in end_iter: + pass if hasattr(out_iter, "close"): out_iter.close() From 0c5ab3f49bf7dfaec694798115d04829a7e96b41 Mon Sep 17 00:00:00 2001 From: jingz-db Date: Mon, 18 Nov 2024 11:30:47 -0800 Subject: [PATCH 06/22] resolve comments --- python/pyspark/sql/pandas/group_ops.py | 27 ++++++++++--------- .../sql/streaming/stateful_processor.py | 3 ++- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 0bc6e7d3c1059..ade0de270ae5a 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -16,7 +16,7 @@ # import itertools import sys -from typing import Any, Iterator, List, Optional, Union, TYPE_CHECKING, cast +from typing import Any, Iterator, List, Optional, Union, Tuple, TYPE_CHECKING, cast import warnings from pyspark.errors import PySparkTypeError @@ -502,6 +502,17 @@ def transformWithStateInPandas( if isinstance(outputStructType, str): outputStructType = cast(StructType, _parse_datatype_string(outputStructType)) + def get_timestamps( + statefulProcessorApiClient: StatefulProcessorApiClient, + ) -> Tuple[int, int]: + if timeMode != "none": + batch_timestamp = statefulProcessorApiClient.get_batch_timestamp() + watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp() + else: + batch_timestamp = -1 + watermark_timestamp = -1 + return batch_timestamp, watermark_timestamp + def handle_data_with_timers( statefulProcessorApiClient: StatefulProcessorApiClient, key: Any, @@ -568,12 +579,7 @@ def transformWithStateUDF( statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) return iter([]) - if timeMode != "none": - batch_timestamp = statefulProcessorApiClient.get_batch_timestamp() - watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp() - else: - batch_timestamp = -1 - watermark_timestamp = -1 + batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient) result = handle_data_with_timers( statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows @@ -614,12 +620,7 @@ def transformWithStateWithInitStateUDF( statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) return iter([]) - if timeMode != "none": - batch_timestamp = statefulProcessorApiClient.get_batch_timestamp() - watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp() - else: - batch_timestamp = -1 - watermark_timestamp = -1 + batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient) # only process initial state if first batch and initial state is not None if initialStates is not None: diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py index 6210ea142d5a4..9caa9304d6a87 100644 --- a/python/pyspark/sql/streaming/stateful_processor.py +++ b/python/pyspark/sql/streaming/stateful_processor.py @@ -412,12 +412,13 @@ def handleInputRows( Timer value for the current batch that process the input rows. Users can get the processing or event time timestamp from TimerValues. """ - return iter([]) + ... def handleExpiredTimer( self, key: Any, timer_values: TimerValues, expired_timer_info: ExpiredTimerInfo ) -> Iterator["PandasDataFrameLike"]: """ + Optional to implement. Will act return an empty iterator if not defined. Function that will be invoked when a timer is fired for a given key. Users can choose to evict state, register new timers and optionally provide output rows. 
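Per the docstring above, handleExpiredTimer is an optional hook that defaults to an empty iterator, while handleInputRows now sees only data rows. A minimal sketch of a processor implementing both hooks under this split is below; the class name, state schema, and 10-second timer interval are illustrative assumptions, and the calls (registerTimer, get_expiry_time_in_ms, value-state operations) follow the API shown in this series.

    import pandas as pd

    from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
    from pyspark.sql.types import StringType, StructField, StructType


    class CountWithTimeoutProcessor(StatefulProcessor):
        """Illustrative sketch using the separated handleInputRows / handleExpiredTimer hooks."""

        def init(self, handle: StatefulProcessorHandle) -> None:
            state_schema = StructType([StructField("count", StringType(), True)])
            self.handle = handle
            self.count_state = handle.getValueState("count_state", state_schema)

        def handleInputRows(self, key, rows, timer_values):
            # Data rows only; expired timers no longer arrive through this method.
            count = int(self.count_state.get()[0]) if self.count_state.exists() else 0
            count += sum(len(pdf) for pdf in rows)
            self.count_state.update((str(count),))
            # Re-arm a processing-time timer 10 seconds out (interval is arbitrary).
            self.handle.registerTimer(
                timer_values.get_current_processing_time_in_ms() + 10 * 1000
            )
            yield pd.DataFrame({"id": key, "countAsString": str(count)})

        def handleExpiredTimer(self, key, timer_values, expired_timer_info):
            # Invoked once per fired timer for the key; evict state and emit a marker row.
            self.count_state.clear()
            yield pd.DataFrame(
                {"id": key, "countAsString": str(expired_timer_info.get_expiry_time_in_ms())}
            )

        def close(self) -> None:
            pass
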
From 3d87b0ee0cc60f50901a003818534e1c87d0f0e6 Mon Sep 17 00:00:00 2001 From: jingz-db Date: Mon, 25 Nov 2024 13:14:22 -0800 Subject: [PATCH 07/22] per partition, need to fix set implicit key --- python/pyspark/sql/pandas/group_ops.py | 68 ++++++++++--------- python/pyspark/sql/pandas/serializers.py | 60 +++++++++++++++- .../sql/streaming/proto/StateMessage_pb2.py | 9 --- .../test_pandas_transform_with_state.py | 13 +++- python/pyspark/worker.py | 1 + ...ransformWithStateInPandasStateServer.scala | 6 ++ .../StatefulProcessorHandleImplBase.scala | 2 + .../execution/streaming/TimerStateImpl.scala | 4 ++ 8 files changed, 120 insertions(+), 43 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index ade0de270ae5a..5dbaa79f14b46 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -513,26 +513,11 @@ def get_timestamps( watermark_timestamp = -1 return batch_timestamp, watermark_timestamp - def handle_data_with_timers( - statefulProcessorApiClient: StatefulProcessorApiClient, - key: Any, - batch_timestamp: int, - watermark_timestamp: int, - inputRows: Optional[Iterator["PandasDataFrameLike"]] = None, - ) -> Iterator["PandasDataFrameLike"]: - statefulProcessorApiClient.set_implicit_key(key) - # process with data rows - if inputRows is not None: - data_iter = statefulProcessor.handleInputRows( - key, inputRows, TimerValues(batch_timestamp, watermark_timestamp) - ) - result_iter_list = [data_iter] - statefulProcessorApiClient.set_handle_state( - StatefulProcessorHandleState.DATA_PROCESSED - ) - else: - result_iter_list = [] - + def process_timers( + statefulProcessorApiClient: StatefulProcessorApiClient, + batch_timestamp: int, + watermark_timestamp: int,) -> Iterator["PandasDataFrameLike"]: + result_iter_list = [] if timeMode.lower() == "processingtime": expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( batch_timestamp @@ -547,6 +532,8 @@ def handle_data_with_timers( # process with expiry timers, only timer related rows will be emitted for expiry_list in expiry_list_iter: for key_obj, expiry_timestamp in expiry_list: + print(f"I am setting implict key as {key_obj}\n") + statefulProcessorApiClient.set_implicit_key(key_obj) result_iter_list.append( statefulProcessor.handleExpiredTimer( key_obj, @@ -556,7 +543,29 @@ def handle_data_with_timers( ) # TODO(SPARK-49603) set the handle state in the lazily initialized iterator result = itertools.chain(*result_iter_list) - return result + print(f"returning from process timers, tuple\n") + return (result, statefulProcessorApiClient) + + def handle_data_rows( + statefulProcessorApiClient: StatefulProcessorApiClient, + key: Any, + batch_timestamp: int, + watermark_timestamp: int, + inputRows: Optional[Iterator["PandasDataFrameLike"]] = None, + ) -> Tuple[Iterator["PandasDataFrameLike"], StatefulProcessorApiClient]: + statefulProcessorApiClient.set_implicit_key(key) + print(f"I am inside handle data rows, key is {key}\n") + # process with data rows + if inputRows is not None: + data_iter = statefulProcessor.handleInputRows( + key, inputRows, TimerValues(batch_timestamp, watermark_timestamp) + ) + statefulProcessorApiClient.set_handle_state( + StatefulProcessorHandleState.DATA_PROCESSED + ) + return (data_iter, statefulProcessorApiClient, statefulProcessor) + else: + return (iter([]), statefulProcessorApiClient, statefulProcessor) def transformWithStateUDF( statefulProcessorApiClient: StatefulProcessorApiClient, @@ -574,6 +583,7 @@ def 
transformWithStateUDF( # Key is None when we have processed all the input data from the worker and ready to # proceed with the cleanup steps. if key is None: + print(f"I am inside key is None\n") statefulProcessorApiClient.remove_implicit_key() statefulProcessor.close() statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) @@ -581,7 +591,7 @@ def transformWithStateUDF( batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient) - result = handle_data_with_timers( + result = handle_data_rows( statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows ) return result @@ -614,14 +624,15 @@ def transformWithStateWithInitStateUDF( # Key is None when we have processed all the input data from the worker and ready to # proceed with the cleanup steps. + batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient) if key is None: + print(f"I am inside key is None, after processing timers, handle state: " + f"{statefulProcessorApiClient.handle_state}\n") statefulProcessorApiClient.remove_implicit_key() statefulProcessor.close() statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) return iter([]) - batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient) - # only process initial state if first batch and initial state is not None if initialStates is not None: for cur_initial_state in initialStates: @@ -641,16 +652,11 @@ def transformWithStateWithInitStateUDF( inputRows = itertools.chain([first], inputRows) if not input_rows_empty: - result = handle_data_with_timers( + return handle_data_rows( statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows ) else: - # if the input rows is empty, we still need to handle the expired timers registered - # in the initial state - result = handle_data_with_timers( - statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, None - ) - return result + return (iter([]), statefulProcessorApiClient, statefulProcessor) if isinstance(outputStructType, str): outputStructType = cast(StructType, _parse_datatype_string(outputStructType)) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 5bf07b87400fe..39859ee0bcf23 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1204,8 +1204,64 @@ def dump_stream(self, iterator, stream): Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow RecordBatches, and write batches to stream. 
""" - result = [(b, t) for x in iterator for y, t in x for b in y] - super().dump_stream(result, stream) + + from itertools import tee, chain + from pyspark.sql.streaming.stateful_processor_api_client import ( + StatefulProcessorHandleState, + ) + from pyspark.sql.streaming.stateful_processor import ( + ExpiredTimerInfo, + TimerValues, + ) + # Clone the original iterator + cloned_iterator, result_iterator = tee(iterator) + result = ([(pd, t) for x in cloned_iterator for y, t in x for pd in y[0]]) + args = [(y[1], y[2], t) for x in result_iterator for y, t in x] + + print(f"args type: {type(args[0])}") + timeMode = "processingTime" + statefulProcessorApiClient = args[0][0] + statefulProcessor = args[0][1] + outputType = args[0][2] + + if timeMode != "none": + batch_timestamp = statefulProcessorApiClient.get_batch_timestamp() + watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp() + else: + batch_timestamp = -1 + watermark_timestamp = -1 + + result_iter_list = [] + if timeMode.lower() == "processingtime": + expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( + batch_timestamp + ) + elif timeMode.lower() == "eventtime": + expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( + watermark_timestamp + ) + else: + expiry_list_iter = iter([[]]) + + # process with expiry timers, only timer related rows will be emitted + for expiry_list in expiry_list_iter: + for key_obj, expiry_timestamp in expiry_list: + print(f"I am setting implicit key for timer as {key_obj}\n") + statefulProcessorApiClient.set_implicit_key(key_obj) + result_iter_list.append( + statefulProcessor.handleExpiredTimer( + key_obj, + TimerValues(batch_timestamp, watermark_timestamp), + ExpiredTimerInfo(expiry_timestamp), + ) + ) + # TODO(SPARK-49603) set the handle state in the lazily initialized iterator + timer_result_list = ((df, outputType) for df in chain(*result_iter_list)) + print(f"planning to chain together, result ele type: {type(result[0])}, " + f"timer type: {type(next(timer_result_list))}") + final_result = chain(result, timer_result_list) + + super().dump_stream(final_result, stream) class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSerializer): diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py index 0a54690513a39..b57440380b3ae 100644 --- a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py +++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py @@ -22,18 +22,9 @@ """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 5, - 28, - 3, - "", - "org/apache/spark/sql/execution/streaming/StateMessage.proto", -) # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 4a6d639a9c9c0..3530511bf1d79 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -55,7 +55,7 @@ class TransformWithStateInPandasTestsMixin: 
@classmethod def conf(cls): cfg = SparkConf() - cfg.set("spark.sql.shuffle.partitions", "5") + cfg.set("spark.sql.shuffle.partitions", "1") cfg.set( "spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider", @@ -103,6 +103,7 @@ def build_test_df_with_3_cols(self, input_path): ) return df_final + """ def _test_transform_with_state_in_pandas_basic( self, stateful_processor, check_results, single_batch=False, timeMode="None" ): @@ -560,6 +561,7 @@ def check_results(batch_df, batch_id): self._test_transform_with_state_in_pandas_event_time( EventTimeStatefulProcessor(), check_results ) + """ def _test_transform_with_state_init_state_in_pandas( self, stateful_processor, check_results, time_mode="None" @@ -606,6 +608,7 @@ def _test_transform_with_state_init_state_in_pandas( q.awaitTermination(10) self.assertTrue(q.exception() is None) + """ def test_transform_with_state_init_state_in_pandas(self): def check_results(batch_df, batch_id): if batch_id == 0: @@ -702,6 +705,7 @@ def check_results(batch_df, batch_id): self._test_transform_with_state_non_contiguous_grouping_cols( SimpleStatefulProcessorWithInitialState(), check_results, initial_state ) + """ def test_transform_with_state_init_state_with_timers(self): def check_results(batch_df, batch_id): @@ -712,6 +716,7 @@ def check_results(batch_df, batch_id): # regardless of whether key exists in the data rows or not expired_df = batch_df.filter(batch_df["id"].contains("expired")) data_df = batch_df.filter(~batch_df["id"].contains("expired")) + print(f"batch id: {batch_id}, batch df: {batch_df.collect()}\n") assert set(expired_df.sort("id").select("id").collect()) == { Row(id="0-expired"), Row(id="3-expired"), @@ -777,8 +782,12 @@ def close(self) -> None: class StatefulProcessorWithInitialStateTimers(SimpleStatefulProcessorWithInitialState): def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]: + print(f"before delete Timers for key: {key}, " + f"timestamp: {expired_timer_info.get_expiry_time_in_ms()}\n") self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) str_key = f"{str(key[0])}-expired" + print(f"after delete Timers for key: {key}, " + f"return key: {str_key}\n") yield pd.DataFrame( {"id": (str_key,), "value": str(expired_timer_info.get_expiry_time_in_ms())} ) @@ -786,6 +795,8 @@ def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[ def handleInitialState(self, key, initialState, timer_values) -> None: super().handleInitialState(key, initialState, timer_values) self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() - 1) + print(f"after register Timers for key: {key}, " + f"timestamp: {timer_values.get_current_processing_time_in_ms() - 1}\n") # A stateful processor that output the max event time it has seen. Register timer for diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 04f95e9f52648..886f46f6a40ff 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -1956,6 +1956,7 @@ def process(): iterator = deserializer.load_stream(infile) out_iter = func(split_index, iterator) try: + print(f"Inside process, before dump stream for data rows\n") serializer.dump_stream(out_iter, outfile) finally: # Sending a signal to TransformWithState UDF to perform proper cleanup steps. 
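For readability, the timer-expiry pass that this commit experiments with inside TransformWithStateInPandasSerializer.dump_stream (see the serializers.py hunk above) reduces, once the debug prints and Arrow plumbing are stripped, to a loop like the sketch below; the standalone function name is illustrative, while the client and processor calls mirror the patch.

    import itertools

    from pyspark.sql.streaming.stateful_processor import ExpiredTimerInfo, TimerValues


    def expired_timer_outputs(api_client, processor, time_mode, batch_timestamp, watermark_timestamp):
        # Choose the expiry threshold: batch processing time vs. event-time watermark.
        if time_mode.lower() == "processingtime":
            expiry_list_iter = api_client.get_expiry_timers_iterator(batch_timestamp)
        elif time_mode.lower() == "eventtime":
            expiry_list_iter = api_client.get_expiry_timers_iterator(watermark_timestamp)
        else:
            expiry_list_iter = iter([[]])

        result_iter_list = []
        # Each element is a list of (key, expiry_timestamp) pairs fetched from the JVM.
        for expiry_list in expiry_list_iter:
            for key_obj, expiry_timestamp in expiry_list:
                # The expired key becomes the implicit key before the user hook runs.
                api_client.set_implicit_key(key_obj)
                result_iter_list.append(
                    processor.handleExpiredTimer(
                        key_obj,
                        TimerValues(batch_timestamp, watermark_timestamp),
                        ExpiredTimerInfo(expiry_timestamp),
                    )
                )
        # The caller chains these timer outputs after the per-key data-row outputs.
        return itertools.chain(*result_iter_list)
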
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala index 0373c8607ff2c..aed2afdaef6da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala @@ -214,15 +214,20 @@ class TransformWithStateInPandasStateServer( // this implementation is safe val expiryRequest = message.getExpiryTimerRequest() val expiryTimestamp = expiryRequest.getExpiryTimestampMs + println(s"JVM server, get expiry timers, expityTimestamp: $expiryTimestamp") if (!expiryTimestampIter.isDefined) { expiryTimestampIter = Option(statefulProcessorHandle.getExpiredTimers(expiryTimestamp)) } // expiryTimestampIter could be None in the TWSPandasServerSuite if (!expiryTimestampIter.isDefined || !expiryTimestampIter.get.hasNext) { + println(s"JVM server, get expiry timers, expityTimestamp: $expiryTimestamp, " + + s"send back response 1") // iterator is exhausted, signal the end of iterator on python client sendResponse(1) } else { + println(s"JVM server, get expiry timers, expityTimestamp: $expiryTimestamp, " + + s"send back response 0") sendResponse(0) val outputSchema = new StructType() .add("key", BinaryType) @@ -318,6 +323,7 @@ class TransformWithStateInPandasStateServer( case TimerStateCallCommand.MethodCase.DELETE => val expiryTimestamp = message.getTimerStateCall.getDelete.getExpiryTimestampMs + println(s"JVM server, delete timer, expityTimestamp: $expiryTimestamp") statefulProcessorHandle.deleteTimer(expiryTimestamp) sendResponse(0) case TimerStateCallCommand.MethodCase.LIST => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala index 64d87073ccf9f..aec0612c7f560 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala @@ -35,11 +35,13 @@ abstract class StatefulProcessorHandleImplBase( def verifyTimerOperations(operationType: String): Unit = { if (timeMode == NoTime) { + println(s"JVM stateful processor, delete timer, timeMode: $timeMode") throw StateStoreErrors.cannotPerformOperationWithInvalidTimeMode(operationType, timeMode.toString) } if (currState < INITIALIZED || currState >= TIMER_PROCESSED) { + println(s"JVM stateful processor, delete timer, currstate: $currState") throw StateStoreErrors.cannotPerformOperationWithInvalidHandleState(operationType, currState.toString) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala index d0fbaf6600609..3999c83387d91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala @@ -112,6 +112,8 @@ class TimerStateImpl( */ def registerTimer(expiryTimestampMs: Long): Unit = { val groupingKey = getGroupingKey(keyToTsCFName) + println(s"JVM timer state, register timer, expiryTimestamp: $expiryTimestampMs, " + + s"groupingKey: $groupingKey") if 
(exists(groupingKey, expiryTimestampMs)) { logWarning(log"Failed to register timer for key=${MDC(KEY, groupingKey)} and " + log"timestamp=${MDC(EXPIRY_TIMESTAMP, expiryTimestampMs)} ms since it already exists") @@ -129,6 +131,8 @@ class TimerStateImpl( */ def deleteTimer(expiryTimestampMs: Long): Unit = { val groupingKey = getGroupingKey(keyToTsCFName) + println(s"JVM timer state, delete timer, expiryTimestamp: $expiryTimestampMs, " + + s"groupingKey: $groupingKey") if (!exists(groupingKey, expiryTimestampMs)) { logWarning(log"Failed to delete timer for key=${MDC(KEY, groupingKey)} and " + From 96e7226ff14c4c58244833ba96767f1f85c2dc3e Mon Sep 17 00:00:00 2001 From: jingz-db Date: Mon, 25 Nov 2024 14:17:21 -0800 Subject: [PATCH 08/22] a fully working version --- python/pyspark/sql/pandas/group_ops.py | 46 ++----------------- python/pyspark/sql/pandas/serializers.py | 46 +++++++++++++------ .../test_pandas_transform_with_state.py | 7 +-- python/pyspark/worker.py | 1 - ...ransformWithStateInPandasStateServer.scala | 6 --- .../StatefulProcessorHandleImplBase.scala | 2 - .../execution/streaming/TimerStateImpl.scala | 4 -- 7 files changed, 36 insertions(+), 76 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 5dbaa79f14b46..cad8f800db549 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -513,39 +513,6 @@ def get_timestamps( watermark_timestamp = -1 return batch_timestamp, watermark_timestamp - def process_timers( - statefulProcessorApiClient: StatefulProcessorApiClient, - batch_timestamp: int, - watermark_timestamp: int,) -> Iterator["PandasDataFrameLike"]: - result_iter_list = [] - if timeMode.lower() == "processingtime": - expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( - batch_timestamp - ) - elif timeMode.lower() == "eventtime": - expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( - watermark_timestamp - ) - else: - expiry_list_iter = iter([[]]) - - # process with expiry timers, only timer related rows will be emitted - for expiry_list in expiry_list_iter: - for key_obj, expiry_timestamp in expiry_list: - print(f"I am setting implict key as {key_obj}\n") - statefulProcessorApiClient.set_implicit_key(key_obj) - result_iter_list.append( - statefulProcessor.handleExpiredTimer( - key_obj, - TimerValues(batch_timestamp, watermark_timestamp), - ExpiredTimerInfo(expiry_timestamp), - ) - ) - # TODO(SPARK-49603) set the handle state in the lazily initialized iterator - result = itertools.chain(*result_iter_list) - print(f"returning from process timers, tuple\n") - return (result, statefulProcessorApiClient) - def handle_data_rows( statefulProcessorApiClient: StatefulProcessorApiClient, key: Any, @@ -554,7 +521,6 @@ def handle_data_rows( inputRows: Optional[Iterator["PandasDataFrameLike"]] = None, ) -> Tuple[Iterator["PandasDataFrameLike"], StatefulProcessorApiClient]: statefulProcessorApiClient.set_implicit_key(key) - print(f"I am inside handle data rows, key is {key}\n") # process with data rows if inputRows is not None: data_iter = statefulProcessor.handleInputRows( @@ -563,9 +529,9 @@ def handle_data_rows( statefulProcessorApiClient.set_handle_state( StatefulProcessorHandleState.DATA_PROCESSED ) - return (data_iter, statefulProcessorApiClient, statefulProcessor) + return (data_iter, statefulProcessorApiClient, statefulProcessor, timeMode) else: - return (iter([]), statefulProcessorApiClient, statefulProcessor) + return (iter([]), 
statefulProcessorApiClient, statefulProcessor, timeMode) def transformWithStateUDF( statefulProcessorApiClient: StatefulProcessorApiClient, @@ -583,7 +549,6 @@ def transformWithStateUDF( # Key is None when we have processed all the input data from the worker and ready to # proceed with the cleanup steps. if key is None: - print(f"I am inside key is None\n") statefulProcessorApiClient.remove_implicit_key() statefulProcessor.close() statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) @@ -591,10 +556,9 @@ def transformWithStateUDF( batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient) - result = handle_data_rows( + return handle_data_rows( statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows ) - return result def transformWithStateWithInitStateUDF( statefulProcessorApiClient: StatefulProcessorApiClient, @@ -626,8 +590,6 @@ def transformWithStateWithInitStateUDF( # proceed with the cleanup steps. batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient) if key is None: - print(f"I am inside key is None, after processing timers, handle state: " - f"{statefulProcessorApiClient.handle_state}\n") statefulProcessorApiClient.remove_implicit_key() statefulProcessor.close() statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) @@ -656,7 +618,7 @@ def transformWithStateWithInitStateUDF( statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows ) else: - return (iter([]), statefulProcessorApiClient, statefulProcessor) + return (iter([]), statefulProcessorApiClient, statefulProcessor, timeMode) if isinstance(outputStructType, str): outputStructType = cast(StructType, _parse_datatype_string(outputStructType)) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 39859ee0bcf23..804b95b108cd5 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1213,16 +1213,21 @@ def dump_stream(self, iterator, stream): ExpiredTimerInfo, TimerValues, ) - # Clone the original iterator + # Clone the original iterator to get additional args cloned_iterator, result_iterator = tee(iterator) result = ([(pd, t) for x in cloned_iterator for y, t in x for pd in y[0]]) - args = [(y[1], y[2], t) for x in result_iterator for y, t in x] + args = [(y[1], y[2], t, y[3]) for x in result_iterator for y, t in x] - print(f"args type: {type(args[0])}") - timeMode = "processingTime" + # if num of keys is smaller than num of partitions, some partitions will have empty + # input rows; we do nothing for such partitions + if len(args) == 0: + return + + # all keys on the same partition share the same args statefulProcessorApiClient = args[0][0] statefulProcessor = args[0][1] outputType = args[0][2] + timeMode = args[0][3] if timeMode != "none": batch_timestamp = statefulProcessorApiClient.get_batch_timestamp() @@ -1243,25 +1248,36 @@ def dump_stream(self, iterator, stream): else: expiry_list_iter = iter([[]]) + def timer_iter_wrapper(func, *args, **kwargs): + def wrapper(): + timer_cur_key = kwargs.get("key", args[0] if len(args) > 0 else None) + # set implicit key for the timer row before calling UDF + statefulProcessorApiClient.set_implicit_key(timer_cur_key) + # Call handleExpiredTimer UDF + iter = func(*args, **kwargs) + try: + for e in iter: + yield e + finally: + statefulProcessorApiClient.remove_implicit_key() + + return wrapper() + # process with expiry timers, only timer related rows 
will be emitted for expiry_list in expiry_list_iter: for key_obj, expiry_timestamp in expiry_list: - print(f"I am setting implicit key for timer as {key_obj}\n") - statefulProcessorApiClient.set_implicit_key(key_obj) - result_iter_list.append( - statefulProcessor.handleExpiredTimer( - key_obj, - TimerValues(batch_timestamp, watermark_timestamp), - ExpiredTimerInfo(expiry_timestamp), - ) + result_iter_list.append(timer_iter_wrapper( + statefulProcessor.handleExpiredTimer, + key=key_obj, + timer_values=TimerValues(batch_timestamp, watermark_timestamp), + expired_timer_info=ExpiredTimerInfo(expiry_timestamp)) ) - # TODO(SPARK-49603) set the handle state in the lazily initialized iterator + timer_result_list = ((df, outputType) for df in chain(*result_iter_list)) - print(f"planning to chain together, result ele type: {type(result[0])}, " - f"timer type: {type(next(timer_result_list))}") final_result = chain(result, timer_result_list) super().dump_stream(final_result, stream) + statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.TIMER_PROCESSED) class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSerializer): diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 3530511bf1d79..2ce9b706580c0 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -55,7 +55,7 @@ class TransformWithStateInPandasTestsMixin: @classmethod def conf(cls): cfg = SparkConf() - cfg.set("spark.sql.shuffle.partitions", "1") + cfg.set("spark.sql.shuffle.partitions", "5") cfg.set( "spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider", @@ -103,7 +103,6 @@ def build_test_df_with_3_cols(self, input_path): ) return df_final - """ def _test_transform_with_state_in_pandas_basic( self, stateful_processor, check_results, single_batch=False, timeMode="None" ): @@ -561,7 +560,6 @@ def check_results(batch_df, batch_id): self._test_transform_with_state_in_pandas_event_time( EventTimeStatefulProcessor(), check_results ) - """ def _test_transform_with_state_init_state_in_pandas( self, stateful_processor, check_results, time_mode="None" @@ -608,7 +606,6 @@ def _test_transform_with_state_init_state_in_pandas( q.awaitTermination(10) self.assertTrue(q.exception() is None) - """ def test_transform_with_state_init_state_in_pandas(self): def check_results(batch_df, batch_id): if batch_id == 0: @@ -705,7 +702,6 @@ def check_results(batch_df, batch_id): self._test_transform_with_state_non_contiguous_grouping_cols( SimpleStatefulProcessorWithInitialState(), check_results, initial_state ) - """ def test_transform_with_state_init_state_with_timers(self): def check_results(batch_df, batch_id): @@ -716,7 +712,6 @@ def check_results(batch_df, batch_id): # regardless of whether key exists in the data rows or not expired_df = batch_df.filter(batch_df["id"].contains("expired")) data_df = batch_df.filter(~batch_df["id"].contains("expired")) - print(f"batch id: {batch_id}, batch df: {batch_df.collect()}\n") assert set(expired_df.sort("id").select("id").collect()) == { Row(id="0-expired"), Row(id="3-expired"), diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 886f46f6a40ff..04f95e9f52648 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -1956,7 +1956,6 @@ def process(): iterator = 
deserializer.load_stream(infile) out_iter = func(split_index, iterator) try: - print(f"Inside process, before dump stream for data rows\n") serializer.dump_stream(out_iter, outfile) finally: # Sending a signal to TransformWithState UDF to perform proper cleanup steps. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala index aed2afdaef6da..0373c8607ff2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala @@ -214,20 +214,15 @@ class TransformWithStateInPandasStateServer( // this implementation is safe val expiryRequest = message.getExpiryTimerRequest() val expiryTimestamp = expiryRequest.getExpiryTimestampMs - println(s"JVM server, get expiry timers, expityTimestamp: $expiryTimestamp") if (!expiryTimestampIter.isDefined) { expiryTimestampIter = Option(statefulProcessorHandle.getExpiredTimers(expiryTimestamp)) } // expiryTimestampIter could be None in the TWSPandasServerSuite if (!expiryTimestampIter.isDefined || !expiryTimestampIter.get.hasNext) { - println(s"JVM server, get expiry timers, expityTimestamp: $expiryTimestamp, " + - s"send back response 1") // iterator is exhausted, signal the end of iterator on python client sendResponse(1) } else { - println(s"JVM server, get expiry timers, expityTimestamp: $expiryTimestamp, " + - s"send back response 0") sendResponse(0) val outputSchema = new StructType() .add("key", BinaryType) @@ -323,7 +318,6 @@ class TransformWithStateInPandasStateServer( case TimerStateCallCommand.MethodCase.DELETE => val expiryTimestamp = message.getTimerStateCall.getDelete.getExpiryTimestampMs - println(s"JVM server, delete timer, expityTimestamp: $expiryTimestamp") statefulProcessorHandle.deleteTimer(expiryTimestamp) sendResponse(0) case TimerStateCallCommand.MethodCase.LIST => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala index aec0612c7f560..64d87073ccf9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala @@ -35,13 +35,11 @@ abstract class StatefulProcessorHandleImplBase( def verifyTimerOperations(operationType: String): Unit = { if (timeMode == NoTime) { - println(s"JVM stateful processor, delete timer, timeMode: $timeMode") throw StateStoreErrors.cannotPerformOperationWithInvalidTimeMode(operationType, timeMode.toString) } if (currState < INITIALIZED || currState >= TIMER_PROCESSED) { - println(s"JVM stateful processor, delete timer, currstate: $currState") throw StateStoreErrors.cannotPerformOperationWithInvalidHandleState(operationType, currState.toString) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala index 3999c83387d91..d0fbaf6600609 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala @@ 
-112,8 +112,6 @@ class TimerStateImpl( */ def registerTimer(expiryTimestampMs: Long): Unit = { val groupingKey = getGroupingKey(keyToTsCFName) - println(s"JVM timer state, register timer, expiryTimestamp: $expiryTimestampMs, " + - s"groupingKey: $groupingKey") if (exists(groupingKey, expiryTimestampMs)) { logWarning(log"Failed to register timer for key=${MDC(KEY, groupingKey)} and " + log"timestamp=${MDC(EXPIRY_TIMESTAMP, expiryTimestampMs)} ms since it already exists") @@ -131,8 +129,6 @@ class TimerStateImpl( */ def deleteTimer(expiryTimestampMs: Long): Unit = { val groupingKey = getGroupingKey(keyToTsCFName) - println(s"JVM timer state, delete timer, expiryTimestamp: $expiryTimestampMs, " + - s"groupingKey: $groupingKey") if (!exists(groupingKey, expiryTimestampMs)) { logWarning(log"Failed to delete timer for key=${MDC(KEY, groupingKey)} and " + From 4c272a5152dd889f89f0679847de3e787fd174a7 Mon Sep 17 00:00:00 2001 From: jingz-db Date: Mon, 25 Nov 2024 14:28:48 -0800 Subject: [PATCH 09/22] add some comments --- python/pyspark/sql/pandas/group_ops.py | 17 ++++------------- python/pyspark/sql/pandas/serializers.py | 19 +++++++++++-------- .../stateful_processor_api_client.py | 9 +++++++++ 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index cad8f800db549..aab56a8a37bc7 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -502,17 +502,6 @@ def transformWithStateInPandas( if isinstance(outputStructType, str): outputStructType = cast(StructType, _parse_datatype_string(outputStructType)) - def get_timestamps( - statefulProcessorApiClient: StatefulProcessorApiClient, - ) -> Tuple[int, int]: - if timeMode != "none": - batch_timestamp = statefulProcessorApiClient.get_batch_timestamp() - watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp() - else: - batch_timestamp = -1 - watermark_timestamp = -1 - return batch_timestamp, watermark_timestamp - def handle_data_rows( statefulProcessorApiClient: StatefulProcessorApiClient, key: Any, @@ -554,7 +543,8 @@ def transformWithStateUDF( statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) return iter([]) - batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient) + batch_timestamp, watermark_timestamp =\ + statefulProcessorApiClient.get_timestamps(timeMode) return handle_data_rows( statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows @@ -588,13 +578,14 @@ def transformWithStateWithInitStateUDF( # Key is None when we have processed all the input data from the worker and ready to # proceed with the cleanup steps. 
- batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient) if key is None: statefulProcessorApiClient.remove_implicit_key() statefulProcessor.close() statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) return iter([]) + batch_timestamp, watermark_timestamp = \ + statefulProcessorApiClient.get_timestamps(timeMode) # only process initial state if first batch and initial state is not None if initialStates is not None: for cur_initial_state in initialStates: diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 804b95b108cd5..3e6cbba305872 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1201,8 +1201,10 @@ def generate_data_batches(batches): def dump_stream(self, iterator, stream): """ - Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow - RecordBatches, and write batches to stream. + Read through chained return results from a single partition of handleInputRows. + For a single partition, after finish handling all input rows, we need to iterate + through all expired timers and handle them. We chain the results of handleInputRows + with handleExpiredTimer into a single iterator and dump the stream as arrow batches. """ from itertools import tee, chain @@ -1229,12 +1231,8 @@ def dump_stream(self, iterator, stream): outputType = args[0][2] timeMode = args[0][3] - if timeMode != "none": - batch_timestamp = statefulProcessorApiClient.get_batch_timestamp() - watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp() - else: - batch_timestamp = -1 - watermark_timestamp = -1 + batch_timestamp, watermark_timestamp = \ + statefulProcessorApiClient.get_timestamps(timeMode) result_iter_list = [] if timeMode.lower() == "processingtime": @@ -1249,6 +1247,11 @@ def dump_stream(self, iterator, stream): expiry_list_iter = iter([[]]) def timer_iter_wrapper(func, *args, **kwargs): + """ + Wrap the timer iterator returned from handleExpiredTimer with implicit key handling. + For a given key, need to properly set implicit key before calling handleExpiredTimer, + and remove the implicit key after consuming the iterator. + """ def wrapper(): timer_cur_key = kwargs.get("key", args[0] if len(args) > 0 else None) # set implicit key for the timer row before calling UDF diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py index 353f75e267962..3076fad3a50c0 100644 --- a/python/pyspark/sql/streaming/stateful_processor_api_client.py +++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py @@ -353,6 +353,15 @@ def delete_if_exists(self, state_name: str) -> None: # TODO(SPARK-49233): Classify user facing errors. raise PySparkRuntimeError(f"Error deleting state: " f"{response_message[1]}") + def get_timestamps(self, timeMode: str) -> Tuple[int, int]: + if timeMode != "none": + batch_timestamp = self.get_batch_timestamp() + watermark_timestamp = self.get_watermark_timestamp() + else: + batch_timestamp = -1 + watermark_timestamp = -1 + return batch_timestamp, watermark_timestamp + def _send_proto_message(self, message: bytes) -> None: # Writing zero here to indicate message version. This allows us to evolve the message # format or even changing the message protocol in the future. 
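The timer_iter_wrapper docstring added in the patch above spells out the contract for expired-timer handling: the implicit grouping key must be set before handleExpiredTimer runs, and removed only once its result iterator has been fully consumed. A minimal standalone sketch of that deferred set-up/tear-down pattern is shown below; set_key, remove_key and handle_expired are hypothetical stand-ins for illustration, not the real StatefulProcessorApiClient or StatefulProcessor APIs.

from typing import Any, Callable, Iterator


def with_implicit_key(
    set_key: Callable[[Any], None],
    remove_key: Callable[[], None],
    func: Callable[..., Iterator[Any]],
    key: Any,
    *args: Any,
) -> Iterator[Any]:
    # Returns a generator; nothing below executes until the caller starts iterating.
    def gen() -> Iterator[Any]:
        set_key(key)  # bracket the user callback with key setup
        try:
            for element in func(key, *args):
                yield element
        finally:
            remove_key()  # always undo, even if iteration stops early

    return gen()


if __name__ == "__main__":
    current_key = []

    def set_key(k):
        current_key.append(k)

    def remove_key():
        current_key.pop()

    def handle_expired(key, expiry_ts):
        # stand-in for handleExpiredTimer: observes the key that was set for it
        yield (key, expiry_ts, list(current_key))

    out = list(with_implicit_key(set_key, remove_key, handle_expired, "k1", 42))
    print(out)          # [('k1', 42, ['k1'])]
    print(current_key)  # [] -- the key was removed after the iterator was drained

Because the wrapped body only runs once the caller iterates, a chained-but-never-consumed iterator performs no key bookkeeping at all, which mirrors how the serializer lazily chains the per-key timer iterators before dump_stream drains them.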
From a69cb6fd5b18edb0f6cb150b168272f93ab756c0 Mon Sep 17 00:00:00 2001 From: jingz-db Date: Mon, 25 Nov 2024 15:06:45 -0800 Subject: [PATCH 10/22] get timestamps per batch --- python/pyspark/sql/pandas/group_ops.py | 17 +--- python/pyspark/sql/pandas/serializers.py | 3 +- .../stateful_processor_api_client.py | 94 +++++++++---------- ...ransformWithStateInPandasStateServer.scala | 1 + 4 files changed, 52 insertions(+), 63 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index aab56a8a37bc7..5b05fa9f28422 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -505,11 +505,10 @@ def transformWithStateInPandas( def handle_data_rows( statefulProcessorApiClient: StatefulProcessorApiClient, key: Any, - batch_timestamp: int, - watermark_timestamp: int, inputRows: Optional[Iterator["PandasDataFrameLike"]] = None, ) -> Tuple[Iterator["PandasDataFrameLike"], StatefulProcessorApiClient]: statefulProcessorApiClient.set_implicit_key(key) + batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps() # process with data rows if inputRows is not None: data_iter = statefulProcessor.handleInputRows( @@ -543,12 +542,7 @@ def transformWithStateUDF( statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) return iter([]) - batch_timestamp, watermark_timestamp =\ - statefulProcessorApiClient.get_timestamps(timeMode) - - return handle_data_rows( - statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows - ) + return handle_data_rows(statefulProcessorApiClient, key, inputRows) def transformWithStateWithInitStateUDF( statefulProcessorApiClient: StatefulProcessorApiClient, @@ -584,8 +578,7 @@ def transformWithStateWithInitStateUDF( statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) return iter([]) - batch_timestamp, watermark_timestamp = \ - statefulProcessorApiClient.get_timestamps(timeMode) + batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps() # only process initial state if first batch and initial state is not None if initialStates is not None: for cur_initial_state in initialStates: @@ -605,9 +598,7 @@ def transformWithStateWithInitStateUDF( inputRows = itertools.chain([first], inputRows) if not input_rows_empty: - return handle_data_rows( - statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows - ) + return handle_data_rows(statefulProcessorApiClient, key, inputRows) else: return (iter([]), statefulProcessorApiClient, statefulProcessor, timeMode) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 3e6cbba305872..d9218ea787ff9 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1231,8 +1231,7 @@ def dump_stream(self, iterator, stream): outputType = args[0][2] timeMode = args[0][3] - batch_timestamp, watermark_timestamp = \ - statefulProcessorApiClient.get_timestamps(timeMode) + batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps() result_iter_list = [] if timeMode.lower() == "processingtime": diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py index 3076fad3a50c0..dac02c6d90801 100644 --- a/python/pyspark/sql/streaming/stateful_processor_api_client.py +++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py @@ -62,6 +62,10 @@ def 
__init__(self, state_server_port: int, key_schema: StructType) -> None: # Dictionaries to store the mapping between iterator id and a tuple of pandas DataFrame # and the index of the last row that was read. self.list_timer_iterator_cursors: Dict[str, Tuple["PandasDataFrameLike", int]] = {} + # statefulProcessorApiClient is initialized per batch per partition, + # so we will have new timestamps for a new batch + self._batch_timestamp = self._get_batch_timestamp() + self._watermark_timestamp = self._get_watermark_timestamp() def set_handle_state(self, state: StatefulProcessorHandleState) -> None: import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage @@ -266,47 +270,8 @@ def get_expiry_timers_iterator( # TODO(SPARK-49233): Classify user facing errors. raise PySparkRuntimeError(f"Error getting expiry timers: " f"{response_message[1]}") - def get_batch_timestamp(self) -> int: - import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage - - get_processing_time_call = stateMessage.GetProcessingTime() - timer_value_call = stateMessage.TimerValueRequest( - getProcessingTimer=get_processing_time_call - ) - timer_request = stateMessage.TimerRequest(timerValueRequest=timer_value_call) - message = stateMessage.StateRequest(timerRequest=timer_request) - - self._send_proto_message(message.SerializeToString()) - response_message = self._receive_proto_message_with_long_value() - status = response_message[0] - if status != 0: - # TODO(SPARK-49233): Classify user facing errors. - raise PySparkRuntimeError( - f"Error getting processing timestamp: " f"{response_message[1]}" - ) - else: - timestamp = response_message[2] - return timestamp - - def get_watermark_timestamp(self) -> int: - import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage - - get_watermark_call = stateMessage.GetWatermark() - timer_value_call = stateMessage.TimerValueRequest(getWatermark=get_watermark_call) - timer_request = stateMessage.TimerRequest(timerValueRequest=timer_value_call) - message = stateMessage.StateRequest(timerRequest=timer_request) - - self._send_proto_message(message.SerializeToString()) - response_message = self._receive_proto_message_with_long_value() - status = response_message[0] - if status != 0: - # TODO(SPARK-49233): Classify user facing errors. - raise PySparkRuntimeError( - f"Error getting eventtime timestamp: " f"{response_message[1]}" - ) - else: - timestamp = response_message[2] - return timestamp + def get_timestamps(self) -> Tuple[int, int]: + return self._batch_timestamp, self._watermark_timestamp def get_map_state( self, @@ -353,14 +318,47 @@ def delete_if_exists(self, state_name: str) -> None: # TODO(SPARK-49233): Classify user facing errors. 
raise PySparkRuntimeError(f"Error deleting state: " f"{response_message[1]}") - def get_timestamps(self, timeMode: str) -> Tuple[int, int]: - if timeMode != "none": - batch_timestamp = self.get_batch_timestamp() - watermark_timestamp = self.get_watermark_timestamp() + def _get_batch_timestamp(self) -> int: + import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage + + get_processing_time_call = stateMessage.GetProcessingTime() + timer_value_call = stateMessage.TimerValueRequest( + getProcessingTimer=get_processing_time_call + ) + timer_request = stateMessage.TimerRequest(timerValueRequest=timer_value_call) + message = stateMessage.StateRequest(timerRequest=timer_request) + + self._send_proto_message(message.SerializeToString()) + response_message = self._receive_proto_message_with_long_value() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError( + f"Error getting processing timestamp: " f"{response_message[1]}" + ) + else: + timestamp = response_message[2] + return timestamp + + def _get_watermark_timestamp(self) -> int: + import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage + + get_watermark_call = stateMessage.GetWatermark() + timer_value_call = stateMessage.TimerValueRequest(getWatermark=get_watermark_call) + timer_request = stateMessage.TimerRequest(timerValueRequest=timer_value_call) + message = stateMessage.StateRequest(timerRequest=timer_request) + + self._send_proto_message(message.SerializeToString()) + response_message = self._receive_proto_message_with_long_value() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError( + f"Error getting eventtime timestamp: " f"{response_message[1]}" + ) else: - batch_timestamp = -1 - watermark_timestamp = -1 - return batch_timestamp, watermark_timestamp + timestamp = response_message[2] + return timestamp def _send_proto_message(self, message: bytes) -> None: # Writing zero here to indicate message version. 
This allows us to evolve the message diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala index 0373c8607ff2c..675d020a2fa5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala @@ -120,6 +120,7 @@ class TransformWithStateInPandasStateServer( } /** Timer related class variables */ + private var expiryTimestampIter: Option[Iterator[(Any, Long)]] = if (expiryTimerIterForTest != null) { Option(expiryTimerIterForTest) From 53fb7ccb40b2482178223ef3030463558813517f Mon Sep 17 00:00:00 2001 From: jingz-db Date: Mon, 25 Nov 2024 15:10:55 -0800 Subject: [PATCH 11/22] lint --- python/pyspark/sql/pandas/serializers.py | 16 ++++++++++------ .../pandas/test_pandas_transform_with_state.py | 15 +++++++++------ .../TransformWithStateInPandasStateServer.scala | 3 ++- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index d9218ea787ff9..ad3542de92c7f 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1215,9 +1215,10 @@ def dump_stream(self, iterator, stream): ExpiredTimerInfo, TimerValues, ) + # Clone the original iterator to get additional args cloned_iterator, result_iterator = tee(iterator) - result = ([(pd, t) for x in cloned_iterator for y, t in x for pd in y[0]]) + result = [(pd, t) for x in cloned_iterator for y, t in x for pd in y[0]] args = [(y[1], y[2], t, y[3]) for x in result_iterator for y, t in x] # if num of keys is smaller than num of partitions, some partitions will have empty @@ -1251,6 +1252,7 @@ def timer_iter_wrapper(func, *args, **kwargs): For a given key, need to properly set implicit key before calling handleExpiredTimer, and remove the implicit key after consuming the iterator. 
""" + def wrapper(): timer_cur_key = kwargs.get("key", args[0] if len(args) > 0 else None) # set implicit key for the timer row before calling UDF @@ -1268,11 +1270,13 @@ def wrapper(): # process with expiry timers, only timer related rows will be emitted for expiry_list in expiry_list_iter: for key_obj, expiry_timestamp in expiry_list: - result_iter_list.append(timer_iter_wrapper( - statefulProcessor.handleExpiredTimer, - key=key_obj, - timer_values=TimerValues(batch_timestamp, watermark_timestamp), - expired_timer_info=ExpiredTimerInfo(expiry_timestamp)) + result_iter_list.append( + timer_iter_wrapper( + statefulProcessor.handleExpiredTimer, + key=key_obj, + timer_values=TimerValues(batch_timestamp, watermark_timestamp), + expired_timer_info=ExpiredTimerInfo(expiry_timestamp), + ) ) timer_result_list = ((df, outputType) for df in chain(*result_iter_list)) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 2ce9b706580c0..2c512ca9c8da9 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -777,12 +777,13 @@ def close(self) -> None: class StatefulProcessorWithInitialStateTimers(SimpleStatefulProcessorWithInitialState): def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]: - print(f"before delete Timers for key: {key}, " - f"timestamp: {expired_timer_info.get_expiry_time_in_ms()}\n") + print( + f"before delete Timers for key: {key}, " + f"timestamp: {expired_timer_info.get_expiry_time_in_ms()}\n" + ) self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) str_key = f"{str(key[0])}-expired" - print(f"after delete Timers for key: {key}, " - f"return key: {str_key}\n") + print(f"after delete Timers for key: {key}, " f"return key: {str_key}\n") yield pd.DataFrame( {"id": (str_key,), "value": str(expired_timer_info.get_expiry_time_in_ms())} ) @@ -790,8 +791,10 @@ def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[ def handleInitialState(self, key, initialState, timer_values) -> None: super().handleInitialState(key, initialState, timer_values) self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() - 1) - print(f"after register Timers for key: {key}, " - f"timestamp: {timer_values.get_current_processing_time_in_ms() - 1}\n") + print( + f"after register Timers for key: {key}, " + f"timestamp: {timer_values.get_current_processing_time_in_ms() - 1}\n" + ) # A stateful processor that output the max event time it has seen. Register timer for diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala index 675d020a2fa5a..7b80dd431ad3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala @@ -120,7 +120,8 @@ class TransformWithStateInPandasStateServer( } /** Timer related class variables */ - + // An iterator store all expired timer info. This is meant to be consumed only once per + // partition. This should be called after finishing handling all input rows. 
private var expiryTimestampIter: Option[Iterator[(Any, Long)]] = if (expiryTimerIterForTest != null) { Option(expiryTimerIterForTest) From 3fa4ef2ae816bf9bfef970e487cbe9290ea8c35e Mon Sep 17 00:00:00 2001 From: jingz-db Date: Mon, 25 Nov 2024 15:11:21 -0800 Subject: [PATCH 12/22] typo --- .../python/TransformWithStateInPandasStateServer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala index 7b80dd431ad3d..2957f4b387580 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala @@ -120,7 +120,7 @@ class TransformWithStateInPandasStateServer( } /** Timer related class variables */ - // An iterator store all expired timer info. This is meant to be consumed only once per + // An iterator to store all expired timer info. This is meant to be consumed only once per // partition. This should be called after finishing handling all input rows. private var expiryTimestampIter: Option[Iterator[(Any, Long)]] = if (expiryTimerIterForTest != null) { From d4cd2e130d93bc9a7ad00bb34111dd5e1a068dc4 Mon Sep 17 00:00:00 2001 From: jingz-db Date: Mon, 25 Nov 2024 15:39:38 -0800 Subject: [PATCH 13/22] fix type hint --- python/pyspark/sql/pandas/group_ops.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 5b05fa9f28422..6820f8ec781f6 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -506,7 +506,9 @@ def handle_data_rows( statefulProcessorApiClient: StatefulProcessorApiClient, key: Any, inputRows: Optional[Iterator["PandasDataFrameLike"]] = None, - ) -> Tuple[Iterator["PandasDataFrameLike"], StatefulProcessorApiClient]: + ) -> Tuple[ + Iterator["PandasDataFrameLike"], StatefulProcessorApiClient, StatefulProcessor, str + ]: statefulProcessorApiClient.set_implicit_key(key) batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps() # process with data rows @@ -525,7 +527,9 @@ def transformWithStateUDF( statefulProcessorApiClient: StatefulProcessorApiClient, key: Any, inputRows: Iterator["PandasDataFrameLike"], - ) -> Iterator["PandasDataFrameLike"]: + ) -> Tuple[ + Iterator["PandasDataFrameLike"], StatefulProcessorApiClient, StatefulProcessor, str + ]: handle = StatefulProcessorHandle(statefulProcessorApiClient) if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED: @@ -549,7 +553,9 @@ def transformWithStateWithInitStateUDF( key: Any, inputRows: Iterator["PandasDataFrameLike"], initialStates: Optional[Iterator["PandasDataFrameLike"]] = None, - ) -> Iterator["PandasDataFrameLike"]: + ) -> Tuple[ + Iterator["PandasDataFrameLike"], StatefulProcessorApiClient, StatefulProcessor, str + ]: """ UDF for TWS operator with non-empty initial states. 
Possible input combinations of inputRows and initialStates iterator: From 5e018ffa1255730de1020697cc64f9c7f25737c5 Mon Sep 17 00:00:00 2001 From: jingz-db Date: Tue, 26 Nov 2024 17:28:06 -0800 Subject: [PATCH 14/22] resolve partial comments --- python/pyspark/sql/pandas/serializers.py | 40 ++++++++++++------- .../sql/streaming/proto/StateMessage_pb2.py | 9 +++++ .../test_pandas_transform_with_state.py | 16 ++++---- 3 files changed, 42 insertions(+), 23 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index ad3542de92c7f..6ece11a5420af 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1218,6 +1218,18 @@ def dump_stream(self, iterator, stream): # Clone the original iterator to get additional args cloned_iterator, result_iterator = tee(iterator) + # since we return Tuple[iterator["PandasDataframeLike"], StatefulProcessorApiClient, + # StatefulProcessor, str of timeMode] from `transformWithStateUDF` and + # `transformWithStateWithInitStateUDF`, the iterator of result grouped for all keys on a + # partition is of type: + # Iterator[List[Tuple[ + # Tuple[ + # iterator["PandasDataframeLike"], StatefulProcessorApiClient, + # StatefulProcessor, str of timeMode + # ], outputStructType]] + # ] + # We want to convert the result iterator to a list of pandas dataframe, + # and get the remaining args to further perform timer handling operations result = [(pd, t) for x in cloned_iterator for y, t in x for pd in y[0]] args = [(y[1], y[2], t, y[3]) for x in result_iterator for y, t in x] @@ -1232,20 +1244,6 @@ def dump_stream(self, iterator, stream): outputType = args[0][2] timeMode = args[0][3] - batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps() - - result_iter_list = [] - if timeMode.lower() == "processingtime": - expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( - batch_timestamp - ) - elif timeMode.lower() == "eventtime": - expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( - watermark_timestamp - ) - else: - expiry_list_iter = iter([[]]) - def timer_iter_wrapper(func, *args, **kwargs): """ Wrap the timer iterator returned from handleExpiredTimer with implicit key handling. 
@@ -1267,6 +1265,20 @@ def wrapper(): return wrapper() + batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps() + + result_iter_list = [] + if timeMode.lower() == "processingtime": + expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( + batch_timestamp + ) + elif timeMode.lower() == "eventtime": + expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( + watermark_timestamp + ) + else: + expiry_list_iter = iter([[]]) + # process with expiry timers, only timer related rows will be emitted for expiry_list in expiry_list_iter: for key_obj, expiry_timestamp in expiry_list: diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py index b57440380b3ae..c0de0d510cdef 100644 --- a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py +++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py @@ -22,9 +22,18 @@ """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder +runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 28, + 3, + "", + "org/apache/spark/sql/execution/streaming/StateMessage.proto", +) # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 2c512ca9c8da9..7cef6cdeb0147 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -732,6 +732,13 @@ def check_results(batch_df, batch_id): StatefulProcessorWithInitialStateTimers(), check_results, "processingTime" ) + # run the same test suites again but with single shuffle partition + def test_transform_with_state_with_timers_single_partition(self): + with self.sql_conf({"spark.sql.shuffle.partitions": "1"}): + self.test_transform_with_state_init_state_with_timers() + self.test_transform_with_state_in_pandas_event_time() + self.test_transform_with_state_in_pandas_proc_timer() + class SimpleStatefulProcessorWithInitialState(StatefulProcessor): # this dict is the same as input initial state dataframe @@ -777,13 +784,8 @@ def close(self) -> None: class StatefulProcessorWithInitialStateTimers(SimpleStatefulProcessorWithInitialState): def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]: - print( - f"before delete Timers for key: {key}, " - f"timestamp: {expired_timer_info.get_expiry_time_in_ms()}\n" - ) self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) str_key = f"{str(key[0])}-expired" - print(f"after delete Timers for key: {key}, " f"return key: {str_key}\n") yield pd.DataFrame( {"id": (str_key,), "value": str(expired_timer_info.get_expiry_time_in_ms())} ) @@ -791,10 +793,6 @@ def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[ def handleInitialState(self, key, initialState, timer_values) -> None: super().handleInitialState(key, initialState, timer_values) self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() - 1) - print( - f"after register Timers for key: {key}, " - f"timestamp: 
{timer_values.get_current_processing_time_in_ms() - 1}\n" - ) # A stateful processor that output the max event time it has seen. Register timer for From 521722e76f2ce1362c8d61aa9f81882324e6f43b Mon Sep 17 00:00:00 2001 From: Jing Zhan <135738831+jingz-db@users.noreply.github.com> Date: Tue, 26 Nov 2024 17:46:56 -0800 Subject: [PATCH 15/22] restore StateMessage_pb2.py --- python/pyspark/sql/streaming/proto/StateMessage_pb2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py index c0de0d510cdef..0a54690513a39 100644 --- a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py +++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py @@ -26,7 +26,7 @@ from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder -runtime_version.ValidateProtobufRuntimeVersion( +_runtime_version.ValidateProtobufRuntimeVersion( _runtime_version.Domain.PUBLIC, 5, 28, From 0bff9ebaf319b9b0894e9560bfa533d79a382cba Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 27 Nov 2024 12:13:13 +0900 Subject: [PATCH 16/22] Made the logic of handling timer and cleanup simpler --- python/pyspark/sql/pandas/group_ops.py | 86 +++++++++++++++---- python/pyspark/sql/pandas/serializers.py | 14 ++- .../sql/streaming/stateful_processor_util.py | 26 ++++++ .../test_pandas_transform_with_state.py | 1 + python/pyspark/worker.py | 68 +++++++-------- 5 files changed, 143 insertions(+), 52 deletions(-) create mode 100644 python/pyspark/sql/streaming/stateful_processor_util.py diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 6820f8ec781f6..396f073bfc35e 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -35,6 +35,7 @@ TimerValues, ) from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle +from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode from pyspark.sql.types import StructType, _parse_datatype_string if TYPE_CHECKING: @@ -502,6 +503,17 @@ def transformWithStateInPandas( if isinstance(outputStructType, str): outputStructType = cast(StructType, _parse_datatype_string(outputStructType)) + def get_timestamps( + statefulProcessorApiClient: StatefulProcessorApiClient, + ) -> Tuple[int, int]: + if timeMode != "none": + batch_timestamp = statefulProcessorApiClient.get_batch_timestamp() + watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp() + else: + batch_timestamp = -1 + watermark_timestamp = -1 + return batch_timestamp, watermark_timestamp + def handle_data_rows( statefulProcessorApiClient: StatefulProcessorApiClient, key: Any, @@ -510,21 +522,53 @@ def handle_data_rows( Iterator["PandasDataFrameLike"], StatefulProcessorApiClient, StatefulProcessor, str ]: statefulProcessorApiClient.set_implicit_key(key) + batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps() + # process with data rows if inputRows is not None: data_iter = statefulProcessor.handleInputRows( key, inputRows, TimerValues(batch_timestamp, watermark_timestamp) ) - statefulProcessorApiClient.set_handle_state( - StatefulProcessorHandleState.DATA_PROCESSED + return data_iter + else: + return iter([]) + + def handle_expired_timers( + statefulProcessorApiClient: StatefulProcessorApiClient, + ) -> Iterator["PandasDataFrameLike"]: + result_iter_list = [] + + 
batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient) + + if timeMode.lower() == "processingtime": + expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( + batch_timestamp + ) + elif timeMode.lower() == "eventtime": + expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( + watermark_timestamp ) - return (data_iter, statefulProcessorApiClient, statefulProcessor, timeMode) else: - return (iter([]), statefulProcessorApiClient, statefulProcessor, timeMode) + expiry_list_iter = iter([[]]) + + # process with expiry timers, only timer related rows will be emitted + for expiry_list in expiry_list_iter: + for key_obj, expiry_timestamp in expiry_list: + statefulProcessorApiClient.set_implicit_key(key_obj) + result_iter_list.append( + statefulProcessor.handleExpiredTimer( + key_obj, + TimerValues(batch_timestamp, watermark_timestamp), + ExpiredTimerInfo(expiry_timestamp), + ) + ) + + return itertools.chain(*result_iter_list) def transformWithStateUDF( statefulProcessorApiClient: StatefulProcessorApiClient, + mode: TransformWithStateInPandasFuncMode, key: Any, inputRows: Iterator["PandasDataFrameLike"], ) -> Tuple[ @@ -538,18 +582,24 @@ def transformWithStateUDF( StatefulProcessorHandleState.INITIALIZED ) - # Key is None when we have processed all the input data from the worker and ready to - # proceed with the cleanup steps. - if key is None: + if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER: + statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.DATA_PROCESSED) + result = handle_expired_timers(statefulProcessorApiClient) + return result + elif mode == TransformWithStateInPandasFuncMode.COMPLETE: + statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.TIMER_PROCESSED) statefulProcessorApiClient.remove_implicit_key() statefulProcessor.close() statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) return iter([]) - - return handle_data_rows(statefulProcessorApiClient, key, inputRows) + else: + # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA + result = handle_data_rows(statefulProcessorApiClient, key, inputRows) + return result def transformWithStateWithInitStateUDF( statefulProcessorApiClient: StatefulProcessorApiClient, + mode: TransformWithStateInPandasFuncMode, key: Any, inputRows: Iterator["PandasDataFrameLike"], initialStates: Optional[Iterator["PandasDataFrameLike"]] = None, @@ -576,15 +626,19 @@ def transformWithStateWithInitStateUDF( StatefulProcessorHandleState.INITIALIZED ) - # Key is None when we have processed all the input data from the worker and ready to - # proceed with the cleanup steps. 
- if key is None: + if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER: + statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.DATA_PROCESSED) + result = handle_expired_timers(statefulProcessorApiClient) + return result + elif mode == TransformWithStateInPandasFuncMode.COMPLETE: statefulProcessorApiClient.remove_implicit_key() statefulProcessor.close() statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) return iter([]) + else: + # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA + batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient) - batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps() # only process initial state if first batch and initial state is not None if initialStates is not None: for cur_initial_state in initialStates: @@ -604,9 +658,11 @@ def transformWithStateWithInitStateUDF( inputRows = itertools.chain([first], inputRows) if not input_rows_empty: - return handle_data_rows(statefulProcessorApiClient, key, inputRows) + result = handle_data_rows(statefulProcessorApiClient, key, inputRows) else: - return (iter([]), statefulProcessorApiClient, statefulProcessor, timeMode) + result = iter([]) + + return result if isinstance(outputStructType, str): outputStructType = cast(StructType, _parse_datatype_string(outputStructType)) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 6ece11a5420af..80b4f14a61fc9 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -36,6 +36,7 @@ _create_converter_from_pandas, _create_converter_to_pandas, ) +from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode from pyspark.sql.types import ( DataType, StringType, @@ -1140,7 +1141,6 @@ def init_stream_yield_batches(batches): return ArrowStreamSerializer.dump_stream(self, batches_to_write, stream) - class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer): """ Serializer used by Python worker to evaluate UDF for @@ -1197,7 +1197,11 @@ def generate_data_batches(batches): data_batches = generate_data_batches(_batches) for k, g in groupby(data_batches, key=lambda x: x[0]): - yield (k, g) + yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g) + + yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None) + + yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None) def dump_stream(self, iterator, stream): """ @@ -1371,4 +1375,8 @@ def flatten_columns(cur_batch, col_name): data_batches = generate_data_batches(_batches) for k, g in groupby(data_batches, key=lambda x: x[0]): - yield (k, g) + yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g) + + yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None) + + yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None) diff --git a/python/pyspark/sql/streaming/stateful_processor_util.py b/python/pyspark/sql/streaming/stateful_processor_util.py new file mode 100644 index 0000000000000..623eff8132255 --- /dev/null +++ b/python/pyspark/sql/streaming/stateful_processor_util.py @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from enum import Enum + +# This file places the utilities for transformWithStateInPandas; we have a separate file to avoid +# putting internal classes to the stateful_processor.py file which contains public APIs. + +class TransformWithStateInPandasFuncMode(Enum): + PROCESS_DATA = 1 + PROCESS_TIMER = 2 + COMPLETE = 3 diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 7cef6cdeb0147..59929f21e4a66 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -61,6 +61,7 @@ def conf(cls): "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider", ) cfg.set("spark.sql.execution.arrow.transformWithStateInPandas.maxRecordsPerBatch", "2") + cfg.set("spark.sql.session.timeZone", "UTC") return cfg def _prepare_input_data(self, input_path, col1, col2): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 04f95e9f52648..1ebc04520ecad 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -34,6 +34,7 @@ _deserialize_accumulator, ) from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient +from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode from pyspark.taskcontext import BarrierTaskContext, TaskContext from pyspark.resource import ResourceInformation from pyspark.util import PythonEvalType, local_connect_and_auth @@ -493,36 +494,36 @@ def wrapped(key_series, value_series): def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf): - def wrapped(stateful_processor_api_client, key, value_series_gen): + def wrapped(stateful_processor_api_client, mode, key, value_series_gen): import pandas as pd values = (pd.concat(x, axis=1) for x in value_series_gen) - result_iter = f(stateful_processor_api_client, key, values) + result_iter = f(stateful_processor_api_client, mode, key, values) # TODO(SPARK-49100): add verification that elements in result_iter are # indeed of type pd.DataFrame and confirm to assigned cols return result_iter - return lambda p, k, v: [(wrapped(p, k, v), to_arrow_type(return_type))] + return lambda p, m, k, v: [(wrapped(p, m, k, v), to_arrow_type(return_type))] def wrap_grouped_transform_with_state_pandas_init_state_udf(f, return_type, runner_conf): - def wrapped(stateful_processor_api_client, key, value_series_gen): + def wrapped(stateful_processor_api_client, mode, key, value_series_gen): import pandas as pd state_values_gen, init_states_gen = itertools.tee(value_series_gen, 2) state_values = (df for x, _ in state_values_gen if not (df := pd.concat(x, axis=1)).empty) init_states = (df for _, x in init_states_gen if not (df := pd.concat(x, axis=1)).empty) - result_iter = f(stateful_processor_api_client, key, state_values, init_states) + result_iter = 
f(stateful_processor_api_client, mode, key, state_values, init_states) # TODO(SPARK-49100): add verification that elements in result_iter are # indeed of type pd.DataFrame and confirm to assigned cols return result_iter - return lambda p, k, v: [(wrapped(p, k, v), to_arrow_type(return_type))] + return lambda p, m, k, v: [(wrapped(p, m, k, v), to_arrow_type(return_type))] def wrap_grouped_map_pandas_udf_with_state(f, return_type): @@ -1697,18 +1698,22 @@ def mapper(a): ser.key_offsets = parsed_offsets[0][0] stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema) - # Create function like this: - # mapper a: f([a[0]], [a[0], a[1]]) def mapper(a): - key = a[0] + mode = a[0] - def values_gen(): - for x in a[1]: - retVal = [x[1][o] for o in parsed_offsets[0][1]] - yield retVal + if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA: + key = a[1] - # This must be generator comprehension - do not materialize. - return f(stateful_processor_api_client, key, values_gen()) + def values_gen(): + for x in a[2]: + retVal = [x[1][o] for o in parsed_offsets[0][1]] + yield retVal + + # This must be generator comprehension - do not materialize. + return f(stateful_processor_api_client, mode, key, values_gen()) + else: + # mode == PROCESS_TIMER or mode == COMPLETE + return f(stateful_processor_api_client, mode, None, iter([])) elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF: # We assume there is only one UDF here because grouped map doesn't @@ -1731,16 +1736,22 @@ def values_gen(): stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema) def mapper(a): - key = a[0] + mode = a[0] - def values_gen(): - for x in a[1]: - retVal = [x[1][o] for o in parsed_offsets[0][1]] - initVal = [x[2][o] for o in parsed_offsets[1][1]] - yield retVal, initVal + if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA: + key = a[1] - # This must be generator comprehension - do not materialize. - return f(stateful_processor_api_client, key, values_gen()) + def values_gen(): + for x in a[2]: + retVal = [x[1][o] for o in parsed_offsets[0][1]] + initVal = [x[2][o] for o in parsed_offsets[1][1]] + yield retVal, initVal + + # This must be generator comprehension - do not materialize. + return f(stateful_processor_api_client, mode, key, values_gen()) + else: + # mode == PROCESS_TIMER or mode == COMPLETE + return f(stateful_processor_api_client, mode, None, iter([])) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF: import pyarrow as pa @@ -1958,17 +1969,6 @@ def process(): try: serializer.dump_stream(out_iter, outfile) finally: - # Sending a signal to TransformWithState UDF to perform proper cleanup steps. - if ( - eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF - or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF - ): - # Sending key as None to indicate that process() has finished. - end_iter = func(split_index, iter([(None, None)])) - # Need to materialize the iterator to trigger the cleanup steps, nothing needs - # to be done here. 
- for _ in end_iter: - pass if hasattr(out_iter, "close"): out_iter.close() From cd16fdae7628c140fd6af95851afd229ad3f1951 Mon Sep 17 00:00:00 2001 From: jingz-db Date: Wed, 27 Nov 2024 11:27:07 -0800 Subject: [PATCH 17/22] resolve conflicts --- python/pyspark/sql/pandas/group_ops.py | 35 +++--- python/pyspark/sql/pandas/serializers.py | 105 ++---------------- .../stateful_processor_api_client.py | 13 ++- .../sql/streaming/stateful_processor_util.py | 1 + 4 files changed, 37 insertions(+), 117 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 396f073bfc35e..a55bfceb8e9b5 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -503,17 +503,6 @@ def transformWithStateInPandas( if isinstance(outputStructType, str): outputStructType = cast(StructType, _parse_datatype_string(outputStructType)) - def get_timestamps( - statefulProcessorApiClient: StatefulProcessorApiClient, - ) -> Tuple[int, int]: - if timeMode != "none": - batch_timestamp = statefulProcessorApiClient.get_batch_timestamp() - watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp() - else: - batch_timestamp = -1 - watermark_timestamp = -1 - return batch_timestamp, watermark_timestamp - def handle_data_rows( statefulProcessorApiClient: StatefulProcessorApiClient, key: Any, @@ -523,7 +512,9 @@ def handle_data_rows( ]: statefulProcessorApiClient.set_implicit_key(key) - batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps() + batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps( + timeMode + ) # process with data rows if inputRows is not None: @@ -539,7 +530,9 @@ def handle_expired_timers( ) -> Iterator["PandasDataFrameLike"]: result_iter_list = [] - batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient) + batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps( + timeMode + ) if timeMode.lower() == "processingtime": expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( @@ -583,11 +576,15 @@ def transformWithStateUDF( ) if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER: - statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.DATA_PROCESSED) + statefulProcessorApiClient.set_handle_state( + StatefulProcessorHandleState.DATA_PROCESSED + ) result = handle_expired_timers(statefulProcessorApiClient) return result elif mode == TransformWithStateInPandasFuncMode.COMPLETE: - statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.TIMER_PROCESSED) + statefulProcessorApiClient.set_handle_state( + StatefulProcessorHandleState.TIMER_PROCESSED + ) statefulProcessorApiClient.remove_implicit_key() statefulProcessor.close() statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) @@ -627,7 +624,9 @@ def transformWithStateWithInitStateUDF( ) if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER: - statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.DATA_PROCESSED) + statefulProcessorApiClient.set_handle_state( + StatefulProcessorHandleState.DATA_PROCESSED + ) result = handle_expired_timers(statefulProcessorApiClient) return result elif mode == TransformWithStateInPandasFuncMode.COMPLETE: @@ -637,7 +636,9 @@ def transformWithStateWithInitStateUDF( return iter([]) else: # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA - batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient) + 
batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps( + timeMode + ) # only process initial state if first batch and initial state is not None if initialStates is not None: diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 80b4f14a61fc9..f39185c4e3de4 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1141,6 +1141,7 @@ def init_stream_yield_batches(batches): return ArrowStreamSerializer.dump_stream(self, batches_to_write, stream) + class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer): """ Serializer used by Python worker to evaluate UDF for @@ -1205,101 +1206,11 @@ def generate_data_batches(batches): def dump_stream(self, iterator, stream): """ - Read through chained return results from a single partition of handleInputRows. - For a single partition, after finish handling all input rows, we need to iterate - through all expired timers and handle them. We chain the results of handleInputRows - with handleExpiredTimer into a single iterator and dump the stream as arrow batches. + Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow + RecordBatches, and write batches to stream. """ - - from itertools import tee, chain - from pyspark.sql.streaming.stateful_processor_api_client import ( - StatefulProcessorHandleState, - ) - from pyspark.sql.streaming.stateful_processor import ( - ExpiredTimerInfo, - TimerValues, - ) - - # Clone the original iterator to get additional args - cloned_iterator, result_iterator = tee(iterator) - # since we return Tuple[iterator["PandasDataframeLike"], StatefulProcessorApiClient, - # StatefulProcessor, str of timeMode] from `transformWithStateUDF` and - # `transformWithStateWithInitStateUDF`, the iterator of result grouped for all keys on a - # partition is of type: - # Iterator[List[Tuple[ - # Tuple[ - # iterator["PandasDataframeLike"], StatefulProcessorApiClient, - # StatefulProcessor, str of timeMode - # ], outputStructType]] - # ] - # We want to convert the result iterator to a list of pandas dataframe, - # and get the remaining args to further perform timer handling operations - result = [(pd, t) for x in cloned_iterator for y, t in x for pd in y[0]] - args = [(y[1], y[2], t, y[3]) for x in result_iterator for y, t in x] - - # if num of keys is smaller than num of partitions, some partitions will have empty - # input rows; we do nothing for such partitions - if len(args) == 0: - return - - # all keys on the same partition share the same args - statefulProcessorApiClient = args[0][0] - statefulProcessor = args[0][1] - outputType = args[0][2] - timeMode = args[0][3] - - def timer_iter_wrapper(func, *args, **kwargs): - """ - Wrap the timer iterator returned from handleExpiredTimer with implicit key handling. - For a given key, need to properly set implicit key before calling handleExpiredTimer, - and remove the implicit key after consuming the iterator. 
- """ - - def wrapper(): - timer_cur_key = kwargs.get("key", args[0] if len(args) > 0 else None) - # set implicit key for the timer row before calling UDF - statefulProcessorApiClient.set_implicit_key(timer_cur_key) - # Call handleExpiredTimer UDF - iter = func(*args, **kwargs) - try: - for e in iter: - yield e - finally: - statefulProcessorApiClient.remove_implicit_key() - - return wrapper() - - batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps() - - result_iter_list = [] - if timeMode.lower() == "processingtime": - expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( - batch_timestamp - ) - elif timeMode.lower() == "eventtime": - expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator( - watermark_timestamp - ) - else: - expiry_list_iter = iter([[]]) - - # process with expiry timers, only timer related rows will be emitted - for expiry_list in expiry_list_iter: - for key_obj, expiry_timestamp in expiry_list: - result_iter_list.append( - timer_iter_wrapper( - statefulProcessor.handleExpiredTimer, - key=key_obj, - timer_values=TimerValues(batch_timestamp, watermark_timestamp), - expired_timer_info=ExpiredTimerInfo(expiry_timestamp), - ) - ) - - timer_result_list = ((df, outputType) for df in chain(*result_iter_list)) - final_result = chain(result, timer_result_list) - - super().dump_stream(final_result, stream) - statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.TIMER_PROCESSED) + result = [(b, t) for x in iterator for y, t in x for b in y] + super().dump_stream(result, stream) class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSerializer): @@ -1375,8 +1286,8 @@ def flatten_columns(cur_batch, col_name): data_batches = generate_data_batches(_batches) for k, g in groupby(data_batches, key=lambda x: x[0]): - yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g) + yield TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g - yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None) + yield TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None - yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None) + yield TransformWithStateInPandasFuncMode.COMPLETE, None, None diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py index dac02c6d90801..53704188081c3 100644 --- a/python/pyspark/sql/streaming/stateful_processor_api_client.py +++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py @@ -64,8 +64,8 @@ def __init__(self, state_server_port: int, key_schema: StructType) -> None: self.list_timer_iterator_cursors: Dict[str, Tuple["PandasDataFrameLike", int]] = {} # statefulProcessorApiClient is initialized per batch per partition, # so we will have new timestamps for a new batch - self._batch_timestamp = self._get_batch_timestamp() - self._watermark_timestamp = self._get_watermark_timestamp() + self._batch_timestamp = -1 + self._watermark_timestamp = -1 def set_handle_state(self, state: StatefulProcessorHandleState) -> None: import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage @@ -270,7 +270,14 @@ def get_expiry_timers_iterator( # TODO(SPARK-49233): Classify user facing errors. 
raise PySparkRuntimeError(f"Error getting expiry timers: " f"{response_message[1]}") - def get_timestamps(self) -> Tuple[int, int]: + def get_timestamps(self, time_mode: str) -> Tuple[int, int]: + if time_mode.lower() == "none": + return -1, -1 + else: + if self._batch_timestamp == -1: + self._batch_timestamp = self._get_batch_timestamp() + if self._watermark_timestamp == -1: + self._watermark_timestamp = self._get_watermark_timestamp() return self._batch_timestamp, self._watermark_timestamp def get_map_state( diff --git a/python/pyspark/sql/streaming/stateful_processor_util.py b/python/pyspark/sql/streaming/stateful_processor_util.py index 623eff8132255..6130a9581bc24 100644 --- a/python/pyspark/sql/streaming/stateful_processor_util.py +++ b/python/pyspark/sql/streaming/stateful_processor_util.py @@ -20,6 +20,7 @@ # This file places the utilities for transformWithStateInPandas; we have a separate file to avoid # putting internal classes to the stateful_processor.py file which contains public APIs. + class TransformWithStateInPandasFuncMode(Enum): PROCESS_DATA = 1 PROCESS_TIMER = 2 From 571dc441f8ae5e200056b1394f85e61c904435fb Mon Sep 17 00:00:00 2001 From: jingz-db Date: Wed, 27 Nov 2024 12:40:27 -0800 Subject: [PATCH 18/22] tests passed for single partition --- python/pyspark/sql/pandas/group_ops.py | 39 ++++++++++++------- python/pyspark/sql/pandas/serializers.py | 4 +- .../test_pandas_transform_with_state.py | 5 ++- 3 files changed, 30 insertions(+), 18 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index a1dfe02d91e21..39f05552a3ca9 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -508,9 +508,7 @@ def handle_data_rows( statefulProcessorApiClient: StatefulProcessorApiClient, key: Any, inputRows: Optional[Iterator["PandasDataFrameLike"]] = None, - ) -> Tuple[ - Iterator["PandasDataFrameLike"], StatefulProcessorApiClient, StatefulProcessor, str - ]: + ) -> Iterator["PandasDataFrameLike"]: statefulProcessorApiClient.set_implicit_key(key) batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps( @@ -529,6 +527,22 @@ def handle_data_rows( def handle_expired_timers( statefulProcessorApiClient: StatefulProcessorApiClient, ) -> Iterator["PandasDataFrameLike"]: + def timer_iter_wrapper(func, *args, **kwargs): + def wrapper(): + timer_cur_key = kwargs.get("key", args[0] if len(args) > 0 else None) + expired_timer_info = kwargs.get("expired_timer_info", args[2] if len(args) > 2 else None) + # set implicit key for the timer row before calling UDF + statefulProcessorApiClient.set_implicit_key(timer_cur_key) + # Call handleExpiredTimer UDF + iter = func(*args, **kwargs) + try: + for e in iter: + yield e + finally: + statefulProcessorApiClient.delete_timer(expired_timer_info.get_expiry_time_in_ms()) + statefulProcessorApiClient.remove_implicit_key() + return wrapper() + result_iter_list = [] batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps( @@ -549,13 +563,12 @@ def handle_expired_timers( # process with expiry timers, only timer related rows will be emitted for expiry_list in expiry_list_iter: for key_obj, expiry_timestamp in expiry_list: - statefulProcessorApiClient.set_implicit_key(key_obj) result_iter_list.append( - statefulProcessor.handleExpiredTimer( - key_obj, - TimerValues(batch_timestamp, watermark_timestamp), - ExpiredTimerInfo(expiry_timestamp), - ) + timer_iter_wrapper( + statefulProcessor.handleExpiredTimer, + key=key_obj, + 
timer_values=TimerValues(batch_timestamp, watermark_timestamp), + expired_timer_info=ExpiredTimerInfo(expiry_timestamp)) ) return itertools.chain(*result_iter_list) @@ -565,9 +578,7 @@ def transformWithStateUDF( mode: TransformWithStateInPandasFuncMode, key: Any, inputRows: Iterator["PandasDataFrameLike"], - ) -> Tuple[ - Iterator["PandasDataFrameLike"], StatefulProcessorApiClient, StatefulProcessor, str - ]: + ) -> Iterator["PandasDataFrameLike"]: handle = StatefulProcessorHandle(statefulProcessorApiClient) if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED: @@ -601,9 +612,7 @@ def transformWithStateWithInitStateUDF( key: Any, inputRows: Iterator["PandasDataFrameLike"], initialStates: Optional[Iterator["PandasDataFrameLike"]] = None, - ) -> Tuple[ - Iterator["PandasDataFrameLike"], StatefulProcessorApiClient, StatefulProcessor, str - ]: + ) -> Iterator["PandasDataFrameLike"]: """ UDF for TWS operator with non-empty initial states. Possible input combinations of inputRows and initialStates iterator: diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index f39185c4e3de4..deb9d5f20cd0e 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1288,6 +1288,6 @@ def flatten_columns(cur_batch, col_name): for k, g in groupby(data_batches, key=lambda x: x[0]): yield TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g - yield TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None + yield TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None - yield TransformWithStateInPandasFuncMode.COMPLETE, None, None + yield TransformWithStateInPandasFuncMode.COMPLETE, None, None diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 3011902618fd0..ec1f421cb382d 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -826,13 +826,16 @@ def check_results(batch_df, batch_id): Row(id="0", value=str(789 + 123 + 46)), Row(id="1", value=str(146 + 346)), } - else: + elif batch_id == 1: # handleInitialState is only processed in the first batch, # no more timer is registered so no more expired timers assert set(batch_df.sort("id").collect()) == { Row(id="0", value=str(789 + 123 + 46 + 67)), Row(id="3", value=str(987 + 12)), } + else: + for q in self.spark.streams.active: + q.stop() self._test_transform_with_state_init_state_in_pandas( StatefulProcessorWithInitialStateTimers(), check_results, "processingTime" From aaf34a32733ab6a3298ab228ad373a994bee5975 Mon Sep 17 00:00:00 2001 From: jingz-db Date: Wed, 27 Nov 2024 14:09:47 -0800 Subject: [PATCH 19/22] fix tests, set implicit key inside yield --- python/pyspark/sql/pandas/group_ops.py | 35 +++++-------------- .../test_pandas_transform_with_state.py | 17 ++++----- 2 files changed, 15 insertions(+), 37 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 39f05552a3ca9..e54e62db07c86 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -527,24 +527,6 @@ def handle_data_rows( def handle_expired_timers( statefulProcessorApiClient: StatefulProcessorApiClient, ) -> Iterator["PandasDataFrameLike"]: - def timer_iter_wrapper(func, *args, **kwargs): - def wrapper(): - timer_cur_key = kwargs.get("key", args[0] if len(args) > 0 else None) - 
expired_timer_info = kwargs.get("expired_timer_info", args[2] if len(args) > 2 else None) - # set implicit key for the timer row before calling UDF - statefulProcessorApiClient.set_implicit_key(timer_cur_key) - # Call handleExpiredTimer UDF - iter = func(*args, **kwargs) - try: - for e in iter: - yield e - finally: - statefulProcessorApiClient.delete_timer(expired_timer_info.get_expiry_time_in_ms()) - statefulProcessorApiClient.remove_implicit_key() - return wrapper() - - result_iter_list = [] - batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps( timeMode ) @@ -563,15 +545,14 @@ def wrapper(): # process with expiry timers, only timer related rows will be emitted for expiry_list in expiry_list_iter: for key_obj, expiry_timestamp in expiry_list: - result_iter_list.append( - timer_iter_wrapper( - statefulProcessor.handleExpiredTimer, - key=key_obj, - timer_values=TimerValues(batch_timestamp, watermark_timestamp), - expired_timer_info=ExpiredTimerInfo(expiry_timestamp)) - ) - - return itertools.chain(*result_iter_list) + statefulProcessorApiClient.set_implicit_key(key_obj) + for pd in statefulProcessor.handleExpiredTimer( + key=key_obj, + timer_values=TimerValues(batch_timestamp, watermark_timestamp), + expired_timer_info=ExpiredTimerInfo(expiry_timestamp), + ): + yield pd + statefulProcessorApiClient.delete_timer(expiry_timestamp) def transformWithStateUDF( statefulProcessorApiClient: StatefulProcessorApiClient, diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index ec1f421cb382d..c1154819bd788 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -559,17 +559,18 @@ def prepare_batch3(input_path): def test_transform_with_state_in_pandas_event_time(self): def check_results(batch_df, batch_id): if batch_id == 0: - assert set(batch_df.sort("id").collect()) == {Row(id="a", timestamp="20")} + assert set(batch_df.sort("id").collect()) == { + Row(id="a", timestamp="20"), + } elif batch_id == 1: assert set(batch_df.sort("id").collect()) == { Row(id="a", timestamp="20"), - Row(id="a-expired", timestamp="0"), + Row(id="a-expired", timestamp="1"), } elif batch_id == 2: # verify that rows and expired timer produce the expected result assert set(batch_df.sort("id").collect()) == { Row(id="a", timestamp="15"), - Row(id="a-expired", timestamp="10000"), } else: for q in self.spark.streams.active: @@ -933,7 +934,7 @@ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: max_event_time = str(max(cur_max, max(timestamp_list))) self.max_state.update((max_event_time,)) - self.handle.registerTimer(timer_values.get_current_watermark_in_ms()) + self.handle.registerTimer(timer_values.get_current_watermark_in_ms() + 1) yield pd.DataFrame({"id": key, "timestamp": max_event_time}) @@ -961,8 +962,6 @@ def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[ idx += 1 if len(timer_list_1) > 0: - # before deleting the expiring timers, there are 2 timers - - # one timer we just registered, and one that is going to be deleted assert len(timer_list_1) == 2 self.count_state.clear() self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) @@ -981,7 +980,7 @@ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: count = int(self.count_state.get()[0]) if key == ("0",): - 
self.handle.registerTimer(timer_values.get_current_processing_time_in_ms()) + self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() + 1) rows_count = 0 for pdf in rows: @@ -1041,9 +1040,7 @@ class StatefulProcessorChainingOps(StatefulProcessor): def init(self, handle: StatefulProcessorHandle) -> None: pass - def handleInputRows( - self, key, rows, timer_values, expired_timer_info - ) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: for pdf in rows: timestamp_list = pdf["eventTime"].tolist() yield pd.DataFrame({"id": key, "outputTimestamp": timestamp_list[0]}) From 2cd4f8eab0c0c270a9fd3a84c447d0a9f5750c4c Mon Sep 17 00:00:00 2001 From: jingz-db Date: Wed, 27 Nov 2024 15:21:08 -0800 Subject: [PATCH 20/22] check for expiry key in same batch --- .../tests/pandas/test_pandas_transform_with_state.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index c1154819bd788..ac6eda4bc7def 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -559,18 +559,22 @@ def prepare_batch3(input_path): def test_transform_with_state_in_pandas_event_time(self): def check_results(batch_df, batch_id): if batch_id == 0: + # check timer registered in the same batch is expired assert set(batch_df.sort("id").collect()) == { Row(id="a", timestamp="20"), + Row(id="a-expired", timestamp="0"), } elif batch_id == 1: + # value state is cleared in batch 0, so timestamp=4 is returned here assert set(batch_df.sort("id").collect()) == { - Row(id="a", timestamp="20"), - Row(id="a-expired", timestamp="1"), + Row(id="a", timestamp="4"), + Row(id='a-expired', timestamp='10000'), } elif batch_id == 2: - # verify that rows and expired timer produce the expected result + # event time is still 20-10=10, so we will see expired row with timestamp=10s assert set(batch_df.sort("id").collect()) == { Row(id="a", timestamp="15"), + Row(id='a-expired', timestamp='10000'), } else: for q in self.spark.streams.active: @@ -934,7 +938,7 @@ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: max_event_time = str(max(cur_max, max(timestamp_list))) self.max_state.update((max_event_time,)) - self.handle.registerTimer(timer_values.get_current_watermark_in_ms() + 1) + self.handle.registerTimer(timer_values.get_current_watermark_in_ms()) yield pd.DataFrame({"id": key, "timestamp": max_event_time}) From 00b63405cb862bf3f7e3e00e8d35c5f2c85a67a4 Mon Sep 17 00:00:00 2001 From: jingz-db Date: Wed, 27 Nov 2024 15:22:35 -0800 Subject: [PATCH 21/22] lint --- .../sql/tests/pandas/test_pandas_transform_with_state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index ac6eda4bc7def..4d2606b4be786 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -568,13 +568,13 @@ def check_results(batch_df, batch_id): # value state is cleared in batch 0, so timestamp=4 is returned here assert set(batch_df.sort("id").collect()) == { Row(id="a", timestamp="4"), - Row(id='a-expired', timestamp='10000'), + Row(id="a-expired", timestamp="10000"), } elif 
batch_id == 2: # event time is still 20-10=10, so we will see expired row with timestamp=10s assert set(batch_df.sort("id").collect()) == { Row(id="a", timestamp="15"), - Row(id='a-expired', timestamp='10000'), + Row(id="a-expired", timestamp="10000"), } else: for q in self.spark.streams.active: From b4f9d91baa35b227c6c3785a96678843a403f731 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 28 Nov 2024 13:35:21 +0900 Subject: [PATCH 22/22] fix linter, reflect review comments --- python/pyspark/sql/pandas/group_ops.py | 2 +- python/pyspark/sql/pandas/serializers.py | 6 +++--- .../tests/pandas/test_pandas_transform_with_state.py | 12 +++++++++--- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index e54e62db07c86..688ad4b05732e 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -16,7 +16,7 @@ # import itertools import sys -from typing import Any, Iterator, List, Optional, Union, Tuple, TYPE_CHECKING, cast +from typing import Any, Iterator, List, Optional, Union, TYPE_CHECKING, cast import warnings from pyspark.errors import PySparkTypeError diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index deb9d5f20cd0e..536bf7307065c 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1286,8 +1286,8 @@ def flatten_columns(cur_batch, col_name): data_batches = generate_data_batches(_batches) for k, g in groupby(data_batches, key=lambda x: x[0]): - yield TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g + yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g) - yield TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None + yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None) - yield TransformWithStateInPandasFuncMode.COMPLETE, None, None + yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 4d2606b4be786..60f2c9348db3f 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -559,19 +559,25 @@ def prepare_batch3(input_path): def test_transform_with_state_in_pandas_event_time(self): def check_results(batch_df, batch_id): if batch_id == 0: - # check timer registered in the same batch is expired + # watermark for late event = 0 + # watermark for eviction = 0 + # timer is registered with expiration time = 0, hence expired at the same batch assert set(batch_df.sort("id").collect()) == { Row(id="a", timestamp="20"), Row(id="a-expired", timestamp="0"), } elif batch_id == 1: - # value state is cleared in batch 0, so timestamp=4 is returned here + # watermark for late event = 0 + # watermark for eviction = 10 (20 - 10) + # timer is registered with expiration time = 10, hence expired at the same batch assert set(batch_df.sort("id").collect()) == { Row(id="a", timestamp="4"), Row(id="a-expired", timestamp="10000"), } elif batch_id == 2: - # event time is still 20-10=10, so we will see expired row with timestamp=10s + # watermark for late event = 10 + # watermark for eviction = 10 (unchanged as 4 < 10) + # timer is registered with expiration time = 10, hence expired at the same batch assert set(batch_df.sort("id").collect()) == { Row(id="a", timestamp="15"), 
Row(id="a-expired", timestamp="10000"),
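For reference, a minimal sketch of the split timer API that these tests exercise: handleInputRows registers a processing-time timer for the grouping key, and handleExpiredTimer runs separately when that timer fires. The class name, output columns, and state schema below are illustrative assumptions only, not part of the patch.

import pandas as pd
from typing import Iterator

from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import IntegerType, StructField, StructType


class CountWithTimerProcessor(StatefulProcessor):
    # Hypothetical example processor illustrating the new timer API.
    def init(self, handle: StatefulProcessorHandle) -> None:
        state_schema = StructType([StructField("count", IntegerType(), True)])
        self.handle = handle
        self.count_state = handle.getValueState("count_state", state_schema)

    def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
        # Accumulate the number of rows seen for this grouping key.
        count = int(self.count_state.get()[0]) if self.count_state.exists() else 0
        for pdf in rows:
            count += len(pdf)
        self.count_state.update((count,))
        # Register a processing-time timer; when it expires, handleExpiredTimer
        # below is invoked for this key during the PROCESS_TIMER phase.
        self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() + 1)
        yield pd.DataFrame({"id": key, "count": (str(count),)})

    def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]:
        # Evict state once the timer fires and emit a marker row carrying
        # the expiry timestamp of the fired timer.
        self.count_state.clear()
        yield pd.DataFrame(
            {
                "id": (f"{key[0]}-expired",),
                "count": (str(expired_timer_info.get_expiry_time_in_ms()),),
            }
        )

    def close(self) -> None:
        pass

In a query, a processor like this would be passed to transformWithStateInPandas with timeMode set to "processingTime" (or "eventTime" for watermark-based timers), since the expired-timer branches above are skipped entirely when timeMode is "none".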