Merge pull request #53 from guardrails-ai/nichwch/fix-async-streaming
Fix Async Streaming
nichwch authored Jul 3, 2024
2 parents 814ee8b + 56b050f commit 86e0b6f
Showing 1 changed file with 101 additions and 18 deletions.
119 changes: 101 additions & 18 deletions guardrails_api/blueprints/guards.py
@@ -1,11 +1,13 @@
import asyncio
import json
import os
import inspect
from guardrails.hub import * # noqa
from string import Template
from typing import Any, Dict, cast
from flask import Blueprint, Response, request, stream_with_context
from urllib.parse import unquote_plus
from guardrails import Guard
from guardrails import AsyncGuard, Guard
from guardrails.classes import ValidationOutcome
from opentelemetry.trace import Span
from guardrails_api_client import Guard as GuardStruct
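The new asyncio and inspect imports support the async streaming path added below: inspect.iscoroutinefunction reports whether a callable was defined with async def, which is how the endpoint decides between Guard and AsyncGuard. A minimal sketch, with hypothetical LLM callables:

import inspect

async def async_llm(prompt: str) -> str:
    # Hypothetical async LLM callable.
    return "ok"

def sync_llm(prompt: str) -> str:
    # Hypothetical sync LLM callable.
    return "ok"

print(inspect.iscoroutinefunction(async_llm))  # True  -> AsyncGuard path
print(inspect.iscoroutinefunction(sync_llm))   # False -> Guard path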
@@ -170,7 +172,6 @@ def validate(guard_name: str):
)
decoded_guard_name = unquote_plus(guard_name)
guard_struct = guard_client.get_guard(decoded_guard_name)

llm_output = payload.pop("llmOutput", None)
num_reasks = payload.pop("numReasks", None)
prompt_params = payload.pop("promptParams", {})
@@ -187,9 +188,7 @@ def validate(guard_name: str):
# f"validate-{decoded_guard_name}"
# ) as validate_span:
# guard: Guard = guard_struct.to_guard(openai_api_key, otel_tracer)
guard = guard_struct
if not isinstance(guard_struct, Guard):
guard: Guard = Guard.from_dict(guard_struct.to_dict())


# validate_span.set_attribute("guardName", decoded_guard_name)
if llm_api is not None:
@@ -204,7 +203,18 @@ def validate(guard_name: str):
" OPENAI_API_KEY environment variable."
),
)
elif num_reasks and num_reasks > 1:

guard = guard_struct
is_async = inspect.iscoroutinefunction(llm_api)
if not isinstance(guard_struct, Guard):
if is_async:
guard = AsyncGuard.from_dict(guard_struct.to_dict())
else:
guard: Guard = Guard.from_dict(guard_struct.to_dict())
elif is_async:
guard: Guard = AsyncGuard.from_dict(guard_struct.to_dict())
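The branch above resolves to an AsyncGuard whenever the LLM callable is a coroutine function, and otherwise keeps or rebuilds a plain Guard. A condensed sketch of the same decision, assuming the imports at the top of this file (the helper name is illustrative):

import inspect

from guardrails import AsyncGuard, Guard

def resolve_guard(guard_struct, llm_api):
    # Async LLM callables need an AsyncGuard so the call can be awaited.
    if inspect.iscoroutinefunction(llm_api):
        return AsyncGuard.from_dict(guard_struct.to_dict())
    if isinstance(guard_struct, Guard):
        return guard_struct
    return Guard.from_dict(guard_struct.to_dict())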

if llm_api is None and num_reasks and num_reasks > 1:
raise HttpError(
status=400,
message="BadRequest",
@@ -229,7 +239,6 @@
)
else:
if stream:

def guard_streamer():
guard_stream = guard(
llm_api=llm_api,
@@ -239,24 +248,30 @@ def guard_streamer():
*args,
**payload,
)

for result in guard_stream:
# TODO: Just make this a ValidationOutcome with history
validation_output: ValidationOutcome = (
ValidationOutcome.from_guard_history(guard.history.last)
)
yield validation_output, cast(ValidationOutcome, result)

# ValidationOutcome(
# guard.history,
# validation_passed=result.validation_passed,
# validated_output=result.validated_output,
# raw_llm_output=result.raw_llm_output,
# )
async def async_guard_streamer():
guard_stream = await guard(
llm_api=llm_api,
prompt_params=prompt_params,
num_reasks=num_reasks,
stream=stream,
*args,
**payload,
)
async for result in guard_stream:
validation_output: ValidationOutcome = (
ValidationOutcome.from_guard_history(guard.history.last)
)
yield validation_output, cast(ValidationOutcome, result)
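async_guard_streamer mirrors the sync streamer with two changes: the guard call itself is awaited to obtain the stream, and the stream is consumed with async for. The same shape in a self-contained toy (all names illustrative):

import asyncio

async def make_stream():
    # Stands in for `await guard(...)`: an awaitable returning an async iterator.
    async def items():
        for i in range(3):
            yield i
    return items()

async def streamer():
    stream = await make_stream()  # awaiting produces the stream
    async for item in stream:     # async for consumes it
        yield item * 10

async def main():
    async for out in streamer():
        print(out)  # 0, 10, 20

asyncio.run(main())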

def validate_streamer(guard_iter):
next_result = None
# next_validation_output = None
for validation_output, result in guard_iter:
next_result = result
# next_validation_output = validation_output
@@ -271,7 +286,6 @@ def validate_streamer(guard_iter):
)
fragment = json.dumps(fragment_dict)
yield f"{fragment}\n"

call = guard.history.last
final_validation_output: ValidationOutcome = ValidationOutcome(
callId=call.id,
@@ -303,11 +317,80 @@ def validate_streamer(guard_iter):
serialized_history = [call.to_dict() for call in guard.history]
cache_key = f"{guard.name}-{final_validation_output.call_id}"
cache_client.set(cache_key, serialized_history, 300)

yield f"{final_output_json}\n"

async def async_validate_streamer(guard_iter):
next_result = None
# next_validation_output = None
async for validation_output, result in guard_iter:
next_result = result
# next_validation_output = validation_output
fragment_dict = result.to_dict()
fragment_dict["error_spans"] = list(
map(
lambda x: json.dumps(
{"start": x.start, "end": x.end, "reason": x.reason}
),
guard.error_spans_in_output(),
)
)
fragment = json.dumps(fragment_dict)
yield f"{fragment}\n"

call = guard.history.last
final_validation_output: ValidationOutcome = ValidationOutcome(
callId=call.id,
validation_passed=next_result.validation_passed,
validated_output=next_result.validated_output,
history=guard.history,
raw_llm_output=next_result.raw_llm_output,
)
# I don't know if these are actually making it to OpenSearch
# because the span may be ended already
# collect_telemetry(
# guard=guard,
# validate_span=validate_span,
# validation_output=next_validation_output,
# prompt_params=prompt_params,
# result=next_result
# )
final_output_dict = final_validation_output.to_dict()
final_output_dict["error_spans"] = list(
map(
lambda x: json.dumps(
{"start": x.start, "end": x.end, "reason": x.reason}
),
guard.error_spans_in_output(),
)
)
final_output_json = json.dumps(final_output_dict)

serialized_history = [call.to_dict() for call in guard.history]
cache_key = f"{guard.name}-{final_validation_output.call_id}"
cache_client.set(cache_key, serialized_history, 300)
yield f"{final_output_json}\n"
# apropos of https://stackoverflow.com/questions/73949570/using-stream-with-context-as-async
def iter_over_async(ait, loop):
ait = ait.__aiter__()
async def get_next():
try:
obj = await ait.__anext__()
return False, obj
except StopAsyncIteration:
return True, None
while True:
done, obj = loop.run_until_complete(get_next())
if done:
break
yield obj
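iter_over_async, adapted from the linked Stack Overflow answer, lets Flask's synchronous response machinery drain an async generator: each run_until_complete call advances the generator by one item until StopAsyncIteration signals the end. A standalone demo of the bridge, reusing the function defined above:

import asyncio

async def letters():
    for ch in "abc":
        yield ch

loop = asyncio.new_event_loop()
try:
    for item in iter_over_async(letters(), loop):
        print(item)  # a, b, c
finally:
    loop.close()

The endpoint below creates a fresh loop per request with asyncio.new_event_loop and asyncio.set_event_loop; the demo adds the explicit close() that a one-off script would want.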
if is_async:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
iterator = iter_over_async(async_validate_streamer(async_guard_streamer()), loop)
else:
iterator = validate_streamer(guard_streamer())
return Response(
stream_with_context(validate_streamer(guard_streamer())),
stream_with_context(iterator),
content_type="application/json",
# content_type="text/event-stream"
)
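For reference, stream_with_context is the Flask helper that keeps the request context alive while the response generator is consumed; without it, the generator body would run after the request had already been torn down. The pattern in isolation (the route is hypothetical):

from flask import Flask, Response, stream_with_context

app = Flask(__name__)

@app.route("/demo")
def demo():
    def generate():
        for i in range(3):
            yield f"{i}\n"
    # Keeps `request` available for the generator's whole lifetime.
    return Response(stream_with_context(generate()), content_type="application/json")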
