diff --git a/.github/workflows/check-python-code-black.yml b/.github/workflows/check-python-code-black.yml
index 8179105f17..4ae9298e2b 100644
--- a/.github/workflows/check-python-code-black.yml
+++ b/.github/workflows/check-python-code-black.yml
@@ -12,4 +12,4 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- - uses: psf/black@stable
+ - uses: psf/black@23.12.0
diff --git a/.gitignore b/.gitignore
index b32eb68f0d..a27893b1a4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -56,4 +56,7 @@ agenta-web/cypress/screenshots/
agenta-web/cypress/videos/
.nextjs_cache/
-rabbitmq_data
\ No newline at end of file
+rabbitmq_data
+
+# docker compose override
+docker-compose.*override.yaml
\ No newline at end of file
diff --git a/agenta-backend/agenta_backend/main.py b/agenta-backend/agenta_backend/main.py
index 81cb84dae8..c556824a1e 100644
--- a/agenta-backend/agenta_backend/main.py
+++ b/agenta-backend/agenta_backend/main.py
@@ -1,4 +1,3 @@
-import os
import asyncio
from contextlib import asynccontextmanager
@@ -12,7 +11,6 @@
human_evaluation_router,
evaluators_router,
observability_router,
- organization_router,
testset_router,
user_profile,
variants_router,
@@ -20,10 +18,11 @@
configs_router,
health_router,
)
+from agenta_backend.utils.common import isCloudEE
from agenta_backend.models.db_engine import DBEngine
from agenta_backend.open_api import open_api_tags_metadata
-if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
+if isCloudEE():
from agenta_backend.commons.services import templates_manager
else:
from agenta_backend.services import templates_manager
@@ -71,12 +70,12 @@ async def lifespan(application: FastAPI, cache=True):
allow_headers=allow_headers,
)
-if os.environ["FEATURE_FLAG"] not in ["cloud", "ee"]:
+if not isCloudEE():
from agenta_backend.services.auth_helper import authentication_middleware
app.middleware("http")(authentication_middleware)
-if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
+if isCloudEE():
import agenta_backend.cloud.main as cloud
app, allow_headers = cloud.extend_main(app)
@@ -102,13 +101,10 @@ async def lifespan(application: FastAPI, cache=True):
app.include_router(
observability_router.router, prefix="/observability", tags=["Observability"]
)
-app.include_router(
- organization_router.router, prefix="/organizations", tags=["Organizations"]
-)
app.include_router(bases_router.router, prefix="/bases", tags=["Bases"])
app.include_router(configs_router.router, prefix="/configs", tags=["Configs"])
-if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
+if isCloudEE():
import agenta_backend.cloud.main as cloud
app = cloud.extend_app_schema(app)
diff --git a/agenta-backend/agenta_backend/migrations/v0_11_0_to_v0_12_0/20240126100524_models_revamp.py b/agenta-backend/agenta_backend/migrations/v0_11_0_to_v0_12_0/20240126100524_models_revamp.py
new file mode 100644
index 0000000000..2305e27323
--- /dev/null
+++ b/agenta-backend/agenta_backend/migrations/v0_11_0_to_v0_12_0/20240126100524_models_revamp.py
@@ -0,0 +1,727 @@
+from enum import Enum
+from uuid import uuid4
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, Field
+from beanie import Document, Link, PydanticObjectId
+
+from beanie import iterative_migration
+
+
+# Common Models
+class ConfigDB(BaseModel):
+ config_name: str
+ parameters: Dict[str, Any] = Field(default_factory=dict)
+
+
+class Error(BaseModel):
+ message: str
+ stacktrace: Optional[str] = None
+
+
+class Result(BaseModel):
+ type: str
+ value: Optional[Any] = None
+ error: Optional[Error] = None
+
+
+class InvokationResult(BaseModel):
+ result: Result
+
+
+class EvaluationScenarioResult(BaseModel):
+ evaluator_config: PydanticObjectId
+ result: Result
+
+
+class AggregatedResult(BaseModel):
+ evaluator_config: PydanticObjectId
+ result: Result
+
+
+class EvaluationScenarioInputDB(BaseModel):
+ name: str
+ type: str
+ value: str
+
+
+class EvaluationScenarioOutputDB(BaseModel):
+ result: Result
+
+
+class HumanEvaluationScenarioInput(BaseModel):
+ input_name: str
+ input_value: str
+
+
+class HumanEvaluationScenarioOutput(BaseModel):
+ variant_id: str
+ variant_output: str
+
+
+class SpanDB(Document):
+ parent_span_id: Optional[str]
+ meta: Optional[Dict[str, Any]]
+ event_name: str # Function or execution name
+ event_type: Optional[str]
+ start_time: datetime
+ duration: Optional[int]
+ status: str # initiated, completed, stopped, cancelled
+ end_time: datetime = Field(default=datetime.utcnow())
+ inputs: Optional[List[str]]
+ outputs: Optional[List[str]]
+ prompt_template: Optional[str]
+ tokens_input: Optional[int]
+ tokens_output: Optional[int]
+ token_total: Optional[int]
+ cost: Optional[float]
+ tags: Optional[List[str]]
+
+ class Settings:
+ name = "spans"
+
+
+class Feedback(BaseModel):
+ uid: str = Field(default=str(uuid4()))
+ user_id: str
+ feedback: Optional[str]
+ score: Optional[float]
+ meta: Optional[Dict[str, Any]]
+ created_at: datetime
+ updated_at: datetime = Field(default=datetime.utcnow())
+
+
+class TemplateDB(Document):
+ type: Optional[str] = Field(default="image")
+ template_uri: Optional[str]
+ tag_id: Optional[int]
+ name: str = Field(unique=True) # tag name of image
+ repo_name: Optional[str]
+ title: str
+ description: str
+ size: Optional[int]
+ digest: Optional[str] # sha256 hash of image digest
+ last_pushed: Optional[datetime]
+
+ class Settings:
+ name = "templates"
+
+
+# Old DB Models
+class InvitationDB(BaseModel):
+ token: str = Field(unique=True)
+ email: str
+ expiration_date: datetime = Field(default="0")
+ used: bool = False
+
+
+class OldOrganizationDB(Document):
+ name: str = Field(default="agenta")
+ description: str = Field(default="")
+ type: Optional[str]
+ owner: str # user id
+ members: Optional[List[PydanticObjectId]]
+ invitations: Optional[List[InvitationDB]] = []
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "organizations"
+
+
+class OldUserDB(Document):
+ uid: str = Field(default="0", unique=True, index=True)
+ username: str = Field(default="agenta")
+ email: str = Field(default="demo@agenta.ai", unique=True)
+ organizations: Optional[List[PydanticObjectId]] = []
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "users"
+
+
+class OldImageDB(Document):
+ """Defines the info needed to get an image and connect it to the app variant"""
+
+ type: Optional[str] = Field(default="image")
+ template_uri: Optional[str]
+ docker_id: Optional[str] = Field(index=True)
+ tags: Optional[str]
+ deletable: bool = Field(default=True)
+ user: Link[OldUserDB]
+ organization: Link[OldOrganizationDB]
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "docker_images"
+
+
+class OldAppDB(Document):
+ app_name: str
+ organization: Link[OldOrganizationDB]
+ user: Link[OldUserDB]
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "app_db"
+
+
+class OldDeploymentDB(Document):
+ app: Link[OldAppDB]
+ organization: Link[OldOrganizationDB]
+ user: Link[OldUserDB]
+ container_name: Optional[str]
+ container_id: Optional[str]
+ uri: Optional[str]
+ status: str
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "deployments"
+
+
+class OldVariantBaseDB(Document):
+ app: Link[OldAppDB]
+ organization: Link[OldOrganizationDB]
+ user: Link[OldUserDB]
+ base_name: str
+ image: Link[OldImageDB]
+ deployment: Optional[PydanticObjectId] # Link to deployment
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "bases"
+
+
+class OldAppVariantDB(Document):
+ app: Link[OldAppDB]
+ variant_name: str
+ revision: int
+ image: Link[OldImageDB]
+ user: Link[OldUserDB]
+ modified_by: Link[OldUserDB]
+ organization: Link[OldOrganizationDB]
+ parameters: Dict[str, Any] = Field(default=dict) # TODO: deprecated. remove
+ previous_variant_name: Optional[str] # TODO: deprecated. remove
+ base_name: Optional[str]
+ base: Link[OldVariantBaseDB]
+ config_name: Optional[str]
+ config: ConfigDB
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ is_deleted: bool = Field( # TODO: deprecated. remove
+ default=False
+ ) # soft deletion for using the template variants
+
+ class Settings:
+ name = "app_variants"
+
+
+class OldAppVariantRevisionsDB(Document):
+ variant: Link[OldAppVariantDB]
+ revision: int
+ modified_by: Link[OldUserDB]
+ base: Link[OldVariantBaseDB]
+ config: ConfigDB
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "app_variant_revisions"
+
+
+class OldAppEnvironmentDB(Document):
+ app: Link[OldAppDB]
+ name: str
+ user: Link[OldUserDB]
+ organization: Link[OldOrganizationDB]
+ deployed_app_variant: Optional[PydanticObjectId]
+ deployed_app_variant_revision: Optional[Link[OldAppVariantRevisionsDB]]
+ deployment: Optional[PydanticObjectId] # reference to deployment
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "environments"
+
+
+class TemplateDB(Document):
+ type: Optional[str] = Field(default="image")
+ template_uri: Optional[str]
+ tag_id: Optional[int]
+ name: str = Field(unique=True) # tag name of image
+ repo_name: Optional[str]
+ title: str
+ description: str
+ size: Optional[int]
+ digest: Optional[str] # sha256 hash of image digest
+ last_pushed: Optional[datetime]
+
+ class Settings:
+ name = "templates"
+
+
+class OldTestSetDB(Document):
+ name: str
+ app: Link[OldAppDB]
+ csvdata: List[Dict[str, str]]
+ user: Link[OldUserDB]
+ organization: Link[OldOrganizationDB]
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "testsets"
+
+
+class OldEvaluatorConfigDB(Document):
+ app: Link[OldAppDB]
+ organization: Link[OldOrganizationDB]
+ user: Link[OldUserDB]
+ name: str
+ evaluator_key: str
+ settings_values: Dict[str, Any] = Field(default=dict)
+ created_at: datetime = Field(default=datetime.utcnow())
+ updated_at: datetime = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "evaluators_configs"
+
+
+class OldHumanEvaluationDB(Document):
+ app: Link[OldAppDB]
+ organization: Link[OldOrganizationDB]
+ user: Link[OldUserDB]
+ status: str
+ evaluation_type: str
+ variants: List[PydanticObjectId]
+ variants_revisions: List[PydanticObjectId]
+ testset: Link[OldTestSetDB]
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "human_evaluations"
+
+
+class OldHumanEvaluationScenarioDB(Document):
+ user: Link[OldUserDB]
+ organization: Link[OldOrganizationDB]
+ evaluation: Link[OldHumanEvaluationDB]
+ inputs: List[HumanEvaluationScenarioInput]
+ outputs: List[HumanEvaluationScenarioOutput]
+ vote: Optional[str]
+ score: Optional[Any]
+ correct_answer: Optional[str]
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+ is_pinned: Optional[bool]
+ note: Optional[str]
+
+ class Settings:
+ name = "human_evaluations_scenarios"
+
+
+class OldEvaluationDB(Document):
+ app: Link[OldAppDB]
+ organization: Link[OldOrganizationDB]
+ user: Link[OldUserDB]
+ status: Result
+ testset: Link[OldTestSetDB]
+ variant: PydanticObjectId
+ variant_revision: Optional[PydanticObjectId] = None
+ evaluators_configs: List[PydanticObjectId]
+ aggregated_results: List[AggregatedResult]
+ created_at: datetime = Field(default=datetime.utcnow())
+ updated_at: datetime = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "new_evaluations"
+
+
+class OldEvaluationScenarioDB(Document):
+ user: Link[OldUserDB]
+ organization: Link[OldOrganizationDB]
+ evaluation: Link[OldEvaluationDB]
+ variant_id: PydanticObjectId
+ inputs: List[EvaluationScenarioInputDB]
+ outputs: List[EvaluationScenarioOutputDB]
+ correct_answer: Optional[str]
+ is_pinned: Optional[bool]
+ note: Optional[str]
+ evaluators_configs: List[PydanticObjectId]
+ results: List[EvaluationScenarioResult]
+ created_at: datetime = Field(default=datetime.utcnow())
+ updated_at: datetime = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "new_evaluation_scenarios"
+
+
+# New DB Models
+class NewUserDB(Document):
+ uid: str = Field(default="0", unique=True, index=True)
+ username: str = Field(default="agenta")
+ email: str = Field(default="demo@agenta.ai", unique=True)
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "users"
+
+
+class NewImageDB(Document):
+ """Defines the info needed to get an image and connect it to the app variant"""
+
+ type: Optional[str] = Field(default="image")
+ template_uri: Optional[str]
+ docker_id: Optional[str] = Field(index=True)
+ tags: Optional[str]
+ deletable: bool = Field(default=True)
+ user: Link[NewUserDB]
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "docker_images"
+
+
+class NewAppDB(Document):
+ app_name: str
+ user: Link[NewUserDB]
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "app_db"
+
+
+class NewDeploymentDB(Document):
+ app: Link[NewAppDB]
+ user: Link[NewUserDB]
+ container_name: Optional[str]
+ container_id: Optional[str]
+ uri: Optional[str]
+ status: str
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "deployments"
+
+
+class NewVariantBaseDB(Document):
+ app: Link[NewAppDB]
+ user: Link[NewUserDB]
+ base_name: str
+ image: Link[NewImageDB]
+ deployment: Optional[PydanticObjectId] # Link to deployment
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "bases"
+
+
+class NewAppVariantDB(Document):
+ app: Link[NewAppDB]
+ variant_name: str
+ revision: int
+ image: Link[NewImageDB]
+ user: Link[NewUserDB]
+ modified_by: Link[NewUserDB]
+ parameters: Dict[str, Any] = Field(default=dict) # TODO: deprecated. remove
+ previous_variant_name: Optional[str] # TODO: deprecated. remove
+ base_name: Optional[str]
+ base: Link[NewVariantBaseDB]
+ config_name: Optional[str]
+ config: ConfigDB
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ is_deleted: bool = Field( # TODO: deprecated. remove
+ default=False
+ ) # soft deletion for using the template variants
+
+ class Settings:
+ name = "app_variants"
+
+
+class NewAppVariantRevisionsDB(Document):
+ variant: Link[NewAppVariantDB]
+ revision: int
+ modified_by: Link[NewUserDB]
+ base: Link[NewVariantBaseDB]
+ config: ConfigDB
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "app_variant_revisions"
+
+
+class NewAppEnvironmentDB(Document):
+ app: Link[NewAppDB]
+ name: str
+ user: Link[NewUserDB]
+ deployed_app_variant: Optional[PydanticObjectId]
+ deployed_app_variant_revision: Optional[Link[NewAppVariantRevisionsDB]]
+ deployment: Optional[PydanticObjectId] # reference to deployment
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "environments"
+
+
+class NewTestSetDB(Document):
+ name: str
+ app: Link[NewAppDB]
+ csvdata: List[Dict[str, str]]
+ user: Link[NewUserDB]
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "testsets"
+
+
+class NewEvaluatorConfigDB(Document):
+ app: Link[NewAppDB]
+ user: Link[NewUserDB]
+ name: str
+ evaluator_key: str
+ settings_values: Dict[str, Any] = Field(default=dict)
+ created_at: datetime = Field(default=datetime.utcnow())
+ updated_at: datetime = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "evaluators_configs"
+
+
+class NewHumanEvaluationDB(Document):
+ app: Link[NewAppDB]
+ user: Link[NewUserDB]
+ status: str
+ evaluation_type: str
+ variants: List[PydanticObjectId]
+ variants_revisions: List[PydanticObjectId]
+ testset: Link[NewTestSetDB]
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "human_evaluations"
+
+
+class NewHumanEvaluationScenarioDB(Document):
+ user: Link[NewUserDB]
+ evaluation: Link[NewHumanEvaluationDB]
+ inputs: List[HumanEvaluationScenarioInput]
+ outputs: List[HumanEvaluationScenarioOutput]
+ vote: Optional[str]
+ score: Optional[Any]
+ correct_answer: Optional[str]
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+ is_pinned: Optional[bool]
+ note: Optional[str]
+
+ class Settings:
+ name = "human_evaluations_scenarios"
+
+
+class NewEvaluationDB(Document):
+ app: Link[NewAppDB]
+ user: Link[NewUserDB]
+ status: Result
+ testset: Link[NewTestSetDB]
+ variant: PydanticObjectId
+ variant_revision: Optional[PydanticObjectId] = None
+ evaluators_configs: List[PydanticObjectId]
+ aggregated_results: List[AggregatedResult]
+ created_at: datetime = Field(default=datetime.utcnow())
+ updated_at: datetime = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "new_evaluations"
+
+
+class NewEvaluationScenarioDB(Document):
+ user: Link[NewUserDB]
+ evaluation: Link[NewEvaluationDB]
+ variant_id: PydanticObjectId
+ inputs: List[EvaluationScenarioInputDB]
+ outputs: List[EvaluationScenarioOutputDB]
+ correct_answer: Optional[str]
+ is_pinned: Optional[bool]
+ note: Optional[str]
+ evaluators_configs: List[PydanticObjectId]
+ results: List[EvaluationScenarioResult]
+ created_at: datetime = Field(default=datetime.utcnow())
+ updated_at: datetime = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "new_evaluation_scenarios"
+
+
+class Forward:
+ @iterative_migration(
+ document_models=[
+ OldUserDB,
+ NewUserDB,
+ ]
+ )
+ async def remove_organization_from_user_model(
+ self, input_document: OldUserDB, output_document: NewUserDB
+ ):
+ input_document.dict(exclude={"organizations"})
+
+ @iterative_migration(
+ document_models=[
+ OldAppDB,
+ NewAppDB,
+ ]
+ )
+ async def remove_organization_from_app_model(
+ self, input_document: OldAppDB, output_document: NewAppDB
+ ):
+ input_document.dict(exclude={"organization"})
+
+ @iterative_migration(
+ document_models=[
+ OldImageDB,
+ NewImageDB,
+ ]
+ )
+ async def remove_organization_from_image_model(
+ self, input_document: OldImageDB, output_document: NewImageDB
+ ):
+ input_document.dict(exclude={"organization"})
+
+ @iterative_migration(
+ document_models=[
+ OldTestSetDB,
+ NewTestSetDB,
+ ]
+ )
+ async def remove_organization_from_testset_model(
+ self, input_document: OldTestSetDB, output_document: NewTestSetDB
+ ):
+ input_document.dict(exclude={"organization"})
+
+ @iterative_migration(
+ document_models=[
+ OldVariantBaseDB,
+ NewVariantBaseDB,
+ ]
+ )
+ async def remove_organization_from_variant_base_model(
+ self, input_document: OldVariantBaseDB, output_document: NewVariantBaseDB
+ ):
+ input_document.dict(exclude={"organization"})
+
+ @iterative_migration(
+ document_models=[
+ OldAppVariantDB,
+ NewVariantBaseDB,
+ ]
+ )
+ async def remove_organization_from_app_variant_model(
+ self, input_document: OldAppVariantDB, output_document: NewAppVariantDB
+ ):
+ input_document.dict(exclude={"organization"})
+
+ @iterative_migration(
+ document_models=[
+ OldEvaluationDB,
+ NewEvaluationDB,
+ ]
+ )
+ async def remove_organization_from_evaluation_model(
+ self, input_document: OldEvaluationDB, output_document: NewEvaluationDB
+ ):
+ input_document.dict(exclude={"organization"})
+
+ @iterative_migration(
+ document_models=[
+ OldDeploymentDB,
+ NewDeploymentDB,
+ ]
+ )
+ async def remove_organization_from_deployment_model(
+ self, input_document: OldDeploymentDB, output_document: NewDeploymentDB
+ ):
+ input_document.dict(exclude={"organization"})
+
+ @iterative_migration(
+ document_models=[
+ OldAppEnvironmentDB,
+ NewAppEnvironmentDB,
+ ]
+ )
+ async def remove_organization_from_app_environment_model(
+ self, input_document: OldAppEnvironmentDB, output_document: NewAppEnvironmentDB
+ ):
+ input_document.dict(exclude={"organization"})
+
+ @iterative_migration(
+ document_models=[
+ OldEvaluatorConfigDB,
+ NewEvaluatorConfigDB,
+ ]
+ )
+ async def remove_organization_from_evaluator_config_model(
+ self,
+ input_document: OldEvaluatorConfigDB,
+ output_document: NewEvaluatorConfigDB,
+ ):
+ input_document.dict(exclude={"organization"})
+
+ @iterative_migration(
+ document_models=[
+ OldHumanEvaluationDB,
+ NewHumanEvaluationDB,
+ ]
+ )
+ async def remove_organization_from_human_evaluation_model(
+ self,
+ input_document: OldHumanEvaluationDB,
+ output_document: NewHumanEvaluationDB,
+ ):
+ input_document.dict(exclude={"organization"})
+
+ @iterative_migration(
+ document_models=[
+ OldEvaluationScenarioDB,
+ NewEvaluationScenarioDB,
+ ]
+ )
+ async def remove_organization_from_evaluation_scenario_model(
+ self,
+ input_document: OldEvaluationScenarioDB,
+ output_document: NewEvaluationScenarioDB,
+ ):
+ input_document.dict(exclude={"organization"})
+
+ @iterative_migration(
+ document_models=[
+ OldHumanEvaluationScenarioDB,
+ NewHumanEvaluationScenarioDB,
+ ]
+ )
+ async def remove_organization_from_app_environment_model(
+ self,
+ input_document: OldHumanEvaluationScenarioDB,
+ output_document: NewHumanEvaluationScenarioDB,
+ ):
+ input_document.dict(exclude={"organization"})
+
+
+class Backward:
+ pass
diff --git a/agenta-backend/agenta_backend/migrations/v0_11_0_to_v0_12_0/20240126144938_drop_organization_model.py b/agenta-backend/agenta_backend/migrations/v0_11_0_to_v0_12_0/20240126144938_drop_organization_model.py
new file mode 100644
index 0000000000..784c76c68e
--- /dev/null
+++ b/agenta-backend/agenta_backend/migrations/v0_11_0_to_v0_12_0/20240126144938_drop_organization_model.py
@@ -0,0 +1,39 @@
+from datetime import datetime
+from typing import List, Optional
+from pydantic import BaseModel, Field
+
+from beanie import Document, PydanticObjectId, free_fall_migration
+
+
+class InvitationDB(BaseModel):
+ token: str = Field(unique=True)
+ email: str
+ expiration_date: datetime = Field(default="0")
+ used: bool = False
+
+
+class OldOrganizationDB(Document):
+ name: str = Field(default="agenta")
+ description: str = Field(default="")
+ type: Optional[str]
+ owner: str # user id
+ members: Optional[List[PydanticObjectId]]
+ invitations: Optional[List[InvitationDB]] = []
+ created_at: Optional[datetime] = Field(default=datetime.utcnow())
+ updated_at: Optional[datetime] = Field(default=datetime.utcnow())
+
+ class Settings:
+ name = "organizations"
+
+
+class Forward:
+ @free_fall_migration(document_models=[OldOrganizationDB])
+ async def drop_old_organization_db(self, session):
+ # Wrap deletion loop in a with_transaction context for potential rollback
+ async with session.start_transaction():
+ async for old_organization in OldOrganizationDB.find_all():
+ await old_organization.delete()
+
+
+class Backward:
+ pass
diff --git a/agenta-backend/agenta_backend/models/api/api_models.py b/agenta-backend/agenta_backend/models/api/api_models.py
index b694f4e360..365a657b4c 100644
--- a/agenta-backend/agenta_backend/models/api/api_models.py
+++ b/agenta-backend/agenta_backend/models/api/api_models.py
@@ -42,7 +42,6 @@ class VariantAction(BaseModel):
class CreateApp(BaseModel):
app_name: str
- organization_id: Optional[str] = None
class CreateAppOutput(BaseModel):
@@ -64,7 +63,6 @@ class AppVariant(BaseModel):
variant_name: str
parameters: Optional[Dict[str, Any]]
previous_variant_name: Optional[str]
- organization_id: Optional[str] = None
base_name: Optional[str]
config_name: Optional[str]
@@ -73,14 +71,13 @@ class AppVariantFromImagePayload(BaseModel):
variant_name: str
-class AppVariantOutput(BaseModel):
+class AppVariantResponse(BaseModel):
app_id: str
app_name: str
variant_id: str
variant_name: str
parameters: Optional[Dict[str, Any]]
previous_variant_name: Optional[str]
- organization_id: str
user_id: str
base_name: str
base_id: str
@@ -102,7 +99,6 @@ class AppVariantOutputExtended(BaseModel):
variant_name: str
parameters: Optional[Dict[str, Any]]
previous_variant_name: Optional[str]
- organization_id: str
user_id: str
base_name: str
base_id: str
@@ -151,7 +147,6 @@ class AppVariantFromImage(BaseModel):
variant_name: str
parameters: Optional[Dict[str, Any]]
previous_variant_name: Optional[str]
- organization_id: Optional[str] = None
class RestartAppContainer(BaseModel):
@@ -162,7 +157,6 @@ class Image(BaseModel):
type: Optional[str]
docker_id: str
tags: str
- organization_id: Optional[str] = None
class AddVariantFromImagePayload(BaseModel):
@@ -215,15 +209,6 @@ class CreateAppVariant(BaseModel):
app_name: str
template_id: str
env_vars: Dict[str, str]
- organization_id: Optional[str] = None
-
-
-class InviteRequest(BaseModel):
- email: str
-
-
-class InviteToken(BaseModel):
- token: str
class Environment(BaseModel):
@@ -231,7 +216,6 @@ class Environment(BaseModel):
deployed_app_variant: Optional[str]
deployed_base_name: Optional[str]
deployed_config_name: Optional[str]
- organization_id: Optional[str] = None
class DeployToEnvironmentPayload(BaseModel):
@@ -255,13 +239,6 @@ class PostVariantConfigPayload(BaseModel):
overwrite: bool
-class ListAPIKeysOutput(BaseModel):
- prefix: str
- created_at: datetime
- last_used_at: datetime = None
- expiration_date: datetime = None
-
-
class BaseOutput(BaseModel):
base_id: str
base_name: str
diff --git a/agenta-backend/agenta_backend/models/converters.py b/agenta-backend/agenta_backend/models/converters.py
index 3c4f27202d..43021f272e 100644
--- a/agenta-backend/agenta_backend/models/converters.py
+++ b/agenta-backend/agenta_backend/models/converters.py
@@ -2,65 +2,100 @@
"""
import json
+import logging
from typing import List
+
from agenta_backend.services import db_manager
+from agenta_backend.utils.common import isCloudEE
from agenta_backend.models.api.user_models import User
-from agenta_backend.models.db_models import (
- AppVariantDB,
- AppVariantRevisionsDB,
- EvaluationScenarioResult,
- EvaluatorConfigDB,
- HumanEvaluationDB,
- HumanEvaluationScenarioDB,
- ImageDB,
- TemplateDB,
- AppDB,
- AppEnvironmentDB,
- AppEnvironmentRevisionDB,
- TestSetDB,
- SpanDB,
- TraceDB,
- Feedback as FeedbackDB,
- EvaluationDB,
- EvaluationScenarioDB,
- VariantBaseDB,
- UserDB,
- AggregatedResult,
-)
-from agenta_backend.models.api.api_models import (
- AppVariant,
- AppVariantRevision,
- AppVariantOutputExtended,
- ImageExtended,
- Template,
- TemplateImageInfo,
- AppVariantOutput,
- App,
- EnvironmentOutput,
- EnvironmentRevision,
- EnvironmentOutputExtended,
- TestSetOutput,
- BaseOutput,
-)
from agenta_backend.models.api.observability_models import (
Span,
Trace,
Feedback as FeedbackOutput,
)
from agenta_backend.models.api.evaluation_model import (
- HumanEvaluation,
- HumanEvaluationScenario,
- SimpleEvaluationOutput,
- EvaluationScenario,
Evaluation,
+ HumanEvaluation,
EvaluatorConfig,
+ EvaluationScenario,
+ SimpleEvaluationOutput,
EvaluationScenarioInput,
+ HumanEvaluationScenario,
EvaluationScenarioOutput,
)
-import logging
+if isCloudEE():
+ from agenta_backend.commons.models.db_models import (
+ AppDB_ as AppDB,
+ UserDB_ as UserDB,
+ ImageDB_ as ImageDB,
+ TestSetDB_ as TestSetDB,
+ EvaluationDB_ as EvaluationDB,
+ AppVariantDB_ as AppVariantDB,
+ VariantBaseDB_ as VariantBaseDB,
+ AppEnvironmentDB_ as AppEnvironmentDB,
+ AppEnvironmentRevisionDB_ as AppEnvironmentRevisionDB,
+ EvaluatorConfigDB_ as EvaluatorConfigDB,
+ HumanEvaluationDB_ as HumanEvaluationDB,
+ EvaluationScenarioDB_ as EvaluationScenarioDB,
+ HumanEvaluationScenarioDB_ as HumanEvaluationScenarioDB,
+ )
+ from agenta_backend.commons.models.api.api_models import (
+ AppVariant_ as AppVariant,
+ ImageExtended_ as ImageExtended,
+ AppVariantResponse_ as AppVariantResponse,
+ AppVariantOutputExtended_ as AppVariantOutputExtended,
+ EnvironmentRevision_ as EnvironmentRevision,
+ EnvironmentOutput_ as EnvironmentOutput,
+ EnvironmentOutputExtended_ as EnvironmentOutputExtended,
+ )
+else:
+ from agenta_backend.models.db_models import (
+ AppDB,
+ UserDB,
+ ImageDB,
+ TestSetDB,
+ EvaluationDB,
+ AppVariantDB,
+ VariantBaseDB,
+ AppEnvironmentDB,
+ AppEnvironmentRevisionDB,
+ EvaluatorConfigDB,
+ HumanEvaluationDB,
+ EvaluationScenarioDB,
+ HumanEvaluationScenarioDB,
+ )
+ from agenta_backend.models.api.api_models import (
+ AppVariant,
+ ImageExtended,
+ AppVariantResponse,
+ AppVariantOutputExtended,
+ EnvironmentRevision,
+ EnvironmentOutput,
+ EnvironmentOutputExtended,
+ )
+
+from agenta_backend.models.db_models import (
+ SpanDB,
+ TraceDB,
+ TemplateDB,
+ AggregatedResult,
+ AppVariantRevisionsDB,
+ Feedback as FeedbackDB,
+ EvaluationScenarioResult,
+)
+from agenta_backend.models.api.api_models import (
+ App,
+ Template,
+ BaseOutput,
+ TestSetOutput,
+ TemplateImageInfo,
+ AppVariantRevision,
+)
+
from beanie import Link
+
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -234,19 +269,24 @@ def evaluation_scenario_db_to_pydantic(
def app_variant_db_to_pydantic(
app_variant_db: AppVariantDB, previous_variant_name: str = None
) -> AppVariant:
- return AppVariant(
+ app_variant = AppVariant(
app_id=str(app_variant_db.app.id),
app_name=app_variant_db.app.app_name,
variant_name=app_variant_db.variant_name,
parameters=app_variant_db.config.parameters,
previous_variant_name=app_variant_db.previous_variant_name,
- organization_id=str(app_variant_db.organization.id),
base_name=app_variant_db.base_name,
config_name=app_variant_db.config_name,
)
+ if isCloudEE():
+ app_variant.organization_id = str(app_variant_db.organization.id)
+ app_variant.workspace_id = str(app_variant_db.workspace.id)
-async def app_variant_db_to_output(app_variant_db: AppVariantDB) -> AppVariantOutput:
+ return app_variant
+
+
+async def app_variant_db_to_output(app_variant_db: AppVariantDB) -> AppVariantResponse:
if app_variant_db.base.deployment:
deployment = await db_manager.get_deployment_by_objectid(
app_variant_db.base.deployment
@@ -256,13 +296,12 @@ async def app_variant_db_to_output(app_variant_db: AppVariantDB) -> AppVariantOu
deployment = None
uri = None
logger.info(f"uri: {uri} deployment: {app_variant_db.base.deployment} {deployment}")
- return AppVariantOutput(
+ variant_response = AppVariantResponse(
app_id=str(app_variant_db.app.id),
app_name=str(app_variant_db.app.app_name),
variant_name=app_variant_db.variant_name,
variant_id=str(app_variant_db.id),
user_id=str(app_variant_db.user.id),
- organization_id=str(app_variant_db.organization.id),
parameters=app_variant_db.config.parameters,
previous_variant_name=app_variant_db.previous_variant_name,
base_name=app_variant_db.base_name,
@@ -271,10 +310,16 @@ async def app_variant_db_to_output(app_variant_db: AppVariantDB) -> AppVariantOu
uri=uri,
)
+ if isCloudEE():
+ variant_response.organization_id = str(app_variant_db.organization.id)
+ variant_response.workspace_id = str(app_variant_db.workspace.id)
+
+ return variant_response
+
async def app_variant_db_and_revision_to_extended_output(
app_variant_db: AppVariantDB, app_variant_revisions_db: AppVariantRevisionsDB
-) -> AppVariantOutput:
+) -> AppVariantResponse:
if app_variant_db.base.deployment:
deployment = await db_manager.get_deployment_by_objectid(
app_variant_db.base.deployment
@@ -295,13 +340,12 @@ async def app_variant_db_and_revision_to_extended_output(
created_at=app_variant_revision_db.created_at,
)
)
- return AppVariantOutputExtended(
+ variant_extended = AppVariantOutputExtended(
app_id=str(app_variant_db.app.id),
app_name=str(app_variant_db.app.app_name),
variant_name=app_variant_db.variant_name,
variant_id=str(app_variant_db.id),
user_id=str(app_variant_db.user.id),
- organization_id=str(app_variant_db.organization.id),
parameters=app_variant_db.config.parameters,
previous_variant_name=app_variant_db.previous_variant_name,
base_name=app_variant_db.base_name,
@@ -312,6 +356,12 @@ async def app_variant_db_and_revision_to_extended_output(
revisions=app_variant_revisions,
)
+ if isCloudEE():
+ variant_extended.organization_id = str(app_variant_db.organization.id)
+ variant_extended.workspace_id = str(app_variant_db.workspace.id)
+
+ return variant_extended
+
async def environment_db_to_output(
environment_db: AppEnvironmentDB,
@@ -331,7 +381,7 @@ async def environment_db_to_output(
deployed_variant_name = None
revision = None
- return EnvironmentOutput(
+ environment_output = EnvironmentOutput(
name=environment_db.name,
app_id=str(environment_db.app.id),
deployed_app_variant_id=deployed_app_variant_id,
@@ -342,6 +392,11 @@ async def environment_db_to_output(
revision=revision,
)
+ if isCloudEE():
+ environment_output.organization_id = str(environment_db.organization.id)
+ environment_output.workspace_id = str(environment_db.workspace.id)
+ return environment_output
+
async def environment_db_and_revision_to_extended_output(
environment_db: AppEnvironmentDB,
@@ -374,7 +429,7 @@ async def environment_db_and_revision_to_extended_output(
created_at=app_environment_revision.created_at,
)
)
- return EnvironmentOutputExtended(
+ environment_output_extended = EnvironmentOutputExtended(
name=environment_db.name,
app_id=str(environment_db.app.id),
deployed_app_variant_id=deployed_app_variant_id,
@@ -386,6 +441,13 @@ async def environment_db_and_revision_to_extended_output(
revisions=app_environment_revisions,
)
+ if isCloudEE():
+ environment_output_extended.organization_id = str(
+ environment_db.organization.id
+ )
+ environment_output_extended.workspace_id = str(environment_db.workspace.id)
+ return environment_output_extended
+
def base_db_to_pydantic(base_db: VariantBaseDB) -> BaseOutput:
return BaseOutput(base_id=str(base_db.id), base_name=base_db.base_name)
@@ -396,13 +458,18 @@ def app_db_to_pydantic(app_db: AppDB) -> App:
def image_db_to_pydantic(image_db: ImageDB) -> ImageExtended:
- return ImageExtended(
- organization_id=str(image_db.organization.id),
+ image = ImageExtended(
docker_id=image_db.docker_id,
tags=image_db.tags,
id=str(image_db.id),
)
+ if isCloudEE():
+ image.organization_id = str(image_db.organization.id)
+ image.workspace_id = str(image_db.workspace.id)
+
+ return image
+
def templates_db_to_pydantic(templates_db: List[TemplateDB]) -> List[Template]:
return [
diff --git a/agenta-backend/agenta_backend/models/db_engine.py b/agenta-backend/agenta_backend/models/db_engine.py
index e9d8252c00..fb812e2bab 100644
--- a/agenta-backend/agenta_backend/models/db_engine.py
+++ b/agenta-backend/agenta_backend/models/db_engine.py
@@ -6,57 +6,83 @@
from beanie import init_beanie, Document
from motor.motor_asyncio import AsyncIOMotorClient
+from agenta_backend.utils.common import isCloudEE
+
+if isCloudEE():
+ from agenta_backend.commons.models.db_models import (
+ APIKeyDB,
+ WorkspaceDB,
+ OrganizationDB,
+ AppDB_ as AppDB,
+ UserDB_ as UserDB,
+ ImageDB_ as ImageDB,
+ TestSetDB_ as TestSetDB,
+ AppVariantDB_ as AppVariantDB,
+ EvaluationDB_ as EvaluationDB,
+ DeploymentDB_ as DeploymentDB,
+ VariantBaseDB_ as VariantBaseDB,
+ AppEnvironmentDB_ as AppEnvironmentDB,
+ AppEnvironmentRevisionDB_ as AppEnvironmentRevisionDB,
+ EvaluatorConfigDB_ as EvaluatorConfigDB,
+ HumanEvaluationDB_ as HumanEvaluationDB,
+ EvaluationScenarioDB_ as EvaluationScenarioDB,
+ HumanEvaluationScenarioDB_ as HumanEvaluationScenarioDB,
+ )
+else:
+ from agenta_backend.models.db_models import (
+ AppDB,
+ UserDB,
+ ImageDB,
+ TestSetDB,
+ EvaluationDB,
+ DeploymentDB,
+ AppVariantDB,
+ VariantBaseDB,
+ AppEnvironmentDB,
+ AppEnvironmentRevisionDB,
+ EvaluatorConfigDB,
+ HumanEvaluationDB,
+ EvaluationScenarioDB,
+ HumanEvaluationScenarioDB,
+ )
+
from agenta_backend.models.db_models import (
- APIKeyDB,
- AppEnvironmentDB,
- AppEnvironmentRevisionDB,
- OrganizationDB,
- UserDB,
- ImageDB,
- AppDB,
- DeploymentDB,
- VariantBaseDB,
- AppVariantRevisionsDB,
- AppVariantDB,
- TemplateDB,
- TestSetDB,
- EvaluatorConfigDB,
- HumanEvaluationDB,
- HumanEvaluationScenarioDB,
- EvaluationDB,
- EvaluationScenarioDB,
SpanDB,
TraceDB,
+ TemplateDB,
+ AppVariantRevisionsDB,
)
-# Configure and set logging level
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
# Define Document Models
document_models: List[Document] = [
- APIKeyDB,
- AppEnvironmentDB,
- AppEnvironmentRevisionDB,
- OrganizationDB,
+ AppDB,
UserDB,
+ SpanDB,
+ TraceDB,
ImageDB,
- AppDB,
+ TestSetDB,
+ TemplateDB,
+ AppVariantDB,
DeploymentDB,
+ EvaluationDB,
VariantBaseDB,
- AppVariantDB,
- AppVariantRevisionsDB,
- TemplateDB,
- TestSetDB,
+ AppEnvironmentDB,
+ AppEnvironmentRevisionDB,
EvaluatorConfigDB,
HumanEvaluationDB,
- HumanEvaluationScenarioDB,
- EvaluationDB,
EvaluationScenarioDB,
- SpanDB,
- TraceDB,
+ AppVariantRevisionsDB,
+ HumanEvaluationScenarioDB,
]
+if isCloudEE():
+ document_models = document_models + [OrganizationDB, WorkspaceDB, APIKeyDB]
+
+
+# Configure and set logging level
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
class DBEngine:
"""
diff --git a/agenta-backend/agenta_backend/models/db_models.py b/agenta-backend/agenta_backend/models/db_models.py
index f44ba6b460..9c086f7662 100644
--- a/agenta-backend/agenta_backend/models/db_models.py
+++ b/agenta-backend/agenta_backend/models/db_models.py
@@ -1,3 +1,4 @@
+from enum import Enum
from uuid import uuid4
from datetime import datetime
from typing import Any, Dict, List, Optional
@@ -6,47 +7,10 @@
from beanie import Document, Link, PydanticObjectId
-class APIKeyDB(Document):
- prefix: str
- hashed_key: str
- user_id: str
- rate_limit: int = Field(default=0)
- hidden: Optional[bool] = Field(default=False)
- expiration_date: Optional[datetime]
- created_at: Optional[datetime] = datetime.now()
- updated_at: Optional[datetime]
-
- class Settings:
- name = "api_keys"
-
-
-class InvitationDB(BaseModel):
- token: str = Field(unique=True)
- email: str
- expiration_date: datetime = Field(default="0")
- used: bool = False
-
-
-class OrganizationDB(Document):
- name: str = Field(default="agenta")
- description: str = Field(default="")
- type: Optional[str]
- owner: str # user id
- members: Optional[List[PydanticObjectId]]
- invitations: Optional[List[InvitationDB]] = []
- created_at: Optional[datetime] = Field(default=datetime.now())
- updated_at: Optional[datetime] = Field(default=datetime.now())
- is_paying: Optional[bool] = Field(default=False)
-
- class Settings:
- name = "organizations"
-
-
class UserDB(Document):
uid: str = Field(default="0", unique=True, index=True)
username: str = Field(default="agenta")
email: str = Field(default="demo@agenta.ai", unique=True)
- organizations: Optional[List[PydanticObjectId]] = []
created_at: Optional[datetime] = Field(default=datetime.now())
updated_at: Optional[datetime] = Field(default=datetime.now())
@@ -63,7 +27,6 @@ class ImageDB(Document):
tags: Optional[str]
deletable: bool = Field(default=True)
user: Link[UserDB]
- organization: Link[OrganizationDB]
created_at: Optional[datetime] = Field(default=datetime.now())
updated_at: Optional[datetime] = Field(default=datetime.now())
@@ -73,7 +36,6 @@ class Settings:
class AppDB(Document):
app_name: str
- organization: Link[OrganizationDB]
user: Link[UserDB]
created_at: Optional[datetime] = Field(default=datetime.now())
updated_at: Optional[datetime] = Field(default=datetime.now())
@@ -84,7 +46,6 @@ class Settings:
class DeploymentDB(Document):
app: Link[AppDB]
- organization: Link[OrganizationDB]
user: Link[UserDB]
container_name: Optional[str]
container_id: Optional[str]
@@ -99,7 +60,6 @@ class Settings:
class VariantBaseDB(Document):
app: Link[AppDB]
- organization: Link[OrganizationDB]
user: Link[UserDB]
base_name: str
image: Link[ImageDB]
@@ -123,7 +83,6 @@ class AppVariantDB(Document):
image: Link[ImageDB]
user: Link[UserDB]
modified_by: Link[UserDB]
- organization: Link[OrganizationDB]
parameters: Dict[str, Any] = Field(default=dict) # TODO: deprecated. remove
previous_variant_name: Optional[str] # TODO: deprecated. remove
base_name: Optional[str]
@@ -159,7 +118,6 @@ class AppEnvironmentDB(Document):
name: str
user: Link[UserDB]
revision: int
- organization: Link[OrganizationDB]
deployed_app_variant: Optional[PydanticObjectId]
deployed_app_variant_revision: Optional[Link[AppVariantRevisionsDB]]
deployment: Optional[PydanticObjectId] # reference to deployment
@@ -202,7 +160,6 @@ class TestSetDB(Document):
app: Link[AppDB]
csvdata: List[Dict[str, str]]
user: Link[UserDB]
- organization: Link[OrganizationDB]
created_at: Optional[datetime] = Field(default=datetime.now())
updated_at: Optional[datetime] = Field(default=datetime.now())
@@ -212,7 +169,6 @@ class Settings:
class EvaluatorConfigDB(Document):
app: Link[AppDB]
- organization: Link[OrganizationDB]
user: Link[UserDB]
name: str
evaluator_key: str
@@ -271,7 +227,6 @@ class HumanEvaluationScenarioOutput(BaseModel):
class HumanEvaluationDB(Document):
app: Link[AppDB]
- organization: Link[OrganizationDB]
user: Link[UserDB]
status: str
evaluation_type: str
@@ -287,7 +242,6 @@ class Settings:
class HumanEvaluationScenarioDB(Document):
user: Link[UserDB]
- organization: Link[OrganizationDB]
evaluation: Link[HumanEvaluationDB]
inputs: List[HumanEvaluationScenarioInput]
outputs: List[HumanEvaluationScenarioOutput]
@@ -305,7 +259,6 @@ class Settings:
class EvaluationDB(Document):
app: Link[AppDB]
- organization: Link[OrganizationDB]
user: Link[UserDB]
status: Result
testset: Link[TestSetDB]
@@ -322,7 +275,6 @@ class Settings:
class EvaluationScenarioDB(Document):
user: Link[UserDB]
- organization: Link[OrganizationDB]
evaluation: Link[EvaluationDB]
variant_id: PydanticObjectId
inputs: List[EvaluationScenarioInputDB]
diff --git a/agenta-backend/agenta_backend/routers/app_router.py b/agenta-backend/agenta_backend/routers/app_router.py
index 55267bfb04..730cba1f15 100644
--- a/agenta-backend/agenta_backend/routers/app_router.py
+++ b/agenta-backend/agenta_backend/routers/app_router.py
@@ -1,46 +1,70 @@
import os
import logging
+
+from typing import List, Optional
from docker.errors import DockerException
from fastapi.responses import JSONResponse
-from agenta_backend.config import settings
-from typing import List, Optional
from fastapi import HTTPException, Request
-from agenta_backend.utils.common import APIRouter
-from agenta_backend.services.selectors import get_user_own_org
+from beanie import PydanticObjectId as ObjectId
+
+from agenta_backend.config import settings
+from agenta_backend.models import converters
+from agenta_backend.utils.common import (
+ isEE,
+ isCloud,
+ APIRouter,
+ isCloudEE,
+)
+
from agenta_backend.services import (
- app_manager,
db_manager,
+ app_manager,
evaluator_manager,
)
-from agenta_backend.utils.common import (
- check_access_to_app,
- check_user_org_access,
-)
from agenta_backend.models.api.api_models import (
App,
- Image,
- CreateApp,
CreateAppOutput,
- CreateAppVariant,
- AppVariantOutput,
AddVariantFromImagePayload,
- EnvironmentOutput,
- EnvironmentOutputExtended,
)
-from agenta_backend.models import converters
-if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
+if isCloudEE():
+ from agenta_backend.commons.models.api.api_models import (
+ Image_ as Image,
+ CreateApp_ as CreateApp,
+ AppVariantResponse_ as AppVariantResponse,
+ CreateAppVariant_ as CreateAppVariant,
+ EnvironmentOutput_ as EnvironmentOutput,
+ EnvironmentOutputExtended_ as EnvironmentOutputExtended,
+ )
+else:
+ from agenta_backend.models.api.api_models import (
+ Image,
+ CreateApp,
+ AppVariantResponse,
+ CreateAppVariant,
+ EnvironmentOutput,
+ EnvironmentOutputExtended,
+ )
+if isCloudEE():
+ from agenta_backend.commons.services import db_manager_ee
from agenta_backend.commons.services.selectors import (
- get_user_and_org_id,
+ get_user_own_org,
+ get_user_org_and_workspace_id,
+ get_org_default_workspace,
) # noqa pylint: disable-all
-else:
- from agenta_backend.services.selectors import get_user_and_org_id
+ from agenta_backend.commons.utils.permissions import (
+ check_action_access,
+ check_rbac_permission,
+ check_apikey_action_access,
+ )
+ from agenta_backend.commons.models.db_models import Permission
-if os.environ["FEATURE_FLAG"] in ["cloud"]:
+
+if isCloud():
from agenta_backend.cloud.services import (
lambda_deployment_manager as deployment_manager,
) # noqa pylint: disable-all
-elif os.environ["FEATURE_FLAG"] in ["ee"]:
+elif isEE():
from agenta_backend.ee.services import (
deployment_manager,
) # noqa pylint: disable-all
@@ -54,7 +78,7 @@
@router.get(
"/{app_id}/variants/",
- response_model=List[AppVariantOutput],
+ response_model=List[AppVariantResponse],
operation_id="list_app_variants",
)
async def list_app_variants(
@@ -69,31 +93,92 @@ async def list_app_variants(
stoken_session (SessionContainer, optional): The session container to verify the user's session. Defaults to Depends(verify_session()).
Returns:
- List[AppVariantOutput]: A list of app variants for the given app ID.
+ List[AppVariantResponse]: A list of app variants for the given app ID.
"""
try:
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
-
- access_app = await check_access_to_app(
- user_org_data=user_org_data, app_id=app_id
- )
- if not access_app:
- error_msg = f"You cannot access app: {app_id}"
- logger.error(error_msg)
- return JSONResponse(
- {"detail": error_msg},
- status_code=403,
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=app_id,
+ object_type="app",
+ permission=Permission.VIEW_APPLICATION,
)
+ logger.debug(f"User has Permission to list app variants: {has_permission}")
+ if not has_permission:
+ error_msg = f"You do not have access to perform this action. Please contact your organization admin."
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
- app_variants = await db_manager.list_app_variants(
- app_id=app_id, **user_org_data
- )
+ app_variants = await db_manager.list_app_variants(app_id=app_id)
return [
await converters.app_variant_db_to_output(app_variant)
for app_variant in app_variants
]
except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.get(
+ "/get_variant_by_env/",
+ response_model=AppVariantResponse,
+ operation_id="get_variant_by_env",
+)
+async def get_variant_by_env(
+ app_id: str,
+ environment: str,
+ request: Request,
+):
+ """
+ Retrieve the app variant based on the provided app_id and environment.
+
+ Args:
+ app_id (str): The ID of the app to retrieve the variant for.
+ environment (str): The environment of the app variant to retrieve.
+ stoken_session (SessionContainer, optional): The session token container. Defaults to Depends(verify_session()).
+
+ Raises:
+ HTTPException: If the app variant is not found (status_code=500), or if a ValueError is raised (status_code=400), or if any other exception is raised (status_code=500).
+
+ Returns:
+ AppVariantResponse: The retrieved app variant.
+ """
+ try:
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=app_id,
+ object_type="app",
+ permission=Permission.VIEW_APPLICATION,
+ )
+ logger.debug(
+ f"user has Permission to get variant by environment: {has_permission}"
+ )
+ if not has_permission:
+ error_msg = f"You do not have access to perform this action. Please contact your organization admin."
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ # Fetch the app variant using the provided app_id and environment
+ app_variant_db = await db_manager.get_app_variant_by_app_name_and_environment(
+ app_id=app_id, environment=environment
+ )
+
+ # Check if the fetched app variant is None and raise exception if it is
+ if app_variant_db is None:
+ raise HTTPException(status_code=500, detail="App Variant not found")
+ return await converters.app_variant_db_to_output(app_variant_db)
+ except ValueError as e:
+ # Handle ValueErrors and return 400 status code
+ raise HTTPException(status_code=400, detail=str(e))
+ except HTTPException as e:
+ raise e
+ except Exception as e:
+ # Handle all other exceptions and return 500 status code
logger.exception(f"An error occurred: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@@ -117,27 +202,74 @@ async def create_app(
HTTPException: If there is an error creating the app or the user does not have permission to access the app.
"""
try:
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- if payload.organization_id:
- access = await check_user_org_access(user_org_data, payload.organization_id)
- if not access:
- raise HTTPException(
- status_code=403,
- detail="You do not have permission to access this app",
+ if isCloudEE():
+ api_key_from_headers = request.headers.get("Authorization")
+ if api_key_from_headers is not None:
+ await check_apikey_action_access(
+ api_key_from_headers,
+ request.state.user_id,
+ Permission.CREATE_APPLICATION,
)
- organization_id = payload.organization_id
- else:
- # Retrieve or create user organization
- organization = await get_user_own_org(user_org_data["uid"])
- if organization is None: # TODO: Check whether we need this
- logger.error("Organization for user not found.")
- organization = await db_manager.create_user_organization(
- user_org_data["uid"]
+ try:
+ user_org_workspace_data = await get_user_org_and_workspace_id(
+ request.state.user_id
+ )
+ if user_org_workspace_data is None:
+ raise HTTPException(
+ status_code=400,
+ detail="Failed to get user org and workspace data",
+ )
+
+ if payload.organization_id:
+ organization_id = payload.organization_id
+ organization = await db_manager_ee.get_organization(organization_id)
+ else:
+ organization = await get_user_own_org(
+ user_org_workspace_data["uid"]
+ )
+ organization_id = str(organization.id)
+
+ if not organization:
+ raise HTTPException(
+ status_code=400,
+ detail="User Organization not found",
+ )
+
+ if payload.workspace_id:
+ workspace_id = payload.workspace_id
+ workspace = db_manager_ee.get_workspace(workspace_id)
+ else:
+ workspace = await get_org_default_workspace(organization)
+
+ if not workspace:
+ raise HTTPException(
+ status_code=400,
+ detail="User Organization not found",
+ )
+
+ has_permission = await check_rbac_permission(
+ user_org_workspace_data=user_org_workspace_data,
+ workspace_id=workspace.id,
+ organization=organization,
+ permission=Permission.CREATE_APPLICATION,
)
- organization_id = str(organization.id)
+ logger.debug(
+ f"User has Permission to Create Application: {has_permission}"
+ )
+ if not has_permission:
+ error_msg = f"You do not have access to perform this action. Please contact your organization admin."
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
app_db = await db_manager.create_app_and_envs(
- payload.app_name, organization_id, **user_org_data
+ payload.app_name,
+ request.state.user_id,
+ organization_id if isCloudEE() else None,
+ workspace.id if isCloudEE() else None,
)
return CreateAppOutput(app_id=str(app_db.id), app_name=str(app_db.app_name))
except Exception as e:
@@ -150,6 +282,7 @@ async def list_apps(
request: Request,
app_name: Optional[str] = None,
org_id: Optional[str] = None,
+ workspace_id: Optional[str] = None,
) -> List[App]:
"""
Retrieve a list of apps filtered by app_name and org_id.
@@ -166,8 +299,12 @@ async def list_apps(
HTTPException: If there was an error retrieving the list of apps.
"""
try:
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- apps = await db_manager.list_apps(app_name, org_id, **user_org_data)
+ apps = await db_manager.list_apps(
+ user_uid=request.state.user_id,
+ app_name=app_name,
+ org_id=org_id,
+ workspace_id=workspace_id,
+ )
return apps
except Exception as e:
logger.exception(f"An error occurred: {str(e)}")
@@ -195,7 +332,7 @@ async def add_variant_from_image(
dict: The newly added variant.
"""
- if os.environ["FEATURE_FLAG"] not in ["cloud", "ee"]:
+ if not isCloudEE():
image = Image(
type="image",
docker_id=payload.docker_id,
@@ -210,17 +347,24 @@ async def add_variant_from_image(
raise HTTPException(status_code=404, detail="Image not found")
try:
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- access_app = await check_access_to_app(user_org_data, app_id=app_id)
- if not access_app:
- error_msg = f"You cannot access app: {app_id}"
- logger.error(error_msg)
- return JSONResponse(
- {"detail": error_msg},
- status_code=403,
- )
app = await db_manager.fetch_app_by_id(app_id)
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=app,
+ permission=Permission.CREATE_APPLICATION,
+ )
+ logger.debug(
+ f"User has Permission to create app from image: {has_permission}"
+ )
+ if not has_permission:
+ error_msg = f"You do not have access to perform this action. Please contact your organization admin."
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
variant_db = await app_manager.add_variant_based_on_image(
app=app,
variant_name=payload.variant_name,
@@ -229,7 +373,7 @@ async def add_variant_from_image(
base_name=payload.base_name,
config_name=payload.config_name,
is_template_image=False,
- **user_org_data,
+ user_uid=request.state.user_id,
)
app_variant_db = await db_manager.fetch_app_variant_by_id(str(variant_db.id))
@@ -250,20 +394,23 @@ async def remove_app(app_id: str, request: Request):
app -- App to remove
"""
try:
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- access_app = await check_access_to_app(
- user_org_data, app_id=app_id, check_owner=True
- )
+ app = await db_manager.fetch_app_by_id(app_id)
- if not access_app:
- error_msg = f"You do not have permission to delete app: {app_id}"
- logger.error(error_msg)
- return JSONResponse(
- {"detail": error_msg},
- status_code=400,
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=app,
+ permission=Permission.DELETE_APPLICATION,
)
- else:
- await app_manager.remove_app(app_id=app_id, **user_org_data)
+ logger.debug(f"User has Permission to delete app: {has_permission}")
+ if not has_permission:
+ error_msg = f"You do not have access to perform this action. Please contact your organization admin."
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ await app_manager.remove_app(app)
except DockerException as e:
detail = f"Docker error while trying to remove the app: {str(e)}"
logger.exception(f"Docker error while trying to remove the app: {str(e)}")
@@ -281,7 +428,7 @@ async def remove_app(app_id: str, request: Request):
async def create_app_and_variant_from_template(
payload: CreateAppVariant,
request: Request,
-) -> AppVariantOutput:
+) -> AppVariantResponse:
"""
Create an app and variant from a template.
@@ -293,90 +440,141 @@ async def create_app_and_variant_from_template(
HTTPException: If the user has reached the app limit or if an app with the same name already exists.
Returns:
- AppVariantOutput: The output of the created app variant.
+ AppVariantResponse: The output of the created app variant.
"""
try:
logger.debug("Start: Creating app and variant from template")
- # Get user and org id
- logger.debug("Step 1: Getting user and organization ID")
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
+ if isCloudEE():
+ # Get user and org id
+ logger.debug("Step 1: Getting user and organization ID")
+ user_org_workspace_data: dict = await get_user_org_and_workspace_id(
+ request.state.user_id
+ )
+
+ logger.debug(
+ "Step 2: Checking that workspace ID and organization ID are provided"
+ )
+ if payload.organization_id is None or payload.workspace_id is None:
+ raise Exception(
+ "Organization ID and Workspace ID must be provided to create app from template",
+ )
- logger.debug("Step 2: Setting organization ID")
- if payload.organization_id is None:
- organization = await get_user_own_org(user_org_data["uid"])
- organization_id = organization.id
- else:
- organization_id = payload.organization_id
+ logger.debug("Step 3: Checking user has permission to create app")
+ has_permission = await check_rbac_permission(
+ user_org_workspace_data=user_org_workspace_data,
+ workspace_id=payload.workspace_id,
+ organization_id=payload.organization_id,
+ permission=Permission.CREATE_APPLICATION,
+ )
+ logger.debug(
+ f"User has Permission to create app from template: {has_permission}"
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
- logger.debug(f"Step 3 Checking if app {payload.app_name} already exists")
+ logger.debug(
+ f"Step 4: Checking if app {payload.app_name} already exists"
+ if isCloudEE()
+ else f"Step 1: Checking if app {payload.app_name} already exists"
+ )
app_name = payload.app_name.lower()
- app = await db_manager.fetch_app_by_name_and_organization(
- app_name, organization_id, **user_org_data
+ app = await db_manager.fetch_app_by_name_and_parameters(
+ app_name,
+ request.state.user_id,
+ payload.organization_id if isCloudEE() else None,
+ payload.workspace_id if isCloudEE() else None,
)
if app is not None:
raise Exception(
f"App with name {app_name} already exists",
)
- logger.debug("Step 4: Creating new app and initializing environments")
+ logger.debug(
+ "Step 5: Creating new app and initializing environments"
+ if isCloudEE()
+ else "Step 2: Creating new app and initializing environments"
+ )
if app is None:
app = await db_manager.create_app_and_envs(
- app_name, organization_id, **user_org_data
+ app_name,
+ request.state.user_id,
+ payload.organization_id if isCloudEE() else None,
+ payload.workspace_id if isCloudEE() else None,
)
- logger.debug("Step 5: Retrieve template from db")
+ logger.debug(
+ "Step 6: Retrieve template from db"
+ if isCloudEE()
+ else "Step 3: Retrieve template from db"
+ )
template_db = await db_manager.get_template(payload.template_id)
repo_name = os.environ.get("AGENTA_TEMPLATE_REPO", "agentaai/templates_v2")
image_name = f"{repo_name}:{template_db.name}"
logger.debug(
- "Step 6: Creating image instance and adding variant based on image"
+ "Step 7: Creating image instance and adding variant based on image"
+ if isCloudEE()
+ else "Step 4: Creating image instance and adding variant based on image"
)
app_variant_db = await app_manager.add_variant_based_on_image(
app=app,
variant_name="app.default",
docker_id_or_template_uri=(
- template_db.template_uri
- if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]
- else template_db.digest
- ),
- tags=(
- f"{image_name}"
- if os.environ["FEATURE_FLAG"] not in ["cloud", "ee"]
- else None
+ template_db.template_uri if isCloudEE() else template_db.digest
),
+ tags=f"{image_name}" if not isCloudEE() else None,
base_name="app",
config_name="default",
is_template_image=True,
- **user_org_data,
+ user_uid=request.state.user_id,
)
- logger.debug("Step 7: Creating testset for app variant")
+ logger.debug(
+ "Step 8: Creating testset for app variant"
+ if isCloudEE()
+ else "Step 5: Creating testset for app variant"
+ )
await db_manager.add_testset_to_app_variant(
app_id=str(app.id),
- org_id=organization_id,
+ org_id=payload.organization_id if isCloudEE() else None,
+ workspace_id=payload.workspace_id if isCloudEE() else None,
template_name=template_db.name,
app_name=app.app_name,
- **user_org_data,
+ user_uid=request.state.user_id,
)
- logger.debug("Step 8: We create ready-to use evaluators")
+ logger.debug(
+ "Step 9: We create ready-to use evaluators"
+ if isCloudEE()
+ else "Step 6: We create ready-to use evaluators"
+ )
await evaluator_manager.create_ready_to_use_evaluators(app=app)
- logger.debug("Step 9: Starting variant and injecting environment variables")
-
- envvars = {} if payload.env_vars is None else payload.env_vars
- if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
- if envvars.get("OPENAI_API_KEY", "") == "":
- if not os.environ["OPENAI_API_KEY"]:
- raise HTTPException(
- status_code=400,
- detail="Unable to start app container. Please file an issue by clicking on the button below.",
- )
- envvars["OPENAI_API_KEY"] = os.environ["OPENAI_API_KEY"]
-
- await app_manager.start_variant(app_variant_db, envvars, **user_org_data)
+ logger.debug(
+ "Step 10: Starting variant and injecting environment variables"
+ if isCloudEE()
+ else "Step 7: Starting variant and injecting environment variables"
+ )
+ if isCloudEE():
+ if not os.environ["OPENAI_API_KEY"]:
+ raise Exception(
+ "Unable to start app container. Please file an issue by clicking on the button below.",
+ )
+ envvars = {
+ **(payload.env_vars or {}),
+ "OPENAI_API_KEY": os.environ[
+ "OPENAI_API_KEY"
+ ], # order is important here
+ }
+ else:
+ envvars = {} if payload.env_vars is None else payload.env_vars
+ await app_manager.start_variant(app_variant_db, envvars)
logger.debug("End: Successfully created app and variant")
return await converters.app_variant_db_to_output(app_variant_db)
@@ -407,30 +605,26 @@ async def list_environments(
"""
logger.debug(f"Listing environments for app: {app_id}")
try:
- logger.debug("get user and org data")
- user_and_org_data: dict = await get_user_and_org_id(request.state.user_id)
-
- # Check if has app access
- logger.debug("check_access_to_app")
- access_app = await check_access_to_app(
- user_org_data=user_and_org_data, app_id=app_id
- )
- logger.debug(f"access_app: {access_app}")
- if not access_app:
- error_msg = f"You do not have access to this app: {app_id}"
- return JSONResponse(
- {"detail": error_msg},
- status_code=400,
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=app_id,
+ object_type="app",
+ permission=Permission.VIEW_APPLICATION,
)
- else:
- environments_db = await db_manager.list_environments(
- app_id=app_id, **user_and_org_data
- )
- logger.debug(f"environments_db: {environments_db}")
- return [
- await converters.environment_db_to_output(env)
- for env in environments_db
- ]
+ logger.debug(f"User has Permission to list environments: {has_permission}")
+ if not has_permission:
+ error_msg = f"You do not have access to perform this action. Please contact your organization admin."
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ environments_db = await db_manager.list_environments(app_id=app_id)
+ logger.debug(f"environments_db: {environments_db}")
+ return [
+ await converters.environment_db_to_output(env) for env in environments_db
+ ]
except Exception as e:
logger.exception(f"An error occurred: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@@ -445,23 +639,27 @@ async def list_app_environment_revisions(
request: Request, app_id: str, environment_name
):
logger.debug("getting environment " + environment_name)
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
+ user_org_workspace_data: dict = await get_user_org_and_workspace_id(
+ request.state.user_id
+ )
try:
- logger.debug("check_access_to_app")
- access_app = await check_access_to_app(
- user_org_data=user_org_data, app_id=app_id
- )
- logger.debug(f"access_app: {access_app}")
-
- if not access_app:
- error_msg = f"You do not have access to this app: {app_id}"
- return JSONResponse(
- {"detail": error_msg},
- status_code=400,
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=app_id,
+ object_type="app",
+ permission=Permission.VIEW_APPLICATION,
)
+ logger.debug(f"User has Permission to list environments: {has_permission}")
+ if not has_permission:
+ error_msg = f"You do not have access to perform this action. Please contact your organization admin."
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
app_environment = await db_manager.fetch_app_environment_by_name_and_appid(
- app_id, environment_name, **user_org_data
+ app_id, environment_name, **user_org_workspace_data
)
if app_environment is None:
return JSONResponse(
@@ -470,7 +668,7 @@ async def list_app_environment_revisions(
app_environment_revisions = (
await db_manager.fetch_environment_revisions_for_environment(
- app_environment, **user_org_data
+ app_environment, **user_org_workspace_data
)
)
if app_environment_revisions is None:
diff --git a/agenta-backend/agenta_backend/routers/bases_router.py b/agenta-backend/agenta_backend/routers/bases_router.py
index ae58937800..cf1eaf7e98 100644
--- a/agenta-backend/agenta_backend/routers/bases_router.py
+++ b/agenta-backend/agenta_backend/routers/bases_router.py
@@ -1,27 +1,23 @@
-import os
+import logging
+
from typing import List, Optional
-from fastapi import Request, HTTPException
-from agenta_backend.utils.common import APIRouter
-from agenta_backend.models.api.api_models import BaseOutput
from fastapi.responses import JSONResponse
-from agenta_backend.services import db_manager
+from fastapi import Request, HTTPException
+
from agenta_backend.models import converters
+from agenta_backend.services import db_manager
+from agenta_backend.utils.common import APIRouter, isCloudEE
+from agenta_backend.models.api.api_models import BaseOutput
-if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
- from agenta_backend.commons.services.selectors import (
- get_user_and_org_id,
- ) # noqa pylint: disable-all
-else:
- from agenta_backend.services.selectors import get_user_and_org_id
-from agenta_backend.utils.common import check_access_to_app
+if isCloudEE():
+ from agenta_backend.commons.models.db_models import Permission
+ from agenta_backend.commons.utils.permissions import check_action_access
-import logging
+router = APIRouter()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
-router = APIRouter()
-
@router.get("/", response_model=List[BaseOutput], operation_id="list_bases")
async def list_bases(
@@ -44,20 +40,22 @@ async def list_bases(
HTTPException: If there was an error retrieving the bases.
"""
try:
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- access_app = await check_access_to_app(
- user_org_data=user_org_data, app_id=app_id
- )
- if not access_app:
- error_msg = f"You cannot access app: {app_id}"
- logger.error(error_msg)
- return JSONResponse(
- {"detail": error_msg},
- status_code=403,
+ if isCloudEE() and app_id is not None:
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=app_id,
+ object_type="app",
+ permission=Permission.VIEW_APPLICATION,
)
- bases = await db_manager.list_bases_for_app_id(
- app_id, base_name, **user_org_data
- )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ bases = await db_manager.list_bases_for_app_id(app_id, base_name)
return [converters.base_db_to_pydantic(base) for base in bases]
except Exception as e:
logger.error(f"list_bases exception ===> {e}")
diff --git a/agenta-backend/agenta_backend/routers/configs_router.py b/agenta-backend/agenta_backend/routers/configs_router.py
index ce94761157..acf86f119f 100644
--- a/agenta-backend/agenta_backend/routers/configs_router.py
+++ b/agenta-backend/agenta_backend/routers/configs_router.py
@@ -1,8 +1,9 @@
-import os
+import logging
+
from typing import Optional
+from fastapi.responses import JSONResponse
from fastapi import Request, HTTPException
-from agenta_backend.utils.common import APIRouter
-import logging
+from agenta_backend.utils.common import APIRouter, isCloudEE
from agenta_backend.models.api.api_models import (
SaveConfigPayload,
@@ -13,19 +14,15 @@
app_manager,
)
-if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
- from agenta_backend.commons.services.selectors import (
- get_user_and_org_id,
- ) # noqa pylint: disable-all
-else:
- from agenta_backend.services.selectors import get_user_and_org_id
+if isCloudEE():
+ from agenta_backend.commons.models.db_models import Permission
+ from agenta_backend.commons.utils.permissions import check_action_access
+router = APIRouter()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
-router = APIRouter()
-
@router.post("/", operation_id="save_config")
async def save_config(
@@ -33,11 +30,23 @@ async def save_config(
request: Request,
):
try:
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- base_db = await db_manager.fetch_base_and_check_access(
- payload.base_id, user_org_data
- )
- variants_db = await db_manager.list_variants_for_base(base_db, **user_org_data)
+ base_db = await db_manager.fetch_base_by_id(payload.base_id)
+
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=base_db,
+ permission=Permission.MODIFY_VARIANT_CONFIGURATIONS,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ variants_db = await db_manager.list_variants_for_base(base_db)
variant_to_overwrite = None
for variant_db in variants_db:
if variant_db.config_name == payload.config_name:
@@ -49,7 +58,7 @@ async def save_config(
await app_manager.update_variant_parameters(
app_variant_id=str(variant_to_overwrite.id),
parameters=payload.parameters,
- **user_org_data,
+ user_uid=request.state.user_id,
)
else:
raise HTTPException(
@@ -64,7 +73,7 @@ async def save_config(
base_db=base_db,
new_config_name=payload.config_name,
parameters=payload.parameters,
- **user_org_data,
+ user_uid=request.state.user_id,
)
except HTTPException as e:
logger.error(f"save_config http exception ===> {e.detail}")
@@ -82,9 +91,23 @@ async def get_config(
environment_name: Optional[str] = None,
):
try:
+ base_db = await db_manager.fetch_base_by_id(base_id)
+
# detemine whether the user has access to the base
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- base_db = await db_manager.fetch_base_and_check_access(base_id, user_org_data)
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=base_db,
+ permission=Permission.MODIFY_VARIANT_CONFIGURATIONS,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
# in case environment_name is provided, find the variant deployed
if environment_name:
app_environments = await db_manager.list_environments(
@@ -109,9 +132,7 @@ async def get_config(
)
config = found_variant_revision.config
elif config_name:
- variants_db = await db_manager.list_variants_for_base(
- base_db, **user_org_data
- )
+ variants_db = await db_manager.list_variants_for_base(base_db)
found_variant = None
for variant_db in variants_db:
if variant_db.config_name == config_name:
@@ -183,6 +204,20 @@ async def revert_deployment_revision(request: Request, deployment_revision_id: s
f"No environment revision found for deployment revision {deployment_revision_id}",
)
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=environment_revision,
+ permission=Permission.EDIT_APP_ENVIRONMENT_DEPLOYMENT,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
if environment_revision.deployed_app_variant_revision is None:
raise HTTPException(
404,
diff --git a/agenta-backend/agenta_backend/routers/container_router.py b/agenta-backend/agenta_backend/routers/container_router.py
index 0f71250fa0..1e0b5bc3d1 100644
--- a/agenta-backend/agenta_backend/routers/container_router.py
+++ b/agenta-backend/agenta_backend/routers/container_router.py
@@ -1,40 +1,38 @@
-import os
import logging
-from typing import List, Optional, Union
+from typing import List, Optional, Union
from fastapi.responses import JSONResponse
from fastapi import Request, UploadFile, HTTPException
-if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
- from agenta_backend.commons.services.selectors import (
- get_user_and_org_id,
- ) # noqa pylint: disable-all
+from agenta_backend.services import db_manager
+from agenta_backend.utils.common import APIRouter, isCloudEE, isCloud, isEE
+
+if isCloudEE():
+ from agenta_backend.commons.models.db_models import Permission
+ from agenta_backend.commons.utils.permissions import check_action_access
+ from agenta_backend.commons.models.api.api_models import Image_ as Image
else:
- from agenta_backend.services.selectors import get_user_and_org_id
+ from agenta_backend.models.api.api_models import Image
-if os.environ["FEATURE_FLAG"] in ["cloud"]:
+if isCloud():
from agenta_backend.cloud.services import container_manager
-elif os.environ["FEATURE_FLAG"] in ["ee"]:
+elif isEE():
from agenta_backend.ee.services import container_manager
else:
from agenta_backend.services import container_manager
from agenta_backend.models.api.api_models import (
URI,
- Image,
RestartAppContainer,
Template,
)
-from agenta_backend.services import db_manager
-from agenta_backend.utils.common import APIRouter
+router = APIRouter()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
-router = APIRouter()
-
# TODO: We need to improve this to use the introduced abstraction to also use start and stop service
# * Edit: someone remind me (abram) to work on this.
@@ -57,21 +55,33 @@ async def build_image(
Returns:
Image: The Docker image that was built.
"""
- # Get user and org id
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
-
- # Check app access
- app_db = await db_manager.fetch_app_and_check_access(
- app_id=app_id, user_org_data=user_org_data
- )
-
- image_result = await container_manager.build_image(
- app_db=app_db,
- base_name=base_name,
- tar_file=tar_file,
- )
+ try:
+ app_db = await db_manager.fetch_app_by_id(app_id)
+
+ # Check app access
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=app_db,
+ permission=Permission.CREATE_APPLICATION,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ image_result = await container_manager.build_image(
+ app_db=app_db,
+ base_name=base_name,
+ tar_file=tar_file,
+ )
- return image_result
+ return image_result
+ except Exception as ex:
+ return JSONResponse({"message": str(ex)}, status_code=500)
@router.post("/restart_container/", operation_id="restart_container")
@@ -85,11 +95,7 @@ async def restart_docker_container(
payload (RestartAppContainer) -- the required data (app_name and variant_name)
"""
logger.debug(f"Restarting container for variant {payload.variant_id}")
- # Get user and org id
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- app_variant_db = await db_manager.fetch_app_variant_and_check_access(
- app_variant_id=payload.variant_id, user_org_data=user_org_data
- )
+ app_variant_db = await db_manager.fetch_app_variant_by_id(payload.variant_id)
try:
deployment = await db_manager.get_deployment_by_objectid(
app_variant_db.base.deployment
@@ -136,7 +142,7 @@ async def construct_app_container_url(
Args:
base_id (Optional[str]): The ID of the base to use for the app container.
variant_id (Optional[str]): The ID of the variant to use for the app container.
- stoken_session (SessionContainer): The session container for the user.
+ request (Request): The request object.
Returns:
URI: The URI for the app container.
@@ -144,33 +150,40 @@ async def construct_app_container_url(
Raises:
HTTPException: If the base or variant cannot be found or the user does not have access.
"""
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
+ # assert that one of base_id or variant_id is provided
+ assert base_id or variant_id, "Please provide either base_id or variant_id"
+
if base_id:
- base_db = await db_manager.fetch_base_and_check_access(
- base_id=base_id, user_org_data=user_org_data
+ object_db = await db_manager.fetch_base_by_id(base_id)
+ elif variant_id:
+ object_db = await db_manager.fetch_app_variant_by_id(variant_id)
+
+ # Check app access
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=object_db,
+ permission=Permission.VIEW_APPLICATION,
)
- # TODO: Add status check if base_db.status == "running"
- if base_db.deployment:
- deployment = await db_manager.get_deployment_by_objectid(base_db.deployment)
- uri = deployment.uri
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ raise HTTPException(status_code=403, detail=error_msg)
+
+ try:
+ if getattr(object_db, "deployment", None): # this is a base
+ deployment = await db_manager.get_deployment_by_objectid(
+ object_db.deployment
+ )
+ elif getattr(object_db.base, "deployment", None): # this is a variant
+ deployment = await db_manager.get_deployment_by_objectid(
+ object_db.base.deployment
+ )
else:
raise HTTPException(
status_code=400,
- detail=f"Base {base_id} does not have a deployment",
+ detail="Deployment not found",
)
-
- return URI(uri=uri)
- elif variant_id:
- variant_db = await db_manager.fetch_app_variant_and_check_access(
- app_variant_id=variant_id, user_org_data=user_org_data
- )
- deployment = await db_manager.get_deployment_by_objectid(
- variant_db.base.deployment
- )
- assert deployment and deployment.uri, "Deployment not found"
return URI(uri=deployment.uri)
- else:
- return JSONResponse(
- {"detail": "Please provide either base_id or variant_id"},
- status_code=400,
- )
+ except Exception as e:
+ return JSONResponse({"message": str(e)}, status_code=500)
diff --git a/agenta-backend/agenta_backend/routers/environment_router.py b/agenta-backend/agenta_backend/routers/environment_router.py
index b08c352d7e..a21efaa581 100644
--- a/agenta-backend/agenta_backend/routers/environment_router.py
+++ b/agenta-backend/agenta_backend/routers/environment_router.py
@@ -1,30 +1,20 @@
-import os
import logging
-from typing import List
-from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
+from fastapi import Request, HTTPException
-from agenta_backend.models import converters
from agenta_backend.services import db_manager
-from agenta_backend.utils.common import APIRouter
-from agenta_backend.utils.common import check_access_to_app, check_access_to_variant
-from agenta_backend.models.api.api_models import (
- DeployToEnvironmentPayload,
-)
+from agenta_backend.utils.common import APIRouter, isCloudEE
+from agenta_backend.models.api.api_models import DeployToEnvironmentPayload
-if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
- from agenta_backend.commons.services.selectors import (
- get_user_and_org_id,
- ) # noqa pylint: disable-all
-else:
- from agenta_backend.services.selectors import get_user_and_org_id
+if isCloudEE():
+ from agenta_backend.commons.models.db_models import Permission
+ from agenta_backend.commons.utils.permissions import check_action_access
+router = APIRouter()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
-router = APIRouter()
-
@router.post("/deploy/", operation_id="deploy_to_environment")
async def deploy_to_environment(
@@ -42,25 +32,27 @@ async def deploy_to_environment(
HTTPException: If the deployment fails.
"""
try:
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
-
- # Check if has app access
- access_app = await check_access_to_variant(
- user_org_data, variant_id=payload.variant_id
- )
-
- if not access_app:
- error_msg = f"You do not have access to this variant: {payload.variant_id}"
- return JSONResponse(
- {"detail": error_msg},
- status_code=400,
- )
- else:
- await db_manager.deploy_to_environment(
- environment_name=payload.environment_name,
- variant_id=payload.variant_id,
- **user_org_data,
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=payload.variant_id,
+ object_type="app_variant",
+ permission=Permission.DEPLOY_APPLICATION,
)
+ logger.debug(f"User has permission deploy to environment: {has_permission}")
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ await db_manager.deploy_to_environment(
+ environment_name=payload.environment_name,
+ variant_id=payload.variant_id,
+ user_uid=request.state.user_id,
+ )
except Exception as e:
logger.exception(f"An error occurred: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
diff --git a/agenta-backend/agenta_backend/routers/evaluation_router.py b/agenta-backend/agenta_backend/routers/evaluation_router.py
index 5974bdb696..a8552b73db 100644
--- a/agenta-backend/agenta_backend/routers/evaluation_router.py
+++ b/agenta-backend/agenta_backend/routers/evaluation_router.py
@@ -1,39 +1,35 @@
-import os
import secrets
+import logging
from typing import Any, List
from fastapi.responses import JSONResponse
from fastapi import HTTPException, Request, status, Response, Query
-from beanie import PydanticObjectId as ObjectId
-from agenta_backend.utils.common import APIRouter
+from agenta_backend.models import converters
+from agenta_backend.tasks.evaluations import evaluate
+from agenta_backend.utils.common import APIRouter, isCloudEE
+from agenta_backend.services import evaluation_service, db_manager
from agenta_backend.models.api.evaluation_model import (
Evaluation,
EvaluationScenario,
- LMProvidersEnum,
NewEvaluation,
DeleteEvaluation,
EvaluationWebhook,
)
-from agenta_backend.services import db_manager
-from agenta_backend.tasks.evaluations import evaluate
-from agenta_backend.services import evaluation_service
-from agenta_backend.utils.common import check_access_to_app
-
from agenta_backend.services.evaluator_manager import (
check_ai_critique_inputs,
)
-if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
- from agenta_backend.commons.services.selectors import ( # noqa pylint: disable-all
- get_user_and_org_id,
- )
-else:
- from agenta_backend.services.selectors import get_user_and_org_id
+if isCloudEE():
+ from agenta_backend.commons.models.db_models import Permission
+ from agenta_backend.commons.utils.permissions import check_action_access
+
+from beanie import PydanticObjectId as ObjectId
-# Initialize api router
router = APIRouter()
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
@router.get(
@@ -59,21 +55,30 @@ async def fetch_evaluation_ids(
Returns:
List[str]: A list of evaluation ids.
"""
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- access_app = await check_access_to_app(
- user_org_data=user_org_data,
- app_id=app_id,
- )
- if not access_app:
- raise HTTPException(
- status_code=403,
- detail=f"You do not have access to this app: {str(app_id)}",
+ try:
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=app_id,
+ object_type="app",
+ permission=Permission.VIEW_EVALUATION,
+ )
+ logger.debug(
+ f"User has permission to get single evaluation: {has_permission}"
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+ evaluations = await evaluation_service.fetch_evaluations_by_resource(
+ resource_type, resource_ids
)
-
- evaluations = await evaluation_service.fetch_evaluations_by_resource(
- resource_type, resource_ids
- )
- return list(map(lambda x: x.id, evaluations))
+ return list(map(lambda x: x.id, evaluations))
+ except Exception as exc:
+ raise HTTPException(status_code=500, detail=str(exc))
@router.post("/", response_model=List[Evaluation], operation_id="create_evaluation")
@@ -88,22 +93,25 @@ async def create_evaluation(
_description_
"""
try:
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- access_app = await check_access_to_app(
- user_org_data=user_org_data,
- app_id=payload.app_id,
- check_owner=False,
- )
- if not access_app:
- error_msg = f"You do not have access to this app: {payload.app_id}"
- return JSONResponse(
- {"detail": error_msg},
- status_code=400,
- )
app = await db_manager.fetch_app_by_id(app_id=payload.app_id)
if app is None:
raise HTTPException(status_code=404, detail="App not found")
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=app,
+ permission=Permission.CREATE_EVALUATION,
+ )
+ logger.debug(f"User has permission to create evaluation: {has_permission}")
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
success, response = await check_ai_critique_inputs(
payload.evaluators_configs, payload.lm_providers_keys
)
@@ -158,11 +166,24 @@ async def fetch_evaluation_status(evaluation_id: str, request: Request):
"""
try:
- # Get user and organization id
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- evaluation = await evaluation_service.fetch_evaluation(
- evaluation_id, **user_org_data
- )
+ evaluation = await db_manager.fetch_evaluation_by_id(evaluation_id)
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=evaluation,
+ permission=Permission.VIEW_EVALUATION,
+ )
+ logger.debug(
+ f"User has permission to fetch evaluation status: {has_permission}"
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
return {"status": evaluation.status}
except Exception as exc:
raise HTTPException(status_code=500, detail=str(exc))
@@ -181,10 +202,26 @@ async def fetch_evaluation_results(evaluation_id: str, request: Request):
"""
try:
- # Get user and organization id
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- results = await evaluation_service.retrieve_evaluation_results(
- evaluation_id, **user_org_data
+ evaluation = await db_manager.fetch_evaluation_by_id(evaluation_id)
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=evaluation,
+ permission=Permission.VIEW_EVALUATION,
+ )
+ logger.debug(
+ f"User has permission to get evaluation results: {has_permission}"
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ results = await converters.aggregated_result_to_pydantic(
+ evaluation.aggregated_results
)
return {"results": results, "evaluation_id": evaluation_id}
except Exception as exc:
@@ -212,12 +249,34 @@ async def fetch_evaluation_scenarios(
List[EvaluationScenario]: A list of evaluation scenarios.
"""
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- eval_scenarios = await evaluation_service.fetch_evaluation_scenarios_for_evaluation(
- evaluation_id, **user_org_data
- )
+ try:
+ evaluation = await db_manager.fetch_evaluation_by_id(evaluation_id)
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=evaluation,
+ permission=Permission.VIEW_EVALUATION,
+ )
+ logger.debug(
+ f"User has permission to get evaluation scenarios: {has_permission}"
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ eval_scenarios = (
+ await evaluation_service.fetch_evaluation_scenarios_for_evaluation(
+ evaluation=evaluation
+ )
+ )
+ return eval_scenarios
- return eval_scenarios
+ except Exception as exc:
+ raise HTTPException(status_code=500, detail=str(exc))
@router.get("/", response_model=List[Evaluation])
@@ -233,10 +292,28 @@ async def fetch_list_evaluations(
Returns:
List[Evaluation]: A list of evaluations.
"""
- user_org_data = await get_user_and_org_id(request.state.user_id)
- return await evaluation_service.fetch_list_evaluations(
- app_id=app_id, **user_org_data
- )
+ try:
+ app = await db_manager.fetch_app_by_id(app_id)
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=app,
+ permission=Permission.VIEW_EVALUATION,
+ )
+ logger.debug(
+ f"User has permission to get list of evaluations: {has_permission}"
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ return await evaluation_service.fetch_list_evaluations(app)
+ except Exception as exc:
+ raise HTTPException(status_code=500, detail=str(exc))
@router.get(
@@ -254,8 +331,29 @@ async def fetch_evaluation(
Returns:
Evaluation: The fetched evaluation.
"""
- user_org_data = await get_user_and_org_id(request.state.user_id)
- return await evaluation_service.fetch_evaluation(evaluation_id, **user_org_data)
+ try:
+ evaluation = await db_manager.fetch_evaluation_by_id(evaluation_id)
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=evaluation_id,
+ object_type="evaluation",
+ permission=Permission.VIEW_EVALUATION,
+ )
+ logger.debug(
+ f"User has permission to get single evaluation: {has_permission}"
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ return await converters.evaluation_db_to_pydantic(evaluation)
+ except Exception as exc:
+ raise HTTPException(status_code=500, detail=str(exc))
@router.delete("/", response_model=List[str], operation_id="delete_evaluations")
@@ -273,12 +371,30 @@ async def delete_evaluations(
A list of the deleted comparison tables' IDs.
"""
- # Get user and organization id
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- await evaluation_service.delete_evaluations(
- delete_evaluations.evaluations_ids, **user_org_data
- )
- return Response(status_code=status.HTTP_204_NO_CONTENT)
+ try:
+ if isCloudEE():
+ for evaluation_id in delete_evaluations.evaluations_ids:
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=evaluation_id,
+ object_type="evaluation",
+ permission=Permission.DELETE_EVALUATION,
+ )
+ logger.debug(
+ f"User has permission to delete evaluation: {has_permission}"
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ await evaluation_service.delete_evaluations(delete_evaluations.evaluations_ids)
+ return Response(status_code=status.HTTP_204_NO_CONTENT)
+ except Exception as exc:
+ raise HTTPException(status_code=500, detail=str(exc))
@router.post(
@@ -318,10 +434,32 @@ async def fetch_evaluation_scenarios(
Returns:
List[EvaluationScenario]: A list of evaluation scenarios.
"""
- evaluations_ids_list = evaluations_ids.split(",")
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- eval_scenarios = await evaluation_service.compare_evaluations_scenarios(
- evaluations_ids_list, **user_org_data
- )
+ try:
+ evaluations_ids_list = evaluations_ids.split(",")
+
+ if isCloudEE():
+ for evaluation_id in evaluations_ids_list:
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=evaluation_id,
+ object_type="evaluation",
+ permission=Permission.VIEW_EVALUATION,
+ )
+ logger.debug(
+ f"User has permission to get evaluation scenarios: {has_permission}"
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ eval_scenarios = await evaluation_service.compare_evaluations_scenarios(
+ evaluations_ids_list
+ )
- return eval_scenarios
+ return eval_scenarios
+ except Exception as exc:
+ raise HTTPException(status_code=500, detail=str(exc))
diff --git a/agenta-backend/agenta_backend/routers/evaluators_router.py b/agenta-backend/agenta_backend/routers/evaluators_router.py
index ef6d25670c..3c1ddfc870 100644
--- a/agenta-backend/agenta_backend/routers/evaluators_router.py
+++ b/agenta-backend/agenta_backend/routers/evaluators_router.py
@@ -1,11 +1,12 @@
-import os
-import json
-from typing import List
import logging
-from fastapi import HTTPException, APIRouter, Query
+from typing import List
+from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
+from agenta_backend.utils.common import APIRouter, isCloudEE
+from agenta_backend.services import evaluator_manager, db_manager
+
from agenta_backend.models.api.evaluation_model import (
Evaluator,
EvaluatorConfig,
@@ -13,20 +14,9 @@
UpdateEvaluatorConfig,
)
-from agenta_backend.services import (
- db_manager,
-)
-
-from agenta_backend.services import evaluator_manager
-
-from agenta_backend.utils.common import check_access_to_app
-
-if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
- from agenta_backend.commons.services.selectors import ( # noqa pylint: disable-all
- get_user_and_org_id,
- )
-else:
- from agenta_backend.services.selectors import get_user_and_org_id
+if isCloudEE():
+ from agenta_backend.commons.models.db_models import Permission
+ from agenta_backend.commons.utils.permissions import check_action_access
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -41,19 +31,24 @@ async def get_evaluators_endpoint():
List[Evaluator]: A list of evaluator objects.
"""
- evaluators = evaluator_manager.get_evaluators()
+ try:
+ evaluators = evaluator_manager.get_evaluators()
- if evaluators is None:
- raise HTTPException(status_code=500, detail="Error processing evaluators file")
+ if evaluators is None:
+ raise HTTPException(
+ status_code=500, detail="Error processing evaluators file"
+ )
- if not evaluators:
- raise HTTPException(status_code=404, detail="No evaluators found")
+ if not evaluators:
+ raise HTTPException(status_code=404, detail="No evaluators found")
- return evaluators
+ return evaluators
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
@router.get("/configs/", response_model=List[EvaluatorConfig])
-async def get_evaluator_configs(app_id: str):
+async def get_evaluator_configs(app_id: str, request: Request):
"""Endpoint to fetch evaluator configurations for a specific app.
Args:
@@ -63,28 +58,68 @@ async def get_evaluator_configs(app_id: str):
List[EvaluatorConfigDB]: A list of evaluator configuration objects.
"""
- evaluators_configs = await evaluator_manager.get_evaluators_configs(app_id)
- return evaluators_configs
+ try:
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=app_id,
+ object_type="app",
+ permission=Permission.VIEW_EVALUATION,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ evaluators_configs = await evaluator_manager.get_evaluators_configs(app_id)
+ return evaluators_configs
+ except Exception as e:
+ raise HTTPException(
+ status_code=500, detail=f"Error fetching evaluator configurations: {str(e)}"
+ )
@router.get("/configs/{evaluator_config_id}/", response_model=EvaluatorConfig)
-async def get_evaluator_config(evaluator_config_id: str):
+async def get_evaluator_config(evaluator_config_id: str, request: Request):
"""Endpoint to fetch evaluator configurations for a specific app.
Returns:
List[EvaluatorConfigDB]: A list of evaluator configuration objects.
"""
- evaluators_configs = await evaluator_manager.get_evaluator_config(
- evaluator_config_id
- )
- return evaluators_configs
+ try:
+ evaluator_config_db = await db_manager.fetch_evaluator_config(
+ evaluator_config_id
+ )
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=evaluator_config_db.app,
+ permission=Permission.VIEW_EVALUATION,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ evaluators_configs = await evaluator_manager.get_evaluator_config(
+ evaluator_config_db
+ )
+ return evaluators_configs
+ except Exception as e:
+ raise HTTPException(
+ status_code=500, detail=f"Error fetching evaluator configuration: {str(e)}"
+ )
@router.post("/configs/", response_model=EvaluatorConfig)
-async def create_new_evaluator_config(
- payload: NewEvaluatorConfig,
-):
+async def create_new_evaluator_config(payload: NewEvaluatorConfig, request: Request):
"""Endpoint to fetch evaluator configurations for a specific app.
Args:
@@ -93,19 +128,38 @@ async def create_new_evaluator_config(
Returns:
EvaluatorConfigDB: Evaluator configuration api model.
"""
-
- evaluator_config = await evaluator_manager.create_evaluator_config(
- app_id=payload.app_id,
- name=payload.name,
- evaluator_key=payload.evaluator_key,
- settings_values=payload.settings_values,
- )
- return evaluator_config
+ try:
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=payload.app_id,
+ object_type="app",
+ permission=Permission.CREATE_EVALUATION,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ evaluator_config = await evaluator_manager.create_evaluator_config(
+ app_id=payload.app_id,
+ name=payload.name,
+ evaluator_key=payload.evaluator_key,
+ settings_values=payload.settings_values,
+ )
+ return evaluator_config
+ except Exception as e:
+ raise HTTPException(
+ status_code=500, detail=f"Error creating evaluator configuration: {str(e)}"
+ )
@router.put("/configs/{evaluator_config_id}/", response_model=EvaluatorConfig)
async def update_evaluator_config(
- evaluator_config_id: str, payload: UpdateEvaluatorConfig
+ evaluator_config_id: str, payload: UpdateEvaluatorConfig, request: Request
):
"""Endpoint to update evaluator configurations for a specific app.
@@ -113,14 +167,34 @@ async def update_evaluator_config(
List[EvaluatorConfigDB]: A list of evaluator configuration objects.
"""
- evaluators_configs = await evaluator_manager.update_evaluator_config(
- evaluator_config_id=evaluator_config_id, updates=payload
- )
- return evaluators_configs
+ try:
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=evaluator_config_id,
+ object_type="evaluator_config",
+ permission=Permission.EDIT_EVALUATION,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ evaluators_configs = await evaluator_manager.update_evaluator_config(
+ evaluator_config_id=evaluator_config_id, updates=payload
+ )
+ return evaluators_configs
+ except Exception as e:
+ raise HTTPException(
+ status_code=500, detail=f"Error updating evaluator configuration: {str(e)}"
+ )
@router.delete("/configs/{evaluator_config_id}/", response_model=bool)
-async def delete_evaluator_config(evaluator_config_id: str):
+async def delete_evaluator_config(evaluator_config_id: str, request: Request):
"""Endpoint to delete a specific evaluator configuration.
Args:
@@ -130,6 +204,21 @@ async def delete_evaluator_config(evaluator_config_id: str):
bool: True if deletion was successful, False otherwise.
"""
try:
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=evaluator_config_id,
+ object_type="evaluator_config",
+ permission=Permission.DELETE_EVALUATION,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
success = await evaluator_manager.delete_evaluator_config(evaluator_config_id)
return success
except Exception as e:
diff --git a/agenta-backend/agenta_backend/routers/human_evaluation_router.py b/agenta-backend/agenta_backend/routers/human_evaluation_router.py
index 9dbdb2240d..6cfadb4dc2 100644
--- a/agenta-backend/agenta_backend/routers/human_evaluation_router.py
+++ b/agenta-backend/agenta_backend/routers/human_evaluation_router.py
@@ -1,10 +1,12 @@
-import os
-import secrets
from typing import List, Dict
-
from fastapi.responses import JSONResponse
-from fastapi.encoders import jsonable_encoder
-from fastapi import HTTPException, APIRouter, Body, Request, status, Response
+from agenta_backend.utils.common import APIRouter, isCloudEE
+from fastapi import HTTPException, Body, Request, status, Response
+
+from agenta_backend.models import converters
+from agenta_backend.services import db_manager
+from agenta_backend.services import results_service
+from agenta_backend.services import evaluation_service
from agenta_backend.models.api.evaluation_model import (
DeleteEvaluation,
@@ -18,27 +20,19 @@
SimpleEvaluationOutput,
)
-from agenta_backend.services import evaluation_service
-from agenta_backend.utils.common import check_access_to_app
-from agenta_backend.services import db_manager
-from agenta_backend.models import converters
-from agenta_backend.services import results_service
-
from agenta_backend.services.evaluation_service import (
UpdateEvaluationScenarioError,
- get_evaluation_scenario_score_service,
- update_evaluation_scenario_score_service,
update_human_evaluation_scenario,
update_human_evaluation_service,
)
-
-if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
- from agenta_backend.commons.services.selectors import ( # noqa pylint: disable-all
- get_user_and_org_id,
- )
-else:
- from agenta_backend.services.selectors import get_user_and_org_id
+if isCloudEE():
+ from agenta_backend.commons.models.db_models import (
+ Permission,
+ ) # noqa pylint: disable-all
+ from agenta_backend.commons.utils.permissions import (
+ check_action_access,
+ ) # noqa pylint: disable-all
router = APIRouter()
@@ -57,25 +51,26 @@ async def create_evaluation(
_description_
"""
try:
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- access_app = await check_access_to_app(
- user_org_data=user_org_data,
- app_id=payload.app_id,
- check_owner=False,
- )
- if not access_app:
- error_msg = f"You do not have access to this app: {payload.app_id}"
- return JSONResponse(
- {"detail": error_msg},
- status_code=400,
- )
app = await db_manager.fetch_app_by_id(app_id=payload.app_id)
-
if app is None:
raise HTTPException(status_code=404, detail="App not found")
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=payload.app_id,
+ object_type="app",
+ permission=Permission.CREATE_EVALUATION,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin."
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
new_human_evaluation_db = await evaluation_service.create_new_human_evaluation(
- payload, **user_org_data
+ payload, request.state.user_id
)
return converters.human_evaluation_db_to_simple_evaluation_output(
new_human_evaluation_db
@@ -100,10 +95,24 @@ async def fetch_list_human_evaluations(
Returns:
List[HumanEvaluation]: A list of evaluations.
"""
- user_org_data = await get_user_and_org_id(request.state.user_id)
- return await evaluation_service.fetch_list_human_evaluations(
- app_id=app_id, **user_org_data
- )
+ try:
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=app_id,
+ object_type="app",
+ permission=Permission.VIEW_EVALUATION,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin."
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ return await evaluation_service.fetch_list_human_evaluations(app_id)
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/{evaluation_id}/", response_model=HumanEvaluation)
@@ -119,10 +128,28 @@ async def fetch_human_evaluation(
Returns:
HumanEvaluation: The fetched evaluation.
"""
- user_org_data = await get_user_and_org_id(request.state.user_id)
- return await evaluation_service.fetch_human_evaluation(
- evaluation_id, **user_org_data
- )
+ try:
+ human_evaluation = await db_manager.fetch_human_evaluation_by_id(evaluation_id)
+ if not human_evaluation:
+ raise HTTPException(status_code=404, detail="Evaluation not found")
+
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=evaluation_id,
+ object_type="human_evaluation",
+ permission=Permission.VIEW_EVALUATION,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin."
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ return await evaluation_service.fetch_human_evaluation(human_evaluation)
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e)) from e
@router.get(
@@ -146,14 +173,36 @@ async def fetch_evaluation_scenarios(
List[EvaluationScenario]: A list of evaluation scenarios.
"""
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- eval_scenarios = (
- await evaluation_service.fetch_human_evaluation_scenarios_for_evaluation(
- evaluation_id, **user_org_data
+ try:
+ human_evaluation = await db_manager.fetch_human_evaluation_by_id(evaluation_id)
+ if human_evaluation is None:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Evaluation with id {evaluation_id} not found",
+ )
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=evaluation_id,
+ object_type="human_evaluation",
+ permission=Permission.VIEW_EVALUATION,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin."
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ eval_scenarios = (
+ await evaluation_service.fetch_human_evaluation_scenarios_for_evaluation(
+ human_evaluation
+ )
)
- )
- return eval_scenarios
+ return eval_scenarios
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e)) from e
@router.put("/{evaluation_id}/", operation_id="update_human_evaluation")
@@ -171,11 +220,24 @@ async def update_human_evaluation(
None: A 204 No Content status code, indicating that the update was successful.
"""
try:
- # Get user and organization id
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- await update_human_evaluation_service(
- evaluation_id, update_data, **user_org_data
- )
+ human_evaluation = await db_manager.fetch_human_evaluation_by_id(evaluation_id)
+ if not human_evaluation:
+ raise HTTPException(status_code=404, detail="Evaluation not found")
+
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=human_evaluation,
+ permission=Permission.EDIT_EVALUATION,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin."
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ await update_human_evaluation_service(human_evaluation, update_data)
return Response(status_code=status.HTTP_204_NO_CONTENT)
except KeyError:
@@ -203,16 +265,39 @@ async def update_evaluation_scenario_router(
Returns:
None: 204 No Content status code upon successful update.
"""
- user_org_data = await get_user_and_org_id(request.state.user_id)
try:
+ evaluation_scenario_db = await db_manager.fetch_human_evaluation_scenario_by_id(
+ evaluation_scenario_id
+ )
+ if evaluation_scenario_db is None:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Evaluation scenario with id {evaluation_scenario_id} not found",
+ )
+
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=evaluation_scenario_db,
+ permission=Permission.EDIT_EVALUATION,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin."
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
await update_human_evaluation_scenario(
- evaluation_scenario_id,
+ evaluation_scenario_db,
evaluation_scenario,
evaluation_type,
- **user_org_data,
)
return Response(status_code=status.HTTP_204_NO_CONTENT)
except UpdateEvaluationScenarioError as e:
+ import traceback
+
+ traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e)) from e
@@ -231,10 +316,35 @@ async def get_evaluation_scenario_score_router(
Returns:
Dictionary containing the scenario ID and its score.
"""
- user_org_data = await get_user_and_org_id(request.state.user_id)
- return await get_evaluation_scenario_score_service(
- evaluation_scenario_id, **user_org_data
- )
+ try:
+ evaluation_scenario = db_manager.fetch_evaluation_scenario_by_id(
+ evaluation_scenario_id
+ )
+ if evaluation_scenario is None:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Evaluation scenario with id {evaluation_scenario_id} not found",
+ )
+
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=evaluation_scenario,
+ permission=Permission.VIEW_EVALUATION,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin."
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ return {
+ "scenario_id": str(evaluation_scenario.id),
+ "score": evaluation_scenario.score,
+ }
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e)) from e
@router.put("/evaluation_scenario/{evaluation_scenario_id}/score/")
@@ -251,11 +361,32 @@ async def update_evaluation_scenario_score_router(
Returns:
None: 204 No Content status code upon successful update.
"""
- user_org_data = await get_user_and_org_id(request.state.user_id)
try:
- await update_evaluation_scenario_score_service(
- evaluation_scenario_id, payload.score, **user_org_data
+ evaluation_scenario = db_manager.fetch_evaluation_scenario_by_id(
+ evaluation_scenario_id
)
+ if evaluation_scenario is None:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Evaluation scenario with id {evaluation_scenario_id} not found",
+ )
+
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=evaluation_scenario,
+ permission=Permission.VIEW_EVALUATION,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin."
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ evaluation_scenario.score = payload.score
+ await evaluation_scenario.save()
+
return Response(status_code=status.HTTP_204_NO_CONTENT)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@@ -275,20 +406,37 @@ async def fetch_results(
_description_
"""
- # Get user and organization id
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- evaluation = await evaluation_service._fetch_human_evaluation_and_check_access(
- evaluation_id, **user_org_data
- )
- if evaluation.evaluation_type == EvaluationType.human_a_b_testing:
- results = await results_service.fetch_results_for_evaluation(evaluation)
- return {"votes_data": results}
-
- elif evaluation.evaluation_type == EvaluationType.single_model_test:
- results = await results_service.fetch_results_for_single_model_test(
- evaluation_id
- )
- return {"results_data": results}
+ try:
+ evaluation = await db_manager.fetch_human_evaluation_by_id(evaluation_id)
+ if evaluation is None:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Evaluation with id {evaluation_id} not found",
+ )
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=evaluation,
+ permission=Permission.VIEW_EVALUATION,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin."
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ if evaluation.evaluation_type == EvaluationType.human_a_b_testing:
+ results = await results_service.fetch_results_for_evaluation(evaluation)
+ return {"votes_data": results}
+
+ elif evaluation.evaluation_type == EvaluationType.single_model_test:
+ results = await results_service.fetch_results_for_single_model_test(
+ evaluation_id
+ )
+ return {"results_data": results}
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e)) from e
@router.delete("/", response_model=List[str])
@@ -306,9 +454,25 @@ async def delete_evaluations(
A list of the deleted comparison tables' IDs.
"""
- # Get user and organization id
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- await evaluation_service.delete_human_evaluations(
- delete_evaluations.evaluations_ids, **user_org_data
- )
- return Response(status_code=status.HTTP_204_NO_CONTENT)
+ try:
+ if isCloudEE():
+ for evaluation_id in delete_evaluations.evaluations_ids:
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=evaluation_id,
+ object_type="evaluation",
+ permission=Permission.DELETE_EVALUATION,
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin."
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ await evaluation_service.delete_human_evaluations(
+ delete_evaluations.evaluations_ids
+ )
+ return Response(status_code=status.HTTP_204_NO_CONTENT)
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e)) from e
diff --git a/agenta-backend/agenta_backend/routers/observability_router.py b/agenta-backend/agenta_backend/routers/observability_router.py
index f385a99c33..45282038a4 100644
--- a/agenta-backend/agenta_backend/routers/observability_router.py
+++ b/agenta-backend/agenta_backend/routers/observability_router.py
@@ -27,13 +27,6 @@
UpdateTrace,
)
-if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
- from agenta_backend.commons.services.selectors import (
- get_user_and_org_id,
- ) # noqa pylint: disable-all
-else:
- from agenta_backend.services.selectors import get_user_and_org_id
-
router = APIRouter()
@@ -43,9 +36,7 @@ async def create_trace(
payload: CreateTrace,
request: Request,
):
- # Get user and org id
- kwargs: dict = await get_user_and_org_id(request.state.user_id)
- trace = await create_app_trace(payload, **kwargs)
+ trace = await create_app_trace(payload, request.state.user_id)
return trace
@@ -59,9 +50,7 @@ async def get_traces(
variant_id: str,
request: Request,
):
- # Get user and org id
- kwargs: dict = await get_user_and_org_id(request.state.user_id)
- traces = await get_variant_traces(app_id, variant_id, **kwargs)
+ traces = await get_variant_traces(app_id, variant_id, request.state.user_id)
return traces
@@ -72,9 +61,7 @@ async def get_single_trace(
trace_id: str,
request: Request,
):
- # Get user and org id
- kwargs: dict = await get_user_and_org_id(request.state.user_id)
- trace = await get_trace_single(trace_id, **kwargs)
+ trace = await get_trace_single(trace_id, request.state.user_id)
return trace
@@ -83,9 +70,7 @@ async def create_span(
payload: CreateSpan,
request: Request,
):
- # Get user and org id
- kwargs: dict = await get_user_and_org_id(request.state.user_id)
- spans_id = await create_trace_span(payload, **kwargs)
+ spans_id = await create_trace_span(payload)
return spans_id
@@ -96,9 +81,7 @@ async def get_spans_of_trace(
trace_id: str,
request: Request,
):
- # Get user and org id
- kwargs: dict = await get_user_and_org_id(request.state.user_id)
- spans = await get_trace_spans(trace_id, **kwargs)
+ spans = await get_trace_spans(trace_id, request.state.user_id)
return spans
@@ -110,9 +93,7 @@ async def update_trace_status(
payload: UpdateTrace,
request: Request,
):
- # Get user and org id
- kwargs: dict = await get_user_and_org_id(request.state.user_id)
- trace = await trace_status_update(trace_id, payload, **kwargs)
+ trace = await trace_status_update(trace_id, payload, request.state.user_id)
return trace
@@ -124,9 +105,7 @@ async def create_feedback(
payload: CreateFeedback,
request: Request,
):
- # Get user and org id
- kwargs: dict = await get_user_and_org_id(request.state.user_id)
- feedback = await add_feedback_to_trace(trace_id, payload, **kwargs)
+ feedback = await add_feedback_to_trace(trace_id, payload, request.state.user_id)
return feedback
@@ -136,9 +115,7 @@ async def create_feedback(
operation_id="get_feedbacks",
)
async def get_feedbacks(trace_id: str, request: Request):
- # Get user and org id
- kwargs: dict = await get_user_and_org_id(request.state.user_id)
- feedbacks = await get_trace_feedbacks(trace_id, **kwargs)
+ feedbacks = await get_trace_feedbacks(trace_id, request.state.user_id)
return feedbacks
@@ -152,9 +129,7 @@ async def get_feedback(
feedback_id: str,
request: Request,
):
- # Get user and org id
- kwargs: dict = await get_user_and_org_id(request.state.user_id)
- feedback = await get_feedback_detail(trace_id, feedback_id, **kwargs)
+ feedback = await get_feedback_detail(trace_id, feedback_id, request.state.user_id)
return feedback
@@ -169,7 +144,7 @@ async def update_feedback(
payload: UpdateFeedback,
request: Request,
):
- # Get user and org id
- kwargs: dict = await get_user_and_org_id(request.state.user_id)
- feedback = await update_trace_feedback(trace_id, feedback_id, payload, **kwargs)
+ feedback = await update_trace_feedback(
+ trace_id, feedback_id, payload, request.state.user_id
+ )
return feedback
diff --git a/agenta-backend/agenta_backend/routers/testset_router.py b/agenta-backend/agenta_backend/routers/testset_router.py
index fca1f2b89b..649723d4fc 100644
--- a/agenta-backend/agenta_backend/routers/testset_router.py
+++ b/agenta-backend/agenta_backend/routers/testset_router.py
@@ -1,39 +1,45 @@
import io
-import os
import csv
import json
+import logging
import requests
+
from bson import ObjectId
from datetime import datetime
from typing import Optional, List
from pydantic import ValidationError
-
from fastapi.responses import JSONResponse
+from agenta_backend.services import db_manager
+from agenta_backend.services.db_manager import get_user
+from agenta_backend.utils.common import APIRouter, isCloudEE
from fastapi import HTTPException, UploadFile, File, Form, Request
+from agenta_backend.models.converters import testset_db_to_pydantic
+
from agenta_backend.models.api.testset_model import (
- TestSetSimpleResponse,
- DeleteTestsets,
NewTestset,
+ DeleteTestsets,
+ TestSetSimpleResponse,
TestSetOutputResponse,
)
-from agenta_backend.services import db_manager
-from agenta_backend.models.db_models import TestSetDB
-from agenta_backend.services.db_manager import get_user
-from agenta_backend.models.converters import testset_db_to_pydantic
-from agenta_backend.utils.common import APIRouter, check_access_to_app
+if isCloudEE():
+ from agenta_backend.commons.utils.permissions import (
+ check_action_access,
+ ) # noqa pylint: disable-all
+ from agenta_backend.commons.models.db_models import (
+ Permission,
+ TestSetDB_ as TestSetDB,
+ ) # noqa pylint: disable-all
-router = APIRouter()
-upload_folder = "./path/to/upload/folder"
+else:
+ from agenta_backend.models.db_models import TestSetDB
+router = APIRouter()
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
-if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
- from agenta_backend.commons.services.selectors import (
- get_user_and_org_id,
- ) # noqa pylint: disable-all
-else:
- from agenta_backend.services.selectors import get_user_and_org_id
+upload_folder = "./path/to/upload/folder"
@router.post(
@@ -58,26 +64,34 @@ async def upload_file(
dict: The result of the upload process.
"""
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- access_app = await check_access_to_app(
- user_org_data=user_org_data, app_id=app_id, check_owner=False
- )
- if not access_app:
- error_msg = f"You do not have access to this app: {app_id}"
- return JSONResponse(
- {"detail": error_msg},
- status_code=400,
- )
app = await db_manager.fetch_app_by_id(app_id=app_id)
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=app,
+ permission=Permission.CREATE_TESTSET,
+ )
+ logger.debug(f"User has Permission to upload Testset: {has_permission}")
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
# Create a document
document = {
"created_at": datetime.now().isoformat(),
"name": testset_name if testset_name else file.filename,
"app": app,
- "organization": app.organization,
"csvdata": [],
}
+ if isCloudEE():
+ document["organization"] = app.organization
+ document["workspace"] = app.workspace
+
if upload_type == "JSON":
# Read and parse the JSON file
json_data = await file.read()
@@ -101,7 +115,7 @@ async def upload_file(
for row in csv_reader:
document["csvdata"].append(row)
- user = await get_user(user_uid=user_org_data["uid"])
+ user = await get_user(request.state.user_id)
try:
testset_instance = TestSetDB(**document, user=user)
except ValidationError as e:
@@ -135,17 +149,21 @@ async def import_testset(
Returns:
dict: The result of the import process.
"""
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- access_app = await check_access_to_app(
- user_org_data=user_org_data, app_id=app_id, check_owner=False
- )
- if not access_app:
- error_msg = f"You do not have access to this app: {app_id}"
- return JSONResponse(
- {"detail": error_msg},
- status_code=400,
- )
app = await db_manager.fetch_app_by_id(app_id=app_id)
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=app,
+ permission=Permission.CREATE_TESTSET,
+ )
+ logger.debug(f"User has Permission to import Testset: {has_permission}")
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
try:
response = requests.get(endpoint, timeout=10)
@@ -160,16 +178,19 @@ async def import_testset(
"created_at": datetime.now().isoformat(),
"name": testset_name,
"app": app,
- "organization": app.organization,
"csvdata": [],
}
+ if isCloudEE():
+ document["organization"] = app.organization
+ document["workspace"] = app.workspace
+
# Populate the document with column names and values
json_response = response.json()
for row in json_response:
document["csvdata"].append(row)
- user = await get_user(user_uid=user_org_data["uid"])
+ user = await get_user(request.state.user_id)
testset_instance = TestSetDB(**document, user=user)
result = await testset_instance.create()
@@ -215,27 +236,35 @@ async def create_testset(
str: The id of the test set created.
"""
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- user = await get_user(user_uid=user_org_data["uid"])
- access_app = await check_access_to_app(
- user_org_data=user_org_data, app_id=app_id, check_owner=False
- )
- if not access_app:
- error_msg = f"You do not have access to this app: {app_id}"
- return JSONResponse(
- {"detail": error_msg},
- status_code=400,
- )
app = await db_manager.fetch_app_by_id(app_id=app_id)
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=app,
+ permission=Permission.CREATE_TESTSET,
+ )
+ logger.debug(f"User has Permission to create Testset: {has_permission}")
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ user = await get_user(request.state.user_id)
testset = {
"created_at": datetime.now().isoformat(),
"name": csvdata.name,
"app": app,
- "organization": app.organization,
"csvdata": csvdata.csvdata,
"user": user,
}
+ if isCloudEE():
+ testset["organization"] = app.organization
+ testset["workspace"] = app.workspace
+
try:
testset_instance = TestSetDB(**testset)
await testset_instance.create()
@@ -267,25 +296,31 @@ async def update_testset(
Returns:
str: The id of the test set updated.
"""
+ test_set = await db_manager.fetch_testset_by_id(testset_id=testset_id)
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=test_set,
+ permission=Permission.EDIT_TESTSET,
+ )
+ logger.debug(f"User has Permission to update Testset: {has_permission}")
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
testset_update = {
"name": csvdata.name,
"csvdata": csvdata.csvdata,
"updated_at": datetime.now().isoformat(),
}
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- test_set = await db_manager.fetch_testset_by_id(testset_id=testset_id)
if test_set is None:
raise HTTPException(status_code=404, detail="testset not found")
- access_app = await check_access_to_app(
- user_org_data=user_org_data, app_id=str(test_set.app.id), check_owner=False
- )
- if not access_app:
- error_msg = f"You do not have access to this app: {test_set.app.id}"
- return JSONResponse(
- {"detail": error_msg},
- status_code=400,
- )
+
try:
await test_set.update({"$set": testset_update})
if isinstance(test_set.id, ObjectId):
@@ -315,17 +350,21 @@ async def get_testsets(
Raises:
- `HTTPException` with status code 404 if no testsets are found.
"""
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- access_app = await check_access_to_app(
- user_org_data=user_org_data, app_id=app_id, check_owner=False
- )
- if not access_app:
- error_msg = f"You do not have access to this app: {app_id}"
- return JSONResponse(
- {"detail": error_msg},
- status_code=400,
- )
app = await db_manager.fetch_app_by_id(app_id=app_id)
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=app,
+ permission=Permission.VIEW_TESTSET,
+ )
+ logger.debug(f"User has Permission to view Testsets: {has_permission}")
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
if app is None:
raise HTTPException(status_code=404, detail="App not found")
@@ -355,19 +394,25 @@ async def get_single_testset(
Returns:
The requested testset if found, else an HTTPException.
"""
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
test_set = await db_manager.fetch_testset_by_id(testset_id=testset_id)
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=test_set,
+ permission=Permission.VIEW_TESTSET,
+ )
+ logger.debug(f"User has Permission to view Testset: {has_permission}")
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
if test_set is None:
raise HTTPException(status_code=404, detail="testset not found")
- access_app = await check_access_to_app(
- user_org_data=user_org_data, app_id=str(test_set.app.id), check_owner=False
- )
- if not access_app:
- error_msg = "You do not have access to this test set"
- return JSONResponse(
- {"detail": error_msg},
- status_code=400,
- )
+
return testset_db_to_pydantic(test_set)
@@ -385,25 +430,28 @@ async def delete_testsets(
Returns:
A list of the deleted testsets' IDs.
"""
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
-
deleted_ids = []
-
for testset_id in delete_testsets.testset_ids:
test_set = await db_manager.fetch_testset_by_id(testset_id=testset_id)
if test_set is None:
raise HTTPException(status_code=404, detail="testset not found")
- access_app = await check_access_to_app(
- user_org_data=user_org_data,
- app_id=str(test_set.app.id),
- check_owner=False,
- )
- if not access_app:
- error_msg = "You do not have access to this test set"
- return JSONResponse(
- {"detail": error_msg},
- status_code=400,
- )
+
+ if isCloudEE():
+ for testset_id in delete_testsets.testset_ids:
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=test_set,
+ permission=Permission.DELETE_TESTSET,
+ )
+ logger.debug(f"User has Permission to delete Testset: {has_permission}")
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
await test_set.delete()
deleted_ids.append(testset_id)
diff --git a/agenta-backend/agenta_backend/routers/user_profile.py b/agenta-backend/agenta_backend/routers/user_profile.py
index a751ae72d7..d0e3332472 100644
--- a/agenta-backend/agenta_backend/routers/user_profile.py
+++ b/agenta-backend/agenta_backend/routers/user_profile.py
@@ -1,27 +1,18 @@
import os
-from agenta_backend.models.db_models import UserDB
from fastapi import HTTPException, Request
-from agenta_backend.models.api.user_models import User
from agenta_backend.services import db_manager
from agenta_backend.utils.common import APIRouter
+from agenta_backend.models.api.user_models import User
router = APIRouter()
-if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
- from agenta_backend.commons.services.selectors import (
- get_user_and_org_id,
- ) # noqa pylint: disable-all
-else:
- from agenta_backend.services.selectors import get_user_and_org_id
-
@router.get("/", operation_id="user_profile")
async def user_profile(
request: Request,
):
try:
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- user = await db_manager.get_user(user_uid=user_org_data["uid"])
+ user = await db_manager.get_user(request.state.user_id)
return User(
id=str(user.id),
uid=str(user.uid),
diff --git a/agenta-backend/agenta_backend/routers/variants_router.py b/agenta-backend/agenta_backend/routers/variants_router.py
index ad242ec152..af650162c0 100644
--- a/agenta-backend/agenta_backend/routers/variants_router.py
+++ b/agenta-backend/agenta_backend/routers/variants_router.py
@@ -1,41 +1,46 @@
import os
import logging
+
+from typing import Any, Optional, Union
from docker.errors import DockerException
from fastapi.responses import JSONResponse
-from typing import Any, Optional, Union
+from agenta_backend.models import converters
from fastapi import HTTPException, Request, Body
-from agenta_backend.utils.common import APIRouter
+from agenta_backend.utils.common import APIRouter, isCloudEE
+
from agenta_backend.services import (
app_manager,
db_manager,
)
-from agenta_backend.utils.common import (
- check_access_to_variant,
-)
-from agenta_backend.models import converters
+
+if isCloudEE():
+ from agenta_backend.commons.utils.permissions import (
+ check_action_access,
+ ) # noqa pylint: disable-all
+ from agenta_backend.commons.models.db_models import (
+ Permission,
+ ) # noqa pylint: disable-all
+ from agenta_backend.commons.models.api.api_models import (
+ Image_ as Image,
+ AppVariantResponse_ as AppVariantResponse,
+ AppVariantOutputExtended_ as AppVariantOutputExtended,
+ )
+else:
+ from agenta_backend.models.api.api_models import (
+ Image,
+ AppVariantResponse,
+ AppVariantOutputExtended,
+ )
from agenta_backend.models.api.api_models import (
- Image,
URI,
DockerEnvVars,
- AddVariantFromBasePayload,
- AppVariantOutput,
- UpdateVariantParameterPayload,
VariantAction,
VariantActionEnum,
- AppVariantOutputExtended,
-)
-from agenta_backend.utils.common import (
- check_access_to_app,
+ AddVariantFromBasePayload,
+ UpdateVariantParameterPayload,
)
-if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
- from agenta_backend.commons.services.selectors import (
- get_user_and_org_id,
- ) # noqa pylint: disable-all
-else:
- from agenta_backend.services.selectors import get_user_and_org_id
-
router = APIRouter()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -45,7 +50,7 @@
async def add_variant_from_base_and_config(
payload: AddVariantFromBasePayload,
request: Request,
-) -> Union[AppVariantOutput, Any]:
+) -> Union[AppVariantResponse, Any]:
"""Add a new variant based on an existing one.
Same as POST /config
@@ -57,22 +62,38 @@ async def add_variant_from_base_and_config(
HTTPException: Raised if the variant could not be added or accessed.
Returns:
- Union[AppVariantOutput, Any]: New variant details or exception.
+ Union[AppVariantResponse, Any]: New variant details or exception.
"""
try:
logger.debug("Initiating process to add a variant based on a previous one.")
logger.debug(f"Received payload: {payload}")
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- base_db = await db_manager.fetch_base_and_check_access(
- payload.base_id, user_org_data
- )
+
+ base_db = await db_manager.fetch_base_by_id(payload.base_id)
+
+ # Check user has permission to add variant
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=base_db,
+ permission=Permission.CREATE_APPLICATION,
+ )
+ logger.debug(
+ f"User has Permission to create variant from base and config: {has_permission}"
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
# Find the previous variant in the database
db_app_variant = await db_manager.add_variant_from_base_and_config(
base_db=base_db,
new_config_name=payload.new_config_name,
parameters=payload.parameters,
- **user_org_data,
+ user_uid=request.state.user_id,
)
logger.debug(f"Successfully added new variant: {db_app_variant}")
app_variant_db = await db_manager.get_app_variant_instance_by_id(
@@ -103,26 +124,23 @@ async def remove_variant(
HTTPException: If there is a problem removing the app variant
"""
try:
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
-
- # Check app access
- access_app = await check_access_to_variant(
- user_org_data, variant_id=variant_id, check_owner=True
- )
-
- if not access_app:
- error_msg = (
- f"You do not have permission to delete app variant: {variant_id}"
- )
- logger.error(error_msg)
- return JSONResponse(
- {"detail": error_msg},
- status_code=400,
- )
- else:
- await app_manager.terminate_and_remove_app_variant(
- app_variant_id=variant_id, **user_org_data
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=variant_id,
+ object_type="app_variant",
+ permission=Permission.DELETE_APPLICATION_VARIANT,
)
+ logger.debug(f"User has Permission to delete app variant: {has_permission}")
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ await app_manager.terminate_and_remove_app_variant(app_variant_id=variant_id)
except DockerException as e:
detail = f"Docker error while trying to remove the app variant: {str(e)}"
raise HTTPException(status_code=500, detail=detail)
@@ -152,26 +170,29 @@ async def update_variant_parameters(
JSONResponse: A JSON response containing the updated app variant parameters.
"""
try:
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- access_variant = await check_access_to_variant(
- user_org_data=user_org_data, variant_id=variant_id
- )
-
- if not access_variant:
- error_msg = (
- f"You do not have permission to update app variant: {variant_id}"
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object_id=variant_id,
+ object_type="app_variant",
+ permission=Permission.MODIFY_VARIANT_CONFIGURATIONS,
)
- logger.error(error_msg)
- return JSONResponse(
- {"detail": error_msg},
- status_code=400,
- )
- else:
- await app_manager.update_variant_parameters(
- app_variant_id=variant_id,
- parameters=payload.parameters,
- **user_org_data,
+ logger.debug(
+ f"User has Permission to update variant parameters: {has_permission}"
)
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ await app_manager.update_variant_parameters(
+ app_variant_id=variant_id,
+ parameters=payload.parameters,
+ user_uid=request.state.user_id,
+ )
except ValueError as e:
detail = f"Error while trying to update the app variant: {str(e)}"
raise HTTPException(status_code=500, detail=detail)
@@ -200,24 +221,30 @@ async def update_variant_image(
JSONResponse: A JSON response indicating whether the update was successful or not.
"""
try:
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- access_variant = await check_access_to_variant(
- user_org_data=user_org_data, variant_id=variant_id
- )
- if not access_variant:
- error_msg = (
- f"You do not have permission to update app variant: {variant_id}"
- )
- logger.error(error_msg)
- return JSONResponse(
- {"detail": error_msg},
- status_code=400,
- )
db_app_variant = await db_manager.fetch_app_variant_by_id(
app_variant_id=variant_id
)
- await app_manager.update_variant_image(db_app_variant, image, **user_org_data)
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=db_app_variant,
+ permission=Permission.CREATE_APPLICATION,
+ )
+ logger.debug(
+ f"User has Permission to update variant image: {has_permission}"
+ )
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
+ await app_manager.update_variant_image(
+ db_app_variant, image, request.state.user_id
+ )
except ValueError as e:
detail = f"Error while trying to update the app variant: {str(e)}"
raise HTTPException(status_code=500, detail=detail)
@@ -251,35 +278,41 @@ async def start_variant(
Raises:
HTTPException: If the app container cannot be started.
"""
+ app_variant_db = await db_manager.fetch_app_variant_by_id(app_variant_id=variant_id)
+
+ # Check user has permission to start variant
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=app_variant_db,
+ permission=Permission.CREATE_APPLICATION,
+ )
+ logger.debug(f"User has Permission to start variant: {has_permission}")
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
logger.debug("Starting variant %s", variant_id)
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
- envvars = {} if env_vars is None else env_vars.env_vars
+
# Inject env vars to docker container
- if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
- if envvars.get("OPENAI_API_KEY", "") == "":
- if not os.environ["OPENAI_API_KEY"]:
- raise HTTPException(
- status_code=400,
- detail="Unable to start app container. Please file an issue by clicking on the button below.",
- )
- envvars["OPENAI_API_KEY"] = os.environ["OPENAI_API_KEY"]
+ if isCloudEE():
+ if not os.environ["OPENAI_API_KEY"]:
+ raise HTTPException(
+ status_code=400,
+ detail="Unable to start app container. Please file an issue by clicking on the button below.",
+ )
+ envvars = {
+ "OPENAI_API_KEY": os.environ["OPENAI_API_KEY"],
+ }
+ else:
+ envvars = {} if env_vars is None else env_vars.env_vars
- access = await check_access_to_variant(
- user_org_data=user_org_data, variant_id=variant_id
- )
- if not access:
- error_msg = f"You do not have access to this variant: {variant_id}"
- logger.error(error_msg)
- return JSONResponse(
- {"detail": error_msg},
- status_code=400,
- )
- app_variant_db = await db_manager.fetch_app_variant_by_id(app_variant_id=variant_id)
if action.action == VariantActionEnum.START:
- url: URI = await app_manager.start_variant(
- app_variant_db, envvars, **user_org_data
- )
+ url: URI = await app_manager.start_variant(app_variant_db, envvars)
return url
@@ -293,23 +326,26 @@ async def get_variant(
request: Request,
):
logger.debug("getting variant " + variant_id)
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
try:
- user_org_data: dict = await get_user_and_org_id(request.state.user_id)
-
- access = await check_access_to_variant(
- user_org_data=user_org_data, variant_id=variant_id
- )
- if not access:
- error_msg = f"You do not have access to this variant: {variant_id}"
- logger.error(error_msg)
- return JSONResponse(
- {"detail": error_msg},
- status_code=400,
- )
app_variant = await db_manager.fetch_app_variant_by_id(
app_variant_id=variant_id
)
+
+ if isCloudEE():
+ has_permission = await check_action_access(
+ user_uid=request.state.user_id,
+ object=app_variant,
+ permission=Permission.VIEW_APPLICATION,
+ )
+ logger.debug(f"User has Permission to get variant: {has_permission}")
+ if not has_permission:
+ error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
+ logger.error(error_msg)
+ return JSONResponse(
+ {"detail": error_msg},
+ status_code=403,
+ )
+
app_variant_revisions = await db_manager.list_app_variant_revisions_by_variant(
app_variant=app_variant
)
diff --git a/agenta-backend/agenta_backend/services/app_manager.py b/agenta-backend/agenta_backend/services/app_manager.py
index 18822909fe..04c7d747bb 100644
--- a/agenta-backend/agenta_backend/services/app_manager.py
+++ b/agenta-backend/agenta_backend/services/app_manager.py
@@ -22,18 +22,25 @@
evaluator_manager,
)
-if os.environ["FEATURE_FLAG"] in ["cloud"]:
+from agenta_backend.utils.common import (
+ isEE,
+ isOssEE,
+ isCloud,
+ isCloudEE,
+)
+
+if isCloud():
from agenta_backend.cloud.services import (
lambda_deployment_manager as deployment_manager,
) # noqa pylint: disable-all
-elif os.environ["FEATURE_FLAG"] in ["ee"]:
+elif isEE():
from agenta_backend.ee.services import (
deployment_manager,
) # noqa pylint: disable-all
else:
from agenta_backend.services import deployment_manager
-if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
+if isCloudEE():
from agenta_backend.commons.services import (
api_key_service,
) # noqa pylint: disable-all
@@ -43,9 +50,7 @@
async def start_variant(
- db_app_variant: AppVariantDB,
- env_vars: DockerEnvVars = None,
- **user_org_data: dict,
+ db_app_variant: AppVariantDB, env_vars: DockerEnvVars = None
) -> URI:
"""
Starts a Docker container for a given app variant.
@@ -66,12 +71,13 @@ async def start_variant(
"""
try:
logger.debug(
- "Starting variant %s with image name %s and tags %s and app_name %s and organization %s",
+ "Starting variant %s with image name %s and tags %s and app_name %s and organization %s and workspace %s",
db_app_variant.variant_name,
db_app_variant.image.docker_id,
db_app_variant.image.tags,
db_app_variant.app.app_name,
- db_app_variant.organization,
+ db_app_variant.organization if isCloudEE() else None,
+ db_app_variant.workspace if isCloudEE() else None,
)
logger.debug("App name is %s", db_app_variant.app.app_name)
# update the env variables
@@ -86,9 +92,12 @@ async def start_variant(
env_vars.update(
{"AGENTA_BASE_ID": str(db_app_variant.base.id), "AGENTA_HOST": domain_name}
)
- if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
+ if isCloudEE():
api_key = await api_key_service.create_api_key(
- str(db_app_variant.user.uid), expiration_date=None, hidden=True
+ str(db_app_variant.user.uid),
+ workspace_id=str(db_app_variant.workspace),
+ expiration_date=None,
+ hidden=True,
)
env_vars.update({"AGENTA_API_KEY": api_key})
deployment = await deployment_manager.start_service(
@@ -99,6 +108,9 @@ async def start_variant(
deployment=deployment.id,
)
except Exception as e:
+ import traceback
+
+ traceback.print_exc()
logger.error(
f"Error starting Docker container for app variant {db_app_variant.app.app_name}/{db_app_variant.variant_name}: {str(e)}"
)
@@ -110,7 +122,7 @@ async def start_variant(
async def update_variant_image(
- app_variant_db: AppVariantDB, image: Image, **user_org_data: dict
+ app_variant_db: AppVariantDB, image: Image, user_uid: str
):
"""Updates the image for app variant in the database.
@@ -129,7 +141,7 @@ async def update_variant_image(
await deployment_manager.stop_and_delete_service(deployment)
await db_manager.remove_deployment(deployment)
- if os.environ["FEATURE_FLAG"] in ["ee", "oss"]:
+ if isOssEE():
await deployment_manager.remove_image(app_variant_db.base.image)
await db_manager.remove_image(app_variant_db.base.image)
@@ -140,23 +152,24 @@ async def update_variant_image(
docker_id=image.docker_id,
user=app_variant_db.user,
deletable=True,
- organization=app_variant_db.organization,
+ organization=app_variant_db.organization if isCloudEE() else None, # noqa
+ workspace=app_variant_db.workspace if isCloudEE() else None, # noqa
)
# Update base with new image
await db_manager.update_base(app_variant_db.base, image=db_image)
# Update variant to remove configuration
await db_manager.update_variant_parameters(
- app_variant_db=app_variant_db, parameters={}, **user_org_data
+ app_variant_db=app_variant_db, parameters={}, user_uid=user_uid
)
# Update variant with new image
app_variant_db = await db_manager.update_app_variant(app_variant_db, image=db_image)
# Start variant
- await start_variant(app_variant_db, **user_org_data)
+ await start_variant(app_variant_db)
async def terminate_and_remove_app_variant(
- app_variant_id: str = None, app_variant_db=None, **kwargs: dict
+ app_variant_id: str = None, app_variant_db=None
) -> None:
"""
Removes app variant from the database. If it's the last one using an image, performs additional operations:
@@ -218,20 +231,20 @@ async def terminate_and_remove_app_variant(
# If image deletable is True, remove docker image and image db
if image.deletable:
try:
- if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
+ if isCloudEE():
await deployment_manager.remove_repository(image.tags)
else:
await deployment_manager.remove_image(image)
except RuntimeError as e:
logger.error(f"Failed to remove image {image} {e}")
- await db_manager.remove_image(image, **kwargs)
+ await db_manager.remove_image(image)
logger.debug("remove base")
- await db_manager.remove_app_variant_from_db(app_variant_db, **kwargs)
+ await db_manager.remove_app_variant_from_db(app_variant_db)
logger.debug("Remove image object from db")
if deployment:
await db_manager.remove_deployment(deployment)
- await db_manager.remove_base_from_db(app_variant_db.base, **kwargs)
+ await db_manager.remove_base_from_db(app_variant_db.base)
logger.debug("remove_app_variant_from_db")
# Only delete the docker image for users that are running the oss version
@@ -243,13 +256,13 @@ async def terminate_and_remove_app_variant(
else:
# remove variant + config
logger.debug("remove_app_variant_from_db")
- await db_manager.remove_app_variant_from_db(app_variant_db, **kwargs)
+ await db_manager.remove_app_variant_from_db(app_variant_db)
logger.debug("list_app_variants")
- app_variants = await db_manager.list_app_variants(app_id=app_id, **kwargs)
+ app_variants = await db_manager.list_app_variants(app_id)
logger.debug(f"{app_variants}")
if len(app_variants) == 0: # this was the last variant for an app
logger.debug("remove_app_related_resources")
- await remove_app_related_resources(app_id=app_id, **kwargs)
+ await remove_app_related_resources(app_id)
except Exception as e:
logger.error(
f"An error occurred while deleting app variant {app_variant_db.app.app_name}/{app_variant_db.variant_name}: {str(e)}"
@@ -257,7 +270,7 @@ async def terminate_and_remove_app_variant(
raise e from None
-async def remove_app_related_resources(app_id: str, **kwargs: dict):
+async def remove_app_related_resources(app_id: str):
"""Removes environments and testsets associated with an app after its deletion.
When an app or its last variant is deleted, this function ensures that
@@ -269,16 +282,16 @@ async def remove_app_related_resources(app_id: str, **kwargs: dict):
try:
# Delete associated environments
environments: List[AppEnvironmentDB] = await db_manager.list_environments(
- app_id, **kwargs
+ app_id
)
for environment_db in environments:
- await db_manager.remove_environment(environment_db, **kwargs)
+ await db_manager.remove_environment(environment_db)
logger.info(f"Successfully deleted environment {environment_db.name}.")
# Delete associated testsets
- await db_manager.remove_app_testsets(app_id, **kwargs)
+ await db_manager.remove_app_testsets(app_id)
logger.info(f"Successfully deleted test sets associated with app {app_id}.")
- await db_manager.remove_app_by_id(app_id, **kwargs)
+ await db_manager.remove_app_by_id(app_id)
logger.info(f"Successfully remove app object {app_id}.")
except Exception as e:
logger.error(
@@ -287,7 +300,7 @@ async def remove_app_related_resources(app_id: str, **kwargs: dict):
raise e from None
-async def remove_app(app_id: str, **kwargs: dict):
+async def remove_app(app: AppDB):
"""Removes all app variants from db, if it is the last one using an image, then
deletes the image from the db, shutdowns the container, deletes it and remove
the image from the registry
@@ -296,39 +309,38 @@ async def remove_app(app_id: str, **kwargs: dict):
app_name -- the app name to remove
"""
# checks if it is the last app variant using its image
- app = await db_manager.fetch_app_by_id(app_id)
if app is None:
- error_msg = f"Failed to delete app {app_id}: Not found in DB."
+ error_msg = f"Failed to delete app {app.id}: Not found in DB."
logger.error(error_msg)
raise ValueError(error_msg)
try:
- app_variants = await db_manager.list_app_variants(app_id=app_id, **kwargs)
+ app_variants = await db_manager.list_app_variants(app.id)
for app_variant_db in app_variants:
- await terminate_and_remove_app_variant(
- app_variant_db=app_variant_db, **kwargs
- )
+ await terminate_and_remove_app_variant(app_variant_db=app_variant_db)
logger.info(
f"Successfully deleted app variant {app_variant_db.app.app_name}/{app_variant_db.variant_name}."
)
if len(app_variants) == 0: # Failsafe in case something went wrong before
logger.debug("remove_app_related_resources")
- await remove_app_related_resources(app_id=app_id, **kwargs)
+ await remove_app_related_resources(str(app.id))
except Exception as e:
logger.error(
- f"An error occurred while deleting app {app_id} and its associated resources: {str(e)}"
+ f"An error occurred while deleting app {app.id} and its associated resources: {str(e)}"
)
raise e from None
async def update_variant_parameters(
- app_variant_id: str, parameters: Dict[str, Any], **user_org_data: dict
+ app_variant_id: str, parameters: Dict[str, Any], user_uid: str
):
"""Updates the parameters for app variant in the database.
Arguments:
app_variant -- the app variant to update
+ parameters -- the parameters to update
+ user_uid -- the user uid
"""
assert app_variant_id is not None, "app_variant_id must be provided"
assert parameters is not None, "parameters must be provided"
@@ -339,7 +351,7 @@ async def update_variant_parameters(
raise ValueError(error_msg)
try:
await db_manager.update_variant_parameters(
- app_variant_db=app_variant_db, parameters=parameters, **user_org_data
+ app_variant_db=app_variant_db, parameters=parameters, user_uid=user_uid
)
except Exception as e:
logger.error(
@@ -352,11 +364,11 @@ async def add_variant_based_on_image(
app: AppDB,
variant_name: str,
docker_id_or_template_uri: str,
+ user_uid: str,
tags: str = None,
base_name: str = None,
config_name: str = "default",
is_template_image: bool = False,
- **user_org_data: dict,
) -> AppVariantDB:
"""
Adds a new variant to the app based on the specified Docker image.
@@ -369,7 +381,7 @@ async def add_variant_based_on_image(
base_name (str, optional): The name of the base to use for the new variant. Defaults to None.
config_name (str, optional): The name of the configuration to use for the new variant. Defaults to "default".
is_template_image (bool, optional): Whether or not the image used is for a template (in this case we won't delete it in the future).
- **user_org_data (dict): Additional user and organization data.
+ user_uid (str): The UID of the user.
Returns:
AppVariantDB: The newly created app variant.
@@ -387,9 +399,9 @@ async def add_variant_based_on_image(
or variant_name in [None, ""]
or docker_id_or_template_uri in [None, ""]
):
- raise ValueError("App variant or image is None")
+ raise ValueError("App variant, variant name or docker_id/template_uri is None")
- if os.environ["FEATURE_FLAG"] not in ["cloud", "ee"]:
+ if not isCloudEE():
if tags in [None, ""]:
raise ValueError("OSS: Tags is None")
@@ -399,9 +411,7 @@ async def add_variant_based_on_image(
# Check if app variant already exists
logger.debug("Step 2: Checking if app variant already exists")
- variants = await db_manager.list_app_variants_for_app_id(
- app_id=str(app.id), **user_org_data
- )
+ variants = await db_manager.list_app_variants_for_app_id(app_id=str(app.id))
already_exists = any(av for av in variants if av.variant_name == variant_name)
if already_exists:
logger.error("App variant with the same name already exists")
@@ -409,16 +419,18 @@ async def add_variant_based_on_image(
# Retrieve user and image objects
logger.debug("Step 3: Retrieving user and image objects")
- user_instance = await db_manager.get_user(user_uid=user_org_data["uid"])
+ user_instance = await db_manager.get_user(user_uid)
if parsed_url.scheme and parsed_url.netloc:
db_image = await db_manager.get_orga_image_instance_by_uri(
- organization_id=str(app.organization.id),
template_uri=docker_id_or_template_uri,
+ organization_id=str(app.organization.id) if isCloudEE() else None, # noqa
+ workspace_id=str(app.workspace.id) if isCloudEE() else None, # noqa
)
else:
db_image = await db_manager.get_orga_image_instance_by_docker_id(
- organization_id=str(app.organization.id),
docker_id=docker_id_or_template_uri,
+ organization_id=str(app.organization.id) if isCloudEE() else None, # noqa
+ workspace_id=str(app.workspace.id) if isCloudEE() else None, # noqa
)
# Create new image if not exists
@@ -430,7 +442,8 @@ async def add_variant_based_on_image(
template_uri=docker_id_or_template_uri,
deletable=not (is_template_image),
user=user_instance,
- organization=app.organization,
+ organization=app.organization if isCloudEE() else None, # noqa
+ workspace=app.workspace if isCloudEE() else None, # noqa
)
else:
docker_id = docker_id_or_template_uri
@@ -440,7 +453,8 @@ async def add_variant_based_on_image(
tags=tags,
deletable=not (is_template_image),
user=user_instance,
- organization=app.organization,
+ organization=app.organization if isCloudEE() else None, # noqa
+ workspace=app.workspace if isCloudEE() else None, # noqa
)
# Create config
@@ -457,7 +471,8 @@ async def add_variant_based_on_image(
] # TODO: Change this in SDK2 to directly use base_name
db_base = await db_manager.create_new_variant_base(
app=app,
- organization=app.organization,
+ organization=app.organization if isCloudEE() else None, # noqa
+ workspace=app.workspace if isCloudEE() else None, # noqa
user=user_instance,
base_name=base_name, # the first variant always has default base
image=db_image,
@@ -470,7 +485,8 @@ async def add_variant_based_on_image(
variant_name=variant_name,
image=db_image,
user=user_instance,
- organization=app.organization,
+ organization=app.organization if isCloudEE() else None, # noqa
+ workspace=app.workspace if isCloudEE() else None, # noqa
parameters={},
base_name=base_name,
config_name=config_name,
diff --git a/agenta-backend/agenta_backend/services/container_manager.py b/agenta-backend/agenta_backend/services/container_manager.py
index 5ee1418703..c055c83ddd 100644
--- a/agenta-backend/agenta_backend/services/container_manager.py
+++ b/agenta-backend/agenta_backend/services/container_manager.py
@@ -20,17 +20,17 @@
AppDB,
)
from agenta_backend.services import docker_utils
+from agenta_backend.utils.common import isCloud
client = docker.from_env()
-
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
async def build_image(app_db: AppDB, base_name: str, tar_file: UploadFile) -> Image:
app_name = app_db.app_name
- organization_id = str(app_db.organization.id)
+ user_id = app_db.user.id
image_name = f"agentaai/{app_name.lower()}_{base_name.lower()}:latest"
# Get event loop
@@ -53,10 +53,10 @@ async def build_image(app_db: AppDB, base_name: str, tar_file: UploadFile) -> Im
*(
app_name,
base_name,
- organization_id,
tar_path,
image_name,
temp_dir,
+ user_id,
),
)
image_result = await asyncio.wrap_future(future)
@@ -66,10 +66,10 @@ async def build_image(app_db: AppDB, base_name: str, tar_file: UploadFile) -> Im
def build_image_job(
app_name: str,
base_name: str,
- organization_id: str,
tar_path: Path,
image_name: str,
temp_dir: Path,
+ user_id: str,
) -> Image:
"""Business logic for building a docker image from a tar file
@@ -80,13 +80,13 @@ def build_image_job(
base_name -- The `base_name` parameter is a string that represents the variant of the \
application. It could be a specific version, configuration, or any other distinguishing \
factor for the application
- organization_id -- The id of the organization the app belongs to
tar_path -- The `tar_path` parameter is the path to the tar file that contains the source code \
or files needed to build the Docker image
image_name -- The `image_name` parameter is a string that represents the name of the Docker \
image that will be built. It is used as the tag for the image
temp_dir -- The `temp_dir` parameter is a `Path` object that represents the temporary directory
where the contents of the tar file will be extracted
+ user_id -- The id of the user that owns the app
Raises:
HTTPException: _description_
@@ -100,26 +100,28 @@ def build_image_job(
shutil.unpack_archive(tar_path, temp_dir)
try:
- if os.environ["FEATURE_FLAG"] in ["cloud"]:
+ if isCloud():
dockerfile = "Dockerfile.cloud"
else:
dockerfile = "Dockerfile"
image, build_log = client.images.build(
path=str(temp_dir),
tag=image_name,
- buildargs={"ROOT_PATH": f"/{organization_id}/{app_name}/{base_name}"},
+ buildargs={"ROOT_PATH": f"/{user_id}/{app_name}/{base_name}"},
rm=True,
dockerfile=dockerfile,
pull=True,
)
for line in build_log:
logger.info(line)
- return Image(
+ pydantic_image = Image(
type="image",
docker_id=image.id,
tags=image.tags[0],
- organization_id=organization_id,
)
+
+ return pydantic_image
+
except docker.errors.BuildError as ex:
log = "Error building Docker image:\n"
log += str(ex) + "\n"
diff --git a/agenta-backend/agenta_backend/services/db_manager.py b/agenta-backend/agenta_backend/services/db_manager.py
index 4a2ade3988..d8dfbb77a2 100644
--- a/agenta-backend/agenta_backend/services/db_manager.py
+++ b/agenta-backend/agenta_backend/services/db_manager.py
@@ -1,58 +1,76 @@
import os
import logging
+import traceback
+
from pathlib import Path
from datetime import datetime
from urllib.parse import urlparse
+from fastapi import HTTPException
+from fastapi.responses import JSONResponse
from typing import Any, Dict, List, Optional
+from agenta_backend.models import converters
+from agenta_backend.utils.common import isCloudEE
+from agenta_backend.services.json_importer_helper import get_json
+
from agenta_backend.models.api.api_models import (
App,
- AppVariant,
- ImageExtended,
Template,
)
-from agenta_backend.models.converters import (
- app_db_to_pydantic,
- image_db_to_pydantic,
- templates_db_to_pydantic,
-)
-from agenta_backend.services.json_importer_helper import get_json
+
+if isCloudEE():
+ from agenta_backend.commons.services import db_manager_ee
+ from agenta_backend.commons.utils.permissions import check_rbac_permission
+ from agenta_backend.commons.services.selectors import get_user_org_and_workspace_id
+
+ from agenta_backend.commons.models.db_models import (
+ Permission,
+ AppDB_ as AppDB,
+ UserDB_ as UserDB,
+ ImageDB_ as ImageDB,
+ TestSetDB_ as TestSetDB,
+ AppVariantDB_ as AppVariantDB,
+ EvaluationDB_ as EvaluationDB,
+ DeploymentDB_ as DeploymentDB,
+ VariantBaseDB_ as VariantBaseDB,
+ AppEnvironmentDB_ as AppEnvironmentDB,
+ AppEnvironmentRevisionDB_ as AppEnvironmentRevisionDB,
+ EvaluatorConfigDB_ as EvaluatorConfigDB,
+ HumanEvaluationDB_ as HumanEvaluationDB,
+ EvaluationScenarioDB_ as EvaluationScenarioDB,
+ HumanEvaluationScenarioDB_ as HumanEvaluationScenarioDB,
+ )
+
+else:
+ from agenta_backend.models.db_models import (
+ AppDB,
+ UserDB,
+ ImageDB,
+ TestSetDB,
+ AppVariantDB,
+ EvaluationDB,
+ DeploymentDB,
+ VariantBaseDB,
+ AppEnvironmentDB,
+ AppEnvironmentRevisionDB,
+ EvaluatorConfigDB,
+ HumanEvaluationDB,
+ EvaluationScenarioDB,
+ HumanEvaluationScenarioDB,
+ )
from agenta_backend.models.db_models import (
- AppEnvironmentRevisionDB,
- Result,
- HumanEvaluationDB,
- HumanEvaluationScenarioDB,
+ ConfigDB,
+ TemplateDB,
AggregatedResult,
- AppDB,
- AppVariantDB,
AppVariantRevisionsDB,
- ConfigDB,
+ EvaluationScenarioResult,
EvaluationScenarioInputDB,
EvaluationScenarioOutputDB,
- EvaluationScenarioResult,
- EvaluatorConfigDB,
- VariantBaseDB,
- AppEnvironmentDB,
- EvaluationDB,
- EvaluationScenarioDB,
- ImageDB,
- OrganizationDB,
- DeploymentDB,
- AppEnvironmentRevisionDB,
- TemplateDB,
- TestSetDB,
- UserDB,
)
-from agenta_backend.utils.common import check_user_org_access
-from agenta_backend.models.api.evaluation_model import EvaluationStatusEnum
-
-from fastapi import HTTPException
-from fastapi.responses import JSONResponse
from beanie.operators import In
from beanie import PydanticObjectId as ObjectId
-
# Define logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -62,7 +80,12 @@
async def add_testset_to_app_variant(
- app_id: str, org_id: str, template_name: str, app_name: str, **kwargs: dict
+ app_id: str,
+ template_name: str,
+ app_name: str,
+ user_uid: str,
+ org_id: str = None,
+ workspace_id: str = None,
):
"""Add testset to app variant.
Args:
@@ -70,13 +93,12 @@ async def add_testset_to_app_variant(
org_id (str): The id of the organization
template_name (str): The name of the app template image
app_name (str): The name of the app
- **kwargs (dict): Additional keyword arguments
+ user_uid (str): The uid of the user
"""
try:
app_db = await get_app_instance_by_id(app_id)
- org_db = await get_organization_object(org_id)
- user_db = await get_user(user_uid=kwargs["uid"])
+ user_db = await get_user(user_uid)
json_path = os.path.join(
PARENT_DIRECTORY,
@@ -94,37 +116,27 @@ async def add_testset_to_app_variant(
"csvdata": csvdata,
}
testset_db = TestSetDB(
- **testset, app=app_db, user=user_db, organization=org_db
+ **testset,
+ app=app_db,
+ user=user_db,
)
- await testset_db.create()
-
- except Exception as e:
- print(f"An error occurred in adding the default testset: {e}")
+ if isCloudEE():
+ # assert that if organization is provided, workspace_id is also provided, and vice versa
+ assert (
+ org_id is not None and workspace_id is not None
+ ), "organization and workspace must be provided together"
-async def get_image(app_variant: AppVariant, **kwargs: dict) -> ImageExtended:
- """Returns the image associated with the app variant
-
- Arguments:
- app_variant -- AppVariant to fetch the image for
+ organization_db = await db_manager_ee.get_organization(org_id)
+ workspace_db = await db_manager_ee.get_workspace(workspace_id)
- Returns:
- Image -- The Image associated with the app variant
- """
+ testset_db.organization = organization_db
+ testset_db.workspace = workspace_db
- # Build the query expression for the two conditions
- query_expression = (
- AppVariantDB.app.id == app_variant.app_id,
- AppVariantDB.variant_name == app_variant.variant_name,
- AppVariantDB.organization.id == app_variant.organization,
- )
+ await testset_db.create()
- db_app_variant = await AppVariantDB.find_one(query_expression)
- if db_app_variant:
- image_db = await ImageDB.find_one(ImageDB.id == db_app_variant.image.id)
- return image_db_to_pydantic(image_db)
- else:
- raise Exception("App variant not found")
+ except Exception as e:
+ print(f"An error occurred in adding the default testset: {e}")
async def get_image_by_id(image_id: str) -> ImageDB:
@@ -141,7 +153,7 @@ async def get_image_by_id(image_id: str) -> ImageDB:
return image
-async def fetch_app_by_id(app_id: str, **kwargs: dict) -> AppDB:
+async def fetch_app_by_id(app_id: str) -> AppDB:
"""Fetches an app by its ID.
Args:
@@ -152,29 +164,6 @@ async def fetch_app_by_id(app_id: str, **kwargs: dict) -> AppDB:
return app
-async def fetch_app_by_name(
- app_name: str, organization_id: Optional[str] = None, **user_org_data: dict
-) -> Optional[AppDB]:
- """Fetches an app by its name.
-
- Args:
- app_name (str): The name of the app to fetch.
-
- Returns:
- AppDB: the instance of the app
- """
-
- if not organization_id:
- user = await get_user(user_uid=user_org_data["uid"])
- app = await AppDB.find_one(AppDB.app_name == app_name, AppDB.user.id == user.id)
- else:
- app = await AppDB.find_one(
- AppDB.app_name == app_name,
- AppDB.organization.id == ObjectId(organization_id),
- )
- return app
-
-
async def fetch_app_variant_by_id(
app_variant_id: str,
) -> Optional[AppVariantDB]:
@@ -220,10 +209,7 @@ async def fetch_app_variant_revision_by_variant(
return app_variant_revision
-async def fetch_base_by_id(
- base_id: str,
- user_org_data: dict,
-) -> Optional[VariantBaseDB]:
+async def fetch_base_by_id(base_id: str) -> Optional[VariantBaseDB]:
"""
Fetches a base by its ID.
Args:
@@ -233,17 +219,12 @@ async def fetch_base_by_id(
"""
if base_id is None:
raise Exception("No base_id provided")
- base = await VariantBaseDB.find_one(VariantBaseDB.id == ObjectId(base_id))
+ base = await VariantBaseDB.find_one(
+ VariantBaseDB.id == ObjectId(base_id), fetch_links=True
+ )
if base is None:
logger.error("Base not found")
return False
- organization_id = base.organization.id
- access = await check_user_org_access(
- user_org_data, str(organization_id), check_owner=False
- )
- if not access:
- logger.error("User does not have access to this base")
- return False
return base
@@ -270,26 +251,40 @@ async def fetch_app_variant_by_name_and_appid(
async def create_new_variant_base(
app: AppDB,
- organization: OrganizationDB,
user: UserDB,
base_name: str,
image: ImageDB,
+ organization=None,
+ workspace=None,
) -> VariantBaseDB:
"""Create a new base.
Args:
base_name (str): The name of the base.
image (ImageDB): The image of the base.
+ user (UserDB): The User Object creating the variant.
+ app (AppDB): The associated App Object.
+ organization (OrganizationDB): The Organization the variant belongs to.
+ workspace (WorkspaceDB): The Workspace the variant belongs to.
Returns:
VariantBaseDB: The created base.
"""
logger.debug(f"Creating new base: {base_name} with image: {image} for app: {app}")
base = VariantBaseDB(
app=app,
- organization=organization,
user=user,
base_name=base_name,
image=image,
)
+
+ if isCloudEE():
+ # assert that if organization is provided, workspace_id is also provided, and vice versa
+ assert (
+ organization is not None and workspace is not None
+ ), "organization and workspace must be provided together"
+
+ base.organization = organization
+ base.workspace = workspace
+
await base.create()
return base
@@ -314,7 +309,6 @@ async def create_new_config(
async def create_new_app_variant(
app: AppDB,
- organization: OrganizationDB,
user: UserDB,
variant_name: str,
image: ImageDB,
@@ -323,6 +317,8 @@ async def create_new_app_variant(
base_name: str,
config_name: str,
parameters: Dict,
+ organization=None,
+ workspace=None,
) -> AppVariantDB:
"""Create a new variant.
Args:
@@ -338,7 +334,6 @@ async def create_new_app_variant(
), "Parameters should be empty when calling create_new_app_variant (otherwise revision should not be set to 0)"
variant = AppVariantDB(
app=app,
- organization=organization,
user=user,
modified_by=user,
revision=0,
@@ -351,6 +346,15 @@ async def create_new_app_variant(
parameters=parameters,
)
+ if isCloudEE():
+ # assert that if organization is provided, workspace_id is also provided, and vice versa
+ assert (
+ organization is not None and workspace is not None
+ ), "organization and workspace must be provided together"
+
+ variant.organization = organization
+ variant.workspace = workspace
+
await variant.create()
variant_revision = AppVariantRevisionsDB(
@@ -369,7 +373,8 @@ async def create_image(
image_type: str,
user: UserDB,
deletable: bool,
- organization: OrganizationDB,
+ organization=None,
+ workspace=None,
template_uri: str = None,
docker_id: str = None,
tags: str = None,
@@ -381,6 +386,7 @@ async def create_image(
user (UserDB): The user that the image belongs to.
deletable (bool): Whether the image can be deleted.
organization (OrganizationDB): The organization that the image belongs to.
+ workspace (WorkspaceDB): The workspace that the image belongs to.
Returns:
ImageDB: The created image.
"""
@@ -400,40 +406,47 @@ async def create_image(
elif image_type == "zip" and template_uri is None:
raise Exception("template_uri must be provided for type zip")
+ image = ImageDB(
+ deletable=deletable,
+ user=user,
+ )
+
if image_type == "zip":
- image = ImageDB(
- type="zip",
- template_uri=template_uri,
- deletable=deletable,
- user=user,
- organization=organization,
- )
+ image.type = "zip"
+ image.template_uri = template_uri
elif image_type == "image":
- image = ImageDB(
- type="image",
- docker_id=docker_id,
- tags=tags,
- deletable=deletable,
- user=user,
- organization=organization,
- )
+ image.type = "image"
+ image.tags = tags
+ image.docker_id = docker_id
+
+ if isCloudEE():
+ # assert that if organization is provided, workspace_id is also provided, and vice versa
+ assert (
+ organization is not None and workspace is not None
+ ), "organization and workspace must be provided together"
+
+ image.organization = organization
+ image.workspace = workspace
+
await image.create()
return image
async def create_deployment(
app: AppVariantDB,
- organization: OrganizationDB,
user: UserDB,
container_name: str,
container_id: str,
uri: str,
status: str,
+ organization=None,
+ workspace=None,
) -> DeploymentDB:
"""Create a new deployment.
Args:
app (AppVariantDB): The app variant to create the deployment for.
organization (OrganizationDB): The organization that the deployment belongs to.
+ workspace (WorkspaceDB): The Workspace that the deployment belongs to.
user (UserDB): The user that the deployment belongs to.
container_name (str): The name of the container.
container_id (str): The ID of the container.
@@ -442,29 +455,40 @@ async def create_deployment(
Returns:
DeploymentDB: The created deployment.
"""
- deployment = DeploymentDB(
- app=app,
- organization=organization,
- user=user,
- container_name=container_name,
- container_id=container_id,
- uri=uri,
- status=status,
- )
- await deployment.create()
- return deployment
+ try:
+ deployment = DeploymentDB(
+ app=app,
+ user=user,
+ container_name=container_name,
+ container_id=container_id,
+ uri=uri,
+ status=status,
+ )
+
+ if isCloudEE():
+ deployment.organization = organization
+ deployment.workspace = workspace
+
+ await deployment.create()
+ return deployment
+ except Exception as e:
+ raise Exception(f"Error while creating deployment: {e}")
async def create_app_and_envs(
- app_name: str, organization_id: str, **user_org_data
+ app_name: str,
+ user_uid: str,
+ organization_id: str = None,
+ workspace_id: str = None,
) -> AppDB:
"""
Create a new app with the given name and organization ID.
Args:
app_name (str): The name of the app to create.
+ user_uid (str): The UID of the user that the app belongs to.
organization_id (str): The ID of the organization that the app belongs to.
- **user_org_data: Additional keyword arguments.
+ workspace_id (str): The ID of the workspace that the app belongs to.
Returns:
AppDB: The created app.
@@ -473,36 +497,33 @@ async def create_app_and_envs(
ValueError: If an app with the same name already exists.
"""
- user_instance = await get_user(user_uid=user_org_data["uid"])
- app = await fetch_app_by_name(app_name, organization_id, **user_org_data)
+ user_instance = await get_user(user_uid)
+ app = await fetch_app_by_name_and_parameters(
+ app_name,
+ user_uid,
+ organization_id,
+ workspace_id,
+ )
if app is not None:
raise ValueError("App with the same name already exists")
- organization_db = await get_organization_object(organization_id)
- app = AppDB(
- app_name=app_name,
- organization=organization_db,
- user=user_instance,
- )
- await app.create()
- await initialize_environments(app, **user_org_data)
- return app
-
+ app = AppDB(app_name=app_name, user=user_instance)
-async def create_user_organization(user_uid: str) -> OrganizationDB:
- """Create a default organization for a user.
+ if isCloudEE():
+ # assert that if organization_id is provided, workspace_id is also provided, and vice versa
+ assert (
+ organization_id is not None and workspace_id is not None
+ ), "org_id and workspace_id must be provided together"
- Args:
- user_uid (str): The uid of the user
+ organization_db = await db_manager_ee.get_organization(organization_id)
+ workspace_db = await db_manager_ee.get_workspace(workspace_id)
- Returns:
- OrganizationDB: Instance of OrganizationDB
- """
+ app.organization = organization_db
+ app.workspace = workspace_db
- user = await UserDB.find_one(UserDB.uid == user_uid)
- org_db = OrganizationDB(owner=str(user.id), type="default")
- await org_db.create()
- return org_db
+ await app.create()
+ await initialize_environments(app)
+ return app
async def get_deployment_by_objectid(
@@ -541,41 +562,8 @@ async def get_deployment_by_appid(app_id: str) -> DeploymentDB:
return deployment
-async def get_organization_object(organization_id: str) -> OrganizationDB:
- """
- Fetches an organization by its ID.
-
- Args:
- organization_id (str): The ID of the organization to fetch.
-
- Returns:
- OrganizationDB: The fetched organization.
- """
- organization = await OrganizationDB.find_one(
- OrganizationDB.id == ObjectId(organization_id)
- )
- return organization
-
-
-async def get_organizations_by_list_ids(organization_ids: List) -> List:
- """
- Retrieve organizations from the database by their IDs.
-
- Args:
- organization_ids (List): A list of organization IDs to retrieve.
-
- Returns:
- List: A list of dictionaries representing the retrieved organizations.
- """
-
- organizations_db = await OrganizationDB.find(
- In(OrganizationDB.id, organization_ids)
- ).to_list()
- return organizations_db
-
-
async def list_app_variants_for_app_id(
- app_id: str, **kwargs: dict
+ app_id: str,
) -> List[AppVariantDB]:
"""
Lists all the app variants from the db
@@ -592,7 +580,7 @@ async def list_app_variants_for_app_id(
async def list_bases_for_app_id(
- app_id: str, base_name: Optional[str] = None, **kwargs: dict
+ app_id: str, base_name: Optional[str] = None
) -> List[VariantBaseDB]:
"""List all the bases for the specified app_id
@@ -612,9 +600,7 @@ async def list_bases_for_app_id(
return bases_db
-async def list_variants_for_base(
- base: VariantBaseDB, **kwargs: dict
-) -> List[AppVariantDB]:
+async def list_variants_for_base(base: VariantBaseDB) -> List[AppVariantDB]:
"""
Lists all the app variants from the db for a base
Args:
@@ -645,16 +631,11 @@ async def get_user(user_uid: str) -> UserDB:
user = await UserDB.find_one(UserDB.uid == user_uid)
if user is None:
- if os.environ["FEATURE_FLAG"] not in ["cloud", "ee"]:
+ if not isCloudEE():
+ # create user
user_db = UserDB(uid="0")
user = await user_db.create()
- org_db = OrganizationDB(type="default", owner=str(user.id))
- org = await org_db.create()
-
- user_db.organizations.append(org.id)
- await user_db.save()
-
return user
raise Exception("Please login or signup")
return user
@@ -725,7 +706,7 @@ async def get_users_by_ids(user_ids: List) -> List:
async def get_orga_image_instance_by_docker_id(
- organization_id: str, docker_id: str
+ docker_id: str, organization_id: str = None, workspace_id: str = None
) -> ImageDB:
"""Get the image object from the database with the provided id.
@@ -737,15 +718,27 @@ async def get_orga_image_instance_by_docker_id(
ImageDB: instance of image object
"""
- image = await ImageDB.find_one(
- ImageDB.docker_id == docker_id,
- ImageDB.organization.id == ObjectId(organization_id),
- )
+ query_expression = {"docker_id": docker_id}
+
+ if isCloudEE():
+ # assert that if organization is provided, workspace_id is also provided, and vice versa
+ assert (
+ organization_id is not None and workspace_id is not None
+ ), "organization and workspace must be provided together"
+
+ query_expression.update(
+ {
+ "organization.id": ObjectId(organization_id),
+ "workspace.id": ObjectId(workspace_id),
+ }
+ )
+
+ image = await ImageDB.find_one(query_expression)
return image
async def get_orga_image_instance_by_uri(
- organization_id: str, template_uri: str
+ template_uri: str, organization_id: str = None, workspace_id: str = None
) -> ImageDB:
"""Get the image object from the database with the provided id.
@@ -761,10 +754,22 @@ async def get_orga_image_instance_by_uri(
if not parsed_url.scheme and not parsed_url.netloc:
raise ValueError(f"Invalid URL: {template_uri}")
- image = await ImageDB.find_one(
- ImageDB.template_uri == template_uri,
- ImageDB.organization.id == ObjectId(organization_id),
- )
+ query_expression = {"template_uri": template_uri}
+
+ if isCloudEE():
+ # assert that if organization is provided, workspace_id is also provided, and vice versa
+ assert (
+ organization_id is not None and workspace_id is not None
+ ), "organization and workspace must be provided together"
+
+ query_expression.update(
+ {
+ "organization.id": ObjectId(organization_id),
+ "workspace.id": ObjectId(workspace_id),
+ }
+ )
+
+ image = await ImageDB.find_one(query_expression)
return image
@@ -786,7 +791,7 @@ async def add_variant_from_base_and_config(
base_db: VariantBaseDB,
new_config_name: str,
parameters: Dict[str, Any],
- **user_org_data: dict,
+ user_uid: str,
):
"""
Add a new variant to the database based on an existing base and a new configuration.
@@ -795,7 +800,7 @@ async def add_variant_from_base_and_config(
base_db (VariantBaseDB): The existing base to use as a template for the new variant.
new_config_name (str): The name of the new configuration to use for the new variant.
parameters (Dict[str, Any]): The parameters to use for the new configuration.
- **user_org_data (dict): Additional user and organization data.
+ user_uid (str): The UID of the user
Returns:
AppVariantDB: The newly created app variant.
@@ -813,7 +818,7 @@ async def add_variant_from_base_and_config(
)
if already_exists:
raise ValueError("App variant with the same name already exists")
- user_db = await get_user(user_uid=user_org_data["uid"])
+ user_db = await get_user(user_uid)
config_db = ConfigDB(
config_name=new_config_name,
parameters=parameters,
@@ -825,7 +830,6 @@ async def add_variant_from_base_and_config(
user=user_db,
modified_by=user_db,
revision=1,
- organization=previous_app_variant_db.organization,
parameters=parameters,
previous_variant_name=previous_app_variant_db.variant_name, # TODO: Remove in future
base_name=base_db.base_name,
@@ -834,6 +838,11 @@ async def add_variant_from_base_and_config(
config=config_db,
is_deleted=False,
)
+
+ if isCloudEE():
+ db_app_variant.organization = previous_app_variant_db.organization
+ db_app_variant.workspace = previous_app_variant_db.workspace
+
await db_app_variant.create()
variant_revision = AppVariantRevisionsDB(
variant=db_app_variant,
@@ -848,7 +857,10 @@ async def add_variant_from_base_and_config(
async def list_apps(
- app_name: str = None, org_id: str = None, **user_org_data: dict
+ user_uid: str,
+ app_name: str = None,
+ org_id: str = None,
+ workspace_id: str = None,
) -> List[App]:
"""
Lists all the unique app names and their IDs from the database
@@ -860,32 +872,59 @@ async def list_apps(
List[App]
"""
- user = await get_user(user_uid=user_org_data["uid"])
+ user = await get_user(user_uid)
assert user is not None, "User is None"
if app_name is not None:
- app_db = await fetch_app_by_name(app_name, org_id, **user_org_data)
- return [app_db_to_pydantic(app_db)]
- elif org_id is not None:
- organization_access = await check_user_org_access(user_org_data, org_id)
- if organization_access:
- apps: List[AppDB] = await AppDB.find(
- AppDB.organization.id == ObjectId(org_id)
- ).to_list()
- return [app_db_to_pydantic(app) for app in apps]
+ app_db = await fetch_app_by_name_and_parameters(
+ app_name=app_name,
+ user_uid=user_uid,
+ organization_id=org_id,
+ workspace_id=workspace_id,
+ )
+ return [converters.app_db_to_pydantic(app_db)]
- else:
+ elif (org_id is not None) or (workspace_id is not None):
+ if not isCloudEE():
+ return JSONResponse(
+ {
+ "error": "organization and/or workspace is only available in Cloud and EE"
+ },
+ status_code=400,
+ )
+
+ # assert that if org_id is provided, workspace_id is also provided, and vice versa
+ assert (
+ org_id is not None and workspace_id is not None
+ ), "org_id and workspace_id must be provided together"
+
+ user_org_workspace_data = await get_user_org_and_workspace_id(user_uid)
+ has_permission = await check_rbac_permission(
+ user_org_workspace_data=user_org_workspace_data,
+ workspace_id=ObjectId(workspace_id),
+ organization_id=ObjectId(org_id),
+ permission=Permission.VIEW_APPLICATION,
+ )
+ logger.debug(f"User has Permission to list apps: {has_permission}")
+ if not has_permission:
+ error_msg = f"You do not have access to perform this action. Please contact your organization admin."
return JSONResponse(
- {"error": "You do not have permission to access this organization"},
+ {"detail": error_msg},
status_code=403,
)
+ apps: List[AppDB] = await AppDB.find(
+ AppDB.organization.id == ObjectId(org_id),
+ AppDB.workspace.id == ObjectId(workspace_id),
+ ).to_list()
+ return [converters.app_db_to_pydantic(app) for app in apps]
+
else:
apps = await AppDB.find(AppDB.user.id == user.id).to_list()
- return [app_db_to_pydantic(app) for app in apps]
+ return [converters.app_db_to_pydantic(app) for app in apps]
-async def list_app_variants(app_id: str = None, **kwargs: dict) -> List[AppVariantDB]:
+async def list_app_variants(app_id: str) -> List[AppVariantDB]:
"""
Lists all the app variants from the db
Args:
@@ -913,14 +952,21 @@ async def check_is_last_variant_for_image(db_app_variant: AppVariantDB) -> bool:
true if it's the last variant, false otherwise
"""
- count_variants = await AppVariantDB.find(
- AppVariantDB.organization.id == db_app_variant.organization.id,
- AppVariantDB.base.id == db_app_variant.base.id,
- ).count()
+ query_expression = {"base.id": db_app_variant.base.id}
+
+ if isCloudEE():
+ query_expression.update(
+ {
+ "organization.id": db_app_variant.organization.id,
+ "workspace.id": db_app_variant.workspace.id,
+ }
+ )
+
+ count_variants = await AppVariantDB.find(query_expression).count()
return count_variants == 1
-async def remove_deployment(deployment_db: DeploymentDB, **kwargs: dict):
+async def remove_deployment(deployment_db: DeploymentDB):
"""Remove a deployment from the db
Arguments:
@@ -932,7 +978,7 @@ async def remove_deployment(deployment_db: DeploymentDB, **kwargs: dict):
await deployment_db.delete()
-async def remove_app_variant_from_db(app_variant_db: AppVariantDB, **kwargs: dict):
+async def remove_app_variant_from_db(app_variant_db: AppVariantDB):
"""Remove an app variant from the db
the logic for removing the image is in app_manager.py
@@ -944,17 +990,12 @@ async def remove_app_variant_from_db(app_variant_db: AppVariantDB, **kwargs: dic
# Remove the variant from the associated environments
logger.debug("list_environments_by_variant")
- environments = await list_environments_by_variant(
- app_variant_db,
- **kwargs,
- )
+ environments = await list_environments_by_variant(app_variant_db)
for environment in environments:
environment.deployed_app_variant = None
await environment.save()
- app_variant_revisions = await list_app_variant_revisions_by_variant(
- app_variant_db, **kwargs
- )
+ app_variant_revisions = await list_app_variant_revisions_by_variant(app_variant_db)
for app_variant_revision in app_variant_revisions:
await app_variant_revision.delete()
@@ -970,7 +1011,6 @@ async def deploy_to_environment(
Args:
environment_name (str): The name of the environment to deploy the app variant to.
variant_id (str): The ID of the app variant to deploy.
- **kwargs (dict): Additional keyword arguments.
Raises:
ValueError: If the app variant is not found or if the environment is not found or if the app variant is already
@@ -1010,7 +1050,7 @@ async def deploy_to_environment(
environment_db.deployment = deployment.id
# Create revision for app environment
- user = await get_user(user_uid=user_org_data["uid"])
+ user = await get_user(user_uid=user_org_data["user_uid"])
await create_environment_revision(
environment_db,
user,
@@ -1130,7 +1170,6 @@ async def list_environments(app_id: str, **kwargs: dict) -> List[AppEnvironmentD
Args:
app_id (str): The ID of the app to list environments for.
- **kwargs (dict): Additional keyword arguments.
Returns:
List[AppEnvironmentDB]: A list of AppEnvironmentDB objects representing the environments for the given app ID.
@@ -1147,47 +1186,42 @@ async def list_environments(app_id: str, **kwargs: dict) -> List[AppEnvironmentD
return environments_db
-async def initialize_environments(
- app_db: AppDB, **kwargs: dict
-) -> List[AppEnvironmentDB]:
+async def initialize_environments(app_db: AppDB) -> List[AppEnvironmentDB]:
"""
Initializes the environments for the app with the given database.
Args:
app_db (AppDB): The database for the app.
- **kwargs (dict): Additional keyword arguments.
Returns:
List[AppEnvironmentDB]: A list of the initialized environments.
"""
environments = []
for env_name in ["development", "staging", "production"]:
- env = await create_environment(name=env_name, app_db=app_db, **kwargs)
+ env = await create_environment(name=env_name, app_db=app_db)
environments.append(env)
return environments
-async def create_environment(
- name: str, app_db: AppDB, **kwargs: dict
-) -> AppEnvironmentDB:
+async def create_environment(name: str, app_db: AppDB) -> AppEnvironmentDB:
"""
Creates a new environment in the database.
Args:
name (str): The name of the environment.
app_db (AppDB): The AppDB object representing the app that the environment belongs to.
- **kwargs (dict): Additional keyword arguments.
Returns:
AppEnvironmentDB: The newly created AppEnvironmentDB object.
"""
environment_db = AppEnvironmentDB(
- app=app_db,
- name=name,
- user=app_db.user,
- revision=0,
- organization=app_db.organization,
+ app=app_db, name=name, user=app_db.user, revision=0
)
+
+ if isCloudEE():
+ environment_db.organization = app_db.organization
+ environment_db.workspace = app_db.workspace
+
await environment_db.create()
return environment_db
@@ -1221,17 +1255,21 @@ async def create_environment_revision(
deployment = kwargs.get("deployment")
if deployment is not None:
environment_revision.deployment = deployment
+
+ if isCloudEE():
+ environment_revision.organization = environment.organization
+ environment_revision.workspace = environment.workspace
+
await environment_revision.create()
async def list_app_variant_revisions_by_variant(
- app_variant: AppVariantDB, **kwargs: dict
+ app_variant: AppVariantDB,
) -> List[AppVariantRevisionsDB]:
"""Returns list of app variant revision for the given app variant
Args:
app_variant (AppVariantDB): The app variant to retrieve environments for.
- **kwargs (dict): Additional keyword arguments.
Returns:
List[AppVariantRevisionsDB]: A list of AppVariantRevisionsDB objects.
@@ -1243,14 +1281,13 @@ async def list_app_variant_revisions_by_variant(
async def list_environments_by_variant(
- app_variant: AppVariantDB, **kwargs: dict
+ app_variant: AppVariantDB,
) -> List[AppEnvironmentDB]:
"""
Returns a list of environments for a given app variant.
Args:
app_variant (AppVariantDB): The app variant to retrieve environments for.
- **kwargs (dict): Additional keyword arguments.
Returns:
List[AppEnvironmentDB]: A list of AppEnvironmentDB objects.
@@ -1262,13 +1299,12 @@ async def list_environments_by_variant(
return environments_db
-async def remove_image(image: ImageDB, **kwargs: dict):
+async def remove_image(image: ImageDB):
"""
Removes an image from the database.
Args:
image (ImageDB): The image to remove from the database.
- **kwargs (dict): Additional keyword arguments.
Raises:
ValueError: If the image is None.
@@ -1281,13 +1317,12 @@ async def remove_image(image: ImageDB, **kwargs: dict):
await image.delete()
-async def remove_environment(environment_db: AppEnvironmentDB, **kwargs: dict):
+async def remove_environment(environment_db: AppEnvironmentDB):
"""
Removes an environment from the database.
Args:
environment_db (AppEnvironmentDB): The environment to remove from the database.
- **kwargs (dict): Additional keyword arguments.
Raises:
AssertionError: If environment_db is None.
@@ -1299,7 +1334,7 @@ async def remove_environment(environment_db: AppEnvironmentDB, **kwargs: dict):
await environment_db.delete()
-async def remove_app_testsets(app_id: str, **kwargs):
+async def remove_app_testsets(app_id: str):
"""Returns a list of testsets owned by an app.
Args:
@@ -1330,13 +1365,12 @@ async def remove_app_testsets(app_id: str, **kwargs):
return 0
-async def remove_base_from_db(base: VariantBaseDB, **kwargs):
+async def remove_base_from_db(base: VariantBaseDB):
"""
Remove a base from the database.
Args:
base (VariantBaseDB): The base to be removed from the database.
- **kwargs: Additional keyword arguments.
Raises:
ValueError: If the base is None.
@@ -1349,7 +1383,7 @@ async def remove_base_from_db(base: VariantBaseDB, **kwargs):
await base.delete()
-async def remove_app_by_id(app_id: str, **kwargs):
+async def remove_app_by_id(app_id: str):
"""
Removes an app instance from the database by its ID.
@@ -1369,7 +1403,7 @@ async def remove_app_by_id(app_id: str, **kwargs):
async def update_variant_parameters(
- app_variant_db: AppVariantDB, parameters: Dict[str, Any], **user_org_data: dict
+ app_variant_db: AppVariantDB, parameters: Dict[str, Any], user_uid: str
) -> None:
"""
Update the parameters of an app variant in the database.
@@ -1377,7 +1411,7 @@ async def update_variant_parameters(
Args:
app_variant_db (AppVariantDB): The app variant to update.
parameters (Dict[str, Any]): The new parameters to set for the app variant.
- **kwargs (dict): Additional keyword arguments.
+ user_uid (str): The UID of the user that is updating the app variant.
Raises:
ValueError: If there is an issue updating the variant parameters.
@@ -1387,7 +1421,7 @@ async def update_variant_parameters(
try:
logging.debug("Updating variant parameters")
- user = await get_user(user_uid=user_org_data["uid"])
+ user = await get_user(user_uid)
# Update associated ConfigDB parameters and versioning
config_db = app_variant_db.config
config_db.parameters = parameters
@@ -1513,7 +1547,7 @@ async def fetch_evaluation_scenario_by_id(
"""
assert evaluation_scenario_id is not None, "evaluation_scenario_id cannot be None"
evaluation_scenario = await EvaluationScenarioDB.find_one(
- EvaluationScenarioDB.id == ObjectId(evaluation_scenario_id)
+ EvaluationScenarioDB.id == ObjectId(evaluation_scenario_id, fetch_links=True)
)
return evaluation_scenario
@@ -1535,6 +1569,23 @@ async def fetch_human_evaluation_scenario_by_id(
return evaluation_scenario
+async def fetch_human_evaluation_scenario_by_evaluation_id(
+ evaluation_id: str,
+) -> Optional[HumanEvaluationScenarioDB]:
+ """Fetches and evaluation scenario by its ID.
+ Args:
+ evaluation_id (str): The ID of the evaluation object to use in fetching the human evaluation.
+ Returns:
+ EvaluationScenarioDB: The fetched evaluation scenario, or None if no evaluation scenario was found.
+ """
+ evaluation = await fetch_human_evaluation_by_id(evaluation_id)
+ human_eval_scenario = await HumanEvaluationScenarioDB.find_one(
+ HumanEvaluationScenarioDB.evaluation.id == ObjectId(evaluation.id),
+ fetch_links=True,
+ )
+ return human_eval_scenario
+
+
async def find_previous_variant_from_base_id(
base_id: str,
) -> Optional[AppVariantDB]:
@@ -1657,22 +1708,7 @@ async def remove_old_template_from_db(tag_ids: list) -> None:
async def get_templates() -> List[Template]:
templates = await TemplateDB.find().to_list()
- return templates_db_to_pydantic(templates)
-
-
-async def count_apps(**user_org_data: dict) -> int:
- """
- Counts all the unique app names from the database
- """
-
- # Get user object
- user = await get_user(user_uid=user_org_data["uid"])
- if user is None:
- return 0
-
- query_expressions = AppVariantDB.user.id == user.id
- no_of_apps = await AppVariantDB.find(query_expressions).count()
- return no_of_apps
+ return converters.templates_db_to_pydantic(templates)
async def update_base(
@@ -1710,138 +1746,59 @@ async def update_app_variant(
return app_variant
-async def fetch_base_and_check_access(
- base_id: str, user_org_data: dict, check_owner=False
-):
- """
- Fetches a base from the database and checks if the user has access to it.
-
- Args:
- base_id (str): The ID of the base to fetch.
- user_org_data (dict): The user's organization data.
- check_owner (bool, optional): Whether to check if the user is the owner of the base. Defaults to False.
-
- Raises:
- Exception: If no base_id is provided.
- HTTPException: If the base is not found or the user does not have access to it.
-
- Returns:
- VariantBaseDB: The fetched base.
- """
- if base_id is None:
- raise Exception("No base_id provided")
- base = await VariantBaseDB.find_one(
- VariantBaseDB.id == ObjectId(base_id), fetch_links=True
- )
- if base is None:
- logger.error("Base not found")
- raise HTTPException(status_code=404, detail="Base not found")
- organization_id = base.organization.id
- access = await check_user_org_access(
- user_org_data, str(organization_id), check_owner
- )
- if not access:
- error_msg = f"You do not have access to this base: {base_id}"
- raise HTTPException(status_code=403, detail=error_msg)
- return base
-
-
-async def fetch_app_and_check_access(
- app_id: str, user_org_data: dict, check_owner=False
-):
- """
- Fetches an app from the database and checks if the user has access to it.
-
- Args:
- app_id (str): The ID of the app to fetch.
- user_org_data (dict): The user's organization data.
- check_owner (bool, optional): Whether to check if the user is the owner of the app. Defaults to False.
-
- Returns:
- dict: The fetched app.
-
- Raises:
- HTTPException: If the app is not found or the user does not have access to it.
- """
- app = await AppDB.find_one(AppDB.id == ObjectId(app_id), fetch_links=True)
- if app is None:
- logger.error("App not found")
- raise HTTPException
-
- # Check user's access to the organization linked to the app.
- organization_id = app.organization.id
- access = await check_user_org_access(
- user_org_data, str(organization_id), check_owner
- )
- if not access:
- error_msg = f"You do not have access to this app: {app_id}"
- raise HTTPException(status_code=403, detail=error_msg)
- return app
-
-
-async def fetch_app_variant_and_check_access(
- app_variant_id: str, user_org_data: dict, check_owner=False
+async def fetch_app_by_name_and_parameters(
+ app_name: str,
+ user_uid: str,
+ organization_id: str = None,
+ workspace_id: str = None,
):
- """
- Fetches an app variant from the database and checks if the user has access to it.
+ """Fetch an app by its name, organization id, and workspace id.
Args:
- app_variant_id (str): The ID of the app variant to fetch.
- user_org_data (dict): The user's organization data.
- check_owner (bool, optional): Whether to check if the user is the owner of the app variant. Defaults to False.
+ app_name (str): The name of the app
+ organization_id (str): The ID of the app organization
+ workspace_id (str): The ID of the app workspace
Returns:
- AppVariantDB: The fetched app variant.
-
- Raises:
- HTTPException: If the app variant is not found or the user does not have access to it.
+ AppDB: the instance of the app
"""
- app_variant = await AppVariantDB.find_one(
- AppVariantDB.id == ObjectId(app_variant_id), fetch_links=True
- )
- if app_variant is None:
- logger.error("App variant not found")
- raise HTTPException
-
- # Check user's access to the organization linked to the app.
- organization_id = app_variant.organization.id
- access = await check_user_org_access(
- user_org_data, str(organization_id), check_owner
- )
- if not access:
- error_msg = f"You do not have access to this app variant: {app_variant_id}"
- raise HTTPException(status_code=403, detail=error_msg)
- return app_variant
+ query_expression = {"app_name": app_name}
-async def fetch_app_by_name_and_organization(
- app_name: str, organization_id: str, **user_org_data: dict
-):
- """Fetch an app by it's name and organization id.
+ if isCloudEE():
+ # assert that if organization is provided, workspace_id is also provided, and vice versa
+ assert (
+ organization_id is not None and workspace_id is not None
+ ), "organization_id and workspace_id must be provided together"
- Args:
- app_name (str): The name of the app
- organization_id (str): The ID of the app organization
+ query_expression.update(
+ {
+ "organization.id": ObjectId(organization_id),
+ "workspace.id": ObjectId(workspace_id),
+ }
+ )
+ else:
+ query_expression.update(
+ {
+ "user.uid": user_uid,
+ }
+ )
- Returns:
- AppDB: the instance of the app
- """
+ app_db = await AppDB.find_one(query_expression, fetch_links=True)
- app_db = await AppDB.find_one(
- {"app_name": app_name, "organization": ObjectId(organization_id)}
- )
return app_db
async def create_new_evaluation(
app: AppDB,
- organization: OrganizationDB,
user: UserDB,
testset: TestSetDB,
status: str,
variant: str,
variant_revision: str,
evaluators_configs: List[str],
+ organization=None,
+ workspace=None,
) -> EvaluationDB:
"""Create a new evaluation scenario.
Returns:
@@ -1849,7 +1806,6 @@ async def create_new_evaluation(
"""
evaluation = EvaluationDB(
app=app,
- organization=organization,
user=user,
testset=testset,
status=status,
@@ -1860,13 +1816,22 @@ async def create_new_evaluation(
created_at=datetime.now().isoformat(),
updated_at=datetime.now().isoformat(),
)
+
+ if isCloudEE():
+ # assert that if organization is provided, workspace is also provided, and vice versa
+ assert (
+ organization is not None and workspace is not None
+ ), "organization and workspace must be provided together"
+
+ evaluation.organization = organization
+ evaluation.workspace = workspace
+
await evaluation.create()
return evaluation
async def create_new_evaluation_scenario(
user: UserDB,
- organization: OrganizationDB,
evaluation: EvaluationDB,
variant_id: str,
inputs: List[EvaluationScenarioInputDB],
@@ -1876,6 +1841,8 @@ async def create_new_evaluation_scenario(
note: Optional[str],
evaluators_configs: List[EvaluatorConfigDB],
results: List[EvaluationScenarioResult],
+ organization=None,
+ workspace=None,
) -> EvaluationScenarioDB:
"""Create a new evaluation scenario.
Returns:
@@ -1883,7 +1850,6 @@ async def create_new_evaluation_scenario(
"""
evaluation_scenario = EvaluationScenarioDB(
user=user,
- organization=organization,
evaluation=evaluation,
variant_id=ObjectId(variant_id),
inputs=inputs,
@@ -1894,6 +1860,16 @@ async def create_new_evaluation_scenario(
evaluators_configs=evaluators_configs,
results=results,
)
+
+ if isCloudEE():
+ # assert that if organization is provided, workspace is also provided, and vice versa
+ assert (
+ organization is not None and workspace is not None
+ ), "organization and workspace must be provided together"
+
+ evaluation_scenario.organization = organization
+ evaluation_scenario.workspace = workspace
+
await evaluation_scenario.create()
return evaluation_scenario
@@ -1939,7 +1915,7 @@ async def fetch_evaluator_config(evaluator_config_id: str):
try:
evaluator_config: EvaluatorConfigDB = await EvaluatorConfigDB.find_one(
- EvaluatorConfigDB.id == ObjectId(evaluator_config_id)
+ EvaluatorConfigDB.id == ObjectId(evaluator_config_id), fetch_links=True
)
return evaluator_config
except Exception as e:
@@ -1998,9 +1974,10 @@ async def fetch_evaluator_config_by_appId(
async def create_evaluator_config(
app: AppDB,
user: UserDB,
- organization: OrganizationDB,
name: str,
evaluator_key: str,
+ organization=None,
+ workspace=None,
settings_values: Optional[Dict[str, Any]] = None,
) -> EvaluatorConfigDB:
"""Create a new evaluator configuration in the database."""
@@ -2008,12 +1985,20 @@ async def create_evaluator_config(
new_evaluator_config = EvaluatorConfigDB(
app=app,
user=user,
- organization=organization,
name=name,
evaluator_key=evaluator_key,
settings_values=settings_values,
)
+ if isCloudEE():
+ # assert that if organization is provided, workspace is also provided, and vice versa
+ assert (
+ organization is not None and workspace is not None
+ ), "organization and workspace must be provided together"
+
+ new_evaluator_config.organization = organization
+ new_evaluator_config.workspace = workspace
+
try:
await new_evaluator_config.create()
return new_evaluator_config
diff --git a/agenta-backend/agenta_backend/services/deployment_manager.py b/agenta-backend/agenta_backend/services/deployment_manager.py
index 2059c61539..6511d319b7 100644
--- a/agenta-backend/agenta_backend/services/deployment_manager.py
+++ b/agenta-backend/agenta_backend/services/deployment_manager.py
@@ -3,6 +3,7 @@
from typing import Dict
from agenta_backend.config import settings
+from agenta_backend.utils.common import isCloudEE
from agenta_backend.models.api.api_models import Image
from agenta_backend.models.db_models import AppVariantDB, DeploymentDB, ImageDB
from agenta_backend.services import db_manager, docker_utils
@@ -19,18 +20,19 @@ async def start_service(
Start a service.
Args:
- image_name: List of image tags.
- app_name: Name of the app.
- base_name: Base name for the container.
- env_vars: Environment variables.
- organization_id: ID of the organization.
+ app_variant_db (AppVariantDB): The app variant to start.
+ env_vars (Dict[str, str]): The environment variables to pass to the container.
Returns:
True if successful, False otherwise.
"""
- uri_path = f"{app_variant_db.organization.id}/{app_variant_db.app.app_name}/{app_variant_db.base_name}"
- container_name = f"{app_variant_db.app.app_name}-{app_variant_db.base_name}-{app_variant_db.organization.id}"
+ if isCloudEE():
+ uri_path = f"{app_variant_db.organization.id}/{app_variant_db.app.app_name}/{app_variant_db.base_name}"
+ container_name = f"{app_variant_db.app.app_name}-{app_variant_db.base_name}-{app_variant_db.organization.id}"
+ else:
+ uri_path = f"{app_variant_db.user.id}/{app_variant_db.app.app_name}/{app_variant_db.base_name}"
+ container_name = f"{app_variant_db.app.app_name}-{app_variant_db.base_name}-{app_variant_db.user.id}"
logger.debug("Starting service with the following parameters:")
logger.debug(f"image_name: {app_variant_db.image.tags}")
logger.debug(f"uri_path: {uri_path}")
@@ -54,12 +56,13 @@ async def start_service(
deployment = await db_manager.create_deployment(
app=app_variant_db.app,
- organization=app_variant_db.organization,
user=app_variant_db.user,
container_name=container_name,
container_id=container_id,
uri=uri,
status="running",
+ organization=app_variant_db.organization if isCloudEE() else None,
+ workspace=app_variant_db.workspace if isCloudEE() else None,
)
return deployment
@@ -75,7 +78,7 @@ async def remove_image(image: Image):
None
"""
try:
- if os.environ["FEATURE_FLAG"] not in ["cloud", "ee"] and image.deletable:
+ if not isCloudEE() and image.deletable:
docker_utils.delete_image(image.docker_id)
logger.info(f"Image {image.docker_id} deleted")
except RuntimeError as e:
diff --git a/agenta-backend/agenta_backend/services/docker_utils.py b/agenta-backend/agenta_backend/services/docker_utils.py
index df5623f14b..9fced69465 100644
--- a/agenta-backend/agenta_backend/services/docker_utils.py
+++ b/agenta-backend/agenta_backend/services/docker_utils.py
@@ -134,6 +134,9 @@ def start_container(
logs = failed_container.logs().decode("utf-8")
raise Exception(f"Docker Logs: {logs}") from error
except Exception as e:
+ import traceback
+
+ traceback.print_exc()
logger.error(
f"Failed to fetch logs: {str(e)} \n Exception Error: {str(error)}"
)
diff --git a/agenta-backend/agenta_backend/services/evaluation_service.py b/agenta-backend/agenta_backend/services/evaluation_service.py
index 991b30aa24..dcdb9fa64a 100644
--- a/agenta-backend/agenta_backend/services/evaluation_service.py
+++ b/agenta-backend/agenta_backend/services/evaluation_service.py
@@ -1,9 +1,13 @@
import logging
from datetime import datetime
-from typing import Dict, List, Any
+from typing import Dict, List
from fastapi import HTTPException
+from agenta_backend.models import converters
+from agenta_backend.services import db_manager
+from agenta_backend.utils.common import isCloudEE
+
from agenta_backend.models.api.evaluation_model import (
Evaluation,
EvaluationScenario,
@@ -17,25 +21,34 @@
EvaluationStatusEnum,
NewHumanEvaluation,
)
-from agenta_backend.models import converters
-from agenta_backend.services import db_manager
-from agenta_backend.services.db_manager import get_user
-from agenta_backend.utils.common import check_access_to_app
+
+if isCloudEE():
+ from agenta_backend.commons.models.db_models import (
+ AppDB_ as AppDB,
+ UserDB_ as UserDB,
+ EvaluationDB_ as EvaluationDB,
+ HumanEvaluationDB_ as HumanEvaluationDB,
+ EvaluationScenarioDB_ as EvaluationScenarioDB,
+ HumanEvaluationScenarioDB_ as HumanEvaluationScenarioDB,
+ )
+else:
+ from agenta_backend.models.db_models import (
+ AppDB,
+ UserDB,
+ EvaluationDB,
+ HumanEvaluationDB,
+ EvaluationScenarioDB,
+ HumanEvaluationScenarioDB,
+ )
+
from agenta_backend.models.db_models import (
- AppVariantDB,
- EvaluationDB,
- EvaluationScenarioDB,
- HumanEvaluationDB,
- HumanEvaluationScenarioDB,
HumanEvaluationScenarioInput,
HumanEvaluationScenarioOutput,
Result,
- UserDB,
- AppDB,
)
-from beanie import PydanticObjectId as ObjectId
from beanie.operators import In
+from beanie import PydanticObjectId as ObjectId
logger = logging.getLogger(__name__)
@@ -48,34 +61,7 @@ class UpdateEvaluationScenarioError(Exception):
pass
-async def _fetch_evaluation_and_check_access(
- evaluation_id: str, **user_org_data: dict
-) -> EvaluationDB:
- # Fetch the evaluation by ID
- evaluation = await db_manager.fetch_evaluation_by_id(evaluation_id=evaluation_id)
-
- # Check if the evaluation exists
- if evaluation is None:
- raise HTTPException(
- status_code=404,
- detail=f"Evaluation with id {evaluation_id} not found",
- )
-
- # Check for access rights
- access = await check_access_to_app(
- user_org_data=user_org_data, app_id=evaluation.app.id
- )
- if not access:
- raise HTTPException(
- status_code=403,
- detail=f"You do not have access to this app: {str(evaluation.app.id)}",
- )
- return evaluation
-
-
-async def _fetch_human_evaluation_and_check_access(
- evaluation_id: str, **user_org_data: dict
-) -> HumanEvaluationDB:
+async def _fetch_human_evaluation(evaluation_id: str) -> HumanEvaluationDB:
# Fetch the evaluation by ID
evaluation = await db_manager.fetch_human_evaluation_by_id(
evaluation_id=evaluation_id
@@ -88,51 +74,9 @@ async def _fetch_human_evaluation_and_check_access(
detail=f"Evaluation with id {evaluation_id} not found",
)
- # Check for access rights
- access = await check_access_to_app(
- user_org_data=user_org_data, app_id=evaluation.app.id
- )
- if not access:
- raise HTTPException(
- status_code=403,
- detail=f"You do not have access to this app: {str(evaluation.app.id)}",
- )
return evaluation
-async def _fetch_human_evaluation_scenario_and_check_access(
- evaluation_scenario_id: str, **user_org_data: dict
-) -> HumanEvaluationDB:
- # Fetch the evaluation by ID
- evaluation_scenario = await db_manager.fetch_human_evaluation_scenario_by_id(
- evaluation_scenario_id=evaluation_scenario_id
- )
- if evaluation_scenario is None:
- raise HTTPException(
- status_code=404,
- detail=f"Evaluation scenario with id {evaluation_scenario_id} not found",
- )
- evaluation = evaluation_scenario.evaluation
-
- # Check if the evaluation exists
- if evaluation is None:
- raise HTTPException(
- status_code=404,
- detail=f"Evaluation scenario for evaluation scenario with id {evaluation_scenario_id} not found",
- )
-
- # Check for access rights
- access = await check_access_to_app(
- user_org_data=user_org_data, app_id=evaluation.app.id
- )
- if not access:
- raise HTTPException(
- status_code=403,
- detail=f"You do not have access to this app: {str(evaluation.app.id)}",
- )
- return evaluation_scenario
-
-
async def prepare_csvdata_and_create_evaluation_scenario(
csvdata: List[Dict[str, str]],
payload_inputs: List[str],
@@ -193,16 +137,20 @@ async def prepare_csvdata_and_create_evaluation_scenario(
eval_scenario_instance = HumanEvaluationScenarioDB(
**evaluation_scenario_payload,
user=user,
- organization=app.organization,
evaluation=new_evaluation,
inputs=list_of_scenario_input,
outputs=[],
)
+
+ if isCloudEE():
+ eval_scenario_instance.organization = app.organization
+ eval_scenario_instance.workspace = app.workspace
+
await eval_scenario_instance.create()
async def create_evaluation_scenario(
- evaluation_id: str, payload: EvaluationScenario, **user_org_data: dict
+ evaluation_id: str, payload: EvaluationScenario
) -> None:
"""
Create a new evaluation scenario.
@@ -210,14 +158,11 @@ async def create_evaluation_scenario(
Args:
evaluation_id (str): The ID of the evaluation.
payload (EvaluationScenario): Evaluation scenario data.
- user_org_data (dict): User and organization data.
Raises:
HTTPException: If evaluation not found or access denied.
"""
- evaluation = await _fetch_evaluation_and_check_access(
- evaluation_id=evaluation_id, **user_org_data
- )
+ evaluation = await db_manager.fetch_evaluation_by_id(evaluation_id)
scenario_inputs = [
EvaluationScenarioInput(
@@ -230,6 +175,7 @@ async def create_evaluation_scenario(
new_eval_scenario = EvaluationScenarioDB(
user=evaluation.user,
organization=evaluation.organization,
+ workspace=evaluation.workspace,
evaluation=evaluation,
inputs=scenario_inputs,
outputs=[],
@@ -244,23 +190,18 @@ async def create_evaluation_scenario(
async def update_human_evaluation_service(
- evaluation_id: str, update_payload: HumanEvaluationUpdate, **user_org_data: dict
+ evaluation: EvaluationDB, update_payload: HumanEvaluationUpdate
) -> None:
"""
Update an existing evaluation based on the provided payload.
Args:
- evaluation_id (str): The existing evaluation ID.
+ evaluation (EvaluationDB): The evaluation instance.
update_payload (EvaluationUpdate): The payload for the update.
Raises:
HTTPException: If the evaluation is not found or access is denied.
"""
- # Fetch the evaluation by ID
- evaluation = await _fetch_human_evaluation_and_check_access(
- evaluation_id=evaluation_id,
- **user_org_data,
- )
# Prepare updates
updates = {}
@@ -272,14 +213,14 @@ async def update_human_evaluation_service(
async def fetch_evaluation_scenarios_for_evaluation(
- evaluation_id: str, **user_org_data: dict
+ evaluation_id: str = None, evaluation: EvaluationDB = None
) -> List[EvaluationScenario]:
"""
Fetch evaluation scenarios for a given evaluation ID.
Args:
evaluation_id (str): The ID of the evaluation.
- user_org_data (dict): User and organization data.
+ evaluation (EvaluationDB): The evaluation instance.
Raises:
HTTPException: If the evaluation is not found or access is denied.
@@ -287,10 +228,13 @@ async def fetch_evaluation_scenarios_for_evaluation(
Returns:
List[EvaluationScenario]: A list of evaluation scenarios.
"""
- evaluation = await _fetch_evaluation_and_check_access(
- evaluation_id=evaluation_id,
- **user_org_data,
- )
+ assert (
+ evaluation_id or evaluation
+ ), "Please provide either evaluation_id or evaluation"
+
+ if not evaluation:
+ evaluation = await db_manager.fetch_evaluation_by_id(evaluation_id)
+
scenarios = await EvaluationScenarioDB.find(
EvaluationScenarioDB.evaluation.id == ObjectId(evaluation.id)
).to_list()
@@ -302,14 +246,13 @@ async def fetch_evaluation_scenarios_for_evaluation(
async def fetch_human_evaluation_scenarios_for_evaluation(
- evaluation_id: str, **user_org_data: dict
+ human_evaluation: HumanEvaluationDB,
) -> List[HumanEvaluationScenario]:
"""
Fetch evaluation scenarios for a given evaluation ID.
Args:
evaluation_id (str): The ID of the evaluation.
- user_org_data (dict): User and organization data.
Raises:
HTTPException: If the evaluation is not found or access is denied.
@@ -317,16 +260,12 @@ async def fetch_human_evaluation_scenarios_for_evaluation(
Returns:
List[EvaluationScenario]: A list of evaluation scenarios.
"""
- evaluation = await _fetch_human_evaluation_and_check_access(
- evaluation_id=evaluation_id,
- **user_org_data,
- )
scenarios = await HumanEvaluationScenarioDB.find(
- HumanEvaluationScenarioDB.evaluation.id == ObjectId(evaluation.id),
+ HumanEvaluationScenarioDB.evaluation.id == ObjectId(human_evaluation.id),
).to_list()
eval_scenarios = [
converters.human_evaluation_scenario_db_to_pydantic(
- scenario, str(evaluation.id)
+ scenario, str(human_evaluation.id)
)
for scenario in scenarios
]
@@ -334,27 +273,21 @@ async def fetch_human_evaluation_scenarios_for_evaluation(
async def update_human_evaluation_scenario(
- evaluation_scenario_id: str,
+ evaluation_scenario_db: HumanEvaluationScenarioDB,
evaluation_scenario_data: EvaluationScenarioUpdate,
evaluation_type: EvaluationType,
- **user_org_data,
) -> None:
"""
Updates an evaluation scenario.
Args:
- evaluation_scenario_id (str): The ID of the evaluation scenario.
+ evaluation_scenario_db (EvaluationScenarioDB): The evaluation scenario instance.
evaluation_scenario_data (EvaluationScenarioUpdate): New data for the scenario.
evaluation_type (EvaluationType): Type of the evaluation.
- user_org_data (dict): User and organization data.
Raises:
HTTPException: If evaluation scenario not found or access denied.
"""
- eval_scenario = await _fetch_human_evaluation_scenario_and_check_access(
- evaluation_scenario_id=evaluation_scenario_id,
- **user_org_data,
- )
updated_data = evaluation_scenario_data.dict()
updated_data["updated_at"] = datetime.now()
@@ -399,52 +332,7 @@ async def update_human_evaluation_scenario(
if updated_data["correct_answer"] is not None:
new_eval_set["correct_answer"] = updated_data["correct_answer"]
- await eval_scenario.update({"$set": new_eval_set})
-
-
-async def update_evaluation_scenario_score_service(
- evaluation_scenario_id: str, score: float, **user_org_data: dict
-) -> None:
- """
- Updates the score of an evaluation scenario.
-
- Args:
- evaluation_scenario_id (str): The ID of the evaluation scenario.
- score (float): The new score to set.
- user_org_data (dict): User and organization data.
-
- Raises:
- HTTPException: If evaluation scenario not found or access denied.
- """
- eval_scenario = await _fetch_human_evaluation_scenario_and_check_access(
- evaluation_scenario_id, **user_org_data
- )
- eval_scenario.score = score
-
- # Save the updated evaluation scenario
- await eval_scenario.save()
-
-
-async def get_evaluation_scenario_score_service(
- evaluation_scenario_id: str, **user_org_data: dict
-) -> Dict[str, str]:
- """
- Retrieve the score of a given evaluation scenario.
-
- Args:
- evaluation_scenario_id: The ID of the evaluation scenario.
- user_org_data: Additional user and organization data.
-
- Returns:
- Dictionary with 'scenario_id' and 'score' keys.
- """
- evaluation_scenario = await _fetch_human_evaluation_scenario_and_check_access(
- evaluation_scenario_id, **user_org_data
- )
- return {
- "scenario_id": str(evaluation_scenario.id),
- "score": evaluation_scenario.score,
- }
+ await evaluation_scenario_db.update({"$set": new_eval_set})
def _extend_with_evaluation(evaluation_type: EvaluationType):
@@ -465,28 +353,20 @@ def _extend_with_correct_answer(evaluation_type: EvaluationType, row: dict):
async def fetch_list_evaluations(
- app_id: str,
- **user_org_data: dict,
+ app: AppDB,
) -> List[Evaluation]:
"""
Fetches a list of evaluations based on the provided filtering criteria.
Args:
- app_id (Optional[str]): An optional app ID to filter the evaluations.
- user_org_data (dict): User and organization data.
+ app (AppDB): An app to filter the evaluations.
Returns:
List[Evaluation]: A list of evaluations.
"""
- access = await check_access_to_app(user_org_data=user_org_data, app_id=app_id)
- if not access:
- raise HTTPException(
- status_code=403,
- detail=f"You do not have access to this app: {app_id}",
- )
evaluations_db = await EvaluationDB.find(
- EvaluationDB.app.id == ObjectId(app_id), fetch_links=True
+ EvaluationDB.app.id == app.id, fetch_links=True
).to_list()
return [
await converters.evaluation_db_to_pydantic(evaluation)
@@ -494,44 +374,18 @@ async def fetch_list_evaluations(
]
-async def fetch_evaluation(evaluation_id: str, **user_org_data: dict) -> Evaluation:
- """
- Fetches a single evaluation based on its ID.
-
- Args:
- evaluation_id (str): The ID of the evaluation.
- user_org_data (dict): User and organization data.
-
- Returns:
- Evaluation: The fetched evaluation.
- """
- evaluation = await _fetch_evaluation_and_check_access(
- evaluation_id=evaluation_id, **user_org_data
- )
- return await converters.evaluation_db_to_pydantic(evaluation)
-
-
async def fetch_list_human_evaluations(
app_id: str,
- **user_org_data: dict,
) -> List[HumanEvaluation]:
"""
Fetches a list of evaluations based on the provided filtering criteria.
Args:
app_id (Optional[str]): An optional app ID to filter the evaluations.
- user_org_data (dict): User and organization data.
Returns:
List[Evaluation]: A list of evaluations.
"""
- access = await check_access_to_app(user_org_data=user_org_data, app_id=app_id)
- if not access:
- raise HTTPException(
- status_code=403,
- detail=f"You do not have access to this app: {app_id}",
- )
-
evaluations_db = await HumanEvaluationDB.find(
HumanEvaluationDB.app.id == ObjectId(app_id), fetch_links=True
).to_list()
@@ -541,77 +395,63 @@ async def fetch_list_human_evaluations(
]
-async def fetch_human_evaluation(
- evaluation_id: str, **user_org_data: dict
-) -> HumanEvaluation:
+async def fetch_human_evaluation(human_evaluation_db) -> HumanEvaluation:
"""
Fetches a single evaluation based on its ID.
Args:
- evaluation_id (str): The ID of the evaluation.
- user_org_data (dict): User and organization data.
+ human_evaluation_db (HumanEvaluationDB): The evaluation instance.
Returns:
Evaluation: The fetched evaluation.
"""
- evaluation = await _fetch_human_evaluation_and_check_access(
- evaluation_id=evaluation_id, **user_org_data
- )
- return await converters.human_evaluation_db_to_pydantic(evaluation)
+ return await converters.human_evaluation_db_to_pydantic(human_evaluation_db)
-async def delete_human_evaluations(
- evaluation_ids: List[str], **user_org_data: dict
-) -> None:
+async def delete_human_evaluations(evaluation_ids: List[str]) -> None:
"""
Delete evaluations by their IDs.
Args:
evaluation_ids (List[str]): A list of evaluation IDs.
- user_org_data (dict): User and organization data.
Raises:
HTTPException: If evaluation not found or access denied.
"""
for evaluation_id in evaluation_ids:
- evaluation = await _fetch_human_evaluation_and_check_access(
- evaluation_id=evaluation_id, **user_org_data
- )
+ evaluation = await _fetch_human_evaluation(evaluation_id=evaluation_id)
await evaluation.delete()
-async def delete_evaluations(evaluation_ids: List[str], **user_org_data: dict) -> None:
+async def delete_evaluations(evaluation_ids: List[str]) -> None:
"""
Delete evaluations by their IDs.
Args:
evaluation_ids (List[str]): A list of evaluation IDs.
- user_org_data (dict): User and organization data.
Raises:
HTTPException: If evaluation not found or access denied.
"""
for evaluation_id in evaluation_ids:
- evaluation = await _fetch_evaluation_and_check_access(
- evaluation_id=evaluation_id, **user_org_data
- )
+ evaluation = await db_manager.fetch_evaluation_by_id(evaluation_id)
await evaluation.delete()
async def create_new_human_evaluation(
- payload: NewHumanEvaluation, **user_org_data: dict
+ payload: NewHumanEvaluation, user_uid: str
) -> HumanEvaluationDB:
"""
Create a new evaluation based on the provided payload and additional arguments.
Args:
payload (NewEvaluation): The evaluation payload.
- **user_org_data (dict): Additional keyword arguments, e.g., user id.
+ user_uid (str): The user_uid of the user
Returns:
HumanEvaluationDB
"""
- user = await get_user(user_uid=user_org_data["uid"])
+ user = await db_manager.get_user(user_uid)
current_time = datetime.now()
@@ -639,7 +479,6 @@ async def create_new_human_evaluation(
]
eval_instance = HumanEvaluationDB(
app=app,
- organization=app.organization, # Assuming user has an organization_id attribute
user=user,
status=payload.status,
evaluation_type=payload.evaluation_type,
@@ -652,6 +491,11 @@ async def create_new_human_evaluation(
created_at=current_time,
updated_at=current_time,
)
+
+ if isCloudEE():
+ eval_instance.organization = app.organization
+ eval_instance.workspace = app.workspace
+
newEvaluation = await eval_instance.create()
if newEvaluation is None:
raise HTTPException(
@@ -698,7 +542,6 @@ async def create_new_evaluation(
evaluation_db = await db_manager.create_new_evaluation(
app=app,
- organization=app.organization,
user=app.user,
testset=testset,
status=Result(
@@ -707,13 +550,13 @@ async def create_new_evaluation(
variant=variant_id,
variant_revision=str(variant_revision.id),
evaluators_configs=evaluator_config_ids,
+ organization=app.organization if isCloudEE() else None,
+ workspace=app.workspace if isCloudEE() else None,
)
return await converters.evaluation_db_to_pydantic(evaluation_db)
-async def retrieve_evaluation_results(
- evaluation_id: str, **user_org_data: dict
-) -> List[dict]:
+async def retrieve_evaluation_results(evaluation_id: str) -> List[dict]:
"""Retrieve the aggregated results for a given evaluation.
Args:
@@ -725,20 +568,11 @@ async def retrieve_evaluation_results(
# Check for access rights
evaluation = await db_manager.fetch_evaluation_by_id(evaluation_id)
- access = await check_access_to_app(
- user_org_data=user_org_data, app_id=str(evaluation.app.id)
- )
- if not access:
- raise HTTPException(
- status_code=403,
- detail=f"You do not have access to this app: {str(evaluation.app.id)}",
- )
return await converters.aggregated_result_to_pydantic(evaluation.aggregated_results)
async def compare_evaluations_scenarios(
evaluations_ids: List[str],
- **user_org_data: dict,
):
evaluation = await db_manager.fetch_evaluation_by_id(evaluations_ids[0])
testset = evaluation.testset
@@ -750,7 +584,7 @@ async def compare_evaluations_scenarios(
for evaluation_id in evaluations_ids:
eval_scenarios = await fetch_evaluation_scenarios_for_evaluation(
- evaluation_id, **user_org_data
+ evaluation_id=evaluation_id
)
all_scenarios.append(eval_scenarios)
diff --git a/agenta-backend/agenta_backend/services/evaluator_manager.py b/agenta-backend/agenta_backend/services/evaluator_manager.py
index 5af6420919..c86a4ef946 100644
--- a/agenta-backend/agenta_backend/services/evaluator_manager.py
+++ b/agenta-backend/agenta_backend/services/evaluator_manager.py
@@ -1,16 +1,22 @@
-import json
import os
+import json
from typing import Any, Dict, Optional, List, Tuple
from fastapi.responses import JSONResponse
from agenta_backend.services import db_manager
+from agenta_backend.utils.common import isCloudEE
-
-from agenta_backend.models.db_models import AppDB, EvaluatorConfigDB
-from agenta_backend.models.api.evaluation_model import Evaluator, EvaluatorConfig
+if isCloudEE():
+ from agenta_backend.commons.models.db_models import (
+ AppDB_ as AppDB,
+ EvaluatorConfigDB_ as EvaluatorConfigDB,
+ )
+else:
+ from agenta_backend.models.db_models import AppDB, EvaluatorConfigDB
from agenta_backend.models.converters import evaluator_config_db_to_pydantic
from agenta_backend.resources.evaluators.evaluators import get_all_evaluators
+from agenta_backend.models.api.evaluation_model import Evaluator, EvaluatorConfig
def get_evaluators() -> Optional[List[Evaluator]]:
@@ -41,18 +47,17 @@ async def get_evaluators_configs(app_id: str) -> List[EvaluatorConfig]:
]
-async def get_evaluator_config(evaluator_config_id: str) -> EvaluatorConfig:
+async def get_evaluator_config(evaluator_config: EvaluatorConfig) -> EvaluatorConfig:
"""
Get an evaluator configuration by its ID.
Args:
- evaluator_config_id (str): The ID of the evaluator configuration.
+ evaluator_config: The evaluator configuration object.
Returns:
EvaluatorConfig: The evaluator configuration object.
"""
- evaluator_config_db = await db_manager.fetch_evaluator_config(evaluator_config_id)
- return evaluator_config_db_to_pydantic(evaluator_config_db)
+ return evaluator_config_db_to_pydantic(evaluator_config)
async def create_evaluator_config(
@@ -76,7 +81,8 @@ async def create_evaluator_config(
app = await db_manager.fetch_app_by_id(app_id)
evaluator_config = await db_manager.create_evaluator_config(
app=app,
- organization=app.organization,
+ organization=app.organization if isCloudEE() else None, # noqa,
+ workspace=app.workspace if isCloudEE() else None, # noqa,
user=app.user,
name=name,
evaluator_key=evaluator_key,
@@ -142,7 +148,8 @@ async def create_ready_to_use_evaluators(app: AppDB):
for evaluator in direct_use_evaluators:
await db_manager.create_evaluator_config(
app=app,
- organization=app.organization,
+ organization=app.organization if isCloudEE() else None, # noqa,
+ workspace=app.workspace if isCloudEE() else None, # noqa,
user=app.user,
name=evaluator["name"],
evaluator_key=evaluator["key"],
diff --git a/agenta-backend/agenta_backend/services/event_db_manager.py b/agenta-backend/agenta_backend/services/event_db_manager.py
index 474ca807e6..a30db917f8 100644
--- a/agenta-backend/agenta_backend/services/event_db_manager.py
+++ b/agenta-backend/agenta_backend/services/event_db_manager.py
@@ -34,7 +34,7 @@
async def get_variant_traces(
- app_id: str, variant_id: str, **kwargs: dict
+ app_id: str, variant_id: str, user_uid: str
) -> List[Trace]:
"""Get the traces for a given app variant.
@@ -46,7 +46,7 @@ async def get_variant_traces(
List[Trace]: the list of traces for the given app variant
"""
- user = await db_manager.get_user(user_uid=kwargs["uid"])
+ user = await db_manager.get_user(user_uid)
traces = await TraceDB.find(
TraceDB.user.id == user.id,
TraceDB.app_id == app_id,
@@ -56,7 +56,7 @@ async def get_variant_traces(
return [trace_db_to_pydantic(trace) for trace in traces]
-async def create_app_trace(payload: CreateTrace, **kwargs: dict) -> str:
+async def create_app_trace(payload: CreateTrace, user_uid: str) -> str:
"""Create a new trace.
Args:
@@ -66,7 +66,7 @@ async def create_app_trace(payload: CreateTrace, **kwargs: dict) -> str:
Trace: the created trace
"""
- user = await db_manager.get_user(user_uid=kwargs["uid"])
+ user = await db_manager.get_user(user_uid)
# Ensure spans exists in the db
for span in payload.spans:
@@ -79,7 +79,7 @@ async def create_app_trace(payload: CreateTrace, **kwargs: dict) -> str:
return str(trace_db.id)
-async def get_trace_single(trace_id: str, **kwargs: dict) -> Trace:
+async def get_trace_single(trace_id: str, user_uid: str) -> Trace:
"""Get a single trace.
Args:
@@ -89,7 +89,7 @@ async def get_trace_single(trace_id: str, **kwargs: dict) -> Trace:
Trace: the trace
"""
- user = await db_manager.get_user(user_uid=kwargs["uid"])
+ user = await db_manager.get_user(user_uid)
# Get trace
trace = await TraceDB.find_one(
@@ -99,7 +99,7 @@ async def get_trace_single(trace_id: str, **kwargs: dict) -> Trace:
async def trace_status_update(
- trace_id: str, payload: UpdateTrace, **kwargs: dict
+ trace_id: str, payload: UpdateTrace, user_uid: str
) -> bool:
"""Update status of trace.
@@ -111,7 +111,7 @@ async def trace_status_update(
bool: True if successful
"""
- user = await db_manager.get_user(user_uid=kwargs["uid"])
+ user = await db_manager.get_user(user_uid)
# Get trace
trace = await TraceDB.find_one(
@@ -124,7 +124,7 @@ async def trace_status_update(
return True
-async def create_trace_span(payload: CreateSpan, **kwargs: dict) -> str:
+async def create_trace_span(payload: CreateSpan) -> str:
"""Create a new span for a given trace.
Args:
@@ -139,7 +139,7 @@ async def create_trace_span(payload: CreateSpan, **kwargs: dict) -> str:
return str(span_db.id)
-async def get_trace_spans(trace_id: str, **kwargs: dict) -> List[Span]:
+async def get_trace_spans(trace_id: str, user_uid: str) -> List[Span]:
"""Get the spans for a given trace.
Args:
@@ -149,7 +149,7 @@ async def get_trace_spans(trace_id: str, **kwargs: dict) -> List[Span]:
List[Span]: the list of spans for the given trace
"""
- user = await db_manager.get_user(user_uid=kwargs["uid"])
+ user = await db_manager.get_user(user_uid)
# Get trace
trace = await TraceDB.find_one(
@@ -162,7 +162,7 @@ async def get_trace_spans(trace_id: str, **kwargs: dict) -> List[Span]:
async def add_feedback_to_trace(
- trace_id: str, payload: CreateFeedback, **kwargs: dict
+ trace_id: str, payload: CreateFeedback, user_uid: str
) -> str:
"""Add a feedback to a trace.
@@ -174,7 +174,7 @@ async def add_feedback_to_trace(
str: the feedback id
"""
- user = await db_manager.get_user(user_uid=kwargs["uid"])
+ user = await db_manager.get_user(user_uid)
feedback = FeedbackDB(
user_id=str(user.id),
feedback=payload.feedback,
@@ -193,7 +193,7 @@ async def add_feedback_to_trace(
return feedback.uid
-async def get_trace_feedbacks(trace_id: str, **kwargs: dict) -> List[Feedback]:
+async def get_trace_feedbacks(trace_id: str, user_uid: str) -> List[Feedback]:
"""Get the feedbacks for a given trace.
Args:
@@ -203,7 +203,7 @@ async def get_trace_feedbacks(trace_id: str, **kwargs: dict) -> List[Feedback]:
List[Feedback]: the list of feedbacks for the given trace
"""
- user = await db_manager.get_user(user_uid=kwargs["uid"])
+ user = await db_manager.get_user(user_uid)
# Get feedbacks in trace
trace = await TraceDB.find_one(
@@ -214,7 +214,7 @@ async def get_trace_feedbacks(trace_id: str, **kwargs: dict) -> List[Feedback]:
async def get_feedback_detail(
- trace_id: str, feedback_id: str, **kwargs: dict
+ trace_id: str, feedback_id: str, user_uid: str
) -> Feedback:
"""Get a single feedback.
@@ -226,7 +226,7 @@ async def get_feedback_detail(
Feedback: the feedback
"""
- user = await db_manager.get_user(user_uid=kwargs["uid"])
+ user = await db_manager.get_user(user_uid)
# Get trace
trace = await TraceDB.find_one(
@@ -243,7 +243,7 @@ async def get_feedback_detail(
async def update_trace_feedback(
- trace_id: str, feedback_id: str, payload: UpdateFeedback, **kwargs: dict
+ trace_id: str, feedback_id: str, payload: UpdateFeedback, user_uid: str
) -> Feedback:
"""Update a feedback.
@@ -256,7 +256,7 @@ async def update_trace_feedback(
Feedback: the feedback
"""
- user = await db_manager.get_user(user_uid=kwargs["uid"])
+ user = await db_manager.get_user(user_uid)
# Get trace
trace = await TraceDB.find_one(
diff --git a/agenta-backend/agenta_backend/services/llm_apps_service.py b/agenta-backend/agenta_backend/services/llm_apps_service.py
index 957c1ef941..e9c597e0a3 100644
--- a/agenta-backend/agenta_backend/services/llm_apps_service.py
+++ b/agenta-backend/agenta_backend/services/llm_apps_service.py
@@ -1,10 +1,10 @@
import json
-import asyncio
+import httpx
import logging
+import asyncio
import traceback
from typing import Any, Dict, List
-import httpx
from agenta_backend.models.db_models import InvokationResult, Result, Error
@@ -178,9 +178,9 @@ async def batch_invoke(
"delay_between_batches"
] # Delay between batches (in seconds)
- list_of_app_outputs: List[InvokationResult] = (
- []
- ) # Outputs after running all batches
+ list_of_app_outputs: List[
+ InvokationResult
+ ] = [] # Outputs after running all batches
openapi_parameters = await get_parameters_from_openapi(uri + "/openapi.json")
async def run_batch(start_idx: int):
diff --git a/agenta-backend/agenta_backend/services/results_service.py b/agenta-backend/agenta_backend/services/results_service.py
index e6a62d6a6a..86425f9fa5 100644
--- a/agenta-backend/agenta_backend/services/results_service.py
+++ b/agenta-backend/agenta_backend/services/results_service.py
@@ -1,12 +1,21 @@
-from agenta_backend.models.db_models import (
- HumanEvaluationDB,
- EvaluationScenarioDB,
- HumanEvaluationScenarioDB,
-)
+from beanie import PydanticObjectId as ObjectId
+
from agenta_backend.services import db_manager
+from agenta_backend.utils.common import isCloudEE
from agenta_backend.models.api.evaluation_model import EvaluationType
-from beanie import PydanticObjectId as ObjectId
+if isCloudEE():
+ from agenta_backend.commons.models.db_models import (
+ HumanEvaluationDB_ as HumanEvaluationDB,
+ EvaluationScenarioDB_ as EvaluationScenarioDB,
+ HumanEvaluationScenarioDB_ as HumanEvaluationScenarioDB,
+ )
+else:
+ from agenta_backend.models.db_models import (
+ HumanEvaluationDB,
+ EvaluationScenarioDB,
+ HumanEvaluationScenarioDB,
+ )
async def fetch_results_for_evaluation(evaluation: HumanEvaluationDB):
diff --git a/agenta-backend/agenta_backend/services/selectors.py b/agenta-backend/agenta_backend/services/selectors.py
deleted file mode 100644
index bae85cc8b5..0000000000
--- a/agenta-backend/agenta_backend/services/selectors.py
+++ /dev/null
@@ -1,61 +0,0 @@
-from typing import Tuple, Dict, List
-
-from agenta_backend.models.db_models import (
- UserDB,
- OrganizationDB,
-)
-
-
-async def get_user_and_org_id(user_uid_id) -> Dict[str, List]:
- """Retrieves the user ID and organization ID based on the logged-in session.
-
- Arguments:
- session (SessionContainer): Used to store and manage the user's session data
-
- Returns:
- A dictionary containing the user_id and a list of the user's organization_ids.
- """
- user_id, org_ids = await get_user_objectid(user_uid_id)
- return {"uid": user_id, "organization_ids": org_ids}
-
-
-async def get_user_objectid(user_uid: str) -> Tuple[str, List]:
- """Retrieves the user object ID and organization IDs from the database
- based on the user ID.
-
- Arguments:
- user_id (str): The unique identifier of a user
-
- Returns:
- a tuple containing the string representation of the user's ObjectId and the List
- of the user's organization_ids.
- """
-
- user = await UserDB.find_one(UserDB.uid == user_uid)
- if user is not None:
- user_id = str(user.uid)
- organization_ids: List = (
- [org for org in user.organizations] if user.organizations else []
- )
- return user_id, organization_ids
- return None, []
-
-
-async def get_user_own_org(user_uid: str) -> OrganizationDB:
- """Get's the default users' organization from the database.
-
- Arguments:
- user_uid (str): The uid of the user
-
- Returns:
- Organization: Instance of OrganizationDB
- """
-
- user = await UserDB.find_one(UserDB.uid == user_uid)
- org: OrganizationDB = await OrganizationDB.find_one(
- OrganizationDB.owner == str(user.id), OrganizationDB.type == "default"
- )
- if org is not None:
- return org
- else:
- return None
diff --git a/agenta-backend/agenta_backend/services/templates_manager.py b/agenta-backend/agenta_backend/services/templates_manager.py
index 28b793ce00..c77da181e6 100644
--- a/agenta-backend/agenta_backend/services/templates_manager.py
+++ b/agenta-backend/agenta_backend/services/templates_manager.py
@@ -1,13 +1,15 @@
+import os
import json
+import httpx
import backoff
+
from typing import Any, Dict, List
-import httpx
-import os
+from asyncio.exceptions import CancelledError
+from httpx import ConnectError, TimeoutException
+
from agenta_backend.config import settings
-from agenta_backend.services import db_manager
from agenta_backend.utils import redis_utils
-from httpx import ConnectError, TimeoutException
-from asyncio.exceptions import CancelledError
+from agenta_backend.services import db_manager
if os.environ["FEATURE_FLAG"] in ["oss", "cloud"]:
from agenta_backend.services import container_manager
diff --git a/agenta-backend/agenta_backend/services/user_service.py b/agenta-backend/agenta_backend/services/user_service.py
index b1761598fd..bf452b316c 100644
--- a/agenta-backend/agenta_backend/services/user_service.py
+++ b/agenta-backend/agenta_backend/services/user_service.py
@@ -1,4 +1,9 @@
-from agenta_backend.models.db_models import UserDB
+import os
+
+if os.environ["FEATURE_FLAG"] in ["cloud"]:
+ from agenta_backend.commons.models.db_models import UserDB_ as UserDB
+else:
+ from agenta_backend.models.db_models import UserDB
from agenta_backend.models.api.user_models import User, UserUpdate
diff --git a/agenta-backend/agenta_backend/tasks/evaluations.py b/agenta-backend/agenta_backend/tasks/evaluations.py
index cfe05d06e3..2408824ccb 100644
--- a/agenta-backend/agenta_backend/tasks/evaluations.py
+++ b/agenta-backend/agenta_backend/tasks/evaluations.py
@@ -1,14 +1,19 @@
+import re
+import os
import asyncio
import logging
-import os
-import re
import traceback
+
from typing import Any, Dict, List
+from celery import shared_task, states
+
+from agenta_backend.utils.common import isCloudEE
+from agenta_backend.models.db_engine import DBEngine
+from agenta_backend.services import evaluators_service, llm_apps_service
from agenta_backend.models.api.evaluation_model import (
EvaluationStatusEnum,
)
-from agenta_backend.models.db_engine import DBEngine
from agenta_backend.models.db_models import (
AggregatedResult,
AppDB,
@@ -38,7 +43,18 @@
EvaluationScenarioResult,
check_if_evaluation_contains_failed_evaluation_scenarios,
)
-from celery import shared_task, states
+
+if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
+ from agenta_backend.commons.models.db_models import AppDB_ as AppDB
+else:
+ from agenta_backend.models.db_models import AppDB
+from agenta_backend.models.db_models import (
+ Result,
+ AggregatedResult,
+ EvaluationScenarioResult,
+ EvaluationScenarioInputDB,
+ EvaluationScenarioOutputDB,
+)
# Set logger
logger = logging.getLogger(__name__)
@@ -237,7 +253,6 @@ def evaluate(
loop.run_until_complete(
create_new_evaluation_scenario(
user=app.user,
- organization=app.organization,
evaluation=new_evaluation_db,
variant_id=variant_id,
evaluators_configs=new_evaluation_db.evaluators_configs,
@@ -251,6 +266,8 @@ def evaluate(
)
],
results=evaluators_results,
+ organization=app.organization if isCloudEE() else None,
+ workspace=app.workspace if isCloudEE() else None,
)
)
diff --git a/agenta-backend/agenta_backend/tests/variants_main_router/conftest.py b/agenta-backend/agenta_backend/tests/variants_main_router/conftest.py
index 08173b6f6a..72b54cc4f2 100644
--- a/agenta-backend/agenta_backend/tests/variants_main_router/conftest.py
+++ b/agenta-backend/agenta_backend/tests/variants_main_router/conftest.py
@@ -10,9 +10,7 @@
ImageDB,
ConfigDB,
AppVariantDB,
- OrganizationDB,
)
-from agenta_backend.services import selectors
import httpx
@@ -39,12 +37,6 @@ async def get_first_user_object():
create_user = UserDB(uid="0")
await create_user.create()
- org = OrganizationDB(type="default", owner=str(create_user.id))
- await org.create()
-
- create_user.organizations.append(org.id)
- await create_user.save()
-
return create_user
return user
@@ -60,12 +52,6 @@ async def get_second_user_object():
)
await create_user.create()
- org = OrganizationDB(type="default", owner=str(create_user.id))
- await org.create()
-
- create_user.organizations.append(org.id)
- await create_user.save()
-
return create_user
return user
@@ -73,16 +59,14 @@ async def get_second_user_object():
@pytest.fixture()
async def get_first_user_app(get_first_user_object):
user = await get_first_user_object
- organization = await selectors.get_user_own_org(user.uid)
- app = AppDB(app_name="myapp", organization=organization, user=user)
+ app = AppDB(app_name="myapp", user=user)
await app.create()
db_image = ImageDB(
docker_id="sha256:xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
tags="agentaai/templates_v2:local_test_prompt",
user=user,
- organization=organization,
)
await db_image.create()
@@ -91,9 +75,7 @@ async def get_first_user_app(get_first_user_object):
parameters={},
)
- db_base = VariantBaseDB(
- base_name="app", image=db_image, organization=organization, user=user, app=app
- )
+ db_base = VariantBaseDB(base_name="app", image=db_image, user=user, app=app)
await db_base.create()
appvariant = AppVariantDB(
@@ -101,7 +83,6 @@ async def get_first_user_app(get_first_user_object):
variant_name="app",
image=db_image,
user=user,
- organization=organization,
parameters={},
base_name="app",
config_name="default",
@@ -111,7 +92,7 @@ async def get_first_user_app(get_first_user_object):
config=db_config,
)
await appvariant.create()
- return appvariant, user, organization, app, db_image, db_config, db_base
+ return appvariant, user, app, db_image, db_config, db_base
@pytest.fixture()
@@ -235,18 +216,11 @@ def fetch_single_prompt_template(fetch_templates):
)
-@pytest.fixture()
-async def fetch_user_organization():
- organization = await OrganizationDB.find().to_list()
- return {"org_id": str(organization[0].id)}
-
-
@pytest.fixture()
def app_from_template():
return {
"app_name": "string",
"env_vars": {"OPENAI_API_KEY": OPEN_AI_KEY},
- "organization_id": "string",
"template_id": "string",
}
diff --git a/agenta-backend/agenta_backend/tests/variants_main_router/test_app_variant_router.py b/agenta-backend/agenta_backend/tests/variants_main_router/test_app_variant_router.py
index 6a8141f8eb..880ecdab84 100644
--- a/agenta-backend/agenta_backend/tests/variants_main_router/test_app_variant_router.py
+++ b/agenta-backend/agenta_backend/tests/variants_main_router/test_app_variant_router.py
@@ -5,7 +5,7 @@
from bson import ObjectId
from agenta_backend.routers import app_router
-from agenta_backend.services import selectors, db_manager
+from agenta_backend.services import db_manager
from agenta_backend.models.db_models import (
AppDB,
VariantBaseDB,
@@ -36,13 +36,11 @@
@pytest.mark.asyncio
async def test_create_app(get_first_user_object):
user = await get_first_user_object
- organization = await selectors.get_user_own_org(user.uid)
response = await test_client.post(
f"{BACKEND_API_HOST}/apps/",
json={
"app_name": "app_variant_test",
- "organization_id": str(organization.id),
},
timeout=timeout,
)
@@ -61,14 +59,12 @@ async def test_list_apps():
@pytest.mark.asyncio
async def test_create_app_variant(get_first_user_object):
user = await get_first_user_object
- organization = await selectors.get_user_own_org(user.uid)
app = await AppDB.find_one(AppDB.app_name == "app_variant_test")
db_image = ImageDB(
docker_id="sha256:xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
tags="agentaai/templates_v2:local_test_prompt",
user=user,
- organization=organization,
)
await db_image.create()
@@ -80,7 +76,6 @@ async def test_create_app_variant(get_first_user_object):
db_base = VariantBaseDB(
base_name="app",
app=app,
- organization=organization,
user=user,
image=db_image,
)
@@ -91,7 +86,6 @@ async def test_create_app_variant(get_first_user_object):
variant_name="app",
image=db_image,
user=user,
- organization=organization,
parameters={},
base_name="app",
config_name="default",
@@ -118,25 +112,6 @@ async def test_list_app_variants():
assert len(response.json()) == 1
-@pytest.mark.asyncio
-async def test_delete_app_without_permission(get_second_user_object):
- user2 = await get_second_user_object
- user2_organization = await selectors.get_user_own_org(user2.uid)
-
- user2_app = AppDB(
- app_name="test_app_by_user2",
- organization=user2_organization,
- user=user2,
- )
- await user2_app.create()
-
- response = await test_client.delete(
- f"{BACKEND_API_HOST}/apps/{str(user2_app.id)}/",
- timeout=timeout,
- )
- assert response.status_code == 400
-
-
@pytest.mark.asyncio
async def test_list_environments():
app = await AppDB.find_one(AppDB.app_name == "app_variant_test")
@@ -150,7 +125,7 @@ async def test_list_environments():
@pytest.mark.asyncio
async def test_get_variant_by_env(get_first_user_app):
- _, _, _, app, _, _, _ = await get_first_user_app
+ _, _, app, _, _, _ = await get_first_user_app
environments = await db_manager.list_environments(app_id=str(app.id))
for environment in environments:
diff --git a/agenta-backend/agenta_backend/tests/variants_main_router/test_observability_router.py b/agenta-backend/agenta_backend/tests/variants_main_router/test_observability_router.py
index 7cb6b4e1bf..b2ce08ad09 100644
--- a/agenta-backend/agenta_backend/tests/variants_main_router/test_observability_router.py
+++ b/agenta-backend/agenta_backend/tests/variants_main_router/test_observability_router.py
@@ -9,13 +9,10 @@
SpanDB,
UserDB,
TraceDB,
- OrganizationDB,
ImageDB,
AppVariantDB,
VariantBaseDB,
)
-from agenta_backend.services import selectors
-
import httpx
@@ -44,11 +41,8 @@ async def test_create_spans_endpoint(spans_db_data):
@pytest.mark.asyncio
async def test_create_image_in_db(image_create_data):
user_db = await UserDB.find_one(UserDB.uid == "0")
- organization_db = await OrganizationDB.find_one(
- OrganizationDB.owner == str(user_db.id)
- )
- image_db = ImageDB(**image_create_data, user=user_db, organization=organization_db)
+ image_db = ImageDB(**image_create_data, user=user_db)
await image_db.create()
assert image_db.user.id == user_db.id
@@ -58,12 +52,10 @@ async def test_create_image_in_db(image_create_data):
@pytest.mark.asyncio
async def test_create_appvariant_in_db(app_variant_create_data):
user_db = await UserDB.find_one(UserDB.uid == "0")
- organization_db = await selectors.get_user_own_org(user_db.uid)
image_db = await ImageDB.find_one(ImageDB.user.id == user_db.id)
app = AppDB(
app_name="test_app",
- organization=organization_db,
user=user_db,
)
await app.create()
@@ -75,7 +67,6 @@ async def test_create_appvariant_in_db(app_variant_create_data):
db_base = VariantBaseDB(
app=app,
- organization=organization_db,
user=user_db,
base_name="app",
image=image_db,
@@ -87,7 +78,6 @@ async def test_create_appvariant_in_db(app_variant_create_data):
app=app,
image=image_db,
user=user_db,
- organization=organization_db,
base_name="app",
config_name="default",
base=db_base,
diff --git a/agenta-backend/agenta_backend/tests/variants_main_router/test_variant_evaluators_router.py b/agenta-backend/agenta_backend/tests/variants_main_router/test_variant_evaluators_router.py
index 6ef0fd6c3e..51e47fe295 100644
--- a/agenta-backend/agenta_backend/tests/variants_main_router/test_variant_evaluators_router.py
+++ b/agenta-backend/agenta_backend/tests/variants_main_router/test_variant_evaluators_router.py
@@ -30,12 +30,10 @@
@pytest.mark.asyncio
async def test_create_app_from_template(
- app_from_template, fetch_user, fetch_single_prompt_template
+ app_from_template, fetch_single_prompt_template
):
- user = await fetch_user
payload = app_from_template
payload["app_name"] = APP_NAME
- payload["organization_id"] = str(user.organizations[0])
payload["template_id"] = fetch_single_prompt_template["id"]
response = httpx.post(
@@ -192,7 +190,10 @@ async def test_create_evaluation():
assert response.status_code == 200
assert response_data["app_id"] == payload["app_id"]
- assert response_data["status"]["value"] == EvaluationStatusEnum.EVALUATION_STARTED
+ assert (
+ response_data["status"]["value"]
+ == EvaluationStatusEnum.EVALUATION_STARTED.value
+ )
assert response_data is not None
diff --git a/agenta-backend/agenta_backend/tests/variants_main_router/test_variant_versioning_deployment.py b/agenta-backend/agenta_backend/tests/variants_main_router/test_variant_versioning_deployment.py
index ae1375cbe7..405801405c 100644
--- a/agenta-backend/agenta_backend/tests/variants_main_router/test_variant_versioning_deployment.py
+++ b/agenta-backend/agenta_backend/tests/variants_main_router/test_variant_versioning_deployment.py
@@ -65,55 +65,3 @@ async def test_deploy_to_environment(deploy_to_environment_payload):
assert (
list_of_response_status_codes.count(200) == 3
), "The list does not contain 3 occurrences of 200 status code"
-
-
-@pytest.mark.asyncio
-async def test_list_app_environment_revisions():
- app = await AppDB.find_one(AppDB.app_name == APP_NAME)
- list_of_response_data = []
- list_of_response_status_codes = []
- for environment in VARIANT_DEPLOY_ENVIRONMENTS:
- response = await test_client.get(
- f"{BACKEND_API_HOST}/apps/{str(app.id)}/revisions/{environment}"
- )
- list_of_response_data.append(response.json())
- list_of_response_status_codes.append(response.status_code)
- assert (
- list_of_response_status_codes.count(200) == 3
- ), "The list does not container 3 occurrences of 200 status code"
- assert len(list_of_response_data) == 3, "The list does not contain 3 response data"
-
-
-@pytest.mark.asyncio
-async def test_get_config_deployment_revision():
- app = await AppDB.find_one(AppDB.app_name == APP_NAME)
- app_environment_revisions_response = await test_client.get(
- f"{BACKEND_API_HOST}/apps/{str(app.id)}/revisions/{VARIANT_DEPLOY_ENVIRONMENTS[0]}"
- )
-
- if app_environment_revisions_response.status_code == 200:
- revisions = app_environment_revisions_response.json()["revisions"]
- config_deployment_revision_response = await test_client.get(
- f"{BACKEND_API_HOST}/configs/deployment/{revisions[0]['id']}"
- )
- assert config_deployment_revision_response.status_code == 200
- assert config_deployment_revision_response.json() is not None
- else:
- assert False, "App environment revisions response is not 200"
-
-
-@pytest.mark.asyncio
-async def test_revert_deployment_revision():
- app = await AppDB.find_one(AppDB.app_name == APP_NAME)
- app_environment_revisions_response = await test_client.get(
- f"{BACKEND_API_HOST}/apps/{str(app.id)}/revisions/{VARIANT_DEPLOY_ENVIRONMENTS[0]}"
- )
-
- if app_environment_revisions_response.status_code == 200:
- revisions = app_environment_revisions_response.json()["revisions"]
- revert_deployment_revision_response = await test_client.post(
- f"{BACKEND_API_HOST}/configs/deployment/{revisions[0]['id']}/revert/"
- )
- assert revert_deployment_revision_response.status_code == 200
- else:
- assert False, "App environment revisions response is not 200"
diff --git a/agenta-backend/agenta_backend/tests/variants_organization_router/test_organization_router.py b/agenta-backend/agenta_backend/tests/variants_organization_router/test_organization_router.py
deleted file mode 100644
index 222c9dbe9c..0000000000
--- a/agenta-backend/agenta_backend/tests/variants_organization_router/test_organization_router.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import os
-
-from agenta_backend.services import selectors
-from agenta_backend.models.db_models import UserDB
-from agenta_backend.models.api.organization_models import OrganizationOutput
-
-import httpx
-import pytest
-
-
-# Initialize http client
-test_client = httpx.AsyncClient()
-timeout = httpx.Timeout(timeout=5, read=None, write=5)
-
-# Set global variables
-ENVIRONMENT = os.environ.get("ENVIRONMENT")
-if ENVIRONMENT == "development":
- BACKEND_API_HOST = "http://host.docker.internal/api"
-elif ENVIRONMENT == "github":
- BACKEND_API_HOST = "http://agenta-backend-test:8000"
-
-
-@pytest.mark.asyncio
-async def test_list_organizations():
- response = await test_client.get(f"{BACKEND_API_HOST}/organizations/")
-
- assert response.status_code == 200
- assert len(response.json()) == 1
-
-
-@pytest.mark.asyncio
-async def test_get_user_organization():
- user = await UserDB.find_one(UserDB.uid == "0")
- user_org = await selectors.get_user_own_org(user.uid)
-
- response = await test_client.get(f"{BACKEND_API_HOST}/organizations/own/")
-
- assert response.status_code == 200
- assert response.json() == OrganizationOutput(
- id=str(user_org.id), name=user_org.name
- )
-
-
-@pytest.mark.asyncio
-async def test_user_does_not_have_an_organization():
- user = UserDB(uid="0123", username="john_doe", email="johndoe@email.com")
- await user.create()
-
- user_org = await selectors.get_user_own_org(user.uid)
- assert user_org == None
diff --git a/agenta-backend/agenta_backend/utils/common.py b/agenta-backend/agenta_backend/utils/common.py
index 4ed0f23c82..02af4e73b8 100644
--- a/agenta-backend/agenta_backend/utils/common.py
+++ b/agenta-backend/agenta_backend/utils/common.py
@@ -1,20 +1,10 @@
+import os
import logging
-from typing import Dict, List, Union, Optional, Any, Callable
+from typing import Any, Callable
from fastapi.types import DecoratedCallable
from fastapi import APIRouter as FastAPIRouter
-from agenta_backend.models.db_models import (
- UserDB,
- AppVariantDB,
- OrganizationDB,
- AppDB,
- VariantBaseDB,
-)
-
-from beanie import PydanticObjectId as ObjectId
-
-
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -60,109 +50,17 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable:
return decorator
-async def get_organization(org_id: str) -> OrganizationDB:
- org = await OrganizationDB.find_one(OrganizationDB.id == ObjectId(org_id))
- if org is not None:
- return org
- else:
- return None
-
-
-async def get_app_instance(
- app_id: str, variant_name: str = None, show_deleted: bool = False
-) -> AppVariantDB:
- queries = (AppVariantDB.is_deleted == show_deleted, AppVariantDB.app == app_id)
- if variant_name is not None:
- queries += AppVariantDB.variant_name == variant_name
-
- app_instance = await AppVariantDB.find_one(*queries)
- return app_instance
-
-
-async def check_user_org_access(
- kwargs: dict, organization_id: str, check_owner=False
-) -> bool:
- if check_owner: # Check that the user is the owner of the organization
- user = await UserDB.find_one(UserDB.uid == kwargs["uid"])
- organization = await get_organization(organization_id)
- if not organization:
- logger.error("Organization not found")
- raise Exception("Organization not found")
- return organization.owner == str(user.id)
- else:
- user_organizations: List = kwargs["organization_ids"]
- object_organization_id = ObjectId(organization_id)
- logger.debug(
- f"object_organization_id: {object_organization_id}, user_organizations: {user_organizations}"
- )
- user_exists_in_organizations = object_organization_id in user_organizations
- return user_exists_in_organizations
+def isCloudEE():
+ return os.environ["FEATURE_FLAG"] in ["cloud", "ee"]
-async def check_access_to_app(
- user_org_data: Dict[str, Union[str, list]],
- app: Optional[AppDB] = None,
- app_id: Optional[str] = None,
- check_owner: bool = False,
-) -> bool:
- """
- Check if a user has access to a specific application.
+def isCloud():
+ return os.environ["FEATURE_FLAG"] == "cloud"
- Args:
- user_org_data (Dict[str, Union[str, list]]): User-specific information.
- app (Optional[AppDB]): An instance of the AppDB model representing the application.
- app_id (Optional[str]): The ID of the application.
- check_owner (bool): Whether to check if the user is the owner of the application.
- Returns:
- bool: True if the user has access, False otherwise.
+def isEE():
+ return os.environ["FEATURE_FLAG"] == "ee"
- Raises:
- Exception: If neither or both `app` and `app_id` are provided.
- """
- if (app is None) == (app_id is None):
- raise Exception("Provide either app or app_id, not both or neither")
-
- # Fetch the app if only app_id is provided.
- if app is None:
- app = await AppDB.find_one(AppDB.id == ObjectId(app_id), fetch_links=True)
- if app is None:
- logger.error("App not found")
- return False
-
- # Check user's access to the organization linked to the app.
- organization_id = app.organization.id
- return await check_user_org_access(user_org_data, str(organization_id), check_owner)
-
-
-async def check_access_to_variant(
- user_org_data: Dict[str, Union[str, list]],
- variant_id: str,
- check_owner: bool = False,
-) -> bool:
- if variant_id is None:
- raise Exception("No variant_id provided")
- variant = await AppVariantDB.find_one(
- AppVariantDB.id == ObjectId(variant_id), fetch_links=True
- )
- if variant is None:
- logger.error("Variant not found")
- return False
- organization_id = variant.organization.id
- return await check_user_org_access(user_org_data, str(organization_id), check_owner)
-
-
-async def check_access_to_base(
- user_org_data: Dict[str, Union[str, list]],
- base_id: str,
- check_owner: bool = False,
-) -> bool:
- if base_id is None:
- raise Exception("No base_id provided")
- base = await VariantBaseDB.find_one(VariantBaseDB.id == base_id, fetch_links=True)
- if base is None:
- logger.error("Base not found")
- return False
- organization_id = base.organization.id
- return await check_user_org_access(user_org_data, str(organization_id), check_owner)
+def isOssEE():
+ return os.environ["FEATURE_FLAG"] in ["oss", "ee"]
diff --git a/agenta-cli/agenta/cli/main.py b/agenta-cli/agenta/cli/main.py
index d2ee1fef00..315a4e5295 100644
--- a/agenta-cli/agenta/cli/main.py
+++ b/agenta-cli/agenta/cli/main.py
@@ -141,11 +141,19 @@ def init(app_name: str, backend_host: str):
api_key=api_key if where_question == "On agenta cloud" else "",
)
+ # list of user organizations
+ user_organizations = []
+
# validate the api key if it is provided
if where_question == "On agenta cloud":
try:
key_prefix = api_key.split(".")[0]
client.validate_api_key(key_prefix=key_prefix)
+
+ # Make request to fetch user organizations after api key validation
+ organizations = client.list_organizations()
+ if len(organizations) >= 1:
+ user_organizations = organizations
except Exception as ex:
if ex.status_code == 401:
click.echo(click.style("Error: Invalid API key", fg="red"))
@@ -154,9 +162,28 @@ def init(app_name: str, backend_host: str):
click.echo(click.style(f"Error: {ex}", fg="red"))
sys.exit(1)
+ if where_question == "On agenta cloud":
+ which_organization = questionary.select(
+ "Which organization do you want to create the app for?",
+ choices=[
+ f"{org.name}: {org.description}" for org in user_organizations
+ ],
+ ).ask()
+ filtered_org = next(
+ (
+ org
+ for org in user_organizations
+ if org.name == which_organization.split(":")[0]
+ ),
+ None,
+ )
+
# Get app_id after creating new app in the backend server
try:
- app_id = client.create_app(app_name=app_name).app_id
+ app_id = client.create_app(
+ app_name=app_name,
+ organization_id=filtered_org.id if filtered_org else None,
+ ).app_id
except Exception as ex:
click.echo(click.style(f"Error: {ex}", fg="red"))
sys.exit(1)
diff --git a/agenta-web/src/components/AppSelector/AppCard.tsx b/agenta-web/src/components/AppSelector/AppCard.tsx
index 08f4e87356..0fcd89b55d 100644
--- a/agenta-web/src/components/AppSelector/AppCard.tsx
+++ b/agenta-web/src/components/AppSelector/AppCard.tsx
@@ -7,7 +7,6 @@ import {renameVariablesCapitalizeAll} from "@/lib/helpers/utils"
import {createUseStyles} from "react-jss"
import {getGradientFromStr} from "@/lib/helpers/colors"
import {ListAppsItem} from "@/lib/Types"
-import {useProfileData, Role} from "@/contexts/profile.context"
import {useAppsData} from "@/contexts/app.context"
const useStyles = createUseStyles({
@@ -79,8 +78,6 @@ const AppCard: React.FC<{
}> = ({app}) => {
const [visibleDelete, setVisibleDelete] = useState(false)
const [confirmLoading, setConfirmLoading] = useState(false)
- const {role} = useProfileData()
- const isOwner = role === Role.OWNER
const {mutate} = useAppsData()
const showDeleteModal = () => {
@@ -109,11 +106,7 @@ const AppCard: React.FC<{
<>
]
- : undefined
- }
+ actions={[]}
>
{
- const router = useRouter()
const posthog = usePostHogAg()
const {appTheme} = useAppTheme()
const classes = useStyles({themeMode: appTheme} as StyleProps)
@@ -123,7 +122,6 @@ const AppSelector: React.FC = () => {
const [statusModalOpen, setStatusModalOpen] = useState(false)
const [fetchingTemplate, setFetchingTemplate] = useState(false)
const [newApp, setNewApp] = useState("")
- const {selectedOrg} = useProfileData()
const {apps, error, isLoading, mutate} = useAppsData()
const [statusData, setStatusData] = useState<{status: string; details?: any; appId?: string}>({
status: "",
@@ -207,7 +205,6 @@ const AppSelector: React.FC = () => {
await createAndStartTemplate({
appName: newApp,
templateId: template_id,
- orgId: selectedOrg?.id!,
providerKey:
isDemo() && apiKey?.length === 0
? []
@@ -287,11 +284,7 @@ const AppSelector: React.FC = () => {
{
- if (
- isDemo() &&
- selectedOrg?.is_paying == false &&
- apps.length > 2
- ) {
+ if (isDemo() && apps.length > 2) {
showMaxAppError()
} else {
showCreateAppModal()
diff --git a/agenta-web/src/components/Evaluations/Evaluations.tsx b/agenta-web/src/components/Evaluations/Evaluations.tsx
index b41b85c57e..29230f21ec 100644
--- a/agenta-web/src/components/Evaluations/Evaluations.tsx
+++ b/agenta-web/src/components/Evaluations/Evaluations.tsx
@@ -13,7 +13,7 @@ import {
} from "antd"
import {DownOutlined} from "@ant-design/icons"
import {createNewEvaluation, fetchVariants, useLoadTestsetsList} from "@/lib/services/api"
-import {dynamicComponent, getAllProviderLlmKeys, getApikeys, isDemo} from "@/lib/helpers/utils"
+import {getAllProviderLlmKeys, getApikeys, isDemo} from "@/lib/helpers/utils"
import {useRouter} from "next/router"
import {Variant, Parameter, GenericObject, JSSTheme} from "@/lib/Types"
import {EvaluationType} from "@/lib/enums"
@@ -29,6 +29,8 @@ import {createUseStyles} from "react-jss"
import HumanEvaluationResult from "./HumanEvaluationResult"
import {getErrorMessage} from "@/lib/helpers/errorHandler"
import AutomaticEvaluationResult from "./AutomaticEvaluationResult"
+import {dynamicComponent} from "@/lib/helpers/dynamic"
+import {PERMISSION_ERR_MSG} from "@/lib/helpers/axiosConfig"
type StyleProps = {
themeMode: "dark" | "light"
@@ -354,11 +356,13 @@ export default function Evaluations() {
selectedCustomEvaluationID,
testsetId: selectedTestset._id!,
}).catch((err) => {
- setError({
- message: getErrorMessage(err),
- btnText: "Go to Test sets",
- endpoint: `/apps/${appId}/testsets`,
- })
+ if (err.message !== PERMISSION_ERR_MSG) {
+ setError({
+ message: getErrorMessage(err),
+ btnText: "Go to Test sets",
+ endpoint: `/apps/${appId}/testsets`,
+ })
+ }
})
if (!evaluationTableId) {
diff --git a/agenta-web/src/components/Playground/Views/TestView.tsx b/agenta-web/src/components/Playground/Views/TestView.tsx
index 7b0cf5bc36..daabacb5c6 100644
--- a/agenta-web/src/components/Playground/Views/TestView.tsx
+++ b/agenta-web/src/components/Playground/Views/TestView.tsx
@@ -10,7 +10,7 @@ import {
Parameter,
Variant,
} from "@/lib/Types"
-import {batchExecute, dynamicComponent, randString, removeKeys} from "@/lib/helpers/utils"
+import {batchExecute, randString, removeKeys} from "@/lib/helpers/utils"
import LoadTestsModal from "../LoadTestsModal"
import AddToTestSetDrawer from "../AddToTestSetDrawer/AddToTestSetDrawer"
import {DeleteOutlined} from "@ant-design/icons"
@@ -29,6 +29,7 @@ import dayjs from "dayjs"
import relativeTime from "dayjs/plugin/relativeTime"
import duration from "dayjs/plugin/duration"
import {useQueryParam} from "@/hooks/useQuery"
+import {dynamicComponent} from "@/lib/helpers/dynamic"
const PromptVersioningDrawer: any = dynamicComponent(
`PromptVersioningDrawer/PromptVersioningDrawer`,
diff --git a/agenta-web/src/components/Sidebar/Sidebar.tsx b/agenta-web/src/components/Sidebar/Sidebar.tsx
index f0aacbb82c..65b9e7a1fd 100644
--- a/agenta-web/src/components/Sidebar/Sidebar.tsx
+++ b/agenta-web/src/components/Sidebar/Sidebar.tsx
@@ -10,10 +10,9 @@ import {
PhoneOutlined,
SettingOutlined,
LogoutOutlined,
- ApartmentOutlined,
FormOutlined,
} from "@ant-design/icons"
-import {Layout, Menu, Space, Tooltip, theme, Avatar} from "antd"
+import {Layout, Menu, Space, Tooltip, theme} from "antd"
import Logo from "../Logo/Logo"
import Link from "next/link"
@@ -22,9 +21,9 @@ import {ErrorBoundary} from "react-error-boundary"
import {createUseStyles} from "react-jss"
import AlertPopup from "../AlertPopup/AlertPopup"
import {useProfileData} from "@/contexts/profile.context"
-import {getColorFromStr} from "@/lib/helpers/colors"
-import {getInitials, isDemo} from "@/lib/helpers/utils"
+import {isDemo} from "@/lib/helpers/utils"
import {useSession} from "@/hooks/useSession"
+import {dynamicComponent} from "@/lib/helpers/dynamic"
import {useLocalStorage} from "usehooks-ts"
type StyleProps = {
@@ -77,44 +76,10 @@ const useStyles = createUseStyles({
menuLinks: {
width: "100%",
},
- menuItemNoBg: {
- textOverflow: "unset !important",
- "& .ant-menu-submenu-title": {display: "flex", alignItems: "center"},
- "& .ant-select-selector": {
- padding: "0 !important",
- },
- "&> span": {
- display: "inline-block",
- marginTop: 4,
- },
- "& .ant-select-selection-item": {
- "&> span > span": {
- width: 120,
- marginRight: 10,
- },
- },
- },
- orgLabel: {
- display: "flex",
- alignItems: "center",
- gap: "6px",
- justifyContent: "flex-start",
- "&> div": {
- width: 18,
- height: 18,
- aspectRatio: "1/1",
- borderRadius: "50%",
- },
- "&> span": {
- overflow: "hidden",
- textOverflow: "ellipsis",
- whiteSpace: "nowrap",
- },
- },
})
const Sidebar: React.FC = () => {
- const {appTheme, toggleAppTheme} = useAppTheme()
+ const {appTheme} = useAppTheme()
const {
token: {colorBgContainer},
} = theme.useToken()
@@ -138,7 +103,7 @@ const Sidebar: React.FC = () => {
initialSelectedKeys = ["apps"]
}
const [selectedKeys, setSelectedKeys] = useState(initialSelectedKeys)
- const {user, orgs, selectedOrg, changeSelectedOrg, reset} = useProfileData()
+ const {user} = useProfileData()
const [collapsed, setCollapsed] = useLocalStorage("sidebarCollapsed", false)
useEffect(() => {
@@ -161,6 +126,8 @@ const Sidebar: React.FC = () => {
})
}
+ const OrgsListSubMenu = dynamicComponent("OrgsListSubMenu/OrgsListSubMenu")
+
return (
{
Book Onboarding Call
- {selectedOrg && (
- }
- >
- {orgs.map((org, index) => (
-
- {getInitials(org.name)}
-
- }
- onClick={() => changeSelectedOrg(org.id)}
- >
- {org.name}
-
- ))}
-
- )}
+
{user?.username && (
{},
}
+const useApps = () => {
+ const [useOrgData, setUseOrgData] = useState(() => () => "")
+
+ useEffect(() => {
+ dynamicContext("org.context", {useOrgData}).then((context) => {
+ setUseOrgData(() => context.useOrgData)
+ })
+ }, [])
+
+ const {selectedOrg, loading} = useOrgData()
+ const {data, error, isLoading, mutate} = useSWR(
+ `${getAgentaApiUrl()}/api/apps/` +
+ (isDemo()
+ ? `?org_id=${selectedOrg?.id}&workspace_id=${selectedOrg?.default_workspace.id}`
+ : ""),
+ isDemo() ? (selectedOrg?.id ? axiosFetcher : () => {}) : axiosFetcher,
+ {
+ shouldRetryOnError: false,
+ },
+ )
+ return {
+ data: (data || []) as ListAppsItem[],
+ error,
+ isLoading: isLoading || loading,
+ mutate,
+ }
+}
+
export const AppContext = createContext(initialValues)
export const useAppsData = () => useContext(AppContext)
diff --git a/agenta-web/src/contexts/profile.context.tsx b/agenta-web/src/contexts/profile.context.tsx
index 0e6ca730b7..d39c3f9b8d 100644
--- a/agenta-web/src/contexts/profile.context.tsx
+++ b/agenta-web/src/contexts/profile.context.tsx
@@ -2,46 +2,20 @@ import {usePostHogAg} from "@/hooks/usePostHogAg"
import {useSession} from "@/hooks/useSession"
import useStateCallback from "@/hooks/useStateCallback"
import {isDemo} from "@/lib/helpers/utils"
-import {getOrgsList, getProfile} from "@/lib/services/api"
-import {Org, User} from "@/lib/Types"
-import {useRouter} from "next/router"
-import {
- PropsWithChildren,
- createContext,
- useState,
- useContext,
- useEffect,
- useCallback,
- useMemo,
-} from "react"
-import {useUpdateEffect} from "usehooks-ts"
-
-const LS_ORG_KEY = "selectedOrg"
-
-export enum Role {
- OWNER = "owner",
- ADMIN = "admin",
- MEMBER = "member",
-}
+import {getProfile} from "@/lib/services/api"
+import {User} from "@/lib/Types"
+import {PropsWithChildren, createContext, useState, useContext, useEffect, useCallback} from "react"
type ProfileContextType = {
user: User | null
- orgs: Org[]
- selectedOrg: Org | null
- role: Role | null
loading: boolean
- changeSelectedOrg: (orgId: string, onSuccess?: () => void) => void
reset: () => void
refetch: (onSuccess?: () => void) => void
}
const initialValues: ProfileContextType = {
user: null,
- orgs: [],
- selectedOrg: null,
- role: null,
loading: false,
- changeSelectedOrg: () => {},
reset: () => {},
refetch: () => {},
}
@@ -56,27 +30,16 @@ export const getProfileValues = () => profileContextValues
const ProfileContextProvider: React.FC = ({children}) => {
const posthog = usePostHogAg()
- const router = useRouter()
- const [user, setUser] = useState(null)
- const [orgs, setOrgs] = useState([])
- const [selectedOrg, setSelectedOrg] = useStateCallback(null)
+ const [user, setUser] = useStateCallback(null)
const [loading, setLoading] = useState(false)
const {logout, doesSessionExist} = useSession()
const fetcher = useCallback((onSuccess?: () => void) => {
setLoading(true)
- Promise.all([getProfile(), getOrgsList()])
- .then(([profile, orgs]) => {
+ getProfile()
+ .then((profile) => {
posthog.identify()
- setUser(profile.data)
- setOrgs(orgs.data)
- setSelectedOrg(
- orgs.data.find((org: Org) => org.id === localStorage.getItem(LS_ORG_KEY)) ||
- orgs.data.find((org: Org) => org.owner === profile.data.id) ||
- orgs.data[0] ||
- null,
- onSuccess,
- )
+ setUser(profile.data, onSuccess)
})
.catch((error) => {
console.error(error)
@@ -85,52 +48,24 @@ const ProfileContextProvider: React.FC = ({children}) => {
.finally(() => setLoading(false))
}, [])
- useUpdateEffect(() => {
- localStorage.setItem(LS_ORG_KEY, selectedOrg?.id || "")
- }, [selectedOrg?.id])
-
useEffect(() => {
- // fetch profile and orgs list only if user is logged in
+ // fetch profile only if user is logged in
if (doesSessionExist) {
fetcher()
}
}, [doesSessionExist])
- const changeSelectedOrg: ProfileContextType["changeSelectedOrg"] = (orgId, onSuccess) => {
- setSelectedOrg(
- orgs.find((org) => org.id === orgId) || selectedOrg,
- onSuccess ||
- (() => {
- router.push("/apps")
- }),
- )
- }
-
const reset = () => {
setUser(initialValues.user)
- setOrgs(initialValues.orgs)
- setSelectedOrg(initialValues.selectedOrg)
}
- const role = useMemo(
- () => (loading ? null : selectedOrg?.owner === user?.id ? Role.OWNER : Role.MEMBER),
- [selectedOrg, user, loading],
- )
-
profileContextValues.user = user
- profileContextValues.orgs = orgs
- profileContextValues.selectedOrg = selectedOrg
- profileContextValues.changeSelectedOrg = changeSelectedOrg
return (
organization_id?: string
+ workspace_id?: string
}
export type GenericObject = Record
@@ -350,14 +351,6 @@ export interface User {
email: string
}
-export interface Org {
- id: string
- name: string
- description?: string
- owner: string
- is_paying: boolean
-}
-
export enum ChatRole {
System = "system",
User = "user",
diff --git a/agenta-web/src/lib/helpers/axiosConfig.ts b/agenta-web/src/lib/helpers/axiosConfig.ts
index cbdb76ef95..13a5844d2b 100644
--- a/agenta-web/src/lib/helpers/axiosConfig.ts
+++ b/agenta-web/src/lib/helpers/axiosConfig.ts
@@ -4,6 +4,10 @@ import {signOut} from "supertokens-auth-react/recipe/thirdpartypasswordless"
import router from "next/router"
import {getAgentaApiUrl} from "./utils"
import {isObject} from "lodash"
+import AlertPopup from "@/components/AlertPopup/AlertPopup"
+
+export const PERMISSION_ERR_MSG =
+ "You don't have permission to perform this action. Please contact your organization admin."
const axios = axiosApi.create({
baseURL: getAgentaApiUrl(),
@@ -27,6 +31,17 @@ axios.interceptors.response.use(
return response
},
(error) => {
+ if (error.response?.status === 403 && error.config.method !== "get") {
+ AlertPopup({
+ title: "Permission Denied",
+ message: PERMISSION_ERR_MSG,
+ cancelText: null,
+ okText: "Ok",
+ })
+ error.message = PERMISSION_ERR_MSG
+ throw error
+ }
+
// if axios config has _ignoreError set to true, then don't handle error
if (error.config?._ignoreError) throw error
diff --git a/agenta-web/src/lib/helpers/dynamic.ts b/agenta-web/src/lib/helpers/dynamic.ts
new file mode 100644
index 0000000000..19f5a003ce
--- /dev/null
+++ b/agenta-web/src/lib/helpers/dynamic.ts
@@ -0,0 +1,32 @@
+import dynamic from "next/dynamic"
+
+export function dynamicComponent(path: string, fallback: any = () => null) {
+ return dynamic(() => import(`@/components/${path}`), {
+ loading: fallback,
+ ssr: false,
+ })
+}
+
+export async function dynamicContext(path: string, fallback?: any) {
+ try {
+ return await import(`@/contexts/${path}`)
+ } catch (error) {
+ return fallback
+ }
+}
+
+export async function dynamicHook(path: string, fallback: any = () => null) {
+ try {
+ return await import(`@/hooks/${path}`)
+ } catch (error) {
+ return fallback
+ }
+}
+
+export async function dynamicService(path: string, fallback?: any) {
+ try {
+ return await import(`@/services/${path}`)
+ } catch (error) {
+ return fallback
+ }
+}
diff --git a/agenta-web/src/lib/helpers/utils.ts b/agenta-web/src/lib/helpers/utils.ts
index a907704b8d..384394586a 100644
--- a/agenta-web/src/lib/helpers/utils.ts
+++ b/agenta-web/src/lib/helpers/utils.ts
@@ -1,5 +1,4 @@
import {v4 as uuidv4} from "uuid"
-import dynamic from "next/dynamic"
import {EvaluationType} from "../enums"
import {GenericObject} from "../Types"
import promiseRetry from "promise-retry"
@@ -214,21 +213,6 @@ export const stringToNumberInRange = (text: string, min: number, max: number) =>
return result
}
-export const getInitials = (str: string, limit = 2) => {
- let initialText = "E"
-
- try {
- initialText = str
- ?.split(" ")
- .slice(0, limit)
- ?.reduce((acc, curr) => acc + (curr[0] || "")?.toUpperCase(), "")
- } catch (error) {
- console.log("Error using getInitials", error)
- }
-
- return initialText
-}
-
export const isDemo = () => {
if (process.env.NEXT_PUBLIC_FF) {
return ["cloud", "ee"].includes(process.env.NEXT_PUBLIC_FF)
@@ -236,13 +220,6 @@ export const isDemo = () => {
return false
}
-export function dynamicComponent(path: string, fallback: any = () => null) {
- return dynamic(() => import(`@/components/${path}`), {
- loading: fallback,
- ssr: false,
- })
-}
-
export const removeKeys = (obj: GenericObject, keys: string[]) => {
let newObj = Object.assign({}, obj)
for (let key of keys) {
@@ -435,3 +412,10 @@ export const redirectIfNoLLMKeys = () => {
}
return false
}
+
+export const snakeToTitle = (str: string) => {
+ return str
+ .split("_")
+ .map((word) => word.charAt(0).toUpperCase() + word.slice(1))
+ .join(" ")
+}
diff --git a/agenta-web/src/lib/hooks/useVariant.ts b/agenta-web/src/lib/hooks/useVariant.ts
index abc659c4a4..1262c656ba 100644
--- a/agenta-web/src/lib/hooks/useVariant.ts
+++ b/agenta-web/src/lib/hooks/useVariant.ts
@@ -1,8 +1,9 @@
-import {useState, useEffect, useContext} from "react"
+import {useState, useEffect} from "react"
import {promptVersioning, saveNewVariant, updateVariantParams} from "@/lib/services/api"
import {Variant, Parameter, IPromptVersioning} from "@/lib/Types"
import {getAllVariantParameters, updateInputParams} from "@/lib/helpers/variantHelper"
import {isDemo} from "../helpers/utils"
+import {PERMISSION_ERR_MSG} from "../helpers/axiosConfig"
/**
* Hook for using the variant.
@@ -42,10 +43,12 @@ export function useVariant(appId: string, variant: Variant) {
setIsChatVariant(isChatVariant)
setHistoryStatus({loading: false, error: true})
} catch (error: any) {
- console.log(error)
- setIsError(true)
- setError(error)
- setHistoryStatus({loading: false, error: true})
+ if (error.message !== PERMISSION_ERR_MSG) {
+ console.log(error)
+ setIsError(true)
+ setError(error)
+ setHistoryStatus({loading: false, error: true})
+ }
} finally {
setIsLoading(false)
setHistoryStatus({loading: false, error: false})
@@ -92,8 +95,10 @@ export function useVariant(appId: string, variant: Variant) {
}, {})
}
setPromptOptParams(updatedOptParams)
- } catch (error) {
- setIsError(true)
+ } catch (error: any) {
+ if (error.message !== PERMISSION_ERR_MSG) {
+ setIsError(true)
+ }
} finally {
setIsParamSaveLoading(false)
}
diff --git a/agenta-web/src/lib/services/api.ts b/agenta-web/src/lib/services/api.ts
index 559cd3b4df..8c935903dc 100644
--- a/agenta-web/src/lib/services/api.ts
+++ b/agenta-web/src/lib/services/api.ts
@@ -16,7 +16,6 @@ import {
DeploymentRevisionConfig,
CreateCustomEvaluation,
ExecuteCustomEvalCode,
- ListAppsItem,
AICritiqueCreate,
ChatMessage,
KeyValuePair,
@@ -26,13 +25,13 @@ import {
fromEvaluationScenarioResponseToEvaluationScenario,
} from "../transformers"
import {EvaluationFlow, EvaluationType} from "../enums"
-import {delay, getAgentaApiUrl, removeKeys, shortPoll} from "../helpers/utils"
-import {useProfileData} from "@/contexts/profile.context"
+import {getAgentaApiUrl, removeKeys, shortPoll} from "../helpers/utils"
+import {dynamicContext} from "../helpers/dynamic"
/**
* Raw interface for the parameters parsed from the openapi.json
*/
-const fetcher = (url: string) => axios.get(url).then((res) => res.data)
+export const axiosFetcher = (url: string) => axios.get(url).then((res) => res.data)
export async function fetchVariants(
appId: string,
@@ -251,8 +250,8 @@ export async function removeVariant(variantId: string) {
export const useLoadTestsetsList = (appId: string) => {
const {data, error, mutate, isLoading} = useSWR(
`${getAgentaApiUrl()}/api/testsets/?app_id=${appId}`,
- fetcher,
- {revalidateOnFocus: false},
+ axiosFetcher,
+ {revalidateOnFocus: false, shouldRetryOnError: false},
)
return {
@@ -536,33 +535,12 @@ export const updateEvaluationScenarioScore = async (
return response
}
-export const useApps = () => {
- const {selectedOrg} = useProfileData()
- const {data, error, isLoading, mutate} = useSWR(
- `${getAgentaApiUrl()}/api/apps/?org_id=${selectedOrg?.id}`,
- selectedOrg?.id ? fetcher : () => {}, //doon't fetch if org is not selected
- )
-
- return {
- data: (data || []) as ListAppsItem[],
- error,
- isLoading: selectedOrg?.id ? isLoading : true,
- mutate,
- }
-}
-
export const getProfile = async (ignoreAxiosError: boolean = false) => {
return axios.get(`${getAgentaApiUrl()}/api/profile/`, {
_ignoreError: ignoreAxiosError,
} as any)
}
-export const getOrgsList = async (ignoreAxiosError: boolean = false) => {
- return axios.get(`${getAgentaApiUrl()}/api/organizations/`, {
- _ignoreError: ignoreAxiosError,
- } as any)
-}
-
export const getTemplates = async () => {
const response = await axios.get(`${getAgentaApiUrl()}/api/containers/templates/`)
return response.data
@@ -616,14 +594,12 @@ export const createAndStartTemplate = async ({
appName,
providerKey,
templateId,
- orgId,
timeout,
onStatusChange,
}: {
appName: string
providerKey: Array<{title: string; key: string; name: string}>
templateId: string
- orgId: string
timeout?: number
onStatusChange?: (
status: "creating_app" | "starting_app" | "success" | "bad_request" | "timeout" | "error",
@@ -640,6 +616,12 @@ export const createAndStartTemplate = async ({
)
try {
+ const {getOrgValues} = await dynamicContext("org.context", {
+ getOrgValues: () => ({
+ selectedOrg: {id: undefined, default_workspace: {id: undefined}},
+ }),
+ })
+ const {selectedOrg} = getOrgValues()
onStatusChange?.("creating_app")
let app
try {
@@ -647,8 +629,9 @@ export const createAndStartTemplate = async ({
{
app_name: appName,
template_id: templateId,
+ organization_id: selectedOrg.id,
+ workspace_id: selectedOrg.default_workspace.id,
env_vars: apiKeys,
- organization_id: orgId,
},
true,
)
diff --git a/agenta-web/src/pages/apps/[app_id]/endpoints/index.tsx b/agenta-web/src/pages/apps/[app_id]/endpoints/index.tsx
index 263325b5a7..f99a2963a4 100644
--- a/agenta-web/src/pages/apps/[app_id]/endpoints/index.tsx
+++ b/agenta-web/src/pages/apps/[app_id]/endpoints/index.tsx
@@ -5,7 +5,8 @@ import DynamicCodeBlock from "@/components/DynamicCodeBlock/DynamicCodeBlock"
import ResultComponent from "@/components/ResultComponent/ResultComponent"
import {useQueryParam} from "@/hooks/useQuery"
import {Environment, GenericObject, Parameter, Variant} from "@/lib/Types"
-import {dynamicComponent, isDemo} from "@/lib/helpers/utils"
+import {isDemo} from "@/lib/helpers/utils"
+import {dynamicComponent} from "@/lib/helpers/dynamic"
import {useVariant} from "@/lib/hooks/useVariant"
import {fetchEnvironments, fetchVariants, getAppContainerURL} from "@/lib/services/api"
import {ApiOutlined, AppstoreOutlined, DownOutlined, HistoryOutlined} from "@ant-design/icons"
diff --git a/agenta-web/src/pages/settings/index.tsx b/agenta-web/src/pages/settings/index.tsx
index d2c7e43151..5fba15eb9b 100644
--- a/agenta-web/src/pages/settings/index.tsx
+++ b/agenta-web/src/pages/settings/index.tsx
@@ -1,10 +1,10 @@
import Secrets from "@/components/pages/settings/Secrets/Secrets"
import ProtectedRoute from "@/components/ProtectedRoute/ProtectedRoute"
import {useQueryParam} from "@/hooks/useQuery"
-import {isFeatureEnabled} from "@/lib/helpers/featureFlag"
-import {dynamicComponent, isDemo} from "@/lib/helpers/utils"
+import {dynamicComponent} from "@/lib/helpers/dynamic"
+import {isDemo} from "@/lib/helpers/utils"
import {ApartmentOutlined, KeyOutlined, LockOutlined} from "@ant-design/icons"
-import {Tabs, Typography} from "antd"
+import {Space, Tabs, Typography} from "antd"
import {createUseStyles} from "react-jss"
const useStyles = createUseStyles({
@@ -36,10 +36,10 @@ const Settings: React.FC = () => {
const items = [
{
label: (
-
+
Workspace
-
+
),
key: "workspace",
children: ,
@@ -47,20 +47,20 @@ const Settings: React.FC = () => {
},
{
label: (
-
+
LLM Keys
-
+
),
key: "secrets",
children: ,
},
{
label: (
-
+
API Keys
-
+
),
key: "apiKeys",
children: ,
diff --git a/docker-compose.test.yml b/docker-compose.test.yml
index fe2525335a..8be8cfcd99 100644
--- a/docker-compose.test.yml
+++ b/docker-compose.test.yml
@@ -22,7 +22,7 @@ services:
- DOMAIN_NAME=http://localhost
- CELERY_BROKER_URL=amqp://guest@rabbitmq//
- CELERY_RESULT_BACKEND=redis://redis:6379/0
- - DATABASE_MODE=v2
+ - DATABASE_MODE=test
- FEATURE_FLAG=oss
- OPENAI_API_KEY=${OPENAI_API_KEY}
- AGENTA_TEMPLATE_REPO=agentaai/templates_v2
@@ -61,6 +61,21 @@ services:
networks:
- agenta-network
+ mongo_express:
+ image: mongo-express:0.54.0
+ environment:
+ ME_CONFIG_MONGODB_ADMINUSERNAME: username
+ ME_CONFIG_MONGODB_ADMINPASSWORD: password
+ ME_CONFIG_MONGODB_SERVER: mongo
+ ports:
+ - "8081:8081"
+ networks:
+ - agenta-network
+ depends_on:
+ mongo:
+ condition: service_healthy
+ restart: always
+
mongo:
image: mongo:5.0
container_name: agenta-mongo-test