From 43b237e50407d1c749835aa32e4a5a62ba11931a Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:47:09 -0700 Subject: [PATCH 01/29] Modernize python type hints for apache_beam --- sdks/python/apache_beam/pvalue.py | 125 +++++++++++++----------------- 1 file changed, 55 insertions(+), 70 deletions(-) diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py index 0858d628a55c..5aff1d35aa24 100644 --- a/sdks/python/apache_beam/pvalue.py +++ b/sdks/python/apache_beam/pvalue.py @@ -80,14 +80,14 @@ class PValue(object): (2) Has a transform that can compute the value if executed. (3) Has a value which is meaningful if the transform was executed. """ - - def __init__(self, - pipeline, # type: Pipeline - tag=None, # type: Optional[str] - element_type=None, # type: Optional[Union[type,typehints.TypeConstraint]] - windowing=None, # type: Optional[Windowing] - is_bounded=True, - ): + def __init__( + self, + pipeline: Pipeline, + tag: Optional[str] = None, + element_type: Optional[Union[type, typehints.TypeConstraint]] = None, + windowing: Optional[Windowing] = None, + is_bounded=True, + ): """Initializes a PValue with all arguments hidden behind keyword arguments. Args: @@ -101,7 +101,7 @@ def __init__(self, # The AppliedPTransform instance for the application of the PTransform # generating this PValue. The field gets initialized when a transform # gets applied. - self.producer = None # type: Optional[AppliedPTransform] + self.producer: Optional[AppliedPTransform] = None self.is_bounded = is_bounded if windowing: self._windowing = windowing @@ -152,8 +152,7 @@ def __hash__(self): return hash((self.tag, self.producer)) @property - def windowing(self): - # type: () -> Windowing + def windowing(self) -> Windowing: if not hasattr(self, '_windowing'): assert self.producer is not None and self.producer.transform is not None self._windowing = self.producer.transform.get_windowing( @@ -167,9 +166,7 @@ def __reduce_ex__(self, unused_version): return _InvalidUnpickledPCollection, () @staticmethod - def from_(pcoll, is_bounded=None): - # type: (PValue, Optional[bool]) -> PCollection - + def from_(pcoll: PValue, is_bounded: Optional[bool] = None) -> PCollection: """Create a PCollection, using another PCollection as a starting point. Transfers relevant attributes. @@ -178,8 +175,8 @@ def from_(pcoll, is_bounded=None): is_bounded = pcoll.is_bounded return PCollection(pcoll.pipeline, is_bounded=is_bounded) - def to_runner_api(self, context): - # type: (PipelineContext) -> beam_runner_api_pb2.PCollection + def to_runner_api( + self, context: PipelineContext) -> beam_runner_api_pb2.PCollection: return beam_runner_api_pb2.PCollection( unique_name=self._unique_name(), coder_id=context.coder_id_from_element_type( @@ -189,8 +186,7 @@ def to_runner_api(self, context): windowing_strategy_id=context.windowing_strategies.get_id( self.windowing)) - def _unique_name(self): - # type: () -> str + def _unique_name(self) -> str: if self.producer: return '%d%s.%s' % ( len(self.producer.full_label), self.producer.full_label, self.tag) @@ -198,8 +194,9 @@ def _unique_name(self): return 'PCollection%s' % id(self) @staticmethod - def from_runner_api(proto, context): - # type: (beam_runner_api_pb2.PCollection, PipelineContext) -> PCollection + def from_runner_api( + proto: beam_runner_api_pb2.PCollection, + context: PipelineContext) -> PCollection: # Producer and tag will be filled in later, the key point is that the same # object is returned for the same pcollection id. 
# We pass None for the PCollection's Pipeline to avoid a cycle during @@ -236,14 +233,14 @@ class PDone(PValue): class DoOutputsTuple(object): """An object grouping the multiple outputs of a ParDo or FlatMap transform.""" - - def __init__(self, - pipeline, # type: Pipeline - transform, # type: ParDo - tags, # type: Sequence[str] - main_tag, # type: Optional[str] - allow_unknown_tags=None, # type: Optional[bool] - ): + def __init__( + self, + pipeline: Pipeline, + transform: ParDo, + tags: Sequence[str], + main_tag: Optional[str], + allow_unknown_tags: Optional[bool] = None, + ): self._pipeline = pipeline self._tags = tags self._main_tag = main_tag @@ -253,9 +250,9 @@ def __init__(self, # The ApplyPTransform instance for the application of the multi FlatMap # generating this value. The field gets initialized when a transform # gets applied. - self.producer = None # type: Optional[AppliedPTransform] + self.producer: Optional[AppliedPTransform] = None # Dictionary of PCollections already associated with tags. - self._pcolls = {} # type: Dict[Optional[str], PCollection] + self._pcolls: Dict[Optional[str], PCollection] = {} def __str__(self): return '<%s>' % self._str_internal() @@ -267,25 +264,21 @@ def _str_internal(self): return '%s main_tag=%s tags=%s transform=%s' % ( self.__class__.__name__, self._main_tag, self._tags, self._transform) - def __iter__(self): - # type: () -> Iterator[PCollection] - + def __iter__(self) -> Iterator[PCollection]: """Iterates over tags returning for each call a (tag, pcollection) pair.""" if self._main_tag is not None: yield self[self._main_tag] for tag in self._tags: yield self[tag] - def __getattr__(self, tag): - # type: (str) -> PCollection + def __getattr__(self, tag: str) -> PCollection: # Special methods which may be accessed before the object is # fully constructed (e.g. in unpickling). if tag[:2] == tag[-2:] == '__': return object.__getattr__(self, tag) # type: ignore return self[tag] - def __getitem__(self, tag): - # type: (Union[int, str, None]) -> PCollection + def __getitem__(self, tag: Union[int, str, None]) -> PCollection: # Accept int tags so that we can look at Partition tags with the # same ints that we used in the partition function. # TODO(gildea): Consider requiring string-based tags everywhere. @@ -337,8 +330,7 @@ class TaggedOutput(object): if it wants to emit on the main output and TaggedOutput objects if it wants to emit a value on a specific tagged output. """ - def __init__(self, tag, value): - # type: (str, Any) -> None + def __init__(self, tag: str, value: Any) -> None: if not isinstance(tag, str): raise TypeError( 'Attempting to create a TaggedOutput with non-string tag %s' % @@ -357,8 +349,7 @@ class AsSideInput(object): options, and should not be instantiated directly. (See instead AsSingleton, AsIter, etc.) """ - def __init__(self, pcoll): - # type: (PCollection) -> None + def __init__(self, pcoll: PCollection) -> None: from apache_beam.transforms import sideinputs self.pvalue = pcoll self._window_mapping_fn = sideinputs.default_window_mapping_fn( @@ -389,8 +380,7 @@ def _windowed_coder(self): # TODO(robertwb): Get rid of _from_runtime_iterable and _view_options # in favor of _side_input_data(). 
- def _side_input_data(self): - # type: () -> SideInputData + def _side_input_data(self) -> SideInputData: view_options = self._view_options() from_runtime_iterable = type(self)._from_runtime_iterable return SideInputData( @@ -398,15 +388,14 @@ def _side_input_data(self): self._window_mapping_fn, lambda iterable: from_runtime_iterable(iterable, view_options)) - def to_runner_api(self, context): - # type: (PipelineContext) -> beam_runner_api_pb2.SideInput + def to_runner_api( + self, context: PipelineContext) -> beam_runner_api_pb2.SideInput: return self._side_input_data().to_runner_api(context) @staticmethod - def from_runner_api(proto, # type: beam_runner_api_pb2.SideInput - context # type: PipelineContext - ): - # type: (...) -> _UnpickledSideInput + def from_runner_api( + proto: beam_runner_api_pb2.SideInput, + context: PipelineContext) -> _UnpickledSideInput: return _UnpickledSideInput(SideInputData.from_runner_api(proto, context)) @staticmethod @@ -418,8 +407,7 @@ def requires_keyed_input(self): class _UnpickledSideInput(AsSideInput): - def __init__(self, side_input_data): - # type: (SideInputData) -> None + def __init__(self, side_input_data: SideInputData) -> None: self._data = side_input_data self._window_mapping_fn = side_input_data.window_mapping_fn @@ -450,17 +438,17 @@ def _side_input_data(self): class SideInputData(object): """All of the data about a side input except for the bound PCollection.""" - def __init__(self, - access_pattern, # type: str - window_mapping_fn, # type: sideinputs.WindowMappingFn - view_fn - ): + def __init__( + self, + access_pattern: str, + window_mapping_fn: sideinputs.WindowMappingFn, + view_fn): self.access_pattern = access_pattern self.window_mapping_fn = window_mapping_fn self.view_fn = view_fn - def to_runner_api(self, context): - # type: (PipelineContext) -> beam_runner_api_pb2.SideInput + def to_runner_api( + self, context: PipelineContext) -> beam_runner_api_pb2.SideInput: return beam_runner_api_pb2.SideInput( access_pattern=beam_runner_api_pb2.FunctionSpec( urn=self.access_pattern), @@ -472,8 +460,9 @@ def to_runner_api(self, context): payload=pickler.dumps(self.window_mapping_fn))) @staticmethod - def from_runner_api(proto, unused_context): - # type: (beam_runner_api_pb2.SideInput, PipelineContext) -> SideInputData + def from_runner_api( + proto: beam_runner_api_pb2.SideInput, + unused_context: PipelineContext) -> SideInputData: assert proto.view_fn.urn == python_urns.PICKLED_VIEWFN assert ( proto.window_mapping_fn.urn == python_urns.PICKLED_WINDOW_MAPPING_FN) @@ -501,8 +490,8 @@ class AsSingleton(AsSideInput): """ _NO_DEFAULT = object() - def __init__(self, pcoll, default_value=_NO_DEFAULT): - # type: (PCollection, Any) -> None + def __init__( + self, pcoll: PCollection, default_value: Any = _NO_DEFAULT) -> None: super().__init__(pcoll) self.default_value = default_value @@ -552,8 +541,7 @@ def __repr__(self): def _from_runtime_iterable(it, options): return it - def _side_input_data(self): - # type: () -> SideInputData + def _side_input_data(self) -> SideInputData: return SideInputData( common_urns.side_inputs.ITERABLE.urn, self._window_mapping_fn, @@ -582,8 +570,7 @@ class AsList(AsSideInput): def _from_runtime_iterable(it, options): return list(it) - def _side_input_data(self): - # type: () -> SideInputData + def _side_input_data(self) -> SideInputData: return SideInputData( common_urns.side_inputs.ITERABLE.urn, self._window_mapping_fn, list) @@ -607,8 +594,7 @@ class AsDict(AsSideInput): def _from_runtime_iterable(it, options): return 
dict(it) - def _side_input_data(self): - # type: () -> SideInputData + def _side_input_data(self) -> SideInputData: return SideInputData( common_urns.side_inputs.ITERABLE.urn, self._window_mapping_fn, dict) @@ -631,8 +617,7 @@ def _from_runtime_iterable(it, options): result[k].append(v) return result - def _side_input_data(self): - # type: () -> SideInputData + def _side_input_data(self) -> SideInputData: return SideInputData( common_urns.side_inputs.MULTIMAP.urn, self._window_mapping_fn, From cd495e9cf8ab67bf5e32a7f66996b45bbd5b7ae5 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:47:12 -0700 Subject: [PATCH 02/29] Modernize python type hints for apache_beam/coders --- .../apache_beam/coders/observable_test.py | 2 +- sdks/python/apache_beam/coders/row_coder.py | 3 +- sdks/python/apache_beam/coders/slow_stream.py | 32 +++++++------------ .../coders/standard_coders_test.py | 2 +- sdks/python/apache_beam/coders/typecoders.py | 20 ++++++------ 5 files changed, 24 insertions(+), 35 deletions(-) diff --git a/sdks/python/apache_beam/coders/observable_test.py b/sdks/python/apache_beam/coders/observable_test.py index 46f5186ba533..df4e7ef09408 100644 --- a/sdks/python/apache_beam/coders/observable_test.py +++ b/sdks/python/apache_beam/coders/observable_test.py @@ -29,7 +29,7 @@ class ObservableMixinTest(unittest.TestCase): observed_count = 0 observed_sum = 0 - observed_keys = [] # type: List[Optional[str]] + observed_keys: List[Optional[str]] = [] def observer(self, value, key=None): self.observed_count += 1 diff --git a/sdks/python/apache_beam/coders/row_coder.py b/sdks/python/apache_beam/coders/row_coder.py index 7765ccebc26f..0d0392e94214 100644 --- a/sdks/python/apache_beam/coders/row_coder.py +++ b/sdks/python/apache_beam/coders/row_coder.py @@ -117,8 +117,7 @@ def from_type_hint(cls, type_hint, registry): return cls(schema) @staticmethod - def from_payload(payload): - # type: (bytes) -> RowCoder + def from_payload(payload: bytes) -> RowCoder: return RowCoder(proto_utils.parse_Bytes(payload, schema_pb2.Schema)) def __reduce__(self): diff --git a/sdks/python/apache_beam/coders/slow_stream.py b/sdks/python/apache_beam/coders/slow_stream.py index 71a5b45d7691..b08ad8e9a37f 100644 --- a/sdks/python/apache_beam/coders/slow_stream.py +++ b/sdks/python/apache_beam/coders/slow_stream.py @@ -30,11 +30,10 @@ class OutputStream(object): A pure Python implementation of stream.OutputStream.""" def __init__(self): - self.data = [] # type: List[bytes] + self.data: List[bytes] = [] self.byte_count = 0 - def write(self, b, nested=False): - # type: (bytes, bool) -> None + def write(self, b: bytes, nested: bool = False) -> None: assert isinstance(b, bytes) if nested: self.write_var_int64(len(b)) @@ -45,8 +44,7 @@ def write_byte(self, val): self.data.append(chr(val).encode('latin-1')) self.byte_count += 1 - def write_var_int64(self, v): - # type: (int) -> None + def write_var_int64(self, v: int) -> None: if v < 0: v += 1 << 64 if v <= 0: @@ -78,16 +76,13 @@ def write_bigendian_double(self, v): def write_bigendian_float(self, v): self.write(struct.pack('>f', v)) - def get(self): - # type: () -> bytes + def get(self) -> bytes: return b''.join(self.data) - def size(self): - # type: () -> int + def size(self) -> int: return self.byte_count - def _clear(self): - # type: () -> None + def _clear(self) -> None: self.data = [] self.byte_count = 0 @@ -101,8 +96,7 @@ def __init__(self): super().__init__() self.count = 0 - def write(self, byte_array, nested=False): - # type: (bytes, bool) -> None + 
def write(self, byte_array: bytes, nested: bool = False) -> None: blen = len(byte_array) if nested: self.write_var_int64(blen) @@ -125,25 +119,21 @@ class InputStream(object): """For internal use only; no backwards-compatibility guarantees. A pure Python implementation of stream.InputStream.""" - def __init__(self, data): - # type: (bytes) -> None + def __init__(self, data: bytes) -> None: self.data = data self.pos = 0 def size(self): return len(self.data) - self.pos - def read(self, size): - # type: (int) -> bytes + def read(self, size: int) -> bytes: self.pos += size return self.data[self.pos - size:self.pos] - def read_all(self, nested): - # type: (bool) -> bytes + def read_all(self, nested: bool) -> bytes: return self.read(self.read_var_int64() if nested else self.size()) - def read_byte(self): - # type: () -> int + def read_byte(self) -> int: self.pos += 1 return self.data[self.pos - 1] diff --git a/sdks/python/apache_beam/coders/standard_coders_test.py b/sdks/python/apache_beam/coders/standard_coders_test.py index b2cbe6e339f7..47df0116f2c6 100644 --- a/sdks/python/apache_beam/coders/standard_coders_test.py +++ b/sdks/python/apache_beam/coders/standard_coders_test.py @@ -300,7 +300,7 @@ def json_value_parser(self, coder_spec): # Used when --fix is passed. fix = False - to_fix = {} # type: Dict[Tuple[int, bytes], bytes] + to_fix: Dict[Tuple[int, bytes], bytes] = {} @classmethod def tearDownClass(cls): diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py index e32e4823c48d..1667cb7a916a 100644 --- a/sdks/python/apache_beam/coders/typecoders.py +++ b/sdks/python/apache_beam/coders/typecoders.py @@ -80,8 +80,8 @@ def MakeXyzs(v): class CoderRegistry(object): """A coder registry for typehint/coder associations.""" def __init__(self, fallback_coder=None): - self._coders = {} # type: Dict[Any, Type[coders.Coder]] - self.custom_types = [] # type: List[Any] + self._coders: Dict[Any, Type[coders.Coder]] = {} + self.custom_types: List[Any] = [] self.register_standard_coders(fallback_coder) def register_standard_coders(self, fallback_coder): @@ -104,12 +104,14 @@ def register_standard_coders(self, fallback_coder): def register_fallback_coder(self, fallback_coder): self._fallback_coder = FirstOf([fallback_coder, self._fallback_coder]) - def _register_coder_internal(self, typehint_type, typehint_coder_class): - # type: (Any, Type[coders.Coder]) -> None + def _register_coder_internal( + self, typehint_type: Any, + typehint_coder_class: Type[coders.Coder]) -> None: self._coders[typehint_type] = typehint_coder_class - def register_coder(self, typehint_type, typehint_coder_class): - # type: (Any, Type[coders.Coder]) -> None + def register_coder( + self, typehint_type: Any, + typehint_coder_class: Type[coders.Coder]) -> None: if not isinstance(typehint_coder_class, type): raise TypeError( 'Coder registration requires a coder class object. ' @@ -122,8 +124,7 @@ def register_coder(self, typehint_type, typehint_coder_class): typehint_type = getattr(typehint_type, '__name__', str(typehint_type)) self._register_coder_internal(typehint_type, typehint_coder_class) - def get_coder(self, typehint): - # type: (Any) -> coders.Coder + def get_coder(self, typehint: Any) -> coders.Coder: if typehint and typehint.__module__ == '__main__': # See https://github.com/apache/beam/issues/21541 # TODO(robertwb): Remove once all runners are portable. @@ -187,8 +188,7 @@ class FirstOf(object): """For internal use only; no backwards-compatibility guarantees. 
A class used to get the first matching coder from a list of coders.""" - def __init__(self, coders): - # type: (Iterable[Type[coders.Coder]]) -> None + def __init__(self, coders: Iterable[Type[coders.Coder]]) -> None: self._coders = coders def from_type_hint(self, typehint, registry): From 6143cd0e0a80ea3548cf1fef515daa129772d6cb Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:47:17 -0700 Subject: [PATCH 03/29] Modernize python type hints for apache_beam/dataframe --- sdks/python/apache_beam/dataframe/convert.py | 32 ++++--- .../apache_beam/dataframe/frame_base.py | 20 ++--- .../apache_beam/dataframe/partitionings.py | 9 +- sdks/python/apache_beam/dataframe/schemas.py | 9 +- .../apache_beam/dataframe/schemas_test.py | 85 ++++++++++++------- 5 files changed, 84 insertions(+), 71 deletions(-) diff --git a/sdks/python/apache_beam/dataframe/convert.py b/sdks/python/apache_beam/dataframe/convert.py index 96d0c4f8b9f5..817cabc4b076 100644 --- a/sdks/python/apache_beam/dataframe/convert.py +++ b/sdks/python/apache_beam/dataframe/convert.py @@ -42,12 +42,10 @@ # TODO: Or should this be called as_dataframe? def to_dataframe( - pcoll, # type: pvalue.PCollection - proxy=None, # type: Optional[pd.core.generic.NDFrame] - label=None, # type: Optional[str] -): - # type: (...) -> frame_base.DeferredFrame - + pcoll: pvalue.PCollection, + proxy: Optional[pd.core.generic.NDFrame] = None, + label: Optional[str] = None, +) -> frame_base.DeferredFrame: """Converts a PCollection to a deferred dataframe-like object, which can manipulated with pandas methods like `filter` and `groupby`. @@ -93,10 +91,10 @@ def to_dataframe( # Note that the pipeline (indirectly) holds references to the transforms which # keeps both the PCollections and expressions alive. This ensures the # expression's ids are never accidentally re-used. 
-TO_PCOLLECTION_CACHE = weakref.WeakValueDictionary( -) # type: weakref.WeakValueDictionary[str, pvalue.PCollection] -UNBATCHED_CACHE = weakref.WeakValueDictionary( -) # type: weakref.WeakValueDictionary[str, pvalue.PCollection] +TO_PCOLLECTION_CACHE: weakref.WeakValueDictionary[ + str, pvalue.PCollection] = weakref.WeakValueDictionary() +UNBATCHED_CACHE: weakref.WeakValueDictionary[ + str, pvalue.PCollection] = weakref.WeakValueDictionary() class RowsToDataFrameFn(beam.DoFn): @@ -173,7 +171,7 @@ def infer_output_type(self, input_element_type): def to_pcollection( - *dataframes, # type: Union[frame_base.DeferredFrame, pd.DataFrame, pd.Series] + *dataframes: Union[frame_base.DeferredFrame, pd.DataFrame, pd.Series], label=None, always_return_tuple=False, yield_elements='schemas', @@ -258,12 +256,12 @@ def extract_input(placeholder): df for df in dataframes if df._expr._id not in TO_PCOLLECTION_CACHE ] if len(new_dataframes): - new_results = {p: extract_input(p) - for p in placeholders - } | label >> transforms._DataframeExpressionsTransform({ - ix: df._expr - for (ix, df) in enumerate(new_dataframes) - }) # type: Dict[Any, pvalue.PCollection] + new_results: Dict[Any, pvalue.PCollection] = { + p: extract_input(p) + for p in placeholders + } | label >> transforms._DataframeExpressionsTransform( + {ix: df._expr + for (ix, df) in enumerate(new_dataframes)}) TO_PCOLLECTION_CACHE.update( {new_dataframes[ix]._expr._id: pc diff --git a/sdks/python/apache_beam/dataframe/frame_base.py b/sdks/python/apache_beam/dataframe/frame_base.py index 4e89e473b730..90f34d45dd98 100644 --- a/sdks/python/apache_beam/dataframe/frame_base.py +++ b/sdks/python/apache_beam/dataframe/frame_base.py @@ -38,7 +38,7 @@ class DeferredBase(object): - _pandas_type_map = {} # type: Dict[Union[type, None], type] + _pandas_type_map: Dict[Union[type, None], type] = {} def __init__(self, expr): self._expr = expr @@ -197,8 +197,8 @@ def _proxy_method( inplace=False, base=None, *, - requires_partition_by, # type: partitionings.Partitioning - preserves_partition_by, # type: partitionings.Partitioning + requires_partition_by: partitionings.Partitioning, + preserves_partition_by: partitionings.Partitioning, ): if name is None: name, func = name_and_func(func) @@ -227,14 +227,14 @@ def _elementwise_function( def _proxy_function( - func, # type: Union[Callable, str] - name=None, # type: Optional[str] - restrictions=None, # type: Optional[Dict[str, Union[Any, List[Any]]]] - inplace=False, # type: bool - base=None, # type: Optional[type] + func: Union[Callable, str], + name: Optional[str] = None, + restrictions: Optional[Dict[str, Union[Any, List[Any]]]] = None, + inplace: bool = False, + base: Optional[type] = None, *, - requires_partition_by, # type: partitionings.Partitioning - preserves_partition_by, # type: partitionings.Partitioning + requires_partition_by: partitionings.Partitioning, + preserves_partition_by: partitionings.Partitioning, ): if name is None: diff --git a/sdks/python/apache_beam/dataframe/partitionings.py b/sdks/python/apache_beam/dataframe/partitionings.py index 5513f4bb496e..ca37c504334b 100644 --- a/sdks/python/apache_beam/dataframe/partitionings.py +++ b/sdks/python/apache_beam/dataframe/partitionings.py @@ -32,9 +32,7 @@ class Partitioning(object): def __repr__(self): return self.__class__.__name__ - def is_subpartitioning_of(self, other): - # type: (Partitioning) -> bool - + def is_subpartitioning_of(self, other: Partitioning) -> bool: """Returns whether self is a sub-partition of other. 
Specifically, returns whether something partitioned by self is necissarily @@ -48,9 +46,8 @@ def __lt__(self, other): def __le__(self, other): return not self.is_subpartitioning_of(other) - def partition_fn(self, df, num_partitions): - # type: (Frame, int) -> Iterable[Tuple[Any, Frame]] - + def partition_fn(self, df: Frame, + num_partitions: int) -> Iterable[Tuple[Any, Frame]]: """A callable that actually performs the partitioning of a Frame df. This will be invoked via a FlatMap in conjunction with a GroupKey to diff --git a/sdks/python/apache_beam/dataframe/schemas.py b/sdks/python/apache_beam/dataframe/schemas.py index 6356945e05f9..e70229f21f77 100644 --- a/sdks/python/apache_beam/dataframe/schemas.py +++ b/sdks/python/apache_beam/dataframe/schemas.py @@ -85,9 +85,7 @@ def expand(self, pcoll): | beam.Map(converter.produce_batch)) -def generate_proxy(element_type): - # type: (type) -> pd.DataFrame - +def generate_proxy(element_type: type) -> pd.DataFrame: """Generate a proxy pandas object for the given PCollection element_type. Currently only supports generating a DataFrame proxy from a schema-aware @@ -106,9 +104,8 @@ def generate_proxy(element_type): return proxy -def element_type_from_dataframe(proxy, include_indexes=False): - # type: (pd.DataFrame, bool) -> type - +def element_type_from_dataframe( + proxy: pd.DataFrame, include_indexes: bool = False) -> type: """Generate an element_type for an element-wise PCollection from a proxy pandas object. Currently only supports converting the element_type for a schema-aware PCollection to a proxy DataFrame. diff --git a/sdks/python/apache_beam/dataframe/schemas_test.py b/sdks/python/apache_beam/dataframe/schemas_test.py index ed0ba6b342af..4c196e29e712 100644 --- a/sdks/python/apache_beam/dataframe/schemas_test.py +++ b/sdks/python/apache_beam/dataframe/schemas_test.py @@ -64,36 +64,57 @@ def check_df_pcoll_equal(actual): # pd.Series([b'abc'], dtype=bytes).dtype != 'S' # pd.Series([b'abc'], dtype=bytes).astype(bytes).dtype == 'S' # (test data, pandas_type, column_name, beam_type) -COLUMNS = [ - ([375, 24, 0, 10, 16], np.int32, 'i32', np.int32), - ([375, 24, 0, 10, 16], np.int64, 'i64', np.int64), - ([375, 24, None, 10, 16], - pd.Int32Dtype(), - 'i32_nullable', - typing.Optional[np.int32]), - ([375, 24, None, 10, 16], - pd.Int64Dtype(), - 'i64_nullable', - typing.Optional[np.int64]), - ([375., 24., None, 10., 16.], - np.float64, - 'f64', - typing.Optional[np.float64]), - ([375., 24., None, 10., 16.], - np.float32, - 'f32', - typing.Optional[np.float32]), - ([True, False, True, True, False], bool, 'bool', bool), - (['Falcon', 'Ostrich', None, 3.14, 0], object, 'any', typing.Any), - ([True, False, True, None, False], - pd.BooleanDtype(), - 'bool_nullable', - typing.Optional[bool]), - (['Falcon', 'Ostrich', None, 'Aardvark', 'Elephant'], - pd.StringDtype(), - 'strdtype', - typing.Optional[str]), -] # type: typing.List[typing.Tuple[typing.List[typing.Any], typing.Any, str, typing.Any]] +COLUMNS: typing.List[typing.Tuple[typing.List[typing.Any], + typing.Any, + str, + typing.Any]] = [ + ([375, 24, 0, 10, 16], + np.int32, + 'i32', + np.int32), + ([375, 24, 0, 10, 16], + np.int64, + 'i64', + np.int64), + ([375, 24, None, 10, 16], + pd.Int32Dtype(), + 'i32_nullable', + typing.Optional[np.int32]), + ([375, 24, None, 10, 16], + pd.Int64Dtype(), + 'i64_nullable', + typing.Optional[np.int64]), + ([375., 24., None, 10., 16.], + np.float64, + 'f64', + typing.Optional[np.float64]), + ([375., 24., None, 10., 16.], + np.float32, + 'f32', + 
typing.Optional[np.float32]), + ([True, False, True, True, False], + bool, + 'bool', + bool), + (['Falcon', 'Ostrich', None, 3.14, 0], + object, + 'any', + typing.Any), + ([True, False, True, None, False], + pd.BooleanDtype(), + 'bool_nullable', + typing.Optional[bool]), + ([ + 'Falcon', + 'Ostrich', + None, + 'Aardvark', + 'Elephant' + ], + pd.StringDtype(), + 'strdtype', + typing.Optional[str]), + ] NICE_TYPES_DF = pd.DataFrame(columns=[name for _, _, name, _ in COLUMNS]) for arr, dtype, name, _ in COLUMNS: @@ -104,9 +125,9 @@ def check_df_pcoll_equal(actual): SERIES_TESTS = [(pd.Series(arr, dtype=dtype, name=name), arr, beam_type) for (arr, dtype, name, beam_type) in COLUMNS] -_TEST_ARRAYS = [ +_TEST_ARRAYS: typing.List[typing.List[typing.Any]] = [ arr for (arr, _, _, _) in COLUMNS -] # type: typing.List[typing.List[typing.Any]] +] DF_RESULT = list(zip(*_TEST_ARRAYS)) BEAM_SCHEMA = typing.NamedTuple( # type: ignore 'BEAM_SCHEMA', [(name, beam_type) for _, _, name, beam_type in COLUMNS]) From d75916b6342666f3f1e086e91bd088182c7a0d74 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:47:20 -0700 Subject: [PATCH 04/29] Modernize python type hints for apache_beam/examples/cookbook --- .../apache_beam/examples/cookbook/bigtableio_it_test.py | 2 +- .../apache_beam/examples/cookbook/datastore_wordcount.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py b/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py index 98023fbc624c..8fdb4946ed5f 100644 --- a/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py +++ b/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py @@ -53,7 +53,7 @@ if TYPE_CHECKING: import google.cloud.bigtable.instance -EXISTING_INSTANCES = [] # type: List[google.cloud.bigtable.instance.Instance] +EXISTING_INSTANCES: List[google.cloud.bigtable.instance.Instance] = [] LABEL_KEY = 'python-bigtable-beam' label_stamp = datetime.datetime.utcnow().replace(tzinfo=UTC) label_stamp_micros = _microseconds_from_datetime(label_stamp) diff --git a/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py b/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py index 6a4b9e234297..65ea7990a2d8 100644 --- a/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py +++ b/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py @@ -87,9 +87,7 @@ def __init__(self): self.word_counter = Metrics.counter('main', 'total_words') self.word_lengths_dist = Metrics.distribution('main', 'word_len_dist') - def process(self, element): - # type: (Entity) -> Optional[Iterable[Text]] - + def process(self, element: Entity) -> Optional[Iterable[Text]]: """Extract words from the 'content' property of Cloud Datastore entities. The element is a line of text. If the line is blank, note that, too. 
From c842252b4b5c69a9b6ad078482ae152249d2fbdf Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:47:27 -0700 Subject: [PATCH 05/29] Modernize python type hints for apache_beam/internal --- .../apache_beam/internal/cloudpickle_pickler.py | 4 +--- sdks/python/apache_beam/internal/dill_pickler.py | 7 ++----- sdks/python/apache_beam/internal/module_test.py | 2 +- sdks/python/apache_beam/internal/pickler.py | 3 +-- sdks/python/apache_beam/internal/util.py | 11 +++++------ 5 files changed, 10 insertions(+), 17 deletions(-) diff --git a/sdks/python/apache_beam/internal/cloudpickle_pickler.py b/sdks/python/apache_beam/internal/cloudpickle_pickler.py index 6063faa0b14c..83cdac4b5f33 100644 --- a/sdks/python/apache_beam/internal/cloudpickle_pickler.py +++ b/sdks/python/apache_beam/internal/cloudpickle_pickler.py @@ -46,9 +46,7 @@ RLOCK_TYPE = type(_pickle_lock) -def dumps(o, enable_trace=True, use_zlib=False): - # type: (...) -> bytes - +def dumps(o, enable_trace=True, use_zlib=False) -> bytes: """For internal use only; no backwards-compatibility guarantees.""" with _pickle_lock: with io.BytesIO() as file: diff --git a/sdks/python/apache_beam/internal/dill_pickler.py b/sdks/python/apache_beam/internal/dill_pickler.py index 8a0742642dfb..7f7ac5b214fa 100644 --- a/sdks/python/apache_beam/internal/dill_pickler.py +++ b/sdks/python/apache_beam/internal/dill_pickler.py @@ -309,8 +309,7 @@ def save_module(pickler, obj): # Pickle module dictionaries (commonly found in lambda's globals) # by referencing their module. old_save_module_dict = dill.dill.save_module_dict - known_module_dicts = { - } # type: Dict[int, Tuple[types.ModuleType, Dict[str, Any]]] + known_module_dicts: Dict[int, Tuple[types.ModuleType, Dict[str, Any]]] = {} @dill.dill.register(dict) def new_save_module_dict(pickler, obj): @@ -370,9 +369,7 @@ def new_log_info(msg, *args, **kwargs): logging.getLogger('dill').setLevel(logging.WARN) -def dumps(o, enable_trace=True, use_zlib=False): - # type: (...) -> bytes - +def dumps(o, enable_trace=True, use_zlib=False) -> bytes: """For internal use only; no backwards-compatibility guarantees.""" with _pickle_lock: try: diff --git a/sdks/python/apache_beam/internal/module_test.py b/sdks/python/apache_beam/internal/module_test.py index eaa1629be8e5..55a178b93f82 100644 --- a/sdks/python/apache_beam/internal/module_test.py +++ b/sdks/python/apache_beam/internal/module_test.py @@ -64,7 +64,7 @@ def get(self): class RecursiveClass(object): """A class that contains a reference to itself.""" - SELF_TYPE = None # type: Type[RecursiveClass] + SELF_TYPE: Type[RecursiveClass] = None def __init__(self, datum): self.datum = 'RecursiveClass:%s' % datum diff --git a/sdks/python/apache_beam/internal/pickler.py b/sdks/python/apache_beam/internal/pickler.py index 1685ae928167..79ebd16314bf 100644 --- a/sdks/python/apache_beam/internal/pickler.py +++ b/sdks/python/apache_beam/internal/pickler.py @@ -38,8 +38,7 @@ desired_pickle_lib = dill_pickler -def dumps(o, enable_trace=True, use_zlib=False): - # type: (...) 
-> bytes +def dumps(o, enable_trace=True, use_zlib=False) -> bytes: return desired_pickle_lib.dumps( o, enable_trace=enable_trace, use_zlib=use_zlib) diff --git a/sdks/python/apache_beam/internal/util.py b/sdks/python/apache_beam/internal/util.py index f0a3ad8288b5..85a6e4c43b83 100644 --- a/sdks/python/apache_beam/internal/util.py +++ b/sdks/python/apache_beam/internal/util.py @@ -66,12 +66,11 @@ def __hash__(self): return hash(type(self)) -def remove_objects_from_args(args, # type: Iterable[Any] - kwargs, # type: Dict[str, Any] - pvalue_class # type: Union[Type[T], Tuple[Type[T], ...]] - ): - # type: (...) -> Tuple[List[Any], Dict[str, Any], List[T]] - +def remove_objects_from_args( + args: Iterable[Any], + kwargs: Dict[str, Any], + pvalue_class: Union[Type[T], Tuple[Type[T], ...]] +) -> Tuple[List[Any], Dict[str, Any], List[T]]: """For internal use only; no backwards-compatibility guarantees. Replaces all objects of a given type in args/kwargs with a placeholder. From 33bde4de9c5778a22a8f0143693e958414b49ac0 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:47:28 -0700 Subject: [PATCH 06/29] Modernize python type hints for apache_beam/internal/metrics --- .../apache_beam/internal/metrics/cells.py | 27 +++----- .../apache_beam/internal/metrics/metric.py | 66 ++++++++++--------- 2 files changed, 45 insertions(+), 48 deletions(-) diff --git a/sdks/python/apache_beam/internal/metrics/cells.py b/sdks/python/apache_beam/internal/metrics/cells.py index 3fcaecf8c677..9a5f8c1f3113 100644 --- a/sdks/python/apache_beam/internal/metrics/cells.py +++ b/sdks/python/apache_beam/internal/metrics/cells.py @@ -55,8 +55,7 @@ def __init__(self, bucket_type): def reset(self): self.data = HistogramAggregator(self._bucket_type).identity_element() - def combine(self, other): - # type: (HistogramCell) -> HistogramCell + def combine(self, other: HistogramCell) -> HistogramCell: result = HistogramCell(self._bucket_type) result.data = self.data.combine(other.data) return result @@ -64,8 +63,7 @@ def combine(self, other): def update(self, value): self.data.histogram.record(value) - def get_cumulative(self): - # type: () -> HistogramData + def get_cumulative(self) -> HistogramData: return self.data.get_cumulative() def to_runner_api_monitoring_info(self, name, transform_id): @@ -92,8 +90,7 @@ def __hash__(self): class HistogramResult(object): - def __init__(self, data): - # type: (HistogramData) -> None + def __init__(self, data: HistogramData) -> None: self.data = data def __eq__(self, other): @@ -142,12 +139,10 @@ def __hash__(self): def __repr__(self): return 'HistogramData({})'.format(self.histogram.get_percentile_info()) - def get_cumulative(self): - # type: () -> HistogramData + def get_cumulative(self) -> HistogramData: return HistogramData(self.histogram) - def combine(self, other): - # type: (Optional[HistogramData]) -> HistogramData + def combine(self, other: Optional[HistogramData]) -> HistogramData: if other is None: return self @@ -161,18 +156,14 @@ class HistogramAggregator(MetricAggregator): Values aggregated should be ``HistogramData`` objects. 
""" - def __init__(self, bucket_type): - # type: (BucketType) -> None + def __init__(self, bucket_type: BucketType) -> None: self._bucket_type = bucket_type - def identity_element(self): - # type: () -> HistogramData + def identity_element(self) -> HistogramData: return HistogramData(Histogram(self._bucket_type)) - def combine(self, x, y): - # type: (HistogramData, HistogramData) -> HistogramData + def combine(self, x: HistogramData, y: HistogramData) -> HistogramData: return x.combine(y) - def result(self, x): - # type: (HistogramData) -> HistogramResult + def result(self, x: HistogramData) -> HistogramResult: return HistogramResult(x.get_cumulative()) diff --git a/sdks/python/apache_beam/internal/metrics/metric.py b/sdks/python/apache_beam/internal/metrics/metric.py index f892dd2024a1..35a5b4f3bc6a 100644 --- a/sdks/python/apache_beam/internal/metrics/metric.py +++ b/sdks/python/apache_beam/internal/metrics/metric.py @@ -61,9 +61,10 @@ class Metrics(object): @staticmethod - def counter(urn, labels=None, process_wide=False): - # type: (str, Optional[Dict[str, str]], bool) -> UserMetrics.DelegatingCounter - + def counter( + urn: str, + labels: Optional[Dict[str, str]] = None, + process_wide: bool = False) -> UserMetrics.DelegatingCounter: """Obtains or creates a Counter metric. Args: @@ -82,9 +83,11 @@ def counter(urn, labels=None, process_wide=False): process_wide=process_wide) @staticmethod - def histogram(namespace, name, bucket_type, logger=None): - # type: (Union[Type, str], str, BucketType, Optional[MetricLogger]) -> Metrics.DelegatingHistogram - + def histogram( + namespace: Union[Type, str], + name: str, + bucket_type: BucketType, + logger: Optional[MetricLogger] = None) -> Metrics.DelegatingHistogram: """Obtains or creates a Histogram metric. 
Args: @@ -103,16 +106,18 @@ def histogram(namespace, name, bucket_type, logger=None): class DelegatingHistogram(Histogram): """Metrics Histogram that Delegates functionality to MetricsEnvironment.""" - def __init__(self, metric_name, bucket_type, logger): - # type: (MetricName, BucketType, Optional[MetricLogger]) -> None + def __init__( + self, + metric_name: MetricName, + bucket_type: BucketType, + logger: Optional[MetricLogger]) -> None: super().__init__(metric_name) self.metric_name = metric_name self.cell_type = HistogramCellFactory(bucket_type) self.logger = logger self.updater = MetricUpdater(self.cell_type, self.metric_name) - def update(self, value): - # type: (object) -> None + def update(self, value: object) -> None: self.updater(value) if self.logger: self.logger.update(self.cell_type, self.metric_name, value) @@ -120,27 +125,30 @@ def update(self, value): class MetricLogger(object): """Simple object to locally aggregate and log metrics.""" - def __init__(self): - # type: () -> None - self._metric = {} # type: Dict[MetricName, MetricCell] + def __init__(self) -> None: + self._metric: Dict[MetricName, MetricCell] = {} self._lock = threading.Lock() self._last_logging_millis = int(time.time() * 1000) self.minimum_logging_frequency_msec = 180000 - def update(self, cell_type, metric_name, value): - # type: (Union[Type[MetricCell], MetricCellFactory], MetricName, object) -> None + def update( + self, + cell_type: Union[Type[MetricCell], MetricCellFactory], + metric_name: MetricName, + value: object) -> None: cell = self._get_metric_cell(cell_type, metric_name) cell.update(value) - def _get_metric_cell(self, cell_type, metric_name): - # type: (Union[Type[MetricCell], MetricCellFactory], MetricName) -> MetricCell + def _get_metric_cell( + self, + cell_type: Union[Type[MetricCell], MetricCellFactory], + metric_name: MetricName) -> MetricCell: with self._lock: if metric_name not in self._metric: self._metric[metric_name] = cell_type() return self._metric[metric_name] - def log_metrics(self, reset_after_logging=False): - # type: (bool) -> None + def log_metrics(self, reset_after_logging: bool = False) -> None: if self._lock.acquire(False): try: current_millis = int(time.time() * 1000) @@ -172,14 +180,14 @@ class ServiceCallMetric(object): TODO(ajamato): Add Request latency metric. 
""" - def __init__(self, request_count_urn, base_labels=None): - # type: (str, Optional[Dict[str, str]]) -> None + def __init__( + self, + request_count_urn: str, + base_labels: Optional[Dict[str, str]] = None) -> None: self.base_labels = base_labels if base_labels else {} self.request_count_urn = request_count_urn - def call(self, status): - # type: (Union[int, str, HttpError]) -> None - + def call(self, status: Union[int, str, HttpError]) -> None: """Record the status of the call into appropriate metrics.""" canonical_status = self.convert_to_canonical_status_string(status) additional_labels = {monitoring_infos.STATUS_LABEL: canonical_status} @@ -191,9 +199,8 @@ def call(self, status): urn=self.request_count_urn, labels=labels, process_wide=True) request_counter.inc() - def convert_to_canonical_status_string(self, status): - # type: (Union[int, str, HttpError]) -> str - + def convert_to_canonical_status_string( + self, status: Union[int, str, HttpError]) -> str: """Converts a status to a canonical GCP status cdoe string.""" http_status_code = None if isinstance(status, int): @@ -222,9 +229,8 @@ def convert_to_canonical_status_string(self, status): return str(http_status_code) @staticmethod - def bigtable_error_code_to_grpc_status_string(grpc_status_code): - # type: (Optional[int]) -> str - + def bigtable_error_code_to_grpc_status_string( + grpc_status_code: Optional[int]) -> str: """ Converts the bigtable error code to a canonical GCP status code string. From d73982af05412feffe53f7a56202e4cedce8671c Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:47:33 -0700 Subject: [PATCH 07/29] Modernize python type hints for apache_beam/io --- sdks/python/apache_beam/io/filebasedsource.py | 24 +++--- sdks/python/apache_beam/io/filesystem.py | 77 ++++++------------ .../python/apache_beam/io/hadoopfilesystem.py | 12 +-- sdks/python/apache_beam/io/iobase.py | 79 +++++++------------ sdks/python/apache_beam/io/jdbc.py | 18 ++--- sdks/python/apache_beam/io/localfilesystem.py | 8 +- .../apache_beam/io/restriction_trackers.py | 9 +-- sdks/python/apache_beam/io/textio.py | 66 ++++++++-------- 8 files changed, 109 insertions(+), 184 deletions(-) diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py index 240fc65c52b3..91763ced6e69 100644 --- a/sdks/python/apache_beam/io/filebasedsource.py +++ b/sdks/python/apache_beam/io/filebasedsource.py @@ -135,8 +135,7 @@ def display_data(self): } @check_accessible(['_pattern']) - def _get_concat_source(self): - # type: () -> concat_source.ConcatSource + def _get_concat_source(self) -> concat_source.ConcatSource: if self._concat_source is None: pattern = self._pattern.get() @@ -369,9 +368,8 @@ def process(self, element: Union[str, FileMetadata], *args, class _ReadRange(DoFn): def __init__( self, - source_from_file, # type: Union[str, iobase.BoundedSource] - with_filename=False # type: bool - ) -> None: + source_from_file: Union[str, iobase.BoundedSource], + with_filename: bool = False) -> None: self._source_from_file = source_from_file self._with_filename = with_filename @@ -402,14 +400,14 @@ class ReadAllFiles(PTransform): PTransform authors who wishes to implement file-based Read transforms that read a PCollection of files. 
""" - def __init__(self, - splittable, # type: bool - compression_type, - desired_bundle_size, # type: int - min_bundle_size, # type: int - source_from_file, # type: Callable[[str], iobase.BoundedSource] - with_filename=False # type: bool - ): + def __init__( + self, + splittable: bool, + compression_type, + desired_bundle_size: int, + min_bundle_size: int, + source_from_file: Callable[[str], iobase.BoundedSource], + with_filename: bool = False): """ Args: splittable: If False, files won't be split into sub-ranges. If True, diff --git a/sdks/python/apache_beam/io/filesystem.py b/sdks/python/apache_beam/io/filesystem.py index 142e04bc295e..550079a482c4 100644 --- a/sdks/python/apache_beam/io/filesystem.py +++ b/sdks/python/apache_beam/io/filesystem.py @@ -145,7 +145,7 @@ class CompressedFile(object): def __init__( self, - fileobj, # type: BinaryIO + fileobj: BinaryIO, compression_type=CompressionTypes.GZIP, read_size=DEFAULT_READ_BUFFER_SIZE): if not fileobj: @@ -167,7 +167,7 @@ def __init__( raise ValueError( 'File object must be at position 0 but was %d' % self._file.tell()) self._uncompressed_position = 0 - self._uncompressed_size = None # type: Optional[int] + self._uncompressed_size: Optional[int] = None if self.readable(): self._read_size = read_size @@ -217,19 +217,15 @@ def _initialize_compressor(self): self._compressor = zlib.compressobj( zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, self._gzip_mask) - def readable(self): - # type: () -> bool + def readable(self) -> bool: mode = self._file.mode return 'r' in mode or 'a' in mode - def writeable(self): - # type: () -> bool + def writeable(self) -> bool: mode = self._file.mode return 'w' in mode or 'a' in mode - def write(self, data): - # type: (bytes) -> None - + def write(self, data: bytes) -> None: """Write data to file.""" if not self._compressor: raise ValueError('compressor not initialized') @@ -303,9 +299,7 @@ def read(self, num_bytes: Optional[int] = None) -> bytes: return self._read_from_internal_buffer( lambda: self._read_buffer.read(num_bytes)) - def readline(self): - # type: () -> bytes - + def readline(self) -> bytes: """Equivalent to standard file.readline(). Same return conventions apply.""" if not self._decompressor: raise ValueError('decompressor not initialized') @@ -345,31 +339,24 @@ def flush(self) -> None: self._file.flush() @property - def seekable(self): - # type: () -> bool + def seekable(self) -> bool: return 'r' in self._file.mode - def _clear_read_buffer(self): - # type: () -> None - + def _clear_read_buffer(self) -> None: """Clears the read buffer by removing all the contents and resetting _read_position to 0""" self._read_position = 0 self._read_buffer.seek(0) self._read_buffer.truncate(0) - def _rewind_file(self): - # type: () -> None - + def _rewind_file(self) -> None: """Seeks to the beginning of the input file. Input file's EOF marker is cleared and _uncompressed_position is reset to zero""" self._file.seek(0, os.SEEK_SET) self._read_eof = False self._uncompressed_position = 0 - def _rewind(self): - # type: () -> None - + def _rewind(self) -> None: """Seeks to the beginning of the input file and resets the internal read buffer. 
The decompressor object is re-initialized to ensure that no data left in it's buffer.""" @@ -379,9 +366,7 @@ def _rewind(self): # Re-initialize decompressor to clear any data buffered prior to rewind self._initialize_decompressor() - def seek(self, offset, whence=os.SEEK_SET): - # type: (int, int) -> None - + def seek(self, offset: int, whence: int = os.SEEK_SET) -> None: """Set the file's current offset. Seeking behavior: @@ -445,9 +430,7 @@ def seek(self, offset, whence=os.SEEK_SET): break bytes_to_skip -= len(data) - def tell(self): - # type: () -> int - + def tell(self) -> int: """Returns current position in uncompressed file.""" return self._uncompressed_position @@ -503,8 +486,7 @@ class MatchResult(object): """Result from the ``FileSystem`` match operation which contains the list of matched ``FileMetadata``. """ - def __init__(self, pattern, metadata_list): - # type: (str, List[FileMetadata]) -> None + def __init__(self, pattern: str, metadata_list: List[FileMetadata]) -> None: self.metadata_list = metadata_list self.pattern = pattern @@ -559,9 +541,7 @@ def scheme(cls): raise NotImplementedError @abc.abstractmethod - def join(self, basepath, *paths): - # type: (str, *str) -> str - + def join(self, basepath: str, *paths: str) -> str: """Join two or more pathname components for the filesystem Args: @@ -573,9 +553,7 @@ def join(self, basepath, *paths): raise NotImplementedError @abc.abstractmethod - def split(self, path): - # type: (str) -> Tuple[str, str] - + def split(self, path: str) -> Tuple[str, str]: """Splits the given path into two parts. Splits the path into a pair (head, tail) such that tail contains the last @@ -648,9 +626,8 @@ def _url_dirname(self, url_or_path): scheme, path = self._split_scheme(url_or_path) return self._combine_scheme(scheme, posixpath.dirname(path)) - def match_files(self, file_metas, pattern): - # type: (List[FileMetadata], str) -> Iterator[FileMetadata] - + def match_files(self, file_metas: List[FileMetadata], + pattern: str) -> Iterator[FileMetadata]: """Filter :class:`FileMetadata` objects by *pattern* Args: @@ -671,9 +648,7 @@ def match_files(self, file_metas, pattern): yield file_metadata @staticmethod - def translate_pattern(pattern): - # type: (str) -> str - + def translate_pattern(pattern: str) -> str: """ Translate a *pattern* to a regular expression. There is no way to quote meta-characters. @@ -809,9 +784,7 @@ def create( self, path, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO - + compression_type=CompressionTypes.AUTO) -> BinaryIO: """Returns a write channel for the given file path. Args: @@ -828,9 +801,7 @@ def open( self, path, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO - + compression_type=CompressionTypes.AUTO) -> BinaryIO: """Returns a read channel for the given file path. Args: @@ -870,9 +841,7 @@ def rename(self, source_file_names, destination_file_names): raise NotImplementedError @abc.abstractmethod - def exists(self, path): - # type: (str) -> bool - + def exists(self, path: str) -> bool: """Check if the provided path exists on the FileSystem. Args: @@ -883,9 +852,7 @@ def exists(self, path): raise NotImplementedError @abc.abstractmethod - def size(self, path): - # type: (str) -> int - + def size(self, path: str) -> int: """Get size in bytes of a file on the FileSystem. 
Args: diff --git a/sdks/python/apache_beam/io/hadoopfilesystem.py b/sdks/python/apache_beam/io/hadoopfilesystem.py index c47a66c0f105..cf488c228a28 100644 --- a/sdks/python/apache_beam/io/hadoopfilesystem.py +++ b/sdks/python/apache_beam/io/hadoopfilesystem.py @@ -237,9 +237,7 @@ def create( self, url, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO - + compression_type=CompressionTypes.AUTO) -> BinaryIO: """ Returns: A Python File-like object. @@ -261,9 +259,7 @@ def open( self, url, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO - + compression_type=CompressionTypes.AUTO) -> BinaryIO: """ Returns: A Python File-like object. @@ -356,9 +352,7 @@ def rename(self, source_file_names, destination_file_names): if exceptions: raise BeamIOError('Rename operation failed', exceptions) - def exists(self, url): - # type: (str) -> bool - + def exists(self, url: str) -> bool: """Checks existence of url in HDFS. Args: diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py index 96f154dbe4b8..1b416d37f8a4 100644 --- a/sdks/python/apache_beam/io/iobase.py +++ b/sdks/python/apache_beam/io/iobase.py @@ -104,8 +104,7 @@ class SourceBase(HasDisplayData, urns.RunnerApiFn): def default_output_coder(self): raise NotImplementedError - def is_bounded(self): - # type: () -> bool + def is_bounded(self) -> bool: raise NotImplementedError @@ -144,9 +143,7 @@ class BoundedSource(SourceBase): implementations may invoke methods of ``BoundedSource`` objects through multi-threaded and/or reentrant execution modes. """ - def estimate_size(self): - # type: () -> Optional[int] - + def estimate_size(self) -> Optional[int]: """Estimates the size of source in bytes. An estimate of the total size (in bytes) of the data that would be read @@ -159,13 +156,12 @@ def estimate_size(self): """ raise NotImplementedError - def split(self, - desired_bundle_size, # type: int - start_position=None, # type: Optional[Any] - stop_position=None, # type: Optional[Any] - ): - # type: (...) -> Iterator[SourceBundle] - + def split( + self, + desired_bundle_size: int, + start_position: Optional[Any] = None, + stop_position: Optional[Any] = None, + ) -> Iterator[SourceBundle]: """Splits the source into a set of bundles. Bundles should be approximately of size ``desired_bundle_size`` bytes. @@ -182,12 +178,11 @@ def split(self, """ raise NotImplementedError - def get_range_tracker(self, - start_position, # type: Optional[Any] - stop_position, # type: Optional[Any] - ): - # type: (...) -> RangeTracker - + def get_range_tracker( + self, + start_position: Optional[Any], + stop_position: Optional[Any], + ) -> RangeTracker: """Returns a RangeTracker for a given position range. Framework may invoke ``read()`` method with the RangeTracker object returned @@ -879,9 +874,7 @@ class Read(ptransform.PTransform): # Import runners here to prevent circular imports from apache_beam.runners.pipeline_context import PipelineContext - def __init__(self, source): - # type: (SourceBase) -> None - + def __init__(self, source: SourceBase) -> None: """Initializes a Read transform. Args: @@ -921,12 +914,12 @@ def expand(self, pbegin): return pvalue.PCollection( pbegin.pipeline, is_bounded=self.source.is_bounded()) - def get_windowing(self, unused_inputs): - # type: (...) 
-> core.Windowing + def get_windowing(self, unused_inputs) -> core.Windowing: return core.Windowing(window.GlobalWindows()) - def _infer_output_coder(self, input_type=None, input_coder=None): - # type: (...) -> Optional[coders.Coder] + def _infer_output_coder(self, + input_type=None, + input_coder=None) -> Optional[coders.Coder]: if isinstance(self.source, SourceBase): return self.source.default_output_coder() else: @@ -1129,8 +1122,7 @@ def from_runner_api_parameter( class WriteImpl(ptransform.PTransform): """Implements the writing of custom sinks.""" - def __init__(self, sink): - # type: (Sink) -> None + def __init__(self, sink: Sink) -> None: super().__init__() self.sink = sink @@ -1289,9 +1281,7 @@ def current_restriction(self): """ raise NotImplementedError - def current_progress(self): - # type: () -> RestrictionProgress - + def current_progress(self) -> RestrictionProgress: """Returns a RestrictionProgress object representing the current progress. This API is recommended to be implemented. The runner can do a better job @@ -1416,16 +1406,12 @@ def get_estimator_state(self): """ raise NotImplementedError(type(self)) - def current_watermark(self): - # type: () -> timestamp.Timestamp - + def current_watermark(self) -> timestamp.Timestamp: """Return estimated output_watermark. This function must return monotonically increasing watermarks.""" raise NotImplementedError(type(self)) - def observe_timestamp(self, timestamp): - # type: (timestamp.Timestamp) -> None - + def observe_timestamp(self, timestamp: timestamp.Timestamp) -> None: """Update tracking watermark with latest output timestamp. Args: @@ -1450,8 +1436,7 @@ def __repr__(self): self._fraction, self._completed, self._remaining) @property - def completed_work(self): - # type: () -> float + def completed_work(self) -> float: if self._completed is not None: return self._completed elif self._remaining is not None and self._fraction is not None: @@ -1460,8 +1445,7 @@ def completed_work(self): return self._fraction @property - def remaining_work(self): - # type: () -> float + def remaining_work(self) -> float: if self._remaining is not None: return self._remaining elif self._completed is not None and self._fraction: @@ -1470,28 +1454,24 @@ def remaining_work(self): return 1 - self._fraction @property - def total_work(self): - # type: () -> float + def total_work(self) -> float: return self.completed_work + self.remaining_work @property - def fraction_completed(self): - # type: () -> float + def fraction_completed(self) -> float: if self._fraction is not None: return self._fraction else: return float(self._completed) / self.total_work @property - def fraction_remaining(self): - # type: () -> float + def fraction_remaining(self) -> float: if self._fraction is not None: return 1 - self._fraction else: return float(self._remaining) / self.total_work - def with_completed(self, completed): - # type: (int) -> RestrictionProgress + def with_completed(self, completed: int) -> RestrictionProgress: return RestrictionProgress( fraction=self._fraction, remaining=self._remaining, completed=completed) @@ -1569,8 +1549,7 @@ def __init__(self, restriction): restriction) self.restriction = restriction - def current_progress(self): - # type: () -> RestrictionProgress + def current_progress(self) -> RestrictionProgress: return RestrictionProgress( fraction=self.restriction.range_tracker().fraction_consumed()) diff --git a/sdks/python/apache_beam/io/jdbc.py b/sdks/python/apache_beam/io/jdbc.py index 903b0d1b0fef..3fef1f5fee35 100644 --- 
a/sdks/python/apache_beam/io/jdbc.py +++ b/sdks/python/apache_beam/io/jdbc.py @@ -373,8 +373,7 @@ def __init__(self, argument=""): pass @classmethod - def representation_type(cls): - # type: () -> type + def representation_type(cls) -> type: return Timestamp @classmethod @@ -385,14 +384,12 @@ def urn(cls): def language_type(cls): return datetime.date - def to_representation_type(self, value): - # type: (datetime.date) -> Timestamp + def to_representation_type(self, value: datetime.date) -> Timestamp: return Timestamp.from_utc_datetime( datetime.datetime.combine( value, datetime.datetime.min.time(), tzinfo=datetime.timezone.utc)) - def to_language_type(self, value): - # type: (Timestamp) -> datetime.date + def to_language_type(self, value: Timestamp) -> datetime.date: return value.to_utc_datetime().date() @@ -420,8 +417,7 @@ def __init__(self, argument=""): pass @classmethod - def representation_type(cls): - # type: () -> type + def representation_type(cls) -> type: return Timestamp @classmethod @@ -432,16 +428,14 @@ def urn(cls): def language_type(cls): return datetime.time - def to_representation_type(self, value): - # type: (datetime.date) -> Timestamp + def to_representation_type(self, value: datetime.date) -> Timestamp: return Timestamp.from_utc_datetime( datetime.datetime.combine( datetime.datetime.utcfromtimestamp(0), value, tzinfo=datetime.timezone.utc)) - def to_language_type(self, value): - # type: (Timestamp) -> datetime.date + def to_language_type(self, value: Timestamp) -> datetime.date: return value.to_utc_datetime().time() diff --git a/sdks/python/apache_beam/io/localfilesystem.py b/sdks/python/apache_beam/io/localfilesystem.py index 3580b79ea56f..e9fe7dd4b1c2 100644 --- a/sdks/python/apache_beam/io/localfilesystem.py +++ b/sdks/python/apache_beam/io/localfilesystem.py @@ -147,9 +147,7 @@ def create( self, path, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO - + compression_type=CompressionTypes.AUTO) -> BinaryIO: """Returns a write channel for the given file path. Args: @@ -166,9 +164,7 @@ def open( self, path, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO - + compression_type=CompressionTypes.AUTO) -> BinaryIO: """Returns a read channel for the given file path. Args: diff --git a/sdks/python/apache_beam/io/restriction_trackers.py b/sdks/python/apache_beam/io/restriction_trackers.py index 06b06fa1ed34..37d902aa5f3f 100644 --- a/sdks/python/apache_beam/io/restriction_trackers.py +++ b/sdks/python/apache_beam/io/restriction_trackers.py @@ -62,8 +62,7 @@ def split(self, desired_num_offsets_per_split, min_num_offsets_per_split=1): yield OffsetRange(current_split_start, current_split_stop) current_split_start = current_split_stop - def split_at(self, split_pos): - # type: (...) -> Tuple[OffsetRange, OffsetRange] + def split_at(self, split_pos) -> Tuple[OffsetRange, OffsetRange]: return OffsetRange(self.start, split_pos), OffsetRange(split_pos, self.stop) def new_tracker(self): @@ -78,8 +77,7 @@ class OffsetRestrictionTracker(RestrictionTracker): Offset range is represented as OffsetRange. 
""" - def __init__(self, offset_range): - # type: (OffsetRange) -> None + def __init__(self, offset_range: OffsetRange) -> None: assert isinstance(offset_range, OffsetRange), offset_range self._range = offset_range self._current_position = None @@ -100,8 +98,7 @@ def check_done(self): def current_restriction(self): return self._range - def current_progress(self): - # type: () -> RestrictionProgress + def current_progress(self) -> RestrictionProgress: if self._current_position is None: fraction = 0.0 elif self._range.stop == self._range.start: diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py index 454fe4d69dea..3de9709d7362 100644 --- a/sdks/python/apache_beam/io/textio.py +++ b/sdks/python/apache_beam/io/textio.py @@ -102,18 +102,19 @@ def reset(self): self.data = b'' self.position = 0 - def __init__(self, - file_pattern, - min_bundle_size, - compression_type, - strip_trailing_newlines, - coder, # type: coders.Coder - buffer_size=DEFAULT_READ_BUFFER_SIZE, - validate=True, - skip_header_lines=0, - header_processor_fns=(None, None), - delimiter=None, - escapechar=None): + def __init__( + self, + file_pattern, + min_bundle_size, + compression_type, + strip_trailing_newlines, + coder: coders.Coder, + buffer_size=DEFAULT_READ_BUFFER_SIZE, + validate=True, + skip_header_lines=0, + header_processor_fns=(None, None), + delimiter=None, + escapechar=None): """Initialize a _TextSource Args: @@ -433,21 +434,21 @@ def output_type_hint(self): class _TextSink(filebasedsink.FileBasedSink): """A sink to a GCS or local text file or files.""" - - def __init__(self, - file_path_prefix, - file_name_suffix='', - append_trailing_newlines=True, - num_shards=0, - shard_name_template=None, - coder=coders.ToBytesCoder(), # type: coders.Coder - compression_type=CompressionTypes.AUTO, - header=None, - footer=None, - *, - max_records_per_shard=None, - max_bytes_per_shard=None, - skip_if_empty=False): + def __init__( + self, + file_path_prefix, + file_name_suffix='', + append_trailing_newlines=True, + num_shards=0, + shard_name_template=None, + coder: coders.Coder = coders.ToBytesCoder(), + compression_type=CompressionTypes.AUTO, + header=None, + footer=None, + *, + max_records_per_shard=None, + max_bytes_per_shard=None, + skip_if_empty=False): """Initialize a _TextSink. 
Args: @@ -591,7 +592,7 @@ def __init__( compression_type=CompressionTypes.AUTO, strip_trailing_newlines=True, validate=False, - coder=coders.StrUtf8Coder(), # type: coders.Coder + coder: coders.Coder = coders.StrUtf8Coder(), skip_header_lines=0, with_filename=False, delimiter=None, @@ -742,7 +743,7 @@ def __init__( min_bundle_size=0, compression_type=CompressionTypes.AUTO, strip_trailing_newlines=True, - coder=coders.StrUtf8Coder(), # type: coders.Coder + coder: coders.Coder = coders.StrUtf8Coder(), validate=True, skip_header_lines=0, delimiter=None, @@ -808,15 +809,14 @@ class ReadFromTextWithFilename(ReadFromText): class WriteToText(PTransform): """A :class:`~apache_beam.transforms.ptransform.PTransform` for writing to text files.""" - def __init__( self, - file_path_prefix, # type: str + file_path_prefix: str, file_name_suffix='', append_trailing_newlines=True, num_shards=0, - shard_name_template=None, # type: Optional[str] - coder=coders.ToBytesCoder(), # type: coders.Coder + shard_name_template: Optional[str] = None, + coder: coders.Coder = coders.ToBytesCoder(), compression_type=CompressionTypes.AUTO, header=None, footer=None, From 8f6f24dc1ff91ddd9e243522c48815622effd540 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:47:35 -0700 Subject: [PATCH 08/29] Modernize python type hints for apache_beam/io/azure --- sdks/python/apache_beam/io/azure/blobstoragefilesystem.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py b/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py index 8bc3dd68281f..d18440a50947 100644 --- a/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py +++ b/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py @@ -150,8 +150,8 @@ def create( self, path, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO # noqa: F821 + compression_type=CompressionTypes.AUTO) -> BinaryIO: + # noqa: F821 """Returns a write channel for the given file path. @@ -168,8 +168,8 @@ def open( self, path, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO # noqa: F821 + compression_type=CompressionTypes.AUTO) -> BinaryIO: + # noqa: F821 """Returns a read channel for the given file path. 
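Every hunk in the patches above applies the same mechanical rewrite: the `# type: (...) -> ...` comment under a `def` becomes inline PEP 484 annotations on the parameters and the return value. The sketch below shows that before/after shape on its own; the names `open_legacy` and `open_modern` are invented for illustration (they are not the `create`/`open` methods of `localfilesystem.py` or `blobstoragefilesystem.py`), and `BinaryIO` is imported directly here rather than only for type checking as the real modules may do.

from typing import BinaryIO

# Before: types live in a comment that only type checkers read, and the
# comment's arity has to be kept in sync with the parameter list by hand.
def open_legacy(path, mime_type='application/octet-stream'):
  # type: (str, str) -> BinaryIO
  raise NotImplementedError

# After: the same contract as inline annotations; each type sits next to its
# parameter and default, and the hints are visible to typing.get_type_hints().
def open_modern(
    path: str, mime_type: str = 'application/octet-stream') -> BinaryIO:
  raise NotImplementedError

One practical gain of the inline form is that an annotation cannot drift out of order relative to the parameter list, which is exactly the kind of silent mismatch the comment style permits.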
From b8029e9aa1e272f1673ce0258cff4322634a9e17 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:47:36 -0700 Subject: [PATCH 09/29] Modernize python type hints for apache_beam/io/flink --- .../apache_beam/io/flink/flink_streaming_impulse_source.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source.py b/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source.py index d50672ed6be2..91c76b5d54bf 100644 --- a/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source.py +++ b/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source.py @@ -35,7 +35,7 @@ class FlinkStreamingImpulseSource(PTransform): URN = "flink:transform:streaming_impulse:v1" - config = {} # type: Dict[str, Any] + config: Dict[str, Any] = {} def expand(self, pbegin): assert isinstance(pbegin, pvalue.PBegin), ( From f49a29a4685cab02ed1a61a3fc18338cda2c9445 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:47:48 -0700 Subject: [PATCH 10/29] Modernize python type hints for apache_beam/io/gcp --- .../apache_beam/io/gcp/bigquery_avro_tools.py | 11 ++-- .../io/gcp/bigquery_schema_tools.py | 5 +- .../io/gcp/datastore/v1new/helper.py | 4 +- .../io/gcp/datastore/v1new/types.py | 18 +++---- .../apache_beam/io/gcp/gcsfilesystem.py | 8 +-- sdks/python/apache_beam/io/gcp/pubsub.py | 53 +++++++------------ 6 files changed, 40 insertions(+), 59 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_avro_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_avro_tools.py index d10a4d8fc2a3..c54ba74a7343 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_avro_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_avro_tools.py @@ -63,8 +63,10 @@ def get_record_schema_from_dict_table_schema( - schema_name, table_schema, namespace="apache_beam.io.gcp.bigquery"): - # type: (Text, Dict[Text, Any], Text) -> Dict[Text, Any] # noqa: F821 + schema_name: Text, + table_schema: Dict[Text, Any], + namespace: Text = "apache_beam.io.gcp.bigquery") -> Dict[Text, Any]: + # noqa: F821 """Convert a table schema into an Avro schema. @@ -90,8 +92,9 @@ def get_record_schema_from_dict_table_schema( } -def table_field_to_avro_field(table_field, namespace): - # type: (Dict[Text, Any], str) -> Dict[Text, Any] # noqa: F821 +def table_field_to_avro_field(table_field: Dict[Text, Any], + namespace: str) -> Dict[Text, Any]: + # noqa: F821 """Convert a BigQuery field to an avro field. diff --git a/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py index 7b8a58e96978..4ba8e2b84bad 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py @@ -53,9 +53,8 @@ } -def generate_user_type_from_bq_schema(the_table_schema, selected_fields=None): - #type: (bigquery.TableSchema) -> type - +def generate_user_type_from_bq_schema( + the_table_schema, selected_fields: bigquery.TableSchema = None) -> type: """Convert a schema of type TableSchema into a pcollection element. 
Args: the_table_schema: A BQ schema of type TableSchema diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/helper.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/helper.py index a6f8ef594695..417a04c3d2b4 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1new/helper.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/helper.py @@ -70,9 +70,9 @@ def retry_on_rpc_error(exception): def create_entities(count, id_or_name=False): """Creates a list of entities with random keys.""" if id_or_name: - ids_or_names = [ + ids_or_names: List[Union[str, int]] = [ uuid.uuid4().int & ((1 << 63) - 1) for _ in range(count) - ] # type: List[Union[str, int]] + ] else: ids_or_names = [str(uuid.uuid4()) for _ in range(count)] diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py index 137df4235d47..2029886f24a9 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py @@ -153,12 +153,12 @@ def __repr__(self): class Key(object): - def __init__(self, - path_elements, # type: List[Union[Text, int]] - parent=None, # type: Optional[Key] - project=None, # type: Optional[Text] - namespace=None # type: Optional[Text] - ): + def __init__( + self, + path_elements: List[Union[Text, int]], + parent: Optional[Key] = None, + project: Optional[Text] = None, + namespace: Optional[Text] = None): """ Represents a Datastore key. @@ -229,11 +229,7 @@ def __repr__(self): class Entity(object): - def __init__( - self, - key, # type: Key - exclude_from_indexes=() # type: Iterable[str] - ): + def __init__(self, key: Key, exclude_from_indexes: Iterable[str] = ()): """ Represents a Datastore entity. diff --git a/sdks/python/apache_beam/io/gcp/gcsfilesystem.py b/sdks/python/apache_beam/io/gcp/gcsfilesystem.py index 173b21a38f88..47d1997ddc7b 100644 --- a/sdks/python/apache_beam/io/gcp/gcsfilesystem.py +++ b/sdks/python/apache_beam/io/gcp/gcsfilesystem.py @@ -159,9 +159,7 @@ def create( self, path, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO - + compression_type=CompressionTypes.AUTO) -> BinaryIO: """Returns a write channel for the given file path. Args: @@ -177,9 +175,7 @@ def open( self, path, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO - + compression_type=CompressionTypes.AUTO) -> BinaryIO: """Returns a read channel for the given file path. Args: diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py index cec65bc530f3..6267837269e1 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub.py +++ b/sdks/python/apache_beam/io/gcp/pubsub.py @@ -110,9 +110,7 @@ def __repr__(self): return 'PubsubMessage(%s, %s)' % (self.data, self.attributes) @staticmethod - def _from_proto_str(proto_msg): - # type: (bytes) -> PubsubMessage - + def _from_proto_str(proto_msg: bytes) -> PubsubMessage: """Construct from serialized form of ``PubsubMessage``. Args: @@ -185,9 +183,7 @@ def _to_proto_str(self, for_publish=False): return serialized @staticmethod - def _from_message(msg): - # type: (Any) -> PubsubMessage - + def _from_message(msg: Any) -> PubsubMessage: """Construct from ``google.cloud.pubsub_v1.subscriber.message.Message``. 
https://googleapis.github.io/google-cloud-python/latest/pubsub/subscriber/api/message.html @@ -211,14 +207,11 @@ class ReadFromPubSub(PTransform): def __init__( self, - topic=None, # type: Optional[str] - subscription=None, # type: Optional[str] - id_label=None, # type: Optional[str] - with_attributes=False, # type: bool - timestamp_attribute=None # type: Optional[str] - ): - # type: (...) -> None - + topic: Optional[str] = None, + subscription: Optional[str] = None, + id_label: Optional[str] = None, + with_attributes: bool = False, + timestamp_attribute: Optional[str] = None) -> None: """Initializes ``ReadFromPubSub``. Args: @@ -327,13 +320,10 @@ class WriteToPubSub(PTransform): def __init__( self, - topic, # type: str - with_attributes=False, # type: bool - id_label=None, # type: Optional[str] - timestamp_attribute=None # type: Optional[str] - ): - # type: (...) -> None - + topic: str, + with_attributes: bool = False, + id_label: Optional[str] = None, + timestamp_attribute: Optional[str] = None) -> None: """Initializes ``WriteToPubSub``. Args: @@ -359,8 +349,7 @@ def __init__( self._sink = _PubSubSink(topic, id_label, timestamp_attribute) @staticmethod - def message_to_proto_str(element): - # type: (PubsubMessage) -> bytes + def message_to_proto_str(element: PubsubMessage) -> bytes: if not isinstance(element, PubsubMessage): raise TypeError( 'Unexpected element. Type: %s (expected: PubsubMessage), ' @@ -368,8 +357,7 @@ def message_to_proto_str(element): return element._to_proto_str(for_publish=True) @staticmethod - def bytes_to_proto_str(element): - # type: (bytes) -> bytes + def bytes_to_proto_str(element: bytes) -> bytes: msg = PubsubMessage(element, {}) return msg._to_proto_str(for_publish=True) @@ -438,12 +426,11 @@ class _PubSubSource(iobase.SourceBase): """ def __init__( self, - topic=None, # type: Optional[str] - subscription=None, # type: Optional[str] - id_label=None, # type: Optional[str] - with_attributes=False, # type: bool - timestamp_attribute=None # type: Optional[str] - ): + topic: Optional[str] = None, + subscription: Optional[str] = None, + id_label: Optional[str] = None, + with_attributes: bool = False, + timestamp_attribute: Optional[str] = None): self.coder = coders.BytesCoder() self.full_topic = topic self.full_subscription = subscription @@ -562,8 +549,8 @@ class MultipleReadFromPubSub(PTransform): """ def __init__( self, - pubsub_source_descriptors, # type: List[PubSubSourceDescriptor] - with_attributes=False, # type: bool + pubsub_source_descriptors: List[PubSubSourceDescriptor], + with_attributes: bool = False, ): """Initializes ``PubSubMultipleReader``. 
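Patches 09 and 10 also migrate variable type comments to PEP 526 annotated assignments (the `config` class attribute in `flink_streaming_impulse_source.py`, the `ids_or_names` list in the Datastore `helper.py`). Below is a minimal sketch of that pattern under the same caveat: `config` and `make_ids` are invented names for illustration, not Beam code.

from typing import Dict, List, Union

# Before (comment form):
#   config = {}  # type: Dict[str, int]
# After: a PEP 526 annotated assignment, which checkers read in place and
# which is also recorded in the module's __annotations__ at runtime.
config: Dict[str, int] = {}


def make_ids(count: int, numeric: bool = False) -> List[Union[int, str]]:
  """Returns `count` ids, numeric or string depending on `numeric`."""
  # Declaring the union type up front (similar to what the Datastore helper
  # does for `ids_or_names`) lets both branches assign to the same name
  # without the checker narrowing it to whichever branch it inferred first.
  ids: List[Union[int, str]]
  if numeric:
    ids = list(range(count))
  else:
    ids = ['id-%d' % i for i in range(count)]
  return ids


print(make_ids(3), make_ids(3, numeric=True))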
From 5d5a09b6b3e1fe2652599255336c19e165a16c37 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:47:49 -0700 Subject: [PATCH 11/29] Modernize python type hints for apache_beam/metrics --- sdks/python/apache_beam/metrics/metric.py | 90 +++++++------------ sdks/python/apache_beam/metrics/metricbase.py | 12 +-- 2 files changed, 40 insertions(+), 62 deletions(-) diff --git a/sdks/python/apache_beam/metrics/metric.py b/sdks/python/apache_beam/metrics/metric.py index 08a359edae90..c107e55fcd89 100644 --- a/sdks/python/apache_beam/metrics/metric.py +++ b/sdks/python/apache_beam/metrics/metric.py @@ -57,8 +57,7 @@ class Metrics(object): """Lets users create/access metric objects during pipeline execution.""" @staticmethod - def get_namespace(namespace): - # type: (Union[Type, str]) -> str + def get_namespace(namespace: Union[Type, str]) -> str: if isinstance(namespace, type): return '{}.{}'.format(namespace.__module__, namespace.__name__) elif isinstance(namespace, str): @@ -67,9 +66,8 @@ def get_namespace(namespace): raise ValueError('Unknown namespace type') @staticmethod - def counter(namespace, name): - # type: (Union[Type, str], str) -> Metrics.DelegatingCounter - + def counter( + namespace: Union[Type, str], name: str) -> Metrics.DelegatingCounter: """Obtains or creates a Counter metric. Args: @@ -83,9 +81,8 @@ def counter(namespace, name): return Metrics.DelegatingCounter(MetricName(namespace, name)) @staticmethod - def distribution(namespace, name): - # type: (Union[Type, str], str) -> Metrics.DelegatingDistribution - + def distribution( + namespace: Union[Type, str], name: str) -> Metrics.DelegatingDistribution: """Obtains or creates a Distribution metric. Distribution metrics are restricted to integer-only distributions. @@ -101,9 +98,7 @@ def distribution(namespace, name): return Metrics.DelegatingDistribution(MetricName(namespace, name)) @staticmethod - def gauge(namespace, name): - # type: (Union[Type, str], str) -> Metrics.DelegatingGauge - + def gauge(namespace: Union[Type, str], name: str) -> Metrics.DelegatingGauge: """Obtains or creates a Gauge metric. Gauge metrics are restricted to integer-only values. 
@@ -120,8 +115,8 @@ def gauge(namespace, name): class DelegatingCounter(Counter): """Metrics Counter that Delegates functionality to MetricsEnvironment.""" - def __init__(self, metric_name, process_wide=False): - # type: (MetricName, bool) -> None + def __init__( + self, metric_name: MetricName, process_wide: bool = False) -> None: super().__init__(metric_name) self.inc = MetricUpdater( # type: ignore[assignment] cells.CounterCell, @@ -131,15 +126,13 @@ def __init__(self, metric_name, process_wide=False): class DelegatingDistribution(Distribution): """Metrics Distribution Delegates functionality to MetricsEnvironment.""" - def __init__(self, metric_name): - # type: (MetricName) -> None + def __init__(self, metric_name: MetricName) -> None: super().__init__(metric_name) self.update = MetricUpdater(cells.DistributionCell, metric_name) # type: ignore[assignment] class DelegatingGauge(Gauge): """Metrics Gauge that Delegates functionality to MetricsEnvironment.""" - def __init__(self, metric_name): - # type: (MetricName) -> None + def __init__(self, metric_name: MetricName) -> None: super().__init__(metric_name) self.set = MetricUpdater(cells.GaugeCell, metric_name) # type: ignore[assignment] @@ -150,8 +143,7 @@ class MetricResults(object): GAUGES = "gauges" @staticmethod - def _matches_name(filter, metric_key): - # type: (MetricsFilter, MetricKey) -> bool + def _matches_name(filter: MetricsFilter, metric_key: MetricKey) -> bool: if ((filter.namespaces and metric_key.metric.namespace not in filter.namespaces) or (filter.names and metric_key.metric.name not in filter.names)): @@ -160,9 +152,7 @@ def _matches_name(filter, metric_key): return True @staticmethod - def _is_sub_list(needle, haystack): - # type: (List[str], List[str]) -> bool - + def _is_sub_list(needle: List[str], haystack: List[str]) -> bool: """True iff `needle` is a sub-list of `haystack` (i.e. a contiguous slice of `haystack` exactly matches `needle`""" needle_len = len(needle) @@ -174,9 +164,7 @@ def _is_sub_list(needle, haystack): return False @staticmethod - def _matches_sub_path(actual_scope, filter_scope): - # type: (str, str) -> bool - + def _matches_sub_path(actual_scope: str, filter_scope: str) -> bool: """True iff the '/'-delimited pieces of filter_scope exist as a sub-list of the '/'-delimited pieces of actual_scope""" return bool( @@ -184,8 +172,7 @@ def _matches_sub_path(actual_scope, filter_scope): filter_scope.split('/'), actual_scope.split('/'))) @staticmethod - def _matches_scope(filter, metric_key): - # type: (MetricsFilter, MetricKey) -> bool + def _matches_scope(filter: MetricsFilter, metric_key: MetricKey) -> bool: if not filter.steps: return True @@ -196,8 +183,7 @@ def _matches_scope(filter, metric_key): return False @staticmethod - def matches(filter, metric_key): - # type: (Optional[MetricsFilter], MetricKey) -> bool + def matches(filter: Optional[MetricsFilter], metric_key: MetricKey) -> bool: if filter is None: return True @@ -206,9 +192,9 @@ def matches(filter, metric_key): return True return False - def query(self, filter=None): - # type: (Optional[MetricsFilter]) -> Dict[str, List[MetricResults]] - + def query( + self, + filter: Optional[MetricsFilter] = None) -> Dict[str, List[MetricResults]]: """Queries the runner for existing user metrics that match the filter. It should return a dictionary, with lists of each kind of metric, and @@ -236,63 +222,53 @@ class MetricsFilter(object): Note: This class only supports user defined metrics. 
""" - def __init__(self): - # type: () -> None - self._names = set() # type: Set[str] - self._namespaces = set() # type: Set[str] - self._steps = set() # type: Set[str] + def __init__(self) -> None: + self._names: Set[str] = set() + self._namespaces: Set[str] = set() + self._steps: Set[str] = set() @property - def steps(self): - # type: () -> FrozenSet[str] + def steps(self) -> FrozenSet[str]: return frozenset(self._steps) @property - def names(self): - # type: () -> FrozenSet[str] + def names(self) -> FrozenSet[str]: return frozenset(self._names) @property - def namespaces(self): - # type: () -> FrozenSet[str] + def namespaces(self) -> FrozenSet[str]: return frozenset(self._namespaces) - def with_metric(self, metric): - # type: (Metric) -> MetricsFilter + def with_metric(self, metric: Metric) -> MetricsFilter: name = metric.metric_name.name or '' namespace = metric.metric_name.namespace or '' return self.with_name(name).with_namespace(namespace) - def with_name(self, name): - # type: (str) -> MetricsFilter + def with_name(self, name: str) -> MetricsFilter: return self.with_names([name]) - def with_names(self, names): - # type: (Iterable[str]) -> MetricsFilter + def with_names(self, names: Iterable[str]) -> MetricsFilter: if isinstance(names, str): raise ValueError('Names must be a collection, not a string') self._names.update(names) return self - def with_namespace(self, namespace): - # type: (Union[Type, str]) -> MetricsFilter + def with_namespace(self, namespace: Union[Type, str]) -> MetricsFilter: return self.with_namespaces([namespace]) - def with_namespaces(self, namespaces): - # type: (Iterable[Union[Type, str]]) -> MetricsFilter + def with_namespaces( + self, namespaces: Iterable[Union[Type, str]]) -> MetricsFilter: if isinstance(namespaces, str): raise ValueError('Namespaces must be an iterable, not a string') self._namespaces.update([Metrics.get_namespace(ns) for ns in namespaces]) return self - def with_step(self, step): - # type: (str) -> MetricsFilter + def with_step(self, step: str) -> MetricsFilter: return self.with_steps([step]) - def with_steps(self, steps): - # type: (Iterable[str]) -> MetricsFilter + def with_steps(self, steps: Iterable[str]) -> MetricsFilter: if isinstance(steps, str): raise ValueError('Steps must be an iterable, not a string') diff --git a/sdks/python/apache_beam/metrics/metricbase.py b/sdks/python/apache_beam/metrics/metricbase.py index 12e7881792f9..53da01f3955c 100644 --- a/sdks/python/apache_beam/metrics/metricbase.py +++ b/sdks/python/apache_beam/metrics/metricbase.py @@ -49,9 +49,12 @@ class MetricName(object): allows grouping related metrics together and also prevents collisions between multiple metrics of the same name. """ - def __init__(self, namespace, name, urn=None, labels=None): - # type: (Optional[str], Optional[str], Optional[str], Optional[Dict[str, str]]) -> None - + def __init__( + self, + namespace: Optional[str], + name: Optional[str], + urn: Optional[str] = None, + labels: Optional[Dict[str, str]] = None) -> None: """Initializes ``MetricName``. 
Note: namespace and name should be set for user metrics, @@ -103,8 +106,7 @@ def fast_name(self): class Metric(object): """Base interface of a metric object.""" - def __init__(self, metric_name): - # type: (MetricName) -> None + def __init__(self, metric_name: MetricName) -> None: self.metric_name = metric_name From fbafe8d780b9bfcf6beb226e3f721bf369c316f7 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:47:51 -0700 Subject: [PATCH 12/29] Modernize python type hints for apache_beam/ml/gcp --- .../apache_beam/ml/gcp/naturallanguageml.py | 36 +++++++++---------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/sdks/python/apache_beam/ml/gcp/naturallanguageml.py b/sdks/python/apache_beam/ml/gcp/naturallanguageml.py index 4f63aef68232..ceeae522890c 100644 --- a/sdks/python/apache_beam/ml/gcp/naturallanguageml.py +++ b/sdks/python/apache_beam/ml/gcp/naturallanguageml.py @@ -52,15 +52,13 @@ class Document(object): from_gcs (bool): Whether the content should be interpret as a Google Cloud Storage URI. The default value is :data:`False`. """ - def __init__( self, - content, # type: str - type='PLAIN_TEXT', # type: Union[str, language_v1.Document.Type] - language_hint=None, # type: Optional[str] - encoding='UTF8', # type: Optional[str] - from_gcs=False # type: bool - ): + content: str, + type: Union[str, language_v1.Document.Type] = 'PLAIN_TEXT', + language_hint: Optional[str] = None, + encoding: Optional[str] = 'UTF8', + from_gcs: bool = False): self.content = content self.type = type self.encoding = encoding @@ -68,8 +66,7 @@ def __init__( self.from_gcs = from_gcs @staticmethod - def to_dict(document): - # type: (Document) -> Mapping[str, Optional[str]] + def to_dict(document: Document) -> Mapping[str, Optional[str]]: if document.from_gcs: dict_repr = {'gcs_content_uri': document.content} else: @@ -82,11 +79,11 @@ def to_dict(document): @beam.ptransform_fn def AnnotateText( - pcoll, # type: beam.pvalue.PCollection - features, # type: Union[Mapping[str, bool], language_v1.AnnotateTextRequest.Features] - timeout=None, # type: Optional[float] - metadata=None # type: Optional[Sequence[Tuple[str, str]]] -): + pcoll: beam.pvalue.PCollection, + features: Union[Mapping[str, bool], + language_v1.AnnotateTextRequest.Features], + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None): """A :class:`~apache_beam.transforms.ptransform.PTransform` for annotating text using the Google Cloud Natural Language API: https://cloud.google.com/natural-language/docs. 
@@ -113,10 +110,10 @@ def AnnotateText( class _AnnotateTextFn(beam.DoFn): def __init__( self, - features, # type: Union[Mapping[str, bool], language_v1.AnnotateTextRequest.Features] - timeout, # type: Optional[float] - metadata=None # type: Optional[Sequence[Tuple[str, str]]] - ): + features: Union[Mapping[str, bool], + language_v1.AnnotateTextRequest.Features], + timeout: Optional[float], + metadata: Optional[Sequence[Tuple[str, str]]] = None): self.features = features self.timeout = timeout self.metadata = metadata @@ -127,8 +124,7 @@ def setup(self): self.client = self._get_api_client() @staticmethod - def _get_api_client(): - # type: () -> language.LanguageServiceClient + def _get_api_client() -> language.LanguageServiceClient: return language.LanguageServiceClient() def process(self, element): From 0eab29802a3954e91acd3dbbd1d8821cf333b0d2 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:47:54 -0700 Subject: [PATCH 13/29] Modernize python type hints for apache_beam/options --- sdks/python/apache_beam/options/value_provider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/options/value_provider.py b/sdks/python/apache_beam/options/value_provider.py index 5a5d36370f39..fa1649beed26 100644 --- a/sdks/python/apache_beam/options/value_provider.py +++ b/sdks/python/apache_beam/options/value_provider.py @@ -95,7 +95,7 @@ class RuntimeValueProvider(ValueProvider): at graph construction time. """ runtime_options = None - experiments = set() # type: Set[str] + experiments: Set[str] = set() def __init__(self, option_name, value_type, default_value): self.option_name = option_name From 842b8ecab01b9e7e21afe82f67468332028f661b Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:47:57 -0700 Subject: [PATCH 14/29] Modernize python type hints for apache_beam/runners --- .../apache_beam/runners/pipeline_context.py | 129 ++++++++---------- sdks/python/apache_beam/runners/runner.py | 42 ++---- sdks/python/apache_beam/runners/sdf_utils.py | 26 ++-- 3 files changed, 85 insertions(+), 112 deletions(-) diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py index 102b8b60d69a..44961241dc15 100644 --- a/sdks/python/apache_beam/runners/pipeline_context.py +++ b/sdks/python/apache_beam/runners/pipeline_context.py @@ -59,13 +59,11 @@ class PortableObject(Protocol): - def to_runner_api(self, __context): - # type: (PipelineContext) -> Any + def to_runner_api(self, __context: PipelineContext) -> Any: pass @classmethod - def from_runner_api(cls, __proto, __context): - # type: (Any, PipelineContext) -> Any + def from_runner_api(cls, __proto: Any, __context: PipelineContext) -> Any: pass @@ -75,27 +73,24 @@ class _PipelineContextMap(Generic[PortableObjectT]): Under the hood it encodes and decodes these objects into runner API representations. """ - def __init__(self, - context, # type: PipelineContext - obj_type, # type: Type[PortableObjectT] - namespace, # type: str - proto_map=None # type: Optional[Mapping[str, message.Message]] - ): - # type: (...) 
-> None + def __init__( + self, + context: PipelineContext, + obj_type: Type[PortableObjectT], + namespace: str, + proto_map: Optional[Mapping[str, message.Message]] = None) -> None: self._pipeline_context = context self._obj_type = obj_type self._namespace = namespace - self._obj_to_id = {} # type: Dict[Any, str] - self._id_to_obj = {} # type: Dict[str, Any] + self._obj_to_id: Dict[Any, str] = {} + self._id_to_obj: Dict[str, Any] = {} self._id_to_proto = dict(proto_map) if proto_map else {} - def populate_map(self, proto_map): - # type: (Mapping[str, message.Message]) -> None + def populate_map(self, proto_map: Mapping[str, message.Message]) -> None: for id, proto in self._id_to_proto.items(): proto_map[id].CopyFrom(proto) - def get_id(self, obj, label=None): - # type: (PortableObjectT, Optional[str]) -> str + def get_id(self, obj: PortableObjectT, label: Optional[str] = None) -> str: if obj not in self._obj_to_id: id = self._pipeline_context.component_id_map.get_or_assign( obj, self._obj_type, label) @@ -104,19 +99,23 @@ def get_id(self, obj, label=None): self._id_to_proto[id] = obj.to_runner_api(self._pipeline_context) return self._obj_to_id[obj] - def get_proto(self, obj, label=None): - # type: (PortableObjectT, Optional[str]) -> message.Message + def get_proto( + self, + obj: PortableObjectT, + label: Optional[str] = None) -> message.Message: return self._id_to_proto[self.get_id(obj, label)] - def get_by_id(self, id): - # type: (str) -> PortableObjectT + def get_by_id(self, id: str) -> PortableObjectT: if id not in self._id_to_obj: self._id_to_obj[id] = self._obj_type.from_runner_api( self._id_to_proto[id], self._pipeline_context) return self._id_to_obj[id] - def get_by_proto(self, maybe_new_proto, label=None, deduplicate=False): - # type: (message.Message, Optional[str], bool) -> str + def get_by_proto( + self, + maybe_new_proto: message.Message, + label: Optional[str] = None, + deduplicate: bool = False) -> str: # TODO: this method may not be safe for arbitrary protos due to # xlang concerns, hence limiting usage to the only current use-case it has. # See: https://github.com/apache/beam/pull/14390#discussion_r616062377 @@ -136,16 +135,17 @@ def get_by_proto(self, maybe_new_proto, label=None, deduplicate=False): obj=obj, obj_type=self._obj_type, label=label), maybe_new_proto) - def get_id_to_proto_map(self): - # type: () -> Dict[str, message.Message] + def get_id_to_proto_map(self) -> Dict[str, message.Message]: return self._id_to_proto - def get_proto_from_id(self, id): - # type: (str) -> message.Message + def get_proto_from_id(self, id: str) -> message.Message: return self.get_id_to_proto_map()[id] - def put_proto(self, id, proto, ignore_duplicates=False): - # type: (str, message.Message, bool) -> str + def put_proto( + self, + id: str, + proto: message.Message, + ignore_duplicates: bool = False) -> str: if not ignore_duplicates and id in self._id_to_proto: raise ValueError("Id '%s' is already taken." % id) elif (ignore_duplicates and id in self._id_to_proto and @@ -158,12 +158,10 @@ def put_proto(self, id, proto, ignore_duplicates=False): self._id_to_proto[id] = proto return id - def __getitem__(self, id): - # type: (str) -> Any + def __getitem__(self, id: str) -> Any: return self.get_by_id(id) - def __contains__(self, id): - # type: (str) -> bool + def __contains__(self, id: str) -> bool: return id in self._id_to_proto @@ -172,18 +170,18 @@ class PipelineContext(object): Used for accessing and constructing the referenced objects of a Pipeline. 
""" - - def __init__(self, - proto=None, # type: Optional[Union[beam_runner_api_pb2.Components, beam_fn_api_pb2.ProcessBundleDescriptor]] - component_id_map=None, # type: Optional[pipeline.ComponentIdMap] - default_environment=None, # type: Optional[environments.Environment] - use_fake_coders=False, # type: bool - iterable_state_read=None, # type: Optional[IterableStateReader] - iterable_state_write=None, # type: Optional[IterableStateWriter] - namespace='ref', # type: str - requirements=(), # type: Iterable[str] - ): - # type: (...) -> None + def __init__( + self, + proto: Optional[Union[beam_runner_api_pb2.Components, + beam_fn_api_pb2.ProcessBundleDescriptor]] = None, + component_id_map: Optional[pipeline.ComponentIdMap] = None, + default_environment: Optional[environments.Environment] = None, + use_fake_coders: bool = False, + iterable_state_read: Optional[IterableStateReader] = None, + iterable_state_write: Optional[IterableStateWriter] = None, + namespace: str = 'ref', + requirements: Iterable[str] = (), + ) -> None: if isinstance(proto, beam_fn_api_pb2.ProcessBundleDescriptor): proto = beam_runner_api_pb2.Components( coders=dict(proto.coders.items()), @@ -224,22 +222,19 @@ def __init__(self, if default_environment is None: default_environment = environments.DefaultEnvironment() - self._default_environment_id = self.environments.get_id( - default_environment, label='default_environment') # type: str + self._default_environment_id: str = self.environments.get_id( + default_environment, label='default_environment') self.use_fake_coders = use_fake_coders - self.deterministic_coder_map = { - } # type: Mapping[coders.Coder, coders.Coder] + self.deterministic_coder_map: Mapping[coders.Coder, coders.Coder] = {} self.iterable_state_read = iterable_state_read self.iterable_state_write = iterable_state_write self._requirements = set(requirements) - def add_requirement(self, requirement): - # type: (str) -> None + def add_requirement(self, requirement: str) -> None: self._requirements.add(requirement) - def requirements(self): - # type: () -> FrozenSet[str] + def requirements(self) -> FrozenSet[str]: return frozenset(self._requirements) # If fake coders are requested, return a pickled version of the element type @@ -248,8 +243,9 @@ def requirements(self): # TODO(https://github.com/apache/beam/issues/18490): Remove once this is no # longer needed. 
def coder_id_from_element_type( - self, element_type, requires_deterministic_key_coder=None): - # type: (Any, Optional[str]) -> str + self, + element_type: Any, + requires_deterministic_key_coder: Optional[str] = None) -> str: if self.use_fake_coders: return pickler.dumps(element_type).decode('ascii') else: @@ -262,14 +258,12 @@ def coder_id_from_element_type( ]) return self.coders.get_id(coder) - def deterministic_coder(self, coder, msg): - # type: (coders.Coder, str) -> coders.Coder + def deterministic_coder(self, coder: coders.Coder, msg: str) -> coders.Coder: if coder not in self.deterministic_coder_map: self.deterministic_coder_map[coder] = coder.as_deterministic_coder(msg) # type: ignore return self.deterministic_coder_map[coder] - def element_type_from_coder_id(self, coder_id): - # type: (str) -> Any + def element_type_from_coder_id(self, coder_id: str) -> Any: if self.use_fake_coders or coder_id not in self.coders: return pickler.loads(coder_id) else: @@ -277,12 +271,10 @@ def element_type_from_coder_id(self, coder_id): self.coders[coder_id].to_type_hint()) @staticmethod - def from_runner_api(proto): - # type: (beam_runner_api_pb2.Components) -> PipelineContext + def from_runner_api(proto: beam_runner_api_pb2.Components) -> PipelineContext: return PipelineContext(proto) - def to_runner_api(self): - # type: () -> beam_runner_api_pb2.Components + def to_runner_api(self) -> beam_runner_api_pb2.Components: context_proto = beam_runner_api_pb2.Components() self.transforms.populate_map(context_proto.transforms) @@ -293,20 +285,19 @@ def to_runner_api(self): return context_proto - def default_environment_id(self): - # type: () -> str + def default_environment_id(self) -> str: return self._default_environment_id def get_environment_id_for_resource_hints( - self, hints): # type: (Dict[str, bytes]) -> str + self, hints: Dict[str, bytes]) -> str: """Returns an environment id that has necessary resource hints.""" if not hints: return self.default_environment_id() def get_or_create_environment_with_resource_hints( - template_env_id, - resource_hints, - ): # type: (str, Dict[str, bytes]) -> str + template_env_id: str, + resource_hints: Dict[str, bytes], + ) -> str: """Creates an environment that has necessary hints and returns its id.""" template_env = self.environments.get_proto_from_id(template_env_id) cloned_env = beam_runner_api_pb2.Environment() diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py index d037e0d42c0b..4ba49378c8a5 100644 --- a/sdks/python/apache_beam/runners/runner.py +++ b/sdks/python/apache_beam/runners/runner.py @@ -55,9 +55,7 @@ _LOGGER = logging.getLogger(__name__) -def create_runner(runner_name): - # type: (str) -> PipelineRunner - +def create_runner(runner_name: str) -> PipelineRunner: """For internal use only; no backwards-compatibility guarantees. Creates a runner instance from a runner class name. @@ -113,13 +111,10 @@ class PipelineRunner(object): provide a new implementation for clear_pvalue(), which is used to wipe out materialized values in order to reduce footprint. """ - - def run(self, - transform, # type: PTransform - options=None # type: Optional[PipelineOptions] - ): - # type: (...) -> PipelineResult - + def run( + self, + transform: PTransform, + options: Optional[PipelineOptions] = None) -> PipelineResult: """Run the given transform or callable with this runner. Blocks until the pipeline is complete. See also `PipelineRunner.run_async`. 
@@ -128,12 +123,10 @@ def run(self, result.wait_until_finish() return result - def run_async(self, - transform, # type: PTransform - options=None # type: Optional[PipelineOptions] - ): - # type: (...) -> PipelineResult - + def run_async( + self, + transform: PTransform, + options: Optional[PipelineOptions] = None) -> PipelineResult: """Run the given transform or callable with this runner. May return immediately, executing the pipeline in the background. @@ -171,12 +164,7 @@ def default_environment( options.view_as(PortableOptions)) def run_pipeline( - self, - pipeline, # type: Pipeline - options # type: PipelineOptions - ): - # type: (...) -> PipelineResult - + self, pipeline: Pipeline, options: PipelineOptions) -> PipelineResult: """Execute the entire pipeline or the sub-DAG reachable from a node. """ pipeline.visit( @@ -194,11 +182,11 @@ def run_pipeline( default_environment=self.default_environment(options)), options) - def apply(self, - transform, # type: PTransform - input, # type: Optional[pvalue.PValue] - options # type: PipelineOptions - ): + def apply( + self, + transform: PTransform, + input: Optional[pvalue.PValue], + options: PipelineOptions): # TODO(robertwb): Remove indirection once internal references are fixed. return self.apply_PTransform(transform, input, options) diff --git a/sdks/python/apache_beam/runners/sdf_utils.py b/sdks/python/apache_beam/runners/sdf_utils.py index bbb6b2de6e85..d2d8a4a3c584 100644 --- a/sdks/python/apache_beam/runners/sdf_utils.py +++ b/sdks/python/apache_beam/runners/sdf_utils.py @@ -55,8 +55,7 @@ class ThreadsafeRestrictionTracker(object): This wrapper guarantees synchronization of modifying restrictions across multi-thread. """ - def __init__(self, restriction_tracker): - # type: (RestrictionTracker) -> None + def __init__(self, restriction_tracker: RestrictionTracker) -> None: from apache_beam.io.iobase import RestrictionTracker if not isinstance(restriction_tracker, RestrictionTracker): raise ValueError( @@ -67,7 +66,7 @@ def __init__(self, restriction_tracker): self._timestamp = None self._lock = threading.RLock() self._deferred_residual = None - self._deferred_timestamp = None # type: Optional[Union[Timestamp, Duration]] + self._deferred_timestamp: Optional[Union[Timestamp, Duration]] = None def current_restriction(self): with self._lock: @@ -110,8 +109,7 @@ def check_done(self): with self._lock: return self._restriction_tracker.check_done() - def current_progress(self): - # type: () -> RestrictionProgress + def current_progress(self) -> RestrictionProgress: with self._lock: return self._restriction_tracker.current_progress() @@ -119,9 +117,7 @@ def try_split(self, fraction_of_remainder): with self._lock: return self._restriction_tracker.try_split(fraction_of_remainder) - def deferred_status(self): - # type: () -> Optional[Tuple[Any, Duration]] - + def deferred_status(self) -> Optional[Tuple[Any, Duration]]: """Returns deferred work which is produced by ``defer_remainder()``. When there is a self-checkpoint performed, the system needs to fulfill the @@ -159,8 +155,9 @@ class RestrictionTrackerView(object): time, the RestrictionTrackerView will be fed into the ``DoFn.process`` as a restriction_tracker. 
""" - def __init__(self, threadsafe_restriction_tracker): - # type: (ThreadsafeRestrictionTracker) -> None + def __init__( + self, + threadsafe_restriction_tracker: ThreadsafeRestrictionTracker) -> None: if not isinstance(threadsafe_restriction_tracker, ThreadsafeRestrictionTracker): raise ValueError( @@ -185,8 +182,7 @@ class ThreadsafeWatermarkEstimator(object): """A threadsafe wrapper which wraps a WatermarkEstimator with locking mechanism to guarantee multi-thread safety. """ - def __init__(self, watermark_estimator): - # type: (WatermarkEstimator) -> None + def __init__(self, watermark_estimator: WatermarkEstimator) -> None: from apache_beam.io.iobase import WatermarkEstimator if not isinstance(watermark_estimator, WatermarkEstimator): raise ValueError('Initializing Threadsafe requires a WatermarkEstimator') @@ -207,13 +203,11 @@ def get_estimator_state(self): with self._lock: return self._watermark_estimator.get_estimator_state() - def current_watermark(self): - # type: () -> Timestamp + def current_watermark(self) -> Timestamp: with self._lock: return self._watermark_estimator.current_watermark() - def observe_timestamp(self, timestamp): - # type: (Timestamp) -> None + def observe_timestamp(self, timestamp: Timestamp) -> None: if not isinstance(timestamp, Timestamp): raise ValueError( 'Input of observe_timestamp should be a Timestamp ' From 0763d7e50f3e83b77d523112c0559234bf22533c Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:48:00 -0700 Subject: [PATCH 15/29] Modernize python type hints for apache_beam/runners/direct --- .../runners/direct/bundle_factory.py | 34 ++--- .../consumer_tracking_pipeline_visitor.py | 10 +- .../runners/direct/direct_runner.py | 5 +- .../runners/direct/evaluation_context.py | 117 +++++++----------- .../apache_beam/runners/direct/executor.py | 101 ++++++--------- .../runners/direct/sdf_direct_runner.py | 12 +- .../runners/direct/test_stream_impl.py | 4 +- .../runners/direct/transform_evaluator.py | 60 ++++----- .../runners/direct/watermark_manager.py | 80 ++++++------ 9 files changed, 183 insertions(+), 240 deletions(-) diff --git a/sdks/python/apache_beam/runners/direct/bundle_factory.py b/sdks/python/apache_beam/runners/direct/bundle_factory.py index e4beefe992c1..8553fdb50656 100644 --- a/sdks/python/apache_beam/runners/direct/bundle_factory.py +++ b/sdks/python/apache_beam/runners/direct/bundle_factory.py @@ -40,16 +40,17 @@ class BundleFactory(object): in case consecutive ones share the same timestamp and windows. DirectRunnerOptions.direct_runner_use_stacked_bundle controls this option. 
""" - def __init__(self, stacked): - # type: (bool) -> None + def __init__(self, stacked: bool) -> None: self._stacked = stacked - def create_bundle(self, output_pcollection): - # type: (Union[pvalue.PBegin, pvalue.PCollection]) -> _Bundle + def create_bundle( + self, output_pcollection: Union[pvalue.PBegin, + pvalue.PCollection]) -> _Bundle: return _Bundle(output_pcollection, self._stacked) - def create_empty_committed_bundle(self, output_pcollection): - # type: (Union[pvalue.PBegin, pvalue.PCollection]) -> _Bundle + def create_empty_committed_bundle( + self, output_pcollection: Union[pvalue.PBegin, + pvalue.PCollection]) -> _Bundle: bundle = self.create_bundle(output_pcollection) bundle.commit(None) return bundle @@ -110,27 +111,27 @@ def pane_info(self): def add_value(self, value): self._appended_values.append(value) - def windowed_values(self): - # type: () -> Iterator[WindowedValue] + def windowed_values(self) -> Iterator[WindowedValue]: # yield first windowed_value as is, then iterate through # _appended_values to yield WindowedValue on the fly. yield self._initial_windowed_value for v in self._appended_values: yield self._initial_windowed_value.with_value(v) - def __init__(self, pcollection, stacked=True): - # type: (Union[pvalue.PBegin, pvalue.PCollection], bool) -> None + def __init__( + self, + pcollection: Union[pvalue.PBegin, pvalue.PCollection], + stacked: bool = True) -> None: assert isinstance(pcollection, (pvalue.PBegin, pvalue.PCollection)) self._pcollection = pcollection - self._elements = [ - ] # type: List[Union[WindowedValue, _Bundle._StackedWindowedValues]] + self._elements: List[Union[WindowedValue, + _Bundle._StackedWindowedValues]] = [] self._stacked = stacked self._committed = False self._tag = None # optional tag information for this bundle - def get_elements_iterable(self, make_copy=False): - # type: (bool) -> Iterable[WindowedValue] - + def get_elements_iterable(self, + make_copy: bool = False) -> Iterable[WindowedValue]: """Returns iterable elements. Args: @@ -203,8 +204,7 @@ def add(self, element): def output(self, element): self.add(element) - def receive(self, element): - # type: (WindowedValue) -> None + def receive(self, element: WindowedValue) -> None: self.add(element) def commit(self, synchronized_processing_time): diff --git a/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py b/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py index 2a6fc3ee6093..60b0e5beeae2 100644 --- a/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py +++ b/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py @@ -40,10 +40,9 @@ class ConsumerTrackingPipelineVisitor(PipelineVisitor): transform has produced and committed output. 
""" def __init__(self): - self.value_to_consumers = { - } # type: Dict[pvalue.PValue, Set[AppliedPTransform]] - self.root_transforms = set() # type: Set[AppliedPTransform] - self.step_names = {} # type: Dict[AppliedPTransform, str] + self.value_to_consumers: Dict[pvalue.PValue, Set[AppliedPTransform]] = {} + self.root_transforms: Set[AppliedPTransform] = set() + self.step_names: Dict[AppliedPTransform, str] = {} self._num_transforms = 0 self._views = set() @@ -57,8 +56,7 @@ def views(self): """ return list(self._views) - def visit_transform(self, applied_ptransform): - # type: (AppliedPTransform) -> None + def visit_transform(self, applied_ptransform: AppliedPTransform) -> None: inputs = list(applied_ptransform.inputs) if inputs: for input_value in inputs: diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index a470ba80d8ee..1cd20550edf3 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -390,8 +390,9 @@ def __init__(self, source): self._source = source def _infer_output_coder( - self, unused_input_type=None, unused_input_coder=None): - # type: (...) -> typing.Optional[coders.Coder] + self, + unused_input_type=None, + unused_input_coder=None) -> typing.Optional[coders.Coder]: return coders.BytesCoder() def get_windowing(self, unused_inputs): diff --git a/sdks/python/apache_beam/runners/direct/evaluation_context.py b/sdks/python/apache_beam/runners/direct/evaluation_context.py index fbe59b072ae4..d42f9d43fe71 100644 --- a/sdks/python/apache_beam/runners/direct/evaluation_context.py +++ b/sdks/python/apache_beam/runners/direct/evaluation_context.py @@ -53,10 +53,7 @@ class _ExecutionContext(object): It holds the watermarks for that transform, as well as keyed states. """ - def __init__( - self, - watermarks, # type: _TransformWatermarks - keyed_states): + def __init__(self, watermarks: _TransformWatermarks, keyed_states): self.watermarks = watermarks self.keyed_states = keyed_states @@ -91,13 +88,12 @@ class _SideInputsContainer(object): It provides methods for blocking until a side-input is available and writing to a side input. """ - def __init__(self, side_inputs): - # type: (Iterable[pvalue.AsSideInput]) -> None + def __init__(self, side_inputs: Iterable[pvalue.AsSideInput]) -> None: self._lock = threading.Lock() - self._views = {} # type: Dict[pvalue.AsSideInput, _SideInputView] - self._transform_to_side_inputs = collections.defaultdict( - list - ) # type: DefaultDict[Optional[AppliedPTransform], List[pvalue.AsSideInput]] + self._views: Dict[pvalue.AsSideInput, _SideInputView] = {} + self._transform_to_side_inputs: DefaultDict[ + Optional[AppliedPTransform], + List[pvalue.AsSideInput]] = collections.defaultdict(list) # this appears unused: self._side_input_to_blocked_tasks = collections.defaultdict(list) # type: ignore @@ -111,13 +107,8 @@ def __repr__(self): for elm in self._views.values()) if self._views else '[]') return '_SideInputsContainer(_views=%s)' % views_string - def get_value_or_block_until_ready(self, - side_input, - task, # type: TransformExecutor - block_until # type: Timestamp - ): - # type: (...) -> Any - + def get_value_or_block_until_ready( + self, side_input, task: TransformExecutor, block_until: Timestamp) -> Any: """Returns the value of a view whose task is unblocked or blocks its task. 
It gets the value of a view whose watermark has been updated and @@ -147,9 +138,7 @@ def add_values(self, side_input, values): view.elements.extend(values) def update_watermarks_for_transform_and_unblock_tasks( - self, ptransform, watermark): - # type: (...) -> List[Tuple[TransformExecutor, Timestamp]] - + self, ptransform, watermark) -> List[Tuple[TransformExecutor, Timestamp]]: """Updates _SideInputsContainer after a watermark update and unbloks tasks. It traverses the list of side inputs per PTransform and calls @@ -170,9 +159,7 @@ def update_watermarks_for_transform_and_unblock_tasks( return unblocked_tasks def _update_watermarks_for_side_input_and_unblock_tasks( - self, side_input, watermark): - # type: (...) -> List[Tuple[TransformExecutor, Timestamp]] - + self, side_input, watermark) -> List[Tuple[TransformExecutor, Timestamp]]: """Helps update _SideInputsContainer after a watermark update. For each view of the side input, it updates the value of the watermark @@ -238,24 +225,24 @@ class EvaluationContext(object): appropriately. This includes updating the per-(step,key) state, updating global watermarks, and executing any callbacks that can be executed. """ - - def __init__(self, - pipeline_options, - bundle_factory, # type: BundleFactory - root_transforms, - value_to_consumers, - step_names, - views, # type: Iterable[pvalue.AsSideInput] - clock - ): + def __init__( + self, + pipeline_options, + bundle_factory: BundleFactory, + root_transforms, + value_to_consumers, + step_names, + views: Iterable[pvalue.AsSideInput], + clock): self.pipeline_options = pipeline_options self._bundle_factory = bundle_factory self._root_transforms = root_transforms self._value_to_consumers = value_to_consumers self._step_names = step_names self.views = views - self._pcollection_to_views = collections.defaultdict( - list) # type: DefaultDict[pvalue.PValue, List[pvalue.AsSideInput]] + self._pcollection_to_views: DefaultDict[ + pvalue.PValue, + List[pvalue.AsSideInput]] = collections.defaultdict(list) for view in views: self._pcollection_to_views[view.pvalue].append(view) self._transform_keyed_states = self._initialize_keyed_states( @@ -266,8 +253,8 @@ def __init__(self, root_transforms, value_to_consumers, self._transform_keyed_states) - self._pending_unblocked_tasks = [ - ] # type: List[Tuple[TransformExecutor, Timestamp]] + self._pending_unblocked_tasks: List[Tuple[TransformExecutor, + Timestamp]] = [] self._counter_factory = counters.CounterFactory() self._metrics = DirectMetrics() @@ -291,15 +278,14 @@ def metrics(self): # TODO. Should this be made a @property? return self._metrics - def is_root_transform(self, applied_ptransform): - # type: (AppliedPTransform) -> bool + def is_root_transform(self, applied_ptransform: AppliedPTransform) -> bool: return applied_ptransform in self._root_transforms - def handle_result(self, - completed_bundle, # type: _Bundle - completed_timers, - result # type: TransformResult - ): + def handle_result( + self, + completed_bundle: _Bundle, + completed_timers, + result: TransformResult): """Handle the provided result produced after evaluating the input bundle. 
Handle the provided TransformResult, produced after evaluating @@ -352,10 +338,8 @@ def handle_result(self, existing_keyed_state[k] = v return committed_bundles - def _update_side_inputs_container(self, - committed_bundles, # type: Iterable[_Bundle] - result # type: TransformResult - ): + def _update_side_inputs_container( + self, committed_bundles: Iterable[_Bundle], result: TransformResult): """Update the side inputs container if we are outputting into a side input. Look at the result, and if it's outputing into a PCollection that we have @@ -381,12 +365,11 @@ def schedule_pending_unblocked_tasks(self, executor_service): executor_service.submit(task) self._pending_unblocked_tasks = [] - def _commit_bundles(self, - uncommitted_bundles, # type: Iterable[_Bundle] - unprocessed_bundles # type: Iterable[_Bundle] - ): - # type: (...) -> Tuple[Tuple[_Bundle, ...], Tuple[_Bundle, ...]] - + def _commit_bundles( + self, + uncommitted_bundles: Iterable[_Bundle], + unprocessed_bundles: Iterable[_Bundle] + ) -> Tuple[Tuple[_Bundle, ...], Tuple[_Bundle, ...]]: """Commits bundles and returns a immutable set of committed bundles.""" for in_progress_bundle in uncommitted_bundles: producing_applied_ptransform = in_progress_bundle.pcollection.producer @@ -398,32 +381,29 @@ def _commit_bundles(self, unprocessed_bundle.commit(None) return tuple(uncommitted_bundles), tuple(unprocessed_bundles) - def get_execution_context(self, applied_ptransform): - # type: (AppliedPTransform) -> _ExecutionContext + def get_execution_context( + self, applied_ptransform: AppliedPTransform) -> _ExecutionContext: return _ExecutionContext( self._watermark_manager.get_watermarks(applied_ptransform), self._transform_keyed_states[applied_ptransform]) - def create_bundle(self, output_pcollection): - # type: (Union[pvalue.PBegin, pvalue.PCollection]) -> _Bundle - + def create_bundle( + self, output_pcollection: Union[pvalue.PBegin, + pvalue.PCollection]) -> _Bundle: """Create an uncommitted bundle for the specified PCollection.""" return self._bundle_factory.create_bundle(output_pcollection) - def create_empty_committed_bundle(self, output_pcollection): - # type: (pvalue.PCollection) -> _Bundle - + def create_empty_committed_bundle( + self, output_pcollection: pvalue.PCollection) -> _Bundle: """Create empty bundle useful for triggering evaluation.""" return self._bundle_factory.create_empty_committed_bundle( output_pcollection) - def extract_all_timers(self): - # type: () -> Tuple[List[Tuple[AppliedPTransform, List[TimerFiring]]], bool] + def extract_all_timers( + self) -> Tuple[List[Tuple[AppliedPTransform, List[TimerFiring]]], bool]: return self._watermark_manager.extract_all_timers() - def is_done(self, transform=None): - # type: (Optional[AppliedPTransform]) -> bool - + def is_done(self, transform: Optional[AppliedPTransform] = None) -> bool: """Checks completion of a step or the pipeline. 
Args: @@ -441,8 +421,7 @@ def is_done(self, transform=None): return False return True - def _is_transform_done(self, transform): - # type: (AppliedPTransform) -> bool + def _is_transform_done(self, transform: AppliedPTransform) -> bool: tw = self._watermark_manager.get_watermarks(transform) return tw.output_watermark == WatermarkManager.WATERMARK_POS_INF diff --git a/sdks/python/apache_beam/runners/direct/executor.py b/sdks/python/apache_beam/runners/direct/executor.py index 0ab3033d68b5..8c389ffccf5d 100644 --- a/sdks/python/apache_beam/runners/direct/executor.py +++ b/sdks/python/apache_beam/runners/direct/executor.py @@ -64,9 +64,7 @@ class _ExecutorServiceWorker(threading.Thread): TIMEOUT = 5 def __init__( - self, - queue, # type: queue.Queue[_ExecutorService.CallableTask] - index): + self, queue: queue.Queue[_ExecutorService.CallableTask], index): super().__init__() self.queue = queue self._index = index @@ -86,8 +84,7 @@ def _update_name(self, task=None): self.name = 'Thread: %d, %s (%s)' % ( self._index, name, 'executing' if task else 'idle') - def _get_task_or_none(self): - # type: () -> Optional[_ExecutorService.CallableTask] + def _get_task_or_none(self) -> Optional[_ExecutorService.CallableTask]: try: # Do not block indefinitely, otherwise we may not act for a requested # shutdown. @@ -114,16 +111,14 @@ def shutdown(self): self.shutdown_requested = True def __init__(self, num_workers): - self.queue = queue.Queue( - ) # type: queue.Queue[_ExecutorService.CallableTask] + self.queue: queue.Queue[_ExecutorService.CallableTask] = queue.Queue() self.workers = [ _ExecutorService._ExecutorServiceWorker(self.queue, i) for i in range(num_workers) ] self.shutdown_requested = False - def submit(self, task): - # type: (_ExecutorService.CallableTask) -> None + def submit(self, task: _ExecutorService.CallableTask) -> None: assert isinstance(task, _ExecutorService.CallableTask) if not self.shutdown_requested: self.queue.put(task) @@ -150,11 +145,7 @@ def shutdown(self): class _TransformEvaluationState(object): - def __init__( - self, - executor_service, - scheduled # type: Set[TransformExecutor] - ): + def __init__(self, executor_service, scheduled: Set[TransformExecutor]): self.executor_service = executor_service self.scheduled = scheduled @@ -219,21 +210,18 @@ class _TransformExecutorServices(object): Controls the concurrency as appropriate for the applied transform the executor exists for. 
""" - def __init__(self, executor_service): - # type: (_ExecutorService) -> None + def __init__(self, executor_service: _ExecutorService) -> None: self._executor_service = executor_service - self._scheduled = set() # type: Set[TransformExecutor] + self._scheduled: Set[TransformExecutor] = set() self._parallel = _ParallelEvaluationState( self._executor_service, self._scheduled) - self._serial_cache = WeakValueDictionary( - ) # type: WeakValueDictionary[Any, _SerialEvaluationState] + self._serial_cache: WeakValueDictionary[ + Any, _SerialEvaluationState] = WeakValueDictionary() - def parallel(self): - # type: () -> _ParallelEvaluationState + def parallel(self) -> _ParallelEvaluationState: return self._parallel - def serial(self, step): - # type: (Any) -> _SerialEvaluationState + def serial(self, step: Any) -> _SerialEvaluationState: cached = self._serial_cache.get(step) if not cached: cached = _SerialEvaluationState(self._executor_service, self._scheduled) @@ -241,8 +229,7 @@ def serial(self, step): return cached @property - def executors(self): - # type: () -> FrozenSet[TransformExecutor] + def executors(self) -> FrozenSet[TransformExecutor]: return frozenset(self._scheduled) @@ -253,12 +240,11 @@ class _CompletionCallback(object): that are triggered due to the arrival of elements from an upstream transform, or for a source transform. """ - - def __init__(self, - evaluation_context, # type: EvaluationContext - all_updates, - timer_firings=None - ): + def __init__( + self, + evaluation_context: EvaluationContext, + all_updates, + timer_firings=None): self._evaluation_context = evaluation_context self._all_updates = all_updates self._timer_firings = timer_firings or [] @@ -295,15 +281,15 @@ class TransformExecutor(_ExecutorService.CallableTask): _MAX_RETRY_PER_BUNDLE = 4 - def __init__(self, - transform_evaluator_registry, # type: TransformEvaluatorRegistry - evaluation_context, # type: EvaluationContext - input_bundle, # type: _Bundle - fired_timers, - applied_ptransform, - completion_callback, - transform_evaluation_state # type: _TransformEvaluationState - ): + def __init__( + self, + transform_evaluator_registry: TransformEvaluatorRegistry, + evaluation_context: EvaluationContext, + input_bundle: _Bundle, + fired_timers, + applied_ptransform, + completion_callback, + transform_evaluation_state: _TransformEvaluationState): self._transform_evaluator_registry = transform_evaluator_registry self._evaluation_context = evaluation_context self._input_bundle = input_bundle @@ -319,7 +305,7 @@ def __init__(self, self._applied_ptransform = applied_ptransform self._completion_callback = completion_callback self._transform_evaluation_state = transform_evaluation_state - self._side_input_values = {} # type: Dict[pvalue.AsSideInput, Any] + self._side_input_values: Dict[pvalue.AsSideInput, Any] = {} self.blocked = False self._call_count = 0 self._retry_count = 0 @@ -444,8 +430,7 @@ def __init__( self, value_to_consumers, transform_evaluator_registry, - evaluation_context # type: EvaluationContext - ): + evaluation_context: EvaluationContext): self.executor_service = _ExecutorService( _ExecutorServiceParallelExecutor.NUM_WORKERS) self.transform_executor_services = _TransformExecutorServices( @@ -487,8 +472,7 @@ def request_shutdown(self): self.executor_service.await_completion() self.evaluation_context.shutdown() - def schedule_consumers(self, committed_bundle): - # type: (_Bundle) -> None + def schedule_consumers(self, committed_bundle: _Bundle) -> None: if committed_bundle.pcollection in 
self.value_to_consumers: consumers = self.value_to_consumers[committed_bundle.pcollection] for applied_ptransform in consumers: @@ -500,20 +484,20 @@ def schedule_consumers(self, committed_bundle): def schedule_unprocessed_bundle(self, applied_ptransform, unprocessed_bundle): self.node_to_pending_bundles[applied_ptransform].append(unprocessed_bundle) - def schedule_consumption(self, - consumer_applied_ptransform, - committed_bundle, # type: _Bundle - fired_timers, - on_complete - ): + def schedule_consumption( + self, + consumer_applied_ptransform, + committed_bundle: _Bundle, + fired_timers, + on_complete): """Schedules evaluation of the given bundle with the transform.""" assert consumer_applied_ptransform assert committed_bundle assert on_complete if self.transform_evaluator_registry.should_execute_serially( consumer_applied_ptransform): - transform_executor_service = self.transform_executor_services.serial( - consumer_applied_ptransform) # type: _TransformEvaluationState + transform_executor_service: _TransformEvaluationState = self.transform_executor_services.serial( + consumer_applied_ptransform) else: transform_executor_service = self.transform_executor_services.parallel() @@ -587,8 +571,7 @@ def __init__(self, exception=None): class _MonitorTask(_ExecutorService.CallableTask): """MonitorTask continuously runs to ensure that pipeline makes progress.""" - def __init__(self, executor): - # type: (_ExecutorServiceParallelExecutor) -> None + def __init__(self, executor: _ExecutorServiceParallelExecutor) -> None: self._executor = executor @property @@ -624,9 +607,7 @@ def call(self, state_sampler): if not self._should_shutdown(): self._executor.executor_service.submit(self) - def _should_shutdown(self): - # type: () -> bool - + def _should_shutdown(self) -> bool: """Checks whether the pipeline is completed and should be shut down. If there is anything in the queue of tasks to do or @@ -690,9 +671,7 @@ def _fire_timers(self): timer_completion_callback) return bool(transform_fired_timers) - def _is_executing(self): - # type: () -> bool - + def _is_executing(self) -> bool: """Checks whether the job is still executing. 
Returns: diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py index 528b2d1f576b..119383856ba2 100644 --- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py @@ -541,8 +541,10 @@ def __init__(self): self.output_iter = None def handle_process_outputs( - self, windowed_input_element, output_iter, watermark_estimator=None): - # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None + self, + windowed_input_element: WindowedValue, + output_iter: Iterable[Any], + watermark_estimator: Optional[WatermarkEstimator] = None) -> None: self.output_iter = output_iter def reset(self): @@ -551,6 +553,8 @@ def reset(self): class _NoneShallPassOutputHandler(OutputHandler): def handle_process_outputs( - self, windowed_input_element, output_iter, watermark_estimator=None): - # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None + self, + windowed_input_element: WindowedValue, + output_iter: Iterable[Any], + watermark_estimator: Optional[WatermarkEstimator] = None) -> None: raise RuntimeError() diff --git a/sdks/python/apache_beam/runners/direct/test_stream_impl.py b/sdks/python/apache_beam/runners/direct/test_stream_impl.py index 0842a51d5666..1cda97bc56eb 100644 --- a/sdks/python/apache_beam/runners/direct/test_stream_impl.py +++ b/sdks/python/apache_beam/runners/direct/test_stream_impl.py @@ -309,8 +309,8 @@ def is_alive(): return not (shutdown_requested or evaluation_context.shutdown_requested) # The shared queue that allows the producer and consumer to communicate. - channel = Queue( - ) # type: Queue[Union[test_stream.Event, _EndOfStream]] # noqa: F821 + channel: Queue[Union[test_stream.Event, + _EndOfStream]] = Queue() # noqa: F821 event_stream = Thread( target=_TestStream._stream_events_from_rpc, args=(endpoint, output_tags, coder, channel, is_alive)) diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py index 37004c7258a7..a7db67f7f098 100644 --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py @@ -89,14 +89,13 @@ class TransformEvaluatorRegistry(object): Creates instances of TransformEvaluator for the application of a transform. 
""" - _test_evaluators_overrides = { - } # type: Dict[Type[core.PTransform], Type[_TransformEvaluator]] + _test_evaluators_overrides: Dict[Type[core.PTransform], + Type[_TransformEvaluator]] = {} - def __init__(self, evaluation_context): - # type: (EvaluationContext) -> None + def __init__(self, evaluation_context: EvaluationContext) -> None: assert evaluation_context self._evaluation_context = evaluation_context - self._evaluators = { + self._evaluators: Dict[Type[core.PTransform], Type[_TransformEvaluator]] = { io.Read: _BoundedReadEvaluator, _DirectReadFromPubSub: _PubSubReadEvaluator, core.Flatten: _FlattenEvaluator, @@ -109,7 +108,7 @@ def __init__(self, evaluation_context): ProcessElements: _ProcessElementsEvaluator, _WatermarkController: _WatermarkControllerEvaluator, PairWithTiming: _PairWithTimingEvaluator, - } # type: Dict[Type[core.PTransform], Type[_TransformEvaluator]] + } self._evaluators.update(self._test_evaluators_overrides) self._root_bundle_providers = { core.PTransform: DefaultRootBundleProvider, @@ -231,13 +230,12 @@ def get_root_bundles(self): class _TransformEvaluator(object): """An evaluator of a specific application of a transform.""" - - def __init__(self, - evaluation_context, # type: EvaluationContext - applied_ptransform, # type: AppliedPTransform - input_committed_bundle, - side_inputs - ): + def __init__( + self, + evaluation_context: EvaluationContext, + applied_ptransform: AppliedPTransform, + input_committed_bundle, + side_inputs): self._evaluation_context = evaluation_context self._applied_ptransform = applied_ptransform self._input_committed_bundle = input_committed_bundle @@ -321,9 +319,7 @@ def process_element(self, element): """Processes a new element as part of the current bundle.""" raise NotImplementedError('%s do not process elements.' % type(self)) - def finish_bundle(self): - # type: () -> TransformResult - + def finish_bundle(self) -> TransformResult: """Finishes the bundle and produces output.""" pass @@ -592,7 +588,7 @@ class _PubSubReadEvaluator(_TransformEvaluator): # A mapping of transform to _PubSubSubscriptionWrapper. # TODO(https://github.com/apache/beam/issues/19751): Prevents garbage # collection of pipeline instances. - _subscription_cache = {} # type: Dict[AppliedPTransform, str] + _subscription_cache: Dict[AppliedPTransform, str] = {} def __init__( self, @@ -607,7 +603,7 @@ def __init__( input_committed_bundle, side_inputs) - self.source = self._applied_ptransform.transform._source # type: _PubSubSource + self.source: _PubSubSource = self._applied_ptransform.transform._source if self.source.id_label: raise NotImplementedError( 'DirectRunner: id_label is not supported for PubSub reads') @@ -655,8 +651,8 @@ def start_bundle(self): def process_element(self, element): pass - def _read_from_pubsub(self, timestamp_attribute): - # type: (...) 
-> List[Tuple[Timestamp, PubsubMessage]] + def _read_from_pubsub( + self, timestamp_attribute) -> List[Tuple[Timestamp, PubsubMessage]]: from apache_beam.io.gcp.pubsub import PubsubMessage from google.cloud import pubsub @@ -699,8 +695,7 @@ def _get_element(message): return results - def finish_bundle(self): - # type: () -> TransformResult + def finish_bundle(self) -> TransformResult: data = self._read_from_pubsub(self.source.timestamp_attribute) if data: output_pcollection = list(self._outputs)[0] @@ -777,8 +772,7 @@ def __init__(self, evaluation_context): class NullReceiver(common.Receiver): """Ignores undeclared outputs, default execution mode.""" - def receive(self, element): - # type: (WindowedValue) -> None + def receive(self, element: WindowedValue) -> None: pass class _InMemoryReceiver(common.Receiver): @@ -787,8 +781,7 @@ def __init__(self, target, tag): self._target = target self._tag = tag - def receive(self, element): - # type: (WindowedValue) -> None + def receive(self, element: WindowedValue) -> None: self._target[self._tag].append(element) def __missing__(self, key): @@ -799,14 +792,13 @@ def __missing__(self, key): class _ParDoEvaluator(_TransformEvaluator): """TransformEvaluator for ParDo transform.""" - - def __init__(self, - evaluation_context, # type: EvaluationContext - applied_ptransform, # type: AppliedPTransform - input_committed_bundle, - side_inputs, - perform_dofn_pickle_test=True - ): + def __init__( + self, + evaluation_context: EvaluationContext, + applied_ptransform: AppliedPTransform, + input_committed_bundle, + side_inputs, + perform_dofn_pickle_test=True): super().__init__( evaluation_context, applied_ptransform, diff --git a/sdks/python/apache_beam/runners/direct/watermark_manager.py b/sdks/python/apache_beam/runners/direct/watermark_manager.py index 8f97de508ff5..077f9f05e183 100644 --- a/sdks/python/apache_beam/runners/direct/watermark_manager.py +++ b/sdks/python/apache_beam/runners/direct/watermark_manager.py @@ -55,8 +55,8 @@ def __init__( self._value_to_consumers = value_to_consumers self._transform_keyed_states = transform_keyed_states # AppliedPTransform -> TransformWatermarks - self._transform_to_watermarks = { - } # type: Dict[AppliedPTransform, _TransformWatermarks] + self._transform_to_watermarks: Dict[AppliedPTransform, + _TransformWatermarks] = {} for root_transform in root_transforms: self._transform_to_watermarks[root_transform] = _TransformWatermarks( @@ -71,8 +71,8 @@ def __init__( for consumer in consumers: self._update_input_transform_watermarks(consumer) - def _update_input_transform_watermarks(self, applied_ptransform): - # type: (AppliedPTransform) -> None + def _update_input_transform_watermarks( + self, applied_ptransform: AppliedPTransform) -> None: assert isinstance(applied_ptransform, pipeline.AppliedPTransform) input_transform_watermarks = [] for input_pvalue in applied_ptransform.inputs: @@ -84,9 +84,8 @@ def _update_input_transform_watermarks(self, applied_ptransform): applied_ptransform].update_input_transform_watermarks( input_transform_watermarks) - def get_watermarks(self, applied_ptransform): - # type: (AppliedPTransform) -> _TransformWatermarks - + def get_watermarks( + self, applied_ptransform: AppliedPTransform) -> _TransformWatermarks: """Gets the input and output watermarks for an AppliedPTransform. 
If the applied_ptransform has not processed any elements, return a @@ -107,15 +106,15 @@ def get_watermarks(self, applied_ptransform): return self._transform_to_watermarks[applied_ptransform] - def update_watermarks(self, - completed_committed_bundle, # type: _Bundle - applied_ptransform, # type: AppliedPTransform - completed_timers, - outputs, - unprocessed_bundles, - keyed_earliest_holds, - side_inputs_container - ): + def update_watermarks( + self, + completed_committed_bundle: _Bundle, + applied_ptransform: AppliedPTransform, + completed_timers, + outputs, + unprocessed_bundles, + keyed_earliest_holds, + side_inputs_container): assert isinstance(applied_ptransform, pipeline.AppliedPTransform) self._update_pending( completed_committed_bundle, @@ -127,13 +126,13 @@ def update_watermarks(self, tw.hold(keyed_earliest_holds) return self._refresh_watermarks(applied_ptransform, side_inputs_container) - def _update_pending(self, - input_committed_bundle, - applied_ptransform, # type: AppliedPTransform - completed_timers, - output_committed_bundles, # type: Iterable[_Bundle] - unprocessed_bundles # type: Iterable[_Bundle] - ): + def _update_pending( + self, + input_committed_bundle, + applied_ptransform: AppliedPTransform, + completed_timers, + output_committed_bundles: Iterable[_Bundle], + unprocessed_bundles: Iterable[_Bundle]): """Updated list of pending bundles for the given AppliedPTransform.""" # Update pending elements. Filter out empty bundles. They do not impact @@ -179,12 +178,11 @@ def _refresh_watermarks(self, applied_ptransform, side_inputs_container): applied_ptransform, tw)) return unblocked_tasks - def extract_all_timers(self): - # type: () -> Tuple[List[Tuple[AppliedPTransform, List[TimerFiring]]], bool] - + def extract_all_timers( + self) -> Tuple[List[Tuple[AppliedPTransform, List[TimerFiring]]], bool]: """Extracts fired timers for all transforms and reports if there are any timers set.""" - all_timers = [] # type: List[Tuple[AppliedPTransform, List[TimerFiring]]] + all_timers: List[Tuple[AppliedPTransform, List[TimerFiring]]] = [] has_realtime_timer = False for applied_ptransform, tw in self._transform_to_watermarks.items(): fired_timers, had_realtime_timer = tw.extract_transform_timers() @@ -203,19 +201,19 @@ class _TransformWatermarks(object): def __init__(self, clock, keyed_states, transform): self._clock = clock self._keyed_states = keyed_states - self._input_transform_watermarks = [] # type: List[_TransformWatermarks] + self._input_transform_watermarks: List[_TransformWatermarks] = [] self._input_watermark = WatermarkManager.WATERMARK_NEG_INF self._output_watermark = WatermarkManager.WATERMARK_NEG_INF self._keyed_earliest_holds = {} # Scheduled bundles targeted for this transform. 
- self._pending = set() # type: Set[_Bundle] + self._pending: Set[_Bundle] = set() self._fired_timers = set() self._lock = threading.Lock() self._label = str(transform) - def update_input_transform_watermarks(self, input_transform_watermarks): - # type: (List[_TransformWatermarks]) -> None + def update_input_transform_watermarks( + self, input_transform_watermarks: List[_TransformWatermarks]) -> None: with self._lock: self._input_transform_watermarks = input_transform_watermarks @@ -225,14 +223,12 @@ def update_timers(self, completed_timers): self._fired_timers.remove(timer_firing) @property - def input_watermark(self): - # type: () -> Timestamp + def input_watermark(self) -> Timestamp: with self._lock: return self._input_watermark @property - def output_watermark(self): - # type: () -> Timestamp + def output_watermark(self) -> Timestamp: with self._lock: return self._output_watermark @@ -244,22 +240,18 @@ def hold(self, keyed_earliest_holds): hold_value == WatermarkManager.WATERMARK_POS_INF): del self._keyed_earliest_holds[key] - def add_pending(self, pending): - # type: (_Bundle) -> None + def add_pending(self, pending: _Bundle) -> None: with self._lock: self._pending.add(pending) - def remove_pending(self, completed): - # type: (_Bundle) -> None + def remove_pending(self, completed: _Bundle) -> None: with self._lock: # Ignore repeated removes. This will happen if a transform has a repeated # input. if completed in self._pending: self._pending.remove(completed) - def refresh(self): - # type: () -> bool - + def refresh(self) -> bool: """Refresh the watermark for a given transform. This method looks at the watermark coming from all input PTransforms, and @@ -308,9 +300,7 @@ def refresh(self): def synchronized_processing_output_time(self): return self._clock.time() - def extract_transform_timers(self): - # type: () -> Tuple[List[TimerFiring], bool] - + def extract_transform_timers(self) -> Tuple[List[TimerFiring], bool]: """Extracts fired timers and reports of any timers set per transform.""" with self._lock: fired_timers = [] From 3cf0c5512a3389c166b007f32cec886fe7383c6a Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:48:08 -0700 Subject: [PATCH 16/29] Modernize python type hints for apache_beam/runners/interactive --- .../interactive/background_caching_job.py | 4 +- .../runners/interactive/cache_manager.py | 4 +- .../runners/interactive/cache_manager_test.py | 2 +- .../interactive/display/pipeline_graph.py | 17 +-- .../display/pipeline_graph_renderer.py | 30 ++-- .../interactive/options/capture_control.py | 9 +- .../interactive/options/capture_limiters.py | 15 +- .../runners/interactive/recording_manager.py | 140 ++++++++---------- .../testing/integration/notebook_executor.py | 3 +- .../testing/integration/screen_diff.py | 7 +- 10 files changed, 95 insertions(+), 136 deletions(-) diff --git a/sdks/python/apache_beam/runners/interactive/background_caching_job.py b/sdks/python/apache_beam/runners/interactive/background_caching_job.py index 3802cfa60095..71f7f77ded4e 100644 --- a/sdks/python/apache_beam/runners/interactive/background_caching_job.py +++ b/sdks/python/apache_beam/runners/interactive/background_caching_job.py @@ -193,9 +193,7 @@ def is_background_caching_job_needed(user_pipeline): cache_changed)) -def is_cache_complete(pipeline_id): - # type: (str) -> bool - +def is_cache_complete(pipeline_id: str) -> bool: """Returns True if the backgrond cache for the given pipeline is done. 
""" user_pipeline = ie.current_env().pipeline_id_to_pipeline(pipeline_id) diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager.py b/sdks/python/apache_beam/runners/interactive/cache_manager.py index b04eb92132a5..ce543796a6bd 100644 --- a/sdks/python/apache_beam/runners/interactive/cache_manager.py +++ b/sdks/python/apache_beam/runners/interactive/cache_manager.py @@ -145,9 +145,7 @@ def cleanup(self): """Cleans up all the PCollection caches.""" raise NotImplementedError - def size(self, *labels): - # type: (*str) -> int - + def size(self, *labels: str) -> int: """Returns the size of the PCollection on disk in bytes.""" raise NotImplementedError diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager_test.py b/sdks/python/apache_beam/runners/interactive/cache_manager_test.py index 8dd525978284..a5d38682716c 100644 --- a/sdks/python/apache_beam/runners/interactive/cache_manager_test.py +++ b/sdks/python/apache_beam/runners/interactive/cache_manager_test.py @@ -37,7 +37,7 @@ class FileBasedCacheManagerTest(object): tested with InteractiveRunner as a part of integration tests instead. """ - cache_format = None # type: str + cache_format: str = None def setUp(self): self.cache_manager = cache.FileBasedCacheManager( diff --git a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py index 92cb108bc46f..1f1e315fea09 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py +++ b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py @@ -48,7 +48,7 @@ class PipelineGraph(object): """Creates a DOT representing the pipeline. Thread-safe. Runner agnostic.""" def __init__( self, - pipeline, # type: Union[beam_runner_api_pb2.Pipeline, beam.Pipeline] + pipeline: Union[beam_runner_api_pb2.Pipeline, beam.Pipeline], default_vertex_attrs={'shape': 'box'}, default_edge_attrs=None, render_option=None): @@ -71,7 +71,7 @@ def __init__( rendered. See display.pipeline_graph_renderer for available options. """ self._lock = threading.Lock() - self._graph = None # type: pydot.Dot + self._graph: pydot.Dot = None self._pipeline_instrument = None if isinstance(pipeline, beam.Pipeline): self._pipeline_instrument = inst.PipelineInstrument( @@ -90,10 +90,9 @@ def __init__( (beam_runner_api_pb2.Pipeline, beam.Pipeline, type(pipeline))) # A dict from PCollection ID to a list of its consuming Transform IDs - self._consumers = collections.defaultdict( - list) # type: DefaultDict[str, List[str]] + self._consumers: DefaultDict[str, List[str]] = collections.defaultdict(list) # A dict from PCollection ID to its producing Transform ID - self._producers = {} # type: Dict[str, str] + self._producers: Dict[str, str] = {} for transform_id, transform_proto in self._top_level_transforms(): for pcoll_id in transform_proto.inputs.values(): @@ -113,8 +112,7 @@ def __init__( self._renderer = pipeline_graph_renderer.get_renderer(render_option) - def get_dot(self): - # type: () -> str + def get_dot(self) -> str: return self._get_graph().to_string() def display_graph(self): @@ -130,9 +128,8 @@ def display_graph(self): 'environment is in a notebook. Cannot display the ' 'pipeline graph.') - def _top_level_transforms(self): - # type: () -> Iterator[Tuple[str, beam_runner_api_pb2.PTransform]] - + def _top_level_transforms( + self) -> Iterator[Tuple[str, beam_runner_api_pb2.PTransform]]: """Yields all top level PTransforms (subtransforms of the root PTransform). 
Yields: (str, PTransform proto) ID, proto pair of top level PTransforms. diff --git a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py index 9e23fc1deeda..a09205e23401 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py +++ b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py @@ -40,17 +40,13 @@ class PipelineGraphRenderer(BeamPlugin, metaclass=abc.ABCMeta): """ @classmethod @abc.abstractmethod - def option(cls): - # type: () -> str - + def option(cls) -> str: """The corresponding rendering option for the renderer. """ raise NotImplementedError @abc.abstractmethod - def render_pipeline_graph(self, pipeline_graph): - # type: (PipelineGraph) -> str - + def render_pipeline_graph(self, pipeline_graph: PipelineGraph) -> str: """Renders the pipeline graph in HTML-compatible format. Args: @@ -66,12 +62,10 @@ class MuteRenderer(PipelineGraphRenderer): """Use this renderer to mute the pipeline display. """ @classmethod - def option(cls): - # type: () -> str + def option(cls) -> str: return 'mute' - def render_pipeline_graph(self, pipeline_graph): - # type: (PipelineGraph) -> str + def render_pipeline_graph(self, pipeline_graph: PipelineGraph) -> str: return '' @@ -79,12 +73,10 @@ class TextRenderer(PipelineGraphRenderer): """This renderer simply returns the dot representation in text format. """ @classmethod - def option(cls): - # type: () -> str + def option(cls) -> str: return 'text' - def render_pipeline_graph(self, pipeline_graph): - # type: (PipelineGraph) -> str + def render_pipeline_graph(self, pipeline_graph: PipelineGraph) -> str: return pipeline_graph.get_dot() @@ -96,18 +88,14 @@ class PydotRenderer(PipelineGraphRenderer): 2. The python module pydot: https://pypi.org/project/pydot/ """ @classmethod - def option(cls): - # type: () -> str + def option(cls) -> str: return 'graph' - def render_pipeline_graph(self, pipeline_graph): - # type: (PipelineGraph) -> str + def render_pipeline_graph(self, pipeline_graph: PipelineGraph) -> str: return pipeline_graph._get_graph().create_svg().decode("utf-8") # pylint: disable=protected-access -def get_renderer(option=None): - # type: (Optional[str]) -> Type[PipelineGraphRenderer] - +def get_renderer(option: Optional[str] = None) -> Type[PipelineGraphRenderer]: """Get an instance of PipelineGraphRenderer given rendering option. 
Args: diff --git a/sdks/python/apache_beam/runners/interactive/options/capture_control.py b/sdks/python/apache_beam/runners/interactive/options/capture_control.py index 86422cd8219d..c14deeeb3956 100644 --- a/sdks/python/apache_beam/runners/interactive/options/capture_control.py +++ b/sdks/python/apache_beam/runners/interactive/options/capture_control.py @@ -45,8 +45,8 @@ def __init__(self): self._capture_size_limit = 1e9 self._test_limiters = None - def limiters(self): - # type: () -> List[capture_limiters.Limiter] # noqa: F821 + def limiters(self) -> List[capture_limiters.Limiter]: + # noqa: F821 if self._test_limiters: return self._test_limiters return [ @@ -54,8 +54,9 @@ def limiters(self): capture_limiters.DurationLimiter(self._capture_duration) ] - def set_limiters_for_test(self, limiters): - # type: (List[capture_limiters.Limiter]) -> None # noqa: F821 + def set_limiters_for_test( + self, limiters: List[capture_limiters.Limiter]) -> None: + # noqa: F821 self._test_limiters = limiters diff --git a/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py b/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py index 9634685e6fb5..3b2fb9f326ea 100644 --- a/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py +++ b/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py @@ -32,9 +32,7 @@ class Limiter: """Limits an aspect of the caching layer.""" - def is_triggered(self): - # type: () -> bool - + def is_triggered(self) -> bool: """Returns True if the limiter has triggered, and caching should stop.""" raise NotImplementedError @@ -43,8 +41,8 @@ class ElementLimiter(Limiter): """A `Limiter` that limits reading from cache based on some property of an element. """ - def update(self, e): - # type: (Any) -> None # noqa: F821 + def update(self, e: Any) -> None: + # noqa: F821 """Update the internal state based on some property of an element. 
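The hunks above all apply the same mechanical rewrite: a PEP 484 signature comment becomes inline parameter and return annotations, which type checkers treat identically. As an illustrative before/after sketch only (not part of the patch; `ByteLimiter` and its fields are made-up names):

from typing import Optional


class ByteLimiter:
  # Old style: the types live in a comment that is never evaluated.
  #   def __init__(self, size_limit, name=None):
  #     # type: (int, Optional[str]) -> None
  # New style: the same information expressed as inline annotations.
  def __init__(self, size_limit: int, name: Optional[str] = None) -> None:
    self._size_limit = size_limit
    self._name = name or 'default'
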
@@ -55,10 +53,7 @@ def update(self, e): class SizeLimiter(Limiter): """Limits the cache size to a specified byte limit.""" - def __init__( - self, - size_limit # type: int - ): + def __init__(self, size_limit: int): self._size_limit = size_limit def is_triggered(self): @@ -75,7 +70,7 @@ class DurationLimiter(Limiter): """Limits the duration of the capture.""" def __init__( self, - duration_limit # type: datetime.timedelta # noqa: F821 + duration_limit: datetime.timedelta # noqa: F821 ): self._duration_limit = duration_limit self._timer = threading.Timer(duration_limit.total_seconds(), self._trigger) diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager.py b/sdks/python/apache_beam/runners/interactive/recording_manager.py index bee215717b4d..a2470e693314 100644 --- a/sdks/python/apache_beam/runners/interactive/recording_manager.py +++ b/sdks/python/apache_beam/runners/interactive/recording_manager.py @@ -40,12 +40,11 @@ class ElementStream: """A stream of elements from a given PCollection.""" def __init__( self, - pcoll, # type: beam.pvalue.PCollection - var, # type: str - cache_key, # type: str - max_n, # type: int - max_duration_secs # type: float - ): + pcoll: beam.pvalue.PCollection, + var: str, + cache_key: str, + max_n: int, + max_duration_secs: float): self._pcoll = pcoll self._cache_key = cache_key self._pipeline = ie.current_env().user_pipeline(pcoll.pipeline) @@ -58,46 +57,38 @@ def __init__( self._done = False @property - def var(self): - # type: () -> str - + def var(self) -> str: """Returns the variable named that defined this PCollection.""" return self._var @property - def pcoll(self): - # type: () -> beam.pvalue.PCollection - + def pcoll(self) -> beam.pvalue.PCollection: """Returns the PCollection that supplies this stream with data.""" return self._pcoll @property - def cache_key(self): - # type: () -> str - + def cache_key(self) -> str: """Returns the cache key for this stream.""" return self._cache_key - def display_id(self, suffix): - # type: (str) -> str - + def display_id(self, suffix: str) -> str: """Returns a unique id able to be displayed in a web browser.""" return utils.obfuscate(self._cache_key, suffix) - def is_computed(self): - # type: () -> boolean # noqa: F821 + def is_computed(self) -> boolean: + # noqa: F821 """Returns True if no more elements will be recorded.""" return self._pcoll in ie.current_env().computed_pcollections - def is_done(self): - # type: () -> boolean # noqa: F821 + def is_done(self) -> boolean: + # noqa: F821 """Returns True if no more new elements will be yielded.""" return self._done - def read(self, tail=True): - # type: (boolean) -> Any # noqa: F821 + def read(self, tail: boolean = True) -> Any: + # noqa: F821 """Reads the elements currently recorded.""" @@ -154,11 +145,11 @@ class Recording: """A group of PCollections from a given pipeline run.""" def __init__( self, - user_pipeline, # type: beam.Pipeline - pcolls, # type: List[beam.pvalue.PCollection] # noqa: F821 - result, # type: beam.runner.PipelineResult - max_n, # type: int - max_duration_secs, # type: float + user_pipeline: beam.Pipeline, + pcolls: List[beam.pvalue.PCollection], # noqa: F821 + result: beam.runner.PipelineResult, + max_n: int, + max_duration_secs: float, ): self._user_pipeline = user_pipeline self._result = result @@ -188,9 +179,7 @@ def __init__( self._mark_computed.daemon = True self._mark_computed.start() - def _mark_all_computed(self): - # type: () -> None - + def _mark_all_computed(self) -> None: """Marks all the PCollections upon a 
     """Marks all the PCollections upon a successful pipeline run."""
     if not self._result:
       return
@@ -216,40 +205,30 @@ def _mark_all_computed(self):
     if self._result.state is PipelineState.DONE and self._set_computed:
       ie.current_env().mark_pcollection_computed(self._pcolls)
 
-  def is_computed(self):
-    # type: () -> boolean # noqa: F821
+  def is_computed(self) -> bool:
 
     """Returns True if all PCollections are computed."""
     return all(s.is_computed() for s in self._streams.values())
 
-  def stream(self, pcoll):
-    # type: (beam.pvalue.PCollection) -> ElementStream
-
+  def stream(self, pcoll: beam.pvalue.PCollection) -> ElementStream:
     """Returns an ElementStream for a given PCollection."""
     return self._streams[pcoll]
 
-  def computed(self):
-    # type: () -> None
-
+  def computed(self) -> None:
     """Returns all computed ElementStreams."""
     return {p: s for p, s in self._streams.items() if s.is_computed()}
 
-  def uncomputed(self):
-    # type: () -> None
-
+  def uncomputed(self) -> None:
     """Returns all uncomputed ElementStreams."""
     return {p: s for p, s in self._streams.items() if not s.is_computed()}
 
-  def cancel(self):
-    # type: () -> None
-
+  def cancel(self) -> None:
     """Cancels the recording."""
     with self._result_lock:
       self._result.cancel()
 
-  def wait_until_finish(self):
-    # type: () -> None
-
+  def wait_until_finish(self) -> None:
     """Waits until the pipeline is done and returns the final state.
 
     This also marks any PCollections as computed right away if the pipeline is
@@ -261,9 +240,7 @@ def wait_until_finish(self):
     self._mark_computed.join()
     return self._result.state
 
-  def describe(self):
-    # type: () -> dict[str, int]
-
+  def describe(self) -> dict[str, int]:
     """Returns a dictionary describing the cache and recording."""
 
     cache_manager = ie.current_env().get_cache_manager(self._user_pipeline)
@@ -274,17 +251,21 @@ def describe(self):
 
 class RecordingManager:
   """Manages recordings of PCollections for a given pipeline."""
-  def __init__(self, user_pipeline, pipeline_var=None, test_limiters=None):
-    # type: (beam.Pipeline, str, list[Limiter]) -> None # noqa: F821
-
-    self.user_pipeline = user_pipeline  # type: beam.Pipeline
-    self.pipeline_var = pipeline_var if pipeline_var else ''  # type: str
-    self._recordings = set()  # type: set[Recording]
-    self._start_time_sec = 0  # type: float
+  def __init__(
+      self,
+      user_pipeline: beam.Pipeline,
+      pipeline_var: str = None,
+      test_limiters: list[Limiter] = None) -> None:
+    # noqa: F821
+
+    self.user_pipeline: beam.Pipeline = user_pipeline
+    self.pipeline_var: str = pipeline_var if pipeline_var else ''
+    self._recordings: set[Recording] = set()
+    self._start_time_sec: float = 0
     self._test_limiters = test_limiters if test_limiters else []
 
-  def _watch(self, pcolls):
-    # type: (List[beam.pvalue.PCollection]) -> None # noqa: F821
+  def _watch(self, pcolls: List[beam.pvalue.PCollection]) -> None:
+    # noqa: F821
 
     """Watch any pcollections not being watched.
@@ -314,9 +295,7 @@ def _watch(self, pcolls):
     ie.current_env().watch(
         {'anonymous_pcollection_{}'.format(id(pcoll)): pcoll})
 
-  def _clear(self):
-    # type: () -> None
-
+  def _clear(self) -> None:
     """Clears the recording of all non-source PCollections."""
 
     cache_manager = ie.current_env().get_cache_manager(self.user_pipeline)
@@ -338,17 +317,13 @@ def _clear_pcolls(self, cache_manager, pcolls):
     for pc in pcolls:
       cache_manager.clear('full', pc)
 
-  def clear(self):
-    # type: () -> None
-
+  def clear(self) -> None:
     """Clears all cached PCollections for this RecordingManager."""
     cache_manager = ie.current_env().get_cache_manager(self.user_pipeline)
     if cache_manager:
       cache_manager.cleanup()
 
-  def cancel(self):
-    # type: (None) -> None
-
+  def cancel(self) -> None:
     """Cancels the current background recording job."""
     bcj.attempt_to_cancel_background_caching_job(self.user_pipeline)
@@ -361,9 +336,7 @@ def cancel(self):
     # evict the BCJ after they complete.
     ie.current_env().evict_background_caching_job(self.user_pipeline)
 
-  def describe(self):
-    # type: () -> dict[str, int]
-
+  def describe(self) -> dict[str, int]:
     """Returns a dictionary describing the cache and recording."""
     cache_manager = ie.current_env().get_cache_manager(self.user_pipeline)
@@ -384,9 +357,7 @@ def describe(self):
       'pipeline_var': self.pipeline_var
     }
 
-  def record_pipeline(self):
-    # type: () -> bool
-
+  def record_pipeline(self) -> bool:
     """Starts a background caching job for this RecordingManager's pipeline."""
 
     runner = self.user_pipeline.runner
@@ -412,8 +383,12 @@ def record_pipeline(self):
       return True
     return False
 
-  def record(self, pcolls, max_n, max_duration):
-    # type: (List[beam.pvalue.PCollection], int, Union[int,str]) -> Recording # noqa: F821
+  def record(
+      self,
+      pcolls: List[beam.pvalue.PCollection],
+      max_n: int,
+      max_duration: Union[int, str]) -> Recording:
+    # noqa: F821
 
     """Records the given PCollections."""
 
@@ -464,8 +439,13 @@ def record(self, pcolls, max_n, max_duration):
 
     return recording
 
-  def read(self, pcoll_name, pcoll, max_n, max_duration_secs):
-    # type: (str, beam.pvalue.PValue, int, float) -> Union[None, ElementStream] # noqa: F821
+  def read(
+      self,
+      pcoll_name: str,
+      pcoll: beam.pvalue.PValue,
+      max_n: int,
+      max_duration_secs: float) -> Union[None, ElementStream]:
+    # noqa: F821
 
     """Reads an ElementStream of a computed PCollection.
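One behavioural difference is worth keeping in mind when reading the `# noqa: F821` hunks above: a `# type:` comment is never evaluated, while an inline annotation is evaluated when the function is defined, so a name that exists only for the type checker has to be quoted or deferred. A minimal sketch of both options (the `Recording` import path here is hypothetical, not taken from the patch):

from __future__ import annotations  # option 1: defer evaluation of all annotations

from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
  # Seen only by the type checker; never imported at runtime.
  from some_package.recording import Recording


def active_recordings() -> List["Recording"]:  # option 2: quote the forward reference
  return []
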
diff --git a/sdks/python/apache_beam/runners/interactive/testing/integration/notebook_executor.py b/sdks/python/apache_beam/runners/interactive/testing/integration/notebook_executor.py index 6a80639ee285..808ede64d60d 100644 --- a/sdks/python/apache_beam/runners/interactive/testing/integration/notebook_executor.py +++ b/sdks/python/apache_beam/runners/interactive/testing/integration/notebook_executor.py @@ -40,8 +40,7 @@ class NotebookExecutor(object): """Executor that reads notebooks, executes it and gathers outputs into static HTML pages that can be served.""" - def __init__(self, path): - # type: (str) -> None + def __init__(self, path: str) -> None: assert _interactive_integration_ready, ( '[interactive_test] dependency is not installed.') diff --git a/sdks/python/apache_beam/runners/interactive/testing/integration/screen_diff.py b/sdks/python/apache_beam/runners/interactive/testing/integration/screen_diff.py index a1c9971b0882..743d5614f9a2 100644 --- a/sdks/python/apache_beam/runners/interactive/testing/integration/screen_diff.py +++ b/sdks/python/apache_beam/runners/interactive/testing/integration/screen_diff.py @@ -52,8 +52,11 @@ class ScreenDiffIntegrationTestEnvironment(object): """A test environment to conduct screen diff integration tests for notebooks. """ - def __init__(self, test_notebook_path, golden_dir, cleanup=True): - # type: (str, str, bool) -> None + def __init__( + self, + test_notebook_path: str, + golden_dir: str, + cleanup: bool = True) -> None: assert _interactive_integration_ready, ( '[interactive_test] dependency is not installed.') From b41698210dfe3ef075f31eaf69ba0dda73fa9a36 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:48:09 -0700 Subject: [PATCH 17/29] Modernize python type hints for apache_beam/runners/job --- sdks/python/apache_beam/runners/job/utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/runners/job/utils.py b/sdks/python/apache_beam/runners/job/utils.py index 205d87941a5a..1e15064ffd70 100644 --- a/sdks/python/apache_beam/runners/job/utils.py +++ b/sdks/python/apache_beam/runners/job/utils.py @@ -27,8 +27,7 @@ from google.protobuf import struct_pb2 -def dict_to_struct(dict_obj): - # type: (dict) -> struct_pb2.Struct +def dict_to_struct(dict_obj: dict) -> struct_pb2.Struct: try: return json_format.ParseDict(dict_obj, struct_pb2.Struct()) except json_format.ParseError: @@ -36,6 +35,5 @@ def dict_to_struct(dict_obj): raise -def struct_to_dict(struct_obj): - # type: (struct_pb2.Struct) -> dict +def struct_to_dict(struct_obj: struct_pb2.Struct) -> dict: return json.loads(json_format.MessageToJson(struct_obj)) From 8b540eb5075a9b5fb9ee75078fbf71c7d48591f2 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:48:15 -0700 Subject: [PATCH 18/29] Modernize python type hints for apache_beam/runners/portability --- .../portability/abstract_job_service.py | 139 ++++---- .../runners/portability/artifact_service.py | 18 +- .../portability/fn_api_runner/execution.py | 300 ++++++++--------- .../portability/fn_api_runner/fn_runner.py | 315 ++++++++---------- .../fn_api_runner/fn_runner_test.py | 2 +- .../fn_api_runner/watermark_manager.py | 11 +- .../runners/portability/job_server.py | 6 +- .../runners/portability/local_job_service.py | 36 +- .../runners/portability/portable_runner.py | 69 ++-- .../runners/portability/stager_test.py | 2 +- 10 files changed, 407 insertions(+), 491 deletions(-) diff --git 
a/sdks/python/apache_beam/runners/portability/abstract_job_service.py b/sdks/python/apache_beam/runners/portability/abstract_job_service.py index 1aa841df4c31..09c388f3b6a3 100644 --- a/sdks/python/apache_beam/runners/portability/abstract_job_service.py +++ b/sdks/python/apache_beam/runners/portability/abstract_job_service.py @@ -74,25 +74,22 @@ class AbstractJobServiceServicer(beam_job_api_pb2_grpc.JobServiceServicer): Servicer for the Beam Job API. """ def __init__(self): - self._jobs = {} # type: Dict[str, AbstractBeamJob] + self._jobs: Dict[str, AbstractBeamJob] = {} def create_beam_job(self, preparation_id, # stype: str - job_name, # type: str - pipeline, # type: beam_runner_api_pb2.Pipeline - options # type: struct_pb2.Struct - ): - # type: (...) -> AbstractBeamJob - + job_name: str, + pipeline: beam_runner_api_pb2.Pipeline, + options: struct_pb2.Struct + ) -> AbstractBeamJob: """Returns an instance of AbstractBeamJob specific to this servicer.""" raise NotImplementedError(type(self)) - def Prepare(self, - request, # type: beam_job_api_pb2.PrepareJobRequest - context=None, - timeout=None - ): - # type: (...) -> beam_job_api_pb2.PrepareJobResponse + def Prepare( + self, + request: beam_job_api_pb2.PrepareJobRequest, + context=None, + timeout=None) -> beam_job_api_pb2.PrepareJobResponse: _LOGGER.debug('Got Prepare request.') preparation_id = '%s-%s' % (request.job_name, uuid.uuid4()) self._jobs[preparation_id] = self.create_beam_job( @@ -108,56 +105,52 @@ def Prepare(self, artifact_staging_endpoint(), staging_session_token=preparation_id) - def Run(self, - request, # type: beam_job_api_pb2.RunJobRequest - context=None, - timeout=None - ): - # type: (...) -> beam_job_api_pb2.RunJobResponse + def Run( + self, + request: beam_job_api_pb2.RunJobRequest, + context=None, + timeout=None) -> beam_job_api_pb2.RunJobResponse: # For now, just use the preparation id as the job id. job_id = request.preparation_id _LOGGER.info("Running job '%s'", job_id) self._jobs[job_id].run() return beam_job_api_pb2.RunJobResponse(job_id=job_id) - def GetJobs(self, - request, # type: beam_job_api_pb2.GetJobsRequest - context=None, - timeout=None - ): - # type: (...) -> beam_job_api_pb2.GetJobsResponse + def GetJobs( + self, + request: beam_job_api_pb2.GetJobsRequest, + context=None, + timeout=None) -> beam_job_api_pb2.GetJobsResponse: return beam_job_api_pb2.GetJobsResponse( job_info=[job.to_runner_api() for job in self._jobs.values()]) def GetState( self, - request, # type: beam_job_api_pb2.GetJobStateRequest - context=None): - # type: (...) -> beam_job_api_pb2.JobStateEvent + request: beam_job_api_pb2.GetJobStateRequest, + context=None) -> beam_job_api_pb2.JobStateEvent: return make_state_event(*self._jobs[request.job_id].get_state()) - def GetPipeline(self, - request, # type: beam_job_api_pb2.GetJobPipelineRequest - context=None, - timeout=None - ): - # type: (...) -> beam_job_api_pb2.GetJobPipelineResponse + def GetPipeline( + self, + request: beam_job_api_pb2.GetJobPipelineRequest, + context=None, + timeout=None) -> beam_job_api_pb2.GetJobPipelineResponse: return beam_job_api_pb2.GetJobPipelineResponse( pipeline=self._jobs[request.job_id].get_pipeline()) - def Cancel(self, - request, # type: beam_job_api_pb2.CancelJobRequest - context=None, - timeout=None - ): - # type: (...) 
-> beam_job_api_pb2.CancelJobResponse + def Cancel( + self, + request: beam_job_api_pb2.CancelJobRequest, + context=None, + timeout=None) -> beam_job_api_pb2.CancelJobResponse: self._jobs[request.job_id].cancel() return beam_job_api_pb2.CancelJobResponse( state=self._jobs[request.job_id].get_state()[0]) - def GetStateStream(self, request, context=None, timeout=None): - # type: (...) -> Iterator[beam_job_api_pb2.JobStateEvent] - + def GetStateStream(self, + request, + context=None, + timeout=None) -> Iterator[beam_job_api_pb2.JobStateEvent]: """Yields state transitions since the stream started. """ if request.job_id not in self._jobs: @@ -167,9 +160,11 @@ def GetStateStream(self, request, context=None, timeout=None): for state, timestamp in job.get_state_stream(): yield make_state_event(state, timestamp) - def GetMessageStream(self, request, context=None, timeout=None): - # type: (...) -> Iterator[beam_job_api_pb2.JobMessagesResponse] - + def GetMessageStream( + self, + request, + context=None, + timeout=None) -> Iterator[beam_job_api_pb2.JobMessagesResponse]: """Yields messages since the stream started. """ if request.job_id not in self._jobs: @@ -184,50 +179,48 @@ def GetMessageStream(self, request, context=None, timeout=None): resp = beam_job_api_pb2.JobMessagesResponse(message_response=msg) yield resp - def DescribePipelineOptions(self, request, context=None, timeout=None): - # type: (...) -> beam_job_api_pb2.DescribePipelineOptionsResponse + def DescribePipelineOptions( + self, + request, + context=None, + timeout=None) -> beam_job_api_pb2.DescribePipelineOptionsResponse: return beam_job_api_pb2.DescribePipelineOptionsResponse() class AbstractBeamJob(object): """Abstract baseclass for managing a single Beam job.""" - - def __init__(self, - job_id, # type: str - job_name, # type: str - pipeline, # type: beam_runner_api_pb2.Pipeline - options # type: struct_pb2.Struct - ): + def __init__( + self, + job_id: str, + job_name: str, + pipeline: beam_runner_api_pb2.Pipeline, + options: struct_pb2.Struct): self._job_id = job_id self._job_name = job_name self._pipeline_proto = pipeline self._pipeline_options = options self._state_history = [(beam_job_api_pb2.JobState.STOPPED, Timestamp.now())] - def prepare(self): - # type: () -> None - + def prepare(self) -> None: """Called immediately after this class is instantiated""" raise NotImplementedError(self) - def run(self): - # type: () -> None + def run(self) -> None: raise NotImplementedError(self) - def cancel(self): - # type: () -> Optional[beam_job_api_pb2.JobState.Enum] + def cancel(self) -> Optional[beam_job_api_pb2.JobState.Enum]: raise NotImplementedError(self) - def artifact_staging_endpoint(self): - # type: () -> Optional[endpoints_pb2.ApiServiceDescriptor] + def artifact_staging_endpoint( + self) -> Optional[endpoints_pb2.ApiServiceDescriptor]: raise NotImplementedError(self) - def get_state_stream(self): - # type: () -> Iterator[StateEvent] + def get_state_stream(self) -> Iterator[StateEvent]: raise NotImplementedError(self) - def get_message_stream(self): - # type: () -> Iterator[Union[StateEvent, Optional[beam_job_api_pb2.JobMessage]]] + def get_message_stream( + self + ) -> Iterator[Union[StateEvent, Optional[beam_job_api_pb2.JobMessage]]]: raise NotImplementedError(self) @property @@ -259,8 +252,7 @@ def with_state_history(self, state_stream): """Utility to prepend recorded state history to an active state stream""" return itertools.chain(self._state_history[:], state_stream) - def get_pipeline(self): - # type: () -> 
beam_runner_api_pb2.Pipeline + def get_pipeline(self) -> beam_runner_api_pb2.Pipeline: return self._pipeline_proto @staticmethod @@ -268,8 +260,7 @@ def is_terminal_state(state): from apache_beam.runners.portability import portable_runner return state in portable_runner.TERMINAL_STATES - def to_runner_api(self): - # type: () -> beam_job_api_pb2.JobInfo + def to_runner_api(self) -> beam_job_api_pb2.JobInfo: return beam_job_api_pb2.JobInfo( job_id=self._job_id, job_name=self._job_name, @@ -285,9 +276,7 @@ def __init__(self, jar_path, root): def close(self): self._zipfile_handle.close() - def file_writer(self, path): - # type: (str) -> Tuple[BinaryIO, str] - + def file_writer(self, path: str) -> Tuple[BinaryIO, str]: """Given a relative path, returns an open handle that can be written to and an reference that can later be used to read this file.""" full_path = '%s/%s' % (self._root, path) diff --git a/sdks/python/apache_beam/runners/portability/artifact_service.py b/sdks/python/apache_beam/runners/portability/artifact_service.py index 6dec4031ee07..b9395caeafaf 100644 --- a/sdks/python/apache_beam/runners/portability/artifact_service.py +++ b/sdks/python/apache_beam/runners/portability/artifact_service.py @@ -57,7 +57,7 @@ class ArtifactRetrievalService( def __init__( self, - file_reader, # type: Callable[[str], BinaryIO] + file_reader: Callable[[str], BinaryIO], chunk_size=None, ): self._file_reader = file_reader @@ -97,18 +97,20 @@ class ArtifactStagingService( beam_artifact_api_pb2_grpc.ArtifactStagingServiceServicer): def __init__( self, - file_writer, # type: Callable[[str, Optional[str]], Tuple[BinaryIO, str]] - ): + file_writer: Callable[[str, Optional[str]], Tuple[BinaryIO, str]], + ): self._lock = threading.Lock() - self._jobs_to_stage = { - } # type: Dict[str, Tuple[Dict[Any, List[beam_runner_api_pb2.ArtifactInformation]], threading.Event]] + self._jobs_to_stage: Dict[ + str, + Tuple[Dict[Any, List[beam_runner_api_pb2.ArtifactInformation]], + threading.Event]] = {} self._file_writer = file_writer def register_job( self, - staging_token, # type: str - dependency_sets # type: MutableMapping[Any, List[beam_runner_api_pb2.ArtifactInformation]] - ): + staging_token: str, + dependency_sets: MutableMapping[ + Any, List[beam_runner_api_pb2.ArtifactInformation]]): if staging_token in self._jobs_to_stage: raise ValueError('Already staging %s' % staging_token) with self._lock: diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py index 885c96146456..3c16cb7cf99d 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py @@ -95,12 +95,10 @@ class Buffer(Protocol): - def __iter__(self): - # type: () -> Iterator[bytes] + def __iter__(self) -> Iterator[bytes]: pass - def append(self, item): - # type: (bytes) -> None + def append(self, item: bytes) -> None: pass def extend(self, other: 'Buffer') -> None: @@ -111,31 +109,26 @@ class PartitionableBuffer(Buffer, Protocol): def copy(self) -> 'PartitionableBuffer': pass - def partition(self, n): - # type: (int) -> List[List[bytes]] + def partition(self, n: int) -> List[List[bytes]]: pass @property - def cleared(self): - # type: () -> bool + def cleared(self) -> bool: pass - def clear(self): - # type: () -> None + def clear(self) -> None: pass - def reset(self): - # type: () -> None + def reset(self) -> None: pass class ListBuffer: """Used to support 
parititioning of a list.""" - def __init__(self, coder_impl): - # type: (Optional[CoderImpl]) -> None + def __init__(self, coder_impl: Optional[CoderImpl]) -> None: self._coder_impl = coder_impl or CoderImpl() - self._inputs = [] # type: List[bytes] - self._grouped_output = None # type: Optional[List[List[bytes]]] + self._inputs: List[bytes] = [] + self._grouped_output: Optional[List[List[bytes]]] = None self.cleared = False def copy(self) -> 'ListBuffer': @@ -151,16 +144,14 @@ def extend(self, extra: 'Buffer') -> None: assert isinstance(extra, ListBuffer) self._inputs.extend(extra._inputs) - def append(self, element): - # type: (bytes) -> None + def append(self, element: bytes) -> None: if self.cleared: raise RuntimeError('Trying to append to a cleared ListBuffer.') if self._grouped_output: raise RuntimeError('ListBuffer append after read.') self._inputs.append(element) - def partition(self, n): - # type: (int) -> List[List[bytes]] + def partition(self, n: int) -> List[List[bytes]]: if self.cleared: raise RuntimeError('Trying to partition a cleared ListBuffer.') if len(self._inputs) >= n or len(self._inputs) == 0: @@ -181,21 +172,17 @@ def partition(self, n): for output_stream in output_stream_list] return self._grouped_output - def __iter__(self): - # type: () -> Iterator[bytes] + def __iter__(self) -> Iterator[bytes]: if self.cleared: raise RuntimeError('Trying to iterate through a cleared ListBuffer.') return iter(self._inputs) - def clear(self): - # type: () -> None + def clear(self) -> None: self.cleared = True self._inputs = [] self._grouped_output = None - def reset(self): - # type: () -> None - + def reset(self) -> None: """Resets a cleared buffer for reuse.""" if not self.cleared: raise RuntimeError('Trying to reset a non-cleared ListBuffer.') @@ -204,19 +191,17 @@ def reset(self): class GroupingBuffer(object): """Used to accumulate groupded (shuffled) results.""" - def __init__(self, - pre_grouped_coder, # type: coders.Coder - post_grouped_coder, # type: coders.Coder - windowing # type: core.Windowing - ): - # type: (...) -> None + def __init__( + self, + pre_grouped_coder: coders.Coder, + post_grouped_coder: coders.Coder, + windowing: core.Windowing) -> None: self._key_coder = pre_grouped_coder.key_coder() self._pre_grouped_coder = pre_grouped_coder self._post_grouped_coder = post_grouped_coder - self._table = collections.defaultdict( - list) # type: DefaultDict[bytes, List[Any]] + self._table: DefaultDict[bytes, List[Any]] = collections.defaultdict(list) self._windowing = windowing - self._grouped_output = None # type: Optional[List[List[bytes]]] + self._grouped_output: Optional[List[List[bytes]]] = None def copy(self) -> 'GroupingBuffer': # This is a silly temporary optimization. This class must be removed once @@ -224,8 +209,7 @@ def copy(self) -> 'GroupingBuffer': # data grouping instead of GroupingBuffer). return self - def append(self, elements_data): - # type: (bytes) -> None + def append(self, elements_data: bytes) -> None: if self._grouped_output: raise RuntimeError('Grouping table append after read.') input_stream = create_InputStream(elements_data) @@ -241,8 +225,7 @@ def append(self, elements_data): value if is_trivial_windowing else windowed_key_value. with_value(value)) - def extend(self, input_buffer): - # type: (Buffer) -> None + def extend(self, input_buffer: Buffer) -> None: if isinstance(input_buffer, ListBuffer): # TODO(pabloem): GroupingBuffer will be removed once shuffling is done # via state. Remove this workaround along with that. 
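The instance-attribute changes in the ListBuffer and GroupingBuffer hunks above follow the PEP 526 form: the `# type:` comment on an assignment moves to a variable annotation on the left-hand side. A standalone toy sketch of the pattern (not Beam code; the class and field names are invented):

import collections
from typing import Any, DefaultDict, List, Optional


class _ToyBuffer:
  def __init__(self) -> None:
    # Each annotation mirrors what the removed '# type:' comment used to say.
    self._inputs: List[bytes] = []
    self._grouped: Optional[List[List[bytes]]] = None
    self._table: DefaultDict[bytes, List[Any]] = collections.defaultdict(list)
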
@@ -252,9 +235,7 @@ def extend(self, input_buffer): for key, values in input_buffer._table.items(): self._table[key].extend(values) - def partition(self, n): - # type: (int) -> List[List[bytes]] - + def partition(self, n: int) -> List[List[bytes]]: """ It is used to partition _GroupingBuffer to N parts. Once it is partitioned, it would not be re-partitioned with diff N. Re-partition is not supported now. @@ -292,9 +273,7 @@ def partition(self, n): self._table.clear() return self._grouped_output - def __iter__(self): - # type: () -> Iterator[bytes] - + def __iter__(self) -> Iterator[bytes]: """ Since partition() returns a list of lists, add this __iter__ to return a list to simplify code when we need to iterate through ALL elements of _GroupingBuffer. @@ -305,12 +284,10 @@ def __iter__(self): # PartionableBuffer protocol cleared = False - def clear(self): - # type: () -> None + def clear(self) -> None: pass - def reset(self): - # type: () -> None + def reset(self) -> None: pass @@ -318,15 +295,13 @@ class WindowGroupingBuffer(object): """Used to partition windowed side inputs.""" def __init__( self, - access_pattern, # type: beam_runner_api_pb2.FunctionSpec - coder # type: WindowedValueCoder - ): - # type: (...) -> None + access_pattern: beam_runner_api_pb2.FunctionSpec, + coder: WindowedValueCoder) -> None: # Here's where we would use a different type of partitioning # (e.g. also by key) for a different access pattern. if access_pattern.urn == common_urns.side_inputs.ITERABLE.urn: self._kv_extractor = lambda value: ('', value) - self._key_coder = coders.SingletonCoder('') # type: coders.Coder + self._key_coder: coders.Coder = coders.SingletonCoder('') self._value_coder = coder.wrapped_value_coder elif access_pattern.urn == common_urns.side_inputs.MULTIMAP.urn: self._kv_extractor = lambda value: value @@ -336,23 +311,22 @@ def __init__( raise ValueError("Unknown access pattern: '%s'" % access_pattern.urn) self._windowed_value_coder = coder self._window_coder = coder.window_coder - self._values_by_window = collections.defaultdict( - list) # type: DefaultDict[Tuple[str, BoundedWindow], List[Any]] + self._values_by_window: DefaultDict[Tuple[str, BoundedWindow], + List[Any]] = collections.defaultdict( + list) - def append(self, elements_data): - # type: (bytes) -> None + def append(self, elements_data: bytes) -> None: input_stream = create_InputStream(elements_data) while input_stream.size() > 0: - windowed_val_coder_impl = self._windowed_value_coder.get_impl( - ) # type: WindowedValueCoderImpl + windowed_val_coder_impl: WindowedValueCoderImpl = self._windowed_value_coder.get_impl( + ) windowed_value = windowed_val_coder_impl.decode_from_stream( input_stream, True) key, value = self._kv_extractor(windowed_value.value) for window in windowed_value.windows: self._values_by_window[key, window].append(value) - def encoded_items(self): - # type: () -> Iterator[Tuple[bytes, bytes, bytes, int]] + def encoded_items(self) -> Iterator[Tuple[bytes, bytes, bytes, int]]: value_coder_impl = self._value_coder.get_impl() key_coder_impl = self._key_coder.get_impl() for (key, window), values in self._values_by_window.items(): @@ -368,22 +342,21 @@ class GenericNonMergingWindowFn(window.NonMergingWindowFn): URN = 'internal-generic-non-merging' - def __init__(self, coder): - # type: (coders.Coder) -> None + def __init__(self, coder: coders.Coder) -> None: self._coder = coder - def assign(self, assign_context): - # type: (window.WindowFn.AssignContext) -> Iterable[BoundedWindow] + def assign( + self, + 
assign_context: window.WindowFn.AssignContext) -> Iterable[BoundedWindow]: raise NotImplementedError() - def get_window_coder(self): - # type: () -> coders.Coder + def get_window_coder(self) -> coders.Coder: return self._coder @staticmethod @window.urns.RunnerApiFn.register_urn(URN, bytes) - def from_runner_api_parameter(window_coder_id, context): - # type: (bytes, Any) -> GenericNonMergingWindowFn + def from_runner_api_parameter( + window_coder_id: bytes, context: Any) -> GenericNonMergingWindowFn: return GenericNonMergingWindowFn( context.coders[window_coder_id.decode('utf-8')]) @@ -478,11 +451,13 @@ class GenericMergingWindowFn(window.WindowFn): TO_SDK_TRANSFORM = 'read' FROM_SDK_TRANSFORM = 'write' - _HANDLES = {} # type: Dict[str, GenericMergingWindowFn] + _HANDLES: Dict[str, GenericMergingWindowFn] = {} - def __init__(self, execution_context, windowing_strategy_proto): - # type: (FnApiRunnerExecutionContext, beam_runner_api_pb2.WindowingStrategy) -> None - self._worker_handler = None # type: Optional[worker_handlers.WorkerHandler] + def __init__( + self, + execution_context: FnApiRunnerExecutionContext, + windowing_strategy_proto: beam_runner_api_pb2.WindowingStrategy) -> None: + self._worker_handler: Optional[worker_handlers.WorkerHandler] = None self._handle_id = handle_id = uuid.uuid4().hex self._HANDLES[handle_id] = self # ExecutionContexts are expensive, we don't want to keep them in the @@ -494,32 +469,30 @@ def __init__(self, execution_context, windowing_strategy_proto): self._counter = 0 # Lazily created in make_process_bundle_descriptor() self._process_bundle_descriptor = None - self._bundle_processor_id = '' # type: str - self.windowed_input_coder_impl = None # type: Optional[CoderImpl] - self.windowed_output_coder_impl = None # type: Optional[CoderImpl] + self._bundle_processor_id: str = '' + self.windowed_input_coder_impl: Optional[CoderImpl] = None + self.windowed_output_coder_impl: Optional[CoderImpl] = None - def _execution_context_ref(self): - # type: () -> FnApiRunnerExecutionContext + def _execution_context_ref(self) -> FnApiRunnerExecutionContext: result = self._execution_context_ref_obj() assert result is not None return result - def payload(self): - # type: () -> bytes + def payload(self) -> bytes: return self._handle_id.encode('utf-8') @staticmethod @window.urns.RunnerApiFn.register_urn(URN, bytes) - def from_runner_api_parameter(handle_id, unused_context): - # type: (bytes, Any) -> GenericMergingWindowFn + def from_runner_api_parameter( + handle_id: bytes, unused_context: Any) -> GenericMergingWindowFn: return GenericMergingWindowFn._HANDLES[handle_id.decode('utf-8')] - def assign(self, assign_context): - # type: (window.WindowFn.AssignContext) -> Iterable[window.BoundedWindow] + def assign( + self, assign_context: window.WindowFn.AssignContext + ) -> Iterable[window.BoundedWindow]: raise NotImplementedError() - def merge(self, merge_context): - # type: (window.WindowFn.MergeContext) -> None + def merge(self, merge_context: window.WindowFn.MergeContext) -> None: worker_handler = self.worker_handle() assert self.windowed_input_coder_impl is not None @@ -554,13 +527,11 @@ def merge(self, merge_context): raise RuntimeError(result.error) # The result was "returned" via the merge callbacks on merge_context above. 
- def get_window_coder(self): - # type: () -> coders.Coder + def get_window_coder(self) -> coders.Coder: return self._execution_context_ref().pipeline_context.coders[ self._windowing_strategy_proto.window_coder_id] - def worker_handle(self): - # type: () -> worker_handlers.WorkerHandler + def worker_handle(self) -> worker_handlers.WorkerHandler: if self._worker_handler is None: worker_handler_manager = self._execution_context_ref( ).worker_handler_manager @@ -574,14 +545,14 @@ def worker_handle(self): return self._worker_handler def make_process_bundle_descriptor( - self, data_api_service_descriptor, state_api_service_descriptor): - # type: (Optional[endpoints_pb2.ApiServiceDescriptor], Optional[endpoints_pb2.ApiServiceDescriptor]) -> beam_fn_api_pb2.ProcessBundleDescriptor - + self, + data_api_service_descriptor: Optional[endpoints_pb2.ApiServiceDescriptor], + state_api_service_descriptor: Optional[endpoints_pb2.ApiServiceDescriptor] + ) -> beam_fn_api_pb2.ProcessBundleDescriptor: """Creates a ProcessBundleDescriptor for invoking the WindowFn's merge operation. """ - def make_channel_payload(coder_id): - # type: (str) -> bytes + def make_channel_payload(coder_id: str) -> bytes: data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id) if data_api_service_descriptor: data_spec.api_service_descriptor.url = (data_api_service_descriptor.url) @@ -593,8 +564,7 @@ def make_channel_payload(coder_id): window.GlobalWindows()).to_runner_api(pipeline_context) coders = dict(pipeline_context.coders.get_id_to_proto_map()) - def make_coder(urn, *components): - # type: (str, str) -> str + def make_coder(urn: str, *components: str) -> str: coder_proto = beam_runner_api_pb2.Coder( spec=beam_runner_api_pb2.FunctionSpec(urn=urn), component_coder_ids=components) @@ -681,8 +651,7 @@ def make_coder(urn, *components): state_api_service_descriptor=state_api_service_descriptor, timer_api_service_descriptor=data_api_service_descriptor) - def uid(self, name=''): - # type: (str) -> str + def uid(self, name: str = '') -> str: self._counter += 1 return '%s_%s_%s' % (self._handle_id, name, self._counter) @@ -693,16 +662,18 @@ class FnApiRunnerExecutionContext(object): PCollection IDs to list that functions as buffer for the ``beam.PCollection``. """ - def __init__(self, - stages, # type: List[translations.Stage] - worker_handler_manager, # type: worker_handlers.WorkerHandlerManager - pipeline_components, # type: beam_runner_api_pb2.Components - safe_coders: translations.SafeCoderMapping, - data_channel_coders: Dict[str, str], - num_workers: int, - uses_teststream: bool = False, - split_managers = () # type: Sequence[Tuple[str, Callable[[int], Iterable[float]]]] - ) -> None: + def __init__( + self, + stages: List[translations.Stage], + worker_handler_manager: worker_handlers.WorkerHandlerManager, + pipeline_components: beam_runner_api_pb2.Components, + safe_coders: translations.SafeCoderMapping, + data_channel_coders: Dict[str, str], + num_workers: int, + uses_teststream: bool = False, + split_managers: Sequence[Tuple[str, Callable[[int], + Iterable[float]]]] = () + ) -> None: """ :param worker_handler_manager: This class manages the set of worker handlers, and the communication with state / control APIs. 
@@ -714,8 +685,8 @@ def __init__(self, self.stages = {s.name: s for s in stages} self.side_input_descriptors_by_stage = ( self._build_data_side_inputs_map(stages)) - self.pcoll_buffers = {} # type: MutableMapping[bytes, PartitionableBuffer] - self.timer_buffers = {} # type: MutableMapping[bytes, ListBuffer] + self.pcoll_buffers: MutableMapping[bytes, PartitionableBuffer] = {} + self.timer_buffers: MutableMapping[bytes, ListBuffer] = {} self.worker_handler_manager = worker_handler_manager self.pipeline_components = pipeline_components self.safe_coders = safe_coders @@ -806,7 +777,7 @@ def setup(self) -> None: def _enqueue_stage_initial_inputs(self, stage: Stage) -> None: """Sets up IMPULSE inputs for a stage, and the data GRPC API endpoint.""" - data_input = {} # type: MutableMapping[str, PartitionableBuffer] + data_input: MutableMapping[str, PartitionableBuffer] = {} ready_to_schedule = True for transform in stage.transforms: if (transform.spec.urn in {bundle_processor.DATA_INPUT_URN, @@ -854,23 +825,23 @@ def _enqueue_stage_initial_inputs(self, stage: Stage) -> None: ((stage.name, MAX_TIMESTAMP), DataInput(data_input, {}))) @staticmethod - def _build_data_side_inputs_map(stages): - # type: (Iterable[translations.Stage]) -> MutableMapping[str, DataSideInput] - + def _build_data_side_inputs_map( + stages: Iterable[translations.Stage] + ) -> MutableMapping[str, DataSideInput]: """Builds an index mapping stages to side input descriptors. A side input descriptor is a map of side input IDs to side input access patterns for all of the outputs of a stage that will be consumed as a side input. """ - transform_consumers = collections.defaultdict( - list) # type: DefaultDict[str, List[beam_runner_api_pb2.PTransform]] - stage_consumers = collections.defaultdict( - list) # type: DefaultDict[str, List[translations.Stage]] - - def get_all_side_inputs(): - # type: () -> Set[str] - all_side_inputs = set() # type: Set[str] + transform_consumers: DefaultDict[ + str, + List[beam_runner_api_pb2.PTransform]] = collections.defaultdict(list) + stage_consumers: DefaultDict[ + str, List[translations.Stage]] = collections.defaultdict(list) + + def get_all_side_inputs() -> Set[str]: + all_side_inputs: Set[str] = set() for stage in stages: for transform in stage.transforms: for input in transform.inputs.values(): @@ -881,7 +852,7 @@ def get_all_side_inputs(): return all_side_inputs all_side_inputs = frozenset(get_all_side_inputs()) - data_side_inputs_by_producing_stage = {} # type: Dict[str, DataSideInput] + data_side_inputs_by_producing_stage: Dict[str, DataSideInput] = {} producing_stages_by_pcoll = {} @@ -912,8 +883,7 @@ def get_all_side_inputs(): return data_side_inputs_by_producing_stage - def _make_safe_windowing_strategy(self, id): - # type: (str) -> str + def _make_safe_windowing_strategy(self, id: str) -> str: windowing_strategy_proto = self.pipeline_components.windowing_strategies[id] if windowing_strategy_proto.window_fn.urn in SAFE_WINDOW_FNS: return id @@ -940,18 +910,16 @@ def _make_safe_windowing_strategy(self, id): return safe_id @property - def state_servicer(self): - # type: () -> worker_handlers.StateServicer + def state_servicer(self) -> worker_handlers.StateServicer: # TODO(BEAM-9625): Ensure FnApiRunnerExecutionContext owns StateServicer return self.worker_handler_manager.state_servicer - def next_uid(self): - # type: () -> str + def next_uid(self) -> str: self._last_uid += 1 return str(self._last_uid) - def _iterable_state_write(self, values, element_coder_impl): - # type: (Iterable, CoderImpl) 
-> bytes + def _iterable_state_write( + self, values: Iterable, element_coder_impl: CoderImpl) -> bytes: token = unique_name(None, 'iter').encode('ascii') out = create_OutputStream() for element in values: @@ -964,9 +932,8 @@ def _iterable_state_write(self, values, element_coder_impl): def commit_side_inputs_to_state( self, - data_side_input, # type: DataSideInput - ): - # type: (...) -> None + data_side_input: DataSideInput, + ) -> None: for (consuming_transform_id, tag), (buffer_id, func_spec) in data_side_input.items(): _, pcoll_id = split_buffer_id(buffer_id) @@ -1024,14 +991,13 @@ def commit_side_inputs_to_state( class BundleContextManager(object): - - def __init__(self, - execution_context, # type: FnApiRunnerExecutionContext - stage, # type: translations.Stage - num_workers, # type: int - split_managers, # type: Sequence[Tuple[str, Callable[[int], Iterable[float]]]] - ): - # type: (...) -> None + def __init__( + self, + execution_context: FnApiRunnerExecutionContext, + stage: translations.Stage, + num_workers: int, + split_managers: Sequence[Tuple[str, Callable[[int], Iterable[float]]]], + ) -> None: self.execution_context = execution_context self.stage = stage self.bundle_uid = self.execution_context.next_uid() @@ -1039,12 +1005,13 @@ def __init__(self, self.split_managers = split_managers # Properties that are lazily initialized - self._process_bundle_descriptor = None # type: Optional[beam_fn_api_pb2.ProcessBundleDescriptor] - self._worker_handlers = None # type: Optional[List[worker_handlers.WorkerHandler]] + self._process_bundle_descriptor: Optional[ + beam_fn_api_pb2.ProcessBundleDescriptor] = None + self._worker_handlers: Optional[List[worker_handlers.WorkerHandler]] = None # a mapping of {(transform_id, timer_family_id): timer_coder_id}. The map # is built after self._process_bundle_descriptor is initialized. # This field can be used to tell whether current bundle has timers. - self._timer_coder_ids = None # type: Optional[Dict[Tuple[str, str], str]] + self._timer_coder_ids: Optional[Dict[Tuple[str, str], str]] = None # A mapping from transform_name to Buffer ID self.stage_data_outputs: DataOutput = {} @@ -1066,36 +1033,35 @@ def _compute_expected_outputs(self) -> None: create_buffer_id(timer_family_id, 'timers'), time_domain) @property - def worker_handlers(self): - # type: () -> List[worker_handlers.WorkerHandler] + def worker_handlers(self) -> List[worker_handlers.WorkerHandler]: if self._worker_handlers is None: self._worker_handlers = ( self.execution_context.worker_handler_manager.get_worker_handlers( self.stage.environment, self.num_workers)) return self._worker_handlers - def data_api_service_descriptor(self): - # type: () -> Optional[endpoints_pb2.ApiServiceDescriptor] + def data_api_service_descriptor( + self) -> Optional[endpoints_pb2.ApiServiceDescriptor]: # All worker_handlers share the same grpc server, so we can read grpc server # info from any worker_handler and read from the first worker_handler. return self.worker_handlers[0].data_api_service_descriptor() - def state_api_service_descriptor(self): - # type: () -> Optional[endpoints_pb2.ApiServiceDescriptor] + def state_api_service_descriptor( + self) -> Optional[endpoints_pb2.ApiServiceDescriptor]: # All worker_handlers share the same grpc server, so we can read grpc server # info from any worker_handler and read from the first worker_handler. 
return self.worker_handlers[0].state_api_service_descriptor() @property - def process_bundle_descriptor(self): - # type: () -> beam_fn_api_pb2.ProcessBundleDescriptor + def process_bundle_descriptor( + self) -> beam_fn_api_pb2.ProcessBundleDescriptor: if self._process_bundle_descriptor is None: self._process_bundle_descriptor = self._build_process_bundle_descriptor() self._timer_coder_ids = self._build_timer_coders_id_map() return self._process_bundle_descriptor - def _build_process_bundle_descriptor(self): - # type: () -> beam_fn_api_pb2.ProcessBundleDescriptor + def _build_process_bundle_descriptor( + self) -> beam_fn_api_pb2.ProcessBundleDescriptor: # Cannot be invoked until *after* _extract_endpoints is called. # Always populate the timer_api_service_descriptor. return beam_fn_api_pb2.ProcessBundleDescriptor( @@ -1115,16 +1081,14 @@ def _build_process_bundle_descriptor(self): state_api_service_descriptor=self.state_api_service_descriptor(), timer_api_service_descriptor=self.data_api_service_descriptor()) - def get_input_coder_impl(self, transform_id): - # type: (str) -> CoderImpl + def get_input_coder_impl(self, transform_id: str) -> CoderImpl: coder_id = beam_fn_api_pb2.RemoteGrpcPort.FromString( self.process_bundle_descriptor.transforms[transform_id].spec.payload ).coder_id assert coder_id return self.get_coder_impl(coder_id) - def _build_timer_coders_id_map(self): - # type: () -> Dict[Tuple[str, str], str] + def _build_timer_coders_id_map(self) -> Dict[Tuple[str, str], str]: assert self._process_bundle_descriptor is not None timer_coder_ids = {} for transform_id, transform_proto in (self._process_bundle_descriptor @@ -1137,23 +1101,21 @@ def _build_timer_coders_id_map(self): timer_family_spec.timer_family_coder_id) return timer_coder_ids - def get_coder_impl(self, coder_id): - # type: (str) -> CoderImpl + def get_coder_impl(self, coder_id: str) -> CoderImpl: if coder_id in self.execution_context.safe_coders: return self.execution_context.pipeline_context.coders[ self.execution_context.safe_coders[coder_id]].get_impl() else: return self.execution_context.pipeline_context.coders[coder_id].get_impl() - def get_timer_coder_impl(self, transform_id, timer_family_id): - # type: (str, str) -> CoderImpl + def get_timer_coder_impl( + self, transform_id: str, timer_family_id: str) -> CoderImpl: assert self._timer_coder_ids is not None return self.get_coder_impl( self._timer_coder_ids[(transform_id, timer_family_id)]) - def get_buffer(self, buffer_id, transform_id): - # type: (bytes, str) -> PartitionableBuffer - + def get_buffer( + self, buffer_id: bytes, transform_id: str) -> PartitionableBuffer: """Returns the buffer for a given (operation_type, PCollection ID). For grouping-typed operations, we produce a ``GroupingBuffer``. For others, we produce a ``ListBuffer``. 
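Note on the pattern: the execution.py hunks above all apply the same mechanical rewrite — the legacy `# type:` comments are folded into PEP 484 inline parameter annotations and PEP 526 variable annotations, with no change to runtime behavior. The sketch below illustrates the before/after shape using the ListBuffer members changed above; `ListBufferSketch` and the `Optional[object]` stand-in for `CoderImpl` are illustrative placeholders, not the exact Beam source.

    # Minimal, self-contained sketch of the transformation (assumed names).
    from typing import List, Optional

    class ListBufferSketch(object):
      # Old style, as removed by this patch:
      #   def __init__(self, coder_impl):
      #     # type: (Optional[CoderImpl]) -> None
      #     self._inputs = []  # type: List[bytes]
      #
      # New style, as added by this patch:
      def __init__(self, coder_impl: Optional[object] = None) -> None:
        self._coder_impl = coder_impl
        self._inputs: List[bytes] = []  # variable annotation replaces the comment

      def append(self, element: bytes) -> None:
        # Inline annotations carry the same information the comment did.
        self._inputs.append(element)

Both spellings are equivalent to a type checker such as mypy; the inline form simply keeps the hints in the signature itself instead of a comment that must be maintained separately.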
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py index 07569fe328d8..d15e04e5f238 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py @@ -102,15 +102,12 @@ class FnApiRunner(runner.PipelineRunner): def __init__( self, - default_environment=None, # type: Optional[environments.Environment] - bundle_repeat=0, # type: int - use_state_iterables=False, # type: bool - provision_info=None, # type: Optional[ExtendedProvisionInfo] - progress_request_frequency=None, # type: Optional[float] - is_drain=False # type: bool - ): - # type: (...) -> None - + default_environment: Optional[environments.Environment] = None, + bundle_repeat: int = 0, + use_state_iterables: bool = False, + provision_info: Optional[ExtendedProvisionInfo] = None, + progress_request_frequency: Optional[float] = None, + is_drain: bool = False) -> None: """Creates a new Fn API Runner. Args: @@ -138,19 +135,16 @@ def __init__( retrieval_token='unused-retrieval-token')) @staticmethod - def supported_requirements(): - # type: () -> Tuple[str, ...] + def supported_requirements() -> Tuple[str, ...]: return ( common_urns.requirements.REQUIRES_STATEFUL_PROCESSING.urn, common_urns.requirements.REQUIRES_BUNDLE_FINALIZATION.urn, common_urns.requirements.REQUIRES_SPLITTABLE_DOFN.urn, ) - def run_pipeline(self, - pipeline, # type: Pipeline - options # type: pipeline_options.PipelineOptions - ): - # type: (...) -> RunnerResult + def run_pipeline( + self, pipeline: Pipeline, + options: pipeline_options.PipelineOptions) -> RunnerResult: RuntimeValueProvider.set_runtime_options({}) # Setup "beam_fn_api" experiment options if lacked. @@ -206,8 +200,10 @@ def run_pipeline(self, options) return self._latest_run_result - def run_via_runner_api(self, pipeline_proto, options): - # type: (beam_runner_api_pb2.Pipeline, pipeline_options.PipelineOptions) -> RunnerResult + def run_via_runner_api( + self, + pipeline_proto: beam_runner_api_pb2.Pipeline, + options: pipeline_options.PipelineOptions) -> RunnerResult: validate_pipeline_graph(pipeline_proto) self._validate_requirements(pipeline_proto) self._check_requirements(pipeline_proto) @@ -282,8 +278,7 @@ def resolve_any_environments(self, pipeline_proto): return pipeline_proto @contextlib.contextmanager - def maybe_profile(self): - # type: () -> Iterator[None] + def maybe_profile(self) -> Iterator[None]: if self._profiler_factory: try: profile_id = 'direct-' + subprocess.check_output([ @@ -291,8 +286,8 @@ def maybe_profile(self): ]).decode(errors='ignore').strip() except subprocess.CalledProcessError: profile_id = 'direct-unknown' - profiler = self._profiler_factory( - profile_id, time_prefix='') # type: Optional[Profile] + profiler: Optional[Profile] = self._profiler_factory( + profile_id, time_prefix='') else: profiler = None @@ -328,14 +323,12 @@ def maybe_profile(self): # Empty context. 
yield - def _validate_requirements(self, pipeline_proto): - # type: (beam_runner_api_pb2.Pipeline) -> None - + def _validate_requirements( + self, pipeline_proto: beam_runner_api_pb2.Pipeline) -> None: """As a test runner, validate requirements were set correctly.""" expected_requirements = set() - def add_requirements(transform_id): - # type: (str) -> None + def add_requirements(transform_id: str) -> None: transform = pipeline_proto.components.transforms[transform_id] if transform.spec.urn in translations.PAR_DO_URNS: payload = proto_utils.parse_Bytes( @@ -366,9 +359,8 @@ def add_requirements(transform_id): 'Missing requirement declaration: %s' % (expected_requirements - set(pipeline_proto.requirements))) - def _check_requirements(self, pipeline_proto): - # type: (beam_runner_api_pb2.Pipeline) -> None - + def _check_requirements( + self, pipeline_proto: beam_runner_api_pb2.Pipeline) -> None: """Check that this runner can satisfy all pipeline requirements.""" supported_requirements = set(self.supported_requirements()) for requirement in pipeline_proto.requirements: @@ -388,10 +380,8 @@ def _check_requirements(self, pipeline_proto): raise NotImplementedError(timer.time_domain) def create_stages( - self, - pipeline_proto # type: beam_runner_api_pb2.Pipeline - ): - # type: (...) -> Tuple[translations.TransformContext, List[translations.Stage]] + self, pipeline_proto: beam_runner_api_pb2.Pipeline + ) -> Tuple[translations.TransformContext, List[translations.Stage]]: return translations.create_and_optimize_stages( copy.deepcopy(pipeline_proto), phases=[ @@ -417,12 +407,10 @@ def create_stages( use_state_iterables=self._use_state_iterables, is_drain=self._is_drain) - def run_stages(self, - stage_context, # type: translations.TransformContext - stages # type: List[translations.Stage] - ): - # type: (...) -> RunnerResult - + def run_stages( + self, + stage_context: translations.TransformContext, + stages: List[translations.Stage]) -> RunnerResult: """Run a list of topologically-sorted stages in batch mode. Args: @@ -593,11 +581,12 @@ def _schedule_ready_bundles( def _run_bundle_multiple_times_for_testing( self, - runner_execution_context, # type: execution.FnApiRunnerExecutionContext - bundle_manager, # type: BundleManager - data_input, # type: MutableMapping[str, execution.PartitionableBuffer] - data_output, # type: DataOutput - fired_timers, # type: Mapping[translations.TimerFamilyId, execution.PartitionableBuffer] + runner_execution_context: execution.FnApiRunnerExecutionContext, + bundle_manager: BundleManager, + data_input: MutableMapping[str, execution.PartitionableBuffer], + data_output: DataOutput, + fired_timers: Mapping[translations.TimerFamilyId, + execution.PartitionableBuffer], expected_output_timers: OutputTimers, ) -> None: """ @@ -679,12 +668,10 @@ def _collect_written_timers( def _add_sdk_delayed_applications_to_deferred_inputs( self, - bundle_context_manager, # type: execution.BundleContextManager - bundle_result, # type: beam_fn_api_pb2.InstructionResponse - deferred_inputs # type: MutableMapping[str, execution.PartitionableBuffer] - ): - # type: (...) -> Set[str] - + bundle_context_manager: execution.BundleContextManager, + bundle_result: beam_fn_api_pb2.InstructionResponse, + deferred_inputs: MutableMapping[str, execution.PartitionableBuffer] + ) -> Set[str]: """Returns a set of PCollection IDs of PColls having delayed applications. 
This transform inspects the bundle_context_manager, and bundle_result @@ -711,13 +698,11 @@ def _add_sdk_delayed_applications_to_deferred_inputs( def _add_residuals_and_channel_splits_to_deferred_inputs( self, - splits, # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse] - bundle_context_manager, # type: execution.BundleContextManager - last_sent, # type: MutableMapping[str, execution.PartitionableBuffer] - deferred_inputs # type: MutableMapping[str, execution.PartitionableBuffer] - ): - # type: (...) -> Tuple[Set[str], Set[str]] - + splits: List[beam_fn_api_pb2.ProcessBundleSplitResponse], + bundle_context_manager: execution.BundleContextManager, + last_sent: MutableMapping[str, execution.PartitionableBuffer], + deferred_inputs: MutableMapping[str, execution.PartitionableBuffer] + ) -> Tuple[Set[str], Set[str]]: """Returns a two sets representing PCollections with watermark holds. The first set represents PCollections with delayed root applications. @@ -726,7 +711,7 @@ def _add_residuals_and_channel_splits_to_deferred_inputs( pcolls_with_delayed_apps = set() transforms_with_channel_splits = set() - prev_stops = {} # type: Dict[str, int] + prev_stops: Dict[str, int] = {} for split in splits: for delayed_application in split.residual_roots: producer_name = bundle_context_manager.input_for( @@ -783,11 +768,11 @@ def _add_residuals_and_channel_splits_to_deferred_inputs( channel_split.transform_id] = channel_split.last_primary_element return pcolls_with_delayed_apps, transforms_with_channel_splits - def _execute_bundle(self, - runner_execution_context, # type: execution.FnApiRunnerExecutionContext - bundle_context_manager, # type: execution.BundleContextManager - bundle_input: DataInput - ) -> beam_fn_api_pb2.InstructionResponse: + def _execute_bundle( + self, + runner_execution_context: execution.FnApiRunnerExecutionContext, + bundle_context_manager: execution.BundleContextManager, + bundle_input: DataInput) -> beam_fn_api_pb2.InstructionResponse: """Execute a bundle end-to-end. Args: @@ -943,7 +928,8 @@ def _get_bundle_manager( cache_token_generator = FnApiRunner.get_cache_token_generator(static=False) if bundle_context_manager.num_workers == 1: # Avoid thread/processor pools for increased performance and debugability. - bundle_manager_type = BundleManager # type: Union[Type[BundleManager], Type[ParallelBundleManager]] + bundle_manager_type: Union[Type[BundleManager], + Type[ParallelBundleManager]] = BundleManager elif bundle_context_manager.stage.is_stateful(): # State is keyed, and a single key cannot be processed concurrently. # Alternatively, we could arrange to partition work by key. @@ -958,12 +944,13 @@ def _get_bundle_manager( @staticmethod def _build_watermark_updates( - runner_execution_context, # type: execution.FnApiRunnerExecutionContext - stage_inputs, # type: Iterable[str] - expected_timers, # type: Iterable[translations.TimerFamilyId] - pcolls_with_da, # type: Set[str] - transforms_w_splits, # type: Set[str] - watermarks_by_transform_and_timer_family # type: Dict[translations.TimerFamilyId, timestamp.Timestamp] + runner_execution_context: execution.FnApiRunnerExecutionContext, + stage_inputs: Iterable[str], + expected_timers: Iterable[translations.TimerFamilyId], + pcolls_with_da: Set[str], + transforms_w_splits: Set[str], + watermarks_by_transform_and_timer_family: Dict[translations.TimerFamilyId, + timestamp.Timestamp] ) -> Dict[Union[str, translations.TimerFamilyId], timestamp.Timestamp]: """Builds a dictionary of PCollection (or TimerFamilyId) to timestamp. 
@@ -979,8 +966,8 @@ def _build_watermark_updates( watermarks_by_transform_and_timer_family: represent the set of watermark holds to be added for each timer family. """ - updates = { - } # type: Dict[Union[str, translations.TimerFamilyId], timestamp.Timestamp] + updates: Dict[Union[str, translations.TimerFamilyId], + timestamp.Timestamp] = {} def get_pcoll_id(transform_id): buffer_id = runner_execution_context.input_transform_to_buffer_id[ @@ -1024,12 +1011,12 @@ def get_pcoll_id(transform_id): def _run_bundle( self, - runner_execution_context, # type: execution.FnApiRunnerExecutionContext - bundle_context_manager, # type: execution.BundleContextManager + runner_execution_context: execution.FnApiRunnerExecutionContext, + bundle_context_manager: execution.BundleContextManager, bundle_input: DataInput, data_output: DataOutput, expected_timer_output: OutputTimers, - bundle_manager # type: BundleManager + bundle_manager: BundleManager ) -> Tuple[beam_fn_api_pb2.InstructionResponse, Dict[str, execution.PartitionableBuffer], OutputTimerData, @@ -1052,7 +1039,7 @@ def _run_bundle( # - timers # - SDK-initiated deferred applications of root elements # - Runner-initiated deferred applications of root elements - deferred_inputs = {} # type: Dict[str, execution.PartitionableBuffer] + deferred_inputs: Dict[str, execution.PartitionableBuffer] = {} watermarks_by_transform_and_timer_family, newly_set_timers = ( self._collect_written_timers(bundle_context_manager)) @@ -1085,48 +1072,42 @@ def _run_bundle( return result, deferred_inputs, newly_set_timers, watermark_updates @staticmethod - def get_cache_token_generator(static=True): - # type: (bool) -> Iterator[beam_fn_api_pb2.ProcessBundleRequest.CacheToken] - + def get_cache_token_generator( + static: bool = True + ) -> Iterator[beam_fn_api_pb2.ProcessBundleRequest.CacheToken]: """A generator for cache tokens. 
:arg static If True, generator always returns the same cache token If False, generator returns a new cache token each time :return A generator which returns a cache token on next(generator) """ - def generate_token(identifier): - # type: (int) -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken + def generate_token( + identifier: int) -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken: return beam_fn_api_pb2.ProcessBundleRequest.CacheToken( user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.UserState( ), token="cache_token_{}".format(identifier).encode("utf-8")) class StaticGenerator(object): - def __init__(self): - # type: () -> None + def __init__(self) -> None: self._token = generate_token(1) - def __iter__(self): - # type: () -> StaticGenerator + def __iter__(self) -> StaticGenerator: # pylint: disable=non-iterator-returned return self - def __next__(self): - # type: () -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken + def __next__(self) -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken: return self._token class DynamicGenerator(object): - def __init__(self): - # type: () -> None + def __init__(self) -> None: self._counter = 0 self._lock = threading.Lock() - def __iter__(self): - # type: () -> DynamicGenerator + def __iter__(self) -> DynamicGenerator: # pylint: disable=non-iterator-returned return self - def __next__(self): - # type: () -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken + def __next__(self) -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken: with self._lock: self._counter += 1 return generate_token(self._counter) @@ -1138,19 +1119,18 @@ def __next__(self): class ExtendedProvisionInfo(object): - def __init__(self, - provision_info=None, # type: Optional[beam_provision_api_pb2.ProvisionInfo] - artifact_staging_dir=None, # type: Optional[str] - job_name='', # type: str - ): - # type: (...) -> None + def __init__( + self, + provision_info: Optional[beam_provision_api_pb2.ProvisionInfo] = None, + artifact_staging_dir: Optional[str] = None, + job_name: str = '', + ) -> None: self.provision_info = ( provision_info or beam_provision_api_pb2.ProvisionInfo()) self.artifact_staging_dir = artifact_staging_dir self.job_name = job_name - def for_environment(self, env): - # type: (...) -> ExtendedProvisionInfo + def for_environment(self, env) -> ExtendedProvisionInfo: if env.dependencies: provision_info_with_deps = copy.deepcopy(self.provision_info) provision_info_with_deps.dependencies.extend(env.dependencies) @@ -1218,31 +1198,27 @@ class BundleManager(object): _uid_counter = 0 _lock = threading.Lock() - def __init__(self, - bundle_context_manager, # type: execution.BundleContextManager - progress_frequency=None, # type: Optional[float] - cache_token_generator=FnApiRunner.get_cache_token_generator(), - split_managers=() - ): - # type: (...) -> None - + def __init__( + self, + bundle_context_manager: execution.BundleContextManager, + progress_frequency: Optional[float] = None, + cache_token_generator=FnApiRunner.get_cache_token_generator(), + split_managers=() + ) -> None: """Set up a bundle manager. 
Args: progress_frequency """ - self.bundle_context_manager = bundle_context_manager # type: execution.BundleContextManager + self.bundle_context_manager: execution.BundleContextManager = bundle_context_manager self._progress_frequency = progress_frequency - self._worker_handler = None # type: Optional[WorkerHandler] + self._worker_handler: Optional[WorkerHandler] = None self._cache_token_generator = cache_token_generator self.split_managers = split_managers - def _send_input_to_worker(self, - process_bundle_id, # type: str - read_transform_id, # type: str - byte_streams - ): - # type: (...) -> None + def _send_input_to_worker( + self, process_bundle_id: str, read_transform_id: str, + byte_streams) -> None: assert self._worker_handler is not None data_out = self._worker_handler.data_conn.output_stream( process_bundle_id, read_transform_id) @@ -1251,8 +1227,7 @@ def _send_input_to_worker(self, data_out.close() def _send_timers_to_worker( - self, process_bundle_id, transform_id, timer_family_id, timers): - # type: (...) -> None + self, process_bundle_id, transform_id, timer_family_id, timers) -> None: assert self._worker_handler is not None timer_out = self._worker_handler.data_conn.output_timer_stream( process_bundle_id, transform_id, timer_family_id) @@ -1273,13 +1248,12 @@ def _select_split_manager(self) -> Optional[Callable[[int], Iterable[float]]]: return None - def _generate_splits_for_testing(self, - split_manager, - inputs, # type: Mapping[str, execution.PartitionableBuffer] - process_bundle_id - ): - # type: (...) -> List[beam_fn_api_pb2.ProcessBundleSplitResponse] - split_results = [] # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse] + def _generate_splits_for_testing( + self, + split_manager, + inputs: Mapping[str, execution.PartitionableBuffer], + process_bundle_id) -> List[beam_fn_api_pb2.ProcessBundleSplitResponse]: + split_results: List[beam_fn_api_pb2.ProcessBundleSplitResponse] = [] read_transform_id, buffer_data = only_element(inputs.items()) byte_stream = b''.join(buffer_data or []) num_elements = len( @@ -1317,8 +1291,8 @@ def _generate_splits_for_testing(self, estimated_input_elements=num_elements) })) logging.info("Requesting split %s", split_request) - split_response = self._worker_handler.control_conn.push( - split_request).get() # type: beam_fn_api_pb2.InstructionResponse + split_response: beam_fn_api_pb2.InstructionResponse = self._worker_handler.control_conn.push( + split_request).get() for t in (0.05, 0.1, 0.2): if ('Unknown process bundle' in split_response.error or split_response.process_bundle_split == @@ -1343,13 +1317,15 @@ def _generate_splits_for_testing(self, break return split_results - def process_bundle(self, - inputs, # type: Mapping[str, execution.PartitionableBuffer] - expected_outputs, # type: DataOutput - fired_timers, # type: Mapping[translations.TimerFamilyId, execution.PartitionableBuffer] - expected_output_timers: OutputTimers, - dry_run=False, # type: bool - ) -> BundleProcessResult: + def process_bundle( + self, + inputs: Mapping[str, execution.PartitionableBuffer], + expected_outputs: DataOutput, + fired_timers: Mapping[translations.TimerFamilyId, + execution.PartitionableBuffer], + expected_output_timers: OutputTimers, + dry_run: bool = False, + ) -> BundleProcessResult: # Unique id for the instruction processing this bundle. 
with BundleManager._lock: BundleManager._uid_counter += 1 @@ -1383,7 +1359,7 @@ def process_bundle(self, cache_tokens=[next(self._cache_token_generator)])) result_future = self._worker_handler.control_conn.push(process_bundle_req) - split_results = [] # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse] + split_results: List[beam_fn_api_pb2.ProcessBundleSplitResponse] = [] with ProgressRequester(self._worker_handler, process_bundle_id, self._progress_frequency): @@ -1392,8 +1368,9 @@ def process_bundle(self, split_results = self._generate_splits_for_testing( split_manager, inputs, process_bundle_id) - expect_reads = list( - expected_outputs.keys()) # type: List[Union[str, Tuple[str, str]]] + expect_reads: List[Union[str, + Tuple[str, + str]]] = list(expected_outputs.keys()) expect_reads.extend(list(expected_output_timers.keys())) # Gather all output data. @@ -1417,7 +1394,7 @@ def process_bundle(self, expected_outputs[output.transform_id], output.transform_id).append(output.data) - result = result_future.get() # type: beam_fn_api_pb2.InstructionResponse + result: beam_fn_api_pb2.InstructionResponse = result_future.get() if result.error: raise RuntimeError(result.error) @@ -1435,30 +1412,30 @@ def process_bundle(self, class ParallelBundleManager(BundleManager): - def __init__( self, - bundle_context_manager, # type: execution.BundleContextManager - progress_frequency=None, # type: Optional[float] + bundle_context_manager: execution.BundleContextManager, + progress_frequency: Optional[float] = None, cache_token_generator=None, - **kwargs): - # type: (...) -> None + **kwargs) -> None: super().__init__( bundle_context_manager, progress_frequency, cache_token_generator=cache_token_generator) self._num_workers = bundle_context_manager.num_workers - def process_bundle(self, - inputs, # type: Mapping[str, execution.PartitionableBuffer] - expected_outputs, # type: DataOutput - fired_timers, # type: Mapping[translations.TimerFamilyId, execution.PartitionableBuffer] - expected_output_timers, # type: OutputTimers - dry_run=False, # type: bool - ): - # type: (...) -> BundleProcessResult - part_inputs = [{} for _ in range(self._num_workers) - ] # type: List[Dict[str, List[bytes]]] + def process_bundle( + self, + inputs: Mapping[str, execution.PartitionableBuffer], + expected_outputs: DataOutput, + fired_timers: Mapping[translations.TimerFamilyId, + execution.PartitionableBuffer], + expected_output_timers: OutputTimers, + dry_run: bool = False, + ) -> BundleProcessResult: + part_inputs: List[Dict[str, + List[bytes]]] = [{} + for _ in range(self._num_workers)] # Timers are only executed on the first worker # TODO(BEAM-9741): Split timers to multiple workers timer_inputs = [ @@ -1468,12 +1445,10 @@ def process_bundle(self, for ix, part in enumerate(input.partition(self._num_workers)): part_inputs[ix][name] = part - merged_result = None # type: Optional[beam_fn_api_pb2.InstructionResponse] - split_result_list = [ - ] # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse] + merged_result: Optional[beam_fn_api_pb2.InstructionResponse] = None + split_result_list: List[beam_fn_api_pb2.ProcessBundleSplitResponse] = [] - def execute(part_map_input_timers): - # type: (...) -> BundleProcessResult + def execute(part_map_input_timers) -> BundleProcessResult: part_map, input_timers = part_map_input_timers bundle_manager = BundleManager( self.bundle_context_manager, @@ -1509,20 +1484,19 @@ class ProgressRequester(threading.Thread): A callback can be passed to call with progress updates. 
""" - - def __init__(self, - worker_handler, # type: WorkerHandler - instruction_id, - frequency, - callback=None - ): - # type: (...) -> None + def __init__( + self, + worker_handler: WorkerHandler, + instruction_id, + frequency, + callback=None) -> None: super().__init__() self._worker_handler = worker_handler self._instruction_id = instruction_id self._frequency = frequency self._done = False - self._latest_progress = None # type: Optional[beam_fn_api_pb2.ProcessBundleProgressResponse] + self._latest_progress: Optional[ + beam_fn_api_pb2.ProcessBundleProgressResponse] = None self._callback = callback self.daemon = True @@ -1593,8 +1567,7 @@ def query(self, filter=None): self.GAUGES: gauges } - def monitoring_infos(self): - # type: () -> List[metrics_pb2.MonitoringInfo] + def monitoring_infos(self) -> List[metrics_pb2.MonitoringInfo]: return [ item for sublist in self._monitoring_infos.values() for item in sublist ] diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py index 4a35da8dd274..97b10b83e051 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py @@ -2224,7 +2224,7 @@ def __reduce__(self): return _unpickle_element_counter, (name, ) -_pickled_element_counters = {} # type: Dict[str, ElementCounter] +_pickled_element_counters: Dict[str, ElementCounter] = {} def _unpickle_element_counter(name): diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/watermark_manager.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/watermark_manager.py index 6f926a6284e2..106eca108297 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/watermark_manager.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/watermark_manager.py @@ -102,8 +102,7 @@ def input_watermark(self): w = min(w, min(i._produced_watermark for i in self.side_inputs)) return w - def __init__(self, stages): - # type: (List[translations.Stage]) -> None + def __init__(self, stages: List[translations.Stage]) -> None: self._pcollections_by_name: Dict[Union[str, translations.TimerFamilyId], WatermarkManager.PCollectionNode] = {} self._stages_by_name: Dict[str, WatermarkManager.StageNode] = {} @@ -189,12 +188,12 @@ def _verify(self, stages: List[translations.Stage]): 'Stage %s has no main inputs. ' 'At least one main input is necessary.' 
% s.name) - def get_stage_node(self, name): - # type: (str) -> StageNode # noqa: F821 + def get_stage_node(self, name: str) -> StageNode: + # noqa: F821 return self._stages_by_name[name] - def get_pcoll_node(self, name): - # type: (str) -> PCollectionNode # noqa: F821 + def get_pcoll_node(self, name: str) -> PCollectionNode: + # noqa: F821 return self._pcollections_by_name[name] def set_pcoll_watermark(self, name, watermark): diff --git a/sdks/python/apache_beam/runners/portability/job_server.py b/sdks/python/apache_beam/runners/portability/job_server.py index eda8755e18ab..030a3b67df33 100644 --- a/sdks/python/apache_beam/runners/portability/job_server.py +++ b/sdks/python/apache_beam/runners/portability/job_server.py @@ -48,8 +48,7 @@ def __init__(self, endpoint, timeout=None): self._endpoint = endpoint self._timeout = timeout - def start(self): - # type: () -> beam_job_api_pb2_grpc.JobServiceStub + def start(self) -> beam_job_api_pb2_grpc.JobServiceStub: channel = grpc.insecure_channel(self._endpoint) grpc.channel_ready_future(channel).result(timeout=self._timeout) return beam_job_api_pb2_grpc.JobServiceStub(channel) @@ -59,8 +58,7 @@ def stop(self): class EmbeddedJobServer(JobServer): - def start(self): - # type: () -> local_job_service.LocalJobServicer + def start(self) -> local_job_service.LocalJobServicer: return local_job_service.LocalJobServicer() def stop(self): diff --git a/sdks/python/apache_beam/runners/portability/local_job_service.py b/sdks/python/apache_beam/runners/portability/local_job_service.py index 6966e66d2c64..4b6d4718f4dd 100644 --- a/sdks/python/apache_beam/runners/portability/local_job_service.py +++ b/sdks/python/apache_beam/runners/portability/local_job_service.py @@ -87,16 +87,16 @@ def __init__(self, staging_dir=None, beam_job_type=None): self._staging_dir = staging_dir or tempfile.mkdtemp() self._artifact_service = artifact_service.ArtifactStagingService( artifact_service.BeamFilesystemHandler(self._staging_dir).file_writer) - self._artifact_staging_endpoint = None # type: Optional[endpoints_pb2.ApiServiceDescriptor] + self._artifact_staging_endpoint: Optional[ + endpoints_pb2.ApiServiceDescriptor] = None self._beam_job_type = beam_job_type or BeamJob def create_beam_job(self, preparation_id, # stype: str - job_name, # type: str - pipeline, # type: beam_runner_api_pb2.Pipeline - options # type: struct_pb2.Struct - ): - # type: (...) -> BeamJob + job_name: str, + pipeline: beam_runner_api_pb2.Pipeline, + options: struct_pb2.Struct + ) -> BeamJob: self._artifact_service.register_job( staging_token=preparation_id, dependency_sets=_extract_dependency_sets( @@ -181,7 +181,7 @@ class SubprocessSdkWorker(object): """ def __init__( self, - worker_command_line, # type: bytes + worker_command_line: bytes, control_address, provision_info, worker_id=None): @@ -238,20 +238,20 @@ class BeamJob(abstract_job_service.AbstractBeamJob): The current state of the pipeline is available as self.state. 
""" - - def __init__(self, - job_id, # type: str - pipeline, - options, - provision_info, # type: fn_runner.ExtendedProvisionInfo - artifact_staging_endpoint, # type: Optional[endpoints_pb2.ApiServiceDescriptor] - artifact_service, # type: artifact_service.ArtifactStagingService - ): + def __init__( + self, + job_id: str, + pipeline, + options, + provision_info: fn_runner.ExtendedProvisionInfo, + artifact_staging_endpoint: Optional[endpoints_pb2.ApiServiceDescriptor], + artifact_service: artifact_service.ArtifactStagingService, + ): super().__init__(job_id, provision_info.job_name, pipeline, options) self._provision_info = provision_info self._artifact_staging_endpoint = artifact_staging_endpoint self._artifact_service = artifact_service - self._state_queues = [] # type: List[queue.Queue] + self._state_queues: List[queue.Queue] = [] self._log_queues = JobLogQueues() self.daemon = True self.result = None @@ -378,7 +378,7 @@ def Logging(self, log_bundles, context=None): class JobLogQueues(object): def __init__(self): - self._queues = [] # type: List[queue.Queue] + self._queues: List[queue.Queue] = [] self._cache = [] self._cache_size = 10 self._lock = threading.Lock() diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py b/sdks/python/apache_beam/runners/portability/portable_runner.py index ab5ee9fff6f9..fd19c2ba2388 100644 --- a/sdks/python/apache_beam/runners/portability/portable_runner.py +++ b/sdks/python/apache_beam/runners/portability/portable_runner.py @@ -97,9 +97,11 @@ def __init__(self, job_service, options, retain_unknown_options=False): self.artifact_endpoint = options.view_as(PortableOptions).artifact_endpoint self._retain_unknown_options = retain_unknown_options - def submit(self, proto_pipeline): - # type: (beam_runner_api_pb2.Pipeline) -> Tuple[str, Iterator[beam_job_api_pb2.JobStateEvent], Iterator[beam_job_api_pb2.JobMessagesResponse]] - + def submit( + self, proto_pipeline: beam_runner_api_pb2.Pipeline + ) -> Tuple[str, + Iterator[beam_job_api_pb2.JobStateEvent], + Iterator[beam_job_api_pb2.JobMessagesResponse]]: """ Submit and run the pipeline defined by `proto_pipeline`. """ @@ -113,9 +115,7 @@ def submit(self, proto_pipeline): prepare_response.staging_session_token) return self.run(prepare_response.preparation_id) - def get_pipeline_options(self): - # type: () -> struct_pb2.Struct - + def get_pipeline_options(self) -> struct_pb2.Struct: """ Get `self.options` as a protobuf Struct """ @@ -189,9 +189,9 @@ def convert_pipeline_option_value(v): } return job_utils.dict_to_struct(p_options) - def prepare(self, proto_pipeline): - # type: (beam_runner_api_pb2.Pipeline) -> beam_job_api_pb2.PrepareJobResponse - + def prepare( + self, proto_pipeline: beam_runner_api_pb2.Pipeline + ) -> beam_job_api_pb2.PrepareJobResponse: """Prepare the job on the job service""" return self.job_service.Prepare( beam_job_api_pb2.PrepareJobRequest( @@ -200,13 +200,11 @@ def prepare(self, proto_pipeline): pipeline_options=self.get_pipeline_options()), timeout=self.timeout) - def stage(self, - proto_pipeline, # type: beam_runner_api_pb2.Pipeline - artifact_staging_endpoint, - staging_session_token - ): - # type: (...) 
-> None - + def stage( + self, + proto_pipeline: beam_runner_api_pb2.Pipeline, + artifact_staging_endpoint, + staging_session_token) -> None: """Stage artifacts""" if artifact_staging_endpoint: artifact_service.offer_artifacts( @@ -216,9 +214,11 @@ def stage(self, artifact_service.BeamFilesystemHandler(None).file_reader), staging_session_token) - def run(self, preparation_id): - # type: (str) -> Tuple[str, Iterator[beam_job_api_pb2.JobStateEvent], Iterator[beam_job_api_pb2.JobMessagesResponse]] - + def run( + self, preparation_id: str + ) -> Tuple[str, + Iterator[beam_job_api_pb2.JobStateEvent], + Iterator[beam_job_api_pb2.JobMessagesResponse]]: """Run the job""" try: state_stream = self.job_service.GetStateStream( @@ -260,11 +260,10 @@ class PortableRunner(runner.PipelineRunner): running and managing the job lies with the job service used. """ def __init__(self): - self._dockerized_job_server = None # type: Optional[job_server.JobServer] + self._dockerized_job_server: Optional[job_server.JobServer] = None @staticmethod - def _create_environment(options): - # type: (PipelineOptions) -> environments.Environment + def _create_environment(options: PipelineOptions) -> environments.Environment: return environments.Environment.from_options( options.view_as(PortableOptions)) @@ -274,20 +273,17 @@ def default_job_server(self, options): 'Alternatively, you may specify which portable runner you intend to ' 'use, such as --runner=FlinkRunner or --runner=SparkRunner.') - def create_job_service_handle(self, job_service, options): - # type: (...) -> JobServiceHandle + def create_job_service_handle(self, job_service, options) -> JobServiceHandle: return JobServiceHandle(job_service, options) - def create_job_service(self, options): - # type: (PipelineOptions) -> JobServiceHandle - + def create_job_service(self, options: PipelineOptions) -> JobServiceHandle: """ Start the job service and return a `JobServiceHandle` """ job_endpoint = options.view_as(PortableOptions).job_endpoint if job_endpoint: if job_endpoint == 'embed': - server = job_server.EmbeddedJobServer() # type: job_server.JobServer + server: job_server.JobServer = job_server.EmbeddedJobServer() else: job_server_timeout = options.view_as(PortableOptions).job_server_timeout server = job_server.ExternalJobServer(job_endpoint, job_server_timeout) @@ -296,8 +292,9 @@ def create_job_service(self, options): return self.create_job_service_handle(server.start(), options) @staticmethod - def get_proto_pipeline(pipeline, options): - # type: (Pipeline, PipelineOptions) -> beam_runner_api_pb2.Pipeline + def get_proto_pipeline( + pipeline: Pipeline, + options: PipelineOptions) -> beam_runner_api_pb2.Pipeline: proto_pipeline = pipeline.to_runner_api( default_environment=environments.Environment.from_options( options.view_as(PortableOptions))) @@ -473,8 +470,7 @@ def __init__( self._metrics = None self._runtime_exception = None - def cancel(self): - # type: () -> None + def cancel(self) -> None: try: self._job_service.Cancel( beam_job_api_pb2.CancelJobRequest(job_id=self._job_id)) @@ -513,8 +509,7 @@ def metrics(self): self._metrics = PortableMetrics(job_metrics_response) return self._metrics - def _last_error_message(self): - # type: () -> str + def _last_error_message(self) -> str: # Filter only messages with the "message_response" and error messages. messages = [ m.message_response for m in self._messages @@ -535,8 +530,7 @@ def wait_until_finish(self, duration=None): the execution. If None or zero, will wait until the pipeline finishes. 
:return: The result of the pipeline, i.e. PipelineResult. """ - def read_messages(): - # type: () -> None + def read_messages() -> None: previous_state = -1 for message in self._message_stream: if message.HasField('message_response'): @@ -595,8 +589,7 @@ def _observe_state(self, message_thread): finally: self._cleanup() - def _cleanup(self, on_exit=False): - # type: (bool) -> None + def _cleanup(self, on_exit: bool = False) -> None: if on_exit and self._cleanup_callbacks: _LOGGER.info( 'Running cleanup on exit. If your pipeline should continue running, ' diff --git a/sdks/python/apache_beam/runners/portability/stager_test.py b/sdks/python/apache_beam/runners/portability/stager_test.py index 25fd62b16533..5535989a5786 100644 --- a/sdks/python/apache_beam/runners/portability/stager_test.py +++ b/sdks/python/apache_beam/runners/portability/stager_test.py @@ -75,7 +75,7 @@ def create_temp_file(self, path, contents): def is_remote_path(self, path): return path.startswith('/tmp/remote/') - remote_copied_files = [] # type: List[str] + remote_copied_files: List[str] = [] def file_copy(self, from_path, to_path): if self.is_remote_path(from_path): From 8fdbe88dc6e1e9dd71b8d4dada1a5c871567d5f7 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:48:18 -0700 Subject: [PATCH 19/29] Modernize python type hints for apache_beam/runners/worker --- .../apache_beam/runners/worker/log_handler.py | 30 +++----- .../apache_beam/runners/worker/logger.py | 29 +++---- .../apache_beam/runners/worker/statecache.py | 76 ++++++------------- .../runners/worker/statesampler.py | 50 +++++------- .../runners/worker/statesampler_slow.py | 51 +++++-------- .../runners/worker/worker_id_interceptor.py | 3 +- .../runners/worker/worker_pool_main.py | 37 ++++----- .../runners/worker/worker_status.py | 4 +- 8 files changed, 107 insertions(+), 173 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/log_handler.py b/sdks/python/apache_beam/runners/worker/log_handler.py index 6a2f612fbee1..b7cf9db757d3 100644 --- a/sdks/python/apache_beam/runners/worker/log_handler.py +++ b/sdks/python/apache_beam/runners/worker/log_handler.py @@ -81,15 +81,15 @@ class FnApiLogRecordHandler(logging.Handler): # dropped. If the average log size is 1KB this may use up to 10MB of memory. 
_QUEUE_SIZE = 10000 - def __init__(self, log_service_descriptor): - # type: (endpoints_pb2.ApiServiceDescriptor) -> None + def __init__( + self, log_service_descriptor: endpoints_pb2.ApiServiceDescriptor) -> None: super().__init__() self._alive = True self._dropped_logs = 0 - self._log_entry_queue = queue.Queue( - maxsize=self._QUEUE_SIZE - ) # type: queue.Queue[Union[beam_fn_api_pb2.LogEntry, Sentinel]] + self._log_entry_queue: queue.Queue[Union[beam_fn_api_pb2.LogEntry, + Sentinel]] = queue.Queue( + maxsize=self._QUEUE_SIZE) ch = GRPCChannelFactory.insecure_channel(log_service_descriptor.url) # Make sure the channel is ready to avoid [BEAM-4649] @@ -101,16 +101,15 @@ def __init__(self, log_service_descriptor): self._reader.daemon = True self._reader.start() - def connect(self): - # type: () -> Iterable + def connect(self) -> Iterable: if hasattr(self, '_logging_stub'): del self._logging_stub # type: ignore[has-type] self._logging_stub = beam_fn_api_pb2_grpc.BeamFnLoggingStub( self._log_channel) return self._logging_stub.Logging(self._write_log_entries()) - def map_log_level(self, level): - # type: (int) -> beam_fn_api_pb2.LogEntry.Severity.Enum.ValueType + def map_log_level( + self, level: int) -> beam_fn_api_pb2.LogEntry.Severity.Enum.ValueType: try: return LOG_LEVEL_TO_LOGENTRY_MAP[level] except KeyError: @@ -119,8 +118,7 @@ def map_log_level(self, level): beam_level in LOG_LEVEL_TO_LOGENTRY_MAP.items() if python_level <= level) - def emit(self, record): - # type: (logging.LogRecord) -> None + def emit(self, record: logging.LogRecord) -> None: log_entry = beam_fn_api_pb2.LogEntry() log_entry.severity = self.map_log_level(record.levelno) try: @@ -154,9 +152,7 @@ def emit(self, record): except queue.Full: self._dropped_logs += 1 - def close(self): - # type: () -> None - + def close(self) -> None: """Flush out all existing log entries and unregister this handler.""" try: self._alive = False @@ -175,8 +171,7 @@ def close(self): # prematurely. logging.error("Error closing the logging channel.", exc_info=True) - def _write_log_entries(self): - # type: () -> Iterator[beam_fn_api_pb2.LogEntry.List] + def _write_log_entries(self) -> Iterator[beam_fn_api_pb2.LogEntry.List]: done = False while not done: log_entries = [self._log_entry_queue.get()] @@ -194,8 +189,7 @@ def _write_log_entries(self): yield beam_fn_api_pb2.LogEntry.List( log_entries=cast(List[beam_fn_api_pb2.LogEntry], log_entries)) - def _read_log_control_messages(self): - # type: () -> None + def _read_log_control_messages(self) -> None: # Only reconnect when we are alive. # We can drop some logs in the unlikely event of logging connection # dropped(not closed) during termination when we still have logs to be sent. diff --git a/sdks/python/apache_beam/runners/worker/logger.py b/sdks/python/apache_beam/runners/worker/logger.py index e1c84bc6ded2..1efebeb3c78c 100644 --- a/sdks/python/apache_beam/runners/worker/logger.py +++ b/sdks/python/apache_beam/runners/worker/logger.py @@ -39,15 +39,13 @@ # context information that changes while work items get executed: # work_item_id, step_name, stage_name. class _PerThreadWorkerData(threading.local): - def __init__(self): - # type: () -> None + def __init__(self) -> None: super().__init__() # in the list, as going up and down all the way to zero incurs several # reallocations. 
- self.stack = [] # type: List[Dict[str, Any]] + self.stack: List[Dict[str, Any]] = [] - def get_data(self): - # type: () -> Dict[str, Any] + def get_data(self) -> Dict[str, Any]: all_data = {} for datum in self.stack: all_data.update(datum) @@ -58,9 +56,7 @@ def get_data(self): @contextlib.contextmanager -def PerThreadLoggingContext(**kwargs): - # type: (**Any) -> Iterator[None] - +def PerThreadLoggingContext(**kwargs: Any) -> Iterator[None]: """A context manager to add per thread attributes.""" stack = per_thread_worker_data.stack stack.append(kwargs) @@ -72,15 +68,12 @@ def PerThreadLoggingContext(**kwargs): class JsonLogFormatter(logging.Formatter): """A JSON formatter class as expected by the logging standard module.""" - def __init__(self, job_id, worker_id): - # type: (str, str) -> None + def __init__(self, job_id: str, worker_id: str) -> None: super().__init__() self.job_id = job_id self.worker_id = worker_id - def format(self, record): - # type: (logging.LogRecord) -> str - + def format(self, record: logging.LogRecord) -> str: """Returns a JSON string based on a LogRecord instance. Args: @@ -115,7 +108,7 @@ def format(self, record): Python thread object. Nevertheless having this value can allow to filter log statement from only one specific thread. """ - output = {} # type: Dict[str, Any] + output: Dict[str, Any] = {} output['timestamp'] = { 'seconds': int(record.created), 'nanos': int(record.msecs * 1000000) } @@ -170,9 +163,11 @@ def format(self, record): return json.dumps(output) -def initialize(job_id, worker_id, log_path, log_level=logging.INFO): - # type: (str, str, str, int) -> None - +def initialize( + job_id: str, + worker_id: str, + log_path: str, + log_level: int = logging.INFO) -> None: """Initialize root logger so that we log JSON to a file and text to stdout.""" file_handler = logging.FileHandler(log_path) diff --git a/sdks/python/apache_beam/runners/worker/statecache.py b/sdks/python/apache_beam/runners/worker/statecache.py index dde4243057dd..d4e61cc9297f 100644 --- a/sdks/python/apache_beam/runners/worker/statecache.py +++ b/sdks/python/apache_beam/runners/worker/statecache.py @@ -58,39 +58,31 @@ class WeightedValue(object): :arg weight The associated weight of the value. If unspecified, the objects size will be used. """ - def __init__(self, value, weight): - # type: (Any, int) -> None + def __init__(self, value: Any, weight: int) -> None: self._value = value if weight <= 0: raise ValueError( 'Expected weight to be > 0 for %s but received %d' % (value, weight)) self._weight = weight - def weight(self): - # type: () -> int + def weight(self) -> int: return self._weight - def value(self): - # type: () -> Any + def value(self) -> Any: return self._value class CacheAware(object): """Allows cache users to override what objects are measured.""" - def __init__(self): - # type: () -> None + def __init__(self) -> None: pass - def get_referents_for_cache(self): - # type: () -> List[Any] - + def get_referents_for_cache(self) -> List[Any]: """Returns the list of objects accounted during cache measurement.""" raise NotImplementedError() -def _safe_isinstance(obj, type): - # type: (Any, Union[type, Tuple[type, ...]]) -> bool - +def _safe_isinstance(obj: Any, type: Union[type, Tuple[type, ...]]) -> bool: """ Return whether an object is an instance of a class or of a subclass thereof. See `isinstance()` for more information. 
@@ -106,9 +98,7 @@ def _safe_isinstance(obj, type): return False -def _size_func(obj): - # type: (Any) -> int - +def _size_func(obj: Any) -> int: """ Returns the size of the object or a default size if an error occurred during sizing. @@ -136,9 +126,7 @@ def _size_func(obj): _size_func.last_log_time = 0 # type: ignore -def _get_referents_func(*objs): - # type: (List[Any]) -> List[Any] - +def _get_referents_func(*objs: List[Any]) -> List[Any]: """Returns the list of objects accounted during cache measurement. Users can inherit CacheAware to override which referents should be @@ -154,9 +142,7 @@ def _get_referents_func(*objs): return rval -def _filter_func(o): - # type: (Any) -> bool - +def _filter_func(o: Any) -> bool: """ Filter out specific types from being measured. @@ -171,9 +157,7 @@ def _filter_func(o): return not _safe_isinstance(o, _TYPES_TO_NOT_MEASURE) -def get_deep_size(*objs): - # type: (Any) -> int - +def get_deep_size(*objs: Any) -> int: """Calculates the deep size of all the arguments in bytes.""" return objsize.get_deep_size( *objs, @@ -184,13 +168,11 @@ def get_deep_size(*objs): class _LoadingValue(WeightedValue): """Allows concurrent users of the cache to wait for a value to be loaded.""" - def __init__(self): - # type: () -> None + def __init__(self) -> None: super().__init__(None, 1) self._wait_event = threading.Event() - def load(self, key, loading_fn): - # type: (Any, Callable[[Any], Any]) -> None + def load(self, key: Any, loading_fn: Callable[[Any], Any]) -> None: try: self._value = loading_fn(key) except Exception as err: @@ -198,8 +180,7 @@ def load(self, key, loading_fn): finally: self._wait_event.set() - def value(self): - # type: () -> Any + def value(self) -> Any: self._wait_event.wait() err = getattr(self, "_error", None) if err: @@ -229,13 +210,12 @@ class StateCache(object): :arg max_weight The maximum weight of entries to store in the cache in bytes. 
""" - def __init__(self, max_weight): - # type: (int) -> None + def __init__(self, max_weight: int) -> None: _LOGGER.info('Creating state cache with size %s', max_weight) self._max_weight = max_weight self._current_weight = 0 - self._cache = collections.OrderedDict( - ) # type: collections.OrderedDict[Any, WeightedValue] + self._cache: collections.OrderedDict[ + Any, WeightedValue] = collections.OrderedDict() self._hit_count = 0 self._miss_count = 0 self._evict_count = 0 @@ -243,8 +223,7 @@ def __init__(self, max_weight): self._load_count = 0 self._lock = threading.RLock() - def peek(self, key): - # type: (Any) -> Any + def peek(self, key: Any) -> Any: assert self.is_cache_enabled() with self._lock: value = self._cache.get(key, None) @@ -256,8 +235,7 @@ def peek(self, key): self._hit_count += 1 return value.value() - def get(self, key, loading_fn): - # type: (Any, Callable[[Any], Any]) -> Any + def get(self, key: Any, loading_fn: Callable[[Any], Any]) -> Any: assert self.is_cache_enabled() and callable(loading_fn) self._lock.acquire() @@ -333,8 +311,7 @@ def get(self, key, loading_fn): return value.value() - def put(self, key, value): - # type: (Any, Any) -> None + def put(self, key: Any, value: Any) -> None: assert self.is_cache_enabled() if not _safe_isinstance(value, WeightedValue): weight = get_deep_size(value) @@ -356,22 +333,19 @@ def put(self, key, value): self._current_weight -= weighted_value.weight() self._evict_count += 1 - def invalidate(self, key): - # type: (Any) -> None + def invalidate(self, key: Any) -> None: assert self.is_cache_enabled() with self._lock: weighted_value = self._cache.pop(key, None) if weighted_value is not None: self._current_weight -= weighted_value.weight() - def invalidate_all(self): - # type: () -> None + def invalidate_all(self) -> None: with self._lock: self._cache.clear() self._current_weight = 0 - def describe_stats(self): - # type: () -> str + def describe_stats(self) -> str: with self._lock: request_count = self._hit_count + self._miss_count if request_count > 0: @@ -390,11 +364,9 @@ def describe_stats(self): self._load_count, self._evict_count) - def is_cache_enabled(self): - # type: () -> bool + def is_cache_enabled(self) -> bool: return self._max_weight > 0 - def size(self): - # type: () -> int + def size(self) -> int: with self._lock: return len(self._cache) diff --git a/sdks/python/apache_beam/runners/worker/statesampler.py b/sdks/python/apache_beam/runners/worker/statesampler.py index 4dc7e97c140d..ece31c05517a 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler.py +++ b/sdks/python/apache_beam/runners/worker/statesampler.py @@ -89,46 +89,39 @@ def for_test(): class StateSampler(statesampler_impl.StateSampler): - - def __init__(self, - prefix, # type: str - counter_factory, - sampling_period_ms=DEFAULT_SAMPLING_PERIOD_MS): + def __init__( + self, + prefix: str, + counter_factory, + sampling_period_ms=DEFAULT_SAMPLING_PERIOD_MS): self._prefix = prefix self._counter_factory = counter_factory - self._states_by_name = { - } # type: Dict[CounterName, statesampler_impl.ScopedState] + self._states_by_name: Dict[CounterName, statesampler_impl.ScopedState] = {} self.sampling_period_ms = sampling_period_ms - self.tracked_thread = None # type: Optional[threading.Thread] + self.tracked_thread: Optional[threading.Thread] = None self.finished = False self.started = False super().__init__(sampling_period_ms) @property - def stage_name(self): - # type: () -> str + def stage_name(self) -> str: return self._prefix - def stop(self): - # type: () 
-> None + def stop(self) -> None: set_current_tracker(None) super().stop() - def stop_if_still_running(self): - # type: () -> None + def stop_if_still_running(self) -> None: if self.started and not self.finished: self.stop() - def start(self): - # type: () -> None + def start(self) -> None: self.tracked_thread = threading.current_thread() set_current_tracker(self) super().start() self.started = True - def get_info(self): - # type: () -> StateSamplerInfo - + def get_info(self) -> StateSamplerInfo: """Returns StateSamplerInfo with transition statistics.""" return StateSamplerInfo( self.current_state().name, @@ -136,14 +129,13 @@ def get_info(self): self.time_since_transition, self.tracked_thread) - def scoped_state(self, - name_context, # type: Union[str, common.NameContext] - state_name, # type: str - io_target=None, - metrics_container=None # type: Optional[MetricsContainer] - ): - # type: (...) -> statesampler_impl.ScopedState - + def scoped_state( + self, + name_context: Union[str, common.NameContext], + state_name: str, + io_target=None, + metrics_container: Optional[MetricsContainer] = None + ) -> statesampler_impl.ScopedState: """Returns a ScopedState object associated to a Step and a State. Args: @@ -173,9 +165,7 @@ def scoped_state(self, counter_name, name_context, output_counter, metrics_container) return self._states_by_name[counter_name] - def commit_counters(self): - # type: () -> None - + def commit_counters(self) -> None: """Updates output counters with latest state statistics.""" for state in self._states_by_name.values(): state_msecs = int(1e-6 * state.nsecs) diff --git a/sdks/python/apache_beam/runners/worker/statesampler_slow.py b/sdks/python/apache_beam/runners/worker/statesampler_slow.py index 4279b4f8d7f3..33bce92d4391 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_slow.py +++ b/sdks/python/apache_beam/runners/worker/statesampler_slow.py @@ -31,21 +31,19 @@ def __init__(self, sampling_period_ms): self.state_transition_count = 0 self.time_since_transition = 0 - def current_state(self): - # type: () -> ScopedState - + def current_state(self) -> ScopedState: """Returns the current execution state. This operation is not thread safe, and should only be called from the execution thread.""" return self._state_stack[-1] - def _scoped_state(self, - counter_name, # type: counters.CounterName - name_context, # type: common.NameContext - output_counter, - metrics_container=None): - # type: (...) -> ScopedState + def _scoped_state( + self, + counter_name: counters.CounterName, + name_context: common.NameContext, + output_counter, + metrics_container=None) -> ScopedState: assert isinstance(name_context, common.NameContext) return ScopedState( self, counter_name, name_context, output_counter, metrics_container) @@ -55,38 +53,33 @@ def update_metric(self, typed_metric_name, value): if metrics_container is not None: metrics_container.get_metric_cell(typed_metric_name).update(value) - def _enter_state(self, state): - # type: (ScopedState) -> None + def _enter_state(self, state: ScopedState) -> None: self.state_transition_count += 1 self._state_stack.append(state) - def _exit_state(self): - # type: () -> None + def _exit_state(self) -> None: self.state_transition_count += 1 self._state_stack.pop() - def start(self): - # type: () -> None + def start(self) -> None: # Sampling not yet supported. Only state tracking at the moment. 
pass - def stop(self): - # type: () -> None + def stop(self) -> None: pass - def reset(self): - # type: () -> None + def reset(self) -> None: pass class ScopedState(object): - - def __init__(self, - sampler, # type: StateSampler - name, # type: counters.CounterName - step_name_context, # type: Optional[common.NameContext] - counter=None, - metrics_container=None): + def __init__( + self, + sampler: StateSampler, + name: counters.CounterName, + step_name_context: Optional[common.NameContext], + counter=None, + metrics_container=None): self.state_sampler = sampler self.name = name self.name_context = step_name_context @@ -94,12 +87,10 @@ def __init__(self, self.nsecs = 0 self.metrics_container = metrics_container - def sampled_seconds(self): - # type: () -> float + def sampled_seconds(self) -> float: return 1e-9 * self.nsecs - def sampled_msecs_int(self): - # type: () -> int + def sampled_msecs_int(self) -> int: return int(1e-6 * self.nsecs) def __repr__(self): diff --git a/sdks/python/apache_beam/runners/worker/worker_id_interceptor.py b/sdks/python/apache_beam/runners/worker/worker_id_interceptor.py index b913f2c63b63..1db2b5f4a151 100644 --- a/sdks/python/apache_beam/runners/worker/worker_id_interceptor.py +++ b/sdks/python/apache_beam/runners/worker/worker_id_interceptor.py @@ -40,8 +40,7 @@ class WorkerIdInterceptor(grpc.UnaryUnaryClientInterceptor, # Unique worker Id for this worker. _worker_id = os.environ.get('WORKER_ID') - def __init__(self, worker_id=None): - # type: (Optional[str]) -> None + def __init__(self, worker_id: Optional[str] = None) -> None: if worker_id: self._worker_id = worker_id diff --git a/sdks/python/apache_beam/runners/worker/worker_pool_main.py b/sdks/python/apache_beam/runners/worker/worker_pool_main.py index 7e81b1fa6d72..307261c2d3c3 100644 --- a/sdks/python/apache_beam/runners/worker/worker_pool_main.py +++ b/sdks/python/apache_beam/runners/worker/worker_pool_main.py @@ -73,18 +73,17 @@ def _kill(): class BeamFnExternalWorkerPoolServicer( beam_fn_api_pb2_grpc.BeamFnExternalWorkerPoolServicer): - - def __init__(self, - use_process=False, - container_executable=None, # type: Optional[str] - state_cache_size=0, - data_buffer_time_limit_ms=0 - ): + def __init__( + self, + use_process=False, + container_executable: Optional[str] = None, + state_cache_size=0, + data_buffer_time_limit_ms=0): self._use_process = use_process self._container_executable = container_executable self._state_cache_size = state_cache_size self._data_buffer_time_limit_ms = data_buffer_time_limit_ms - self._worker_processes = {} # type: Dict[str, subprocess.Popen] + self._worker_processes: Dict[str, subprocess.Popen] = {} @classmethod def start( @@ -93,9 +92,7 @@ def start( port=0, state_cache_size=0, data_buffer_time_limit_ms=-1, - container_executable=None # type: Optional[str] - ): - # type: (...) -> Tuple[str, grpc.Server] + container_executable: Optional[str] = None) -> Tuple[str, grpc.Server]: options = [("grpc.http2.max_pings_without_data", 0), ("grpc.http2.max_ping_strikes", 0)] worker_server = grpc.server( @@ -121,11 +118,10 @@ def kill_worker_processes(): return worker_address, worker_server - def StartWorker(self, - start_worker_request, # type: beam_fn_api_pb2.StartWorkerRequest - unused_context - ): - # type: (...) 
-> beam_fn_api_pb2.StartWorkerResponse + def StartWorker( + self, + start_worker_request: beam_fn_api_pb2.StartWorkerRequest, + unused_context) -> beam_fn_api_pb2.StartWorkerResponse: try: if self._use_process: command = [ @@ -182,11 +178,10 @@ def StartWorker(self, except Exception: return beam_fn_api_pb2.StartWorkerResponse(error=traceback.format_exc()) - def StopWorker(self, - stop_worker_request, # type: beam_fn_api_pb2.StopWorkerRequest - unused_context - ): - # type: (...) -> beam_fn_api_pb2.StopWorkerResponse + def StopWorker( + self, + stop_worker_request: beam_fn_api_pb2.StopWorkerRequest, + unused_context) -> beam_fn_api_pb2.StopWorkerResponse: # applicable for process mode to ensure process cleanup # thread based workers terminate automatically worker_process = self._worker_processes.pop( diff --git a/sdks/python/apache_beam/runners/worker/worker_status.py b/sdks/python/apache_beam/runners/worker/worker_status.py index a7f4890344a8..2271b4495d79 100644 --- a/sdks/python/apache_beam/runners/worker/worker_status.py +++ b/sdks/python/apache_beam/runners/worker/worker_status.py @@ -96,9 +96,7 @@ def heap_dump(): return banner + heap + ending -def _state_cache_stats(state_cache): - # type: (StateCache) -> str - +def _state_cache_stats(state_cache: StateCache) -> str: """Gather state cache statistics.""" cache_stats = ['=' * 10 + ' CACHE STATS ' + '=' * 10] if not state_cache.is_cache_enabled(): From 77d818951094adce95954764037625a4454c42cd Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:48:22 -0700 Subject: [PATCH 20/29] Modernize python type hints for apache_beam/testing/benchmarks --- .../testing/benchmarks/nexmark/monitor.py | 3 +-- .../benchmarks/nexmark/nexmark_launcher.py | 3 +-- .../testing/benchmarks/nexmark/nexmark_perf.py | 4 +--- .../testing/benchmarks/nexmark/nexmark_util.py | 18 +++++++----------- 4 files changed, 10 insertions(+), 18 deletions(-) diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/monitor.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/monitor.py index 9d363bfeec61..064fbb11da5d 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/monitor.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/monitor.py @@ -32,8 +32,7 @@ class Monitor(object): name_prefix: a prefix for this Monitor's metrics' names, intended to be unique in per-monitor basis in pipeline """ - def __init__(self, namespace, name_prefix): - # type: (str, str) -> None + def __init__(self, namespace: str, name_prefix: str) -> None: self.namespace = namespace self.name_prefix = name_prefix self.doFn = MonitorDoFn(namespace, name_prefix) diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_launcher.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_launcher.py index bdf6f476212d..ec686543c3f9 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_launcher.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_launcher.py @@ -381,8 +381,7 @@ def monitor(self, job, event_monitor, result_monitor): return perf @staticmethod - def log_performance(perf): - # type: (NexmarkPerf) -> None + def log_performance(perf: NexmarkPerf) -> None: logging.info( 'input event count: %d, output event count: %d' % (perf.event_count, perf.result_count)) diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_perf.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_perf.py index e691b312e201..a783e58aacf7 100644 --- 
a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_perf.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_perf.py @@ -36,9 +36,7 @@ def __init__( # number of result produced self.result_count = result_count if result_count else -1 - def has_progress(self, previous_perf): - # type: (NexmarkPerf) -> bool - + def has_progress(self, previous_perf: NexmarkPerf) -> bool: """ Args: previous_perf: a NexmarkPerf object to be compared to self diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_util.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_util.py index 570fcb1e1ec0..ef53156d8be0 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_util.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_util.py @@ -219,15 +219,13 @@ def unnest_to_json(cand): return cand -def millis_to_timestamp(millis): - # type: (int) -> Timestamp +def millis_to_timestamp(millis: int) -> Timestamp: micro_second = millis * 1000 return Timestamp(micros=micro_second) -def get_counter_metric(result, namespace, name): - # type: (PipelineResult, str, str) -> int - +def get_counter_metric( + result: PipelineResult, namespace: str, name: str) -> int: """ get specific counter metric from pipeline result @@ -249,9 +247,8 @@ def get_counter_metric(result, namespace, name): return counters[0].result if len(counters) > 0 else -1 -def get_start_time_metric(result, namespace, name): - # type: (PipelineResult, str, str) -> int - +def get_start_time_metric( + result: PipelineResult, namespace: str, name: str) -> int: """ get the start time out of all times recorded by the specified distribution metric @@ -271,9 +268,8 @@ def get_start_time_metric(result, namespace, name): return min(min_list) if len(min_list) > 0 else -1 -def get_end_time_metric(result, namespace, name): - # type: (PipelineResult, str, str) -> int - +def get_end_time_metric( + result: PipelineResult, namespace: str, name: str) -> int: """ get the end time out of all times recorded by the specified distribution metric From acfd72c7c066d8bee257c096fc5d9c30cdec06f3 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:48:23 -0700 Subject: [PATCH 21/29] Modernize python type hints for apache_beam/testing/load_tests --- .../load_tests/load_test_metrics_utils.py | 32 ++++++++----------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py b/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py index d1da4667dcb8..e71c4a471923 100644 --- a/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py +++ b/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py @@ -199,7 +199,7 @@ def __init__( bq_table=None, bq_dataset=None, publish_to_bq=False, - influxdb_options=None, # type: Optional[InfluxDBMetricsPublisherOptions] + influxdb_options: Optional[InfluxDBMetricsPublisherOptions] = None, namespace=None, filters=None): """Initializes :class:`MetricsReader` . 
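A small sketch of constructing the InfluxDBMetricsPublisherOptions referenced by the influxdb_options annotation above (its definition follows in the next hunk); the hostname and database values are placeholders.

    from apache_beam.testing.load_tests.load_test_metrics_utils import (
        InfluxDBMetricsPublisherOptions)

    options = InfluxDBMetricsPublisherOptions(
        measurement='python_load_test',
        db_name='beam_test_metrics',
        hostname='http://localhost:8086',
        user=None,
        password=None)
    assert options.validate()               # measurement and db_name are set
    assert not options.http_auth_enabled()  # no credentials supplied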
@@ -524,37 +524,31 @@ def save(self, results): class InfluxDBMetricsPublisherOptions(object): def __init__( self, - measurement, # type: str - db_name, # type: str - hostname, # type: str - user=None, # type: Optional[str] - password=None # type: Optional[str] - ): + measurement: str, + db_name: str, + hostname: str, + user: Optional[str] = None, + password: Optional[str] = None): self.measurement = measurement self.db_name = db_name self.hostname = hostname self.user = user self.password = password - def validate(self): - # type: () -> bool + def validate(self) -> bool: return bool(self.measurement) and bool(self.db_name) - def http_auth_enabled(self): - # type: () -> bool + def http_auth_enabled(self) -> bool: return self.user is not None and self.password is not None class InfluxDBMetricsPublisher(MetricsPublisher): """Publishes collected metrics to InfluxDB database.""" - def __init__( - self, - options # type: InfluxDBMetricsPublisherOptions - ): + def __init__(self, options: InfluxDBMetricsPublisherOptions): self.options = options - def publish(self, results): - # type: (List[Mapping[str, Union[float, str, int]]]) -> None + def publish( + self, results: List[Mapping[str, Union[float, str, int]]]) -> None: url = '{}/write'.format(self.options.hostname) payload = self._build_payload(results) query_str = {'db': self.options.db_name, 'precision': 's'} @@ -575,8 +569,8 @@ def publish(self, results): 'with an error message: %s' % (response.status_code, content['error'])) - def _build_payload(self, results): - # type: (List[Mapping[str, Union[float, str, int]]]) -> str + def _build_payload( + self, results: List[Mapping[str, Union[float, str, int]]]) -> str: def build_kv(mapping, key): return '{}={}'.format(key, mapping[key]) From 4402f2d44e9f11c18abff046986e92c529173d27 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:48:28 -0700 Subject: [PATCH 22/29] Modernize python type hints for apache_beam/transforms --- .../combinefn_lifecycle_pipeline.py | 2 +- .../apache_beam/transforms/combiners.py | 2 +- .../apache_beam/transforms/external_java.py | 4 +- .../apache_beam/transforms/resources.py | 32 ++-- .../apache_beam/transforms/sideinputs.py | 28 ++- sdks/python/apache_beam/transforms/trigger.py | 7 +- .../apache_beam/transforms/userstate.py | 143 +++++++--------- .../apache_beam/transforms/userstate_test.py | 2 +- sdks/python/apache_beam/transforms/window.py | 162 +++++++----------- 9 files changed, 153 insertions(+), 229 deletions(-) diff --git a/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py b/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py index 51f66b3c1bb0..33cc3db34811 100644 --- a/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py +++ b/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py @@ -35,7 +35,7 @@ @with_input_types(int) @with_output_types(int) class CallSequenceEnforcingCombineFn(beam.CombineFn): - instances = set() # type: Set[CallSequenceEnforcingCombineFn] + instances: Set[CallSequenceEnforcingCombineFn] = set() def __init__(self): super().__init__() diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py index 53d5190bf625..8b05e8da1df5 100644 --- a/sdks/python/apache_beam/transforms/combiners.py +++ b/sdks/python/apache_beam/transforms/combiners.py @@ -380,7 +380,7 @@ def push(hp, e): return False if self._compare or self._key: - heapc = [] # type: List[cy_combiners.ComparableValue] + heapc: List[cy_combiners.ComparableValue] = 
[] for bundle in bundles: if not heapc: heapc = [ diff --git a/sdks/python/apache_beam/transforms/external_java.py b/sdks/python/apache_beam/transforms/external_java.py index 534b2622c8a0..85eeff977609 100644 --- a/sdks/python/apache_beam/transforms/external_java.py +++ b/sdks/python/apache_beam/transforms/external_java.py @@ -46,8 +46,8 @@ class JavaExternalTransformTest(object): # This will be overwritten if set via a flag. - expansion_service_jar = None # type: str - expansion_service_port = None # type: int + expansion_service_jar: str = None + expansion_service_port: int = None class _RunWithExpansion(object): def __init__(self): diff --git a/sdks/python/apache_beam/transforms/resources.py b/sdks/python/apache_beam/transforms/resources.py index 7c4160df8edd..bc15271aadd0 100644 --- a/sdks/python/apache_beam/transforms/resources.py +++ b/sdks/python/apache_beam/transforms/resources.py @@ -52,13 +52,13 @@ class ResourceHint: """A superclass to define resource hints.""" # A unique URN, one per Resource Hint class. - urn = None # type: Optional[str] + urn: Optional[str] = None - _urn_to_known_hints = {} # type: Dict[str, type] - _name_to_known_hints = {} # type: Dict[str, type] + _urn_to_known_hints: Dict[str, type] = {} + _name_to_known_hints: Dict[str, type] = {} @classmethod - def parse(cls, value): # type: (str) -> Dict[str, bytes] + def parse(cls, value: str) -> Dict[str, bytes]: """Describes how to parse the hint. Override to specify a custom parsing logic.""" assert cls.urn is not None @@ -66,8 +66,7 @@ def parse(cls, value): # type: (str) -> Dict[str, bytes] return {cls.urn: ResourceHint._parse_str(value)} @classmethod - def get_merged_value( - cls, outer_value, inner_value): # type: (bytes, bytes) -> bytes + def get_merged_value(cls, outer_value: bytes, inner_value: bytes) -> bytes: """Reconciles values of a hint when the hint specified on a transform is also defined in an outer context, for example on a composite transform, or specified in the transform's execution environment. 
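A brief sketch of attaching the hints these classes parse to a transform; with_resource_hints and the 'min_ram' hint name follow Beam's documented resource-hint API, and the value is a placeholder.

    import apache_beam as beam

    p = beam.Pipeline()
    _ = (
        p
        | beam.Create([1, 2, 3])
        # Parsed by MinRamHint into a byte count on the transform's environment.
        | beam.Map(lambda x: x * x).with_resource_hints(min_ram='4GB'))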
@@ -89,8 +88,7 @@ def is_registered(name): return name in ResourceHint._name_to_known_hints @staticmethod - def register_resource_hint( - hint_name, hint_class): # type: (str, type) -> None + def register_resource_hint(hint_name: str, hint_class: type) -> None: assert issubclass(hint_class, ResourceHint) assert hint_class.urn is not None ResourceHint._name_to_known_hints[hint_name] = hint_class @@ -164,12 +162,11 @@ class MinRamHint(ResourceHint): urn = resource_hints.MIN_RAM_BYTES.urn @classmethod - def parse(cls, value): # type: (str) -> Dict[str, bytes] + def parse(cls, value: str) -> Dict[str, bytes]: return {cls.urn: ResourceHint._parse_storage_size_str(value)} @classmethod - def get_merged_value( - cls, outer_value, inner_value): # type: (bytes, bytes) -> bytes + def get_merged_value(cls, outer_value: bytes, inner_value: bytes) -> bytes: return ResourceHint._use_max(outer_value, inner_value) @@ -183,8 +180,7 @@ class CpuCountHint(ResourceHint): urn = resource_hints.CPU_COUNT.urn @classmethod - def get_merged_value( - cls, outer_value, inner_value): # type: (bytes, bytes) -> bytes + def get_merged_value(cls, outer_value: bytes, inner_value: bytes) -> bytes: return ResourceHint._use_max(outer_value, inner_value) @@ -193,7 +189,7 @@ def get_merged_value( ResourceHint.register_resource_hint('cpuCount', CpuCountHint) -def parse_resource_hints(hints): # type: (Dict[Any, Any]) -> Dict[str, bytes] +def parse_resource_hints(hints: Dict[Any, Any]) -> Dict[str, bytes]: parsed_hints = {} for hint, value in hints.items(): try: @@ -208,8 +204,8 @@ def parse_resource_hints(hints): # type: (Dict[Any, Any]) -> Dict[str, bytes] return parsed_hints -def resource_hints_from_options(options): - # type: (Optional[PipelineOptions]) -> Dict[str, bytes] +def resource_hints_from_options( + options: Optional[PipelineOptions]) -> Dict[str, bytes]: if options is None: return {} hints = {} @@ -225,8 +221,8 @@ def resource_hints_from_options(options): def merge_resource_hints( - outer_hints, inner_hints -): # type: (Mapping[str, bytes], Mapping[str, bytes]) -> Dict[str, bytes] + outer_hints: Mapping[str, bytes], + inner_hints: Mapping[str, bytes]) -> Dict[str, bytes]: merged_hints = dict(inner_hints) for urn, outer_value in outer_hints.items(): if urn in inner_hints: diff --git a/sdks/python/apache_beam/transforms/sideinputs.py b/sdks/python/apache_beam/transforms/sideinputs.py index 5c92eafe5422..4951594e63b0 100644 --- a/sdks/python/apache_beam/transforms/sideinputs.py +++ b/sdks/python/apache_beam/transforms/sideinputs.py @@ -45,21 +45,20 @@ # Top-level function so we can identify it later. -def _global_window_mapping_fn(w, global_window=window.GlobalWindow()): - # type: (...) 
-> window.GlobalWindow +def _global_window_mapping_fn( + w, global_window=window.GlobalWindow()) -> window.GlobalWindow: return global_window -def default_window_mapping_fn(target_window_fn): - # type: (window.WindowFn) -> WindowMappingFn +def default_window_mapping_fn( + target_window_fn: window.WindowFn) -> WindowMappingFn: if target_window_fn == window.GlobalWindows(): return _global_window_mapping_fn if isinstance(target_window_fn, window.Sessions): raise RuntimeError("Sessions is not allowed in side inputs") - def map_via_end(source_window): - # type: (window.BoundedWindow) -> window.BoundedWindow + def map_via_end(source_window: window.BoundedWindow) -> window.BoundedWindow: return list( target_window_fn.assign( window.WindowFn.AssignContext(source_window.max_timestamp())))[-1] @@ -67,8 +66,7 @@ def map_via_end(source_window): return map_via_end -def get_sideinput_index(tag): - # type: (str) -> int +def get_sideinput_index(tag: str) -> int: match = re.match(SIDE_INPUT_REGEX, tag, re.DOTALL) if match: return int(match.group(1)) @@ -78,28 +76,22 @@ def get_sideinput_index(tag): class SideInputMap(object): """Represents a mapping of windows to side input values.""" - def __init__( - self, - view_class, # type: pvalue.AsSideInput - view_options, - iterable): + def __init__(self, view_class: pvalue.AsSideInput, view_options, iterable): self._window_mapping_fn = view_options.get( 'window_mapping_fn', _global_window_mapping_fn) self._view_class = view_class self._view_options = view_options self._iterable = iterable - self._cache = {} # type: Dict[window.BoundedWindow, Any] + self._cache: Dict[window.BoundedWindow, Any] = {} - def __getitem__(self, window): - # type: (window.BoundedWindow) -> Any + def __getitem__(self, window: window.BoundedWindow) -> Any: if window not in self._cache: target_window = self._window_mapping_fn(window) self._cache[window] = self._view_class._from_runtime_iterable( _FilteringIterable(self._iterable, target_window), self._view_options) return self._cache[window] - def is_globally_windowed(self): - # type: () -> bool + def is_globally_windowed(self) -> bool: return self._window_mapping_fn == _global_window_mapping_fn diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py index 6483859adcf8..63895704727f 100644 --- a/sdks/python/apache_beam/transforms/trigger.py +++ b/sdks/python/apache_beam/transforms/trigger.py @@ -181,8 +181,7 @@ class DataLossReason(Flag): # to `reason & flag == flag` -def _IncludesMayFinish(reason): - # type: (DataLossReason) -> bool +def _IncludesMayFinish(reason: DataLossReason) -> bool: return reason & DataLossReason.MAY_FINISH == DataLossReason.MAY_FINISH @@ -267,9 +266,7 @@ def reset(self, window, context): """Clear any state and timers used by this TriggerFn.""" pass - def may_lose_data(self, unused_windowing): - # type: (core.Windowing) -> DataLossReason - + def may_lose_data(self, unused_windowing: core.Windowing) -> DataLossReason: """Returns whether or not this trigger could cause data loss. 
A trigger can cause data loss in the following scenarios: diff --git a/sdks/python/apache_beam/transforms/userstate.py b/sdks/python/apache_beam/transforms/userstate.py index c266d0685472..45f73ab69ad0 100644 --- a/sdks/python/apache_beam/transforms/userstate.py +++ b/sdks/python/apache_beam/transforms/userstate.py @@ -50,8 +50,7 @@ class StateSpec(object): """Specification for a user DoFn state cell.""" - def __init__(self, name, coder): - # type: (str, Coder) -> None + def __init__(self, name: str, coder: Coder) -> None: if not isinstance(name, str): raise TypeError("name is not a string") if not isinstance(coder, Coder): @@ -59,19 +58,18 @@ def __init__(self, name, coder): self.name = name self.coder = coder - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: return '%s(%s)' % (self.__class__.__name__, self.name) - def to_runner_api(self, context): - # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec + def to_runner_api( + self, context: PipelineContext) -> beam_runner_api_pb2.StateSpec: raise NotImplementedError class ReadModifyWriteStateSpec(StateSpec): """Specification for a user DoFn value state cell.""" - def to_runner_api(self, context): - # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec + def to_runner_api( + self, context: PipelineContext) -> beam_runner_api_pb2.StateSpec: return beam_runner_api_pb2.StateSpec( read_modify_write_spec=beam_runner_api_pb2.ReadModifyWriteStateSpec( coder_id=context.coders.get_id(self.coder)), @@ -81,8 +79,8 @@ def to_runner_api(self, context): class BagStateSpec(StateSpec): """Specification for a user DoFn bag state cell.""" - def to_runner_api(self, context): - # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec + def to_runner_api( + self, context: PipelineContext) -> beam_runner_api_pb2.StateSpec: return beam_runner_api_pb2.StateSpec( bag_spec=beam_runner_api_pb2.BagStateSpec( element_coder_id=context.coders.get_id(self.coder)), @@ -92,8 +90,8 @@ def to_runner_api(self, context): class SetStateSpec(StateSpec): """Specification for a user DoFn Set State cell""" - def to_runner_api(self, context): - # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec + def to_runner_api( + self, context: PipelineContext) -> beam_runner_api_pb2.StateSpec: return beam_runner_api_pb2.StateSpec( set_spec=beam_runner_api_pb2.SetStateSpec( element_coder_id=context.coders.get_id(self.coder)), @@ -103,9 +101,11 @@ def to_runner_api(self, context): class CombiningValueStateSpec(StateSpec): """Specification for a user DoFn combining value state cell.""" - def __init__(self, name, coder=None, combine_fn=None): - # type: (str, Optional[Coder], Any) -> None - + def __init__( + self, + name: str, + coder: Optional[Coder] = None, + combine_fn: Any = None) -> None: """Initialize the specification for CombiningValue state. 
CombiningValueStateSpec(name, combine_fn) -> Coder-inferred combining value @@ -140,8 +140,8 @@ def __init__(self, name, coder=None, combine_fn=None): super().__init__(name, coder) - def to_runner_api(self, context): - # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec + def to_runner_api( + self, context: PipelineContext) -> beam_runner_api_pb2.StateSpec: return beam_runner_api_pb2.StateSpec( combining_spec=beam_runner_api_pb2.CombiningStateSpec( combine_fn=self.combine_fn.to_runner_api(context), @@ -169,29 +169,26 @@ class TimerSpec(object): """Specification for a user stateful DoFn timer.""" prefix = "ts-" - def __init__(self, name, time_domain): - # type: (str, str) -> None + def __init__(self, name: str, time_domain: str) -> None: self.name = self.prefix + name if time_domain not in (TimeDomain.WATERMARK, TimeDomain.REAL_TIME): raise ValueError('Unsupported TimeDomain: %r.' % (time_domain, )) self.time_domain = time_domain - self._attached_callback = None # type: Optional[Callable] + self._attached_callback: Optional[Callable] = None - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: return '%s(%s)' % (self.__class__.__name__, self.name) - def to_runner_api(self, context, key_coder, window_coder): - # type: (PipelineContext, Coder, Coder) -> beam_runner_api_pb2.TimerFamilySpec + def to_runner_api( + self, context: PipelineContext, key_coder: Coder, + window_coder: Coder) -> beam_runner_api_pb2.TimerFamilySpec: return beam_runner_api_pb2.TimerFamilySpec( time_domain=TimeDomain.to_runner_api(self.time_domain), timer_family_coder_id=context.coders.get_id( coders._TimerCoder(key_coder, window_coder))) -def on_timer(timer_spec): - # type: (TimerSpec) -> Callable[[CallableT], CallableT] - +def on_timer(timer_spec: TimerSpec) -> Callable[[CallableT], CallableT]: """Decorator for timer firing DoFn method. This decorator allows a user to specify an on_timer processing method @@ -208,8 +205,7 @@ def my_timer_expiry_callback(self): if not isinstance(timer_spec, TimerSpec): raise ValueError('@on_timer decorator expected TimerSpec.') - def _inner(method): - # type: (CallableT) -> CallableT + def _inner(method: CallableT) -> CallableT: if not callable(method): raise ValueError('@on_timer decorator expected callable.') if timer_spec._attached_callback: @@ -221,9 +217,7 @@ def _inner(method): return _inner -def get_dofn_specs(dofn): - # type: (DoFn) -> Tuple[Set[StateSpec], Set[TimerSpec]] - +def get_dofn_specs(dofn: DoFn) -> Tuple[Set[StateSpec], Set[TimerSpec]]: """Gets the state and timer specs for a DoFn, if any. Args: @@ -262,9 +256,7 @@ def get_dofn_specs(dofn): return all_state_specs, all_timer_specs -def is_stateful_dofn(dofn): - # type: (DoFn) -> bool - +def is_stateful_dofn(dofn: DoFn) -> bool: """Determines whether a given DoFn is a stateful DoFn.""" # A Stateful DoFn is a DoFn that uses user state or timers. @@ -272,9 +264,7 @@ def is_stateful_dofn(dofn): return bool(all_state_specs or all_timer_specs) -def validate_stateful_dofn(dofn): - # type: (DoFn) -> None - +def validate_stateful_dofn(dofn: DoFn) -> None: """Validates the proper specification of a stateful DoFn.""" # Get state and timer specs. 
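A condensed sketch of the kind of stateful DoFn that get_dofn_specs() and validate_stateful_dofn() above inspect, using the state and timer specs from this module; the buffering logic itself is a placeholder.

    import apache_beam as beam
    from apache_beam.coders import VarIntCoder
    from apache_beam.transforms.timeutil import TimeDomain
    from apache_beam.transforms.userstate import BagStateSpec, TimerSpec, on_timer


    class BufferingDoFn(beam.DoFn):
      BUFFER = BagStateSpec('buffer', VarIntCoder())
      FLUSH = TimerSpec('flush', TimeDomain.WATERMARK)

      def process(
          self,
          element,  # expected to be a (key, value) pair
          window=beam.DoFn.WindowParam,
          buffer=beam.DoFn.StateParam(BUFFER),
          flush=beam.DoFn.TimerParam(FLUSH)):
        _, value = element
        buffer.add(value)
        flush.set(window.end)  # fire when the watermark passes the window end

      @on_timer(FLUSH)
      def on_flush(self, buffer=beam.DoFn.StateParam(BUFFER)):
        yield list(buffer.read())
        buffer.clear()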
@@ -306,12 +296,10 @@ def validate_stateful_dofn(dofn): class BaseTimer(object): - def clear(self, dynamic_timer_tag=''): - # type: (str) -> None + def clear(self, dynamic_timer_tag: str = '') -> None: raise NotImplementedError - def set(self, timestamp, dynamic_timer_tag=''): - # type: (Timestamp, str) -> None + def set(self, timestamp: Timestamp, dynamic_timer_tag: str = '') -> None: raise NotImplementedError @@ -321,66 +309,54 @@ def set(self, timestamp, dynamic_timer_tag=''): class RuntimeTimer(BaseTimer): """Timer interface object passed to user code.""" def __init__(self) -> None: - self._timer_recordings = {} # type: Dict[str, _TimerTuple] + self._timer_recordings: Dict[str, _TimerTuple] = {} self._cleared = False - self._new_timestamp = None # type: Optional[Timestamp] + self._new_timestamp: Optional[Timestamp] = None - def clear(self, dynamic_timer_tag=''): - # type: (str) -> None + def clear(self, dynamic_timer_tag: str = '') -> None: self._timer_recordings[dynamic_timer_tag] = _TimerTuple( cleared=True, timestamp=None) - def set(self, timestamp, dynamic_timer_tag=''): - # type: (Timestamp, str) -> None + def set(self, timestamp: Timestamp, dynamic_timer_tag: str = '') -> None: self._timer_recordings[dynamic_timer_tag] = _TimerTuple( cleared=False, timestamp=timestamp) class RuntimeState(object): """State interface object passed to user code.""" - def prefetch(self): - # type: () -> None + def prefetch(self) -> None: # The default implementation here does nothing. pass - def finalize(self): - # type: () -> None + def finalize(self) -> None: pass class ReadModifyWriteRuntimeState(RuntimeState): - def read(self): - # type: () -> Any + def read(self) -> Any: raise NotImplementedError(type(self)) - def write(self, value): - # type: (Any) -> None + def write(self, value: Any) -> None: raise NotImplementedError(type(self)) - def clear(self): - # type: () -> None + def clear(self) -> None: raise NotImplementedError(type(self)) - def commit(self): - # type: () -> None + def commit(self) -> None: raise NotImplementedError(type(self)) class AccumulatingRuntimeState(RuntimeState): - def read(self): - # type: () -> Iterable[Any] + def read(self) -> Iterable[Any]: raise NotImplementedError(type(self)) - def add(self, value): - # type: (Any) -> None + def add(self, value: Any) -> None: raise NotImplementedError(type(self)) - def clear(self): - # type: () -> None + def clear(self) -> None: raise NotImplementedError(type(self)) - def commit(self): - # type: () -> None + def commit(self) -> None: raise NotImplementedError(type(self)) @@ -398,24 +374,23 @@ class CombiningValueRuntimeState(AccumulatingRuntimeState): class UserStateContext(object): """Wrapper allowing user state and timers to be accessed by a DoFnInvoker.""" - def get_timer(self, - timer_spec, # type: TimerSpec - key, # type: Any - window, # type: windowed_value.BoundedWindow - timestamp, # type: Timestamp - pane, # type: windowed_value.PaneInfo - ): - # type: (...) -> BaseTimer + def get_timer( + self, + timer_spec: TimerSpec, + key: Any, + window: windowed_value.BoundedWindow, + timestamp: Timestamp, + pane: windowed_value.PaneInfo, + ) -> BaseTimer: raise NotImplementedError(type(self)) - def get_state(self, - state_spec, # type: StateSpec - key, # type: Any - window, # type: windowed_value.BoundedWindow - ): - # type: (...) 
-> RuntimeState + def get_state( + self, + state_spec: StateSpec, + key: Any, + window: windowed_value.BoundedWindow, + ) -> RuntimeState: raise NotImplementedError(type(self)) - def commit(self): - # type: () -> None + def commit(self) -> None: raise NotImplementedError(type(self)) diff --git a/sdks/python/apache_beam/transforms/userstate_test.py b/sdks/python/apache_beam/transforms/userstate_test.py index e17894ccb949..9caa31386569 100644 --- a/sdks/python/apache_beam/transforms/userstate_test.py +++ b/sdks/python/apache_beam/transforms/userstate_test.py @@ -437,7 +437,7 @@ def __repr__(self): class StatefulDoFnOnDirectRunnerTest(unittest.TestCase): # pylint: disable=expression-not-assigned - all_records = None # type: List[Any] + all_records: List[Any] = None def setUp(self): # Use state on the TestCase class, since other references would be pickled diff --git a/sdks/python/apache_beam/transforms/window.py b/sdks/python/apache_beam/transforms/window.py index c76b30fb8ff7..d13894cff8a1 100644 --- a/sdks/python/apache_beam/transforms/window.py +++ b/sdks/python/apache_beam/transforms/window.py @@ -104,8 +104,9 @@ class TimestampCombiner(object): OUTPUT_AT_EARLIEST_TRANSFORMED = 'OUTPUT_AT_EARLIEST_TRANSFORMED' @staticmethod - def get_impl(timestamp_combiner, window_fn): - # type: (beam_runner_api_pb2.OutputTime.Enum, WindowFn) -> timeutil.TimestampCombinerImpl + def get_impl( + timestamp_combiner: beam_runner_api_pb2.OutputTime.Enum, + window_fn: WindowFn) -> timeutil.TimestampCombinerImpl: if timestamp_combiner == TimestampCombiner.OUTPUT_AT_EOW: return timeutil.OutputAtEndOfWindowImpl() elif timestamp_combiner == TimestampCombiner.OUTPUT_AT_EARLIEST: @@ -124,18 +125,16 @@ class AssignContext(object): """Context passed to WindowFn.assign().""" def __init__( self, - timestamp, # type: TimestampTypes - element=None, # type: Optional[Any] - window=None # type: Optional[BoundedWindow] - ): - # type: (...) -> None + timestamp: TimestampTypes, + element: Optional[Any] = None, + window: Optional[BoundedWindow] = None) -> None: self.timestamp = Timestamp.of(timestamp) self.element = element self.window = window @abc.abstractmethod - def assign(self, assign_context): - # type: (AssignContext) -> Iterable[BoundedWindow] # noqa: F821 + def assign(self, assign_context: AssignContext) -> Iterable[BoundedWindow]: + # noqa: F821 """Associates windows to an element. 
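A tiny worked example of the assign() contract annotated above, using the FixedWindows implementation that appears later in this patch; the timestamps are arbitrary.

    from apache_beam.transforms.window import FixedWindows, WindowFn
    from apache_beam.utils.timestamp import Timestamp

    fn = FixedWindows(size=60)                    # 60-second fixed windows
    ctx = WindowFn.AssignContext(Timestamp(125))  # element timestamp 125s
    [window] = fn.assign(ctx)
    # start = 125 - (125 - 0) % 60 = 120, end = start + 60
    assert (window.start, window.end) == (Timestamp(120), Timestamp(180))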
@@ -149,35 +148,30 @@ def assign(self, assign_context): class MergeContext(object): """Context passed to WindowFn.merge() to perform merging, if any.""" - def __init__(self, windows): - # type: (Iterable[BoundedWindow]) -> None + def __init__(self, windows: Iterable[BoundedWindow]) -> None: self.windows = list(windows) - def merge(self, to_be_merged, merge_result): - # type: (Iterable[BoundedWindow], BoundedWindow) -> None + def merge( + self, + to_be_merged: Iterable[BoundedWindow], + merge_result: BoundedWindow) -> None: raise NotImplementedError @abc.abstractmethod - def merge(self, merge_context): - # type: (WindowFn.MergeContext) -> None - + def merge(self, merge_context: WindowFn.MergeContext) -> None: """Returns a window that is the result of merging a set of windows.""" raise NotImplementedError - def is_merging(self): - # type: () -> bool - + def is_merging(self) -> bool: """Returns whether this WindowFn merges windows.""" return True @abc.abstractmethod - def get_window_coder(self): - # type: () -> coders.Coder + def get_window_coder(self) -> coders.Coder: raise NotImplementedError - def get_transformed_output_time(self, window, input_timestamp): # pylint: disable=unused-argument - # type: (BoundedWindow, Timestamp) -> Timestamp - + def get_transformed_output_time( + self, window: BoundedWindow, input_timestamp: Timestamp) -> Timestamp: # pylint: disable=unused-argument """Given input time and output window, returns output time for window. If TimestampCombiner.OUTPUT_AT_EARLIEST_TRANSFORMED is used in the @@ -205,22 +199,18 @@ class BoundedWindow(object): Attributes: end: End of window. """ - def __init__(self, end): - # type: (TimestampTypes) -> None + def __init__(self, end: TimestampTypes) -> None: self._end = Timestamp.of(end) @property - def start(self): - # type: () -> Timestamp + def start(self) -> Timestamp: raise NotImplementedError @property - def end(self): - # type: () -> Timestamp + def end(self) -> Timestamp: return self._end - def max_timestamp(self): - # type: () -> Timestamp + def max_timestamp(self) -> Timestamp: return self.end.predecessor() def __eq__(self, other): @@ -270,12 +260,10 @@ def __lt__(self, other): return self.end < other.end return hash(self) < hash(other) - def intersects(self, other): - # type: (IntervalWindow) -> bool + def intersects(self, other: IntervalWindow) -> bool: return other.start < self.end or self.start < other.end - def union(self, other): - # type: (IntervalWindow) -> IntervalWindow + def union(self, other: IntervalWindow) -> IntervalWindow: return IntervalWindow( min(self.start, other.start), max(self.end, other.end)) @@ -291,8 +279,7 @@ class TimestampedValue(Generic[V]): value: The underlying value. timestamp: Timestamp associated with the value as seconds since Unix epoch. 
""" - def __init__(self, value, timestamp): - # type: (V, TimestampTypes) -> None + def __init__(self, value: V, timestamp: TimestampTypes) -> None: self.value = value self.timestamp = Timestamp.of(timestamp) @@ -314,15 +301,14 @@ def __lt__(self, other): class GlobalWindow(BoundedWindow): """The default window into which all data is placed (via GlobalWindows).""" - _instance = None # type: GlobalWindow + _instance: GlobalWindow = None def __new__(cls): if cls._instance is None: cls._instance = super(GlobalWindow, cls).__new__(cls) return cls._instance - def __init__(self): - # type: () -> None + def __init__(self) -> None: super().__init__(GlobalWindow._getTimestampFromProto()) def __repr__(self): @@ -336,25 +322,21 @@ def __eq__(self, other): return self is other or type(self) is type(other) @property - def start(self): - # type: () -> Timestamp + def start(self) -> Timestamp: return MIN_TIMESTAMP @staticmethod - def _getTimestampFromProto(): - # type: () -> Timestamp + def _getTimestampFromProto() -> Timestamp: ts_millis = int( common_urns.constants.GLOBAL_WINDOW_MAX_TIMESTAMP_MILLIS.constant) return Timestamp(micros=ts_millis * 1000) class NonMergingWindowFn(WindowFn): - def is_merging(self): - # type: () -> bool + def is_merging(self) -> bool: return False - def merge(self, merge_context): - # type: (WindowFn.MergeContext) -> None + def merge(self, merge_context: WindowFn.MergeContext) -> None: pass # No merging. @@ -363,34 +345,31 @@ class GlobalWindows(NonMergingWindowFn): @classmethod def windowed_batch( cls, - batch, # type: Any - timestamp=MIN_TIMESTAMP, # type: Timestamp - pane_info=windowed_value.PANE_INFO_UNKNOWN # type: windowed_value.PaneInfo - ): - # type: (...) -> windowed_value.WindowedBatch + batch: Any, + timestamp: Timestamp = MIN_TIMESTAMP, + pane_info: windowed_value.PaneInfo = windowed_value.PANE_INFO_UNKNOWN + ) -> windowed_value.WindowedBatch: return windowed_value.HomogeneousWindowedBatch.of( batch, timestamp, (GlobalWindow(), ), pane_info) @classmethod def windowed_value( cls, - value, # type: Any - timestamp=MIN_TIMESTAMP, # type: Timestamp - pane_info=windowed_value.PANE_INFO_UNKNOWN # type: windowed_value.PaneInfo - ): - # type: (...) -> WindowedValue + value: Any, + timestamp: Timestamp = MIN_TIMESTAMP, + pane_info: windowed_value.PaneInfo = windowed_value.PANE_INFO_UNKNOWN + ) -> WindowedValue: return WindowedValue(value, timestamp, (GlobalWindow(), ), pane_info) @classmethod def windowed_value_at_end_of_window(cls, value): return cls.windowed_value(value, GlobalWindow().max_timestamp()) - def assign(self, assign_context): - # type: (WindowFn.AssignContext) -> List[GlobalWindow] + def assign(self, + assign_context: WindowFn.AssignContext) -> List[GlobalWindow]: return [GlobalWindow()] - def get_window_coder(self): - # type: () -> coders.GlobalWindowCoder + def get_window_coder(self) -> coders.GlobalWindowCoder: return coders.GlobalWindowCoder() def __hash__(self): @@ -405,8 +384,8 @@ def to_runner_api_parameter(self, context): @staticmethod @urns.RunnerApiFn.register_urn(common_urns.global_windows.urn, None) - def from_runner_api_parameter(unused_fn_parameter, unused_context): - # type: (...) -> GlobalWindows + def from_runner_api_parameter( + unused_fn_parameter, unused_context) -> GlobalWindows: return GlobalWindows() @@ -424,11 +403,7 @@ class FixedWindows(NonMergingWindowFn): value in range [0, size). If it is not it will be normalized to this range. 
""" - def __init__( - self, - size, # type: DurationTypes - offset=0 # type: TimestampTypes - ): + def __init__(self, size: DurationTypes, offset: TimestampTypes = 0): """Initialize a ``FixedWindows`` function for a given size and offset. Args: @@ -443,14 +418,12 @@ def __init__( self.size = Duration.of(size) self.offset = Timestamp.of(offset) % self.size - def assign(self, context): - # type: (WindowFn.AssignContext) -> List[IntervalWindow] + def assign(self, context: WindowFn.AssignContext) -> List[IntervalWindow]: timestamp = context.timestamp start = timestamp - (timestamp - self.offset) % self.size return [IntervalWindow(start, start + self.size)] - def get_window_coder(self): - # type: () -> coders.IntervalWindowCoder + def get_window_coder(self) -> coders.IntervalWindowCoder: return coders.IntervalWindowCoder() def __eq__(self, other): @@ -473,8 +446,7 @@ def to_runner_api_parameter(self, context): @urns.RunnerApiFn.register_urn( common_urns.fixed_windows.urn, standard_window_fns_pb2.FixedWindowsPayload) - def from_runner_api_parameter(fn_parameter, unused_context): - # type: (...) -> FixedWindows + def from_runner_api_parameter(fn_parameter, unused_context) -> FixedWindows: return FixedWindows( size=Duration(micros=fn_parameter.size.ToMicroseconds()), offset=Timestamp(micros=fn_parameter.offset.ToMicroseconds())) @@ -494,20 +466,19 @@ class SlidingWindows(NonMergingWindowFn): t=N * period + offset where t=0 is the epoch. The offset must be a value in range [0, period). If it is not it will be normalized to this range. """ - - def __init__(self, - size, # type: DurationTypes - period, # type: DurationTypes - offset=0, # type: TimestampTypes - ): + def __init__( + self, + size: DurationTypes, + period: DurationTypes, + offset: TimestampTypes = 0, + ): if size <= 0: raise ValueError('The size parameter must be strictly positive.') self.size = Duration.of(size) self.period = Duration.of(period) self.offset = Timestamp.of(offset) % period - def assign(self, context): - # type: (WindowFn.AssignContext) -> List[IntervalWindow] + def assign(self, context: WindowFn.AssignContext) -> List[IntervalWindow]: timestamp = context.timestamp start = timestamp - ((timestamp - self.offset) % self.period) return [ @@ -520,8 +491,7 @@ def assign(self, context): -self.period.micros) ] - def get_window_coder(self): - # type: () -> coders.IntervalWindowCoder + def get_window_coder(self) -> coders.IntervalWindowCoder: return coders.IntervalWindowCoder() def __eq__(self, other): @@ -548,8 +518,7 @@ def to_runner_api_parameter(self, context): @urns.RunnerApiFn.register_urn( common_urns.sliding_windows.urn, standard_window_fns_pb2.SlidingWindowsPayload) - def from_runner_api_parameter(fn_parameter, unused_context): - # type: (...) -> SlidingWindows + def from_runner_api_parameter(fn_parameter, unused_context) -> SlidingWindows: return SlidingWindows( size=Duration(micros=fn_parameter.size.ToMicroseconds()), offset=Timestamp(micros=fn_parameter.offset.ToMicroseconds()), @@ -565,24 +534,20 @@ class Sessions(WindowFn): Attributes: gap_size: Size of the gap between windows as floating-point seconds. 
""" - def __init__(self, gap_size): - # type: (DurationTypes) -> None + def __init__(self, gap_size: DurationTypes) -> None: if gap_size <= 0: raise ValueError('The size parameter must be strictly positive.') self.gap_size = Duration.of(gap_size) - def assign(self, context): - # type: (WindowFn.AssignContext) -> List[IntervalWindow] + def assign(self, context: WindowFn.AssignContext) -> List[IntervalWindow]: timestamp = context.timestamp return [IntervalWindow(timestamp, timestamp + self.gap_size)] - def get_window_coder(self): - # type: () -> coders.IntervalWindowCoder + def get_window_coder(self) -> coders.IntervalWindowCoder: return coders.IntervalWindowCoder() - def merge(self, merge_context): - # type: (WindowFn.MergeContext) -> None - to_merge = [] # type: List[BoundedWindow] + def merge(self, merge_context: WindowFn.MergeContext) -> None: + to_merge: List[BoundedWindow] = [] end = MIN_TIMESTAMP for w in sorted(merge_context.windows, key=lambda w: w.start): if to_merge: @@ -620,7 +585,6 @@ def to_runner_api_parameter(self, context): @urns.RunnerApiFn.register_urn( common_urns.session_windows.urn, standard_window_fns_pb2.SessionWindowsPayload) - def from_runner_api_parameter(fn_parameter, unused_context): - # type: (...) -> Sessions + def from_runner_api_parameter(fn_parameter, unused_context) -> Sessions: return Sessions( gap_size=Duration(micros=fn_parameter.gap_size.ToMicroseconds())) From 79d4ffd20badf02498abdef39cd169b1d9fcc23f Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:48:34 -0700 Subject: [PATCH 23/29] Modernize python type hints for apache_beam/typehints --- .../apache_beam/typehints/decorators.py | 79 ++++++++----------- .../typehints/native_type_compatibility.py | 2 +- .../python/apache_beam/typehints/typehints.py | 2 +- 3 files changed, 35 insertions(+), 48 deletions(-) diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py index c24f2ed8f43c..ee0cb76d45d4 100644 --- a/sdks/python/apache_beam/typehints/decorators.py +++ b/sdks/python/apache_beam/typehints/decorators.py @@ -202,8 +202,11 @@ class IOTypeHints(NamedTuple): origin: List[str] @classmethod - def _make_origin(cls, bases, tb=True, msg=()): - # type: (List[IOTypeHints], bool, Iterable[str]) -> List[str] + def _make_origin( + cls, + bases: List[IOTypeHints], + tb: bool = True, + msg: Iterable[str] = ()) -> List[str]: if msg: res = list(msg) else: @@ -229,16 +232,12 @@ def _make_origin(cls, bases, tb=True, msg=()): return res @classmethod - def empty(cls): - # type: () -> IOTypeHints - + def empty(cls) -> IOTypeHints: """Construct a base IOTypeHints object with no hints.""" return IOTypeHints(None, None, []) @classmethod - def from_callable(cls, fn): - # type: (Callable) -> Optional[IOTypeHints] - + def from_callable(cls, fn: Callable) -> Optional[IOTypeHints]: """Construct an IOTypeHints object from a callable's signature. Supports Python 3 annotations. For partial annotations, sets unknown types @@ -292,23 +291,19 @@ def from_callable(cls, fn): output_types=(tuple(output_args), {}), origin=cls._make_origin([], tb=False, msg=msg)) - def with_input_types(self, *args, **kwargs): - # type: (...) -> IOTypeHints + def with_input_types(self, *args, **kwargs) -> IOTypeHints: return self._replace( input_types=(args, kwargs), origin=self._make_origin([self])) - def with_output_types(self, *args, **kwargs): - # type: (...) 
-> IOTypeHints + def with_output_types(self, *args, **kwargs) -> IOTypeHints: return self._replace( output_types=(args, kwargs), origin=self._make_origin([self])) - def with_input_types_from(self, other): - # type: (IOTypeHints) -> IOTypeHints + def with_input_types_from(self, other: IOTypeHints) -> IOTypeHints: return self._replace( input_types=other.input_types, origin=self._make_origin([self])) - def with_output_types_from(self, other): - # type: (IOTypeHints) -> IOTypeHints + def with_output_types_from(self, other: IOTypeHints) -> IOTypeHints: return self._replace( output_types=other.output_types, origin=self._make_origin([self])) @@ -355,14 +350,13 @@ def strip_pcoll(self): def strip_pcoll_helper( self, - my_type, # type: any - has_my_type, # type: Callable[[], bool] - my_key, # type: str - special_containers, # type: List[Union[PBegin, PDone, PCollection]] # noqa: F821 - error_str, # type: str - source_str # type: str - ): - # type: (...) -> IOTypeHints + my_type: any, + has_my_type: Callable[[], bool], + my_key: str, + special_containers: List[Union[PBegin, PDone, PCollection]], # noqa: F821 + error_str: str, + source_str: str + ) -> IOTypeHints: from apache_beam.pvalue import PCollection if not has_my_type() or not my_type or len(my_type[0]) != 1: @@ -396,9 +390,7 @@ def strip_pcoll_helper( origin=self._make_origin([self], tb=False, msg=[source_str]), **kwarg_dict) - def strip_iterable(self): - # type: () -> IOTypeHints - + def strip_iterable(self) -> IOTypeHints: """Removes outer Iterable (or equivalent) from output type. Only affects instances with simple output types, otherwise is a no-op. @@ -437,8 +429,7 @@ def strip_iterable(self): output_types=((yielded_type, ), {}), origin=self._make_origin([self], tb=False, msg=['strip_iterable()'])) - def with_defaults(self, hints): - # type: (Optional[IOTypeHints]) -> IOTypeHints + def with_defaults(self, hints: Optional[IOTypeHints]) -> IOTypeHints: if not hints: return self if not self: @@ -501,8 +492,7 @@ class WithTypeHints(object): def __init__(self, *unused_args, **unused_kwargs): self._type_hints = IOTypeHints.empty() - def _get_or_create_type_hints(self): - # type: () -> IOTypeHints + def _get_or_create_type_hints(self) -> IOTypeHints: # __init__ may have not been called try: # Only return an instance bound to self (see BEAM-8629). 
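A short sketch of the decorators whose internals are annotated in this file, applied to a plain callable; the function itself is a placeholder.

    from apache_beam.typehints import with_input_types, with_output_types


    @with_input_types(int)
    @with_output_types(str)
    def format_number(n):
      return 'value: %d' % n

    # beam.Map(format_number) can now be checked against upstream and
    # downstream PCollection element types at pipeline construction time.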
@@ -524,23 +514,24 @@ def get_type_hints(self): self.default_type_hints()).with_defaults( get_type_hints(self.__class__))) - def _set_type_hints(self, type_hints): - # type: (IOTypeHints) -> None + def _set_type_hints(self, type_hints: IOTypeHints) -> None: self._type_hints = type_hints def default_type_hints(self): return None - def with_input_types(self, *arg_hints, **kwarg_hints): - # type: (WithTypeHintsT, *Any, **Any) -> WithTypeHintsT + def with_input_types( + self: WithTypeHintsT, *arg_hints: Any, + **kwarg_hints: Any) -> WithTypeHintsT: arg_hints = native_type_compatibility.convert_to_beam_types(arg_hints) kwarg_hints = native_type_compatibility.convert_to_beam_types(kwarg_hints) self._type_hints = self._get_or_create_type_hints().with_input_types( *arg_hints, **kwarg_hints) return self - def with_output_types(self, *arg_hints, **kwarg_hints): - # type: (WithTypeHintsT, *Any, **Any) -> WithTypeHintsT + def with_output_types( + self: WithTypeHintsT, *arg_hints: Any, + **kwarg_hints: Any) -> WithTypeHintsT: arg_hints = native_type_compatibility.convert_to_beam_types(arg_hints) kwarg_hints = native_type_compatibility.convert_to_beam_types(kwarg_hints) self._type_hints = self._get_or_create_type_hints().with_output_types( @@ -681,9 +672,7 @@ def getcallargs_forhints(func, *type_args, **type_kwargs): return dict(bound_args) -def get_type_hints(fn): - # type: (Any) -> IOTypeHints - +def get_type_hints(fn: Any) -> IOTypeHints: """Gets the type hint associated with an arbitrary object fn. Always returns a valid IOTypeHints object, creating one if necessary. @@ -704,9 +693,8 @@ def get_type_hints(fn): # pylint: enable=protected-access -def with_input_types(*positional_hints, **keyword_hints): - # type: (*Any, **Any) -> Callable[[T], T] - +def with_input_types(*positional_hints: Any, + **keyword_hints: Any) -> Callable[[T], T]: """A decorator that type-checks defined type-hints with passed func arguments. All type-hinted arguments can be specified using positional arguments, @@ -790,9 +778,8 @@ def annotate_input_types(f): return annotate_input_types -def with_output_types(*return_type_hint, **kwargs): - # type: (*Any, **Any) -> Callable[[T], T] - +def with_output_types(*return_type_hint: Any, + **kwargs: Any) -> Callable[[T], T]: """A decorator that type-checks defined type-hints for return values(s). This decorator will type-check the return value(s) of the decorated function. diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index cd517cd6ac70..621adc44507e 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -184,7 +184,7 @@ def is_forward_ref(typ): # Mapping from typing.TypeVar/typehints.TypeVariable ids to an object of the # other type. Bidirectional mapping preserves typing.TypeVar instances. -_type_var_cache = {} # type: typing.Dict[int, typehints.TypeVariable] +_type_var_cache: typing.Dict[int, typehints.TypeVariable] = {} def convert_builtin_to_typing(typ): diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 70eb78b6ffc6..b368f0abdf3d 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -1257,7 +1257,7 @@ def __getitem__(self, type_params): # There is a circular dependency between defining this mapping # and using it in normalize(). Initialize it here and populate # it below. 
-_KNOWN_PRIMITIVE_TYPES = {} # type: typing.Dict[type, typing.Any] +_KNOWN_PRIMITIVE_TYPES: typing.Dict[type, typing.Any] = {} def normalize(x, none_as_type=False): From f2ffa5ec49e25146068426b8555a234363ceac07 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 15:48:36 -0700 Subject: [PATCH 24/29] Modernize python type hints for apache_beam/utils --- sdks/python/apache_beam/utils/profiler.py | 24 ++- sdks/python/apache_beam/utils/proto_utils.py | 34 ++-- .../apache_beam/utils/python_callable.py | 6 +- sdks/python/apache_beam/utils/sharded_key.py | 5 +- sdks/python/apache_beam/utils/shared.py | 28 +-- sdks/python/apache_beam/utils/timestamp.py | 160 ++++++------------ sdks/python/apache_beam/utils/urns.py | 57 +++---- 7 files changed, 114 insertions(+), 200 deletions(-) diff --git a/sdks/python/apache_beam/utils/profiler.py b/sdks/python/apache_beam/utils/profiler.py index d10703c17289..7463f50eb55b 100644 --- a/sdks/python/apache_beam/utils/profiler.py +++ b/sdks/python/apache_beam/utils/profiler.py @@ -45,18 +45,18 @@ class Profile(object): SORTBY = 'cumulative' - profile_output = None # type: str - stats = None # type: pstats.Stats + profile_output: str = None + stats: pstats.Stats = None def __init__( self, - profile_id, # type: str - profile_location=None, # type: Optional[str] - log_results=False, # type: bool - file_copy_fn=None, # type: Optional[Callable[[str, str], None]] - time_prefix='%Y-%m-%d_%H_%M_%S-', # type: str - enable_cpu_profiling=False, # type: bool - enable_memory_profiling=False, # type: bool + profile_id: str, + profile_location: Optional[str] = None, + log_results: bool = False, + file_copy_fn: Optional[Callable[[str, str], None]] = None, + time_prefix: str = '%Y-%m-%d_%H_%M_%S-', + enable_cpu_profiling: bool = False, + enable_memory_profiling: bool = False, ): """Creates a Profile object. @@ -139,8 +139,7 @@ def default_file_copy_fn(src, dest): filesystems.FileSystems.rename([dest + '.tmp'], [dest]) @staticmethod - def factory_from_options(options): - # type: (...) -> Optional[Callable[..., Profile]] + def factory_from_options(options) -> Optional[Callable[..., Profile]]: if options.profile_cpu or options.profile_memory: def create_profiler(profile_id, **kwargs): @@ -156,8 +155,7 @@ def create_profiler(profile_id, **kwargs): return None def _upload_profile_data( - self, profile_location, dir, data, write_binary=True): - # type: (...) 
-> str + self, profile_location, dir, data, write_binary=True) -> str: dump_location = os.path.join( profile_location, dir, diff --git a/sdks/python/apache_beam/utils/proto_utils.py b/sdks/python/apache_beam/utils/proto_utils.py index 3a5e020df167..cc637dead477 100644 --- a/sdks/python/apache_beam/utils/proto_utils.py +++ b/sdks/python/apache_beam/utils/proto_utils.py @@ -38,14 +38,12 @@ @overload -def pack_Any(msg): - # type: (message.Message) -> any_pb2.Any +def pack_Any(msg: message.Message) -> any_pb2.Any: pass @overload -def pack_Any(msg): - # type: (None) -> None +def pack_Any(msg: None) -> None: pass @@ -63,14 +61,12 @@ def pack_Any(msg): @overload -def unpack_Any(any_msg, msg_class): - # type: (any_pb2.Any, Type[MessageT]) -> MessageT +def unpack_Any(any_msg: any_pb2.Any, msg_class: Type[MessageT]) -> MessageT: pass @overload -def unpack_Any(any_msg, msg_class): - # type: (any_pb2.Any, None) -> None +def unpack_Any(any_msg: any_pb2.Any, msg_class: None) -> None: pass @@ -87,14 +83,13 @@ def unpack_Any(any_msg, msg_class): @overload -def parse_Bytes(serialized_bytes, msg_class): - # type: (bytes, Type[MessageT]) -> MessageT +def parse_Bytes(serialized_bytes: bytes, msg_class: Type[MessageT]) -> MessageT: pass @overload -def parse_Bytes(serialized_bytes, msg_class): - # type: (bytes, Union[Type[bytes], None]) -> bytes +def parse_Bytes( + serialized_bytes: bytes, msg_class: Union[Type[bytes], None]) -> bytes: pass @@ -109,9 +104,7 @@ def parse_Bytes(serialized_bytes, msg_class): return msg -def pack_Struct(**kwargs): - # type: (...) -> struct_pb2.Struct - +def pack_Struct(**kwargs) -> struct_pb2.Struct: """Returns a struct containing the values indicated by kwargs. """ msg = struct_pb2.Struct() @@ -120,16 +113,13 @@ def pack_Struct(**kwargs): return msg -def from_micros(cls, micros): - # type: (Type[TimeMessageT], int) -> TimeMessageT +def from_micros(cls: Type[TimeMessageT], micros: int) -> TimeMessageT: result = cls() result.FromMicroseconds(micros) return result -def to_Timestamp(time): - # type: (Union[int, float]) -> timestamp_pb2.Timestamp - +def to_Timestamp(time: Union[int, float]) -> timestamp_pb2.Timestamp: """Convert a float returned by time.time() to a Timestamp. """ seconds = int(time) @@ -137,9 +127,7 @@ def to_Timestamp(time): return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) -def from_Timestamp(timestamp): - # type: (timestamp_pb2.Timestamp) -> float - +def from_Timestamp(timestamp: timestamp_pb2.Timestamp) -> float: """Convert a Timestamp to a float expressed as seconds since the epoch. """ return timestamp.seconds + float(timestamp.nanos) / 10**9 diff --git a/sdks/python/apache_beam/utils/python_callable.py b/sdks/python/apache_beam/utils/python_callable.py index 70aa7cb39e5c..f6f507300ea8 100644 --- a/sdks/python/apache_beam/utils/python_callable.py +++ b/sdks/python/apache_beam/utils/python_callable.py @@ -43,8 +43,7 @@ class PythonCallableWithSource(object): is a valid chunk of source code. 
""" - def __init__(self, source): - # type: (str) -> None + def __init__(self, source: str) -> None: self._source = source self._callable = self.load_from_source(source) @@ -120,8 +119,7 @@ def default_label(self): def _argspec_fn(self): return self._callable - def get_source(self): - # type: () -> str + def get_source(self) -> str: return self._source def __call__(self, *args, **kwargs): diff --git a/sdks/python/apache_beam/utils/sharded_key.py b/sdks/python/apache_beam/utils/sharded_key.py index 9a03ab36bfd2..f6492779ef34 100644 --- a/sdks/python/apache_beam/utils/sharded_key.py +++ b/sdks/python/apache_beam/utils/sharded_key.py @@ -30,9 +30,8 @@ class ShardedKey(object): def __init__( self, key, - shard_id, # type: bytes - ): - # type: (...) -> None + shard_id: bytes, + ) -> None: assert shard_id is not None self._key = key self._shard_id = shard_id diff --git a/sdks/python/apache_beam/utils/shared.py b/sdks/python/apache_beam/utils/shared.py index d7eed350b0c1..bb04d1a19fb0 100644 --- a/sdks/python/apache_beam/utils/shared.py +++ b/sdks/python/apache_beam/utils/shared.py @@ -109,13 +109,7 @@ def __init__(self): self._ref = None self._tag = None - def acquire( - self, - constructor_fn, # type: Callable[[], Any] - tag=None # type: Any - ): - # type: (...) -> Any - + def acquire(self, constructor_fn: Callable[[], Any], tag: Any = None) -> Any: """Acquire a reference to the object this shared control block manages. Args: @@ -209,18 +203,14 @@ def __init__(self): # to keep it alive self._keepalive = (None, None) - def make_key(self): - # type: (...) -> Text + def make_key(self) -> Text: return str(uuid.uuid1()) def acquire( self, - key, # type: Text - constructor_fn, # type: Callable[[], Any] - tag=None # type: Any - ): - # type: (...) -> Any - + key: Text, + constructor_fn: Callable[[], Any], + tag: Any = None) -> Any: """Acquire a reference to a Shared object. Args: @@ -280,13 +270,7 @@ class Shared(object): def __init__(self): self._key = _shared_map.make_key() - def acquire( - self, - constructor_fn, # type: Callable[[], Any] - tag=None # type: Any - ): - # type: (...) -> Any - + def acquire(self, constructor_fn: Callable[[], Any], tag: Any = None) -> Any: """Acquire a reference to the object associated with this Shared handle. Args: diff --git a/sdks/python/apache_beam/utils/timestamp.py b/sdks/python/apache_beam/utils/timestamp.py index c54b5bf44e5c..dbc768308e60 100644 --- a/sdks/python/apache_beam/utils/timestamp.py +++ b/sdks/python/apache_beam/utils/timestamp.py @@ -52,8 +52,10 @@ class Timestamp(object): especially after arithmetic operations (for example, 10000000 % 0.1 evaluates to 0.0999999994448885). """ - def __init__(self, seconds=0, micros=0): - # type: (Union[int, float], Union[int, float]) -> None + def __init__( + self, + seconds: Union[int, float] = 0, + micros: Union[int, float] = 0) -> None: if not isinstance(seconds, (int, float)): raise TypeError( 'Cannot interpret %s %s as seconds.' % (seconds, type(seconds))) @@ -63,9 +65,7 @@ def __init__(self, seconds=0, micros=0): self.micros = int(seconds * 1000000) + int(micros) @staticmethod - def of(seconds): - # type: (TimestampTypes) -> Timestamp - + def of(seconds: TimestampTypes) -> Timestamp: """Return the Timestamp for the given number of seconds. If the input is already a Timestamp, the input itself will be returned. @@ -88,19 +88,15 @@ def of(seconds): 'Cannot interpret %s %s as Timestamp.' 
% (seconds, type(seconds))) @staticmethod - def now(): - # type: () -> Timestamp + def now() -> Timestamp: return Timestamp(seconds=time.time()) @staticmethod - def _epoch_datetime_utc(): - # type: () -> datetime.datetime + def _epoch_datetime_utc() -> datetime.datetime: return datetime.datetime.fromtimestamp(0, pytz.utc) @classmethod - def from_utc_datetime(cls, dt): - # type: (datetime.datetime) -> Timestamp - + def from_utc_datetime(cls, dt: datetime.datetime) -> Timestamp: """Create a ``Timestamp`` instance from a ``datetime.datetime`` object. Args: @@ -117,9 +113,7 @@ def from_utc_datetime(cls, dt): return Timestamp(duration.total_seconds()) @classmethod - def from_rfc3339(cls, rfc3339): - # type: (str) -> Timestamp - + def from_rfc3339(cls, rfc3339: str) -> Timestamp: """Create a ``Timestamp`` instance from an RFC 3339 compliant string. .. note:: @@ -140,20 +134,15 @@ def seconds(self) -> int: """Returns the timestamp in seconds.""" return self.micros // 1000000 - def predecessor(self): - # type: () -> Timestamp - + def predecessor(self) -> Timestamp: """Returns the largest timestamp smaller than self.""" return Timestamp(micros=self.micros - 1) - def successor(self): - # type: () -> Timestamp - + def successor(self) -> Timestamp: """Returns the smallest timestamp larger than self.""" return Timestamp(micros=self.micros + 1) - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: micros = self.micros sign = '' if micros < 0: @@ -165,9 +154,7 @@ def __repr__(self): return 'Timestamp(%s%d.%06d)' % (sign, int_part, frac_part) return 'Timestamp(%s%d)' % (sign, int_part) - def to_utc_datetime(self, has_tz=False): - # type: (bool) -> datetime.datetime - + def to_utc_datetime(self, has_tz: bool = False) -> datetime.datetime: """Returns a ``datetime.datetime`` object of UTC for this Timestamp. Note that this method returns a ``datetime.datetime`` object without a @@ -189,23 +176,18 @@ def to_utc_datetime(self, has_tz=False): epoch = epoch.replace(tzinfo=None) return epoch + datetime.timedelta(microseconds=self.micros) - def to_rfc3339(self): - # type: () -> str + def to_rfc3339(self) -> str: # Append 'Z' for UTC timezone. return self.to_utc_datetime().isoformat() + 'Z' - def to_proto(self): - # type: () -> timestamp_pb2.Timestamp - + def to_proto(self) -> timestamp_pb2.Timestamp: """Returns the `google.protobuf.timestamp_pb2` representation.""" secs = self.micros // 1000000 nanos = (self.micros % 1000000) * 1000 return timestamp_pb2.Timestamp(seconds=secs, nanos=nanos) @staticmethod - def from_proto(timestamp_proto): - # type: (timestamp_pb2.Timestamp) -> Timestamp - + def from_proto(timestamp_proto: timestamp_pb2.Timestamp) -> Timestamp: """Creates a Timestamp from a `google.protobuf.timestamp_pb2`. Note that the google has a sub-second resolution of nanoseconds whereas this @@ -227,18 +209,15 @@ class has a resolution of microsends. This class will truncate the return Timestamp( seconds=timestamp_proto.seconds, micros=timestamp_proto.nanos // 1000) - def __float__(self): - # type: () -> float + def __float__(self) -> float: # Note that the returned value may have lost precision. return self.micros / 1000000 - def __int__(self): - # type: () -> int + def __int__(self) -> int: # Note that the returned value may have lost precision. return self.micros // 1000000 - def __eq__(self, other): - # type: (object) -> bool + def __eq__(self, other: object) -> bool: # Allow comparisons between Duration and Timestamp values. 
if isinstance(other, (Duration, Timestamp)): return self.micros == other.micros @@ -248,57 +227,48 @@ def __eq__(self, other): # Support equality with other types return NotImplemented - def __lt__(self, other): - # type: (TimestampDurationTypes) -> bool + def __lt__(self, other: TimestampDurationTypes) -> bool: # Allow comparisons between Duration and Timestamp values. if not isinstance(other, Duration): other = Timestamp.of(other) return self.micros < other.micros - def __gt__(self, other): - # type: (TimestampDurationTypes) -> bool + def __gt__(self, other: TimestampDurationTypes) -> bool: return not (self < other or self == other) - def __le__(self, other): - # type: (TimestampDurationTypes) -> bool + def __le__(self, other: TimestampDurationTypes) -> bool: return self < other or self == other - def __ge__(self, other): - # type: (TimestampDurationTypes) -> bool + def __ge__(self, other: TimestampDurationTypes) -> bool: return not self < other - def __hash__(self): - # type: () -> int + def __hash__(self) -> int: return hash(self.micros) - def __add__(self, other): - # type: (DurationTypes) -> Timestamp + def __add__(self, other: DurationTypes) -> Timestamp: other = Duration.of(other) return Timestamp(micros=self.micros + other.micros) - def __radd__(self, other): - # type: (DurationTypes) -> Timestamp + def __radd__(self, other: DurationTypes) -> Timestamp: return self + other @overload - def __sub__(self, other): - # type: (DurationTypes) -> Timestamp + def __sub__(self, other: DurationTypes) -> Timestamp: pass @overload - def __sub__(self, other): - # type: (Timestamp) -> Duration + def __sub__(self, other: Timestamp) -> Duration: pass - def __sub__(self, other): - # type: (Union[DurationTypes, Timestamp]) -> Union[Timestamp, Duration] + def __sub__( + self, other: Union[DurationTypes, + Timestamp]) -> Union[Timestamp, Duration]: if isinstance(other, Timestamp): return Duration(micros=self.micros - other.micros) other = Duration.of(other) return Timestamp(micros=self.micros - other.micros) - def __mod__(self, other): - # type: (DurationTypes) -> Duration + def __mod__(self, other: DurationTypes) -> Duration: other = Duration.of(other) return Duration(micros=self.micros % other.micros) @@ -319,14 +289,14 @@ class Duration(object): especially after arithmetic operations (for example, 10000000 % 0.1 evaluates to 0.0999999994448885). """ - def __init__(self, seconds=0, micros=0): - # type: (Union[int, float], Union[int, float]) -> None + def __init__( + self, + seconds: Union[int, float] = 0, + micros: Union[int, float] = 0) -> None: self.micros = int(seconds * 1000000) + int(micros) @staticmethod - def of(seconds): - # type: (DurationTypes) -> Duration - + def of(seconds: DurationTypes) -> Duration: """Return the Duration for the given number of seconds since Unix epoch. If the input is already a Duration, the input itself will be returned. @@ -344,18 +314,14 @@ def of(seconds): return seconds return Duration(seconds) - def to_proto(self): - # type: () -> duration_pb2.Duration - + def to_proto(self) -> duration_pb2.Duration: """Returns the `google.protobuf.duration_pb2` representation.""" secs = self.micros // 1000000 nanos = (self.micros % 1000000) * 1000 return duration_pb2.Duration(seconds=secs, nanos=nanos) @staticmethod - def from_proto(duration_proto): - # type: (duration_pb2.Duration) -> Duration - + def from_proto(duration_proto: duration_pb2.Duration) -> Duration: """Creates a Duration from a `google.protobuf.duration_pb2`. 
Note that the google has a sub-second resolution of nanoseconds whereas this @@ -377,8 +343,7 @@ class has a resolution of microsends. This class will truncate the return Duration( seconds=duration_proto.seconds, micros=duration_proto.nanos // 1000) - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: micros = self.micros sign = '' if micros < 0: @@ -390,13 +355,11 @@ def __repr__(self): return 'Duration(%s%d.%06d)' % (sign, int_part, frac_part) return 'Duration(%s%d)' % (sign, int_part) - def __float__(self): - # type: () -> float + def __float__(self) -> float: # Note that the returned value may have lost precision. return self.micros / 1000000 - def __eq__(self, other): - # type: (object) -> bool + def __eq__(self, other: object) -> bool: # Allow comparisons between Duration and Timestamp values. if isinstance(other, (Duration, Timestamp)): return self.micros == other.micros @@ -406,65 +369,52 @@ def __eq__(self, other): # Support equality with other types return NotImplemented - def __lt__(self, other): - # type: (TimestampDurationTypes) -> bool + def __lt__(self, other: TimestampDurationTypes) -> bool: # Allow comparisons between Duration and Timestamp values. if not isinstance(other, Timestamp): other = Duration.of(other) return self.micros < other.micros - def __gt__(self, other): - # type: (TimestampDurationTypes) -> bool + def __gt__(self, other: TimestampDurationTypes) -> bool: return not (self < other or self == other) - def __le__(self, other): - # type: (TimestampDurationTypes) -> bool + def __le__(self, other: TimestampDurationTypes) -> bool: return self < other or self == other - def __ge__(self, other): - # type: (TimestampDurationTypes) -> bool + def __ge__(self, other: TimestampDurationTypes) -> bool: return not self < other - def __hash__(self): - # type: () -> int + def __hash__(self) -> int: return hash(self.micros) - def __neg__(self): - # type: () -> Duration + def __neg__(self) -> Duration: return Duration(micros=-self.micros) - def __add__(self, other): - # type: (DurationTypes) -> Duration + def __add__(self, other: DurationTypes) -> Duration: if isinstance(other, Timestamp): # defer to Timestamp.__add__ return NotImplemented other = Duration.of(other) return Duration(micros=self.micros + other.micros) - def __radd__(self, other): - # type: (DurationTypes) -> Duration + def __radd__(self, other: DurationTypes) -> Duration: return self + other - def __sub__(self, other): - # type: (DurationTypes) -> Duration + def __sub__(self, other: DurationTypes) -> Duration: other = Duration.of(other) return Duration(micros=self.micros - other.micros) - def __rsub__(self, other): - # type: (DurationTypes) -> Duration + def __rsub__(self, other: DurationTypes) -> Duration: return -(self - other) - def __mul__(self, other): - # type: (DurationTypes) -> Duration + def __mul__(self, other: DurationTypes) -> Duration: other = Duration.of(other) return Duration(micros=self.micros * other.micros // 1000000) - def __rmul__(self, other): - # type: (DurationTypes) -> Duration + def __rmul__(self, other: DurationTypes) -> Duration: return self * other - def __mod__(self, other): - # type: (DurationTypes) -> Duration + def __mod__(self, other: DurationTypes) -> Duration: other = Duration.of(other) return Duration(micros=self.micros % other.micros) diff --git a/sdks/python/apache_beam/utils/urns.py b/sdks/python/apache_beam/utils/urns.py index 3f2cb43e9753..9d2f393fd7a3 100644 --- a/sdks/python/apache_beam/utils/urns.py +++ b/sdks/python/apache_beam/utils/urns.py @@ -65,7 
+65,7 @@ class RunnerApiFn(object): # classes + abc metaclass # __metaclass__ = abc.ABCMeta - _known_urns = {} # type: Dict[str, Tuple[Optional[type], ConstructorFn]] + _known_urns: Dict[str, Tuple[Optional[type], ConstructorFn]] = {} # @abc.abstractmethod is disabled here to avoid an error with mypy. mypy # performs abc.abtractmethod/property checks even if a class does @@ -74,9 +74,8 @@ class RunnerApiFn(object): # mypy incorrectly infers that this method has not been overridden with a # concrete implementation. # @abc.abstractmethod - def to_runner_api_parameter(self, unused_context): - # type: (PipelineContext) -> Tuple[str, Any] - + def to_runner_api_parameter( + self, unused_context: PipelineContext) -> Tuple[str, Any]: """Returns the urn and payload for this Fn. The returned urn(s) should be registered with `register_urn`. @@ -87,40 +86,38 @@ def to_runner_api_parameter(self, unused_context): @overload def register_urn( cls, - urn, # type: str - parameter_type, # type: Type[T] - ): - # type: (...) -> Callable[[Callable[[T, PipelineContext], Any]], Callable[[T, PipelineContext], Any]] + urn: str, + parameter_type: Type[T], + ) -> Callable[[Callable[[T, PipelineContext], Any]], + Callable[[T, PipelineContext], Any]]: pass @classmethod @overload def register_urn( cls, - urn, # type: str - parameter_type, # type: None - ): - # type: (...) -> Callable[[Callable[[bytes, PipelineContext], Any]], Callable[[bytes, PipelineContext], Any]] + urn: str, + parameter_type: None, + ) -> Callable[[Callable[[bytes, PipelineContext], Any]], + Callable[[bytes, PipelineContext], Any]]: pass @classmethod @overload - def register_urn(cls, - urn, # type: str - parameter_type, # type: Type[T] - fn # type: Callable[[T, PipelineContext], Any] - ): - # type: (...) -> None + def register_urn( + cls, + urn: str, + parameter_type: Type[T], + fn: Callable[[T, PipelineContext], Any]) -> None: pass @classmethod @overload - def register_urn(cls, - urn, # type: str - parameter_type, # type: None - fn # type: Callable[[bytes, PipelineContext], Any] - ): - # type: (...) -> None + def register_urn( + cls, + urn: str, + parameter_type: None, + fn: Callable[[bytes, PipelineContext], Any]) -> None: pass @classmethod @@ -161,9 +158,8 @@ def register_pickle_urn(cls, pickle_urn): lambda proto, unused_context: pickler.loads(proto.value)) - def to_runner_api(self, context): - # type: (PipelineContext) -> beam_runner_api_pb2.FunctionSpec - + def to_runner_api( + self, context: PipelineContext) -> beam_runner_api_pb2.FunctionSpec: """Returns an FunctionSpec encoding this Fn. Prefer overriding self.to_runner_api_parameter. @@ -176,9 +172,10 @@ def to_runner_api(self, context): typed_param, message.Message) else typed_param) @classmethod - def from_runner_api(cls, fn_proto, context): - # type: (Type[RunnerApiFnT], beam_runner_api_pb2.FunctionSpec, PipelineContext) -> RunnerApiFnT - + def from_runner_api( + cls: Type[RunnerApiFnT], + fn_proto: beam_runner_api_pb2.FunctionSpec, + context: PipelineContext) -> RunnerApiFnT: """Converts from an FunctionSpec to a Fn object. Prefer registering a urn with its parameter type and constructor. From abdb1b742a9e9fdd0dac91a705ae4339f1b6aa29 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 2 Jul 2024 16:30:40 -0700 Subject: [PATCH 25/29] Fix circular references, mypy complaints. 
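The quoting added throughout this commit follows the usual forward-reference rule for annotations; a short illustrative sketch (the LinkedNode class is hypothetical and not part of the diff):

from typing import Optional


class LinkedNode:
  def __init__(self, value=None, next_node: Optional['LinkedNode'] = None):
    # 'LinkedNode' is not yet a bound name while this class body executes,
    # so the self-reference is written as a string (a forward reference)
    # that type checkers resolve lazily.
    self.value = value
    self.next_node = next_node

  def prepend(self, value) -> 'LinkedNode':
    return LinkedNode(value, self)  # the bare name is fine at call time

# `from __future__ import annotations` (PEP 563) would make every annotation
# in a module lazy, but this series quotes the individual references instead.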
--- sdks/python/apache_beam/coders/row_coder.py | 2 +- sdks/python/apache_beam/dataframe/convert.py | 6 +-- .../apache_beam/dataframe/partitionings.py | 2 +- .../apache_beam/internal/metrics/cells.py | 12 ++--- .../apache_beam/internal/metrics/metric.py | 20 ++++---- .../apache_beam/internal/module_test.py | 3 +- .../io/azure/blobstoragefilesystem.py | 8 +--- .../apache_beam/io/gcp/bigquery_avro_tools.py | 25 +++++----- .../io/gcp/bigquery_schema_tools.py | 2 +- .../io/gcp/datastore/v1new/types.py | 9 ++-- sdks/python/apache_beam/io/gcp/pubsub.py | 4 +- sdks/python/apache_beam/io/iobase.py | 6 +-- .../apache_beam/io/restriction_trackers.py | 2 +- sdks/python/apache_beam/metrics/metric.py | 32 +++++++------ .../apache_beam/ml/gcp/naturallanguageml.py | 2 +- sdks/python/apache_beam/pvalue.py | 32 ++++++------- .../runners/direct/bundle_factory.py | 4 +- .../consumer_tracking_pipeline_visitor.py | 5 +- .../apache_beam/runners/direct/executor.py | 12 ++--- .../runners/direct/watermark_manager.py | 16 +++---- .../display/pipeline_graph_renderer.py | 8 ++-- .../apache_beam/runners/pipeline_context.py | 19 ++++---- .../portability/abstract_job_service.py | 14 ++---- .../portability/fn_api_runner/execution.py | 33 +++++++------ .../portability/fn_api_runner/fn_runner.py | 34 +++++++------- .../runners/portability/job_server.py | 2 +- .../runners/portability/local_job_service.py | 7 +-- .../runners/portability/portable_runner.py | 7 +-- sdks/python/apache_beam/runners/runner.py | 16 +++---- sdks/python/apache_beam/runners/sdf_utils.py | 6 +-- .../apache_beam/runners/worker/log_handler.py | 5 +- .../runners/worker/statesampler.py | 4 +- .../runners/worker/statesampler_slow.py | 10 ++-- .../benchmarks/nexmark/nexmark_perf.py | 2 +- .../load_tests/load_test_metrics_utils.py | 2 +- .../combinefn_lifecycle_pipeline.py | 2 +- .../apache_beam/transforms/external_java.py | 5 +- .../apache_beam/transforms/resources.py | 7 +-- .../apache_beam/transforms/sideinputs.py | 2 +- .../apache_beam/transforms/userstate.py | 28 +++++------ .../apache_beam/transforms/userstate_test.py | 2 +- sdks/python/apache_beam/transforms/window.py | 32 +++++++------ .../apache_beam/typehints/decorators.py | 23 +++++----- sdks/python/apache_beam/utils/profiler.py | 6 +-- sdks/python/apache_beam/utils/timestamp.py | 46 +++++++++---------- sdks/python/apache_beam/utils/urns.py | 21 ++++----- 46 files changed, 261 insertions(+), 286 deletions(-) diff --git a/sdks/python/apache_beam/coders/row_coder.py b/sdks/python/apache_beam/coders/row_coder.py index 0d0392e94214..e93abbc887fb 100644 --- a/sdks/python/apache_beam/coders/row_coder.py +++ b/sdks/python/apache_beam/coders/row_coder.py @@ -117,7 +117,7 @@ def from_type_hint(cls, type_hint, registry): return cls(schema) @staticmethod - def from_payload(payload: bytes) -> RowCoder: + def from_payload(payload: bytes) -> 'RowCoder': return RowCoder(proto_utils.parse_Bytes(payload, schema_pb2.Schema)) def __reduce__(self): diff --git a/sdks/python/apache_beam/dataframe/convert.py b/sdks/python/apache_beam/dataframe/convert.py index 817cabc4b076..0ccd4489767b 100644 --- a/sdks/python/apache_beam/dataframe/convert.py +++ b/sdks/python/apache_beam/dataframe/convert.py @@ -17,10 +17,10 @@ import inspect import warnings import weakref -from typing import TYPE_CHECKING from typing import Any from typing import Dict from typing import Iterable +from typing import Optional from typing import Tuple from typing import Union @@ -35,10 +35,6 @@ from apache_beam.dataframe.schemas import generate_proxy 
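The convert.py change below illustrates a companion rule: names that appear in annotations evaluated at runtime (such as Optional here) must be imported unconditionally, while names used only inside quoted annotations may stay behind TYPE_CHECKING. A hedged sketch with a hypothetical function:

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
  # Safe to guard: the name only ever appears inside a quoted annotation,
  # which is never evaluated at runtime.
  from apache_beam.pvalue import PCollection


def label_of(pcoll: 'PCollection', default: Optional[str] = None) -> str:
  # Annotations on a def are evaluated when the def statement runs, so
  # Optional must be a real runtime name; 'PCollection' stays a string.
  return default or 'unnamed'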
from apache_beam.typehints.pandas_type_compatibility import dtype_to_fieldtype -if TYPE_CHECKING: - # pylint: disable=ungrouped-imports - from typing import Optional - # TODO: Or should this be called as_dataframe? def to_dataframe( diff --git a/sdks/python/apache_beam/dataframe/partitionings.py b/sdks/python/apache_beam/dataframe/partitionings.py index ca37c504334b..0ff09e111480 100644 --- a/sdks/python/apache_beam/dataframe/partitionings.py +++ b/sdks/python/apache_beam/dataframe/partitionings.py @@ -32,7 +32,7 @@ class Partitioning(object): def __repr__(self): return self.__class__.__name__ - def is_subpartitioning_of(self, other: Partitioning) -> bool: + def is_subpartitioning_of(self, other: 'Partitioning') -> bool: """Returns whether self is a sub-partition of other. Specifically, returns whether something partitioned by self is necissarily diff --git a/sdks/python/apache_beam/internal/metrics/cells.py b/sdks/python/apache_beam/internal/metrics/cells.py index 9a5f8c1f3113..c7b546258a70 100644 --- a/sdks/python/apache_beam/internal/metrics/cells.py +++ b/sdks/python/apache_beam/internal/metrics/cells.py @@ -55,7 +55,7 @@ def __init__(self, bucket_type): def reset(self): self.data = HistogramAggregator(self._bucket_type).identity_element() - def combine(self, other: HistogramCell) -> HistogramCell: + def combine(self, other: 'HistogramCell') -> 'HistogramCell': result = HistogramCell(self._bucket_type) result.data = self.data.combine(other.data) return result @@ -63,7 +63,7 @@ def combine(self, other: HistogramCell) -> HistogramCell: def update(self, value): self.data.histogram.record(value) - def get_cumulative(self) -> HistogramData: + def get_cumulative(self) -> 'HistogramData': return self.data.get_cumulative() def to_runner_api_monitoring_info(self, name, transform_id): @@ -90,7 +90,7 @@ def __hash__(self): class HistogramResult(object): - def __init__(self, data: HistogramData) -> None: + def __init__(self, data: 'HistogramData') -> None: self.data = data def __eq__(self, other): @@ -139,10 +139,10 @@ def __hash__(self): def __repr__(self): return 'HistogramData({})'.format(self.histogram.get_percentile_info()) - def get_cumulative(self) -> HistogramData: + def get_cumulative(self) -> 'HistogramData': return HistogramData(self.histogram) - def combine(self, other: Optional[HistogramData]) -> HistogramData: + def combine(self, other: Optional['HistogramData']) -> 'HistogramData': if other is None: return self @@ -156,7 +156,7 @@ class HistogramAggregator(MetricAggregator): Values aggregated should be ``HistogramData`` objects. """ - def __init__(self, bucket_type: BucketType) -> None: + def __init__(self, bucket_type: 'BucketType') -> None: self._bucket_type = bucket_type def identity_element(self) -> HistogramData: diff --git a/sdks/python/apache_beam/internal/metrics/metric.py b/sdks/python/apache_beam/internal/metrics/metric.py index 35a5b4f3bc6a..8acf800ff8c6 100644 --- a/sdks/python/apache_beam/internal/metrics/metric.py +++ b/sdks/python/apache_beam/internal/metrics/metric.py @@ -86,8 +86,8 @@ def counter( def histogram( namespace: Union[Type, str], name: str, - bucket_type: BucketType, - logger: Optional[MetricLogger] = None) -> Metrics.DelegatingHistogram: + bucket_type: 'BucketType', + logger: Optional['MetricLogger'] = None) -> 'Metrics.DelegatingHistogram': """Obtains or creates a Histogram metric. 
Args: @@ -109,8 +109,8 @@ class DelegatingHistogram(Histogram): def __init__( self, metric_name: MetricName, - bucket_type: BucketType, - logger: Optional[MetricLogger]) -> None: + bucket_type: 'BucketType', + logger: Optional['MetricLogger']) -> None: super().__init__(metric_name) self.metric_name = metric_name self.cell_type = HistogramCellFactory(bucket_type) @@ -126,14 +126,14 @@ def update(self, value: object) -> None: class MetricLogger(object): """Simple object to locally aggregate and log metrics.""" def __init__(self) -> None: - self._metric: Dict[MetricName, MetricCell] = {} + self._metric: Dict[MetricName, 'MetricCell'] = {} self._lock = threading.Lock() self._last_logging_millis = int(time.time() * 1000) self.minimum_logging_frequency_msec = 180000 def update( self, - cell_type: Union[Type[MetricCell], MetricCellFactory], + cell_type: Union[Type['MetricCell'], 'MetricCellFactory'], metric_name: MetricName, value: object) -> None: cell = self._get_metric_cell(cell_type, metric_name) @@ -141,8 +141,8 @@ def update( def _get_metric_cell( self, - cell_type: Union[Type[MetricCell], MetricCellFactory], - metric_name: MetricName) -> MetricCell: + cell_type: Union[Type['MetricCell'], 'MetricCellFactory'], + metric_name: MetricName) -> 'MetricCell': with self._lock: if metric_name not in self._metric: self._metric[metric_name] = cell_type() @@ -187,7 +187,7 @@ def __init__( self.base_labels = base_labels if base_labels else {} self.request_count_urn = request_count_urn - def call(self, status: Union[int, str, HttpError]) -> None: + def call(self, status: Union[int, str, 'HttpError']) -> None: """Record the status of the call into appropriate metrics.""" canonical_status = self.convert_to_canonical_status_string(status) additional_labels = {monitoring_infos.STATUS_LABEL: canonical_status} @@ -200,7 +200,7 @@ def call(self, status: Union[int, str, HttpError]) -> None: request_counter.inc() def convert_to_canonical_status_string( - self, status: Union[int, str, HttpError]) -> str: + self, status: Union[int, str, 'HttpError']) -> str: """Converts a status to a canonical GCP status cdoe string.""" http_status_code = None if isinstance(status, int): diff --git a/sdks/python/apache_beam/internal/module_test.py b/sdks/python/apache_beam/internal/module_test.py index 55a178b93f82..89e7d7eaa821 100644 --- a/sdks/python/apache_beam/internal/module_test.py +++ b/sdks/python/apache_beam/internal/module_test.py @@ -21,7 +21,6 @@ import re import sys -from typing import Type class TopClass(object): @@ -64,7 +63,7 @@ def get(self): class RecursiveClass(object): """A class that contains a reference to itself.""" - SELF_TYPE: Type[RecursiveClass] = None + SELF_TYPE = None def __init__(self, datum): self.datum = 'RecursiveClass:%s' % datum diff --git a/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py b/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py index d18440a50947..c446c17247d7 100644 --- a/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py +++ b/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py @@ -150,9 +150,7 @@ def create( self, path, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO) -> BinaryIO: - # noqa: F821 - + compression_type=CompressionTypes.AUTO): """Returns a write channel for the given file path. 
Args: @@ -168,9 +166,7 @@ def open( self, path, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO) -> BinaryIO: - # noqa: F821 - + compression_type=CompressionTypes.AUTO): """Returns a read channel for the given file path. Args: diff --git a/sdks/python/apache_beam/io/gcp/bigquery_avro_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_avro_tools.py index c54ba74a7343..b6c177fc7418 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_avro_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_avro_tools.py @@ -23,6 +23,9 @@ NOTHING IN THIS FILE HAS BACKWARDS COMPATIBILITY GUARANTEES. """ +from typing import Any +from typing import Dict + # BigQuery types as listed in # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types # with aliases (RECORD, BOOLEAN, FLOAT, INTEGER) as defined in @@ -63,20 +66,20 @@ def get_record_schema_from_dict_table_schema( - schema_name: Text, - table_schema: Dict[Text, Any], - namespace: Text = "apache_beam.io.gcp.bigquery") -> Dict[Text, Any]: + schema_name: str, + table_schema: Dict[str, Any], + namespace: str = "apache_beam.io.gcp.bigquery") -> Dict[str, Any]: # noqa: F821 """Convert a table schema into an Avro schema. Args: - schema_name (Text): The name of the record. - table_schema (Dict[Text, Any]): A BigQuery table schema in dict form. - namespace (Text): The namespace of the Avro schema. + schema_name (str): The name of the record. + table_schema (Dict[str, Any]): A BigQuery table schema in dict form. + namespace (str): The namespace of the Avro schema. Returns: - Dict[Text, Any]: The schema as an Avro RecordSchema. + Dict[str, Any]: The schema as an Avro RecordSchema. """ avro_fields = [ table_field_to_avro_field(field, ".".join((namespace, schema_name))) @@ -92,17 +95,17 @@ def get_record_schema_from_dict_table_schema( } -def table_field_to_avro_field(table_field: Dict[Text, Any], - namespace: str) -> Dict[Text, Any]: +def table_field_to_avro_field(table_field: Dict[str, Any], + namespace: str) -> Dict[str, Any]: # noqa: F821 """Convert a BigQuery field to an avro field. Args: - table_field (Dict[Text, Any]): A BigQuery field in dict form. + table_field (Dict[str, Any]): A BigQuery field in dict form. Returns: - Dict[Text, Any]: An equivalent Avro field in dict form. + Dict[str, Any]: An equivalent Avro field in dict form. """ assert "type" in table_field, \ "Unable to get type for table field {}".format(table_field) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py index 4ba8e2b84bad..beb373a7dea3 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py @@ -54,7 +54,7 @@ def generate_user_type_from_bq_schema( - the_table_schema, selected_fields: bigquery.TableSchema = None) -> type: + the_table_schema, selected_fields: 'bigquery.TableSchema' = None) -> type: """Convert a schema of type TableSchema into a pcollection element. 
Args: the_table_schema: A BQ schema of type TableSchema diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py index 2029886f24a9..f7ce69099ca0 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py @@ -25,7 +25,6 @@ from typing import Iterable from typing import List from typing import Optional -from typing import Text from typing import Union from google.cloud.datastore import entity @@ -155,10 +154,10 @@ def __repr__(self): class Key(object): def __init__( self, - path_elements: List[Union[Text, int]], - parent: Optional[Key] = None, - project: Optional[Text] = None, - namespace: Optional[Text] = None): + path_elements: List[Union[str, int]], + parent: Optional['Key'] = None, + project: Optional[str] = None, + namespace: Optional[str] = None): """ Represents a Datastore key. diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py index 6267837269e1..1c6cf31a48ce 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub.py +++ b/sdks/python/apache_beam/io/gcp/pubsub.py @@ -110,7 +110,7 @@ def __repr__(self): return 'PubsubMessage(%s, %s)' % (self.data, self.attributes) @staticmethod - def _from_proto_str(proto_msg: bytes) -> PubsubMessage: + def _from_proto_str(proto_msg: bytes) -> 'PubsubMessage': """Construct from serialized form of ``PubsubMessage``. Args: @@ -183,7 +183,7 @@ def _to_proto_str(self, for_publish=False): return serialized @staticmethod - def _from_message(msg: Any) -> PubsubMessage: + def _from_message(msg: Any) -> 'PubsubMessage': """Construct from ``google.cloud.pubsub_v1.subscriber.message.Message``. https://googleapis.github.io/google-cloud-python/latest/pubsub/subscriber/api/message.html diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py index 1b416d37f8a4..53215275e050 100644 --- a/sdks/python/apache_beam/io/iobase.py +++ b/sdks/python/apache_beam/io/iobase.py @@ -182,7 +182,7 @@ def get_range_tracker( self, start_position: Optional[Any], stop_position: Optional[Any], - ) -> RangeTracker: + ) -> 'RangeTracker': """Returns a RangeTracker for a given position range. Framework may invoke ``read()`` method with the RangeTracker object returned @@ -1281,7 +1281,7 @@ def current_restriction(self): """ raise NotImplementedError - def current_progress(self) -> RestrictionProgress: + def current_progress(self) -> 'RestrictionProgress': """Returns a RestrictionProgress object representing the current progress. This API is recommended to be implemented. 
The runner can do a better job @@ -1471,7 +1471,7 @@ def fraction_remaining(self) -> float: else: return float(self._remaining) / self.total_work - def with_completed(self, completed: int) -> RestrictionProgress: + def with_completed(self, completed: int) -> 'RestrictionProgress': return RestrictionProgress( fraction=self._fraction, remaining=self._remaining, completed=completed) diff --git a/sdks/python/apache_beam/io/restriction_trackers.py b/sdks/python/apache_beam/io/restriction_trackers.py index 37d902aa5f3f..4b819e87a8d6 100644 --- a/sdks/python/apache_beam/io/restriction_trackers.py +++ b/sdks/python/apache_beam/io/restriction_trackers.py @@ -62,7 +62,7 @@ def split(self, desired_num_offsets_per_split, min_num_offsets_per_split=1): yield OffsetRange(current_split_start, current_split_stop) current_split_start = current_split_stop - def split_at(self, split_pos) -> Tuple[OffsetRange, OffsetRange]: + def split_at(self, split_pos) -> Tuple['OffsetRange', 'OffsetRange']: return OffsetRange(self.start, split_pos), OffsetRange(split_pos, self.stop) def new_tracker(self): diff --git a/sdks/python/apache_beam/metrics/metric.py b/sdks/python/apache_beam/metrics/metric.py index c107e55fcd89..3722af6dc17a 100644 --- a/sdks/python/apache_beam/metrics/metric.py +++ b/sdks/python/apache_beam/metrics/metric.py @@ -67,7 +67,7 @@ def get_namespace(namespace: Union[Type, str]) -> str: @staticmethod def counter( - namespace: Union[Type, str], name: str) -> Metrics.DelegatingCounter: + namespace: Union[Type, str], name: str) -> 'Metrics.DelegatingCounter': """Obtains or creates a Counter metric. Args: @@ -82,7 +82,8 @@ def counter( @staticmethod def distribution( - namespace: Union[Type, str], name: str) -> Metrics.DelegatingDistribution: + namespace: Union[Type, str], + name: str) -> 'Metrics.DelegatingDistribution': """Obtains or creates a Distribution metric. Distribution metrics are restricted to integer-only distributions. @@ -98,7 +99,8 @@ def distribution( return Metrics.DelegatingDistribution(MetricName(namespace, name)) @staticmethod - def gauge(namespace: Union[Type, str], name: str) -> Metrics.DelegatingGauge: + def gauge( + namespace: Union[Type, str], name: str) -> 'Metrics.DelegatingGauge': """Obtains or creates a Gauge metric. Gauge metrics are restricted to integer-only values. 
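Usage sketch for the metric accessors typed in this hunk (illustrative; the DoFn and metric names are made up, and only the public Metrics API is assumed):

import apache_beam as beam
from apache_beam.metrics import Metrics


class CountShortWords(beam.DoFn):
  def __init__(self):
    super().__init__()
    # Both accessors take a namespace (a class or a string) and a name,
    # matching the signatures above.
    self.short_words = Metrics.counter(self.__class__, 'short_words')
    self.word_len = Metrics.distribution(self.__class__, 'word_len')

  def process(self, word):
    self.word_len.update(len(word))
    if len(word) <= 3:
      self.short_words.inc()
    yield word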
@@ -143,7 +145,7 @@ class MetricResults(object): GAUGES = "gauges" @staticmethod - def _matches_name(filter: MetricsFilter, metric_key: MetricKey) -> bool: + def _matches_name(filter: 'MetricsFilter', metric_key: 'MetricKey') -> bool: if ((filter.namespaces and metric_key.metric.namespace not in filter.namespaces) or (filter.names and metric_key.metric.name not in filter.names)): @@ -172,7 +174,7 @@ def _matches_sub_path(actual_scope: str, filter_scope: str) -> bool: filter_scope.split('/'), actual_scope.split('/'))) @staticmethod - def _matches_scope(filter: MetricsFilter, metric_key: MetricKey) -> bool: + def _matches_scope(filter: 'MetricsFilter', metric_key: 'MetricKey') -> bool: if not filter.steps: return True @@ -183,7 +185,8 @@ def _matches_scope(filter: MetricsFilter, metric_key: MetricKey) -> bool: return False @staticmethod - def matches(filter: Optional[MetricsFilter], metric_key: MetricKey) -> bool: + def matches( + filter: Optional['MetricsFilter'], metric_key: 'MetricKey') -> bool: if filter is None: return True @@ -194,7 +197,8 @@ def matches(filter: Optional[MetricsFilter], metric_key: MetricKey) -> bool: def query( self, - filter: Optional[MetricsFilter] = None) -> Dict[str, List[MetricResults]]: + filter: Optional['MetricsFilter'] = None + ) -> Dict[str, List['MetricResults']]: """Queries the runner for existing user metrics that match the filter. It should return a dictionary, with lists of each kind of metric, and @@ -239,36 +243,36 @@ def names(self) -> FrozenSet[str]: def namespaces(self) -> FrozenSet[str]: return frozenset(self._namespaces) - def with_metric(self, metric: Metric) -> MetricsFilter: + def with_metric(self, metric: 'Metric') -> 'MetricsFilter': name = metric.metric_name.name or '' namespace = metric.metric_name.namespace or '' return self.with_name(name).with_namespace(namespace) - def with_name(self, name: str) -> MetricsFilter: + def with_name(self, name: str) -> 'MetricsFilter': return self.with_names([name]) - def with_names(self, names: Iterable[str]) -> MetricsFilter: + def with_names(self, names: Iterable[str]) -> 'MetricsFilter': if isinstance(names, str): raise ValueError('Names must be a collection, not a string') self._names.update(names) return self - def with_namespace(self, namespace: Union[Type, str]) -> MetricsFilter: + def with_namespace(self, namespace: Union[Type, str]) -> 'MetricsFilter': return self.with_namespaces([namespace]) def with_namespaces( - self, namespaces: Iterable[Union[Type, str]]) -> MetricsFilter: + self, namespaces: Iterable[Union[Type, str]]) -> 'MetricsFilter': if isinstance(namespaces, str): raise ValueError('Namespaces must be an iterable, not a string') self._namespaces.update([Metrics.get_namespace(ns) for ns in namespaces]) return self - def with_step(self, step: str) -> MetricsFilter: + def with_step(self, step: str) -> 'MetricsFilter': return self.with_steps([step]) - def with_steps(self, steps: Iterable[str]) -> MetricsFilter: + def with_steps(self, steps: Iterable[str]) -> 'MetricsFilter': if isinstance(steps, str): raise ValueError('Steps must be an iterable, not a string') diff --git a/sdks/python/apache_beam/ml/gcp/naturallanguageml.py b/sdks/python/apache_beam/ml/gcp/naturallanguageml.py index ceeae522890c..f46b8d61639b 100644 --- a/sdks/python/apache_beam/ml/gcp/naturallanguageml.py +++ b/sdks/python/apache_beam/ml/gcp/naturallanguageml.py @@ -66,7 +66,7 @@ def __init__( self.from_gcs = from_gcs @staticmethod - def to_dict(document: Document) -> Mapping[str, Optional[str]]: + def to_dict(document: 
'Document') -> Mapping[str, Optional[str]]: if document.from_gcs: dict_repr = {'gcs_content_uri': document.content} else: diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py index 5aff1d35aa24..5a400570cf18 100644 --- a/sdks/python/apache_beam/pvalue.py +++ b/sdks/python/apache_beam/pvalue.py @@ -82,10 +82,10 @@ class PValue(object): """ def __init__( self, - pipeline: Pipeline, + pipeline: 'Pipeline', tag: Optional[str] = None, - element_type: Optional[Union[type, typehints.TypeConstraint]] = None, - windowing: Optional[Windowing] = None, + element_type: Optional[Union[type, 'typehints.TypeConstraint']] = None, + windowing: Optional['Windowing'] = None, is_bounded=True, ): """Initializes a PValue with all arguments hidden behind keyword arguments. @@ -152,7 +152,7 @@ def __hash__(self): return hash((self.tag, self.producer)) @property - def windowing(self) -> Windowing: + def windowing(self) -> 'Windowing': if not hasattr(self, '_windowing'): assert self.producer is not None and self.producer.transform is not None self._windowing = self.producer.transform.get_windowing( @@ -166,7 +166,7 @@ def __reduce_ex__(self, unused_version): return _InvalidUnpickledPCollection, () @staticmethod - def from_(pcoll: PValue, is_bounded: Optional[bool] = None) -> PCollection: + def from_(pcoll: PValue, is_bounded: Optional[bool] = None) -> 'PCollection': """Create a PCollection, using another PCollection as a starting point. Transfers relevant attributes. @@ -176,7 +176,7 @@ def from_(pcoll: PValue, is_bounded: Optional[bool] = None) -> PCollection: return PCollection(pcoll.pipeline, is_bounded=is_bounded) def to_runner_api( - self, context: PipelineContext) -> beam_runner_api_pb2.PCollection: + self, context: 'PipelineContext') -> beam_runner_api_pb2.PCollection: return beam_runner_api_pb2.PCollection( unique_name=self._unique_name(), coder_id=context.coder_id_from_element_type( @@ -196,7 +196,7 @@ def _unique_name(self) -> str: @staticmethod def from_runner_api( proto: beam_runner_api_pb2.PCollection, - context: PipelineContext) -> PCollection: + context: 'PipelineContext') -> 'PCollection': # Producer and tag will be filled in later, the key point is that the same # object is returned for the same pcollection id. # We pass None for the PCollection's Pipeline to avoid a cycle during @@ -235,8 +235,8 @@ class DoOutputsTuple(object): """An object grouping the multiple outputs of a ParDo or FlatMap transform.""" def __init__( self, - pipeline: Pipeline, - transform: ParDo, + pipeline: 'Pipeline', + transform: 'ParDo', tags: Sequence[str], main_tag: Optional[str], allow_unknown_tags: Optional[bool] = None, @@ -380,7 +380,7 @@ def _windowed_coder(self): # TODO(robertwb): Get rid of _from_runtime_iterable and _view_options # in favor of _side_input_data(). 
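Aside: a minimal pipeline sketch (illustrative, not part of the diff) of the side-input wrappers whose plumbing is annotated in this hunk; transform labels and values are arbitrary.

import apache_beam as beam


def run_sketch():
  with beam.Pipeline() as p:
    main = p | 'Main' >> beam.Create([1, 2, 3, 4])
    total = main | 'Sum' >> beam.CombineGlobally(sum)
    # AsSingleton wraps the one-element PCollection as a side input; its
    # window mapping is supplied by the SideInputData machinery below.
    _ = main | 'Scale' >> beam.Map(
        lambda x, t: x / t, t=beam.pvalue.AsSingleton(total))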
- def _side_input_data(self) -> SideInputData: + def _side_input_data(self) -> 'SideInputData': view_options = self._view_options() from_runtime_iterable = type(self)._from_runtime_iterable return SideInputData( @@ -389,13 +389,13 @@ def _side_input_data(self) -> SideInputData: lambda iterable: from_runtime_iterable(iterable, view_options)) def to_runner_api( - self, context: PipelineContext) -> beam_runner_api_pb2.SideInput: + self, context: 'PipelineContext') -> beam_runner_api_pb2.SideInput: return self._side_input_data().to_runner_api(context) @staticmethod def from_runner_api( proto: beam_runner_api_pb2.SideInput, - context: PipelineContext) -> _UnpickledSideInput: + context: 'PipelineContext') -> '_UnpickledSideInput': return _UnpickledSideInput(SideInputData.from_runner_api(proto, context)) @staticmethod @@ -407,7 +407,7 @@ def requires_keyed_input(self): class _UnpickledSideInput(AsSideInput): - def __init__(self, side_input_data: SideInputData) -> None: + def __init__(self, side_input_data: 'SideInputData') -> None: self._data = side_input_data self._window_mapping_fn = side_input_data.window_mapping_fn @@ -441,14 +441,14 @@ class SideInputData(object): def __init__( self, access_pattern: str, - window_mapping_fn: sideinputs.WindowMappingFn, + window_mapping_fn: 'sideinputs.WindowMappingFn', view_fn): self.access_pattern = access_pattern self.window_mapping_fn = window_mapping_fn self.view_fn = view_fn def to_runner_api( - self, context: PipelineContext) -> beam_runner_api_pb2.SideInput: + self, context: 'PipelineContext') -> beam_runner_api_pb2.SideInput: return beam_runner_api_pb2.SideInput( access_pattern=beam_runner_api_pb2.FunctionSpec( urn=self.access_pattern), @@ -462,7 +462,7 @@ def to_runner_api( @staticmethod def from_runner_api( proto: beam_runner_api_pb2.SideInput, - unused_context: PipelineContext) -> SideInputData: + unused_context: 'PipelineContext') -> 'SideInputData': assert proto.view_fn.urn == python_urns.PICKLED_VIEWFN assert ( proto.window_mapping_fn.urn == python_urns.PICKLED_WINDOW_MAPPING_FN) diff --git a/sdks/python/apache_beam/runners/direct/bundle_factory.py b/sdks/python/apache_beam/runners/direct/bundle_factory.py index 8553fdb50656..95d8c06111a2 100644 --- a/sdks/python/apache_beam/runners/direct/bundle_factory.py +++ b/sdks/python/apache_beam/runners/direct/bundle_factory.py @@ -45,12 +45,12 @@ def __init__(self, stacked: bool) -> None: def create_bundle( self, output_pcollection: Union[pvalue.PBegin, - pvalue.PCollection]) -> _Bundle: + pvalue.PCollection]) -> '_Bundle': return _Bundle(output_pcollection, self._stacked) def create_empty_committed_bundle( self, output_pcollection: Union[pvalue.PBegin, - pvalue.PCollection]) -> _Bundle: + pvalue.PCollection]) -> '_Bundle': bundle = self.create_bundle(output_pcollection) bundle.commit(None) return bundle diff --git a/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py b/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py index 60b0e5beeae2..91085274f32a 100644 --- a/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py +++ b/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py @@ -19,16 +19,13 @@ # pytype: skip-file -from typing import TYPE_CHECKING from typing import Dict from typing import Set from apache_beam import pvalue +from apache_beam.pipeline import AppliedPTransform from apache_beam.pipeline import PipelineVisitor -if TYPE_CHECKING: - from apache_beam.pipeline import AppliedPTransform - class 
ConsumerTrackingPipelineVisitor(PipelineVisitor): """For internal use only; no backwards-compatibility guarantees. diff --git a/sdks/python/apache_beam/runners/direct/executor.py b/sdks/python/apache_beam/runners/direct/executor.py index 8c389ffccf5d..59e282b91bd9 100644 --- a/sdks/python/apache_beam/runners/direct/executor.py +++ b/sdks/python/apache_beam/runners/direct/executor.py @@ -64,7 +64,7 @@ class _ExecutorServiceWorker(threading.Thread): TIMEOUT = 5 def __init__( - self, queue: queue.Queue[_ExecutorService.CallableTask], index): + self, queue: queue.Queue['_ExecutorService.CallableTask'], index): super().__init__() self.queue = queue self._index = index @@ -84,12 +84,12 @@ def _update_name(self, task=None): self.name = 'Thread: %d, %s (%s)' % ( self._index, name, 'executing' if task else 'idle') - def _get_task_or_none(self) -> Optional[_ExecutorService.CallableTask]: + def _get_task_or_none(self) -> Optional['_ExecutorService.CallableTask']: try: # Do not block indefinitely, otherwise we may not act for a requested # shutdown. return self.queue.get( - timeout=_ExecutorService._ExecutorServiceWorker.TIMEOUT) + timeout='_ExecutorService._ExecutorServiceWorker.TIMEOUT') except queue.Empty: return None @@ -118,7 +118,7 @@ def __init__(self, num_workers): ] self.shutdown_requested = False - def submit(self, task: _ExecutorService.CallableTask) -> None: + def submit(self, task: '_ExecutorService.CallableTask') -> None: assert isinstance(task, _ExecutorService.CallableTask) if not self.shutdown_requested: self.queue.put(task) @@ -496,8 +496,8 @@ def schedule_consumption( assert on_complete if self.transform_evaluator_registry.should_execute_serially( consumer_applied_ptransform): - transform_executor_service: _TransformEvaluationState = self.transform_executor_services.serial( - consumer_applied_ptransform) + transform_executor_service: _TransformEvaluationState = ( + self.transform_executor_services.serial(consumer_applied_ptransform)) else: transform_executor_service = self.transform_executor_services.parallel() diff --git a/sdks/python/apache_beam/runners/direct/watermark_manager.py b/sdks/python/apache_beam/runners/direct/watermark_manager.py index 077f9f05e183..b4b16c0a5d21 100644 --- a/sdks/python/apache_beam/runners/direct/watermark_manager.py +++ b/sdks/python/apache_beam/runners/direct/watermark_manager.py @@ -28,16 +28,16 @@ from typing import Tuple from apache_beam import pipeline +from apache_beam.pipeline import AppliedPTransform from apache_beam import pvalue from apache_beam.runners.direct.util import TimerFiring from apache_beam.utils.timestamp import MAX_TIMESTAMP from apache_beam.utils.timestamp import MIN_TIMESTAMP from apache_beam.utils.timestamp import TIME_GRANULARITY +from apache_beam.utils.timestamp import Timestamp if TYPE_CHECKING: - from apache_beam.pipeline import AppliedPTransform from apache_beam.runners.direct.bundle_factory import _Bundle - from apache_beam.utils.timestamp import Timestamp class WatermarkManager(object): @@ -108,7 +108,7 @@ def get_watermarks( def update_watermarks( self, - completed_committed_bundle: _Bundle, + completed_committed_bundle: '_Bundle', applied_ptransform: AppliedPTransform, completed_timers, outputs, @@ -131,8 +131,8 @@ def _update_pending( input_committed_bundle, applied_ptransform: AppliedPTransform, completed_timers, - output_committed_bundles: Iterable[_Bundle], - unprocessed_bundles: Iterable[_Bundle]): + output_committed_bundles: Iterable['_Bundle'], + unprocessed_bundles: Iterable['_Bundle']): """Updated list of 
pending bundles for the given AppliedPTransform.""" # Update pending elements. Filter out empty bundles. They do not impact @@ -206,7 +206,7 @@ def __init__(self, clock, keyed_states, transform): self._output_watermark = WatermarkManager.WATERMARK_NEG_INF self._keyed_earliest_holds = {} # Scheduled bundles targeted for this transform. - self._pending: Set[_Bundle] = set() + self._pending: Set['_Bundle'] = set() self._fired_timers = set() self._lock = threading.Lock() @@ -240,11 +240,11 @@ def hold(self, keyed_earliest_holds): hold_value == WatermarkManager.WATERMARK_POS_INF): del self._keyed_earliest_holds[key] - def add_pending(self, pending: _Bundle) -> None: + def add_pending(self, pending: '_Bundle') -> None: with self._lock: self._pending.add(pending) - def remove_pending(self, completed: _Bundle) -> None: + def remove_pending(self, completed: '_Bundle') -> None: with self._lock: # Ignore repeated removes. This will happen if a transform has a repeated # input. diff --git a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py index a09205e23401..ad46f5d65ea3 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py +++ b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py @@ -46,7 +46,7 @@ def option(cls) -> str: raise NotImplementedError @abc.abstractmethod - def render_pipeline_graph(self, pipeline_graph: PipelineGraph) -> str: + def render_pipeline_graph(self, pipeline_graph: 'PipelineGraph') -> str: """Renders the pipeline graph in HTML-compatible format. Args: @@ -65,7 +65,7 @@ class MuteRenderer(PipelineGraphRenderer): def option(cls) -> str: return 'mute' - def render_pipeline_graph(self, pipeline_graph: PipelineGraph) -> str: + def render_pipeline_graph(self, pipeline_graph: 'PipelineGraph') -> str: return '' @@ -76,7 +76,7 @@ class TextRenderer(PipelineGraphRenderer): def option(cls) -> str: return 'text' - def render_pipeline_graph(self, pipeline_graph: PipelineGraph) -> str: + def render_pipeline_graph(self, pipeline_graph: 'PipelineGraph') -> str: return pipeline_graph.get_dot() @@ -91,7 +91,7 @@ class PydotRenderer(PipelineGraphRenderer): def option(cls) -> str: return 'graph' - def render_pipeline_graph(self, pipeline_graph: PipelineGraph) -> str: + def render_pipeline_graph(self, pipeline_graph: 'PipelineGraph') -> str: return pipeline_graph._get_graph().create_svg().decode("utf-8") # pylint: disable=protected-access diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py index 44961241dc15..6af40e10332d 100644 --- a/sdks/python/apache_beam/runners/pipeline_context.py +++ b/sdks/python/apache_beam/runners/pipeline_context.py @@ -23,7 +23,6 @@ # pytype: skip-file # mypy: disallow-untyped-defs -from typing import TYPE_CHECKING from typing import Any from typing import Dict from typing import FrozenSet @@ -35,9 +34,12 @@ from typing import TypeVar from typing import Union +from google.protobuf import message from typing_extensions import Protocol from apache_beam import coders +from apache_beam.coders.coder_impl import IterableStateReader +from apache_beam.coders.coder_impl import IterableStateWriter from apache_beam import pipeline from apache_beam import pvalue from apache_beam.internal import pickler @@ -49,21 +51,15 @@ from apache_beam.transforms.resources import merge_resource_hints from apache_beam.typehints import native_type_compatibility 
-if TYPE_CHECKING: - from google.protobuf import message # pylint: disable=ungrouped-imports - from apache_beam.coders.coder_impl import IterableStateReader - from apache_beam.coders.coder_impl import IterableStateWriter - from apache_beam.transforms import ptransform - PortableObjectT = TypeVar('PortableObjectT', bound='PortableObject') class PortableObject(Protocol): - def to_runner_api(self, __context: PipelineContext) -> Any: + def to_runner_api(self, __context: 'PipelineContext') -> Any: pass @classmethod - def from_runner_api(cls, __proto: Any, __context: PipelineContext) -> Any: + def from_runner_api(cls, __proto: Any, __context: 'PipelineContext') -> Any: pass @@ -75,7 +71,7 @@ class _PipelineContextMap(Generic[PortableObjectT]): """ def __init__( self, - context: PipelineContext, + context: 'PipelineContext', obj_type: Type[PortableObjectT], namespace: str, proto_map: Optional[Mapping[str, message.Message]] = None) -> None: @@ -271,7 +267,8 @@ def element_type_from_coder_id(self, coder_id: str) -> Any: self.coders[coder_id].to_type_hint()) @staticmethod - def from_runner_api(proto: beam_runner_api_pb2.Components) -> PipelineContext: + def from_runner_api( + proto: beam_runner_api_pb2.Components) -> 'PipelineContext': return PipelineContext(proto) def to_runner_api(self) -> beam_runner_api_pb2.Components: diff --git a/sdks/python/apache_beam/runners/portability/abstract_job_service.py b/sdks/python/apache_beam/runners/portability/abstract_job_service.py index 09c388f3b6a3..87162d5feda5 100644 --- a/sdks/python/apache_beam/runners/portability/abstract_job_service.py +++ b/sdks/python/apache_beam/runners/portability/abstract_job_service.py @@ -25,7 +25,7 @@ import uuid import zipfile from concurrent import futures -from typing import TYPE_CHECKING +from typing import BinaryIO from typing import Dict from typing import Iterator from typing import Optional @@ -34,21 +34,17 @@ import grpc from google.protobuf import json_format +from google.protobuf import struct_pb2 from google.protobuf import timestamp_pb2 from apache_beam.portability.api import beam_artifact_api_pb2_grpc from apache_beam.portability.api import beam_job_api_pb2 from apache_beam.portability.api import beam_job_api_pb2_grpc +from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.portability.api import endpoints_pb2 from apache_beam.runners.portability import artifact_service from apache_beam.utils.timestamp import Timestamp -if TYPE_CHECKING: - # pylint: disable=ungrouped-imports - from typing import BinaryIO - from google.protobuf import struct_pb2 - from apache_beam.portability.api import beam_runner_api_pb2 - _LOGGER = logging.getLogger(__name__) StateEvent = Tuple[int, Union[timestamp_pb2.Timestamp, Timestamp]] @@ -81,7 +77,7 @@ def create_beam_job(self, job_name: str, pipeline: beam_runner_api_pb2.Pipeline, options: struct_pb2.Struct - ) -> AbstractBeamJob: + ) -> 'AbstractBeamJob': """Returns an instance of AbstractBeamJob specific to this servicer.""" raise NotImplementedError(type(self)) @@ -208,7 +204,7 @@ def prepare(self) -> None: def run(self) -> None: raise NotImplementedError(self) - def cancel(self) -> Optional[beam_job_api_pb2.JobState.Enum]: + def cancel(self) -> Optional['beam_job_api_pb2.JobState.Enum']: raise NotImplementedError(self) def artifact_staging_endpoint( diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py index 3c16cb7cf99d..4976ff78aa99 100644 --- 
a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py @@ -50,6 +50,7 @@ from apache_beam.coders.coder_impl import CoderImpl from apache_beam.coders.coder_impl import create_InputStream from apache_beam.coders.coder_impl import create_OutputStream +from apache_beam.coders.coder_impl import WindowedValueCoderImpl from apache_beam.coders.coders import WindowedValueCoder from apache_beam.portability import common_urns from apache_beam.portability import python_urns @@ -59,6 +60,7 @@ from apache_beam.runners.common import ENCODED_IMPULSE_VALUE from apache_beam.runners.direct.clock import RealClock from apache_beam.runners.direct.clock import TestClock +from apache_beam.portability.api import endpoints_pb2 from apache_beam.runners.portability.fn_api_runner import translations from apache_beam.runners.portability.fn_api_runner.translations import DataInput from apache_beam.runners.portability.fn_api_runner.translations import DataOutput @@ -73,6 +75,7 @@ from apache_beam.transforms import core from apache_beam.transforms import trigger from apache_beam.transforms import window +from apache_beam.transforms.window import BoundedWindow from apache_beam.transforms.window import GlobalWindow from apache_beam.transforms.window import GlobalWindows from apache_beam.utils import proto_utils @@ -81,12 +84,8 @@ from apache_beam.utils.timestamp import Timestamp if TYPE_CHECKING: - from apache_beam.coders.coder_impl import WindowedValueCoderImpl - from apache_beam.portability.api import endpoints_pb2 from apache_beam.runners.portability.fn_api_runner import worker_handlers from apache_beam.runners.portability.fn_api_runner.translations import DataSideInput - from apache_beam.runners.portability.fn_api_runner.translations import TimerFamilyId - from apache_beam.transforms.window import BoundedWindow _LOGGER = logging.getLogger(__name__) @@ -318,8 +317,8 @@ def __init__( def append(self, elements_data: bytes) -> None: input_stream = create_InputStream(elements_data) while input_stream.size() > 0: - windowed_val_coder_impl: WindowedValueCoderImpl = self._windowed_value_coder.get_impl( - ) + windowed_val_coder_impl: WindowedValueCoderImpl = ( + self._windowed_value_coder.get_impl()) windowed_value = windowed_val_coder_impl.decode_from_stream( input_stream, True) key, value = self._kv_extractor(windowed_value.value) @@ -356,7 +355,7 @@ def get_window_coder(self) -> coders.Coder: @staticmethod @window.urns.RunnerApiFn.register_urn(URN, bytes) def from_runner_api_parameter( - window_coder_id: bytes, context: Any) -> GenericNonMergingWindowFn: + window_coder_id: bytes, context: Any) -> 'GenericNonMergingWindowFn': return GenericNonMergingWindowFn( context.coders[window_coder_id.decode('utf-8')]) @@ -451,11 +450,11 @@ class GenericMergingWindowFn(window.WindowFn): TO_SDK_TRANSFORM = 'read' FROM_SDK_TRANSFORM = 'write' - _HANDLES: Dict[str, GenericMergingWindowFn] = {} + _HANDLES: Dict[str, 'GenericMergingWindowFn'] = {} def __init__( self, - execution_context: FnApiRunnerExecutionContext, + execution_context: 'FnApiRunnerExecutionContext', windowing_strategy_proto: beam_runner_api_pb2.WindowingStrategy) -> None: self._worker_handler: Optional[worker_handlers.WorkerHandler] = None self._handle_id = handle_id = uuid.uuid4().hex @@ -473,7 +472,7 @@ def __init__( self.windowed_input_coder_impl: Optional[CoderImpl] = None self.windowed_output_coder_impl: Optional[CoderImpl] = None - def _execution_context_ref(self) -> 
FnApiRunnerExecutionContext: + def _execution_context_ref(self) -> 'FnApiRunnerExecutionContext': result = self._execution_context_ref_obj() assert result is not None return result @@ -484,7 +483,7 @@ def payload(self) -> bytes: @staticmethod @window.urns.RunnerApiFn.register_urn(URN, bytes) def from_runner_api_parameter( - handle_id: bytes, unused_context: Any) -> GenericMergingWindowFn: + handle_id: bytes, unused_context: Any) -> 'GenericMergingWindowFn': return GenericMergingWindowFn._HANDLES[handle_id.decode('utf-8')] def assign( @@ -531,7 +530,7 @@ def get_window_coder(self) -> coders.Coder: return self._execution_context_ref().pipeline_context.coders[ self._windowing_strategy_proto.window_coder_id] - def worker_handle(self) -> worker_handlers.WorkerHandler: + def worker_handle(self) -> 'worker_handlers.WorkerHandler': if self._worker_handler is None: worker_handler_manager = self._execution_context_ref( ).worker_handler_manager @@ -665,7 +664,7 @@ class FnApiRunnerExecutionContext(object): def __init__( self, stages: List[translations.Stage], - worker_handler_manager: worker_handlers.WorkerHandlerManager, + worker_handler_manager: 'worker_handlers.WorkerHandlerManager', pipeline_components: beam_runner_api_pb2.Components, safe_coders: translations.SafeCoderMapping, data_channel_coders: Dict[str, str], @@ -827,7 +826,7 @@ def _enqueue_stage_initial_inputs(self, stage: Stage) -> None: @staticmethod def _build_data_side_inputs_map( stages: Iterable[translations.Stage] - ) -> MutableMapping[str, DataSideInput]: + ) -> MutableMapping[str, 'DataSideInput']: """Builds an index mapping stages to side input descriptors. A side input descriptor is a map of side input IDs to side input access @@ -910,7 +909,7 @@ def _make_safe_windowing_strategy(self, id: str) -> str: return safe_id @property - def state_servicer(self) -> worker_handlers.StateServicer: + def state_servicer(self) -> 'worker_handlers.StateServicer': # TODO(BEAM-9625): Ensure FnApiRunnerExecutionContext owns StateServicer return self.worker_handler_manager.state_servicer @@ -932,7 +931,7 @@ def _iterable_state_write( def commit_side_inputs_to_state( self, - data_side_input: DataSideInput, + data_side_input: 'DataSideInput', ) -> None: for (consuming_transform_id, tag), (buffer_id, func_spec) in data_side_input.items(): @@ -1033,7 +1032,7 @@ def _compute_expected_outputs(self) -> None: create_buffer_id(timer_family_id, 'timers'), time_domain) @property - def worker_handlers(self) -> List[worker_handlers.WorkerHandler]: + def worker_handlers(self) -> List['worker_handlers.WorkerHandler']: if self._worker_handlers is None: self._worker_handlers = ( self.execution_context.worker_handler_manager.get_worker_handlers( diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py index d15e04e5f238..8b313d624a52 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py @@ -31,7 +31,6 @@ import sys import threading import time -from typing import TYPE_CHECKING from typing import Callable from typing import Dict from typing import Iterable @@ -55,11 +54,13 @@ from apache_beam.metrics.monitoring_infos import consolidate as consolidate_monitoring_infos from apache_beam.options import pipeline_options from apache_beam.options.value_provider import RuntimeValueProvider +from apache_beam.pipeline import Pipeline from apache_beam.portability import 
common_urns from apache_beam.portability import python_urns from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.portability.api import beam_provision_api_pb2 from apache_beam.portability.api import beam_runner_api_pb2 +from apache_beam.portability.api import metrics_pb2 from apache_beam.runners import runner from apache_beam.runners.common import group_by_key_input_visitor from apache_beam.runners.common import merge_common_environments @@ -75,6 +76,7 @@ from apache_beam.runners.portability.fn_api_runner.translations import OutputTimers from apache_beam.runners.portability.fn_api_runner.translations import create_buffer_id from apache_beam.runners.portability.fn_api_runner.translations import only_element +from apache_beam.runners.portability.fn_api_runner.worker_handlers import WorkerHandler from apache_beam.runners.portability.fn_api_runner.worker_handlers import WorkerHandlerManager from apache_beam.runners.worker import bundle_processor from apache_beam.transforms import environments @@ -83,11 +85,6 @@ from apache_beam.utils import timestamp from apache_beam.utils.profiler import Profile -if TYPE_CHECKING: - from apache_beam.pipeline import Pipeline - from apache_beam.portability.api import metrics_pb2 - from apache_beam.runners.portability.fn_api_runner.worker_handlers import WorkerHandler - _LOGGER = logging.getLogger(__name__) _BUNDLE_LOGGER = logging.getLogger(__name__ + ".run_bundle") @@ -105,7 +102,7 @@ def __init__( default_environment: Optional[environments.Environment] = None, bundle_repeat: int = 0, use_state_iterables: bool = False, - provision_info: Optional[ExtendedProvisionInfo] = None, + provision_info: Optional['ExtendedProvisionInfo'] = None, progress_request_frequency: Optional[float] = None, is_drain: bool = False) -> None: """Creates a new Fn API Runner. @@ -144,7 +141,7 @@ def supported_requirements() -> Tuple[str, ...]: def run_pipeline( self, pipeline: Pipeline, - options: pipeline_options.PipelineOptions) -> RunnerResult: + options: pipeline_options.PipelineOptions) -> 'RunnerResult': RuntimeValueProvider.set_runtime_options({}) # Setup "beam_fn_api" experiment options if lacked. @@ -203,7 +200,7 @@ def run_pipeline( def run_via_runner_api( self, pipeline_proto: beam_runner_api_pb2.Pipeline, - options: pipeline_options.PipelineOptions) -> RunnerResult: + options: pipeline_options.PipelineOptions) -> 'RunnerResult': validate_pipeline_graph(pipeline_proto) self._validate_requirements(pipeline_proto) self._check_requirements(pipeline_proto) @@ -410,7 +407,7 @@ def create_stages( def run_stages( self, stage_context: translations.TransformContext, - stages: List[translations.Stage]) -> RunnerResult: + stages: List[translations.Stage]) -> 'RunnerResult': """Run a list of topologically-sorted stages in batch mode. 
Args: @@ -582,7 +579,7 @@ def _schedule_ready_bundles( def _run_bundle_multiple_times_for_testing( self, runner_execution_context: execution.FnApiRunnerExecutionContext, - bundle_manager: BundleManager, + bundle_manager: 'BundleManager', data_input: MutableMapping[str, execution.PartitionableBuffer], data_output: DataOutput, fired_timers: Mapping[translations.TimerFamilyId, @@ -1016,7 +1013,7 @@ def _run_bundle( bundle_input: DataInput, data_output: DataOutput, expected_timer_output: OutputTimers, - bundle_manager: BundleManager + bundle_manager: 'BundleManager' ) -> Tuple[beam_fn_api_pb2.InstructionResponse, Dict[str, execution.PartitionableBuffer], OutputTimerData, @@ -1091,7 +1088,7 @@ class StaticGenerator(object): def __init__(self) -> None: self._token = generate_token(1) - def __iter__(self) -> StaticGenerator: + def __iter__(self) -> 'StaticGenerator': # pylint: disable=non-iterator-returned return self @@ -1103,7 +1100,7 @@ def __init__(self) -> None: self._counter = 0 self._lock = threading.Lock() - def __iter__(self) -> DynamicGenerator: + def __iter__(self) -> 'DynamicGenerator': # pylint: disable=non-iterator-returned return self @@ -1130,7 +1127,7 @@ def __init__( self.artifact_staging_dir = artifact_staging_dir self.job_name = job_name - def for_environment(self, env) -> ExtendedProvisionInfo: + def for_environment(self, env) -> 'ExtendedProvisionInfo': if env.dependencies: provision_info_with_deps = copy.deepcopy(self.provision_info) provision_info_with_deps.dependencies.extend(env.dependencies) @@ -1210,7 +1207,8 @@ def __init__( Args: progress_frequency """ - self.bundle_context_manager: execution.BundleContextManager = bundle_context_manager + self.bundle_context_manager: execution.BundleContextManager = ( + bundle_context_manager) self._progress_frequency = progress_frequency self._worker_handler: Optional[WorkerHandler] = None self._cache_token_generator = cache_token_generator @@ -1291,8 +1289,8 @@ def _generate_splits_for_testing( estimated_input_elements=num_elements) })) logging.info("Requesting split %s", split_request) - split_response: beam_fn_api_pb2.InstructionResponse = self._worker_handler.control_conn.push( - split_request).get() + split_response: beam_fn_api_pb2.InstructionResponse = ( + self._worker_handler.control_conn.push(split_request).get()) for t in (0.05, 0.1, 0.2): if ('Unknown process bundle' in split_response.error or split_response.process_bundle_split == diff --git a/sdks/python/apache_beam/runners/portability/job_server.py b/sdks/python/apache_beam/runners/portability/job_server.py index 030a3b67df33..e44d8ab0ae93 100644 --- a/sdks/python/apache_beam/runners/portability/job_server.py +++ b/sdks/python/apache_beam/runners/portability/job_server.py @@ -58,7 +58,7 @@ def stop(self): class EmbeddedJobServer(JobServer): - def start(self) -> local_job_service.LocalJobServicer: + def start(self) -> 'local_job_service.LocalJobServicer': return local_job_service.LocalJobServicer() def stop(self): diff --git a/sdks/python/apache_beam/runners/portability/local_job_service.py b/sdks/python/apache_beam/runners/portability/local_job_service.py index 4b6d4718f4dd..869f013d0d26 100644 --- a/sdks/python/apache_beam/runners/portability/local_job_service.py +++ b/sdks/python/apache_beam/runners/portability/local_job_service.py @@ -27,7 +27,6 @@ import threading import time import traceback -from typing import TYPE_CHECKING from typing import Any from typing import List from typing import Mapping @@ -35,6 +34,7 @@ import grpc from google.protobuf import 
json_format +from google.protobuf import struct_pb2 from google.protobuf import text_format # type: ignore # not in typeshed from apache_beam import pipeline @@ -57,9 +57,6 @@ from apache_beam.transforms import environments from apache_beam.utils import thread_pool_executor -if TYPE_CHECKING: - from google.protobuf import struct_pb2 # pylint: disable=ungrouped-imports - _LOGGER = logging.getLogger(__name__) @@ -96,7 +93,7 @@ def create_beam_job(self, job_name: str, pipeline: beam_runner_api_pb2.Pipeline, options: struct_pb2.Struct - ) -> BeamJob: + ) -> 'BeamJob': self._artifact_service.register_job( staging_token=preparation_id, dependency_sets=_extract_dependency_sets( diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py b/sdks/python/apache_beam/runners/portability/portable_runner.py index fd19c2ba2388..92f123697a9d 100644 --- a/sdks/python/apache_beam/runners/portability/portable_runner.py +++ b/sdks/python/apache_beam/runners/portability/portable_runner.py @@ -25,7 +25,6 @@ import logging import threading import time -from typing import TYPE_CHECKING from typing import Any from typing import Dict from typing import Iterator @@ -33,6 +32,7 @@ from typing import Tuple import grpc +from google.protobuf import struct_pb2 from apache_beam.metrics import metric from apache_beam.metrics.execution import MetricResult @@ -41,6 +41,7 @@ from apache_beam.options.pipeline_options import PortableOptions from apache_beam.options.pipeline_options import StandardOptions from apache_beam.options.value_provider import ValueProvider +from apache_beam.pipeline import Pipeline from apache_beam.portability import common_urns from apache_beam.portability import python_urns from apache_beam.portability.api import beam_artifact_api_pb2_grpc @@ -56,10 +57,6 @@ from apache_beam.runners.worker import worker_pool_main from apache_beam.transforms import environments -if TYPE_CHECKING: - from google.protobuf import struct_pb2 # pylint: disable=ungrouped-imports - from apache_beam.pipeline import Pipeline - __all__ = ['PortableRunner'] MESSAGE_LOG_LEVELS = { diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py index 4ba49378c8a5..78022724226a 100644 --- a/sdks/python/apache_beam/runners/runner.py +++ b/sdks/python/apache_beam/runners/runner.py @@ -55,7 +55,7 @@ _LOGGER = logging.getLogger(__name__) -def create_runner(runner_name: str) -> PipelineRunner: +def create_runner(runner_name: str) -> 'PipelineRunner': """For internal use only; no backwards-compatibility guarantees. Creates a runner instance from a runner class name. @@ -113,8 +113,8 @@ class PipelineRunner(object): """ def run( self, - transform: PTransform, - options: Optional[PipelineOptions] = None) -> PipelineResult: + transform: 'PTransform', + options: Optional[PipelineOptions] = None) -> 'PipelineResult': """Run the given transform or callable with this runner. Blocks until the pipeline is complete. See also `PipelineRunner.run_async`. @@ -125,8 +125,8 @@ def run( def run_async( self, - transform: PTransform, - options: Optional[PipelineOptions] = None) -> PipelineResult: + transform: 'PTransform', + options: Optional[PipelineOptions] = None) -> 'PipelineResult': """Run the given transform or callable with this runner. May return immediately, executing the pipeline in the background. 
@@ -164,7 +164,7 @@ def default_environment( options.view_as(PortableOptions)) def run_pipeline( - self, pipeline: Pipeline, options: PipelineOptions) -> PipelineResult: + self, pipeline: 'Pipeline', options: PipelineOptions) -> 'PipelineResult': """Execute the entire pipeline or the sub-DAG reachable from a node. """ pipeline.visit( @@ -184,8 +184,8 @@ def run_pipeline( def apply( self, - transform: PTransform, - input: Optional[pvalue.PValue], + transform: 'PTransform', + input: Optional['pvalue.PValue'], options: PipelineOptions): # TODO(robertwb): Remove indirection once internal references are fixed. return self.apply_PTransform(transform, input, options) diff --git a/sdks/python/apache_beam/runners/sdf_utils.py b/sdks/python/apache_beam/runners/sdf_utils.py index d2d8a4a3c584..01573656b6ac 100644 --- a/sdks/python/apache_beam/runners/sdf_utils.py +++ b/sdks/python/apache_beam/runners/sdf_utils.py @@ -55,7 +55,7 @@ class ThreadsafeRestrictionTracker(object): This wrapper guarantees synchronization of modifying restrictions across multi-thread. """ - def __init__(self, restriction_tracker: RestrictionTracker) -> None: + def __init__(self, restriction_tracker: 'RestrictionTracker') -> None: from apache_beam.io.iobase import RestrictionTracker if not isinstance(restriction_tracker, RestrictionTracker): raise ValueError( @@ -109,7 +109,7 @@ def check_done(self): with self._lock: return self._restriction_tracker.check_done() - def current_progress(self) -> RestrictionProgress: + def current_progress(self) -> 'RestrictionProgress': with self._lock: return self._restriction_tracker.current_progress() @@ -182,7 +182,7 @@ class ThreadsafeWatermarkEstimator(object): """A threadsafe wrapper which wraps a WatermarkEstimator with locking mechanism to guarantee multi-thread safety. """ - def __init__(self, watermark_estimator: WatermarkEstimator) -> None: + def __init__(self, watermark_estimator: 'WatermarkEstimator') -> None: from apache_beam.io.iobase import WatermarkEstimator if not isinstance(watermark_estimator, WatermarkEstimator): raise ValueError('Initializing Threadsafe requires a WatermarkEstimator') diff --git a/sdks/python/apache_beam/runners/worker/log_handler.py b/sdks/python/apache_beam/runners/worker/log_handler.py index b7cf9db757d3..88cc3c9791d5 100644 --- a/sdks/python/apache_beam/runners/worker/log_handler.py +++ b/sdks/python/apache_beam/runners/worker/log_handler.py @@ -27,7 +27,6 @@ import threading import time import traceback -from typing import TYPE_CHECKING from typing import Iterable from typing import Iterator from typing import List @@ -38,14 +37,12 @@ from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.portability.api import beam_fn_api_pb2_grpc +from apache_beam.portability.api import endpoints_pb2 from apache_beam.runners.worker import statesampler from apache_beam.runners.worker.channel_factory import GRPCChannelFactory from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor from apache_beam.utils.sentinel import Sentinel -if TYPE_CHECKING: - from apache_beam.portability.api import endpoints_pb2 - # Mapping from logging levels to LogEntry levels. 
LOG_LEVEL_TO_LOGENTRY_MAP = { logging.FATAL: beam_fn_api_pb2.LogEntry.Severity.CRITICAL, diff --git a/sdks/python/apache_beam/runners/worker/statesampler.py b/sdks/python/apache_beam/runners/worker/statesampler.py index ece31c05517a..b9c75f4de93d 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler.py +++ b/sdks/python/apache_beam/runners/worker/statesampler.py @@ -131,10 +131,10 @@ def get_info(self) -> StateSamplerInfo: def scoped_state( self, - name_context: Union[str, common.NameContext], + name_context: Union[str, 'common.NameContext'], state_name: str, io_target=None, - metrics_container: Optional[MetricsContainer] = None + metrics_container: Optional['MetricsContainer'] = None ) -> statesampler_impl.ScopedState: """Returns a ScopedState object associated to a Step and a State. diff --git a/sdks/python/apache_beam/runners/worker/statesampler_slow.py b/sdks/python/apache_beam/runners/worker/statesampler_slow.py index 33bce92d4391..be801284450a 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_slow.py +++ b/sdks/python/apache_beam/runners/worker/statesampler_slow.py @@ -31,7 +31,7 @@ def __init__(self, sampling_period_ms): self.state_transition_count = 0 self.time_since_transition = 0 - def current_state(self) -> ScopedState: + def current_state(self) -> 'ScopedState': """Returns the current execution state. This operation is not thread safe, and should only be called from the @@ -41,9 +41,9 @@ def current_state(self) -> ScopedState: def _scoped_state( self, counter_name: counters.CounterName, - name_context: common.NameContext, + name_context: 'common.NameContext', output_counter, - metrics_container=None) -> ScopedState: + metrics_container=None) -> 'ScopedState': assert isinstance(name_context, common.NameContext) return ScopedState( self, counter_name, name_context, output_counter, metrics_container) @@ -53,7 +53,7 @@ def update_metric(self, typed_metric_name, value): if metrics_container is not None: metrics_container.get_metric_cell(typed_metric_name).update(value) - def _enter_state(self, state: ScopedState) -> None: + def _enter_state(self, state: 'ScopedState') -> None: self.state_transition_count += 1 self._state_stack.append(state) @@ -77,7 +77,7 @@ def __init__( self, sampler: StateSampler, name: counters.CounterName, - step_name_context: Optional[common.NameContext], + step_name_context: Optional['common.NameContext'], counter=None, metrics_container=None): self.state_sampler = sampler diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_perf.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_perf.py index a783e58aacf7..c29825f95f3e 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_perf.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_perf.py @@ -36,7 +36,7 @@ def __init__( # number of result produced self.result_count = result_count if result_count else -1 - def has_progress(self, previous_perf: NexmarkPerf) -> bool: + def has_progress(self, previous_perf: 'NexmarkPerf') -> bool: """ Args: previous_perf: a NexmarkPerf object to be compared to self diff --git a/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py b/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py index e71c4a471923..caadbaca1e1e 100644 --- a/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py +++ b/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py @@ -199,7 +199,7 @@ def __init__( bq_table=None, bq_dataset=None, publish_to_bq=False, 
- influxdb_options: Optional[InfluxDBMetricsPublisherOptions] = None, + influxdb_options: Optional['InfluxDBMetricsPublisherOptions'] = None, namespace=None, filters=None): """Initializes :class:`MetricsReader` . diff --git a/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py b/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py index 33cc3db34811..3cb5f32c3114 100644 --- a/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py +++ b/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py @@ -35,7 +35,7 @@ @with_input_types(int) @with_output_types(int) class CallSequenceEnforcingCombineFn(beam.CombineFn): - instances: Set[CallSequenceEnforcingCombineFn] = set() + instances: Set['CallSequenceEnforcingCombineFn'] = set() def __init__(self): super().__init__() diff --git a/sdks/python/apache_beam/transforms/external_java.py b/sdks/python/apache_beam/transforms/external_java.py index 85eeff977609..e3984fa8ef20 100644 --- a/sdks/python/apache_beam/transforms/external_java.py +++ b/sdks/python/apache_beam/transforms/external_java.py @@ -21,6 +21,7 @@ import logging import subprocess import sys +from typing import Optional import grpc from mock import patch @@ -46,8 +47,8 @@ class JavaExternalTransformTest(object): # This will be overwritten if set via a flag. - expansion_service_jar: str = None - expansion_service_port: int = None + expansion_service_jar: Optional[str] = None + expansion_service_port: Optional[int] = None class _RunWithExpansion(object): def __init__(self): diff --git a/sdks/python/apache_beam/transforms/resources.py b/sdks/python/apache_beam/transforms/resources.py index bc15271aadd0..04f38d368122 100644 --- a/sdks/python/apache_beam/transforms/resources.py +++ b/sdks/python/apache_beam/transforms/resources.py @@ -26,18 +26,15 @@ """ import re -from typing import TYPE_CHECKING from typing import Any from typing import Dict +from typing import Mapping from typing import Optional +from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.options.pipeline_options import StandardOptions from apache_beam.portability.common_urns import resource_hints -if TYPE_CHECKING: - from typing import Mapping - from apache_beam.options.pipeline_options import PipelineOptions - __all__ = [ 'ResourceHint', 'AcceleratorHint', diff --git a/sdks/python/apache_beam/transforms/sideinputs.py b/sdks/python/apache_beam/transforms/sideinputs.py index 4951594e63b0..0ff2a388b9e1 100644 --- a/sdks/python/apache_beam/transforms/sideinputs.py +++ b/sdks/python/apache_beam/transforms/sideinputs.py @@ -76,7 +76,7 @@ def get_sideinput_index(tag: str) -> int: class SideInputMap(object): """Represents a mapping of windows to side input values.""" - def __init__(self, view_class: pvalue.AsSideInput, view_options, iterable): + def __init__(self, view_class: 'pvalue.AsSideInput', view_options, iterable): self._window_mapping_fn = view_options.get( 'window_mapping_fn', _global_window_mapping_fn) self._view_class = view_class diff --git a/sdks/python/apache_beam/transforms/userstate.py b/sdks/python/apache_beam/transforms/userstate.py index 45f73ab69ad0..ada0b755bd6c 100644 --- a/sdks/python/apache_beam/transforms/userstate.py +++ b/sdks/python/apache_beam/transforms/userstate.py @@ -38,12 +38,12 @@ from apache_beam.portability import common_urns from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.transforms.timeutil import TimeDomain +from apache_beam.utils import windowed_value +from apache_beam.utils.timestamp import 
Timestamp if TYPE_CHECKING: from apache_beam.runners.pipeline_context import PipelineContext - from apache_beam.transforms.core import CombineFn, DoFn - from apache_beam.utils import windowed_value - from apache_beam.utils.timestamp import Timestamp + from apache_beam.transforms.core import DoFn CallableT = TypeVar('CallableT', bound=Callable) @@ -62,14 +62,14 @@ def __repr__(self) -> str: return '%s(%s)' % (self.__class__.__name__, self.name) def to_runner_api( - self, context: PipelineContext) -> beam_runner_api_pb2.StateSpec: + self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec: raise NotImplementedError class ReadModifyWriteStateSpec(StateSpec): """Specification for a user DoFn value state cell.""" def to_runner_api( - self, context: PipelineContext) -> beam_runner_api_pb2.StateSpec: + self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec: return beam_runner_api_pb2.StateSpec( read_modify_write_spec=beam_runner_api_pb2.ReadModifyWriteStateSpec( coder_id=context.coders.get_id(self.coder)), @@ -80,7 +80,7 @@ def to_runner_api( class BagStateSpec(StateSpec): """Specification for a user DoFn bag state cell.""" def to_runner_api( - self, context: PipelineContext) -> beam_runner_api_pb2.StateSpec: + self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec: return beam_runner_api_pb2.StateSpec( bag_spec=beam_runner_api_pb2.BagStateSpec( element_coder_id=context.coders.get_id(self.coder)), @@ -91,7 +91,7 @@ def to_runner_api( class SetStateSpec(StateSpec): """Specification for a user DoFn Set State cell""" def to_runner_api( - self, context: PipelineContext) -> beam_runner_api_pb2.StateSpec: + self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec: return beam_runner_api_pb2.StateSpec( set_spec=beam_runner_api_pb2.SetStateSpec( element_coder_id=context.coders.get_id(self.coder)), @@ -141,7 +141,7 @@ def __init__( super().__init__(name, coder) def to_runner_api( - self, context: PipelineContext) -> beam_runner_api_pb2.StateSpec: + self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec: return beam_runner_api_pb2.StateSpec( combining_spec=beam_runner_api_pb2.CombiningStateSpec( combine_fn=self.combine_fn.to_runner_api(context), @@ -180,7 +180,7 @@ def __repr__(self) -> str: return '%s(%s)' % (self.__class__.__name__, self.name) def to_runner_api( - self, context: PipelineContext, key_coder: Coder, + self, context: 'PipelineContext', key_coder: Coder, window_coder: Coder) -> beam_runner_api_pb2.TimerFamilySpec: return beam_runner_api_pb2.TimerFamilySpec( time_domain=TimeDomain.to_runner_api(self.time_domain), @@ -217,7 +217,7 @@ def _inner(method: CallableT) -> CallableT: return _inner -def get_dofn_specs(dofn: DoFn) -> Tuple[Set[StateSpec], Set[TimerSpec]]: +def get_dofn_specs(dofn: 'DoFn') -> Tuple[Set[StateSpec], Set[TimerSpec]]: """Gets the state and timer specs for a DoFn, if any. Args: @@ -256,7 +256,7 @@ def get_dofn_specs(dofn: DoFn) -> Tuple[Set[StateSpec], Set[TimerSpec]]: return all_state_specs, all_timer_specs -def is_stateful_dofn(dofn: DoFn) -> bool: +def is_stateful_dofn(dofn: 'DoFn') -> bool: """Determines whether a given DoFn is a stateful DoFn.""" # A Stateful DoFn is a DoFn that uses user state or timers. @@ -264,7 +264,7 @@ def is_stateful_dofn(dofn: DoFn) -> bool: return bool(all_state_specs or all_timer_specs) -def validate_stateful_dofn(dofn: DoFn) -> None: +def validate_stateful_dofn(dofn: 'DoFn') -> None: """Validates the proper specification of a stateful DoFn.""" # Get state and timer specs. 
@@ -378,7 +378,7 @@ def get_timer( self, timer_spec: TimerSpec, key: Any, - window: windowed_value.BoundedWindow, + window: 'windowed_value.BoundedWindow', timestamp: Timestamp, pane: windowed_value.PaneInfo, ) -> BaseTimer: @@ -388,7 +388,7 @@ def get_state( self, state_spec: StateSpec, key: Any, - window: windowed_value.BoundedWindow, + window: 'windowed_value.BoundedWindow', ) -> RuntimeState: raise NotImplementedError(type(self)) diff --git a/sdks/python/apache_beam/transforms/userstate_test.py b/sdks/python/apache_beam/transforms/userstate_test.py index 9caa31386569..5dd6c61d6add 100644 --- a/sdks/python/apache_beam/transforms/userstate_test.py +++ b/sdks/python/apache_beam/transforms/userstate_test.py @@ -437,7 +437,7 @@ def __repr__(self): class StatefulDoFnOnDirectRunnerTest(unittest.TestCase): # pylint: disable=expression-not-assigned - all_records: List[Any] = None + all_records: List[Any] def setUp(self): # Use state on the TestCase class, since other references would be pickled diff --git a/sdks/python/apache_beam/transforms/window.py b/sdks/python/apache_beam/transforms/window.py index d13894cff8a1..592164a5ef49 100644 --- a/sdks/python/apache_beam/transforms/window.py +++ b/sdks/python/apache_beam/transforms/window.py @@ -106,7 +106,7 @@ class TimestampCombiner(object): @staticmethod def get_impl( timestamp_combiner: beam_runner_api_pb2.OutputTime.Enum, - window_fn: WindowFn) -> timeutil.TimestampCombinerImpl: + window_fn: 'WindowFn') -> timeutil.TimestampCombinerImpl: if timestamp_combiner == TimestampCombiner.OUTPUT_AT_EOW: return timeutil.OutputAtEndOfWindowImpl() elif timestamp_combiner == TimestampCombiner.OUTPUT_AT_EARLIEST: @@ -127,13 +127,14 @@ def __init__( self, timestamp: TimestampTypes, element: Optional[Any] = None, - window: Optional[BoundedWindow] = None) -> None: + window: Optional['BoundedWindow'] = None) -> None: self.timestamp = Timestamp.of(timestamp) self.element = element self.window = window @abc.abstractmethod - def assign(self, assign_context: AssignContext) -> Iterable[BoundedWindow]: + def assign(self, + assign_context: 'AssignContext') -> Iterable['BoundedWindow']: # noqa: F821 """Associates windows to an element. @@ -148,17 +149,17 @@ def assign(self, assign_context: AssignContext) -> Iterable[BoundedWindow]: class MergeContext(object): """Context passed to WindowFn.merge() to perform merging, if any.""" - def __init__(self, windows: Iterable[BoundedWindow]) -> None: + def __init__(self, windows: Iterable['BoundedWindow']) -> None: self.windows = list(windows) def merge( self, - to_be_merged: Iterable[BoundedWindow], - merge_result: BoundedWindow) -> None: + to_be_merged: Iterable['BoundedWindow'], + merge_result: 'BoundedWindow') -> None: raise NotImplementedError @abc.abstractmethod - def merge(self, merge_context: WindowFn.MergeContext) -> None: + def merge(self, merge_context: 'WindowFn.MergeContext') -> None: """Returns a window that is the result of merging a set of windows.""" raise NotImplementedError @@ -171,7 +172,7 @@ def get_window_coder(self) -> coders.Coder: raise NotImplementedError def get_transformed_output_time( - self, window: BoundedWindow, input_timestamp: Timestamp) -> Timestamp: # pylint: disable=unused-argument + self, window: 'BoundedWindow', input_timestamp: Timestamp) -> Timestamp: # pylint: disable=unused-argument """Given input time and output window, returns output time for window. 
If TimestampCombiner.OUTPUT_AT_EARLIEST_TRANSFORMED is used in the @@ -260,10 +261,10 @@ def __lt__(self, other): return self.end < other.end return hash(self) < hash(other) - def intersects(self, other: IntervalWindow) -> bool: + def intersects(self, other: 'IntervalWindow') -> bool: return other.start < self.end or self.start < other.end - def union(self, other: IntervalWindow) -> IntervalWindow: + def union(self, other: 'IntervalWindow') -> 'IntervalWindow': return IntervalWindow( min(self.start, other.start), max(self.end, other.end)) @@ -301,7 +302,7 @@ def __lt__(self, other): class GlobalWindow(BoundedWindow): """The default window into which all data is placed (via GlobalWindows).""" - _instance: GlobalWindow = None + _instance: Optional['GlobalWindow'] = None def __new__(cls): if cls._instance is None: @@ -385,7 +386,7 @@ def to_runner_api_parameter(self, context): @staticmethod @urns.RunnerApiFn.register_urn(common_urns.global_windows.urn, None) def from_runner_api_parameter( - unused_fn_parameter, unused_context) -> GlobalWindows: + unused_fn_parameter, unused_context) -> 'GlobalWindows': return GlobalWindows() @@ -446,7 +447,7 @@ def to_runner_api_parameter(self, context): @urns.RunnerApiFn.register_urn( common_urns.fixed_windows.urn, standard_window_fns_pb2.FixedWindowsPayload) - def from_runner_api_parameter(fn_parameter, unused_context) -> FixedWindows: + def from_runner_api_parameter(fn_parameter, unused_context) -> 'FixedWindows': return FixedWindows( size=Duration(micros=fn_parameter.size.ToMicroseconds()), offset=Timestamp(micros=fn_parameter.offset.ToMicroseconds())) @@ -518,7 +519,8 @@ def to_runner_api_parameter(self, context): @urns.RunnerApiFn.register_urn( common_urns.sliding_windows.urn, standard_window_fns_pb2.SlidingWindowsPayload) - def from_runner_api_parameter(fn_parameter, unused_context) -> SlidingWindows: + def from_runner_api_parameter( + fn_parameter, unused_context) -> 'SlidingWindows': return SlidingWindows( size=Duration(micros=fn_parameter.size.ToMicroseconds()), offset=Timestamp(micros=fn_parameter.offset.ToMicroseconds()), @@ -585,6 +587,6 @@ def to_runner_api_parameter(self, context): @urns.RunnerApiFn.register_urn( common_urns.session_windows.urn, standard_window_fns_pb2.SessionWindowsPayload) - def from_runner_api_parameter(fn_parameter, unused_context) -> Sessions: + def from_runner_api_parameter(fn_parameter, unused_context) -> 'Sessions': return Sessions( gap_size=Duration(micros=fn_parameter.gap_size.ToMicroseconds())) diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py index ee0cb76d45d4..9c0cc2b8af4e 100644 --- a/sdks/python/apache_beam/typehints/decorators.py +++ b/sdks/python/apache_beam/typehints/decorators.py @@ -204,7 +204,7 @@ class IOTypeHints(NamedTuple): @classmethod def _make_origin( cls, - bases: List[IOTypeHints], + bases: List['IOTypeHints'], tb: bool = True, msg: Iterable[str] = ()) -> List[str]: if msg: @@ -232,12 +232,12 @@ def _make_origin( return res @classmethod - def empty(cls) -> IOTypeHints: + def empty(cls) -> 'IOTypeHints': """Construct a base IOTypeHints object with no hints.""" return IOTypeHints(None, None, []) @classmethod - def from_callable(cls, fn: Callable) -> Optional[IOTypeHints]: + def from_callable(cls, fn: Callable) -> Optional['IOTypeHints']: """Construct an IOTypeHints object from a callable's signature. Supports Python 3 annotations. 
For partial annotations, sets unknown types @@ -291,19 +291,19 @@ def from_callable(cls, fn: Callable) -> Optional[IOTypeHints]: output_types=(tuple(output_args), {}), origin=cls._make_origin([], tb=False, msg=msg)) - def with_input_types(self, *args, **kwargs) -> IOTypeHints: + def with_input_types(self, *args, **kwargs) -> 'IOTypeHints': return self._replace( input_types=(args, kwargs), origin=self._make_origin([self])) - def with_output_types(self, *args, **kwargs) -> IOTypeHints: + def with_output_types(self, *args, **kwargs) -> 'IOTypeHints': return self._replace( output_types=(args, kwargs), origin=self._make_origin([self])) - def with_input_types_from(self, other: IOTypeHints) -> IOTypeHints: + def with_input_types_from(self, other: 'IOTypeHints') -> 'IOTypeHints': return self._replace( input_types=other.input_types, origin=self._make_origin([self])) - def with_output_types_from(self, other: IOTypeHints) -> IOTypeHints: + def with_output_types_from(self, other: 'IOTypeHints') -> 'IOTypeHints': return self._replace( output_types=other.output_types, origin=self._make_origin([self])) @@ -353,10 +353,11 @@ def strip_pcoll_helper( my_type: any, has_my_type: Callable[[], bool], my_key: str, - special_containers: List[Union[PBegin, PDone, PCollection]], # noqa: F821 + special_containers: List[ + Union['PBegin', 'PDone', 'PCollection']], # noqa: F821 error_str: str, source_str: str - ) -> IOTypeHints: + ) -> 'IOTypeHints': from apache_beam.pvalue import PCollection if not has_my_type() or not my_type or len(my_type[0]) != 1: @@ -390,7 +391,7 @@ def strip_pcoll_helper( origin=self._make_origin([self], tb=False, msg=[source_str]), **kwarg_dict) - def strip_iterable(self) -> IOTypeHints: + def strip_iterable(self) -> 'IOTypeHints': """Removes outer Iterable (or equivalent) from output type. Only affects instances with simple output types, otherwise is a no-op. @@ -429,7 +430,7 @@ def strip_iterable(self) -> IOTypeHints: output_types=((yielded_type, ), {}), origin=self._make_origin([self], tb=False, msg=['strip_iterable()'])) - def with_defaults(self, hints: Optional[IOTypeHints]) -> IOTypeHints: + def with_defaults(self, hints: Optional['IOTypeHints']) -> 'IOTypeHints': if not hints: return self if not self: diff --git a/sdks/python/apache_beam/utils/profiler.py b/sdks/python/apache_beam/utils/profiler.py index 7463f50eb55b..c75fdcc5878d 100644 --- a/sdks/python/apache_beam/utils/profiler.py +++ b/sdks/python/apache_beam/utils/profiler.py @@ -45,8 +45,8 @@ class Profile(object): SORTBY = 'cumulative' - profile_output: str = None - stats: pstats.Stats = None + profile_output: str + stats: pstats.Stats def __init__( self, @@ -139,7 +139,7 @@ def default_file_copy_fn(src, dest): filesystems.FileSystems.rename([dest + '.tmp'], [dest]) @staticmethod - def factory_from_options(options) -> Optional[Callable[..., Profile]]: + def factory_from_options(options) -> Optional[Callable[..., 'Profile']]: if options.profile_cpu or options.profile_memory: def create_profiler(profile_id, **kwargs): diff --git a/sdks/python/apache_beam/utils/timestamp.py b/sdks/python/apache_beam/utils/timestamp.py index dbc768308e60..3f585eecae08 100644 --- a/sdks/python/apache_beam/utils/timestamp.py +++ b/sdks/python/apache_beam/utils/timestamp.py @@ -65,7 +65,7 @@ def __init__( self.micros = int(seconds * 1000000) + int(micros) @staticmethod - def of(seconds: TimestampTypes) -> Timestamp: + def of(seconds: TimestampTypes) -> 'Timestamp': """Return the Timestamp for the given number of seconds. 
If the input is already a Timestamp, the input itself will be returned. @@ -88,7 +88,7 @@ def of(seconds: TimestampTypes) -> Timestamp: 'Cannot interpret %s %s as Timestamp.' % (seconds, type(seconds))) @staticmethod - def now() -> Timestamp: + def now() -> 'Timestamp': return Timestamp(seconds=time.time()) @staticmethod @@ -96,7 +96,7 @@ def _epoch_datetime_utc() -> datetime.datetime: return datetime.datetime.fromtimestamp(0, pytz.utc) @classmethod - def from_utc_datetime(cls, dt: datetime.datetime) -> Timestamp: + def from_utc_datetime(cls, dt: datetime.datetime) -> 'Timestamp': """Create a ``Timestamp`` instance from a ``datetime.datetime`` object. Args: @@ -113,7 +113,7 @@ def from_utc_datetime(cls, dt: datetime.datetime) -> Timestamp: return Timestamp(duration.total_seconds()) @classmethod - def from_rfc3339(cls, rfc3339: str) -> Timestamp: + def from_rfc3339(cls, rfc3339: str) -> 'Timestamp': """Create a ``Timestamp`` instance from an RFC 3339 compliant string. .. note:: @@ -134,11 +134,11 @@ def seconds(self) -> int: """Returns the timestamp in seconds.""" return self.micros // 1000000 - def predecessor(self) -> Timestamp: + def predecessor(self) -> 'Timestamp': """Returns the largest timestamp smaller than self.""" return Timestamp(micros=self.micros - 1) - def successor(self) -> Timestamp: + def successor(self) -> 'Timestamp': """Returns the smallest timestamp larger than self.""" return Timestamp(micros=self.micros + 1) @@ -187,7 +187,7 @@ def to_proto(self) -> timestamp_pb2.Timestamp: return timestamp_pb2.Timestamp(seconds=secs, nanos=nanos) @staticmethod - def from_proto(timestamp_proto: timestamp_pb2.Timestamp) -> Timestamp: + def from_proto(timestamp_proto: timestamp_pb2.Timestamp) -> 'Timestamp': """Creates a Timestamp from a `google.protobuf.timestamp_pb2`. Note that the google has a sub-second resolution of nanoseconds whereas this @@ -245,30 +245,30 @@ def __ge__(self, other: TimestampDurationTypes) -> bool: def __hash__(self) -> int: return hash(self.micros) - def __add__(self, other: DurationTypes) -> Timestamp: + def __add__(self, other: DurationTypes) -> 'Timestamp': other = Duration.of(other) return Timestamp(micros=self.micros + other.micros) - def __radd__(self, other: DurationTypes) -> Timestamp: + def __radd__(self, other: DurationTypes) -> 'Timestamp': return self + other @overload - def __sub__(self, other: DurationTypes) -> Timestamp: + def __sub__(self, other: DurationTypes) -> 'Timestamp': pass @overload - def __sub__(self, other: Timestamp) -> Duration: + def __sub__(self, other: 'Timestamp') -> 'Duration': pass def __sub__( self, other: Union[DurationTypes, - Timestamp]) -> Union[Timestamp, Duration]: + 'Timestamp']) -> Union['Timestamp', 'Duration']: if isinstance(other, Timestamp): return Duration(micros=self.micros - other.micros) other = Duration.of(other) return Timestamp(micros=self.micros - other.micros) - def __mod__(self, other: DurationTypes) -> Duration: + def __mod__(self, other: DurationTypes) -> 'Duration': other = Duration.of(other) return Duration(micros=self.micros % other.micros) @@ -296,7 +296,7 @@ def __init__( self.micros = int(seconds * 1000000) + int(micros) @staticmethod - def of(seconds: DurationTypes) -> Duration: + def of(seconds: DurationTypes) -> 'Duration': """Return the Duration for the given number of seconds since Unix epoch. If the input is already a Duration, the input itself will be returned. 
@@ -321,7 +321,7 @@ def to_proto(self) -> duration_pb2.Duration: return duration_pb2.Duration(seconds=secs, nanos=nanos) @staticmethod - def from_proto(duration_proto: duration_pb2.Duration) -> Duration: + def from_proto(duration_proto: duration_pb2.Duration) -> 'Duration': """Creates a Duration from a `google.protobuf.duration_pb2`. Note that the google has a sub-second resolution of nanoseconds whereas this @@ -387,34 +387,34 @@ def __ge__(self, other: TimestampDurationTypes) -> bool: def __hash__(self) -> int: return hash(self.micros) - def __neg__(self) -> Duration: + def __neg__(self) -> 'Duration': return Duration(micros=-self.micros) - def __add__(self, other: DurationTypes) -> Duration: + def __add__(self, other: DurationTypes) -> 'Duration': if isinstance(other, Timestamp): # defer to Timestamp.__add__ return NotImplemented other = Duration.of(other) return Duration(micros=self.micros + other.micros) - def __radd__(self, other: DurationTypes) -> Duration: + def __radd__(self, other: DurationTypes) -> 'Duration': return self + other - def __sub__(self, other: DurationTypes) -> Duration: + def __sub__(self, other: DurationTypes) -> 'Duration': other = Duration.of(other) return Duration(micros=self.micros - other.micros) - def __rsub__(self, other: DurationTypes) -> Duration: + def __rsub__(self, other: DurationTypes) -> 'Duration': return -(self - other) - def __mul__(self, other: DurationTypes) -> Duration: + def __mul__(self, other: DurationTypes) -> 'Duration': other = Duration.of(other) return Duration(micros=self.micros * other.micros // 1000000) - def __rmul__(self, other: DurationTypes) -> Duration: + def __rmul__(self, other: DurationTypes) -> 'Duration': return self * other - def __mod__(self, other: DurationTypes) -> Duration: + def __mod__(self, other: DurationTypes) -> 'Duration': other = Duration.of(other) return Duration(micros=self.micros % other.micros) diff --git a/sdks/python/apache_beam/utils/urns.py b/sdks/python/apache_beam/utils/urns.py index 9d2f393fd7a3..2647a0200bde 100644 --- a/sdks/python/apache_beam/utils/urns.py +++ b/sdks/python/apache_beam/utils/urns.py @@ -38,10 +38,10 @@ from google.protobuf import wrappers_pb2 from apache_beam.internal import pickler +from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.utils import proto_utils if TYPE_CHECKING: - from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.runners.pipeline_context import PipelineContext T = TypeVar('T') @@ -75,7 +75,7 @@ class RunnerApiFn(object): # concrete implementation. # @abc.abstractmethod def to_runner_api_parameter( - self, unused_context: PipelineContext) -> Tuple[str, Any]: + self, unused_context: 'PipelineContext') -> Tuple[str, Any]: """Returns the urn and payload for this Fn. The returned urn(s) should be registered with `register_urn`. 
@@ -88,8 +88,8 @@ def register_urn( cls, urn: str, parameter_type: Type[T], - ) -> Callable[[Callable[[T, PipelineContext], Any]], - Callable[[T, PipelineContext], Any]]: + ) -> Callable[[Callable[[T, 'PipelineContext'], Any]], + Callable[[T, 'PipelineContext'], Any]]: pass @classmethod @@ -98,8 +98,8 @@ def register_urn( cls, urn: str, parameter_type: None, - ) -> Callable[[Callable[[bytes, PipelineContext], Any]], - Callable[[bytes, PipelineContext], Any]]: + ) -> Callable[[Callable[[bytes, 'PipelineContext'], Any]], + Callable[[bytes, 'PipelineContext'], Any]]: pass @classmethod @@ -108,7 +108,7 @@ def register_urn( cls, urn: str, parameter_type: Type[T], - fn: Callable[[T, PipelineContext], Any]) -> None: + fn: Callable[[T, 'PipelineContext'], Any]) -> None: pass @classmethod @@ -117,7 +117,7 @@ def register_urn( cls, urn: str, parameter_type: None, - fn: Callable[[bytes, PipelineContext], Any]) -> None: + fn: Callable[[bytes, 'PipelineContext'], Any]) -> None: pass @classmethod @@ -159,12 +159,11 @@ def register_pickle_urn(cls, pickle_urn): unused_context: pickler.loads(proto.value)) def to_runner_api( - self, context: PipelineContext) -> beam_runner_api_pb2.FunctionSpec: + self, context: 'PipelineContext') -> beam_runner_api_pb2.FunctionSpec: """Returns an FunctionSpec encoding this Fn. Prefer overriding self.to_runner_api_parameter. """ - from apache_beam.portability.api import beam_runner_api_pb2 urn, typed_param = self.to_runner_api_parameter(context) return beam_runner_api_pb2.FunctionSpec( urn=urn, @@ -175,7 +174,7 @@ def to_runner_api( def from_runner_api( cls: Type[RunnerApiFnT], fn_proto: beam_runner_api_pb2.FunctionSpec, - context: PipelineContext) -> RunnerApiFnT: + context: 'PipelineContext') -> RunnerApiFnT: """Converts from an FunctionSpec to a Fn object. Prefer registering a urn with its parameter type and constructor. From d4de077a2a96d2fed23ba15bf5e36a3ccb8150b1 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Wed, 3 Jul 2024 08:49:11 -0700 Subject: [PATCH 26/29] Fix bad type declarations. --- sdks/python/apache_beam/dataframe/convert.py | 8 ++--- .../examples/cookbook/bigtableio_it_test.py | 2 +- .../apache_beam/internal/module_test.py | 4 ++- .../runners/direct/evaluation_context.py | 30 +++++++++---------- .../apache_beam/runners/direct/executor.py | 24 +++++++-------- .../runners/direct/sdf_direct_runner.py | 5 +--- .../runners/direct/test_stream_impl.py | 4 +-- .../runners/direct/transform_evaluator.py | 12 ++++---- .../runners/direct/watermark_manager.py | 4 +-- .../interactive/options/capture_control.py | 5 ++-- .../interactive/options/capture_limiters.py | 2 ++ .../runners/interactive/recording_manager.py | 24 +++++++-------- 12 files changed, 63 insertions(+), 61 deletions(-) diff --git a/sdks/python/apache_beam/dataframe/convert.py b/sdks/python/apache_beam/dataframe/convert.py index 0ccd4489767b..e44cc429eac1 100644 --- a/sdks/python/apache_beam/dataframe/convert.py +++ b/sdks/python/apache_beam/dataframe/convert.py @@ -87,10 +87,10 @@ def to_dataframe( # Note that the pipeline (indirectly) holds references to the transforms which # keeps both the PCollections and expressions alive. This ensures the # expression's ids are never accidentally re-used. 
-TO_PCOLLECTION_CACHE: weakref.WeakValueDictionary[ - str, pvalue.PCollection] = weakref.WeakValueDictionary() -UNBATCHED_CACHE: weakref.WeakValueDictionary[ - str, pvalue.PCollection] = weakref.WeakValueDictionary() +TO_PCOLLECTION_CACHE: 'weakref.WeakValueDictionary[str, pvalue.PCollection]' = ( + weakref.WeakValueDictionary()) +UNBATCHED_CACHE: 'weakref.WeakValueDictionary[str, pvalue.PCollection]' = ( + weakref.WeakValueDictionary()) class RowsToDataFrameFn(beam.DoFn): diff --git a/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py b/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py index 8fdb4946ed5f..6b5573aa4569 100644 --- a/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py +++ b/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py @@ -53,7 +53,7 @@ if TYPE_CHECKING: import google.cloud.bigtable.instance -EXISTING_INSTANCES: List[google.cloud.bigtable.instance.Instance] = [] +EXISTING_INSTANCES: List['google.cloud.bigtable.instance.Instance'] = [] LABEL_KEY = 'python-bigtable-beam' label_stamp = datetime.datetime.utcnow().replace(tzinfo=UTC) label_stamp_micros = _microseconds_from_datetime(label_stamp) diff --git a/sdks/python/apache_beam/internal/module_test.py b/sdks/python/apache_beam/internal/module_test.py index 89e7d7eaa821..1bb5a9d424b9 100644 --- a/sdks/python/apache_beam/internal/module_test.py +++ b/sdks/python/apache_beam/internal/module_test.py @@ -22,6 +22,8 @@ import re import sys +from typing import Any + class TopClass(object): class NestedClass(object): @@ -63,7 +65,7 @@ def get(self): class RecursiveClass(object): """A class that contains a reference to itself.""" - SELF_TYPE = None + SELF_TYPE: Any = None def __init__(self, datum): self.datum = 'RecursiveClass:%s' % datum diff --git a/sdks/python/apache_beam/runners/direct/evaluation_context.py b/sdks/python/apache_beam/runners/direct/evaluation_context.py index d42f9d43fe71..c34735499abc 100644 --- a/sdks/python/apache_beam/runners/direct/evaluation_context.py +++ b/sdks/python/apache_beam/runners/direct/evaluation_context.py @@ -31,21 +31,21 @@ from typing import Tuple from typing import Union +from apache_beam import pvalue +from apache_beam.pipeline import AppliedPTransform from apache_beam.runners.direct.direct_metrics import DirectMetrics from apache_beam.runners.direct.executor import TransformExecutor from apache_beam.runners.direct.watermark_manager import WatermarkManager from apache_beam.transforms import sideinputs from apache_beam.transforms.trigger import InMemoryUnmergedState from apache_beam.utils import counters +from apache_beam.utils.timestamp import Timestamp if TYPE_CHECKING: - from apache_beam import pvalue - from apache_beam.pipeline import AppliedPTransform from apache_beam.runners.direct.bundle_factory import BundleFactory, _Bundle from apache_beam.runners.direct.util import TimerFiring from apache_beam.runners.direct.util import TransformResult from apache_beam.runners.direct.watermark_manager import _TransformWatermarks - from apache_beam.utils.timestamp import Timestamp class _ExecutionContext(object): @@ -53,7 +53,7 @@ class _ExecutionContext(object): It holds the watermarks for that transform, as well as keyed states. 
""" - def __init__(self, watermarks: _TransformWatermarks, keyed_states): + def __init__(self, watermarks: '_TransformWatermarks', keyed_states): self.watermarks = watermarks self.keyed_states = keyed_states @@ -88,7 +88,7 @@ class _SideInputsContainer(object): It provides methods for blocking until a side-input is available and writing to a side input. """ - def __init__(self, side_inputs: Iterable[pvalue.AsSideInput]) -> None: + def __init__(self, side_inputs: Iterable['pvalue.AsSideInput']) -> None: self._lock = threading.Lock() self._views: Dict[pvalue.AsSideInput, _SideInputView] = {} self._transform_to_side_inputs: DefaultDict[ @@ -228,7 +228,7 @@ class EvaluationContext(object): def __init__( self, pipeline_options, - bundle_factory: BundleFactory, + bundle_factory: 'BundleFactory', root_transforms, value_to_consumers, step_names, @@ -283,9 +283,9 @@ def is_root_transform(self, applied_ptransform: AppliedPTransform) -> bool: def handle_result( self, - completed_bundle: _Bundle, + completed_bundle: '_Bundle', completed_timers, - result: TransformResult): + result: 'TransformResult'): """Handle the provided result produced after evaluating the input bundle. Handle the provided TransformResult, produced after evaluating @@ -339,7 +339,7 @@ def handle_result( return committed_bundles def _update_side_inputs_container( - self, committed_bundles: Iterable[_Bundle], result: TransformResult): + self, committed_bundles: Iterable['_Bundle'], result: 'TransformResult'): """Update the side inputs container if we are outputting into a side input. Look at the result, and if it's outputing into a PCollection that we have @@ -367,9 +367,9 @@ def schedule_pending_unblocked_tasks(self, executor_service): def _commit_bundles( self, - uncommitted_bundles: Iterable[_Bundle], - unprocessed_bundles: Iterable[_Bundle] - ) -> Tuple[Tuple[_Bundle, ...], Tuple[_Bundle, ...]]: + uncommitted_bundles: Iterable['_Bundle'], + unprocessed_bundles: Iterable['_Bundle'] + ) -> Tuple[Tuple['_Bundle', ...], Tuple['_Bundle', ...]]: """Commits bundles and returns a immutable set of committed bundles.""" for in_progress_bundle in uncommitted_bundles: producing_applied_ptransform = in_progress_bundle.pcollection.producer @@ -389,18 +389,18 @@ def get_execution_context( def create_bundle( self, output_pcollection: Union[pvalue.PBegin, - pvalue.PCollection]) -> _Bundle: + pvalue.PCollection]) -> '_Bundle': """Create an uncommitted bundle for the specified PCollection.""" return self._bundle_factory.create_bundle(output_pcollection) def create_empty_committed_bundle( - self, output_pcollection: pvalue.PCollection) -> _Bundle: + self, output_pcollection: pvalue.PCollection) -> '_Bundle': """Create empty bundle useful for triggering evaluation.""" return self._bundle_factory.create_empty_committed_bundle( output_pcollection) def extract_all_timers( - self) -> Tuple[List[Tuple[AppliedPTransform, List[TimerFiring]]], bool]: + self) -> Tuple[List[Tuple[AppliedPTransform, List['TimerFiring']]], bool]: return self._watermark_manager.extract_all_timers() def is_done(self, transform: Optional[AppliedPTransform] = None) -> bool: diff --git a/sdks/python/apache_beam/runners/direct/executor.py b/sdks/python/apache_beam/runners/direct/executor.py index 59e282b91bd9..e8be9d64f993 100644 --- a/sdks/python/apache_beam/runners/direct/executor.py +++ b/sdks/python/apache_beam/runners/direct/executor.py @@ -64,7 +64,7 @@ class _ExecutorServiceWorker(threading.Thread): TIMEOUT = 5 def __init__( - self, queue: 
queue.Queue['_ExecutorService.CallableTask'], index): + self, queue: 'queue.Queue[_ExecutorService.CallableTask]', index): super().__init__() self.queue = queue self._index = index @@ -89,7 +89,7 @@ def _get_task_or_none(self) -> Optional['_ExecutorService.CallableTask']: # Do not block indefinitely, otherwise we may not act for a requested # shutdown. return self.queue.get( - timeout='_ExecutorService._ExecutorServiceWorker.TIMEOUT') + timeout=_ExecutorService._ExecutorServiceWorker.TIMEOUT) except queue.Empty: return None @@ -145,7 +145,7 @@ def shutdown(self): class _TransformEvaluationState(object): - def __init__(self, executor_service, scheduled: Set[TransformExecutor]): + def __init__(self, executor_service, scheduled: Set['TransformExecutor']): self.executor_service = executor_service self.scheduled = scheduled @@ -229,7 +229,7 @@ def serial(self, step: Any) -> _SerialEvaluationState: return cached @property - def executors(self) -> FrozenSet[TransformExecutor]: + def executors(self) -> FrozenSet['TransformExecutor']: return frozenset(self._scheduled) @@ -242,7 +242,7 @@ class _CompletionCallback(object): """ def __init__( self, - evaluation_context: EvaluationContext, + evaluation_context: 'EvaluationContext', all_updates, timer_firings=None): self._evaluation_context = evaluation_context @@ -283,9 +283,9 @@ class TransformExecutor(_ExecutorService.CallableTask): def __init__( self, - transform_evaluator_registry: TransformEvaluatorRegistry, - evaluation_context: EvaluationContext, - input_bundle: _Bundle, + transform_evaluator_registry: 'TransformEvaluatorRegistry', + evaluation_context: 'EvaluationContext', + input_bundle: '_Bundle', fired_timers, applied_ptransform, completion_callback, @@ -430,7 +430,7 @@ def __init__( self, value_to_consumers, transform_evaluator_registry, - evaluation_context: EvaluationContext): + evaluation_context: 'EvaluationContext'): self.executor_service = _ExecutorService( _ExecutorServiceParallelExecutor.NUM_WORKERS) self.transform_executor_services = _TransformExecutorServices( @@ -472,7 +472,7 @@ def request_shutdown(self): self.executor_service.await_completion() self.evaluation_context.shutdown() - def schedule_consumers(self, committed_bundle: _Bundle) -> None: + def schedule_consumers(self, committed_bundle: '_Bundle') -> None: if committed_bundle.pcollection in self.value_to_consumers: consumers = self.value_to_consumers[committed_bundle.pcollection] for applied_ptransform in consumers: @@ -487,7 +487,7 @@ def schedule_unprocessed_bundle(self, applied_ptransform, unprocessed_bundle): def schedule_consumption( self, consumer_applied_ptransform, - committed_bundle: _Bundle, + committed_bundle: '_Bundle', fired_timers, on_complete): """Schedules evaluation of the given bundle with the transform.""" @@ -571,7 +571,7 @@ def __init__(self, exception=None): class _MonitorTask(_ExecutorService.CallableTask): """MonitorTask continuously runs to ensure that pipeline makes progress.""" - def __init__(self, executor: _ExecutorServiceParallelExecutor) -> None: + def __init__(self, executor: '_ExecutorServiceParallelExecutor') -> None: self._executor = executor @property diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py index 119383856ba2..e0a58db0ef3e 100644 --- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py @@ -23,7 +23,6 @@ import uuid from threading import Lock from threading import Timer 
-from typing import TYPE_CHECKING from typing import Any from typing import Iterable from typing import Optional @@ -32,6 +31,7 @@ from apache_beam import TimeDomain from apache_beam import pvalue from apache_beam.coders import typecoders +from apache_beam.io.iobase import WatermarkEstimator from apache_beam.pipeline import AppliedPTransform from apache_beam.pipeline import PTransformOverride from apache_beam.runners.common import DoFnContext @@ -47,9 +47,6 @@ from apache_beam.transforms.trigger import _ReadModifyWriteStateTag from apache_beam.utils.windowed_value import WindowedValue -if TYPE_CHECKING: - from apache_beam.iobase import WatermarkEstimator - class SplittableParDoOverride(PTransformOverride): """A transform override for ParDo transformss of SplittableDoFns. diff --git a/sdks/python/apache_beam/runners/direct/test_stream_impl.py b/sdks/python/apache_beam/runners/direct/test_stream_impl.py index 1cda97bc56eb..dce161bb0b99 100644 --- a/sdks/python/apache_beam/runners/direct/test_stream_impl.py +++ b/sdks/python/apache_beam/runners/direct/test_stream_impl.py @@ -309,8 +309,8 @@ def is_alive(): return not (shutdown_requested or evaluation_context.shutdown_requested) # The shared queue that allows the producer and consumer to communicate. - channel: Queue[Union[test_stream.Event, - _EndOfStream]] = Queue() # noqa: F821 + channel: 'Queue[Union[test_stream.Event, _EndOfStream]]' = ( + Queue()) # noqa: F821 event_stream = Thread( target=_TestStream._stream_events_from_rpc, args=(endpoint, output_tags, coder, channel, is_alive)) diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py index a7db67f7f098..b0278ba5356c 100644 --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py @@ -36,6 +36,7 @@ from apache_beam import io from apache_beam import pvalue from apache_beam.internal import pickler +from apache_beam.pipeline import AppliedPTransform from apache_beam.runners import common from apache_beam.runners.common import DoFnRunner from apache_beam.runners.common import DoFnState @@ -77,7 +78,6 @@ if TYPE_CHECKING: from apache_beam.io.gcp.pubsub import _PubSubSource from apache_beam.io.gcp.pubsub import PubsubMessage - from apache_beam.pipeline import AppliedPTransform from apache_beam.runners.direct.evaluation_context import EvaluationContext _LOGGER = logging.getLogger(__name__) @@ -90,9 +90,9 @@ class TransformEvaluatorRegistry(object): """ _test_evaluators_overrides: Dict[Type[core.PTransform], - Type[_TransformEvaluator]] = {} + Type['_TransformEvaluator']] = {} - def __init__(self, evaluation_context: EvaluationContext) -> None: + def __init__(self, evaluation_context: 'EvaluationContext') -> None: assert evaluation_context self._evaluation_context = evaluation_context self._evaluators: Dict[Type[core.PTransform], Type[_TransformEvaluator]] = { @@ -232,7 +232,7 @@ class _TransformEvaluator(object): """An evaluator of a specific application of a transform.""" def __init__( self, - evaluation_context: EvaluationContext, + evaluation_context: 'EvaluationContext', applied_ptransform: AppliedPTransform, input_committed_bundle, side_inputs): @@ -652,7 +652,7 @@ def process_element(self, element): pass def _read_from_pubsub( - self, timestamp_attribute) -> List[Tuple[Timestamp, PubsubMessage]]: + self, timestamp_attribute) -> List[Tuple[Timestamp, 'PubsubMessage']]: from apache_beam.io.gcp.pubsub import PubsubMessage from 
google.cloud import pubsub @@ -794,7 +794,7 @@ class _ParDoEvaluator(_TransformEvaluator): """TransformEvaluator for ParDo transform.""" def __init__( self, - evaluation_context: EvaluationContext, + evaluation_context: 'EvaluationContext', applied_ptransform: AppliedPTransform, input_committed_bundle, side_inputs, diff --git a/sdks/python/apache_beam/runners/direct/watermark_manager.py b/sdks/python/apache_beam/runners/direct/watermark_manager.py index b4b16c0a5d21..854d6475de04 100644 --- a/sdks/python/apache_beam/runners/direct/watermark_manager.py +++ b/sdks/python/apache_beam/runners/direct/watermark_manager.py @@ -85,7 +85,7 @@ def _update_input_transform_watermarks( input_transform_watermarks) def get_watermarks( - self, applied_ptransform: AppliedPTransform) -> _TransformWatermarks: + self, applied_ptransform: AppliedPTransform) -> '_TransformWatermarks': """Gets the input and output watermarks for an AppliedPTransform. If the applied_ptransform has not processed any elements, return a @@ -213,7 +213,7 @@ def __init__(self, clock, keyed_states, transform): self._label = str(transform) def update_input_transform_watermarks( - self, input_transform_watermarks: List[_TransformWatermarks]) -> None: + self, input_transform_watermarks: List['_TransformWatermarks']) -> None: with self._lock: self._input_transform_watermarks = input_transform_watermarks diff --git a/sdks/python/apache_beam/runners/interactive/options/capture_control.py b/sdks/python/apache_beam/runners/interactive/options/capture_control.py index c14deeeb3956..826b596bbc6d 100644 --- a/sdks/python/apache_beam/runners/interactive/options/capture_control.py +++ b/sdks/python/apache_beam/runners/interactive/options/capture_control.py @@ -25,6 +25,7 @@ import logging from datetime import timedelta +from typing import List from apache_beam.io.gcp.pubsub import ReadFromPubSub from apache_beam.runners.interactive import interactive_environment as ie @@ -45,7 +46,7 @@ def __init__(self): self._capture_size_limit = 1e9 self._test_limiters = None - def limiters(self) -> List[capture_limiters.Limiter]: + def limiters(self) -> List['capture_limiters.Limiter']: # noqa: F821 if self._test_limiters: return self._test_limiters @@ -55,7 +56,7 @@ def limiters(self) -> List[capture_limiters.Limiter]: ] def set_limiters_for_test( - self, limiters: List[capture_limiters.Limiter]) -> None: + self, limiters: List['capture_limiters.Limiter']) -> None: # noqa: F821 self._test_limiters = limiters diff --git a/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py b/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py index 3b2fb9f326ea..497772f94c36 100644 --- a/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py +++ b/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py @@ -20,7 +20,9 @@ For internal use only; no backwards-compatibility guarantees. 
""" +import datetime import threading +from typing import Any import pandas as pd diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager.py b/sdks/python/apache_beam/runners/interactive/recording_manager.py index a2470e693314..23c63884aa7c 100644 --- a/sdks/python/apache_beam/runners/interactive/recording_manager.py +++ b/sdks/python/apache_beam/runners/interactive/recording_manager.py @@ -19,6 +19,10 @@ import threading import time import warnings +from typing import Any +from typing import Dict +from typing import List +from typing import Union import pandas as pd @@ -75,21 +79,19 @@ def display_id(self, suffix: str) -> str: """Returns a unique id able to be displayed in a web browser.""" return utils.obfuscate(self._cache_key, suffix) - def is_computed(self) -> boolean: + def is_computed(self) -> bool: # noqa: F821 """Returns True if no more elements will be recorded.""" return self._pcoll in ie.current_env().computed_pcollections - def is_done(self) -> boolean: + def is_done(self) -> bool: # noqa: F821 """Returns True if no more new elements will be yielded.""" return self._done - def read(self, tail: boolean = True) -> Any: - # noqa: F821 - + def read(self, tail: bool = True) -> Any: """Reads the elements currently recorded.""" # Get the cache manager and wait until the file exists. @@ -147,7 +149,7 @@ def __init__( self, user_pipeline: beam.Pipeline, pcolls: List[beam.pvalue.PCollection], # noqa: F821 - result: beam.runner.PipelineResult, + result: 'beam.runner.PipelineResult', max_n: int, max_duration_secs: float, ): @@ -205,9 +207,7 @@ def _mark_all_computed(self) -> None: if self._result.state is PipelineState.DONE and self._set_computed: ie.current_env().mark_pcollection_computed(self._pcolls) - def is_computed(self) -> boolean: - # noqa: F821 - + def is_computed(self) -> bool: """Returns True if all PCollections are computed.""" return all(s.is_computed() for s in self._streams.values()) @@ -240,7 +240,7 @@ def wait_until_finish(self) -> None: self._mark_computed.join() return self._result.state - def describe(self) -> dict[str, int]: + def describe(self) -> Dict[str, int]: """Returns a dictionary describing the cache and recording.""" cache_manager = ie.current_env().get_cache_manager(self._user_pipeline) @@ -255,7 +255,7 @@ def __init__( self, user_pipeline: beam.Pipeline, pipeline_var: str = None, - test_limiters: list[Limiter] = None) -> None: + test_limiters: List['Limiter'] = None) -> None: # noqa: F821 self.user_pipeline: beam.Pipeline = user_pipeline @@ -336,7 +336,7 @@ def cancel(self: None) -> None: # evict the BCJ after they complete. ie.current_env().evict_background_caching_job(self.user_pipeline) - def describe(self) -> dict[str, int]: + def describe(self) -> Dict[str, int]: """Returns a dictionary describing the cache and recording.""" cache_manager = ie.current_env().get_cache_manager(self.user_pipeline) From 14c52d66ec653c87551e84bfc3cb8882b878757c Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Wed, 3 Jul 2024 16:37:46 -0700 Subject: [PATCH 27/29] Fix bad typing in PubSub tests. 
--- sdks/python/apache_beam/io/gcp/pubsub.py | 7 ++++--- sdks/python/apache_beam/io/gcp/pubsub_test.py | 10 +++++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py index 1c6cf31a48ce..32e7fbe5ed58 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub.py +++ b/sdks/python/apache_beam/io/gcp/pubsub.py @@ -37,6 +37,7 @@ from typing import NamedTuple from typing import Optional from typing import Tuple +from typing import Union from apache_beam import coders from apache_beam.io import iobase @@ -357,15 +358,15 @@ def message_to_proto_str(element: PubsubMessage) -> bytes: return element._to_proto_str(for_publish=True) @staticmethod - def bytes_to_proto_str(element: bytes) -> bytes: + def bytes_to_proto_str(element: Union[bytes, str]) -> bytes: msg = PubsubMessage(element, {}) return msg._to_proto_str(for_publish=True) def expand(self, pcoll): if self.with_attributes: - pcoll = pcoll | 'ToProtobuf' >> Map(self.message_to_proto_str) + pcoll = pcoll | 'ToProtobufX' >> Map(self.message_to_proto_str) else: - pcoll = pcoll | 'ToProtobuf' >> Map(self.bytes_to_proto_str) + pcoll = pcoll | 'ToProtobufY' >> Map(self.bytes_to_proto_str) pcoll.element_type = bytes return pcoll | Write(self._sink) diff --git a/sdks/python/apache_beam/io/gcp/pubsub_test.py b/sdks/python/apache_beam/io/gcp/pubsub_test.py index 7b4a4d5c93b9..f704338626ee 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub_test.py +++ b/sdks/python/apache_beam/io/gcp/pubsub_test.py @@ -391,6 +391,7 @@ def test_expand(self): pcoll = ( p | ReadFromPubSub('projects/fakeprj/topics/baz') + | beam.Map(lambda x: PubsubMessage(x)) | WriteToPubSub( 'projects/fakeprj/topics/a_topic', with_attributes=True) | beam.Map(lambda x: x)) @@ -875,7 +876,7 @@ def test_write_messages_with_attributes_error(self, mock_pubsub): options = PipelineOptions([]) options.view_as(StandardOptions).streaming = True - with self.assertRaisesRegex(AttributeError, r'str.*has no attribute.*data'): + with self.assertRaisesRegex(Exception, r'Type hint violation'): with TestPipeline(options=options) as p: _ = ( p @@ -897,7 +898,9 @@ def test_write_messages_unsupported_features(self, mock_pubsub): p | Create(payloads) | WriteToPubSub( - 'projects/fakeprj/topics/a_topic', id_label='a_label')) + 'projects/fakeprj/topics/a_topic', + id_label='a_label', + with_attributes=True)) options = PipelineOptions([]) options.view_as(StandardOptions).streaming = True @@ -909,7 +912,8 @@ def test_write_messages_unsupported_features(self, mock_pubsub): | Create(payloads) | WriteToPubSub( 'projects/fakeprj/topics/a_topic', - timestamp_attribute='timestamp')) + timestamp_attribute='timestamp', + with_attributes=True)) def test_runner_api_transformation(self, unused_mock_pubsub): sink = _PubSubSink( From 64e6194b9486e14e43fbd02624594238ae9de277 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Wed, 3 Jul 2024 16:45:36 -0700 Subject: [PATCH 28/29] Preserve existing linter comments. 
--- sdks/python/apache_beam/internal/module_test.py | 1 - sdks/python/apache_beam/runners/direct/test_stream_impl.py | 5 +++-- .../apache_beam/runners/interactive/recording_manager.py | 5 +---- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/sdks/python/apache_beam/internal/module_test.py b/sdks/python/apache_beam/internal/module_test.py index 1bb5a9d424b9..ff0ad0c564e6 100644 --- a/sdks/python/apache_beam/internal/module_test.py +++ b/sdks/python/apache_beam/internal/module_test.py @@ -21,7 +21,6 @@ import re import sys - from typing import Any diff --git a/sdks/python/apache_beam/runners/direct/test_stream_impl.py b/sdks/python/apache_beam/runners/direct/test_stream_impl.py index dce161bb0b99..c720418b05ed 100644 --- a/sdks/python/apache_beam/runners/direct/test_stream_impl.py +++ b/sdks/python/apache_beam/runners/direct/test_stream_impl.py @@ -30,6 +30,7 @@ from queue import Empty as EmptyException from queue import Queue from threading import Thread +from typing import Union import grpc @@ -309,8 +310,8 @@ def is_alive(): return not (shutdown_requested or evaluation_context.shutdown_requested) # The shared queue that allows the producer and consumer to communicate. - channel: 'Queue[Union[test_stream.Event, _EndOfStream]]' = ( - Queue()) # noqa: F821 + channel: 'Queue[Union[test_stream.Event, _EndOfStream]]' = ( # noqa: F821 + Queue()) event_stream = Thread( target=_TestStream._stream_events_from_rpc, args=(endpoint, output_tags, coder, channel, is_alive)) diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager.py b/sdks/python/apache_beam/runners/interactive/recording_manager.py index 23c63884aa7c..2e113240c09c 100644 --- a/sdks/python/apache_beam/runners/interactive/recording_manager.py +++ b/sdks/python/apache_beam/runners/interactive/recording_manager.py @@ -255,8 +255,7 @@ def __init__( self, user_pipeline: beam.Pipeline, pipeline_var: str = None, - test_limiters: List['Limiter'] = None) -> None: - # noqa: F821 + test_limiters: List['Limiter'] = None) -> None: # noqa: F821 self.user_pipeline: beam.Pipeline = user_pipeline self.pipeline_var: str = pipeline_var if pipeline_var else '' @@ -265,8 +264,6 @@ def __init__( self._test_limiters = test_limiters if test_limiters else [] def _watch(self, pcolls: List[beam.pvalue.PCollection]) -> None: - # noqa: F821 - """Watch any pcollections not being watched. 
This allows for the underlying caching layer to identify the PCollection as From a0ba8dea7d829f6f4ea869238fa8f281b371d3cc Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Mon, 8 Jul 2024 16:16:39 -0700 Subject: [PATCH 29/29] isort --- sdks/python/apache_beam/runners/direct/watermark_manager.py | 2 +- sdks/python/apache_beam/runners/pipeline_context.py | 4 ++-- .../runners/portability/fn_api_runner/execution.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/runners/direct/watermark_manager.py b/sdks/python/apache_beam/runners/direct/watermark_manager.py index 854d6475de04..666ade6cf82d 100644 --- a/sdks/python/apache_beam/runners/direct/watermark_manager.py +++ b/sdks/python/apache_beam/runners/direct/watermark_manager.py @@ -28,8 +28,8 @@ from typing import Tuple from apache_beam import pipeline -from apache_beam.pipeline import AppliedPTransform from apache_beam import pvalue +from apache_beam.pipeline import AppliedPTransform from apache_beam.runners.direct.util import TimerFiring from apache_beam.utils.timestamp import MAX_TIMESTAMP from apache_beam.utils.timestamp import MIN_TIMESTAMP diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py index 6af40e10332d..0a03c96bc19b 100644 --- a/sdks/python/apache_beam/runners/pipeline_context.py +++ b/sdks/python/apache_beam/runners/pipeline_context.py @@ -38,10 +38,10 @@ from typing_extensions import Protocol from apache_beam import coders -from apache_beam.coders.coder_impl import IterableStateReader -from apache_beam.coders.coder_impl import IterableStateWriter from apache_beam import pipeline from apache_beam import pvalue +from apache_beam.coders.coder_impl import IterableStateReader +from apache_beam.coders.coder_impl import IterableStateWriter from apache_beam.internal import pickler from apache_beam.pipeline import ComponentIdMap from apache_beam.portability.api import beam_fn_api_pb2 diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py index 4976ff78aa99..e69e37495f64 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py @@ -48,19 +48,19 @@ from apache_beam import coders from apache_beam.coders.coder_impl import CoderImpl +from apache_beam.coders.coder_impl import WindowedValueCoderImpl from apache_beam.coders.coder_impl import create_InputStream from apache_beam.coders.coder_impl import create_OutputStream -from apache_beam.coders.coder_impl import WindowedValueCoderImpl from apache_beam.coders.coders import WindowedValueCoder from apache_beam.portability import common_urns from apache_beam.portability import python_urns from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.portability.api import beam_runner_api_pb2 +from apache_beam.portability.api import endpoints_pb2 from apache_beam.runners import pipeline_context from apache_beam.runners.common import ENCODED_IMPULSE_VALUE from apache_beam.runners.direct.clock import RealClock from apache_beam.runners.direct.clock import TestClock -from apache_beam.portability.api import endpoints_pb2 from apache_beam.runners.portability.fn_api_runner import translations from apache_beam.runners.portability.fn_api_runner.translations import DataInput from apache_beam.runners.portability.fn_api_runner.translations import DataOutput