Skip to content

Commit

Permalink
Fixes to NvidiaRanker
Browse files (browse the repository at this point in the history)
  • Loading branch information
sjrl committed Nov 14, 2024
1 parent a986ace commit 67a898d
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -167,7 +167,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "NvidiaDocumentEmbedder":
:returns:
The deserialized component.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
init_parameters = data.get("init_parameters", {})
if init_parameters:
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
return default_from_dict(cls, data)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
Expand Down
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,15 @@
import warnings
from typing import Any, Dict, List, Optional, Union

from haystack import Document, component, default_from_dict, default_to_dict
from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.utils import Secret, deserialize_secrets_inplace

from haystack_integrations.utils.nvidia import NimBackend, url_validation

from .truncate import RankerTruncateMode

logger = logging.getLogger(__name__)

_DEFAULT_MODEL = "nvidia/nv-rerankqa-mistral-4b-v3"

_MODEL_ENDPOINT_MAP = {
Expand Down Expand Up @@ -51,7 +53,7 @@ def __init__(
model: Optional[str] = None,
truncate: Optional[Union[RankerTruncateMode, str]] = None,
api_url: Optional[str] = None,
api_key: Optional[Secret] = None,
api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"),
top_k: int = 5,
):
"""
Expand Down Expand Up @@ -100,6 +102,7 @@ def __init__(
self._api_key = Secret.from_env_var("NVIDIA_API_KEY")
self._top_k = top_k
self._initialized = False
self._backend: Optional[Any] = None

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -113,7 +116,7 @@ def to_dict(self) -> Dict[str, Any]:
top_k=self._top_k,
truncate=self._truncate,
api_url=self._api_url,
api_key=self._api_key,
api_key=self._api_key.to_dict() if self._api_key else None,
)

@classmethod
Expand All @@ -124,7 +127,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "NvidiaRanker":
:param data: A dictionary containing the ranker's attributes.
:returns: The deserialized ranker.
"""
deserialize_secrets_inplace(data, keys=["api_key"])
init_parameters = data.get("init_parameters", {})
if init_parameters:
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
return default_from_dict(cls, data)

def warm_up(self):
Expand Down Expand Up @@ -170,24 +175,24 @@ def run(
msg = "The ranker has not been loaded. Please call warm_up() before running."
raise RuntimeError(msg)
if not isinstance(query, str):
msg = "Ranker expects the `query` parameter to be a string."
msg = "NvidiaRanker expects the `query` parameter to be a string."
raise TypeError(msg)
if not isinstance(documents, list):
msg = "Ranker expects the `documents` parameter to be a list."
msg = "NvidiaRanker expects the `documents` parameter to be a list."
raise TypeError(msg)
if not all(isinstance(doc, Document) for doc in documents):
msg = "Ranker expects the `documents` parameter to be a list of Document objects."
msg = "NvidiaRanker expects the `documents` parameter to be a list of Document objects."
raise TypeError(msg)
if top_k is not None and not isinstance(top_k, int):
msg = "Ranker expects the `top_k` parameter to be an integer."
msg = "NvidiaRanker expects the `top_k` parameter to be an integer."
raise TypeError(msg)

if len(documents) == 0:
return {"documents": []}

top_k = top_k if top_k is not None else self._top_k
if top_k < 1:
warnings.warn("top_k should be at least 1, returning nothing", stacklevel=2)
logger.warning("top_k should be at least 1, returning nothing")
return {"documents": []}

assert self._backend is not None
Expand Down
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
import warnings
from typing import List
from typing import List, Optional
from urllib.parse import urlparse, urlunparse


def url_validation(api_url: str, default_api_url: str, allowed_paths: List[str]) -> str:
def url_validation(api_url: str, default_api_url: Optional[str], allowed_paths: List[str]) -> str:
"""
Validate and normalize an API URL.
Expand Down

0 comments on commit 67a898d

Please sign in to comment.