
Commit

split async and sync into separate methods
nichwch committed Jul 3, 2024
1 parent 51c1e31 commit 56b050f
Showing 1 changed file with 96 additions and 64 deletions.
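
The commit replaces generators that branched on `is_async` internally with separate sync and async generators, chosen once at dispatch time. Below is a minimal, self-contained sketch of that shape, not the handler itself: `produce`, `async_produce`, `collect`, and the literal chunks are hypothetical stand-ins, and `is_async` here is just a plain flag standing in for the route's own check.

import asyncio

# Hypothetical stand-ins for the handler's guard streamers; `is_async` mirrors
# the flag the route uses to pick between the two code paths.
is_async = False

def produce():
    # Sync path: a plain generator that can be consumed directly.
    for chunk in ("a", "b", "c"):
        yield chunk

async def async_produce():
    # Async path: an async generator that needs an event loop to drive it.
    for chunk in ("a", "b", "c"):
        yield chunk

async def collect():
    return [chunk async for chunk in async_produce()]

items = asyncio.run(collect()) if is_async else list(produce())
print(items)  # ['a', 'b', 'c'] on either path

In the real route the async path is not collected eagerly like this; it is bridged into a sync generator (see the iter_over_async sketch after the diff) so the response can still stream.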
160 changes: 96 additions & 64 deletions guardrails_api/blueprints/guards.py
@@ -239,73 +239,103 @@ def validate(guard_name: str):
         )
     else:
         if stream:
 
-            async def guard_streamer():
-                if is_async:
-                    guard_stream = await guard(
-                        llm_api=llm_api,
-                        prompt_params=prompt_params,
-                        num_reasks=num_reasks,
-                        stream=stream,
-                        *args,
-                        **payload,
-                    )
-                    async for result in guard_stream:
-                        validation_output: ValidationOutcome = (
-                            ValidationOutcome.from_guard_history(guard.history.last)
-                        )
-                        yield validation_output, cast(ValidationOutcome, result)
-                else:
-                    guard_stream = guard(
-                        llm_api=llm_api,
-                        prompt_params=prompt_params,
-                        num_reasks=num_reasks,
-                        stream=stream,
-                        *args,
-                        **payload,
-                    )
-                    for result in guard_stream:
-                        # TODO: Just make this a ValidationOutcome with history
-                        validation_output: ValidationOutcome = (
-                            ValidationOutcome.from_guard_history(guard.history.last)
-                        )
-                        yield validation_output, cast(ValidationOutcome, result)
-
-            async def validate_streamer(guard_iter):
-                next_result = None
-                # next_validation_output = None
-                if is_async:
-                    async for validation_output, result in guard_iter:
-                        next_result = result
-                        # next_validation_output = validation_output
-                        fragment_dict = result.to_dict()
-                        fragment_dict["error_spans"] = list(
-                            map(
-                                lambda x: json.dumps(
-                                    {"start": x.start, "end": x.end, "reason": x.reason}
-                                ),
-                                guard.error_spans_in_output(),
-                            )
-                        )
-                        fragment = json.dumps(fragment_dict)
-                        yield f"{fragment}\n"
-                else:
-                    for validation_output, result in guard_iter:
-                        next_result = result
-                        # next_validation_output = validation_output
-                        fragment_dict = result.to_dict()
-                        fragment_dict["error_spans"] = list(
-                            map(
-                                lambda x: json.dumps(
-                                    {"start": x.start, "end": x.end, "reason": x.reason}
-                                ),
-                                guard.error_spans_in_output(),
-                            )
-                        )
-                        fragment = json.dumps(fragment_dict)
-                        yield f"{fragment}\n"
+            def guard_streamer():
+                guard_stream = guard(
+                    llm_api=llm_api,
+                    prompt_params=prompt_params,
+                    num_reasks=num_reasks,
+                    stream=stream,
+                    *args,
+                    **payload,
+                )
+                for result in guard_stream:
+                    # TODO: Just make this a ValidationOutcome with history
+                    validation_output: ValidationOutcome = (
+                        ValidationOutcome.from_guard_history(guard.history.last)
+                    )
+                    yield validation_output, cast(ValidationOutcome, result)
+
+            async def async_guard_streamer():
+                guard_stream = await guard(
+                    llm_api=llm_api,
+                    prompt_params=prompt_params,
+                    num_reasks=num_reasks,
+                    stream=stream,
+                    *args,
+                    **payload,
+                )
+                async for result in guard_stream:
+                    validation_output: ValidationOutcome = (
+                        ValidationOutcome.from_guard_history(guard.history.last)
+                    )
+                    yield validation_output, cast(ValidationOutcome, result)
+
+            def validate_streamer(guard_iter):
+                next_result = None
+                for validation_output, result in guard_iter:
+                    next_result = result
+                    # next_validation_output = validation_output
+                    fragment_dict = result.to_dict()
+                    fragment_dict["error_spans"] = list(
+                        map(
+                            lambda x: json.dumps(
+                                {"start": x.start, "end": x.end, "reason": x.reason}
+                            ),
+                            guard.error_spans_in_output(),
+                        )
+                    )
+                    fragment = json.dumps(fragment_dict)
+                    yield f"{fragment}\n"
+                call = guard.history.last
+                final_validation_output: ValidationOutcome = ValidationOutcome(
+                    callId=call.id,
+                    validation_passed=next_result.validation_passed,
+                    validated_output=next_result.validated_output,
+                    history=guard.history,
+                    raw_llm_output=next_result.raw_llm_output,
+                )
+                # I don't know if these are actually making it to OpenSearch
+                # because the span may be ended already
+                # collect_telemetry(
+                #     guard=guard,
+                #     validate_span=validate_span,
+                #     validation_output=next_validation_output,
+                #     prompt_params=prompt_params,
+                #     result=next_result
+                # )
+                final_output_dict = final_validation_output.to_dict()
+                final_output_dict["error_spans"] = list(
+                    map(
+                        lambda x: json.dumps(
+                            {"start": x.start, "end": x.end, "reason": x.reason}
+                        ),
+                        guard.error_spans_in_output(),
+                    )
+                )
+                final_output_json = json.dumps(final_output_dict)
+
+                serialized_history = [call.to_dict() for call in guard.history]
+                cache_key = f"{guard.name}-{final_validation_output.call_id}"
+                cache_client.set(cache_key, serialized_history, 300)
+                yield f"{final_output_json}\n"
+
+            async def async_validate_streamer(guard_iter):
+                next_result = None
+                # next_validation_output = None
+                async for validation_output, result in guard_iter:
+                    next_result = result
+                    # next_validation_output = validation_output
+                    fragment_dict = result.to_dict()
+                    fragment_dict["error_spans"] = list(
+                        map(
+                            lambda x: json.dumps(
+                                {"start": x.start, "end": x.end, "reason": x.reason}
+                            ),
+                            guard.error_spans_in_output(),
+                        )
+                    )
+                    fragment = json.dumps(fragment_dict)
+                    yield f"{fragment}\n"
 
                 call = guard.history.last
                 final_validation_output: ValidationOutcome = ValidationOutcome(
@@ -353,10 +383,12 @@ async def get_next():
                     if done:
                         break
                     yield obj
 
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-            iter = iter_over_async(validate_streamer(guard_streamer()), loop)
+            if is_async:
+                loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(loop)
+                iter = iter_over_async(async_validate_streamer(async_guard_streamer()), loop)
+            else:
+                iter = validate_streamer(guard_streamer())
             return Response(
                 stream_with_context(iter),
                 content_type="application/json",
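
The second hunk only shows the tail of iter_over_async, the helper that lets a synchronous Flask response drain an async generator. The sketch below reconstructs the overall pattern under the assumption that the collapsed body follows the usual __anext__/run_until_complete approach; `async_source` and its JSON fragments are hypothetical stand-ins for `async_validate_streamer(async_guard_streamer())`.

import asyncio
from typing import AsyncIterator, Iterator, TypeVar

T = TypeVar("T")

def iter_over_async(ait: AsyncIterator[T], loop: asyncio.AbstractEventLoop) -> Iterator[T]:
    """Pump an async iterator one item at a time on `loop` so a sync caller
    (e.g. Flask's stream_with_context) can consume it as a plain generator."""
    ait = ait.__aiter__()

    async def get_next():
        try:
            return False, await ait.__anext__()
        except StopAsyncIteration:
            return True, None

    while True:
        done, obj = loop.run_until_complete(get_next())
        if done:
            break
        yield obj

async def async_source():
    # Stand-in for async_validate_streamer(async_guard_streamer()).
    for fragment in ('{"n": 1}\n', '{"n": 2}\n'):
        yield fragment

if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    for line in iter_over_async(async_source(), loop):
        print(line, end="")  # emits each JSON fragment as it arrives

This mirrors the dispatch in the diff: the async pair is wrapped by iter_over_async on a fresh event loop, while the sync pair is handed to the Response as an ordinary generator.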
