Skip to content

Commit

Permalink
Add tool_context_manger to more places in Embeddings and VectorSearch
Browse files Browse the repository at this point in the history
  • Loading branch information
holtskinner committed Feb 26, 2024
1 parent 2993393 commit ad59603
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 15 deletions.
2 changes: 2 additions & 0 deletions libs/vertexai/langchain_google_vertexai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
from langchain_google_vertexai.chains import create_structured_runnable
from langchain_google_vertexai.chat_models import ChatVertexAI
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
from langchain_google_vertexai.gemma import (
GemmaChatLocalHF,
GemmaChatLocalKaggle,
GemmaChatVertexAIModelGarden,
GemmaLocalHF,
Expand Down
3 changes: 3 additions & 0 deletions libs/vertexai/langchain_google_vertexai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)

from langchain_google_vertexai._base import _VertexAICommon
from langchain_google_vertexai._utils import get_user_agent

logger = logging.getLogger(__name__)

Expand All @@ -47,6 +48,8 @@ def validate_environment(cls, values: Dict) -> Dict:
"textembedding-gecko@001"
)
values["model_name"] = "textembedding-gecko@001"
_, user_agent = get_user_agent(f"{cls.__name__}_{values['model_name']}")
with tool_context_manager(user_agent):
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
return values

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
MatchingEngineIndex,
MatchingEngineIndexEndpoint,
)
from google.cloud.aiplatform.telemetry import tool_context_manager
from google.oauth2.service_account import Credentials # type: ignore

if TYPE_CHECKING:
from google.cloud import datastore # type: ignore[attr-defined]

from langchain_google_vertexai._utils import get_client_info, get_user_agent


class VectorSearchSDKManager:
"""Class in charge of building all Google Cloud SDK Objects needed to build
Expand Down Expand Up @@ -60,7 +63,11 @@ def get_gcs_client(self) -> storage.Client:
Returns:
Google Cloud Storage Agent.
"""
return storage.Client(project=self._project_id, credentials=self._credentials)
return storage.Client(
project=self._project_id,
credentials=self._credentials,
client_info=get_client_info(module="vertex-ai-matching-engine"),
)

def get_gcs_bucket(self, bucket_name: str) -> storage.Bucket:
"""Retrieves a Google Cloud Bucket by bucket name.
Expand All @@ -79,12 +86,14 @@ def get_index(self, index_id: str) -> MatchingEngineIndex:
Returns:
MatchingEngineIndex instance.
"""
return MatchingEngineIndex(
index_name=index_id,
project=self._project_id,
location=self._region,
credentials=self._credentials,
)
_, user_agent = get_user_agent("vertex-ai-matching-engine")
with tool_context_manager(user_agent):
return MatchingEngineIndex(
index_name=index_id,
project=self._project_id,
location=self._region,
credentials=self._credentials,
)

def get_endpoint(self, endpoint_id: str) -> MatchingEngineIndexEndpoint:
"""Retrieves a MatchingEngineIndexEndpoint (VectorSearchIndexEndpoint) by id.
Expand All @@ -93,24 +102,29 @@ def get_endpoint(self, endpoint_id: str) -> MatchingEngineIndexEndpoint:
Returns:
MatchingEngineIndexEndpoint instance.
"""
return MatchingEngineIndexEndpoint(
index_endpoint_name=endpoint_id,
project=self._project_id,
location=self._region,
credentials=self._credentials,
)
_, user_agent = get_user_agent("vertex-ai-matching-engine")
with tool_context_manager(user_agent):
return MatchingEngineIndexEndpoint(
index_endpoint_name=endpoint_id,
project=self._project_id,
location=self._region,
credentials=self._credentials,
)

def get_datastore_client(self, **kwargs: Any) -> "datastore.Client":
"""Gets a datastore Client.
Args:
**kwargs: Keyword arguments to pass to datatastore.Client constructor.
**kwargs: Keyword arguments to pass to datastore.Client constructor.
Returns:
datastore Client.
"""
from google.cloud import datastore # type: ignore[attr-defined]

ds_client = datastore.Client(
project=self._project_id, credentials=self._credentials, **kwargs
project=self._project_id,
credentials=self._credentials,
client_info=get_client_info(module="vertex-ai-matching-engine"),
**kwargs,
)

return ds_client

0 comments on commit ad59603

Please sign in to comment.