feat: Add custom telemetry context upon client creation (#31)
* feat: Add custom telemetry context upon client creation

- **Description:** Adds a custom user agent to Vertex AI SDK initialization so that API usage metrics can be collected (a usage sketch follows the file summary below).
  - Follow-up to langchain-ai/langchain#12168

---------

Co-authored-by: Leonid Kuligin <[email protected]>
holtskinner and lkuligin authored Mar 15, 2024
1 parent 4dd35fe commit 5ad6963
Showing 10 changed files with 474 additions and 373 deletions.
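The mechanism is straightforward: each model class builds a per-class user agent through the new `get_user_agent` helper, and every Vertex AI SDK call is wrapped in `telemetry.tool_context_manager` so the request is attributed to that agent. Below is a minimal sketch of how the pieces fit together; it is not part of the diff, the class/model names and version string are illustrative, and `google-cloud-aiplatform` >= 1.44.0 is assumed for the `telemetry` module.

```python
# Illustrative sketch only: shows how the new user-agent context is expected
# to wrap Vertex AI SDK calls so per-integration usage metrics can be collected.
from google.cloud.aiplatform import telemetry
from vertexai.preview.generative_models import GenerativeModel  # illustrative model client

from langchain_google_vertexai._utils import get_user_agent

# Mirrors the _user_agent property added in _base.py; the returned pair looks like
# ("<langchain-version>-ChatVertexAI_gemini-pro",
#  "langchain/<langchain-version>-ChatVertexAI_gemini-pro").
_, user_agent = get_user_agent("ChatVertexAI_gemini-pro")

with telemetry.tool_context_manager(user_agent):
    # Any SDK call issued inside this block carries the custom user agent.
    chat = GenerativeModel("gemini-pro").start_chat()
    response = chat.send_message("Hello!")
```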
7 changes: 7 additions & 0 deletions libs/vertexai/langchain_google_vertexai/_base.py
@@ -27,6 +27,7 @@
from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
from langchain_google_vertexai._utils import (
get_client_info,
get_user_agent,
is_codey_model,
is_gemini_model,
)
@@ -142,6 +143,12 @@ def _default_params(self) -> Dict[str, Any]:
)
return updated_params

@property
def _user_agent(self) -> str:
"""Gets the User Agent."""
_, user_agent = get_user_agent(f"{type(self).__name__}_{self.model_name}")
return user_agent

@classmethod
def _init_vertexai(cls, values: Dict) -> None:
vertexai.init(
2 changes: 1 addition & 1 deletion libs/vertexai/langchain_google_vertexai/_image_utils.py
@@ -7,7 +7,7 @@
from urllib.parse import urlparse

import requests
from google.cloud import storage # type: ignore[attr-defined]
from google.cloud import storage


class ImageBytesLoader:
30 changes: 23 additions & 7 deletions libs/vertexai/langchain_google_vertexai/_utils.py
@@ -3,7 +3,7 @@
import dataclasses
import re
from importlib import metadata
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import google.api_core
import proto # type: ignore[import-untyped]
@@ -45,7 +45,7 @@ def create_retry_decorator(
return decorator


def raise_vertex_import_error(minimum_expected_version: str = "1.38.0") -> None:
def raise_vertex_import_error(minimum_expected_version: str = "1.44.0") -> None:
"""Raise ImportError related to Vertex SDK being not available.
Args:
@@ -59,27 +59,43 @@ def raise_vertex_import_error(minimum_expected_version: str = "1.38.0") -> None:
)


def get_client_info(module: Optional[str] = None) -> "ClientInfo":
def get_user_agent(module: Optional[str] = None) -> Tuple[str, str]:
r"""Returns a custom user agent header.
Args:
module (Optional[str]):
Optional. The module for a custom user agent header.
Returns:
google.api_core.gapic_v1.client_info.ClientInfo
Tuple[str, str]
"""
langchain_version = metadata.version("langchain")
try:
langchain_version = metadata.version("langchain")
except metadata.PackageNotFoundError:
langchain_version = "0.0.0"
client_library_version = (
f"{langchain_version}-{module}" if module else langchain_version
)
return client_library_version, f"langchain/{client_library_version}"


def get_client_info(module: Optional[str] = None) -> "ClientInfo":
r"""Returns a client info object with a custom user agent header.
Args:
module (Optional[str]):
Optional. The module for a custom user agent header.
Returns:
google.api_core.gapic_v1.client_info.ClientInfo
"""
client_library_version, user_agent = get_user_agent(module)
return ClientInfo(
client_library_version=client_library_version,
user_agent=f"langchain/{client_library_version}",
user_agent=user_agent,
)


def load_image_from_gcs(path: str, project: Optional[str] = None) -> Image:
"""Loads im Image from GCS."""
"""Loads an Image from GCS."""
gcs_client = storage.Client(project=project)
pieces = path.split("/")
blobs = list(gcs_client.list_blobs(pieces[2], prefix="/".join(pieces[3:])))
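The `Tuple[str, str]` return value of `get_user_agent` above is easiest to see with a small usage sketch. The langchain version shown is illustrative; the real value is read from the installed package and falls back to "0.0.0" when `langchain` is not installed.

```python
from langchain_google_vertexai._utils import get_client_info, get_user_agent

version, user_agent = get_user_agent("ChatVertexAI")
# version    -> "0.1.11-ChatVertexAI"            (illustrative langchain version)
# user_agent -> "langchain/0.1.11-ChatVertexAI"

# get_client_info wraps the same pair into a gapic ClientInfo object.
client_info = get_client_info("ChatVertexAI")
assert client_info.client_library_version == version
assert client_info.user_agent == user_agent
```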
177 changes: 96 additions & 81 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -1,5 +1,6 @@
"""Wrapper around Google VertexAI chat-based models."""
from __future__ import annotations

from __future__ import annotations # noqa

import json
import logging
@@ -10,6 +11,8 @@
import proto # type: ignore[import-untyped]
from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart
from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall
from google.cloud.aiplatform import telemetry

from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@@ -250,7 +253,7 @@ def _get_question(messages: List[BaseMessage]) -> HumanMessage:
def _parse_response_candidate(response_candidate: "Candidate") -> AIMessage:
try:
content = response_candidate.text
except ValueError:
except AttributeError:
content = ""

additional_kwargs = {}
@@ -345,10 +348,11 @@ def _generate(
should_stream = stream if stream is not None else self.streaming
safety_settings = kwargs.pop("safety_settings", None)
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
with telemetry.tool_context_manager(self._user_agent):
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)

params = self._prepare_params(stop=stop, stream=False, **kwargs)
msg_params = {}
@@ -362,17 +366,19 @@ def _generate(
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
with telemetry.tool_context_manager(self._user_agent):
chat = self.client.start_chat(history=history_gemini)

# set param to `functions` until core tool/function calling implemented
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
response = chat.send_message(
message,
generation_config=params,
tools=tools,
safety_settings=safety_settings,
)
with telemetry.tool_context_manager(self._user_agent):
response = chat.send_message(
message,
generation_config=params,
tools=tools,
safety_settings=safety_settings,
)
generations = [
ChatGeneration(
message=_parse_response_candidate(candidate),
@@ -390,8 +396,9 @@
examples = kwargs.get("examples") or self.examples
if examples:
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, **params)
response = chat.send_message(question.content, **msg_params)
with telemetry.tool_context_manager(self._user_agent):
chat = self._start_chat(history, **params)
response = chat.send_message(question.content, **msg_params)
generations = [
ChatGeneration(
message=AIMessage(content=candidate.text),
@@ -443,16 +450,18 @@ async def _agenerate(
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
with telemetry.tool_context_manager(self._user_agent):
chat = self.client.start_chat(history=history_gemini)
# set param to `functions` until core tool/function calling implemented
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
response = await chat.send_message_async(
message,
generation_config=params,
tools=tools,
safety_settings=safety_settings,
)
with telemetry.tool_context_manager(self._user_agent):
response = await chat.send_message_async(
message,
generation_config=params,
tools=tools,
safety_settings=safety_settings,
)
generations = [
ChatGeneration(
message=_parse_response_candidate(c),
@@ -470,8 +479,9 @@ async def _agenerate(
examples = kwargs.get("examples", None) or self.examples
if examples:
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, **params)
response = await chat.send_message_async(question.content, **msg_params)
with telemetry.tool_context_manager(self._user_agent):
chat = self._start_chat(history, **params)
response = await chat.send_message_async(question.content, **msg_params)
generations = [
ChatGeneration(
message=AIMessage(content=r.text),
@@ -500,52 +510,55 @@ def _stream(
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
with telemetry.tool_context_manager(self._user_agent):
chat = self.client.start_chat(history=history_gemini)
# set param to `functions` until core tool/function calling implemented
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
safety_settings = params.pop("safety_settings", None)
responses = chat.send_message(
message,
stream=True,
generation_config=params,
safety_settings=safety_settings,
tools=tools,
)
for response in responses:
message = _parse_response_candidate(response.candidates[0])
if run_manager:
run_manager.on_llm_new_token(message.content)
yield ChatGenerationChunk(
message=AIMessageChunk(
content=message.content,
additional_kwargs=message.additional_kwargs,
),
generation_info=get_generation_info(
response.candidates[0],
self._is_gemini_model,
usage_metadata=response.to_dict().get("usage_metadata"),
),
with telemetry.tool_context_manager(self._user_agent):
responses = chat.send_message(
message,
stream=True,
generation_config=params,
safety_settings=safety_settings,
tools=tools,
)
for response in responses:
message = _parse_response_candidate(response.candidates[0])
if run_manager:
run_manager.on_llm_new_token(message.content)
yield ChatGenerationChunk(
message=AIMessageChunk(
content=message.content,
additional_kwargs=message.additional_kwargs,
),
generation_info=get_generation_info(
response.candidates[0],
self._is_gemini_model,
usage_metadata=response.to_dict().get("usage_metadata"),
),
)
else:
question = _get_question(messages)
history = _parse_chat_history(messages[:-1])
examples = kwargs.get("examples", None)
if examples:
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, **params)
responses = chat.send_message_streaming(question.content, **params)
for response in responses:
if run_manager:
run_manager.on_llm_new_token(response.text)
yield ChatGenerationChunk(
message=AIMessageChunk(content=response.text),
generation_info=get_generation_info(
response,
self._is_gemini_model,
usage_metadata=response.raw_prediction_response.metadata,
),
)
with telemetry.tool_context_manager(self._user_agent):
chat = self._start_chat(history, **params)
responses = chat.send_message_streaming(question.content, **params)
for response in responses:
if run_manager:
run_manager.on_llm_new_token(response.text)
yield ChatGenerationChunk(
message=AIMessageChunk(content=response.text),
generation_info=get_generation_info(
response,
self._is_gemini_model,
usage_metadata=response.raw_prediction_response.metadata,
),
)

async def _astream(
self,
@@ -563,31 +576,33 @@ async def _astream(
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
with telemetry.tool_context_manager(self._user_agent):
chat = self.client.start_chat(history=history_gemini)
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
safety_settings = params.pop("safety_settings", None)
async for chunk in await chat.send_message_async(
message,
stream=True,
generation_config=params,
safety_settings=safety_settings,
tools=tools,
):
message = _parse_response_candidate(chunk.candidates[0])
if run_manager:
await run_manager.on_llm_new_token(message.content)
yield ChatGenerationChunk(
message=AIMessageChunk(
content=message.content,
additional_kwargs=message.additional_kwargs,
),
generation_info=get_generation_info(
chunk.candidates[0],
self._is_gemini_model,
usage_metadata=chunk.to_dict().get("usage_metadata"),
),
)
with telemetry.tool_context_manager(self._user_agent):
async for chunk in await chat.send_message_async(
message,
stream=True,
generation_config=params,
safety_settings=safety_settings,
tools=tools,
):
message = _parse_response_candidate(chunk.candidates[0])
if run_manager:
await run_manager.on_llm_new_token(message.content)
yield ChatGenerationChunk(
message=AIMessageChunk(
content=message.content,
additional_kwargs=message.additional_kwargs,
),
generation_info=get_generation_info(
chunk.candidates[0],
self._is_gemini_model,
usage_metadata=chunk.to_dict().get("usage_metadata"),
),
)

def with_structured_output(
self,