Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
nichwch committed Jun 25, 2024
1 parent bb20bd3 commit 4b7dc30
Show file tree
Hide file tree
Showing 9 changed files with 297 additions and 38 deletions.
6 changes: 4 additions & 2 deletions guardrails_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def register_config(config: Optional[str] = None):
SourceFileLoader("config", config_file_path).load_module()


def create_app(env: Optional[str] = None, config: Optional[str] = None, port: Optional[int] = None):
def create_app(
env: Optional[str] = None, config: Optional[str] = None, port: Optional[int] = None
):
if os.environ.get("APP_ENVIRONMENT") != "production":
from dotenv import load_dotenv

Expand Down Expand Up @@ -86,7 +88,7 @@ def create_app(env: Optional[str] = None, config: Optional[str] = None, port: Op

pg_client = PostgresClient()
pg_client.initialize(app)

cache_client = CacheClient()
cache_client.initialize(app)

Expand Down
34 changes: 25 additions & 9 deletions guardrails_api/blueprints/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
is_guard = isinstance(export, Guard)
if is_guard:
guard_client.create_guard(export)

cache_client = CacheClient()


Expand Down Expand Up @@ -242,8 +242,10 @@ def guard_streamer():

for result in guard_stream:
# TODO: Just make this a ValidationOutcome with history
validation_output: ValidationOutcome = ValidationOutcome.from_guard_history(guard.history.last)

validation_output: ValidationOutcome = (
ValidationOutcome.from_guard_history(guard.history.last)
)

# ValidationOutcome(
# guard.history,
# validation_passed=result.validation_passed,
Expand All @@ -259,7 +261,14 @@ def validate_streamer(guard_iter):
next_result = result
# next_validation_output = validation_output
fragment_dict = result.to_dict()
fragment_dict['error_spans'] = list(map(lambda x: json.dumps({"start":x.start, "end":x.end,"reason":x.reason }),guard.error_spans_in_output()))
fragment_dict["error_spans"] = list(
map(
lambda x: json.dumps(
{"start": x.start, "end": x.end, "reason": x.reason}
),
guard.error_spans_in_output(),
)
)
fragment = json.dumps(fragment_dict)
yield f"{fragment}\n"

Expand All @@ -281,9 +290,17 @@ def validate_streamer(guard_iter):
# result=next_result
# )
final_output_dict = final_validation_output.to_dict()
final_output_dict['error_spans'] = list(map(lambda x: json.dumps({"start":x.start, "end":x.end,"reason":x.reason }),guard.error_spans_in_output()))
final_output_dict["error_spans"] = list(
map(
lambda x: json.dumps(
{"start": x.start, "end": x.end, "reason": x.reason}
),
guard.error_spans_in_output(),
)
)
final_output_json = json.dumps(final_output_dict)
yield f"{final_output_json}\n"

return Response(
stream_with_context(validate_streamer(guard_streamer())),
content_type="application/json",
Expand Down Expand Up @@ -314,16 +331,15 @@ def validate_streamer(guard_iter):
# prompt_params=prompt_params,
# result=result
# )
serialized_history = [
call.to_dict() for call in guard.history
]
serialized_history = [call.to_dict() for call in guard.history]
cache_key = f"{guard.name}-{result.call_id}"
cache_client.set(cache_key, serialized_history, 300)
return result.to_dict()


@guards_bp.route("/<guard_name>/history/<call_id>", methods=["GET"])
@handle_error
def guard_history(guard_name: str, call_id: str):
if request.method == "GET":
cache_key = f"{guard_name}-{call_id}"
return cache_client.get(cache_key)
return cache_client.get(cache_key)
9 changes: 4 additions & 5 deletions guardrails_api/clients/cache_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,15 @@ def __new__(cls):
if cls._instance is None:
cls._instance = super(CacheClient, cls).__new__(cls)
return cls._instance


def initialize(self, app):
self.cache = Cache(
app,
app,
config={
"CACHE_TYPE": "SimpleCache",
"CACHE_DEFAULT_TIMEOUT": 300,
"CACHE_THRESHOLD": 50
}
"CACHE_THRESHOLD": 50,
},
)

def get(self, key):
Expand All @@ -31,4 +30,4 @@ def delete(self, key):
self.cache.delete(key)

def clear(self):
self.cache.clear()
self.cache.clear()
2 changes: 0 additions & 2 deletions guardrails_api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,3 @@
and guards will be persisted into postgres. In that case,
these guards will not be initialized.
"""

from guardrails import Guard
9 changes: 9 additions & 0 deletions guardrails_api/start
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
export APP_ENVIRONMENT=local
export PYTHONUNBUFFERED=1
export LOGLEVEL="INFO"
export GUARDRAILS_LOG_LEVEL="INFO"
export GUARDRAILS_PROCESS_COUNT=1
export SELF_ENDPOINT=http://localhost:8001
export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES

gunicorn --bind 0.0.0.0:8001 --timeout=5 --threads=10 "guardrails_api.app:create_app()"
Loading

0 comments on commit 4b7dc30

Please sign in to comment.