End-to-end support for concurrent async models
This builds on the work in #2057 and wires it up end-to-end.

We can now support async models with a max concurrency configured (concurrency.max
in cog.yaml), and submit multiple predictions to them concurrently.
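
Concretely, this targets predictors whose predict is a coroutine. A minimal
sketch of such a model (the predictor shape is cog's; `generate` is a
hypothetical stand-in for a real async model call):

    import asyncio

    from cog import BasePredictor

    async def generate(prompt: str) -> str:
        # stand-in for a real async model call
        await asyncio.sleep(0.1)
        return prompt.upper()

    class Predictor(BasePredictor):
        async def predict(self, prompt: str) -> str:
            # with concurrency.max > 1, several of these coroutines can
            # be in flight at once on a single event loop
            return await generate(prompt)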

We only support Python 3.11 and later for async models; this is so that we can
use asyncio.TaskGroup to keep track of multiple predictions in flight and ensure
they all complete when shutting down.
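
What TaskGroup buys us, in sketch form (illustrative, not code from this diff):
tasks created inside the async with block are guaranteed to be done by the time
the block exits, so shutdown cannot race past in-flight predictions.

    import asyncio

    async def handle(payload: str) -> None:
        # stand-in for the real _apredict coroutine
        await asyncio.sleep(0.1)

    async def serve(queue: "asyncio.Queue[str]") -> None:
        # Python 3.11+: exiting the TaskGroup block waits for every task
        # created inside it, so breaking out of the loop on shutdown
        # still drains in-flight predictions.
        async with asyncio.TaskGroup() as tg:
            while True:
                event = await queue.get()
                if event == "shutdown":
                    break
                tg.create_task(handle(event))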

The cog HTTP server was already async, but at one point it called wait() on a
concurrent.futures.Future, which blocked the event loop and therefore prevented
concurrent prediction requests (when not using prefer-async, which is how the
tests run). I have updated this code to await asyncio.wrap_future(fut) instead,
which does not block the event loop. As part of this I have also made the
training endpoints asynchronous.
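
The difference in miniature (a sketch, not the exact server code; fut stands
for the concurrent.futures.Future the worker hands back):

    import asyncio
    from concurrent.futures import Future

    async def wait_for(fut: "Future[None]") -> None:
        # fut.result() would park the event loop thread until the
        # prediction finished, serializing all requests; wrap_future
        # bridges the future into asyncio so we can await it while the
        # loop keeps serving other requests.
        await asyncio.wrap_future(fut)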

We now have three places in the code that keep track of how many predictions
are in flight: PredictionRunner, Worker, and _ChildWorker each do their own
bookkeeping. I'm not sure this is the best design, but it works.

The code is now an uneasy mix of threaded and asyncio code. This is evident in
the use of threading.Lock, which wouldn't be needed if we were 100% async (and
I'm not sure it's actually needed currently; I just added it to be safe).
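
One plausible scenario the lock guards against (my inference, not established
by this commit): done-callbacks on a concurrent.futures.Future run on whichever
thread completes the future, so the tasks dictionary could be touched from
outside the event-loop thread. The guarded pattern is:

    import threading
    from typing import Dict

    class TaskTable:
        # minimal sketch of the lock-guarded bookkeeping in PredictionRunner
        def __init__(self) -> None:
            self._tasks: Dict[str, object] = {}
            self._lock = threading.Lock()

        def add(self, tag: str, task: object) -> None:
            with self._lock:  # serialize access across threads
                self._tasks[tag] = task

        def remove(self, tag: str) -> None:
            with self._lock:
                self._tasks.pop(tag, None)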
philandstuff authored and aron committed Nov 28, 2024
1 parent 1d413f9 commit fd123b8
Showing 16 changed files with 364 additions and 85 deletions.
2 changes: 1 addition & 1 deletion pkg/config/config.go
@@ -263,7 +263,7 @@ func ValidateModelPythonVersion(cfg *Config) error {
return fmt.Errorf("minimum supported Python version is %d.%d. requested %s",
MinimumMajorPythonVersion, MinimumMinorPythonVersion, version)
}
if cfg.Concurrency.Max > 1 && minor < MinimumMinorPythonVersionForConcurrency {
if cfg.Concurrency != nil && cfg.Concurrency.Max > 1 && minor < MinimumMinorPythonVersionForConcurrency {
return fmt.Errorf("when concurrency.max is set, minimum supported Python version is %d.%d. requested %s",
MinimumMajorPythonVersion, MinimumMinorPythonVersionForConcurrency, version)
}
8 changes: 6 additions & 2 deletions pkg/config/config_test.go
@@ -64,9 +64,13 @@ func TestValidateModelPythonVersion(t *testing.T) {
Build: &Build{
PythonVersion: tc.pythonVersion,
},
Concurrency: &Concurrency{
}
if tc.concurrencyMax != 0 {
// the Concurrency key is optional, only populate it if
// concurrencyMax is a non-default value
cfg.Concurrency = &Concurrency{
Max: tc.concurrencyMax,
},
}
}
err := ValidateModelPythonVersion(cfg)
if tc.expectedErr != "" {
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -42,6 +42,7 @@ tests = [
"numpy",
"pillow",
"pytest",
"pytest-asyncio",
"pytest-httpserver",
"pytest-timeout",
"pytest-xdist",
@@ -70,6 +71,9 @@ reportUnusedExpression = "warning"
[tool.pyright.defineConstant]
PYDANTIC_V2 = true

[tool.pytest.ini_options]
asyncio_default_fixture_loop_scope = "function"

[tool.setuptools]
include-package-data = false

7 changes: 7 additions & 0 deletions python/cog/config.py
@@ -30,6 +30,7 @@
COG_PREDICT_CODE_STRIP_ENV_VAR = "COG_PREDICT_CODE_STRIP"
COG_TRAIN_CODE_STRIP_ENV_VAR = "COG_TRAIN_CODE_STRIP"
COG_GPU_ENV_VAR = "COG_GPU"
COG_MAX_CONCURRENCY_ENV_VAR = "COG_MAX_CONCURRENCY"
PREDICT_METHOD_NAME = "predict"
TRAIN_METHOD_NAME = "train"

@@ -98,6 +99,12 @@ def requires_gpu(self) -> bool:
"""Whether this cog requires the use of a GPU."""
return bool(self._cog_config.get("build", {}).get("gpu", False))

@property
@env_property(COG_MAX_CONCURRENCY_ENV_VAR)
def max_concurrency(self) -> int:
"""The maximum concurrency of predictions supported by this model. Defaults to 1."""
return int(self._cog_config.get("concurrency", {}).get("max", 1))

def _predictor_code(
self,
module_path: str,
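In effect, max_concurrency resolves like this (a sketch; it assumes
env_property prefers the environment variable when set, which is the
decorator's apparent contract, though its definition is outside this diff):

    import os

    def max_concurrency(cog_config: dict) -> int:
        # COG_MAX_CONCURRENCY, if set, wins over cog.yaml's
        # concurrency.max, which in turn defaults to 1 (assumption:
        # env_property's definition is not part of this diff)
        env = os.environ.get("COG_MAX_CONCURRENCY")
        if env is not None:
            return int(env)
        return int(cog_config.get("concurrency", {}).get("max", 1))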
23 changes: 13 additions & 10 deletions python/cog/server/http.py
@@ -162,8 +162,11 @@ async def start_shutdown() -> Any:
add_setup_failed_routes(app, started_at, msg)
return app

worker = make_worker(predictor_ref=cog_config.get_predictor_ref(mode=mode))
runner = PredictionRunner(worker=worker)
worker = make_worker(
predictor_ref=cog_config.get_predictor_ref(mode=mode),
max_concurrency=cog_config.max_concurrency,
)
runner = PredictionRunner(worker=worker, max_concurrency=cog_config.max_concurrency)

class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType)):
pass
@@ -215,7 +218,7 @@ class TrainingRequest(
response_model=TrainingResponse,
response_model_exclude_unset=True,
)
def train(
async def train(
request: TrainingRequest = Body(default=None),
prefer: Optional[str] = Header(default=None),
traceparent: Optional[str] = Header(
@@ -228,7 +231,7 @@ def train(
respond_async = prefer == "respond-async"

with trace_context(make_trace_context(traceparent, tracestate)):
return _predict(
return await _predict(
request=request,
response_type=TrainingResponse,
respond_async=respond_async,
@@ -239,7 +242,7 @@
response_model=TrainingResponse,
response_model_exclude_unset=True,
)
def train_idempotent(
async def train_idempotent(
training_id: str = Path(..., title="Training ID"),
request: TrainingRequest = Body(..., title="Training Request"),
prefer: Optional[str] = Header(default=None),
@@ -276,7 +279,7 @@ def train_idempotent(
respond_async = prefer == "respond-async"

with trace_context(make_trace_context(traceparent, tracestate)):
return _predict(
return await _predict(
request=request,
response_type=TrainingResponse,
respond_async=respond_async,
@@ -355,7 +358,7 @@ async def predict(
respond_async = prefer == "respond-async"

with trace_context(make_trace_context(traceparent, tracestate)):
return _predict(
return await _predict(
request=request,
response_type=PredictionResponse,
respond_async=respond_async,
@@ -403,13 +406,13 @@ async def predict_idempotent(
respond_async = prefer == "respond-async"

with trace_context(make_trace_context(traceparent, tracestate)):
return _predict(
return await _predict(
request=request,
response_type=PredictionResponse,
respond_async=respond_async,
)

def _predict(
async def _predict(
*,
request: Optional[PredictionRequest],
response_type: Type[schema.PredictionResponse],
@@ -451,7 +454,7 @@ def _predict(
)

# Otherwise, wait for the prediction to complete...
predict_task.wait()
await predict_task.wait_async()

# ...and return the result.
if PYDANTIC_V2:
70 changes: 51 additions & 19 deletions python/cog/server/runner.py
@@ -1,5 +1,8 @@
import asyncio
import io
import threading
import traceback
import uuid
from abc import ABC, abstractmethod
from concurrent.futures import Future
from datetime import datetime, timezone
@@ -60,13 +63,15 @@ class PredictionRunner:
def __init__(
self,
*,
max_concurrency: int = 1,
worker: Worker,
) -> None:
self._worker = worker
self._max_concurrency = max_concurrency

self._setup_task: Optional[SetupTask] = None
self._predict_task: Optional[PredictTask] = None
self._prediction_id = None
self._predict_tasks: Dict[str, PredictTask] = {}
self._predict_tasks_lock = threading.Lock()

def setup(self) -> "SetupTask":
assert self._setup_task is None, "do not call setup twice"
@@ -88,8 +93,21 @@ def predict(

task_kwargs = task_kwargs or {}

self._predict_task = PredictTask(prediction, **task_kwargs)
self._prediction_id = prediction.id
tag = prediction.id
if tag is None:
tag = uuid.uuid4().hex

task = PredictTask(prediction, **task_kwargs)

with self._predict_tasks_lock:
# first remove finished tasks so we don't grow the dictionary without bound
done_ids = [
id for id in self._predict_tasks if self._predict_tasks[id].done()
]
for id in done_ids:
del self._predict_tasks[id]

self._predict_tasks[tag] = task

if isinstance(prediction.input, BaseInput):
if PYDANTIC_V2:
@@ -101,18 +119,15 @@
else:
payload = prediction.input.copy()

sid = self._worker.subscribe(self._predict_task.handle_event)
self._predict_task.track(self._worker.predict(payload))
self._predict_task.add_done_callback(lambda _: self._worker.unsubscribe(sid))
sid = self._worker.subscribe(task.handle_event, tag=tag)
task.track(self._worker.predict(payload, tag=tag))
task.add_done_callback(lambda _: self._worker.unsubscribe(sid))

return self._predict_task
return task

def get_predict_task(self, id: str) -> Optional["PredictTask"]:
if not self._predict_task:
return None
if self._predict_task.result.id != id:
return None
return self._predict_task
with self._predict_tasks_lock:
return self._predict_tasks.get(id, None)

def is_busy(self) -> bool:
try:
@@ -124,9 +139,13 @@ def is_busy(self) -> bool:
def cancel(self, prediction_id: str) -> None:
if not prediction_id:
raise ValueError("prediction_id is required")
if self._prediction_id != prediction_id:
raise UnknownPredictionError("id mismatch")
self._worker.cancel()
with self._predict_tasks_lock:
if (
prediction_id not in self._predict_tasks
or self._predict_tasks[prediction_id].done()
):
raise UnknownPredictionError("unknown prediction id")
self._worker.cancel(tag=prediction_id)

def _raise_if_busy(self) -> None:
if self._setup_task is None:
@@ -135,9 +154,17 @@ def _raise_if_busy(self) -> None:
if not self._setup_task.done():
# Setup is still running.
raise RunnerBusyError("setup is not complete")
if self._predict_task is not None and not self._predict_task.done():
# Prediction is still running.
raise RunnerBusyError("prediction running")

with self._predict_tasks_lock:
processing_tasks = [
id for id in self._predict_tasks if not self._predict_tasks[id].done()
]

if len(processing_tasks) >= self._max_concurrency:
# We're at max concurrency
if self._max_concurrency == 1:
raise RunnerBusyError("prediction running")
raise RunnerBusyError("max predictions running")


T = TypeVar("T")
@@ -317,6 +344,11 @@ def done(self) -> bool:
assert self._fut, "call track before checking done"
return self._fut.done()

async def wait_async(self) -> None:
assert self._fut, "call track before waiting"
await asyncio.wrap_future(self._fut)
return None

def wait(self, timeout: Optional[float] = None) -> None:
assert self._fut, "call track before waiting"
self._fut.result(timeout=timeout)
55 changes: 40 additions & 15 deletions python/cog/server/worker.py
@@ -349,6 +349,7 @@ def __init__(
self,
predictor_ref: str,
events: Connection,
max_concurrency: int = 1,
tee_output: bool = True,
) -> None:
self._predictor_ref = predictor_ref
@@ -358,6 +359,7 @@
)
self._tee_output = tee_output
self._cancelable = False
self._max_concurrency = max_concurrency

# for synchronous predictors only! async predictors use _tag_var instead
self._sync_tag: Optional[str] = None
@@ -425,6 +427,25 @@ def _setup(self, redirector: AsyncStreamRedirector) -> None:
# Could be a function or a class
if hasattr(self._predictor, "setup"):
run_setup(self._predictor)

predict = get_predict(self._predictor)

is_async_predictor = inspect.iscoroutinefunction(
predict
) or inspect.isasyncgenfunction(predict)

# Async models require python >= 3.11 so we can use asyncio.TaskGroup
# We should check for this before getting to this point
if is_async_predictor and sys.version_info < (3, 11):
raise FatalWorkerException(
"Cog requires python >=3.11 for `async def predict(..)` support"
)

if self._max_concurrency > 1 and not is_async_predictor:
raise FatalWorkerException(
"max_concurrency>1 requires `async def predict()`"
)

except Exception as e: # pylint: disable=broad-exception-caught
traceback.print_exc()
done.error = True
@@ -480,20 +501,19 @@ async def _aloop(
task = None

with scope(self._loop_scope()), redirector:
while True:
e = cast(Envelope, await self._events.recv())
if isinstance(e.event, Cancel) and task and self._cancelable:
task.cancel()
elif isinstance(e.event, Shutdown):
break
elif isinstance(e.event, PredictionInput):
task = asyncio.create_task(
self._apredict(e.tag, e.event.payload, predict, redirector)
)
else:
print(f"Got unexpected event: {e.event}", file=sys.stderr)
if task:
await task
async with asyncio.TaskGroup() as tg:
while True:
e = cast(Envelope, await self._events.recv())
if isinstance(e.event, Cancel) and task and self._cancelable:
task.cancel()
elif isinstance(e.event, Shutdown):
break
elif isinstance(e.event, PredictionInput):
task = tg.create_task(
self._apredict(e.tag, e.event.payload, predict, redirector)
)
else:
print(f"Got unexpected event: {e.event}", file=sys.stderr)

def _loop_scope(self) -> Scope:
return Scope(record_metric=self.record_metric)
@@ -683,7 +703,12 @@ def make_worker(
predictor_ref: str, tee_output: bool = True, max_concurrency: int = 1
) -> Worker:
parent_conn, child_conn = _spawn.Pipe()
child = _ChildWorker(predictor_ref, events=child_conn, tee_output=tee_output)
child = _ChildWorker(
predictor_ref,
events=child_conn,
tee_output=tee_output,
max_concurrency=max_concurrency,
)
parent = Worker(child=child, events=parent_conn, max_concurrency=max_concurrency)
return parent

5 changes: 5 additions & 0 deletions python/cog/types.py
@@ -43,6 +43,7 @@ class ExperimentalFeatureWarning(Warning):

class CogConfig(TypedDict): # pylint: disable=too-many-ancestors
build: "CogBuildConfig"
concurrency: "CogConcurrencyConfig"
image: NotRequired[str]
predict: NotRequired[str]
train: NotRequired[str]
@@ -58,6 +59,10 @@ class CogBuildConfig(TypedDict, total=False): # pylint: disable=too-many-ancestors
run: Optional[Union[List[str], List[Dict[str, Any]]]]


class CogConcurrencyConfig(TypedDict, total=False): # pylint: disable=too-many-ancestors
max: Optional[int]


def Input( # pylint: disable=invalid-name, too-many-arguments
default: Any = ...,
description: str = None,
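Put together, a config carrying the new key might look like this (illustrative
values, restricted to keys visible in these TypedDicts; assumes they are
importable from cog.types):

    from cog.types import CogConfig

    config: CogConfig = {
        "build": {"run": ["pip install torch"]},  # CogBuildConfig, total=False
        "concurrency": {"max": 4},                # the new CogConcurrencyConfig key
        "predict": "predict.py:Predictor",
    }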