Skip to content

Commit

Permalink
Refactor: implemented sdkclient class to build sdk clien instance and…
Browse files Browse the repository at this point in the history
… rewrite agenta singleton init
  • Loading branch information
aybruhm committed Apr 26, 2024
1 parent c90fcfc commit 447282b
Showing 1 changed file with 78 additions and 63 deletions.
141 changes: 78 additions & 63 deletions agenta-cli/agenta/sdk/agenta_init.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import os
import logging
from typing import Any, Optional

from .utils.globals import set_global
from typing import Optional

from agenta.sdk.utils.globals import set_global
from agenta.client.backend.client import AgentaApi
from agenta.sdk.tracing.llm_tracing import Tracing
from agenta.client.exceptions import APIRequestError
Expand All @@ -13,16 +12,17 @@
logger.setLevel(logging.DEBUG)


BACKEND_URL_SUFFIX = os.environ.get("BACKEND_URL_SUFFIX", "api")
CLIENT_API_KEY = os.environ.get("AGENTA_API_KEY")
CLIENT_HOST = os.environ.get("AGENTA_HOST", "http://localhost")
class SDKClient(object):
"""
Class to build SDK client instance.
"""

def __init__(self, api_key: str, host: str):
self.api_key = api_key
self.host = host + "/api"

# initialize the client with the backend url and api key
backend_url = f"{CLIENT_HOST}/{BACKEND_URL_SUFFIX}"
client = AgentaApi(
base_url=backend_url,
api_key=CLIENT_API_KEY if CLIENT_API_KEY else "",
)
def _build_sdk_client(self) -> AgentaApi:
return AgentaApi(base_url=self.host, api_key=self.api_key)


class AgentaSingleton:
Expand All @@ -37,73 +37,67 @@ def __new__(cls):
cls._instance = super(AgentaSingleton, cls).__new__(cls)
return cls._instance

@property
def client(self):
"""Builds sdk client instance.
Returns:
AgentaAPI: instance of agenta api backend
"""

sdk_client = SDKClient(api_key=self.api_key, host=self.host) # type: ignore
return sdk_client._build_sdk_client()

def init(
self,
app_name: Optional[str] = None,
base_name: Optional[str] = None,
api_key: Optional[str] = None,
base_id: Optional[str] = None,
app_id: Optional[str] = None,
host: Optional[str] = None,
**kwargs: Any,
api_key: Optional[str] = None,
) -> None:
"""Main function to initialize the singleton.
Initializes the singleton with the given `app_name`, `base_name`, and `host`. If any of these arguments are not provided,
the function will look for them in environment variables.
Args:
app_name (Optional[str]): Name of the Agenta application. Defaults to None. If not provided, will look for "AGENTA_APP_NAME" in environment variables.
base_name (Optional[str]): Base name for the Agenta setup. Defaults to None. If not provided, will look for "AGENTA_BASE_NAME" in environment variables.
app_id (Optional[str]): ID of the Agenta application. Defaults to None. If not provided, will look for "AGENTA_APP_NAME" in environment variables.
host (Optional[str]): Host name of the backend server. Defaults to None. If not provided, will look for "AGENTA_HOST" in environment variables.
api_key (Optional[str]): API Key to use with the host of the backend server.
kwargs (Any): Additional keyword arguments.
Raises:
ValueError: If `app_name`, `base_name`, or `host` are not specified either as arguments or in the environment variables.
"""
if app_name is None:
app_name = os.environ.get("AGENTA_APP_NAME")
if base_name is None:
base_name = os.environ.get("AGENTA_BASE_NAME")
if api_key is None:
api_key = os.environ.get("AGENTA_API_KEY")
if base_id is None:
base_id = os.environ.get("AGENTA_BASE_ID")
if host is None:
host = os.environ.get("AGENTA_HOST", "http://localhost")

