From 8280f0c6b81fae5beefb017e361645deca434282 Mon Sep 17 00:00:00 2001 From: Philip Potter Date: Tue, 19 Nov 2024 10:58:29 +0000 Subject: [PATCH 01/16] cog build: validate python version is new enough to support concurrency We require python >=3.11 to support asyncio.TaskGroup --- pkg/cli/build.go | 2 +- pkg/config/config.go | 28 ++++++++++----- pkg/config/config_test.go | 76 +++++++++++++++++++++------------------ 3 files changed, 62 insertions(+), 44 deletions(-) diff --git a/pkg/cli/build.go b/pkg/cli/build.go index ebdc2e1d92..92dda89c6b 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -64,7 +64,7 @@ func buildCommand(cmd *cobra.Command, args []string) error { imageName = config.DockerImageName(projectDir) } - err = config.ValidateModelPythonVersion(cfg.Build.PythonVersion) + err = config.ValidateModelPythonVersion(cfg) if err != nil { return err } diff --git a/pkg/config/config.go b/pkg/config/config.go index 4a282deebd..38a1e64dad 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -30,9 +30,10 @@ var ( // TODO(andreas): suggest valid torchvision versions (e.g. if the user wants to use 0.8.0, suggest 0.8.1) const ( - MinimumMajorPythonVersion int = 3 - MinimumMinorPythonVersion int = 8 - MinimumMajorCudaVersion int = 11 + MinimumMajorPythonVersion int = 3 + MinimumMinorPythonVersion int = 8 + MinimumMinorPythonVersionForConcurrency int = 11 + MinimumMajorCudaVersion int = 11 ) type RunItem struct { @@ -58,16 +59,21 @@ type Build struct { pythonRequirementsContent []string } +type Concurrency struct { + Max int `json:"max,omitempty" yaml:"max"` +} + type Example struct { Input map[string]string `json:"input" yaml:"input"` Output string `json:"output" yaml:"output"` } type Config struct { - Build *Build `json:"build" yaml:"build"` - Image string `json:"image,omitempty" yaml:"image"` - Predict string `json:"predict,omitempty" yaml:"predict"` - Train string `json:"train,omitempty" yaml:"train"` + Build *Build `json:"build" yaml:"build"` + Image string `json:"image,omitempty" yaml:"image"` + Predict string `json:"predict,omitempty" yaml:"predict"` + Train string `json:"train,omitempty" yaml:"train"` + Concurrency *Concurrency `json:"concurrency,omitempty" yaml:"concurrency"` } func DefaultConfig() *Config { @@ -244,7 +250,9 @@ func splitPythonVersion(version string) (major int, minor int, err error) { return major, minor, nil } -func ValidateModelPythonVersion(version string) error { +func ValidateModelPythonVersion(cfg *Config) error { + version := cfg.Build.PythonVersion + // we check for minimum supported here major, minor, err := splitPythonVersion(version) if err != nil { @@ -255,6 +263,10 @@ func ValidateModelPythonVersion(version string) error { return fmt.Errorf("minimum supported Python version is %d.%d. requested %s", MinimumMajorPythonVersion, MinimumMinorPythonVersion, version) } + if cfg.Concurrency.Max > 1 && minor < MinimumMinorPythonVersionForConcurrency { + return fmt.Errorf("when concurrency.max is set, minimum supported Python version is %d.%d. requested %s", + MinimumMajorPythonVersion, MinimumMinorPythonVersionForConcurrency, version) + } return nil } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index d0ec825ab8..798a865423 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -13,47 +13,64 @@ import ( func TestValidateModelPythonVersion(t *testing.T) { testCases := []struct { - name string - input string - expectedErr bool + name string + pythonVersion string + concurrencyMax int + expectedErr string }{ { - name: "ValidVersion", - input: "3.12", - expectedErr: false, + name: "ValidVersion", + pythonVersion: "3.12", }, { - name: "MinimumVersion", - input: "3.8", - expectedErr: false, + name: "MinimumVersion", + pythonVersion: "3.8", }, { - name: "FullyQualifiedVersion", - input: "3.12.1", - expectedErr: false, + name: "MinimumVersionForConcurrency", + pythonVersion: "3.11", + concurrencyMax: 5, }, { - name: "InvalidFormat", - input: "3-12", - expectedErr: true, + name: "TooOldForConcurrency", + pythonVersion: "3.8", + concurrencyMax: 5, + expectedErr: "when concurrency.max is set, minimum supported Python version is 3.11. requested 3.8", }, { - name: "InvalidMissingMinor", - input: "3", - expectedErr: true, + name: "FullyQualifiedVersion", + pythonVersion: "3.12.1", }, { - name: "LessThanMinimum", - input: "3.7", - expectedErr: true, + name: "InvalidFormat", + pythonVersion: "3-12", + expectedErr: "invalid Python version format: missing minor version in 3-12", + }, + { + name: "InvalidMissingMinor", + pythonVersion: "3", + expectedErr: "invalid Python version format: missing minor version in 3", + }, + { + name: "LessThanMinimum", + pythonVersion: "3.7", + expectedErr: "minimum supported Python version is 3.8. requested 3.7", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - err := ValidateModelPythonVersion(tc.input) - if tc.expectedErr { - require.Error(t, err) + cfg := &Config{ + Build: &Build{ + PythonVersion: tc.pythonVersion, + }, + Concurrency: &Concurrency{ + Max: tc.concurrencyMax, + }, + } + err := ValidateModelPythonVersion(cfg) + if tc.expectedErr != "" { + require.ErrorContains(t, err, tc.expectedErr) } else { require.NoError(t, err) } @@ -649,17 +666,6 @@ func TestBlankBuild(t *testing.T) { require.Equal(t, false, config.Build.GPU) } -func TestModelPythonVersionValidation(t *testing.T) { - err := ValidateModelPythonVersion("3.8") - require.NoError(t, err) - err = ValidateModelPythonVersion("3.8.1") - require.NoError(t, err) - err = ValidateModelPythonVersion("3.7") - require.Equal(t, "minimum supported Python version is 3.8. requested 3.7", err.Error()) - err = ValidateModelPythonVersion("3.7.1") - require.Equal(t, "minimum supported Python version is 3.8. requested 3.7.1", err.Error()) -} - func TestSplitPinnedPythonRequirement(t *testing.T) { testCases := []struct { input string From 9813d8d027986ae7f46284ba2a88399aafcb696d Mon Sep 17 00:00:00 2001 From: Philip Potter Date: Tue, 19 Nov 2024 12:44:48 +0000 Subject: [PATCH 02/16] End-to-end support for concurrent async models This builds on the work in #2057 and wires it up end-to-end. We can now support async models with a max concurrency configured, and submit multiple predictions concurrently to them. We only support python 3.11 for async models; this is so that we can use asyncio.TaskGroup to keep track of multiple predictions in flight and ensure they all complete when shutting down. The cog http server was already async, but at one point it called wait() on a concurrent.futures.Future() which blocked the event loop and therefore prevented concurrent prediction requests (when not using prefer-async, which is how the tests run). I have updated this code to wait on asyncio.wrap_future(fut) instead which does not block the event loop. As part of this I have updated the training endpoints to also be asynchronous. We now have three places in the code which keep track of how many predictions are in flight: PredictionRunner, Worker and _ChildWorker all do their own bookkeeping. I'm not sure this is the best design but it works. The code is now an uneasy mix of threaded and asyncio code. This is evident in the usage of threading.Lock, which wouldn't be needed if we were 100% async (and I'm not sure if it's actually needed currently; I just added it to be safe). --- pkg/config/config.go | 2 +- pkg/config/config_test.go | 8 +- pyproject.toml | 4 + python/cog/config.py | 7 ++ python/cog/server/http.py | 22 ++-- python/cog/server/runner.py | 70 ++++++++---- python/cog/server/worker.py | 54 +++++++--- python/cog/types.py | 5 + python/tests/server/conftest.py | 30 ++++-- python/tests/server/test_runner.py | 91 ++++++++++++---- python/tests/server/test_worker.py | 101 +++++++++++------- .../fixtures/async-sleep-project/cog.yaml | 5 + .../fixtures/async-sleep-project/predict.py | 9 ++ .../fixtures/async-string-project/cog.yaml | 3 + .../fixtures/async-string-project/predict.py | 6 ++ .../test_integration/test_predict.py | 47 ++++++++ 16 files changed, 349 insertions(+), 115 deletions(-) create mode 100644 test-integration/test_integration/fixtures/async-sleep-project/cog.yaml create mode 100644 test-integration/test_integration/fixtures/async-sleep-project/predict.py create mode 100644 test-integration/test_integration/fixtures/async-string-project/cog.yaml create mode 100644 test-integration/test_integration/fixtures/async-string-project/predict.py diff --git a/pkg/config/config.go b/pkg/config/config.go index 38a1e64dad..f3e28eb7f6 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -263,7 +263,7 @@ func ValidateModelPythonVersion(cfg *Config) error { return fmt.Errorf("minimum supported Python version is %d.%d. requested %s", MinimumMajorPythonVersion, MinimumMinorPythonVersion, version) } - if cfg.Concurrency.Max > 1 && minor < MinimumMinorPythonVersionForConcurrency { + if cfg.Concurrency != nil && cfg.Concurrency.Max > 1 && minor < MinimumMinorPythonVersionForConcurrency { return fmt.Errorf("when concurrency.max is set, minimum supported Python version is %d.%d. requested %s", MinimumMajorPythonVersion, MinimumMinorPythonVersionForConcurrency, version) } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 798a865423..934fba5a70 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -64,9 +64,13 @@ func TestValidateModelPythonVersion(t *testing.T) { Build: &Build{ PythonVersion: tc.pythonVersion, }, - Concurrency: &Concurrency{ + } + if tc.concurrencyMax != 0 { + // the Concurrency key is optional, only populate it if + // concurrencyMax is a non-default value + cfg.Concurrency = &Concurrency{ Max: tc.concurrencyMax, - }, + } } err := ValidateModelPythonVersion(cfg) if tc.expectedErr != "" { diff --git a/pyproject.toml b/pyproject.toml index cfe833b8ac..e467288ba9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ tests = [ "numpy", "pillow", "pytest", + "pytest-asyncio", "pytest-httpserver", "pytest-timeout", "pytest-xdist", @@ -70,6 +71,9 @@ reportUnusedExpression = "warning" [tool.pyright.defineConstant] PYDANTIC_V2 = true +[tool.pytest.ini_options] +asyncio_default_fixture_loop_scope = "function" + [tool.setuptools] include-package-data = false diff --git a/python/cog/config.py b/python/cog/config.py index ec367e4fb7..bc560a1ddc 100644 --- a/python/cog/config.py +++ b/python/cog/config.py @@ -33,6 +33,7 @@ COG_PREDICT_CODE_STRIP_ENV_VAR = "COG_PREDICT_CODE_STRIP" COG_TRAIN_CODE_STRIP_ENV_VAR = "COG_TRAIN_CODE_STRIP" COG_GPU_ENV_VAR = "COG_GPU" +COG_MAX_CONCURRENCY_ENV_VAR = "COG_MAX_CONCURRENCY" PREDICT_METHOD_NAME = "predict" TRAIN_METHOD_NAME = "train" @@ -101,6 +102,12 @@ def requires_gpu(self) -> bool: """Whether this cog requires the use of a GPU.""" return bool(self._cog_config.get("build", {}).get("gpu", False)) + @property + @env_property(COG_MAX_CONCURRENCY_ENV_VAR) + def max_concurrency(self) -> int: + """The maximum concurrency of predictions supported by this model. Defaults to 1.""" + return int(self._cog_config.get("concurrency", {}).get("max", 1)) + def _predictor_code( self, module_path: str, diff --git a/python/cog/server/http.py b/python/cog/server/http.py index ae350fcfb3..f7da747dd4 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -165,9 +165,11 @@ async def start_shutdown() -> Any: return app worker = make_worker( - predictor_ref=cog_config.get_predictor_ref(mode=mode), is_async=is_async + predictor_ref=cog_config.get_predictor_ref(mode=mode), + is_async=is_async, + max_concurrency=cog_config.max_concurrency, ) - runner = PredictionRunner(worker=worker) + runner = PredictionRunner(worker=worker, max_concurrency=cog_config.max_concurrency) class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType)): pass @@ -219,7 +221,7 @@ class TrainingRequest( response_model=TrainingResponse, response_model_exclude_unset=True, ) - def train( + async def train( request: TrainingRequest = Body(default=None), prefer: Optional[str] = Header(default=None), traceparent: Optional[str] = Header( @@ -232,7 +234,7 @@ def train( respond_async = prefer == "respond-async" with trace_context(make_trace_context(traceparent, tracestate)): - return _predict( + return await _predict( request=request, response_type=TrainingResponse, respond_async=respond_async, @@ -243,7 +245,7 @@ def train( response_model=TrainingResponse, response_model_exclude_unset=True, ) - def train_idempotent( + async def train_idempotent( training_id: str = Path(..., title="Training ID"), request: TrainingRequest = Body(..., title="Training Request"), prefer: Optional[str] = Header(default=None), @@ -280,7 +282,7 @@ def train_idempotent( respond_async = prefer == "respond-async" with trace_context(make_trace_context(traceparent, tracestate)): - return _predict( + return await _predict( request=request, response_type=TrainingResponse, respond_async=respond_async, @@ -359,7 +361,7 @@ async def predict( respond_async = prefer == "respond-async" with trace_context(make_trace_context(traceparent, tracestate)): - return _predict( + return await _predict( request=request, response_type=PredictionResponse, respond_async=respond_async, @@ -407,13 +409,13 @@ async def predict_idempotent( respond_async = prefer == "respond-async" with trace_context(make_trace_context(traceparent, tracestate)): - return _predict( + return await _predict( request=request, response_type=PredictionResponse, respond_async=respond_async, ) - def _predict( + async def _predict( *, request: Optional[PredictionRequest], response_type: Type[schema.PredictionResponse], @@ -455,7 +457,7 @@ def _predict( ) # Otherwise, wait for the prediction to complete... - predict_task.wait() + await predict_task.wait_async() # ...and return the result. if PYDANTIC_V2: diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index 7468be2572..f9e8573439 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -1,5 +1,8 @@ +import asyncio import io +import threading import traceback +import uuid from abc import ABC, abstractmethod from concurrent.futures import Future from datetime import datetime, timezone @@ -60,13 +63,15 @@ class PredictionRunner: def __init__( self, *, + max_concurrency: int = 1, worker: Worker, ) -> None: self._worker = worker + self._max_concurrency = max_concurrency self._setup_task: Optional[SetupTask] = None - self._predict_task: Optional[PredictTask] = None - self._prediction_id = None + self._predict_tasks: Dict[str, PredictTask] = {} + self._predict_tasks_lock = threading.Lock() def setup(self) -> "SetupTask": assert self._setup_task is None, "do not call setup twice" @@ -88,8 +93,21 @@ def predict( task_kwargs = task_kwargs or {} - self._predict_task = PredictTask(prediction, **task_kwargs) - self._prediction_id = prediction.id + tag = prediction.id + if tag is None: + tag = uuid.uuid4().hex + + task = PredictTask(prediction, **task_kwargs) + + with self._predict_tasks_lock: + # first remove finished tasks so we don't grow the dictionary without bound + done_ids = [ + id for id in self._predict_tasks if self._predict_tasks[id].done() + ] + for id in done_ids: + del self._predict_tasks[id] + + self._predict_tasks[tag] = task if isinstance(prediction.input, BaseInput): if PYDANTIC_V2: @@ -101,18 +119,15 @@ def predict( else: payload = prediction.input.copy() - sid = self._worker.subscribe(self._predict_task.handle_event) - self._predict_task.track(self._worker.predict(payload)) - self._predict_task.add_done_callback(lambda _: self._worker.unsubscribe(sid)) + sid = self._worker.subscribe(task.handle_event, tag=tag) + task.track(self._worker.predict(payload, tag=tag)) + task.add_done_callback(lambda _: self._worker.unsubscribe(sid)) - return self._predict_task + return task def get_predict_task(self, id: str) -> Optional["PredictTask"]: - if not self._predict_task: - return None - if self._predict_task.result.id != id: - return None - return self._predict_task + with self._predict_tasks_lock: + return self._predict_tasks.get(id, None) def is_busy(self) -> bool: try: @@ -124,9 +139,13 @@ def is_busy(self) -> bool: def cancel(self, prediction_id: str) -> None: if not prediction_id: raise ValueError("prediction_id is required") - if self._prediction_id != prediction_id: - raise UnknownPredictionError("id mismatch") - self._worker.cancel() + with self._predict_tasks_lock: + if ( + prediction_id not in self._predict_tasks + or self._predict_tasks[prediction_id].done() + ): + raise UnknownPredictionError("unknown prediction id") + self._worker.cancel(tag=prediction_id) def _raise_if_busy(self) -> None: if self._setup_task is None: @@ -135,9 +154,17 @@ def _raise_if_busy(self) -> None: if not self._setup_task.done(): # Setup is still running. raise RunnerBusyError("setup is not complete") - if self._predict_task is not None and not self._predict_task.done(): - # Prediction is still running. - raise RunnerBusyError("prediction running") + + with self._predict_tasks_lock: + processing_tasks = [ + id for id in self._predict_tasks if not self._predict_tasks[id].done() + ] + + if len(processing_tasks) >= self._max_concurrency: + # We're at max concurrency + if self._max_concurrency == 1: + raise RunnerBusyError("prediction running") + raise RunnerBusyError("max predictions running") T = TypeVar("T") @@ -317,6 +344,11 @@ def done(self) -> bool: assert self._fut, "call track before checking done" return self._fut.done() + async def wait_async(self) -> None: + assert self._fut, "call track before waiting" + await asyncio.wrap_future(self._fut) + return None + def wait(self, timeout: Optional[float] = None) -> None: assert self._fut, "call track before waiting" self._fut.result(timeout=timeout) diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index 713e0a38f8..bba098e786 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -351,6 +351,7 @@ def __init__( *, is_async: bool, events: Connection, + max_concurrency: int = 1, tee_output: bool = True, ) -> None: self._predictor_ref = predictor_ref @@ -360,6 +361,7 @@ def __init__( ) self._tee_output = tee_output self._cancelable = False + self._max_concurrency = max_concurrency # for synchronous predictors only! async predictors use _tag_var instead self._sync_tag: Optional[str] = None @@ -459,6 +461,25 @@ def _setup( # Could be a function or a class if hasattr(self._predictor, "setup"): run_setup(self._predictor) + + predict = get_predict(self._predictor) + + is_async_predictor = inspect.iscoroutinefunction( + predict + ) or inspect.isasyncgenfunction(predict) + + # Async models require python >= 3.11 so we can use asyncio.TaskGroup + # We should check for this before getting to this point + if is_async_predictor and sys.version_info < (3, 11): + raise FatalWorkerException( + "Cog requires python >=3.11 for `async def predict(..)` support" + ) + + if self._max_concurrency > 1 and not is_async_predictor: + raise FatalWorkerException( + "max_concurrency>1 requires `async def predict()`" + ) + except Exception as e: # pylint: disable=broad-exception-caught traceback.print_exc() done.error = True @@ -512,20 +533,19 @@ async def _aloop( task = None - while True: - e = cast(Envelope, await self._events.recv()) - if isinstance(e.event, Cancel) and task and self._cancelable: - task.cancel() - elif isinstance(e.event, Shutdown): - break - elif isinstance(e.event, PredictionInput): - task = asyncio.create_task( - self._apredict(e.tag, e.event.payload, predict, redirector) - ) - else: - print(f"Got unexpected event: {e.event}", file=sys.stderr) - if task: - await task + async with asyncio.TaskGroup() as tg: + while True: + e = cast(Envelope, await self._events.recv()) + if isinstance(e.event, Cancel) and task and self._cancelable: + task.cancel() + elif isinstance(e.event, Shutdown): + break + elif isinstance(e.event, PredictionInput): + task = tg.create_task( + self._apredict(e.tag, e.event.payload, predict, redirector) + ) + else: + print(f"Got unexpected event: {e.event}", file=sys.stderr) def _predict( self, @@ -717,7 +737,11 @@ def make_worker( ) -> Worker: parent_conn, child_conn = _spawn.Pipe() child = _ChildWorker( - predictor_ref, events=child_conn, tee_output=tee_output, is_async=is_async + predictor_ref, + is_async=is_async, + events=child_conn, + tee_output=tee_output, + max_concurrency=max_concurrency, ) parent = Worker(child=child, events=parent_conn, max_concurrency=max_concurrency) return parent diff --git a/python/cog/types.py b/python/cog/types.py index 29d868c9e7..f6329dc6e4 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -43,6 +43,7 @@ class ExperimentalFeatureWarning(Warning): class CogConfig(TypedDict): # pylint: disable=too-many-ancestors build: "CogBuildConfig" + concurrency: "CogConcurrencyConfig" image: NotRequired[str] predict: NotRequired[str] train: NotRequired[str] @@ -58,6 +59,10 @@ class CogBuildConfig(TypedDict, total=False): # pylint: disable=too-many-ancest run: Optional[Union[List[str], List[Dict[str, Any]]]] +class CogConcurrencyConfig(TypedDict, total=False): # pylint: disable=too-many-ancestors + max: Optional[int] + + def Input( # pylint: disable=invalid-name, too-many-arguments default: Any = ..., description: str = None, diff --git a/python/tests/server/conftest.py b/python/tests/server/conftest.py index 3bf6d71c01..171ea3643f 100644 --- a/python/tests/server/conftest.py +++ b/python/tests/server/conftest.py @@ -1,4 +1,5 @@ import os +import sys import threading import time from contextlib import ExitStack @@ -27,6 +28,7 @@ class WorkerConfig: is_async: bool = False setup: bool = True max_concurrency: int = 1 + min_python: Optional[Tuple[int, int]] = None def pytest_make_parametrize_id(config, val): @@ -72,7 +74,9 @@ def uses_predictor_with_client_options(name, **options): ) -def uses_worker(name_or_names, setup=True, max_concurrency=1, is_async=False): +def uses_worker( + name_or_names, setup=True, max_concurrency=1, min_python=None, is_async=False +): """ Decorator for tests that require a Worker instance. `name_or_names` can be a single fixture name, or a sequence (list, tuple) of fixture names. If @@ -81,31 +85,34 @@ def uses_worker(name_or_names, setup=True, max_concurrency=1, is_async=False): If `setup` is True (the default) setup will be run before the test runs. """ if isinstance(name_or_names, (tuple, list)): - values = [ + values = ( WorkerConfig( fixture_name=n, setup=setup, max_concurrency=max_concurrency, + min_python=min_python, is_async=is_async, ) for n in name_or_names - ] + ) else: - values = [ + values = ( WorkerConfig( fixture_name=name_or_names, setup=setup, max_concurrency=max_concurrency, + min_python=min_python, is_async=is_async, ), - ] - return uses_worker_configs(values) + ) + return uses_worker_configs(list(values)) def uses_worker_configs(values: Sequence[WorkerConfig]): """ - Decorator for tests that require a Worker instance. `configs` can be - a sequence of `WorkerConfig` instances. + Decorator for tests that require a Worker instance. The test will be + run once for each worker. `configs` is a sequence (list, tuple, generator) + of WorkerConfig. """ return pytest.mark.parametrize("worker", values, indirect=True) @@ -168,6 +175,13 @@ def static_schema(client) -> dict: @pytest.fixture def worker(request): ref = _fixture_path(request.param.fixture_name) + if ( + request.param.min_python is not None + and sys.version_info < request.param.min_python + ): + pytest.skip( + f"Test requires python {request.param.min_python[0]}.{request.param.min_python[1]}" + ) w = make_worker( predictor_ref=ref, is_async=request.param.is_async, diff --git a/python/tests/server/test_runner.py b/python/tests/server/test_runner.py index b012e5f159..5e27357213 100644 --- a/python/tests/server/test_runner.py +++ b/python/tests/server/test_runner.py @@ -38,18 +38,23 @@ def __call__(self): class FakeWorker: def __init__(self): self.subscribers = {} - self.last_prediction_payload = None + self.subscribers_by_tag = {} self._setup_future = None - self._predict_future = None + self._predict_futures = {} + self.last_prediction_payload = None def subscribe(self, subscriber, tag=None): sid = uuid.uuid4() - self.subscribers[sid] = subscriber + self.subscribers[sid] = tag + if tag not in self.subscribers_by_tag: + self.subscribers_by_tag[tag] = {} + self.subscribers_by_tag[tag][sid] = subscriber return sid def unsubscribe(self, sid): - del self.subscribers[sid] + tag = self.subscribers.pop(sid) + del self.subscribers_by_tag[tag][sid] def setup(self): assert self._setup_future is None @@ -61,32 +66,38 @@ def run_setup(self, events): if isinstance(event, Exception): self._setup_future.set_exception(event) return - for subscriber in self.subscribers.values(): + for subscriber in self.subscribers_by_tag.get(None, {}).values(): subscriber(event) if isinstance(event, Done): self._setup_future.set_result(event) def predict(self, payload, tag=None): - assert self._predict_future is None or self._predict_future.done() + assert tag not in self._predict_futures or self._predict_futures[tag].done() self.last_prediction_payload = payload - self._predict_future = Future() - return self._predict_future - - def run_predict(self, events): + self._predict_futures[tag] = Future() + print(f"setting {tag}, now {self._predict_futures}") + return self._predict_futures[tag] + + def run_predict(self, events, id=None): + if id is None: + if len(self._predict_futures) != 1: + raise ValueError("Could not guess prediction id, please specify") + id = next(iter(self._predict_futures)) for event in events: if isinstance(event, Exception): - self._predict_future.set_exception(event) + self._predict_futures[id].set_exception(event) return - for subscriber in self.subscribers.values(): + for subscriber in self.subscribers_by_tag.get(id, {}).values(): subscriber(event) if isinstance(event, Done): - self._predict_future.set_result(event) + print(f"reading {id} from {self._predict_futures}") + self._predict_futures[id].set_result(event) def cancel(self, tag=None): done = Done(canceled=True) - for subscriber in self.subscribers.values(): + for subscriber in self.subscribers_by_tag.get(tag, {}).values(): subscriber(done) - self._predict_future.set_result(done) + self._predict_futures[tag].set_result(done) def test_prediction_runner_setup_success(): @@ -229,11 +240,11 @@ def test_prediction_runner_predict_after_predict_completes(): r.setup() w.run_setup([Done()]) - r.predict(PredictionRequest(input={"text": "giraffes"})) - w.run_predict([Done()]) + r.predict(PredictionRequest(id="p-1", input={"text": "giraffes"})) + w.run_predict([Done()], id="p-1") - r.predict(PredictionRequest(input={"text": "elephants"})) - w.run_predict([Done()]) + r.predict(PredictionRequest(id="p-2", input={"text": "elephants"})) + w.run_predict([Done()], id="p-2") assert w.last_prediction_payload == {"text": "elephants"} @@ -257,6 +268,31 @@ def test_prediction_runner_is_busy(): assert not r.is_busy() +def test_prediction_runner_is_busy_concurrency(): + w = FakeWorker() + r = PredictionRunner(worker=w, max_concurrency=3) + + assert r.is_busy() + + r.setup() + assert r.is_busy() + + w.run_setup([Done()]) + assert not r.is_busy() + + r.predict(PredictionRequest(id="1", input={"text": "elephants"})) + assert not r.is_busy() + + r.predict(PredictionRequest(id="2", input={"text": "elephants"})) + assert not r.is_busy() + + r.predict(PredictionRequest(id="3", input={"text": "elephants"})) + assert r.is_busy() + + w.run_predict([Done()], id="1") + assert not r.is_busy() + + def test_prediction_runner_predict_cancelation(): w = FakeWorker() r = PredictionRunner(worker=w) @@ -299,6 +335,23 @@ def test_prediction_runner_predict_cancelation_multiple_predictions(): assert task2.result.status == Status.CANCELED +def test_prediction_runner_predict_cancelation_concurrent_predictions(): + w = FakeWorker() + r = PredictionRunner(worker=w, max_concurrency=5) + + r.setup() + w.run_setup([Done()]) + + task1 = r.predict(PredictionRequest(id="abcd1234", input={"text": "giraffes"})) + + task2 = r.predict(PredictionRequest(id="defg6789", input={"text": "elephants"})) + + r.cancel("abcd1234") + w.run_predict([Done()], id="defg6789") + assert task1.result.status == Status.CANCELED + assert task2.result.status == Status.SUCCEEDED + + def test_prediction_runner_setup_e2e(): w = make_worker(predictor_ref=_fixture_path("sleep"), is_async=False) r = PredictionRunner(worker=w) diff --git a/python/tests/server/test_worker.py b/python/tests/server/test_worker.py index cb6a469430..abaa4ca19f 100644 --- a/python/tests/server/test_worker.py +++ b/python/tests/server/test_worker.py @@ -1,5 +1,6 @@ import multiprocessing import os +import sys import threading import time import uuid @@ -76,21 +77,7 @@ }, ), ( - WorkerConfig("record_metric_async", is_async=True), - {"name": ST_NAMES}, - { - "foo": 123, - }, - ), - ( - WorkerConfig("emit_metric"), - {"name": ST_NAMES}, - { - "foo": 123, - }, - ), - ( - WorkerConfig("emit_metric_async", is_async=True), + WorkerConfig("record_metric_async", min_python=(3, 11), is_async=True), {"name": ST_NAMES}, { "foo": 123, @@ -105,7 +92,7 @@ lambda x: f"hello, {x['name']}", ), ( - WorkerConfig("hello_world_async", is_async=True), + WorkerConfig("hello_world_async", min_python=(3, 11), is_async=True), {"name": ST_NAMES}, lambda x: f"hello, {x['name']}", ), @@ -132,7 +119,7 @@ "writing to stderr at import time\n", ), ( - WorkerConfig("logging_async", is_async=True, setup=False), + WorkerConfig("logging_async", setup=False, min_python=(3, 11), is_async=True), ("writing to stdout at import time\n" "setting up predictor\n"), "writing to stderr at import time\n", ), @@ -145,12 +132,22 @@ ("WARNING:root:writing log message\n" "writing to stderr\n"), ), ( - WorkerConfig("logging_async", is_async=True), + WorkerConfig("logging_async", min_python=(3, 11), is_async=True), ("writing with print\n"), ("WARNING:root:writing log message\n" "writing to stderr\n"), ), ] +SLEEP_FIXTURES = [ + WorkerConfig("sleep"), + WorkerConfig("sleep_async", min_python=(3, 11), is_async=True), +] + +SLEEP_NO_SETUP_FIXTURES = [ + WorkerConfig("sleep", setup=False), + WorkerConfig("sleep_async", min_python=(3, 11), setup=False, is_async=True), +] + @define class Result: @@ -255,9 +252,11 @@ def test_no_exceptions_from_recoverable_failures(worker): _process(worker, lambda: worker.predict({})) -# TODO test this works with errors and cancelations and the like @uses_worker_configs( - [WorkerConfig("simple"), WorkerConfig("simple_async", is_async=True)] + [ + WorkerConfig("simple"), + WorkerConfig("simple_async", min_python=(3, 11), is_async=True), + ] ) def test_can_subscribe_for_a_specific_tag(worker): tag = "123" @@ -280,12 +279,12 @@ def test_can_subscribe_for_a_specific_tag(worker): worker.unsubscribe(subid) -@uses_worker("sleep_async", is_async=True, max_concurrency=5) +@uses_worker("sleep_async", max_concurrency=5, min_python=(3, 11), is_async=True) def test_can_run_predictions_concurrently_on_async_predictor(worker): subids = [] try: - start = time.time() + start = time.perf_counter() futures = [] results = [] for i in range(5): @@ -299,7 +298,7 @@ def test_can_run_predictions_concurrently_on_async_predictor(worker): for fut in futures: fut.result() - end = time.time() + end = time.perf_counter() duration = end - start # we should take at least 0.5 seconds (the time for 1 prediction) but @@ -319,6 +318,40 @@ def test_can_run_predictions_concurrently_on_async_predictor(worker): worker.unsubscribe(subid) +@pytest.mark.skipif( + sys.version_info >= (3, 11), reason="Testing error message on python versions <3.11" +) +@uses_worker("simple_async", setup=False) +def test_async_predictor_on_python_3_10_or_older_raises_error(worker): + fut = worker.setup() + result = Result() + worker.subscribe(result.handle_event) + + with pytest.raises(FatalWorkerException): + fut.result() + assert result.done + assert result.done.error + assert ( + result.done.error_detail + == "Cog requires python >=3.11 for `async def predict(..)` support" + ) + + +@uses_worker("simple", max_concurrency=5, setup=False) +def test_concurrency_with_sync_predictor_raises_error(worker): + fut = worker.setup() + result = Result() + worker.subscribe(result.handle_event) + + with pytest.raises(FatalWorkerException): + fut.result() + assert result.done + assert result.done.error + assert ( + result.done.error_detail == "max_concurrency>1 requires `async def predict()`" + ) + + @uses_worker("stream_redirector_race_condition") def test_stream_redirector_race_condition(worker): """ @@ -403,12 +436,7 @@ def test_predict_logging(worker, expected_stdout, expected_stderr): assert result.stderr == expected_stderr -@uses_worker_configs( - [ - WorkerConfig("sleep", setup=False), - WorkerConfig("sleep_async", is_async=True, setup=False), - ] -) +@uses_worker_configs(SLEEP_NO_SETUP_FIXTURES) def test_cancel_is_safe(worker): """ Calls to cancel at any time should not result in unexpected things @@ -442,12 +470,7 @@ def test_cancel_is_safe(worker): assert result2.output == "done in 0.1 seconds" -@uses_worker_configs( - [ - WorkerConfig("sleep", setup=False), - WorkerConfig("sleep_async", is_async=True, setup=False), - ] -) +@uses_worker_configs(SLEEP_NO_SETUP_FIXTURES) def test_cancel_idempotency(worker): """ Multiple calls to cancel within the same prediction, while not necessary or @@ -479,9 +502,7 @@ def cancel_a_bunch(_): assert result2.output == "done in 0.1 seconds" -@uses_worker_configs( - [WorkerConfig("sleep"), WorkerConfig("sleep_async", is_async=True)] -) +@uses_worker_configs(SLEEP_FIXTURES) def test_cancel_multiple_predictions(worker): """ Multiple predictions cancelled in a row shouldn't be a problem. This test @@ -499,9 +520,7 @@ def test_cancel_multiple_predictions(worker): assert not worker.predict({"sleep": 0}).result().canceled -@uses_worker_configs( - [WorkerConfig("sleep"), WorkerConfig("sleep_async", is_async=True)] -) +@uses_worker_configs(SLEEP_FIXTURES) def test_graceful_shutdown(worker): """ On shutdown, the worker should finish running the current prediction, and diff --git a/test-integration/test_integration/fixtures/async-sleep-project/cog.yaml b/test-integration/test_integration/fixtures/async-sleep-project/cog.yaml new file mode 100644 index 0000000000..04d04bf7c8 --- /dev/null +++ b/test-integration/test_integration/fixtures/async-sleep-project/cog.yaml @@ -0,0 +1,5 @@ +build: + python_version: "3.11" +predict: "predict.py:Predictor" +concurrency: + max: 5 diff --git a/test-integration/test_integration/fixtures/async-sleep-project/predict.py b/test-integration/test_integration/fixtures/async-sleep-project/predict.py new file mode 100644 index 0000000000..e6c65797a0 --- /dev/null +++ b/test-integration/test_integration/fixtures/async-sleep-project/predict.py @@ -0,0 +1,9 @@ +import asyncio + +from cog import BasePredictor + + +class Predictor(BasePredictor): + async def predict(self, s: str, sleep: float) -> str: + await asyncio.sleep(sleep) + return f"wake up {s}" diff --git a/test-integration/test_integration/fixtures/async-string-project/cog.yaml b/test-integration/test_integration/fixtures/async-string-project/cog.yaml new file mode 100644 index 0000000000..7b6d5d4dce --- /dev/null +++ b/test-integration/test_integration/fixtures/async-string-project/cog.yaml @@ -0,0 +1,3 @@ +build: + python_version: "3.11" +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/async-string-project/predict.py b/test-integration/test_integration/fixtures/async-string-project/predict.py new file mode 100644 index 0000000000..fb2805c794 --- /dev/null +++ b/test-integration/test_integration/fixtures/async-string-project/predict.py @@ -0,0 +1,6 @@ +from cog import BasePredictor + + +class Predictor(BasePredictor): + async def predict(self, s: str) -> str: + return "hello " + s diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index 4065b69927..459f09f03e 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -1,6 +1,8 @@ +import asyncio import pathlib import shutil import subprocess +import time from pathlib import Path import httpx @@ -27,6 +29,20 @@ def test_predict_takes_string_inputs_and_returns_strings_to_stdout(): assert "falling back to slow loader" in result.stderr +def test_predict_supports_async_predictors(): + project_dir = Path(__file__).parent / "fixtures/async-string-project" + result = subprocess.run( + ["cog", "predict", "--debug", "-i", "s=world"], + cwd=project_dir, + check=True, + capture_output=True, + text=True, + timeout=DEFAULT_TIMEOUT, + ) + # stdout should be clean without any log messages so it can be piped to other commands + assert result.stdout == "hello world\n" + + def test_predict_takes_int_inputs_and_returns_ints_to_stdout(): project_dir = Path(__file__).parent / "fixtures/int-project" result = subprocess.run( @@ -322,3 +338,34 @@ def test_predict_with_subprocess_in_setup(fixture_name): assert response.status_code == 200, str(response) assert busy_count < 10 + + +@pytest.mark.asyncio +async def test_concurrent_predictions(): + async def make_request(i: int) -> httpx.Response: + return await client.post( + f"{addr}/predictions", + json={ + "id": f"id-{i}", + "input": {"s": f"sleepyhead{i}", "sleep": 1.0}, + }, + ) + + with cog_server_http_run( + Path(__file__).parent / "fixtures" / "async-sleep-project" + ) as addr: + async with httpx.AsyncClient() as client: + tasks = [] + start = time.perf_counter() + async with asyncio.TaskGroup() as tg: + for i in range(5): + tasks.append(tg.create_task(make_request(i))) + # give time for all of the predictions to be accepted, but not completed + await asyncio.sleep(0.2) + # we shut the server down, but expect all running predictions to complete + await client.post(f"{addr}/shutdown") + end = time.perf_counter() + assert (end - start) < 3.0 # ensure the predictions ran concurrently + for i, task in enumerate(tasks): + assert task.result().status_code == 200 + assert task.result().json()["output"] == f"wake up sleepyhead{i}" From fce5fb25719818928e839e4c7fb00a0fa74ced55 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Tue, 26 Nov 2024 15:15:46 +0000 Subject: [PATCH 03/16] Fix typing of CogConcurrencyConfig The use of `Optional` allowed `None` as a valid value. This has been changed to use `NotRequired` which allows the field to be omitted but must always be an integer when present. --- python/cog/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cog/types.py b/python/cog/types.py index f6329dc6e4..4bacf8ee16 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -60,7 +60,7 @@ class CogBuildConfig(TypedDict, total=False): # pylint: disable=too-many-ancest class CogConcurrencyConfig(TypedDict, total=False): # pylint: disable=too-many-ancestors - max: Optional[int] + max: NotRequired[int] def Input( # pylint: disable=invalid-name, too-many-arguments From 2be4397ae878956a7aa5aa51844b259474f7b79d Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Tue, 26 Nov 2024 15:22:14 +0000 Subject: [PATCH 04/16] Improve consistency of error messages for async predict method --- python/cog/server/worker.py | 4 ++-- python/tests/server/test_worker.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index bba098e786..c22f2fbd62 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -472,12 +472,12 @@ def _setup( # We should check for this before getting to this point if is_async_predictor and sys.version_info < (3, 11): raise FatalWorkerException( - "Cog requires python >=3.11 for `async def predict(..)` support" + "Cog requires Python >=3.11 for `async def predict()` support" ) if self._max_concurrency > 1 and not is_async_predictor: raise FatalWorkerException( - "max_concurrency>1 requires `async def predict()`" + "max_concurrency > 1 requires an async predict function, e.g. `async def predict()`" ) except Exception as e: # pylint: disable=broad-exception-caught diff --git a/python/tests/server/test_worker.py b/python/tests/server/test_worker.py index abaa4ca19f..d470890c76 100644 --- a/python/tests/server/test_worker.py +++ b/python/tests/server/test_worker.py @@ -333,7 +333,7 @@ def test_async_predictor_on_python_3_10_or_older_raises_error(worker): assert result.done.error assert ( result.done.error_detail - == "Cog requires python >=3.11 for `async def predict(..)` support" + == "Cog requires Python >=3.11 for `async def predict()` support" ) @@ -348,7 +348,8 @@ def test_concurrency_with_sync_predictor_raises_error(worker): assert result.done assert result.done.error assert ( - result.done.error_detail == "max_concurrency>1 requires `async def predict()`" + result.done.error_detail + == "max_concurrency > 1 requires an async predict function, e.g. `async def predict()`" ) From 8623698560a6ef74b54e1da1e6b56f2c2f3b0164 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Tue, 26 Nov 2024 15:33:06 +0000 Subject: [PATCH 05/16] Re-word internal use of id to tag Inside the worker we track predictions by tag not exterenal predicition IDs, this commit updates the variable names to reflect this. --- python/cog/server/runner.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index f9e8573439..ebf97d1e4b 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -101,11 +101,12 @@ def predict( with self._predict_tasks_lock: # first remove finished tasks so we don't grow the dictionary without bound - done_ids = [ - id for id in self._predict_tasks if self._predict_tasks[id].done() + # TODO: clean this up by adding a done callback to the task. + done_tags = [ + tag for tag in self._predict_tasks if self._predict_tasks[tag].done() ] - for id in done_ids: - del self._predict_tasks[id] + for tag in done_tags: + del self._predict_tasks[tag] self._predict_tasks[tag] = task From 9f96f5a816cfaf98c10d91d5251b2758ad7502f6 Mon Sep 17 00:00:00 2001 From: Philip Potter Date: Tue, 26 Nov 2024 16:35:27 +0000 Subject: [PATCH 06/16] fix test failure the `for tag in done_tags:` was resetting the existing `tag` variable and breaking things. --- python/cog/server/runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index ebf97d1e4b..081fb593b3 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -105,8 +105,8 @@ def predict( done_tags = [ tag for tag in self._predict_tasks if self._predict_tasks[tag].done() ] - for tag in done_tags: - del self._predict_tasks[tag] + for done_tag in done_tags: + del self._predict_tasks[done_tag] self._predict_tasks[tag] = task From 947b7acfe035007e1ce84187ced699f930c64be6 Mon Sep 17 00:00:00 2001 From: Philip Potter Date: Tue, 26 Nov 2024 16:40:22 +0000 Subject: [PATCH 07/16] Remove tasks from _predict_tasks in callback --- python/cog/server/runner.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index 081fb593b3..1b752ae27a 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -100,14 +100,6 @@ def predict( task = PredictTask(prediction, **task_kwargs) with self._predict_tasks_lock: - # first remove finished tasks so we don't grow the dictionary without bound - # TODO: clean this up by adding a done callback to the task. - done_tags = [ - tag for tag in self._predict_tasks if self._predict_tasks[tag].done() - ] - for done_tag in done_tags: - del self._predict_tasks[done_tag] - self._predict_tasks[tag] = task if isinstance(prediction.input, BaseInput): @@ -122,10 +114,18 @@ def predict( sid = self._worker.subscribe(task.handle_event, tag=tag) task.track(self._worker.predict(payload, tag=tag)) - task.add_done_callback(lambda _: self._worker.unsubscribe(sid)) + task.add_done_callback(self._task_done_callback(tag, sid)) return task + def _task_done_callback(self, tag: str, sid: int) -> Callable[[Any], None]: + def _callback(_) -> None: + self._worker.unsubscribe(sid) + with self._predict_tasks_lock: + del self._predict_tasks[tag] + + return _callback + def get_predict_task(self, id: str) -> Optional["PredictTask"]: with self._predict_tasks_lock: return self._predict_tasks.get(id, None) From 54dc287a6df71a92d25678eed9131f1b89ca4b89 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 27 Nov 2024 13:42:56 +0000 Subject: [PATCH 08/16] Ensure SimpleStreamWrapper flushes on newlines This commit manually calls `flush()` on the `SimpleStreamWrapper` each time the string provided to `write()` contains a newline character. The previous implementation assumed that the underlying TextIOWrapper class would call our custom `flush()` method but this is not the case as `TextIOWrapper` is implemented in C and calls into the compiled code. --- python/cog/server/helpers.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/python/cog/server/helpers.py b/python/cog/server/helpers.py index 3aa967ed2a..990fa49cb0 100644 --- a/python/cog/server/helpers.py +++ b/python/cog/server/helpers.py @@ -33,7 +33,7 @@ def __init__( callback: Callable[[str, str], None], tee: bool = False, ) -> None: - super().__init__(buffer, line_buffering=True) + super().__init__(buffer) self._callback = callback self._tee = tee @@ -44,11 +44,10 @@ def write(self, s: str) -> int: self._buffer.append(s) if self._tee: super().write(s) - else: - # If we're not teeing, we have to handle automatic flush on - # newline. When `tee` is true, this is handled by the write method. - if "\n" in s or "\r" in s: - self.flush() + + if "\n" in s or "\r" in s: + self.flush() + return length def flush(self) -> None: From 75c5408b5e25b0e9c3ecaea3c5c97218b15aadc1 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 27 Nov 2024 11:13:38 +0000 Subject: [PATCH 09/16] Add deprecated `emit_metric()` helper to cog interface This helps with the transition between the main cog branch and the experimental `async` branch. Models built for the `async` branch can be run on either without code changes. --- python/tests/server/test_worker.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/tests/server/test_worker.py b/python/tests/server/test_worker.py index d470890c76..236c3e88d7 100644 --- a/python/tests/server/test_worker.py +++ b/python/tests/server/test_worker.py @@ -83,6 +83,20 @@ "foo": 123, }, ), + ( + WorkerConfig("emit_metric"), + {"name": ST_NAMES}, + { + "foo": 123, + }, + ), + ( + WorkerConfig("emit_metric_async"), + {"name": ST_NAMES}, + { + "foo": 123, + }, + ), ] OUTPUT_FIXTURES = [ From 76d58bb7df42de6ecdbc1c98dd5f1800b981a701 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 27 Nov 2024 15:12:05 +0000 Subject: [PATCH 10/16] Pass min_python into WorkerConfig tests --- python/tests/server/test_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/server/test_worker.py b/python/tests/server/test_worker.py index 236c3e88d7..5acb532123 100644 --- a/python/tests/server/test_worker.py +++ b/python/tests/server/test_worker.py @@ -91,7 +91,7 @@ }, ), ( - WorkerConfig("emit_metric_async"), + WorkerConfig("emit_metric_async", min_python=(3, 11)), {"name": ST_NAMES}, { "foo": 123, From 0895db461ef9702edc28fa7e6d27ad26a07fa32a Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Thu, 28 Nov 2024 11:27:02 +0000 Subject: [PATCH 11/16] Add AsyncConcatenateIterator to exported types --- python/cog/__init__.py | 2 + python/cog/types.py | 91 ++++++++++++++++++++++++++++++++---------- 2 files changed, 71 insertions(+), 22 deletions(-) diff --git a/python/cog/__init__.py b/python/cog/__init__.py index d6ad24d9af..72f1399cd0 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -6,6 +6,7 @@ from .mimetypes_ext import install_mime_extensions from .server.scope import current_scope, emit_metric from .types import ( + AsyncConcatenateIterator, ConcatenateIterator, ExperimentalFeatureWarning, File, @@ -26,6 +27,7 @@ "__version__", "current_scope", "emit_metric", + "AsyncConcatenateIterator", "BaseModel", "BasePredictor", "ConcatenateIterator", diff --git a/python/cog/types.py b/python/cog/types.py index 4bacf8ee16..c27247afa9 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -9,6 +9,7 @@ import urllib.response from typing import ( Any, + AsyncIterator, Dict, Iterator, List, @@ -65,13 +66,13 @@ class CogConcurrencyConfig(TypedDict, total=False): # pylint: disable=too-many- def Input( # pylint: disable=invalid-name, too-many-arguments default: Any = ..., - description: str = None, - ge: float = None, - le: float = None, - min_length: int = None, - max_length: int = None, - regex: str = None, - choices: List[Union[str, int]] = None, + description: Optional[str] = None, + ge: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + regex: Optional[str] = None, + choices: Optional[List[Union[str, int]]] = None, ) -> Any: """Input is similar to pydantic.Field, but doesn't require a default value to be the first argument.""" field_kwargs = { @@ -415,6 +416,12 @@ def get_filename(url: str) -> str: Item = TypeVar("Item") +_concatenate_iterator_schema = { + "type": "array", + "items": {"type": "string"}, + "x-cog-array-type": "iterator", + "x-cog-array-display": "concatenate", +} class ConcatenateIterator(Iterator[Item]): # pylint: disable=abstract-method @@ -450,14 +457,7 @@ def __get_pydantic_json_schema__( ) -> "JsonSchemaValue": # type: ignore # noqa: F821 json_schema = handler(core_schema) json_schema.pop("allOf", None) - json_schema.update( - { - "type": "array", - "items": {"type": "string"}, - "x-cog-array-type": "iterator", - "x-cog-array-display": "concatenate", - } - ) + json_schema.update(_concatenate_iterator_schema) return json_schema else: @@ -470,15 +470,62 @@ def __get_validators__(cls) -> Iterator[Any]: def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: """Defines what this type should be in openapi.json""" field_schema.pop("allOf", None) - field_schema.update( - { - "type": "array", - "items": {"type": "string"}, - "x-cog-array-type": "iterator", - "x-cog-array-display": "concatenate", - } + field_schema.update(_concatenate_iterator_schema) + + +class AsyncConcatenateIterator(AsyncIterator[Item]): + @classmethod + def validate(cls, value: AsyncIterator[Any]) -> AsyncIterator[Any]: + return value + + if PYDANTIC_V2: + from pydantic import GetCoreSchemaHandler + from pydantic.json_schema import JsonSchemaValue + from pydantic_core import CoreSchema + + @classmethod + def __get_pydantic_core_schema__( + cls, + source: Type[Any], # pylint: disable=unused-argument + handler: "pydantic.GetCoreSchemaHandler", # pylint: disable=unused-argument + ) -> "CoreSchema": + from pydantic_core import ( # pylint: disable=import-outside-toplevel + core_schema, ) + return core_schema.union_schema( + [ + core_schema.is_instance_schema(AsyncIterator), + core_schema.no_info_plain_validator_function(cls.validate), + ] + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, core_schema: "CoreSchema", handler: "pydantic.GetJsonSchemaHandler" + ) -> "JsonSchemaValue": # type: ignore # noqa: F821 + json_schema = handler(core_schema) + json_schema.pop("allOf", None) + json_schema.update(_concatenate_iterator_schema) + return json_schema + else: + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + """Defines what this type should be in openapi.json""" + field_schema.pop("allOf", None) + field_schema.update(_concatenate_iterator_schema) + + @classmethod + def __get_validators__(cls) -> Iterator[Any]: + yield cls.validate + + +def get_filename_from_urlopen(resp: urllib.response.addinfourl) -> str: + mime_type = resp.headers.get_content_type() + extension = mimetypes.guess_extension(mime_type) + return ("file" + extension) if extension else "file" + def _len_bytes(s: str, encoding: str = "utf-8") -> int: return len(s.encode(encoding)) From 22b70b4ff2532390a566c3eb59745c0394962dbe Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Thu, 28 Nov 2024 12:08:08 +0000 Subject: [PATCH 12/16] Dynamically add self.log onto the Predictor instance This copies the functionality over from the `async` branch and allows us to test models on production. This will emit a `DeprecationWarning`. Users will need to add `warnings.filterwarnings("once", DeprecationWarning)` to their code to see the error. --- python/cog/server/worker.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index c22f2fbd62..de95c432a8 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -11,6 +11,7 @@ import traceback import types import uuid +import warnings from concurrent.futures import Future, ThreadPoolExecutor from enum import Enum, auto, unique from multiprocessing.connection import Connection @@ -396,6 +397,7 @@ def run(self) -> None: # it has sent a error Done event and we're done here. if not self._predictor: return + self._predictor.log = self._log predict = get_predict(self._predictor) if self._is_async: @@ -727,6 +729,18 @@ def _stream_write_hook(self, stream_name: str, data: str) -> None: Envelope(event=Log(data, source="stderr"), tag=self._current_tag) ) + def _log(self, *messages: str, source: str = "stderr") -> None: + """ + DEPRECATED: This function will be removed in a future version of cog. + """ + warnings.warn( + "log() is deprecated and will be removed in a future version. Use `print` or `logging` module instead", + category=DeprecationWarning, + stacklevel=1, + ) + file = sys.stdout if source == "stdout" else sys.stderr + print(*messages, file=file) + def make_worker( predictor_ref: str, From 43ee1881b817e50ed1d321debba959e5ced9855c Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Thu, 28 Nov 2024 12:44:31 +0000 Subject: [PATCH 13/16] Do not include trailing newline in log implementation --- python/cog/server/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index de95c432a8..e1470356a0 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -739,7 +739,7 @@ def _log(self, *messages: str, source: str = "stderr") -> None: stacklevel=1, ) file = sys.stdout if source == "stdout" else sys.stderr - print(*messages, file=file) + print(*messages, file=file, end="") def make_worker( From 5cdaf9b378fdf0eaaeb2367d91ba1ec4e36a09f2 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 29 Nov 2024 15:51:16 +0000 Subject: [PATCH 14/16] linting --- python/tests/server/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/server/conftest.py b/python/tests/server/conftest.py index 171ea3643f..4d5dfd8bb7 100644 --- a/python/tests/server/conftest.py +++ b/python/tests/server/conftest.py @@ -3,7 +3,7 @@ import threading import time from contextlib import ExitStack -from typing import Any, Dict, Optional, Sequence +from typing import Any, Dict, Optional, Sequence, Tuple from unittest import mock import pytest From 9623a668c42c3edb9d0cbbf7f5f3a691eca5af66 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 29 Nov 2024 16:03:17 +0000 Subject: [PATCH 15/16] Ignore type issue --- python/cog/server/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index e1470356a0..2cba7c3f7e 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -397,7 +397,7 @@ def run(self) -> None: # it has sent a error Done event and we're done here. if not self._predictor: return - self._predictor.log = self._log + self._predictor.log = self._log # type: ignore predict = get_predict(self._predictor) if self._is_async: From 6b7f18dec8c151f7d2c8e305b8d0ae025edd5258 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 29 Nov 2024 16:03:30 +0000 Subject: [PATCH 16/16] Add missing is_async flag to emit_metrics commit --- python/tests/server/test_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/server/test_worker.py b/python/tests/server/test_worker.py index 5acb532123..d7aa69cad1 100644 --- a/python/tests/server/test_worker.py +++ b/python/tests/server/test_worker.py @@ -91,7 +91,7 @@ }, ), ( - WorkerConfig("emit_metric_async", min_python=(3, 11)), + WorkerConfig("emit_metric_async", min_python=(3, 11), is_async=True), {"name": ST_NAMES}, { "foo": 123,