From cf381490aecc6b432f597f858637041d0647a81d Mon Sep 17 00:00:00 2001 From: Theodor Mihalache Date: Tue, 4 Jun 2024 14:03:51 -0400 Subject: [PATCH] Added missing remote offline store apis implementation Signed-off-by: Theodor Mihalache --- sdk/python/feast/feature_store.py | 5 + .../feast/infra/offline_stores/remote.py | 357 ++++++++++++++---- sdk/python/feast/offline_server.py | 216 ++++++++--- sdk/python/feast/templates/local/bootstrap.py | 1 + .../local/feature_repo/example_repo.py | 5 + .../offline_stores/test_offline_store.py | 34 +- sdk/python/tests/unit/test_offline_server.py | 83 ++-- 7 files changed, 537 insertions(+), 164 deletions(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index db21b44283b..e7cb2f97902 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1121,6 +1121,11 @@ def create_saved_dataset( dataset.min_event_timestamp = from_.metadata.min_event_timestamp dataset.max_event_timestamp = from_.metadata.max_event_timestamp + # Added to handle the case that the offline store is remote + self._registry.apply_data_source( + storage.to_data_source(), self.project, commit=True + ) + from_.persist(storage=storage, allow_overwrite=allow_overwrite) dataset = dataset.with_retrieval_job( diff --git a/sdk/python/feast/infra/offline_stores/remote.py b/sdk/python/feast/infra/offline_stores/remote.py index a4bce15c3cf..5be1ed50a58 100644 --- a/sdk/python/feast/infra/offline_stores/remote.py +++ b/sdk/python/feast/infra/offline_stores/remote.py @@ -3,8 +3,9 @@ import uuid from datetime import datetime from pathlib import Path -from typing import Any, Callable, List, Literal, Optional, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +import numpy as np import pandas as pd import pyarrow as pa import pyarrow.flight as fl @@ -13,14 +14,21 @@ from feast import OnDemandFeatureView from feast.data_source import DataSource -from feast.feature_logging import LoggingConfig, LoggingSource +from feast.feature_logging import ( + FeatureServiceLoggingSource, + LoggingConfig, + LoggingSource, +) from feast.feature_view import FeatureView +from feast.infra.offline_stores import offline_utils from feast.infra.offline_stores.offline_store import ( OfflineStore, RetrievalJob, + RetrievalMetadata, ) from feast.infra.registry.base_registry import BaseRegistry from feast.repo_config import FeastConfigBaseModel, RepoConfig +from feast.saved_dataset import SavedDatasetStorage logger = logging.getLogger(__name__) @@ -38,38 +46,19 @@ class RemoteRetrievalJob(RetrievalJob): def __init__( self, client: fl.FlightClient, - feature_view_names: List[str], - name_aliases: List[Optional[str]], - feature_refs: List[str], - entity_df: Union[pd.DataFrame, str], - project: str, - full_feature_names: bool = False, + api: str, + api_parameters: Dict[str, Any], + entity_df: Union[pd.DataFrame, str] = None, + table: pa.Table = None, + metadata: Optional[RetrievalMetadata] = None, ): # Initialize the client connection self.client = client - self.feature_view_names = feature_view_names - self.name_aliases = name_aliases - self.feature_refs = feature_refs + self.api = api + self.api_parameters = api_parameters self.entity_df = entity_df - self.project = project - self._full_feature_names = full_feature_names - - @property - def full_feature_names(self) -> bool: - return self._full_feature_names - - # TODO add one specialized implementation for each OfflineStore API - # This can result in a dictionary of functions 
indexed by api (e.g., "get_historical_features") - def _put_parameters(self, command_descriptor): - entity_df_table = pa.Table.from_pandas(self.entity_df) - - writer, _ = self.client.do_put( - command_descriptor, - entity_df_table.schema, - ) - - writer.write_table(entity_df_table) - writer.close() + self.table = table + self._metadata = metadata # Invoked to realize the Pandas DataFrame def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame: @@ -79,34 +68,51 @@ def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame: # Invoked to synchronously execute the underlying query and return the result as an arrow table # This is where do_get service is invoked def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table: - # Generate unique command identifier - command_id = str(uuid.uuid4()) - command = { - "command_id": command_id, - "api": "get_historical_features", - "feature_view_names": self.feature_view_names, - "name_aliases": self.name_aliases, - "feature_refs": self.feature_refs, - "project": self.project, - "full_feature_names": self._full_feature_names, - } - command_descriptor = fl.FlightDescriptor.for_command( - json.dumps( - command, - ) + return _send_retrieve_remote( + self.api, self.api_parameters, self.entity_df, self.table, self.client ) - self._put_parameters(command_descriptor) - flight = self.client.get_flight_info(command_descriptor) - ticket = flight.endpoints[0].ticket - - reader = self.client.do_get(ticket) - return reader.read_all() - @property def on_demand_feature_views(self) -> List[OnDemandFeatureView]: return [] + @property + def metadata(self) -> Optional[RetrievalMetadata]: + return self._metadata + + @property + def full_feature_names(self) -> bool: + logger.info(f"{self.api_parameters}") + return self.api_parameters["full_feature_names"] + + def persist( + self, + storage: SavedDatasetStorage, + allow_overwrite: bool = False, + timeout: Optional[int] = None, + ): + api_parameters = { + "data_source_name": storage.to_data_source().name, + "allow_overwrite": allow_overwrite, + "timeout": timeout, + } + + # Add api parameters to command + for key, value in self.api_parameters.items(): + api_parameters[key] = value + + command_descriptor = _call_put( + api=self.api, + api_parameters=api_parameters, + client=self.client, + table=self.table, + entity_df=self.entity_df, + ) + bytes = command_descriptor.serialize() + + self.client.do_action(pa.flight.Action("persist", bytes)) + logger.info("end of persist()") + class RemoteOfflineStore(OfflineStore): @staticmethod @@ -121,8 +127,6 @@ def get_historical_features( ) -> RemoteRetrievalJob: assert isinstance(config.offline_store, RemoteOfflineStoreConfig) - # TODO: extend RemoteRetrievalJob API with all method parameters - # Initialize the client connection location = f"grpc://{config.offline_store.host}:{config.offline_store.port}" client = fl.connect(location=location) @@ -130,14 +134,23 @@ def get_historical_features( feature_view_names = [fv.name for fv in feature_views] name_aliases = [fv.projection.name_alias for fv in feature_views] + + api_parameters = { + "feature_view_names": feature_view_names, + "feature_refs": feature_refs, + "project": project, + "full_feature_names": full_feature_names, + "name_aliases": name_aliases, + } + return RemoteRetrievalJob( client=client, - feature_view_names=feature_view_names, - name_aliases=name_aliases, - feature_refs=feature_refs, + api="get_historical_features", + api_parameters=api_parameters, entity_df=entity_df, - 
project=project, - full_feature_names=full_feature_names, + metadata=_create_retrieval_metadata( + feature_refs, entity_df + ), ) @staticmethod @@ -150,8 +163,27 @@ def pull_all_from_table_or_query( start_date: datetime, end_date: datetime, ) -> RetrievalJob: - # TODO Implementation here. - raise NotImplementedError + assert isinstance(config.offline_store, RemoteOfflineStoreConfig) + + # Initialize the client connection + location = f"grpc://{config.offline_store.host}:{config.offline_store.port}" + client = fl.connect(location=location) + logger.info(f"Connecting FlightClient at {location}") + + api_parameters = { + "data_source_name": data_source.name, + "join_key_columns": join_key_columns, + "feature_name_columns": feature_name_columns, + "timestamp_field": timestamp_field, + "start_date": start_date.isoformat(), + "end_date": end_date.isoformat(), + } + + return RemoteRetrievalJob( + client=client, + api="pull_all_from_table_or_query", + api_parameters=api_parameters, + ) @staticmethod def pull_latest_from_table_or_query( @@ -164,8 +196,28 @@ def pull_latest_from_table_or_query( start_date: datetime, end_date: datetime, ) -> RetrievalJob: - # TODO Implementation here. - raise NotImplementedError + assert isinstance(config.offline_store, RemoteOfflineStoreConfig) + + # Initialize the client connection + location = f"grpc://{config.offline_store.host}:{config.offline_store.port}" + client = fl.connect(location=location) + logger.info(f"Connecting FlightClient at {location}") + + api_parameters = { + "data_source_name": data_source.name, + "join_key_columns": join_key_columns, + "feature_name_columns": feature_name_columns, + "timestamp_field": timestamp_field, + "created_timestamp_column": created_timestamp_column, + "start_date": start_date.isoformat(), + "end_date": end_date.isoformat(), + } + + return RemoteRetrievalJob( + client=client, + api="pull_latest_from_table_or_query", + api_parameters=api_parameters, + ) @staticmethod def write_logged_features( @@ -175,8 +227,31 @@ def write_logged_features( logging_config: LoggingConfig, registry: BaseRegistry, ): - # TODO Implementation here. - raise NotImplementedError + assert isinstance(config.offline_store, RemoteOfflineStoreConfig) + assert isinstance(source, FeatureServiceLoggingSource) + + if isinstance(data, Path): + data = pyarrow.parquet.read_table(data, use_threads=False, pre_buffer=False) + + # Initialize the client connection + location = f"grpc://{config.offline_store.host}:{config.offline_store.port}" + client = fl.connect(location=location) + logger.info(f"Connecting FlightClient at {location}") + + api_parameters = { + "feature_service_name": source._feature_service.name, + } + + command_descriptor = _call_put( + api="write_logged_features", + api_parameters=api_parameters, + client=client, + table=data, + entity_df=None, + ) + bytes = command_descriptor.serialize() + + client.do_action(pa.flight.Action("write_logged_features", bytes)) @staticmethod def offline_write_batch( @@ -185,5 +260,153 @@ def offline_write_batch( table: pyarrow.Table, progress: Optional[Callable[[int], Any]], ): - # TODO Implementation here. 
- raise NotImplementedError + assert isinstance(config.offline_store, RemoteOfflineStoreConfig) + + # Initialize the client connection + location = f"grpc://{config.offline_store.host}:{config.offline_store.port}" + client = fl.connect(location=location) + logger.info(f"Connecting FlightClient at {location}") + + feature_view_names = [feature_view.name] + name_aliases = [feature_view.projection.name_alias] + + api_parameters = { + "feature_view_names": feature_view_names, + "progress": progress, + "name_aliases": name_aliases, + } + + command_descriptor = _call_put( + api="offline_write_batch", + api_parameters=api_parameters, + client=client, + table=table, + entity_df=None, + ) + bytes = command_descriptor.serialize() + + client.do_action(pa.flight.Action("offline_write_batch", bytes)) + +def _create_retrieval_metadata( + feature_refs: List[str], entity_df: pd.DataFrame +): + entity_schema = _get_entity_schema( + entity_df=entity_df, + ) + + event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( + entity_schema=entity_schema, + ) + + timestamp_range = _get_entity_df_event_timestamp_range( + entity_df, event_timestamp_col + ) + + return RetrievalMetadata( + features=feature_refs, + keys=list(set(entity_df.columns) - {event_timestamp_col}), + min_event_timestamp=timestamp_range[0], + max_event_timestamp=timestamp_range[1], + ) + +def _get_entity_schema(entity_df: pd.DataFrame) -> Dict[str, np.dtype]: + return dict(zip(entity_df.columns, entity_df.dtypes)) + +def _get_entity_df_event_timestamp_range( + entity_df: Union[pd.DataFrame, str], + entity_df_event_timestamp_col: str, +) -> Tuple[datetime, datetime]: + if not isinstance(entity_df, pd.DataFrame): + raise ValueError( + f"Please provide an entity_df of type {type(pd.DataFrame)} instead of type {type(entity_df)}" + ) + + entity_df_event_timestamp = entity_df.loc[ + :, entity_df_event_timestamp_col + ].infer_objects() + if pd.api.types.is_string_dtype(entity_df_event_timestamp): + entity_df_event_timestamp = pd.to_datetime( + entity_df_event_timestamp, utc=True + ) + + return ( + entity_df_event_timestamp.min().to_pydatetime(), + entity_df_event_timestamp.max().to_pydatetime(), + ) + + +def _send_retrieve_remote( + api: str, + api_parameters: Dict[str, Any], + entity_df: Union[pd.DataFrame, str], + table: pa.Table, + client: fl.FlightClient, +): + command_descriptor = _call_put(api, api_parameters, client, entity_df, table) + return _call_get(client, command_descriptor) + + +def _call_get(client, command_descriptor): + flight = client.get_flight_info(command_descriptor) + ticket = flight.endpoints[0].ticket + reader = client.do_get(ticket) + return reader.read_all() + + +def _call_put(api, api_parameters, client, entity_df, table): + # Generate unique command identifier + command_id = str(uuid.uuid4()) + command = { + "command_id": command_id, + "api": api, + } + # Add api parameters to command + for key, value in api_parameters.items(): + command[key] = value + + command_descriptor = fl.FlightDescriptor.for_command( + json.dumps( + command, + ) + ) + + _put_parameters(command_descriptor, entity_df, table, client) + return command_descriptor + + +def _put_parameters( + command_descriptor, + entity_df: Union[pd.DataFrame, str], + table: pa.Table, + client: fl.FlightClient, +): + updatedTable: pa.Table + + if entity_df is not None: + updatedTable = pa.Table.from_pandas(entity_df) + elif table is not None: + updatedTable = table + else: + updatedTable = _create_empty_table() + + writer, _ = client.do_put( + 
command_descriptor, + updatedTable.schema, + ) + + writer.write_table(updatedTable) + writer.close() + + +def _create_empty_table(): + schema = pa.schema( + { + "key": pa.string(), + } + ) + + keys = ["mock_key"] + + table = pa.Table.from_pydict(dict(zip(schema.names, keys)), schema=schema) + + return table diff --git a/sdk/python/feast/offline_server.py b/sdk/python/feast/offline_server.py index ebd83092c95..3876604e1b9 100644 --- a/sdk/python/feast/offline_server.py +++ b/sdk/python/feast/offline_server.py @@ -2,13 +2,17 @@ import json import logging import traceback +from datetime import datetime from typing import Any, Dict, List import pyarrow as pa import pyarrow.flight as fl -from feast import FeatureStore, FeatureView +from feast import FeatureStore, FeatureView, utils +from feast.feature_logging import FeatureServiceLoggingSource from feast.feature_view import DUMMY_ENTITY_NAME +from feast.infra.offline_stores.offline_utils import get_offline_store_from_config +from feast.saved_dataset import SavedDatasetStorage logger = logging.getLogger(__name__) @@ -20,6 +24,7 @@ def __init__(self, store: FeatureStore, location: str, **kwargs): # A dictionary of configured flights, e.g. API calls received and not yet served self.flights: Dict[str, Any] = {} self.store = store + self.offline_store = get_offline_store_from_config(store.config.offline_store) @classmethod def descriptor_to_key(self, descriptor): @@ -127,69 +132,188 @@ def do_get(self, context, ticket): logger.debug(f"get command is {command}") logger.debug(f"requested api is {api}") if api == "get_historical_features": - # Extract parameters from the internal flights dictionary - entity_df_value = self.flights[key] - entity_df = pa.Table.to_pandas(entity_df_value) - logger.debug(f"do_get: entity_df is {entity_df}") + training_df = self.get_historical_features(command, key) + elif api == "pull_all_from_table_or_query": + training_df = self.pull_all_from_table_or_query(command) + elif api == "pull_latest_from_table_or_query": + training_df = self.pull_latest_from_table_or_query(command) + else: + raise NotImplementedError + + table = pa.Table.from_pandas(training_df) + # Get service is consumed, so we clear the corresponding flight and data + del self.flights[key] + return fl.RecordBatchStream(table) + + def offline_write_batch(self, command, key): + try: feature_view_names = command["feature_view_names"] - logger.debug(f"do_get: feature_view_names is {feature_view_names}") name_aliases = command["name_aliases"] - logger.debug(f"do_get: name_aliases is {name_aliases}") - feature_refs = command["feature_refs"] - logger.debug(f"do_get: feature_refs is {feature_refs}") - project = command["project"] - logger.debug(f"do_get: project is {project}") - full_feature_names = command["full_feature_names"] + project = self.store.config.project feature_views = self.list_feature_views_by_name( feature_view_names=feature_view_names, name_aliases=name_aliases, project=project, ) - logger.debug(f"do_get: feature_views is {feature_views}") - logger.info( - f"get_historical_features for: entity_df from {entity_df.index[0]} to {entity_df.index[len(entity_df)-1]}, " - f"feature_views is {[(fv.name, fv.entities) for fv in feature_views]}" - f"feature_refs is {feature_refs}" + table = self.flights[key] + self.offline_store.offline_write_batch( + self.store.config, feature_views[0], table, command["progress"] ) + except Exception as e: + logger.exception(e) + traceback.print_exc() + raise e - try: - training_df = ( - self.store._get_provider() - 
.get_historical_features( - config=self.store.config, - feature_views=feature_views, - feature_refs=feature_refs, - entity_df=entity_df, - registry=self.store._registry, - project=project, - full_feature_names=full_feature_names, - ) - .to_df() - ) - logger.debug(f"Len of training_df is {len(training_df)}") - table = pa.Table.from_pandas(training_df) - except Exception as e: - logger.exception(e) - traceback.print_exc() - raise e + def write_logged_features(self, command, key): + try: + table = self.flights[key] + feature_service = self.store.get_feature_service( + command["feature_service_name"] + ) - # Get service is consumed, so we clear the corresponding flight and data - del self.flights[key] + self.offline_store.write_logged_features( + config=self.store.config, + data=table, + source=FeatureServiceLoggingSource( + feature_service, self.store.config.project + ), + logging_config=feature_service.logging_config, + registry=self.store.registry, + ) + except Exception as e: + logger.exception(e) + traceback.print_exc() + raise e - return fl.RecordBatchStream(table) - else: - raise NotImplementedError + def pull_all_from_table_or_query(self, command): + try: + return self._pull_all_from_table_or_query(command).to_df() + except Exception as e: + logger.exception(e) + traceback.print_exc() + raise e + + def _pull_all_from_table_or_query(self, command): + return self.offline_store.pull_all_from_table_or_query( + self.store.config, + self.store.get_data_source(command["data_source_name"]), + command["join_key_columns"], + command["feature_name_columns"], + command["timestamp_field"], + utils.make_tzaware(datetime.fromisoformat(command["start_date"])), + utils.make_tzaware(datetime.fromisoformat(command["end_date"])), + ) + + def pull_latest_from_table_or_query(self, command): + try: + return self._pull_latest_from_table_or_query(command).to_df() + except Exception as e: + logger.exception(e) + traceback.print_exc() + raise e + + def _pull_latest_from_table_or_query(self, command): + return self.offline_store.pull_latest_from_table_or_query( + self.store.config, + self.store.get_data_source(command["data_source_name"]), + command["join_key_columns"], + command["feature_name_columns"], + command["timestamp_field"], + command["created_timestamp_column"], + utils.make_tzaware(datetime.fromisoformat(command["start_date"])), + utils.make_tzaware(datetime.fromisoformat(command["end_date"])), + ) def list_actions(self, context): - return [] + return [ + ( + "offline_write_batch", + "Writes the specified arrow table to the data source underlying the specified feature view.", + ), + ( + "write_logged_features", + "Writes logged features to a specified destination in the offline store.", + ), + ( + "persist", + "Synchronously executes the underlying query and persists the result in the same offline store at the " + "specified destination.", + ), + ] + + def get_historical_features(self, command, key): + try: + ret_job = self._get_historical_features(command, key) + training_df = ret_job.to_df() + + return training_df + except Exception as e: + logger.exception(e) + traceback.print_exc() + raise e + + def _get_historical_features(self, command, key): + # Extract parameters from the internal flights dictionary + entity_df_value = self.flights[key] + entity_df = pa.Table.to_pandas(entity_df_value) + feature_view_names = command["feature_view_names"] + name_aliases = command["name_aliases"] + feature_refs = command["feature_refs"] + project = command["project"] + full_feature_names = 
command["full_feature_names"] + feature_views = self.list_feature_views_by_name( + feature_view_names=feature_view_names, + name_aliases=name_aliases, + project=project, + ) + retJob = self.store._get_provider().get_historical_features( + config=self.store.config, + feature_views=feature_views, + feature_refs=feature_refs, + entity_df=entity_df, + registry=self.store.registry, + project=project, + full_feature_names=full_feature_names, + ) + return retJob + + def persist(self, command, key): + try: + api = command["api"] + if api == "get_historical_features": + ret_job = self._get_historical_features(command, key) + elif api == "pull_latest_from_table_or_query": + ret_job = self._pull_latest_from_table_or_query(command) + elif api == "pull_all_from_table_or_query": + ret_job = self._pull_all_from_table_or_query(command) + else: + raise NotImplementedError + + data_source = self.store.get_data_source(command["data_source_name"]) + storage = SavedDatasetStorage.from_data_source(data_source) + ret_job.persist(storage, command["allow_overwrite"], command["timeout"]) + except Exception as e: + logger.exception(e) + traceback.print_exc() + raise e def do_action(self, context, action): - raise NotImplementedError + command_descriptor = fl.FlightDescriptor.deserialize(action.body.to_pybytes()) - def do_drop_dataset(self, dataset): - pass + key = OfflineServer.descriptor_to_key(command_descriptor) + command = json.loads(key[1]) + logger.info(f"do_action command is {command}") + + if action.type == "offline_write_batch": + self.offline_write_batch(command, key) + elif action.type == "write_logged_features": + self.write_logged_features(command, key) + elif action.type == "persist": + self.persist(command, key) + else: + raise NotImplementedError def remove_dummies(fv: FeatureView) -> FeatureView: diff --git a/sdk/python/feast/templates/local/bootstrap.py b/sdk/python/feast/templates/local/bootstrap.py index 125eb7c2e72..ee2847c19c8 100644 --- a/sdk/python/feast/templates/local/bootstrap.py +++ b/sdk/python/feast/templates/local/bootstrap.py @@ -24,6 +24,7 @@ def bootstrap(): example_py_file = repo_path / "example_repo.py" replace_str_in_file(example_py_file, "%PARQUET_PATH%", str(driver_stats_path)) + replace_str_in_file(example_py_file, "%LOGGING_PATH%", str(data_path)) if __name__ == "__main__": diff --git a/sdk/python/feast/templates/local/feature_repo/example_repo.py b/sdk/python/feast/templates/local/feature_repo/example_repo.py index 5aed3371b14..debe9d45e92 100644 --- a/sdk/python/feast/templates/local/feature_repo/example_repo.py +++ b/sdk/python/feast/templates/local/feature_repo/example_repo.py @@ -13,6 +13,8 @@ PushSource, RequestSource, ) +from feast.feature_logging import LoggingConfig +from feast.infra.offline_stores.file_source import FileLoggingDestination from feast.on_demand_feature_view import on_demand_feature_view from feast.types import Float32, Float64, Int64 @@ -88,6 +90,9 @@ def transformed_conv_rate(inputs: pd.DataFrame) -> pd.DataFrame: driver_stats_fv[["conv_rate"]], # Sub-selects a feature from a feature view transformed_conv_rate, # Selects all features from the feature view ], + logging_config=LoggingConfig( + destination=FileLoggingDestination(path="%LOGGING_PATH%") + ), ) driver_activity_v2 = FeatureService( name="driver_activity_v2", features=[driver_stats_fv, transformed_conv_rate] diff --git a/sdk/python/tests/unit/infra/offline_stores/test_offline_store.py b/sdk/python/tests/unit/infra/offline_stores/test_offline_store.py index 956951f780c..fd50d376322 100644 --- 
a/sdk/python/tests/unit/infra/offline_stores/test_offline_store.py +++ b/sdk/python/tests/unit/infra/offline_stores/test_offline_store.py @@ -215,24 +215,26 @@ def retrieval_job(request, environment): port=0, ) environment.config._offline_store = offline_store_config + + entity_df = pd.DataFrame.from_dict( + { + "id": [1], + "event_timestamp": ["datetime"], + "val_to_add": [1], + } + ) + return RemoteRetrievalJob( client=MagicMock(), - feature_refs=[ - "str:str", - ], - feature_view_names=[ - "str:str", - ], - name_aliases=[ - "str:str", - ], - project="project", - entity_df=pd.DataFrame.from_dict( - { - "id": [1], - "event_timestamp": ["datetime"], - "val_to_add": [1], - } + api_parameters={ + "str": "str", + }, + api="api", + table=pyarrow.Table.from_pandas(entity_df), + entity_df=entity_df, + metadata=RetrievalMetadata( + features=["1", "2", "3", "4"], + keys=["1", "2", "3", "4"], ), ) else: diff --git a/sdk/python/tests/unit/test_offline_server.py b/sdk/python/tests/unit/test_offline_server.py index 9d5960821c5..5991e7450d1 100644 --- a/sdk/python/tests/unit/test_offline_server.py +++ b/sdk/python/tests/unit/test_offline_server.py @@ -9,6 +9,7 @@ import pytest from feast import FeatureStore +from feast.feature_logging import FeatureServiceLoggingSource from feast.infra.offline_stores.remote import ( RemoteOfflineStore, RemoteOfflineStoreConfig, @@ -43,7 +44,23 @@ def test_offline_server_is_alive(environment, empty_offline_server, arrow_client actions = list(client.list_actions()) flights = list(client.list_flights()) - assertpy.assert_that(actions).is_empty() + assertpy.assert_that(actions).is_equal_to( + [ + ( + "offline_write_batch", + "Writes the specified arrow table to the data source underlying the specified feature view.", + ), + ( + "write_logged_features", + "Writes logged features to a specified destination in the offline store.", + ), + ( + "persist", + "Synchronously executes the underlying query and persists the result in the same offline store at the " + "specified destination.", + ), + ] + ) assertpy.assert_that(flights).is_empty() @@ -81,7 +98,7 @@ def remote_feature_store(offline_server): return store -def test_get_historical_features(): +def test_remote_offline_store_apis(): with tempfile.TemporaryDirectory() as temp_dir: store = default_store(str(temp_dir)) location = "grpc+tcp://localhost:0" @@ -179,10 +196,9 @@ def _test_offline_write_batch(temp_dir, fs: FeatureStore): data_df = pd.read_parquet(data_file) feature_view = fs.get_feature_view("driver_hourly_stats") - with pytest.raises(NotImplementedError): - RemoteOfflineStore.offline_write_batch( - fs.config, feature_view, pa.Table.from_pandas(data_df), progress=None - ) + RemoteOfflineStore.offline_write_batch( + fs.config, feature_view, pa.Table.from_pandas(data_df), progress=None + ) def _test_write_logged_features(temp_dir, fs: FeatureStore): @@ -192,14 +208,13 @@ def _test_write_logged_features(temp_dir, fs: FeatureStore): data_df = pd.read_parquet(data_file) feature_service = fs.get_feature_service("driver_activity_v1") - with pytest.raises(NotImplementedError): - RemoteOfflineStore.write_logged_features( - config=fs.config, - data=pa.Table.from_pandas(data_df), - source=feature_service, - logging_config=None, - registry=fs.registry, - ) + RemoteOfflineStore.write_logged_features( + config=fs.config, + data=pa.Table.from_pandas(data_df), + source=FeatureServiceLoggingSource(feature_service, fs.config.project), + logging_config=feature_service.logging_config, + registry=fs.registry, + ) def 
_test_pull_latest_from_table_or_query(temp_dir, fs: FeatureStore): @@ -207,17 +222,16 @@ def _test_pull_latest_from_table_or_query(temp_dir, fs: FeatureStore): end_date = datetime.now().replace(microsecond=0, second=0, minute=0) start_date = end_date - timedelta(days=15) - with pytest.raises(NotImplementedError): - RemoteOfflineStore.pull_latest_from_table_or_query( - config=fs.config, - data_source=data_source, - join_key_columns=[], - feature_name_columns=[], - timestamp_field="event_timestamp", - created_timestamp_column="created", - start_date=start_date, - end_date=end_date, - ) + RemoteOfflineStore.pull_latest_from_table_or_query( + config=fs.config, + data_source=data_source, + join_key_columns=[], + feature_name_columns=[], + timestamp_field="event_timestamp", + created_timestamp_column="created", + start_date=start_date, + end_date=end_date, + ).to_df() def _test_pull_all_from_table_or_query(temp_dir, fs: FeatureStore): @@ -225,13 +239,12 @@ def _test_pull_all_from_table_or_query(temp_dir, fs: FeatureStore): end_date = datetime.now().replace(microsecond=0, second=0, minute=0) start_date = end_date - timedelta(days=15) - with pytest.raises(NotImplementedError): - RemoteOfflineStore.pull_all_from_table_or_query( - config=fs.config, - data_source=data_source, - join_key_columns=[], - feature_name_columns=[], - timestamp_field="event_timestamp", - start_date=start_date, - end_date=end_date, - ) + RemoteOfflineStore.pull_all_from_table_or_query( + config=fs.config, + data_source=data_source, + join_key_columns=[], + feature_name_columns=[], + timestamp_field="event_timestamp", + start_date=start_date, + end_date=end_date, + ).to_df()
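
Reviewer note: below is a minimal client-side sketch of how the remote offline
store APIs implemented in this patch are expected to be exercised. It is
illustrative only and not part of the patch; the repo path, host, port, and
feature names are assumptions taken from the local template and the tests, not
values required by this change.

    # feature_store.yaml on the client side, pointing at a running offline
    # server (values are examples):
    #
    #   offline_store:
    #       type: remote
    #       host: localhost
    #       port: 8815

    import pandas as pd
    from feast import FeatureStore

    store = FeatureStore(repo_path=".")  # hypothetical local feature repo

    entity_df = pd.DataFrame(
        {
            "driver_id": [1001],
            "event_timestamp": [pd.Timestamp.now(tz="UTC")],
        }
    )

    # Routed over Arrow Flight: the client issues do_put for the entity_df and
    # do_get for the result, as implemented by RemoteRetrievalJob above.
    job = store.get_historical_features(
        entity_df=entity_df,
        features=["driver_hourly_stats:conv_rate"],
    )
    training_df = job.to_df()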