feat: add psc support feature store (#632)
eliasecchig authored Dec 9, 2024
1 parent 019d578 commit 5c453e0
Showing 3 changed files with 118 additions and 70 deletions.
@@ -115,7 +115,7 @@ def _similarity_search_by_vectors_with_scores_and_embeddings(
filter: Optional[Dict[str, Any]] = None,
k: int = 5,
batch_size: Union[int, None] = None,
) -> list[list[list[Any]]]:
) -> List[List[List[Any]]]:
...

@model_validator(mode="after")
@@ -67,6 +67,31 @@ class VertexFSVectorStore(BaseBigQueryVectorStore):
crowding_column (str, optional): Column to use for crowding.
distance_measure_type (str, optional): Distance measure type (default:
DOT_PRODUCT_DISTANCE).
enable_private_service_connect (bool, optional): Whether to enable Private
Service Connect for the online store at creation time. Defaults to False.
transport (Optional[Union[str, FeatureOnlineStoreServiceTransport,
Callable[..., FeatureOnlineStoreServiceTransport]]]): Transport
configuration for API requests. Can be a transport instance, string
identifier, or callable that returns a transport.
Required when using Private Service Connect for querying. Example:
```python
import grpc
from google.cloud.aiplatform_v1.services.feature_online_store_service.\
transports.grpc import FeatureOnlineStoreServiceGrpcTransport
transport = FeatureOnlineStoreServiceGrpcTransport(
channel=grpc.insecure_channel("10.128.0.1:10002")
)
vertex_fs = VertexFSVectorStore(
transport=transport,
Your other params....
)
vertex_fs.similarity_search("My query")
```
project_allowlist (List[str], optional): Only needed when
`enable_private_service_connect` is set to true. List of projects allowed
to access the online store. Required at creation time.
Defaults to empty list.
"""

online_store_name: Union[str, None] = None
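Taken together, the parameters documented above enable an end-to-end PSC flow: create the online store with Private Service Connect enabled and a project allowlist, then query it through a gRPC transport bound to the private endpoint. Below is a minimal sketch of that flow, not taken from this commit; the IP address, project, dataset, and table names are placeholders, and `VertexAIEmbeddings` from `langchain_google_vertexai` is assumed as the embedding model.

```python
import grpc
from google.cloud.aiplatform_v1.services.feature_online_store_service.transports.grpc import (
    FeatureOnlineStoreServiceGrpcTransport,
)
from langchain_google_community import VertexFSVectorStore
from langchain_google_vertexai import VertexAIEmbeddings  # assumed embedding model

# gRPC transport bound to the PSC endpoint IP (placeholder address).
transport = FeatureOnlineStoreServiceGrpcTransport(
    channel=grpc.insecure_channel("10.128.0.1:10002")
)

vertex_fs = VertexFSVectorStore(
    project_id="my-project",                # placeholder
    location="us-central1",
    dataset_name="my_dataset",              # placeholder
    table_name="my_table",                  # placeholder
    embedding=VertexAIEmbeddings(model_name="text-embedding-005"),
    enable_private_service_connect=True,    # create the store with PSC enabled
    project_allowlist=["my-project"],       # required when PSC is enabled
    transport=transport,                    # required for querying over PSC
)
vertex_fs.similarity_search("My query")
```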
@@ -78,16 +103,23 @@ class VertexFSVectorStore(BaseBigQueryVectorStore):
crowding_column: Optional[str] = None
distance_measure_type: Optional[str] = None
online_store: Any = None
enable_private_service_connect: bool = False
transport: Any = None
project_allowlist: List[str] = []
_user_agent: str = ""
feature_view: Any = None
_admin_client: Any = None

@model_validator(mode="after")
def _initialize_bq_vector_index(self) -> Self:
import vertexai
from google.cloud.aiplatform_v1beta1 import (
from google.cloud.aiplatform_v1 import (
FeatureOnlineStoreAdminServiceClient,
FeatureOnlineStoreServiceClient,
)

# ruff: noqa: E501
from google.cloud.aiplatform_v1.services.feature_online_store_service.transports.base import (
FeatureOnlineStoreServiceTransport,
)
from vertexai.resources.preview.feature_store import (
utils, # type: ignore[import-untyped]
@@ -104,25 +136,25 @@ def _initialize_bq_vector_index(self) -> Self:
self.online_store_name = self.dataset_name
if self.view_name is None:
self.view_name = self.table_name
if self.transport:
if not isinstance(self.transport, FeatureOnlineStoreServiceTransport):
raise ValueError(
"Transport must be an instance of "
"FeatureOnlineStoreServiceTransport"
)

api_endpoint = f"{self.location}-aiplatform.googleapis.com"
self._admin_client = FeatureOnlineStoreAdminServiceClient(
client_options={"api_endpoint": api_endpoint},
client_info=get_client_info(module=self._user_agent),
)
self.online_store = _create_online_store(
project_id=self.project_id,
location=self.location,
online_store_name=self.online_store_name,
_admin_client=self._admin_client,
_logger=self._logger,
)
self.online_store = self._create_online_store()

gca_resource = self.online_store.gca_resource
endpoint = gca_resource.dedicated_serving_endpoint.public_endpoint_domain_name
self._search_client = FeatureOnlineStoreServiceClient(
client_options={"api_endpoint": endpoint},
client_info=get_client_info(module=self._user_agent),
public_endpoint = (
gca_resource.dedicated_serving_endpoint.public_endpoint_domain_name
)
self._search_client = self._get_search_client(public_endpoint=public_endpoint)
self.feature_view = _get_feature_view(self.online_store, self.view_name)

self._logger.info(
Expand All @@ -131,16 +163,22 @@ def _initialize_bq_vector_index(self) -> Self:
)
return self

def _init_store(self) -> None:
from google.cloud.aiplatform_v1beta1 import FeatureOnlineStoreServiceClient
def _get_search_client(self, public_endpoint: Optional[str] = None) -> Any:
from google.cloud.aiplatform_v1 import FeatureOnlineStoreServiceClient

return FeatureOnlineStoreServiceClient(
transport=self.transport,
client_options={"api_endpoint": public_endpoint},
client_info=get_client_info(module=self._user_agent),
)

def _init_store(self) -> None:
self.online_store = self._create_online_store()
gca_resource = self.online_store.gca_resource
endpoint = gca_resource.dedicated_serving_endpoint.public_endpoint_domain_name
self._search_client = FeatureOnlineStoreServiceClient(
client_options={"api_endpoint": endpoint},
client_info=get_client_info(module=self._user_agent),
public_endpoint = (
gca_resource.dedicated_serving_endpoint.public_endpoint_domain_name
)
self._search_client = self._get_search_client(public_endpoint=public_endpoint)
self.feature_view = self._get_feature_view()

def _validate_bq_existing_source(
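The new `_get_search_client` helper above routes both query paths through a single client factory: a caller-supplied transport when Private Service Connect is used, or the store's dedicated public serving endpoint otherwise. A simplified standalone sketch of that choice follows, assuming the `google-cloud-aiplatform` client library; `make_search_client` is a hypothetical helper name, not part of this change.

```python
from typing import Any, Optional

from google.cloud.aiplatform_v1 import FeatureOnlineStoreServiceClient


def make_search_client(
    transport: Any = None, public_endpoint: Optional[str] = None
) -> FeatureOnlineStoreServiceClient:
    """Hypothetical helper mirroring the two query paths of _get_search_client."""
    if transport is not None:
        # PSC path: the caller supplies a gRPC transport already bound to the
        # private endpoint, so the client uses it directly.
        return FeatureOnlineStoreServiceClient(transport=transport)
    # Public path: point the client at the store's dedicated public endpoint.
    return FeatureOnlineStoreServiceClient(
        client_options={"api_endpoint": public_endpoint}
    )
```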
@@ -218,12 +256,11 @@ def sync_data(self) -> None:

self._wait_until_dummy_query_success()

def _similarity_search_by_vectors_with_scores_and_embeddings(
def _similarity_search_by_vectors_with_scores_and_embeddings( # type: ignore[override]
self,
embeddings: List[List[float]],
filter: Optional[Dict[str, Any]] = None,
k: int = 5,
batch_size: Union[int, None] = None,
**kwargs: Any,
) -> List[List[List[Any]]]:
"""Performs a similarity search using vector embeddings
@@ -242,7 +279,6 @@ def _similarity_search_by_vectors_with_scores_and_embeddings(
"int_property": 123
}
k: The number of top results to return for each query.
batch_size: The size of batches to process embeddings.
Returns:
A list of lists of lists. Each inner list represents the results for a
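For reference, the filter format documented in this docstring maps column names to required values. A minimal usage sketch follows, assuming `filter` and `k` are forwarded from the public `similarity_search` method to this internal helper (not confirmed by this diff); `vertex_fs` refers to the store constructed in the earlier example.

```python
# Hypothetical query using the documented filter format; assumes the public
# similarity_search method forwards `filter` and `k` to the internal search.
docs = vertex_fs.similarity_search(
    "My query",
    k=5,
    filter={"str_property": "string_value", "int_property": 123},
)
```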
@@ -390,7 +426,7 @@ def _search_embedding(
leaf_nodes_search_fraction: Optional[float] = None,
) -> MutableSequence[Any]:
from google.cloud import aiplatform
from google.cloud.aiplatform_v1beta1.types import (
from google.cloud.aiplatform_v1.types import (
NearestNeighborQuery,
feature_online_store_service,
)
@@ -420,14 +456,36 @@ def _search_embedding(

def _create_online_store(self) -> Any:
# Search for existing Online store
import vertexai

stores_list = vertexai.resources.preview.FeatureOnlineStore.list(
project=self.project_id,
location=self.location,
)
for store in stores_list:
if store.name == self.online_store_name:
return store

# Create it otherwise
if self.online_store_name:
return _create_online_store(
project_id=self.project_id,
fos = vertexai.resources.preview.FeatureOnlineStore.create_optimized_store(
project=self.project_id,
location=self.location,
online_store_name=self.online_store_name,
_admin_client=self._admin_client,
_logger=self._logger,
name=self.online_store_name,
enable_private_service_connect=self.enable_private_service_connect,
project_allowlist=self.project_allowlist,
credentials=self.credentials,
)
if self.enable_private_service_connect:
self._logger.info(
"Optimized Store created with Private Service Connect Enabled. "
"Please follow instructions in "
"https://cloud.google.com/vertex-ai/docs/featurestore/latest/"
"serve-feature-values#optimized_serving_private to setup PSC. "
"Note that Service attachment string will be available after "
"the first feature view creation and sync."
)
return fos

def _create_feature_view(self) -> Any:
import vertexai
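The log message in `_create_online_store` above notes that the PSC service attachment string only becomes available after the first feature view has been created and synced. A hedged sketch of how it might be read afterwards; the `service_attachment` attribute path is an assumption based on the `aiplatform_v1` `FeatureOnlineStore` proto and should be verified against your client library version.

```python
def get_psc_service_attachment(vertex_fs: "VertexFSVectorStore") -> str:
    """Best-effort lookup of the PSC service attachment string.

    Assumes the online store was created with enable_private_service_connect=True
    and that at least one feature view sync has completed; the attribute path is
    an assumption, not confirmed by this commit.
    """
    gca_resource = vertex_fs.online_store.gca_resource
    return gca_resource.dedicated_serving_endpoint.service_attachment
```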
@@ -523,47 +581,6 @@ def to_bq_vector_store(self, **kwargs: Any) -> Any:
return bq_obj


def _create_online_store(
project_id: str,
location: str,
online_store_name: str,
_logger: Any,
_admin_client: Any,
) -> Any:
# Search for existing Online store
import vertexai
from google.cloud.aiplatform_v1beta1.types import (
feature_online_store as feature_online_store_pb2,
)

stores_list = vertexai.resources.preview.FeatureOnlineStore.list(
project=project_id, location=location
)
for store in stores_list:
if store.name == online_store_name:
return store

_logger.info("Creating feature store online store")
# Create it otherwise

online_store_config = feature_online_store_pb2.FeatureOnlineStore(
optimized=feature_online_store_pb2.FeatureOnlineStore.Optimized()
)
create_store_lro = _admin_client.create_feature_online_store(
parent=f"projects/{project_id}/locations/{location}",
feature_online_store_id=online_store_name,
feature_online_store=online_store_config,
)
_logger.info(create_store_lro.result())
_logger.info(create_store_lro.result())
stores_list = vertexai.resources.preview.FeatureOnlineStore.list(
project=project_id, location=location
)
for store in stores_list:
if store.name == online_store_name:
return store


def _get_feature_view(online_store: Any, view_name: Optional[str]) -> Any:
# Search for existing Feature view
import vertexai
@@ -5,6 +5,7 @@
import os
import random

import grpc
import pytest

from langchain_google_community import VertexFSVectorStore
@@ -127,3 +128,33 @@ def test_to_bq_vector_store(
"""Test getter feature store vectorstore"""
new_store = store_fs_vectorstore.to_bq_vector_store()
assert new_store.dataset_name == TEST_DATASET


@pytest.mark.extended
def test_psc_feature_store() -> None:
"""Test creation of feature store with private service connect enabled"""
# ruff: noqa: E501
from google.cloud.aiplatform_v1.services.feature_online_store_service.transports.grpc import (
FeatureOnlineStoreServiceGrpcTransport,
)

embedding_model = FakeEmbeddings(size=EMBEDDING_SIZE)
project_id = os.environ.get("PROJECT_ID", None)

transport = FeatureOnlineStoreServiceGrpcTransport(
channel=grpc.insecure_channel("dummy:10002")
)
try:
vertex_fs = VertexFSVectorStore(
project_id=project_id, # type: ignore[arg-type]
location="us-central1",
dataset_name=TEST_DATASET + f"_psc_{str(random.randint(1,100000))}",
table_name=TEST_TABLE_NAME,
embedding=embedding_model,
enable_private_service_connect=True,
project_allowlist=[project_id], # type: ignore[list-item]
transport=transport,
)
finally:
# Clean up resources
vertex_fs.online_store.delete()
