From deb394570e1f0b0d047a582743fa6de582fa424b Mon Sep 17 00:00:00 2001 From: Abram Date: Tue, 12 Dec 2023 17:09:14 +0100 Subject: [PATCH] Update - refactor SDK to support asynchronous operations while maintaining backward compatibility for synchronous use --- agenta-cli/agenta/sdk/agenta_decorator.py | 74 +++++++++++++++++------ 1 file changed, 55 insertions(+), 19 deletions(-) diff --git a/agenta-cli/agenta/sdk/agenta_decorator.py b/agenta-cli/agenta/sdk/agenta_decorator.py index e132b084a1..d47a45f719 100644 --- a/agenta-cli/agenta/sdk/agenta_decorator.py +++ b/agenta-cli/agenta/sdk/agenta_decorator.py @@ -1,19 +1,19 @@ """The code for the Agenta SDK""" -import argparse -import functools -import inspect import os import sys +import inspect +import argparse import traceback +import functools from pathlib import Path from tempfile import NamedTemporaryFile -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple, List -import agenta from fastapi import Body, FastAPI, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse +import agenta from .context import save_context from .router import router as router from .types import ( @@ -62,20 +62,21 @@ def entrypoint(func: Callable[..., Any]) -> Callable[..., Any]: Returns: Wrapped function for HTTP POST and terminal. """ + endpoint_name = "generate" func_signature = inspect.signature(func) config_params = agenta.config.all() ingestible_files = extract_ingestible_files(func_signature) @functools.wraps(func) - def wrapper(*args, **kwargs) -> Any: + async def wrapper(*args, **kwargs) -> Any: func_params, api_config_params = split_kwargs(kwargs, config_params) ingest_files(func_params, ingestible_files) agenta.config.set(**api_config_params) - return execute_function(func, *args, **func_params) + return await execute_function(func, *args, **func_params) @functools.wraps(func) - def wrapper_deployed(*args, **kwargs) -> Any: + async def wrapper_deployed(*args, **kwargs) -> Any: func_params = { k: v for k, v in kwargs.items() if k not in ["config", "environment"] } @@ -83,9 +84,9 @@ def wrapper_deployed(*args, **kwargs) -> Any: agenta.config.pull(environment_name=kwargs["environment"]) elif "config" in kwargs and kwargs["config"] is not None: agenta.config.pull(config_name=kwargs["config"]) - else: # if no config is specified in the api call, we pull the default config + else: agenta.config.pull(config_name="default") - return execute_function(func, *args, **func_params) + return await execute_function(func, *args, **func_params) update_function_signature(wrapper, func_signature, config_params, ingestible_files) route = f"/{endpoint_name}" @@ -107,9 +108,11 @@ def wrapper_deployed(*args, **kwargs) -> Any: if is_main_script(func): handle_terminal_run( - func, func_signature.parameters, config_params, ingestible_files + func, + func_signature.parameters, + config_params, + ingestible_files, ) - return None @@ -117,6 +120,7 @@ def extract_ingestible_files( func_signature: inspect.Signature, ) -> Dict[str, inspect.Parameter]: """Extract parameters annotated as InFile from function signature.""" + return { name: param for name, param in func_signature.parameters.items() @@ -128,6 +132,7 @@ def split_kwargs( kwargs: Dict[str, Any], config_params: Dict[str, Any] ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Split keyword arguments into function parameters and API configuration parameters.""" + func_params = {k: v for k, v in kwargs.items() if k not in config_params} api_config_params = {k: v for k, v in kwargs.items() if k in config_params} return func_params, api_config_params @@ -137,15 +142,27 @@ def ingest_files( func_params: Dict[str, Any], ingestible_files: Dict[str, inspect.Parameter] ) -> None: """Ingest files specified in function parameters.""" + for name in ingestible_files: if name in func_params and func_params[name] is not None: func_params[name] = ingest_file(func_params[name]) -def execute_function(func: Callable[..., Any], *args, **func_params) -> Any: +async def execute_function(func: Callable[..., Any], *args, **func_params) -> Any: """Execute the function and handle any exceptions.""" + try: - result = func(*args, **func_params) + """Note: The following block is for backward compatibility. + It allows functions to work seamlessly whether they are synchronous or asynchronous. + For synchronous functions, it calls them directly, while for asynchronous functions, + it awaits their execution. + """ + is_coroutine_function = inspect.iscoroutinefunction(func) + if is_coroutine_function: + result = await func(*args, **func_params) + else: + result = func(*args, **func_params) + if isinstance(result, Context): save_context(result) return result @@ -155,6 +172,7 @@ def execute_function(func: Callable[..., Any], *args, **func_params) -> Any: def handle_exception(e: Exception) -> JSONResponse: """Handle exceptions and return a JSONResponse.""" + traceback_str = traceback.format_exception(e, value=e, tb=e.__traceback__) return JSONResponse( status_code=500, @@ -162,6 +180,21 @@ def handle_exception(e: Exception) -> JSONResponse: ) +def update_wrapper_signature(wrapper: Callable[..., Any], updated_params: List): + """ + Updates the signature of a wrapper function with a new list of parameters. + + Args: + wrapper (callable): A callable object, such as a function or a method, that requires a signature update. + updated_params (List[inspect.Parameter]): A list of `inspect.Parameter` objects representing the updated parameters + for the wrapper function. + """ + + wrapper_signature = inspect.signature(wrapper) + wrapper_signature = wrapper_signature.replace(parameters=updated_params) + wrapper.__signature__ = wrapper_signature + + def update_function_signature( wrapper: Callable[..., Any], func_signature: inspect.Signature, @@ -169,10 +202,11 @@ def update_function_signature( ingestible_files: Dict[str, inspect.Parameter], ) -> None: """Update the function signature to include new parameters.""" + updated_params = [] add_config_params_to_parser(updated_params, config_params) add_func_params_to_parser(updated_params, func_signature, ingestible_files) - wrapper.__signature__ = func_signature.replace(parameters=updated_params) + update_wrapper_signature(wrapper, updated_params) def update_deployed_function_signature( @@ -195,7 +229,7 @@ def update_deployed_function_signature( annotation=str, ) ) - wrapper.__signature__ = func_signature.replace(parameters=updated_params) + update_wrapper_signature(wrapper, updated_params) def add_config_params_to_parser( @@ -271,13 +305,15 @@ def handle_terminal_run( Example: handle_terminal_run(func_params=inspect.signature(my_function).parameters, config_params=config.all()) """ - parser = argparse.ArgumentParser() + # For required parameters, we add them as arguments + parser = argparse.ArgumentParser() for name, param in func_params.items(): if name in ingestible_files: parser.add_argument(name, type=str) else: parser.add_argument(name, type=param.annotation) + for name, param in config_params.items(): if type(param) is MultipleChoiceParam: parser.add_argument( @@ -295,7 +331,8 @@ def handle_terminal_run( args = parser.parse_args() - # split the arg list into the arg in the app_param and the arge from the sig.parameter + # split the arg list into the arg in the app_param and + # the args from the sig.parameter args_config_params = {k: v for k, v in vars(args).items() if k in config_params} args_func_params = {k: v for k, v in vars(args).items() if k not in config_params} for name in ingestible_files: @@ -304,7 +341,6 @@ def handle_terminal_run( file_path=args_func_params[name], ) agenta.config.set(**args_config_params) - # print(func(**args_func_params)) def override_schema(openapi_schema: dict, func_name: str, endpoint: str, params: dict):