diff --git a/python/cog/predictor.py b/python/cog/predictor.py index 45056e8be8..9f0b2ad778 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -8,6 +8,7 @@ import uuid from abc import ABC, abstractmethod from collections.abc import Iterator +from inspect import Signature from pathlib import Path from typing import ( Any, @@ -267,7 +268,7 @@ def validate_input_type( type: Type[Any], # pylint: disable=redefined-builtin name: str, ) -> None: - if type is inspect.Signature.empty: + if type is Signature.empty: raise TypeError( f"No input type provided for parameter `{name}`. Supported input types are: {readable_types_list(ALLOWED_INPUT_TYPES)}, or a Union or List of those types." ) @@ -284,7 +285,7 @@ def validate_input_type( def get_input_create_model_kwargs( - signature: inspect.Signature, input_types: Dict[str, Any] + signature: Signature, input_types: Dict[str, Any] ) -> Dict[str, Any]: create_model_kwargs = {} @@ -296,7 +297,7 @@ def get_input_create_model_kwargs( validate_input_type(InputType, name) # if no default is specified, create an empty, required input - if parameter.default is inspect.Signature.empty: + if parameter.default is Signature.empty: default = Input() else: default = parameter.default @@ -374,14 +375,11 @@ class Input(BaseModel): ) # type: ignore -def get_return_annotation(fn: Callable[..., Any]) -> Optional[Type[Any]]: +def get_return_annotation(fn: Callable[..., Any]) -> "Type[Any]|Signature.empty": try: - return get_type_hints(fn).get("return", None) + return get_type_hints(fn).get("return", Signature.empty) except TypeError: - return_annotation = inspect.signature(fn).return_annotation - if return_annotation is inspect.Signature.empty: - return None - return return_annotation + return inspect.signature(fn).return_annotation def get_output_type(predictor: BasePredictor) -> Type[BaseModel]: @@ -392,7 +390,7 @@ def get_output_type(predictor: BasePredictor) -> Type[BaseModel]: predict = get_predict(predictor) maybe_output_type = get_return_annotation(predict) - if maybe_output_type is None: + if maybe_output_type is Signature.empty: raise TypeError( """You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type. @@ -496,7 +494,7 @@ def get_training_output_type(predictor: BasePredictor) -> Type[BaseModel]: train = get_train(predictor) TrainingOutputType = get_return_annotation(train) # pylint: disable=invalid-name - if not TrainingOutputType: + if TrainingOutputType is Signature.empty: raise TypeError( """You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type. diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index 31e1a1085b..2bb9295661 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -299,6 +299,7 @@ def test_predict_works_with_deferred_annotations(): timeout=DEFAULT_TIMEOUT, ) + def test_predict_works_with_partial_wrapper(): project_dir = Path(__file__).parent / "fixtures/partial-predict-project"