Merge pull request #79 from guardrails-ai/fix_mismatch_with_await
fix mismatch and rev required version
zsimjee authored Oct 4, 2024
2 parents a1e71a9 + 7bd2bc6 commit 94617a2
Showing 3 changed files with 22 additions and 21 deletions.
34 changes: 15 additions & 19 deletions guardrails_api/api/guards.py
@@ -114,7 +114,7 @@ async def openai_v1_chat_completions(guard_name: str, request: Request):
         )
 
     guard = (
-        Guard.from_dict(guard_struct.to_dict())
+        AsyncGuard.from_dict(guard_struct.to_dict())
         if not isinstance(guard_struct, Guard)
         else guard_struct
     )
@@ -125,7 +125,7 @@ async def openai_v1_chat_completions(guard_name: str, request: Request):
         )
 
     if not stream:
-        validation_outcome: ValidationOutcome = guard(num_reasks=0, **payload)
+        validation_outcome: ValidationOutcome = await guard(num_reasks=0, **payload)
         llm_response = guard.history.last.iterations.last.outputs.llm_response_info
         result = outcome_to_chat_completion(
             validation_outcome=validation_outcome,
@@ -136,8 +136,8 @@ async def openai_v1_chat_completions(guard_name: str, request: Request):
     else:
 
         async def openai_streamer():
-            guard_stream = guard(num_reasks=0, **payload)
-            for result in guard_stream:
+            guard_stream = await guard(num_reasks=0, **payload)
+            async for result in guard_stream:
                 chunk = json.dumps(
                     outcome_to_stream_response(validation_outcome=result)
                 )
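The three hunks above move the completions endpoint onto an async guard: the guard is built with AsyncGuard.from_dict, the non-streaming call is awaited directly, and the streaming call is awaited once to obtain an async iterator that is then consumed with async for. A minimal standalone sketch of that calling pattern (FakeAsyncGuard and the payload keys are illustrative stand-ins, not the project's API):

import asyncio

class FakeAsyncGuard:
    """Stand-in for an async guard whose __call__ must be awaited."""

    async def __call__(self, stream: bool = False, **payload):
        if not stream:
            return {"validated_output": "hello"}

        async def chunks():
            for piece in ("he", "llo"):
                yield {"validated_chunk": piece}

        return chunks()

async def main():
    guard = FakeAsyncGuard()

    # Non-streaming: awaiting the call yields a single outcome.
    outcome = await guard(num_reasks=0)
    print(outcome)

    # Streaming: awaiting the call yields an async iterator of partial outcomes.
    guard_stream = await guard(num_reasks=0, stream=True)
    async for result in guard_stream:
        print(result)

asyncio.run(main())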
@@ -253,22 +253,18 @@ async def validate_streamer(guard_iter):
             validate_streamer(guard_streamer()), media_type="application/json"
         )
     else:
-        if inspect.iscoroutinefunction(guard):
-            result: ValidationOutcome = await guard(
-                llm_api=llm_api,
-                prompt_params=prompt_params,
-                num_reasks=num_reasks,
-                *args,
-                **payload,
-            )
+        execution = guard(
+            llm_api=llm_api,
+            prompt_params=prompt_params,
+            num_reasks=num_reasks,
+            *args,
+            **payload,
+        )
+
+        if inspect.iscoroutine(execution):
+            result: ValidationOutcome = await execution
         else:
-            result: ValidationOutcome = guard(
-                llm_api=llm_api,
-                prompt_params=prompt_params,
-                num_reasks=num_reasks,
-                *args,
-                **payload,
-            )
+            result: ValidationOutcome = execution
 
         serialized_history = [call.to_dict() for call in guard.history]
         cache_key = f"{guard.name}-{result.call_id}"
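Instead of checking the guard object up front with inspect.iscoroutinefunction, the rewritten branch calls the guard once and awaits the result only when the call actually returned a coroutine, so the same endpoint handles both a sync Guard and an AsyncGuard. A small self-contained sketch of that dispatch pattern, with stand-in guard functions rather than the real classes:

import asyncio
import inspect

def sync_guard(**payload):
    # Stands in for a synchronous Guard.__call__: returns the outcome directly.
    return {"validated": True, **payload}

async def async_guard(**payload):
    # Stands in for AsyncGuard.__call__: the call itself returns a coroutine.
    return {"validated": True, **payload}

async def execute(guard, **payload):
    execution = guard(**payload)
    if inspect.iscoroutine(execution):
        # Async guard: the coroutine still has to run, so await it.
        return await execution
    # Sync guard: the call already produced the outcome.
    return execution

async def main():
    print(await execute(sync_guard, num_reasks=0))
    print(await execute(async_guard, num_reasks=0))

asyncio.run(main())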
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -10,7 +10,7 @@ readme = "README.md"
 keywords = ["Guardrails", "Guardrails AI", "Guardrails API", "Guardrails API"]
 requires-python = ">= 3.8.1"
 dependencies = [
-    "guardrails-ai>=0.5.10",
+    "guardrails-ai>=0.5.12",
     "jsonschema>=4.22.0,<5",
     "referencing>=0.35.1,<1",
     "boto3>=1.34.115,<2",
7 changes: 6 additions & 1 deletion tests/api/test_guards.py
@@ -14,6 +14,9 @@
 from tests.mocks.mock_guard_client import MockGuardStruct
 from guardrails_api.api.guards import router as guards_router
 
+
+import asyncio
+
 # TODO: Should we mock this somehow?
 # Right now it's just empty, but it technically does a file read
 register_config()
@@ -344,7 +347,9 @@ def test_openai_v1_chat_completions__call(mocker):
     )
 
     mock___call__ = mocker.patch.object(MockGuardStruct, "__call__")
-    mock___call__.return_value = mock_outcome
+    future = asyncio.Future()
+    future.set_result(mock_outcome)
+    mock___call__.return_value = future
 
     mock_from_dict = mocker.patch("guardrails_api.api.guards.Guard.from_dict")
     mock_from_dict.return_value = mock_guard
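The test change is needed because the handler now awaits the patched __call__: returning the outcome directly would raise a TypeError along the lines of "object MagicMock can't be used in 'await' expression", whereas a Future resolves under await. A rough sketch of the same idea outside pytest-mock (the mock and outcome names are placeholders):

import asyncio
from unittest.mock import MagicMock

mock_outcome = {"validated_output": "Hello world!"}  # placeholder outcome

async def main():
    mock_call = MagicMock()

    # Wrap the canned result in a Future so `await mock_call(...)` resolves to it.
    future = asyncio.get_running_loop().create_future()
    future.set_result(mock_outcome)
    mock_call.return_value = future

    result = await mock_call(num_reasks=0)
    assert result is mock_outcome

asyncio.run(main())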
