Refactor - modified obs sdk and backend logic
aybruhm committed Mar 25, 2024
1 parent f33a84f commit a409093
Showing 7 changed files with 48 additions and 35 deletions.
@@ -103,7 +103,7 @@ class Trace(Span):
 
 class TraceDetail(Trace):
     content: Dict[str, Any]
-    variant_config: Dict[str, Any]
+    config: Dict[str, Any]
 
 
 class ObservabilityData(BaseModel):
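Note: downstream consumers of TraceDetail should now read a span's configuration from config instead of variant_config. A minimal sketch of the updated model in use, assuming pydantic (as the BaseModel base suggests) and a hypothetical trace_id field standing in for the real Trace(Span) definition:

from typing import Any, Dict
from pydantic import BaseModel

class Trace(BaseModel):  # hypothetical stand-in for the Trace(Span) model above
    trace_id: str

class TraceDetail(Trace):
    content: Dict[str, Any]
    config: Dict[str, Any]  # renamed from `variant_config` in this commit

detail = TraceDetail(
    trace_id="t-123",
    content={"inputs": {"country": "NG"}},
    config={"model": "gpt-3.5-turbo", "temperature": 0.2},
)
print(detail.config["model"])  # -> gpt-3.5-turbo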
1 change: 0 additions & 1 deletion agenta-backend/agenta_backend/models/db_models.py
@@ -356,7 +356,6 @@ class SpanDB(Document):
     user: Optional[str]
     environment: Optional[str]  # request source -> playground, development, etc
     start_time: datetime
-    config: Optional[Dict[str, Any]]
     end_time: datetime = Field(default=datetime.now())
     tokens: Optional[LLMTokens]
     cost: Optional[float]
2 changes: 1 addition & 1 deletion agenta-backend/agenta_backend/services/event_db_manager.py
@@ -226,7 +226,7 @@ async def fetch_generation_span_detail(span_id: str) -> SpanDetail:
                 ],
                 "output": span_db.output,
             },
-            "config": span_db.config,
+            "config": span_db.meta.get("model_config"),
         },
     )
 
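Note: with the config column dropped from SpanDB (previous file), the span detail endpoint now derives a span's configuration from its metadata under the model_config key — the same key the SDK example further down writes via set_span_attribute. A rough sketch of the lookup, assuming meta is a plain optional dict on the span document:

from typing import Any, Dict, Optional

def span_config(meta: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    # Mirrors span_db.meta.get("model_config"); spans traced before this
    # change (or without a model_config attribute) yield None.
    return (meta or {}).get("model_config")

print(span_config({"model_config": {"model": "gpt-3.5-turbo"}}))  # {'model': ...}
print(span_config(None))  # None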
2 changes: 1 addition & 1 deletion agenta-cli/agenta/docker/docker_utils.py
@@ -12,7 +12,7 @@
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
-DEBUG = False
+DEBUG = True
 
 
 def create_dockerfile(out_folder: Path):
50 changes: 32 additions & 18 deletions agenta-cli/agenta/sdk/agenta_decorator.py
@@ -9,10 +9,10 @@
 import functools
 from pathlib import Path
 from tempfile import NamedTemporaryFile
-from typing import Any, Callable, Dict, Optional, Tuple, List, Union
+from typing import Any, Callable, Dict, Optional, Tuple, List
 
-from fastapi import Body, FastAPI, UploadFile
 from fastapi.responses import JSONResponse
+from fastapi import Body, FastAPI, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
 
 import agenta
@@ -84,18 +84,25 @@ async def wrapper(*args, **kwargs) -> Any:
             trace_name=func.__name__,
             inputs=func_params,
             config=api_config_params,
+            **{"environment": "playground"},  # type: ignore
         )
         ingest_files(func_params, ingestible_files)
         agenta.config.set(**api_config_params)
-        result = await execute_function(
+        llm_result = await execute_function(
             func, *args, **{"params": func_params, "config_params": config_params}
         )
 
         # End tracing
-        if isinstance(result, JSONResponse):
-            result = {"message": str(result), "usage": None, "cost": None}
-        tracing.end_trace(outputs=[result["message"]], usage=result["usage"], cost=result["cost"])  # type: ignore
-        return result
+        if isinstance(llm_result, JSONResponse):
+            result = {"message": str(llm_result), "total_tokens": 0, "cost": 0}
+        else:
+            result = {
+                "message": llm_result.message,
+                "total_tokens": llm_result.usage.total_tokens,
+                "cost": llm_result.cost,
+            }
+        tracing.end_trace(outputs=[result["message"]], total_tokens=result["total_tokens"], cost=result["cost"])  # type: ignore
+        return llm_result
 
     @functools.wraps(func)
     async def wrapper_deployed(*args, **kwargs) -> Any:
@@ -118,7 +125,7 @@ async def wrapper_deployed(*args, **kwargs) -> Any:
             config=config,
             **{"environment": kwargs["environment"]},
         )
-        result = await execute_function(
+        llm_result = await execute_function(
             func,
             *args,
             **{
@@ -128,10 +135,18 @@
         )
 
         # End tracing
-        if isinstance(result, JSONResponse):
-            result = {"message": str(result), "usage": None}
-        tracing.end_trace(outputs=[result["message"]], **kwargs)  # type: ignore
-        return result
+        if isinstance(llm_result, JSONResponse):
+            result = {"message": str(llm_result), "total_tokens": 0, "cost": 0}
+        else:
+            result = {
+                "message": llm_result.message,
+                "total_tokens": (
+                    llm_result.usage.total_tokens if llm_result.usage else None
+                ),
+                "cost": llm_result.cost,
+            }
+        tracing.end_trace(outputs=[result["message"]], total_tokens=result["total_tokens"], cost=result["cost"])  # type: ignore
+        return llm_result
 
     update_function_signature(wrapper, func_signature, config_params, ingestible_files)
     route = f"/{endpoint_name}"
@@ -193,9 +208,7 @@ def ingest_files(
             func_params[name] = ingest_file(func_params[name])
 
 
-async def execute_function(
-    func: Callable[..., Any], *args, **func_params
-) -> Union[Dict[str, Any], JSONResponse]:
+async def execute_function(func: Callable[..., Any], *args, **func_params):
     """Execute the function and handle any exceptions."""
 
     try:
@@ -210,18 +223,19 @@
             result = await func(*args, **func_params["params"])
         else:
             result = func(*args, **func_params["params"])
+
         end_time = time.perf_counter()
         latency = end_time - start_time
 
         if isinstance(result, Context):
             save_context(result)
         if isinstance(result, Dict):
-            return FuncResponse(**result, latency=round(latency, 4)).dict()
+            return FuncResponse(**result, latency=round(latency, 4))
         if isinstance(result, str):
-            return FuncResponse(message=result, latency=round(latency, 4)).dict()  # type: ignore
+            return FuncResponse(message=result, latency=round(latency, 4))  # type: ignore
     except Exception as e:
         return handle_exception(e)
-    return FuncResponse(message="Unexpected error occurred", latency=round(latency, 4)).dict()  # type: ignore
+    return FuncResponse(message="Unexpected error occurred", latency=0)  # type: ignore
 
 
 def handle_exception(e: Exception) -> JSONResponse:
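Note: execute_function now returns the FuncResponse model itself rather than .dict(), so the wrappers can read typed fields (llm_result.message, llm_result.usage.total_tokens, llm_result.cost) before FastAPI serializes the model into the response. A minimal sketch with hypothetical field definitions standing in for the real FuncResponse:

from typing import Optional
from pydantic import BaseModel

class LLMTokens(BaseModel):  # hypothetical shape; the real models live in the SDK
    total_tokens: int

class FuncResponse(BaseModel):
    message: str
    usage: Optional[LLMTokens] = None
    cost: Optional[float] = None
    latency: float

llm_result = FuncResponse(
    message="Aisha", usage=LLMTokens(total_tokens=24), cost=0.0001, latency=0.93
)

# What wrapper_deployed now extracts for end_trace:
total_tokens = llm_result.usage.total_tokens if llm_result.usage else None
print(llm_result.message, total_tokens, llm_result.cost)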
17 changes: 10 additions & 7 deletions agenta-cli/agenta/sdk/tracing/llm_tracing.py
@@ -130,10 +130,10 @@ def client(self) -> AsyncObservabilityClient:
         ).observability
 
     def set_span_attribute(
-        self, parent_key: Optional[str] = None, **kwargs: Dict[str, Any]
+        self, parent_key: Optional[str] = None, attributes: Dict[str, Any] = {}
     ):
         span = self.span_dict[self.active_span]  # type: ignore
-        for key, value in kwargs.items():
+        for key, value in attributes.items():
             span.set_attribute(key, value, parent_key)
 
     def set_trace_tags(self, tags: List[str]):
@@ -234,17 +234,20 @@ def trace(
         except Exception as exc:
             self.llm_logger.error(f"Error creating trace: {str(exc)}")
 
-    def end_trace(self, outputs: List[str], **kwargs: Dict[str, Any]):
+    def end_trace(
+        self,
+        outputs: List[str],
+        cost: Optional[float] = None,
+        total_tokens: Optional[int] = None,
+    ):
         try:
             self.tasks_manager.add_task(
                 self.client.update_trace(
                     trace_id=self.active_trace,  # type: ignore
                     status="COMPLETED",
                     end_time=datetime.now(),
-                    cost=kwargs.get("cost"),  # type: ignore
-                    token_consumption=kwargs["usage"].get(
-                        "total_tokens"
-                    ),  # type: ignore
+                    cost=cost,
+                    token_consumption=total_tokens,
                     outputs=outputs,
                 )
             )
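Note: taken together, the two signature changes make the tracing API explicit — nested attributes are passed as a single dict under a parent key, and trace totals are named parameters instead of a loose **kwargs bag. A usage sketch, assuming an initialized Tracing instance named tracing as in the example app below:

# Nests the dict under the parent key, i.e.
# {"model_config": {"model": "gpt-3.5-turbo", "temperature": 0.2}}
tracing.set_span_attribute(
    "model_config", {"model": "gpt-3.5-turbo", "temperature": 0.2}
)

# Totals are explicit keyword arguments; no more kwargs["usage"]["total_tokens"].
tracing.end_trace(outputs=["Aisha"], cost=0.0001, total_tokens=24)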
9 changes: 3 additions & 6 deletions examples/baby_name_generator/app_async.py
@@ -22,25 +22,22 @@ async def llm_call(prompt):
         model="gpt-3.5-turbo", messages=[{"role": "user", "content": prompt}]
     )
     tracing.set_span_attribute(
-        "model", name="gpt-3.5-turbo"
-    )  # translate to {"model": {"name": "gpt-3.5-turbo"}}
+        "model_config", {"model": "gpt-3.5-turbo", "temperature": ag.config.temperature}
+    )  # translate to {"model_config": {"model": "gpt-3.5-turbo", "temperature": 0.2}}
     return {
         "message": chat_completion.choices[0].message.content,
         "usage": chat_completion.usage.dict(),
     }
 
 
 @ag.entrypoint
-async def generate(country: str, gender: str) -> str:
+async def generate(country: str, gender: str):
     """
     Generate a baby name based on the given country and gender.
 
     Args:
         country (str): The country to generate the name from.
         gender (str): The gender of the baby.
-
-    Returns:
-        str: The generated baby name.
     """
 
     prompt = ag.config.prompt_template.format(country=country, gender=gender)
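Note: with the -> str annotation gone, the endpoint returns the serialized FuncResponse rather than a bare string. Roughly the response body after this commit (shape taken from the SDK diff above; field values illustrative):

# Illustrative JSON body from POST /generate:
response = {
    "message": "Aisha",
    "usage": {"prompt_tokens": 21, "completion_tokens": 3, "total_tokens": 24},
    "cost": 0.0001,
    "latency": 0.93,
}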
