From d22c13ec48be2828ead6bba46ca6611bb58e0b28 Mon Sep 17 00:00:00 2001
From: Ran
Date: Wed, 6 Dec 2023 01:42:00 +0200
Subject: [PATCH] Mask API key for Minimax LLM (#14309)

- **Description:** Added masking for the Minimax LLM API key, plus tests
  inspired by https://github.com/langchain-ai/langchain/pull/12418.
- **Issue:** fixes https://github.com/langchain-ai/langchain/issues/12165
- **Dependencies:** this fix depends on the Minimax instantiation fix
  introduced in https://github.com/langchain-ai/langchain/pull/13439, so
  merge this one after it.
- **Tag maintainer:** @eyurtsev

---------

Co-authored-by: Harrison Chase
---
 libs/langchain/langchain/llms/minimax.py      | 29 ++++++-------
 .../tests/unit_tests/llms/test_minimax.py     | 42 +++++++++++++++++++
 2 files changed, 55 insertions(+), 16 deletions(-)
 create mode 100644 libs/langchain/tests/unit_tests/llms/test_minimax.py

diff --git a/libs/langchain/langchain/llms/minimax.py b/libs/langchain/langchain/llms/minimax.py
index 488f296d5a864..dd5f00266b21d 100644
--- a/libs/langchain/langchain/llms/minimax.py
+++ b/libs/langchain/langchain/llms/minimax.py
@@ -10,14 +10,14 @@
 )
 
 import requests
-from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 
 from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
 )
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
-from langchain.utils import get_from_dict_or_env
+from langchain.utils import convert_to_secret_str, get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
 
@@ -27,7 +27,7 @@ class _MinimaxEndpointClient(BaseModel):
 
     host: str
     group_id: str
-    api_key: str
+    api_key: SecretStr
     api_url: str
 
     @root_validator(pre=True, allow_reuse=True)
@@ -40,7 +40,7 @@ def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         return values
 
     def post(self, request: Any) -> Any:
-        headers = {"Authorization": f"Bearer {self.api_key}"}
+        headers = {"Authorization": f"Bearer {self.api_key.get_secret_value()}"}
         response = requests.post(self.api_url, headers=headers, json=request)
         # TODO: error handling and automatic retries
         if not response.ok:
@@ -56,7 +56,7 @@ def post(self, request: Any) -> Any:
 class MinimaxCommon(BaseModel):
     """Common parameters for Minimax large language models."""
 
-    _client: Any = None
+    _client: _MinimaxEndpointClient
     model: str = "abab5.5-chat"
     """Model name to use."""
     max_tokens: int = 256
@@ -69,13 +69,13 @@ class MinimaxCommon(BaseModel):
     """Holds any model parameters valid for `create` call not explicitly specified."""
     minimax_api_host: Optional[str] = None
     minimax_group_id: Optional[str] = None
-    minimax_api_key: Optional[str] = None
+    minimax_api_key: Optional[SecretStr] = None
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        values["minimax_api_key"] = get_from_dict_or_env(
-            values, "minimax_api_key", "MINIMAX_API_KEY"
+        values["minimax_api_key"] = convert_to_secret_str(
+            get_from_dict_or_env(values, "minimax_api_key", "MINIMAX_API_KEY")
        )
         values["minimax_group_id"] = get_from_dict_or_env(
             values, "minimax_group_id", "MINIMAX_GROUP_ID"
@@ -87,6 +87,11 @@ def validate_environment(cls, values: Dict) -> Dict:
             "MINIMAX_API_HOST",
             default="https://api.minimax.chat",
         )
+        values["_client"] = _MinimaxEndpointClient(
+            host=values["minimax_api_host"],
+            api_key=values["minimax_api_key"],
+            group_id=values["minimax_group_id"],
+        )
         return values
 
     @property
@@ -110,14 +115,6 @@ def _llm_type(self) -> str:
         """Return type of llm."""
         return "minimax"
 
-    def __init__(self, **data: Any):
-        super().__init__(**data)
-        self._client = _MinimaxEndpointClient(
-            host=self.minimax_api_host,
-            api_key=self.minimax_api_key,
-            group_id=self.minimax_group_id,
-        )
-
 
 class Minimax(MinimaxCommon, LLM):
     """Wrapper around Minimax large language models.
diff --git a/libs/langchain/tests/unit_tests/llms/test_minimax.py b/libs/langchain/tests/unit_tests/llms/test_minimax.py
new file mode 100644
index 0000000000000..9b53408f21d2a
--- /dev/null
+++ b/libs/langchain/tests/unit_tests/llms/test_minimax.py
@@ -0,0 +1,42 @@
+"""Test Minimax llm"""
+from typing import cast
+
+from langchain_core.pydantic_v1 import SecretStr
+from pytest import CaptureFixture, MonkeyPatch
+
+from langchain.llms.minimax import Minimax
+
+
+def test_api_key_is_secret_string() -> None:
+    llm = Minimax(minimax_api_key="secret-api-key", minimax_group_id="group_id")
+    assert isinstance(llm.minimax_api_key, SecretStr)
+
+
+def test_api_key_masked_when_passed_from_env(
+    monkeypatch: MonkeyPatch, capsys: CaptureFixture
+) -> None:
+    """Test initialization with an API key provided via an env variable"""
+    monkeypatch.setenv("MINIMAX_API_KEY", "secret-api-key")
+    monkeypatch.setenv("MINIMAX_GROUP_ID", "group_id")
+    llm = Minimax()
+    print(llm.minimax_api_key, end="")
+    captured = capsys.readouterr()
+
+    assert captured.out == "**********"
+
+
+def test_api_key_masked_when_passed_via_constructor(
+    capsys: CaptureFixture,
+) -> None:
+    """Test initialization with an API key provided via the initializer"""
+    llm = Minimax(minimax_api_key="secret-api-key", minimax_group_id="group_id")
+    print(llm.minimax_api_key, end="")
+    captured = capsys.readouterr()
+
+    assert captured.out == "**********"
+
+
+def test_uses_actual_secret_value_from_secretstr() -> None:
+    """Test that actual secret is retrieved using `.get_secret_value()`."""
+    llm = Minimax(minimax_api_key="secret-api-key", minimax_group_id="group_id")
+    assert cast(SecretStr, llm.minimax_api_key).get_secret_value() == "secret-api-key"
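For context, a minimal usage sketch (not part of the patch) of the behavior this change introduces, mirroring the new unit tests: the key is stored as a pydantic SecretStr, so printing or logging the field yields only asterisks, and the real value is exposed only via .get_secret_value(), which is what _MinimaxEndpointClient.post() now calls when building the Authorization header. It assumes a langchain build that includes this change; the key and group id values below are placeholders.

    # Illustrative sketch only; credentials here are placeholders. Passing
    # them via the MINIMAX_API_KEY / MINIMAX_GROUP_ID env vars works the same.
    from langchain.llms.minimax import Minimax

    llm = Minimax(minimax_api_key="secret-api-key", minimax_group_id="group_id")

    # The field is a SecretStr, so str()/repr() output is masked.
    print(llm.minimax_api_key)  # **********

    # The actual key is only returned on an explicit request.
    print(llm.minimax_api_key.get_secret_value())  # secret-api-key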