-
Notifications
You must be signed in to change notification settings - Fork 563
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
End-to-end support for concurrent async models #2066
base: main
Are you sure you want to change the base?
Conversation
We require python >=3.11 to support asyncio.TaskGroup
This builds on the work in #2057 and wires it up end-to-end. We can now support async models with a max concurrency configured, and submit multiple predictions concurrently to them. We only support python 3.11 for async models; this is so that we can use asyncio.TaskGroup to keep track of multiple predictions in flight and ensure they all complete when shutting down. The cog http server was already async, but at one point it called wait() on a concurrent.futures.Future() which blocked the event loop and therefore prevented concurrent prediction requests (when not using prefer-async, which is how the tests run). I have updated this code to wait on asyncio.wrap_future(fut) instead which does not block the event loop. As part of this I have updated the training endpoints to also be asynchronous. We now have three places in the code which keep track of how many predictions are in flight: PredictionRunner, Worker and _ChildWorker all do their own bookkeeping. I'm not sure this is the best design but it works. The code is now an uneasy mix of threaded and asyncio code. This is evident in the usage of threading.Lock, which wouldn't be needed if we were 100% async (and I'm not sure if it's actually needed currently; I just added it to be safe).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks pretty good to me. A few minor comments and suggestions.
self._prediction_id = prediction.id | ||
tag = prediction.id | ||
if tag is None: | ||
tag = uuid.uuid4().hex |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need two ways of getting the tag? Are we using the prediction ID sometimes for convenience or does it matter under some conditions and if so, does the fallback still work? If we can, I think it would be better to use one mechanism (presumably the UUID) than 2, so we don't have any temptation to rely on variable the behaviour.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The external interface is that callers can cancel predictions by providing the prediction id. If we were to always use uuids for tags, we'd need to maintain a mapping from prediction id to tag. I'd rather not do that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So should it be an error if we don't have a prediction.id
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, it's currently allowed to not provide a prediction id, it just means you can't then cancel predictions.
We could make it required, that'd be a breaking change but I don't imagine it would actually break many users.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Final thought: should we be setting prediction.id
if it's not already set here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Final thought: should we be setting prediction.id if it's not already set here?
I do think it's worth considering whether auto generating a prediction id is a useful thing, it likely is. However as this would change the existing behavior I think it's best left out of this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually i think this is important, otherwise we will default to None
when we record in self._predict_tasks
and we will overwrite any running task. This will make the busy check report us as not busy even if we're actually at max concurrency.
I'll push a fix to do this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I discussed with @aron and I'm wrong and he's right. (I briefly pushed a commit that always set prediction.id but that wasn't needed and we're backing it out.)
Setting prediction.id is weird because it changes the task we give back to the user.
We don't need to set it (contrary to my comment above) because we key _predict_tasks on tag
, not on prediction.id
.
@@ -124,9 +139,13 @@ def is_busy(self) -> bool: | |||
def cancel(self, prediction_id: str) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if we could consistently call the ID the "task ID"? Whether or not it's also the prediction ID, it feels like we're using it to identify the task, not the prediction, aren't we?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The external interface for managing predictions currently uses the prediction ID. Internally we've abstracted this into the concept of a "tag" so that the prediction ID remains optional. This may not be the right thing to do longer term, but for the moment it is consistent.
elif isinstance(e.event, Shutdown): | ||
break | ||
elif isinstance(e.event, PredictionInput): | ||
task = tg.create_task( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this potentially memory leak all these events forever by adding them to the task group?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Completed tasks are removed from the task group once they resolve. So no, this wont leak memory. However if any task raises an exception other than an asyncio.CancelledError
the existing tasks in the group will be cancelled. I don't think this is what we want…
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've just tested this locally with a basic async model that yields strings at an interval. Works wonderfully.
The use of `Optional` allowed `None` as a valid value. This has been changed to use `NotRequired` which allows the field to be omitted but must always be an integer when present.
3078278
to
e307a39
Compare
Inside the worker we track predictions by tag not exterenal predicition IDs, this commit updates the variable names to reflect this.
e307a39
to
41adaa9
Compare
the `for tag in done_tags:` was resetting the existing `tag` variable and breaking things.
Cutting a 0.14.0-alpha.1 pre-release build https://github.com/replicate/cog/actions/runs/12035403546 |
This is probably for later, but I don't think we have the correct support for typing the output of these async predict functions. For example |
This builds on the work in #2057 and wires it up end-to-end.
We can now support async models with a max concurrency configured, and submit
multiple predictions concurrently to them.
We only support python 3.11 for async models; this is so that we can use
asyncio.TaskGroup to keep track of multiple predictions in flight and ensure
they all complete when shutting down.
The cog http server was already async, but at one point it called wait() on a
concurrent.futures.Future() which blocked the event loop and therefore prevented
concurrent prediction requests (when not using prefer-async, which is how the
tests run). I have updated this code to wait on asyncio.wrap_future(fut)
instead which does not block the event loop. As part of this I have updated the
training endpoints to also be asynchronous.
We now have three places in the code which keep track of how many predictions
are in flight: PredictionRunner, Worker and _ChildWorker all do their own
bookkeeping. I'm not sure this is the best design but it works.
The code is now an uneasy mix of threaded and asyncio code. This is evident in
the usage of threading.Lock, which wouldn't be needed if we were 100% async (and
I'm not sure if it's actually needed currently; I just added it to be safe).