Skip to content

Commit

Permalink
Merge pull request #50 from langchain-ai/raspawar/LLOC-80
Browse files Browse the repository at this point in the history
Fix for NVIDIAClient output to redact the api_key
  • Loading branch information
mattf authored Jun 10, 2024
2 parents 21b2203 + c484b4c commit 8706589
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 5 deletions.
29 changes: 24 additions & 5 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import time
import warnings
from copy import deepcopy
from typing import (
Any,
AsyncIterator,
Expand Down Expand Up @@ -125,7 +126,7 @@ def headers(self) -> dict:
for header in headers_.values():
if "{api_key}" in header["Authorization"] and self.api_key:
header["Authorization"] = header["Authorization"].format(
api_key=self.api_key.get_secret_value(),
api_key=self.api_key,
)
return headers_

Expand All @@ -151,6 +152,13 @@ def _validate_model(cls, values: Dict[str, Any]) -> Dict[str, Any]:
)
return values

def __add_authorization(self, payload: dict) -> dict:
if self.api_key:
payload["headers"].update(
{"Authorization": f"Bearer {self.api_key.get_secret_value()}"}
)
return payload

@property
def available_models(self) -> list[Model]:
"""List the available models that can be invoked."""
Expand Down Expand Up @@ -200,7 +208,9 @@ def _post(
"stream": False,
}
session = self.get_session_fn()
self.last_response = response = session.post(**self.last_inputs)
self.last_response = response = session.post(
**self.__add_authorization(deepcopy(self.last_inputs))
)
self._try_raise(response)
return response, session

Expand All @@ -217,8 +227,11 @@ def _get(
}
if payload:
self.last_inputs["json"] = self.payload_fn(payload)

session = self.get_session_fn()
self.last_response = response = session.get(**self.last_inputs)
self.last_response = response = session.get(
**self.__add_authorization(deepcopy(self.last_inputs))
)
self._try_raise(response)
return response, session

Expand Down Expand Up @@ -440,7 +453,10 @@ def get_req_stream(
"json": self.payload_fn(payload),
"stream": True,
}
response = self.get_session_fn().post(**self.last_inputs)

response = self.get_session_fn().post(
**self.__add_authorization(deepcopy(self.last_inputs))
)
self._try_raise(response)
call = self.copy()

Expand Down Expand Up @@ -474,8 +490,11 @@ async def get_req_astream(
"headers": self.headers["stream"],
"json": self.payload_fn(payload),
}

async with self.get_asession_fn() as session:
async with session.post(**self.last_inputs) as response:
async with session.post(
**self.__add_authorization(deepcopy(self.last_inputs))
) as response:
self._try_raise(response)
async for line in response.content.iter_any():
if line and line.strip() != b"data: [DONE]":
Expand Down
20 changes: 20 additions & 0 deletions libs/ai-endpoints/tests/integration_tests/test_api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage

from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings, NVIDIARerank

Expand Down Expand Up @@ -49,3 +50,22 @@ def test_api_key(public_class: type, param: str) -> None:
with no_env_var("NVIDIA_API_KEY"):
client = public_class(**{param: api_key})
contact_service(client)


def test_api_key_leakage(chat_model: str, mode: dict) -> None:
"""Test ChatNVIDIA wrapper."""
chat = ChatNVIDIA(model=chat_model, temperature=0.7, **mode)
message = HumanMessage(content="Hello")
chat.invoke([message])

# check last_input post request
last_inputs = chat._client.client.last_inputs
assert last_inputs

authorization_header = last_inputs.get("headers", {}).get("Authorization")

if authorization_header:
key = authorization_header.split("Bearer ")[1]

assert not key.startswith("nvapi-")
assert key == "**********"
11 changes: 11 additions & 0 deletions libs/ai-endpoints/tests/unit_tests/test_api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Generator

import pytest
from langchain_core.pydantic_v1 import SecretStr


@contextmanager
Expand Down Expand Up @@ -43,3 +44,13 @@ def get_api_key(instance: Any) -> str:
assert get_api_key(public_class(nvidia_api_key="PARAM")) == "PARAM"
assert get_api_key(public_class(api_key="PARAM")) == "PARAM"
assert get_api_key(public_class(api_key="LOW", nvidia_api_key="HIGH")) == "HIGH"


def test_api_key_type(public_class: type) -> None:
# Test case to make sure the api_key is SecretStr and not str
def get_api_key(instance: Any) -> str:
return instance._client.client.api_key

with no_env_var("NVIDIA_API_KEY"):
os.environ["NVIDIA_API_KEY"] = "ENV"
assert type(get_api_key(public_class())) == SecretStr

0 comments on commit 8706589

Please sign in to comment.