Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

migrate from flask to fast api for uvicorn and asgi support #75

Merged
merged 13 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
312 changes: 312 additions & 0 deletions guardrails_api/api/guards.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,312 @@
import json
import os
import inspect
from typing import Any, Dict, Optional
from fastapi import HTTPException, Request, APIRouter
from fastapi.responses import JSONResponse, StreamingResponse
from urllib.parse import unquote_plus
from guardrails import AsyncGuard, Guard
from guardrails.classes import ValidationOutcome
from opentelemetry.trace import Span
from guardrails_api_client import Guard as GuardStruct
from guardrails_api.clients.cache_client import CacheClient
from guardrails_api.clients.memory_guard_client import MemoryGuardClient
from guardrails_api.clients.pg_guard_client import PGGuardClient
from guardrails_api.clients.postgres_client import postgres_is_enabled
from guardrails_api.utils.get_llm_callable import get_llm_callable
from guardrails_api.utils.openai import (
outcome_to_chat_completion,
outcome_to_stream_response,
)
from guardrails_api.utils.handle_error import handle_error
from string import Template

# if no pg_host is set, use in memory guards
if postgres_is_enabled():
guard_client = PGGuardClient()
else:
guard_client = MemoryGuardClient()
# Will be defined at runtime
import config # noqa

exports = config.__dir__()
for export_name in exports:
export = getattr(config, export_name)
is_guard = isinstance(export, Guard)
if is_guard:
guard_client.create_guard(export)

cache_client = CacheClient()
router = APIRouter()


@router.get("/guards")
@handle_error
async def get_guards():
guards = guard_client.get_guards()
return [g.to_dict() for g in guards]


@router.post("/guards")
@handle_error
async def create_guard(guard: GuardStruct):
if not postgres_is_enabled():
raise HTTPException(
status_code=501,
detail="Not Implemented POST /guards is not implemented for in-memory guards.",
)
new_guard = guard_client.create_guard(guard)
return new_guard.to_dict()


@router.get("/guards/{guard_name}")
@handle_error
async def get_guard(guard_name: str, asOf: Optional[str] = None):
decoded_guard_name = unquote_plus(guard_name)
guard = guard_client.get_guard(decoded_guard_name, asOf)
if guard is None:
raise HTTPException(
status_code=404,
detail=f"A Guard with the name {decoded_guard_name} does not exist!",
)
return guard.to_dict()


@router.put("/guards/{guard_name}")
@handle_error
async def update_guard(guard_name: str, guard: GuardStruct):
if not postgres_is_enabled():
raise HTTPException(
status_code=501,
detail="PUT /<guard_name> is not implemented for in-memory guards.",
)
decoded_guard_name = unquote_plus(guard_name)
updated_guard = guard_client.upsert_guard(decoded_guard_name, guard)
return updated_guard.to_dict()


@router.delete("/guards/{guard_name}")
@handle_error
async def delete_guard(guard_name: str):
if not postgres_is_enabled():
raise HTTPException(
status_code=501,
detail="DELETE /<guard_name> is not implemented for in-memory guards.",
)
decoded_guard_name = unquote_plus(guard_name)
guard = guard_client.delete_guard(decoded_guard_name)
return guard.to_dict()


@router.post("/guards/{guard_name}/openai/v1/chat/completions")
@handle_error
async def openai_v1_chat_completions(guard_name: str, request: Request):
payload = await request.json()
decoded_guard_name = unquote_plus(guard_name)
guard_struct = guard_client.get_guard(decoded_guard_name)
if guard_struct is None:
raise HTTPException(
status_code=404,
detail=f"A Guard with the name {decoded_guard_name} does not exist!",
)

guard = (
Guard.from_dict(guard_struct.to_dict())
if not isinstance(guard_struct, Guard)
else guard_struct
)
stream = payload.get("stream", False)
has_tool_gd_tool_call = any(
tool.get("function", {}).get("name") == "gd_response_tool"
for tool in payload.get("tools", [])
)

