
Update Worker to support concurrent predictions #2057

Merged
merged 7 commits into main from support-concurrent-predictions
Nov 25, 2024

Conversation

philandstuff
Contributor

PLAT-502.

This adds a max_concurrency parameter to Worker that allows it to accept
predictions while others are already in-flight.

I have removed the PROCESSING WorkerState; in concurrent mode, there is no
distinction between READY and PROCESSING because we might be able to accept a
prediction in either case.

Worker now keeps track of multiple in-flight predictions in a dictionary, keyed
on tag. Tags are required if max_concurrency > 1. Otherwise tags are
optional (and, if omitted, we store the prediction with tag=None).

There is one awkward place which is _prepare_payload(). As I understand it,
this synchronously downloads URLFiles, which will block us from processing any
other updates from the child while we download the URL.
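
To make the shape concrete, here is a minimal sketch of the design described above (illustrative only, not the actual worker.py; PredictionState mirrors the name used in the diff, everything else is simplified):

```python
# Sketch only: the real Worker also manages a child process, an event
# consumer thread, and setup state.
import threading
from concurrent.futures import Future
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class PredictionState:
    tag: Optional[str]
    payload: Dict[str, Any]
    result: "Future[Any]"
    cancel_sent: bool = False


class Worker:
    def __init__(self, max_concurrency: int = 1) -> None:
        self._max_concurrency = max_concurrency
        self._predictions_lock = threading.Lock()
        self._predictions_in_flight: Dict[Optional[str], PredictionState] = {}

    def predict(self, payload: Dict[str, Any], tag: Optional[str] = None) -> "Future[Any]":
        if self._max_concurrency > 1 and tag is None:
            raise ValueError("tags are required when max_concurrency > 1")
        with self._predictions_lock:
            if len(self._predictions_in_flight) >= self._max_concurrency:
                raise RuntimeError("already at max concurrency")
            if tag in self._predictions_in_flight:
                raise ValueError(f"prediction with tag {tag!r} already in flight")
            result: "Future[Any]" = Future()
            self._predictions_in_flight[tag] = PredictionState(tag, payload, result)
        # hand the payload off to the child process here, outside the lock
        return result
```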

Opening as a draft since I have local test failures I don't fully understand: TestWorkerState is flaky, and the test_build integration tests are failing?!

@philandstuff philandstuff force-pushed the support-concurrent-predictions branch 3 times, most recently from 87e8be8 to 86b421a Compare November 15, 2024 15:27
@philandstuff philandstuff changed the base branch from main to fix-test-flake November 15, 2024 15:27
@philandstuff
Contributor Author

OK, this is ready for review now: I think I've got a handle on the tests.

@philandstuff philandstuff marked this pull request as ready for review November 15, 2024 15:28
Base automatically changed from fix-test-flake to main November 15, 2024 15:43
@philandstuff philandstuff force-pushed the support-concurrent-predictions branch 3 times, most recently from 2afe540 to 9c51063 Compare November 18, 2024 10:28
Member

@nickstenning nickstenning left a comment

Overall I think this is looking great. Left a bunch of comments inline, most of which are for my own understanding.

self._assert_state(WorkerState.READY)
result = Future()
self._predictions_in_flight[tag] = PredictionState(tag, payload, result)
self._request_send_conn.send(PredictionRequest(tag))
Member

Should this happen while holding the lock?

Contributor Author

good point, no

python/cog/server/worker.py
@@ -197,93 +236,106 @@ def _consume_events_inner(self) -> None:
 # If we didn't get a done event, the child process died.
 if not done:
     exitcode = self._child.exitcode
-    assert self._result
-    self._result.set_exception(
+    assert self._setup_result
Member

I think this assertion (and those that follow) is probably now redundant given that self._setup_result is always set?

if predict_state and not predict_state.cancel_sent:
    self._child.send_cancel()
    self._events.send(Envelope(event=Cancel(), tag=ev.tag))
    predict_state.cancel_sent = True
Member

Should we explicitly handle the case that we get an object that isn't a PredictionRequest or a CancelRequest?

Contributor Author

I can stick a log.warn in but I'm not sure I want to crash or anything like that
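
For illustration, the shape that warning might take (hypothetical names; only the two request classes mentioned above come from this thread, and the logger here is the standard-library one rather than whatever the real code uses):

```python
# Sketch only, not the cog event-consumer code.
import logging
from dataclasses import dataclass
from typing import Optional

log = logging.getLogger(__name__)


@dataclass
class PredictionRequest:
    tag: Optional[str]


@dataclass
class CancelRequest:
    tag: Optional[str]


def handle_request(req: object) -> None:
    if isinstance(req, PredictionRequest):
        ...  # forward the prediction input to the child
    elif isinstance(req, CancelRequest):
        ...  # cancel the in-flight prediction with this tag
    else:
        # unexpected object: don't crash the consumer thread, just record it
        log.warning("ignoring unexpected request of type %s", type(req).__name__)
```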

Comment on lines 269 to 270
for sock in read_socks:
    if sock == self._request_recv_conn:
Member

Purely stylistic, but with select.select this is often written as

Suggested change
-for sock in read_socks:
-    if sock == self._request_recv_conn:
+if self._request_recv_conn in read_socks:

Comment on lines 331 to 332
if len(self._predictions_in_flight) == 0:
    self._state = WorkerState.READY
Member

Is this right? I don't see that worker ever moves out of state READY (in normal operations) after it's got there?

Contributor Author

good spot, this is vestigial

self._state = WorkerState.DEFUNCT
with self._predictions_lock:
    for tag in list(self._predictions_in_flight.keys()):
Member

The importance of the list(...) is IMO a bit obscure to the reader.

I'd be tempted to do something like:

# take a copy so we can change the dictionary
tags = [t for t in self._predictions_in_flight.keys()]
for t in tags:
    ...

Member

Or, come to think of it, given you have a lock on the whole object, maybe just:

for state in self._predictions_in_flight.values():
    state.result.set_exception(...)
self._predictions_in_flight.clear()

Contributor Author

@philandstuff philandstuff Nov 18, 2024

come to think of it, there's no need to clear out the dict here, we're tearing the whole worker down anyway. I'll just remove the del and the list(...)

EDIT: actually I like your second suggestion best
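
A sketch of where that lands (illustrative only: here the dict maps tags straight to futures, whereas the real dict holds PredictionState objects whose .result is the future, and the exception type is made up):

```python
from concurrent.futures import Future
from typing import Dict, Optional


def fail_all_in_flight(predictions_in_flight: Dict[Optional[str], "Future[object]"]) -> None:
    # every waiter gets an error instead of hanging forever; no per-key deletes needed
    for fut in predictions_in_flight.values():
        if not fut.done():
            fut.set_exception(RuntimeError("worker shut down unexpectedly"))
    predictions_in_flight.clear()
```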

Comment on lines 724 to 748
if state.canceled:
    # if this has previously been canceled, we expect no Cancel event
    # sent to the child
    assert not self.child_events.poll(timeout=0.1)
else:
    assert self.child_events.poll(timeout=0.5)
    e = self.child_events.recv()
    assert isinstance(e, Envelope)
    assert isinstance(e.event, Cancel)
    assert e.tag == state.tag
    assert self.child.cancel_sent
Member

I'm not sure how I feel about asserting on what we sent to the child -- it feels like it might make things a bit more brittle than before. Was there something specific that prompted adding this?

Contributor Author

we don't have to assert, but listening to the child events is how we ensure we synchronize the test with the worker and wait for the worker event consumer thread to do its thing.

This doesn't matter too much here, but it does matter in predict() where we need to wait until the PredictionInput event has been sent to the child - if we don't do that, we can send a Done event for a prediction before the PredictionInput event is sent, which breaks Worker assumptions.

And once we're monitoring the child worker for the PredictionInput event we need to monitor the Cancel event here too because we have to monitor the whole connection.

Contributor Author

I'll add a comment to make this clearer in predict()

    assert result.output == "done in 0.5 seconds"

finally:
    worker.unsubscribe(subid)
Member

Should this also unsubscribe for all the values in subids?

if state.canceled:
    # if this has previously been canceled, we expect no Cancel event
    # sent to the child
    assert not self.child_events.poll(timeout=0.1)
Member

Do we want to check there are no events left after the else block below, like we do on line 639 above? If so, we could move this down below that block and assert it unconditionally.

Contributor Author

I think re #2057 (comment) I'm actually more in favour of doing less asserting here rather than more. The code here is mainly to ensure we stay synced up with the child connection, not to make actual assertions - that should happen on the outside of the black box, as it were.

@philandstuff
Contributor Author

I pushed 27e35b4 to update the hypothesis tests to support concurrent prediction subscribers, and 3f25357 with PR feedback

Member

@evilstreak evilstreak left a comment

Member

@nickstenning nickstenning left a comment

LGTM. The only thing I had to double check was the implementation of the healthcheck -- which needs to return Health.BUSY when we're at max concurrency -- which I was assuming would use the worker state to figure that out, and we'd have broken that by removing WorkerState.PROCESSING. It turns out it's implemented completely differently right now, so we can worry about that later on.

philandstuff and others added 7 commits November 25, 2024 14:19
This adds support to TestWorkerState for prediction tags as added in #2020.  We
want to test that if we tag predictions and subscribe to those tags, the worker
still behaves as expected and we receive the events we expect to receive.

This also updates cancel to assert that send_cancel() was called on the child
worker.

As we start performing concurrent predictions, we will hit problems iterating
over the subscribers dictionary in one thread while another thread tries to add
or remove subscribers.
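
A minimal sketch of the race this commit guards against and one way to fix it (illustrative only, not the real Worker publish path): take a snapshot of the subscribers under a lock before iterating.

```python
import threading
from typing import Any, Callable, Dict

Subscriber = Callable[[Any], None]


class Publisher:
    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._next_id = 0
        self._subscribers: Dict[int, Subscriber] = {}

    def subscribe(self, fn: Subscriber) -> int:
        with self._lock:
            sid = self._next_id
            self._next_id += 1
            self._subscribers[sid] = fn
            return sid

    def unsubscribe(self, sid: int) -> None:
        with self._lock:
            self._subscribers.pop(sid, None)

    def publish(self, event: Any) -> None:
        # snapshot under the lock: iterating the live dict could raise
        # "dictionary changed size during iteration" if another thread
        # subscribes or unsubscribes mid-publish
        with self._lock:
            subscribers = list(self._subscribers.values())
        for fn in subscribers:
            fn(event)
```
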
This adds a max_concurrency parameter to Worker that allows it to accept
predictions while others are already in-flight.

I have removed the PROCESSING WorkerState; in concurrent mode, there is no
distinction between READY and PROCESSING because we might be able to accept a
prediction in either case.

Worker now keeps track of multiple in-flight predictions in a dictionary, keyed
on tag.  Tags are required if max_concurrency > 1.  Otherwise tags are
optional (and, if omitted, we store the prediction with tag=None).

There is one awkward place which is _prepare_payload().  As I understand it,
this synchronously downloads URLFiles, which will block us from processing any
other updates from the child while we download the URL.

This updates TestWorkerState to allow multiple prediction subscribers to be
active concurrently, to ensure that we are correctly publishing to specific
tags.

Co-authored-by: Nick Stenning <[email protected]>
Signed-off-by: Philip Potter <[email protected]>
@philandstuff philandstuff enabled auto-merge (rebase) November 25, 2024 14:19
@philandstuff philandstuff merged commit cf0f8b2 into main Nov 25, 2024
19 checks passed
@philandstuff philandstuff deleted the support-concurrent-predictions branch November 25, 2024 14:29
philandstuff added a commit that referenced this pull request Nov 26, 2024
This builds on the work in #2057 and wires it up end-to-end.

We can now support async models with a max concurrency configured, and submit
multiple predictions concurrently to them.

We only support python 3.11 for async models; this is so that we can use
asyncio.TaskGroup to keep track of multiple predictions in flight and ensure
they all complete when shutting down.

The cog http server was already async, but at one point it called wait() on a
concurrent.futures.Future() which blocked the event loop and therefore prevented
concurrent prediction requests (when not using prefer-async, which is how the
tests run).  I have updated this code to wait on asyncio.wrap_future(fut)
instead which does not block the event loop.  As part of this I have updated the
training endpoints to also be asynchronous.

We now have three places in the code which keep track of how many predictions
are in flight: PredictionRunner, Worker and _ChildWorker all do their own
bookkeeping. I'm not sure this is the best design but it works.

The code is now an uneasy mix of threaded and asyncio code.  This is evident in
the usage of threading.Lock, which wouldn't be needed if we were 100% async (and
I'm not sure if it's actually needed currently; I just added it to be safe).
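
A small standalone sketch of the wrap_future pattern described above (not the cog server code; names are illustrative):

```python
import asyncio
import time
from concurrent.futures import Future, ThreadPoolExecutor

pool = ThreadPoolExecutor()


def slow_predict(name: str) -> str:
    time.sleep(0.5)  # stands in for a prediction running in the worker
    return f"{name} done"


async def handler(name: str) -> str:
    fut: "Future[str]" = pool.submit(slow_predict, name)
    # fut.result() here would block the event loop for the full 0.5s;
    # awaiting the wrapped future lets other requests make progress.
    return await asyncio.wrap_future(fut)


async def main() -> None:
    # both handlers overlap, finishing in ~0.5s total rather than ~1s
    print(await asyncio.gather(handler("a"), handler("b")))


asyncio.run(main())
```
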
aron pushed a commit that referenced this pull request Nov 28, 2024
aron pushed a commit that referenced this pull request Nov 28, 2024
aron pushed a commit that referenced this pull request Nov 29, 2024