Skip to content

Commit

Permalink
fix review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
raspawar committed Jun 25, 2024
1 parent 834843e commit 4eb3d46
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 17 deletions.
16 changes: 6 additions & 10 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,6 @@ def available_models(self) -> list[Model]:
# add base model for local-nim mode
model.base_model = element.get("root")

if model.base_model and model.id != model.base_model:
model.model_type = "lora"

self._available_models.append(model)

return self._available_models
Expand Down Expand Up @@ -608,14 +605,13 @@ def get_available_models(
) -> List[Model]:
"""Retrieve a list of available models."""

# set client for lora models in local-nim mode
if not self.is_hosted:
for model in self.client.available_models:
if model.model_type == "lora":
model.client = filter

available = [
model for model in self.client.available_models if model.client == filter
model
for model in self.client.available_models
if (
model.client == filter
or (model.base_model and model.base_model != model.id)
)
]

# if we're talking to a hosted endpoint, we mix in the known models
Expand Down
23 changes: 20 additions & 3 deletions libs/ai-endpoints/tests/unit_tests/test_api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import Any, Generator

import pytest
import requests
from langchain_core.pydantic_v1 import SecretStr
from requests_mock import Mocker


@contextmanager
Expand All @@ -18,6 +18,24 @@ def no_env_var(var: str) -> Generator[None, None, None]:
os.environ[var] = val


@pytest.fixture(autouse=True)
def mock_v1_local_models(requests_mock: Mocker) -> None:
requests_mock.get(
"https://test_url/v1/models",
json={
"data": [
{
"id": "model1",
"object": "model",
"created": 1234567890,
"owned_by": "OWNER",
"root": "model1",
},
]
},
)


def test_create_without_api_key(public_class: type) -> None:
with no_env_var("NVIDIA_API_KEY"):
with pytest.warns(UserWarning):
Expand All @@ -26,8 +44,7 @@ def test_create_without_api_key(public_class: type) -> None:

def test_create_unknown_url_no_api_key(public_class: type) -> None:
with no_env_var("NVIDIA_API_KEY"):
with pytest.raises(requests.exceptions.ConnectionError):
public_class(base_url="https://test_url/v1")
public_class(base_url="https://test_url/v1")


@pytest.mark.parametrize("param", ["nvidia_api_key", "api_key"])
Expand Down
25 changes: 21 additions & 4 deletions libs/ai-endpoints/tests/unit_tests/test_base_url.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
import requests
from requests_mock import Mocker


@pytest.mark.parametrize(
Expand All @@ -25,6 +25,24 @@ def test_param_base_url_hosted(public_class: type, base_url: str) -> None:
assert client._client.is_hosted


@pytest.fixture(autouse=True)
def mock_v1_local_models(requests_mock: Mocker, base_url: str) -> None:
requests_mock.get(
f"{base_url}/models",
json={
"data": [
{
"id": "model1",
"object": "model",
"created": 1234567890,
"owned_by": "OWNER",
"root": "model1",
},
]
},
)


@pytest.mark.parametrize(
"base_url",
[
Expand All @@ -34,6 +52,5 @@ def test_param_base_url_hosted(public_class: type, base_url: str) -> None:
],
)
def test_param_base_url_not_hosted(public_class: type, base_url: str) -> None:
with pytest.raises(requests.exceptions.ConnectionError):
client = public_class(base_url=base_url)
assert not client._client.is_hosted
client = public_class(base_url=base_url)
assert not client._client.is_hosted

0 comments on commit 4eb3d46

Please sign in to comment.