Skip to content

Commit

Permalink
Refactor: call the llm_tracing() function directly instead of using the module-level `tracing` singleton
Browse files Browse the repository at this point in the history
  • Loading branch information
aybruhm committed Apr 4, 2024
1 parent dd81a35 commit 18e0508
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 11 deletions.
3 changes: 2 additions & 1 deletion agenta-cli/agenta/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
)
from .sdk.tracing.decorators import span
from .sdk.utils.preinit import PreInitObject
from .sdk.agenta_init import Config, init, tracing
from .sdk.agenta_init import Config, init, llm_tracing
from .sdk.utils.helper.openai_cost import calculate_token_usage


config = PreInitObject("agenta.config", Config)
3 changes: 2 additions & 1 deletion agenta-cli/agenta/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
BinaryParam,
)
from .tracing.decorators import span
from .agenta_init import Config, init, tracing
from .agenta_init import Config, init, llm_tracing
from .utils.helper.openai_cost import calculate_token_usage


config = PreInitObject("agenta.config", Config)
5 changes: 4 additions & 1 deletion agenta-cli/agenta/sdk/agenta_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,15 @@ def entrypoint(func: Callable[..., Any]) -> Callable[..., Any]:
config_params = agenta.config.all()
ingestible_files = extract_ingestible_files(func_signature)

# Initialize tracing
tracing = agenta.llm_tracing()

@functools.wraps(func)
async def wrapper(*args, **kwargs) -> Any:
func_params, api_config_params = split_kwargs(kwargs, config_params)

# Start tracing
agenta.tracing.trace(
tracing.trace(
trace_name=func.__name__,
inputs=func_params,
config=api_config_params,
Expand Down
3 changes: 0 additions & 3 deletions agenta-cli/agenta/sdk/agenta_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,3 @@ def llm_tracing(max_workers: Optional[int] = None) -> Tracing:
api_key=singleton.api_key,
max_workers=max_workers,
)


tracing = llm_tracing()
8 changes: 5 additions & 3 deletions agenta-cli/agenta/sdk/tracing/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@
def span(type: str):
"""Decorator to automatically start and end spans."""

tracing = ag.llm_tracing()

def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
result = None
span = ag.tracing.start_span(
span = tracing.start_span(
func.__name__,
input=kwargs,
type=type,
trace_id=ag.tracing.active_trace,
trace_id=tracing.active_trace,
)
try:
is_coroutine_function = inspect.iscoroutinefunction(func)
Expand All @@ -33,7 +35,7 @@ async def wrapper(*args, **kwargs):
finally:
if not isinstance(result, dict):
result = {"message": result}
ag.tracing.end_span(output=result, span=span)
tracing.end_span(output=result, span=span)
return result

return wrapper
Expand Down
3 changes: 2 additions & 1 deletion examples/app_with_observability/app_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)

ag.init()
tracing = ag.llm_tracing()
ag.config.default(
temperature=ag.FloatParam(0.2), prompt_template=ag.TextParam(default_prompt)
)
Expand All @@ -18,7 +19,7 @@ async def llm_call(prompt):
chat_completion = await client.chat.completions.create(
model="gpt-3.5-turbo", messages=[{"role": "user", "content": prompt}]
)
ag.tracing.set_span_attribute(
tracing.set_span_attribute(
"model_config", {"model": "gpt-3.5-turbo", "temperature": ag.config.temperature}
) # translate to {"model_config": {"model": "gpt-3.5-turbo", "temperature": 0.2}}
tokens_usage = chat_completion.usage.dict()
Expand Down
3 changes: 2 additions & 1 deletion examples/app_with_observability/app_nested_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
]

ag.init()
tracing = ag.llm_tracing()
ag.config.default(
temperature_1=ag.FloatParam(default=1, minval=0.0, maxval=2.0),
model_1=ag.MultipleChoiceParam("gpt-3.5-turbo", CHAT_LLM_GPT),
Expand Down Expand Up @@ -56,7 +57,7 @@ async def llm_call(
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
ag.tracing.set_span_attribute(
tracing.set_span_attribute(
"model_config", {"model": model, "temperature": temperature}
)
tokens_usage = response.usage.dict() # type: ignore
Expand Down

0 comments on commit 18e0508

Please sign in to comment.