Skip to content

Commit

Permalink
Refactor (sdk/agenta_decorator): update function_signature and overri…
Browse files Browse the repository at this point in the history
…de_schema to be Pydantic v2 compliance
  • Loading branch information
aybruhm committed May 25, 2024
1 parent 0b30431 commit 9c48ee2
Showing 1 changed file with 77 additions and 34 deletions.
111 changes: 77 additions & 34 deletions agenta-cli/agenta/sdk/agenta_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def ingest_file(upfile: UploadFile):
return InFile(file_name=upfile.filename, file_path=temp_file.name)


def entrypoint(func: Callable[..., Any]) -> Callable[..., Any]:
def entrypoint(func: Callable[..., Any]):
"""
Decorator to wrap a function for HTTP POST and terminal exposure.
Expand Down Expand Up @@ -97,8 +97,8 @@ async def wrapper(*args, **kwargs) -> Any:

# End trace recording
tracing.end_recording(
outputs=llm_result.dict(),
span=tracing.active_trace,
outputs=llm_result.model_dump(),
span=tracing.active_trace, # type: ignore
)
return llm_result

Expand Down Expand Up @@ -130,8 +130,8 @@ async def wrapper_deployed(*args, **kwargs) -> Any:

# End trace recording
tracing.end_recording(
outputs=llm_result.dict(),
span=tracing.active_trace,
outputs=llm_result.model_dump(),
span=tracing.active_trace, # type: ignore
)
return llm_result

Expand All @@ -144,6 +144,7 @@ async def wrapper_deployed(*args, **kwargs) -> Any:
func_signature,
ingestible_files,
)

route_deployed = f"/{endpoint_name}_deployed"
app.post(route_deployed, response_model=FuncResponse)(wrapper_deployed)
override_schema(
Expand All @@ -156,11 +157,10 @@ async def wrapper_deployed(*args, **kwargs) -> Any:
if is_main_script(func):
handle_terminal_run(
func,
func_signature.parameters,
func_signature.parameters, # type: ignore
config_params,
ingestible_files,
)
return None


def extract_ingestible_files(
Expand Down Expand Up @@ -248,7 +248,7 @@ def update_wrapper_signature(wrapper: Callable[..., Any], updated_params: List):

wrapper_signature = inspect.signature(wrapper)
wrapper_signature = wrapper_signature.replace(parameters=updated_params)
wrapper.__signature__ = wrapper_signature
wrapper.__signature__ = wrapper_signature # type: ignore


def update_function_signature(
Expand All @@ -259,7 +259,7 @@ def update_function_signature(
) -> None:
"""Update the function signature to include new parameters."""

updated_params = []
updated_params: List[inspect.Parameter] = []
add_config_params_to_parser(updated_params, config_params)
add_func_params_to_parser(updated_params, func_signature, ingestible_files)
update_wrapper_signature(wrapper, updated_params)
Expand All @@ -271,7 +271,7 @@ def update_deployed_function_signature(
ingestible_files: Dict[str, inspect.Parameter],
) -> None:
"""Update the function signature to include new parameters."""
updated_params = []
updated_params: List[inspect.Parameter] = []
add_func_params_to_parser(updated_params, func_signature, ingestible_files)
for param in [
"config",
Expand All @@ -298,7 +298,11 @@ def add_config_params_to_parser(
name,
inspect.Parameter.KEYWORD_ONLY,
default=Body(param),
annotation=Optional[type(param)],
annotation=param.__class__.__bases__[
0
], # determines and get the base (parent/inheritance) type of the sdk-type at run-time. \
# E.g __class__ is ag.MessagesInput() and accessing it parent type will return (<class 'list'>,), \
# thus, why we are accessing the first item.
)
)

Expand All @@ -319,8 +323,12 @@ def add_func_params_to_parser(
inspect.Parameter(
name,
inspect.Parameter.KEYWORD_ONLY,
default=Body(..., embed=True),
annotation=param.annotation,
default=Body(param.default),
annotation=param.default.__class__.__bases__[
0
], # determines and get the base (parent/inheritance) type of the sdk-type at run-time. \
# E.g __class__ is ag.MessagesInput() and accessing it parent type will return (<class 'list'>,), \
# thus, why we are accessing the first item.
)
)

Expand Down Expand Up @@ -375,8 +383,8 @@ def handle_terminal_run(
parser.add_argument(
f"--{name}",
type=str,
default=param.default,
choices=param.choices,
default=param.default, # type: ignore
choices=param.choices, # type: ignore
)
else:
parser.add_argument(
Expand Down Expand Up @@ -426,7 +434,7 @@ def override_schema(openapi_schema: dict, func_name: str, endpoint: str, params:
params (dict(param_name, param_val)): The dictionary of the parameters for the function
"""

