diff --git a/libs/vertexai/langchain_google_vertexai/_base.py b/libs/vertexai/langchain_google_vertexai/_base.py
index 69fe4adb..6811c3d3 100644
--- a/libs/vertexai/langchain_google_vertexai/_base.py
+++ b/libs/vertexai/langchain_google_vertexai/_base.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from concurrent.futures import Executor
-from typing import Any, ClassVar, Dict, List, Optional
+from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple
 
 import vertexai  # type: ignore[import-untyped]
 from google.api_core.client_options import ClientOptions
@@ -54,7 +54,7 @@ class _VertexAIBase(BaseModel):
     async_client: Any = None  #: :meta private:
     project: Optional[str] = None
     "The default GCP project to use when making Vertex API calls."
-    location: str = _DEFAULT_LOCATION
+    location: str = Field(default=_DEFAULT_LOCATION)
     "The default location to use when making API calls."
     request_parallelism: int = 5
     "The amount of parallelism allowed for requests issued to VertexAI models. "
@@ -66,9 +66,22 @@ class _VertexAIBase(BaseModel):
     "Optional list of stop words to use when generating."
     model_name: Optional[str] = Field(default=None, alias="model")
     "Underlying model name."
-    model_family: Optional[GoogleModelFamily] = None
-    full_model_name: Optional[str] = None
-    """The full name of the model's endpoint."""
+    model_family: Optional[GoogleModelFamily] = None  #: :meta private:
+    full_model_name: Optional[str] = None  #: :meta private:
+    "The full name of the model's endpoint."
+    client_options: Optional["ClientOptions"] = Field(
+        default=None, exclude=True
+    )  #: :meta private:
+    api_endpoint: Optional[str] = None
+    "Desired API endpoint, e.g., us-central1-aiplatform.googleapis.com"
+    default_metadata: Sequence[Tuple[str, str]] = Field(
+        default_factory=list
+    )  #: :meta private:
+    additional_headers: Optional[Dict[str, str]] = Field(default=None)
+    "A key-value dictionary representing additional headers for the model call"
+    client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None
+    "A callback which returns client certificate bytes and private key bytes both "
+    "in PEM format."
 
     class Config:
         """Configuration for this pydantic object."""
@@ -82,17 +95,25 @@ def validate_params_base(cls, values: dict) -> dict:
             values["model_name"] = values.pop("model")
         if values.get("project") is None:
             values["project"] = initializer.global_config.project
+        if values.get("api_endpoint"):
+            api_endpoint = values["api_endpoint"]
+        else:
+            location = values.get("location", cls.__fields__["location"].default)
+            api_endpoint = f"{location}-{constants.PREDICTION_API_BASE_PATH}"
+        client_options = ClientOptions(api_endpoint=api_endpoint)
+        if values.get("client_cert_source"):
+            client_options.client_cert_source = values["client_cert_source"]
+        values["client_options"] = client_options
+        additional_headers = values.get("additional_headers", {})
+        values["default_metadata"] = tuple(additional_headers.items())
         return values
 
     @property
     def prediction_client(self) -> v1beta1PredictionServiceClient:
         """Returns PredictionServiceClient."""
         if self.client is None:
-            client_options = {
-                "api_endpoint": f"{self.location}-{constants.PREDICTION_API_BASE_PATH}"
-            }
             self.client = v1beta1PredictionServiceClient(
-                client_options=client_options,
+                client_options=self.client_options,
                 client_info=get_client_info(module=self._user_agent),
             )
         return self.client
@@ -101,11 +122,8 @@ def prediction_client(self) -> v1beta1PredictionServiceClient:
     def async_prediction_client(self) -> v1beta1PredictionServiceAsyncClient:
         """Returns PredictionServiceClient."""
         if self.async_client is None:
-            client_options = {
-                "api_endpoint": f"{self.location}-{constants.PREDICTION_API_BASE_PATH}"
-            }
             self.async_client = v1beta1PredictionServiceAsyncClient(
-                client_options=ClientOptions(**client_options),
+                client_options=self.client_options,
                 client_info=get_client_info(module=self._user_agent),
             )
         return self.async_client
