Skip to content

Commit

Permalink
Modernize python type hints for apache_beam.
Browse files Browse the repository at this point in the history
This was done with com2ann plus some manaual edits.
  • Loading branch information
robertwb committed Oct 22, 2024
1 parent ef0fc8b commit edf7abf
Show file tree
Hide file tree
Showing 16 changed files with 165 additions and 210 deletions.
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/dataframe/doctests.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class _InMemoryResultRecorder(object):
"""

# Class-level value to survive pickling.
_ALL_RESULTS = {} # type: Dict[str, List[Any]]
_ALL_RESULTS: Dict[str, List[Any]] = {}

def __init__(self):
self._id = id(self)
Expand Down
33 changes: 15 additions & 18 deletions sdks/python/apache_beam/dataframe/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ class Session(object):
def __init__(self, bindings=None):
self._bindings = dict(bindings or {})

def evaluate(self, expr): # type: (Expression) -> Any
def evaluate(self, expr: 'Expression') -> Any:
if expr not in self._bindings:
self._bindings[expr] = expr.evaluate_at(self)
return self._bindings[expr]

def lookup(self, expr): # type: (Expression) -> Any
def lookup(self, expr: 'Expression') -> Any:
return self._bindings[expr]


Expand Down Expand Up @@ -251,9 +251,9 @@ def preserves_partition_by(self) -> partitionings.Partitioning:
class PlaceholderExpression(Expression):
"""An expression whose value must be explicitly bound in the session."""
def __init__(
self, # type: PlaceholderExpression
proxy, # type: T
reference=None, # type: Any
self,
proxy: T,
reference: Any = None,
):
"""Initialize a placeholder expression.
Expand Down Expand Up @@ -282,11 +282,7 @@ def preserves_partition_by(self):

class ConstantExpression(Expression):
"""An expression whose value is known at pipeline construction time."""
def __init__(
self, # type: ConstantExpression
value, # type: T
proxy=None # type: Optional[T]
):
def __init__(self, value: T, proxy: Optional[T] = None):
"""Initialize a constant expression.
Args:
Expand Down Expand Up @@ -319,14 +315,15 @@ def preserves_partition_by(self):
class ComputedExpression(Expression):
"""An expression whose value must be computed at pipeline execution time."""
def __init__(
self, # type: ComputedExpression
name, # type: str
func, # type: Callable[...,T]
args, # type: Iterable[Expression]
proxy=None, # type: Optional[T]
_id=None, # type: Optional[str]
requires_partition_by=partitionings.Index(), # type: partitionings.Partitioning
preserves_partition_by=partitionings.Singleton(), # type: partitionings.Partitioning
self,
name: str,
func: Callable[..., T],
args: Iterable[Expression],
proxy: Optional[T] = None,
_id: Optional[str] = None,
requires_partition_by: partitionings.Partitioning = partitionings.Index(),
preserves_partition_by: partitionings.Partitioning = partitionings.
Singleton(),
):
"""Initialize a computed expression.
Expand Down
25 changes: 10 additions & 15 deletions sdks/python/apache_beam/dataframe/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import collections
import logging
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import List
Expand All @@ -32,15 +31,13 @@
from apache_beam.dataframe import expressions
from apache_beam.dataframe import frames # pylint: disable=unused-import
from apache_beam.dataframe import partitionings
from apache_beam.pvalue import PCollection
from apache_beam.utils import windowed_value

__all__ = [
'DataframeTransform',
]

if TYPE_CHECKING:
# pylint: disable=ungrouped-imports
from apache_beam.pvalue import PCollection

T = TypeVar('T')

Expand Down Expand Up @@ -108,15 +105,15 @@ def expand(self, input_pcolls):
from apache_beam.dataframe import convert

# Convert inputs to a flat dict.
input_dict = _flatten(input_pcolls) # type: Dict[Any, PCollection]
input_dict: Dict[Any, PCollection] = _flatten(input_pcolls)
proxies = _flatten(self._proxy) if self._proxy is not None else {
tag: None
for tag in input_dict
}
input_frames = {
input_frames: Dict[Any, DeferredFrame] = {
k: convert.to_dataframe(pc, proxies[k])
for k, pc in input_dict.items()
} # type: Dict[Any, DeferredFrame] # noqa: F821
} # noqa: F821

# Apply the function.
frames_input = _substitute(input_pcolls, input_frames)
Expand Down Expand Up @@ -152,9 +149,9 @@ def expand(self, inputs):

def _apply_deferred_ops(
self,
inputs, # type: Dict[expressions.Expression, PCollection]
outputs, # type: Dict[Any, expressions.Expression]
): # -> Dict[Any, PCollection]
inputs: Dict[expressions.Expression, PCollection],
outputs: Dict[Any, expressions.Expression],
): # -> Dict[Any, PCollection]
"""Construct a Beam graph that evaluates a set of expressions on a set of
input PCollections.
Expand Down Expand Up @@ -585,11 +582,9 @@ def _concat(parts):


def _flatten(
valueish, # type: Union[T, List[T], Tuple[T], Dict[Any, T]]
root=(), # type: Tuple[Any, ...]
):
# type: (...) -> Mapping[Tuple[Any, ...], T]

valueish: Union[T, List[T], Tuple[T], Dict[Any, T]],
root: Tuple[Any, ...] = (),
) -> Mapping[Tuple[Any, ...], T]:
"""Given a nested structure of dicts, tuples, and lists, return a flat
dictionary where the values are the leafs and the keys are the "paths" to
these leaves.
Expand Down
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/io/avroio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@

class AvroBase(object):

_temp_files = [] # type: List[str]
_temp_files: List[str] = []

