diff --git a/pkg/config/config.go b/pkg/config/config.go
index 38a1e64dad..f3e28eb7f6 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -263,7 +263,7 @@ func ValidateModelPythonVersion(cfg *Config) error {
 		return fmt.Errorf("minimum supported Python version is %d.%d. requested %s",
 			MinimumMajorPythonVersion, MinimumMinorPythonVersion, version)
 	}
-	if cfg.Concurrency.Max > 1 && minor < MinimumMinorPythonVersionForConcurrency {
+	if cfg.Concurrency != nil && cfg.Concurrency.Max > 1 && minor < MinimumMinorPythonVersionForConcurrency {
 		return fmt.Errorf("when concurrency.max is set, minimum supported Python version is %d.%d. requested %s",
 			MinimumMajorPythonVersion, MinimumMinorPythonVersionForConcurrency, version)
 	}
diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go
index 798a865423..934fba5a70 100644
--- a/pkg/config/config_test.go
+++ b/pkg/config/config_test.go
@@ -64,9 +64,13 @@ func TestValidateModelPythonVersion(t *testing.T) {
 			Build: &Build{
 				PythonVersion: tc.pythonVersion,
 			},
-			Concurrency: &Concurrency{
+		}
+		if tc.concurrencyMax != 0 {
+			// the Concurrency key is optional; only populate it if
+			// concurrencyMax is a non-default value
+			cfg.Concurrency = &Concurrency{
 				Max: tc.concurrencyMax,
-			},
+			}
 		}
 		err := ValidateModelPythonVersion(cfg)
 		if tc.expectedErr != "" {
diff --git a/pyproject.toml b/pyproject.toml
index cfe833b8ac..e467288ba9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,6 +42,7 @@ tests = [
     "numpy",
     "pillow",
     "pytest",
+    "pytest-asyncio",
     "pytest-httpserver",
     "pytest-timeout",
     "pytest-xdist",
@@ -70,6 +71,9 @@ reportUnusedExpression = "warning"
 
 [tool.pyright.defineConstant]
 PYDANTIC_V2 = true
 
+[tool.pytest.ini_options]
+asyncio_default_fixture_loop_scope = "function"
+
 [tool.setuptools]
 include-package-data = false
 
diff --git a/python/cog/config.py b/python/cog/config.py
index 44675c79dc..d0a37fb777 100644
--- a/python/cog/config.py
+++ b/python/cog/config.py
@@ -30,6 +30,7 @@
 COG_PREDICT_CODE_STRIP_ENV_VAR = "COG_PREDICT_CODE_STRIP"
 COG_TRAIN_CODE_STRIP_ENV_VAR = "COG_TRAIN_CODE_STRIP"
 COG_GPU_ENV_VAR = "COG_GPU"
+COG_MAX_CONCURRENCY_ENV_VAR = "COG_MAX_CONCURRENCY"
 
 PREDICT_METHOD_NAME = "predict"
 TRAIN_METHOD_NAME = "train"
@@ -98,6 +99,12 @@ def requires_gpu(self) -> bool:
         """Whether this cog requires the use of a GPU."""
         return bool(self._cog_config.get("build", {}).get("gpu", False))
 
+    @property
+    @env_property(COG_MAX_CONCURRENCY_ENV_VAR)
+    def max_concurrency(self) -> int:
+        """The maximum concurrency of predictions supported by this model. 
Defaults to 1.""" + return int(self._cog_config.get("concurrency", {}).get("max", 1)) + def _predictor_code( self, module_path: str, diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 75c04d1c9f..c63fc39431 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -162,8 +162,11 @@ async def start_shutdown() -> Any: add_setup_failed_routes(app, started_at, msg) return app - worker = make_worker(predictor_ref=cog_config.get_predictor_ref(mode=mode)) - runner = PredictionRunner(worker=worker) + worker = make_worker( + predictor_ref=cog_config.get_predictor_ref(mode=mode), + max_concurrency=cog_config.max_concurrency, + ) + runner = PredictionRunner(worker=worker, max_concurrency=cog_config.max_concurrency) class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType)): pass @@ -215,7 +218,7 @@ class TrainingRequest( response_model=TrainingResponse, response_model_exclude_unset=True, ) - def train( + async def train( request: TrainingRequest = Body(default=None), prefer: Optional[str] = Header(default=None), traceparent: Optional[str] = Header( @@ -228,7 +231,7 @@ def train( respond_async = prefer == "respond-async" with trace_context(make_trace_context(traceparent, tracestate)): - return _predict( + return await _predict( request=request, response_type=TrainingResponse, respond_async=respond_async, @@ -239,7 +242,7 @@ def train( response_model=TrainingResponse, response_model_exclude_unset=True, ) - def train_idempotent( + async def train_idempotent( training_id: str = Path(..., title="Training ID"), request: TrainingRequest = Body(..., title="Training Request"), prefer: Optional[str] = Header(default=None), @@ -276,7 +279,7 @@ def train_idempotent( respond_async = prefer == "respond-async" with trace_context(make_trace_context(traceparent, tracestate)): - return _predict( + return await _predict( request=request, response_type=TrainingResponse, respond_async=respond_async, @@ -355,7 +358,7 @@ async def predict( respond_async = prefer == "respond-async" with trace_context(make_trace_context(traceparent, tracestate)): - return _predict( + return await _predict( request=request, response_type=PredictionResponse, respond_async=respond_async, @@ -403,13 +406,13 @@ async def predict_idempotent( respond_async = prefer == "respond-async" with trace_context(make_trace_context(traceparent, tracestate)): - return _predict( + return await _predict( request=request, response_type=PredictionResponse, respond_async=respond_async, ) - def _predict( + async def _predict( *, request: Optional[PredictionRequest], response_type: Type[schema.PredictionResponse], @@ -451,7 +454,7 @@ def _predict( ) # Otherwise, wait for the prediction to complete... - predict_task.wait() + await predict_task.wait_async() # ...and return the result. 
if PYDANTIC_V2: diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index 7468be2572..f9e8573439 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -1,5 +1,8 @@ +import asyncio import io +import threading import traceback +import uuid from abc import ABC, abstractmethod from concurrent.futures import Future from datetime import datetime, timezone @@ -60,13 +63,15 @@ class PredictionRunner: def __init__( self, *, + max_concurrency: int = 1, worker: Worker, ) -> None: self._worker = worker + self._max_concurrency = max_concurrency self._setup_task: Optional[SetupTask] = None - self._predict_task: Optional[PredictTask] = None - self._prediction_id = None + self._predict_tasks: Dict[str, PredictTask] = {} + self._predict_tasks_lock = threading.Lock() def setup(self) -> "SetupTask": assert self._setup_task is None, "do not call setup twice" @@ -88,8 +93,21 @@ def predict( task_kwargs = task_kwargs or {} - self._predict_task = PredictTask(prediction, **task_kwargs) - self._prediction_id = prediction.id + tag = prediction.id + if tag is None: + tag = uuid.uuid4().hex + + task = PredictTask(prediction, **task_kwargs) + + with self._predict_tasks_lock: + # first remove finished tasks so we don't grow the dictionary without bound + done_ids = [ + id for id in self._predict_tasks if self._predict_tasks[id].done() + ] + for id in done_ids: + del self._predict_tasks[id] + + self._predict_tasks[tag] = task if isinstance(prediction.input, BaseInput): if PYDANTIC_V2: @@ -101,18 +119,15 @@ def predict( else: payload = prediction.input.copy() - sid = self._worker.subscribe(self._predict_task.handle_event) - self._predict_task.track(self._worker.predict(payload)) - self._predict_task.add_done_callback(lambda _: self._worker.unsubscribe(sid)) + sid = self._worker.subscribe(task.handle_event, tag=tag) + task.track(self._worker.predict(payload, tag=tag)) + task.add_done_callback(lambda _: self._worker.unsubscribe(sid)) - return self._predict_task + return task def get_predict_task(self, id: str) -> Optional["PredictTask"]: - if not self._predict_task: - return None - if self._predict_task.result.id != id: - return None - return self._predict_task + with self._predict_tasks_lock: + return self._predict_tasks.get(id, None) def is_busy(self) -> bool: try: @@ -124,9 +139,13 @@ def is_busy(self) -> bool: def cancel(self, prediction_id: str) -> None: if not prediction_id: raise ValueError("prediction_id is required") - if self._prediction_id != prediction_id: - raise UnknownPredictionError("id mismatch") - self._worker.cancel() + with self._predict_tasks_lock: + if ( + prediction_id not in self._predict_tasks + or self._predict_tasks[prediction_id].done() + ): + raise UnknownPredictionError("unknown prediction id") + self._worker.cancel(tag=prediction_id) def _raise_if_busy(self) -> None: if self._setup_task is None: @@ -135,9 +154,17 @@ def _raise_if_busy(self) -> None: if not self._setup_task.done(): # Setup is still running. raise RunnerBusyError("setup is not complete") - if self._predict_task is not None and not self._predict_task.done(): - # Prediction is still running. 
- raise RunnerBusyError("prediction running") + + with self._predict_tasks_lock: + processing_tasks = [ + id for id in self._predict_tasks if not self._predict_tasks[id].done() + ] + + if len(processing_tasks) >= self._max_concurrency: + # We're at max concurrency + if self._max_concurrency == 1: + raise RunnerBusyError("prediction running") + raise RunnerBusyError("max predictions running") T = TypeVar("T") @@ -317,6 +344,11 @@ def done(self) -> bool: assert self._fut, "call track before checking done" return self._fut.done() + async def wait_async(self) -> None: + assert self._fut, "call track before waiting" + await asyncio.wrap_future(self._fut) + return None + def wait(self, timeout: Optional[float] = None) -> None: assert self._fut, "call track before waiting" self._fut.result(timeout=timeout) diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index 94c9e8b02b..3235ac678b 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -349,6 +349,7 @@ def __init__( self, predictor_ref: str, events: Connection, + max_concurrency: int = 1, tee_output: bool = True, ) -> None: self._predictor_ref = predictor_ref @@ -358,6 +359,7 @@ def __init__( ) self._tee_output = tee_output self._cancelable = False + self._max_concurrency = max_concurrency # for synchronous predictors only! async predictors use _tag_var instead self._sync_tag: Optional[str] = None @@ -425,6 +427,25 @@ def _setup(self, redirector: AsyncStreamRedirector) -> None: # Could be a function or a class if hasattr(self._predictor, "setup"): run_setup(self._predictor) + + predict = get_predict(self._predictor) + + is_async_predictor = inspect.iscoroutinefunction( + predict + ) or inspect.isasyncgenfunction(predict) + + # Async models require python >= 3.11 so we can use asyncio.TaskGroup + # We should check for this before getting to this point + if is_async_predictor and sys.version_info < (3, 11): + raise FatalWorkerException( + "Cog requires python >=3.11 for `async def predict(..)` support" + ) + + if self._max_concurrency > 1 and not is_async_predictor: + raise FatalWorkerException( + "max_concurrency>1 requires `async def predict()`" + ) + except Exception as e: # pylint: disable=broad-exception-caught traceback.print_exc() done.error = True @@ -480,20 +501,19 @@ async def _aloop( task = None with scope(self._loop_scope()), redirector: - while True: - e = cast(Envelope, await self._events.recv()) - if isinstance(e.event, Cancel) and task and self._cancelable: - task.cancel() - elif isinstance(e.event, Shutdown): - break - elif isinstance(e.event, PredictionInput): - task = asyncio.create_task( - self._apredict(e.tag, e.event.payload, predict, redirector) - ) - else: - print(f"Got unexpected event: {e.event}", file=sys.stderr) - if task: - await task + async with asyncio.TaskGroup() as tg: + while True: + e = cast(Envelope, await self._events.recv()) + if isinstance(e.event, Cancel) and task and self._cancelable: + task.cancel() + elif isinstance(e.event, Shutdown): + break + elif isinstance(e.event, PredictionInput): + task = tg.create_task( + self._apredict(e.tag, e.event.payload, predict, redirector) + ) + else: + print(f"Got unexpected event: {e.event}", file=sys.stderr) def _loop_scope(self) -> Scope: return Scope(record_metric=self.record_metric) @@ -683,7 +703,12 @@ def make_worker( predictor_ref: str, tee_output: bool = True, max_concurrency: int = 1 ) -> Worker: parent_conn, child_conn = _spawn.Pipe() - child = _ChildWorker(predictor_ref, events=child_conn, tee_output=tee_output) + 
child = _ChildWorker(
+        predictor_ref,
+        events=child_conn,
+        tee_output=tee_output,
+        max_concurrency=max_concurrency,
+    )
     parent = Worker(child=child, events=parent_conn, max_concurrency=max_concurrency)
     return parent
 
diff --git a/python/cog/types.py b/python/cog/types.py
index 29d868c9e7..f6329dc6e4 100644
--- a/python/cog/types.py
+++ b/python/cog/types.py
@@ -43,6 +43,7 @@ class ExperimentalFeatureWarning(Warning):
 
 class CogConfig(TypedDict):  # pylint: disable=too-many-ancestors
     build: "CogBuildConfig"
+    concurrency: NotRequired["CogConcurrencyConfig"]
     image: NotRequired[str]
     predict: NotRequired[str]
     train: NotRequired[str]
@@ -58,6 +59,10 @@ class CogBuildConfig(TypedDict, total=False):  # pylint: disable=too-many-ancest
     run: Optional[Union[List[str], List[Dict[str, Any]]]]
 
 
+class CogConcurrencyConfig(TypedDict, total=False):  # pylint: disable=too-many-ancestors
+    max: Optional[int]
+
+
 def Input(  # pylint: disable=invalid-name, too-many-arguments
     default: Any = ...,
     description: str = None,
diff --git a/python/tests/server/conftest.py b/python/tests/server/conftest.py
index 80e530f61f..5499e60497 100644
--- a/python/tests/server/conftest.py
+++ b/python/tests/server/conftest.py
@@ -1,8 +1,9 @@
 import os
+import sys
 import threading
 import time
 from contextlib import ExitStack
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 from unittest import mock
 
 import pytest
@@ -26,6 +27,7 @@ class WorkerConfig:
     fixture_name: str
     setup: bool = True
     max_concurrency: int = 1
+    min_python: Optional[Tuple[int, int]] = None
 
 
 def pytest_make_parametrize_id(config, val):
@@ -71,7 +73,7 @@ def uses_predictor_with_client_options(name, **options):
     )
 
 
-def uses_worker(name_or_names, setup=True, max_concurrency=1):
+def uses_worker(name_or_names, setup=True, max_concurrency=1, min_python=None):
     """
     Decorator for tests that require a Worker instance. `name_or_names` can
     be a single fixture name, or a sequence (list, tuple) of fixture names. If
@@ -81,16 +83,34 @@
     """
     if isinstance(name_or_names, (tuple, list)):
         values = (
-            WorkerConfig(fixture_name=n, setup=setup, max_concurrency=max_concurrency)
+            WorkerConfig(
+                fixture_name=n,
+                setup=setup,
+                max_concurrency=max_concurrency,
+                min_python=min_python,
+            )
             for n in name_or_names
         )
     else:
         values = (
             WorkerConfig(
-                fixture_name=name_or_names, setup=setup, max_concurrency=max_concurrency
+                fixture_name=name_or_names,
+                setup=setup,
+                max_concurrency=max_concurrency,
+                min_python=min_python,
             ),
         )
-    return pytest.mark.parametrize("worker", values, indirect=True)
+    return uses_worker_configs(values)
+
+
+def uses_worker_configs(configs):
+    """
+    Decorator for tests that require a Worker instance. The test will be
+    run once for each worker. `configs` is a sequence (list, tuple, generator)
+    of WorkerConfig.
+
+
+    """
+    return pytest.mark.parametrize("worker", configs, indirect=True)
 
 
 def make_client(
@@ -151,6 +171,13 @@ def static_schema(client) -> dict:
 @pytest.fixture
 def worker(request):
     ref = _fixture_path(request.param.fixture_name)
+    if (
+        request.param.min_python is not None
+        and sys.version_info < request.param.min_python
+    ):
+        pytest.skip(
+            f"Test requires python {request.param.min_python[0]}.{request.param.min_python[1]}"
+        )
     w = make_worker(
         predictor_ref=ref,
         tee_output=False,
diff --git a/python/tests/server/test_runner.py b/python/tests/server/test_runner.py
index ccd5aa174b..7cfcfd7b6a 100644
--- a/python/tests/server/test_runner.py
+++ b/python/tests/server/test_runner.py
@@ -38,18 +38,23 @@ def __call__(self):
 class FakeWorker:
     def __init__(self):
         self.subscribers = {}
-        self.last_prediction_payload = None
+        self.subscribers_by_tag = {}
 
         self._setup_future = None
-        self._predict_future = None
+        self._predict_futures = {}
+
+        self.last_prediction_payload = None
 
     def subscribe(self, subscriber, tag=None):
         sid = uuid.uuid4()
-        self.subscribers[sid] = subscriber
+        self.subscribers[sid] = tag
+        if tag not in self.subscribers_by_tag:
+            self.subscribers_by_tag[tag] = {}
+        self.subscribers_by_tag[tag][sid] = subscriber
         return sid
 
     def unsubscribe(self, sid):
-        del self.subscribers[sid]
+        tag = self.subscribers.pop(sid)
+        del self.subscribers_by_tag[tag][sid]
 
     def setup(self):
         assert self._setup_future is None
@@ -61,32 +66,36 @@ def run_setup(self, events):
         if isinstance(event, Exception):
             self._setup_future.set_exception(event)
             return
-        for subscriber in self.subscribers.values():
+        for subscriber in self.subscribers_by_tag.get(None, {}).values():
             subscriber(event)
         if isinstance(event, Done):
             self._setup_future.set_result(event)
 
     def predict(self, payload, tag=None):
-        assert self._predict_future is None or self._predict_future.done()
+        assert tag not in self._predict_futures or self._predict_futures[tag].done()
         self.last_prediction_payload = payload
-        self._predict_future = Future()
-        return self._predict_future
-
-    def run_predict(self, events):
+        self._predict_futures[tag] = Future()
+        return self._predict_futures[tag]
+
+    def run_predict(self, events, id=None):
+        if id is None:
+            if len(self._predict_futures) != 1:
+                raise ValueError("Could not guess prediction id, please specify")
+            id = next(iter(self._predict_futures))
         for event in events:
             if isinstance(event, Exception):
-                self._predict_future.set_exception(event)
+                self._predict_futures[id].set_exception(event)
                 return
-            for subscriber in self.subscribers.values():
+            for subscriber in self.subscribers_by_tag.get(id, {}).values():
                 subscriber(event)
             if isinstance(event, Done):
-                self._predict_future.set_result(event)
+                self._predict_futures[id].set_result(event)
 
     def cancel(self, tag=None):
         done = Done(canceled=True)
-        for subscriber in self.subscribers.values():
+        for subscriber in self.subscribers_by_tag.get(tag, {}).values():
             subscriber(done)
-        self._predict_future.set_result(done)
+        self._predict_futures[tag].set_result(done)
 
 
 def test_prediction_runner_setup_success():
@@ -229,11 +240,11 @@ def test_prediction_runner_predict_after_predict_completes():
     r.setup()
     w.run_setup([Done()])
 
-    r.predict(PredictionRequest(input={"text": "giraffes"}))
-    w.run_predict([Done()])
+    r.predict(PredictionRequest(id="p-1", input={"text": "giraffes"}))
+    w.run_predict([Done()], id="p-1")
 
-    r.predict(PredictionRequest(input={"text": 
"elephants"})) - w.run_predict([Done()]) + r.predict(PredictionRequest(id="p-2", input={"text": "elephants"})) + w.run_predict([Done()], id="p-2") assert w.last_prediction_payload == {"text": "elephants"} @@ -257,6 +268,31 @@ def test_prediction_runner_is_busy(): assert not r.is_busy() +def test_prediction_runner_is_busy_concurrency(): + w = FakeWorker() + r = PredictionRunner(worker=w, max_concurrency=3) + + assert r.is_busy() + + r.setup() + assert r.is_busy() + + w.run_setup([Done()]) + assert not r.is_busy() + + r.predict(PredictionRequest(id="1", input={"text": "elephants"})) + assert not r.is_busy() + + r.predict(PredictionRequest(id="2", input={"text": "elephants"})) + assert not r.is_busy() + + r.predict(PredictionRequest(id="3", input={"text": "elephants"})) + assert r.is_busy() + + w.run_predict([Done()], id="1") + assert not r.is_busy() + + def test_prediction_runner_predict_cancelation(): w = FakeWorker() r = PredictionRunner(worker=w) @@ -299,6 +335,23 @@ def test_prediction_runner_predict_cancelation_multiple_predictions(): assert task2.result.status == Status.CANCELED +def test_prediction_runner_predict_cancelation_concurrent_predictions(): + w = FakeWorker() + r = PredictionRunner(worker=w, max_concurrency=5) + + r.setup() + w.run_setup([Done()]) + + task1 = r.predict(PredictionRequest(id="abcd1234", input={"text": "giraffes"})) + + task2 = r.predict(PredictionRequest(id="defg6789", input={"text": "elephants"})) + + r.cancel("abcd1234") + w.run_predict([Done()], id="defg6789") + assert task1.result.status == Status.CANCELED + assert task2.result.status == Status.SUCCEEDED + + def test_prediction_runner_setup_e2e(): w = make_worker(predictor_ref=_fixture_path("sleep")) r = PredictionRunner(worker=w) diff --git a/python/tests/server/test_worker.py b/python/tests/server/test_worker.py index 46009dea06..7dab8d027e 100644 --- a/python/tests/server/test_worker.py +++ b/python/tests/server/test_worker.py @@ -1,5 +1,6 @@ import multiprocessing import os +import sys import threading import time import uuid @@ -30,7 +31,7 @@ from cog.server.exceptions import FatalWorkerException, InvalidStateException from cog.server.worker import Worker, _PublicEventType -from .conftest import WorkerConfig, uses_worker +from .conftest import WorkerConfig, uses_worker, uses_worker_configs if TYPE_CHECKING: from concurrent.futures import Future @@ -76,7 +77,7 @@ }, ), ( - WorkerConfig("record_metric_async"), + WorkerConfig("record_metric_async", min_python=(3, 11)), {"name": ST_NAMES}, { "foo": 123, @@ -91,7 +92,7 @@ lambda x: f"hello, {x['name']}", ), ( - WorkerConfig("hello_world_async"), + WorkerConfig("hello_world_async", min_python=(3, 11)), {"name": ST_NAMES}, lambda x: f"hello, {x['name']}", ), @@ -114,7 +115,7 @@ "writing to stderr at import time\n", ), ( - WorkerConfig("logging_async", setup=False), + WorkerConfig("logging_async", setup=False, min_python=(3, 11)), ("writing to stdout at import time\n" "setting up predictor\n"), "writing to stderr at import time\n", ), @@ -127,12 +128,22 @@ ("WARNING:root:writing log message\n" "writing to stderr\n"), ), ( - WorkerConfig("logging_async"), + WorkerConfig("logging_async", min_python=(3, 11)), ("writing with print\n"), ("WARNING:root:writing log message\n" "writing to stderr\n"), ), ] +SLEEP_FIXTURES = [ + WorkerConfig("sleep"), + WorkerConfig("sleep_async", min_python=(3, 11)), +] + +SLEEP_NO_SETUP_FIXTURES = [ + WorkerConfig("sleep", setup=False), + WorkerConfig("sleep_async", min_python=(3, 11), setup=False), +] + @define class Result: @@ 
-237,8 +248,12 @@ def test_no_exceptions_from_recoverable_failures(worker): _process(worker, lambda: worker.predict({})) -# TODO test this works with errors and cancelations and the like -@uses_worker(["simple", "simple_async"]) +@uses_worker_configs( + [ + WorkerConfig("simple"), + WorkerConfig("simple_async", min_python=(3, 11)), + ] +) def test_can_subscribe_for_a_specific_tag(worker): tag = "123" @@ -260,12 +275,12 @@ def test_can_subscribe_for_a_specific_tag(worker): worker.unsubscribe(subid) -@uses_worker("sleep_async", max_concurrency=5) +@uses_worker("sleep_async", max_concurrency=5, min_python=(3, 11)) def test_can_run_predictions_concurrently_on_async_predictor(worker): subids = [] try: - start = time.time() + start = time.perf_counter() futures = [] results = [] for i in range(5): @@ -279,7 +294,7 @@ def test_can_run_predictions_concurrently_on_async_predictor(worker): for fut in futures: fut.result() - end = time.time() + end = time.perf_counter() duration = end - start # we should take at least 0.5 seconds (the time for 1 prediction) but @@ -299,6 +314,40 @@ def test_can_run_predictions_concurrently_on_async_predictor(worker): worker.unsubscribe(subid) +@pytest.mark.skipif( + sys.version_info >= (3, 11), reason="Testing error message on python versions <3.11" +) +@uses_worker("simple_async", setup=False) +def test_async_predictor_on_python_3_10_or_older_raises_error(worker): + fut = worker.setup() + result = Result() + worker.subscribe(result.handle_event) + + with pytest.raises(FatalWorkerException): + fut.result() + assert result.done + assert result.done.error + assert ( + result.done.error_detail + == "Cog requires python >=3.11 for `async def predict(..)` support" + ) + + +@uses_worker("simple", max_concurrency=5, setup=False) +def test_concurrency_with_sync_predictor_raises_error(worker): + fut = worker.setup() + result = Result() + worker.subscribe(result.handle_event) + + with pytest.raises(FatalWorkerException): + fut.result() + assert result.done + assert result.done.error + assert ( + result.done.error_detail == "max_concurrency>1 requires `async def predict()`" + ) + + @uses_worker("stream_redirector_race_condition") def test_stream_redirector_race_condition(worker): """ @@ -383,7 +432,7 @@ def test_predict_logging(worker, expected_stdout, expected_stderr): assert result.stderr == expected_stderr -@uses_worker(["sleep", "sleep_async"], setup=False) +@uses_worker_configs(SLEEP_NO_SETUP_FIXTURES) def test_cancel_is_safe(worker): """ Calls to cancel at any time should not result in unexpected things @@ -417,7 +466,7 @@ def test_cancel_is_safe(worker): assert result2.output == "done in 0.1 seconds" -@uses_worker(["sleep", "sleep_async"], setup=False) +@uses_worker_configs(SLEEP_NO_SETUP_FIXTURES) def test_cancel_idempotency(worker): """ Multiple calls to cancel within the same prediction, while not necessary or @@ -449,7 +498,7 @@ def cancel_a_bunch(_): assert result2.output == "done in 0.1 seconds" -@uses_worker(["sleep", "sleep_async"]) +@uses_worker_configs(SLEEP_FIXTURES) def test_cancel_multiple_predictions(worker): """ Multiple predictions cancelled in a row shouldn't be a problem. 
This test @@ -467,7 +516,7 @@ def test_cancel_multiple_predictions(worker): assert not worker.predict({"sleep": 0}).result().canceled -@uses_worker(["sleep", "sleep_async"]) +@uses_worker_configs(SLEEP_FIXTURES) def test_graceful_shutdown(worker): """ On shutdown, the worker should finish running the current prediction, and diff --git a/test-integration/test_integration/fixtures/async-sleep-project/cog.yaml b/test-integration/test_integration/fixtures/async-sleep-project/cog.yaml new file mode 100644 index 0000000000..04d04bf7c8 --- /dev/null +++ b/test-integration/test_integration/fixtures/async-sleep-project/cog.yaml @@ -0,0 +1,5 @@ +build: + python_version: "3.11" +predict: "predict.py:Predictor" +concurrency: + max: 5 diff --git a/test-integration/test_integration/fixtures/async-sleep-project/predict.py b/test-integration/test_integration/fixtures/async-sleep-project/predict.py new file mode 100644 index 0000000000..e6c65797a0 --- /dev/null +++ b/test-integration/test_integration/fixtures/async-sleep-project/predict.py @@ -0,0 +1,9 @@ +import asyncio + +from cog import BasePredictor + + +class Predictor(BasePredictor): + async def predict(self, s: str, sleep: float) -> str: + await asyncio.sleep(sleep) + return f"wake up {s}" diff --git a/test-integration/test_integration/fixtures/async-string-project/cog.yaml b/test-integration/test_integration/fixtures/async-string-project/cog.yaml new file mode 100644 index 0000000000..7b6d5d4dce --- /dev/null +++ b/test-integration/test_integration/fixtures/async-string-project/cog.yaml @@ -0,0 +1,3 @@ +build: + python_version: "3.11" +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/async-string-project/predict.py b/test-integration/test_integration/fixtures/async-string-project/predict.py new file mode 100644 index 0000000000..fb2805c794 --- /dev/null +++ b/test-integration/test_integration/fixtures/async-string-project/predict.py @@ -0,0 +1,6 @@ +from cog import BasePredictor + + +class Predictor(BasePredictor): + async def predict(self, s: str) -> str: + return "hello " + s diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index 4065b69927..459f09f03e 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -1,6 +1,8 @@ +import asyncio import pathlib import shutil import subprocess +import time from pathlib import Path import httpx @@ -27,6 +29,20 @@ def test_predict_takes_string_inputs_and_returns_strings_to_stdout(): assert "falling back to slow loader" in result.stderr +def test_predict_supports_async_predictors(): + project_dir = Path(__file__).parent / "fixtures/async-string-project" + result = subprocess.run( + ["cog", "predict", "--debug", "-i", "s=world"], + cwd=project_dir, + check=True, + capture_output=True, + text=True, + timeout=DEFAULT_TIMEOUT, + ) + # stdout should be clean without any log messages so it can be piped to other commands + assert result.stdout == "hello world\n" + + def test_predict_takes_int_inputs_and_returns_ints_to_stdout(): project_dir = Path(__file__).parent / "fixtures/int-project" result = subprocess.run( @@ -322,3 +338,34 @@ def test_predict_with_subprocess_in_setup(fixture_name): assert response.status_code == 200, str(response) assert busy_count < 10 + + +@pytest.mark.asyncio +async def test_concurrent_predictions(): + async def make_request(i: int) -> httpx.Response: + return await client.post( + f"{addr}/predictions", + json={ + "id": 
f"id-{i}", + "input": {"s": f"sleepyhead{i}", "sleep": 1.0}, + }, + ) + + with cog_server_http_run( + Path(__file__).parent / "fixtures" / "async-sleep-project" + ) as addr: + async with httpx.AsyncClient() as client: + tasks = [] + start = time.perf_counter() + async with asyncio.TaskGroup() as tg: + for i in range(5): + tasks.append(tg.create_task(make_request(i))) + # give time for all of the predictions to be accepted, but not completed + await asyncio.sleep(0.2) + # we shut the server down, but expect all running predictions to complete + await client.post(f"{addr}/shutdown") + end = time.perf_counter() + assert (end - start) < 3.0 # ensure the predictions ran concurrently + for i, task in enumerate(tasks): + assert task.result().status_code == 200 + assert task.result().json()["output"] == f"wake up sleepyhead{i}"