Skip to content

Commit

Permalink
Merge pull request #99 from langchain-ai/mattf/fix-202-polling-auth
Browse files Browse the repository at this point in the history
fix missing auth header on polling requests
  • Loading branch information
mattf authored Sep 17, 2024
2 parents d14850d + 04bbfc3 commit e276623
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 8 deletions.
24 changes: 16 additions & 8 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,16 @@ class Config:
api_key: Optional[SecretStr] = Field(description="API Key for service of choice")

## Generation arguments
timeout: float = Field(60, ge=0, description="Timeout for waiting on response (s)")
interval: float = Field(0.02, ge=0, description="Interval for pulling response")
timeout: float = Field(
60,
ge=0,
description="The minimum amount of time (in sec) to poll after a 202 response",
)
interval: float = Field(
0.02,
ge=0,
description="Interval (in sec) between polling attempts after a 202 response",
)
last_inputs: Optional[dict] = Field(
description="Last inputs sent over to the server"
)
Expand Down Expand Up @@ -372,9 +380,7 @@ def _wait(self, response: Response, session: requests.Session) -> Response:
start_time = time.time()
# note: the local NIM does not return a 202 status code
# (per RL 22may2024 circa 24.05)
while (
response.status_code == 202
): # todo: there are no tests that reach this point
while response.status_code == 202:
time.sleep(self.interval)
if (time.time() - start_time) > self.timeout:
raise TimeoutError(
Expand All @@ -385,10 +391,12 @@ def _wait(self, response: Response, session: requests.Session) -> Response:
"NVCF-REQID" in response.headers
), "Received 202 response with no request id to follow"
request_id = response.headers.get("NVCF-REQID")
# todo: this needs testing, missing auth header update
payload = {
"url": self.polling_url_tmpl.format(request_id=request_id),
"headers": self.headers_tmpl["call"],
}
self.last_response = response = session.get(
self.polling_url_tmpl.format(request_id=request_id),
headers=self.headers_tmpl["call"],
**self.__add_authorization(payload)
)
self._try_raise(response)
return response
Expand Down
57 changes: 57 additions & 0 deletions libs/ai-endpoints/tests/unit_tests/test_202_polling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import requests_mock
from langchain_core.messages import AIMessage

from langchain_nvidia_ai_endpoints import ChatNVIDIA


def test_polling_auth_header(
requests_mock: requests_mock.Mocker,
mock_model: str,
) -> None:
infer_url = "https://integrate.api.nvidia.com/v1/chat/completions"
polling_url = "https://api.nvcf.nvidia.com/v2/nvcf/pexec/status/test-request-id"

requests_mock.post(
infer_url, status_code=202, headers={"NVCF-REQID": "test-request-id"}, json={}
)

requests_mock.get(
polling_url,
status_code=200,
json={
"id": "mock-id",
"created": 1234567890,
"object": "chat.completion",
"model": mock_model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "WORKED"},
}
],
},
)

client = ChatNVIDIA(model=mock_model, api_key="BOGUS")
response = client.invoke("IGNORED")

# expected behavior -
# - first a GET request to /v1/models to check the model exists
# - second a POST request to /v1/chat/completions
# - third a GET request to /v2/nvcf/pexec/status/test-request-id
# we want to check on the second and third requests

assert len(requests_mock.request_history) == 3

infer_request = requests_mock.request_history[-2]
assert infer_request.method == "POST"
assert infer_request.url == infer_url
assert infer_request.headers["Authorization"] == "Bearer BOGUS"

poll_request = requests_mock.request_history[-1]
assert poll_request.method == "GET"
assert poll_request.url == polling_url
assert poll_request.headers["Authorization"] == "Bearer BOGUS"

assert isinstance(response, AIMessage)
assert response.content == "WORKED"

0 comments on commit e276623

Please sign in to comment.