From 1c9b0efd496f973acbcbc0fc835b3b203c9be3e9 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Thu, 5 Sep 2024 07:27:50 -0400 Subject: [PATCH 1/3] fix missing auth header on polling requests --- .../langchain_nvidia_ai_endpoints/_common.py | 12 ++-- .../tests/unit_tests/test_202_polling.py | 57 +++++++++++++++++++ 2 files changed, 63 insertions(+), 6 deletions(-) create mode 100644 libs/ai-endpoints/tests/unit_tests/test_202_polling.py diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py index 79c5a54b..0d0893f7 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py @@ -372,9 +372,7 @@ def _wait(self, response: Response, session: requests.Session) -> Response: start_time = time.time() # note: the local NIM does not return a 202 status code # (per RL 22may2024 circa 24.05) - while ( - response.status_code == 202 - ): # todo: there are no tests that reach this point + while response.status_code == 202: time.sleep(self.interval) if (time.time() - start_time) > self.timeout: raise TimeoutError( @@ -385,10 +383,12 @@ def _wait(self, response: Response, session: requests.Session) -> Response: "NVCF-REQID" in response.headers ), "Received 202 response with no request id to follow" request_id = response.headers.get("NVCF-REQID") - # todo: this needs testing, missing auth header update + payload = { + "url": self.polling_url_tmpl.format(request_id=request_id), + "headers": self.headers_tmpl["call"], + } self.last_response = response = session.get( - self.polling_url_tmpl.format(request_id=request_id), - headers=self.headers_tmpl["call"], + **self.__add_authorization(payload) ) self._try_raise(response) return response diff --git a/libs/ai-endpoints/tests/unit_tests/test_202_polling.py b/libs/ai-endpoints/tests/unit_tests/test_202_polling.py new file mode 100644 index 00000000..18469e63 --- /dev/null +++ b/libs/ai-endpoints/tests/unit_tests/test_202_polling.py @@ -0,0 +1,57 @@ +import requests_mock +from langchain_core.messages import AIMessage + +from langchain_nvidia_ai_endpoints import ChatNVIDIA + + +def test_polling_auth_header( + requests_mock: requests_mock.Mocker, + mock_model: str, +) -> None: + infer_url = "https://integrate.api.nvidia.com/v1/chat/completions" + polling_url = "https://api.nvcf.nvidia.com/v2/nvcf/pexec/status/test-request-id" + + requests_mock.post( + infer_url, status_code=202, headers={"NVCF-REQID": "test-request-id"}, json={} + ) + + requests_mock.get( + polling_url, + status_code=200, + json={ + "id": "mock-id", + "created": 1234567890, + "object": "chat.completion", + "model": mock_model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "WORKED"}, + } + ], + }, + ) + + client = ChatNVIDIA(model=mock_model, api_key="BOGUS") + response = client.invoke("IGNORED") + + # expected behavior - + # - first a GET request to /v1/models to check the model exists + # - second a POST request to /v1/chat/completions + # - third a GET request to /v2/nvcf/pexec/status/test-request-id + # we want to check on the second and third requests + + assert len(requests_mock.request_history) == 3 + + infer_request = requests_mock.request_history[-2] + assert infer_request.method == "POST" + assert infer_request.url == infer_url + assert infer_request.headers["Authorization"] == "Bearer BOGUS" + + poll_request = requests_mock.request_history[-1] + assert poll_request.method == "GET" + assert poll_request.url == polling_url + assert poll_request.headers["Authorization"] == "Bearer BOGUS" + + assert isinstance(response, AIMessage) + assert response.content == "WORKED" From 95a4233087bb034c69fe767294fe70c55da9496a Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Thu, 5 Sep 2024 07:31:47 -0400 Subject: [PATCH 2/3] add warning for debugging purposes --- libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py index 0d0893f7..55aa88ba 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py @@ -383,6 +383,7 @@ def _wait(self, response: Response, session: requests.Session) -> Response: "NVCF-REQID" in response.headers ), "Received 202 response with no request id to follow" request_id = response.headers.get("NVCF-REQID") + warnings.warn(f"Polling for response: {request_id}") # todo: remove payload = { "url": self.polling_url_tmpl.format(request_id=request_id), "headers": self.headers_tmpl["call"], From 04bbfc321c5c698ade798acdac9ff1b982f55364 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Tue, 10 Sep 2024 10:48:32 -0400 Subject: [PATCH 3/3] clarify descriptions for timeout & interval, remove debug warning --- .../langchain_nvidia_ai_endpoints/_common.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py index 55aa88ba..bd82b6df 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py @@ -74,8 +74,16 @@ class Config: api_key: Optional[SecretStr] = Field(description="API Key for service of choice") ## Generation arguments - timeout: float = Field(60, ge=0, description="Timeout for waiting on response (s)") - interval: float = Field(0.02, ge=0, description="Interval for pulling response") + timeout: float = Field( + 60, + ge=0, + description="The minimum amount of time (in sec) to poll after a 202 response", + ) + interval: float = Field( + 0.02, + ge=0, + description="Interval (in sec) between polling attempts after a 202 response", + ) last_inputs: Optional[dict] = Field( description="Last inputs sent over to the server" ) @@ -383,7 +391,6 @@ def _wait(self, response: Response, session: requests.Session) -> Response: "NVCF-REQID" in response.headers ), "Received 202 response with no request id to follow" request_id = response.headers.get("NVCF-REQID") - warnings.warn(f"Polling for response: {request_id}") # todo: remove payload = { "url": self.polling_url_tmpl.format(request_id=request_id), "headers": self.headers_tmpl["call"],