Update transforms module to PEP 585 type hints #33081

Open · wants to merge 2 commits into base: master
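For context on the change itself: PEP 585 (Python 3.9+) makes the builtin container types subscriptable, so the `typing.List`, `typing.Tuple`, `typing.Dict`, and `typing.Set` aliases are redundant (and deprecated since 3.9). A minimal before/after sketch, using hypothetical functions not taken from this diff:

```python
# Before: container generics must be imported from typing.
from typing import Dict, List

def word_lengths(words: List[str]) -> Dict[str, int]:
  return {w: len(w) for w in words}

# After (PEP 585, Python 3.9+): the builtins are generic themselves,
# and abstract container types come from collections.abc.
from collections.abc import Iterator

def word_lengths_585(words: list[str]) -> dict[str, int]:
  return {w: len(w) for w in words}

def chunked(xs: list[int], n: int) -> Iterator[list[int]]:
  # Yield successive n-sized slices of xs.
  for i in range(0, len(xs), n):
    yield xs[i:i + n]
```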
38 changes: 18 additions & 20 deletions sdks/python/apache_beam/transforms/batch_dofn_test.py
@@ -20,9 +20,7 @@
 # pytype: skip-file

 import unittest
-from typing import Iterator
-from typing import List
-from typing import Tuple
+from collections.abc import Iterator
 from typing import no_type_check

 from parameterized import parameterized_class
@@ -36,13 +34,13 @@ def process(self, element: int, *args, **kwargs) -> Iterator[float]:


 class BatchDoFn(beam.DoFn):
-  def process_batch(self, batch: List[int], *args,
-                    **kwargs) -> Iterator[List[float]]:
+  def process_batch(self, batch: list[int], *args,
+                    **kwargs) -> Iterator[list[float]]:
     yield [element / 2 for element in batch]


 class NoReturnAnnotation(beam.DoFn):
-  def process_batch(self, batch: List[int], *args, **kwargs):
+  def process_batch(self, batch: list[int], *args, **kwargs):
     yield [element * 2 for element in batch]

@@ -51,24 +49,24 @@ def process_batch(self, batch, *args, **kwargs):
     yield [element * 2 for element in batch]

   def get_input_batch_type(self, input_element_type):
-    return List[input_element_type]
+    return list[input_element_type]

   def get_output_batch_type(self, input_element_type):
-    return List[input_element_type]
+    return list[input_element_type]


 class EitherDoFn(beam.DoFn):
   def process(self, element: int, *args, **kwargs) -> Iterator[float]:
     yield element / 2

-  def process_batch(self, batch: List[int], *args,
-                    **kwargs) -> Iterator[List[float]]:
+  def process_batch(self, batch: list[int], *args,
+                    **kwargs) -> Iterator[list[float]]:
     yield [element / 2 for element in batch]


 class ElementToBatchDoFn(beam.DoFn):
   @beam.DoFn.yields_batches
-  def process(self, element: int, *args, **kwargs) -> Iterator[List[int]]:
+  def process(self, element: int, *args, **kwargs) -> Iterator[list[int]]:
     yield [element] * element

   def infer_output_type(self, input_element_type):
@@ -77,8 +75,8 @@ def infer_output_type(self, input_element_type):

 class BatchToElementDoFn(beam.DoFn):
   @beam.DoFn.yields_elements
-  def process_batch(self, batch: List[int], *args,
-                    **kwargs) -> Iterator[Tuple[int, int]]:
+  def process_batch(self, batch: list[int], *args,
+                    **kwargs) -> Iterator[tuple[int, int]]:
     yield (sum(batch), len(batch))

@@ -178,11 +176,11 @@ class MismatchedBatchProducingDoFn(beam.DoFn):
   mismatched return types (one yields floats, the other ints). Should yield
   a construction time error when applied."""
   @beam.DoFn.yields_batches
-  def process(self, element: int, *args, **kwargs) -> Iterator[List[int]]:
+  def process(self, element: int, *args, **kwargs) -> Iterator[list[int]]:
     yield [element]

-  def process_batch(self, batch: List[int], *args,
-                    **kwargs) -> Iterator[List[float]]:
+  def process_batch(self, batch: list[int], *args,
+                    **kwargs) -> Iterator[list[float]]:
     yield [element / 2 for element in batch]

@@ -194,13 +192,13 @@ def process(self, element: int, *args, **kwargs) -> Iterator[float]:
     yield element / 2

   @beam.DoFn.yields_elements
-  def process_batch(self, batch: List[int], *args, **kwargs) -> Iterator[int]:
+  def process_batch(self, batch: list[int], *args, **kwargs) -> Iterator[int]:
     yield batch[0]


 class NoElementOutputAnnotation(beam.DoFn):
-  def process_batch(self, batch: List[int], *args,
-                    **kwargs) -> Iterator[List[int]]:
+  def process_batch(self, batch: list[int], *args,
+                    **kwargs) -> Iterator[list[int]]:
     yield [element * 2 for element in batch]

@@ -225,7 +223,7 @@ def test_no_input_annotation_raises(self):
   def test_unsupported_dofn_param_raises(self):
     class BadParam(beam.DoFn):
       @no_type_check
-      def process_batch(self, batch: List[int], key=beam.DoFn.KeyParam):
+      def process_batch(self, batch: list[int], key=beam.DoFn.KeyParam):
         yield batch * key

     p = beam.Pipeline()
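A note on why the `get_input_batch_type`/`get_output_batch_type` hooks above can return `list[input_element_type]`: PEP 585 parameterization is an ordinary runtime expression that produces an inspectable `types.GenericAlias`, so parameterized types can be built without `typing.List`. A small sketch:

```python
import types

# list[int] is an ordinary runtime object (a types.GenericAlias) on 3.9+,
# so code can construct and inspect parameterized types directly.
batch_type = list[int]
assert isinstance(batch_type, types.GenericAlias)
assert batch_type.__origin__ is list   # the concrete container class
assert batch_type.__args__ == (int, )  # the type parameters
```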
@@ -18,8 +18,6 @@
 # pytype: skip-file

 import math
-from typing import Set
-from typing import Tuple

 import apache_beam as beam
 from apache_beam.options.pipeline_options import TypeOptions
@@ -36,7 +34,7 @@
 @with_input_types(int)
 @with_output_types(int)
 class CallSequenceEnforcingCombineFn(beam.CombineFn):
-  instances: Set['CallSequenceEnforcingCombineFn'] = set()
+  instances: set['CallSequenceEnforcingCombineFn'] = set()

   def __init__(self):
     super().__init__()
@@ -81,8 +79,8 @@ def teardown(self, *args, **kwargs):
     self._teardown_called = True


-@with_input_types(Tuple[None, str])
-@with_output_types(Tuple[int, str])
+@with_input_types(tuple[None, str])
+@with_output_types(tuple[int, str])
 class IndexAssigningDoFn(beam.DoFn):
   state_param = beam.DoFn.StateParam(
       userstate.CombiningValueStateSpec(
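One caveat worth recording, assuming (as this change does) a Python 3.9+ baseline for the SDK: `from __future__ import annotations` only defers evaluation of annotations, while arguments to decorators such as `with_input_types(tuple[None, str])` are ordinary expressions evaluated at import time, so the builtin-generic spelling in decorator position genuinely requires 3.9. A sketch of the distinction:

```python
# PEP 563 defers evaluation of annotations, so this file would import
# cleanly even on a pre-3.9 interpreter:
from __future__ import annotations

def halve(xs: list[int]) -> list[float]:  # annotation is never evaluated
  return [x / 2 for x in xs]

# Decorator arguments are evaluated eagerly. On Python < 3.8-style
# interpreters the line below would raise
#   TypeError: 'type' object is not subscriptable
# regardless of the __future__ import; it stands in for a decorator
# argument like with_input_types(tuple[None, str]) from the diff above.
row_type = tuple[None, str]  # fine on 3.9+
```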
68 changes: 32 additions & 36 deletions sdks/python/apache_beam/transforms/combiners.py
@@ -24,12 +24,8 @@
 import itertools
 import operator
 import random
+from collections.abc import Iterable
 from typing import Any
-from typing import Dict
-from typing import Iterable
-from typing import List
-from typing import Set
-from typing import Tuple
 from typing import TypeVar
 from typing import Union

@@ -146,15 +142,15 @@ def expand(self, pcoll):
     else:
       return pcoll | core.CombineGlobally(CountCombineFn()).without_defaults()

-  @with_input_types(Tuple[K, V])
-  @with_output_types(Tuple[K, int])
+  @with_input_types(tuple[K, V])
+  @with_output_types(tuple[K, int])
   class PerKey(ptransform.PTransform):
     """combiners.Count.PerKey counts how many elements each unique key has."""
     def expand(self, pcoll):
       return pcoll | core.CombinePerKey(CountCombineFn())

   @with_input_types(T)
-  @with_output_types(Tuple[T, int])
+  @with_output_types(tuple[T, int])
   class PerElement(ptransform.PTransform):
     """combiners.Count.PerElement counts how many times each element occurs."""
     def expand(self, pcoll):
@@ -193,7 +189,7 @@ class Top(object):

   # pylint: disable=no-self-argument
   @with_input_types(T)
-  @with_output_types(List[T])
+  @with_output_types(list[T])
   class Of(CombinerWithoutDefaults):
     """Returns the n greatest elements in the PCollection.

@@ -246,8 +242,8 @@ def expand(self, pcoll):
           TopCombineFn(self._n, self._key,
                        self._reverse)).without_defaults()

-  @with_input_types(Tuple[K, V])
-  @with_output_types(Tuple[K, List[V]])
+  @with_input_types(tuple[K, V])
+  @with_output_types(tuple[K, list[V]])
   class PerKey(ptransform.PTransform):
     """Identifies the N greatest elements associated with each key.

@@ -280,7 +276,7 @@ def expand(self, pcoll):
      """Expands the transform.

      Raises TypeCheckError: If the output type of the input PCollection is not
-      compatible with Tuple[A, B].
+      compatible with tuple[A, B].

      Args:
        pcoll: PCollection to process
@@ -323,7 +319,7 @@ def SmallestPerKey(pcoll, n, *, key=None):


 @with_input_types(T)
-@with_output_types(Tuple[None, List[T]])
+@with_output_types(tuple[None, list[T]])
 class _TopPerBundle(core.DoFn):
   def __init__(self, n, key, reverse):
     self._n = n
@@ -356,8 +352,8 @@ def finish_bundle(self):
     yield window.GlobalWindows.windowed_value((None, self._heap))


-@with_input_types(Tuple[None, Iterable[List[T]]])
-@with_output_types(List[T])
+@with_input_types(tuple[None, Iterable[list[T]]])
+@with_output_types(list[T])
 class _MergeTopPerBundle(core.DoFn):
   def __init__(self, n, key, reverse):
     self._n = n
@@ -380,7 +376,7 @@ def push(hp, e):
       return False

     if self._compare or self._key:
-      heapc: List[cy_combiners.ComparableValue] = []
+      heapc: list[cy_combiners.ComparableValue] = []
       for bundle in bundles:
         if not heapc:
           heapc = [
@@ -421,7 +417,7 @@ def push(hp, e):


 @with_input_types(T)
-@with_output_types(List[T])
+@with_output_types(list[T])
 class TopCombineFn(core.CombineFn):
   """CombineFn doing the combining for all of the Top transforms.

@@ -467,7 +463,7 @@ def display_data(self):
     }

   # The accumulator type is a tuple
-  # (bool, Union[List[T], List[ComparableValue[T]])
+  # (bool, Union[list[T], list[ComparableValue[T]])
   # where the boolean indicates whether the second slot contains a List of T
   # (False) or List of ComparableValue[T] (True). In either case, the List
   # maintains heap invariance. When the contents of the List are
@@ -564,7 +560,7 @@ class Sample(object):
   # pylint: disable=no-self-argument

   @with_input_types(T)
-  @with_output_types(List[T])
+  @with_output_types(list[T])
   class FixedSizeGlobally(CombinerWithoutDefaults):
     """Sample n elements from the input PCollection without replacement."""
     def __init__(self, n):
@@ -584,8 +580,8 @@ def display_data(self):
     def default_label(self):
       return 'FixedSizeGlobally(%d)' % self._n

-  @with_input_types(Tuple[K, V])
-  @with_output_types(Tuple[K, List[V]])
+  @with_input_types(tuple[K, V])
+  @with_output_types(tuple[K, list[V]])
   class FixedSizePerKey(ptransform.PTransform):
     """Sample n elements associated with each key without replacement."""
     def __init__(self, n):
@@ -602,7 +598,7 @@ def default_label(self):


 @with_input_types(T)
-@with_output_types(List[T])
+@with_output_types(list[T])
 class SampleCombineFn(core.CombineFn):
   """CombineFn for all Sample transforms."""
   def __init__(self, n):
@@ -734,7 +730,7 @@ def add_input(self, accumulator, element, *args, **kwargs):


 @with_input_types(T)
-@with_output_types(List[T])
+@with_output_types(list[T])
 class ToList(CombinerWithoutDefaults):
   """A global CombineFn that condenses a PCollection into a single list."""
   def expand(self, pcoll):
@@ -746,7 +742,7 @@ def expand(self, pcoll):


 @with_input_types(T)
-@with_output_types(List[T])
+@with_output_types(list[T])
 class ToListCombineFn(core.CombineFn):
   """CombineFn for to_list."""
   def create_accumulator(self):
@@ -780,8 +776,8 @@ def extract_output(self, accumulator):
     return accumulator


-@with_input_types(Tuple[K, V])
-@with_output_types(Dict[K, V])
+@with_input_types(tuple[K, V])
+@with_output_types(dict[K, V])
 class ToDict(CombinerWithoutDefaults):
   """A global CombineFn that condenses a PCollection into a single dict.

@@ -797,8 +793,8 @@ def expand(self, pcoll):
             ToDictCombineFn()).without_defaults()


-@with_input_types(Tuple[K, V])
-@with_output_types(Dict[K, V])
+@with_input_types(tuple[K, V])
+@with_output_types(dict[K, V])
 class ToDictCombineFn(core.CombineFn):
   """CombineFn for to_dict."""
   def create_accumulator(self):
@@ -820,7 +816,7 @@


 @with_input_types(T)
-@with_output_types(Set[T])
+@with_output_types(set[T])
 class ToSet(CombinerWithoutDefaults):
   """A global CombineFn that condenses a PCollection into a set."""
   def expand(self, pcoll):
@@ -832,7 +828,7 @@ def expand(self, pcoll):


 @with_input_types(T)
-@with_output_types(Set[T])
+@with_output_types(set[T])
 class ToSetCombineFn(core.CombineFn):
   """CombineFn for ToSet."""
   def create_accumulator(self):
@@ -941,17 +937,17 @@ def expand(self, pcoll):
       return (
           pcoll
           | core.ParDo(self.add_timestamp).with_output_types(
-              Tuple[T, TimestampType])
+              tuple[T, TimestampType])
           | core.CombineGlobally(LatestCombineFn()))
     else:
       return (
           pcoll
           | core.ParDo(self.add_timestamp).with_output_types(
-              Tuple[T, TimestampType])
+              tuple[T, TimestampType])
           | core.CombineGlobally(LatestCombineFn()).without_defaults())

-  @with_input_types(Tuple[K, V])
-  @with_output_types(Tuple[K, V])
+  @with_input_types(tuple[K, V])
+  @with_output_types(tuple[K, V])
   class PerKey(ptransform.PTransform):
     """Compute elements with the latest timestamp for each key
     from a keyed PCollection"""
@@ -964,11 +960,11 @@ def expand(self, pcoll):
       return (
           pcoll
           | core.ParDo(self.add_timestamp).with_output_types(
-              Tuple[K, Tuple[T, TimestampType]])
+              tuple[K, tuple[T, TimestampType]])
           | core.CombinePerKey(LatestCombineFn()))


-@with_input_types(Tuple[T, TimestampType])
+@with_input_types(tuple[T, TimestampType])
 @with_output_types(T)
 class LatestCombineFn(core.CombineFn):
   """CombineFn to get the element with the latest timestamp
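Also note what the diff deliberately keeps: PEP 585 covers the container generics only, so `Any`, `TypeVar`, and `Union` still come from `typing`, as the retained imports in combiners.py show. A small sketch of that boundary, with hypothetical names:

```python
# PEP 585 replaces container aliases only; these still live in typing.
from typing import Any, TypeVar, Union

T = TypeVar('T')

# On a 3.9 baseline Union keeps its typing spelling; the `int | float`
# form (PEP 604) would require Python 3.10+.
Number = Union[int, float]

def first(xs: list[T]) -> T:
  return xs[0]

def describe(value: Any) -> str:
  return f'{type(value).__name__}: {value!r}'
```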