From 9c48ee2e90a917b569bffc4c8e37f1b14073a5da Mon Sep 17 00:00:00 2001 From: Abram Date: Sat, 25 May 2024 18:55:20 +0100 Subject: [PATCH] Refactor (sdk/agenta_decorator): update function_signature and override_schema to be Pydantic v2 compliance --- agenta-cli/agenta/sdk/agenta_decorator.py | 111 +++++++++++++++------- 1 file changed, 77 insertions(+), 34 deletions(-) diff --git a/agenta-cli/agenta/sdk/agenta_decorator.py b/agenta-cli/agenta/sdk/agenta_decorator.py index 4f5a4e75c7..4e20ff267a 100644 --- a/agenta-cli/agenta/sdk/agenta_decorator.py +++ b/agenta-cli/agenta/sdk/agenta_decorator.py @@ -57,7 +57,7 @@ def ingest_file(upfile: UploadFile): return InFile(file_name=upfile.filename, file_path=temp_file.name) -def entrypoint(func: Callable[..., Any]) -> Callable[..., Any]: +def entrypoint(func: Callable[..., Any]): """ Decorator to wrap a function for HTTP POST and terminal exposure. @@ -97,8 +97,8 @@ async def wrapper(*args, **kwargs) -> Any: # End trace recording tracing.end_recording( - outputs=llm_result.dict(), - span=tracing.active_trace, + outputs=llm_result.model_dump(), + span=tracing.active_trace, # type: ignore ) return llm_result @@ -130,8 +130,8 @@ async def wrapper_deployed(*args, **kwargs) -> Any: # End trace recording tracing.end_recording( - outputs=llm_result.dict(), - span=tracing.active_trace, + outputs=llm_result.model_dump(), + span=tracing.active_trace, # type: ignore ) return llm_result @@ -144,6 +144,7 @@ async def wrapper_deployed(*args, **kwargs) -> Any: func_signature, ingestible_files, ) + route_deployed = f"/{endpoint_name}_deployed" app.post(route_deployed, response_model=FuncResponse)(wrapper_deployed) override_schema( @@ -156,11 +157,10 @@ async def wrapper_deployed(*args, **kwargs) -> Any: if is_main_script(func): handle_terminal_run( func, - func_signature.parameters, + func_signature.parameters, # type: ignore config_params, ingestible_files, ) - return None def extract_ingestible_files( @@ -248,7 +248,7 @@ def update_wrapper_signature(wrapper: Callable[..., Any], updated_params: List): wrapper_signature = inspect.signature(wrapper) wrapper_signature = wrapper_signature.replace(parameters=updated_params) - wrapper.__signature__ = wrapper_signature + wrapper.__signature__ = wrapper_signature # type: ignore def update_function_signature( @@ -259,7 +259,7 @@ def update_function_signature( ) -> None: """Update the function signature to include new parameters.""" - updated_params = [] + updated_params: List[inspect.Parameter] = [] add_config_params_to_parser(updated_params, config_params) add_func_params_to_parser(updated_params, func_signature, ingestible_files) update_wrapper_signature(wrapper, updated_params) @@ -271,7 +271,7 @@ def update_deployed_function_signature( ingestible_files: Dict[str, inspect.Parameter], ) -> None: """Update the function signature to include new parameters.""" - updated_params = [] + updated_params: List[inspect.Parameter] = [] add_func_params_to_parser(updated_params, func_signature, ingestible_files) for param in [ "config", @@ -298,7 +298,11 @@ def add_config_params_to_parser( name, inspect.Parameter.KEYWORD_ONLY, default=Body(param), - annotation=Optional[type(param)], + annotation=param.__class__.__bases__[ + 0 + ], # determines and get the base (parent/inheritance) type of the sdk-type at run-time. \ + # E.g __class__ is ag.MessagesInput() and accessing it parent type will return (,), \ + # thus, why we are accessing the first item. ) ) @@ -319,8 +323,12 @@ def add_func_params_to_parser( inspect.Parameter( name, inspect.Parameter.KEYWORD_ONLY, - default=Body(..., embed=True), - annotation=param.annotation, + default=Body(param.default), + annotation=param.default.__class__.__bases__[ + 0 + ], # determines and get the base (parent/inheritance) type of the sdk-type at run-time. \ + # E.g __class__ is ag.MessagesInput() and accessing it parent type will return (,), \ + # thus, why we are accessing the first item. ) ) @@ -375,8 +383,8 @@ def handle_terminal_run( parser.add_argument( f"--{name}", type=str, - default=param.default, - choices=param.choices, + default=param.default, # type: ignore + choices=param.choices, # type: ignore ) else: parser.add_argument( @@ -426,7 +434,7 @@ def override_schema(openapi_schema: dict, func_name: str, endpoint: str, params: params (dict(param_name, param_val)): The dictionary of the parameters for the function """ - def find_in_schema(schema: dict, param_name: str, xparam: str): + def find_in_schema(schema_type_properties: dict, schema: dict, param_name: str, xparam: str): """Finds a parameter in the schema based on its name and x-parameter value""" for _, value in schema.items(): value_title_lower = str(value.get("title")).lower() @@ -438,9 +446,17 @@ def find_in_schema(schema: dict, param_name: str, xparam: str): if ( isinstance(value, dict) - and value.get("x-parameter") == xparam + and schema_type_properties.get("x-parameter") == xparam and value_title == param_name ): + # this will update the default type schema with the properties gotten + # from the schema type (param_val) __schema_properties__ classmethod + for type_key, type_value in schema_type_properties.items(): + # BEFORE: + # value = {'temperature': {'title': 'Temperature'}} + value[type_key] = type_value + # AFTER: + # value = {'temperature': { "type": "number", "title": "Temperature", "x-parameter": "float" }} return value schema_to_override = openapi_schema["components"]["schemas"][ @@ -448,16 +464,20 @@ def find_in_schema(schema: dict, param_name: str, xparam: str): ]["properties"] for param_name, param_val in params.items(): if isinstance(param_val, GroupedMultipleChoiceParam): - subschema = find_in_schema(schema_to_override, param_name, "grouped_choice") + subschema = find_in_schema( + param_val.__schema_type_properties__(), schema_to_override, param_name, "grouped_choice" + ) assert ( subschema ), f"GroupedMultipleChoiceParam '{param_name}' is in the parameters but could not be found in the openapi.json" - subschema["choices"] = param_val.choices - subschema["default"] = param_val.default + subschema["choices"] = param_val.choices # type: ignore + subschema["default"] = param_val.default # type: ignore if isinstance(param_val, MultipleChoiceParam): - subschema = find_in_schema(schema_to_override, param_name, "choice") + subschema = find_in_schema( + param_val.__schema_type_properties__(), schema_to_override, param_name, "choice" + ) default = str(param_val) - param_choices = param_val.choices + param_choices = param_val.choices # type: ignore choices = ( [default] + param_choices if param_val not in param_choices @@ -466,36 +486,59 @@ def find_in_schema(schema: dict, param_name: str, xparam: str): subschema["enum"] = choices subschema["default"] = default if default in param_choices else choices[0] if isinstance(param_val, FloatParam): - subschema = find_in_schema(schema_to_override, param_name, "float") - subschema["minimum"] = param_val.minval - subschema["maximum"] = param_val.maxval + subschema = find_in_schema( + param_val.__schema_type_properties__(), schema_to_override, param_name, "float" + ) + subschema["minimum"] = param_val.minval # type: ignore + subschema["maximum"] = param_val.maxval # type: ignore subschema["default"] = param_val if isinstance(param_val, IntParam): - subschema = find_in_schema(schema_to_override, param_name, "int") - subschema["minimum"] = param_val.minval - subschema["maximum"] = param_val.maxval + subschema = find_in_schema( + param_val.__schema_type_properties__(), schema_to_override, param_name, "int" + ) + subschema["minimum"] = param_val.minval # type: ignore + subschema["maximum"] = param_val.maxval # type: ignore subschema["default"] = param_val if ( isinstance(param_val, inspect.Parameter) and param_val.annotation is DictInput ): - subschema = find_in_schema(schema_to_override, param_name, "dict") + subschema = find_in_schema( + param_val.annotation.__schema_type_properties__(), + schema_to_override, + param_name, + "dict", + ) subschema["default"] = param_val.default["default_keys"] if isinstance(param_val, TextParam): - subschema = find_in_schema(schema_to_override, param_name, "text") + subschema = find_in_schema( + param_val.__schema_type_properties__(), schema_to_override, param_name, "text" + ) subschema["default"] = param_val if ( isinstance(param_val, inspect.Parameter) and param_val.annotation is MessagesInput ): - subschema = find_in_schema(schema_to_override, param_name, "messages") + subschema = find_in_schema( + param_val.annotation.__schema_type_properties__(), + schema_to_override, + param_name, + "messages", + ) subschema["default"] = param_val.default if ( isinstance(param_val, inspect.Parameter) and param_val.annotation is FileInputURL ): - subschema = find_in_schema(schema_to_override, param_name, "file_url") + subschema = find_in_schema( + param_val.annotation.__schema_type_properties__(), + schema_to_override, + param_name, + "file_url", + ) subschema["default"] = "https://example.com" if isinstance(param_val, BinaryParam): - subschema = find_in_schema(schema_to_override, param_name, "bool") - subschema["default"] = param_val.default + subschema = find_in_schema( + param_val.__schema_type_properties__(), schema_to_override, param_name, "bool" + ) + subschema["default"] = param_val.default # type: ignore