
End-to-end support for concurrent async models #2066

Open · philandstuff wants to merge 7 commits into main from support-concurrent-predictions-in-child
Conversation

philandstuff (Contributor)

This builds on the work in #2057 and wires it up end-to-end.

We can now support async models with a max concurrency configured, and submit
multiple predictions concurrently to them.

We only support Python 3.11 and later for async models; this lets us use
asyncio.TaskGroup to keep track of multiple predictions in flight and ensure
they all complete when shutting down.
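
A minimal sketch of the pattern (names here are illustrative, not the actual worker code):

```python
import asyncio

async def handle_prediction(request):
    ...  # stand-in for running one prediction to completion

async def run(requests):
    # asyncio.TaskGroup is new in Python 3.11. Exiting the
    # `async with` block waits for every task created on the
    # group, so shutdown can't complete with predictions in flight.
    async with asyncio.TaskGroup() as tg:
        async for request in requests:
            if request == "shutdown":
                break
            tg.create_task(handle_prediction(request))
    # all in-flight predictions have finished here
```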

The Cog HTTP server was already async, but at one point it called wait() on a
concurrent.futures.Future, which blocked the event loop and therefore prevented
concurrent prediction requests (when not using prefer-async, which is how the
tests run). I have updated this code to await asyncio.wrap_future(fut) instead,
which does not block the event loop. As part of this I have updated the
training endpoints to also be asynchronous.
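
The change amounts to something like this (a sketch, not the exact server code):

```python
import asyncio
import concurrent.futures

async def respond(fut: concurrent.futures.Future):
    # Before: fut.result() (or concurrent.futures.wait([fut])) blocked
    # the event loop thread, stalling every other request.
    # After: wrapping the future lets us await it cooperatively.
    return await asyncio.wrap_future(fut)
```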

We now have three places in the code which keep track of how many predictions
are in flight: PredictionRunner, Worker and _ChildWorker all do their own
bookkeeping. I'm not sure this is the best design but it works.

The code is now an uneasy mix of threaded and asyncio code. This is evident in
the usage of threading.Lock, which wouldn't be needed if we were 100% async (and
I'm not sure if it's actually needed currently; I just added it to be safe).
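
The kind of shared bookkeeping in question might be guarded like this (a sketch; class and method names are illustrative):

```python
import threading

class InFlightCounter:
    def __init__(self) -> None:
        self._lock = threading.Lock()  # unnecessary if everything ran on one event loop
        self._count = 0

    def increment(self) -> None:
        # safe whether called from a worker thread or the event loop thread
        with self._lock:
            self._count += 1
```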

We require Python >=3.11 to support asyncio.TaskGroup
philandstuff requested a review from a team on November 26, 2024 at 11:33
erbridge (Member) left a comment


This looks pretty good to me. A few minor comments and suggestions.

```python
self._prediction_id = prediction.id
tag = prediction.id
if tag is None:
    tag = uuid.uuid4().hex
```
erbridge (Member):

Why do we need two ways of getting the tag? Are we using the prediction ID sometimes for convenience, or does it matter under some conditions, and if so, does the fallback still work? If we can, I think it would be better to use one mechanism (presumably the UUID) than two, so we don't have any temptation to rely on the variable behaviour.

philandstuff (Author):

The external interface is that callers can cancel predictions by providing the prediction ID. If we were to always use UUIDs for tags, we'd need to maintain a mapping from prediction ID to tag. I'd rather not do that.

erbridge (Member):

So should it be an error if we don't have a prediction.id?

philandstuff (Author):

No, it's currently allowed to not provide a prediction ID; it just means you can't then cancel predictions.

We could make it required; that'd be a breaking change, but I don't imagine it would actually break many users.

erbridge (Member):

Final thought: should we be setting prediction.id if it's not already set here?

Contributor:

> Final thought: should we be setting prediction.id if it's not already set here?

I do think it's worth considering whether auto-generating a prediction ID is useful; it likely is. However, as this would change the existing behavior, I think it's best left out of this PR.

philandstuff (Author):

Actually, I think this is important; otherwise we will default to None when we record in self._predict_tasks, and we will overwrite any running task. This will make the busy check report us as not busy even when we're actually at max concurrency.

I'll push a fix to do this.

philandstuff (Author):

I discussed with @aron and I'm wrong and he's right. (I briefly pushed a commit that always set prediction.id but that wasn't needed and we're backing it out.)

Setting prediction.id is weird because it changes the task we give back to the user.

We don't need to set it (contrary to my comment above) because we key _predict_tasks on tag, not on prediction.id.
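
Roughly (a sketch inferred from the snippet above; only `_predict_tasks` and the tag fallback come from this PR):

```python
import uuid

_predict_tasks = {}  # tag -> in-flight task

def track(prediction, task):
    # The tag falls back to a fresh UUID when no prediction ID is given,
    # so two anonymous predictions never overwrite each other.
    tag = prediction.id if prediction.id is not None else uuid.uuid4().hex
    _predict_tasks[tag] = task
```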

@@ -124,9 +139,13 @@ def is_busy(self) -> bool:
def cancel(self, prediction_id: str) -> None:
erbridge (Member):

I wonder if we could consistently call the ID the "task ID"? Whether or not it's also the prediction ID, it feels like we're using it to identify the task, not the prediction, aren't we?

Contributor:

The external interface for managing predictions currently uses the prediction ID. Internally we've abstracted this into the concept of a "tag" so that the prediction ID remains optional. This may not be the right thing to do longer term, but for the moment it is consistent.

```python
elif isinstance(e.event, Shutdown):
    break
elif isinstance(e.event, PredictionInput):
    task = tg.create_task(
```
Contributor:

Does this potentially leak memory by keeping all these events in the task group forever?

Contributor:

Completed tasks are removed from the task group once they resolve, so no, this won't leak memory. However, if any task raises an exception other than asyncio.CancelledError, the existing tasks in the group will be cancelled. I don't think this is what we want…
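
That cancellation behaviour is easy to demonstrate in isolation (Python 3.11+):

```python
import asyncio

async def fails():
    raise RuntimeError("one prediction blew up")

async def slow():
    await asyncio.sleep(10)

async def main():
    try:
        async with asyncio.TaskGroup() as tg:
            tg.create_task(fails())
            tg.create_task(slow())
    except* RuntimeError:
        # slow() was cancelled as soon as fails() raised --
        # undesirable if the tasks are independent predictions
        print("sibling task was cancelled")

asyncio.run(main())
```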

aron (Contributor) left a comment


I've just tested this locally with a basic async model that yields strings at an interval. Works wonderfully.
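
Something along these lines (a guess at the test model, not the actual code; `setup` omitted):

```python
import asyncio
from typing import AsyncIterator

from cog import BasePredictor


class Predictor(BasePredictor):
    async def predict(self, n: int = 5) -> AsyncIterator[str]:
        # yield one string every half second
        for i in range(n):
            await asyncio.sleep(0.5)
            yield f"output {i}"
```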

The use of `Optional` allowed `None` as a valid value. This has been
changed to use `NotRequired`, which allows the field to be omitted but
requires it to be an integer when present.
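
The distinction, sketched with a throwaway field name:

```python
from typing import NotRequired, Optional, TypedDict  # NotRequired: Python 3.11+


class Before(TypedDict):
    max: Optional[int]  # key must be present; None is a legal value


class After(TypedDict):
    max: NotRequired[int]  # key may be omitted, but must be an int when present
```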
aron force-pushed the support-concurrent-predictions-in-child branch from 3078278 to e307a39 on November 26, 2024 at 15:36
Inside the worker we track predictions by tag, not external prediction
IDs; this commit updates the variable names to reflect this.
aron force-pushed the support-concurrent-predictions-in-child branch from e307a39 to 41adaa9 on November 26, 2024 at 15:40
The `for tag in done_tags:` loop was resetting the existing `tag` variable and
breaking things.
aron (Contributor) commented on Nov 26, 2024

Cutting a 0.14.0-alpha.1 pre-release build: https://github.com/replicate/cog/actions/runs/12035403546

aron (Contributor) commented on Nov 26, 2024

This is probably for later, but I don't think we have correct support for typing the output of these async predict functions. For example, `async def predict(...) -> Iterator[str]` will raise type errors, as it should be `AsyncGenerator`.
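
The mismatch looks like this (illustrative; the final annotation cog settles on may differ):

```python
from typing import AsyncGenerator, Iterator

async def predict_wrong(n: int) -> Iterator[str]:  # type checkers reject this
    for i in range(n):
        yield str(i)

# An async generator function actually produces an AsyncGenerator:
async def predict_right(n: int) -> AsyncGenerator[str, None]:
    for i in range(n):
        yield str(i)
```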
