Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhancement - Improve SDK output format #1025

Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions agenta-cli/agenta/sdk/agenta_decorator.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
"""The code for the Agenta SDK"""
import os
import sys
import time
import inspect
import argparse
import traceback
import functools
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, Optional, Tuple, List
from typing import Any, Callable, Dict, Optional, Tuple, List, Union

from fastapi import Body, FastAPI, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware

import agenta
from .context import save_context
Expand All @@ -26,6 +27,7 @@
TextParam,
MessagesInput,
FileInputURL,
FuncResponse,
)

app = FastAPI()
Expand Down Expand Up @@ -90,15 +92,15 @@ async def wrapper_deployed(*args, **kwargs) -> Any:

update_function_signature(wrapper, func_signature, config_params, ingestible_files)
route = f"/{endpoint_name}"
app.post(route)(wrapper)
app.post(route, response_model=FuncResponse)(wrapper)

update_deployed_function_signature(
wrapper_deployed,
func_signature,
ingestible_files,
)
route_deployed = f"/{endpoint_name}_deployed"
app.post(route_deployed)(wrapper_deployed)
app.post(route_deployed, response_model=FuncResponse)(wrapper_deployed)
override_schema(
openapi_schema=app.openapi(),
func_name=func.__name__,
Expand Down Expand Up @@ -148,7 +150,9 @@ def ingest_files(
func_params[name] = ingest_file(func_params[name])


async def execute_function(func: Callable[..., Any], *args, **func_params) -> Any:
async def execute_function(
func: Callable[..., Any], *args, **func_params
) -> Union[Dict[str, Any], JSONResponse]:
"""Execute the function and handle any exceptions."""

try:
Expand All @@ -159,13 +163,22 @@ async def execute_function(func: Callable[..., Any], *args, **func_params) -> An
"""
is_coroutine_function = inspect.iscoroutinefunction(func)
if is_coroutine_function:
start_time = time.perf_counter()
result = await func(*args, **func_params)
end_time = time.perf_counter()
latency = end_time - start_time
else:
start_time = time.perf_counter()
result = func(*args, **func_params)
end_time = time.perf_counter()
latency = end_time - start_time

if isinstance(result, Context):
save_context(result)
return result
if isinstance(result, FuncResponse):
return FuncResponse(**result, latency=str(latency)).dict()
if isinstance(result, str):
return FuncResponse(message=result, latency=str(latency)).dict()
except Exception as e:
return handle_exception(e)

Expand Down
15 changes: 14 additions & 1 deletion agenta-cli/agenta/sdk/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, HttpUrl

Expand All @@ -10,6 +10,19 @@ def __init__(self, file_name: str, file_path: str):
self.file_path = file_path


class FuncTokenUsage(BaseModel):
completion_tokens: str
prompt_tokens: str
total_tokens: str


class FuncResponse(BaseModel):
message: str
usage: Optional[FuncTokenUsage]
cost: Optional[str]
latency: str


class DictInput(dict):
def __new__(cls, default_keys=None):
instance = super().__new__(cls, default_keys)
Expand Down
Loading