Add default model for NVIDIA Haystack local NIM endpoints #915

Merged · 21 commits · Aug 12, 2024
Changes from 4 commits
2 changes: 1 addition & 1 deletion integrations/nvidia/pyproject.toml
@@ -42,7 +42,7 @@ root = "../.."
 git_describe_command = 'git describe --tags --match="integrations/nvidia-v[0-9]*"'
 
 [tool.hatch.envs.default]
-dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools", "requests_mock", "pydantic"]
+dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools", "requests_mock"]
 [tool.hatch.envs.default.scripts]
 test = "pytest --reruns 3 --reruns-delay 30 -x {args:tests}"
 test-cov = "coverage run -m pytest --reruns 3 --reruns-delay 30 -x {args:tests}"

This file was deleted.

This file was deleted.

@@ -4,10 +4,9 @@
 
 from haystack import Document, component, default_from_dict, default_to_dict
 from haystack.utils import Secret, deserialize_secrets_inplace
+from haystack_integrations.util.nvidia import EmbedderBackend, NimBackend
 from tqdm import tqdm
 
-from ._nim_backend import NimBackend
-from .backend import EmbedderBackend
 from .truncate import EmbeddingTruncateMode
 
 
@@ -49,6 +48,8 @@ def __init__(
 
         :param model:
             Embedding model to use.
+            If no model is specified and the API URL points to a locally hosted NIM,
+            the default is a model discovered via the /models endpoint.
         :param api_key:
             API key for the NVIDIA NIM.
         :param api_url:
@@ -87,11 +88,15 @@ def __init__(
 
         self.backend: Optional[EmbedderBackend] = None
         self._initialized = False
-        self.is_hosted = urlparse(self.api_url).netloc in [
-            "integrate.api.nvidia.com",
-            "ai.api.nvidia.com",
-        ]
-        if self.is_hosted and not self.model:
+
+        if (
+            urlparse(self.api_url).netloc
+            in [
+                "integrate.api.nvidia.com",
+                "ai.api.nvidia.com",
+            ]
+            and not self.model
+        ):
             # manually set default model
             self.model = "NV-Embed-QA"
 
@@ -132,7 +137,8 @@ def warm_up(self):
         )
 
         self._initialized = True
-        if not self.is_hosted and not self.model:
+
+        if not self.model:
             self.default_model()
 
     def to_dict(self) -> Dict[str, Any]:
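
For context, a minimal usage sketch of the new fallback in `NvidiaDocumentEmbedder` (the local URL below is illustrative, and API-key handling is elided):

```python
from haystack_integrations.components.embedders.nvidia import NvidiaDocumentEmbedder

# No model given: with a locally hosted NIM URL, warm_up() resolves the
# default model from the endpoint's /models API instead of failing.
embedder = NvidiaDocumentEmbedder(api_url="http://localhost:8080/v1")
embedder.warm_up()
print(embedder.model)  # whichever model the local NIM advertises
```
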
@@ -4,9 +4,8 @@
 
 from haystack import component, default_from_dict, default_to_dict
 from haystack.utils import Secret, deserialize_secrets_inplace
+from haystack_integrations.util.nvidia import EmbedderBackend, NimBackend
 
-from ._nim_backend import NimBackend
-from .backend import EmbedderBackend
 from .truncate import EmbeddingTruncateMode
 

@@ -46,6 +45,8 @@ def __init__(
 
         :param model:
             Embedding model to use.
+            If no model is specified and the API URL points to a locally hosted NIM,
+            the default is a model discovered via the /models endpoint.
         :param api_key:
             API key for the NVIDIA NIM.
         :param api_url:
@@ -72,11 +73,14 @@ def __init__(
         self.backend: Optional[EmbedderBackend] = None
         self._initialized = False
 
-        self.is_hosted = urlparse(self.api_url).netloc in [
-            "integrate.api.nvidia.com",
-            "ai.api.nvidia.com",
-        ]
-        if self.is_hosted and not self.model:
+        if (
+            urlparse(self.api_url).netloc
+            in [
+                "integrate.api.nvidia.com",
+                "ai.api.nvidia.com",
+            ]
+            and not self.model
+        ):
             # manually set default model
             self.model = "NV-Embed-QA"
 
@@ -118,7 +122,7 @@ def warm_up(self):
 
         self._initialized = True
 
-        if not self.is_hosted and not self.model:
+        if not self.model:
             self.default_model()
 
     def to_dict(self) -> Dict[str, Any]:
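
The `default_model()` helper called from `warm_up()` is not part of this diff; a plausible sketch, assuming it simply takes the first model the backend advertises and warns about the implicit choice:

```python
import warnings


def default_model(self):
    """Hypothetical sketch: fall back to the first model listed by the endpoint."""
    models = self.backend.models()
    if not models:
        msg = "No hosted model was found via the /models API."
        raise ValueError(msg)
    self.model = models[0].id
    warnings.warn(f"Defaulting to model '{self.model}'.", UserWarning, stacklevel=2)
```
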
@@ -7,9 +7,7 @@
 
 from haystack import component, default_from_dict, default_to_dict
 from haystack.utils.auth import Secret, deserialize_secrets_inplace
-
-from ._nim_backend import NimBackend
-from .backend import GeneratorBackend
+from haystack_integrations.util.nvidia import GeneratorBackend, NimBackend
 
 _DEFAULT_API_URL = "https://integrate.api.nvidia.com/v1"
 
@@ -55,6 +53,8 @@ def __init__(
             Name of the model to use for text generation.
             See the [NVIDIA NIMs](https://ai.nvidia.com)
             for more information on the supported models.
+            `Note`: If no model is specified and the API URL points to a locally
+            hosted NIM, the default is a model discovered via the /models endpoint.
         :param api_key:
             API key for the NVIDIA NIM.
         :param api_url:
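
A corresponding usage sketch for `NvidiaGenerator` against a local NIM (URL illustrative, API-key handling elided):

```python
from haystack_integrations.components.generators.nvidia import NvidiaGenerator

generator = NvidiaGenerator(api_url="http://localhost:8080/v1")  # no model given
generator.warm_up()  # resolves the default model via /models
result = generator.run(prompt="What is a NIM?")
print(result["replies"][0])
```
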
@@ -0,0 +1,4 @@
+from .backend import EmbedderBackend, GeneratorBackend, Model
+from .nim_backend import NimBackend
+
+__all__ = ["NimBackend", "EmbedderBackend", "GeneratorBackend", "Model"]
@@ -1,10 +1,10 @@
 from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Tuple
 
-from pydantic import BaseModel, ConfigDict
-
 
-class Model(BaseModel):
+@dataclass
+class Model:
     """
     Model information.
 
@@ -15,11 +15,47 @@ class Model(BaseModel):
     """
 
     id: str
-    model_config = ConfigDict(from_attributes=True, protected_namespaces=())
-    aliases: Optional[list] = None
+    aliases: Optional[List[str]] = field(default_factory=list)
     base_model: Optional[str] = None
 
 
+class EmbedderBackend(ABC):
+    def __init__(self, model: str, model_kwargs: Optional[Dict[str, Any]] = None):
+        """
+        Initialize the backend.
+
+        :param model:
+            The name of the model to use.
+        :param model_kwargs:
+            Additional keyword arguments to pass to the model.
+        """
+        self.model_name = model
+        self.model_kwargs = model_kwargs or {}
+
+    @abstractmethod
+    def embed(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]:
+        """
+        Invoke the backend and embed the given texts.
+
+        :param texts:
+            Texts to embed.
+        :return:
+            Vector representation of the texts and
+            metadata returned by the service.
+        """
+        pass
+
+    @abstractmethod
+    def models(self) -> List[Model]:
+        """
+        Invoke the backend to get available models.
+
+        :return:
+            Available models
+        """
+        pass
+
+
 class GeneratorBackend(ABC):
     def __init__(self, model: str, model_kwargs: Optional[Dict[str, Any]] = None):
         """
@@ -52,6 +88,6 @@ def models(self) -> List[Model]:
         Invoke the backend to get available models.
 
         :return:
-            Models available
+            Available models
         """
         pass
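
Since `Model` is now a plain dataclass rather than a pydantic `BaseModel`, `id` is accepted positionally and `aliases` gets a fresh list per instance via `field(default_factory=list)`. A quick illustration, assuming `Model` as defined above:

```python
from dataclasses import asdict

m = Model("NV-Embed-QA")  # id can be passed positionally now
print(m.aliases)          # [] -- a new list per instance, not a shared mutable default
print(asdict(m))          # {'id': 'NV-Embed-QA', 'aliases': [], 'base_model': None}
```
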
@@ -3,12 +3,12 @@
 import requests
 from haystack.utils import Secret
 
-from .backend import GeneratorBackend, Model
+from .backend import EmbedderBackend, GeneratorBackend, Model
 
 REQUEST_TIMEOUT = 60
 
 
-class NimBackend(GeneratorBackend):
+class NimBackend(GeneratorBackend, EmbedderBackend):
     def __init__(
         self,
         model: str,
@@ -31,6 +31,26 @@ def __init__(
         self.api_url = api_url
         self.model_kwargs = model_kwargs or {}
 
+    def embed(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]:
+        url = f"{self.api_url}/embeddings"
+
+        res = self.session.post(
+            url,
+            json={
+                "model": self.model,
+                "input": texts,
+                **self.model_kwargs,
+            },
+            timeout=REQUEST_TIMEOUT,
+        )
+        res.raise_for_status()
+
+        data = res.json()
+        # Sort the embeddings by index, since we don't know whether they arrive in order
+        embeddings = [e["embedding"] for e in sorted(data["data"], key=lambda e: e["index"])]
+
+        return embeddings, {"usage": data["usage"]}
+
     def generate(self, prompt: str) -> Tuple[List[str], List[Dict[str, Any]]]:
         # We're using the chat completion endpoint as the NIM API doesn't support
         # the /completions endpoint. So both the non-chat and chat generator will use this.
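
To see why `embed()` sorts by `index`, consider a response in the OpenAI-compatible `/embeddings` shape whose entries arrive out of order (values made up):

```python
data = {
    "data": [
        {"index": 1, "embedding": [0.3, 0.4]},
        {"index": 0, "embedding": [0.1, 0.2]},
    ],
    "usage": {"prompt_tokens": 4, "total_tokens": 4},
}
# Same expression as in embed(): restores the original input order.
embeddings = [e["embedding"] for e in sorted(data["data"], key=lambda e: e["index"])]
assert embeddings == [[0.1, 0.2], [0.3, 0.4]]
```
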
@@ -89,9 +109,8 @@ def models(self) -> List[Model]:
         res.raise_for_status()
 
         data = res.json()["data"]
-        models = []
-        for element in data:
-            assert "id" in element, f"No id found in {element}"
-            models.append(Model(id=element["id"]))
-
+        models = [Model(element["id"]) for element in data if "id" in element]
         if not models:
             msg = "No valid hosted model found."
             raise ValueError(msg)
         return models
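
The rewritten `models()` now skips entries without an `id` instead of asserting on them; with an illustrative payload, assuming the `Model` dataclass from above:

```python
data = [{"id": "meta/llama3-8b-instruct"}, {"object": "list"}]  # second entry lacks "id"
models = [Model(element["id"]) for element in data if "id" in element]
assert [m.id for m in models] == ["meta/llama3-8b-instruct"]
```
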