From 77e9bce9e50890e16ddfb192f61335c962374487 Mon Sep 17 00:00:00 2001 From: David Tam Date: Fri, 13 Sep 2024 13:18:48 -0700 Subject: [PATCH 01/13] migration from flask to fastapi for support asgi and for better performance --- .../{blueprints => api}/__init__.py | 0 guardrails_api/api/guards.py | 263 +++++++ guardrails_api/api/root.py | 74 ++ guardrails_api/app.py | 157 ++-- guardrails_api/blueprints/guards.py | 516 ------------- guardrails_api/blueprints/root.py | 79 -- guardrails_api/clients/cache_client.py | 53 +- guardrails_api/start-dev.sh | 13 +- guardrails_api/start.sh | 2 +- guardrails_api/utils/handle_error.py | 50 +- pyproject.toml | 9 +- requirements-lock.txt | 6 + tests/{blueprints => api}/__init__.py | 0 tests/api/test_guards.py | 370 +++++++++ tests/api/test_root.py | 59 ++ tests/blueprints/test_guards.py | 719 ------------------ tests/blueprints/test_root.py | 48 -- 17 files changed, 956 insertions(+), 1462 deletions(-) rename guardrails_api/{blueprints => api}/__init__.py (100%) create mode 100644 guardrails_api/api/guards.py create mode 100644 guardrails_api/api/root.py delete mode 100644 guardrails_api/blueprints/guards.py delete mode 100644 guardrails_api/blueprints/root.py rename tests/{blueprints => api}/__init__.py (100%) create mode 100644 tests/api/test_guards.py create mode 100644 tests/api/test_root.py delete mode 100644 tests/blueprints/test_guards.py delete mode 100644 tests/blueprints/test_root.py diff --git a/guardrails_api/blueprints/__init__.py b/guardrails_api/api/__init__.py similarity index 100% rename from guardrails_api/blueprints/__init__.py rename to guardrails_api/api/__init__.py diff --git a/guardrails_api/api/guards.py b/guardrails_api/api/guards.py new file mode 100644 index 0000000..076c2d8 --- /dev/null +++ b/guardrails_api/api/guards.py @@ -0,0 +1,263 @@ +import asyncio +import json +import os +import inspect +from typing import Any, Dict, List, Optional +from fastapi import FastAPI, HTTPException, Request, Response, 
APIRouter +from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import BaseModel +from urllib.parse import unquote_plus +from guardrails import AsyncGuard, Guard +from guardrails.classes import ValidationOutcome +from opentelemetry.trace import Span +from guardrails_api_client import Guard as GuardStruct +from guardrails_api.clients.cache_client import CacheClient +from guardrails_api.clients.memory_guard_client import MemoryGuardClient +from guardrails_api.clients.pg_guard_client import PGGuardClient +from guardrails_api.clients.postgres_client import postgres_is_enabled +from guardrails_api.utils.get_llm_callable import get_llm_callable +from guardrails_api.utils.openai import outcome_to_chat_completion, outcome_to_stream_response +from guardrails_api.utils.handle_error import handle_error +from string import Template + +# if no pg_host is set, use in memory guards +if postgres_is_enabled(): + guard_client = PGGuardClient() +else: + guard_client = MemoryGuardClient() + # Will be defined at runtime + import config # noqa + + exports = config.__dir__() + for export_name in exports: + export = getattr(config, export_name) + is_guard = isinstance(export, Guard) + if is_guard: + guard_client.create_guard(export) + +cache_client = CacheClient() + +router = APIRouter() + +@router.get("/guards") +@handle_error +async def get_guards(): + guards = guard_client.get_guards() + return [g.to_dict() for g in guards] + +@router.post("/guards") +@handle_error +async def create_guard(guard: GuardStruct): + if not postgres_is_enabled(): + raise HTTPException(status_code=501, detail="Not Implemented POST /guards is not implemented for in-memory guards.") + new_guard = guard_client.create_guard(guard) + return new_guard.to_dict() + +@router.get("/guards/{guard_name}") +@handle_error +async def get_guard(guard_name: str, asOf: Optional[str] = None): + decoded_guard_name = unquote_plus(guard_name) + guard = guard_client.get_guard(decoded_guard_name, asOf) + if 
guard is None: + raise HTTPException(status_code=404, detail=f"A Guard with the name {decoded_guard_name} does not exist!") + return guard.to_dict() + +@router.put("/guards/{guard_name}") +@handle_error +async def update_guard(guard_name: str, guard: GuardStruct): + if not postgres_is_enabled(): + raise HTTPException(status_code=501, detail="PUT / is not implemented for in-memory guards.") + decoded_guard_name = unquote_plus(guard_name) + updated_guard = guard_client.upsert_guard(decoded_guard_name, guard) + return updated_guard.to_dict() + +@router.delete("/guards/{guard_name}") +@handle_error +async def delete_guard(guard_name: str): + if not postgres_is_enabled(): + raise HTTPException(status_code=501, detail="DELETE / is not implemented for in-memory guards.") + decoded_guard_name = unquote_plus(guard_name) + guard = guard_client.delete_guard(decoded_guard_name) + return guard.to_dict() + +@router.post("/guards/{guard_name}/openai/v1/chat/completions") +@handle_error +async def openai_v1_chat_completions(guard_name: str, request: Request): + payload = await request.json() + decoded_guard_name = unquote_plus(guard_name) + guard_struct = guard_client.get_guard(decoded_guard_name) + if guard_struct is None: + raise HTTPException(status_code=404, detail=f"A Guard with the name {decoded_guard_name} does not exist!") + + guard = Guard.from_dict(guard_struct.to_dict()) if not isinstance(guard_struct, Guard) else guard_struct + stream = payload.get("stream", False) + has_tool_gd_tool_call = any(tool.get("function", {}).get("name") == "gd_response_tool" for tool in payload.get("tools", [])) + + if not stream: + validation_outcome: ValidationOutcome = guard(num_reasks=0, **payload) + llm_response = guard.history.last.iterations.last.outputs.llm_response_info + result = outcome_to_chat_completion( + validation_outcome=validation_outcome, + llm_response=llm_response, + has_tool_gd_tool_call=has_tool_gd_tool_call, + ) + return JSONResponse(content=result) + else: + async 
def openai_streamer(): + guard_stream = guard(num_reasks=0, **payload) + for result in guard_stream: + chunk = json.dumps(outcome_to_stream_response(validation_outcome=result)) + yield f"data: {chunk}\n\n" + yield "\n" + + return StreamingResponse(openai_streamer(), media_type="text/event-stream") + +@router.post("/guards/{guard_name}/validate") +@handle_error +async def validate(guard_name: str, request: Request): + payload = await request.json() + openai_api_key = request.headers.get("x-openai-api-key", os.environ.get("OPENAI_API_KEY")) + decoded_guard_name = unquote_plus(guard_name) + guard_struct = guard_client.get_guard(decoded_guard_name) + + llm_output = payload.pop("llmOutput", None) + num_reasks = payload.pop("numReasks", None) + prompt_params = payload.pop("promptParams", {}) + llm_api = payload.pop("llmApi", None) + args = payload.pop("args", []) + stream = payload.pop("stream", False) + + payload["api_key"] = payload.get("api_key", openai_api_key) + + if llm_api is not None: + llm_api = get_llm_callable(llm_api) + if openai_api_key is None: + raise HTTPException(status_code=400, detail="Cannot perform calls to OpenAI without an api key.") + + guard = guard_struct + is_async = inspect.iscoroutinefunction(llm_api) + + if not isinstance(guard_struct, Guard): + if is_async: + guard = AsyncGuard.from_dict(guard_struct.to_dict()) + else: + guard: Guard = Guard.from_dict(guard_struct.to_dict()) + elif is_async: + guard:Guard = AsyncGuard.from_dict(guard_struct.to_dict()) + + if llm_api is None and num_reasks and num_reasks > 1: + raise HTTPException(status_code=400, detail="Cannot perform re-asks without an LLM API. 
Specify llm_api when calling guard(...).") + + if llm_output is not None: + if stream: + raise HTTPException(status_code=400, detail="Streaming is not supported for parse calls!") + result: ValidationOutcome = guard.parse( + llm_output=llm_output, + num_reasks=num_reasks, + prompt_params=prompt_params, + llm_api=llm_api, + **payload, + ) + else: + if stream: + async def guard_streamer(): + guard_stream = guard( + llm_api=llm_api, + prompt_params=prompt_params, + num_reasks=num_reasks, + stream=stream, + *args, + **payload, + ) + for result in guard_stream: + validation_output = ValidationOutcome.from_guard_history(guard.history.last) + yield validation_output, result + + async def validate_streamer(guard_iter): + async for validation_output, result in guard_iter: + fragment_dict = result.to_dict() + fragment_dict["error_spans"] = [ + json.dumps({"start": x.start, "end": x.end, "reason": x.reason}) + for x in guard.error_spans_in_output() + ] + yield json.dumps(fragment_dict) + "\n" + + call = guard.history.last + final_validation_output = ValidationOutcome( + callId=call.id, + validation_passed=result.validation_passed, + validated_output=result.validated_output, + history=guard.history, + raw_llm_output=result.raw_llm_output, + ) + final_output_dict = final_validation_output.to_dict() + final_output_dict["error_spans"] = [ + json.dumps({"start": x.start, "end": x.end, "reason": x.reason}) + for x in guard.error_spans_in_output() + ] + yield json.dumps(final_output_dict) + "\n" + + serialized_history = [call.to_dict() for call in guard.history] + cache_key = f"{guard.name}-{final_validation_output.call_id}" + await cache_client.set(cache_key, serialized_history, 300) + + return StreamingResponse(validate_streamer(guard_streamer()), media_type="application/json") + else: + result: ValidationOutcome = guard( + llm_api=llm_api, + prompt_params=prompt_params, + num_reasks=num_reasks, + *args, + **payload, + ) + + # serialized_history = [call.to_dict() for call in 
guard.history] + # cache_key = f"{guard.name}-{result.call_id}" + # await cache_client.set(cache_key, serialized_history, 300) + return result.to_dict() + +@router.get("/guards/{guard_name}/history/{call_id}") +@handle_error +async def guard_history(guard_name: str, call_id: str): + cache_key = f"{guard_name}-{call_id}" + return await cache_client.get(cache_key) + +def collect_telemetry( + *, + guard: Guard, + validate_span: Span, + validation_output: ValidationOutcome, + prompt_params: Dict[str, Any], + result: ValidationOutcome, +): + # Below is all telemetry collection and + # should have no impact on what is returned to the user + prompt = guard.history.last.inputs.prompt + if prompt: + prompt = Template(prompt).safe_substitute(**prompt_params) + validate_span.set_attribute("prompt", prompt) + + instructions = guard.history.last.inputs.instructions + if instructions: + instructions = Template(instructions).safe_substitute(**prompt_params) + validate_span.set_attribute("instructions", instructions) + + validate_span.set_attribute("validation_status", guard.history.last.status) + validate_span.set_attribute("raw_llm_ouput", result.raw_llm_output) + + # Use the serialization from the class instead of re-writing it + valid_output: str = ( + json.dumps(validation_output.validated_output) + if isinstance(validation_output.validated_output, dict) + else str(validation_output.validated_output) + ) + validate_span.set_attribute("validated_output", valid_output) + + validate_span.set_attribute("tokens_consumed", guard.history.last.tokens_consumed) + + num_of_reasks = ( + guard.history.last.iterations.length - 1 + if guard.history.last.iterations.length > 0 + else 0 + ) + validate_span.set_attribute("num_of_reasks", num_of_reasks) diff --git a/guardrails_api/api/root.py b/guardrails_api/api/root.py new file mode 100644 index 0000000..2b0a73f --- /dev/null +++ b/guardrails_api/api/root.py @@ -0,0 +1,74 @@ +import os +import json +from string import Template +from typing 
import Dict + +from fastapi import HTTPException, APIRouter +from fastapi.responses import HTMLResponse, JSONResponse +from pydantic import BaseModel + +from guardrails_api.open_api_spec import get_open_api_spec +from sqlalchemy import text +from guardrails_api.classes.health_check import HealthCheck +from guardrails_api.clients.postgres_client import PostgresClient, postgres_is_enabled +from guardrails_api.utils.logger import logger + +class HealthCheckResponse(BaseModel): + status: int + message: str + +router = APIRouter() + +@router.get("/") +async def home(): + return "Hello, FastAPI!" + +@router.get("/health-check", response_model=HealthCheckResponse) +async def health_check(): + try: + if not postgres_is_enabled(): + return HealthCheck(200, "Ok").to_dict() + + pg_client = PostgresClient() + query = text("SELECT count(datid) FROM pg_stat_activity;") + response = pg_client.db.session.execute(query).all() + + logger.info("response: %s", response) + + return HealthCheck(200, "Ok").to_dict() + except Exception as e: + logger.error(f"Health check failed: {str(e)}") + raise HTTPException(status_code=500, detail="Internal Server Error") + +@router.get("/api-docs", response_class=JSONResponse) +async def api_docs(): + api_spec = get_open_api_spec() + return JSONResponse(content=api_spec) + +@router.get("/docs", response_class=HTMLResponse) +async def docs(): + host = os.environ.get("SELF_ENDPOINT", "http://localhost:8000") + swagger_ui = Template(""" + + + + + + SwaggerUI + + + +
+ + + +""").safe_substitute(apiDocUrl=f"{host}/api-docs") + + return HTMLResponse(content=swagger_ui) diff --git a/guardrails_api/app.py b/guardrails_api/app.py index a33fbbc..7e05bf0 100644 --- a/guardrails_api/app.py +++ b/guardrails_api/app.py @@ -1,22 +1,75 @@ -import os -from typing import Optional -from flask import Flask -from flask.json.provider import DefaultJSONProvider -from flask_cors import CORS -from werkzeug.middleware.proxy_fix import ProxyFix -from urllib.parse import urlparse +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse from guardrails import configure_logging -from opentelemetry.instrumentation.flask import FlaskInstrumentor +from guardrails_api.clients.cache_client import CacheClient +from guardrails_api.clients.cache_client import CacheClient from guardrails_api.clients.postgres_client import postgres_is_enabled from guardrails_api.otel import otel_is_disabled, initialize from guardrails_api.utils.trace_server_start_if_enabled import trace_server_start_if_enabled -from guardrails_api.clients.cache_client import CacheClient +from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from rich.console import Console from rich.rule import Rule +from starlette.middleware.base import BaseHTTPMiddleware +from typing import Optional +from urllib.parse import urlparse +import importlib.util +import json +import os -# TODO: Move this to a separate file -class OverrideJsonProvider(DefaultJSONProvider): +# from pyinstrument import Profiler +# from pyinstrument.renderers.html import HTMLRenderer +# from pyinstrument.renderers.speedscope import SpeedscopeRenderer +# from starlette.middleware.base import RequestResponseEndpoint +# class ProfilingMiddleware(BaseHTTPMiddleware): +# async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: +# """Profile the current request + +# Taken from 
https://pyinstrument.readthedocs.io/en/latest/guide.html#profile-a-web-request-in-fastapi +# with small improvements. + +# """ +# # we map a profile type to a file extension, as well as a pyinstrument profile renderer +# profile_type_to_ext = {"html": "html", "speedscope": "speedscope.json"} +# profile_type_to_renderer = { +# "html": HTMLRenderer, +# "speedscope": SpeedscopeRenderer, +# } + +# if request.headers.get("X-Profile-Request"): +# # The default profile format is speedscope +# profile_type = request.query_params.get("profile_format", "speedscope") + +# # we profile the request along with all additional middlewares, by interrupting +# # the program every 1ms1 and records the entire stack at that point +# with Profiler(interval=0.001, async_mode="enabled") as profiler: +# response = await call_next(request) + +# # we dump the profiling into a file +# # Generate a unique filename based on timestamp and request properties +# timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") +# method = request.method +# path = request.url.path.replace("/", "_").strip("_") +# extension = profile_type_to_ext[profile_type] +# filename = f"profile_{timestamp}_{method}_{path}.{extension}" + +# # Ensure the profiling directory exists +# profiling_dir = "profiling" +# os.makedirs(profiling_dir, exist_ok=True) + +# # Dump the profiling into a file +# renderer = profile_type_to_renderer[profile_type]() +# filepath = os.path.join(profiling_dir, filename) +# with open(filepath, "w") as out: +# out.write(profiler.output(renderer=renderer)) + +# return response +# else: +# return await call_next(request) + +# Custom JSON encoder +class CustomJSONEncoder(json.JSONEncoder): def default(self, o): if isinstance(o, set): return list(o) @@ -24,30 +77,23 @@ def default(self, o): return str(o) return super().default(o) - -class ReverseProxied(object): - def __init__(self, app): - self.app = app - - def __call__(self, environ, start_response): +# Custom middleware for reverse proxy +class 
ReverseProxyMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): self_endpoint = os.environ.get("SELF_ENDPOINT", "http://localhost:8000") url = urlparse(self_endpoint) - environ["wsgi.url_scheme"] = url.scheme - return self.app(environ, start_response) - + request.scope["scheme"] = url.scheme + response = await call_next(request) + return response def register_config(config: Optional[str] = None): default_config_file = os.path.join(os.getcwd(), "./config.py") config_file = config or default_config_file config_file_path = os.path.abspath(config_file) if os.path.isfile(config_file_path): - from importlib.machinery import SourceFileLoader - - # This creates a module named "validators" with the contents of the init file - # This allow statements like `from validators import StartsWith` - # But more importantly, it registers all of the validators imported in the init - SourceFileLoader("config", config_file_path).load_module() - + spec = importlib.util.spec_from_file_location("config", config_file_path) + config_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config_module) def create_app( env: Optional[str] = None, config: Optional[str] = None, port: Optional[int] = None @@ -74,21 +120,29 @@ def create_app( register_config(config) - app = Flask(__name__) - app.json = OverrideJsonProvider(app) + app = FastAPI() - app.config["APPLICATION_ROOT"] = "/" - app.config["PREFERRED_URL_SCHEME"] = "https" - app.wsgi_app = ReverseProxied(app.wsgi_app) - CORS(app) + # Initialize FastAPIInstrumentor + FastAPIInstrumentor.instrument_app(app) - app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_port=1) + # app.add_middleware(ProfilingMiddleware) + + # Add CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # Add reverse proxy middleware + app.add_middleware(ReverseProxyMiddleware) guardrails_log_level 
= os.environ.get("GUARDRAILS_LOG_LEVEL", "INFO") configure_logging(log_level=guardrails_log_level) if not otel_is_disabled(): - FlaskInstrumentor().instrument_app(app) initialize() # if no pg_host is set, don't set up postgres @@ -101,11 +155,19 @@ def create_app( cache_client = CacheClient() cache_client.initialize(app) - from guardrails_api.blueprints.root import root_bp - from guardrails_api.blueprints.guards import guards_bp + from guardrails_api.api.root import router as root_router + from guardrails_api.api.guards import router as guards_router, guard_client + + app.include_router(root_router) + app.include_router(guards_router) - app.register_blueprint(root_bp) - app.register_blueprint(guards_bp) + # Custom JSON encoder + @app.exception_handler(ValueError) + async def value_error_handler(request: Request, exc: ValueError): + return JSONResponse( + status_code=400, + content={"message": str(exc)}, + ) console.print( f"\n:rocket: Guardrails API is available at {self_endpoint}" @@ -114,13 +176,18 @@ def create_app( console.print(":green_circle: Active guards and OpenAI compatible endpoints:") - with app.app_context(): - from guardrails_api.blueprints.guards import guard_client - for g in guard_client.get_guards(): - g = g.to_dict() - console.print(f"- Guard: [bold white]{g.get('name')}[/bold white] {self_endpoint}/guards/{g.get('name')}/openai/v1") + guards = guard_client.get_guards() + + for g in guards: + g_dict = g.to_dict() + console.print(f"- Guard: [bold white]{g_dict.get('name')}[/bold white] {self_endpoint}/guards/{g_dict.get('name')}/openai/v1") console.print("") console.print(Rule("[bold grey]Server Logs[/bold grey]", characters="=", style="white")) - return app \ No newline at end of file + return app + +if __name__ == "__main__": + import uvicorn + app = create_app() + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/guardrails_api/blueprints/guards.py b/guardrails_api/blueprints/guards.py deleted file mode 100644 index dd8db10..0000000 --- 
a/guardrails_api/blueprints/guards.py +++ /dev/null @@ -1,516 +0,0 @@ -import asyncio -import json -import os -import inspect -from guardrails.hub import * # noqa -from string import Template -from typing import Any, Dict, cast -from flask import Blueprint, Response, request, stream_with_context -from urllib.parse import unquote_plus -from guardrails import AsyncGuard, Guard -from guardrails.classes import ValidationOutcome -from opentelemetry.trace import Span -from guardrails_api_client import Guard as GuardStruct -from guardrails_api.classes.http_error import HttpError -from guardrails_api.clients.cache_client import CacheClient -from guardrails_api.clients.memory_guard_client import MemoryGuardClient -from guardrails_api.clients.pg_guard_client import PGGuardClient -from guardrails_api.clients.postgres_client import postgres_is_enabled -from guardrails_api.utils.handle_error import handle_error -from guardrails_api.utils.get_llm_callable import get_llm_callable -from guardrails_api.utils.openai import outcome_to_chat_completion, outcome_to_stream_response - -guards_bp = Blueprint("guards", __name__, url_prefix="/guards") - - -# if no pg_host is set, use in memory guards -if postgres_is_enabled(): - guard_client = PGGuardClient() -else: - guard_client = MemoryGuardClient() - # Will be defined at runtime - import config # noqa - - exports = config.__dir__() - for export_name in exports: - export = getattr(config, export_name) - is_guard = isinstance(export, Guard) - if is_guard: - guard_client.create_guard(export) - -cache_client = CacheClient() - - -@guards_bp.route("/", methods=["GET", "POST"]) -@handle_error -def guards(): - if request.method == "GET": - guards = guard_client.get_guards() - return [g.to_dict() for g in guards] - elif request.method == "POST": - if not postgres_is_enabled(): - raise HttpError( - 501, - "NotImplemented", - "POST /guards is not implemented for in-memory guards.", - ) - payload = request.json - guard = 
GuardStruct.from_dict(payload) - new_guard = guard_client.create_guard(guard) - return new_guard.to_dict() - else: - raise HttpError( - 405, - "Method Not Allowed", - "/guards only supports the GET and POST methods. You specified" - " {request_method}".format(request_method=request.method), - ) - - -@guards_bp.route("/", methods=["GET", "PUT", "DELETE"]) -@handle_error -def guard(guard_name: str): - decoded_guard_name = unquote_plus(guard_name) - if request.method == "GET": - as_of_query = request.args.get("asOf") - guard = guard_client.get_guard(decoded_guard_name, as_of_query) - if guard is None: - raise HttpError( - 404, - "NotFound", - "A Guard with the name {guard_name} does not exist!".format( - guard_name=decoded_guard_name - ), - ) - return guard.to_dict() - elif request.method == "PUT": - if not postgres_is_enabled(): - raise HttpError( - 501, - "NotImplemented", - "PUT / is not implemented for in-memory guards.", - ) - payload = request.json - guard = GuardStruct.from_dict(payload) - updated_guard = guard_client.upsert_guard(decoded_guard_name, guard) - return updated_guard.to_dict() - elif request.method == "DELETE": - if not postgres_is_enabled(): - raise HttpError( - 501, - "NotImplemented", - "DELETE / is not implemented for in-memory guards.", - ) - guard = guard_client.delete_guard(decoded_guard_name) - return guard.to_dict() - else: - raise HttpError( - 405, - "Method Not Allowed", - "/guard/ only supports the GET, PUT, and DELETE methods." 
- " You specified {request_method}".format(request_method=request.method), - ) - - -def collect_telemetry( - *, - guard: Guard, - validate_span: Span, - validation_output: ValidationOutcome, - prompt_params: Dict[str, Any], - result: ValidationOutcome, -): - # Below is all telemetry collection and - # should have no impact on what is returned to the user - prompt = guard.history.last.inputs.prompt - if prompt: - prompt = Template(prompt).safe_substitute(**prompt_params) - validate_span.set_attribute("prompt", prompt) - - instructions = guard.history.last.inputs.instructions - if instructions: - instructions = Template(instructions).safe_substitute(**prompt_params) - validate_span.set_attribute("instructions", instructions) - - validate_span.set_attribute("validation_status", guard.history.last.status) - validate_span.set_attribute("raw_llm_ouput", result.raw_llm_output) - - # Use the serialization from the class instead of re-writing it - valid_output: str = ( - json.dumps(validation_output.validated_output) - if isinstance(validation_output.validated_output, dict) - else str(validation_output.validated_output) - ) - validate_span.set_attribute("validated_output", valid_output) - - validate_span.set_attribute("tokens_consumed", guard.history.last.tokens_consumed) - - num_of_reasks = ( - guard.history.last.iterations.length - 1 - if guard.history.last.iterations.length > 0 - else 0 - ) - validate_span.set_attribute("num_of_reasks", num_of_reasks) - - -@guards_bp.route("//openai/v1/chat/completions", methods=["POST"]) -@handle_error -def openai_v1_chat_completions(guard_name: str): - # This endpoint implements the OpenAI Chat API - # It is mean to be fully compatible - # The only difference is that it uses the Guard API under the hood - # instead of the OpenAI API and supports guardrail API error handling - # To use this with the OpenAI SDK you can use the following code: - # import openai - # openai.base_url = "http://localhost:8000/guards//openai/v1/" - # response 
= openai.chat.completions( - # model="gpt-3.5-turbo-0125", - # messages=[ - # {"role": "user", "content": "Hello, how are you?"}, - # ], - # stream=True, - # ) - # print(response) - # to configure guard rails error handling from the server side you can use the following code: - # - - payload = request.json - decoded_guard_name = unquote_plus(guard_name) - guard_struct = guard_client.get_guard(decoded_guard_name) - guard = guard_struct - if guard_struct is None: - raise HttpError( - 404, - "NotFound", - "A Guard with the name {guard_name} does not exist!".format( - guard_name=decoded_guard_name - ), - ) - - if not isinstance(guard_struct, Guard): - guard: Guard = Guard.from_dict(guard_struct.to_dict()) - stream = payload.get("stream", False) - has_tool_gd_tool_call = False - - try: - tools = payload.get("tools", []) - tools.filter(lambda tool: tool["funcion"]["name"] == "gd_response_tool") - has_tool_gd_tool_call = len(tools) > 0 - except (KeyError, AttributeError): - pass - - if not stream: - validation_outcome: ValidationOutcome = guard( - # todo make this come from the guard struct? - # currently we dont support .configure - num_reasks=0, - **payload, - ) - llm_response = guard.history.last.iterations.last.outputs.llm_response_info - result = outcome_to_chat_completion( - validation_outcome=validation_outcome, - llm_response=llm_response, - has_tool_gd_tool_call=has_tool_gd_tool_call, - ) - return result - - else: - # need to return validated chunks that look identical to openai's - # should look something like - # data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-3.5-turbo-0125", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":None,"finish_reason":None}]} - # .... 
- # data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-3.5-turbo-0125", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{},"logprobs":None,"finish_reason":"stop"}]} - def openai_streamer(): - guard_stream = guard( - num_reasks=0, - **payload, - ) - for result in guard_stream: - chunk_string = f"data: {json.dumps(outcome_to_stream_response(validation_outcome=result))}\n\n" - yield chunk_string.encode("utf-8") - # close the stream - yield b"\n" - - return Response( - stream_with_context(openai_streamer()), - ) - - -@guards_bp.route("//validate", methods=["POST"]) -@handle_error -def validate(guard_name: str): - # Do we actually need a child span here? - # We could probably use the existing span from the request unless we forsee - # capturing the same attributes on non-GaaS Guard runs. - if request.method != "POST": - raise HttpError( - 405, - "Method Not Allowed", - "/guards//validate only supports the POST method. You specified" - " {request_method}".format(request_method=request.method), - ) - payload = request.json - openai_api_key = request.headers.get( - "x-openai-api-key", os.environ.get("OPENAI_API_KEY") - ) - decoded_guard_name = unquote_plus(guard_name) - guard_struct = guard_client.get_guard(decoded_guard_name) - llm_output = payload.pop("llmOutput", None) - num_reasks = payload.pop("numReasks", None) - prompt_params = payload.pop("promptParams", {}) - llm_api = payload.pop("llmApi", None) - args = payload.pop("args", []) - stream = payload.pop("stream", False) - - # service_name = os.environ.get("OTEL_SERVICE_NAME", "guardrails-api") - # otel_tracer = get_tracer(service_name) - - payload["api_key"] = payload.get("api_key", openai_api_key) - - # with otel_tracer.start_as_current_span( - # f"validate-{decoded_guard_name}" - # ) as validate_span: - # guard: Guard = guard_struct.to_guard(openai_api_key, otel_tracer) - - - # validate_span.set_attribute("guardName", decoded_guard_name) - if llm_api 
is not None: - llm_api = get_llm_callable(llm_api) - if openai_api_key is None: - raise HttpError( - status=400, - message="BadRequest", - cause=( - "Cannot perform calls to OpenAI without an api key. Pass" - " openai_api_key when initializing the Guard or set the" - " OPENAI_API_KEY environment variable." - ), - ) - - guard = guard_struct - is_async = inspect.iscoroutinefunction(llm_api) - if not isinstance(guard_struct, Guard): - if is_async: - guard = AsyncGuard.from_dict(guard_struct.to_dict()) - else: - guard: Guard = Guard.from_dict(guard_struct.to_dict()) - elif is_async: - guard:Guard = AsyncGuard.from_dict(guard_struct.to_dict()) - - if llm_api is None and num_reasks and num_reasks > 1: - raise HttpError( - status=400, - message="BadRequest", - cause=( - "Cannot perform re-asks without an LLM API. Specify llm_api when" - " calling guard(...)." - ), - ) - if llm_output is not None: - if stream: - raise HttpError( - status=400, - message="BadRequest", - cause="Streaming is not supported for parse calls!", - ) - result: ValidationOutcome = guard.parse( - llm_output=llm_output, - num_reasks=num_reasks, - prompt_params=prompt_params, - llm_api=llm_api, - **payload, - ) - else: - if stream: - def guard_streamer(): - guard_stream = guard( - llm_api=llm_api, - prompt_params=prompt_params, - num_reasks=num_reasks, - stream=stream, - *args, - **payload, - ) - for result in guard_stream: - # TODO: Just make this a ValidationOutcome with history - validation_output: ValidationOutcome = ( - ValidationOutcome.from_guard_history(guard.history.last) - ) - yield validation_output, cast(ValidationOutcome, result) - - async def async_guard_streamer(): - guard_stream = await guard( - llm_api=llm_api, - prompt_params=prompt_params, - num_reasks=num_reasks, - stream=stream, - *args, - **payload, - ) - async for result in guard_stream: - validation_output: ValidationOutcome = ( - ValidationOutcome.from_guard_history(guard.history.last) - ) - yield validation_output, 
cast(ValidationOutcome, result) - - def validate_streamer(guard_iter): - next_result = None - for validation_output, result in guard_iter: - next_result = result - # next_validation_output = validation_output - fragment_dict = result.to_dict() - fragment_dict["error_spans"] = list( - map( - lambda x: json.dumps( - {"start": x.start, "end": x.end, "reason": x.reason} - ), - guard.error_spans_in_output(), - ) - ) - fragment = json.dumps(fragment_dict) - yield f"{fragment}\n" - call = guard.history.last - final_validation_output: ValidationOutcome = ValidationOutcome( - callId=call.id, - validation_passed=next_result.validation_passed, - validated_output=next_result.validated_output, - history=guard.history, - raw_llm_output=next_result.raw_llm_output, - ) - # I don't know if these are actually making it to OpenSearch - # because the span may be ended already - # collect_telemetry( - # guard=guard, - # validate_span=validate_span, - # validation_output=next_validation_output, - # prompt_params=prompt_params, - # result=next_result - # ) - final_output_dict = final_validation_output.to_dict() - final_output_dict["error_spans"] = list( - map( - lambda x: json.dumps( - {"start": x.start, "end": x.end, "reason": x.reason} - ), - guard.error_spans_in_output(), - ) - ) - final_output_json = json.dumps(final_output_dict) - - serialized_history = [call.to_dict() for call in guard.history] - cache_key = f"{guard.name}-{final_validation_output.call_id}" - cache_client.set(cache_key, serialized_history, 300) - yield f"{final_output_json}\n" - - async def async_validate_streamer(guard_iter): - next_result = None - # next_validation_output = None - async for validation_output, result in guard_iter: - next_result = result - # next_validation_output = validation_output - fragment_dict = result.to_dict() - fragment_dict["error_spans"] = list( - map( - lambda x: json.dumps( - {"start": x.start, "end": x.end, "reason": x.reason} - ), - guard.error_spans_in_output(), - ) - ) - fragment 
= json.dumps(fragment_dict) - yield f"{fragment}\n" - - call = guard.history.last - final_validation_output: ValidationOutcome = ValidationOutcome( - callId=call.id, - validation_passed=next_result.validation_passed, - validated_output=next_result.validated_output, - history=guard.history, - raw_llm_output=next_result.raw_llm_output, - ) - # I don't know if these are actually making it to OpenSearch - # because the span may be ended already - # collect_telemetry( - # guard=guard, - # validate_span=validate_span, - # validation_output=next_validation_output, - # prompt_params=prompt_params, - # result=next_result - # ) - final_output_dict = final_validation_output.to_dict() - final_output_dict["error_spans"] = list( - map( - lambda x: json.dumps( - {"start": x.start, "end": x.end, "reason": x.reason} - ), - guard.error_spans_in_output(), - ) - ) - final_output_json = json.dumps(final_output_dict) - - serialized_history = [call.to_dict() for call in guard.history] - cache_key = f"{guard.name}-{final_validation_output.call_id}" - cache_client.set(cache_key, serialized_history, 300) - yield f"{final_output_json}\n" - # apropos of https://stackoverflow.com/questions/73949570/using-stream-with-context-as-async - def iter_over_async(ait, loop): - ait = ait.__aiter__() - async def get_next(): - try: - obj = await ait.__anext__() - return False, obj - except StopAsyncIteration: - return True, None - while True: - done, obj = loop.run_until_complete(get_next()) - if done: - break - yield obj - if is_async: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - iter = iter_over_async(async_validate_streamer(async_guard_streamer()), loop) - else: - iter = validate_streamer(guard_streamer()) - return Response( - stream_with_context(iter), - content_type="application/json", - # content_type="text/event-stream" - ) - - result: ValidationOutcome = guard( - llm_api=llm_api, - prompt_params=prompt_params, - num_reasks=num_reasks, - # api_key=openai_api_key, - *args, - 
**payload, - ) - - # TODO: Just make this a ValidationOutcome with history - # validation_output = ValidationOutcome( - # validation_passed = result.validation_passed, - # validated_output=result.validated_output, - # history=guard.history, - # raw_llm_output=result.raw_llm_output, - # ) - - # collect_telemetry( - # guard=guard, - # validate_span=validate_span, - # validation_output=validation_output, - # prompt_params=prompt_params, - # result=result - # ) - serialized_history = [call.to_dict() for call in guard.history] - cache_key = f"{guard.name}-{result.call_id}" - cache_client.set(cache_key, serialized_history, 300) - return result.to_dict() - - -@guards_bp.route("//history/", methods=["GET"]) -@handle_error -def guard_history(guard_name: str, call_id: str): - if request.method == "GET": - cache_key = f"{guard_name}-{call_id}" - return cache_client.get(cache_key) diff --git a/guardrails_api/blueprints/root.py b/guardrails_api/blueprints/root.py deleted file mode 100644 index ed388be..0000000 --- a/guardrails_api/blueprints/root.py +++ /dev/null @@ -1,79 +0,0 @@ -import os -import json -import flask -from string import Template -from flask import Blueprint -from guardrails_api.open_api_spec import get_open_api_spec -from sqlalchemy import text -from guardrails_api.classes.health_check import HealthCheck -from guardrails_api.clients.postgres_client import PostgresClient, postgres_is_enabled -from guardrails_api.utils.handle_error import handle_error -from guardrails_api.utils.logger import logger - - -root_bp = Blueprint("root", __name__, url_prefix="/") - - -@root_bp.route("/") -@handle_error -def home(): - return "Hello, Flask!" 
- - -@root_bp.route("/health-check") -@handle_error -def health_check(): - # If we're not using postgres, just return Ok - if not postgres_is_enabled(): - return HealthCheck(200, "Ok").to_dict() - # Make sure we're connected to the database and can run queries - pg_client = PostgresClient() - query = text("SELECT count(datid) FROM pg_stat_activity;") - response = pg_client.db.session.execute(query).all() - # # This works with otel logging - # logger.info(f"response: {response}") - # As does this - logger.info("response: %s", response) - # # This throws an error - # print("response: ", response) - return HealthCheck(200, "Ok").to_dict() - - -@root_bp.route("/api-docs") -@handle_error -def api_docs(): - api_spec = get_open_api_spec() - return json.dumps(api_spec) - - -@root_bp.route("/docs") -@handle_error -def docs(): - host = os.environ.get("SELF_ENDPOINT", "http://localhost:8000") - swagger_ui = Template(""" - - - - - - SwaggerUI - - - -
- - - -""").safe_substitute(apiDocUrl=f"{host}/api-docs") # noqa - - return flask.render_template_string(swagger_ui) diff --git a/guardrails_api/clients/cache_client.py b/guardrails_api/clients/cache_client.py index 550bc4b..fa52d3a 100644 --- a/guardrails_api/clients/cache_client.py +++ b/guardrails_api/clients/cache_client.py @@ -1,8 +1,8 @@ import threading -from flask_caching import Cache +from fastapi import FastAPI +from aiocache import caches, Cache +from aiocache.serializers import JsonSerializer - -# TODO: Add option to connect to Redis or MemCached backend with environment variables class CacheClient: _instance = None _lock = threading.Lock() @@ -10,27 +10,30 @@ class CacheClient: def __new__(cls): if cls._instance is None: with cls._lock: - cls._instance = super(CacheClient, cls).__new__(cls) + if cls._instance is None: # Double-checked locking + cls._instance = super().__new__(cls) return cls._instance - def initialize(self, app): - self.cache = Cache( - app, - config={ - "CACHE_TYPE": "SimpleCache", - "CACHE_DEFAULT_TIMEOUT": 300, - "CACHE_THRESHOLD": 50, - }, - ) - - def get(self, key): - return self.cache.get(key) - - def set(self, key, value, ttl): - self.cache.set(key, value, timeout=ttl) - - def delete(self, key): - self.cache.delete(key) - - def clear(self): - self.cache.clear() + def initialize(self, app: FastAPI): + caches.set_config({ + 'default': { + 'cache': "aiocache.SimpleMemoryCache", + 'serializer': { + 'class': "aiocache.serializers.JsonSerializer" + }, + 'ttl': 300 + } + }) + self.cache = caches.get('default') + + async def get(self, key: str): + return await self.cache.get(key) + + async def set(self, key: str, value: str, ttl: int): + await self.cache.set(key, value, ttl=ttl) + + async def delete(self, key: str): + await self.cache.delete(key) + + async def clear(self): + await self.cache.clear() \ No newline at end of file diff --git a/guardrails_api/start-dev.sh b/guardrails_api/start-dev.sh index a27f2d3..36f33ba 100755 --- 
a/guardrails_api/start-dev.sh +++ b/guardrails_api/start-dev.sh @@ -1 +1,12 @@ -gunicorn --bind 0.0.0.0:8000 --timeout=5 --threads=10 "guardrails_api.app:create_app()" --reload --capture-output --enable-stdio-inheritance +gunicorn --bind 0.0.0.0:8000 \ + --timeout 120 \ + --workers 3 \ + --threads 2 \ + --worker-class=uvicorn.workers.UvicornWorker \ + "guardrails_api.app:create_app()" \ + --reload \ + --capture-output \ + --enable-stdio-inheritance \ + --access-logfile - \ + --error-logfile - \ + --access-logformat '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s" pid=%(p)s' \ No newline at end of file diff --git a/guardrails_api/start.sh b/guardrails_api/start.sh index 2696b88..4a4a353 100755 --- a/guardrails_api/start.sh +++ b/guardrails_api/start.sh @@ -1 +1 @@ -gunicorn --bind 0.0.0.0:8000 --timeout=5 --threads=10 "guardrails_api.app:create_app()" +gunicorn --bind 0.0.0.0:8000 --timeout=5 --workers=3 --worker-class=uvicorn.workers.UvicornWorker "guardrails_api.app:create_app()" \ No newline at end of file diff --git a/guardrails_api/utils/handle_error.py b/guardrails_api/utils/handle_error.py index 4fcf231..cf7cb07 100644 --- a/guardrails_api/utils/handle_error.py +++ b/guardrails_api/utils/handle_error.py @@ -1,32 +1,36 @@ from functools import wraps import traceback -from werkzeug.exceptions import HTTPException from guardrails_api.classes.http_error import HttpError from guardrails_api.utils.logger import logger from guardrails.errors import ValidationError -def handle_error(fn): - @wraps(fn) - def decorator(*args, **kwargs): - try: - return fn(*args, **kwargs) - except ValidationError as validation_error: - logger.error(validation_error) - traceback.print_exception(type(validation_error), validation_error, validation_error.__traceback__) - return str(validation_error), 400 - except HttpError as http_error: - logger.error(http_error) - traceback.print_exception(type(http_error), http_error, http_error.__traceback__) - return http_error.to_dict(), 
http_error.status - except HTTPException as http_exception: - logger.error(http_exception) - traceback.print_exception(http_exception) - http_error = HttpError(http_exception.code, http_exception.description) - return http_error.to_dict(), http_error.status - except Exception as e: - logger.error(e) - traceback.print_exception(e) - return HttpError(500, "Internal Server Error").to_dict(), 500 +from fastapi import HTTPException +def handle_error(func=None): + def decorator(fn): + @wraps(fn) + async def wrapper(*args, **kwargs): + try: + return await fn(*args, **kwargs) + except ValidationError as validation_error: + logger.error(validation_error) + traceback.print_exception(type(validation_error), validation_error, validation_error.__traceback__) + raise HTTPException(status_code=400, detail=str(validation_error)) + except HttpError as http_error: + logger.error(http_error) + traceback.print_exception(type(http_error), http_error, http_error.__traceback__) + raise HTTPException(status_code=http_error.status, detail=http_error.to_dict()) + except HTTPException as http_exception: + logger.error(http_exception) + traceback.print_exception(type(http_exception), http_exception, http_exception.__traceback__) + raise + except Exception as e: + logger.error(e) + traceback.print_exception(type(e), e, e.__traceback__) + raise HTTPException(status_code=500, detail="Internal Server Error") + return wrapper + + if func: + return decorator(func) return decorator diff --git a/pyproject.toml b/pyproject.toml index ef12ba0..d7ecddd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,13 +11,10 @@ keywords = ["Guardrails", "Guardrails AI", "Guardrails API", "Guardrails API"] requires-python = ">= 3.8.1" dependencies = [ "guardrails-ai>=0.5.6", - "flask>=3.0.3,<4", "Flask-SQLAlchemy>=3.1.1,<4", - "Flask-Caching>=2.3.0,<3", "Werkzeug>=3.0.3,<4", "jsonschema>=4.22.0,<5", "referencing>=0.35.1,<1", - "Flask-Cors>=4.0.1,<6", "boto3>=1.34.115,<2", "psycopg2-binary>=2.9.9,<3",
"litellm>=1.39.3,<2", @@ -26,8 +23,10 @@ dependencies = [ "opentelemetry-sdk>=1.0.0,<2", "opentelemetry-exporter-otlp-proto-grpc>=1.0.0,<2", "opentelemetry-exporter-otlp-proto-http>=1.0.0,<2", - "opentelemetry-instrumentation-flask>=0.12b0,<1", - "requests>=2.32.3" + "opentelemetry-instrumentation-fastapi>=0.47b0", + "requests>=2.32.3", + "aiocache>=0.11.1", + "fastapi", ] [tool.setuptools.dynamic] diff --git a/requirements-lock.txt b/requirements-lock.txt index 0c18e6e..950b069 100644 --- a/requirements-lock.txt +++ b/requirements-lock.txt @@ -28,10 +28,16 @@ frozenlist==1.4.1 fsspec==2024.6.1 googleapis-common-protos==1.63.2 griffe==0.36.9 +<<<<<<< Updated upstream grpcio==1.65.1 guardrails-ai==0.5.7 guardrails-api-client==0.3.12 guardrails_hub_types==0.0.4 +======= +grpcio==1.64.1 +guardrails-ai==0.5.0a2 +guardrails-api-client==0.3.8 +>>>>>>> Stashed changes gunicorn==22.0.0 h11==0.14.0 httpcore==1.0.5 diff --git a/tests/blueprints/__init__.py b/tests/api/__init__.py similarity index 100% rename from tests/blueprints/__init__.py rename to tests/api/__init__.py diff --git a/tests/api/test_guards.py b/tests/api/test_guards.py new file mode 100644 index 0000000..83ec28f --- /dev/null +++ b/tests/api/test_guards.py @@ -0,0 +1,370 @@ +import os +from unittest.mock import PropertyMock +from typing import Dict, Tuple + +import pytest +from fastapi.testclient import TestClient +from fastapi import FastAPI + +from guardrails.classes import ValidationOutcome +from guardrails.classes.generic import Stack +from guardrails.classes.history import Call, Iteration +from guardrails.errors import ValidationError + +# Assuming these imports exist in your FastAPI project +from guardrails_api.app import register_config +from tests.mocks.mock_guard_client import MockGuardStruct + +# TODO: Should we mock this somehow? 
+# Right now it's just empty, but it technically does a file read +register_config() + +app = FastAPI() +from guardrails_api.api.guards import router as guards_router +app.include_router(guards_router) +client = TestClient(app) + +MOCK_GUARD_STRING = { + "id": "mock-guard-id", + "name": "mock-guard", + "description": "mock guard description", + "history": Stack(), +} + +@pytest.fixture(autouse=True) +def around_each(): + # Code that will run before the test + openai_api_key_bak = os.environ.get("OPENAI_API_KEY") + if openai_api_key_bak: + del os.environ["OPENAI_API_KEY"] + yield + # Code that will run after the test + if openai_api_key_bak: + os.environ["OPENAI_API_KEY"] = openai_api_key_bak + +def test_guards__get(mocker): + mock_guard = MockGuardStruct() + mock_get_guards = mocker.patch( + "guardrails_api.api.guards.guard_client.get_guards", + return_value=[mock_guard], + ) + mocker.patch("guardrails_api.api.guards.collect_telemetry") + + response = client.get("/guards") + + assert mock_get_guards.call_count == 1 + assert response.status_code == 200 + assert response.json() == [MOCK_GUARD_STRING] + +def test_guards__post_pg(mocker): + os.environ["PGHOST"] = "localhost" + mock_guard = MockGuardStruct() + mock_from_request = mocker.patch( + "guardrails_api.api.guards.GuardStruct.from_dict", + return_value=mock_guard, + ) + mock_create_guard = mocker.patch( + "guardrails_api.api.guards.guard_client.create_guard", + return_value=mock_guard, + ) + + response = client.post("/guards", json=mock_guard.to_dict()) + + # mock_from_request.assert_called_once_with(mock_guard.to_dict()) + # mock_create_guard.assert_called_once_with(mock_guard) + assert response.status_code == 200 + assert response.json() == MOCK_GUARD_STRING + + del os.environ["PGHOST"] + +def test_guards__post_mem(mocker): + old = None + if "PGHOST" in os.environ: + old = os.environ.get("PGHOST") + del os.environ["PGHOST"] + mock_guard = MockGuardStruct() + + response = client.post("/guards", 
json=mock_guard.to_dict()) + + assert response.status_code == 501 + assert "Not Implemented" in response.json()["detail"] + if (old): + os.environ["PGHOST"] = old + +def test_guard__get_mem(mocker): + mock_guard = MockGuardStruct() + timestamp = "2024-03-04T14:11:42-06:00" + mock_get_guard = mocker.patch( + "guardrails_api.api.guards.guard_client.get_guard", + return_value=mock_guard, + ) + + response = client.get(f"/guards/My%20Guard's%20Name?asOf={timestamp}") + + mock_get_guard.assert_called_once_with("My Guard's Name", timestamp) + assert response.status_code == 200 + assert response.json() == MOCK_GUARD_STRING + +def test_guard__put_pg(mocker): + os.environ["PGHOST"] = "localhost" + mock_guard = MockGuardStruct() + json_guard = { + "name": "mock-guard", + "id": "mock-guard-id", + "description": "mock guard description", + "history": Stack(), + } + mocker.patch( + "guardrails_api.api.guards.GuardStruct.from_dict", + return_value=mock_guard, + ) + mocker.patch( + "guardrails_api.api.guards.guard_client.upsert_guard", + return_value=mock_guard, + ) + + response = client.put("/guards/My%20Guard's%20Name", json=json_guard) + + assert response.status_code == 200 + assert response.json() == MOCK_GUARD_STRING + del os.environ["PGHOST"] + +def test_guard__delete_pg(mocker): + os.environ["PGHOST"] = "localhost" + mock_guard = MockGuardStruct() + mock_delete_guard = mocker.patch( + "guardrails_api.api.guards.guard_client.delete_guard", + return_value=mock_guard, + ) + + response = client.delete("/guards/my-guard-name") + + mock_delete_guard.assert_called_once_with("my-guard-name") + assert response.status_code == 200 + assert response.json() == MOCK_GUARD_STRING + del os.environ["PGHOST"] + +def test_validate__parse(mocker): + os.environ["PGHOST"] = "localhost" + mock_outcome = ValidationOutcome( + call_id="mock-call-id", + raw_llm_output="Hello world!", + validated_output="Hello world!", + validation_passed=True, + ) + + mock_parse = mocker.patch.object(MockGuardStruct, 
"parse") + mock_parse.return_value = mock_outcome + + mock_guard = MockGuardStruct() + mock_from_dict = mocker.patch("guardrails_api.api.guards.Guard.from_dict") + mock_from_dict.return_value = mock_guard + + mock_get_guard = mocker.patch( + "guardrails_api.api.guards.guard_client.get_guard", + return_value=mock_guard, + ) + + mock_status = mocker.patch( + "guardrails.classes.history.call.Call.status", new_callable=PropertyMock + ) + mock_status.return_value = "pass" + mock_guard.history = Stack(Call()) + + response = client.post("/guards/My%20Guard's%20Name/validate", json={ + "llmOutput": "Hello world!", + "args": [1, 2, 3], + "some_kwarg": "foo" + }) + + mock_get_guard.assert_called_once_with("My Guard's Name") + assert mock_parse.call_count == 1 + mock_parse.assert_called_once_with( + llm_output="Hello world!", + num_reasks=None, + prompt_params={}, + llm_api=None, + some_kwarg="foo", + api_key=None, + ) + + assert response.status_code == 200 + assert response.json() == { + "callId": "mock-call-id", + "validatedOutput": "Hello world!", + "validationPassed": True, + "rawLlmOutput": "Hello world!", + } + + del os.environ["PGHOST"] + +def test_validate__call(mocker): + os.environ["PGHOST"] = "localhost" + mock_outcome = ValidationOutcome( + call_id="mock-call-id", + raw_llm_output="Hello world!", + validated_output=None, + validation_passed=False, + ) + + mock___call__ = mocker.patch.object(MockGuardStruct, "__call__") + mock___call__.return_value = mock_outcome + + mock_guard = MockGuardStruct() + mock_from_dict = mocker.patch("guardrails_api.api.guards.Guard.from_dict") + mock_from_dict.return_value = mock_guard + + mock_get_guard = mocker.patch( + "guardrails_api.api.guards.guard_client.get_guard", + return_value=mock_guard, + ) + + mock_status = mocker.patch( + "guardrails.classes.history.call.Call.status", new_callable=PropertyMock + ) + mock_status.return_value = "fail" + mock_guard.history = Stack(Call()) + + response = 
client.post("/guards/My%20Guard's%20Name/validate", json={ + "promptParams": {"p1": "bar"}, + "args": [1, 2, 3], + "some_kwarg": "foo", + "prompt": "Hello world!" + }, headers={"x-openai-api-key": "mock-key"}) + + mock_get_guard.assert_called_once_with("My Guard's Name") + assert mock___call__.call_count == 1 + mock___call__.assert_called_once_with( + 1, + 2, + 3, + llm_api=None, + prompt_params={"p1": "bar"}, + num_reasks=None, + some_kwarg="foo", + api_key="mock-key", + prompt="Hello world!", + ) + + assert response.status_code == 200 + assert response.json() == { + "callId": "mock-call-id", + "validationPassed": False, + "validatedOutput": None, + "rawLlmOutput": "Hello world!", + } + + del os.environ["PGHOST"] + +def test_validate__call_throws_validation_error(mocker): + os.environ["PGHOST"] = "localhost" + error = ValidationError("Test guard validation error") + mock_parse = mocker.patch.object(MockGuardStruct, "__call__") + mock_parse.side_effect = error + + mock_guard = MockGuardStruct() + mock_from_dict = mocker.patch("guardrails_api.api.guards.Guard.from_dict") + mock_from_dict.return_value = mock_guard + + mock_get_guard = mocker.patch( + "guardrails_api.api.guards.guard_client.get_guard", + return_value=mock_guard, + ) + + mock_status = mocker.patch( + "guardrails.classes.history.call.Call.status", new_callable=PropertyMock + ) + mock_status.return_value = "fail" + mock_guard.history = Stack(Call()) + + response = client.post("/guards/My%20Guard's%20Name/validate", json={ + "promptParams": {"p1": "bar"}, + "args": [1, 2, 3], + "some_kwarg": "foo", + "prompt": "Hello world!" 
+ }) + + mock_get_guard.assert_called_once_with("My Guard's Name") + + assert response.status_code == 400 + assert response.json() == {"detail": "Test guard validation error"} + + del os.environ["PGHOST"] + +def test_openai_v1_chat_completions__raises_404(mocker): + os.environ["PGHOST"] = "localhost" + mock_guard = None + + mock_get_guard = mocker.patch( + "guardrails_api.api.guards.guard_client.get_guard", + return_value=mock_guard, + ) + + response = client.post("/guards/My%20Guard's%20Name/openai/v1/chat/completions", json={ + "messages": [{"role":"user", "content":"Hello world!"}], + }, headers={"x-openai-api-key": "mock-key"}) + + assert response.status_code == 404 + assert response.json()["detail"] == "A Guard with the name My Guard's Name does not exist!" + + mock_get_guard.assert_called_once_with("My Guard's Name") + + del os.environ["PGHOST"] + +def test_openai_v1_chat_completions__call(mocker): + os.environ["PGHOST"] = "localhost" + mock_guard = MockGuardStruct() + mock_outcome = ValidationOutcome( + call_id="mock-call-id", + raw_llm_output="Hello world!", + validated_output="Hello world!", + validation_passed=False, + ) + + mock___call__ = mocker.patch.object(MockGuardStruct, "__call__") + mock___call__.return_value = mock_outcome + + mock_from_dict = mocker.patch("guardrails_api.api.guards.Guard.from_dict") + mock_from_dict.return_value = mock_guard + + mock_get_guard = mocker.patch( + "guardrails_api.api.guards.guard_client.get_guard", + return_value=mock_guard, + ) + + mock_status = mocker.patch( + "guardrails.classes.history.call.Call.status", new_callable=PropertyMock + ) + mock_status.return_value = "fail" + mock_call = Call() + mock_call.iterations= Stack(Iteration('some-id', 1)) + mock_guard.history = Stack(mock_call) + + response = client.post("/guards/My%20Guard's%20Name/openai/v1/chat/completions", json={ + "messages": [{"role":"user", "content":"Hello world!"}], + }, headers={"x-openai-api-key": "mock-key"}) + + 
mock_get_guard.assert_called_once_with("My Guard's Name") + assert mock___call__.call_count == 1 + mock___call__.assert_called_once_with( + num_reasks=0, + messages=[{"role":"user", "content":"Hello world!"}], + ) + + assert response.status_code == 200 + assert response.json() == { + "choices": [ + { + "message": { + "content": "Hello world!", + }, + } + ], + "guardrails": { + "reask": None, + "validation_passed": False, + "error": None, + }, + } + + del os.environ["PGHOST"] diff --git a/tests/api/test_root.py b/tests/api/test_root.py new file mode 100644 index 0000000..affbdb2 --- /dev/null +++ b/tests/api/test_root.py @@ -0,0 +1,59 @@ +import os +from fastapi.testclient import TestClient +from fastapi import FastAPI +import pytest + +from guardrails_api.utils.logger import logger +from tests.mocks.mock_postgres_client import MockPostgresClient + +# Assuming you have a similar structure in your FastAPI app +from guardrails_api.api import root + +@pytest.fixture +def app(): + app = FastAPI() + app.include_router(root.router) + return app + +@pytest.fixture +def client(app): + return TestClient(app) + +def test_home(client): + response = client.get("/") + assert response.status_code == 200 + assert response.json() == "Hello, FastAPI!" 
+ + # Check if all expected routes are registered + routes = [route.path for route in client.app.routes] + assert "/" in routes + assert "/health-check" in routes + assert "/openapi.json" in routes # This is FastAPI's equivalent to /api-docs + assert "/docs" in routes + +def test_health_check(client, mocker): + os.environ["PGHOST"] = "localhost" + + mock_pg = MockPostgresClient() + mock_pg.db.session._set_rows([(1,)]) + mocker.patch("guardrails_api.api.root.PostgresClient", return_value=mock_pg) + + def text_side_effect(query: str): + return query + + mock_text = mocker.patch( + "guardrails_api.api.root.text", side_effect=text_side_effect + ) + + info_spy = mocker.spy(logger, "info") + + response = client.get("/health-check") + + mock_text.assert_called_once_with("SELECT count(datid) FROM pg_stat_activity;") + assert mock_pg.db.session.queries == ["SELECT count(datid) FROM pg_stat_activity;"] + + info_spy.assert_called_once_with("response: %s", [(1,)]) + + assert response.json() == {"status": 200, "message": "Ok"} + + del os.environ["PGHOST"] \ No newline at end of file diff --git a/tests/blueprints/test_guards.py b/tests/blueprints/test_guards.py deleted file mode 100644 index 3cce59d..0000000 --- a/tests/blueprints/test_guards.py +++ /dev/null @@ -1,719 +0,0 @@ -import os -from unittest.mock import PropertyMock -from typing import Dict, Tuple - -import pytest - -from tests.mocks.mock_blueprint import MockBlueprint -from tests.mocks.mock_guard_client import MockGuardStruct -from tests.mocks.mock_request import MockRequest -from guardrails.classes import ValidationOutcome -from guardrails.classes.generic import Stack -from guardrails.classes.history import Call, Iteration -from guardrails_api.app import register_config -from guardrails.errors import ValidationError - -# TODO: Should we mock this somehow? 
-# Right now it's just empty, but it technically does a file read -register_config() - - -MOCK_GUARD_STRING = { - "id": "mock-guard-id", - "name": "mock-guard", - "description": "mock guard description", - "history": Stack(), -} - - -# FIXME: Why doesn't this work when running a single test? -# Either a config issue or a pytest issue -@pytest.fixture(autouse=True) -def around_each(): - # Code that will run before the test - openai_api_key_bak = os.environ.get("OPENAI_API_KEY") - if openai_api_key_bak: - del os.environ["OPENAI_API_KEY"] - yield - # Code that will run after the test - if openai_api_key_bak: - os.environ["OPENAI_API_KEY"] = openai_api_key_bak - - -def test_route_setup(mocker): - mocker.patch("flask.Blueprint", new=MockBlueprint) - - from guardrails_api.blueprints.guards import guards_bp - - assert guards_bp.route_call_count == 5 - assert guards_bp.routes == [ - "/", - "/", - "//openai/v1/chat/completions", - "//validate", - "//history/", - ] - - -def test_guards__get(mocker): - mock_guard = MockGuardStruct() - mock_request = MockRequest("GET") - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - mock_get_guards = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.get_guards", - return_value=[mock_guard], - ) - mocker.patch("guardrails_api.blueprints.guards.collect_telemetry") - - # >>> Conflict - # mock_get_guards = mocker.patch( - # "guardrails_api.blueprints.guards.guard_client.get_guards", return_value=[mock_guard] - # ) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - - from guardrails_api.blueprints.guards import guards - - response = guards() - - assert mock_get_guards.call_count == 1 - - assert response == [MOCK_GUARD_STRING] - - -def test_guards__post_pg(mocker): - os.environ["PGHOST"] = "localhost" - mock_guard = MockGuardStruct() - mock_request = MockRequest("POST", mock_guard.to_dict()) - - mocker.patch("flask.Blueprint", 
new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - mock_from_request = mocker.patch( - "guardrails_api.blueprints.guards.GuardStruct.from_dict", - return_value=mock_guard, - ) - mock_create_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.create_guard", - return_value=mock_guard, - ) - - from guardrails_api.blueprints.guards import guards - - response = guards() - - mock_from_request.assert_called_once_with(mock_guard.to_dict()) - mock_create_guard.assert_called_once_with(mock_guard) - - assert response == MOCK_GUARD_STRING - - del os.environ["PGHOST"] - - -def test_guards__post_mem(mocker): - mock_guard = MockGuardStruct() - mock_request = MockRequest("POST", mock_guard.to_dict()) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - - from guardrails_api.blueprints.guards import guards - - response = guards() - - error_body, status = response - - assert status == 501 - - -def test_guards__raises(mocker): - mock_request = MockRequest("PUT") - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - mocker.patch("guardrails_api.utils.handle_error.logger.error") - mocker.patch("guardrails_api.utils.handle_error.traceback.print_exception") - from guardrails_api.blueprints.guards import guards - - response = guards() - - assert isinstance(response, Tuple) - error, status = response - assert isinstance(error, Dict) - assert error.get("status") == 405 - assert error.get("message") == "Method Not Allowed" - assert ( - error.get("cause") - == "/guards only supports the GET and POST methods. 
You specified PUT" - ) - assert status == 405 - - -def test_guard__get_mem(mocker): - mock_guard = MockGuardStruct() - timestamp = "2024-03-04T14:11:42-06:00" - mock_request = MockRequest("GET", args={"asOf": timestamp}) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - - mock_get_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.get_guard", - return_value=mock_guard, - ) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - - # >>> Conflict - # mock_get_guard = mocker.patch( - # "guardrails_api.blueprints.guards.guard_client.get_guard", return_value=mock_guard - # ) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - - from guardrails_api.blueprints.guards import guard - - response = guard("My%20Guard's%20Name") - - mock_get_guard.assert_called_once_with("My Guard's Name", timestamp) - assert response == MOCK_GUARD_STRING - - -def test_guard__put_pg(mocker): - os.environ["PGHOST"] = "localhost" - mock_guard = MockGuardStruct() - json_guard = { - "name": "mock-guard", - "id": "mock-guard-id", - "description": "mock guard description", - "history": Stack(), - } - mock_request = MockRequest("PUT", json=json_guard) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - - mock_from_request = mocker.patch( - "guardrails_api.blueprints.guards.GuardStruct.from_dict", - return_value=mock_guard, - ) - mock_upsert_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.upsert_guard", - return_value=mock_guard, - ) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - - # >>> Conflict - # mock_from_request = mocker.patch( - # "guardrails_api.blueprints.guards.GuardStruct.from_request", return_value=mock_guard - # ) - # mock_upsert_guard = mocker.patch( - # "guardrails_api.blueprints.guards.guard_client.upsert_guard", return_value=mock_guard - # ) - # 
mocker.patch("guardrails_api.blueprints.guards.get_tracer") - - from guardrails_api.blueprints.guards import guard - - response = guard("My%20Guard's%20Name") - - mock_from_request.assert_called_once_with(json_guard) - mock_upsert_guard.assert_called_once_with("My Guard's Name", mock_guard) - assert response == MOCK_GUARD_STRING - del os.environ["PGHOST"] - - -def test_guard__delete_pg(mocker): - os.environ["PGHOST"] = "localhost" - mock_guard = MockGuardStruct() - mock_request = MockRequest("DELETE") - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - - mock_delete_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.delete_guard", - return_value=mock_guard, - ) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - - # >>> Conflict - # mock_delete_guard = mocker.patch( - # "guardrails_api.blueprints.guards.guard_client.delete_guard", return_value=mock_guard - # ) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - - from guardrails_api.blueprints.guards import guard - - response = guard("my-guard-name") - - mock_delete_guard.assert_called_once_with("my-guard-name") - assert response == MOCK_GUARD_STRING - del os.environ["PGHOST"] - - -def test_guard__raises(mocker): - mock_request = MockRequest("POST") - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - mocker.patch("guardrails_api.utils.handle_error.logger.error") - mocker.patch("guardrails_api.utils.handle_error.traceback.print_exception") - from guardrails_api.blueprints.guards import guard - - response = guard("guard") - - assert isinstance(response, Tuple) - error, status = response - assert isinstance(error, Dict) - assert error.get("status") == 405 - assert error.get("message") == "Method Not Allowed" - assert ( - error.get("cause") - == "/guard/ 
only supports the GET, PUT, and DELETE methods. You specified POST" - ) - assert status == 405 - - -def test_validate__raises_method_not_allowed(mocker): - mock_request = MockRequest("PUT") - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - mocker.patch("guardrails_api.utils.handle_error.logger.error") - mocker.patch("guardrails_api.utils.handle_error.traceback.print_exception") - from guardrails_api.blueprints.guards import validate - - response = validate("guard") - - assert isinstance(response, Tuple) - error, status = response - assert isinstance(error, Dict) - assert error.get("status") == 405 - assert error.get("message") == "Method Not Allowed" - assert ( - error.get("cause") - == "/guards//validate only supports the POST method. You specified PUT" - ) - assert status == 405 - - -def test_validate__raises_bad_request__openai_api_key(mocker): - os.environ["PGHOST"] = "localhost" - mock_guard = MockGuardStruct() - # mock_tracer = MockTracer() - mock_request = MockRequest("POST", json={"llmApi": "bar"}) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.get_guard", - return_value=mock_guard, - ) - - # mocker.patch("guardrails_api.blueprints.guards.get_tracer", return_value=mock_tracer) - mocker.patch("guardrails_api.utils.handle_error.logger.error") - mocker.patch("guardrails_api.utils.handle_error.traceback.print_exception") - from guardrails_api.blueprints.guards import validate - - response = validate("mock-guard") - - mock_get_guard.assert_called_once_with("mock-guard") - - assert isinstance(response, Tuple) - error, status = response - assert isinstance(error, Dict) - assert error.get("status") == 400 - assert error.get("message") == "BadRequest" - assert 
error.get("cause") == ( - "Cannot perform calls to OpenAI without an api key. Pass" - " openai_api_key when initializing the Guard or set the" - " OPENAI_API_KEY environment variable." - ) - assert status == 400 - del os.environ["PGHOST"] - - -def test_validate__raises_bad_request__num_reasks(mocker): - os.environ["PGHOST"] = "localhost" - mock_guard = MockGuardStruct() - # mock_tracer = MockTracer() - mock_request = MockRequest("POST", json={"numReasks": 3}) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.get_guard", - return_value=mock_guard, - ) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer", return_value=mock_tracer) - mocker.patch("guardrails_api.utils.handle_error.logger.error") - mocker.patch("guardrails_api.utils.handle_error.traceback.print_exception") - from guardrails_api.blueprints.guards import validate - - response = validate("mock-guard") - - mock_get_guard.assert_called_once_with("mock-guard") - - assert isinstance(response, Tuple) - error, status = response - assert isinstance(error, Dict) - assert error.get("status") == 400 - assert error.get("message") == "BadRequest" - assert error.get("cause") == ( - "Cannot perform re-asks without an LLM API. Specify llm_api when" - " calling guard(...)." 
- ) - assert status == 400 - del os.environ["PGHOST"] - - -def test_validate__parse(mocker): - os.environ["PGHOST"] = "localhost" - mock_outcome = ValidationOutcome( - call_id="mock-call-id", - raw_llm_output="Hello world!", - validated_output="Hello world!", - validation_passed=True, - ) - - mock_parse = mocker.patch.object(MockGuardStruct, "parse") - mock_parse.return_value = mock_outcome - - mock_guard = MockGuardStruct() - mock_from_dict = mocker.patch("guardrails_api.blueprints.guards.Guard.from_dict") - mock_from_dict.return_value = mock_guard - - # mock_tracer = MockTracer() - mock_request = MockRequest( - "POST", - json={"llmOutput": "Hello world!", "args": [1, 2, 3], "some_kwarg": "foo"}, - ) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.get_guard", - return_value=mock_guard, - ) - - mocker.patch("guardrails_api.blueprints.guards.CacheClient.set") - - # mocker.patch("guardrails_api.blueprints.guards.get_tracer", return_value=mock_tracer) - - # >>> Conflict - # mocker.patch("guardrails_api.blueprints.guards.get_tracer", return_value=mock_tracer) - - # set_attribute_spy = mocker.spy(mock_tracer.span, "set_attribute") - - mock_status = mocker.patch( - "guardrails.classes.history.call.Call.status", new_callable=PropertyMock - ) - mock_status.return_value = "pass" - mock_guard.history = Stack(Call()) - from guardrails_api.blueprints.guards import validate - - response = validate("My%20Guard's%20Name") - - mock_get_guard.assert_called_once_with("My Guard's Name") - - assert mock_parse.call_count == 1 - - mock_parse.assert_called_once_with( - llm_output="Hello world!", - num_reasks=None, - prompt_params={}, - llm_api=None, - some_kwarg="foo", - api_key=None, - ) - - # Temporarily Disabled - # assert set_attribute_spy.call_count == 7 - # expected_calls = [ - # call("guardName", "My Guard's Name"), - # 
call("prompt", "Hello world prompt!"), - # call("validation_status", "pass"), - # call("raw_llm_ouput", "Hello world!"), - # call("validated_output", "Hello world!"), - # call("tokens_consumed", None), - # call("num_of_reasks", 0), - # ] - # set_attribute_spy.assert_has_calls(expected_calls) - - assert response == { - "callId": "mock-call-id", - "validatedOutput": "Hello world!", - "validationPassed": True, - "rawLlmOutput": "Hello world!", - } - - del os.environ["PGHOST"] - - -def test_validate__call(mocker): - os.environ["PGHOST"] = "localhost" - mock_guard = MockGuardStruct() - mock_outcome = ValidationOutcome( - call_id="mock-call-id", - raw_llm_output="Hello world!", - validated_output=None, - validation_passed=False, - ) - - mock___call__ = mocker.patch.object(MockGuardStruct, "__call__") - mock___call__.return_value = mock_outcome - - mock_guard = MockGuardStruct() - mock_from_dict = mocker.patch("guardrails_api.blueprints.guards.Guard.from_dict") - mock_from_dict.return_value = mock_guard - - # mock_tracer = MockTracer() - mock_request = MockRequest( - "POST", - json={ - "llmApi": "openai.Completion.create", - "promptParams": {"p1": "bar"}, - "args": [1, 2, 3], - "some_kwarg": "foo", - "prompt": "Hello world!", - }, - headers={"x-openai-api-key": "mock-key"}, - ) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.get_guard", - return_value=mock_guard, - ) - mocker.patch( - "guardrails_api.blueprints.guards.get_llm_callable", - return_value="openai.Completion.create", - ) - - mocker.patch("guardrails_api.blueprints.guards.CacheClient.set") - - # mocker.patch("guardrails_api.blueprints.guards.get_tracer", return_value=mock_tracer) - - # >>> Conflict - # mocker.patch("guardrails_api.blueprints.guards.get_tracer", return_value=mock_tracer) - - # set_attribute_spy = mocker.spy(mock_tracer.span, 
"set_attribute") - - mock_status = mocker.patch( - "guardrails.classes.history.call.Call.status", new_callable=PropertyMock - ) - mock_status.return_value = "fail" - mock_guard.history = Stack(Call()) - from guardrails_api.blueprints.guards import validate - - response = validate("My%20Guard's%20Name") - - mock_get_guard.assert_called_once_with("My Guard's Name") - - assert mock___call__.call_count == 1 - - mock___call__.assert_called_once_with( - 1, - 2, - 3, - llm_api="openai.Completion.create", - prompt_params={"p1": "bar"}, - num_reasks=None, - some_kwarg="foo", - api_key="mock-key", - prompt="Hello world!", - ) - - # Temporarily Disabled - # assert set_attribute_spy.call_count == 8 - # expected_calls = [ - # call("guardName", "My Guard's Name"), - # call("prompt", "Hello world prompt!"), - # call("instructions", "Hello world instructions!"), - # call("validation_status", "fail"), - # call("raw_llm_ouput", "Hello world!"), - # call("validated_output", "None"), - # call("tokens_consumed", None), - # call("num_of_reasks", 0), - # ] - # set_attribute_spy.assert_has_calls(expected_calls) - - assert response == { - "callId": "mock-call-id", - "validationPassed": False, - "validatedOutput": None, - "rawLlmOutput": "Hello world!", - } - - del os.environ["PGHOST"] - -def test_validate__call_throws_validation_error(mocker): - os.environ["PGHOST"] = "localhost" - - mock___call__ = mocker.patch.object(MockGuardStruct, "__call__") - mock___call__.side_effect = ValidationError("Test guard validation error") - - mock_guard = MockGuardStruct() - mock_from_dict = mocker.patch("guardrails_api.blueprints.guards.Guard.from_dict") - mock_from_dict.return_value = mock_guard - - # mock_tracer = MockTracer() - mock_request = MockRequest( - "POST", - json={ - "llmApi": "openai.Completion.create", - "promptParams": {"p1": "bar"}, - "args": [1, 2, 3], - "some_kwarg": "foo", - "prompt": "Hello world!", - }, - headers={"x-openai-api-key": "mock-key"}, - ) - - 
mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.get_guard", - return_value=mock_guard, - ) - mocker.patch( - "guardrails_api.blueprints.guards.get_llm_callable", - return_value="openai.Completion.create", - ) - - mocker.patch("guardrails_api.blueprints.guards.CacheClient.set") - - mock_status = mocker.patch( - "guardrails.classes.history.call.Call.status", new_callable=PropertyMock - ) - mock_status.return_value = "fail" - mock_guard.history = Stack(Call()) - from guardrails_api.blueprints.guards import validate - - response = validate("My%20Guard's%20Name") - - mock_get_guard.assert_called_once_with("My Guard's Name") - - assert mock___call__.call_count == 1 - - mock___call__.assert_called_once_with( - 1, - 2, - 3, - llm_api="openai.Completion.create", - prompt_params={"p1": "bar"}, - num_reasks=None, - some_kwarg="foo", - api_key="mock-key", - prompt="Hello world!", - ) - - assert response == ('Test guard validation error', 400) - - del os.environ["PGHOST"] - -def test_openai_v1_chat_completions__raises_404(mocker): - from guardrails_api.blueprints.guards import openai_v1_chat_completions - os.environ["PGHOST"] = "localhost" - mock_guard = None - - mock_request = MockRequest( - "POST", - json={ - "messages": [{"role":"user", "content":"Hello world!"}], - }, - headers={"x-openai-api-key": "mock-key"}, - ) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.get_guard", - return_value=mock_guard, - ) - mocker.patch("guardrails_api.blueprints.guards.CacheClient.set") - - response = openai_v1_chat_completions("My%20Guard's%20Name") - assert response[1] == 404 - assert response[0]["message"] == 'NotFound' - - - mock_get_guard.assert_called_once_with("My 
Guard's Name") - - del os.environ["PGHOST"] - -def test_openai_v1_chat_completions__call(mocker): - from guardrails_api.blueprints.guards import openai_v1_chat_completions - os.environ["PGHOST"] = "localhost" - mock_guard = MockGuardStruct() - mock_outcome = ValidationOutcome( - call_id="mock-call-id", - raw_llm_output="Hello world!", - validated_output="Hello world!", - validation_passed=False, - ) - - mock___call__ = mocker.patch.object(MockGuardStruct, "__call__") - mock___call__.return_value = mock_outcome - - mock_from_dict = mocker.patch("guardrails_api.blueprints.guards.Guard.from_dict") - mock_from_dict.return_value = mock_guard - - mock_request = MockRequest( - "POST", - json={ - "messages": [{"role":"user", "content":"Hello world!"}], - }, - headers={"x-openai-api-key": "mock-key"}, - ) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.get_guard", - return_value=mock_guard, - ) - mocker.patch( - "guardrails_api.blueprints.guards.get_llm_callable", - return_value="openai.Completion.create", - ) - - mocker.patch("guardrails_api.blueprints.guards.CacheClient.set") - - mock_status = mocker.patch( - "guardrails.classes.history.call.Call.status", new_callable=PropertyMock - ) - mock_status.return_value = "fail" - mock_call = Call() - mock_call.iterations= Stack(Iteration('some-id', 1)) - mock_guard.history = Stack(mock_call) - - response = openai_v1_chat_completions("My%20Guard's%20Name") - - mock_get_guard.assert_called_once_with("My Guard's Name") - - assert mock___call__.call_count == 1 - - mock___call__.assert_called_once_with( - num_reasks=0, - messages=[{"role":"user", "content":"Hello world!"}], - ) - - assert response == { - "choices": [ - { - "message": { - "content": "Hello world!", - }, - } - ], - "guardrails": { - "reask": None, - "validation_passed": False, - "error": None, - }, - } - - del 
os.environ["PGHOST"] \ No newline at end of file diff --git a/tests/blueprints/test_root.py b/tests/blueprints/test_root.py deleted file mode 100644 index 7ef611f..0000000 --- a/tests/blueprints/test_root.py +++ /dev/null @@ -1,48 +0,0 @@ -import os -from guardrails_api.utils.logger import logger -from tests.mocks.mock_blueprint import MockBlueprint -from tests.mocks.mock_postgres_client import MockPostgresClient - - -def test_home(mocker): - mocker.patch("flask.Blueprint", new=MockBlueprint) - from guardrails_api.blueprints.root import home, root_bp - - response = home() - - assert root_bp.route_call_count == 4 - assert root_bp.routes == ["/", "/health-check", "/api-docs", "/docs"] - assert response == "Hello, Flask!" - - mocker.resetall() - - -def test_health_check(mocker): - os.environ["PGHOST"] = "localhost" - mocker.patch("flask.Blueprint", new=MockBlueprint) - - mock_pg = MockPostgresClient() - mock_pg.db.session._set_rows([(1,)]) - mocker.patch("guardrails_api.blueprints.root.PostgresClient", return_value=mock_pg) - - def text_side_effect(query: str): - return query - - mock_text = mocker.patch( - "guardrails_api.blueprints.root.text", side_effect=text_side_effect - ) - - from guardrails_api.blueprints.root import health_check - - info_spy = mocker.spy(logger, "info") - - response = health_check() - - mock_text.assert_called_once_with("SELECT count(datid) FROM pg_stat_activity;") - assert mock_pg.db.session.queries == ["SELECT count(datid) FROM pg_stat_activity;"] - - info_spy.assert_called_once_with("response: %s", [(1,)]) - assert response == {"status": 200, "message": "Ok"} - - mocker.resetall() - del os.environ["PGHOST"] From 78fde5c10f8627ba34b1b0d133aff9a877cf48a4 Mon Sep 17 00:00:00 2001 From: David Tam Date: Fri, 13 Sep 2024 13:22:13 -0700 Subject: [PATCH 02/13] lint --- guardrails_api/api/guards.py | 94 ++++++++++++++----- guardrails_api/api/root.py | 12 ++- guardrails_api/app.py | 25 +++-- guardrails_api/cli/start.py | 1 + 
guardrails_api/clients/cache_client.py | 24 ++--- guardrails_api/utils/configuration.py | 26 +++-- guardrails_api/utils/handle_error.py | 22 ++++- .../utils/has_internet_connection.py | 2 +- guardrails_api/utils/openai.py | 1 + .../utils/trace_server_start_if_enabled.py | 3 +- tests/api/test_guards.py | 94 ++++++++++++------- tests/api/test_root.py | 6 +- tests/cli/test_start.py | 2 + tests/mocks/mock_guard_client.py | 1 + tests/utils/test_configuration.py | 7 +- 15 files changed, 222 insertions(+), 98 deletions(-) diff --git a/guardrails_api/api/guards.py b/guardrails_api/api/guards.py index 076c2d8..8621e57 100644 --- a/guardrails_api/api/guards.py +++ b/guardrails_api/api/guards.py @@ -1,11 +1,9 @@ -import asyncio import json import os import inspect -from typing import Any, Dict, List, Optional -from fastapi import FastAPI, HTTPException, Request, Response, APIRouter +from typing import Any, Dict, Optional +from fastapi import HTTPException, Request, APIRouter from fastapi.responses import JSONResponse, StreamingResponse -from pydantic import BaseModel from urllib.parse import unquote_plus from guardrails import AsyncGuard, Guard from guardrails.classes import ValidationOutcome @@ -16,7 +14,10 @@ from guardrails_api.clients.pg_guard_client import PGGuardClient from guardrails_api.clients.postgres_client import postgres_is_enabled from guardrails_api.utils.get_llm_callable import get_llm_callable -from guardrails_api.utils.openai import outcome_to_chat_completion, outcome_to_stream_response +from guardrails_api.utils.openai import ( + outcome_to_chat_completion, + outcome_to_stream_response, +) from guardrails_api.utils.handle_error import handle_error from string import Template @@ -39,47 +40,65 @@ router = APIRouter() + @router.get("/guards") @handle_error async def get_guards(): guards = guard_client.get_guards() return [g.to_dict() for g in guards] + @router.post("/guards") @handle_error async def create_guard(guard: GuardStruct): if not 
postgres_is_enabled(): - raise HTTPException(status_code=501, detail="Not Implemented POST /guards is not implemented for in-memory guards.") + raise HTTPException( + status_code=501, + detail="Not Implemented POST /guards is not implemented for in-memory guards.", + ) new_guard = guard_client.create_guard(guard) return new_guard.to_dict() + @router.get("/guards/{guard_name}") @handle_error async def get_guard(guard_name: str, asOf: Optional[str] = None): decoded_guard_name = unquote_plus(guard_name) guard = guard_client.get_guard(decoded_guard_name, asOf) if guard is None: - raise HTTPException(status_code=404, detail=f"A Guard with the name {decoded_guard_name} does not exist!") + raise HTTPException( + status_code=404, + detail=f"A Guard with the name {decoded_guard_name} does not exist!", + ) return guard.to_dict() + @router.put("/guards/{guard_name}") @handle_error async def update_guard(guard_name: str, guard: GuardStruct): if not postgres_is_enabled(): - raise HTTPException(status_code=501, detail="PUT / is not implemented for in-memory guards.") + raise HTTPException( + status_code=501, + detail="PUT / is not implemented for in-memory guards.", + ) decoded_guard_name = unquote_plus(guard_name) updated_guard = guard_client.upsert_guard(decoded_guard_name, guard) return updated_guard.to_dict() + @router.delete("/guards/{guard_name}") @handle_error async def delete_guard(guard_name: str): if not postgres_is_enabled(): - raise HTTPException(status_code=501, detail="DELETE / is not implemented for in-memory guards.") + raise HTTPException( + status_code=501, + detail="DELETE / is not implemented for in-memory guards.", + ) decoded_guard_name = unquote_plus(guard_name) guard = guard_client.delete_guard(decoded_guard_name) return guard.to_dict() + @router.post("/guards/{guard_name}/openai/v1/chat/completions") @handle_error async def openai_v1_chat_completions(guard_name: str, request: Request): @@ -87,11 +106,21 @@ async def openai_v1_chat_completions(guard_name: 
str, request: Request): decoded_guard_name = unquote_plus(guard_name) guard_struct = guard_client.get_guard(decoded_guard_name) if guard_struct is None: - raise HTTPException(status_code=404, detail=f"A Guard with the name {decoded_guard_name} does not exist!") + raise HTTPException( + status_code=404, + detail=f"A Guard with the name {decoded_guard_name} does not exist!", + ) - guard = Guard.from_dict(guard_struct.to_dict()) if not isinstance(guard_struct, Guard) else guard_struct + guard = ( + Guard.from_dict(guard_struct.to_dict()) + if not isinstance(guard_struct, Guard) + else guard_struct + ) stream = payload.get("stream", False) - has_tool_gd_tool_call = any(tool.get("function", {}).get("name") == "gd_response_tool" for tool in payload.get("tools", [])) + has_tool_gd_tool_call = any( + tool.get("function", {}).get("name") == "gd_response_tool" + for tool in payload.get("tools", []) + ) if not stream: validation_outcome: ValidationOutcome = guard(num_reasks=0, **payload) @@ -103,23 +132,29 @@ async def openai_v1_chat_completions(guard_name: str, request: Request): ) return JSONResponse(content=result) else: + async def openai_streamer(): guard_stream = guard(num_reasks=0, **payload) for result in guard_stream: - chunk = json.dumps(outcome_to_stream_response(validation_outcome=result)) + chunk = json.dumps( + outcome_to_stream_response(validation_outcome=result) + ) yield f"data: {chunk}\n\n" yield "\n" return StreamingResponse(openai_streamer(), media_type="text/event-stream") + @router.post("/guards/{guard_name}/validate") @handle_error async def validate(guard_name: str, request: Request): payload = await request.json() - openai_api_key = request.headers.get("x-openai-api-key", os.environ.get("OPENAI_API_KEY")) + openai_api_key = request.headers.get( + "x-openai-api-key", os.environ.get("OPENAI_API_KEY") + ) decoded_guard_name = unquote_plus(guard_name) guard_struct = guard_client.get_guard(decoded_guard_name) - + llm_output = payload.pop("llmOutput", None) 
num_reasks = payload.pop("numReasks", None) prompt_params = payload.pop("promptParams", {}) @@ -132,25 +167,33 @@ async def validate(guard_name: str, request: Request): if llm_api is not None: llm_api = get_llm_callable(llm_api) if openai_api_key is None: - raise HTTPException(status_code=400, detail="Cannot perform calls to OpenAI without an api key.") + raise HTTPException( + status_code=400, + detail="Cannot perform calls to OpenAI without an api key.", + ) guard = guard_struct is_async = inspect.iscoroutinefunction(llm_api) - + if not isinstance(guard_struct, Guard): if is_async: guard = AsyncGuard.from_dict(guard_struct.to_dict()) else: guard: Guard = Guard.from_dict(guard_struct.to_dict()) elif is_async: - guard:Guard = AsyncGuard.from_dict(guard_struct.to_dict()) + guard: Guard = AsyncGuard.from_dict(guard_struct.to_dict()) if llm_api is None and num_reasks and num_reasks > 1: - raise HTTPException(status_code=400, detail="Cannot perform re-asks without an LLM API. Specify llm_api when calling guard(...).") + raise HTTPException( + status_code=400, + detail="Cannot perform re-asks without an LLM API. Specify llm_api when calling guard(...).", + ) if llm_output is not None: if stream: - raise HTTPException(status_code=400, detail="Streaming is not supported for parse calls!") + raise HTTPException( + status_code=400, detail="Streaming is not supported for parse calls!" 
+ ) result: ValidationOutcome = guard.parse( llm_output=llm_output, num_reasks=num_reasks, @@ -160,6 +203,7 @@ async def validate(guard_name: str, request: Request): ) else: if stream: + async def guard_streamer(): guard_stream = guard( llm_api=llm_api, @@ -170,7 +214,9 @@ async def guard_streamer(): **payload, ) for result in guard_stream: - validation_output = ValidationOutcome.from_guard_history(guard.history.last) + validation_output = ValidationOutcome.from_guard_history( + guard.history.last + ) yield validation_output, result async def validate_streamer(guard_iter): @@ -201,7 +247,9 @@ async def validate_streamer(guard_iter): cache_key = f"{guard.name}-{final_validation_output.call_id}" await cache_client.set(cache_key, serialized_history, 300) - return StreamingResponse(validate_streamer(guard_streamer()), media_type="application/json") + return StreamingResponse( + validate_streamer(guard_streamer()), media_type="application/json" + ) else: result: ValidationOutcome = guard( llm_api=llm_api, @@ -216,12 +264,14 @@ async def validate_streamer(guard_iter): # await cache_client.set(cache_key, serialized_history, 300) return result.to_dict() + @router.get("/guards/{guard_name}/history/{call_id}") @handle_error async def guard_history(guard_name: str, call_id: str): cache_key = f"{guard_name}-{call_id}" return await cache_client.get(cache_key) + def collect_telemetry( *, guard: Guard, diff --git a/guardrails_api/api/root.py b/guardrails_api/api/root.py index 2b0a73f..d818923 100644 --- a/guardrails_api/api/root.py +++ b/guardrails_api/api/root.py @@ -1,7 +1,5 @@ import os -import json from string import Template -from typing import Dict from fastapi import HTTPException, APIRouter from fastapi.responses import HTMLResponse, JSONResponse @@ -13,16 +11,20 @@ from guardrails_api.clients.postgres_client import PostgresClient, postgres_is_enabled from guardrails_api.utils.logger import logger + class HealthCheckResponse(BaseModel): status: int message: str + router = 
APIRouter() + @router.get("/") async def home(): return "Hello, FastAPI!" + @router.get("/health-check", response_model=HealthCheckResponse) async def health_check(): try: @@ -32,19 +34,21 @@ async def health_check(): pg_client = PostgresClient() query = text("SELECT count(datid) FROM pg_stat_activity;") response = pg_client.db.session.execute(query).all() - + logger.info("response: %s", response) - + return HealthCheck(200, "Ok").to_dict() except Exception as e: logger.error(f"Health check failed: {str(e)}") raise HTTPException(status_code=500, detail="Internal Server Error") + @router.get("/api-docs", response_class=JSONResponse) async def api_docs(): api_spec = get_open_api_spec() return JSONResponse(content=api_spec) + @router.get("/docs", response_class=HTMLResponse) async def docs(): host = os.environ.get("SELF_ENDPOINT", "http://localhost:8000") diff --git a/guardrails_api/app.py b/guardrails_api/app.py index 7e05bf0..64e5180 100644 --- a/guardrails_api/app.py +++ b/guardrails_api/app.py @@ -3,10 +3,11 @@ from fastapi.responses import JSONResponse from guardrails import configure_logging from guardrails_api.clients.cache_client import CacheClient -from guardrails_api.clients.cache_client import CacheClient from guardrails_api.clients.postgres_client import postgres_is_enabled from guardrails_api.otel import otel_is_disabled, initialize -from guardrails_api.utils.trace_server_start_if_enabled import trace_server_start_if_enabled +from guardrails_api.utils.trace_server_start_if_enabled import ( + trace_server_start_if_enabled, +) from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from rich.console import Console from rich.rule import Rule @@ -68,6 +69,7 @@ # else: # return await call_next(request) + # Custom JSON encoder class CustomJSONEncoder(json.JSONEncoder): def default(self, o): @@ -77,6 +79,7 @@ def default(self, o): return str(o) return super().default(o) + # Custom middleware for reverse proxy class 
ReverseProxyMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): @@ -86,6 +89,7 @@ async def dispatch(self, request: Request, call_next): response = await call_next(request) return response + def register_config(config: Optional[str] = None): default_config_file = os.path.join(os.getcwd(), "./config.py") config_file = config or default_config_file @@ -95,6 +99,7 @@ def register_config(config: Optional[str] = None): config_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(config_module) + def create_app( env: Optional[str] = None, config: Optional[str] = None, port: Optional[int] = None ): @@ -126,7 +131,7 @@ def create_app( FastAPIInstrumentor.instrument_app(app) # app.add_middleware(ProfilingMiddleware) - + # Add CORS middleware app.add_middleware( CORSMiddleware, @@ -169,10 +174,10 @@ async def value_error_handler(request: Request, exc: ValueError): content={"message": str(exc)}, ) + console.print(f"\n:rocket: Guardrails API is available at {self_endpoint}") console.print( - f"\n:rocket: Guardrails API is available at {self_endpoint}" + f":book: Visit {self_endpoint}/docs to see available API endpoints.\n" ) - console.print(f":book: Visit {self_endpoint}/docs to see available API endpoints.\n") console.print(":green_circle: Active guards and OpenAI compatible endpoints:") @@ -180,14 +185,20 @@ async def value_error_handler(request: Request, exc: ValueError): for g in guards: g_dict = g.to_dict() - console.print(f"- Guard: [bold white]{g_dict.get('name')}[/bold white] {self_endpoint}/guards/{g_dict.get('name')}/openai/v1") + console.print( + f"- Guard: [bold white]{g_dict.get('name')}[/bold white] {self_endpoint}/guards/{g_dict.get('name')}/openai/v1" + ) console.print("") - console.print(Rule("[bold grey]Server Logs[/bold grey]", characters="=", style="white")) + console.print( + Rule("[bold grey]Server Logs[/bold grey]", characters="=", style="white") + ) return app + if __name__ == "__main__": import 
uvicorn + app = create_app() uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/guardrails_api/cli/start.py b/guardrails_api/cli/start.py index eb8f027..6ac66fa 100644 --- a/guardrails_api/cli/start.py +++ b/guardrails_api/cli/start.py @@ -4,6 +4,7 @@ from guardrails_api.app import create_app from guardrails_api.utils.configuration import valid_configuration + @cli.command("start") def start( env: Optional[str] = typer.Option( diff --git a/guardrails_api/clients/cache_client.py b/guardrails_api/clients/cache_client.py index fa52d3a..4d80a63 100644 --- a/guardrails_api/clients/cache_client.py +++ b/guardrails_api/clients/cache_client.py @@ -1,7 +1,7 @@ import threading from fastapi import FastAPI -from aiocache import caches, Cache -from aiocache.serializers import JsonSerializer +from aiocache import caches + class CacheClient: _instance = None @@ -15,16 +15,16 @@ def __new__(cls): return cls._instance def initialize(self, app: FastAPI): - caches.set_config({ - 'default': { - 'cache': "aiocache.SimpleMemoryCache", - 'serializer': { - 'class': "aiocache.serializers.JsonSerializer" - }, - 'ttl': 300 + caches.set_config( + { + "default": { + "cache": "aiocache.SimpleMemoryCache", + "serializer": {"class": "aiocache.serializers.JsonSerializer"}, + "ttl": 300, + } } - }) - self.cache = caches.get('default') + ) + self.cache = caches.get("default") async def get(self, key: str): return await self.cache.get(key) @@ -36,4 +36,4 @@ async def delete(self, key: str): await self.cache.delete(key) async def clear(self): - await self.cache.clear() \ No newline at end of file + await self.cache.clear() diff --git a/guardrails_api/utils/configuration.py b/guardrails_api/utils/configuration.py index 1fdb965..f793a68 100644 --- a/guardrails_api/utils/configuration.py +++ b/guardrails_api/utils/configuration.py @@ -2,21 +2,31 @@ from typing import Optional import os -def valid_configuration(config: Optional[str]=""): + +def valid_configuration(config: Optional[str] = ""): 
default_config_file = os.path.join(os.getcwd(), "./config.py") default_config_file_path = os.path.abspath(default_config_file) - # If config.py is not present and + # If config.py is not present and # if a config filepath is not passed and - # if postgres is not there (i.e. we’re using in-mem db) + # if postgres is not there (i.e. we’re using in-mem db) # then raise ConfigurationError has_default_config_file = os.path.isfile(default_config_file_path) - has_config_file = (config != "" and config is not None) and os.path.isfile(os.path.abspath(config)) - if not has_default_config_file and not has_config_file and not postgres_is_enabled(): - raise ConfigurationError("Can not start. Configuration not provided and default" - " configuration not found and postgres is not enabled.") + has_config_file = (config != "" and config is not None) and os.path.isfile( + os.path.abspath(config) + ) + if ( + not has_default_config_file + and not has_config_file + and not postgres_is_enabled() + ): + raise ConfigurationError( + "Can not start. Configuration not provided and default" + " configuration not found and postgres is not enabled." 
+ ) return True + class ConfigurationError(Exception): - pass \ No newline at end of file + pass diff --git a/guardrails_api/utils/handle_error.py b/guardrails_api/utils/handle_error.py index cf7cb07..1458d9d 100644 --- a/guardrails_api/utils/handle_error.py +++ b/guardrails_api/utils/handle_error.py @@ -7,6 +7,7 @@ from fastapi import HTTPException + def handle_error(func=None): def decorator(fn): @wraps(fn) @@ -15,22 +16,33 @@ async def wrapper(*args, **kwargs): return await fn(*args, **kwargs) except ValidationError as validation_error: logger.error(validation_error) - traceback.print_exception(type(validation_error), validation_error, validation_error.__traceback__) + traceback.print_exception( + type(validation_error), + validation_error, + validation_error.__traceback__, + ) raise HTTPException(status_code=400, detail=str(validation_error)) except HttpError as http_error: logger.error(http_error) - traceback.print_exception(type(http_error), http_error, http_error.__traceback__) - raise HTTPException(status_code=http_error.status_code, detail=http_error.detail) + traceback.print_exception( + type(http_error), http_error, http_error.__traceback__ + ) + raise HTTPException( + status_code=http_error.status_code, detail=http_error.detail + ) except HTTPException as http_exception: logger.error(http_exception) - traceback.print_exception(type(http_exception), http_exception, http_exception.__traceback__) + traceback.print_exception( + type(http_exception), http_exception, http_exception.__traceback__ + ) raise except Exception as e: logger.error(e) traceback.print_exception(type(e), e, e.__traceback__) raise HTTPException(status_code=500, detail="Internal Server Error") + return wrapper - + if func: return decorator(func) return decorator diff --git a/guardrails_api/utils/has_internet_connection.py b/guardrails_api/utils/has_internet_connection.py index 8a7099c..1a92721 100644 --- a/guardrails_api/utils/has_internet_connection.py +++ 
b/guardrails_api/utils/has_internet_connection.py @@ -7,4 +7,4 @@ def has_internet_connection() -> bool: res.raise_for_status() return True except requests.ConnectionError: - return False \ No newline at end of file + return False diff --git a/guardrails_api/utils/openai.py b/guardrails_api/utils/openai.py index 10ecfa2..79cc7b5 100644 --- a/guardrails_api/utils/openai.py +++ b/guardrails_api/utils/openai.py @@ -1,5 +1,6 @@ from guardrails.classes import ValidationOutcome + def outcome_to_stream_response(validation_outcome: ValidationOutcome): stream_chunk_template = { "choices": [ diff --git a/guardrails_api/utils/trace_server_start_if_enabled.py b/guardrails_api/utils/trace_server_start_if_enabled.py index 467abd6..91fbbcf 100644 --- a/guardrails_api/utils/trace_server_start_if_enabled.py +++ b/guardrails_api/utils/trace_server_start_if_enabled.py @@ -8,6 +8,7 @@ def trace_server_start_if_enabled(): config = Credentials.from_rc_file() if config.enable_metrics is True and has_internet_connection(): from guardrails.utils.hub_telemetry_utils import HubTelemetry + HubTelemetry().create_new_span( "guardrails-api/start", [ @@ -21,4 +22,4 @@ def trace_server_start_if_enabled(): ], True, False, - ) \ No newline at end of file + ) diff --git a/tests/api/test_guards.py b/tests/api/test_guards.py index 83ec28f..453a976 100644 --- a/tests/api/test_guards.py +++ b/tests/api/test_guards.py @@ -1,6 +1,5 @@ import os from unittest.mock import PropertyMock -from typing import Dict, Tuple import pytest from fastapi.testclient import TestClient @@ -11,16 +10,16 @@ from guardrails.classes.history import Call, Iteration from guardrails.errors import ValidationError -# Assuming these imports exist in your FastAPI project from guardrails_api.app import register_config from tests.mocks.mock_guard_client import MockGuardStruct +from guardrails_api.api.guards import router as guards_router # TODO: Should we mock this somehow? 
# Right now it's just empty, but it technically does a file read register_config() app = FastAPI() -from guardrails_api.api.guards import router as guards_router + app.include_router(guards_router) client = TestClient(app) @@ -31,6 +30,7 @@ "history": Stack(), } + @pytest.fixture(autouse=True) def around_each(): # Code that will run before the test @@ -42,6 +42,7 @@ def around_each(): if openai_api_key_bak: os.environ["OPENAI_API_KEY"] = openai_api_key_bak + def test_guards__get(mocker): mock_guard = MockGuardStruct() mock_get_guards = mocker.patch( @@ -56,27 +57,27 @@ def test_guards__get(mocker): assert response.status_code == 200 assert response.json() == [MOCK_GUARD_STRING] + def test_guards__post_pg(mocker): os.environ["PGHOST"] = "localhost" mock_guard = MockGuardStruct() - mock_from_request = mocker.patch( + mocker.patch( "guardrails_api.api.guards.GuardStruct.from_dict", return_value=mock_guard, ) - mock_create_guard = mocker.patch( + mocker.patch( "guardrails_api.api.guards.guard_client.create_guard", return_value=mock_guard, ) response = client.post("/guards", json=mock_guard.to_dict()) - # mock_from_request.assert_called_once_with(mock_guard.to_dict()) - # mock_create_guard.assert_called_once_with(mock_guard) assert response.status_code == 200 assert response.json() == MOCK_GUARD_STRING del os.environ["PGHOST"] + def test_guards__post_mem(mocker): old = None if "PGHOST" in os.environ: @@ -88,9 +89,10 @@ def test_guards__post_mem(mocker): assert response.status_code == 501 assert "Not Implemented" in response.json()["detail"] - if (old): + if old: os.environ["PGHOST"] = old + def test_guard__get_mem(mocker): mock_guard = MockGuardStruct() timestamp = "2024-03-04T14:11:42-06:00" @@ -105,6 +107,7 @@ def test_guard__get_mem(mocker): assert response.status_code == 200 assert response.json() == MOCK_GUARD_STRING + def test_guard__put_pg(mocker): os.environ["PGHOST"] = "localhost" mock_guard = MockGuardStruct() @@ -129,6 +132,7 @@ def 
test_guard__put_pg(mocker): assert response.json() == MOCK_GUARD_STRING del os.environ["PGHOST"] + def test_guard__delete_pg(mocker): os.environ["PGHOST"] = "localhost" mock_guard = MockGuardStruct() @@ -144,6 +148,7 @@ def test_guard__delete_pg(mocker): assert response.json() == MOCK_GUARD_STRING del os.environ["PGHOST"] + def test_validate__parse(mocker): os.environ["PGHOST"] = "localhost" mock_outcome = ValidationOutcome( @@ -171,11 +176,10 @@ def test_validate__parse(mocker): mock_status.return_value = "pass" mock_guard.history = Stack(Call()) - response = client.post("/guards/My%20Guard's%20Name/validate", json={ - "llmOutput": "Hello world!", - "args": [1, 2, 3], - "some_kwarg": "foo" - }) + response = client.post( + "/guards/My%20Guard's%20Name/validate", + json={"llmOutput": "Hello world!", "args": [1, 2, 3], "some_kwarg": "foo"}, + ) mock_get_guard.assert_called_once_with("My Guard's Name") assert mock_parse.call_count == 1 @@ -198,6 +202,7 @@ def test_validate__parse(mocker): del os.environ["PGHOST"] + def test_validate__call(mocker): os.environ["PGHOST"] = "localhost" mock_outcome = ValidationOutcome( @@ -225,12 +230,16 @@ def test_validate__call(mocker): mock_status.return_value = "fail" mock_guard.history = Stack(Call()) - response = client.post("/guards/My%20Guard's%20Name/validate", json={ - "promptParams": {"p1": "bar"}, - "args": [1, 2, 3], - "some_kwarg": "foo", - "prompt": "Hello world!" 
- }, headers={"x-openai-api-key": "mock-key"}) + response = client.post( + "/guards/My%20Guard's%20Name/validate", + json={ + "promptParams": {"p1": "bar"}, + "args": [1, 2, 3], + "some_kwarg": "foo", + "prompt": "Hello world!", + }, + headers={"x-openai-api-key": "mock-key"}, + ) mock_get_guard.assert_called_once_with("My Guard's Name") assert mock___call__.call_count == 1 @@ -256,6 +265,7 @@ def test_validate__call(mocker): del os.environ["PGHOST"] + def test_validate__call_throws_validation_error(mocker): os.environ["PGHOST"] = "localhost" error = ValidationError("Test guard validation error") @@ -277,12 +287,15 @@ def test_validate__call_throws_validation_error(mocker): mock_status.return_value = "fail" mock_guard.history = Stack(Call()) - response = client.post("/guards/My%20Guard's%20Name/validate", json={ - "promptParams": {"p1": "bar"}, - "args": [1, 2, 3], - "some_kwarg": "foo", - "prompt": "Hello world!" - }) + response = client.post( + "/guards/My%20Guard's%20Name/validate", + json={ + "promptParams": {"p1": "bar"}, + "args": [1, 2, 3], + "some_kwarg": "foo", + "prompt": "Hello world!", + }, + ) mock_get_guard.assert_called_once_with("My Guard's Name") @@ -291,6 +304,7 @@ def test_validate__call_throws_validation_error(mocker): del os.environ["PGHOST"] + def test_openai_v1_chat_completions__raises_404(mocker): os.environ["PGHOST"] = "localhost" mock_guard = None @@ -300,17 +314,25 @@ def test_openai_v1_chat_completions__raises_404(mocker): return_value=mock_guard, ) - response = client.post("/guards/My%20Guard's%20Name/openai/v1/chat/completions", json={ - "messages": [{"role":"user", "content":"Hello world!"}], - }, headers={"x-openai-api-key": "mock-key"}) + response = client.post( + "/guards/My%20Guard's%20Name/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hello world!"}], + }, + headers={"x-openai-api-key": "mock-key"}, + ) assert response.status_code == 404 - assert response.json()["detail"] == "A Guard with the 
name My Guard's Name does not exist!" + assert ( + response.json()["detail"] + == "A Guard with the name My Guard's Name does not exist!" + ) mock_get_guard.assert_called_once_with("My Guard's Name") del os.environ["PGHOST"] + def test_openai_v1_chat_completions__call(mocker): os.environ["PGHOST"] = "localhost" mock_guard = MockGuardStruct() @@ -337,18 +359,22 @@ def test_openai_v1_chat_completions__call(mocker): ) mock_status.return_value = "fail" mock_call = Call() - mock_call.iterations= Stack(Iteration('some-id', 1)) + mock_call.iterations = Stack(Iteration("some-id", 1)) mock_guard.history = Stack(mock_call) - response = client.post("/guards/My%20Guard's%20Name/openai/v1/chat/completions", json={ - "messages": [{"role":"user", "content":"Hello world!"}], - }, headers={"x-openai-api-key": "mock-key"}) + response = client.post( + "/guards/My%20Guard's%20Name/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hello world!"}], + }, + headers={"x-openai-api-key": "mock-key"}, + ) mock_get_guard.assert_called_once_with("My Guard's Name") assert mock___call__.call_count == 1 mock___call__.assert_called_once_with( num_reasks=0, - messages=[{"role":"user", "content":"Hello world!"}], + messages=[{"role": "user", "content": "Hello world!"}], ) assert response.status_code == 200 diff --git a/tests/api/test_root.py b/tests/api/test_root.py index affbdb2..c0b4029 100644 --- a/tests/api/test_root.py +++ b/tests/api/test_root.py @@ -9,16 +9,19 @@ # Assuming you have a similar structure in your FastAPI app from guardrails_api.api import root + @pytest.fixture def app(): app = FastAPI() app.include_router(root.router) return app + @pytest.fixture def client(app): return TestClient(app) + def test_home(client): response = client.get("/") assert response.status_code == 200 @@ -31,6 +34,7 @@ def test_home(client): assert "/openapi.json" in routes # This is FastAPI's equivalent to /api-docs assert "/docs" in routes + def test_health_check(client, 
mocker): os.environ["PGHOST"] = "localhost" @@ -56,4 +60,4 @@ def text_side_effect(query: str): assert response.json() == {"status": 200, "message": "Ok"} - del os.environ["PGHOST"] \ No newline at end of file + del os.environ["PGHOST"] diff --git a/tests/cli/test_start.py b/tests/cli/test_start.py index 2fd10da..e6973d9 100644 --- a/tests/cli/test_start.py +++ b/tests/cli/test_start.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock import os + def test_start(mocker): mocker.patch("guardrails_api.cli.start.cli") @@ -10,6 +11,7 @@ def test_start(mocker): ) from guardrails_api.cli.start import start + # pg enabled os.environ["PGHOST"] = "localhost" start("env", "config", 8000) diff --git a/tests/mocks/mock_guard_client.py b/tests/mocks/mock_guard_client.py index 04bb77f..beca0a7 100644 --- a/tests/mocks/mock_guard_client.py +++ b/tests/mocks/mock_guard_client.py @@ -3,6 +3,7 @@ from pydantic import ConfigDict from guardrails.classes.generic import Stack + class MockGuardStruct(GuardStruct): # Pydantic Config model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/tests/utils/test_configuration.py b/tests/utils/test_configuration.py index 635893a..0ad5098 100644 --- a/tests/utils/test_configuration.py +++ b/tests/utils/test_configuration.py @@ -2,15 +2,16 @@ import pytest from guardrails_api.utils.configuration import valid_configuration, ConfigurationError + def test_valid_configuration(mocker): with pytest.raises(ConfigurationError): valid_configuration() - + # pg enabled os.environ["PGHOST"] = "localhost" valid_configuration("config.py") os.environ.pop("PGHOST") - + # custom config mock_isfile = mocker.patch("os.path.isfile") mock_isfile.side_effect = [False, True] @@ -20,7 +21,7 @@ def test_valid_configuration(mocker): mock_isfile.side_effect = [False, False] with pytest.raises(ConfigurationError): valid_configuration("") - + # default config mock_isfile = mocker.patch("os.path.isfile") mock_isfile.side_effect = [True, False] From 
2e3f050140790cc1d430754a5d008d4ae91ecbb5 Mon Sep 17 00:00:00 2001 From: David Tam Date: Fri, 13 Sep 2024 13:49:11 -0700 Subject: [PATCH 03/13] update version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d7ecddd..d3d3d94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "opentelemetry-instrumentation-fastapi>=0.47b0", "requests>=2.32.3", "aiocache>=0.11.1", - "fastapi", + "fastapi>=0.114.1", ] [tool.setuptools.dynamic] From 1ab9e5219955bc7a4d0a7bfc15eeb0383cc9c6c8 Mon Sep 17 00:00:00 2001 From: David Tam Date: Fri, 13 Sep 2024 22:17:08 -0700 Subject: [PATCH 04/13] fix pg support --- guardrails_api/app.py | 1 - guardrails_api/clients/pg_guard_client.py | 38 +++++--- guardrails_api/clients/postgres_client.py | 101 +++++++++++++++++----- guardrails_api/models/__init__.py | 5 ++ guardrails_api/models/base.py | 8 -- guardrails_api/models/guard_item.py | 4 +- guardrails_api/models/guard_item_audit.py | 30 +------ guardrails_api/start-dev.sh | 2 +- pyproject.toml | 3 +- requirements-lock.txt | 8 +- tests/cli/test_start.py | 4 +- 11 files changed, 119 insertions(+), 85 deletions(-) delete mode 100644 guardrails_api/models/base.py diff --git a/guardrails_api/app.py b/guardrails_api/app.py index 64e5180..27512f7 100644 --- a/guardrails_api/app.py +++ b/guardrails_api/app.py @@ -18,7 +18,6 @@ import json import os - # from pyinstrument import Profiler # from pyinstrument.renderers.html import HTMLRenderer # from pyinstrument.renderers.speedscope import SpeedscopeRenderer diff --git a/guardrails_api/clients/pg_guard_client.py b/guardrails_api/clients/pg_guard_client.py index c7a1f48..226232a 100644 --- a/guardrails_api/clients/pg_guard_client.py +++ b/guardrails_api/clients/pg_guard_client.py @@ -18,14 +18,20 @@ def __init__(self): self.initialized = True self.pgClient = PostgresClient() + def get_db(self): # generator for local sessions + db = 
self.pgClient.SessionLocal() + try: + yield db + finally: + db.close() + def get_guard(self, guard_name: str, as_of_date: str = None) -> GuardStruct: - latest_guard_item = ( - self.pgClient.db.session.query(GuardItem).filter_by(name=guard_name).first() - ) + db = next(self.get_db()) + latest_guard_item = db.query(GuardItem).filter_by(name=guard_name).first() audit_item = None if as_of_date is not None: audit_item = ( - self.pgClient.db.session.query(GuardItemAudit) + db.query(GuardItemAudit) .filter_by(name=guard_name) .filter(GuardItemAudit.replaced_on > as_of_date) .order_by(GuardItemAudit.replaced_on.asc()) @@ -43,27 +49,29 @@ def get_guard(self, guard_name: str, as_of_date: str = None) -> GuardStruct: return from_guard_item(guard_item) def get_guard_item(self, guard_name: str) -> GuardItem: - return ( - self.pgClient.db.session.query(GuardItem).filter_by(name=guard_name).first() - ) + db = next(self.get_db()) + return db.query(GuardItem).filter_by(name=guard_name).first() def get_guards(self) -> List[GuardStruct]: - guard_items = self.pgClient.db.session.query(GuardItem).all() + db = next(self.get_db()) + guard_items = db.query(GuardItem).all() return [from_guard_item(gi) for gi in guard_items] def create_guard(self, guard: GuardStruct) -> GuardStruct: + db = next(self.get_db()) guard_item = GuardItem( name=guard.name, railspec=guard.to_dict(), num_reasks=None, description=guard.description, ) - self.pgClient.db.session.add(guard_item) - self.pgClient.db.session.commit() + db.add(guard_item) + db.commit() return from_guard_item(guard_item) def update_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct: + db = next(self.get_db()) guard_item = self.get_guard_item(guard_name) if guard_item is None: raise HttpError( @@ -76,21 +84,23 @@ def update_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct: # guard_item.num_reasks = guard.num_reasks guard_item.railspec = guard.to_dict() guard_item.description = guard.description - 
self.pgClient.db.session.commit() + db.commit() return from_guard_item(guard_item) def upsert_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct: + db = next(self.get_db()) guard_item = self.get_guard_item(guard_name) if guard_item is not None: guard_item.railspec = guard.to_dict() guard_item.description = guard.description # guard_item.num_reasks = guard.num_reasks - self.pgClient.db.session.commit() + db.commit() return from_guard_item(guard_item) else: return self.create_guard(guard) def delete_guard(self, guard_name: str) -> GuardStruct: + db = next(self.get_db()) guard_item = self.get_guard_item(guard_name) if guard_item is None: raise HttpError( @@ -100,7 +110,7 @@ def delete_guard(self, guard_name: str) -> GuardStruct: guard_name=guard_name ), ) - self.pgClient.db.session.delete(guard_item) - self.pgClient.db.session.commit() + db.delete(guard_item) + db.commit() guard = from_guard_item(guard_item) return guard diff --git a/guardrails_api/clients/postgres_client.py b/guardrails_api/clients/postgres_client.py index 951a4f4..56e1501 100644 --- a/guardrails_api/clients/postgres_client.py +++ b/guardrails_api/clients/postgres_client.py @@ -2,16 +2,24 @@ import json import os import threading -from flask import Flask -from sqlalchemy import text +from fastapi import FastAPI from typing import Tuple -from guardrails_api.models.base import db, INIT_EXTENSIONS +from sqlalchemy import create_engine, text +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +Base = declarative_base() def postgres_is_enabled() -> bool: return os.environ.get("PGHOST", None) is not None +# Global variables for database session +postgres_client = None +SessionLocal = None + + class PostgresClient: _instance = None _lock = threading.Lock() @@ -45,7 +53,17 @@ def get_pg_creds(self) -> Tuple[str, str]: pg_password = pg_password or os.environ.get("PGPASSWORD") return pg_user, pg_password - def initialize(self, app: Flask): + def 
get_db(self): + if postgres_is_enabled(): + db = self.SessionLocal() + try: + yield db + finally: + db.close() + else: + yield None + + def initialize(self, app: FastAPI): pg_user, pg_password = self.get_pg_creds() pg_host = os.environ.get("PGHOST", "localhost") pg_port = os.environ.get("PGPORT", "5432") @@ -64,23 +82,64 @@ def initialize(self, app: Flask): if os.environ.get("NODE_ENV") == "production": conf = f"{conf}?sslmode=verify-ca&sslrootcert=global-bundle.pem" - app.config["SQLALCHEMY_DATABASE_URI"] = conf + engine = create_engine(conf) + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False - app.secret_key = "secret" self.app = app - self.db = db - db.init_app(app) - from guardrails_api.models.guard_item import GuardItem # NOQA - from guardrails_api.models.guard_item_audit import ( # NOQA - GuardItemAudit, - AUDIT_FUNCTION, - AUDIT_TRIGGER, - ) + self.engine = engine + self.SessionLocal = SessionLocal + # Create tables + from guardrails_api.models import GuardItem, GuardItemAudit # noqa + + Base.metadata.create_all(bind=engine) + + # Execute custom SQL + with engine.connect() as connection: + connection.execute(text(INIT_EXTENSIONS)) + connection.execute(text(AUDIT_FUNCTION)) + connection.execute(text(AUDIT_TRIGGER)) + connection.commit() + + +# Define INIT_EXTENSIONS, AUDIT_FUNCTION, and AUDIT_TRIGGER here as they were in your original code +INIT_EXTENSIONS = """ +-- Your SQL for initializing extensions +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'uuid-ossp') THEN + CREATE EXTENSION "uuid-ossp"; + END IF; +END $$; + +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN + CREATE EXTENSION "vector"; + END IF; +END $$; +""" + +AUDIT_FUNCTION = """ +CREATE OR REPLACE FUNCTION guard_audit_function() RETURNS TRIGGER AS $guard_audit$ +BEGIN + IF (TG_OP = 'DELETE') THEN + INSERT INTO guards_audit SELECT uuid_generate_v4(), 
OLD.*, now(), 'D'; + ELSIF (TG_OP = 'UPDATE') THEN + INSERT INTO guards_audit SELECT uuid_generate_v4(), OLD.*, now(), 'U'; + ELSIF (TG_OP = 'INSERT') THEN + INSERT INTO guards_audit SELECT uuid_generate_v4(), NEW.*, now(), 'I'; + END IF; + RETURN null; +END; +$guard_audit$ +LANGUAGE plpgsql; +""" - with self.app.app_context(): - self.db.session.execute(text(INIT_EXTENSIONS)) - self.db.create_all() - self.db.session.execute(text(AUDIT_FUNCTION)) - self.db.session.execute(text(AUDIT_TRIGGER)) - self.db.session.commit() +AUDIT_TRIGGER = """ +DROP TRIGGER IF EXISTS guard_audit_trigger + ON guards; +CREATE TRIGGER guard_audit_trigger + AFTER INSERT OR UPDATE OR DELETE ON guards + FOR EACH ROW + EXECUTE PROCEDURE guard_audit_function(); +""" diff --git a/guardrails_api/models/__init__.py b/guardrails_api/models/__init__.py index e69de29..391a299 100644 --- a/guardrails_api/models/__init__.py +++ b/guardrails_api/models/__init__.py @@ -0,0 +1,5 @@ +# __init__.py +from .guard_item_audit import GuardItemAudit +from .guard_item import GuardItem + +__all__ = ["GuardItemAudit", "GuardItem"] diff --git a/guardrails_api/models/base.py b/guardrails_api/models/base.py deleted file mode 100644 index 29f0169..0000000 --- a/guardrails_api/models/base.py +++ /dev/null @@ -1,8 +0,0 @@ -from flask_sqlalchemy import SQLAlchemy - -db = SQLAlchemy() - -INIT_EXTENSIONS = """ -CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; -CREATE EXTENSION IF NOT EXISTS "vector"; -""" diff --git a/guardrails_api/models/guard_item.py b/guardrails_api/models/guard_item.py index a2dbf32..6bcacfa 100644 --- a/guardrails_api/models/guard_item.py +++ b/guardrails_api/models/guard_item.py @@ -1,9 +1,9 @@ from sqlalchemy import Column, String, Integer from sqlalchemy.dialects.postgresql import JSONB -from guardrails_api.models.base import db +from guardrails_api.clients.postgres_client import Base -class GuardItem(db.Model): +class GuardItem(Base): __tablename__ = "guards" # TODO: Make primary key a composite between 
guard.name and the guard owner's userId name = Column(String, primary_key=True) diff --git a/guardrails_api/models/guard_item_audit.py b/guardrails_api/models/guard_item_audit.py index 183626e..13ee3c2 100644 --- a/guardrails_api/models/guard_item_audit.py +++ b/guardrails_api/models/guard_item_audit.py @@ -1,9 +1,9 @@ from sqlalchemy import Column, String, Integer from sqlalchemy.dialects.postgresql import JSONB, TIMESTAMP, CHAR -from guardrails_api.models.base import db +from guardrails_api.clients.postgres_client import Base -class GuardItemAudit(db.Model): +class GuardItemAudit(Base): __tablename__ = "guards_audit" id = Column(String, primary_key=True) name = Column(String, nullable=False, index=True) @@ -35,29 +35,3 @@ def __init__( self.replaced_on = replaced_on self.operation = operation # self.owner = owner - - -AUDIT_FUNCTION = """ -CREATE OR REPLACE FUNCTION guard_audit_function() RETURNS TRIGGER AS $guard_audit$ -BEGIN - IF (TG_OP = 'DELETE') THEN - INSERT INTO guards_audit SELECT uuid_generate_v4(), OLD.*, now(), 'D'; - ELSIF (TG_OP = 'UPDATE') THEN - INSERT INTO guards_audit SELECT uuid_generate_v4(), OLD.*, now(), 'U'; - ELSIF (TG_OP = 'INSERT') THEN - INSERT INTO guards_audit SELECT uuid_generate_v4(), NEW.*, now(), 'I'; - END IF; - RETURN null; -END; -$guard_audit$ -LANGUAGE plpgsql; -""" - -AUDIT_TRIGGER = """ -DROP TRIGGER IF EXISTS guard_audit_trigger - ON guards; -CREATE TRIGGER guard_audit_trigger - AFTER INSERT OR UPDATE OR DELETE ON guards - FOR EACH ROW - EXECUTE PROCEDURE guard_audit_function(); -""" diff --git a/guardrails_api/start-dev.sh b/guardrails_api/start-dev.sh index 36f33ba..83a2d70 100755 --- a/guardrails_api/start-dev.sh +++ b/guardrails_api/start-dev.sh @@ -1,6 +1,6 @@ gunicorn --bind 0.0.0.0:8000 \ --timeout 120 \ - --workers 3 \ + --workers 2 \ --threads 2 \ --worker-class=uvicorn.workers.UvicornWorker \ "guardrails_api.app:create_app()" \ diff --git a/pyproject.toml b/pyproject.toml index d3d3d94..a2db2fb 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,6 @@ keywords = ["Guardrails", "Guardrails AI", "Guardrails API", "Guardrails API"] requires-python = ">= 3.8.1" dependencies = [ "guardrails-ai>=0.5.6", - "Flask-SQLAlchemy>=3.1.1,<4", "Werkzeug>=3.0.3,<4", "jsonschema>=4.22.0,<5", "referencing>=0.35.1,<1", @@ -27,6 +26,7 @@ dependencies = [ "requests>=2.32.3", "aiocache>=0.11.1", "fastapi>=0.114.1", + "SQLAlchemy>=2.0.34", ] [tool.setuptools.dynamic] @@ -42,6 +42,7 @@ dev = [ "coverage", "pytest-mock", "gunicorn>=22.0.0,<23", + "uvicorn", ] [tool.pytest.ini_options] diff --git a/requirements-lock.txt b/requirements-lock.txt index 950b069..cacc27e 100644 --- a/requirements-lock.txt +++ b/requirements-lock.txt @@ -28,16 +28,10 @@ frozenlist==1.4.1 fsspec==2024.6.1 googleapis-common-protos==1.63.2 griffe==0.36.9 -<<<<<<< Updated upstream grpcio==1.65.1 -guardrails-ai==0.5.7 +guardrails-ai==0.5.9 guardrails-api-client==0.3.12 guardrails_hub_types==0.0.4 -======= -grpcio==1.64.1 -guardrails-ai==0.5.0a2 -guardrails-api-client==0.3.8 ->>>>>>> Stashed changes gunicorn==22.0.0 h11==0.14.0 httpcore==1.0.5 diff --git a/tests/cli/test_start.py b/tests/cli/test_start.py index e6973d9..befe21a 100644 --- a/tests/cli/test_start.py +++ b/tests/cli/test_start.py @@ -5,9 +5,9 @@ def test_start(mocker): mocker.patch("guardrails_api.cli.start.cli") - mock_flask_app = MagicMock() + mock_app = MagicMock() mock_create_app = mocker.patch( - "guardrails_api.cli.start.create_app", return_value=mock_flask_app + "guardrails_api.cli.start.create_app", return_value=mock_app ) from guardrails_api.cli.start import start From 065a6bc6659a473e2642d5e6c361cef07915a24a Mon Sep 17 00:00:00 2001 From: David Tam Date: Fri, 13 Sep 2024 22:22:51 -0700 Subject: [PATCH 05/13] remove last refs to flask --- requirements-lock.txt | 4 ---- setup.py | 18 ++++++++++-------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/requirements-lock.txt b/requirements-lock.txt index cacc27e..1462843 100644 
--- a/requirements-lock.txt +++ b/requirements-lock.txt @@ -19,10 +19,6 @@ diff-match-patch==20230430 distro==1.9.0 Faker==25.9.2 filelock==3.15.4 -Flask==3.0.3 -Flask-Caching==2.3.0 -Flask-Cors==5.0.0 -Flask-SQLAlchemy==3.1.1 fqdn==1.5.1 frozenlist==1.4.1 fsspec==2024.6.1 diff --git a/setup.py b/setup.py index 78528e2..ca6d9f1 100644 --- a/setup.py +++ b/setup.py @@ -16,21 +16,23 @@ packages=find_packages(), python_requires=">=3.8, <4", install_requires=[ - "guardrails-ai>=0.4.5", - "flask>=3.0.3,<4", - "Flask-SQLAlchemy>=3.1.1,<4", + "guardrails-ai>=0.5.6", "Werkzeug>=3.0.3,<4", "jsonschema>=4.22.0,<5", "referencing>=0.35.1,<1", - "Flask-Cors>=4.0.1,<6", "boto3>=1.34.115,<2", "psycopg2-binary>=2.9.9,<3", "litellm>=1.39.3,<2", "typer>=0.9.4,<1", - "opentelemetry-api>1,<2", - "opentelemetry-exporter-otlp-proto-grpc>1,<2", - "opentelemetry-exporter-otlp-proto-http>1,<2", - "opentelemetry-instrumentation-flask>=0.12b0,<1" + "opentelemetry-api>=1.0.0,<2", + "opentelemetry-sdk>=1.0.0,<2", + "opentelemetry-exporter-otlp-proto-grpc>=1.0.0,<2", + "opentelemetry-exporter-otlp-proto-http>=1.0.0,<2", + "opentelemetry-instrumentation-fastapi>=0.47b0", + "requests>=2.32.3", + "aiocache>=0.11.1", + "fastapi>=0.114.1", + "SQLAlchemy>=2.0.34", ], package_data={"guardrails_api": ["py.typed", "open-api-spec.json"]}, ) From 2cfa5f5d9745509e50c506726464ec122a1c9e10 Mon Sep 17 00:00:00 2001 From: David Tam Date: Mon, 16 Sep 2024 10:38:16 -0700 Subject: [PATCH 06/13] fix tests --- tests/clients/test_pg_guard_client.py | 111 +++++++++++++++++++------- tests/mocks/mock_postgres_client.py | 7 ++ 2 files changed, 90 insertions(+), 28 deletions(-) diff --git a/tests/clients/test_pg_guard_client.py b/tests/clients/test_pg_guard_client.py index 0b94224..add0048 100644 --- a/tests/clients/test_pg_guard_client.py +++ b/tests/clients/test_pg_guard_client.py @@ -1,5 +1,5 @@ import pytest -from unittest.mock import ANY as AnyMatcher +from unittest.mock import ANY as AnyMatcher, MagicMock from 
guardrails_api.classes.http_error import HttpError from guardrails_api.models.guard_item import GuardItem @@ -28,14 +28,19 @@ def test_init(mocker): class TestGetGuard: def test_get_latest(self, mocker): mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) + mock_session = mock_pg_client.SessionLocal() mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, ) + query_spy = mock_session.query + query_spy.return_value = mock_session - query_spy = mocker.spy(mock_pg_client.db.session, "query") - filter_by_spy = mocker.spy(mock_pg_client.db.session, "filter_by") - mock_first = mocker.patch.object(mock_pg_client.db.session, "first") + filter_by_spy = mock_session.filter_by + filter_by_spy.return_value = mock_session + + mock_first = mock_session.first latest_guard = MockGuardStruct() mock_first.return_value = latest_guard @@ -59,16 +64,25 @@ def test_get_latest(self, mocker): def test_with_as_of_date(self, mocker): mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) + mock_session = mock_pg_client.SessionLocal() mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, ) + query_spy = mock_session.query + query_spy.return_value = mock_session + + filter_by_spy = mock_session.filter_by + filter_by_spy.return_value = mock_session + + filter_spy = mock_session.filter + filter_spy.return_value = mock_session - query_spy = mocker.spy(mock_pg_client.db.session, "query") - filter_by_spy = mocker.spy(mock_pg_client.db.session, "filter_by") - filter_spy = mocker.spy(mock_pg_client.db.session, "filter") - order_by_spy = mocker.spy(mock_pg_client.db.session, "order_by") - mock_first = mocker.patch.object(mock_pg_client.db.session, "first") + order_by_spy = mock_session.order_by + order_by_spy.return_value = mock_session + + mock_first = mock_session.first latest_guard = 
MockGuardStruct(name="latest") previous_guard = MockGuardStruct(name="previous") mock_first.side_effect = [latest_guard, previous_guard] @@ -107,13 +121,20 @@ def test_with_as_of_date(self, mocker): def test_raises_not_found(self, mocker): mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) + mock_session = mock_pg_client.SessionLocal() mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, ) - - mock_first = mocker.patch.object(mock_pg_client.db.session, "first") - mock_first.return_value = None + # Mock the query method to return a mock query object + mock_query = mock_session.query.return_value + + # Mock the filter_by method to return a mock filter object + mock_filter_by = mock_query.filter_by.return_value + mock_first = mock_filter_by.first + # Mock the first method on the mock filter object to return None + mock_filter_by.first.return_value = None mock_from_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.from_guard_item" ) @@ -136,14 +157,20 @@ def test_raises_not_found(self, mocker): def test_get_guard_item(mocker): mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) + mock_session = mock_pg_client.SessionLocal() mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, ) - query_spy = mocker.spy(mock_pg_client.db.session, "query") - filter_by_spy = mocker.spy(mock_pg_client.db.session, "filter_by") - mock_first = mocker.patch.object(mock_pg_client.db.session, "first") + query_spy = mock_session.query + query_spy.return_value = mock_session + + filter_by_spy = mock_session.filter_by + filter_by_spy.return_value = mock_session + + mock_first = mock_session.first latest_guard = MockGuardStruct(name="latest") mock_first.return_value = latest_guard @@ -162,17 +189,23 @@ def test_get_guard_item(mocker): def test_get_guards(mocker): mock_pg_client = 
MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) + mock_session = mock_pg_client.SessionLocal() mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, ) - query_spy = mocker.spy(mock_pg_client.db.session, "query") - mock_all = mocker.patch.object(mock_pg_client.db.session, "all") + # Ensure that query returns the mock session itself + mock_session.query.return_value = mock_session + query_spy = mock_session.query + guard_one = MockGuardStruct(name="guard one") guard_two = MockGuardStruct(name="guard two") guards = [guard_one, guard_two] - mock_all.return_value = guards + # Mock the all method on the mock session + mock_session.all.return_value = guards + mock_all = mock_session.all mock_from_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.from_guard_item" @@ -198,7 +231,9 @@ def test_get_guards(mocker): def test_create_guard(mocker): mock_guard = MockGuardStruct() mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) mock_guard_struct_init_spy = mocker.spy(MockGuardStruct, "__init__") + mock_session = mock_pg_client.SessionLocal() mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, @@ -207,8 +242,8 @@ def test_create_guard(mocker): "guardrails_api.clients.pg_guard_client.GuardItem", new=MockGuardStruct ) - add_spy = mocker.spy(mock_pg_client.db.session, "add") - commit_spy = mocker.spy(mock_pg_client.db.session, "commit") + add_spy = mocker.spy(mock_session, "add") + commit_spy = mocker.spy(mock_session, "commit") mock_from_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.from_guard_item" @@ -244,6 +279,8 @@ class TestUpdateGuard: def test_raises_not_found(self, mocker): mock_guard = MockGuardStruct() mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) + mock_session = 
mock_pg_client.SessionLocal() mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, @@ -253,7 +290,7 @@ def test_raises_not_found(self, mocker): ) mock_get_guard_item.return_value = None - commit_spy = mocker.spy(mock_pg_client.db.session, "commit") + commit_spy = mocker.spy(mock_session, "commit") mock_from_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.from_guard_item" ) @@ -287,6 +324,8 @@ def test_updates_guard_item(self, mocker): updated_guard = MockGuardStruct(description="updated description") mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) + mock_session = mock_pg_client.SessionLocal() mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, @@ -296,7 +335,7 @@ def test_updates_guard_item(self, mocker): ) mock_get_guard_item.return_value = old_guard_item - commit_spy = mocker.spy(mock_pg_client.db.session, "commit") + commit_spy = mocker.spy(mock_session, "commit") mock_from_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.from_guard_item" ) @@ -324,6 +363,7 @@ def test_guard_doesnt_exist_yet(self, mocker): input_guard = MockGuardStruct() new_guard = MockGuardStruct() mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, @@ -367,16 +407,20 @@ def test_guard_already_exists(self, mocker): updated_guard = MockGuardStruct(description="updated description") mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, ) + + mock_session = mock_pg_client.SessionLocal() + mock_get_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.PGGuardClient.get_guard_item" ) 
mock_get_guard_item.return_value = old_guard_item - commit_spy = mocker.spy(mock_pg_client.db.session, "commit") + commit_spy = mocker.spy(mock_session, "commit") mock_from_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.from_guard_item" ) @@ -404,16 +448,21 @@ def test_guard_already_exists(self, mocker): class TestDeleteGuard: def test_raises_not_found(self, mocker): mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, ) + + mock_session = mock_pg_client.SessionLocal() + mock_get_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.PGGuardClient.get_guard_item" ) mock_get_guard_item.return_value = None - commit_spy = mocker.spy(mock_pg_client.db.session, "commit") + commit_spy = mocker.spy(mock_session, "commit") + mock_from_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.from_guard_item" ) @@ -438,17 +487,23 @@ def test_raises_not_found(self, mocker): def test_deletes_guard_item(self, mocker): old_guard = MockGuardStruct() mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, ) + + mock_session = mock_pg_client.SessionLocal() + mock_get_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.PGGuardClient.get_guard_item" ) mock_get_guard_item.return_value = old_guard - delete_spy = mocker.spy(mock_pg_client.db.session, "delete") - commit_spy = mocker.spy(mock_pg_client.db.session, "commit") + # Mock the query and delete operations + mock_query = mock_session.query.return_value + mock_filter = mock_query.filter_by.return_value + mock_filter.first.return_value = old_guard mock_from_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.from_guard_item" ) @@ -461,8 +516,8 @@ def 
test_deletes_guard_item(self, mocker): result = guard_client.delete_guard("mock-guard") mock_get_guard_item.assert_called_once_with("mock-guard") - assert delete_spy.call_count == 1 - assert commit_spy.call_count == 1 + assert mock_session.delete.call_count == 1 + assert mock_session.commit.call_count == 1 mock_from_guard_item.assert_called_once_with(old_guard) assert result == old_guard diff --git a/tests/mocks/mock_postgres_client.py b/tests/mocks/mock_postgres_client.py index 4197882..7ee9b2b 100644 --- a/tests/mocks/mock_postgres_client.py +++ b/tests/mocks/mock_postgres_client.py @@ -49,7 +49,14 @@ class MockDb: def __init__(self) -> None: self.session = MockSession() + def SessionLocal(self): + return self.session + class MockPostgresClient: def __init__(self): self.db = MockDb() + self.pgClient = self.db + + def get_db(self): + return MockSession() From 758cd8fcea744b48ffd9c1ce71723131592f7114 Mon Sep 17 00:00:00 2001 From: David Tam Date: Tue, 17 Sep 2024 14:46:42 -0700 Subject: [PATCH 07/13] code review comments --- guardrails_api/api/guards.py | 7 ++-- guardrails_api/api/root.py | 2 +- guardrails_api/app.py | 67 +----------------------------------- 3 files changed, 5 insertions(+), 71 deletions(-) diff --git a/guardrails_api/api/guards.py b/guardrails_api/api/guards.py index 8621e57..f5939ec 100644 --- a/guardrails_api/api/guards.py +++ b/guardrails_api/api/guards.py @@ -37,7 +37,6 @@ guard_client.create_guard(export) cache_client = CacheClient() - router = APIRouter() @@ -259,9 +258,9 @@ async def validate_streamer(guard_iter): **payload, ) - # serialized_history = [call.to_dict() for call in guard.history] - # cache_key = f"{guard.name}-{result.call_id}" - # await cache_client.set(cache_key, serialized_history, 300) + serialized_history = [call.to_dict() for call in guard.history] + cache_key = f"{guard.name}-{result.call_id}" + await cache_client.set(cache_key, serialized_history, 300) return result.to_dict() diff --git a/guardrails_api/api/root.py 
b/guardrails_api/api/root.py index d818923..6a13a9b 100644 --- a/guardrails_api/api/root.py +++ b/guardrails_api/api/root.py @@ -22,7 +22,7 @@ class HealthCheckResponse(BaseModel): @router.get("/") async def home(): - return "Hello, FastAPI!" + return "Hello, world!" @router.get("/health-check", response_model=HealthCheckResponse) diff --git a/guardrails_api/app.py b/guardrails_api/app.py index 27512f7..add23e4 100644 --- a/guardrails_api/app.py +++ b/guardrails_api/app.py @@ -11,63 +11,11 @@ from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from rich.console import Console from rich.rule import Rule -from starlette.middleware.base import BaseHTTPMiddleware from typing import Optional -from urllib.parse import urlparse import importlib.util import json import os -# from pyinstrument import Profiler -# from pyinstrument.renderers.html import HTMLRenderer -# from pyinstrument.renderers.speedscope import SpeedscopeRenderer -# from starlette.middleware.base import RequestResponseEndpoint -# class ProfilingMiddleware(BaseHTTPMiddleware): -# async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: -# """Profile the current request - -# Taken from https://pyinstrument.readthedocs.io/en/latest/guide.html#profile-a-web-request-in-fastapi -# with small improvements. 
- -# """ -# # we map a profile type to a file extension, as well as a pyinstrument profile renderer -# profile_type_to_ext = {"html": "html", "speedscope": "speedscope.json"} -# profile_type_to_renderer = { -# "html": HTMLRenderer, -# "speedscope": SpeedscopeRenderer, -# } - -# if request.headers.get("X-Profile-Request"): -# # The default profile format is speedscope -# profile_type = request.query_params.get("profile_format", "speedscope") - -# # we profile the request along with all additional middlewares, by interrupting -# # the program every 1ms1 and records the entire stack at that point -# with Profiler(interval=0.001, async_mode="enabled") as profiler: -# response = await call_next(request) - -# # we dump the profiling into a file -# # Generate a unique filename based on timestamp and request properties -# timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") -# method = request.method -# path = request.url.path.replace("/", "_").strip("_") -# extension = profile_type_to_ext[profile_type] -# filename = f"profile_{timestamp}_{method}_{path}.{extension}" - -# # Ensure the profiling directory exists -# profiling_dir = "profiling" -# os.makedirs(profiling_dir, exist_ok=True) - -# # Dump the profiling into a file -# renderer = profile_type_to_renderer[profile_type]() -# filepath = os.path.join(profiling_dir, filename) -# with open(filepath, "w") as out: -# out.write(profiler.output(renderer=renderer)) - -# return response -# else: -# return await call_next(request) - # Custom JSON encoder class CustomJSONEncoder(json.JSONEncoder): @@ -79,16 +27,6 @@ def default(self, o): return super().default(o) -# Custom middleware for reverse proxy -class ReverseProxyMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): - self_endpoint = os.environ.get("SELF_ENDPOINT", "http://localhost:8000") - url = urlparse(self_endpoint) - request.scope["scheme"] = url.scheme - response = await call_next(request) - return response - - def 
register_config(config: Optional[str] = None): default_config_file = os.path.join(os.getcwd(), "./config.py") config_file = config or default_config_file @@ -124,7 +62,7 @@ register_config(config) - app = FastAPI() + app = FastAPI(openapi_url="") # Initialize FastAPIInstrumentor FastAPIInstrumentor.instrument_app(app) @@ -140,9 +78,6 @@ create_app( allow_headers=["*"], ) - # Add reverse proxy middleware - app.add_middleware(ReverseProxyMiddleware) - guardrails_log_level = os.environ.get("GUARDRAILS_LOG_LEVEL", "INFO") configure_logging(log_level=guardrails_log_level) From 4fa1b4e9fc79fd14601f0cd5b1b6e9a233f06581 Mon Sep 17 00:00:00 2001 From: David Tam Date: Wed, 18 Sep 2024 16:40:21 -0700 Subject: [PATCH 08/13] updates for async all the way down and add extra user info to share with oss --- Dockerfile | 4 ++-- compose.yml | 6 +++--- guardrails_api/api/guards.py | 23 +++++++++++++++------- guardrails_api/app.py | 37 +++++++++++++++++++++++++++++++++++- pyproject.toml | 4 ++-- setup.py | 2 +- tests/api/test_root.py | 2 +- 7 files changed, 61 insertions(+), 17 deletions(-) diff --git a/Dockerfile b/Dockerfile index ca847f0..d01a5ee 100644 --- a/Dockerfile +++ b/Dockerfile @@ -45,7 +45,7 @@ COPY . . EXPOSE 8000 # This is our start command; yours might be different. -# The guardrails-api is a standard Flask application. -# You can use whatever production server you want that support Flask. +# The guardrails-api is a standard FastAPI application. +# You can use whatever production server you want that supports FastAPI. 
# Here we use gunicorn CMD gunicorn --bind 0.0.0.0:8000 --timeout=90 --workers=2 'guardrails_api.app:create_app(".env", "sample-config.py")' \ No newline at end of file diff --git a/compose.yml b/compose.yml index b445876..5033bcf 100644 --- a/compose.yml +++ b/compose.yml @@ -3,8 +3,8 @@ services: profiles: ["all", "db", "infra"] image: ankane/pgvector environment: - POSTGRES_USER: ${PGUSER:-postgres} - POSTGRES_PASSWORD: ${PGPASSWORD:-changeme} + POSTGRES_USER: admin + POSTGRES_PASSWORD: admin POSTGRES_DATA: /data/postgres volumes: - ./postgres:/data/postgres @@ -21,7 +21,7 @@ services: - "8088:80" environment: PGADMIN_DEFAULT_EMAIL: "${PGUSER:-postgres}@guardrails.com" - PGADMIN_DEFAULT_PASSWORD: ${PGPASSWORD:-changeme} + PGADMIN_DEFAULT_PASSWORD: admin PGADMIN_SERVER_JSON_FILE: /var/lib/pgadmin/servers.json # FIXME: Copy over server.json file and create passfile volumes: diff --git a/guardrails_api/api/guards.py b/guardrails_api/api/guards.py index f5939ec..26f3269 100644 --- a/guardrails_api/api/guards.py +++ b/guardrails_api/api/guards.py @@ -250,13 +250,22 @@ async def validate_streamer(guard_iter): validate_streamer(guard_streamer()), media_type="application/json" ) else: - result: ValidationOutcome = guard( - llm_api=llm_api, - prompt_params=prompt_params, - num_reasks=num_reasks, - *args, - **payload, - ) + if inspect.iscoroutinefunction(guard): + result: ValidationOutcome = await guard( + llm_api=llm_api, + prompt_params=prompt_params, + num_reasks=num_reasks, + *args, + **payload, + ) + else: + result: ValidationOutcome = guard( + llm_api=llm_api, + prompt_params=prompt_params, + num_reasks=num_reasks, + *args, + **payload, + ) serialized_history = [call.to_dict() for call in guard.history] cache_key = f"{guard.name}-{result.call_id}" diff --git a/guardrails_api/app.py b/guardrails_api/app.py index add23e4..75602a6 100644 --- a/guardrails_api/app.py +++ b/guardrails_api/app.py @@ -9,6 +9,8 @@ trace_server_start_if_enabled, ) from 
opentelemetry.instrumentation.fastapi import FastAPIInstrumentor +from opentelemetry import trace, context, baggage + from rich.console import Console from rich.rule import Rule from typing import Optional @@ -16,6 +18,37 @@ import json import os +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request + +class RequestInfoMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + tracer = trace.get_tracer(__name__) + # Get the current context and attach it to this task + with tracer.start_as_current_span("request_info") as span: + client_ip = request.client.host + user_agent = request.headers.get("user-agent", "unknown") + referrer = request.headers.get("referrer", "unknown") + user_id = request.headers.get("x-user-id", "unknown") + organization = request.headers.get("x-organization", "unknown") + app = request.headers.get("x-app", "unknown") + + context.attach(baggage.set_baggage("client.ip", client_ip)) + context.attach(baggage.set_baggage("http.user_agent", user_agent)) + context.attach(baggage.set_baggage("http.referrer", referrer)) + context.attach(baggage.set_baggage("user.id", user_id)) + context.attach(baggage.set_baggage("organization", organization)) + context.attach(baggage.set_baggage("app", app)) + + span.set_attribute("client.ip", client_ip) + span.set_attribute("http.user_agent", user_agent) + span.set_attribute("http.referrer", referrer) + span.set_attribute("user.id", user_id) + span.set_attribute("organization", organization) + span.set_attribute("app", app) + + response = await call_next(request) + return response # Custom JSON encoder class CustomJSONEncoder(json.JSONEncoder): @@ -64,10 +97,12 @@ def create_app( app = FastAPI(openapi_url="") + # Add the custom middleware + app.add_middleware(RequestInfoMiddleware) + # Initialize FastAPIInstrumentor FastAPIInstrumentor.instrument_app(app) - # app.add_middleware(ProfilingMiddleware) # Add CORS middleware app.add_middleware( 
diff --git a/pyproject.toml b/pyproject.toml index a2db2fb..a6d9e52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ readme = "README.md" keywords = ["Guardrails", "Guardrails AI", "Guardrails API", "Guardrails API"] requires-python = ">= 3.8.1" dependencies = [ - "guardrails-ai>=0.5.6", + "guardrails-ai>=0.5.10", "Werkzeug>=3.0.3,<4", "jsonschema>=4.22.0,<5", "referencing>=0.35.1,<1", @@ -22,7 +22,7 @@ dependencies = [ "opentelemetry-sdk>=1.0.0,<2", "opentelemetry-exporter-otlp-proto-grpc>=1.0.0,<2", "opentelemetry-exporter-otlp-proto-http>=1.0.0,<2", - "opentelemetry-instrumentation-fastapi>=0.47b0", + "opentelemetry-instrumentation-fastapi>=0.48b0", "requests>=2.32.3", "aiocache>=0.11.1", "fastapi>=0.114.1", diff --git a/setup.py b/setup.py index ca6d9f1..e8f135c 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ "opentelemetry-sdk>=1.0.0,<2", "opentelemetry-exporter-otlp-proto-grpc>=1.0.0,<2", "opentelemetry-exporter-otlp-proto-http>=1.0.0,<2", - "opentelemetry-instrumentation-fastapi>=0.47b0", + "opentelemetry-instrumentation-fastapi>=0.48b0", "requests>=2.32.3", "aiocache>=0.11.1", "fastapi>=0.114.1", diff --git a/tests/api/test_root.py b/tests/api/test_root.py index c0b4029..753dee0 100644 --- a/tests/api/test_root.py +++ b/tests/api/test_root.py @@ -25,7 +25,7 @@ def client(app): def test_home(client): response = client.get("/") assert response.status_code == 200 - assert response.json() == "Hello, FastAPI!" + assert response.json() == "Hello, world!" 
# Check if all expected routes are registered routes = [route.path for route in client.app.routes] From 04884ae9c57cac577a88ec2689f03fa2f3c6fba7 Mon Sep 17 00:00:00 2001 From: David Tam Date: Wed, 18 Sep 2024 16:41:09 -0700 Subject: [PATCH 09/13] backout compose updates --- compose.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/compose.yml b/compose.yml index 5033bcf..b445876 100644 --- a/compose.yml +++ b/compose.yml @@ -3,8 +3,8 @@ services: profiles: ["all", "db", "infra"] image: ankane/pgvector environment: - POSTGRES_USER: admin - POSTGRES_PASSWORD: admin + POSTGRES_USER: ${PGUSER:-postgres} + POSTGRES_PASSWORD: ${PGPASSWORD:-changeme} POSTGRES_DATA: /data/postgres volumes: - ./postgres:/data/postgres @@ -21,7 +21,7 @@ services: - "8088:80" environment: PGADMIN_DEFAULT_EMAIL: "${PGUSER:-postgres}@guardrails.com" - PGADMIN_DEFAULT_PASSWORD: admin + PGADMIN_DEFAULT_PASSWORD: ${PGPASSWORD:-changeme} PGADMIN_SERVER_JSON_FILE: /var/lib/pgadmin/servers.json # FIXME: Copy over server.json file and create passfile volumes: From b073f6be17ad0bc92f91e3a7ebad0bd138a507a3 Mon Sep 17 00:00:00 2001 From: David Tam Date: Thu, 19 Sep 2024 11:15:07 -0700 Subject: [PATCH 10/13] fix tests --- guardrails_api/api/guards.py | 3 +++ guardrails_api/clients/cache_client.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/guardrails_api/api/guards.py b/guardrails_api/api/guards.py index 26f3269..afdd206 100644 --- a/guardrails_api/api/guards.py +++ b/guardrails_api/api/guards.py @@ -37,6 +37,9 @@ guard_client.create_guard(export) cache_client = CacheClient() + +cache_client.initialize() + router = APIRouter() diff --git a/guardrails_api/clients/cache_client.py b/guardrails_api/clients/cache_client.py index 4d80a63..869162a 100644 --- a/guardrails_api/clients/cache_client.py +++ b/guardrails_api/clients/cache_client.py @@ -14,7 +14,7 @@ def __new__(cls): cls._instance = super().__new__(cls) return cls._instance - def initialize(self, app: 
FastAPI): + def initialize(self): caches.set_config( { "default": { From 0fc56addfe697b04057ea7390baa70508de2ab95 Mon Sep 17 00:00:00 2001 From: David Tam Date: Thu, 19 Sep 2024 11:55:00 -0700 Subject: [PATCH 11/13] update our webserver --- guardrails_api/app.py | 2 +- guardrails_api/cli/start.py | 5 +++-- pyproject.toml | 3 +-- setup.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/guardrails_api/app.py b/guardrails_api/app.py index 75602a6..e591c4f 100644 --- a/guardrails_api/app.py +++ b/guardrails_api/app.py @@ -127,7 +127,7 @@ def create_app( pg_client.initialize(app) cache_client = CacheClient() - cache_client.initialize(app) + cache_client.initialize() from guardrails_api.api.root import router as root_router from guardrails_api.api.guards import router as guards_router, guard_client diff --git a/guardrails_api/cli/start.py b/guardrails_api/cli/start.py index 6ac66fa..d1844eb 100644 --- a/guardrails_api/cli/start.py +++ b/guardrails_api/cli/start.py @@ -3,7 +3,7 @@ from guardrails_api.cli.cli import cli from guardrails_api.app import create_app from guardrails_api.utils.configuration import valid_configuration - +import uvicorn @cli.command("start") def start( @@ -25,4 +25,5 @@ def start( env = env or None config = config or None valid_configuration(config) - create_app(env, config, port).run(port=port) + app = create_app(env, config, port) + uvicorn.run(app, port=port) diff --git a/pyproject.toml b/pyproject.toml index a6d9e52..2520656 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,6 @@ keywords = ["Guardrails", "Guardrails AI", "Guardrails API", "Guardrails API"] requires-python = ">= 3.8.1" dependencies = [ "guardrails-ai>=0.5.10", - "Werkzeug>=3.0.3,<4", "jsonschema>=4.22.0,<5", "referencing>=0.35.1,<1", "boto3>=1.34.115,<2", @@ -27,6 +26,7 @@ dependencies = [ "aiocache>=0.11.1", "fastapi>=0.114.1", "SQLAlchemy>=2.0.34", + "uvicorn>=0.30.6", ] [tool.setuptools.dynamic] @@ -42,7 +42,6 @@ dev = [ "coverage", 
"pytest-mock", "gunicorn>=22.0.0,<23", - "uvicorn", ] [tool.pytest.ini_options] diff --git a/setup.py b/setup.py index e8f135c..806bde1 100644 --- a/setup.py +++ b/setup.py @@ -16,8 +16,7 @@ packages=find_packages(), python_requires=">=3.8, <4", install_requires=[ - "guardrails-ai>=0.5.6", - "Werkzeug>=3.0.3,<4", + "guardrails-ai>=0.5.10", "jsonschema>=4.22.0,<5", "referencing>=0.35.1,<1", "boto3>=1.34.115,<2", @@ -33,6 +32,7 @@ "aiocache>=0.11.1", "fastapi>=0.114.1", "SQLAlchemy>=2.0.34", + "uvicorn>=0.30.6", ], package_data={"guardrails_api": ["py.typed", "open-api-spec.json"]}, ) From 1442c54080560123d013d803231e3e4dbca70cad Mon Sep 17 00:00:00 2001 From: David Tam Date: Thu, 19 Sep 2024 12:02:29 -0700 Subject: [PATCH 12/13] lint --- guardrails_api/app.py | 3 +-- guardrails_api/cli/start.py | 1 + guardrails_api/clients/cache_client.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/guardrails_api/app.py b/guardrails_api/app.py index e591c4f..eb1dd8a 100644 --- a/guardrails_api/app.py +++ b/guardrails_api/app.py @@ -19,7 +19,6 @@ import os from starlette.middleware.base import BaseHTTPMiddleware -from starlette.requests import Request class RequestInfoMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): @@ -50,6 +49,7 @@ async def dispatch(self, request: Request, call_next): response = await call_next(request) return response + # Custom JSON encoder class CustomJSONEncoder(json.JSONEncoder): def default(self, o): @@ -103,7 +103,6 @@ def create_app( # Initialize FastAPIInstrumentor FastAPIInstrumentor.instrument_app(app) - # Add CORS middleware app.add_middleware( CORSMiddleware, diff --git a/guardrails_api/cli/start.py b/guardrails_api/cli/start.py index d1844eb..95d93d2 100644 --- a/guardrails_api/cli/start.py +++ b/guardrails_api/cli/start.py @@ -5,6 +5,7 @@ from guardrails_api.utils.configuration import valid_configuration import uvicorn + @cli.command("start") def start( env: Optional[str] = typer.Option( 
diff --git a/guardrails_api/clients/cache_client.py b/guardrails_api/clients/cache_client.py index 869162a..4ddace1 100644 --- a/guardrails_api/clients/cache_client.py +++ b/guardrails_api/clients/cache_client.py @@ -1,5 +1,4 @@ import threading -from fastapi import FastAPI from aiocache import caches From a2d99213d45ad69c40f7856f33d1f731c4ddf4e4 Mon Sep 17 00:00:00 2001 From: David Tam Date: Thu, 19 Sep 2024 13:46:23 -0700 Subject: [PATCH 13/13] mock server --- guardrails_api/app.py | 1 + tests/cli/test_start.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/guardrails_api/app.py b/guardrails_api/app.py index eb1dd8a..29e2d44 100644 --- a/guardrails_api/app.py +++ b/guardrails_api/app.py @@ -20,6 +20,7 @@ from starlette.middleware.base import BaseHTTPMiddleware + class RequestInfoMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): tracer = trace.get_tracer(__name__) diff --git a/tests/cli/test_start.py b/tests/cli/test_start.py index befe21a..3138a66 100644 --- a/tests/cli/test_start.py +++ b/tests/cli/test_start.py @@ -10,6 +10,8 @@ def test_start(mocker): "guardrails_api.cli.start.create_app", return_value=mock_app ) + mocker.patch("uvicorn.run") + from guardrails_api.cli.start import start # pg enabled