Use new core schemas in guardrails API #46

Merged · 25 commits · Jun 14, 2024
9 changes: 5 additions & 4 deletions .github/workflows/pr_qa.yml
@@ -13,17 +13,18 @@ jobs:
     name: PR checks
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
-      - uses: actions/setup-python@v4
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
         with:
           python-version: '3.12'
       - name: Install Dependencies
         run: |
           python -m venv ./.venv
           source ./.venv/bin/activate
           make install-lock;
+          make install-dev;
 
-          curl https://raw.githubusercontent.com/guardrails-ai/guardrails-api-client/main/service-specs/guardrails-service-spec.yml -o ./open-api-spec.yml
-          npx @redocly/cli bundle --dereferenced --output ./open-api-spec.json --ext json ./open-api-spec.yml
+          cp ./.venv/lib/python3.12/site-packages/guardrails_api_client/openapi-spec.json ./open-api-spec.json
 
       - name: Run Quality Checks
         run: |
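Context for the CI change above (and the matching local.sh change below): the OpenAPI document now ships inside the guardrails_api_client package, so the workflow copies it out of the installed distribution instead of fetching and re-bundling the service spec from GitHub. A hedged sketch of reading the bundled spec from Python, assuming the JSON is shipped as package data under the name the workflow copies:

```python
# Hedged sketch: locate the OpenAPI spec bundled with guardrails_api_client.
# Assumes the file is shipped as package data named "openapi-spec.json",
# mirroring the site-packages path the workflow copies from.
import importlib.resources as resources

spec_file = resources.files("guardrails_api_client").joinpath("openapi-spec.json")
print(spec_file.read_text()[:120])  # first few characters of the spec JSON
```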
2 changes: 1 addition & 1 deletion Makefile
@@ -58,7 +58,7 @@ test:
 
 test-cov:
 	coverage run --source=./src -m pytest ./tests
-	coverage report --fail-under=50
+	coverage report --fail-under=45
 
 view-test-cov:
 	coverage run --source=./src -m pytest ./tests
10 changes: 10 additions & 0 deletions app.py
@@ -1,5 +1,6 @@
 import os
 from flask import Flask
+from flask.json.provider import DefaultJSONProvider
 from flask_cors import CORS
 from werkzeug.middleware.proxy_fix import ProxyFix
 from urllib.parse import urlparse
@@ -9,6 +10,14 @@
 from src.otel import otel_is_disabled, initialize
 
 
+# TODO: Move this to a separate file
+class OverrideJsonProvider(DefaultJSONProvider):
+    def default(self, o):
+        if isinstance(o, set):
+            return list(o)
+        return super().default(o)
+
+
 class ReverseProxied(object):
     def __init__(self, app):
         self.app = app
@@ -27,6 +36,7 @@ def create_app():
     load_dotenv()
 
     app = Flask(__name__)
+    app.json = OverrideJsonProvider(app)
 
     app.config["APPLICATION_ROOT"] = "/"
     app.config["PREFERRED_URL_SCHEME"] = "https"
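One note on OverrideJsonProvider: Flask's DefaultJSONProvider.default is a staticmethod, so the super call takes only the object, i.e. super().default(o); the override's only job is to coerce sets, which the stdlib json encoder cannot serialize. A minimal sketch of the behavior this enables, assuming Flask 2.2+ (where JSON providers were introduced):

```python
# Minimal sketch, assuming Flask >= 2.2.
from flask import Flask, jsonify
from flask.json.provider import DefaultJSONProvider


class OverrideJsonProvider(DefaultJSONProvider):
    def default(self, o):
        if isinstance(o, set):
            return list(o)  # sets are not JSON serializable by default
        return super().default(o)


app = Flask(__name__)
app.json = OverrideJsonProvider(app)

with app.app_context():
    # Without the override this raises:
    # TypeError: Object of type set is not JSON serializable
    print(jsonify({"tags": {"camping", "music"}}).get_data(as_text=True))
```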
1 change: 1 addition & 0 deletions compose.yml
@@ -23,6 +23,7 @@ services:
       PGADMIN_DEFAULT_EMAIL: "${PGUSER:-postgres}@guardrails.com"
       PGADMIN_DEFAULT_PASSWORD: ${PGPASSWORD:-changeme}
       PGADMIN_SERVER_JSON_FILE: /var/lib/pgadmin/servers.json
+      # FIXME: Copy over server.json file and create passfile
     volumes:
       - ./pgadmin-data:/var/lib/pgadmin
     depends_on:
4 changes: 0 additions & 4 deletions local.sh
@@ -35,10 +35,6 @@ export SELF_ENDPOINT=http://localhost:8000
 export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
 export HF_API_KEY=${HF_TOKEN}
 
-
-curl https://raw.githubusercontent.com/guardrails-ai/guardrails-api-client/main/service-specs/guardrails-service-spec.yml -o ./open-api-spec.yml
-npx @redocly/cli bundle --dereferenced --output ./open-api-spec.json --ext json ./open-api-spec.yml
-
 # For running https locally
 # mkdir -p ~/certificates
 # if [ ! -f ~/certificates/local.key ]; then
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,7 +1,7 @@
 flask
 sqlalchemy
 lxml
-guardrails-ai
+guardrails-ai @ git+https://github.com/guardrails-ai/guardrails.git@core-schema-impl
 # Let this come from guardrails-ai as a transient dependency.
 # Pip confuses tag versions with commit ids,
 # and claims a conflict even though it's the same thing.
74 changes: 45 additions & 29 deletions sample-config.py
@@ -9,13 +9,13 @@
 '''
 
 from guardrails import Guard
-from guardrails.hub import RegexMatch, RestrictToTopic
+from guardrails.hub import RegexMatch, ValidChoices, ValidLength  # , RestrictToTopic
 
 name_case = Guard(
     name='name-case',
     description='Checks that a string is in Name Case format.'
 ).use(
-    RegexMatch(regex="^[A-Z][a-z\\s]*$")
+    RegexMatch(regex="^(?:[A-Z][^\\s]*\\s?)+$")
 )
 
 all_caps = Guard(
@@ -25,31 +25,47 @@
     RegexMatch(regex="^[A-Z\\s]*$")
 )
 
-valid_topics = ["music", "cooking", "camping", "outdoors"]
-invalid_topics = ["sports", "work", "ai"]
-all_topics = [*valid_topics, *invalid_topics]
-
-def custom_llm(text: str, *args, **kwargs):
-    return [
-        {
-            "name": t,
-            "present": (t in text),
-            "confidence": 5
-        }
-        for t in all_topics
-    ]
-
-custom_code_guard = Guard(
-    name='custom',
-    description='Uses a custom llm for RestrictToTopic'
-).use(
-    RestrictToTopic(
-        valid_topics=valid_topics,
-        invalid_topics=invalid_topics,
-        llm_callable=custom_llm,
-        disable_classifier=True,
-        disable_llm=False,
-        # Pass this so it doesn't load the bart model
-        classifier_api_endpoint="https://m-1e7af27102f54c3a9eb9cb11aa4715bd-m.default.model-v2.inferless.com/v2/models/RestrictToTopic_1e7af27102f54c3a9eb9cb11aa4715bd/versions/1/infer",
-    )
-)
+lower_case = Guard(
+    name='lower-case',
+    description='Checks that a string is all lowercase.'
+).use(
+    RegexMatch(regex="^[a-z\\s]*$")
+).use(
+    ValidLength(1, 100)
+).use(
+    ValidChoices(["music", "cooking", "camping", "outdoors"])
+)
+
+print(lower_case.to_json())
+
+
+# valid_topics = ["music", "cooking", "camping", "outdoors"]
+# invalid_topics = ["sports", "work", "ai"]
+# all_topics = [*valid_topics, *invalid_topics]
+
+# def custom_llm(text: str, *args, **kwargs):
+#     return [
+#         {
+#             "name": t,
+#             "present": (t in text),
+#             "confidence": 5
+#         }
+#         for t in all_topics
+#     ]
+
+# custom_code_guard = Guard(
+#     name='custom',
+#     description='Uses a custom llm for RestrictToTopic'
+# ).use(
+#     RestrictToTopic(
+#         valid_topics=valid_topics,
+#         invalid_topics=invalid_topics,
+#         llm_callable=custom_llm,
+#         disable_classifier=True,
+#         disable_llm=False,
+#         # Pass this so it doesn't load the bart model
+#         classifier_api_endpoint="https://m-1e7af27102f54c3a9eb9cb11aa4715bd-m.default.model-v2.inferless.com/v2/models/RestrictToTopic_1e7af27102f54c3a9eb9cb11aa4715bd/versions/1/infer",
+#     )
+# )
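A quick sketch of exercising the sample config above, assuming guardrails-ai and the hub validators (RegexMatch, ValidChoices, ValidLength) are installed; Guard.validate returns a ValidationOutcome:

```python
# Hedged usage sketch for sample-config.py; output values are illustrative.
outcome = lower_case.validate("music")
print(outcome.validation_passed)  # True: lowercase, within length, a valid choice

outcome = name_case.validate("John Doe")
print(outcome.validation_passed)  # True: every word starts with a capital letter
```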
78 changes: 29 additions & 49 deletions src/blueprints/guards.py
@@ -8,15 +8,13 @@
 from guardrails import Guard
 from guardrails.classes import ValidationOutcome
 from opentelemetry.trace import Span
-from src.classes.guard_struct import GuardStruct
 from src.classes.http_error import HttpError
-from src.classes.validation_output import ValidationOutput
 from src.clients.memory_guard_client import MemoryGuardClient
 from src.clients.pg_guard_client import PGGuardClient
 from src.clients.postgres_client import postgres_is_enabled
 from src.utils.handle_error import handle_error
 from src.utils.get_llm_callable import get_llm_callable
-from src.utils.prep_environment import cleanup_environment, prep_environment
+from guardrails_api_client import Guard as GuardStruct
 
 
 guards_bp = Blueprint("guards", __name__, url_prefix="/guards")
@@ -43,9 +41,7 @@
 def guards():
     if request.method == "GET":
         guards = guard_client.get_guards()
-        if len(guards) > 0 and (isinstance(guards[0], Guard)):
-            return [g._to_request() for g in guards]
-        return [g.to_response() for g in guards]
+        return [g.to_dict() for g in guards]
     elif request.method == "POST":
         if not postgres_is_enabled():
             raise HttpError(
@@ -54,11 +50,9 @@ def guards():
                 "POST /guards is not implemented for in-memory guards.",
             )
         payload = request.json
-        guard = GuardStruct.from_request(payload)
+        guard = GuardStruct.from_dict(payload)
         new_guard = guard_client.create_guard(guard)
-        if isinstance(new_guard, Guard):
-            return new_guard._to_request()
-        return new_guard.to_response()
+        return new_guard.to_dict()
     else:
         raise HttpError(
             405,
@@ -83,9 +77,7 @@ def guard(guard_name: str):
                 guard_name=decoded_guard_name
             ),
         )
-        if isinstance(guard, Guard):
-            return guard._to_request()
-        return guard.to_response()
+        return guard.to_dict()
    elif request.method == "PUT":
        if not postgres_is_enabled():
            raise HttpError(
@@ -94,11 +86,9 @@ def guard(guard_name: str):
                "PUT /<guard_name> is not implemented for in-memory guards.",
            )
        payload = request.json
-        guard = GuardStruct.from_request(payload)
+        guard = GuardStruct.from_dict(payload)
        updated_guard = guard_client.upsert_guard(decoded_guard_name, guard)
-        if isinstance(updated_guard, Guard):
-            return updated_guard._to_request()
-        return updated_guard.to_response()
+        return updated_guard.to_dict()
    elif request.method == "DELETE":
        if not postgres_is_enabled():
            raise HttpError(
@@ -107,9 +97,7 @@ def guard(guard_name: str):
                "DELETE /<guard_name> is not implemented for in-memory guards.",
            )
        guard = guard_client.delete_guard(decoded_guard_name)
-        if isinstance(guard, Guard):
-            return guard._to_request()
-        return guard.to_response()
+        return guard.to_dict()
    else:
        raise HttpError(
            405,
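The CRUD handlers above all funnel through the same round trip: guardrails_api_client.Guard (imported as GuardStruct) and guardrails.Guard both serialize to the shared core schema, so a single to_dict/from_dict pair replaces the old _to_request/to_response split. A hedged sketch of that round trip, using the names from this diff; the payload shape is illustrative, not the full schema:

```python
# Hedged sketch of the core-schema round trip (payload shape is illustrative).
payload = {"id": "lower-case", "name": "lower-case"}
guard_struct = GuardStruct.from_dict(payload)    # API model from a request body
guard = Guard.from_dict(guard_struct.to_dict())  # promote to an executable Guard
print(guard.to_dict())                           # same core schema back out
```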
@@ -123,7 +111,7 @@ def collect_telemetry(
     *,
     guard: Guard,
     validate_span: Span,
-    validation_output: ValidationOutput,
+    validation_output: ValidationOutcome,
     prompt_params: Dict[str, Any],
     result: ValidationOutcome,
 ):
@@ -179,12 +167,9 @@ def validate(guard_name: str):
     )
     decoded_guard_name = unquote_plus(guard_name)
     guard_struct = guard_client.get_guard(decoded_guard_name)
-    if isinstance(guard_struct, GuardStruct):
-        # TODO: is there a way to do this with Guard?
-        prep_environment(guard_struct)
 
     llm_output = payload.pop("llmOutput", None)
-    num_reasks = payload.pop("numReasks", guard_struct.num_reasks)
+    num_reasks = payload.pop("numReasks", None)
     prompt_params = payload.pop("promptParams", {})
     llm_api = payload.pop("llmApi", None)
     args = payload.pop("args", [])
@@ -199,11 +184,10 @@
     #     f"validate-{decoded_guard_name}"
     # ) as validate_span:
     #     guard: Guard = guard_struct.to_guard(openai_api_key, otel_tracer)
-    guard: Guard = Guard()
-    if isinstance(guard_struct, GuardStruct):
-        guard: Guard = guard_struct.to_guard(openai_api_key)
-    elif isinstance(guard_struct, Guard):
-        guard = guard_struct
+    guard = guard_struct
+    if not isinstance(guard_struct, Guard):
+        guard: Guard = Guard.from_dict(guard_struct.to_dict())
 
     # validate_span.set_attribute("guardName", decoded_guard_name)
     if llm_api is not None:
         llm_api = get_llm_callable(llm_api)
@@ -234,22 +218,20 @@
                 message="BadRequest",
                 cause="Streaming is not supported for parse calls!",
             )
-
         result: ValidationOutcome = guard.parse(
             llm_output=llm_output,
             num_reasks=num_reasks,
             prompt_params=prompt_params,
             llm_api=llm_api,
-            # api_key=openai_api_key,
             *args,
             **payload,
         )
     else:
         if stream:
 
             def guard_streamer():
                 guard_stream = guard(
-                    llm_api=llm_api,
+                    # llm_api=llm_api,
                     prompt_params=prompt_params,
                     num_reasks=num_reasks,
                     stream=stream,
@@ -260,7 +242,7 @@ def guard_streamer():
 
                 for result in guard_stream:
                     # TODO: Just make this a ValidationOutcome with history
-                    validation_output: ValidationOutput = ValidationOutput(
+                    validation_output: ValidationOutcome = ValidationOutcome(
                         result.validation_passed,
                         result.validated_output,
                         guard.history,
@@ -278,11 +260,11 @@ def validate_streamer(guard_iter):
                 fragment = json.dumps(validation_output.to_response())
                 yield f"{fragment}\n"
 
-            final_validation_output: ValidationOutput = ValidationOutput(
-                next_result.validation_passed,
-                next_result.validated_output,
-                guard.history,
-                next_result.raw_llm_output,
+            final_validation_output: ValidationOutcome = ValidationOutcome(
+                validation_passed=next_result.validation_passed,
+                validated_output=next_result.validated_output,
+                history=guard.history,
+                raw_llm_output=next_result.raw_llm_output,
             )
             # I don't know if these are actually making it to OpenSearch
             # because the span may be ended already
@@ -293,7 +275,7 @@ def validate_streamer(guard_iter):
             #     prompt_params=prompt_params,
             #     result=next_result
             # )
-            final_output_json = json.dumps(final_validation_output.to_response())
+            final_output_json = final_validation_output.to_json()
             yield f"{final_output_json}\n"
 
         return Response(
@@ -312,12 +294,12 @@ def validate_streamer(guard_iter):
         )
 
         # TODO: Just make this a ValidationOutcome with history
-        validation_output = ValidationOutput(
-            result.validation_passed,
-            result.validated_output,
-            guard.history,
-            result.raw_llm_output,
-        )
+        # validation_output = ValidationOutcome(
+        #     validation_passed=result.validation_passed,
+        #     validated_output=result.validated_output,
+        #     history=guard.history,
+        #     raw_llm_output=result.raw_llm_output,
+        # )
 
         # collect_telemetry(
         #     guard=guard,
@@ -326,6 +308,4 @@ def validate_streamer(guard_iter):
        #     prompt_params=prompt_params,
        #     result=result
        # )
-        if isinstance(guard_struct, GuardStruct):
-            cleanup_environment(guard_struct)
-        return validation_output.to_response()
+        return result.to_dict()
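For reference, a hedged sketch of consuming the streaming validate endpoint implemented above: each line of the response body is one serialized outcome fragment, with the final ValidationOutcome last. The URL, port, and payload keys are assumptions for illustration, taken from the names this diff pops off the request body:

```python
# Hedged client sketch; endpoint path/port and response shape are assumptions.
import json

import requests

resp = requests.post(
    "http://localhost:8000/guards/lower-case/validate",
    json={"llmOutput": "music", "stream": True},
    stream=True,
)
for line in resp.iter_lines():
    if line:
        fragment = json.loads(line)
        print(fragment)  # intermediate fragments, then the final outcome
```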