Skip to content

Commit

Permalink
fix missing auth header on polling requests
Browse files Browse the repository at this point in the history
  • Loading branch information
mattf committed Sep 5, 2024
1 parent c546bc4 commit 72a1eac
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 2 deletions.
7 changes: 5 additions & 2 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,12 @@ def _wait(self, response: Response, session: requests.Session) -> Response:
), "Received 202 response with no request id to follow"
request_id = response.headers.get("NVCF-REQID")
# todo: this needs testing, missing auth header update
payload = {
"url": self.polling_url_tmpl.format(request_id=request_id),
"headers": self.headers_tmpl["call"],
}
self.last_response = response = session.get(
self.polling_url_tmpl.format(request_id=request_id),
headers=self.headers_tmpl["call"],
**self.__add_authorization(payload)
)
self._try_raise(response)
return response
Expand Down
57 changes: 57 additions & 0 deletions libs/ai-endpoints/tests/unit_tests/test_202_polling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import requests_mock
from langchain_core.messages import AIMessage

from langchain_nvidia_ai_endpoints import ChatNVIDIA


def test_polling_auth_header(
    requests_mock: requests_mock.Mocker,
    mock_model: str,
) -> None:
    """Check that the Authorization header is sent on the 202 polling request.

    The inference endpoint is mocked to answer 202 with an ``NVCF-REQID``
    header, and the matching status endpoint is mocked to answer 200 with a
    chat completion. The test then verifies that both the inference POST and
    the follow-up polling GET carry ``Authorization: Bearer BOGUS``, and that
    the polled completion is what the client returns.

    ``mock_model`` is used as the model name for the client and for the
    ``model`` field of the mocked completion payload.
    """
    completions_endpoint = "https://integrate.api.nvidia.com/v1/chat/completions"
    status_endpoint = (
        "https://api.nvcf.nvidia.com/v2/nvcf/pexec/status/test-request-id"
    )

    # Completion body served by the status endpoint once polling succeeds.
    completion_payload = {
        "id": "mock-id",
        "created": 1234567890,
        "object": "chat.completion",
        "model": mock_model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": "WORKED"},
            }
        ],
    }

    # The inference call is accepted (202) and hands back a request id to poll.
    requests_mock.post(
        completions_endpoint,
        status_code=202,
        headers={"NVCF-REQID": "test-request-id"},
        json={},
    )
    requests_mock.get(status_endpoint, status_code=200, json=completion_payload)

    llm = ChatNVIDIA(model=mock_model, api_key="BOGUS")
    reply = llm.invoke("IGNORED")

    # Expected traffic, in order:
    #   1. GET  /v1/models                               (model existence check)
    #   2. POST /v1/chat/completions                     (inference, answers 202)
    #   3. GET  /v2/nvcf/pexec/status/test-request-id    (polling)
    # Only the last two requests are inspected below.
    history = requests_mock.request_history
    assert len(history) == 3

    infer_call, poll_call = history[-2:]

    assert infer_call.method == "POST"
    assert infer_call.url == completions_endpoint
    assert infer_call.headers["Authorization"] == "Bearer BOGUS"

    assert poll_call.method == "GET"
    assert poll_call.url == status_endpoint
    assert poll_call.headers["Authorization"] == "Bearer BOGUS"

    assert isinstance(reply, AIMessage)
    assert reply.content == "WORKED"

0 comments on commit 72a1eac

Please sign in to comment.