From 941e5421fbb3c3b46746b5c1e7ec152f95220e1d Mon Sep 17 00:00:00 2001 From: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> Date: Tue, 12 Nov 2024 13:00:30 -0500 Subject: [PATCH] Update Enrichment Handlers to PEP 585 typing (#33087) * Update Enrichment Handlers to PEP 585 typing * BT test * linting --- .../enrichment_handlers/bigquery.py | 30 +++++++++---------- .../enrichment_handlers/bigtable.py | 5 ++-- .../enrichment_handlers/bigtable_it_test.py | 9 ++---- .../feast_feature_store.py | 7 ++--- .../feast_feature_store_it_test.py | 2 +- .../vertex_ai_feature_store.py | 5 ++-- 6 files changed, 25 insertions(+), 33 deletions(-) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py index ea98fb6b0bbd..06b40bf38cc1 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py @@ -14,11 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from collections.abc import Callable +from collections.abc import Mapping from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Mapping from typing import Optional from typing import Union @@ -30,7 +28,7 @@ from apache_beam.transforms.enrichment import EnrichmentSourceHandler QueryFn = Callable[[beam.Row], str] -ConditionValueFn = Callable[[beam.Row], List[Any]] +ConditionValueFn = Callable[[beam.Row], list[Any]] def _validate_bigquery_metadata( @@ -54,8 +52,8 @@ def _validate_bigquery_metadata( "`condition_value_fn`") -class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, List[Row]], - Union[Row, List[Row]]]): +class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, list[Row]], + Union[Row, list[Row]]]): """Enrichment handler for Google Cloud BigQuery. Use this handler with :class:`apache_beam.transforms.enrichment.Enrichment` @@ -83,8 +81,8 @@ def __init__( *, table_name: str = "", row_restriction_template: str = "", - fields: Optional[List[str]] = None, - column_names: Optional[List[str]] = None, + fields: Optional[list[str]] = None, + column_names: Optional[list[str]] = None, condition_value_fn: Optional[ConditionValueFn] = None, query_fn: Optional[QueryFn] = None, min_batch_size: int = 1, @@ -107,10 +105,10 @@ def __init__( row_restriction_template (str): A template string for the `WHERE` clause in the BigQuery query with placeholders (`{}`) to dynamically filter rows based on input data. - fields: (Optional[List[str]]) List of field names present in the input + fields: (Optional[list[str]]) List of field names present in the input `beam.Row`. These are used to construct the WHERE clause (if `condition_value_fn` is not provided). - column_names: (Optional[List[str]]) Names of columns to select from the + column_names: (Optional[list[str]]) Names of columns to select from the BigQuery table. If not provided, all columns (`*`) are selected. condition_value_fn: (Optional[Callable[[beam.Row], Any]]) A function that takes a `beam.Row` and returns a list of value to populate in the @@ -179,11 +177,11 @@ def create_row_key(self, row: beam.Row): return (tuple(row_dict[field] for field in self.fields)) raise ValueError("Either fields or condition_value_fn must be specified") - def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs): - if isinstance(request, List): + def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs): + if isinstance(request, list): values = [] responses = [] - requests_map: Dict[Any, Any] = {} + requests_map: dict[Any, Any] = {} batch_size = len(request) raw_query = self.query_template if batch_size > 1: @@ -230,8 +228,8 @@ def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs): def __exit__(self, exc_type, exc_val, exc_tb): self.client.close() - def get_cache_key(self, request: Union[beam.Row, List[beam.Row]]): - if isinstance(request, List): + def get_cache_key(self, request: Union[beam.Row, list[beam.Row]]): + if isinstance(request, list): cache_keys = [] for req in request: req_dict = req._asdict() diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py index ddb62c2f60d5..c251ab05ecab 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py @@ -15,9 +15,8 @@ # limitations under the License. # import logging +from collections.abc import Callable from typing import Any -from typing import Callable -from typing import Dict from typing import Optional from google.api_core.exceptions import NotFound @@ -115,7 +114,7 @@ def __call__(self, request: beam.Row, *args, **kwargs): Args: request: the input `beam.Row` to enrich. """ - response_dict: Dict[str, Any] = {} + response_dict: dict[str, Any] = {} row_key_str: str = "" try: if self._row_key_fn: diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py index 79d73178e94e..6bf57cefacbe 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py @@ -18,10 +18,7 @@ import datetime import logging import unittest -from typing import Dict -from typing import List from typing import NamedTuple -from typing import Tuple from unittest.mock import MagicMock import pytest @@ -57,8 +54,8 @@ class ValidateResponse(beam.DoFn): def __init__( self, n_fields: int, - fields: List[str], - enriched_fields: Dict[str, List[str]], + fields: list[str], + enriched_fields: dict[str, list[str]], include_timestamp: bool = False, ): self.n_fields = n_fields @@ -88,7 +85,7 @@ def process(self, element: beam.Row, *args, **kwargs): "Response from bigtable should contain a %s column_family with " "%s columns." % (column_family, columns)) if (self._include_timestamp and - not isinstance(element_dict[column_family][key][0], Tuple)): # type: ignore[arg-type] + not isinstance(element_dict[column_family][key][0], tuple)): raise BeamAssertException( "Response from bigtable should contain timestamp associated with " "its value.") diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py index dc2a71786f65..f8e8b4db1d7f 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py @@ -16,11 +16,10 @@ # import logging import tempfile +from collections.abc import Callable +from collections.abc import Mapping from pathlib import Path from typing import Any -from typing import Callable -from typing import List -from typing import Mapping from typing import Optional import apache_beam as beam @@ -95,7 +94,7 @@ class FeastFeatureStoreEnrichmentHandler(EnrichmentSourceHandler[beam.Row, def __init__( self, feature_store_yaml_path: str, - feature_names: Optional[List[str]] = None, + feature_names: Optional[list[str]] = None, feature_service_name: Optional[str] = "", full_feature_names: Optional[bool] = False, entity_id: str = "", diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py index 89cb39c2c19c..9c4dab3d68b8 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py @@ -22,8 +22,8 @@ """ import unittest +from collections.abc import Mapping from typing import Any -from typing import Mapping import pytest diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py index 753b04e1793d..b6de3aa1c826 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py @@ -15,7 +15,6 @@ # limitations under the License. # import logging -from typing import List import proto from google.api_core.exceptions import NotFound @@ -209,7 +208,7 @@ def __init__( api_endpoint: str, feature_store_id: str, entity_type_id: str, - feature_ids: List[str], + feature_ids: list[str], row_key: str, *, exception_level: ExceptionLevel = ExceptionLevel.WARN, @@ -224,7 +223,7 @@ def __init__( Vertex AI Feature Store (Legacy). feature_store_id (str): The id of the Vertex AI Feature Store (Legacy). entity_type_id (str): The entity type of the feature store. - feature_ids (List[str]): A list of feature-ids to fetch + feature_ids (list[str]): A list of feature-ids to fetch from the Feature Store. row_key (str): The row key field name containing the entity id for the feature values.