def __init__(self, methodName='runTest'):
super().__init__(methodName)
Expand Down
37 changes: 12 additions & 25 deletions sdks/python/apache_beam/io/fileio.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@
import uuid
from collections import namedtuple
from functools import partial
from typing import TYPE_CHECKING
from typing import Any
from typing import BinaryIO # pylint: disable=unused-import
from typing import Callable
Expand All @@ -115,15 +114,13 @@
from apache_beam.options.value_provider import ValueProvider
from apache_beam.transforms.periodicsequence import PeriodicImpulse
from apache_beam.transforms.userstate import CombiningValueStateSpec
from apache_beam.transforms.window import BoundedWindow
from apache_beam.transforms.window import FixedWindows
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import IntervalWindow
from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import Timestamp

if TYPE_CHECKING:
from apache_beam.transforms.window import BoundedWindow

__all__ = [
'EmptyMatchTreatment',
'MatchFiles',
Expand Down Expand Up @@ -382,8 +379,7 @@ def create_metadata(
mime_type="application/octet-stream",
compression_type=CompressionTypes.AUTO)

def open(self, fh):
# type: (BinaryIO) -> None
def open(self, fh: BinaryIO) -> None:
raise NotImplementedError

def write(self, record):
Expand Down Expand Up @@ -575,8 +571,7 @@ class signature or an instance of FileSink to this parameter. If none is
self._max_num_writers_per_bundle = max_writers_per_bundle

@staticmethod
def _get_sink_fn(input_sink):
# type: (...) -> Callable[[Any], FileSink]
def _get_sink_fn(input_sink) -> Callable[[Any], FileSink]:
if isinstance(input_sink, type) and issubclass(input_sink, FileSink):
return lambda x: input_sink()
elif isinstance(input_sink, FileSink):
Expand All @@ -588,8 +583,7 @@ def _get_sink_fn(input_sink):
return lambda x: TextSink()

@staticmethod
def _get_destination_fn(destination):
# type: (...) -> Callable[[Any], str]
def _get_destination_fn(destination) -> Callable[[Any], str]:
if isinstance(destination, ValueProvider):
return lambda elm: destination.get()
elif callable(destination):
Expand Down Expand Up @@ -757,12 +751,8 @@ def _check_orphaned_files(self, writer_key):


class _WriteShardedRecordsFn(beam.DoFn):

def __init__(self,
base_path,
sink_fn, # type: Callable[[Any], FileSink]
shards # type: int
):
def __init__(
self, base_path, sink_fn: Callable[[Any], FileSink], shards: int):
self.base_path = base_path
self.sink_fn = sink_fn
self.shards = shards
Expand Down Expand Up @@ -805,17 +795,13 @@ def process(


class _AppendShardedDestination(beam.DoFn):
def __init__(
self,
destination, # type: Callable[[Any], str]
shards # type: int
):
def __init__(self, destination: Callable[[Any], str], shards: int):
self.destination_fn = destination
self.shards = shards

# We start the shards for a single destination at an arbitrary point.
self._shard_counter = collections.defaultdict(
lambda: random.randrange(self.shards)) # type: DefaultDict[str, int]
self._shard_counter: DefaultDict[str, int] = collections.defaultdict(
lambda: random.randrange(self.shards))

def _next_shard_for_destination(self, destination):
self._shard_counter[destination] = ((self._shard_counter[destination] + 1) %
Expand All @@ -835,8 +821,9 @@ class _WriteUnshardedRecordsFn(beam.DoFn):
SPILLED_RECORDS = 'spilled_records'
WRITTEN_FILES = 'written_files'

_writers_and_sinks = None # type: Dict[Tuple[str, BoundedWindow], Tuple[BinaryIO, FileSink]]
_file_names = None # type: Dict[Tuple[str, BoundedWindow], str]
_writers_and_sinks: Dict[Tuple[str, BoundedWindow], Tuple[BinaryIO,
FileSink]] = None
_file_names: Dict[Tuple[str, BoundedWindow], str] = None

def __init__(
self,
Expand Down
10 changes: 6 additions & 4 deletions sdks/python/apache_beam/io/gcp/bigquery_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def _insert_load_job(

def _start_job(
self,
request, # type: bigquery.BigqueryJobsInsertRequest
request: 'bigquery.BigqueryJobsInsertRequest',
stream=None,
):
"""Inserts a BigQuery job.
Expand Down Expand Up @@ -1786,9 +1786,11 @@ def generate_bq_job_name(job_name, step_id, job_type, random=None):


def check_schema_equal(
left, right, *, ignore_descriptions=False, ignore_field_order=False):
# type: (Union[bigquery.TableSchema, bigquery.TableFieldSchema], Union[bigquery.TableSchema, bigquery.TableFieldSchema], bool, bool) -> bool

left: Union['bigquery.TableSchema', 'bigquery.TableFieldSchema'],
right: Union['bigquery.TableSchema', 'bigquery.TableFieldSchema'],
*,
ignore_descriptions: bool = False,
ignore_field_order: bool = False) -> bool:
"""Check whether schemas are equivalent.
This comparison function differs from using == to compare TableSchema
Expand Down
6 changes: 4 additions & 2 deletions sdks/python/apache_beam/io/gcp/gcsio.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,10 @@ def create_storage_client(pipeline_options, use_credentials=True):

class GcsIO(object):
"""Google Cloud Storage I/O client."""
def __init__(self, storage_client=None, pipeline_options=None):
# type: (Optional[storage.Client], Optional[Union[dict, PipelineOptions]]) -> None
def __init__(
self,
storage_client: Optional[storage.Client] = None,
pipeline_options: Optional[Union[dict, PipelineOptions]] = None) -> None:
if pipeline_options is None:
pipeline_options = PipelineOptions()
elif isinstance(pipeline_options, dict):
Expand Down
Loading

0 comments on commit edf7abf

Please sign in to comment.