diff --git a/agenta-backend/agenta_backend/services/app_manager.py b/agenta-backend/agenta_backend/services/app_manager.py
index a0c1aaf1bb..dfd9f0c021 100644
--- a/agenta-backend/agenta_backend/services/app_manager.py
+++ b/agenta-backend/agenta_backend/services/app_manager.py
@@ -95,8 +95,6 @@ async def start_variant(
     env_vars = {} if env_vars is None else env_vars
     env_vars.update(
         {
-            "AGENTA_VARIANT_NAME": db_app_variant.variant_name,
-            "AGENTA_VARIANT_ID": str(db_app_variant.id),
             "AGENTA_BASE_ID": str(db_app_variant.base.id),
             "AGENTA_APP_ID": str(db_app_variant.app.id),
             "AGENTA_HOST": domain_name,
diff --git a/agenta-cli/agenta/__init__.py b/agenta-cli/agenta/__init__.py
index 78711111ef..7f5bcde7ab 100644
--- a/agenta-cli/agenta/__init__.py
+++ b/agenta-cli/agenta/__init__.py
@@ -1,5 +1,4 @@
 from .sdk.utils.preinit import PreInitObject
-from .sdk.agenta_decorator import app, entrypoint
 from .sdk.context import get_contexts, save_context
 from .sdk.types import (
     Context,
@@ -14,9 +13,13 @@
     FileInputURL,
     BinaryParam,
 )
-from .sdk.tracing.decorators import span
-from .sdk.agenta_init import Config, init, llm_tracing
+from .sdk.tracing.llm_tracing import Tracing
+from .sdk.decorators.tracing import instrument
+from .sdk.decorators.llm_entrypoint import entrypoint, app
+from .sdk.agenta_init import Config, AgentaSingleton, init
 from .sdk.utils.helper.openai_cost import calculate_token_usage
 from .sdk.client import Agenta

 config = PreInitObject("agenta.config", Config)
+DEFAULT_AGENTA_SINGLETON_INSTANCE = AgentaSingleton()
+tracing = DEFAULT_AGENTA_SINGLETON_INSTANCE.tracing  # type: ignore
diff --git a/agenta-cli/agenta/client/backend/client.py b/agenta-cli/agenta/client/backend/client.py
index 6ac62f552e..0e0d643627 100644
--- a/agenta-cli/agenta/client/backend/client.py
+++ b/agenta-cli/agenta/client/backend/client.py
@@ -56,9 +56,9 @@ def __init__(
         self._client_wrapper = SyncClientWrapper(
             base_url=base_url,
             api_key=api_key,
-            httpx_client=httpx.Client(timeout=timeout)
-            if httpx_client is None
-            else httpx_client,
+            httpx_client=(
+                httpx.Client(timeout=timeout) if httpx_client is None else httpx_client
+            ),
         )
         self.observability = ObservabilityClient(client_wrapper=self._client_wrapper)
         self.apps = AppsClient(client_wrapper=self._client_wrapper)
@@ -1037,9 +1037,11 @@ def __init__(
         self._client_wrapper = AsyncClientWrapper(
             base_url=base_url,
             api_key=api_key,
-            httpx_client=httpx.AsyncClient(timeout=timeout)
-            if httpx_client is None
-            else httpx_client,
+            httpx_client=(
+                httpx.AsyncClient(timeout=timeout)
+                if httpx_client is None
+                else httpx_client
+            ),
         )
         self.observability = AsyncObservabilityClient(
             client_wrapper=self._client_wrapper
diff --git a/agenta-cli/agenta/client/backend/types/create_span.py b/agenta-cli/agenta/client/backend/types/create_span.py
index 89c123c782..4ce8cfd5a6 100644
--- a/agenta-cli/agenta/client/backend/types/create_span.py
+++ b/agenta-cli/agenta/client/backend/types/create_span.py
@@ -53,6 +53,6 @@ def dict(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]:
         return super().dict(**kwargs_with_defaults)

     class Config:
-        frozen = True
+        frozen = False
         smart_union = True
         json_encoders = {dt.datetime: serialize_datetime}
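Note on the `frozen = False` change above: the new tracing code in this patch mutates the generated `CreateSpan` models in place (for example `span.status = value` in `llm_tracing.py` further down), which pydantic forbids on frozen models. A minimal sketch of the behavior this unlocks, using a simplified stand-in rather than the generated class:

```python
from pydantic import BaseModel


class Span(BaseModel):
    """Simplified stand-in for the generated CreateSpan model."""

    id: str
    status: str = "UNSET"

    class Config:
        frozen = False  # frozen = True would make instances immutable (and hashable)


span = Span(id="abc123")
span.status = "OK"  # raises a TypeError when frozen = True
```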
diff --git a/agenta-cli/agenta/docker/docker-assets/Dockerfile.cloud.template b/agenta-cli/agenta/docker/docker-assets/Dockerfile.cloud.template
index cfbcb8b934..d4ecb0a66d 100644
--- a/agenta-cli/agenta/docker/docker-assets/Dockerfile.cloud.template
+++ b/agenta-cli/agenta/docker/docker-assets/Dockerfile.cloud.template
@@ -1,9 +1,9 @@
 FROM public.ecr.aws/s2t9a1r1/agentaai/lambda_templates_public:main

 COPY requirements.txt ${LAMBDA_TASK_ROOT}
-RUN pip install --no-cache-dir --disable-pip-version-check -r requirements.txt
-RUN pip install --no-cache-dir --disable-pip-version-check mangum
 RUN pip install --no-cache-dir --disable-pip-version-check -U agenta
+RUN pip install --no-cache-dir --disable-pip-version-check -U -r requirements.txt
+RUN pip install --no-cache-dir --disable-pip-version-check mangum
 COPY . ${LAMBDA_TASK_ROOT}

 CMD [ "lambda_function.handler" ]
diff --git a/agenta-cli/agenta/docker/docker-assets/Dockerfile.template b/agenta-cli/agenta/docker/docker-assets/Dockerfile.template
index e6d613a536..9eb6b06a54 100644
--- a/agenta-cli/agenta/docker/docker-assets/Dockerfile.template
+++ b/agenta-cli/agenta/docker/docker-assets/Dockerfile.template
@@ -4,8 +4,8 @@
 WORKDIR /app

 COPY . .

-RUN pip install --no-cache-dir --disable-pip-version-check -r requirements.txt
 RUN pip install --no-cache-dir --disable-pip-version-check -U agenta
+RUN pip install --no-cache-dir --disable-pip-version-check -r requirements.txt

 EXPOSE 80
diff --git a/agenta-cli/agenta/sdk/__init__.py b/agenta-cli/agenta/sdk/__init__.py
index 6d6d02e6d9..d9d9c89e76 100644
--- a/agenta-cli/agenta/sdk/__init__.py
+++ b/agenta-cli/agenta/sdk/__init__.py
@@ -1,6 +1,4 @@
 from .utils.preinit import PreInitObject  # always the first import!
-from . import agenta_decorator, context, types, utils  # noqa: F401
-from .agenta_decorator import app, entrypoint
 from .context import get_contexts, save_context
 from .types import (
     Context,
@@ -15,9 +13,13 @@
     FileInputURL,
     BinaryParam,
 )
-from .tracing.decorators import span
-from .agenta_init import Config, init, llm_tracing
+from .tracing.llm_tracing import Tracing
+from .decorators.tracing import instrument
+from .decorators.llm_entrypoint import entrypoint, app
+from .agenta_init import Config, AgentaSingleton, init
 from .utils.helper.openai_cost import calculate_token_usage

 config = PreInitObject("agenta.config", Config)
+DEFAULT_AGENTA_SINGLETON_INSTANCE = AgentaSingleton()
+tracing = DEFAULT_AGENTA_SINGLETON_INSTANCE.tracing  # type: ignore

diff --git a/agenta-cli/agenta/sdk/agenta_decorator.py b/agenta-cli/agenta/sdk/agenta_decorator.py
deleted file mode 100644
index 4f5a4e75c7..0000000000
--- a/agenta-cli/agenta/sdk/agenta_decorator.py
+++ /dev/null
@@ -1,501 +0,0 @@
-"""The code for the Agenta SDK"""
-
-import os
-import sys
-import time
-import inspect
-import argparse
-import asyncio
-import traceback
-import functools
-from pathlib import Path
-from tempfile import NamedTemporaryFile
-from typing import Any, Callable, Dict, Optional, Tuple, List
-
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi import Body, FastAPI, UploadFile, HTTPException
-
-import agenta
-from .context import save_context
-from .router import router as router
-from .types import (
-    Context,
-    DictInput,
-    FloatParam,
-    InFile,
-    IntParam,
-    MultipleChoiceParam,
-    GroupedMultipleChoiceParam,
-    TextParam,
-    MessagesInput,
-    FileInputURL,
-    FuncResponse,
-    BinaryParam,
-)
-
-app = FastAPI()
-
-origins = [
-    "*",
-]
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=origins,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-app.include_router(router, prefix="")
-
-
-def ingest_file(upfile: UploadFile):
-    temp_file = NamedTemporaryFile(delete=False)
-    temp_file.write(upfile.file.read())
-    temp_file.close()
-    return InFile(file_name=upfile.filename, file_path=temp_file.name)
-
-
-def entrypoint(func: Callable[..., Any]) -> Callable[..., Any]:
-    """
-    Decorator to wrap a function for HTTP POST and terminal exposure.
-
-    Args:
-        func: Function to wrap.
-
-    Returns:
-        Wrapped function for HTTP POST and terminal.
-    """
-
-    endpoint_name = "generate"
-    func_signature = inspect.signature(func)
-    config_params = agenta.config.all()
-    ingestible_files = extract_ingestible_files(func_signature)
-
-    # Initialize tracing
-    tracing = agenta.llm_tracing()
-
-    @functools.wraps(func)
-    async def wrapper(*args, **kwargs) -> Any:
-        func_params, api_config_params = split_kwargs(kwargs, config_params)
-
-        # Start tracing
-        tracing.start_parent_span(
-            name=func.__name__,
-            inputs=func_params,
-            config=config_params,
-            environment="playground",  # type: ignore #NOTE: wrapper is only called in playground
-        )
-
-        # Ingest files, prepare configurations and run llm app
-        ingest_files(func_params, ingestible_files)
-        agenta.config.set(**api_config_params)
-        llm_result = await execute_function(
-            func, *args, params=func_params, config_params=config_params
-        )
-
-        # End trace recording
-        tracing.end_recording(
-            outputs=llm_result.dict(),
-            span=tracing.active_trace,
-        )
-        return llm_result
-
-    @functools.wraps(func)
-    async def wrapper_deployed(*args, **kwargs) -> Any:
-        func_params = {
-            k: v for k, v in kwargs.items() if k not in ["config", "environment"]
-        }
-        if "environment" in kwargs and kwargs["environment"] is not None:
-            agenta.config.pull(environment_name=kwargs["environment"])
-        elif "config" in kwargs and kwargs["config"] is not None:
-            agenta.config.pull(config_name=kwargs["config"])
-        else:
-            agenta.config.pull(config_name="default")
-
-        config = agenta.config.all()
-
-        # Start tracing
-        tracing.start_parent_span(
-            name=func.__name__,
-            inputs=func_params,
-            config=config,
-            environment=kwargs["environment"],  # type: ignore #NOTE: wrapper is only called in playground
-        )
-
-        llm_result = await execute_function(
-            func, *args, params=func_params, config_params=config_params
-        )
-
-        # End trace recording
-        tracing.end_recording(
-            outputs=llm_result.dict(),
-            span=tracing.active_trace,
-        )
-        return llm_result
-
-    update_function_signature(wrapper, func_signature, config_params, ingestible_files)
-    route = f"/{endpoint_name}"
-    app.post(route, response_model=FuncResponse)(wrapper)
-
-    update_deployed_function_signature(
-        wrapper_deployed,
-        func_signature,
-        ingestible_files,
-    )
-    route_deployed = f"/{endpoint_name}_deployed"
-    app.post(route_deployed, response_model=FuncResponse)(wrapper_deployed)
-    override_schema(
-        openapi_schema=app.openapi(),
-        func_name=func.__name__,
-        endpoint=endpoint_name,
-        params={**config_params, **func_signature.parameters},
-    )
-
-    if is_main_script(func):
-        handle_terminal_run(
-            func,
-            func_signature.parameters,
-            config_params,
-            ingestible_files,
-        )
-    return None
-
-
-def extract_ingestible_files(
-    func_signature: inspect.Signature,
-) -> Dict[str, inspect.Parameter]:
-    """Extract parameters annotated as InFile from function signature."""
-
-    return {
-        name: param
-        for name, param in func_signature.parameters.items()
-        if param.annotation is InFile
-    }
-
-
-def split_kwargs(
-    kwargs: Dict[str, Any], config_params: Dict[str, Any]
-) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-    """Split keyword arguments into function parameters and API configuration parameters."""
-
-    func_params = {k: v for k, v in kwargs.items() if k not in config_params}
-    api_config_params = {k: v for k, v in kwargs.items() if k in config_params}
-    return func_params, api_config_params
-
-
-def ingest_files(
-    func_params: Dict[str, Any], ingestible_files: Dict[str, inspect.Parameter]
-) -> None:
-    """Ingest files specified in function parameters."""
-
-    for name in ingestible_files:
-        if name in func_params and func_params[name] is not None:
-            func_params[name] = ingest_file(func_params[name])
-
-
-async def execute_function(func: Callable[..., Any], *args, **func_params):
-    """Execute the function and handle any exceptions."""
-
-    try:
-        """Note: The following block is for backward compatibility.
-        It allows functions to work seamlessly whether they are synchronous or asynchronous.
-        For synchronous functions, it calls them directly, while for asynchronous functions,
-        it awaits their execution.
-        """
-        is_coroutine_function = inspect.iscoroutinefunction(func)
-        start_time = time.perf_counter()
-        if is_coroutine_function:
-            result = await func(*args, **func_params["params"])
-        else:
-            result = func(*args, **func_params["params"])
-
-        end_time = time.perf_counter()
-        latency = end_time - start_time
-
-        if isinstance(result, Context):
-            save_context(result)
-        if isinstance(result, Dict):
-            return FuncResponse(**result, latency=round(latency, 4))
-        if isinstance(result, str):
-            return FuncResponse(message=result, latency=round(latency, 4))  # type: ignore
-    except Exception as e:
-        handle_exception(e)
-    return FuncResponse(message="Unexpected error occurred", latency=0)  # type: ignore
-
-
-def handle_exception(e: Exception):
-    """Handle exceptions."""
-
-    status_code: int = e.status_code if hasattr(e, "status_code") else 500
-    traceback_str = traceback.format_exception(e, value=e, tb=e.__traceback__)  # type: ignore
-    raise HTTPException(
-        status_code=status_code,
-        detail={"error": str(e), "traceback": "".join(traceback_str)},
-    )
-
-
-def update_wrapper_signature(wrapper: Callable[..., Any], updated_params: List):
-    """
-    Updates the signature of a wrapper function with a new list of parameters.
-
-    Args:
-        wrapper (callable): A callable object, such as a function or a method, that requires a signature update.
-        updated_params (List[inspect.Parameter]): A list of `inspect.Parameter` objects representing the updated parameters
-        for the wrapper function.
-    """
-
-    wrapper_signature = inspect.signature(wrapper)
-    wrapper_signature = wrapper_signature.replace(parameters=updated_params)
-    wrapper.__signature__ = wrapper_signature
-
-
-def update_function_signature(
-    wrapper: Callable[..., Any],
-    func_signature: inspect.Signature,
-    config_params: Dict[str, Any],
-    ingestible_files: Dict[str, inspect.Parameter],
-) -> None:
-    """Update the function signature to include new parameters."""
-
-    updated_params = []
-    add_config_params_to_parser(updated_params, config_params)
-    add_func_params_to_parser(updated_params, func_signature, ingestible_files)
-    update_wrapper_signature(wrapper, updated_params)
-
-
-def update_deployed_function_signature(
-    wrapper: Callable[..., Any],
-    func_signature: inspect.Signature,
-    ingestible_files: Dict[str, inspect.Parameter],
-) -> None:
-    """Update the function signature to include new parameters."""
-    updated_params = []
-    add_func_params_to_parser(updated_params, func_signature, ingestible_files)
-    for param in [
-        "config",
-        "environment",
-    ]:  # we add the config and environment parameters
-        updated_params.append(
-            inspect.Parameter(
-                param,
-                inspect.Parameter.KEYWORD_ONLY,
-                default=Body(None),
-                annotation=str,
-            )
-        )
-    update_wrapper_signature(wrapper, updated_params)
-
-
-def add_config_params_to_parser(
-    updated_params: list, config_params: Dict[str, Any]
-) -> None:
-    """Add configuration parameters to function signature."""
-    for name, param in config_params.items():
-        updated_params.append(
-            inspect.Parameter(
-                name,
-                inspect.Parameter.KEYWORD_ONLY,
-                default=Body(param),
-                annotation=Optional[type(param)],
-            )
-        )
-
-
-def add_func_params_to_parser(
-    updated_params: list,
-    func_signature: inspect.Signature,
-    ingestible_files: Dict[str, inspect.Parameter],
-) -> None:
-    """Add function parameters to function signature."""
-    for name, param in func_signature.parameters.items():
-        if name in ingestible_files:
-            updated_params.append(
-                inspect.Parameter(name, param.kind, annotation=UploadFile)
-            )
-        else:
-            updated_params.append(
-                inspect.Parameter(
-                    name,
-                    inspect.Parameter.KEYWORD_ONLY,
-                    default=Body(..., embed=True),
-                    annotation=param.annotation,
-                )
-            )
-
-
-def is_main_script(func: Callable) -> bool:
-    """
-    Check if the script containing the function is the main script being run.
-
-    Args:
-        func (Callable): The function object to check.
-
-    Returns:
-        bool: True if the script containing the function is the main script, False otherwise.
-
-    Example:
-        if is_main_script(my_function):
-            print("This is the main script.")
-    """
-    return (
-        os.path.splitext(os.path.basename(sys.argv[0]))[0]
-        == os.path.splitext(os.path.basename(inspect.getfile(func)))[0]
-    )
-
-
-def handle_terminal_run(
-    func: Callable,
-    func_params: Dict[str, Any],
-    config_params: Dict[str, Any],
-    ingestible_files: Dict,
-) -> None:
-    """
-    Parses command line arguments and sets configuration when script is run from the terminal.
-
-    Args:
-        func_params (dict): A dictionary containing the function parameters and their annotations.
-        config_params (dict): A dictionary containing the configuration parameters.
-
-    Example:
-        handle_terminal_run(func_params=inspect.signature(my_function).parameters, config_params=config.all())
-    """
-
-    # For required parameters, we add them as arguments
-    parser = argparse.ArgumentParser()
-    for name, param in func_params.items():
-        if name in ingestible_files:
-            parser.add_argument(name, type=str)
-        else:
-            parser.add_argument(name, type=param.annotation)
-
-    for name, param in config_params.items():
-        if type(param) is MultipleChoiceParam:
-            parser.add_argument(
-                f"--{name}",
-                type=str,
-                default=param.default,
-                choices=param.choices,
-            )
-        else:
-            parser.add_argument(
-                f"--{name}",
-                type=type(param),
-                default=param,
-            )
-
-    args = parser.parse_args()
-
-    # split the arg list into the arg in the app_param and
-    # the args from the sig.parameter
-    args_config_params = {k: v for k, v in vars(args).items() if k in config_params}
-    args_func_params = {k: v for k, v in vars(args).items() if k not in config_params}
-    for name in ingestible_files:
-        args_func_params[name] = InFile(
-            file_name=Path(args_func_params[name]).stem,
-            file_path=args_func_params[name],
-        )
-    agenta.config.set(**args_config_params)
-
-    loop = asyncio.get_event_loop()
-    result = loop.run_until_complete(
-        execute_function(
-            func, **{"params": args_func_params, "config_params": args_config_params}
-        )
-    )
-    print(result)
-
-
-def override_schema(openapi_schema: dict, func_name: str, endpoint: str, params: dict):
-    """
-    Overrides the default openai schema generated by fastapi with additional information about:
-    - The choices available for each MultipleChoiceParam instance
-    - The min and max values for each FloatParam instance
-    - The min and max values for each IntParam instance
-    - The default value for DictInput instance
-    - The default value for MessagesParam instance
-    - The default value for FileInputURL instance
-    - The default value for BinaryParam instance
-    - ... [PLEASE ADD AT EACH CHANGE]
-
-    Args:
-        openapi_schema (dict): The openapi schema generated by fastapi
-        func_name (str): The name of the function to override
-        endpoint (str): The name of the endpoint to override
-        params (dict(param_name, param_val)): The dictionary of the parameters for the function
-    """
-
-    def find_in_schema(schema: dict, param_name: str, xparam: str):
-        """Finds a parameter in the schema based on its name and x-parameter value"""
-        for _, value in schema.items():
-            value_title_lower = str(value.get("title")).lower()
-            value_title = (
-                "_".join(value_title_lower.split())
-                if len(value_title_lower.split()) >= 2
-                else value_title_lower
-            )
-
-            if (
-                isinstance(value, dict)
-                and value.get("x-parameter") == xparam
-                and value_title == param_name
-            ):
-                return value
-
-    schema_to_override = openapi_schema["components"]["schemas"][
-        f"Body_{func_name}_{endpoint}_post"
-    ]["properties"]
-    for param_name, param_val in params.items():
-        if isinstance(param_val, GroupedMultipleChoiceParam):
-            subschema = find_in_schema(schema_to_override, param_name, "grouped_choice")
-            assert (
-                subschema
-            ), f"GroupedMultipleChoiceParam '{param_name}' is in the parameters but could not be found in the openapi.json"
-            subschema["choices"] = param_val.choices
-            subschema["default"] = param_val.default
-        if isinstance(param_val, MultipleChoiceParam):
-            subschema = find_in_schema(schema_to_override, param_name, "choice")
-            default = str(param_val)
-            param_choices = param_val.choices
-            choices = (
-                [default] + param_choices
-                if param_val not in param_choices
-                else param_choices
-            )
-            subschema["enum"] = choices
-            subschema["default"] = default if default in param_choices else choices[0]
-        if isinstance(param_val, FloatParam):
-            subschema = find_in_schema(schema_to_override, param_name, "float")
-            subschema["minimum"] = param_val.minval
-            subschema["maximum"] = param_val.maxval
-            subschema["default"] = param_val
-        if isinstance(param_val, IntParam):
-            subschema = find_in_schema(schema_to_override, param_name, "int")
-            subschema["minimum"] = param_val.minval
-            subschema["maximum"] = param_val.maxval
-            subschema["default"] = param_val
-        if (
-            isinstance(param_val, inspect.Parameter)
-            and param_val.annotation is DictInput
-        ):
-            subschema = find_in_schema(schema_to_override, param_name, "dict")
-            subschema["default"] = param_val.default["default_keys"]
-        if isinstance(param_val, TextParam):
-            subschema = find_in_schema(schema_to_override, param_name, "text")
-            subschema["default"] = param_val
-        if (
-            isinstance(param_val, inspect.Parameter)
-            and param_val.annotation is MessagesInput
-        ):
-            subschema = find_in_schema(schema_to_override, param_name, "messages")
-            subschema["default"] = param_val.default
-        if (
-            isinstance(param_val, inspect.Parameter)
-            and param_val.annotation is FileInputURL
-        ):
-            subschema = find_in_schema(schema_to_override, param_name, "file_url")
-            subschema["default"] = "https://example.com"
-        if isinstance(param_val, BinaryParam):
-            subschema = find_in_schema(schema_to_override, param_name, "bool")
-            subschema["default"] = param_val.default
diff --git a/agenta-cli/agenta/sdk/agenta_init.py b/agenta-cli/agenta/sdk/agenta_init.py
index de9b2f2460..3e4f522bdd 100644
--- a/agenta-cli/agenta/sdk/agenta_init.py
+++ b/agenta-cli/agenta/sdk/agenta_init.py
@@ -1,9 +1,9 @@
 import os
 import logging

-from typing import Any, Optional
-
-from .utils.globals import set_global
+import toml
+from typing import Optional
+from agenta.sdk.utils.globals import set_global
 from agenta.client.backend.client import AgentaApi
 from agenta.sdk.tracing.llm_tracing import Tracing
 from agenta.client.exceptions import APIRequestError
@@ -13,119 +13,97 @@
 logger.setLevel(logging.DEBUG)

-BACKEND_URL_SUFFIX = os.environ.get("BACKEND_URL_SUFFIX", "api")
-CLIENT_API_KEY = os.environ.get("AGENTA_API_KEY")
-CLIENT_HOST = os.environ.get("AGENTA_HOST", "http://localhost")
-
-# initialize the client with the backend url and api key
-backend_url = f"{CLIENT_HOST}/{BACKEND_URL_SUFFIX}"
-client = AgentaApi(
-    base_url=backend_url,
-    api_key=CLIENT_API_KEY if CLIENT_API_KEY else "",
-)
-
-
 class AgentaSingleton:
     """Singleton class to save all the "global variables" for the sdk."""

     _instance = None
     setup = None
     config = None
+    tracing: Optional[Tracing] = None

     def __new__(cls):
         if not cls._instance:
             cls._instance = super(AgentaSingleton, cls).__new__(cls)
         return cls._instance

+    @property
+    def client(self):
+        """API Backend client.
+
+        Returns:
+            AgentaAPI: instance of agenta api backend
+        """
+
+        return AgentaApi(base_url=self.host + "/api", api_key=self.api_key)
+
     def init(
         self,
-        app_name: Optional[str] = None,
-        base_name: Optional[str] = None,
-        api_key: Optional[str] = None,
-        base_id: Optional[str] = None,
         app_id: Optional[str] = None,
         host: Optional[str] = None,
-        **kwargs: Any,
+        api_key: Optional[str] = None,
+        config_fname: Optional[str] = None,
     ) -> None:
         """Main function to initialize the singleton.

-        Initializes the singleton with the given `app_name`, `base_name`, and `host`. If any of these arguments are not provided,
-        the function will look for them in environment variables.
+        Initializes the singleton with the given `app_id`, `host`, and `api_key`. The order of precedence for these variables is:
+        1. Explicit argument provided in the function call.
+        2. Value from the configuration file specified by `config_fname`.
+        3. Environment variables.
+
+        Examples:
+        ag.init(app_id="xxxx", api_key="xxx")
+        ag.init(config_fname="config.toml")
+        ag.init()  # assuming env vars are set

         Args:
-            app_name (Optional[str]): Name of the Agenta application. Defaults to None. If not provided, will look for "AGENTA_APP_NAME" in environment variables.
-            base_name (Optional[str]): Base name for the Agenta setup. Defaults to None. If not provided, will look for "AGENTA_BASE_NAME" in environment variables.
-            host (Optional[str]): Host name of the backend server. Defaults to None. If not provided, will look for "AGENTA_HOST" in environment variables.
-            kwargs (Any): Additional keyword arguments.
+            app_id (Optional[str]): ID of the Agenta application. Defaults to None. If not provided, will look for "app_id" in the config file, then "AGENTA_APP_ID" in environment variables.
+            host (Optional[str]): Host name of the backend server. Defaults to None. If not provided, will look for "backend_host" in the config file, then "AGENTA_HOST" in environment variables.
+            api_key (Optional[str]): API Key to use with the host of the backend server. Defaults to None. If not provided, will look for "api_key" in the config file, then "AGENTA_API_KEY" in environment variables.
+            config_fname (Optional[str]): Path to the configuration file (relative or absolute). Defaults to None.

         Raises:
-            ValueError: If `app_name`, `base_name`, or `host` are not specified either as arguments or in the environment variables.
-        """
-        if app_name is None:
-            app_name = os.environ.get("AGENTA_APP_NAME")
-        if base_name is None:
-            base_name = os.environ.get("AGENTA_BASE_NAME")
-        if api_key is None:
-            api_key = os.environ.get("AGENTA_API_KEY")
-        if base_id is None:
-            base_id = os.environ.get("AGENTA_BASE_ID")
-        if host is None:
-            host = os.environ.get("AGENTA_HOST", "http://localhost")
-
-        if base_id is None:
-            if app_name is None or base_name is None:
-                print(
-                    f"Warning: Your configuration will not be saved permanently since app_name and base_name are not provided."
-                )
-            else:
-                try:
-                    app_id = self.get_app(app_name)
-                    base_id = self.get_app_base(app_id, base_name)
-                except Exception as ex:
-                    raise APIRequestError(
-                        f"Failed to get base id and/or app_id from the server with error: {ex}"
-                    )
-
-        self.base_id = base_id
-        self.host = host
-        self.app_id = os.environ.get("AGENTA_APP_ID") if app_id is None else app_id
-        self.variant_id = os.environ.get("AGENTA_VARIANT_ID")
-        self.variant_name = os.environ.get("AGENTA_VARIANT_NAME")
-        self.api_key = api_key
-        self.config = Config(base_id=base_id, host=host)
-
-    def get_app(self, app_name: str) -> str:
-        apps = client.apps.list_apps(app_name=app_name)
-        if len(apps) == 0:
-            raise APIRequestError(f"App with name {app_name} not found")
-
-        app_id = apps[0].app_id
-        return app_id
-
-    def get_app_base(self, app_id: str, base_name: str) -> str:
-        bases = client.bases.list_bases(app_id=app_id, base_name=base_name)
-        if len(bases) == 0:
-            raise APIRequestError(f"No base was found for the app {app_id}")
-        return bases[0].base_id
-
-    def get_current_config(self):
-        """
-        Retrieves the current active configuration
+            ValueError: If `app_id` is not specified either as an argument, in the config file, or in the environment variables.
         """
+        config = {}
+        if config_fname:
+            config = toml.load(config_fname)
+
+        self.app_id = app_id or config.get("app_id") or os.environ.get("AGENTA_APP_ID")
+        self.host = (
+            host
+            or config.get("backend_host")
+            or os.environ.get("AGENTA_HOST", "https://cloud.agenta.ai")
+        )
+        self.api_key = (
+            api_key or config.get("api_key") or os.environ.get("AGENTA_API_KEY")
+        )
+
+        if not self.app_id:
+            raise ValueError(
+                "App ID must be specified. You can provide it in one of the following ways:\n"
+                "1. As an argument when calling ag.init(app_id='your_app_id').\n"
+                "2. In the configuration file specified by config_fname.\n"
+                "3. As an environment variable 'AGENTA_APP_ID'."
+            )
+        self.base_id = os.environ.get("AGENTA_BASE_ID")
+        if self.base_id is None:
+            print(
+                "Warning: Your configuration will not be saved permanently since base_id is not provided."
+            )

-        if self._config_data is None:
-            raise RuntimeError("AgentaSingleton has not been initialized")
-        return self._config_data
+        self.config = Config(base_id=self.base_id, host=self.host)  # type: ignore


 class Config:
-    def __init__(self, base_id, host):
+    def __init__(self, base_id: str, host: str, api_key: str = ""):
         self.base_id = base_id
         self.host = host
+
         if base_id is None or host is None:
             self.persist = False
         else:
             self.persist = True
+        self.client = AgentaApi(base_url=self.host + "/api", api_key=api_key)

     def register_default(self, overwrite=False, **kwargs):
         """alias for default"""
@@ -144,7 +122,7 @@ def default(self, overwrite=False, **kwargs):
             self.push(config_name="default", overwrite=overwrite, **kwargs)
         except Exception as ex:
             logger.warning(
-                "Unable to push the default configuration to the server." + str(ex)
+                "Unable to push the default configuration to the server. %s", str(ex)
             )
@@ -157,7 +135,7 @@ def push(self, config_name: str, overwrite=True, **kwargs):
         if not self.persist:
             return
         try:
-            client.configs.save_config(
+            self.client.configs.save_config(
                 base_id=self.base_id,
                 config_name=config_name,
                 parameters=kwargs,
@@ -165,38 +143,40 @@
             )
         except Exception as ex:
             logger.warning(
-                "Failed to push the configuration to the server with error: " + str(ex)
+                "Failed to push the configuration to the server with error: %s", ex
             )

-    def pull(self, config_name: str = "default", environment_name: str = None):
+    def pull(
+        self, config_name: str = "default", environment_name: Optional[str] = None
+    ):
         """Pulls the parameters for the app variant from the server and sets them to the config"""
         if not self.persist and (
             config_name != "default" or environment_name is not None
         ):
-            raise Exception(
+            raise ValueError(
                 "Cannot pull the configuration from the server since the app_name and base_name are not provided."
             )
         if self.persist:
             try:
                 if environment_name:
-                    config = client.configs.get_config(
+                    config = self.client.configs.get_config(
                         base_id=self.base_id, environment_name=environment_name
                     )
                 else:
-                    config = client.configs.get_config(
+                    config = self.client.configs.get_config(
                         base_id=self.base_id,
                         config_name=config_name,
                     )
             except Exception as ex:
                 logger.warning(
-                    "Failed to pull the configuration from the server with error: "
-                    + str(ex)
+                    "Failed to pull the configuration from the server with error: %s",
+                    str(ex),
                 )
         try:
             self.set(**{"current_version": config.current_version, **config.parameters})
         except Exception as ex:
-            logger.warning("Failed to set the configuration with error: " + str(ex))
+            logger.warning("Failed to set the configuration with error: %s", str(ex))

     def all(self):
         """Returns all the parameters for the app variant"""
@@ -204,7 +184,15 @@ def all(self):
             k: v
             for k, v in self.__dict__.items()
             if k
-            not in ["app_name", "base_name", "host", "base_id", "api_key", "persist"]
+            not in [
+                "app_name",
+                "base_name",
+                "host",
+                "base_id",
+                "api_key",
+                "persist",
+                "client",
+            ]
         }

     # function to set the parameters for the app variant
@@ -217,28 +205,52 @@ def set(self, **kwargs):
         for key, value in kwargs.items():
             setattr(self, key, value)

+    def dump(self):
+        """Returns all the information about the current version in the configuration.

-def init(app_name=None, base_name=None, **kwargs):
-    """Main function to be called by the user to initialize the sdk.
+        Raises:
+            NotImplementedError: _description_
+        """

-    Args:
-        app_name: _description_. Defaults to None.
-        base_name: _description_. Defaults to None.
-    """
-    singleton = AgentaSingleton()
-    singleton.init(app_name=app_name, base_name=base_name, **kwargs)
-    set_global(setup=singleton.setup, config=singleton.config)
+        raise NotImplementedError()


-def llm_tracing(max_workers: Optional[int] = None) -> Tracing:
-    """Function to start llm tracing."""
+def init(
+    app_id: Optional[str] = None,
+    host: Optional[str] = None,
+    api_key: Optional[str] = None,
+    config_fname: Optional[str] = None,
+    max_workers: Optional[int] = None,
+):
+    """Main function to initialize the agenta sdk.
+
+    Initializes agenta with the given `app_id`, `host`, and `api_key`. The order of precedence for these variables is:
+    1. Explicit argument provided in the function call.
+    2. Value from the configuration file specified by `config_fname`.
+    3. Environment variables.
+
+    - `app_id` is a required parameter (to be specified in one of the above ways)
+    - `host` is optional and defaults to "https://cloud.agenta.ai"
+    - `api_key` is optional and defaults to "". It is required only when using the cloud or enterprise version of agenta.
+
+    Args:
+        app_id (Optional[str]): ID of the Agenta application. Defaults to None. If not provided, will look for "app_id" in the config file, then "AGENTA_APP_ID" in environment variables.
+        host (Optional[str]): Host name of the backend server. Defaults to None. If not provided, will look for "backend_host" in the config file, then "AGENTA_HOST" in environment variables.
+        api_key (Optional[str]): API Key to use with the host of the backend server. Defaults to None. If not provided, will look for "api_key" in the config file, then "AGENTA_API_KEY" in environment variables.
+        config_fname (Optional[str]): Path to the configuration file. Defaults to None.
+
+    Raises:
+        ValueError: If `app_id` is not specified either as an argument, in the config file, or in the environment variables.
+    """
     singleton = AgentaSingleton()
-    return Tracing(
-        base_url=singleton.host,
+
+    singleton.init(app_id=app_id, host=host, api_key=api_key, config_fname=config_fname)
+    tracing = Tracing(
+        host=singleton.host,  # type: ignore
         app_id=singleton.app_id,  # type: ignore
-        variant_id=singleton.variant_id,  # type: ignore
-        variant_name=singleton.variant_name,
         api_key=singleton.api_key,
         max_workers=max_workers,
     )
+    set_global(setup=singleton.setup, config=singleton.config, tracing=tracing)
diff --git a/agenta-cli/agenta/sdk/decorators/base.py b/agenta-cli/agenta/sdk/decorators/base.py
new file mode 100644
index 0000000000..ae831fbba1
--- /dev/null
+++ b/agenta-cli/agenta/sdk/decorators/base.py
@@ -0,0 +1,10 @@
+# Stdlib Imports
+from typing import Any, Callable
+
+
+class BaseDecorator:
+    def __init__(self):
+        pass
+
+    def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
+        raise NotImplementedError
diff --git a/agenta-cli/agenta/sdk/decorators/llm_entrypoint.py b/agenta-cli/agenta/sdk/decorators/llm_entrypoint.py
new file mode 100644
index 0000000000..25b82f38de
--- /dev/null
+++ b/agenta-cli/agenta/sdk/decorators/llm_entrypoint.py
@@ -0,0 +1,499 @@
+"""The code for the Agenta SDK"""
+
+import os
+import sys
+import time
+import inspect
+import argparse
+import asyncio
+import traceback
+import functools
+from pathlib import Path
+from tempfile import NamedTemporaryFile
+from typing import Any, Callable, Dict, Optional, Tuple, List
+
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi import Body, FastAPI, UploadFile, HTTPException
+
+import agenta
+from agenta.sdk.context import save_context
+from agenta.sdk.router import router as router
+from agenta.sdk.tracing.llm_tracing import Tracing
+from agenta.sdk.decorators.base import BaseDecorator
+from agenta.sdk.types import (
+    Context,
+    DictInput,
+    FloatParam,
+    InFile,
+    IntParam,
+    MultipleChoiceParam,
+    GroupedMultipleChoiceParam,
+    TextParam,
+    MessagesInput,
+    FileInputURL,
+    FuncResponse,
+    BinaryParam,
+)
+
+app = FastAPI()
+
+origins = [
+    "*",
+]
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+app.include_router(router, prefix="")
+
+
+class entrypoint(BaseDecorator):
+    """Decorator class to wrap a function for HTTP POST and terminal exposure, and to enable tracing.
+
+
+    Example:
+    ```python
+        import agenta as ag
+
+        @ag.entrypoint
+        async def chain_of_prompts_llm(prompt: str):
+            return ...
+    ```
+    """
+
+    def __init__(self, func: Callable[..., Any]):
+        endpoint_name = "generate"
+        func_signature = inspect.signature(func)
+        config_params = agenta.config.all()
+        ingestible_files = self.extract_ingestible_files(func_signature)
+
+        @functools.wraps(func)
+        async def wrapper(*args, **kwargs) -> Any:
+            func_params, api_config_params = self.split_kwargs(kwargs, config_params)
+            self.ingest_files(func_params, ingestible_files)
+            agenta.config.set(**api_config_params)
+
+            # Set the configuration and environment of the LLM app parent span at run-time
+            agenta.tracing.set_span_attribute(
+                {"config": config_params, "environment": "playground"}
+            )
+
+            llm_result = await self.execute_function(
+                func, *args, params=func_params, config_params=config_params
+            )
+            return llm_result
+
+        @functools.wraps(func)
+        async def wrapper_deployed(*args, **kwargs) -> Any:
+            func_params = {
+                k: v for k, v in kwargs.items() if k not in ["config", "environment"]
+            }
+
+            if "environment" in kwargs and kwargs["environment"] is not None:
+                agenta.config.pull(environment_name=kwargs["environment"])
+            elif "config" in kwargs and kwargs["config"] is not None:
+                agenta.config.pull(config_name=kwargs["config"])
+            else:
+                agenta.config.pull(config_name="default")
+
+            # Set the configuration and environment of the LLM app parent span at run-time
+            agenta.tracing.set_span_attribute(
+                {"config": config_params, "environment": kwargs["environment"]}
+            )
+
+            llm_result = await self.execute_function(
+                func, *args, params=func_params, config_params=config_params
+            )
+            return llm_result
+
+        self.update_function_signature(
+            wrapper, func_signature, config_params, ingestible_files
+        )
+        route = f"/{endpoint_name}"
+        app.post(route, response_model=FuncResponse)(wrapper)
+
+        self.update_deployed_function_signature(
+            wrapper_deployed,
+            func_signature,
+            ingestible_files,
+        )
+        route_deployed = f"/{endpoint_name}_deployed"
+        app.post(route_deployed, response_model=FuncResponse)(wrapper_deployed)
+        self.override_schema(
+            openapi_schema=app.openapi(),
+            func_name=func.__name__,
+            endpoint=endpoint_name,
+            params={**config_params, **func_signature.parameters},
+        )
+
+        if self.is_main_script(func):
+            self.handle_terminal_run(
+                func,
+                func_signature.parameters,  # type: ignore
+                config_params,
+                ingestible_files,
+            )
+
+    def extract_ingestible_files(
+        self,
+        func_signature: inspect.Signature,
+    ) -> Dict[str, inspect.Parameter]:
+        """Extract parameters annotated as InFile from function signature."""
+
+        return {
+            name: param
+            for name, param in func_signature.parameters.items()
+            if param.annotation is InFile
+        }
+
+    def split_kwargs(
+        self, kwargs: Dict[str, Any], config_params: Dict[str, Any]
+    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        """Split keyword arguments into function parameters and API configuration parameters."""
+
+        func_params = {k: v for k, v in kwargs.items() if k not in config_params}
+        api_config_params = {k: v for k, v in kwargs.items() if k in config_params}
+        return func_params, api_config_params
+
+    def ingest_file(self, upfile: UploadFile):
+        temp_file = NamedTemporaryFile(delete=False)
+        temp_file.write(upfile.file.read())
+        temp_file.close()
+        return InFile(file_name=upfile.filename, file_path=temp_file.name)
+
+    def ingest_files(
+        self,
+        func_params: Dict[str, Any],
+        ingestible_files: Dict[str, inspect.Parameter],
+    ) -> None:
+        """Ingest files specified in function parameters."""
+
+        for name in ingestible_files:
+            if name in func_params and func_params[name] is not None:
+                func_params[name] = self.ingest_file(func_params[name])
+
+    async def execute_function(self, func: Callable[..., Any], *args, **func_params):
+        """Execute the function and handle any exceptions."""
+
+        try:
+            """Note: The following block is for backward compatibility.
+            It allows functions to work seamlessly whether they are synchronous or asynchronous.
+            For synchronous functions, it calls them directly, while for asynchronous functions,
+            it awaits their execution.
+            """
+            is_coroutine_function = inspect.iscoroutinefunction(func)
+            start_time = time.perf_counter()
+            if is_coroutine_function:
+                result = await func(*args, **func_params["params"])
+            else:
+                result = func(*args, **func_params["params"])
+
+            end_time = time.perf_counter()
+            latency = end_time - start_time
+
+            if isinstance(result, Context):
+                save_context(result)
+            if isinstance(result, Dict):
+                return FuncResponse(**result, latency=round(latency, 4))
+            if isinstance(result, str):
+                return FuncResponse(message=result, latency=round(latency, 4))  # type: ignore
+            if isinstance(result, int) or isinstance(result, float):
+                return FuncResponse(message=str(result), latency=round(latency, 4))
+            if result is None:
+                return FuncResponse(
+                    message="Function executed successfully, but returned None.\nAre you sure you did not forget to return a value?",
+                    latency=round(latency, 4),
+                )
+        except Exception as e:
+            self.handle_exception(e)
+        return FuncResponse(message="Unexpected error occurred when calling the @entrypoint decorated function", latency=0)  # type: ignore
+
+    def handle_exception(self, e: Exception):
+        """Handle exceptions."""
+
+        status_code: int = e.status_code if hasattr(e, "status_code") else 500
+        traceback_str = traceback.format_exception(e, value=e, tb=e.__traceback__)  # type: ignore
+        raise HTTPException(
+            status_code=status_code,
+            detail={"error": str(e), "traceback": "".join(traceback_str)},
+        )
+
+    def update_wrapper_signature(
+        self, wrapper: Callable[..., Any], updated_params: List
+    ):
+        """
+        Updates the signature of a wrapper function with a new list of parameters.
+
+        Args:
+            wrapper (callable): A callable object, such as a function or a method, that requires a signature update.
+            updated_params (List[inspect.Parameter]): A list of `inspect.Parameter` objects representing the updated parameters
+            for the wrapper function.
+        """
+
+        wrapper_signature = inspect.signature(wrapper)
+        wrapper_signature = wrapper_signature.replace(parameters=updated_params)
+        wrapper.__signature__ = wrapper_signature
+
+    def update_function_signature(
+        self,
+        wrapper: Callable[..., Any],
+        func_signature: inspect.Signature,
+        config_params: Dict[str, Any],
+        ingestible_files: Dict[str, inspect.Parameter],
+    ) -> None:
+        """Update the function signature to include new parameters."""
+
+        updated_params = []
+        self.add_config_params_to_parser(updated_params, config_params)
+        self.add_func_params_to_parser(updated_params, func_signature, ingestible_files)
+        self.update_wrapper_signature(wrapper, updated_params)
+
+    def update_deployed_function_signature(
+        self,
+        wrapper: Callable[..., Any],
+        func_signature: inspect.Signature,
+        ingestible_files: Dict[str, inspect.Parameter],
+    ) -> None:
+        """Update the function signature to include new parameters."""
+        updated_params = []
+        self.add_func_params_to_parser(updated_params, func_signature, ingestible_files)
+        for param in [
+            "config",
+            "environment",
+        ]:  # we add the config and environment parameters
+            updated_params.append(
+                inspect.Parameter(
+                    param,
+                    inspect.Parameter.KEYWORD_ONLY,
+                    default=Body(None),
+                    annotation=str,
+                )
+            )
+        self.update_wrapper_signature(wrapper, updated_params)
+
+    def add_config_params_to_parser(
+        self, updated_params: list, config_params: Dict[str, Any]
+    ) -> None:
+        """Add configuration parameters to function signature."""
+        for name, param in config_params.items():
+            updated_params.append(
+                inspect.Parameter(
+                    name,
+                    inspect.Parameter.KEYWORD_ONLY,
+                    default=Body(param),
+                    annotation=Optional[type(param)],
+                )
+            )
+
+    def add_func_params_to_parser(
+        self,
+        updated_params: list,
+        func_signature: inspect.Signature,
+        ingestible_files: Dict[str, inspect.Parameter],
+    ) -> None:
+        """Add function parameters to function signature."""
+        for name, param in func_signature.parameters.items():
+            if name in ingestible_files:
+                updated_params.append(
+                    inspect.Parameter(name, param.kind, annotation=UploadFile)
+                )
+            else:
+                updated_params.append(
+                    inspect.Parameter(
+                        name,
+                        inspect.Parameter.KEYWORD_ONLY,
+                        default=Body(..., embed=True),
+                        annotation=param.annotation,
+                    )
+                )
+
+    def is_main_script(self, func: Callable) -> bool:
+        """
+        Check if the script containing the function is the main script being run.
+
+        Args:
+            func (Callable): The function object to check.
+
+        Returns:
+            bool: True if the script containing the function is the main script, False otherwise.
+
+        Example:
+            if is_main_script(my_function):
+                print("This is the main script.")
+        """
+        return func.__module__ == "__main__"
+
+    def handle_terminal_run(
+        self,
+        func: Callable,
+        func_params: Dict[str, inspect.Parameter],
+        config_params: Dict[str, Any],
+        ingestible_files: Dict,
+    ):
+        """
+        Parses command line arguments and sets configuration when script is run from the terminal.
+
+        Args:
+            func_params (dict): A dictionary containing the function parameters and their annotations.
+            config_params (dict): A dictionary containing the configuration parameters.
+            ingestible_files (dict): A dictionary containing the files that should be ingested.
+        """
+
+        # For required parameters, we add them as arguments
+        parser = argparse.ArgumentParser()
+        for name, param in func_params.items():
+            if name in ingestible_files:
+                parser.add_argument(name, type=str)
+            else:
+                parser.add_argument(name, type=param.annotation)
+
+        for name, param in config_params.items():
+            if type(param) is MultipleChoiceParam:
+                parser.add_argument(
+                    f"--{name}",
+                    type=str,
+                    default=param.default,
+                    choices=param.choices,
+                )
+            else:
+                parser.add_argument(
+                    f"--{name}",
+                    type=type(param),
+                    default=param,
+                )
+
+        args = parser.parse_args()
+
+        # split the arg list into the arg in the app_param and
+        # the args from the sig.parameter
+        args_config_params = {k: v for k, v in vars(args).items() if k in config_params}
+        args_func_params = {
+            k: v for k, v in vars(args).items() if k not in config_params
+        }
+        for name in ingestible_files:
+            args_func_params[name] = InFile(
+                file_name=Path(args_func_params[name]).stem,
+                file_path=args_func_params[name],
+            )
+
+        agenta.config.set(**args_config_params)
+
+        # Set the configuration and environment of the LLM app parent span at run-time
+        agenta.tracing.set_span_attribute(
+            {"config": agenta.config.all(), "environment": "bash"}
+        )
+
+        loop = asyncio.get_event_loop()
+        result = loop.run_until_complete(
+            self.execute_function(
+                func,
+                **{"params": args_func_params, "config_params": args_config_params},
+            )
+        )
+        print(
+            f"\n========== Result ==========\n\nMessage: {result.message}\nCost: {result.cost}\nToken Usage: {result.usage}"
+        )
+
+    def override_schema(
+        self, openapi_schema: dict, func_name: str, endpoint: str, params: dict
+    ):
+        """
+        Overrides the default OpenAPI schema generated by FastAPI with additional information about:
+        - The choices available for each MultipleChoiceParam instance
+        - The min and max values for each FloatParam instance
+        - The min and max values for each IntParam instance
+        - The default value for DictInput instance
+        - The default value for MessagesInput instance
+        - The default value for FileInputURL instance
+        - The default value for BinaryParam instance
+        - ... [PLEASE ADD AT EACH CHANGE]
+
+        Args:
+            openapi_schema (dict): The OpenAPI schema generated by FastAPI
+            func_name (str): The name of the function to override
+            endpoint (str): The name of the endpoint to override
+            params (dict(param_name, param_val)): The dictionary of the parameters for the function
+        """
+
+        def find_in_schema(schema: dict, param_name: str, xparam: str):
+            """Finds a parameter in the schema based on its name and x-parameter value"""
+            for _, value in schema.items():
+                value_title_lower = str(value.get("title")).lower()
+                value_title = (
+                    "_".join(value_title_lower.split())
+                    if len(value_title_lower.split()) >= 2
+                    else value_title_lower
+                )
+
+                if (
+                    isinstance(value, dict)
+                    and value.get("x-parameter") == xparam
+                    and value_title == param_name
+                ):
+                    return value
+
+        schema_to_override = openapi_schema["components"]["schemas"][
+            f"Body_{func_name}_{endpoint}_post"
+        ]["properties"]
+        for param_name, param_val in params.items():
+            if isinstance(param_val, GroupedMultipleChoiceParam):
+                subschema = find_in_schema(
+                    schema_to_override, param_name, "grouped_choice"
+                )
+                assert (
+                    subschema
+                ), f"GroupedMultipleChoiceParam '{param_name}' is in the parameters but could not be found in the openapi.json"
+                subschema["choices"] = param_val.choices
+                subschema["default"] = param_val.default
+            if isinstance(param_val, MultipleChoiceParam):
+                subschema = find_in_schema(schema_to_override, param_name, "choice")
+                default = str(param_val)
+                param_choices = param_val.choices
+                choices = (
+                    [default] + param_choices
+                    if param_val not in param_choices
+                    else param_choices
+                )
+                subschema["enum"] = choices
+                subschema["default"] = (
+                    default if default in param_choices else choices[0]
+                )
+            if isinstance(param_val, FloatParam):
+                subschema = find_in_schema(schema_to_override, param_name, "float")
+                subschema["minimum"] = param_val.minval
+                subschema["maximum"] = param_val.maxval
+                subschema["default"] = param_val
+            if isinstance(param_val, IntParam):
+                subschema = find_in_schema(schema_to_override, param_name, "int")
+                subschema["minimum"] = param_val.minval
+                subschema["maximum"] = param_val.maxval
+                subschema["default"] = param_val
+            if (
+                isinstance(param_val, inspect.Parameter)
+                and param_val.annotation is DictInput
+            ):
+                subschema = find_in_schema(schema_to_override, param_name, "dict")
+                subschema["default"] = param_val.default["default_keys"]
+            if isinstance(param_val, TextParam):
+                subschema = find_in_schema(schema_to_override, param_name, "text")
+                subschema["default"] = param_val
+            if (
+                isinstance(param_val, inspect.Parameter)
+                and param_val.annotation is MessagesInput
+            ):
+                subschema = find_in_schema(schema_to_override, param_name, "messages")
+                subschema["default"] = param_val.default
+            if (
+                isinstance(param_val, inspect.Parameter)
+                and param_val.annotation is FileInputURL
+            ):
+                subschema = find_in_schema(schema_to_override, param_name, "file_url")
+                subschema["default"] = "https://example.com"
+            if isinstance(param_val, BinaryParam):
+                subschema = find_in_schema(schema_to_override, param_name, "bool")
+                subschema["default"] = param_val.default
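The `entrypoint` decorator defined above registers the wrapped function as `POST /generate` (and `POST /generate_deployed`) on the module-level FastAPI `app`, and falls back to an argparse CLI when the module runs as `__main__`. A minimal usage sketch based on the class docstring; the app id, default config value, and function body are placeholders:

```python
import agenta as ag

ag.init(app_id="my-app-id")  # placeholder id
ag.config.default(prompt_template=ag.TextParam("Tell me a joke about {subject}"))


@ag.entrypoint
async def generate(subject: str) -> str:
    # ... call your LLM with ag.config.prompt_template here ...
    return f"a joke about {subject}"
```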
diff --git a/agenta-cli/agenta/sdk/decorators/tracing.py b/agenta-cli/agenta/sdk/decorators/tracing.py
new file mode 100644
index 0000000000..199599f66e
--- /dev/null
+++ b/agenta-cli/agenta/sdk/decorators/tracing.py
@@ -0,0 +1,99 @@
+# Stdlib Imports
+import inspect
+from functools import wraps
+from typing import Any, Callable, Optional
+
+# Own Imports
+import agenta as ag
+from agenta.sdk.decorators.base import BaseDecorator
+
+
+class instrument(BaseDecorator):
+    """Decorator class for monitoring LLM app functions.
+
+    Args:
+        BaseDecorator (object): base decorator class
+
+    Example:
+    ```python
+        import agenta as ag
+
+        prompt_config = {"system_prompt": ..., "temperature": 0.5, "max_tokens": ...}
+
+        @ag.instrument(spankind="llm")
+        async def litellm_openai_call(prompt: str) -> str:
+            return "do something"
+
+        @ag.instrument(config=prompt_config)  # spankind for parent span defaults to workflow
+        async def generate(prompt: str):
+            return ...
+    ```
+    """
+
+    def __init__(
+        self, config: Optional[dict] = None, spankind: str = "workflow"
+    ) -> None:
+        self.config = config
+        self.spankind = spankind
+        self.tracing = ag.tracing
+
+    def __call__(self, func: Callable[..., Any]):
+        is_coroutine_function = inspect.iscoroutinefunction(func)
+
+        @wraps(func)
+        async def async_wrapper(*args, **kwargs):
+            result = None
+            func_args = inspect.getfullargspec(func).args
+            input_dict = {name: value for name, value in zip(func_args, args)}
+            input_dict.update(kwargs)
+
+            span = self.tracing.start_span(
+                name=func.__name__,
+                input=input_dict,
+                spankind=self.spankind,
+                config=self.config,
+            )
+
+            try:
+                result = await func(*args, **kwargs)
+                self.tracing.update_span_status(span=span, value="OK")
+            except Exception as e:
+                result = str(e)
+                self.tracing.update_span_status(span=span, value="ERROR")
+            finally:
+                self.tracing.end_span(
+                    outputs=(
+                        {"message": result} if not isinstance(result, dict) else result
+                    )
+                )
+            return result
+
+        @wraps(func)
+        def sync_wrapper(*args, **kwargs):
+            result = None
+            func_args = inspect.getfullargspec(func).args
+            input_dict = {name: value for name, value in zip(func_args, args)}
+            input_dict.update(kwargs)
+
+            span = self.tracing.start_span(
+                name=func.__name__,
+                input=input_dict,
+                spankind=self.spankind,
+                config=self.config,
+            )
+
+            try:
+                result = func(*args, **kwargs)
+                self.tracing.update_span_status(span=span, value="OK")
+            except Exception as e:
+                result = str(e)
+                self.tracing.update_span_status(span=span, value="ERROR")
+            finally:
+                self.tracing.end_span(
+                    outputs=(
+                        {"message": result} if not isinstance(result, dict) else result
+                    )
+                )
+            return result
+
+        return async_wrapper if is_coroutine_function else sync_wrapper
diff --git a/agenta-cli/agenta/sdk/tracing/decorators.py b/agenta-cli/agenta/sdk/tracing/decorators.py
deleted file mode 100644
index 338033be9b..0000000000
--- a/agenta-cli/agenta/sdk/tracing/decorators.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Stdlib Imports
-import inspect
-from functools import wraps
-
-# Own Imports
-import agenta as ag
-
-
-def span(type: str):
-    """Decorator to automatically start and end spans."""
-
-    tracing = ag.llm_tracing()
-
-    def decorator(func):
-        @wraps(func)
-        async def wrapper(*args, **kwargs):
-            result = None
-            span = tracing.start_span(
-                name=func.__name__,
-                input=kwargs,
-                spankind=type,
-            )
-            try:
-                is_coroutine_function = inspect.iscoroutinefunction(func)
-                if is_coroutine_function:
-                    result = await func(*args, **kwargs)
-                else:
-                    result = func(*args, **kwargs)
-                tracing.update_span_status(span=span, value="OK")
-            except Exception as e:
-                result = str(e)
-                tracing.update_span_status(span=span, value="ERROR")
-            finally:
-                if not isinstance(result, dict):
-                    result = {"message": result}
-                tracing.end_span(outputs=result, span=span)
-            return result
-
-        return wrapper
-
-    return decorator
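The deletion above completes the migration from the old function-based `@ag.span(type=...)` decorator to the class-based `@ag.instrument(...)` introduced earlier in this patch. A hedged before-and-after sketch; the function name and body are placeholders:

```python
import agenta as ag

# Before (removed in this patch):
#
#     @ag.span(type="llm")
#     async def call_llm(prompt: str) -> str: ...

# After; spankind defaults to "workflow" for the parent span:
@ag.instrument(spankind="llm")
async def call_llm(prompt: str) -> str:
    return "..."  # call your LLM here
```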
100644 --- a/agenta-cli/agenta/sdk/tracing/llm_tracing.py +++ b/agenta-cli/agenta/sdk/tracing/llm_tracing.py @@ -1,59 +1,79 @@ -# Stdlib Imports +import os +from threading import Lock from datetime import datetime, timezone from typing import Optional, Dict, Any, List, Union -# Own Imports from agenta.sdk.tracing.logger import llm_logger from agenta.sdk.tracing.tasks_manager import TaskQueue from agenta.client.backend.client import AsyncAgentaApi from agenta.client.backend.client import AsyncObservabilityClient from agenta.client.backend.types.create_span import CreateSpan, SpanKind, SpanStatusCode -# Third Party Imports from bson.objectid import ObjectId +VARIANT_TRACKING_FEATURE_FLAG = False -class Tracing(object): - """Agenta llm tracing object. - Args: - base_url (str): The URL of the backend host +class SingletonMeta(type): + """ + Thread-safe implementation of Singleton. + """ + + _instances = {} # type: ignore + + # We need the lock mechanism to synchronize threads \ + # during the initial access to the Singleton object. + _lock: Lock = Lock() + + def __call__(cls, *args, **kwargs): + """ + Ensures that changes to the `__init__` arguments do not affect the + returned instance. + + Uses a lock to make this method thread-safe. If an instance of the class + does not already exist, it creates one. Otherwise, it returns the + existing instance. + """ + + with cls._lock: + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance + return cls._instances[cls] + + +class Tracing(metaclass=SingletonMeta): + """The `Tracing` class is an agent for LLM tracing with specific initialization arguments. + + __init__ args: + host (str): The URL of the backend host api_key (str): The API Key of the backend host tasks_manager (TaskQueue): The tasks manager dedicated to handling asynchronous tasks llm_logger (Logger): The logger associated with the LLM tracing max_workers (int): The maximum number of workers to run tracing """ - _instance = None - - def __new__(cls, *args, **kwargs): - if not cls._instance: - cls._instance = super().__new__(cls) - return cls._instance - def __init__( self, - base_url: str, + host: str, app_id: str, - variant_id: str, - variant_name: Optional[str] = None, api_key: Optional[str] = None, max_workers: Optional[int] = None, ): - self.base_url = base_url + "/api" + self.host = host + "/api" self.api_key = api_key if api_key is not None else "" self.llm_logger = llm_logger self.app_id = app_id - self.variant_id = variant_id - self.variant_name = variant_name self.tasks_manager = TaskQueue( max_workers if max_workers else 4, logger=llm_logger ) - self.active_span = CreateSpan - self.active_trace = CreateSpan - self.recording_trace_id: Union[str, None] = None - self.recorded_spans: List[CreateSpan] = [] + self.active_span: Optional[CreateSpan] = None + self.active_trace_id: Optional[str] = None + self.pending_spans: List[CreateSpan] = [] self.tags: List[str] = [] + self.trace_config_cache: Dict[ + str, Any + ] = {} # used to save the trace configuration before starting the first span self.span_dict: Dict[str, CreateSpan] = {} # type: ignore @property @@ -65,131 +85,130 @@ def client(self) -> AsyncObservabilityClient: """ return AsyncAgentaApi( - base_url=self.base_url, api_key=self.api_key, timeout=120 # type: ignore + base_url=self.host, api_key=self.api_key, timeout=120 # type: ignore ).observability def set_span_attribute( - self, parent_key: Optional[str] = None, attributes: Dict[str, Any] = {} - ): - span = 
-        span = self.span_dict[self.active_span.id]  # type: ignore
-        for key, value in attributes.items():
-            self.set_attribute(span.attributes, key, value, parent_key)  # type: ignore
-
-    def set_attribute(
         self,
-        span_attributes: Dict[str, Any],
-        key: str,
-        value: Any,
-        parent_key: Optional[str] = None,
+        attributes: Dict[str, Any] = {},
     ):
-        if parent_key is not None:
-            model_config = span_attributes.get(parent_key, None)
-            if not model_config:
-                span_attributes[parent_key] = {}
-            span_attributes[parent_key][key] = value
+        if (
+            self.active_span is None
+        ):  # the entrypoint wants to save trace information before the parent span has been initialized
+            for key, value in attributes.items():
+                self.trace_config_cache[key] = value
         else:
-            span_attributes[key] = value
+            for key, value in attributes.items():
+                self.active_span.attributes[key] = value
 
     def set_trace_tags(self, tags: List[str]):
         self.tags.extend(tags)
 
-    def start_parent_span(
-        self, name: str, inputs: Dict[str, Any], config: Dict[str, Any], **kwargs
-    ):
-        trace_id = self._create_trace_id()
-        span_id = self._create_span_id()
-        self.llm_logger.info("Recording parent span...")
-        span = CreateSpan(
-            id=span_id,
-            app_id=self.app_id,
-            variant_id=self.variant_id,
-            variant_name=self.variant_name,
-            inputs=inputs,
-            name=name,
-            config=config,
-            environment=kwargs.get("environment"),
-            spankind=SpanKind.WORKFLOW.value,
-            status=SpanStatusCode.UNSET.value,
-            start_time=datetime.now(timezone.utc),
-        )
-        self.active_trace = span
-        self.recording_trace_id = trace_id
-        self.parent_span_id = span.id
-        self.llm_logger.info(
-            f"Recorded active_trace and setting parent_span_id: {span.id}"
-        )
-
     def start_span(
         self,
         name: str,
         spankind: str,
         input: Dict[str, Any],
-        config: Dict[str, Any] = {},
+        config: Optional[Dict[str, Any]] = None,
+        **kwargs,
     ) -> CreateSpan:
         span_id = self._create_span_id()
-        self.llm_logger.info(f"Recording {spankind} span...")
+        self.llm_logger.info(
+            f"Recording {'parent' if spankind == 'workflow' else spankind} span..."
+        )
         span = CreateSpan(
             id=span_id,
             inputs=input,
             name=name,
             app_id=self.app_id,
-            variant_id=self.variant_id,
-            variant_name=self.variant_name,
             config=config,
-            environment=self.active_trace.environment,
-            parent_span_id=self.parent_span_id,
             spankind=spankind.upper(),
             attributes={},
             status=SpanStatusCode.UNSET.value,
             start_time=datetime.now(timezone.utc),
+            outputs=None,
+            tags=None,
+            user=None,
+            end_time=None,
+            tokens=None,
+            cost=None,
+            token_consumption=None,
+            parent_span_id=None,
         )
-        self.active_span = span
+        if self.active_trace_id is None:  # This is a parent span
+            self.active_trace_id = self._create_trace_id()
+            span.environment = (
+                self.trace_config_cache.get("environment")
+                if self.trace_config_cache is not None
+                else os.environ.get("environment", "unset")
+            )
+            span.config = (
+                self.trace_config_cache.get("config")
+                if not config and self.trace_config_cache is not None
+                else config
+            )
+            if VARIANT_TRACKING_FEATURE_FLAG:
+                # TODO: we should get the variant_id and variant_name (and environment) from the config object
+                span.variant_id = config.variant_id
+                span.variant_name = config.variant_name
+
+        else:
+            span.parent_span_id = self.active_span.id
         self.span_dict[span.id] = span
-        self.parent_span_id = span.id
-        self.llm_logger.info(
-            f"Recorded active_span and setting parent_span_id: {span.id}"
-        )
+        self.active_span = span
+
+        self.llm_logger.info(f"Recorded span and setting parent_span_id: {span.id}")
         return span
 
     def update_span_status(self, span: CreateSpan, value: str):
-        updated_span = CreateSpan(**{**span.dict(), "status": value})
-        self.active_span = updated_span
+        span.status = value
 
-    def end_span(self, outputs: Dict[str, Any], span: CreateSpan, **kwargs):
-        updated_span = CreateSpan(
-            **span.dict(),
-            end_time=datetime.now(timezone.utc),
-            outputs=[outputs["message"]],
-            cost=outputs.get("cost", None),
-            tokens=outputs.get("usage"),
-        )
+    def end_span(self, outputs: Dict[str, Any]):
+        """
+        Ends the active span and, if it is a parent span, ends the trace too.
+        """
+        if self.active_span is None:
+            raise ValueError("There is no active span to end.")
+        self.active_span.end_time = datetime.now(timezone.utc)
+        self.active_span.outputs = [outputs.get("message", "")]
+        self.active_span.cost = outputs.get("cost", None)
+        self.active_span.tokens = outputs.get("usage", None)
 
         # Push span to list of recorded spans
-        self.recorded_spans.append(updated_span)
+        self.pending_spans.append(self.active_span)
         self.llm_logger.info(
-            f"Pushed {updated_span.spankind} span {updated_span.id} to recorded spans."
+            f"Pushed {self.active_span.spankind} span {self.active_span.id} to recorded spans."
        )
+        if self.active_span.parent_span_id is None:
+            self.end_trace(parent_span=self.active_span)
+        else:
+            self.active_span = self.span_dict[self.active_span.parent_span_id]
 
-    def end_recording(self, outputs: Dict[str, Any], span: CreateSpan, **kwargs):
-        self.end_span(outputs=outputs, span=span, **kwargs)
+    def end_trace(self, parent_span: CreateSpan):
         if self.api_key == "":
             return
 
-        self.llm_logger.info(f"Preparing to send recorded spans for processing.")
-        self.llm_logger.info(f"Recorded spans => {len(self.recorded_spans)}")
+        if not self.active_trace_id:
+            raise RuntimeError("No active trace to end.")
+
+        self.llm_logger.info("Preparing to send recorded spans for processing.")
+        self.llm_logger.info(f"Recorded spans => {len(self.pending_spans)}")
         self.tasks_manager.add_task(
-            self.active_trace.id,
+            self.active_trace_id,
             "trace",
             self.client.create_traces(
-                trace=self.recording_trace_id, spans=self.recorded_spans  # type: ignore
+                trace=self.active_trace_id, spans=self.pending_spans  # type: ignore
             ),
             self.client,
         )
         self.llm_logger.info(
-            f"Tracing for {span.id} recorded successfully and sent for processing."
+            f"Tracing for {parent_span.id} recorded successfully and sent for processing."
         )
-        self._clear_recorded_spans()
+        self._clear_pending_spans()
+        self.active_trace_id = None
+        self.active_span = None
+        self.trace_config_cache.clear()
 
     def _create_trace_id(self) -> str:
         """Creates a unique mongo id for the trace object.
@@ -209,12 +228,12 @@ def _create_span_id(self) -> str:
 
         return str(ObjectId())
 
-    def _clear_recorded_spans(self) -> None:
+    def _clear_pending_spans(self) -> None:
         """
         Clear the list of recorded spans to prepare for next batch processing.
         """
-        self.recorded_spans = []
+        self.pending_spans = []
         self.llm_logger.info(
-            f"Cleared all recorded spans from batch: {self.recorded_spans}"
+            f"Cleared all recorded spans from batch: {self.pending_spans}"
         )
diff --git a/agenta-cli/agenta/sdk/tracing/logger.py b/agenta-cli/agenta/sdk/tracing/logger.py
index 1dc83e8cfd..a2038989ae 100644
--- a/agenta-cli/agenta/sdk/tracing/logger.py
+++ b/agenta-cli/agenta/sdk/tracing/logger.py
@@ -2,7 +2,7 @@
 
 class LLMLogger:
-    def __init__(self, name="LLMLogger", level=logging.INFO):
+    def __init__(self, name="LLMLogger", level=logging.DEBUG):
         self.logger = logging.getLogger(name)
         self.logger.setLevel(level)
diff --git a/agenta-cli/agenta/sdk/tracing/tasks_manager.py b/agenta-cli/agenta/sdk/tracing/tasks_manager.py
index a7d807f995..29f988c5de 100644
--- a/agenta-cli/agenta/sdk/tracing/tasks_manager.py
+++ b/agenta-cli/agenta/sdk/tracing/tasks_manager.py
@@ -106,9 +106,7 @@ def _worker(self):
                 future.result()
             except Exception as exc:
                 self._logger.error(f"Error running task: {str(exc)}")
-                self._logger.error(
-                    f"Recording trace {task.coroutine_type} status to ERROR."
-                )
+                self._logger.error(f"Recording {task.coroutine_type} status to ERROR.")
                 break
             finally:
                 self.tasks.task_done()
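As a quick illustration of the `SingletonMeta` semantics introduced in llm_tracing.py above (a sketch with placeholder arguments, not part of the diff): constructing `Tracing` a second time returns the first instance and ignores the new arguments, which is why the SDK can safely expose a module-level `ag.tracing` object.

```python
from agenta.sdk.tracing.llm_tracing import Tracing

t1 = Tracing(host="https://cloud.agenta.ai", app_id="app-1")
t2 = Tracing(host="https://other.example", app_id="app-2")  # arguments ignored

assert t1 is t2              # the same instance is returned
assert t2.app_id == "app-1"  # __init__ only ran for the first call
```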
diff --git a/agenta-cli/agenta/sdk/utils/globals.py b/agenta-cli/agenta/sdk/utils/globals.py
index 58b8e31bb5..7b1bb3ff00 100644
--- a/agenta-cli/agenta/sdk/utils/globals.py
+++ b/agenta-cli/agenta/sdk/utils/globals.py
@@ -1,7 +1,7 @@
 import agenta
 
 
-def set_global(setup=None, config=None):
+def set_global(setup=None, config=None, tracing=None):
     """Allows usage of agenta.config and agenta.setup in the user's code.
 
     Args:
@@ -12,3 +12,5 @@ def set_global(setup=None, config=None):
         agenta.setup = setup
     if config is not None:
         agenta.config = config
+    if tracing is not None:
+        agenta.tracing = tracing
diff --git a/agenta-cli/pyproject.toml b/agenta-cli/pyproject.toml
index b284d037f4..95c8edcfd2 100644
--- a/agenta-cli/pyproject.toml
+++ b/agenta-cli/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "agenta"
-version = "0.14.14"
+version = "0.15.0a4"
 description = "The SDK for agenta is an open-source LLMOps platform."
 readme = "README.md"
 authors = ["Mahmoud Mabrouk "]
diff --git a/examples/app_with_observability/app_async.py b/examples/app_with_observability/app_async.py
index 0e437d7c5e..0cf24a9ddb 100644
--- a/examples/app_with_observability/app_async.py
+++ b/examples/app_with_observability/app_async.py
@@ -8,19 +8,23 @@
 )
 
 ag.init()
-tracing = ag.llm_tracing()
 ag.config.default(
     temperature=ag.FloatParam(0.2), prompt_template=ag.TextParam(default_prompt)
 )
 
 
-@ag.span(type="LLM")
+@ag.instrument(spankind="llm")
 async def llm_call(prompt):
     chat_completion = await client.chat.completions.create(
         model="gpt-3.5-turbo", messages=[{"role": "user", "content": prompt}]
     )
-    tracing.set_span_attribute(
-        "model_config", {"model": "gpt-3.5-turbo", "temperature": ag.config.temperature}
+    ag.tracing.set_span_attribute(
+        {
+            "model_config": {
+                "model": "gpt-3.5-turbo",
+                "temperature": ag.config.temperature,
+            }
+        }
     )  # translates to {"model_config": {"model": "gpt-3.5-turbo", "temperature": 0.2}}
     tokens_usage = chat_completion.usage.dict()
     return {
@@ -31,6 +35,7 @@ async def llm_call(prompt):
 
 
 @ag.entrypoint
+@ag.instrument()
 async def generate(country: str, gender: str):
     """
     Generate a baby name based on the given country and gender.
diff --git a/examples/app_with_observability/app_nested_async.py b/examples/app_with_observability/app_nested_async.py
index 8bdc4291b1..e0e88a7db3 100644
--- a/examples/app_with_observability/app_nested_async.py
+++ b/examples/app_with_observability/app_nested_async.py
@@ -19,7 +19,6 @@
 ]
 
 ag.init()
-tracing = ag.llm_tracing()
 ag.config.default(
     temperature_1=ag.FloatParam(default=1, minval=0.0, maxval=2.0),
     model_1=ag.MultipleChoiceParam("gpt-3.5-turbo", CHAT_LLM_GPT),
@@ -38,7 +37,7 @@
 )
 
 
-@ag.span(type="llm")
+@ag.instrument(spankind="llm")
 async def llm_call(
     prompt: str,
     model: str,
@@ -57,8 +56,8 @@ async def llm_call(
         frequency_penalty=frequency_penalty,
         presence_penalty=presence_penalty,
     )
-    tracing.set_span_attribute(
-        "model_config", {"model": model, "temperature": temperature}
+    ag.tracing.set_span_attribute(
+        {"model_config": {"model": model, "temperature": temperature}}
     )
     tokens_usage = response.usage.dict()  # type: ignore
     return {
@@ -68,7 +67,7 @@ async def llm_call(
     }
 
 
-@ag.span(type="chain")
+@ag.instrument(spankind="chain")
 async def finalize_wrapper(context_1: str, max_tokens: int, llm_response: str):
     prompt = ag.config.prompt_user_2.format(topics=llm_response, context_1=context_1)
     response = await llm_call(
@@ -83,7 +82,7 @@ async def finalize_wrapper(context_1: str, max_tokens: int, llm_response: str):
     return response
 
 
-@ag.span(type="chain")
+@ag.instrument(spankind="chain")
 async def wrapper(context_1: str, max_tokens: int):
     prompt = ag.config.prompt_user_1.format(context_1=context_1)
 
@@ -105,6 +104,7 @@ async def wrapper(context_1: str, max_tokens: int):
 
 
 @ag.entrypoint
+@ag.instrument()
 async def generate(context_1: str):
     """
     Generate a baby name based on the given country and gender.
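The example updates above also reflect the changed `set_span_attribute` call shape: instead of a `(parent_key, attributes)` pair, it now takes a single, possibly nested dict and writes it either onto the active span's attributes or, before any span has started, into the trace config cache. A hypothetical before/after for comparison:

```python
import agenta as ag

# Old (removed): nesting expressed through an explicit parent key
# tracing.set_span_attribute("model_config", {"model": "gpt-3.5-turbo"})

# New: the nesting is part of the dict itself, on the global tracing object
ag.tracing.set_span_attribute(
    {"model_config": {"model": "gpt-3.5-turbo", "temperature": 0.2}}
)
```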
diff --git a/examples/app_with_observability/dict_app_async.py b/examples/app_with_observability/dict_app_async.py
index 5d42a00c97..0671ce7463 100644
--- a/examples/app_with_observability/dict_app_async.py
+++ b/examples/app_with_observability/dict_app_async.py
@@ -1,7 +1,12 @@
 import agenta as ag
 import litellm
 
-ag.init()
+
+ag.init(
+    app_id="xxxxxxxx",
+    host="https://cloud.agenta.ai",
+    api_key="xxxxx.xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
+)
 
 prompts = {
     "system_prompt": "You are an expert in geography.",
@@ -33,7 +38,7 @@
 )
 
 
-@ag.span(type="llm_request")
+@ag.instrument(spankind="llm")
 async def litellm_call(prompt_system: str, prompt_user: str):
     max_tokens = ag.config.max_tokens if ag.config.max_tokens != -1 else None
     if ag.config.force_json and ag.config.model not in GPT_FORMAT_RESPONSE:
@@ -58,7 +63,14 @@ async def litellm_call(prompt_system: str, prompt_user: str):
         presence_penalty=ag.config.presence_penalty,
         response_format=response_format,
     )
-
+    ag.tracing.set_span_attribute(
+        {
+            "model_config": {
+                "model": ag.config.model,
+                "temperature": ag.config.temperature,
+            }
+        }
+    )
     tokens_usage = response.usage.dict()
     return {
         "cost": ag.calculate_token_usage(ag.config.model, tokens_usage),
@@ -68,6 +80,7 @@ async def litellm_call(prompt_system: str, prompt_user: str):
 
 
 @ag.entrypoint
+@ag.instrument()
 async def generate(
     inputs: ag.DictInput = ag.DictInput(default_keys=["country"]),
 ):
diff --git a/examples/app_with_observability/workflows/tracing_from_bash.py b/examples/app_with_observability/workflows/tracing_from_bash.py
new file mode 100644
index 0000000000..35f56b8b45
--- /dev/null
+++ b/examples/app_with_observability/workflows/tracing_from_bash.py
@@ -0,0 +1,60 @@
+import asyncio
+import agenta as ag
+from openai import AsyncOpenAI
+
+
+client = AsyncOpenAI()
+
+default_prompt = (
+    "Give me 10 names for a baby from this country {country} with gender {gender}!!!!"
+)
+
+ag.init(
+    app_id="xxxxxxxx",
+    host="https://cloud.agenta.ai",
+    api_key="xxxxx.xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
+)
+ag.config.default(
+    temperature=ag.FloatParam(0.2), prompt_template=ag.TextParam(default_prompt)
+)
+
+
+@ag.instrument(spankind="llm")
+async def llm_call(prompt):
+    chat_completion = await client.chat.completions.create(
+        model="gpt-3.5-turbo", messages=[{"role": "user", "content": prompt}]
+    )
+    ag.tracing.set_span_attribute(
+        {
+            "model_config": {
+                "model": "gpt-3.5-turbo",
+                "temperature": ag.config.temperature,
+            }
+        }
+    )  # translates to {"model_config": {"model": "gpt-3.5-turbo", "temperature": 0.2}}
+    tokens_usage = chat_completion.usage.dict()
+    return {
+        "cost": ag.calculate_token_usage("gpt-3.5-turbo", tokens_usage),
+        "message": chat_completion.choices[0].message.content,
+        "usage": tokens_usage,
+    }
+
+
+@ag.instrument(config=ag.config.all())
+async def generate(country: str, gender: str):
+    """
+    Generate a baby name based on the given country and gender.
+
+    Args:
+        country (str): The country to generate the name from.
+        gender (str): The gender of the baby.
+ """ + + prompt = ag.config.prompt_template.format(country=country, gender=gender) + response = await llm_call(prompt=prompt) + return response + + +if __name__ == "__main__": + loop = asyncio.get_event_loop() + loop.run_until_complete(generate(country="Germany", gender="Male")) diff --git a/examples/app_with_observability/workflows/tracing_from_hosted.py b/examples/app_with_observability/workflows/tracing_from_hosted.py new file mode 100644 index 0000000000..5a49eb1182 --- /dev/null +++ b/examples/app_with_observability/workflows/tracing_from_hosted.py @@ -0,0 +1,49 @@ +import os +import requests +import agenta as ag + + +API_URL = "https://xxxxxxx.xxx" +llm_config = {"environment": "production"} + +ag.init( + app_id="xxxxxxxx", + host="https://cloud.agenta.ai", + api_key="xxxxx.xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", +) + + +def hosted_platform_call(content: str): + ag.tracing.start_span( + name="gpt3.5-llm-call", + spankind="llm", + input={"content": content}, + ) + response = requests.post( + url=API_URL, + json={ + "inputs": [{"role": "user", "content": content}], + "environment": llm_config["environment"], + }, + ) + ag.tracing.end_span(outputs=response.json()) + return response.json() + + +def query(content: str): + ag.tracing.start_span( + name="query", + input={"content": content}, + spankind="workflow", + config=llm_config, + ) + response = hosted_platform_call(content=content) + ag.tracing.end_span(outputs=response) + return response + + +if __name__ == "__main__": + result = query( + content="How is a vector database used when building LLM applications?" + ) + print("\n\n============== Result ============== \n\n", result["message"]) diff --git a/examples/app_with_observability/workflows/tracing_plus_entrypoint.py b/examples/app_with_observability/workflows/tracing_plus_entrypoint.py new file mode 100644 index 0000000000..b2e5f94dae --- /dev/null +++ b/examples/app_with_observability/workflows/tracing_plus_entrypoint.py @@ -0,0 +1,52 @@ +import agenta as ag +from openai import AsyncOpenAI + + +client = AsyncOpenAI() + +default_prompt = ( + "Give me 10 names for a baby from this country {country} with gender {gender}!!!!" +) + +ag.init() + +ag.config.default( + temperature=ag.FloatParam(0.2), prompt_template=ag.TextParam(default_prompt) +) + + +@ag.instrument(spankind="llm") +async def llm_call(prompt): + chat_completion = await client.chat.completions.create( + model="gpt-3.5-turbo", messages=[{"role": "user", "content": prompt}] + ) + ag.tracing.set_span_attribute( + { + "model_config": { + "model": "gpt-3.5-turbo", + "temperature": ag.config.temperature, + } + } + ) # translates to {"model_config": {"model": "gpt-3.5-turbo", "temperature": 0.2}} + tokens_usage = chat_completion.usage.dict() + return { + "cost": ag.calculate_token_usage("gpt-3.5-turbo", tokens_usage), + "message": chat_completion.choices[0].message.content, + "usage": tokens_usage, + } + + +@ag.entrypoint +@ag.instrument() +async def generate(country: str, gender: str): + """ + Generate a baby name based on the given country and gender. + + Args: + country (str): The country to generate the name from. + gender (str): The gender of the baby. 
+ """ + + prompt = ag.config.prompt_template.format(country=country, gender=gender) + response = await llm_call(prompt=prompt) + return response diff --git a/examples/baby_name_generator/app.py b/examples/baby_name_generator/app.py index e9a553a005..f20ac6fd09 100644 --- a/examples/baby_name_generator/app.py +++ b/examples/baby_name_generator/app.py @@ -8,7 +8,7 @@ "Give me 10 names for a baby from this country {country} with gender {gender}!!!!" ) -ag.init(app_name="test", base_name="app") +ag.init() ag.config.default( temperature=FloatParam(0.2), prompt_template=TextParam(default_prompt) ) diff --git a/examples/qa_generator_chain_of_prompts/app.py b/examples/qa_generator_chain_of_prompts/app.py index 68e1040d65..a7e605c15c 100644 --- a/examples/qa_generator_chain_of_prompts/app.py +++ b/examples/qa_generator_chain_of_prompts/app.py @@ -7,7 +7,7 @@ prompt_1 = "Determine the three main topics that a user would ask about based on this documentation page {context_1}" prompt_2 = "Create 10 Question and Answers based on the following topics {topics} and the documentation page {context_1} " -ag.init(app_name="test", base_name="app") +ag.init() CHAT_LLM_GPT = [ "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k",