Skip to content

Commit

Permalink
tweaks to separate -> None from -> Signature.empty
Browse files Browse the repository at this point in the history
  • Loading branch information
technillogue committed Sep 13, 2024
1 parent 7720523 commit 7789385
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
20 changes: 9 additions & 11 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import uuid
from abc import ABC, abstractmethod
from collections.abc import Iterator
from inspect import Signature
from pathlib import Path
from typing import (
Any,
Expand Down Expand Up @@ -267,7 +268,7 @@ def validate_input_type(
type: Type[Any], # pylint: disable=redefined-builtin
name: str,
) -> None:
if type is inspect.Signature.empty:
if type is Signature.empty:
raise TypeError(
f"No input type provided for parameter `{name}`. Supported input types are: {readable_types_list(ALLOWED_INPUT_TYPES)}, or a Union or List of those types."
)
Expand All @@ -284,7 +285,7 @@ def validate_input_type(


def get_input_create_model_kwargs(
signature: inspect.Signature, input_types: Dict[str, Any]
signature: Signature, input_types: Dict[str, Any]
) -> Dict[str, Any]:
create_model_kwargs = {}

Expand All @@ -296,7 +297,7 @@ def get_input_create_model_kwargs(
validate_input_type(InputType, name)

# if no default is specified, create an empty, required input
if parameter.default is inspect.Signature.empty:
if parameter.default is Signature.empty:
default = Input()
else:
default = parameter.default
Expand Down Expand Up @@ -374,14 +375,11 @@ class Input(BaseModel):
) # type: ignore


def get_return_annotation(fn: Callable[..., Any]) -> Optional[Type[Any]]:
def get_return_annotation(fn: Callable[..., Any]) -> "Type[Any]|Signature.empty":
try:
return get_type_hints(fn).get("return", None)
return get_type_hints(fn).get("return", Signature.empty)
except TypeError:
return_annotation = inspect.signature(fn).return_annotation
if return_annotation is inspect.Signature.empty:
return None
return return_annotation
return inspect.signature(fn).return_annotation


def get_output_type(predictor: BasePredictor) -> Type[BaseModel]:
Expand All @@ -392,7 +390,7 @@ def get_output_type(predictor: BasePredictor) -> Type[BaseModel]:
predict = get_predict(predictor)
maybe_output_type = get_return_annotation(predict)

if maybe_output_type is None:
if maybe_output_type is Signature.empty:
raise TypeError(
"""You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type.
Expand Down Expand Up @@ -496,7 +494,7 @@ def get_training_output_type(predictor: BasePredictor) -> Type[BaseModel]:
train = get_train(predictor)

TrainingOutputType = get_return_annotation(train) # pylint: disable=invalid-name
if not TrainingOutputType:
if TrainingOutputType is Signature.empty:
raise TypeError(
"""You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type.
Expand Down
1 change: 1 addition & 0 deletions test-integration/test_integration/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def test_predict_works_with_deferred_annotations():
timeout=DEFAULT_TIMEOUT,
)


def test_predict_works_with_partial_wrapper():
project_dir = Path(__file__).parent / "fixtures/partial-predict-project"

Expand Down

0 comments on commit 7789385

Please sign in to comment.