From 13621450b6d2a689dcf5f0077457cfa67aa0e8df Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Thu, 24 Oct 2024 11:13:01 -0400 Subject: [PATCH] allow base_url to work with proxies, e.g. http://host/proxy/path/v1 fixes #90 --- .../langchain_nvidia_ai_endpoints/_common.py | 26 ++++++++++--- libs/ai-endpoints/pyproject.toml | 2 +- .../tests/unit_tests/test_base_url.py | 37 ++++++++++++++++++- 3 files changed, 58 insertions(+), 7 deletions(-) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py index 218a0ab4..daa47ebf 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py @@ -126,6 +126,17 @@ class _NVIDIAClient(BaseModel): @field_validator("base_url") def _validate_base_url(cls, v: str) -> str: + """ + validate the base_url. + + if the base_url is not a url, raise an error + + if the base_url does not end in /v1, e.g. /embeddings, /completions, /rankings, + or /reranking, emit a warning. old documentation told users to pass in the full + inference url, which is incorrect and prevents model listing from working. + + normalize base_url to end in /v1 + """ ## Making sure /v1 in added to the url if v is not None: parsed = urlparse(v) @@ -135,12 +146,17 @@ def _validate_base_url(cls, v: str) -> str: expected_format = "Expected format is: http://host:port" raise ValueError(f"Invalid base_url format. {expected_format} Got: {v}") - if v.strip("/").endswith( - ("/embeddings", "/completions", "/rankings", "/reranking") - ): - warnings.warn(f"Using {v}, ignoring the rest") + normalized_path = parsed.path.rstrip("/") + if not normalized_path.endswith("/v1"): + warnings.warn( + f"{v} does not end in /v1, you may " + "have inference and listing issues" + ) + normalized_path += "/v1" - v = urlunparse((parsed.scheme, parsed.netloc, "v1", None, None, None)) + v = urlunparse( + (parsed.scheme, parsed.netloc, normalized_path, None, None, None) + ) return v diff --git a/libs/ai-endpoints/pyproject.toml b/libs/ai-endpoints/pyproject.toml index 5fd51139..d7b060d6 100644 --- a/libs/ai-endpoints/pyproject.toml +++ b/libs/ai-endpoints/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-nvidia-ai-endpoints" -version = "0.3.1" +version = "0.3.2" description = "An integration package connecting NVIDIA AI Endpoints and LangChain" authors = [] readme = "README.md" diff --git a/libs/ai-endpoints/tests/unit_tests/test_base_url.py b/libs/ai-endpoints/tests/unit_tests/test_base_url.py index fb9c513a..8c108830 100644 --- a/libs/ai-endpoints/tests/unit_tests/test_base_url.py +++ b/libs/ai-endpoints/tests/unit_tests/test_base_url.py @@ -1,5 +1,6 @@ import os import re +import warnings from typing import Any import pytest @@ -98,6 +99,7 @@ def test_param_base_url_hosted(public_class: type, base_url: str) -> None: ], ) def test_param_base_url_not_hosted(public_class: type, base_url: str) -> None: + warnings.filterwarnings("ignore", r".*does not end in /v1.*") with no_env_var("NVIDIA_BASE_URL"): client = public_class(model="model1", base_url=base_url) assert not client._client.is_hosted @@ -119,9 +121,42 @@ def test_expect_warn(public_class: type, base_url: str) -> None: with pytest.warns(UserWarning) as record: public_class(model="model1", base_url=base_url) assert len(record) == 1 - assert "ignoring the rest" in str(record[0].message) + assert "does not end in /v1" in str(record[0].message) def test_default_hosted(public_class: type) -> None: x = public_class(api_key="BOGUS") assert x._client.is_hosted + + +@pytest.mark.parametrize( + "base_url", + [ + "http://host/path0/path1/path2/v1", + "http://host:123/path0/path1/path2/v1/", + ], +) +def test_proxy_base_url( + public_class: type, base_url: str, requests_mock: Mocker +) -> None: + with no_env_var("NVIDIA_BASE_URL"): + client = public_class(model="model1", base_url=base_url) + assert base_url.startswith(client.base_url) + + +@pytest.mark.parametrize( + "base_url", + [ + "http://host/path0/path1/path2/v1", + "http://host:123/path0/path1/path2/v1/", + ], +) +def test_proxy_base_url_models( + public_class: type, base_url: str, requests_mock: Mocker +) -> None: + with no_env_var("NVIDIA_BASE_URL"): + client = public_class(model="model1", base_url=base_url) + client.available_models + models_url = base_url.rstrip("/") + "/models" + assert requests_mock.last_request + assert requests_mock.last_request.url == models_url