Update Worker to support concurrent predictions #2057
Conversation
Force-pushed 87e8be8 to 86b421a
Ok, this is ready for review now; I think I've got a handle on the tests.
Force-pushed 2afe540 to 9c51063
Overall I think this is looking great. Left a bunch of comments inline, most of which are for my own understanding.
python/cog/server/worker.py
Outdated
self._assert_state(WorkerState.READY)
result = Future()
self._predictions_in_flight[tag] = PredictionState(tag, payload, result)
self._request_send_conn.send(PredictionRequest(tag))
Should this happen while holding the lock?
good point, no
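A minimal sketch of the agreed change, assuming the surrounding predict() body matches the snippet quoted above:

with self._predictions_lock:
    self._assert_state(WorkerState.READY)
    result = Future()
    self._predictions_in_flight[tag] = PredictionState(tag, payload, result)
# the pipe send doesn't touch shared worker state, so it can happen
# after the lock is released
self._request_send_conn.send(PredictionRequest(tag))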
python/cog/server/worker.py
Outdated
@@ -197,93 +236,106 @@ def _consume_events_inner(self) -> None:
     # If we didn't get a done event, the child process died.
     if not done:
         exitcode = self._child.exitcode
-        assert self._result
-        self._result.set_exception(
+        assert self._setup_result
I think this assertion (and those that follow) is probably now redundant given that self._setup_result is always set?
python/cog/server/worker.py
Outdated
if predict_state and not predict_state.cancel_sent:
    self._child.send_cancel()
    self._events.send(Envelope(event=Cancel(), tag=ev.tag))
    predict_state.cancel_sent = True
Should we explicitly handle the case that we get an object that isn't a PredictionRequest or a CancelRequest?
I can stick a log.warn in, but I'm not sure I want to crash or anything like that.
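For illustration, a defensive branch along these lines (the exact control flow and the log.warn usage are assumptions, not the code in the PR):

ev = self._request_recv_conn.recv()
if isinstance(ev, PredictionRequest):
    ...  # forward the prediction input to the child
elif isinstance(ev, CancelRequest):
    ...  # send_cancel() and the Cancel envelope, as in the snippet above
else:
    # don't crash the consumer thread on an unexpected object; just log it
    log.warn("ignoring unexpected request object: %s" % ev)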
python/cog/server/worker.py
Outdated
for sock in read_socks:
    if sock == self._request_recv_conn:
Purely stylistic, but with select.select this is often written as:

-    for sock in read_socks:
-        if sock == self._request_recv_conn:
+    if self._request_recv_conn in read_socks:
python/cog/server/worker.py
Outdated
if len(self._predictions_in_flight) == 0:
    self._state = WorkerState.READY
Is this right? I don't see that worker ever moves out of state READY (in normal operations) after it's got there?
good spot, this is vestigial
python/cog/server/worker.py
Outdated
self._state = WorkerState.DEFUNCT
with self._predictions_lock:
    for tag in list(self._predictions_in_flight.keys()):
The importance of the list(...) is IMO a bit obscure to the reader.
I'd be tempted to do something like:

# take a copy so we can change the dictionary
tags = [t for t in self._predictions_in_flight.keys()]
for t in tags:
    ...
Or, come to think of it, given you have a lock on the whole object, maybe just:
for state in self._predictions_in_flight.values():
    state.result.set_exception(...)
self._predictions_in_flight.clear()
come to think of it, there's no need to clear out the dict here, we're tearing the whole worker down anyway. I'll just remove the del and the list(...)
EDIT: actually I like your second suggestion best
python/tests/server/test_worker.py
Outdated
if state.canceled:
    # if this has previously been canceled, we expect no Cancel event
    # sent to the child
    assert not self.child_events.poll(timeout=0.1)
else:
    assert self.child_events.poll(timeout=0.5)
    e = self.child_events.recv()
    assert isinstance(e, Envelope)
    assert isinstance(e.event, Cancel)
    assert e.tag == state.tag
assert self.child.cancel_sent
I'm not sure how I feel about asserting on what we sent to the child -- it feels like it might make things a bit more brittle than before. Was there something specific that prompted adding this?
we don't have to assert, but listening to the child events is how we ensure we synchronize the test with the worker and wait for the worker event consumer thread to do its thing.
This doesn't matter too much here, but it does matter in predict() where we need to wait until the PredictionInput event has been sent to the child - if we don't do that, we can send a Done event for a prediction before the PredictionInput event is sent, which breaks Worker assumptions.
And once we're monitoring the child worker for the PredictionInput event we need to monitor the Cancel event here too because we have to monitor the whole connection.
I'll add a comment to make this clearer in predict()
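Roughly, the predict() helper in the test harness would do something like this (a sketch based on the discussion above, not the actual test code; the helper signature and names are assumptions):

def predict(self, payload, tag=None):
    fut = self.worker.predict(payload, tag=tag)
    # synchronize with the worker's event consumer thread: wait for the
    # PredictionInput envelope to reach the child before the test sends
    # any Done event for this prediction
    assert self.child_events.poll(timeout=0.5)
    e = self.child_events.recv()
    assert isinstance(e, Envelope)
    assert isinstance(e.event, PredictionInput)
    assert e.tag == tag
    return fut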
python/tests/server/test_worker.py
Outdated
assert result.output == "done in 0.5 seconds"

finally:
    worker.unsubscribe(subid)
Should this also unsubscribe for all the values in subids?
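i.e. something like the following (assuming subids holds the other subscription IDs created earlier in the test):

finally:
    worker.unsubscribe(subid)
    for s in subids:
        worker.unsubscribe(s)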
python/tests/server/test_worker.py
Outdated
if state.canceled:
    # if this has previously been canceled, we expect no Cancel event
    # sent to the child
    assert not self.child_events.poll(timeout=0.1)
Do we want to check there are no events left after the else block below, like we do on line 639 above? If so, we could move this down below that block and assert it unconditionally.
I think re #2057 (comment) I'm actually more in favour of doing less asserting here rather than more. The code here is mainly to ensure we stay synced up with the child connection, not to make actual assertions - that should happen on the outside of the black box, as it were.
Force-pushed 9c51063 to 3aadb5a
Force-pushed 3f25357 to deb2ff9
✨
LGTM. The only thing I had to double check was the implementation of the healthcheck -- which needs to return Health.BUSY when we're at max concurrency -- which I was assuming would use the worker state to figure that out, and we'd have broken that by removing WorkerState.PROCESSING. It turns out it's implemented completely differently right now, so we can worry about that later on.
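For reference, a hypothetical sketch of what a busy check based on the worker's own bookkeeping could look like -- not how the healthcheck is actually implemented today, per the comment above; attribute names like _max_concurrency are assumptions:

def _is_busy(self) -> bool:
    # hypothetical: report Health.BUSY once the number of in-flight
    # predictions reaches max_concurrency
    with self._predictions_lock:
        return len(self._predictions_in_flight) >= self._max_concurrency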
This adds support to TestWorkerState for prediction tags as added in #2020. We want to test that if we tag predictions and subscribe to those tags, the worker still behaves as expected and we receive the events we expect to receive. This also updates cancel to assert that send_cancel() was called on the child worker.
As we start performing concurrent predictions, we will hit problems iterating over the subscribers dictionary in one thread while another thread tries to add or remove subscribers.
This adds a max_concurrency parameter to Worker that allows it to accept predictions while others are already in-flight. I have removed the PROCESSING WorkerState; in concurrent mode, there is no distinction between READY and PROCESSING because we might be able to accept a prediction in either case. Worker now keeps track of multiple in-flight predictions in a dictionary, keyed on tag. Tags are required if max_concurrency > 1. Otherwise tags are optional (and, if omitted, we store the prediction with tag=None). There is one awkward place which is _prepare_payload(). As I understand it, this synchronously downloads URLFiles, which will block us from processing any other updates from the child while we download the URL.
This updates TestWorkerState to allow multiple prediction subscribers to be active concurrently, to ensure that we are correctly publishing to specific tags.
Thanks @evilstreak and @nickstenning
Co-authored-by: Nick Stenning <[email protected]> Signed-off-by: Philip Potter <[email protected]>
17e3594
to
625f5de
Compare
This builds on the work in #2057 and wires it up end-to-end. We can now support async models with a max concurrency configured, and submit multiple predictions concurrently to them. We only support python 3.11 for async models; this is so that we can use asyncio.TaskGroup to keep track of multiple predictions in flight and ensure they all complete when shutting down. The cog http server was already async, but at one point it called wait() on a concurrent.futures.Future() which blocked the event loop and therefore prevented concurrent prediction requests (when not using prefer-async, which is how the tests run). I have updated this code to wait on asyncio.wrap_future(fut) instead which does not block the event loop. As part of this I have updated the training endpoints to also be asynchronous. We now have three places in the code which keep track of how many predictions are in flight: PredictionRunner, Worker and _ChildWorker all do their own bookkeeping. I'm not sure this is the best design but it works. The code is now an uneasy mix of threaded and asyncio code. This is evident in the usage of threading.Lock, which wouldn't be needed if we were 100% async (and I'm not sure if it's actually needed currently; I just added it to be safe).
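In code terms, the change described there is roughly the following (the function name is hypothetical; only asyncio.wrap_future is the point):

import asyncio
from concurrent.futures import Future

async def wait_for_prediction(fut: Future):
    # fut.result() would block the event loop; wrapping the
    # concurrent.futures.Future lets the handler await it instead
    return await asyncio.wrap_future(fut)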
PLAT-502.
This adds a max_concurrency parameter to Worker that allows it to accept
predictions while others are already in-flight.
I have removed the PROCESSING WorkerState; in concurrent mode, there is no
distinction between READY and PROCESSING because we might be able to accept a
prediction in either case.
Worker now keeps track of multiple in-flight predictions in a dictionary, keyed
on tag. Tags are required if max_concurrency > 1. Otherwise tags are
optional (and, if omitted, we store the prediction with tag=None).
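Sketched out, the tag handling looks roughly like this (illustrative only; the exact validation and error type are assumptions):

def predict(self, payload, tag=None):
    if self._max_concurrency > 1 and tag is None:
        # tags are how child events get routed back to the right prediction
        raise ValueError("tag is required when max_concurrency > 1")
    result = Future()
    with self._predictions_lock:
        # in-flight predictions are keyed on tag; tag may be None when only
        # one prediction runs at a time
        self._predictions_in_flight[tag] = PredictionState(tag, payload, result)
    self._request_send_conn.send(PredictionRequest(tag))
    return result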
There is one awkward place which is _prepare_payload(). As I understand it,
this synchronously downloads URLFiles, which will block us from processing any
other updates from the child while we download the URL.
Opening as draft as I have local test failures I don't fully understand - TestWorkerState is flaky, and the test_build integration tests are failing?!