From 8ea0de755df3e2ab8946b5e1cb4638c96b0e6eda Mon Sep 17 00:00:00 2001 From: raspawar Date: Mon, 26 Aug 2024 19:21:54 +0530 Subject: [PATCH 1/2] raise a warning for known endpoints --- .../langchain_nvidia_ai_endpoints/_common.py | 23 ++++--------------- .../tests/unit_tests/test_base_url.py | 14 +++-------- 2 files changed, 7 insertions(+), 30 deletions(-) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py index 5393e160..687f2e13 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py @@ -17,7 +17,7 @@ Tuple, Union, ) -from urllib.parse import urlparse, urlunparse +from urllib.parse import urlparse import requests from langchain_core.pydantic_v1 import ( @@ -124,7 +124,7 @@ def _preprocess_args(cls, values: Dict[str, Any]) -> Dict[str, Any]: ## Making sure /v1 in added to the url, followed by infer_path if "base_url" in values: - base_url = values["base_url"] + base_url = values["base_url"].strip("/") parsed = urlparse(base_url) expected_format = "Expected format is: http://host:port" @@ -133,24 +133,9 @@ def _preprocess_args(cls, values: Dict[str, Any]) -> Dict[str, Any]: f"Invalid base_url format. {expected_format} Got: {base_url}" ) - if parsed.path: - normalized_path = parsed.path.strip("/") - if normalized_path == "v1": - pass - elif normalized_path in [ - "v1/embeddings", - "v1/completions", - "v1/rankings", - ]: - warnings.warn(f"Using {base_url}, ignoring the rest") - else: - raise ValueError( - f"Base URL path is not recognized. {expected_format}" - ) + if base_url.endswith(("/embeddings", "/completions", "/rankings")): + warnings.warn(f"Using {base_url}, ignoring the rest") - base_url = urlunparse( - (parsed.scheme, parsed.netloc, "v1", None, None, None) - ) values["base_url"] = base_url values["infer_path"] = values["infer_path"].format(base_url=base_url) diff --git a/libs/ai-endpoints/tests/unit_tests/test_base_url.py b/libs/ai-endpoints/tests/unit_tests/test_base_url.py index b5c3a8ef..e8e393b2 100644 --- a/libs/ai-endpoints/tests/unit_tests/test_base_url.py +++ b/libs/ai-endpoints/tests/unit_tests/test_base_url.py @@ -94,6 +94,7 @@ def test_param_base_url_hosted(public_class: type, base_url: str) -> None: "https://localhost", "http://localhost:8888", "http://0.0.0.0:8888/v1", + "http://0.0.0.0:8888/v1/", ], ) def test_param_base_url_not_hosted(public_class: type, base_url: str) -> None: @@ -107,18 +108,9 @@ def test_param_base_url_not_hosted(public_class: type, base_url: str) -> None: [ "http://localhost:8888/embeddings", "http://0.0.0.0:8888/rankings", + "http://localhost:8888/embeddings/", + "http://0.0.0.0:8888/rankings/", "http://localhost:8888/chat/completions", - ], -) -def test_expect_error(public_class: type, base_url: str) -> None: - with pytest.raises(ValueError) as e: - public_class(model="model1", base_url=base_url) - assert "Expected format is" in str(e.value) - - -@pytest.mark.parametrize( - "base_url", - [ "http://localhost:8080/v1/embeddings", "http://0.0.0.0:8888/v1/rankings", ], From d315b112065909550a96f4b065752759f8652bfa Mon Sep 17 00:00:00 2001 From: raspawar Date: Tue, 27 Aug 2024 11:35:55 +0530 Subject: [PATCH 2/2] add reranker test case and url --- libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py | 4 +++- libs/ai-endpoints/tests/unit_tests/test_base_url.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py index 687f2e13..2bde648a 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py @@ -133,7 +133,9 @@ def _preprocess_args(cls, values: Dict[str, Any]) -> Dict[str, Any]: f"Invalid base_url format. {expected_format} Got: {base_url}" ) - if base_url.endswith(("/embeddings", "/completions", "/rankings")): + if base_url.endswith( + ("/embeddings", "/completions", "/rankings", "/reranking") + ): warnings.warn(f"Using {base_url}, ignoring the rest") values["base_url"] = base_url diff --git a/libs/ai-endpoints/tests/unit_tests/test_base_url.py b/libs/ai-endpoints/tests/unit_tests/test_base_url.py index e8e393b2..0baaca6c 100644 --- a/libs/ai-endpoints/tests/unit_tests/test_base_url.py +++ b/libs/ai-endpoints/tests/unit_tests/test_base_url.py @@ -95,6 +95,7 @@ def test_param_base_url_hosted(public_class: type, base_url: str) -> None: "http://localhost:8888", "http://0.0.0.0:8888/v1", "http://0.0.0.0:8888/v1/", + "http://blah/some/other/path/v1", ], ) def test_param_base_url_not_hosted(public_class: type, base_url: str) -> None: