Skip to content

Commit

Permalink
Merge pull request #93 from langchain-ai/raspawar/base_url_fix
Browse files Browse the repository at this point in the history
Fix: Raise a warning for known endpoints
  • Loading branch information
raspawar authored Aug 28, 2024
2 parents 603a636 + d315b11 commit 199438f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 30 deletions.
25 changes: 6 additions & 19 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Tuple,
Union,
)
from urllib.parse import urlparse, urlunparse
from urllib.parse import urlparse

import requests
from langchain_core.pydantic_v1 import (
Expand Down Expand Up @@ -124,7 +124,7 @@ def _preprocess_args(cls, values: Dict[str, Any]) -> Dict[str, Any]:

## Making sure /v1 in added to the url, followed by infer_path
if "base_url" in values:
base_url = values["base_url"]
base_url = values["base_url"].strip("/")
parsed = urlparse(base_url)
expected_format = "Expected format is: http://host:port"

Expand All @@ -133,24 +133,11 @@ def _preprocess_args(cls, values: Dict[str, Any]) -> Dict[str, Any]:
f"Invalid base_url format. {expected_format} Got: {base_url}"
)

if parsed.path:
normalized_path = parsed.path.strip("/")
if normalized_path == "v1":
pass
elif normalized_path in [
"v1/embeddings",
"v1/completions",
"v1/rankings",
]:
warnings.warn(f"Using {base_url}, ignoring the rest")
else:
raise ValueError(
f"Base URL path is not recognized. {expected_format}"
)
if base_url.endswith(
("/embeddings", "/completions", "/rankings", "/reranking")
):
warnings.warn(f"Using {base_url}, ignoring the rest")

base_url = urlunparse(
(parsed.scheme, parsed.netloc, "v1", None, None, None)
)
values["base_url"] = base_url
values["infer_path"] = values["infer_path"].format(base_url=base_url)

Expand Down
15 changes: 4 additions & 11 deletions libs/ai-endpoints/tests/unit_tests/test_base_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def test_param_base_url_hosted(public_class: type, base_url: str) -> None:
"https://localhost",
"http://localhost:8888",
"http://0.0.0.0:8888/v1",
"http://0.0.0.0:8888/v1/",
"http://blah/some/other/path/v1",
],
)
def test_param_base_url_not_hosted(public_class: type, base_url: str) -> None:
Expand All @@ -107,18 +109,9 @@ def test_param_base_url_not_hosted(public_class: type, base_url: str) -> None:
[
"http://localhost:8888/embeddings",
"http://0.0.0.0:8888/rankings",
"http://localhost:8888/embeddings/",
"http://0.0.0.0:8888/rankings/",
"http://localhost:8888/chat/completions",
],
)
def test_expect_error(public_class: type, base_url: str) -> None:
with pytest.raises(ValueError) as e:
public_class(model="model1", base_url=base_url)
assert "Expected format is" in str(e.value)


@pytest.mark.parametrize(
"base_url",
[
"http://localhost:8080/v1/embeddings",
"http://0.0.0.0:8888/v1/rankings",
],
Expand Down

0 comments on commit 199438f

Please sign in to comment.