Skip to content

Commit

Permalink
update openai chat completion to be fully async
Browse files Browse the repository at this point in the history
  • Loading branch information
dtam committed Oct 4, 2024
1 parent cd69b1e commit 7bd2bc6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
8 changes: 4 additions & 4 deletions guardrails_api/api/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ async def openai_v1_chat_completions(guard_name: str, request: Request):
)

guard = (
Guard.from_dict(guard_struct.to_dict())
AsyncGuard.from_dict(guard_struct.to_dict())
if not isinstance(guard_struct, Guard)
else guard_struct
)
Expand All @@ -125,7 +125,7 @@ async def openai_v1_chat_completions(guard_name: str, request: Request):
)

if not stream:
validation_outcome: ValidationOutcome = guard(num_reasks=0, **payload)
validation_outcome: ValidationOutcome = await guard(num_reasks=0, **payload)
llm_response = guard.history.last.iterations.last.outputs.llm_response_info
result = outcome_to_chat_completion(
validation_outcome=validation_outcome,
Expand All @@ -136,8 +136,8 @@ async def openai_v1_chat_completions(guard_name: str, request: Request):
else:

async def openai_streamer():
guard_stream = guard(num_reasks=0, **payload)
for result in guard_stream:
guard_stream = await guard(num_reasks=0, **payload)
async for result in guard_stream:
chunk = json.dumps(
outcome_to_stream_response(validation_outcome=result)
)
Expand Down
7 changes: 6 additions & 1 deletion tests/api/test_guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from tests.mocks.mock_guard_client import MockGuardStruct
from guardrails_api.api.guards import router as guards_router


import asyncio

# TODO: Should we mock this somehow?
# Right now it's just empty, but it technically does a file read
register_config()
Expand Down Expand Up @@ -344,7 +347,9 @@ def test_openai_v1_chat_completions__call(mocker):
)

mock___call__ = mocker.patch.object(MockGuardStruct, "__call__")
mock___call__.return_value = mock_outcome
future = asyncio.Future()
future.set_result(mock_outcome)
mock___call__.return_value = future

mock_from_dict = mocker.patch("guardrails_api.api.guards.Guard.from_dict")
mock_from_dict.return_value = mock_guard
Expand Down

0 comments on commit 7bd2bc6

Please sign in to comment.