Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhancement - refactor SDK to support asynchronous operations #1043

Merged
merged 2 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 55 additions & 19 deletions agenta-cli/agenta/sdk/agenta_decorator.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
"""The code for the Agenta SDK"""
import argparse
import functools
import inspect
import os
import sys
import inspect
import argparse
import traceback
import functools
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple, List

import agenta
from fastapi import Body, FastAPI, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

import agenta
from .context import save_context
from .router import router as router
from .types import (
Expand Down Expand Up @@ -62,30 +62,31 @@ def entrypoint(func: Callable[..., Any]) -> Callable[..., Any]:
Returns:
Wrapped function for HTTP POST and terminal.
"""

endpoint_name = "generate"
func_signature = inspect.signature(func)
config_params = agenta.config.all()
ingestible_files = extract_ingestible_files(func_signature)

@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
async def wrapper(*args, **kwargs) -> Any:
func_params, api_config_params = split_kwargs(kwargs, config_params)
ingest_files(func_params, ingestible_files)
agenta.config.set(**api_config_params)
return execute_function(func, *args, **func_params)
return await execute_function(func, *args, **func_params)

@functools.wraps(func)
def wrapper_deployed(*args, **kwargs) -> Any:
async def wrapper_deployed(*args, **kwargs) -> Any:
func_params = {
k: v for k, v in kwargs.items() if k not in ["config", "environment"]
}
if "environment" in kwargs and kwargs["environment"] is not None:
agenta.config.pull(environment_name=kwargs["environment"])
elif "config" in kwargs and kwargs["config"] is not None:
agenta.config.pull(config_name=kwargs["config"])
else: # if no config is specified in the api call, we pull the default config
else:
agenta.config.pull(config_name="default")
return execute_function(func, *args, **func_params)
return await execute_function(func, *args, **func_params)

update_function_signature(wrapper, func_signature, config_params, ingestible_files)
route = f"/{endpoint_name}"
Expand All @@ -107,16 +108,19 @@ def wrapper_deployed(*args, **kwargs) -> Any:

if is_main_script(func):
handle_terminal_run(
func, func_signature.parameters, config_params, ingestible_files
func,
func_signature.parameters,
config_params,
ingestible_files,
)

return None


def extract_ingestible_files(
func_signature: inspect.Signature,
) -> Dict[str, inspect.Parameter]:
"""Extract parameters annotated as InFile from function signature."""

return {
name: param
for name, param in func_signature.parameters.items()
Expand All @@ -128,6 +132,7 @@ def split_kwargs(
kwargs: Dict[str, Any], config_params: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Split keyword arguments into function parameters and API configuration parameters."""

func_params = {k: v for k, v in kwargs.items() if k not in config_params}
api_config_params = {k: v for k, v in kwargs.items() if k in config_params}
return func_params, api_config_params
Expand All @@ -137,15 +142,27 @@ def ingest_files(
func_params: Dict[str, Any], ingestible_files: Dict[str, inspect.Parameter]
) -> None:
"""Ingest files specified in function parameters."""

for name in ingestible_files:
if name in func_params and func_params[name] is not None:
func_params[name] = ingest_file(func_params[name])


def execute_function(func: Callable[..., Any], *args, **func_params) -> Any:
async def execute_function(func: Callable[..., Any], *args, **func_params) -> Any:
"""Execute the function and handle any exceptions."""

try:
result = func(*args, **func_params)
"""Note: The following block is for backward compatibility.
It allows functions to work seamlessly whether they are synchronous or asynchronous.
For synchronous functions, it calls them directly, while for asynchronous functions,
it awaits their execution.
"""
is_coroutine_function = inspect.iscoroutinefunction(func)
if is_coroutine_function:
result = await func(*args, **func_params)
else:
result = func(*args, **func_params)

if isinstance(result, Context):
save_context(result)
return result
Expand All @@ -155,24 +172,41 @@ def execute_function(func: Callable[..., Any], *args, **func_params) -> Any:

def handle_exception(e: Exception) -> JSONResponse:
"""Handle exceptions and return a JSONResponse."""

traceback_str = traceback.format_exception(e, value=e, tb=e.__traceback__)
return JSONResponse(
status_code=500,
content={"error": str(e), "traceback": "".join(traceback_str)},
)


def update_wrapper_signature(wrapper: Callable[..., Any], updated_params: List):
"""
Updates the signature of a wrapper function with a new list of parameters.

Args:
wrapper (callable): A callable object, such as a function or a method, that requires a signature update.
updated_params (List[inspect.Parameter]): A list of `inspect.Parameter` objects representing the updated parameters
for the wrapper function.
"""

wrapper_signature = inspect.signature(wrapper)
wrapper_signature = wrapper_signature.replace(parameters=updated_params)
wrapper.__signature__ = wrapper_signature


def update_function_signature(
wrapper: Callable[..., Any],
func_signature: inspect.Signature,
config_params: Dict[str, Any],
ingestible_files: Dict[str, inspect.Parameter],
) -> None:
"""Update the function signature to include new parameters."""

updated_params = []
add_config_params_to_parser(updated_params, config_params)
add_func_params_to_parser(updated_params, func_signature, ingestible_files)
wrapper.__signature__ = func_signature.replace(parameters=updated_params)
update_wrapper_signature(wrapper, updated_params)


def update_deployed_function_signature(
Expand All @@ -195,7 +229,7 @@ def update_deployed_function_signature(
annotation=str,
)
)
wrapper.__signature__ = func_signature.replace(parameters=updated_params)
update_wrapper_signature(wrapper, updated_params)


def add_config_params_to_parser(
Expand Down Expand Up @@ -271,13 +305,15 @@ def handle_terminal_run(
Example:
handle_terminal_run(func_params=inspect.signature(my_function).parameters, config_params=config.all())
"""
parser = argparse.ArgumentParser()

# For required parameters, we add them as arguments
parser = argparse.ArgumentParser()
for name, param in func_params.items():
if name in ingestible_files:
parser.add_argument(name, type=str)
else:
parser.add_argument(name, type=param.annotation)

for name, param in config_params.items():
if type(param) is MultipleChoiceParam:
parser.add_argument(
Expand All @@ -295,7 +331,8 @@ def handle_terminal_run(

args = parser.parse_args()

# split the arg list into the arg in the app_param and the arge from the sig.parameter
# split the arg list into the arg in the app_param and
# the args from the sig.parameter
args_config_params = {k: v for k, v in vars(args).items() if k in config_params}
args_func_params = {k: v for k, v in vars(args).items() if k not in config_params}
for name in ingestible_files:
Expand All @@ -304,7 +341,6 @@ def handle_terminal_run(
file_path=args_func_params[name],
)
agenta.config.set(**args_config_params)
# print(func(**args_func_params))


def override_schema(openapi_schema: dict, func_name: str, endpoint: str, params: dict):
Expand Down
36 changes: 36 additions & 0 deletions examples/async_startup_technical_ideas/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import agenta as ag
from agenta import FloatParam, MessagesInput, MultipleChoiceParam
from openai import AsyncOpenAI


client = AsyncOpenAI()

SYSTEM_PROMPT = "You have expertise in offering technical ideas to startups."
CHAT_LLM_GPT = [
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4",
]

ag.init()
ag.config.default(
temperature=FloatParam(0.2),
model=MultipleChoiceParam("gpt-3.5-turbo", CHAT_LLM_GPT),
max_tokens=ag.IntParam(-1, -1, 4000),
prompt_system=ag.TextParam(SYSTEM_PROMPT),
)


@ag.entrypoint
async def chat(inputs: MessagesInput = MessagesInput()) -> str:
messages = [{"role": "system", "content": ag.config.prompt_system}] + inputs
max_tokens = ag.config.max_tokens if ag.config.max_tokens != -1 else None
chat_completion = await client.chat.completions.create(
model=ag.config.model,
messages=messages,
temperature=ag.config.temperature,
max_tokens=max_tokens,
)
return chat_completion.choices[0].message.content
2 changes: 2 additions & 0 deletions examples/async_startup_technical_ideas/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
agenta
openai
Loading