allowed custom headers, cert and api_endpoint (langchain-ai#266)
lkuligin authored Jun 3, 2024
1 parent ad02b46 commit 5164781
Showing 3 changed files with 41 additions and 28 deletions.
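
Taken together, the changes let callers supply a custom API endpoint, extra request headers, and a mutual-TLS client certificate when constructing a model. A minimal usage sketch, assuming the new field names (api_endpoint, additional_headers, client_cert_source) are accepted as constructor kwargs; the file paths and header names are placeholders:

from typing import Tuple

from langchain_google_vertexai import ChatVertexAI

def my_cert_source() -> Tuple[bytes, bytes]:
    # Hypothetical callback: returns (certificate bytes, private key bytes),
    # both PEM-encoded, as the new client_cert_source field expects.
    with open("client-cert.pem", "rb") as cert, open("client-key.pem", "rb") as key:
        return cert.read(), key.read()

llm = ChatVertexAI(
    model_name="gemini-pro",
    api_endpoint="us-central1-aiplatform.googleapis.com",  # overrides the location-derived default
    additional_headers={"X-Custom-Header": "value"},  # forwarded as per-request metadata
    client_cert_source=my_cert_source,
)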
46 changes: 31 additions & 15 deletions libs/vertexai/langchain_google_vertexai/_base.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from concurrent.futures import Executor
-from typing import Any, ClassVar, Dict, List, Optional
+from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple

 import vertexai  # type: ignore[import-untyped]
 from google.api_core.client_options import ClientOptions
@@ -54,7 +54,7 @@ class _VertexAIBase(BaseModel):
     async_client: Any = None  #: :meta private:
     project: Optional[str] = None
     "The default GCP project to use when making Vertex API calls."
-    location: str = _DEFAULT_LOCATION
+    location: str = Field(default=_DEFAULT_LOCATION)
     "The default location to use when making API calls."
     request_parallelism: int = 5
     "The amount of parallelism allowed for requests issued to VertexAI models. "
@@ -66,9 +66,22 @@ class _VertexAIBase(BaseModel):
     "Optional list of stop words to use when generating."
     model_name: Optional[str] = Field(default=None, alias="model")
     "Underlying model name."
-    model_family: Optional[GoogleModelFamily] = None
-    full_model_name: Optional[str] = None
-    """The full name of the model's endpoint."""
+    model_family: Optional[GoogleModelFamily] = None  #: :meta private:
+    full_model_name: Optional[str] = None  #: :meta private:
+    "The full name of the model's endpoint."
+    client_options: Optional["ClientOptions"] = Field(
+        default=None, exclude=True
+    )  #: :meta private:
+    api_endpoint: Optional[str] = None
+    "Desired API endpoint, e.g., us-central1-aiplatform.googleapis.com"
+    default_metadata: Sequence[Tuple[str, str]] = Field(
+        default_factory=list
+    )  #: :meta private:
+    additional_headers: Optional[Dict[str, str]] = Field(default=None)
+    "A key-value dictionary representing additional headers for the model call"
+    client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None
+    "A callback which returns client certificate bytes and private key bytes both "
+    "in PEM format."

     class Config:
         """Configuration for this pydantic object."""
@@ -82,17 +95,25 @@ def validate_params_base(cls, values: dict) -> dict:
             values["model_name"] = values.pop("model")
         if values.get("project") is None:
             values["project"] = initializer.global_config.project
+        if values.get("api_endpoint"):
+            api_endpoint = values["api_endpoint"]
+        else:
+            location = values.get("location", cls.__fields__["location"].default)
+            api_endpoint = f"{location}-{constants.PREDICTION_API_BASE_PATH}"
+        client_options = ClientOptions(api_endpoint=api_endpoint)
+        if values.get("client_cert_source"):
+            client_options.client_cert_source = values["client_cert_source"]
+        values["client_options"] = client_options
+        additional_headers = values.get("additional_headers", {})
+        values["default_metadata"] = tuple(additional_headers.items())
         return values

     @property
     def prediction_client(self) -> v1beta1PredictionServiceClient:
         """Returns PredictionServiceClient."""
         if self.client is None:
-            client_options = {
-                "api_endpoint": f"{self.location}-{constants.PREDICTION_API_BASE_PATH}"
-            }
             self.client = v1beta1PredictionServiceClient(
-                client_options=client_options,
+                client_options=self.client_options,
                 client_info=get_client_info(module=self._user_agent),
             )
         return self.client
@@ -101,11 +122,8 @@ def prediction_client(self) -> v1beta1PredictionServiceClient:
     def async_prediction_client(self) -> v1beta1PredictionServiceAsyncClient:
         """Returns PredictionServiceClient."""
         if self.async_client is None:
-            client_options = {
-                "api_endpoint": f"{self.location}-{constants.PREDICTION_API_BASE_PATH}"
-            }
             self.async_client = v1beta1PredictionServiceAsyncClient(
-                client_options=ClientOptions(**client_options),
+                client_options=self.client_options,
                 client_info=get_client_info(module=self._user_agent),
             )
         return self.async_client
@@ -157,8 +175,6 @@ class _VertexAICommon(_VertexAIBase):

     api_transport: Optional[str] = None
     """The desired API transport method, can be either 'grpc' or 'rest'"""
-    api_endpoint: Optional[str] = None
-    """The desired API endpoint, e.g., us-central1-aiplatform.googleapis.com"""
     tuned_model_name: Optional[str] = None
     """The name of a tuned model. If tuned_model_name is passed
     model_name will be used to determine the model family
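
The validator above centralizes endpoint resolution: an explicit api_endpoint wins, otherwise one is derived from the location, and additional_headers is materialized once as the default_metadata tuple. A standalone restatement of that logic, assuming constants.PREDICTION_API_BASE_PATH is "aiplatform.googleapis.com" (the value the updated test below also expects):

from typing import Optional

PREDICTION_API_BASE_PATH = "aiplatform.googleapis.com"  # assumed constant value

def resolve_endpoint(api_endpoint: Optional[str], location: str = "us-central1") -> str:
    # Mirrors validate_params_base: explicit endpoint first, else location-derived.
    return api_endpoint or f"{location}-{PREDICTION_API_BASE_PATH}"

assert resolve_endpoint(None, "europe-west4") == "europe-west4-aiplatform.googleapis.com"
assert resolve_endpoint("private.googleapis.com") == "private.googleapis.com"
assert tuple({"X-Custom-Header": "value"}.items()) == (("X-Custom-Header", "value"),)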
14 changes: 6 additions & 8 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -105,6 +105,7 @@
     create_retry_decorator,
     get_generation_info,
     _format_model_name,
+    is_gemini_model,
 )
 from langchain_google_vertexai.functions_utils import (
     _format_tool_config,
@@ -611,21 +612,15 @@ def validate_environment(cls, values: Dict) -> Dict:
             project=values["project"],
         )

-        if safety_settings and values["model_family"] not in [
-            GoogleModelFamily.GEMINI,
-            GoogleModelFamily.GEMINI_ADVANCED,
-        ]:
+        if safety_settings and not is_gemini_model(values["model_family"]):
             raise ValueError("Safety settings are only supported for Gemini models")

         if tuned_model_name:
             generative_model_name = values["tuned_model_name"]
         else:
             generative_model_name = values["model_name"]

-        if values["model_family"] not in [
-            GoogleModelFamily.GEMINI,
-            GoogleModelFamily.GEMINI_ADVANCED,
-        ]:
+        if not is_gemini_model(values["model_family"]):
             cls._init_vertexai(values)
             if values["model_family"] == GoogleModelFamily.CODEY:
                 model_cls = CodeChatModel
@@ -784,6 +779,7 @@ def _generate_gemini(
             self.prediction_client.generate_content,
             max_retries=self.max_retries,
             request=request,
+            metadata=self.default_metadata,
             **kwargs,
         )
         return self._gemini_response_to_chat_result(response)
@@ -802,6 +798,7 @@ async def _agenerate_gemini(
                 messages=messages, stop=stop, **kwargs
             ),
             is_gemini=True,
+            metadata=self.default_metadata,
             **kwargs,
         )
         return self._gemini_response_to_chat_result(response)
@@ -988,6 +985,7 @@ def _stream_gemini(
             max_retries=self.max_retries,
             request=request,
             is_gemini=True,
+            metadata=self.default_metadata,
             **kwargs,
         )
         for response_chunk in response_iter:
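
The repeated family-membership checks are folded into an is_gemini_model helper, and every Gemini call site now forwards default_metadata so the custom headers reach the request. The helper's implementation is not part of this diff; a plausible sketch inferred from the code it replaces, with a stand-in enum for self-containment:

from enum import Enum
from typing import Optional

class GoogleModelFamily(str, Enum):
    # Stand-in for the library's enum; only the families named in this file.
    GEMINI = "gemini"
    GEMINI_ADVANCED = "gemini_advanced"
    CODEY = "codey"

def is_gemini_model(model_family: Optional[GoogleModelFamily]) -> bool:
    # Equivalent to the membership test removed from validate_environment.
    return model_family in (GoogleModelFamily.GEMINI, GoogleModelFamily.GEMINI_ADVANCED)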
9 changes: 4 additions & 5 deletions libs/vertexai/tests/unit_tests/test_chat_models.py
@@ -3,7 +3,7 @@
 import json
 from dataclasses import dataclass
 from typing import Any, Dict, Optional
-from unittest.mock import ANY, MagicMock, patch
+from unittest.mock import MagicMock, patch

 import pytest
 from google.cloud.aiplatform_v1beta1.types import (
@@ -97,11 +97,10 @@ def test_init_client(model: str, location: str) -> None:
         mock_prediction_service.return_value.generate_content.return_value = response

         llm._generate_gemini(messages=[])
-        mock_prediction_service.assert_called_once_with(
-            client_options={"api_endpoint": f"{location}-aiplatform.googleapis.com"},
-            client_info=ANY,
-        )
+        mock_prediction_service.assert_called_once()
         client_info = mock_prediction_service.call_args.kwargs["client_info"]
+        client_options = mock_prediction_service.call_args.kwargs["client_options"]
+        assert client_options.api_endpoint == f"{location}-aiplatform.googleapis.com"
         assert "langchain-google-vertexai" in client_info.user_agent
         assert "ChatVertexAI" in client_info.user_agent
         assert "langchain-google-vertexai" in client_info.client_library_version
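
The looser assertion is likely forced by the switch from a plain dict to a ClientOptions instance: assuming ClientOptions does not define value-based equality, assert_called_once_with can no longer match the kwargs exactly, so the test inspects attributes instead:

from google.api_core.client_options import ClientOptions

a = ClientOptions(api_endpoint="us-central1-aiplatform.googleapis.com")
b = ClientOptions(api_endpoint="us-central1-aiplatform.googleapis.com")

# Distinct instances compare unequal (no __eq__ override, assumed),
# hence asserting on .api_endpoint rather than on exact call kwargs.
assert a != b
assert a.api_endpoint == b.api_endpoint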