def find_in_schema(schema: dict, param_name: str, xparam: str):
def find_in_schema(schema_type_properties: dict, schema: dict, param_name: str, xparam: str):
"""Finds a parameter in the schema based on its name and x-parameter value"""
for _, value in schema.items():
value_title_lower = str(value.get("title")).lower()
Expand All @@ -438,26 +446,38 @@ def find_in_schema(schema: dict, param_name: str, xparam: str):

if (
isinstance(value, dict)
and value.get("x-parameter") == xparam
and schema_type_properties.get("x-parameter") == xparam
and value_title == param_name
):
# this will update the default type schema with the properties gotten
# from the schema type (param_val) __schema_properties__ classmethod
for type_key, type_value in schema_type_properties.items():
# BEFORE:
# value = {'temperature': {'title': 'Temperature'}}
value[type_key] = type_value
# AFTER:
# value = {'temperature': { "type": "number", "title": "Temperature", "x-parameter": "float" }}
return value

schema_to_override = openapi_schema["components"]["schemas"][
f"Body_{func_name}_{endpoint}_post"
]["properties"]
for param_name, param_val in params.items():
if isinstance(param_val, GroupedMultipleChoiceParam):
subschema = find_in_schema(schema_to_override, param_name, "grouped_choice")
subschema = find_in_schema(
param_val.__schema_type_properties__(), schema_to_override, param_name, "grouped_choice"
)
assert (
subschema
), f"GroupedMultipleChoiceParam '{param_name}' is in the parameters but could not be found in the openapi.json"
subschema["choices"] = param_val.choices
subschema["default"] = param_val.default
subschema["choices"] = param_val.choices # type: ignore
subschema["default"] = param_val.default # type: ignore
if isinstance(param_val, MultipleChoiceParam):
subschema = find_in_schema(schema_to_override, param_name, "choice")
subschema = find_in_schema(
param_val.__schema_type_properties__(), schema_to_override, param_name, "choice"
)
default = str(param_val)
param_choices = param_val.choices
param_choices = param_val.choices # type: ignore
choices = (
[default] + param_choices
if param_val not in param_choices
Expand All @@ -466,36 +486,59 @@ def find_in_schema(schema: dict, param_name: str, xparam: str):
subschema["enum"] = choices
subschema["default"] = default if default in param_choices else choices[0]
if isinstance(param_val, FloatParam):
subschema = find_in_schema(schema_to_override, param_name, "float")
subschema["minimum"] = param_val.minval
subschema["maximum"] = param_val.maxval
subschema = find_in_schema(
param_val.__schema_type_properties__(), schema_to_override, param_name, "float"
)
subschema["minimum"] = param_val.minval # type: ignore
subschema["maximum"] = param_val.maxval # type: ignore
subschema["default"] = param_val
if isinstance(param_val, IntParam):
subschema = find_in_schema(schema_to_override, param_name, "int")
subschema["minimum"] = param_val.minval
subschema["maximum"] = param_val.maxval
subschema = find_in_schema(
param_val.__schema_type_properties__(), schema_to_override, param_name, "int"
)
subschema["minimum"] = param_val.minval # type: ignore
subschema["maximum"] = param_val.maxval # type: ignore
subschema["default"] = param_val
if (
isinstance(param_val, inspect.Parameter)
and param_val.annotation is DictInput
):
subschema = find_in_schema(schema_to_override, param_name, "dict")
subschema = find_in_schema(
param_val.annotation.__schema_type_properties__(),
schema_to_override,
param_name,
"dict",
)
subschema["default"] = param_val.default["default_keys"]
if isinstance(param_val, TextParam):
subschema = find_in_schema(schema_to_override, param_name, "text")
subschema = find_in_schema(
param_val.__schema_type_properties__(), schema_to_override, param_name, "text"
)
subschema["default"] = param_val
if (
isinstance(param_val, inspect.Parameter)
and param_val.annotation is MessagesInput
):
subschema = find_in_schema(schema_to_override, param_name, "messages")
subschema = find_in_schema(
param_val.annotation.__schema_type_properties__(),
schema_to_override,
param_name,
"messages",
)
subschema["default"] = param_val.default
if (
isinstance(param_val, inspect.Parameter)
and param_val.annotation is FileInputURL
):
subschema = find_in_schema(schema_to_override, param_name, "file_url")
subschema = find_in_schema(
param_val.annotation.__schema_type_properties__(),
schema_to_override,
param_name,
"file_url",
)
subschema["default"] = "https://example.com"
if isinstance(param_val, BinaryParam):
subschema = find_in_schema(schema_to_override, param_name, "bool")
subschema["default"] = param_val.default
subschema = find_in_schema(
param_val.__schema_type_properties__(), schema_to_override, param_name, "bool"
)
subschema["default"] = param_val.default # type: ignore

0 comments on commit 9c48ee2

Please sign in to comment.