Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

End-to-end support for concurrent async models #2066

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/cli/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func buildCommand(cmd *cobra.Command, args []string) error {
imageName = config.DockerImageName(projectDir)
}

err = config.ValidateModelPythonVersion(cfg.Build.PythonVersion)
err = config.ValidateModelPythonVersion(cfg)
if err != nil {
return err
}
Expand Down
28 changes: 20 additions & 8 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ var (
// TODO(andreas): suggest valid torchvision versions (e.g. if the user wants to use 0.8.0, suggest 0.8.1)

const (
MinimumMajorPythonVersion int = 3
MinimumMinorPythonVersion int = 8
MinimumMajorCudaVersion int = 11
MinimumMajorPythonVersion int = 3
MinimumMinorPythonVersion int = 8
MinimumMinorPythonVersionForConcurrency int = 11
MinimumMajorCudaVersion int = 11
)

type RunItem struct {
Expand All @@ -58,16 +59,21 @@ type Build struct {
pythonRequirementsContent []string
}

// Concurrency holds the concurrency settings from cog.yaml.
type Concurrency struct {
	// Max is the maximum number of predictions the model is willing to
	// process at the same time.
	Max int `json:"max,omitempty" yaml:"max"`
}

// Example pairs a sample prediction input with its expected output.
type Example struct {
	// Input maps input field names to their example values.
	Input map[string]string `json:"input" yaml:"input"`
	// Output is the output produced for Input.
	Output string `json:"output" yaml:"output"`
}

// Config is the top-level model configuration parsed from cog.yaml.
//
// NOTE: the diff rendering duplicated the field list (old and new copies of
// Build/Image/Predict/Train); a Go struct cannot declare the same field
// twice, so only the post-change set — which adds Concurrency — is kept.
type Config struct {
	// Build holds the image build settings (Python version, GPU, etc.).
	Build *Build `json:"build" yaml:"build"`
	// Image is the Docker image name, when set explicitly.
	Image string `json:"image,omitempty" yaml:"image"`
	// Predict identifies the prediction entry point.
	Predict string `json:"predict,omitempty" yaml:"predict"`
	// Train identifies the training entry point.
	Train string `json:"train,omitempty" yaml:"train"`
	// Concurrency, when non-nil, configures concurrent prediction limits.
	Concurrency *Concurrency `json:"concurrency,omitempty" yaml:"concurrency"`
}

func DefaultConfig() *Config {
Expand Down Expand Up @@ -244,7 +250,9 @@ func splitPythonVersion(version string) (major int, minor int, err error) {
return major, minor, nil
}

func ValidateModelPythonVersion(version string) error {
func ValidateModelPythonVersion(cfg *Config) error {
version := cfg.Build.PythonVersion

// we check for minimum supported here
major, minor, err := splitPythonVersion(version)
if err != nil {
Expand All @@ -255,6 +263,10 @@ func ValidateModelPythonVersion(version string) error {
return fmt.Errorf("minimum supported Python version is %d.%d. requested %s",
MinimumMajorPythonVersion, MinimumMinorPythonVersion, version)
}
if cfg.Concurrency != nil && cfg.Concurrency.Max > 1 && minor < MinimumMinorPythonVersionForConcurrency {
return fmt.Errorf("when concurrency.max is set, minimum supported Python version is %d.%d. requested %s",
MinimumMajorPythonVersion, MinimumMinorPythonVersionForConcurrency, version)
}
return nil
}

Expand Down
80 changes: 45 additions & 35 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,47 +13,68 @@ import (

func TestValidateModelPythonVersion(t *testing.T) {
testCases := []struct {
name string
input string
expectedErr bool
name string
pythonVersion string
concurrencyMax int
expectedErr string
}{
{
name: "ValidVersion",
input: "3.12",
expectedErr: false,
name: "ValidVersion",
pythonVersion: "3.12",
},
{
name: "MinimumVersion",
input: "3.8",
expectedErr: false,
name: "MinimumVersion",
pythonVersion: "3.8",
},
{
name: "FullyQualifiedVersion",
input: "3.12.1",
expectedErr: false,
name: "MinimumVersionForConcurrency",
pythonVersion: "3.11",
concurrencyMax: 5,
},
{
name: "InvalidFormat",
input: "3-12",
expectedErr: true,
name: "TooOldForConcurrency",
pythonVersion: "3.8",
concurrencyMax: 5,
expectedErr: "when concurrency.max is set, minimum supported Python version is 3.11. requested 3.8",
},
{
name: "InvalidMissingMinor",
input: "3",
expectedErr: true,
name: "FullyQualifiedVersion",
pythonVersion: "3.12.1",
},
{
name: "LessThanMinimum",
input: "3.7",
expectedErr: true,
name: "InvalidFormat",
pythonVersion: "3-12",
expectedErr: "invalid Python version format: missing minor version in 3-12",
},
{
name: "InvalidMissingMinor",
pythonVersion: "3",
expectedErr: "invalid Python version format: missing minor version in 3",
},
{
name: "LessThanMinimum",
pythonVersion: "3.7",
expectedErr: "minimum supported Python version is 3.8. requested 3.7",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := ValidateModelPythonVersion(tc.input)
if tc.expectedErr {
require.Error(t, err)
cfg := &Config{
Build: &Build{
PythonVersion: tc.pythonVersion,
},
}
if tc.concurrencyMax != 0 {
// the Concurrency key is optional, only populate it if
// concurrencyMax is a non-default value
cfg.Concurrency = &Concurrency{
Max: tc.concurrencyMax,
}
}
err := ValidateModelPythonVersion(cfg)
if tc.expectedErr != "" {
require.ErrorContains(t, err, tc.expectedErr)
} else {
require.NoError(t, err)
}
Expand Down Expand Up @@ -649,17 +670,6 @@ func TestBlankBuild(t *testing.T) {
require.Equal(t, false, config.Build.GPU)
}

func TestModelPythonVersionValidation(t *testing.T) {
err := ValidateModelPythonVersion("3.8")
require.NoError(t, err)
err = ValidateModelPythonVersion("3.8.1")
require.NoError(t, err)
err = ValidateModelPythonVersion("3.7")
require.Equal(t, "minimum supported Python version is 3.8. requested 3.7", err.Error())
err = ValidateModelPythonVersion("3.7.1")
require.Equal(t, "minimum supported Python version is 3.8. requested 3.7.1", err.Error())
}

func TestSplitPinnedPythonRequirement(t *testing.T) {
testCases := []struct {
input string
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ tests = [
"numpy",
"pillow",
"pytest",
"pytest-asyncio",
"pytest-httpserver",
"pytest-timeout",
"pytest-xdist",
Expand Down Expand Up @@ -70,6 +71,9 @@ reportUnusedExpression = "warning"
[tool.pyright.defineConstant]
PYDANTIC_V2 = true

[tool.pytest.ini_options]
asyncio_default_fixture_loop_scope = "function"

[tool.setuptools]
include-package-data = false

Expand Down
7 changes: 7 additions & 0 deletions python/cog/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
COG_PREDICT_CODE_STRIP_ENV_VAR = "COG_PREDICT_CODE_STRIP"
COG_TRAIN_CODE_STRIP_ENV_VAR = "COG_TRAIN_CODE_STRIP"
COG_GPU_ENV_VAR = "COG_GPU"
COG_MAX_CONCURRENCY_ENV_VAR = "COG_MAX_CONCURRENCY"
PREDICT_METHOD_NAME = "predict"
TRAIN_METHOD_NAME = "train"

Expand Down Expand Up @@ -98,6 +99,12 @@ def requires_gpu(self) -> bool:
"""Whether this cog requires the use of a GPU."""
return bool(self._cog_config.get("build", {}).get("gpu", False))

@property
# NOTE(review): env_property is a project decorator; it presumably lets the
# COG_MAX_CONCURRENCY environment variable override the config value —
# confirm the decorator's semantics before relying on this.
@env_property(COG_MAX_CONCURRENCY_ENV_VAR)
def max_concurrency(self) -> int:
    """The maximum concurrency of predictions supported by this model. Defaults to 1."""
    # Reads cog.yaml's optional `concurrency.max` key; missing keys fall
    # back to 1 (no concurrent predictions).
    return int(self._cog_config.get("concurrency", {}).get("max", 1))

def _predictor_code(
self,
module_path: str,
Expand Down
23 changes: 13 additions & 10 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,11 @@ async def start_shutdown() -> Any:
add_setup_failed_routes(app, started_at, msg)
return app

worker = make_worker(predictor_ref=cog_config.get_predictor_ref(mode=mode))
runner = PredictionRunner(worker=worker)
worker = make_worker(
predictor_ref=cog_config.get_predictor_ref(mode=mode),
max_concurrency=cog_config.max_concurrency,
)
runner = PredictionRunner(worker=worker, max_concurrency=cog_config.max_concurrency)

class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType)):
pass
Expand Down Expand Up @@ -215,7 +218,7 @@ class TrainingRequest(
response_model=TrainingResponse,
response_model_exclude_unset=True,
)
def train(
async def train(
request: TrainingRequest = Body(default=None),
prefer: Optional[str] = Header(default=None),
traceparent: Optional[str] = Header(
Expand All @@ -228,7 +231,7 @@ def train(
respond_async = prefer == "respond-async"

with trace_context(make_trace_context(traceparent, tracestate)):
return _predict(
return await _predict(
request=request,
response_type=TrainingResponse,
respond_async=respond_async,
Expand All @@ -239,7 +242,7 @@ def train(
response_model=TrainingResponse,
response_model_exclude_unset=True,
)
def train_idempotent(
async def train_idempotent(
training_id: str = Path(..., title="Training ID"),
request: TrainingRequest = Body(..., title="Training Request"),
prefer: Optional[str] = Header(default=None),
Expand Down Expand Up @@ -276,7 +279,7 @@ def train_idempotent(
respond_async = prefer == "respond-async"

with trace_context(make_trace_context(traceparent, tracestate)):
return _predict(
return await _predict(
request=request,
response_type=TrainingResponse,
respond_async=respond_async,
Expand Down Expand Up @@ -355,7 +358,7 @@ async def predict(
respond_async = prefer == "respond-async"

with trace_context(make_trace_context(traceparent, tracestate)):
return _predict(
return await _predict(
request=request,
response_type=PredictionResponse,
respond_async=respond_async,
Expand Down Expand Up @@ -403,13 +406,13 @@ async def predict_idempotent(
respond_async = prefer == "respond-async"

with trace_context(make_trace_context(traceparent, tracestate)):
return _predict(
return await _predict(
request=request,
response_type=PredictionResponse,
respond_async=respond_async,
)

def _predict(
async def _predict(
*,
request: Optional[PredictionRequest],
response_type: Type[schema.PredictionResponse],
Expand Down Expand Up @@ -451,7 +454,7 @@ def _predict(
)

# Otherwise, wait for the prediction to complete...
predict_task.wait()
await predict_task.wait_async()

# ...and return the result.
if PYDANTIC_V2:
Expand Down
Loading
Loading