diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py index 5393e160..2bde648a 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py @@ -17,7 +17,7 @@ Tuple, Union, ) -from urllib.parse import urlparse, urlunparse +from urllib.parse import urlparse import requests from langchain_core.pydantic_v1 import ( @@ -124,7 +124,7 @@ def _preprocess_args(cls, values: Dict[str, Any]) -> Dict[str, Any]: ## Making sure /v1 in added to the url, followed by infer_path if "base_url" in values: - base_url = values["base_url"] + base_url = values["base_url"].strip("/") parsed = urlparse(base_url) expected_format = "Expected format is: http://host:port" @@ -133,24 +133,11 @@ def _preprocess_args(cls, values: Dict[str, Any]) -> Dict[str, Any]: f"Invalid base_url format. {expected_format} Got: {base_url}" ) - if parsed.path: - normalized_path = parsed.path.strip("/") - if normalized_path == "v1": - pass - elif normalized_path in [ - "v1/embeddings", - "v1/completions", - "v1/rankings", - ]: - warnings.warn(f"Using {base_url}, ignoring the rest") - else: - raise ValueError( - f"Base URL path is not recognized. {expected_format}" - ) + if base_url.endswith( + ("/embeddings", "/completions", "/rankings", "/reranking") + ): + warnings.warn(f"Using {base_url}, ignoring the rest") - base_url = urlunparse( - (parsed.scheme, parsed.netloc, "v1", None, None, None) - ) values["base_url"] = base_url values["infer_path"] = values["infer_path"].format(base_url=base_url) diff --git a/libs/ai-endpoints/tests/unit_tests/test_base_url.py b/libs/ai-endpoints/tests/unit_tests/test_base_url.py index b5c3a8ef..0baaca6c 100644 --- a/libs/ai-endpoints/tests/unit_tests/test_base_url.py +++ b/libs/ai-endpoints/tests/unit_tests/test_base_url.py @@ -94,6 +94,8 @@ def test_param_base_url_hosted(public_class: type, base_url: str) -> None: "https://localhost", "http://localhost:8888", "http://0.0.0.0:8888/v1", + "http://0.0.0.0:8888/v1/", + "http://blah/some/other/path/v1", ], ) def test_param_base_url_not_hosted(public_class: type, base_url: str) -> None: @@ -107,18 +109,9 @@ def test_param_base_url_not_hosted(public_class: type, base_url: str) -> None: [ "http://localhost:8888/embeddings", "http://0.0.0.0:8888/rankings", + "http://localhost:8888/embeddings/", + "http://0.0.0.0:8888/rankings/", "http://localhost:8888/chat/completions", - ], -) -def test_expect_error(public_class: type, base_url: str) -> None: - with pytest.raises(ValueError) as e: - public_class(model="model1", base_url=base_url) - assert "Expected format is" in str(e.value) - - -@pytest.mark.parametrize( - "base_url", - [ "http://localhost:8080/v1/embeddings", "http://0.0.0.0:8888/v1/rankings", ],