if base_id is None:
if app_name is None or base_name is None:
print(
f"Warning: Your configuration will not be saved permanently since app_name and base_name are not provided."

self.api_key = api_key or os.environ.get("AGENTA_API_KEY")
self.host = host or os.environ.get("AGENTA_HOST", "http://localhost")

app_id = app_id or os.environ.get("AGENTA_APP_ID")
if not app_id:
raise ValueError("App ID must be specified.")

base_id = os.environ.get("AGENTA_BASE_ID")
base_name = os.environ.get("AGENTA_BASE_NAME")
if base_id is None and (app_id is None or base_name is None):
print(
f"Warning: Your configuration will not be saved permanently since app_name and base_name are not provided."
)
else:
try:
base_id = self.get_app_base(app_id, base_name) # type: ignore
except Exception as ex:
raise APIRequestError(
f"Failed to get base id and/or app_id from the server with error: {ex}"
)
else:
try:
app_id = self.get_app(app_name)
base_id = self.get_app_base(app_id, base_name)
except Exception as ex:
raise APIRequestError(
f"Failed to get base id and/or app_id from the server with error: {ex}"
)

self.app_id = app_id
self.base_id = base_id
self.host = host
self.app_id = os.environ.get("AGENTA_APP_ID") if app_id is None else app_id
self.variant_id = os.environ.get("AGENTA_VARIANT_ID")
self.variant_name = os.environ.get("AGENTA_VARIANT_NAME")
self.api_key = api_key
self.config = Config(base_id=base_id, host=host)

def get_app(self, app_name: str) -> str:
apps = client.apps.list_apps(app_name=app_name)
if len(apps) == 0:
raise APIRequestError(f"App with name {app_name} not found")

app_id = apps[0].app_id
return app_id

def get_app_base(self, app_id: str, base_name: str) -> str:
bases = client.bases.list_bases(app_id=app_id, base_name=base_name)
bases = self.client.bases.list_bases(app_id=app_id, base_name=base_name)
if len(bases) == 0:
raise APIRequestError(f"No base was found for the app {app_id}")
return bases[0].base_id
Expand All @@ -122,11 +116,23 @@ class Config:
def __init__(self, base_id, host):
self.base_id = base_id
self.host = host

if base_id is None or host is None:
self.persist = False
else:
self.persist = True

@property
def client(self):
"""Builds sdk client instance.
Returns:
AgentaAPI: instance of agenta api backend
"""

sdk_client = SDKClient(api_key=self.api_key, host=self.host) # type: ignore
return sdk_client._build_sdk_client()

def register_default(self, overwrite=False, **kwargs):
"""alias for default"""
return self.default(overwrite=overwrite, **kwargs)
Expand Down Expand Up @@ -157,7 +163,7 @@ def push(self, config_name: str, overwrite=True, **kwargs):
if not self.persist:
return
try:
client.configs.save_config(
self.client.configs.save_config(
base_id=self.base_id,
config_name=config_name,
parameters=kwargs,
Expand All @@ -168,7 +174,9 @@ def push(self, config_name: str, overwrite=True, **kwargs):
"Failed to push the configuration to the server with error: " + str(ex)
)

def pull(self, config_name: str = "default", environment_name: str = None):
def pull(
self, config_name: str = "default", environment_name: Optional[str] = None
):
"""Pulls the parameters for the app variant from the server and sets them to the config"""
if not self.persist and (
config_name != "default" or environment_name is not None
Expand All @@ -179,12 +187,12 @@ def pull(self, config_name: str = "default", environment_name: str = None):
if self.persist:
try:
if environment_name:
config = client.configs.get_config(
config = self.client.configs.get_config(
base_id=self.base_id, environment_name=environment_name
)

else:
config = client.configs.get_config(
config = self.client.configs.get_config(
base_id=self.base_id,
config_name=config_name,
)
Expand Down Expand Up @@ -218,15 +226,22 @@ def set(self, **kwargs):
setattr(self, key, value)


def init(app_name=None, base_name=None, **kwargs):
def init(
app_id: Optional[str] = None,
host: Optional[str] = None,
api_key: Optional[str] = None,
):
"""Main function to be called by the user to initialize the sdk.
Args:
app_name: _description_. Defaults to None.
base_name: _description_. Defaults to None.
app_id (str): The Id of the app.
host (str): The host of the backend server.
api_key (str): The API key to use for the backend server.
"""

singleton = AgentaSingleton()
singleton.init(app_name=app_name, base_name=base_name, **kwargs)

singleton.init(app_id=app_id, host=host, api_key=api_key)
set_global(setup=singleton.setup, config=singleton.config)


Expand All @@ -235,7 +250,7 @@ def llm_tracing(max_workers: Optional[int] = None) -> Tracing:

singleton = AgentaSingleton()
return Tracing(
base_url=singleton.host,
base_url=singleton.host, # type: ignore
app_id=singleton.app_id, # type: ignore
variant_id=singleton.variant_id, # type: ignore
variant_name=singleton.variant_name,
Expand Down

0 comments on commit 447282b

Please sign in to comment.