if not stream:
validation_outcome: ValidationOutcome = guard(num_reasks=0, **payload)
llm_response = guard.history.last.iterations.last.outputs.llm_response_info
result = outcome_to_chat_completion(
validation_outcome=validation_outcome,
llm_response=llm_response,
has_tool_gd_tool_call=has_tool_gd_tool_call,
)
return JSONResponse(content=result)
else:

async def openai_streamer():
guard_stream = guard(num_reasks=0, **payload)
for result in guard_stream:
chunk = json.dumps(
outcome_to_stream_response(validation_outcome=result)
)
yield f"data: {chunk}\n\n"
yield "\n"

return StreamingResponse(openai_streamer(), media_type="text/event-stream")


@router.post("/guards/{guard_name}/validate")
@handle_error
async def validate(guard_name: str, request: Request):
payload = await request.json()
openai_api_key = request.headers.get(
"x-openai-api-key", os.environ.get("OPENAI_API_KEY")
)
decoded_guard_name = unquote_plus(guard_name)
guard_struct = guard_client.get_guard(decoded_guard_name)

llm_output = payload.pop("llmOutput", None)
num_reasks = payload.pop("numReasks", None)
prompt_params = payload.pop("promptParams", {})
llm_api = payload.pop("llmApi", None)
args = payload.pop("args", [])
stream = payload.pop("stream", False)

payload["api_key"] = payload.get("api_key", openai_api_key)

if llm_api is not None:
llm_api = get_llm_callable(llm_api)
if openai_api_key is None:
raise HTTPException(
status_code=400,
detail="Cannot perform calls to OpenAI without an api key.",
)

guard = guard_struct
is_async = inspect.iscoroutinefunction(llm_api)

if not isinstance(guard_struct, Guard):
if is_async:
guard = AsyncGuard.from_dict(guard_struct.to_dict())
else:
guard: Guard = Guard.from_dict(guard_struct.to_dict())
elif is_async:
guard: Guard = AsyncGuard.from_dict(guard_struct.to_dict())

if llm_api is None and num_reasks and num_reasks > 1:
raise HTTPException(
status_code=400,
detail="Cannot perform re-asks without an LLM API. Specify llm_api when calling guard(...).",
)

if llm_output is not None:
if stream:
raise HTTPException(
status_code=400, detail="Streaming is not supported for parse calls!"
)
result: ValidationOutcome = guard.parse(
llm_output=llm_output,
num_reasks=num_reasks,
prompt_params=prompt_params,
llm_api=llm_api,
**payload,
)
else:
if stream:

async def guard_streamer():
guard_stream = guard(
llm_api=llm_api,
prompt_params=prompt_params,
num_reasks=num_reasks,
stream=stream,
*args,
**payload,
)
for result in guard_stream:
validation_output = ValidationOutcome.from_guard_history(
guard.history.last
)
yield validation_output, result

async def validate_streamer(guard_iter):
async for validation_output, result in guard_iter:
fragment_dict = result.to_dict()
fragment_dict["error_spans"] = [
json.dumps({"start": x.start, "end": x.end, "reason": x.reason})
for x in guard.error_spans_in_output()
]
yield json.dumps(fragment_dict) + "\n"

call = guard.history.last
final_validation_output = ValidationOutcome(
callId=call.id,
validation_passed=result.validation_passed,
validated_output=result.validated_output,
history=guard.history,
raw_llm_output=result.raw_llm_output,
)
final_output_dict = final_validation_output.to_dict()
final_output_dict["error_spans"] = [
json.dumps({"start": x.start, "end": x.end, "reason": x.reason})
for x in guard.error_spans_in_output()
]
yield json.dumps(final_output_dict) + "\n"

serialized_history = [call.to_dict() for call in guard.history]
cache_key = f"{guard.name}-{final_validation_output.call_id}"
await cache_client.set(cache_key, serialized_history, 300)

