Use new core schemas in guardrails API #46

Merged: 25 commits, Jun 14, 2024
3 changes: 1 addition & 2 deletions .github/workflows/pr_qa.yml
@@ -22,8 +22,7 @@ jobs:
make install-lock;
make install-dev;

curl https://raw.githubusercontent.com/guardrails-ai/guardrails-api-client/main/service-specs/guardrails-service-spec.yml -o ./open-api-spec.yml
npx @redocly/cli bundle --dereferenced --output ./open-api-spec.json --ext json ./open-api-spec.yml
cp ./.venv/lib/python3.12/site-packages/guardrails_api_client/openapi-spec.json ./open-api-spec.json

- name: Run Quality Checks
run: |
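Note on the workflow step above: instead of downloading and bundling the service spec, the job now copies the OpenAPI spec that ships inside the installed guardrails_api_client package. A minimal sketch of a version-agnostic way to locate that file, assuming it sits next to the package as the hard-coded site-packages path in the diff suggests:

# Sketch only: resolve the spec from the installed package rather than
# hard-coding .venv/lib/python3.12/site-packages in the workflow.
import pathlib
import shutil

import guardrails_api_client  # assumed importable inside the workflow's virtualenv

spec_path = pathlib.Path(guardrails_api_client.__file__).parent / "openapi-spec.json"
shutil.copy(spec_path, "./open-api-spec.json")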
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,7 +1,7 @@
flask
sqlalchemy
lxml
guardrails-ai
guardrails-ai @ git+https://github.com/guardrails-ai/guardrails.git@core-schema-impl
# Let this come from guardrails-ai as a transient dependency.
# Pip confuses tag versions with commit ids,
# and claims a conflict even though it's the same thing.
83 changes: 38 additions & 45 deletions src/blueprints/guards.py
@@ -5,18 +5,18 @@
from typing import Any, Dict, cast
from flask import Blueprint, Response, request, stream_with_context
from urllib.parse import unquote_plus
from guardrails import Guard
from guardrails.api_client import GuardrailsApiClient
from guardrails.classes import ValidationOutcome
from opentelemetry.trace import Span
from src.classes.guard_struct import GuardStruct
from src.classes.http_error import HttpError
from src.classes.validation_output import ValidationOutput
from src.clients.memory_guard_client import MemoryGuardClient
from src.clients.pg_guard_client import PGGuardClient
from src.clients.postgres_client import postgres_is_enabled
from src.utils.handle_error import handle_error
from src.utils.get_llm_callable import get_llm_callable
from src.utils.prep_environment import cleanup_environment, prep_environment
from guardrails_api_client import Guard as GuardStruct


guards_bp = Blueprint("guards", __name__, url_prefix="/guards")
@@ -43,9 +43,7 @@
def guards():
if request.method == "GET":
guards = guard_client.get_guards()
if len(guards) > 0 and (isinstance(guards[0], Guard)):
return [g._to_request() for g in guards]
return [g.to_response() for g in guards]
return [g.to_json() for g in guards]
elif request.method == "POST":
if not postgres_is_enabled():
raise HttpError(
@@ -54,11 +52,9 @@ def guards():
"POST /guards is not implemented for in-memory guards.",
)
payload = request.json
guard = GuardStruct.from_request(payload)
guard = GuardStruct.from_json(payload)
new_guard = guard_client.create_guard(guard)
if isinstance(new_guard, Guard):
return new_guard._to_request()
return new_guard.to_response()
return new_guard.to_json()
else:
raise HttpError(
405,
@@ -83,9 +79,7 @@ def guard(guard_name: str):
guard_name=decoded_guard_name
),
)
if isinstance(guard, Guard):
return guard._to_request()
return guard.to_response()
return guard.to_json()
elif request.method == "PUT":
if not postgres_is_enabled():
raise HttpError(
@@ -94,11 +88,9 @@ def guard(guard_name: str):
"PUT /<guard_name> is not implemented for in-memory guards.",
)
payload = request.json
guard = GuardStruct.from_request(payload)
guard = GuardStruct.from_json(payload)
updated_guard = guard_client.upsert_guard(decoded_guard_name, guard)
if isinstance(updated_guard, Guard):
return updated_guard._to_request()
return updated_guard.to_response()
return updated_guard.to_json()
elif request.method == "DELETE":
if not postgres_is_enabled():
raise HttpError(
@@ -107,9 +99,7 @@ def guard(guard_name: str):
"DELETE /<guard_name> is not implemented for in-memory guards.",
)
guard = guard_client.delete_guard(decoded_guard_name)
if isinstance(guard, Guard):
return guard._to_request()
return guard.to_response()
return guard.to_json()
else:
raise HttpError(
405,
@@ -123,7 +113,7 @@ def collect_telemetry(
*,
guard: Guard,
validate_span: Span,
validation_output: ValidationOutput,
validation_output: ValidationOutcome,
prompt_params: Dict[str, Any],
result: ValidationOutcome,
):
@@ -179,12 +169,11 @@ def validate(guard_name: str):
)
decoded_guard_name = unquote_plus(guard_name)
guard_struct = guard_client.get_guard(decoded_guard_name)
if isinstance(guard_struct, GuardStruct):
# TODO: is there a way to do this with Guard?
prep_environment(guard_struct)
prep_environment(guard_struct)

llm_output = payload.pop("llmOutput", None)
num_reasks = payload.pop("numReasks", guard_struct.num_reasks)
# TODO: not sure if this is right - how do we get numReasks from new IGuard?
num_reasks = payload.pop("numReasks", 0)
prompt_params = payload.pop("promptParams", {})
llm_api = payload.pop("llmApi", None)
args = payload.pop("args", [])
@@ -201,7 +190,14 @@ def validate(guard_name: str):
# guard: Guard = guard_struct.to_guard(openai_api_key, otel_tracer)
guard: Guard = Guard()
if isinstance(guard_struct, GuardStruct):
guard: Guard = guard_struct.to_guard(openai_api_key)
guard: Guard = Guard(
id=guard_struct.id,
name=guard_struct.name,
description=guard_struct.description,
validators=guard_struct.validators,
output_schema=guard_struct.output_schema,
)
guard._api_client = GuardrailsApiClient(api_key=openai_api_key)
elif isinstance(guard_struct, Guard):
guard = guard_struct
# validate_span.set_attribute("guardName", decoded_guard_name)
@@ -234,22 +230,20 @@ def validate(guard_name: str):
message="BadRequest",
cause="Streaming is not supported for parse calls!",
)

result: ValidationOutcome = guard.parse(
llm_output=llm_output,
num_reasks=num_reasks,
prompt_params=prompt_params,
llm_api=llm_api,
# api_key=openai_api_key,
*args,
**payload,
)
else:
if stream:

def guard_streamer():
guard_stream = guard(
llm_api=llm_api,
# llm_api=llm_api,
prompt_params=prompt_params,
num_reasks=num_reasks,
stream=stream,
@@ -260,7 +254,7 @@ def guard_streamer():

for result in guard_stream:
# TODO: Just make this a ValidationOutcome with history
validation_output: ValidationOutput = ValidationOutput(
validation_output: ValidationOutcome = ValidationOutcome(
result.validation_passed,
result.validated_output,
guard.history,
@@ -278,11 +272,11 @@ def validate_streamer(guard_iter):
fragment = json.dumps(validation_output.to_response())
yield f"{fragment}\n"

final_validation_output: ValidationOutput = ValidationOutput(
next_result.validation_passed,
next_result.validated_output,
guard.history,
next_result.raw_llm_output,
final_validation_output: ValidationOutcome = ValidationOutcome(
validation_passed=next_result.validation_passed,
validated_output=next_result.validated_output,
history=guard.history,
raw_llm_output=next_result.raw_llm_output,
)
# I don't know if these are actually making it to OpenSearch
# because the span may be ended already
@@ -293,7 +287,7 @@ def validate_streamer(guard_iter):
# prompt_params=prompt_params,
# result=next_result
# )
final_output_json = json.dumps(final_validation_output.to_response())
final_output_json = final_validation_output.to_json()
yield f"{final_output_json}\n"

return Response(
@@ -312,12 +306,12 @@ def validate_streamer(guard_iter):
)

# TODO: Just make this a ValidationOutcome with history
validation_output = ValidationOutput(
result.validation_passed,
result.validated_output,
guard.history,
result.raw_llm_output,
)
# validation_output = ValidationOutcome(
# validation_passed = result.validation_passed,
# validated_output=result.validated_output,
# history=guard.history,
# raw_llm_output=result.raw_llm_output,
# )

# collect_telemetry(
# guard=guard,
@@ -326,6 +320,5 @@ def validate_streamer(guard_iter):
# prompt_params=prompt_params,
# result=result
# )
if isinstance(guard_struct, GuardStruct):
cleanup_environment(guard_struct)
return validation_output.to_response()
cleanup_environment(guard_struct)
return result.to_dict()
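For context on the serialization change running through this file: handlers now build GuardStruct (the guardrails_api_client Guard model) from request JSON and serialize responses with to_json() or to_dict(), replacing the old from_request(), _to_request(), and to_response() helpers. A minimal sketch of that pattern, using only the calls visible in this diff; the blueprint and route names are hypothetical and the persistence step is elided:

# Hypothetical illustration of the request/response pattern this PR adopts.
from flask import Blueprint, request
from guardrails_api_client import Guard as GuardStruct  # same alias as in guards.py

example_bp = Blueprint("example_guards", __name__)  # name chosen for illustration only

@example_bp.route("/example-guards", methods=["POST"])
def create_example_guard():
    # Parse the request body into the core schema model (was GuardStruct.from_request).
    guard_struct = GuardStruct.from_json(request.json)
    # ...the real handler persists it via guard_client.create_guard(guard_struct)...
    # Serialize the model back out (was _to_request() / to_response()).
    return guard_struct.to_json()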