Skip to content

Commit

Permalink
fix: nvidia-haystack- Handle non-strict env var secrets correctly (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
shadeMe authored Mar 6, 2024
1 parent 5a339d4 commit 4998a7a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ def __init__(
if isinstance(model, str):
model = NvidiaEmbeddingModel.from_str(model)

resolved_api_key = api_key.resolve_value()
assert resolved_api_key is not None

# Upper-limit for the endpoint.
if batch_size > MAX_INPUTS:
msg = f"NVIDIA Cloud Functions currently support a maximum batch size of {MAX_INPUTS}."
Expand All @@ -83,7 +80,7 @@ def __init__(
self.embedding_separator = embedding_separator

self.client = NvidiaCloudFunctionsClient(
api_key=resolved_api_key,
api_key=api_key,
headers={
"Content-Type": "application/json",
"Accept": "application/json",
Expand Down Expand Up @@ -193,7 +190,7 @@ def run(self, documents: List[Document]):
if not self._initialized:
msg = "The embedding model has not been loaded. Please call warm_up() before running."
raise RuntimeError(msg)
if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
elif not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
msg = (
"NvidiaDocumentEmbedder expects a list of Documents as input."
"In case you want to embed a string, please use the NvidiaTextEmbedder."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,12 @@ def __init__(
if isinstance(model, str):
model = NvidiaEmbeddingModel.from_str(model)

resolved_api_key = api_key.resolve_value()
assert resolved_api_key is not None

self.api_key = api_key
self.model = model
self.prefix = prefix
self.suffix = suffix
self.client = NvidiaCloudFunctionsClient(
api_key=resolved_api_key,
api_key=api_key,
headers={
"Content-Type": "application/json",
"Accept": "application/json",
Expand Down Expand Up @@ -128,7 +125,7 @@ def run(self, text: str):
if not self._initialized:
msg = "The embedding model has not been loaded. Please call warm_up() before running."
raise RuntimeError(msg)
if not isinstance(text, str):
elif not isinstance(text, str):
msg = (
"NvidiaTextEmbedder expects a string as an input."
"In case you want to embed a list of Documents, please use the NvidiaDocumentEmbedder."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Dict, Optional

import requests
from haystack.utils import Secret

FUNCTIONS_ENDPOINT = "https://api.nvcf.nvidia.com/v2/nvcf/functions"
INVOKE_ENDPOINT = "https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions"
Expand All @@ -19,13 +20,17 @@ class AvailableNvidiaCloudFunctions:


class NvidiaCloudFunctionsClient:
def __init__(self, *, api_key: str, headers: Dict[str, str], timeout: int = 60):
self.api_key = api_key
def __init__(self, *, api_key: Secret, headers: Dict[str, str], timeout: int = 60):
self.api_key = api_key.resolve_value()
if self.api_key is None:
msg = "Nvidia Cloud Functions API key is not set."
raise ValueError(msg)

self.fetch_url_format = STATUS_ENDPOINT
self.headers = copy.deepcopy(headers)
self.headers.update(
{
"Authorization": f"Bearer {api_key}",
"Authorization": f"Bearer {self.api_key}",
}
)
self.timeout = timeout
Expand Down

0 comments on commit 4998a7a

Please sign in to comment.