From 56b050fabbbea1c93cb1a85a79f84616b016ff48 Mon Sep 17 00:00:00 2001 From: Nicholas Chen Date: Wed, 3 Jul 2024 17:35:05 -0400 Subject: [PATCH] split async and sync into separate methods --- guardrails_api/blueprints/guards.py | 160 +++++++++++++++++----------- 1 file changed, 96 insertions(+), 64 deletions(-) diff --git a/guardrails_api/blueprints/guards.py b/guardrails_api/blueprints/guards.py index 7d3c099..1485f07 100644 --- a/guardrails_api/blueprints/guards.py +++ b/guardrails_api/blueprints/guards.py @@ -239,73 +239,103 @@ def validate(guard_name: str): ) else: if stream: - - async def guard_streamer(): - if is_async: - guard_stream = await guard( - llm_api=llm_api, - prompt_params=prompt_params, - num_reasks=num_reasks, - stream=stream, - *args, - **payload, + def guard_streamer(): + guard_stream = guard( + llm_api=llm_api, + prompt_params=prompt_params, + num_reasks=num_reasks, + stream=stream, + *args, + **payload, + ) + for result in guard_stream: + # TODO: Just make this a ValidationOutcome with history + validation_output: ValidationOutcome = ( + ValidationOutcome.from_guard_history(guard.history.last) ) - else: - guard_stream = guard( - llm_api=llm_api, - prompt_params=prompt_params, - num_reasks=num_reasks, - stream=stream, - *args, - **payload, + yield validation_output, cast(ValidationOutcome, result) + + async def async_guard_streamer(): + guard_stream = await guard( + llm_api=llm_api, + prompt_params=prompt_params, + num_reasks=num_reasks, + stream=stream, + *args, + **payload, + ) + async for result in guard_stream: + validation_output: ValidationOutcome = ( + ValidationOutcome.from_guard_history(guard.history.last) ) - if is_async: - async for result in guard_stream: - validation_output: ValidationOutcome = ( - ValidationOutcome.from_guard_history(guard.history.last) - ) - yield validation_output, cast(ValidationOutcome, result) - else: - for result in guard_stream: - # TODO: Just make this a ValidationOutcome with history - validation_output: ValidationOutcome = ( - ValidationOutcome.from_guard_history(guard.history.last) + yield validation_output, cast(ValidationOutcome, result) + + def validate_streamer(guard_iter): + next_result = None + for validation_output, result in guard_iter: + next_result = result + # next_validation_output = validation_output + fragment_dict = result.to_dict() + fragment_dict["error_spans"] = list( + map( + lambda x: json.dumps( + {"start": x.start, "end": x.end, "reason": x.reason} + ), + guard.error_spans_in_output(), ) - yield validation_output, cast(ValidationOutcome, result) + ) + fragment = json.dumps(fragment_dict) + yield f"{fragment}\n" + call = guard.history.last + final_validation_output: ValidationOutcome = ValidationOutcome( + callId=call.id, + validation_passed=next_result.validation_passed, + validated_output=next_result.validated_output, + history=guard.history, + raw_llm_output=next_result.raw_llm_output, + ) + # I don't know if these are actually making it to OpenSearch + # because the span may be ended already + # collect_telemetry( + # guard=guard, + # validate_span=validate_span, + # validation_output=next_validation_output, + # prompt_params=prompt_params, + # result=next_result + # ) + final_output_dict = final_validation_output.to_dict() + final_output_dict["error_spans"] = list( + map( + lambda x: json.dumps( + {"start": x.start, "end": x.end, "reason": x.reason} + ), + guard.error_spans_in_output(), + ) + ) + final_output_json = json.dumps(final_output_dict) + + serialized_history = [call.to_dict() for call in guard.history] + cache_key = f"{guard.name}-{final_validation_output.call_id}" + cache_client.set(cache_key, serialized_history, 300) + yield f"{final_output_json}\n" - async def validate_streamer(guard_iter): + async def async_validate_streamer(guard_iter): next_result = None # next_validation_output = None - if is_async: - async for validation_output, result in guard_iter: - next_result = result - # next_validation_output = validation_output - fragment_dict = result.to_dict() - fragment_dict["error_spans"] = list( - map( - lambda x: json.dumps( - {"start": x.start, "end": x.end, "reason": x.reason} - ), - guard.error_spans_in_output(), - ) - ) - fragment = json.dumps(fragment_dict) - yield f"{fragment}\n" - else: - for validation_output, result in guard_iter: - next_result = result - # next_validation_output = validation_output - fragment_dict = result.to_dict() - fragment_dict["error_spans"] = list( - map( - lambda x: json.dumps( - {"start": x.start, "end": x.end, "reason": x.reason} - ), - guard.error_spans_in_output(), - ) + async for validation_output, result in guard_iter: + next_result = result + # next_validation_output = validation_output + fragment_dict = result.to_dict() + fragment_dict["error_spans"] = list( + map( + lambda x: json.dumps( + {"start": x.start, "end": x.end, "reason": x.reason} + ), + guard.error_spans_in_output(), ) - fragment = json.dumps(fragment_dict) - yield f"{fragment}\n" + ) + fragment = json.dumps(fragment_dict) + yield f"{fragment}\n" call = guard.history.last final_validation_output: ValidationOutcome = ValidationOutcome( @@ -353,10 +383,12 @@ async def get_next(): if done: break yield obj - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - iter = iter_over_async(validate_streamer(guard_streamer()), loop) + if is_async: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + iter = iter_over_async(async_validate_streamer(async_guard_streamer()), loop) + else: + iter = validate_streamer(guard_streamer()) return Response( stream_with_context(iter), content_type="application/json",