From e9356b3aa857d1508807fff78228c87765b3a8d0 Mon Sep 17 00:00:00 2001 From: Abram Date: Sun, 10 Dec 2023 14:27:21 +0100 Subject: [PATCH] Update - improve execute_function and entrypoint functions --- agenta-cli/agenta/sdk/agenta_decorator.py | 45 +++++++++++++++++------ 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/agenta-cli/agenta/sdk/agenta_decorator.py b/agenta-cli/agenta/sdk/agenta_decorator.py index e132b084a1..f9893dac67 100644 --- a/agenta-cli/agenta/sdk/agenta_decorator.py +++ b/agenta-cli/agenta/sdk/agenta_decorator.py @@ -4,16 +4,17 @@ import inspect import os import sys +import time import traceback from pathlib import Path from tempfile import NamedTemporaryFile -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple, Union, TypeVar -import agenta +from fastapi.responses import JSONResponse from fastapi import Body, FastAPI, UploadFile from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse +import agenta from .context import save_context from .router import router as router from .types import ( @@ -26,9 +27,11 @@ TextParam, MessagesInput, FileInputURL, + FuncResponse, ) app = FastAPI() +T = TypeVar("T") origins = [ "*", @@ -52,7 +55,7 @@ def ingest_file(upfile: UploadFile): return InFile(file_name=upfile.filename, file_path=temp_file.name) -def entrypoint(func: Callable[..., Any]) -> Callable[..., Any]: +def entrypoint(func: Callable[..., T]) -> Callable[..., Dict[str, T]]: """ Decorator to wrap a function for HTTP POST and terminal exposure. @@ -68,14 +71,14 @@ def entrypoint(func: Callable[..., Any]) -> Callable[..., Any]: ingestible_files = extract_ingestible_files(func_signature) @functools.wraps(func) - def wrapper(*args, **kwargs) -> Any: + def wrapper(*args, **kwargs) -> Dict[str, T]: func_params, api_config_params = split_kwargs(kwargs, config_params) ingest_files(func_params, ingestible_files) agenta.config.set(**api_config_params) return execute_function(func, *args, **func_params) @functools.wraps(func) - def wrapper_deployed(*args, **kwargs) -> Any: + def wrapper_deployed(*args, **kwargs) -> FuncResponse: func_params = { k: v for k, v in kwargs.items() if k not in ["config", "environment"] } @@ -89,7 +92,7 @@ def wrapper_deployed(*args, **kwargs) -> Any: update_function_signature(wrapper, func_signature, config_params, ingestible_files) route = f"/{endpoint_name}" - app.post(route)(wrapper) + app.post(route, response_model=FuncResponse)(wrapper) update_deployed_function_signature( wrapper_deployed, @@ -97,7 +100,7 @@ def wrapper_deployed(*args, **kwargs) -> Any: ingestible_files, ) route_deployed = f"/{endpoint_name}_deployed" - app.post(route_deployed)(wrapper_deployed) + app.post(route_deployed, response_model=FuncResponse)(wrapper_deployed) override_schema( openapi_schema=app.openapi(), func_name=func.__name__, @@ -142,13 +145,33 @@ def ingest_files( func_params[name] = ingest_file(func_params[name]) -def execute_function(func: Callable[..., Any], *args, **func_params) -> Any: - """Execute the function and handle any exceptions.""" +def execute_function( + func: Callable[..., Any], *args, **func_params +) -> Union[Dict[str, Any], JSONResponse]: + """ + Execute the given function and handle any exceptions. + + Parameters: + - func: The function to be executed. + - args: Positional arguments for the function. + - func_params: Keyword arguments for the function. + + Returns: + Either a dictionary or a JSONResponse object. + """ + try: + start_time = time.time() result = func(*args, **func_params) + end_time = time.time() + latency = end_time - start_time + if isinstance(result, Context): save_context(result) - return result + if isinstance(result, FuncResponse): + return FuncResponse(**result, latency=str(latency)).dict() + if isinstance(result, str): + return FuncResponse(message=result, latency=str(latency)).dict() except Exception as e: return handle_exception(e)