From 10689f2f8c0da41bd29c3857c6016d12e7dfe150 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 17 Oct 2024 17:35:35 -0700 Subject: [PATCH] Modernize python type hints for apache_beam. This was done with com2ann plus some manual edits. --- sdks/python/apache_beam/dataframe/doctests.py | 2 +- .../apache_beam/dataframe/expressions.py | 33 ++++--- .../apache_beam/dataframe/transforms.py | 27 +++--- sdks/python/apache_beam/io/avroio_test.py | 2 +- sdks/python/apache_beam/io/fileio.py | 37 +++----- .../apache_beam/io/gcp/bigquery_tools.py | 10 ++- sdks/python/apache_beam/io/gcp/gcsio.py | 6 +- .../apache_beam/metrics/monitoring_infos.py | 54 ++++++----- .../apache_beam/options/pipeline_options.py | 26 ++---- sdks/python/apache_beam/pipeline_test.py | 8 +- .../runners/interactive/interactive_beam.py | 18 ++-- .../apache_beam/runners/interactive/utils.py | 50 +++++------ .../apache_beam/runners/portability/stager.py | 90 +++++++++---------- .../python/apache_beam/transforms/external.py | 10 +-- .../apache_beam/transforms/ptransform_test.py | 2 +- .../typehints/trivial_inference.py | 2 +- 16 files changed, 166 insertions(+), 211 deletions(-) diff --git a/sdks/python/apache_beam/dataframe/doctests.py b/sdks/python/apache_beam/dataframe/doctests.py index 33faa6b58599..84f5420e6e78 100644 --- a/sdks/python/apache_beam/dataframe/doctests.py +++ b/sdks/python/apache_beam/dataframe/doctests.py @@ -146,7 +146,7 @@ class _InMemoryResultRecorder(object): """ # Class-level value to survive pickling. 
- _ALL_RESULTS = {} # type: Dict[str, List[Any]] + _ALL_RESULTS: Dict[str, List[Any]] = {} def __init__(self): self._id = id(self) diff --git a/sdks/python/apache_beam/dataframe/expressions.py b/sdks/python/apache_beam/dataframe/expressions.py index 91d237c7de96..f0704dd37c81 100644 --- a/sdks/python/apache_beam/dataframe/expressions.py +++ b/sdks/python/apache_beam/dataframe/expressions.py @@ -36,12 +36,12 @@ class Session(object): def __init__(self, bindings=None): self._bindings = dict(bindings or {}) - def evaluate(self, expr): # type: (Expression) -> Any + def evaluate(self, expr: 'Expression') -> Any: if expr not in self._bindings: self._bindings[expr] = expr.evaluate_at(self) return self._bindings[expr] - def lookup(self, expr): # type: (Expression) -> Any + def lookup(self, expr: 'Expression') -> Any: return self._bindings[expr] @@ -251,9 +251,9 @@ def preserves_partition_by(self) -> partitionings.Partitioning: class PlaceholderExpression(Expression): """An expression whose value must be explicitly bound in the session.""" def __init__( - self, # type: PlaceholderExpression - proxy, # type: T - reference=None, # type: Any + self, + proxy: T, + reference: Any = None, ): """Initialize a placeholder expression. @@ -282,11 +282,7 @@ def preserves_partition_by(self): class ConstantExpression(Expression): """An expression whose value is known at pipeline construction time.""" - def __init__( - self, # type: ConstantExpression - value, # type: T - proxy=None # type: Optional[T] - ): + def __init__(self, value: T, proxy: Optional[T] = None): """Initialize a constant expression. 
Args: @@ -319,14 +315,15 @@ def preserves_partition_by(self): class ComputedExpression(Expression): """An expression whose value must be computed at pipeline execution time.""" def __init__( - self, # type: ComputedExpression - name, # type: str - func, # type: Callable[...,T] - args, # type: Iterable[Expression] - proxy=None, # type: Optional[T] - _id=None, # type: Optional[str] - requires_partition_by=partitionings.Index(), # type: partitionings.Partitioning - preserves_partition_by=partitionings.Singleton(), # type: partitionings.Partitioning + self, + name: str, + func: Callable[..., T], + args: Iterable[Expression], + proxy: Optional[T] = None, + _id: Optional[str] = None, + requires_partition_by: partitionings.Partitioning = partitionings.Index(), + preserves_partition_by: partitionings.Partitioning = partitionings. + Singleton(), ): """Initialize a computed expression. diff --git a/sdks/python/apache_beam/dataframe/transforms.py b/sdks/python/apache_beam/dataframe/transforms.py index d0b5be4eb2a9..dee3f61827be 100644 --- a/sdks/python/apache_beam/dataframe/transforms.py +++ b/sdks/python/apache_beam/dataframe/transforms.py @@ -16,7 +16,6 @@ import collections import logging -from typing import TYPE_CHECKING from typing import Any from typing import Dict from typing import List @@ -30,18 +29,16 @@ import apache_beam as beam from apache_beam import transforms from apache_beam.dataframe import expressions +from apache_beam.dataframe import frame_base from apache_beam.dataframe import frames # pylint: disable=unused-import from apache_beam.dataframe import partitionings +from apache_beam.pvalue import PCollection from apache_beam.utils import windowed_value __all__ = [ 'DataframeTransform', ] -if TYPE_CHECKING: - # pylint: disable=ungrouped-imports - from apache_beam.pvalue import PCollection - T = TypeVar('T') TARGET_PARTITION_SIZE = 1 << 23 # 8M @@ -108,15 +105,15 @@ def expand(self, input_pcolls): from apache_beam.dataframe import convert # Convert inputs to 
a flat dict. - input_dict = _flatten(input_pcolls) # type: Dict[Any, PCollection] + input_dict: Dict[Any, PCollection] = _flatten(input_pcolls) proxies = _flatten(self._proxy) if self._proxy is not None else { tag: None for tag in input_dict } - input_frames = { + input_frames: Dict[Any, frame_base.DeferredFrame] = { k: convert.to_dataframe(pc, proxies[k]) for k, pc in input_dict.items() - } # type: Dict[Any, DeferredFrame] # noqa: F821 + } # noqa: F821 # Apply the function. frames_input = _substitute(input_pcolls, input_frames) @@ -152,9 +149,9 @@ def expand(self, inputs): def _apply_deferred_ops( self, - inputs, # type: Dict[expressions.Expression, PCollection] - outputs, # type: Dict[Any, expressions.Expression] - ): # -> Dict[Any, PCollection] + inputs: Dict[expressions.Expression, PCollection], + outputs: Dict[Any, expressions.Expression], + ): # -> Dict[Any, PCollection] """Construct a Beam graph that evaluates a set of expressions on a set of input PCollections. @@ -585,11 +582,9 @@ def _concat(parts): def _flatten( - valueish, # type: Union[T, List[T], Tuple[T], Dict[Any, T]] - root=(), # type: Tuple[Any, ...] - ): - # type: (...) -> Mapping[Tuple[Any, ...], T] - + valueish: Union[T, List[T], Tuple[T], Dict[Any, T]], + root: Tuple[Any, ...] = (), +) -> Mapping[Tuple[Any, ...], T]: """Given a nested structure of dicts, tuples, and lists, return a flat dictionary where the values are the leafs and the keys are the "paths" to these leaves. 
diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py index 77b20117e702..2d25010da486 100644 --- a/sdks/python/apache_beam/io/avroio_test.py +++ b/sdks/python/apache_beam/io/avroio_test.py @@ -82,7 +82,7 @@ class AvroBase(object): - _temp_files = [] # type: List[str] + _temp_files: List[str] = [] def __init__(self, methodName='runTest'): super().__init__(methodName) diff --git a/sdks/python/apache_beam/io/fileio.py b/sdks/python/apache_beam/io/fileio.py index d9b2a2040675..111206a18a28 100644 --- a/sdks/python/apache_beam/io/fileio.py +++ b/sdks/python/apache_beam/io/fileio.py @@ -94,7 +94,6 @@ import uuid from collections import namedtuple from functools import partial -from typing import TYPE_CHECKING from typing import Any from typing import BinaryIO # pylint: disable=unused-import from typing import Callable @@ -115,15 +114,13 @@ from apache_beam.options.value_provider import ValueProvider from apache_beam.transforms.periodicsequence import PeriodicImpulse from apache_beam.transforms.userstate import CombiningValueStateSpec +from apache_beam.transforms.window import BoundedWindow from apache_beam.transforms.window import FixedWindows from apache_beam.transforms.window import GlobalWindow from apache_beam.transforms.window import IntervalWindow from apache_beam.utils.timestamp import MAX_TIMESTAMP from apache_beam.utils.timestamp import Timestamp -if TYPE_CHECKING: - from apache_beam.transforms.window import BoundedWindow - __all__ = [ 'EmptyMatchTreatment', 'MatchFiles', @@ -382,8 +379,7 @@ def create_metadata( mime_type="application/octet-stream", compression_type=CompressionTypes.AUTO) - def open(self, fh): - # type: (BinaryIO) -> None + def open(self, fh: BinaryIO) -> None: raise NotImplementedError def write(self, record): @@ -575,8 +571,7 @@ class signature or an instance of FileSink to this parameter. 
If none is self._max_num_writers_per_bundle = max_writers_per_bundle @staticmethod - def _get_sink_fn(input_sink): - # type: (...) -> Callable[[Any], FileSink] + def _get_sink_fn(input_sink) -> Callable[[Any], FileSink]: if isinstance(input_sink, type) and issubclass(input_sink, FileSink): return lambda x: input_sink() elif isinstance(input_sink, FileSink): @@ -588,8 +583,7 @@ def _get_sink_fn(input_sink): return lambda x: TextSink() @staticmethod - def _get_destination_fn(destination): - # type: (...) -> Callable[[Any], str] + def _get_destination_fn(destination) -> Callable[[Any], str]: if isinstance(destination, ValueProvider): return lambda elm: destination.get() elif callable(destination): @@ -757,12 +751,8 @@ def _check_orphaned_files(self, writer_key): class _WriteShardedRecordsFn(beam.DoFn): - - def __init__(self, - base_path, - sink_fn, # type: Callable[[Any], FileSink] - shards # type: int - ): + def __init__( + self, base_path, sink_fn: Callable[[Any], FileSink], shards: int): self.base_path = base_path self.sink_fn = sink_fn self.shards = shards @@ -805,17 +795,13 @@ def process( class _AppendShardedDestination(beam.DoFn): - def __init__( - self, - destination, # type: Callable[[Any], str] - shards # type: int - ): + def __init__(self, destination: Callable[[Any], str], shards: int): self.destination_fn = destination self.shards = shards # We start the shards for a single destination at an arbitrary point. 
- self._shard_counter = collections.defaultdict( - lambda: random.randrange(self.shards)) # type: DefaultDict[str, int] + self._shard_counter: DefaultDict[str, int] = collections.defaultdict( + lambda: random.randrange(self.shards)) def _next_shard_for_destination(self, destination): self._shard_counter[destination] = ((self._shard_counter[destination] + 1) % @@ -835,8 +821,9 @@ class _WriteUnshardedRecordsFn(beam.DoFn): SPILLED_RECORDS = 'spilled_records' WRITTEN_FILES = 'written_files' - _writers_and_sinks = None # type: Dict[Tuple[str, BoundedWindow], Tuple[BinaryIO, FileSink]] - _file_names = None # type: Dict[Tuple[str, BoundedWindow], str] + _writers_and_sinks: Dict[Tuple[str, BoundedWindow], Tuple[BinaryIO, + FileSink]] = None + _file_names: Dict[Tuple[str, BoundedWindow], str] = None def __init__( self, diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_tools.py index c7128e7899ec..afd580710219 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools.py @@ -560,7 +560,7 @@ def _insert_load_job( def _start_job( self, - request, # type: bigquery.BigqueryJobsInsertRequest + request: 'bigquery.BigqueryJobsInsertRequest', stream=None, ): """Inserts a BigQuery job. @@ -1786,9 +1786,11 @@ def generate_bq_job_name(job_name, step_id, job_type, random=None): def check_schema_equal( - left, right, *, ignore_descriptions=False, ignore_field_order=False): - # type: (Union[bigquery.TableSchema, bigquery.TableFieldSchema], Union[bigquery.TableSchema, bigquery.TableFieldSchema], bool, bool) -> bool - + left: Union['bigquery.TableSchema', 'bigquery.TableFieldSchema'], + right: Union['bigquery.TableSchema', 'bigquery.TableFieldSchema'], + *, + ignore_descriptions: bool = False, + ignore_field_order: bool = False) -> bool: """Check whether schemas are equivalent. 
This comparison function differs from using == to compare TableSchema diff --git a/sdks/python/apache_beam/io/gcp/gcsio.py b/sdks/python/apache_beam/io/gcp/gcsio.py index 22a33fa13c63..8056de51f43f 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio.py +++ b/sdks/python/apache_beam/io/gcp/gcsio.py @@ -137,8 +137,10 @@ def create_storage_client(pipeline_options, use_credentials=True): class GcsIO(object): """Google Cloud Storage I/O client.""" - def __init__(self, storage_client=None, pipeline_options=None): - # type: (Optional[storage.Client], Optional[Union[dict, PipelineOptions]]) -> None + def __init__( + self, + storage_client: Optional[storage.Client] = None, + pipeline_options: Optional[Union[dict, PipelineOptions]] = None) -> None: if pipeline_options is None: pipeline_options = PipelineOptions() elif isinstance(pipeline_options, dict): diff --git a/sdks/python/apache_beam/metrics/monitoring_infos.py b/sdks/python/apache_beam/metrics/monitoring_infos.py index 09cb350b3826..5227a4c9872b 100644 --- a/sdks/python/apache_beam/metrics/monitoring_infos.py +++ b/sdks/python/apache_beam/metrics/monitoring_infos.py @@ -182,9 +182,8 @@ def create_labels(ptransform=None, namespace=None, name=None, pcollection=None): return labels -def int64_user_counter(namespace, name, metric, ptransform=None): - # type: (...) -> metrics_pb2.MonitoringInfo - +def int64_user_counter( + namespace, name, metric, ptransform=None) -> metrics_pb2.MonitoringInfo: """Return the counter monitoring info for the specifed URN, metric and labels. Args: @@ -199,9 +198,12 @@ def int64_user_counter(namespace, name, metric, ptransform=None): USER_COUNTER_URN, SUM_INT64_TYPE, metric, labels) -def int64_counter(urn, metric, ptransform=None, pcollection=None, labels=None): - # type: (...) 
-> metrics_pb2.MonitoringInfo - +def int64_counter( + urn, + metric, + ptransform=None, + pcollection=None, + labels=None) -> metrics_pb2.MonitoringInfo: """Return the counter monitoring info for the specifed URN, metric and labels. Args: @@ -217,9 +219,8 @@ def int64_counter(urn, metric, ptransform=None, pcollection=None, labels=None): return create_monitoring_info(urn, SUM_INT64_TYPE, metric, labels) -def int64_user_distribution(namespace, name, metric, ptransform=None): - # type: (...) -> metrics_pb2.MonitoringInfo - +def int64_user_distribution( + namespace, name, metric, ptransform=None) -> metrics_pb2.MonitoringInfo: """Return the distribution monitoring info for the URN, metric and labels. Args: @@ -234,9 +235,11 @@ def int64_user_distribution(namespace, name, metric, ptransform=None): USER_DISTRIBUTION_URN, DISTRIBUTION_INT64_TYPE, payload, labels) -def int64_distribution(urn, metric, ptransform=None, pcollection=None): - # type: (...) -> metrics_pb2.MonitoringInfo - +def int64_distribution( + urn, + metric, + ptransform=None, + pcollection=None) -> metrics_pb2.MonitoringInfo: """Return a distribution monitoring info for the URN, metric and labels. Args: @@ -251,9 +254,8 @@ def int64_distribution(urn, metric, ptransform=None, pcollection=None): return create_monitoring_info(urn, DISTRIBUTION_INT64_TYPE, payload, labels) -def int64_user_gauge(namespace, name, metric, ptransform=None): - # type: (...) -> metrics_pb2.MonitoringInfo - +def int64_user_gauge( + namespace, name, metric, ptransform=None) -> metrics_pb2.MonitoringInfo: """Return the gauge monitoring info for the URN, metric and labels. Args: @@ -276,9 +278,7 @@ def int64_user_gauge(namespace, name, metric, ptransform=None): USER_GAUGE_URN, LATEST_INT64_TYPE, payload, labels) -def int64_gauge(urn, metric, ptransform=None): - # type: (...) 
-> metrics_pb2.MonitoringInfo - +def int64_gauge(urn, metric, ptransform=None) -> metrics_pb2.MonitoringInfo: """Return the gauge monitoring info for the URN, metric and labels. Args: @@ -320,9 +320,8 @@ def user_set_string(namespace, name, metric, ptransform=None): USER_STRING_SET_URN, STRING_SET_TYPE, metric, labels) -def create_monitoring_info(urn, type_urn, payload, labels=None): - # type: (...) -> metrics_pb2.MonitoringInfo - +def create_monitoring_info( + urn, type_urn, payload, labels=None) -> metrics_pb2.MonitoringInfo: """Return the gauge monitoring info for the URN, type, metric and labels. Args: @@ -366,9 +365,9 @@ def is_user_monitoring_info(monitoring_info_proto): return monitoring_info_proto.urn in USER_METRIC_URNS -def extract_metric_result_map_value(monitoring_info_proto): - # type: (...) -> Union[None, int, DistributionResult, GaugeResult, set] - +def extract_metric_result_map_value( + monitoring_info_proto +) -> Union[None, int, DistributionResult, GaugeResult, set]: """Returns the relevant GaugeResult, DistributionResult or int value for counter metric, set for StringSet metric. @@ -408,14 +407,13 @@ def get_step_name(monitoring_info_proto): return monitoring_info_proto.labels.get(PTRANSFORM_LABEL) -def to_key(monitoring_info_proto): - # type: (metrics_pb2.MonitoringInfo) -> FrozenSet[Hashable] - +def to_key( + monitoring_info_proto: metrics_pb2.MonitoringInfo) -> FrozenSet[Hashable]: """Returns a key based on the URN and labels. This is useful in maps to prevent reporting the same MonitoringInfo twice. 
""" - key_items = list(monitoring_info_proto.labels.items()) # type: List[Hashable] + key_items: List[Hashable] = list(monitoring_info_proto.labels.items()) key_items.append(monitoring_info_proto.urn) return frozenset(key_items) diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 837dc0f5439f..2b78832ae493 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -267,8 +267,7 @@ def __getstate__(self): return self.__dict__ @classmethod - def _add_argparse_args(cls, parser): - # type: (_BeamArgumentParser) -> None + def _add_argparse_args(cls, parser: _BeamArgumentParser) -> None: # Override this in subclasses to provide options. pass @@ -317,11 +316,8 @@ def from_dictionary(cls, options): def get_all_options( self, drop_default=False, - add_extra_args_fn=None, # type: Optional[Callable[[_BeamArgumentParser], None]] - retain_unknown_options=False - ): - # type: (...) -> Dict[str, Any] - + add_extra_args_fn: Optional[Callable[[_BeamArgumentParser], None]] = None, + retain_unknown_options=False) -> Dict[str, Any]: """Returns a dictionary of all defined arguments. Returns a dictionary of all defined arguments (arguments that are defined in @@ -446,9 +442,7 @@ def from_urn(key): def display_data(self): return self.get_all_options(drop_default=True, retain_unknown_options=True) - def view_as(self, cls): - # type: (Type[PipelineOptionsT]) -> PipelineOptionsT - + def view_as(self, cls: Type[PipelineOptionsT]) -> PipelineOptionsT: """Returns a view of current object as provided PipelineOption subclass. 
Example Usage:: @@ -487,13 +481,11 @@ def view_as(self, cls): view._all_options = self._all_options return view - def _visible_option_list(self): - # type: () -> List[str] + def _visible_option_list(self) -> List[str]: return sorted( option for option in dir(self._visible_options) if option[0] != '_') - def __dir__(self): - # type: () -> List[str] + def __dir__(self) -> List[str]: return sorted( dir(type(self)) + list(self.__dict__) + self._visible_option_list()) @@ -643,9 +635,9 @@ def additional_option_ptransform_fn(): # Optional type checks that aren't enabled by default. -additional_type_checks = { +additional_type_checks: Dict[str, Callable[[], None]] = { 'ptransform_fn': additional_option_ptransform_fn, -} # type: Dict[str, Callable[[], None]] +} def enable_all_additional_type_checks(): @@ -1836,7 +1828,7 @@ class OptionsContext(object): Can also be used as a decorator. """ - overrides = [] # type: List[Dict[str, Any]] + overrides: List[Dict[str, Any]] = [] def __init__(self, **options): self.options = options diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py index 61aac350280f..1c11f953c58d 100644 --- a/sdks/python/apache_beam/pipeline_test.py +++ b/sdks/python/apache_beam/pipeline_test.py @@ -1053,7 +1053,7 @@ def expand(self, p): self.p = p return p | beam.Create([None]) - def display_data(self): # type: () -> dict + def display_data(self) -> dict: parent_dd = super().display_data() parent_dd['p_dd_string'] = DisplayDataItem( 'p_dd_string_value', label='p_dd_string_label') @@ -1067,7 +1067,7 @@ def expand(self, p): self.p = p return p | beam.Create([None]) - def display_data(self): # type: () -> dict + def display_data(self) -> dict: parent_dd = super().display_data() parent_dd['dd_string'] = DisplayDataItem( 'dd_string_value', label='dd_string_label') @@ -1183,7 +1183,7 @@ class UseMaxValueHint(ResourceHint): @classmethod def get_merged_value( - cls, outer_value, inner_value): # type: (bytes, bytes) -> bytes 
+ cls, outer_value: bytes, inner_value: bytes) -> bytes: return ResourceHint._use_max(outer_value, inner_value) ResourceHint.register_resource_hint('foo_hint', FooHint) @@ -1312,7 +1312,7 @@ class UseMaxValueHint(ResourceHint): @classmethod def get_merged_value( - cls, outer_value, inner_value): # type: (bytes, bytes) -> bytes + cls, outer_value: bytes, inner_value: bytes) -> bytes: return ResourceHint._use_max(outer_value, inner_value) ResourceHint.register_resource_hint('foo_hint', FooHint) diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam.py b/sdks/python/apache_beam/runners/interactive/interactive_beam.py index 60453b5066c3..e3dc8b8968ad 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_beam.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_beam.py @@ -273,9 +273,9 @@ class Recordings(): from all defined unbounded sources for that PCollection's pipeline. The following methods allow for introspection into that background recording job. """ - def describe(self, pipeline=None): - # type: (Optional[beam.Pipeline]) -> dict[str, Any] # noqa: F821 - + def describe( + self, + pipeline: Optional[beam.Pipeline] = None) -> Dict[str, Any]: # noqa: F821 """Returns a description of all the recordings for the given pipeline. If no pipeline is given then this returns a dictionary of descriptions for @@ -292,9 +292,7 @@ def describe(self, pipeline=None): return description[pipeline] return description - def clear(self, pipeline): - # type: (beam.Pipeline) -> bool - + def clear(self, pipeline: beam.Pipeline) -> bool: """Clears all recordings of the given pipeline. 
Returns True if cleared.""" description = self.describe(pipeline) @@ -308,18 +306,14 @@ def clear(self, pipeline): ie.current_env().cleanup(pipeline) return True - def stop(self, pipeline): - # type: (beam.Pipeline) -> None - + def stop(self, pipeline: beam.Pipeline) -> None: """Stops the background source recording of the given pipeline.""" recording_manager = ie.current_env().get_recording_manager( pipeline, create_if_absent=True) recording_manager.cancel() - def record(self, pipeline): - # type: (beam.Pipeline) -> bool - + def record(self, pipeline: beam.Pipeline) -> bool: """Starts a background source recording job for the given pipeline. Returns True if the recording job was started. """ diff --git a/sdks/python/apache_beam/runners/interactive/utils.py b/sdks/python/apache_beam/runners/interactive/utils.py index b7d56ce90acb..828f23a467c2 100644 --- a/sdks/python/apache_beam/runners/interactive/utils.py +++ b/sdks/python/apache_beam/runners/interactive/utils.py @@ -24,12 +24,17 @@ import json import logging from typing import Any +from typing import Callable from typing import Dict +from typing import Iterator +from typing import List from typing import Tuple +from typing import Union import pandas as pd import apache_beam as beam +from apache_beam.coders import Coder from apache_beam.dataframe.convert import to_pcollection from apache_beam.dataframe.frame_base import DeferredBase from apache_beam.options.pipeline_options import PipelineOptions @@ -55,14 +60,13 @@ def to_element_list( - reader, # type: Generator[Union[beam_runner_api_pb2.TestStreamPayload.Event, WindowedValueHolder]] # noqa: F821 - coder, # type: Coder # noqa: F821 - include_window_info, # type: bool - n=None, # type: int - include_time_events=False, # type: bool -): - # type: (...) 
-> List[WindowedValue] # noqa: F821 - + reader: Iterator[Union[beam_runner_api_pb2.TestStreamPayload.Event, + WindowedValueHolder]], + coder: Coder, + include_window_info: bool, + n: int = None, + include_time_events: bool = False, +) -> List[WindowedValue]: """Returns an iterator that properly decodes the elements from the reader. """ @@ -102,9 +106,10 @@ def elements(): count += 1 -def elements_to_df(elements, include_window_info=False, element_type=None): - # type: (List[WindowedValue], bool, Any) -> DataFrame # noqa: F821 - +def elements_to_df( + elements: List[WindowedValue], + include_window_info: bool = False, + element_type: Any = None) -> 'DataFrame': # noqa: F821 """Parses the given elements into a Dataframe. If the elements are a list of WindowedValues, then it will break out the @@ -143,9 +148,7 @@ def elements_to_df(elements, include_window_info=False, element_type=None): return final_df -def register_ipython_log_handler(): - # type: () -> None - +def register_ipython_log_handler() -> None: """Adds the IPython handler to a dummy parent logger (named 'apache_beam.runners.interactive') of all interactive modules' loggers so that if is_in_notebook, logging displays the logs as HTML in frontends. @@ -200,9 +203,7 @@ def emit(self, record): pass # NOOP when dependencies are not available. 
-def obfuscate(*inputs): - # type: (*Any) -> str - +def obfuscate(*inputs: Any) -> str: """Obfuscates any inputs into a hexadecimal string.""" str_inputs = [str(input) for input in inputs] merged_inputs = '_'.join(str_inputs) @@ -223,8 +224,7 @@ class ProgressIndicator(object): spinner_removal_template = """ $("#{id}").remove();""" - def __init__(self, enter_text, exit_text): - # type: (str, str) -> None + def __init__(self, enter_text: str, exit_text: str) -> None: self._id = 'progress_indicator_{}'.format(obfuscate(id(self))) self._enter_text = enter_text @@ -267,9 +267,7 @@ def __exit__(self, exc_type, exc_value, traceback): 'or notebook environment: %s' % e) -def progress_indicated(func): - # type: (Callable[..., Any]) -> Callable[..., Any] # noqa: F821 - +def progress_indicated(func: Callable[..., Any]) -> Callable[..., Any]: """A decorator using a unique progress indicator as a context manager to execute the given function within.""" @functools.wraps(func) @@ -280,9 +278,7 @@ def run_within_progress_indicator(*args, **kwargs): return run_within_progress_indicator -def as_json(func): - # type: (Callable[..., Any]) -> Callable[..., str] # noqa: F821 - +def as_json(func: Callable[..., Any]) -> Callable[..., str]: """A decorator convert python objects returned by callables to json string. @@ -440,9 +436,7 @@ def create_var_in_main(name: str, return name, value -def assert_bucket_exists(bucket_name): - # type: (str) -> None - +def assert_bucket_exists(bucket_name: str) -> None: """Asserts whether the specified GCS bucket with the name bucket_name exists. 
diff --git a/sdks/python/apache_beam/runners/portability/stager.py b/sdks/python/apache_beam/runners/portability/stager.py index 98c0e3176f75..c7142bfddcaf 100644 --- a/sdks/python/apache_beam/runners/portability/stager.py +++ b/sdks/python/apache_beam/runners/portability/stager.py @@ -107,9 +107,9 @@ class Stager(object): """ _DEFAULT_CHUNK_SIZE = 2 << 20 - def stage_artifact(self, local_path_to_artifact, artifact_name, sha256): - # type: (str, str, str) -> None - + def stage_artifact( + self, local_path_to_artifact: str, artifact_name: str, + sha256: str) -> None: """ Stages the artifact to Stager._staging_location and adds artifact_name to the manifest of artifacts that have been staged.""" raise NotImplementedError @@ -159,14 +159,16 @@ def extract_staging_tuple_iter( raise RuntimeError("unknown artifact type: %s" % artifact.type_urn) @staticmethod - def create_job_resources(options, # type: PipelineOptions - temp_dir, # type: str - build_setup_args=None, # type: Optional[List[str]] - pypi_requirements=None, # type: Optional[List[str]] - populate_requirements_cache=None, # type: Optional[Callable[[str, str, bool], None]] - skip_prestaged_dependencies=False, # type: Optional[bool] - log_submission_env_dependencies=True, # type: Optional[bool] - ): + def create_job_resources( + options: PipelineOptions, + temp_dir: str, + build_setup_args: Optional[List[str]] = None, + pypi_requirements: Optional[List[str]] = None, + populate_requirements_cache: Optional[Callable[[str, str, bool], + None]] = None, + skip_prestaged_dependencies: Optional[bool] = False, + log_submission_env_dependencies: Optional[bool] = True, + ): """For internal use only; no backwards-compatibility guarantees. Creates (if needed) a list of job resources. @@ -198,7 +200,7 @@ def create_job_resources(options, # type: PipelineOptions while trying to create the resources (e.g., build a setup package). 
""" - resources = [] # type: List[beam_runner_api_pb2.ArtifactInformation] + resources: List[beam_runner_api_pb2.ArtifactInformation] = [] setup_options = options.view_as(SetupOptions) use_beam_default_container = options.view_as( @@ -381,10 +383,10 @@ def create_job_resources(options, # type: PipelineOptions return resources - def stage_job_resources(self, - resources, # type: List[Tuple[str, str, str]] - staging_location=None # type: Optional[str] - ): + def stage_job_resources( + self, + resources: List[Tuple[str, str, str]], + staging_location: Optional[str] = None): """For internal use only; no backwards-compatibility guarantees. Stages job resources to staging_location. @@ -416,13 +418,13 @@ def stage_job_resources(self, def create_and_stage_job_resources( self, - options, # type: PipelineOptions - build_setup_args=None, # type: Optional[List[str]] - temp_dir=None, # type: Optional[str] - pypi_requirements=None, # type: Optional[List[str]] - populate_requirements_cache=None, # type: Optional[Callable[[str, str, bool], None]] - staging_location=None # type: Optional[str] - ): + options: PipelineOptions, + build_setup_args: Optional[List[str]] = None, + temp_dir: Optional[str] = None, + pypi_requirements: Optional[List[str]] = None, + populate_requirements_cache: Optional[Callable[[str, str, bool], + None]] = None, + staging_location: Optional[str] = None): """For internal use only; no backwards-compatibility guarantees. Creates (if needed) and stages job resources to staging_location. @@ -523,9 +525,8 @@ def _is_remote_path(path): return path.find('://') != -1 @staticmethod - def _create_jar_packages(jar_packages, temp_dir): - # type: (...) -> List[beam_runner_api_pb2.ArtifactInformation] - + def _create_jar_packages( + jar_packages, temp_dir) -> List[beam_runner_api_pb2.ArtifactInformation]: """Creates a list of local jar packages for Java SDK Harness. 
:param jar_packages: Ordered list of local paths to jar packages to be @@ -538,9 +539,9 @@ def _create_jar_packages(jar_packages, temp_dir): RuntimeError: If files specified are not found or do not have expected name patterns. """ - resources = [] # type: List[beam_runner_api_pb2.ArtifactInformation] + resources: List[beam_runner_api_pb2.ArtifactInformation] = [] staging_temp_dir = tempfile.mkdtemp(dir=temp_dir) - local_packages = [] # type: List[str] + local_packages: List[str] = [] for package in jar_packages: if not os.path.basename(package).endswith('.jar'): raise RuntimeError( @@ -574,9 +575,9 @@ def _create_jar_packages(jar_packages, temp_dir): return resources @staticmethod - def _create_extra_packages(extra_packages, temp_dir): - # type: (...) -> List[beam_runner_api_pb2.ArtifactInformation] - + def _create_extra_packages( + extra_packages, + temp_dir) -> List[beam_runner_api_pb2.ArtifactInformation]: """Creates a list of local extra packages. Args: @@ -595,9 +596,9 @@ def _create_extra_packages(extra_packages, temp_dir): RuntimeError: If files specified are not found or do not have expected name patterns. 
""" - resources = [] # type: List[beam_runner_api_pb2.ArtifactInformation] + resources: List[beam_runner_api_pb2.ArtifactInformation] = [] staging_temp_dir = tempfile.mkdtemp(dir=temp_dir) - local_packages = [] # type: List[str] + local_packages: List[str] = [] for package in extra_packages: if not (os.path.basename(package).endswith('.tar') or os.path.basename(package).endswith('.tar.gz') or @@ -665,9 +666,7 @@ def _get_python_executable(): @staticmethod def _remove_dependency_from_requirements( - requirements_file, # type: str - dependency_to_remove, # type: str - temp_directory_path): + requirements_file: str, dependency_to_remove: str, temp_directory_path): """Function to remove dependencies from a given requirements file.""" # read all the dependency names with open(requirements_file, 'r') as f: @@ -776,11 +775,10 @@ def _populate_requirements_cache( processes.check_output(cmd_args, stderr=processes.STDOUT) @staticmethod - def _build_setup_package(setup_file, # type: str - temp_dir, # type: str - build_setup_args=None # type: Optional[List[str]] - ): - # type: (...) -> str + def _build_setup_package( + setup_file: str, + temp_dir: str, + build_setup_args: Optional[List[str]] = None) -> str: saved_current_directory = os.getcwd() try: os.chdir(os.path.dirname(setup_file)) @@ -819,9 +817,7 @@ def _build_setup_package(setup_file, # type: str os.chdir(saved_current_directory) @staticmethod - def _desired_sdk_filename_in_staging_location(sdk_location): - # type: (...) -> str - + def _desired_sdk_filename_in_staging_location(sdk_location) -> str: """Returns the name that SDK file should have in the staging location. Args: sdk_location: Full path to SDK file. @@ -836,9 +832,9 @@ def _desired_sdk_filename_in_staging_location(sdk_location): return names.STAGED_SDK_SOURCES_FILENAME @staticmethod - def _create_beam_sdk(sdk_remote_location, temp_dir): - # type: (...) 
-> List[beam_runner_api_pb2.ArtifactInformation] - + def _create_beam_sdk( + sdk_remote_location, + temp_dir) -> List[beam_runner_api_pb2.ArtifactInformation]: """Creates a Beam SDK file with the appropriate version. Args: diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index 8a04e7efb195..2aec8ca5b914 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -653,8 +653,8 @@ def __init__(self, urn, payload, expansion_service=None): payload.payload() if isinstance(payload, PayloadBuilder) else payload) self._expansion_service = expansion_service self._external_namespace = self._fresh_namespace() - self._inputs = {} # type: Dict[str, pvalue.PCollection] - self._outputs = {} # type: Dict[str, pvalue.PCollection] + self._inputs: Dict[str, pvalue.PCollection] = {} + self._outputs: Dict[str, pvalue.PCollection] = {} def with_output_types(self, *args, **kwargs): return WithTypeHints.with_output_types(self, *args, **kwargs) @@ -691,13 +691,11 @@ def outer_namespace(cls, namespace): cls._external_namespace.value = prev @classmethod - def _fresh_namespace(cls): - # type: () -> str + def _fresh_namespace(cls) -> str: ExternalTransform._namespace_counter += 1 return '%s_%d' % (cls.get_local_namespace(), cls._namespace_counter) - def expand(self, pvalueish): - # type: (pvalue.PCollection) -> pvalue.PCollection + def expand(self, pvalueish: pvalue.PCollection) -> pvalue.PCollection: if isinstance(pvalueish, pvalue.PBegin): self._inputs = {} elif isinstance(pvalueish, (list, tuple)): diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index d760ef74fb14..6f5eeb50f7d4 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -1099,7 +1099,7 @@ def SamplePTransform(pcoll): class PTransformLabelsTest(unittest.TestCase): 
class CustomTransform(beam.PTransform): - pardo = None # type: Optional[beam.PTransform] + pardo: Optional[beam.PTransform] = None def expand(self, pcoll): self.pardo = '*Do*' >> beam.FlatMap(lambda x: [x + 1]) diff --git a/sdks/python/apache_beam/typehints/trivial_inference.py b/sdks/python/apache_beam/typehints/trivial_inference.py index 8b6f43abaa83..fe9007ed63ca 100644 --- a/sdks/python/apache_beam/typehints/trivial_inference.py +++ b/sdks/python/apache_beam/typehints/trivial_inference.py @@ -492,7 +492,7 @@ def infer_return_type_func(f, input_types, debug=False, depth=0): # stack[-has_kwargs]: Map of keyword args. # stack[-1 - has_kwargs]: Iterable of positional args. # stack[-2 - has_kwargs]: Function to call. - has_kwargs = arg & 1 # type: int + has_kwargs: int = arg & 1 pop_count = has_kwargs + 2 if has_kwargs: # TODO(BEAM-24755): Unimplemented. Requires same functionality as a