diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
index 019cdb01..ce672690 100644
--- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
+++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
@@ -74,8 +74,16 @@ class Config:
     api_key: Optional[SecretStr] = Field(description="API Key for service of choice")
 
     ## Generation arguments
-    timeout: float = Field(60, ge=0, description="Timeout for waiting on response (s)")
-    interval: float = Field(0.02, ge=0, description="Interval for pulling response")
+    timeout: float = Field(
+        60,
+        ge=0,
+        description="The minimum amount of time (in sec) to poll after a 202 response",
+    )
+    interval: float = Field(
+        0.02,
+        ge=0,
+        description="Interval (in sec) between polling attempts after a 202 response",
+    )
     last_inputs: Optional[dict] = Field(
         description="Last inputs sent over to the server"
     )
@@ -372,9 +380,7 @@ def _wait(self, response: Response, session: requests.Session) -> Response:
         start_time = time.time()
         # note: the local NIM does not return a 202 status code
         #  (per RL 22may2024 circa 24.05)
-        while (
-            response.status_code == 202
-        ):  # todo: there are no tests that reach this point
+        while response.status_code == 202:
             time.sleep(self.interval)
             if (time.time() - start_time) > self.timeout:
                 raise TimeoutError(
@@ -385,10 +391,12 @@ def _wait(self, response: Response, session: requests.Session) -> Response:
                 "NVCF-REQID" in response.headers
             ), "Received 202 response with no request id to follow"
             request_id = response.headers.get("NVCF-REQID")
-            # todo: this needs testing, missing auth header update
+            payload = {
+                "url": self.polling_url_tmpl.format(request_id=request_id),
+                "headers": self.headers_tmpl["call"],
+            }
             self.last_response = response = session.get(
-                self.polling_url_tmpl.format(request_id=request_id),
-                headers=self.headers_tmpl["call"],
+                **self.__add_authorization(payload)
             )
         self._try_raise(response)
         return response
diff --git a/libs/ai-endpoints/tests/unit_tests/test_202_polling.py b/libs/ai-endpoints/tests/unit_tests/test_202_polling.py
new file mode 100644
index 00000000..18469e63
--- /dev/null
+++ b/libs/ai-endpoints/tests/unit_tests/test_202_polling.py
@@ -0,0 +1,57 @@
+import requests_mock
+from langchain_core.messages import AIMessage
+
+from langchain_nvidia_ai_endpoints import ChatNVIDIA
+
+
+def test_polling_auth_header(
+    requests_mock: requests_mock.Mocker,
+    mock_model: str,
+) -> None:
+    infer_url = "https://integrate.api.nvidia.com/v1/chat/completions"
+    polling_url = "https://api.nvcf.nvidia.com/v2/nvcf/pexec/status/test-request-id"
+
+    requests_mock.post(
+        infer_url, status_code=202, headers={"NVCF-REQID": "test-request-id"}, json={}
+    )
+
+    requests_mock.get(
+        polling_url,
+        status_code=200,
+        json={
+            "id": "mock-id",
+            "created": 1234567890,
+            "object": "chat.completion",
+            "model": mock_model,
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {"role": "assistant", "content": "WORKED"},
+                }
+            ],
+        },
+    )
+
+    client = ChatNVIDIA(model=mock_model, api_key="BOGUS")
+    response = client.invoke("IGNORED")
+
+    # expected behavior -
+    #  - first a GET request to /v1/models to check the model exists
+    #  - second a POST request to /v1/chat/completions
+    #  - third a GET request to /v2/nvcf/pexec/status/test-request-id
+    # we want to check on the second and third requests
+
+    assert len(requests_mock.request_history) == 3
+
+    infer_request = requests_mock.request_history[-2]
+    assert infer_request.method == "POST"
+    assert infer_request.url == infer_url
+    assert infer_request.headers["Authorization"] == "Bearer BOGUS"
+
+    poll_request = requests_mock.request_history[-1]
+    assert poll_request.method == "GET"
+    assert poll_request.url == polling_url
+    assert poll_request.headers["Authorization"] == "Bearer BOGUS"
+
+    assert isinstance(response, AIMessage)
+    assert response.content == "WORKED"
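
The new test relies on two pytest fixtures that are not part of this diff: the `requests_mock` fixture provided by the requests-mock pytest plugin, and a `mock_model` fixture that must also stub the initial `GET /v1/models` lookup (otherwise the request history could not contain three entries). A minimal conftest.py sketch under those assumptions follows; the model id, listing payload shape, and fixture wiring are illustrative guesses, not the package's actual test fixtures.

# conftest.py (sketch) -- hypothetical fixtures assumed by test_202_polling.py;
# the real fixtures live in the package's test suite and may differ.
import pytest
import requests_mock as rm


@pytest.fixture
def mock_model(requests_mock: rm.Mocker) -> str:
    model = "mock-model"  # placeholder model id, not a real NVIDIA model name
    # Stub the model-existence lookup so ChatNVIDIA(model=...) validates;
    # an OpenAI-style model listing payload is assumed here.
    requests_mock.get(
        "https://integrate.api.nvidia.com/v1/models",
        json={"data": [{"id": model, "object": "model", "owned_by": "test"}]},
    )
    return model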