Skip to content

Commit

Permalink
Merge branch 'main' into fix/978-inputs-not-saved-to-eval-scenario
Browse files Browse the repository at this point in the history
  • Loading branch information
bekossy committed Dec 5, 2023
2 parents 523cbe4 + 9a89cd0 commit e7c016f
Show file tree
Hide file tree
Showing 41 changed files with 1,090 additions and 837 deletions.
4 changes: 3 additions & 1 deletion agenta-backend/agenta_backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os

if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
if os.environ["FEATURE_FLAG"] in ["cloud"]:
import agenta_backend.cloud.__init__
if os.environ["FEATURE_FLAG"] in ["ee"]:
import agenta_backend.ee.__init__
2 changes: 1 addition & 1 deletion agenta-backend/agenta_backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)

if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
from agenta_backend.cloud.services import templates_manager
from agenta_backend.commons.services import templates_manager
else:
from agenta_backend.services import templates_manager

Expand Down
4 changes: 2 additions & 2 deletions agenta-backend/agenta_backend/models/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def app_variant_db_to_pydantic(
app_id=str(app_variant_db.app.id),
app_name=app_variant_db.app.app_name,
variant_name=app_variant_db.variant_name,
parameters=app_variant_db.parameters,
parameters=app_variant_db.config.parameters,
previous_variant_name=app_variant_db.previous_variant_name,
organization_id=str(app_variant_db.organization.id),
base_name=app_variant_db.base_name,
Expand All @@ -132,7 +132,7 @@ async def app_variant_db_to_output(app_variant_db: AppVariantDB) -> AppVariantOu
variant_id=str(app_variant_db.id),
user_id=str(app_variant_db.user.id),
organization_id=str(app_variant_db.organization.id),
parameters=app_variant_db.parameters,
parameters=app_variant_db.config.parameters,
previous_variant_name=app_variant_db.previous_variant_name,
base_name=app_variant_db.base_name,
base_id=str(app_variant_db.base.id),
Expand Down
8 changes: 4 additions & 4 deletions agenta-backend/agenta_backend/routers/app_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from agenta_backend.models import converters

if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
from agenta_backend.cloud.services.selectors import (
from agenta_backend.commons.services.selectors import (
get_user_and_org_id,
) # noqa pylint: disable-all
else:
Expand Down Expand Up @@ -352,7 +352,7 @@ async def create_app_and_variant_from_template(

logger.debug("Step 5: Retrieve template from db")
template_db = await db_manager.get_template(payload.template_id)
repo_name = os.environ.get("AGENTA_TEMPLATE_REPO", "agentaai/lambda_templates")
repo_name = os.environ.get("AGENTA_TEMPLATE_REPO", "agentaai/templates_v2")
image_name = f"{repo_name}:{template_db.name}"

logger.debug(
Expand All @@ -362,10 +362,10 @@ async def create_app_and_variant_from_template(
app=app,
variant_name="app.default",
docker_id_or_template_uri=template_db.template_uri
if os.environ["FEATURE_FLAG"] in ["cloud"]
if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]
else template_db.digest,
tags=f"{image_name}"
if os.environ["FEATURE_FLAG"] not in ["cloud"]
if os.environ["FEATURE_FLAG"] not in ["cloud", "ee"]
else None,
base_name="app",
config_name="default",
Expand Down
2 changes: 1 addition & 1 deletion agenta-backend/agenta_backend/routers/bases_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from agenta_backend.models import converters

if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
from agenta_backend.cloud.services.selectors import (
from agenta_backend.commons.services.selectors import (
get_user_and_org_id,
) # noqa pylint: disable-all
else:
Expand Down
2 changes: 1 addition & 1 deletion agenta-backend/agenta_backend/routers/configs_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)

if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
from agenta_backend.cloud.services.selectors import (
from agenta_backend.commons.services.selectors import (
get_user_and_org_id,
) # noqa pylint: disable-all
else:
Expand Down
2 changes: 1 addition & 1 deletion agenta-backend/agenta_backend/routers/container_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from fastapi.responses import JSONResponse

if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
from agenta_backend.cloud.services.selectors import (
from agenta_backend.commons.services.selectors import (
get_user_and_org_id,
) # noqa pylint: disable-all
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)

if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
from agenta_backend.cloud.services.selectors import (
from agenta_backend.commons.services.selectors import (
get_user_and_org_id,
) # noqa pylint: disable-all
else:
Expand Down
2 changes: 1 addition & 1 deletion agenta-backend/agenta_backend/routers/evaluation_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from agenta_backend.services import results_service

if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
from agenta_backend.cloud.services.selectors import ( # noqa pylint: disable-all
from agenta_backend.commons.services.selectors import ( # noqa pylint: disable-all
get_user_and_org_id,
)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)

if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
from agenta_backend.cloud.services.selectors import (
from agenta_backend.commons.services.selectors import (
get_user_and_org_id,
) # noqa pylint: disable-all
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from agenta_backend.services import db_manager

if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
from agenta_backend.cloud.services.selectors import (
from agenta_backend.commons.services.selectors import (
get_user_and_org_id,
) # noqa pylint: disable-all
else:
Expand Down
8 changes: 6 additions & 2 deletions agenta-backend/agenta_backend/routers/testset_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from fastapi import HTTPException, APIRouter, UploadFile, File, Form, Request
from fastapi.responses import JSONResponse
from pydantic import ValidationError

from agenta_backend.models.api.testset_model import (
TestSetSimpleResponse,
Expand All @@ -28,7 +29,7 @@


if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
from agenta_backend.cloud.services.selectors import (
from agenta_backend.commons.services.selectors import (
get_user_and_org_id,
) # noqa pylint: disable-all
else:
Expand Down Expand Up @@ -99,7 +100,10 @@ async def upload_file(
document["csvdata"].append(row)

user = await get_user(user_uid=user_org_data["uid"])
testset_instance = TestSetDB(**document, user=user)
try:
testset_instance = TestSetDB(**document, user=user)
except ValidationError as e:
raise HTTPException(status_code=403, detail=e.errors())
result = await engine.save(testset_instance)

if isinstance(result.id, ObjectId):
Expand Down
2 changes: 1 addition & 1 deletion agenta-backend/agenta_backend/routers/user_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
router = APIRouter()

if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
from agenta_backend.cloud.services.selectors import (
from agenta_backend.commons.services.selectors import (
get_user_and_org_id,
) # noqa pylint: disable-all
else:
Expand Down
2 changes: 1 addition & 1 deletion agenta-backend/agenta_backend/routers/variants_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)

if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
from agenta_backend.cloud.services.selectors import (
from agenta_backend.commons.services.selectors import (
get_user_and_org_id,
) # noqa pylint: disable-all
else:
Expand Down
6 changes: 3 additions & 3 deletions agenta-backend/agenta_backend/services/app_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from agenta_backend.services import deployment_manager

if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
from agenta_backend.cloud.services import (
from agenta_backend.commons.services import (
api_key_service,
) # noqa pylint: disable-all

Expand Down Expand Up @@ -208,7 +208,7 @@ async def terminate_and_remove_app_variant(
# If image deletable is True, remove docker image and image db
if image.deletable:
try:
if os.environ["FEATURE_FLAG"] == "cloud":
if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]:
await deployment_manager.remove_repository(image.tags)
else:
await deployment_manager.remove_image(image)
Expand Down Expand Up @@ -379,7 +379,7 @@ async def add_variant_based_on_image(
):
raise ValueError("App variant or image is None")

if os.environ["FEATURE_FLAG"] not in ["cloud"]:
if os.environ["FEATURE_FLAG"] not in ["cloud", "ee"]:
if tags in [None, ""]:
raise ValueError("OSS: Tags is None")

Expand Down
4 changes: 2 additions & 2 deletions agenta-backend/agenta_backend/services/evaluation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ async def prepare_csvdata_and_create_evaluation_scenario(
Args:
csvdata: A list of dictionaries representing the CSV data.
inputs: A list of strings representing the names of the inputs in the variant.
payload_inputs: A list of strings representing the names of the inputs in the variant.
evaluation_type: The type of evaluation
new_evaluation: The instance of EvaluationDB
user: The owner of the evaluation scenario
Expand All @@ -208,7 +208,7 @@ async def prepare_csvdata_and_create_evaluation_scenario(
await engine.delete(new_evaluation)
msg = f"""
Columns in the test set should match the names of the inputs in the variant.
Inputs names in variant are: {inputs} while
Inputs names in variant are: {[variant_input for variant_input in payload_inputs]} while
columns in test set are: {[col for col in datum.keys() if col != 'correct_answer']}
"""
raise HTTPException(
Expand Down
Loading

0 comments on commit e7c016f

Please sign in to comment.