From 11d9bd5466f8a5d1d12692424895163642d6be75 Mon Sep 17 00:00:00 2001 From: Nicholas Chen Date: Fri, 28 Jun 2024 17:02:33 -0400 Subject: [PATCH 1/5] fix async streaming --- guardrails_api/blueprints/guards.py | 135 +++++++++++++++++++--------- 1 file changed, 92 insertions(+), 43 deletions(-) diff --git a/guardrails_api/blueprints/guards.py b/guardrails_api/blueprints/guards.py index dea5083..6407099 100644 --- a/guardrails_api/blueprints/guards.py +++ b/guardrails_api/blueprints/guards.py @@ -1,11 +1,13 @@ +import asyncio import json import os +import inspect from guardrails.hub import * # noqa from string import Template from typing import Any, Dict, cast from flask import Blueprint, Response, request, stream_with_context from urllib.parse import unquote_plus -from guardrails import Guard +from guardrails import AsyncGuard, Guard from guardrails.classes import ValidationOutcome from opentelemetry.trace import Span from guardrails_api_client import Guard as GuardStruct @@ -187,9 +189,7 @@ def validate(guard_name: str): # f"validate-{decoded_guard_name}" # ) as validate_span: # guard: Guard = guard_struct.to_guard(openai_api_key, otel_tracer) - guard = guard_struct - if not isinstance(guard_struct, Guard): - guard: Guard = Guard.from_dict(guard_struct.to_dict()) + # validate_span.set_attribute("guardName", decoded_guard_name) if llm_api is not None: @@ -204,6 +204,17 @@ def validate(guard_name: str): " OPENAI_API_KEY environment variable." ), ) + + guard = guard_struct + is_async = inspect.iscoroutinefunction(llm_api) + if not isinstance(guard_struct, Guard): + if is_async: + guard = AsyncGuard.from_dict(guard_struct.to_dict()) + else: + guard: Guard = Guard.from_dict(guard_struct.to_dict()) + elif is_async: + guard:Guard = AsyncGuard.from_dict(guard_struct.to_dict()) + elif num_reasks and num_reasks > 1: raise HttpError( status=400, @@ -230,47 +241,73 @@ def validate(guard_name: str): else: if stream: - def guard_streamer(): - guard_stream = guard( - llm_api=llm_api, - prompt_params=prompt_params, - num_reasks=num_reasks, - stream=stream, - *args, - **payload, - ) - - for result in guard_stream: - # TODO: Just make this a ValidationOutcome with history - validation_output: ValidationOutcome = ( - ValidationOutcome.from_guard_history(guard.history.last) + async def guard_streamer(): + if is_async: + guard_stream = await guard( + llm_api=llm_api, + prompt_params=prompt_params, + num_reasks=num_reasks, + stream=stream, + *args, + **payload, ) + else: + guard_stream = guard( + llm_api=llm_api, + prompt_params=prompt_params, + num_reasks=num_reasks, + stream=stream, + *args, + **payload, + ) + if is_async: + async for result in guard_stream: + validation_output: ValidationOutcome = ( + ValidationOutcome.from_guard_history(guard.history.last) + ) + yield validation_output, cast(ValidationOutcome, result) + else: + for result in guard_stream: + # TODO: Just make this a ValidationOutcome with history + validation_output: ValidationOutcome = ( + ValidationOutcome.from_guard_history(guard.history.last) + ) + yield validation_output, cast(ValidationOutcome, result) - # ValidationOutcome( - # guard.history, - # validation_passed=result.validation_passed, - # validated_output=result.validated_output, - # raw_llm_output=result.raw_llm_output, - # ) - yield validation_output, cast(ValidationOutcome, result) - - def validate_streamer(guard_iter): + async def validate_streamer(guard_iter): next_result = None # next_validation_output = None - for validation_output, result in guard_iter: - next_result = result - # next_validation_output = validation_output - fragment_dict = result.to_dict() - fragment_dict["error_spans"] = list( - map( - lambda x: json.dumps( - {"start": x.start, "end": x.end, "reason": x.reason} - ), - guard.error_spans_in_output(), + if is_async: + async for validation_output, result in guard_iter: + next_result = result + # next_validation_output = validation_output + fragment_dict = result.to_dict() + fragment_dict["error_spans"] = list( + map( + lambda x: json.dumps( + {"start": x.start, "end": x.end, "reason": x.reason} + ), + guard.error_spans_in_output(), + ) ) - ) - fragment = json.dumps(fragment_dict) - yield f"{fragment}\n" + fragment = json.dumps(fragment_dict) + print('yileding async', fragment) + yield f"{fragment}\n" + else: + for validation_output, result in guard_iter: + next_result = result + # next_validation_output = validation_output + fragment_dict = result.to_dict() + fragment_dict["error_spans"] = list( + map( + lambda x: json.dumps( + {"start": x.start, "end": x.end, "reason": x.reason} + ), + guard.error_spans_in_output(), + ) + ) + fragment = json.dumps(fragment_dict) + yield f"{fragment}\n" call = guard.history.last final_validation_output: ValidationOutcome = ValidationOutcome( @@ -303,11 +340,23 @@ def validate_streamer(guard_iter): serialized_history = [call.to_dict() for call in guard.history] cache_key = f"{guard.name}-{final_validation_output.call_id}" cache_client.set(cache_key, serialized_history, 300) - yield f"{final_output_json}\n" - + + def iter_over_async(ait, loop): + ait = ait.__aiter__() + async def get_next(): + try: obj = await ait.__anext__(); return False, obj + except StopAsyncIteration: return True, None + while True: + done, obj = loop.run_until_complete(get_next()) + if done: break + yield obj + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + iter = iter_over_async(validate_streamer(guard_streamer()), loop) return Response( - stream_with_context(validate_streamer(guard_streamer())), + stream_with_context(iter), content_type="application/json", # content_type="text/event-stream" ) From 8f13c6ee0c5b21ea6862d6c9944ad738c25d93d8 Mon Sep 17 00:00:00 2001 From: Nicholas Chen Date: Fri, 28 Jun 2024 17:03:06 -0400 Subject: [PATCH 2/5] give stack overflow credit --- guardrails_api/blueprints/guards.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/guardrails_api/blueprints/guards.py b/guardrails_api/blueprints/guards.py index 6407099..63eeaae 100644 --- a/guardrails_api/blueprints/guards.py +++ b/guardrails_api/blueprints/guards.py @@ -291,7 +291,6 @@ async def validate_streamer(guard_iter): ) ) fragment = json.dumps(fragment_dict) - print('yileding async', fragment) yield f"{fragment}\n" else: for validation_output, result in guard_iter: @@ -341,7 +340,7 @@ async def validate_streamer(guard_iter): cache_key = f"{guard.name}-{final_validation_output.call_id}" cache_client.set(cache_key, serialized_history, 300) yield f"{final_output_json}\n" - + # apropos of https://stackoverflow.com/questions/73949570/using-stream-with-context-as-async def iter_over_async(ait, loop): ait = ait.__aiter__() async def get_next(): From f34345c2ba51b356af3be582858b74c33b845af6 Mon Sep 17 00:00:00 2001 From: Nicholas Chen Date: Mon, 1 Jul 2024 11:58:28 -0400 Subject: [PATCH 3/5] format --- guardrails_api/blueprints/guards.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/guardrails_api/blueprints/guards.py b/guardrails_api/blueprints/guards.py index 63eeaae..0903295 100644 --- a/guardrails_api/blueprints/guards.py +++ b/guardrails_api/blueprints/guards.py @@ -344,11 +344,15 @@ async def validate_streamer(guard_iter): def iter_over_async(ait, loop): ait = ait.__aiter__() async def get_next(): - try: obj = await ait.__anext__(); return False, obj - except StopAsyncIteration: return True, None + try: + obj = await ait.__anext__() + return False, obj + except StopAsyncIteration: + return True, None while True: done, obj = loop.run_until_complete(get_next()) - if done: break + if done: + break yield obj loop = asyncio.new_event_loop() From 51c1e31bbe7ffb851d124f41a142be9089751b28 Mon Sep 17 00:00:00 2001 From: Nicholas Chen Date: Mon, 1 Jul 2024 19:06:48 -0400 Subject: [PATCH 4/5] fix bug that tests caught --- guardrails_api/blueprints/guards.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/guardrails_api/blueprints/guards.py b/guardrails_api/blueprints/guards.py index 0903295..7d3c099 100644 --- a/guardrails_api/blueprints/guards.py +++ b/guardrails_api/blueprints/guards.py @@ -172,7 +172,6 @@ def validate(guard_name: str): ) decoded_guard_name = unquote_plus(guard_name) guard_struct = guard_client.get_guard(decoded_guard_name) - llm_output = payload.pop("llmOutput", None) num_reasks = payload.pop("numReasks", None) prompt_params = payload.pop("promptParams", {}) @@ -215,7 +214,7 @@ def validate(guard_name: str): elif is_async: guard:Guard = AsyncGuard.from_dict(guard_struct.to_dict()) - elif num_reasks and num_reasks > 1: + if llm_api is None and num_reasks and num_reasks > 1: raise HttpError( status=400, message="BadRequest", From 56b050fabbbea1c93cb1a85a79f84616b016ff48 Mon Sep 17 00:00:00 2001 From: Nicholas Chen Date: Wed, 3 Jul 2024 17:35:05 -0400 Subject: [PATCH 5/5] split async and sync into separate methods --- guardrails_api/blueprints/guards.py | 160 +++++++++++++++++----------- 1 file changed, 96 insertions(+), 64 deletions(-) diff --git a/guardrails_api/blueprints/guards.py b/guardrails_api/blueprints/guards.py index 7d3c099..1485f07 100644 --- a/guardrails_api/blueprints/guards.py +++ b/guardrails_api/blueprints/guards.py @@ -239,73 +239,103 @@ def validate(guard_name: str): ) else: if stream: - - async def guard_streamer(): - if is_async: - guard_stream = await guard( - llm_api=llm_api, - prompt_params=prompt_params, - num_reasks=num_reasks, - stream=stream, - *args, - **payload, + def guard_streamer(): + guard_stream = guard( + llm_api=llm_api, + prompt_params=prompt_params, + num_reasks=num_reasks, + stream=stream, + *args, + **payload, + ) + for result in guard_stream: + # TODO: Just make this a ValidationOutcome with history + validation_output: ValidationOutcome = ( + ValidationOutcome.from_guard_history(guard.history.last) ) - else: - guard_stream = guard( - llm_api=llm_api, - prompt_params=prompt_params, - num_reasks=num_reasks, - stream=stream, - *args, - **payload, + yield validation_output, cast(ValidationOutcome, result) + + async def async_guard_streamer(): + guard_stream = await guard( + llm_api=llm_api, + prompt_params=prompt_params, + num_reasks=num_reasks, + stream=stream, + *args, + **payload, + ) + async for result in guard_stream: + validation_output: ValidationOutcome = ( + ValidationOutcome.from_guard_history(guard.history.last) ) - if is_async: - async for result in guard_stream: - validation_output: ValidationOutcome = ( - ValidationOutcome.from_guard_history(guard.history.last) - ) - yield validation_output, cast(ValidationOutcome, result) - else: - for result in guard_stream: - # TODO: Just make this a ValidationOutcome with history - validation_output: ValidationOutcome = ( - ValidationOutcome.from_guard_history(guard.history.last) + yield validation_output, cast(ValidationOutcome, result) + + def validate_streamer(guard_iter): + next_result = None + for validation_output, result in guard_iter: + next_result = result + # next_validation_output = validation_output + fragment_dict = result.to_dict() + fragment_dict["error_spans"] = list( + map( + lambda x: json.dumps( + {"start": x.start, "end": x.end, "reason": x.reason} + ), + guard.error_spans_in_output(), ) - yield validation_output, cast(ValidationOutcome, result) + ) + fragment = json.dumps(fragment_dict) + yield f"{fragment}\n" + call = guard.history.last + final_validation_output: ValidationOutcome = ValidationOutcome( + callId=call.id, + validation_passed=next_result.validation_passed, + validated_output=next_result.validated_output, + history=guard.history, + raw_llm_output=next_result.raw_llm_output, + ) + # I don't know if these are actually making it to OpenSearch + # because the span may be ended already + # collect_telemetry( + # guard=guard, + # validate_span=validate_span, + # validation_output=next_validation_output, + # prompt_params=prompt_params, + # result=next_result + # ) + final_output_dict = final_validation_output.to_dict() + final_output_dict["error_spans"] = list( + map( + lambda x: json.dumps( + {"start": x.start, "end": x.end, "reason": x.reason} + ), + guard.error_spans_in_output(), + ) + ) + final_output_json = json.dumps(final_output_dict) + + serialized_history = [call.to_dict() for call in guard.history] + cache_key = f"{guard.name}-{final_validation_output.call_id}" + cache_client.set(cache_key, serialized_history, 300) + yield f"{final_output_json}\n" - async def validate_streamer(guard_iter): + async def async_validate_streamer(guard_iter): next_result = None # next_validation_output = None - if is_async: - async for validation_output, result in guard_iter: - next_result = result - # next_validation_output = validation_output - fragment_dict = result.to_dict() - fragment_dict["error_spans"] = list( - map( - lambda x: json.dumps( - {"start": x.start, "end": x.end, "reason": x.reason} - ), - guard.error_spans_in_output(), - ) - ) - fragment = json.dumps(fragment_dict) - yield f"{fragment}\n" - else: - for validation_output, result in guard_iter: - next_result = result - # next_validation_output = validation_output - fragment_dict = result.to_dict() - fragment_dict["error_spans"] = list( - map( - lambda x: json.dumps( - {"start": x.start, "end": x.end, "reason": x.reason} - ), - guard.error_spans_in_output(), - ) + async for validation_output, result in guard_iter: + next_result = result + # next_validation_output = validation_output + fragment_dict = result.to_dict() + fragment_dict["error_spans"] = list( + map( + lambda x: json.dumps( + {"start": x.start, "end": x.end, "reason": x.reason} + ), + guard.error_spans_in_output(), ) - fragment = json.dumps(fragment_dict) - yield f"{fragment}\n" + ) + fragment = json.dumps(fragment_dict) + yield f"{fragment}\n" call = guard.history.last final_validation_output: ValidationOutcome = ValidationOutcome( @@ -353,10 +383,12 @@ async def get_next(): if done: break yield obj - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - iter = iter_over_async(validate_streamer(guard_streamer()), loop) + if is_async: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + iter = iter_over_async(async_validate_streamer(async_guard_streamer()), loop) + else: + iter = validate_streamer(guard_streamer()) return Response( stream_with_context(iter), content_type="application/json",