Skip to content

Commit

Permalink
Update Enrichment Handlers to PEP 585 typing (#33087)
Browse files Browse the repository at this point in the history
* Update Enrichment Handlers to PEP 585 typing

* BT test

* linting
  • Loading branch information
jrmccluskey authored Nov 12, 2024
1 parent 0f77934 commit 941e542
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 33 deletions.
30 changes: 14 additions & 16 deletions sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Callable
from collections.abc import Mapping
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
from typing import Union

Expand All @@ -30,7 +28,7 @@
from apache_beam.transforms.enrichment import EnrichmentSourceHandler

QueryFn = Callable[[beam.Row], str]
ConditionValueFn = Callable[[beam.Row], List[Any]]
ConditionValueFn = Callable[[beam.Row], list[Any]]


def _validate_bigquery_metadata(
Expand All @@ -54,8 +52,8 @@ def _validate_bigquery_metadata(
"`condition_value_fn`")


class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, List[Row]],
Union[Row, List[Row]]]):
class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, list[Row]],
Union[Row, list[Row]]]):
"""Enrichment handler for Google Cloud BigQuery.
Use this handler with :class:`apache_beam.transforms.enrichment.Enrichment`
Expand Down Expand Up @@ -83,8 +81,8 @@ def __init__(
*,
table_name: str = "",
row_restriction_template: str = "",
fields: Optional[List[str]] = None,
column_names: Optional[List[str]] = None,
fields: Optional[list[str]] = None,
column_names: Optional[list[str]] = None,
condition_value_fn: Optional[ConditionValueFn] = None,
query_fn: Optional[QueryFn] = None,
min_batch_size: int = 1,
Expand All @@ -107,10 +105,10 @@ def __init__(
row_restriction_template (str): A template string for the `WHERE` clause
in the BigQuery query with placeholders (`{}`) to dynamically filter
rows based on input data.
fields: (Optional[List[str]]) List of field names present in the input
fields: (Optional[list[str]]) List of field names present in the input
`beam.Row`. These are used to construct the WHERE clause
(if `condition_value_fn` is not provided).
column_names: (Optional[List[str]]) Names of columns to select from the
column_names: (Optional[list[str]]) Names of columns to select from the
BigQuery table. If not provided, all columns (`*`) are selected.
condition_value_fn: (Optional[Callable[[beam.Row], Any]]) A function
that takes a `beam.Row` and returns a list of value to populate in the
Expand Down Expand Up @@ -179,11 +177,11 @@ def create_row_key(self, row: beam.Row):
return (tuple(row_dict[field] for field in self.fields))
raise ValueError("Either fields or condition_value_fn must be specified")

def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs):
if isinstance(request, List):
def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
if isinstance(request, list):
values = []
responses = []
requests_map: Dict[Any, Any] = {}
requests_map: dict[Any, Any] = {}
batch_size = len(request)
raw_query = self.query_template
if batch_size > 1:
Expand Down Expand Up @@ -230,8 +228,8 @@ def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs):
def __exit__(self, exc_type, exc_val, exc_tb):
self.client.close()

def get_cache_key(self, request: Union[beam.Row, List[beam.Row]]):
if isinstance(request, List):
def get_cache_key(self, request: Union[beam.Row, list[beam.Row]]):
if isinstance(request, list):
cache_keys = []
for req in request:
req_dict = req._asdict()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
# limitations under the License.
#
import logging
from collections.abc import Callable
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional

from google.api_core.exceptions import NotFound
Expand Down Expand Up @@ -115,7 +114,7 @@ def __call__(self, request: beam.Row, *args, **kwargs):
Args:
request: the input `beam.Row` to enrich.
"""
response_dict: Dict[str, Any] = {}
response_dict: dict[str, Any] = {}
row_key_str: str = ""
try:
if self._row_key_fn:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@
import datetime
import logging
import unittest
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Tuple
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -57,8 +54,8 @@ class ValidateResponse(beam.DoFn):
def __init__(
self,
n_fields: int,
fields: List[str],
enriched_fields: Dict[str, List[str]],
fields: list[str],
enriched_fields: dict[str, list[str]],
include_timestamp: bool = False,
):
self.n_fields = n_fields
Expand Down Expand Up @@ -88,7 +85,7 @@ def process(self, element: beam.Row, *args, **kwargs):
"Response from bigtable should contain a %s column_family with "
"%s columns." % (column_family, columns))
if (self._include_timestamp and
not isinstance(element_dict[column_family][key][0], Tuple)): # type: ignore[arg-type]
not isinstance(element_dict[column_family][key][0], tuple)):
raise BeamAssertException(
"Response from bigtable should contain timestamp associated with "
"its value.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
#
import logging
import tempfile
from collections.abc import Callable
from collections.abc import Mapping
from pathlib import Path
from typing import Any
from typing import Callable
from typing import List
from typing import Mapping
from typing import Optional

import apache_beam as beam
Expand Down Expand Up @@ -95,7 +94,7 @@ class FeastFeatureStoreEnrichmentHandler(EnrichmentSourceHandler[beam.Row,
def __init__(
self,
feature_store_yaml_path: str,
feature_names: Optional[List[str]] = None,
feature_names: Optional[list[str]] = None,
feature_service_name: Optional[str] = "",
full_feature_names: Optional[bool] = False,
entity_id: str = "",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
"""

import unittest
from collections.abc import Mapping
from typing import Any
from typing import Mapping

import pytest

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# limitations under the License.
#
import logging
from typing import List

import proto
from google.api_core.exceptions import NotFound
Expand Down Expand Up @@ -209,7 +208,7 @@ def __init__(
api_endpoint: str,
feature_store_id: str,
entity_type_id: str,
feature_ids: List[str],
feature_ids: list[str],
row_key: str,
*,
exception_level: ExceptionLevel = ExceptionLevel.WARN,
Expand All @@ -224,7 +223,7 @@ def __init__(
Vertex AI Feature Store (Legacy).
feature_store_id (str): The id of the Vertex AI Feature Store (Legacy).
entity_type_id (str): The entity type of the feature store.
feature_ids (List[str]): A list of feature-ids to fetch
feature_ids (list[str]): A list of feature-ids to fetch
from the Feature Store.
row_key (str): The row key field name containing the entity id
for the feature values.
Expand Down

0 comments on commit 941e542

Please sign in to comment.