feat: Add custom telemetry context upon client creation
- **Description:** Adds a custom user agent to Vertex AI SDK initialization so that API usage metrics can be collected.
  - Follow-up to langchain-ai/langchain#12168
- **Dependencies:** `google-cloud-aiplatform` must be updated to the version released from PR googleapis/python-aiplatform#3261
  - Before merging, update `raise_vertex_import_error(minimum_expected_version: str = "1.38.0")` to the actual version once the SDK is updated.
  - https://pypi.org/project/google-cloud-aiplatform/

Tested successfully locally when installing from the source PR.

DO NOT MERGE until googleapis/python-aiplatform#3261 is released and `minimum_expected_version` is updated to the correct version.
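
For reference, this is the pattern the commit applies at every client-creation site. A minimal sketch, assuming the `tool_context_manager` API added to `google.cloud.aiplatform.telemetry` in googleapis/python-aiplatform#3261; the user agent string is illustrative (the real call sites build it via `get_user_agent`, shown in the `_utils.py` diff below):

from google.cloud.aiplatform.telemetry import tool_context_manager

# Clients created (and requests issued) inside the block carry the custom
# tool context in their user agent header; the string here is illustrative.
with tool_context_manager("langchain/0.1.0-vertex-ai-llm"):
    ...  # e.g. GenerativeModel(...) or TextEmbeddingModel.from_pretrained(...)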
holtskinner committed Feb 26, 2024
1 parent 8d13520 commit 26c2487
Showing 5 changed files with 71 additions and 40 deletions.
24 changes: 20 additions & 4 deletions libs/vertexai/langchain_google_vertexai/_utils.py
@@ -2,7 +2,7 @@
 
 import dataclasses
 from importlib import metadata
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import google.api_core
 import proto  # type: ignore[import-untyped]
@@ -58,7 +58,7 @@ def raise_vertex_import_error(minimum_expected_version: str = "1.38.0") -> None:
     )
 
 
-def get_client_info(module: Optional[str] = None) -> "ClientInfo":
+def get_user_agent(module: Optional[str] = None) -> Tuple[str, str]:
     r"""Returns a custom user agent header.
 
     Args:
@@ -67,13 +67,29 @@ def get_client_info(module: Optional[str] = None) -> "ClientInfo":
         module (Optional[str]):
             Optional. The module for a custom user agent header.
     Returns:
         google.api_core.gapic_v1.client_info.ClientInfo
     """
-    langchain_version = metadata.version("langchain")
+    try:
+        langchain_version = metadata.version("langchain")
+    except metadata.PackageNotFoundError:
+        langchain_version = "0.0.0"
     client_library_version = (
         f"{langchain_version}-{module}" if module else langchain_version
     )
+    return client_library_version, f"langchain/{client_library_version}"
+
+
+def get_client_info(module: Optional[str] = None) -> "ClientInfo":
+    r"""Returns a client info object with a custom user agent header.
+
+    Args:
+        module (Optional[str]):
+            Optional. The module for a custom user agent header.
+    Returns:
+        google.api_core.gapic_v1.client_info.ClientInfo
+    """
+    client_library_version, user_agent = get_user_agent(module)
     return ClientInfo(
         client_library_version=client_library_version,
-        user_agent=f"langchain/{client_library_version}",
+        user_agent=user_agent,
     )
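
To illustrate the split between the two helpers, a sketch with illustrative values (the version segment depends on the installed `langchain` package):

from langchain_google_vertexai._utils import get_client_info, get_user_agent

# get_user_agent() returns the raw pair; get_client_info() wraps the same
# pair into a gapic ClientInfo for clients that accept one.
client_library_version, user_agent = get_user_agent("vertex-ai-llm")
# e.g. client_library_version == "0.1.0-vertex-ai-llm" (illustrative version;
# falls back to "0.0.0-vertex-ai-llm" when langchain is not installed)
# e.g. user_agent == "langchain/0.1.0-vertex-ai-llm"
client_info = get_client_info("vertex-ai-llm")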
38 changes: 21 additions & 17 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -1,4 +1,5 @@
 """Wrapper around Google VertexAI chat-based models."""
+
 from __future__ import annotations
 
 import json
@@ -13,6 +14,7 @@
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
+from google.cloud.aiplatform.telemetry import tool_context_manager
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
     generate_from_stream,
@@ -55,6 +57,7 @@
 from langchain_google_vertexai._image_utils import ImageBytesLoader
 from langchain_google_vertexai._utils import (
     get_generation_info,
+    get_user_agent,
     is_codey_model,
     is_gemini_model,
 )
@@ -291,24 +294,25 @@ def validate_environment(cls, values: Dict) -> Dict:
             raise ValueError("Safety settings are only supported for Gemini models")
 
         cls._init_vertexai(values)
-        if is_gemini:
-            values["client"] = GenerativeModel(
-                model_name=values["model_name"], safety_settings=safety_settings
-            )
-            values["client_preview"] = GenerativeModel(
-                model_name=values["model_name"], safety_settings=safety_settings
-            )
-        else:
-            if is_codey_model(values["model_name"]):
-                model_cls = CodeChatModel
-                model_cls_preview = PreviewCodeChatModel
-            else:
-                model_cls = ChatModel
-                model_cls_preview = PreviewChatModel
-            values["client"] = model_cls.from_pretrained(values["model_name"])
-            values["client_preview"] = model_cls_preview.from_pretrained(
-                values["model_name"]
-            )
+        with tool_context_manager(get_user_agent("vertex-ai-llm")):
+            if is_gemini:
+                values["client"] = GenerativeModel(
+                    model_name=values["model_name"], safety_settings=safety_settings
+                )
+                values["client_preview"] = GenerativeModel(
+                    model_name=values["model_name"], safety_settings=safety_settings
+                )
+            else:
+                if is_codey_model(values["model_name"]):
+                    model_cls = CodeChatModel
+                    model_cls_preview = PreviewCodeChatModel
+                else:
+                    model_cls = ChatModel
+                    model_cls_preview = PreviewChatModel
+                values["client"] = model_cls.from_pretrained(values["model_name"])
+                values["client_preview"] = model_cls_preview.from_pretrained(
+                    values["model_name"]
+                )
         return values
 
     def _generate(
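
With this change, the GA and preview clients are both created inside a single telemetry scope during validation, so nothing changes for callers. A sketch, assuming the `ChatVertexAI` class this module exports and a project where the model is available:

from langchain_google_vertexai import ChatVertexAI

# validate_environment() runs on construction; values["client"] and
# values["client_preview"] are both built inside the tool_context_manager
# block, so their requests report the custom user agent.
llm = ChatVertexAI(model_name="gemini-pro")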
11 changes: 7 additions & 4 deletions libs/vertexai/langchain_google_vertexai/embeddings.py
@@ -12,6 +12,7 @@
     ResourceExhausted,
     ServiceUnavailable,
 )
+from google.cloud.aiplatform.telemetry import tool_context_manager
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.llms import create_base_retry_decorator
 from langchain_core.pydantic_v1 import root_validator
@@ -21,6 +22,7 @@
 )
 
 from langchain_google_vertexai._base import _VertexAICommon
+from langchain_google_vertexai._utils import get_user_agent
 
 logger = logging.getLogger(__name__)
 
@@ -46,7 +48,8 @@ def validate_environment(cls, values: Dict) -> Dict:
                 "textembedding-gecko@001"
             )
             values["model_name"] = "textembedding-gecko@001"
-        values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
+        with tool_context_manager(get_user_agent("vertex-ai-embeddings")):
+            values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
         return values
 
     def __init__(
@@ -79,9 +82,9 @@ def __init__(
         self.instance["task_executor"] = ThreadPoolExecutor(
             max_workers=request_parallelism
         )
-        self.instance[
-            "embeddings_task_type_supported"
-        ] = not self.client._endpoint_name.endswith("/textembedding-gecko@001")
+        self.instance["embeddings_task_type_supported"] = (
+            not self.client._endpoint_name.endswith("/textembedding-gecko@001")
+        )
 
     @staticmethod
     def _split_by_punctuation(text: str) -> List[str]:
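
The embeddings client is tagged with its own module suffix ("vertex-ai-embeddings"), keeping embedding traffic distinguishable from LLM traffic. A sketch, assuming the `VertexAIEmbeddings` class this module exports:

from langchain_google_vertexai import VertexAIEmbeddings

# TextEmbeddingModel.from_pretrained() is called inside
# tool_context_manager(get_user_agent("vertex-ai-embeddings")).
embeddings = VertexAIEmbeddings(model_name="textembedding-gecko@001")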
33 changes: 19 additions & 14 deletions libs/vertexai/langchain_google_vertexai/llms.py
@@ -10,6 +10,7 @@
     PredictionServiceClient,
 )
 from google.cloud.aiplatform.models import Prediction
+from google.cloud.aiplatform.telemetry import tool_context_manager
 from google.protobuf import json_format
 from google.protobuf.struct_pb2 import Value
 from langchain_core.callbacks.manager import (
@@ -55,6 +56,7 @@
     create_retry_decorator,
     get_client_info,
     get_generation_info,
+    get_user_agent,
     is_codey_model,
     is_gemini_model,
 )
@@ -314,22 +316,25 @@ def validate_environment(cls, values: Dict) -> Dict:
             model_cls = TextGenerationModel
             preview_model_cls = PreviewTextGenerationModel
 
-        if tuned_model_name:
-            values["client"] = model_cls.get_tuned_model(tuned_model_name)
-            values["client_preview"] = preview_model_cls.get_tuned_model(
-                tuned_model_name
-            )
-        else:
-            if is_gemini:
-                values["client"] = model_cls(
-                    model_name=model_name, safety_settings=safety_settings
-                )
-                values["client_preview"] = preview_model_cls(
-                    model_name=model_name, safety_settings=safety_settings
-                )
-            else:
-                values["client"] = model_cls.from_pretrained(model_name)
-                values["client_preview"] = preview_model_cls.from_pretrained(model_name)
+        with tool_context_manager(get_user_agent("vertex-ai-llm")):
+            if tuned_model_name:
+                values["client"] = model_cls.get_tuned_model(tuned_model_name)
+                values["client_preview"] = preview_model_cls.get_tuned_model(
+                    tuned_model_name
+                )
+            else:
+                if is_gemini:
+                    values["client"] = model_cls(
+                        model_name=model_name, safety_settings=safety_settings
+                    )
+                    values["client_preview"] = preview_model_cls(
+                        model_name=model_name, safety_settings=safety_settings
+                    )
+                else:
+                    values["client"] = model_cls.from_pretrained(model_name)
+                    values["client_preview"] = preview_model_cls.from_pretrained(
+                        model_name
+                    )
 
         if values["streaming"] and values["n"] > 1:
             raise ValueError("Only one candidate can be generated with streaming!")
5 changes: 4 additions & 1 deletion libs/vertexai/langchain_google_vertexai/vision_models.py
@@ -2,6 +2,7 @@
 
 from typing import Any, Dict, List, Union
 
+from google.cloud.aiplatform.telemetry import tool_context_manager
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import BaseChatModel, BaseLLM
 from langchain_core.messages import AIMessage, BaseMessage
@@ -22,6 +23,7 @@
     get_text_str_from_content_part,
     image_bytes_to_b64_string,
 )
+from langchain_google_vertexai._utils import get_user_agent
 
 
 class _BaseImageTextModel(BaseModel):
@@ -38,7 +40,8 @@ class _BaseImageTextModel(BaseModel):
 
     def _create_model(self) -> ImageTextModel:
         """Builds the model object from the class attributes."""
-        return ImageTextModel.from_pretrained(model_name=self.model_name)
+        with tool_context_manager(get_user_agent("vertex-ai-imagen")):
+            return ImageTextModel.from_pretrained(model_name=self.model_name)
 
     def _get_image_from_message_part(self, message_part: str | Dict) -> Image | None:
         """Given a message part obtain a image if the part represents it.
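
Because `_create_model()` is the single factory used by the Imagen-backed classes, wrapping it covers all of them. A sketch, assuming the `VertexAIImageCaptioning` class this module exports and an illustrative image path:

from langchain_google_vertexai import VertexAIImageCaptioning

# The underlying ImageTextModel is built by _create_model() inside the
# "vertex-ai-imagen" telemetry context.
captioner = VertexAIImageCaptioning()
caption = captioner.invoke("gs://my-bucket/example.png")  # illustrative path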
