diff --git a/agenta-cli/agenta/__init__.py b/agenta-cli/agenta/__init__.py index 49ab103b87..7108ba4cd9 100644 --- a/agenta-cli/agenta/__init__.py +++ b/agenta-cli/agenta/__init__.py @@ -14,7 +14,8 @@ ) from .sdk.tracing.decorators import span from .sdk.utils.preinit import PreInitObject -from .sdk.agenta_init import Config, init, tracing +from .sdk.agenta_init import Config, init, llm_tracing from .sdk.utils.helper.openai_cost import calculate_token_usage + config = PreInitObject("agenta.config", Config) diff --git a/agenta-cli/agenta/sdk/__init__.py b/agenta-cli/agenta/sdk/__init__.py index 1d8e6a69ba..252feff591 100644 --- a/agenta-cli/agenta/sdk/__init__.py +++ b/agenta-cli/agenta/sdk/__init__.py @@ -15,7 +15,8 @@ BinaryParam, ) from .tracing.decorators import span -from .agenta_init import Config, init, tracing +from .agenta_init import Config, init, llm_tracing from .utils.helper.openai_cost import calculate_token_usage + config = PreInitObject("agenta.config", Config) diff --git a/agenta-cli/agenta/sdk/agenta_decorator.py b/agenta-cli/agenta/sdk/agenta_decorator.py index eb5fa0f4b4..b5c08dfec5 100644 --- a/agenta-cli/agenta/sdk/agenta_decorator.py +++ b/agenta-cli/agenta/sdk/agenta_decorator.py @@ -73,12 +73,15 @@ def entrypoint(func: Callable[..., Any]) -> Callable[..., Any]: config_params = agenta.config.all() ingestible_files = extract_ingestible_files(func_signature) + # Initialize tracing + tracing = agenta.llm_tracing() + @functools.wraps(func) async def wrapper(*args, **kwargs) -> Any: func_params, api_config_params = split_kwargs(kwargs, config_params) # Start tracing - agenta.tracing.trace( + tracing.trace( trace_name=func.__name__, inputs=func_params, config=api_config_params, diff --git a/agenta-cli/agenta/sdk/agenta_init.py b/agenta-cli/agenta/sdk/agenta_init.py index 0d707ea594..5b1302bab9 100644 --- a/agenta-cli/agenta/sdk/agenta_init.py +++ b/agenta-cli/agenta/sdk/agenta_init.py @@ -240,6 +240,3 @@ def llm_tracing(max_workers: Optional[int] = None) -> Tracing: api_key=singleton.api_key, max_workers=max_workers, ) - - -tracing = llm_tracing() diff --git a/agenta-cli/agenta/sdk/tracing/decorators.py b/agenta-cli/agenta/sdk/tracing/decorators.py index 84dfb607bf..af799d5876 100644 --- a/agenta-cli/agenta/sdk/tracing/decorators.py +++ b/agenta-cli/agenta/sdk/tracing/decorators.py @@ -9,15 +9,17 @@ def span(type: str): """Decorator to automatically start and end spans.""" + tracing = ag.llm_tracing() + def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): result = None - span = ag.tracing.start_span( + span = tracing.start_span( func.__name__, input=kwargs, type=type, - trace_id=ag.tracing.active_trace, + trace_id=tracing.active_trace, ) try: is_coroutine_function = inspect.iscoroutinefunction(func) @@ -33,7 +35,7 @@ async def wrapper(*args, **kwargs): finally: if not isinstance(result, dict): result = {"message": result} - ag.tracing.end_span(output=result, span=span) + tracing.end_span(output=result, span=span) return result return wrapper diff --git a/examples/app_with_observability/app_async.py b/examples/app_with_observability/app_async.py index cf423750e4..6f0612c15d 100644 --- a/examples/app_with_observability/app_async.py +++ b/examples/app_with_observability/app_async.py @@ -8,6 +8,7 @@ ) ag.init() +tracing = ag.llm_tracing() ag.config.default( temperature=ag.FloatParam(0.2), prompt_template=ag.TextParam(default_prompt) ) @@ -18,7 +19,7 @@ async def llm_call(prompt): chat_completion = await client.chat.completions.create( model="gpt-3.5-turbo", messages=[{"role": "user", "content": prompt}] ) - ag.tracing.set_span_attribute( + tracing.set_span_attribute( "model_config", {"model": "gpt-3.5-turbo", "temperature": ag.config.temperature} ) # translate to {"model_config": {"model": "gpt-3.5-turbo", "temperature": 0.2}} tokens_usage = chat_completion.usage.dict() diff --git a/examples/app_with_observability/app_nested_async.py b/examples/app_with_observability/app_nested_async.py index 291796df10..fe94789e78 100644 --- a/examples/app_with_observability/app_nested_async.py +++ b/examples/app_with_observability/app_nested_async.py @@ -19,6 +19,7 @@ ] ag.init() +tracing = ag.llm_tracing() ag.config.default( temperature_1=ag.FloatParam(default=1, minval=0.0, maxval=2.0), model_1=ag.MultipleChoiceParam("gpt-3.5-turbo", CHAT_LLM_GPT), @@ -56,7 +57,7 @@ async def llm_call( frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, ) - ag.tracing.set_span_attribute( + tracing.set_span_attribute( "model_config", {"model": model, "temperature": temperature} ) tokens_usage = response.usage.dict() # type: ignore