Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhancement - Improve SDK output format #1025

Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions agenta-cli/agenta/sdk/agenta_decorator.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
"""The code for the Agenta SDK"""
import os
import sys
import time
import inspect
import argparse
import traceback
import functools
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, Optional, Tuple, List
from typing import Any, Callable, Dict, Optional, Tuple, List, Union

from fastapi import Body, FastAPI, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware

import agenta
from .context import save_context
Expand All @@ -26,6 +27,7 @@
TextParam,
MessagesInput,
FileInputURL,
FuncResponse,
)

app = FastAPI()
Expand Down Expand Up @@ -90,15 +92,15 @@ async def wrapper_deployed(*args, **kwargs) -> Any:

update_function_signature(wrapper, func_signature, config_params, ingestible_files)
route = f"/{endpoint_name}"
app.post(route)(wrapper)
app.post(route, response_model=FuncResponse)(wrapper)

update_deployed_function_signature(
wrapper_deployed,
func_signature,
ingestible_files,
)
route_deployed = f"/{endpoint_name}_deployed"
app.post(route_deployed)(wrapper_deployed)
app.post(route_deployed, response_model=FuncResponse)(wrapper_deployed)
override_schema(
openapi_schema=app.openapi(),
func_name=func.__name__,
Expand Down Expand Up @@ -148,7 +150,9 @@ def ingest_files(
func_params[name] = ingest_file(func_params[name])


async def execute_function(func: Callable[..., Any], *args, **func_params) -> Any:
async def execute_function(
    func: Callable[..., Any], *args, **func_params
) -> Union[Dict[str, Any], JSONResponse]:
    """Execute the wrapped app function, time it, and normalize its output.

    Args:
        func: The user's app function (sync or async).
        *args: Positional arguments forwarded to ``func``.
        **func_params: Keyword arguments forwarded to ``func``.

    Returns:
        A ``FuncResponse``-shaped dict (with ``latency`` attached) when
        ``func`` returns a ``dict`` or ``str``, the raw ``Context`` when it
        returns one, or a ``JSONResponse`` describing the error when ``func``
        raises.
    """
    try:
        # Time the call itself.  Both branches are measured identically;
        # the only difference is whether the result must be awaited.
        start_time = time.perf_counter()
        if inspect.iscoroutinefunction(func):
            result = await func(*args, **func_params)
        else:
            result = func(*args, **func_params)
        latency = time.perf_counter() - start_time

        if isinstance(result, Context):
            save_context(result)
            return result
        if isinstance(result, dict):
            return FuncResponse(**result, latency=round(latency, 4)).dict()
        if isinstance(result, str):
            return FuncResponse(message=result, latency=round(latency, 4)).dict()
        # NOTE(review): any other result type falls through and returns None,
        # matching the original behavior — consider raising a TypeError here.
    except Exception as e:
        return handle_exception(e)

Expand Down
15 changes: 14 additions & 1 deletion agenta-cli/agenta/sdk/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, HttpUrl

Expand All @@ -10,6 +10,19 @@ def __init__(self, file_name: str, file_path: str):
self.file_path = file_path


class LLMTokenUsage(BaseModel):
    """Token accounting for a single LLM call, as reported by the provider."""

    # Tokens generated in the model's completion.
    completion_tokens: int
    # Tokens consumed by the prompt/input.
    prompt_tokens: int
    # Sum of completion and prompt tokens, as reported by the provider.
    total_tokens: int


class FuncResponse(BaseModel):
    """Standard response envelope returned by SDK entrypoint routes.

    ``usage`` and ``cost`` are given explicit ``None`` defaults: pydantic v1
    treats bare ``Optional[X]`` fields as defaulting to ``None`` implicitly,
    but pydantic v2 makes them required — explicit defaults keep the model
    unambiguous and forward-compatible without changing current behavior.
    """

    # The app function's textual output.
    message: str
    # Token usage reported by the LLM provider, when available.
    usage: Optional[LLMTokenUsage] = None
    # Cost of the call as a display string, when available.
    cost: Optional[str] = None
    # Wall-clock execution time of the app function, in seconds (rounded).
    latency: float


class DictInput(dict):
def __new__(cls, default_keys=None):
instance = super().__new__(cls, default_keys)
Expand Down
40 changes: 40 additions & 0 deletions examples/async_chat_sdk_output_format/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import agenta as ag
from agenta import FloatParam, MessagesInput, MultipleChoiceParam
from openai import AsyncOpenAI


# Async OpenAI client shared by all requests handled by this app.
client = AsyncOpenAI()

# Default system prompt used when the playground does not override it.
SYSTEM_PROMPT = "You have expertise in offering technical ideas to startups."
# Models the user can pick from in the playground dropdown.
CHAT_LLM_GPT = [
    "gpt-3.5-turbo-16k",
    "gpt-3.5-turbo-0301",
    "gpt-3.5-turbo-0613",
    "gpt-3.5-turbo-16k-0613",
    "gpt-4",
]

# Register this app with Agenta and declare its configurable parameters.
ag.init()
ag.config.default(
    temperature=FloatParam(0.2),
    model=MultipleChoiceParam("gpt-3.5-turbo", CHAT_LLM_GPT),
    # -1 means "no limit"; range is -1..4000 in the playground slider.
    max_tokens=ag.IntParam(-1, -1, 4000),
    prompt_system=ag.TextParam(SYSTEM_PROMPT),
)


@ag.entrypoint
async def chat(inputs: MessagesInput = MessagesInput()) -> dict:
    """Run one chat completion and return the SDK output-format dict.

    Args:
        inputs: Conversation history as a list of ``{"role", "content"}``
            message dicts (Agenta's ``MessagesInput`` type).

    Returns:
        A dict with ``message`` (the assistant's reply) and ``usage``
        (provider token accounting), which the SDK wraps into a
        ``FuncResponse``.  Return annotation fixed from ``str`` to ``dict``
        to match the actual return value.
    """
    messages = [{"role": "system", "content": ag.config.prompt_system}] + inputs
    # The playground uses -1 as "no limit"; translate it to None for OpenAI.
    max_tokens = ag.config.max_tokens if ag.config.max_tokens != -1 else None
    chat_completion = await client.chat.completions.create(
        model=ag.config.model,
        messages=messages,
        temperature=ag.config.temperature,
        max_tokens=max_tokens,
    )
    return {
        "message": chat_completion.choices[0].message.content,
        "usage": chat_completion.usage.dict(),
        # "cost": ...  # cost reporting not implemented yet
    }
2 changes: 2 additions & 0 deletions examples/async_chat_sdk_output_format/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
agenta
openai