Skip to content

Commit

Permalink
fix defered annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
technillogue committed Aug 22, 2024
1 parent f594145 commit c83a8e4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 18 deletions.
40 changes: 25 additions & 15 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,7 @@ def get_input_create_model_kwargs(
order = 0

for name, parameter in signature.parameters.items():
if name not in input_types:
raise TypeError(f"No input type provided for parameter `{name}`.")

InputType = input_types[name] # pylint: disable=invalid-name
InputType = input_types.get(name, parameter.annotation)

validate_input_type(InputType, name)

Expand Down Expand Up @@ -360,7 +357,10 @@ class Input(BaseModel):
predict = get_predict(predictor)
signature = inspect.signature(predict)

input_types = get_type_hints(predict)
try:
input_types = get_type_hints(predict)
except TypeError:
input_types = {}
if "return" in input_types:
del input_types["return"]

Expand All @@ -374,16 +374,25 @@ class Input(BaseModel):
) # type: ignore


def get_return_annotation(fn: Callable[..., Any]) -> Optional[Type[Any]]:
try:
return get_type_hints(fn).get("return", None)
except TypeError:
return_annotation = inspect.signature(fn).return_annotation
if return_annotation is inspect.Signature.empty:
return None
return return_annotation


def get_output_type(predictor: BasePredictor) -> Type[BaseModel]:
"""
Creates a Pydantic Output model from the return type annotation of a Predictor's predict() method.
"""

predict = get_predict(predictor)
maybe_output_type = get_return_annotation(predict)

input_types = get_type_hints(predict)

if "return" not in input_types:
if maybe_output_type is None:
raise TypeError(
"""You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type.
Expand All @@ -398,8 +407,8 @@ def predict(
...
"""
)
OutputType = input_types.pop("return") # pylint: disable=invalid-name

# we need the indirection to narrow the type to the one that's declared later
OutputType = maybe_output_type # pylint: disable=invalid-name
# The type that goes in the response is a list of the yielded type
if get_origin(OutputType) is Iterator:
# Annotated allows us to attach Field annotations to the list, which we use to mark that this is an iterator
Expand Down Expand Up @@ -462,7 +471,10 @@ class TrainingInput(BaseModel):
train = get_train(predictor)
signature = inspect.signature(train)

input_types = get_type_hints(train)
try:
input_types = get_type_hints(train)
except TypeError:
input_types = {}
if "return" in input_types:
del input_types["return"]

Expand All @@ -483,8 +495,8 @@ def get_training_output_type(predictor: BasePredictor) -> Type[BaseModel]:

train = get_train(predictor)

input_types = get_type_hints(train)
if "return" not in input_types:
TrainingOutputType = get_return_annotation(train) # pylint: disable=invalid-name
if not TrainingOutputType:
raise TypeError(
"""You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type.
Expand All @@ -500,8 +512,6 @@ def train(
"""
)

TrainingOutputType = input_types.pop("return") # pylint: disable=invalid-name

name = (
TrainingOutputType.__name__ if hasattr(TrainingOutputType, "__name__") else ""
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable
import functools
import inspect
from typing import Any, Callable

from cog import BasePredictor, Input

Expand All @@ -10,7 +11,7 @@ def general(
) -> int:
return 1

def _remove(f: Callable, defaults: dict[str, Any]) -> Callable:
def _remove(f: Callable, defaults: "dict[str, Any]") -> Callable:
# pylint: disable=no-self-argument
def wrapper(self, *args, **kwargs):
kwargs.update(defaults)
Expand All @@ -31,7 +32,7 @@ def wrapper(self, *args, **kwargs):
predict = _remove(general, {"system_prompt": ""})


def _train(self, prompt: str = Input(description="hi"), system_prompt: str = None):
def _train(prompt: str = Input(description="hi"), system_prompt: str = None):
return 1


Expand Down

0 comments on commit c83a8e4

Please sign in to comment.