Skip to content

Commit

Permalink
Merge branch 'evaluations-in-backend' of https://github.com/Agenta-AI…
Browse files Browse the repository at this point in the history
…/agenta into evaluations-in-backend
  • Loading branch information
MohammedMaaz committed Jan 8, 2024
2 parents c16f0c4 + 0b753a5 commit 206ae00
Show file tree
Hide file tree
Showing 45 changed files with 1,152 additions and 1,530 deletions.
8 changes: 7 additions & 1 deletion agenta-backend/agenta_backend/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from celery import Celery
import asyncio
from contextlib import asynccontextmanager

from agenta_backend.config import settings
Expand All @@ -20,6 +20,7 @@
configs_router,
health_router,
)
from agenta_backend.models.db_engine import DBEngine

if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
from agenta_backend.commons.services import templates_manager
Expand All @@ -29,6 +30,9 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from celery import Celery


origins = [
"http://localhost:3000",
"http://localhost:3001",
Expand All @@ -48,6 +52,8 @@ async def lifespan(application: FastAPI, cache=True):
application: FastAPI application.
cache: A boolean value that indicates whether to use the cached data or not.
"""
# initialize the database
await DBEngine().init_db()
await templates_manager.update_and_sync_templates(cache=cache)
yield

Expand Down
37 changes: 1 addition & 36 deletions agenta-backend/agenta_backend/models/api/evaluation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,8 @@ class EvaluatorConfig(BaseModel):
updated_at: datetime


class EvaluationTypeSettings(BaseModel):
    """Optional configuration attached to an evaluation.

    All fields are optional: only the settings relevant to the chosen
    evaluation type are expected to be populated — TODO confirm against
    the code that consumes this model.
    """

    # Presumably a [0, 1] cutoff for auto_similarity_match — verify with caller.
    similarity_threshold: Optional[float]
    # Pattern used by auto_regex_test.
    regex_pattern: Optional[str]
    # Whether the regex is expected to match (True) or not match (False).
    regex_should_match: Optional[bool]
    # Endpoint invoked by auto_webhook_test.
    webhook_url: Optional[str]
    # Reference to a stored custom-code evaluation — assumed to be a DB id.
    custom_code_evaluation_id: Optional[str]
    # Prompt template of the LLM app under evaluation.
    llm_app_prompt_template: Optional[str]
    # Prompt template used by the AI-critique evaluator.
    evaluation_prompt_template: Optional[str]


class EvaluationType(str, Enum):
    """Closed set of supported evaluation types.

    Inherits from ``str`` so members compare and serialize as plain
    strings (each member's value equals its name).
    """

    auto_exact_match = "auto_exact_match"
    auto_similarity_match = "auto_similarity_match"
    auto_regex_test = "auto_regex_test"
    auto_webhook_test = "auto_webhook_test"
    auto_ai_critique = "auto_ai_critique"
    human_a_b_testing = "human_a_b_testing"
    human_scoring = "human_scoring"
    custom_code_run = "custom_code_run"
    single_model_test = "single_model_test"


Expand All @@ -63,7 +46,6 @@ class NewHumanEvaluation(BaseModel):
app_id: str
variant_ids: List[str]
evaluation_type: EvaluationType
evaluation_type_settings: Optional[EvaluationTypeSettings]
inputs: List[str]
testset_id: str
status: str
Expand Down Expand Up @@ -99,7 +81,6 @@ class SimpleEvaluationOutput(BaseModel):

class HumanEvaluationUpdate(BaseModel):
status: Optional[EvaluationStatusEnum]
evaluation_type_settings: Optional[EvaluationTypeSettings]


class EvaluationScenarioResult(BaseModel):
Expand Down Expand Up @@ -134,7 +115,6 @@ class HumanEvaluation(BaseModel):
user_id: str
user_username: str
evaluation_type: EvaluationType
evaluation_type_settings: Optional[EvaluationTypeSettings]
variant_ids: List[str]
variant_names: List[str]
testset_id: str
Expand Down Expand Up @@ -179,18 +159,9 @@ class EvaluationScenario(BaseModel):
results: List[EvaluationScenarioResult]


class AICritiqueCreate(BaseModel):
    """Request payload for running an AI-critique evaluation on one scenario.

    NOTE(review): field semantics inferred from names — confirm against
    the endpoint that accepts this model.
    """

    # Reference answer the LLM output is judged against.
    correct_answer: str
    # Prompt template of the LLM app under evaluation.
    llm_app_prompt_template: Optional[str]
    # Inputs that produced the outputs being critiqued.
    inputs: List[EvaluationScenarioInput]
    # Outputs to be critiqued.
    outputs: List[EvaluationScenarioOutput]
    # Prompt template driving the critique itself.
    evaluation_prompt_template: Optional[str]
    # Caller-supplied OpenAI key — presumably used instead of a server-side key.
    open_ai_key: Optional[str]


class EvaluationScenarioUpdate(BaseModel):
vote: Optional[str]
score: Optional[Union[str, int]]
score: Optional[Any]
correct_answer: Optional[str] # will be used when running custom code evaluation
outputs: Optional[List[EvaluationScenarioOutput]]
inputs: Optional[List[EvaluationScenarioInput]]
Expand Down Expand Up @@ -245,12 +216,6 @@ class EvaluationWebhook(BaseModel):
score: float


class EvaluationSettingsTemplate(BaseModel):
    """Describes one evaluation setting field.

    NOTE(review): looks like metadata for rendering/validating a settings
    form (type, default value, human-readable description) — confirm with
    consumers of this model.
    """

    # Value type of the setting (e.g. "string", "number") — TODO confirm vocabulary.
    type: str
    # Default value, stored as a string.
    default: str
    # Human-readable description of the setting.
    description: str


class LLMRunRateLimit(BaseModel):
batch_size: int
max_retries: int
Expand Down
105 changes: 67 additions & 38 deletions agenta-backend/agenta_backend/models/db_engine.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,94 @@
import logging
import os
from typing import List, Type

from beanie import Document, init_beanie
from motor.motor_asyncio import AsyncIOMotorClient
from odmantic import AIOEngine
from pymongo import MongoClient

from agenta_backend.models.db_models import (
APIKeyDB,
AppEnvironmentDB,
OrganizationDB,
UserDB,
ImageDB,
AppDB,
DeploymentDB,
VariantBaseDB,
ConfigDB,
AppVariantDB,
TemplateDB,
TestSetDB,
EvaluatorConfigDB,
HumanEvaluationDB,
HumanEvaluationScenarioDB,
EvaluationDB,
EvaluationScenarioDB,
SpanDB,
TraceDB,
)

# Configure and set logging level
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Beanie document models registered with the database at startup.
# The list holds the model *classes* themselves (they are later passed to
# ``init_beanie(document_models=...)``, which expects classes, not
# instances), hence the ``Type[Document]`` element annotation — the
# previous ``List[Document]`` incorrectly described a list of instances.
document_models: List[Type[Document]] = [
    APIKeyDB,
    AppEnvironmentDB,
    OrganizationDB,
    UserDB,
    ImageDB,
    AppDB,
    DeploymentDB,
    VariantBaseDB,
    ConfigDB,
    AppVariantDB,
    TemplateDB,
    TestSetDB,
    EvaluatorConfigDB,
    HumanEvaluationDB,
    HumanEvaluationScenarioDB,
    EvaluationDB,
    EvaluationScenarioDB,
    SpanDB,
    TraceDB,
]

class DBEngine(object):

class DBEngine:
"""
Database engine to initialize client and return engine based on mode
Database engine to initialize Beanie and return the engine based on mode.
"""

def __init__(self) -> None:
    """Read database configuration from the environment.

    Raises:
        KeyError: if ``MONGODB_URI`` is not set in the environment.
    """
    # Mode selects which database name to use; "v2" is the default.
    self.mode = os.environ.get("DATABASE_MODE", "v2")
    # Connection string is mandatory — a direct lookup fails fast when missing.
    self.db_url = os.environ["MONGODB_URI"]

@property
def initialize_client(self) -> AsyncIOMotorClient:
async def initialize_client(self):
return AsyncIOMotorClient(self.db_url)

async def init_db(self):
"""
Returns an instance of `AsyncIOMotorClient` initialized \
with the provided `db_url`.
Initialize Beanie based on the mode and store the engine.
"""

client = AsyncIOMotorClient(self.db_url)
return client
client = await self.initialize_client()
db_name = self._get_database_name(self.mode)

await init_beanie(database=client[db_name], document_models=document_models)
logger.info(f"Using {db_name} database...")

def engine(self) -> AIOEngine:
def _get_database_name(self, mode: str) -> str:
"""
Returns an AIOEngine object with a specified database name based on the mode.
Determine the appropriate database name based on the mode.
"""
if mode in ("test", "default", "v2"):
return f"agenta_{mode}"

if self.mode == "test":
aio_engine = AIOEngine(
client=self.initialize_client, database="agenta_test"
)
logger.info("Using test database...")
return aio_engine
elif self.mode == "default":
aio_engine = AIOEngine(client=self.initialize_client, database="agenta")
logger.info("Using default database...")
return aio_engine
elif self.mode == "v2":
aio_engine = AIOEngine(client=self.initialize_client, database="agenta_v2")
logger.info("Using v2 database...")
return aio_engine
else:
# make sure that self.mode does only contain alphanumeric characters
if not self.mode.isalnum():
raise ValueError("Mode of database needs to be alphanumeric.")
aio_engine = AIOEngine(
client=self.initialize_client, database=f"agenta_{self.mode}"
)
logger.info(f"Using {self.mode} database...")
return aio_engine
if not mode.isalnum():
raise ValueError("Mode of database needs to be alphanumeric.")
return f"agenta_{mode}"

def remove_db(self) -> None:
"""
Expand All @@ -67,7 +98,5 @@ def remove_db(self) -> None:
client = MongoClient(self.db_url)
if self.mode == "default":
client.drop_database("agenta")
elif self.mode == "v2":
client.drop_database("agenta_v2")
elif self.mode == "test":
client.drop_database("agenta_test")
else:
client.drop_database(f"agenta_{self.mode}")
Loading

0 comments on commit 206ae00

Please sign in to comment.