@@ -157,8 +175,6 @@ class _VertexAICommon(_VertexAIBase):
     api_transport: Optional[str] = None
     """The desired API transport method, can be either 'grpc' or 'rest'"""
-    api_endpoint: Optional[str] = None
-    """The desired API endpoint, e.g., us-central1-aiplatform.googleapis.com"""
     tuned_model_name: Optional[str] = None
     """The name of a tuned model.
     If tuned_model_name is passed
     model_name will be used to determine the model family
diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py
index a63705a5..c6f6713e 100644
--- a/libs/vertexai/langchain_google_vertexai/chat_models.py
+++ b/libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -105,6 +105,7 @@
     create_retry_decorator,
     get_generation_info,
     _format_model_name,
+    is_gemini_model,
 )
 from langchain_google_vertexai.functions_utils import (
     _format_tool_config,
@@ -611,10 +612,7 @@ def validate_environment(cls, values: Dict) -> Dict:
             project=values["project"],
         )
 
-        if safety_settings and values["model_family"] not in [
-            GoogleModelFamily.GEMINI,
-            GoogleModelFamily.GEMINI_ADVANCED,
-        ]:
+        if safety_settings and not is_gemini_model(values["model_family"]):
             raise ValueError("Safety settings are only supported for Gemini models")
 
         if tuned_model_name:
@@ -622,10 +620,7 @@ def validate_environment(cls, values: Dict) -> Dict:
         else:
             generative_model_name = values["model_name"]
 
-        if values["model_family"] not in [
-            GoogleModelFamily.GEMINI,
-            GoogleModelFamily.GEMINI_ADVANCED,
-        ]:
+        if not is_gemini_model(values["model_family"]):
             cls._init_vertexai(values)
             if values["model_family"] == GoogleModelFamily.CODEY:
                 model_cls = CodeChatModel
@@ -784,6 +779,7 @@ def _generate_gemini(
             self.prediction_client.generate_content,
             max_retries=self.max_retries,
             request=request,
+            metadata=self.default_metadata,
             **kwargs,
         )
         return self._gemini_response_to_chat_result(response)
@@ -802,6 +798,7 @@ async def _agenerate_gemini(
                 messages=messages, stop=stop, **kwargs
             ),
             is_gemini=True,
+            metadata=self.default_metadata,
             **kwargs,
         )
         return self._gemini_response_to_chat_result(response)
@@ -988,6 +985,7 @@ def _stream_gemini(
             max_retries=self.max_retries,
             request=request,
             is_gemini=True,
+            metadata=self.default_metadata,
             **kwargs,
         )
         for response_chunk in response_iter:
diff --git a/libs/vertexai/tests/unit_tests/test_chat_models.py b/libs/vertexai/tests/unit_tests/test_chat_models.py
index bc07c90b..0038b2de 100644
--- a/libs/vertexai/tests/unit_tests/test_chat_models.py
+++ b/libs/vertexai/tests/unit_tests/test_chat_models.py
@@ -3,7 +3,7 @@
 import json
 from dataclasses import dataclass
 from typing import Any, Dict, Optional
-from unittest.mock import ANY, MagicMock, patch
+from unittest.mock import MagicMock, patch
 
 import pytest
 from google.cloud.aiplatform_v1beta1.types import (
@@ -97,11 +97,10 @@ def test_init_client(model: str, location: str) -> None:
         mock_prediction_service.return_value.generate_content.return_value = response
 
         llm._generate_gemini(messages=[])
+        mock_prediction_service.assert_called_once()
         client_info = mock_prediction_service.call_args.kwargs["client_info"]
-        mock_prediction_service.assert_called_once_with(
-            client_options={"api_endpoint": f"{location}-aiplatform.googleapis.com"},
-            client_info=ANY,
-        )
+        client_options = mock_prediction_service.call_args.kwargs["client_options"]
+        assert client_options.api_endpoint == f"{location}-aiplatform.googleapis.com"
         assert "langchain-google-vertexai" in client_info.user_agent
         assert "ChatVertexAI" in client_info.user_agent
        assert "langchain-google-vertexai" in client_info.client_library_version