Skip to content

Commit

Permalink
Merge pull request #2283 from Agenta-AI/fix/missing-headers-in-get-pa…
Browse files Browse the repository at this point in the history
…rams-from-openapi

[Enhancement]: CORS + App Security hotfix
  • Loading branch information
jp-agenta authored Nov 22, 2024
2 parents 39ca4c8 + 485ab69 commit 1221f27
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 34 deletions.
4 changes: 2 additions & 2 deletions agenta-backend/agenta_backend/routers/permissions_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def __init__(
class Deny(HTTPException):
def __init__(self) -> None:
super().__init__(
status_code=401,
detail="Unauthorized",
status_code=403,
detail="Forbidden",
)


Expand Down
28 changes: 22 additions & 6 deletions agenta-backend/agenta_backend/services/llm_apps_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import traceback
import aiohttp
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional


from agenta_backend.models.shared_models import InvokationResult, Result, Error
Expand Down Expand Up @@ -296,7 +296,17 @@ async def batch_invoke(
list_of_app_outputs: List[
InvokationResult
] = [] # Outputs after running all batches
openapi_parameters = await get_parameters_from_openapi(uri + "/openapi.json")

headers = None
if isCloudEE():
secret_token = await sign_secret_token(user_id, project_id, None)

headers = {"Authorization": f"Secret {secret_token}"}

openapi_parameters = await get_parameters_from_openapi(
uri + "/openapi.json",
headers,
)

async def run_batch(start_idx: int):
tasks = []
Expand Down Expand Up @@ -336,7 +346,10 @@ async def run_batch(start_idx: int):
return list_of_app_outputs


async def get_parameters_from_openapi(uri: str) -> List[Dict]:
async def get_parameters_from_openapi(
uri: str,
headers: Optional[Dict[str, str]],
) -> List[Dict]:
"""
Parse the OpenAI schema of an LLM app to return list of parameters that it takes with their type as determined by the x-parameter
Args:
Expand All @@ -351,7 +364,7 @@ async def get_parameters_from_openapi(uri: str) -> List[Dict]:
"""

schema = await _get_openai_json_from_uri(uri)
schema = await _get_openai_json_from_uri(uri, headers)

try:
body_schema_name = (
Expand Down Expand Up @@ -381,9 +394,12 @@ async def get_parameters_from_openapi(uri: str) -> List[Dict]:
return parameters


async def _get_openai_json_from_uri(uri):
async def _get_openai_json_from_uri(
uri: str,
headers: Optional[Dict[str, str]],
):
async with aiohttp.ClientSession() as client:
resp = await client.get(uri, timeout=5)
resp = await client.get(uri, headers=headers, timeout=5)
resp_text = await resp.text()
json_data = json.loads(resp_text)
return json_data
18 changes: 17 additions & 1 deletion agenta-backend/agenta_backend/tasks/evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
from celery import shared_task, states

from agenta_backend.utils.common import isCloudEE

if isCloudEE():
from agenta_backend.cloud.services.auth_helper import sign_secret_token

from agenta_backend.services import (
evaluators_service,
llm_apps_service,
Expand Down Expand Up @@ -143,8 +147,20 @@ def evaluate(
)

# 4. Evaluate the app outputs
secret_token = None
headers = None
if isCloudEE():
secret_token = loop.run_until_complete(
sign_secret_token(user_id, project_id, None)
)
if secret_token:
headers = {"Authorization": f"Secret {secret_token}"}

openapi_parameters = loop.run_until_complete(
llm_apps_service.get_parameters_from_openapi(uri + "/openapi.json")
llm_apps_service.get_parameters_from_openapi(
uri + "/openapi.json",
headers,
),
)

for data_point, app_output in zip(testset_db.csvdata, app_outputs): # type: ignore
Expand Down
31 changes: 17 additions & 14 deletions agenta-cli/agenta/sdk/decorators/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,21 @@

import agenta as ag

app = FastAPI()

origins = [
"*",
]

app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
AGENTA_USE_CORS = str(environ.get("AGENTA_USE_CORS", "true")).lower() in (
"true",
"1",
"t",
)

_MIDDLEWARES = True
app = FastAPI()
log.setLevel("DEBUG")


app.include_router(router, prefix="")
_MIDDLEWARES = True


log.setLevel("DEBUG")
app.include_router(router, prefix="")


class PathValidator(BaseModel):
Expand Down Expand Up @@ -137,6 +131,15 @@ def __init__(
resource_type="application",
)

if AGENTA_USE_CORS:
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
allow_credentials=True,
)

_MIDDLEWARES = False

except: # pylint: disable=bare-except
Expand Down
22 changes: 14 additions & 8 deletions agenta-cli/agenta/sdk/middleware/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
15 * 60, # 15 minutes
)

AGENTA_SDK_AUTH_CACHE = str(environ.get("AGENTA_SDK_AUTH_CACHE", True)).lower() in (
"true",
"1",
"t",
)

AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED = str(
environ.get("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", False)
).lower() in ("true", "1", "t")
Expand Down Expand Up @@ -89,9 +95,11 @@ async def dispatch(
sort_keys=True,
)

cached_policy = cache.get(_hash)
policy = None
if AGENTA_SDK_AUTH_CACHE:
policy = cache.get(_hash)

if not cached_policy:
if not policy:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.host}/api/permissions/verify",
Expand All @@ -110,19 +118,17 @@ async def dispatch(
cache.put(_hash, {"effect": "deny"})
return Deny()

cached_policy = {
policy = {
"effect": "allow",
"credentials": auth.get("credentials"),
}

cache.put(_hash, cached_policy)
cache.put(_hash, policy)

if cached_policy.get("effect") == "deny":
if not policy or policy.get("effect") == "deny":
return Deny()

request.state.credentials = cached_policy.get("credentials")

print(f"credentials: {request.state.credentials}")
request.state.credentials = policy.get("credentials")

return await call_next(request)

Expand Down
4 changes: 2 additions & 2 deletions agenta-cli/agenta/sdk/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class LLMTokenUsage(BaseModel):

class BaseResponse(BaseModel):
version: Optional[str] = "2.0"
data: Optional[Union[str, Dict[str, Any]]]
trace: Optional[Dict[str, Any]]
data: Optional[Union[str, Dict[str, Any]]] = None
trace: Optional[Dict[str, Any]] = None


class DictInput(dict):
Expand Down
28 changes: 27 additions & 1 deletion agenta-web/src/services/api.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import axios from "@/lib//helpers/axiosConfig"
import Session from "supertokens-auth-react/recipe/session"
import {formatDay} from "@/lib/helpers/dateTimeHelper"
import {
detectChatVariantFromOpenAISchema,
Expand Down Expand Up @@ -113,17 +114,36 @@ export async function callVariant(
}

const appContainerURI = await fetchAppContainerURL(appId, undefined, baseId)
const jwt = await getJWT()

return axios
.post(`${appContainerURI}/generate`, requestBody, {
signal,
_ignoreError: ignoreAxiosError,
headers: {
Authorization: jwt && `Bearer ${jwt}`,
},
} as any)
.then((res) => {
return res.data
})
}

/**
* Get the JWT from SuperTokens
*/
const getJWT = async () => {
try {
if (await Session.doesSessionExist()) {
let jwt = await Session.getAccessToken()

return jwt
}
} catch (error) {}

return undefined
}

/**
* Parses the openapi.json from a variant and returns the parameters as an array of objects.
* @param app
Expand All @@ -138,7 +158,13 @@ export const fetchVariantParametersFromOpenAPI = async (
) => {
const appContainerURI = await fetchAppContainerURL(appId, variantId, baseId)
const url = `${appContainerURI}/openapi.json`
const response = await axios.get(url, {_ignoreError: ignoreAxiosError} as any)
const jwt = await getJWT()
const response = await axios.get(url, {
_ignoreError: ignoreAxiosError,
headers: {
Authorization: jwt && `Bearer ${jwt}`,
},
} as any)
const isChatVariant = detectChatVariantFromOpenAISchema(response.data)
let APIParams = openAISchemaToParameters(response.data)

Expand Down

0 comments on commit 1221f27

Please sign in to comment.