[SPARK-50194][SS][PYTHON] Integration of New Timer API and Initial State API with Timer #48838
Conversation
LGTM overall, just several minor comments.
if timeMode != "none":
    batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
    watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp()
else:
    batch_timestamp = -1
    watermark_timestamp = -1
Can we abstract this as a separate method and share it in both UDFs to reduce redundant code?
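Something like the following could be shared (a sketch; the name matches the `get_timestamps` helper that shows up later in this thread, and passing the time mode in explicitly is an assumption):

def get_timestamps(stateful_processor_api_client, time_mode):
    # Shared by both UDFs: fetch the timestamps once, or return -1
    # sentinels when no time mode is configured.
    if time_mode != "none":
        batch_timestamp = stateful_processor_api_client.get_batch_timestamp()
        watermark_timestamp = stateful_processor_api_client.get_watermark_timestamp()
    else:
        batch_timestamp = -1
        watermark_timestamp = -1
    return batch_timestamp, watermark_timestamp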
@@ -420,10 +411,27 @@ def handleInputRows(
        timer_values: TimerValues
            Timer value for the current batch that process the input rows.
            Users can get the processing or event time timestamp from TimerValues.
        """
        return iter([])
Why do we change the `...` placeholder here?
Sorry, I confused this with `handleExpiredTimer`. Changed back to `...`.
""" | ||
return iter([]) | ||
|
||
def handleExpiredTimer( |
Just to double check: this method is not required for users to implement, correct?
Correct. Added a comment line in the docstring explicitly saying this is optional to implement.
LGTM, thanks!
 result_iter_list = [data_iter]
-# process with valid expiry time info and with empty input rows,
-# only timer related rows will be emitted
+# process with expiry timers, only timer related rows will be emitted
I get confused about this every time. Is this relying on the behavior that an expired timer is removed, so we won't list the same timer as expired multiple times? This is very easy to forget.
Is there any way we can move this out and do it after we process all input? Can this be done in `transformWithStateUDF`/`transformWithStateWithInitStateUDF` with key = null? (See the sketch below.)
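A sketch of this suggestion; `handle_expired_timers` and `handle_data_rows` are hypothetical helpers standing in for the existing logic:

def handle_expired_timers(api_client):
    # hypothetical helper: list and process expired timers once per partition
    ...

def handle_data_rows(api_client, key, input_rows):
    # hypothetical helper: the normal per-key input-row processing
    ...

def transformWithStateUDF(statefulProcessorApiClient, key, inputRows):
    if key is None:
        # Invoked once per partition after all keyed groups have been
        # processed: handle all expired timers in one place.
        return handle_expired_timers(statefulProcessorApiClient)
    return handle_data_rows(statefulProcessorApiClient, key, inputRows)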
Thanks so much for catching this! I introduced a terrible correctness bug in my prior timer implementation. I have now moved all timer handling code into `serializer.py`, where the expired timers are processed per partition.
Left an explanation of what caused the correctness issue in my prior implementation, just in case you are curious: #48838 (comment)
@@ -573,7 +579,11 @@ def transformWithStateUDF(
        statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED)
        return iter([])

-       result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+       batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient)
Ideally this shouldn't be called for every key. If we split the handling of timer expiration from the handling of input rows, we would only need to call this once.
Moved `get_timestamps()` into `stateful_processor_api_client.py` inside `__init__()`, so we only make the API call once per batch.
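Roughly the shape of that change (a sketch against the existing class; the constructor arguments shown are assumptions):

class StatefulProcessorApiClient:
    def __init__(self, state_server_port, key_schema, time_mode="none"):
        # ... existing socket/handle setup elided ...
        # Cache both timestamps once per batch; every per-key UDF call
        # then reads the cached values instead of issuing an API call.
        if time_mode != "none":
            self.batch_timestamp = self.get_batch_timestamp()
            self.watermark_timestamp = self.get_watermark_timestamp()
        else:
            self.batch_timestamp = -1
            self.watermark_timestamp = -1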
I'll revisit the PR once my comments are addressed (or @jingz-db has a reasonable point for not doing this), as my proposal would change the code non-trivially.
@@ -1201,11 +1201,89 @@ def generate_data_batches(batches):

    def dump_stream(self, iterator, stream):
        """
-       Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow
-       RecordBatches, and write batches to stream.
+       Read through chained return results from a single partition of handleInputRows.
@bogao007 Could you revisit this change? It changed since you last reviewed because I found a correctness bug in my prior timer change. Thanks!
In my prior implementation, a correctness issue occurs when multiple keys expire on a single partition. E.g., the test case `test_transform_with_state_init_state_with_timers` will fail if we set the number of partitions to "1".
Previously we called `get_expiry_timers_iterator()` and `handleExpiredTimer()` in `group_ops.py` inside the UDF, which is called per key. So we register a timer for key "0" inside `handleInitialState()` and then enter `get_expiry_timers_iterator()`. Because at that time the UDF for key "3" has not been called yet, the timer for key "3" is not registered. We will only see key "0" expire and will only get `Row(id="0-expired")` in the output of the first batch. When we enter the UDF for key "3", since `TransformWithStateInPandasStateServer` enforces that expiryTimestampIter is consumed only once per partition, the JVM returns none for key "3" because this iterator was already consumed for key "0". This is how the correctness issue arises.
I have now moved `handleExpiredTimer` inside `serializer.py`, so `get_expiry_timers_iterator()` is called after `handleInitialState()` has executed for all keys on the partition, and it is also chained after `handleInputRows()` has been called for all keys on the same partition.
Thanks @jingz-db for the detailed explanation! Do you think we should add a test case where multiple keys expire in the same partition? Like we either set the partition number to 1 or increase the input to have more keys.
+1 to verify this explicitly in a test.
Added a `test_transform_with_state_with_timers_single_partition` to run all timer suites with a single partition.
Did a first pass after the change; LGTM overall, mostly minor comments.
        else:
            expiry_list_iter = iter([[]])

        def timer_iter_wrapper(func, *args, **kwargs):
Nit: can we move this method definition to the top of `dump_stream` to follow the same pattern in this file? This would also make the code easier to read.
Moved it to just below where `statefulProcessorApiClient` is initialized. We need to access this object from `timer_iter_wrapper`.
@@ -1201,11 +1201,89 @@ def generate_data_batches(batches):

    def dump_stream(self, iterator, stream):
        """
-       Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow
-       RecordBatches, and write batches to stream.
+       Read through chained return results from a single partition of handleInputRows.
Maybe better to include what the structure of the input `iterator` looks like, given we have added a bunch of new objects to the UDF output. Either add it here or down below where `args` are defined.
@@ -1201,11 +1201,89 @@ def generate_data_batches(batches):

    def dump_stream(self, iterator, stream):
        """
-       Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow
-       RecordBatches, and write batches to stream.
+       Read through chained return results from a single partition of handleInputRows.
Any reason we can't do this in `if key is None:` in `transformWithStateUDF` and `transformWithStateWithInitStateUDF`?
This was my suggestion, and I believe you can just retrieve the expired timers and timestamps, call `handleExpiredTimer()` with that information, and be done. I don't think this complication is necessary - if we can't do this in `key is None` for some reason, I suspect fixing that would be much easier.
I'll wait for the next update on whether my suggestion works or not. I think the complexity would be very different, hence I would like to defer further review until after that.
TL;DR: if we put the timer handling code inside `if key is None`, we add higher code complexity.
We need to make a tradeoff on whether the complication lands in `serializer.py` or in `TransformWithStateInPandasPythonRunner` if we put the above timer handling code in `if key is None`.
If we put the timer handling logic inside `if key is None`, we would need to call `dump_stream()` again in the `finally` block here: https://github.com/apache/spark/blob/master/python/pyspark/worker.py#L1966. Calling `dump_stream()` twice means we need to properly handle how the JVM receives batches. Currently we reuse the `read()` function inside `PythonArrowOutput`, and the reader stops reading when the Python `dump_stream` signals the end here: https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala#L118. Since we would now call `dump_stream()` twice, we would need to override this function in `TransformWithStateInPandasPythonRunner` and continue reading one more time after receiving the end signal. The extra complexity is that some partitions may not have a timer iterator and would never start the additional dump stream writer at all, and we would also need to handle exceptions if one of the `dump_stream` calls fails. Additionally, we need to set the stateful processor handle state to `TIMER_PROCESSED` after all timer rows are processed, which requires code changes inside `worker.py`; that means getting the `StatefulProcessorApiClient` object inside `worker.py` to set the state correctly. This implies code complexity similar to what we have now in `serializer.py` (returning one extra `StatefulProcessorApiClient` from `transformWithStateWithInitStateUDF` and deserializing it from `out_iter`). We cannot set the `TIMER_PROCESSED` state in `group_ops.py` because the output rows iterator is not fully consumed there; it is only fully consumed after `dump_stream` is called inside `worker.py`.
So either way we have to deal with extra complexity. I personally think putting the timer handling code into `serializer.py` is slightly better because it is closer to how we deal with timers on the Scala side - we chain the timer output rows after the data handling rows into a single iterator.
Let me know if you have suggestions on which way is better.
This is the implementation of my suggestion based on 0c5ab3f. I've confirmed that `pyspark.sql.tests.pandas.test_pandas_transform_with_state` passes with this change - I haven't added the new tests you added later, though.
I think this is a lot simpler - we just add two markers into the input iterator which carry over the mode, and the flow does not change at all. No tricks with teeing and chaining iterators, minimal changes to the data structures, etc.
How does this work? It is just the same as how we use iterators in the Scala codebase: with an iterator in Scala, we pull one entry, process it and produce output, and then pull another entry. The generator yields an entry for every grouping key, then the marker for timers, and then the marker for completion. Each entry calls the function which eventually calls the user function; the user function is expected to return an iterator, but the logic producing that iterator should be synchronous (no async and no laziness - otherwise I guess it could even fail without my change).
So when the marker for timers is evaluated, the function calls for all grouping keys must already have completed. The same holds for the completion marker. This matches the Scala implementation.
As a side effect, updating the phase is corrected in this commit.
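The shape of this (a sketch; `PROCESS_TIMER` and `COMPLETE` are assumed marker names, extrapolated from the `PROCESS_DATA` marker visible in the diff further down):

from enum import Enum
from itertools import groupby

class TransformWithStateInPandasFuncMode(Enum):
    PROCESS_DATA = 1
    PROCESS_TIMER = 2
    COMPLETE = 3

def generate_entries(data_batches):
    # One entry per grouping key, then the timer marker, then completion.
    for k, g in groupby(data_batches, key=lambda x: x[0]):
        yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)
    # By the time these markers are pulled, every per-key call above has
    # already run, so all timers have been registered.
    yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
    yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)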
If you agree with this, please pick the commit above. You've already added more commits and I can't partially revert by myself.
My fork is public, so you can add my repo as a remote, fetch the branch, and cherry-pick the commit into this PR branch, resolving the merge conflict. I'd recommend taking a different route altogether - perform a "hard reset" to my commit in this PR branch (`git reset --hard f8952b213ba7f2cbfbc78ef145552317812e9f9b`), and then add more commits to address the other review comments.
Thanks for putting out the commit! I cherry-picked your change and this is now looking much cleaner!
 self.max_state.update((max_event_time,))
-self.handle.registerTimer(timer_values.get_current_watermark_in_ms())
+self.handle.registerTimer(timer_values.get_current_watermark_in_ms() + 1)
I modified this to current batch timestamp + 1 to test a more common use case, as registering with the current batch timestamp is not very common.
We should also check that a timer can expire in the same batch it is registered. So I am keeping the event time suite with a timer expiring in the same batch, and registering a future timestamp in the processing time suite.
Only nits and a linter failure left. Thanks for the patience.
@@ -1197,7 +1198,11 @@ def generate_data_batches(batches):
        data_batches = generate_data_batches(_batches)

        for k, g in groupby(data_batches, key=lambda x: x[0]):
-           yield (k, g)
+           yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)
nit: this looks inconsistent? Here we use a tuple with explicit `()` and the class below doesn't use `()`. Not a huge deal if the linter doesn't complain, but while we are here (the linter is failing)...
@@ -558,14 +559,19 @@ def prepare_batch3(input_path):
    def test_transform_with_state_in_pandas_event_time(self):
        def check_results(batch_df, batch_id):
            if batch_id == 0:
                assert set(batch_df.sort("id").collect()) == {Row(id="a", timestamp="20")}
            elif batch_id == 1:
+               # check timer registered in the same batch is expired
nit: let's comment on the `watermark for late event` and `watermark for eviction` per batch, to help verify the output. E.g., in batch_id == 1 the watermark for eviction is 10, but the watermark for late events is 0, hence 4 is accepted. The value of `timestamp` in the expired row follows the value of the watermark for eviction, so noting that is also helpful.
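For instance (a sketch of the suggested comments inside check_results, using the watermark values from the example above):

def check_results(batch_df, batch_id):
    if batch_id == 1:
        # watermark for late event = 0, watermark for eviction = 10:
        # the input row with timestamp 4 is accepted because 4 > 0, and
        # the timestamp of the expired row follows the eviction watermark
        ...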
I just pushed a commit addressing my own review comments as well as the linter failure. These are nits so I think it doesn't matter.
+1 pending CI
CI has passed: https://github.com/jingz-db/spark/runs/33636466911. Thanks! Merging to master.
What changes were proposed in this pull request?
As on the Scala side, we modified the timer API to use a separate `handleExpiredTimer` function inside `StatefulProcessor`; this PR changes the Python timer API to match the API on the Scala side. It also adds a timer parameter to the `handleInitialState` function to support use cases that register timers in the first batch for initial state rows.
Why are the changes needed?
This change aligns with the Scala side of the API: #48553
Does this PR introduce any user-facing change?
Yes.
We add a new user-defined function to explicitly handle expired timers:
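A sketch of the new optional handler (the class and emitted rows here are illustrative; the signature follows the discussion in this PR, and the import path is an assumption):

import pandas as pd
from pyspark.sql.streaming.stateful_processor import StatefulProcessor

class TimerProcessor(StatefulProcessor):
    def init(self, handle):
        self.handle = handle

    def handleInputRows(self, key, rows, timer_values):
        # Register a timer one second past the current processing time.
        self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() + 1000)
        return iter([])

    def handleExpiredTimer(self, key, timer_values, expired_timer_info):
        # Optional to implement; the base class default returns iter([]).
        # Emit one row per expired timer, mirroring the "0-expired" rows
        # in the tests discussed above.
        yield pd.DataFrame({"id": [f"{key[0]}-expired"]})

    def close(self):
        pass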
We also add a new timer parameter to enable users to register timers for keys that exist in the initial state:
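Continuing the illustrative processor above, a sketch of the new parameter on the initial state handler (the parameter name is assumed to match handleInputRows):

class TimerProcessorWithInitState(TimerProcessor):
    def handleInitialState(self, key, initialState, timer_values):
        # The new timer_values parameter lets us register a timer for a
        # key that exists only in the initial state, during the first batch.
        self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() + 1000)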
How was this patch tested?
Add a new test in `test_pandas_transform_with_state`.
Was this patch authored or co-authored using generative AI tooling?
No