Fix Async Streaming #53

Merged · 5 commits · Jul 3, 2024
119 changes: 101 additions & 18 deletions guardrails_api/blueprints/guards.py
@@ -1,11 +1,13 @@
+import asyncio
import json
import os
+import inspect
from guardrails.hub import *  # noqa
from string import Template
from typing import Any, Dict, cast
from flask import Blueprint, Response, request, stream_with_context
from urllib.parse import unquote_plus
-from guardrails import Guard
+from guardrails import AsyncGuard, Guard
from guardrails.classes import ValidationOutcome
from opentelemetry.trace import Span
from guardrails_api_client import Guard as GuardStruct
@@ -170,7 +172,6 @@ def validate(guard_name: str):
    )
    decoded_guard_name = unquote_plus(guard_name)
    guard_struct = guard_client.get_guard(decoded_guard_name)
-
    llm_output = payload.pop("llmOutput", None)
    num_reasks = payload.pop("numReasks", None)
    prompt_params = payload.pop("promptParams", {})
@@ -187,9 +188,7 @@ def validate(guard_name: str):
# f"validate-{decoded_guard_name}"
# ) as validate_span:
# guard: Guard = guard_struct.to_guard(openai_api_key, otel_tracer)
guard = guard_struct
if not isinstance(guard_struct, Guard):
guard: Guard = Guard.from_dict(guard_struct.to_dict())


# validate_span.set_attribute("guardName", decoded_guard_name)
if llm_api is not None:
@@ -204,7 +203,18 @@ def validate(guard_name: str):
" OPENAI_API_KEY environment variable."
),
)
elif num_reasks and num_reasks > 1:

guard = guard_struct
is_async = inspect.iscoroutinefunction(llm_api)
if not isinstance(guard_struct, Guard):
if is_async:
guard = AsyncGuard.from_dict(guard_struct.to_dict())
else:
guard: Guard = Guard.from_dict(guard_struct.to_dict())
elif is_async:
guard:Guard = AsyncGuard.from_dict(guard_struct.to_dict())

if llm_api is None and num_reasks and num_reasks > 1:
raise HttpError(
status=400,
message="BadRequest",
@@ -229,7 +239,6 @@ def validate(guard_name: str):
            )
        else:
            if stream:
-
                def guard_streamer():
                    guard_stream = guard(
                        llm_api=llm_api,
@@ -239,24 +248,30 @@ def guard_streamer():
                        *args,
                        **payload,
                    )
-
                    for result in guard_stream:
                        # TODO: Just make this a ValidationOutcome with history
                        validation_output: ValidationOutcome = (
                            ValidationOutcome.from_guard_history(guard.history.last)
                        )
                        yield validation_output, cast(ValidationOutcome, result)

-                    # ValidationOutcome(
-                    #     guard.history,
-                    #     validation_passed=result.validation_passed,
-                    #     validated_output=result.validated_output,
-                    #     raw_llm_output=result.raw_llm_output,
-                    # )
+            async def async_guard_streamer():
+                guard_stream = await guard(
+                    llm_api=llm_api,
+                    prompt_params=prompt_params,
+                    num_reasks=num_reasks,
+                    stream=stream,
+                    *args,
+                    **payload,
+                )
+                async for result in guard_stream:
+                    validation_output: ValidationOutcome = (
+                        ValidationOutcome.from_guard_history(guard.history.last)
+                    )
+                    yield validation_output, cast(ValidationOutcome, result)

            def validate_streamer(guard_iter):
                next_result = None
                # next_validation_output = None
                for validation_output, result in guard_iter:
                    next_result = result
                    # next_validation_output = validation_output
@@ -271,7 +286,6 @@ def validate_streamer(guard_iter):
                    )
                    fragment = json.dumps(fragment_dict)
                    yield f"{fragment}\n"
-
                call = guard.history.last
                final_validation_output: ValidationOutcome = ValidationOutcome(
                    callId=call.id,
@@ -303,11 +317,80 @@ def validate_streamer(guard_iter):
                serialized_history = [call.to_dict() for call in guard.history]
                cache_key = f"{guard.name}-{final_validation_output.call_id}"
                cache_client.set(cache_key, serialized_history, 300)
-
                yield f"{final_output_json}\n"

+            async def async_validate_streamer(guard_iter):
+                next_result = None
+                # next_validation_output = None
+                async for validation_output, result in guard_iter:
+                    next_result = result
+                    # next_validation_output = validation_output
+                    fragment_dict = result.to_dict()
+                    fragment_dict["error_spans"] = list(
+                        map(
+                            lambda x: json.dumps(
+                                {"start": x.start, "end": x.end, "reason": x.reason}
+                            ),
+                            guard.error_spans_in_output(),
+                        )
+                    )
+                    fragment = json.dumps(fragment_dict)
+                    yield f"{fragment}\n"
+
+                call = guard.history.last
+                final_validation_output: ValidationOutcome = ValidationOutcome(
+                    callId=call.id,
+                    validation_passed=next_result.validation_passed,
+                    validated_output=next_result.validated_output,
+                    history=guard.history,
+                    raw_llm_output=next_result.raw_llm_output,
+                )
+                # I don't know if these are actually making it to OpenSearch
+                # because the span may be ended already
+                # collect_telemetry(
+                #     guard=guard,
+                #     validate_span=validate_span,
+                #     validation_output=next_validation_output,
+                #     prompt_params=prompt_params,
+                #     result=next_result
+                # )
+                final_output_dict = final_validation_output.to_dict()
+                final_output_dict["error_spans"] = list(
+                    map(
+                        lambda x: json.dumps(
+                            {"start": x.start, "end": x.end, "reason": x.reason}
+                        ),
+                        guard.error_spans_in_output(),
+                    )
+                )
+                final_output_json = json.dumps(final_output_dict)
+
+                serialized_history = [call.to_dict() for call in guard.history]
+                cache_key = f"{guard.name}-{final_validation_output.call_id}"
+                cache_client.set(cache_key, serialized_history, 300)
+                yield f"{final_output_json}\n"
+            # apropos of https://stackoverflow.com/questions/73949570/using-stream-with-context-as-async
+            def iter_over_async(ait, loop):
+                ait = ait.__aiter__()
+                async def get_next():
+                    try:
+                        obj = await ait.__anext__()
+                        return False, obj
+                    except StopAsyncIteration:
+                        return True, None
+                while True:
+                    done, obj = loop.run_until_complete(get_next())
+                    if done:
+                        break
+                    yield obj
+            if is_async:
+                loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(loop)
+                iter = iter_over_async(async_validate_streamer(async_guard_streamer()), loop)
+            else:
+                iter = validate_streamer(guard_streamer())
            return Response(
-                stream_with_context(validate_streamer(guard_streamer())),
+                stream_with_context(iter),
                content_type="application/json",
                # content_type="text/event-stream"
            )
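
The heart of this fix is the async-to-sync bridge: Flask's stream_with_context expects a plain synchronous iterator, so when llm_api is a coroutine function the endpoint drives the async generator on a dedicated event loop, one chunk at a time. Below is a minimal standalone sketch of the same pattern using only the standard library; the numbers generator is a hypothetical stand-in for async_validate_streamer, not part of this PR.

import asyncio

def iter_over_async(ait, loop):
    # Pull items out of an async iterator by running the event loop
    # one __anext__ call at a time, yielding each item synchronously.
    ait = ait.__aiter__()

    async def get_next():
        try:
            return False, await ait.__anext__()
        except StopAsyncIteration:
            return True, None

    while True:
        done, obj = loop.run_until_complete(get_next())
        if done:
            break
        yield obj

async def numbers():
    # Hypothetical stand-in for the async guard stream.
    for i in range(3):
        await asyncio.sleep(0)  # simulate awaiting an LLM chunk
        yield i

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
print(list(iter_over_async(numbers(), loop)))  # [0, 1, 2]
loop.close()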
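Both streamers yield newline-delimited JSON: one ValidationOutcome fragment per line, with the final accumulated outcome last. A hypothetical client-side sketch for consuming that stream follows; the URL and payload shape are assumptions rather than part of this PR, while error_spans is the field the streamers attach above.

import json
import requests  # assumed available; any streaming HTTP client works

resp = requests.post(
    "http://localhost:8000/guards/my-guard/validate",  # assumed deployment URL
    json={"stream": True},                             # assumed payload shape
    stream=True,
)
for line in resp.iter_lines():
    if not line:
        continue
    fragment = json.loads(line)  # one ValidationOutcome fragment per line
    print(fragment.get("error_spans"))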