return StreamingResponse(
validate_streamer(guard_streamer()), media_type="application/json"
)
else:
result: ValidationOutcome = guard(
llm_api=llm_api,
prompt_params=prompt_params,
num_reasks=num_reasks,
*args,
**payload,
)

serialized_history = [call.to_dict() for call in guard.history]
cache_key = f"{guard.name}-{result.call_id}"
await cache_client.set(cache_key, serialized_history, 300)
return result.to_dict()


@router.get("/guards/{guard_name}/history/{call_id}")
@handle_error
async def guard_history(guard_name: str, call_id: str):
cache_key = f"{guard_name}-{call_id}"
return await cache_client.get(cache_key)


def collect_telemetry(
*,
guard: Guard,
validate_span: Span,
validation_output: ValidationOutcome,
prompt_params: Dict[str, Any],
result: ValidationOutcome,
):
# Below is all telemetry collection and
# should have no impact on what is returned to the user
prompt = guard.history.last.inputs.prompt
if prompt:
prompt = Template(prompt).safe_substitute(**prompt_params)
validate_span.set_attribute("prompt", prompt)

instructions = guard.history.last.inputs.instructions
if instructions:
instructions = Template(instructions).safe_substitute(**prompt_params)
validate_span.set_attribute("instructions", instructions)

validate_span.set_attribute("validation_status", guard.history.last.status)
validate_span.set_attribute("raw_llm_ouput", result.raw_llm_output)

# Use the serialization from the class instead of re-writing it
valid_output: str = (
json.dumps(validation_output.validated_output)
if isinstance(validation_output.validated_output, dict)
else str(validation_output.validated_output)
)
validate_span.set_attribute("validated_output", valid_output)

validate_span.set_attribute("tokens_consumed", guard.history.last.tokens_consumed)

num_of_reasks = (
guard.history.last.iterations.length - 1
if guard.history.last.iterations.length > 0
else 0
)
validate_span.set_attribute("num_of_reasks", num_of_reasks)
78 changes: 78 additions & 0 deletions guardrails_api/api/root.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import os
from string import Template

from fastapi import HTTPException, APIRouter
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel

from guardrails_api.open_api_spec import get_open_api_spec
from sqlalchemy import text
from guardrails_api.classes.health_check import HealthCheck
from guardrails_api.clients.postgres_client import PostgresClient, postgres_is_enabled
from guardrails_api.utils.logger import logger


class HealthCheckResponse(BaseModel):
status: int
message: str


router = APIRouter()


@router.get("/")
async def home():
return "Hello, world!"


@router.get("/health-check", response_model=HealthCheckResponse)
async def health_check():
try:
if not postgres_is_enabled():
return HealthCheck(200, "Ok").to_dict()

pg_client = PostgresClient()
query = text("SELECT count(datid) FROM pg_stat_activity;")
response = pg_client.db.session.execute(query).all()

logger.info("response: %s", response)

return HealthCheck(200, "Ok").to_dict()
except Exception as e:
logger.error(f"Health check failed: {str(e)}")
raise HTTPException(status_code=500, detail="Internal Server Error")


@router.get("/api-docs", response_class=JSONResponse)
async def api_docs():
api_spec = get_open_api_spec()
return JSONResponse(content=api_spec)


@router.get("/docs", response_class=HTMLResponse)
async def docs():
host = os.environ.get("SELF_ENDPOINT", "http://localhost:8000")
swagger_ui = Template("""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="description" content="SwaggerUI" />
<title>SwaggerUI</title>
<link rel="stylesheet" href="https://unpkg.com/[email protected]/swagger-ui.css" />
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://unpkg.com/[email protected]/swagger-ui-bundle.js" crossorigin></script>
<script>
window.onload = () => {
window.ui = SwaggerUIBundle({
url: '${apiDocUrl}',
dom_id: '#swagger-ui',
});
};
</script>
</body>
</html>""").safe_substitute(apiDocUrl=f"{host}/api-docs")

return HTMLResponse(content=swagger_ui)
Loading
Loading