diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 0f1700f52486..89c137fe4366 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -45,6 +45,7 @@ from typing import Iterator from typing import List from typing import Mapping +from typing import MutableMapping from typing import Optional from typing import Set from typing import Tuple @@ -130,18 +131,16 @@ class RunnerIOOperation(operations.Operation): """Common baseclass for runner harness IO operations.""" - - def __init__(self, - name_context, # type: common.NameContext - step_name, # type: Any - consumers, # type: Mapping[Any, Iterable[operations.Operation]] - counter_factory, # type: counters.CounterFactory - state_sampler, # type: statesampler.StateSampler - windowed_coder, # type: coders.Coder - transform_id, # type: str - data_channel # type: data_plane.DataChannel - ): - # type: (...) -> None + def __init__( + self, + name_context: common.NameContext, + step_name: Any, + consumers: Mapping[Any, Iterable[operations.Operation]], + counter_factory: counters.CounterFactory, + state_sampler: statesampler.StateSampler, + windowed_coder: coders.Coder, + transform_id: str, + data_channel: data_plane.DataChannel) -> None: super().__init__(name_context, None, counter_factory, state_sampler) self.windowed_coder = windowed_coder self.windowed_coder_impl = windowed_coder.get_impl() @@ -157,36 +156,32 @@ def __init__(self, class DataOutputOperation(RunnerIOOperation): """A sink-like operation that gathers outputs to be sent back to the runner. """ - def set_output_stream(self, output_stream): - # type: (data_plane.ClosableOutputStream) -> None + def set_output_stream( + self, output_stream: data_plane.ClosableOutputStream) -> None: self.output_stream = output_stream - def process(self, windowed_value): - # type: (windowed_value.WindowedValue) -> None + def process(self, windowed_value: windowed_value.WindowedValue) -> None: self.windowed_coder_impl.encode_to_stream( windowed_value, self.output_stream, True) self.output_stream.maybe_flush() - def finish(self): - # type: () -> None + def finish(self) -> None: super().finish() self.output_stream.close() class DataInputOperation(RunnerIOOperation): """A source-like operation that gathers input from the runner.""" - - def __init__(self, - operation_name, # type: common.NameContext - step_name, - consumers, # type: Mapping[Any, List[operations.Operation]] - counter_factory, # type: counters.CounterFactory - state_sampler, # type: statesampler.StateSampler - windowed_coder, # type: coders.Coder - transform_id, - data_channel # type: data_plane.GrpcClientDataChannel - ): - # type: (...) -> None + def __init__( + self, + operation_name: common.NameContext, + step_name, + consumers: Mapping[Any, List[operations.Operation]], + counter_factory: counters.CounterFactory, + state_sampler: statesampler.StateSampler, + windowed_coder: coders.Coder, + transform_id, + data_channel: data_plane.GrpcClientDataChannel) -> None: super().__init__( operation_name, step_name, @@ -217,18 +212,15 @@ def setup(self, data_sampler=None): producer_batch_converter=self.get_output_batch_converter()) ] - def start(self): - # type: () -> None + def start(self) -> None: super().start() with self.splitting_lock: self.started = True - def process(self, windowed_value): - # type: (windowed_value.WindowedValue) -> None + def process(self, windowed_value: windowed_value.WindowedValue) -> None: self.output(windowed_value) - def process_encoded(self, encoded_windowed_values): - # type: (bytes) -> None + def process_encoded(self, encoded_windowed_values: bytes) -> None: input_stream = coder_impl.create_InputStream(encoded_windowed_values) while input_stream.size() > 0: with self.splitting_lock: @@ -244,8 +236,9 @@ def process_encoded(self, encoded_windowed_values): str(self.windowed_coder)) from exn self.output(decoded_value) - def monitoring_infos(self, transform_id, tag_to_pcollection_id): - # type: (str, Dict[str, str]) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo] + def monitoring_infos( + self, transform_id: str, tag_to_pcollection_id: Dict[str, str] + ) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]: all_monitoring_infos = super().monitoring_infos( transform_id, tag_to_pcollection_id) read_progress_info = monitoring_infos.int64_counter( @@ -259,8 +252,13 @@ def monitoring_infos(self, transform_id, tag_to_pcollection_id): # TODO(https://github.com/apache/beam/issues/19737): typing not compatible # with super type def try_split( # type: ignore[override] - self, fraction_of_remainder, total_buffer_size, allowed_split_points): - # type: (...) -> Optional[Tuple[int, Iterable[operations.SdfSplitResultsPrimary], Iterable[operations.SdfSplitResultsResidual], int]] + self, fraction_of_remainder, total_buffer_size, allowed_split_points + ) -> Optional[ + Tuple[ + int, + Iterable[operations.SdfSplitResultsPrimary], + Iterable[operations.SdfSplitResultsResidual], + int]]: with self.splitting_lock: if not self.started: return None @@ -314,9 +312,10 @@ def is_valid_split_point(index): # try splitting at the current element. if (keep_of_element_remainder < 1 and is_valid_split_point(index) and is_valid_split_point(index + 1)): - split = try_split( - keep_of_element_remainder - ) # type: Optional[Tuple[Iterable[operations.SdfSplitResultsPrimary], Iterable[operations.SdfSplitResultsResidual]]] + split: Optional[Tuple[ + Iterable[operations.SdfSplitResultsPrimary], + Iterable[operations.SdfSplitResultsResidual]]] = try_split( + keep_of_element_remainder) if split: element_primaries, element_residuals = split return index - 1, element_primaries, element_residuals, index + 1 @@ -343,15 +342,13 @@ def is_valid_split_point(index): else: return None - def finish(self): - # type: () -> None + def finish(self) -> None: super().finish() with self.splitting_lock: self.index += 1 self.started = False - def reset(self): - # type: () -> None + def reset(self) -> None: with self.splitting_lock: self.index = -1 self.stop = float('inf') @@ -359,12 +356,12 @@ def reset(self): class _StateBackedIterable(object): - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - state_key, # type: beam_fn_api_pb2.StateKey - coder_or_impl, # type: Union[coders.Coder, coder_impl.CoderImpl] - ): - # type: (...) -> None + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + state_key: beam_fn_api_pb2.StateKey, + coder_or_impl: Union[coders.Coder, coder_impl.CoderImpl], + ) -> None: self._state_handler = state_handler self._state_key = state_key if isinstance(coder_or_impl, coders.Coder): @@ -372,8 +369,7 @@ def __init__(self, else: self._coder_impl = coder_or_impl - def __iter__(self): - # type: () -> Iterator[Any] + def __iter__(self) -> Iterator[Any]: return iter( self._state_handler.blocking_get(self._state_key, self._coder_impl)) @@ -391,15 +387,15 @@ class StateBackedSideInputMap(object): _BULK_READ_FULLY = "fully" _BULK_READ_PARTIALLY = "partially" - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - transform_id, # type: str - tag, # type: Optional[str] - side_input_data, # type: pvalue.SideInputData - coder, # type: WindowedValueCoder - use_bulk_read = False, # type: bool - ): - # type: (...) -> None + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + transform_id: str, + tag: Optional[str], + side_input_data: pvalue.SideInputData, + coder: WindowedValueCoder, + use_bulk_read: bool = False, + ) -> None: self._state_handler = state_handler self._transform_id = transform_id self._tag = tag @@ -407,7 +403,7 @@ def __init__(self, self._element_coder = coder.wrapped_value_coder self._target_window_coder = coder.window_coder # TODO(robertwb): Limit the cache size. - self._cache = {} # type: Dict[BoundedWindow, Any] + self._cache: Dict[BoundedWindow, Any] = {} self._use_bulk_read = use_bulk_read def __getitem__(self, window): @@ -503,14 +499,12 @@ def __reduce__(self): self._cache[target_window] = self._side_input_data.view_fn(raw_view) return self._cache[target_window] - def is_globally_windowed(self): - # type: () -> bool + def is_globally_windowed(self) -> bool: return ( self._side_input_data.window_mapping_fn == sideinputs._global_window_mapping_fn) - def reset(self): - # type: () -> None + def reset(self) -> None: # TODO(BEAM-5428): Cross-bundle caching respecting cache tokens. self._cache = {} @@ -519,26 +513,28 @@ class ReadModifyWriteRuntimeState(userstate.ReadModifyWriteRuntimeState): def __init__(self, underlying_bag_state): self._underlying_bag_state = underlying_bag_state - def read(self): # type: () -> Any + def read(self) -> Any: values = list(self._underlying_bag_state.read()) if not values: return None return values[0] - def write(self, value): # type: (Any) -> None + def write(self, value: Any) -> None: self.clear() self._underlying_bag_state.add(value) - def clear(self): # type: () -> None + def clear(self) -> None: self._underlying_bag_state.clear() - def commit(self): # type: () -> None + def commit(self) -> None: self._underlying_bag_state.commit() class CombiningValueRuntimeState(userstate.CombiningValueRuntimeState): - def __init__(self, underlying_bag_state, combinefn): - # type: (userstate.AccumulatingRuntimeState, core.CombineFn) -> None + def __init__( + self, + underlying_bag_state: userstate.AccumulatingRuntimeState, + combinefn: core.CombineFn) -> None: self._combinefn = combinefn self._combinefn.setup() self._underlying_bag_state = underlying_bag_state @@ -552,12 +548,10 @@ def _read_accumulator(self, rewrite=True): self._underlying_bag_state.add(merged_accumulator) return merged_accumulator - def read(self): - # type: () -> Iterable[Any] + def read(self) -> Iterable[Any]: return self._combinefn.extract_output(self._read_accumulator()) - def add(self, value): - # type: (Any) -> None + def add(self, value: Any) -> None: # Prefer blind writes, but don't let them grow unboundedly. # This should be tuned to be much lower, but for now exercise # both paths well. @@ -569,8 +563,7 @@ def add(self, value): self._underlying_bag_state.add( self._combinefn.add_input(accumulator, value)) - def clear(self): - # type: () -> None + def clear(self) -> None: self._underlying_bag_state.clear() def commit(self): @@ -587,13 +580,11 @@ class _ConcatIterable(object): Unlike itertools.chain, this allows reiteration. """ - def __init__(self, first, second): - # type: (Iterable[Any], Iterable[Any]) -> None + def __init__(self, first: Iterable[Any], second: Iterable[Any]) -> None: self.first = first self.second = second - def __iter__(self): - # type: () -> Iterator[Any] + def __iter__(self) -> Iterator[Any]: for elem in self.first: yield elem for elem in self.second: @@ -604,38 +595,32 @@ def __iter__(self): class SynchronousBagRuntimeState(userstate.BagRuntimeState): - - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - state_key, # type: beam_fn_api_pb2.StateKey - value_coder # type: coders.Coder - ): - # type: (...) -> None + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + state_key: beam_fn_api_pb2.StateKey, + value_coder: coders.Coder) -> None: self._state_handler = state_handler self._state_key = state_key self._value_coder = value_coder self._cleared = False - self._added_elements = [] # type: List[Any] + self._added_elements: List[Any] = [] - def read(self): - # type: () -> Iterable[Any] + def read(self) -> Iterable[Any]: return _ConcatIterable([] if self._cleared else cast( 'Iterable[Any]', _StateBackedIterable( self._state_handler, self._state_key, self._value_coder)), self._added_elements) - def add(self, value): - # type: (Any) -> None + def add(self, value: Any) -> None: self._added_elements.append(value) - def clear(self): - # type: () -> None + def clear(self) -> None: self._cleared = True self._added_elements = [] - def commit(self): - # type: () -> None + def commit(self) -> None: to_await = None if self._cleared: to_await = self._state_handler.clear(self._state_key) @@ -648,18 +633,16 @@ def commit(self): class SynchronousSetRuntimeState(userstate.SetRuntimeState): - - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - state_key, # type: beam_fn_api_pb2.StateKey - value_coder # type: coders.Coder - ): - # type: (...) -> None + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + state_key: beam_fn_api_pb2.StateKey, + value_coder: coders.Coder) -> None: self._state_handler = state_handler self._state_key = state_key self._value_coder = value_coder self._cleared = False - self._added_elements = set() # type: Set[Any] + self._added_elements: Set[Any] = set() def _compact_data(self, rewrite=True): accumulator = set( @@ -679,12 +662,10 @@ def _compact_data(self, rewrite=True): return accumulator - def read(self): - # type: () -> Set[Any] + def read(self) -> Set[Any]: return self._compact_data(rewrite=False) - def add(self, value): - # type: (Any) -> None + def add(self, value: Any) -> None: if self._cleared: # This is a good time explicitly clear. self._state_handler.clear(self._state_key) @@ -694,13 +675,11 @@ def add(self, value): if random.random() > 0.5: self._compact_data() - def clear(self): - # type: () -> None + def clear(self) -> None: self._cleared = True self._added_elements = set() - def commit(self): - # type: () -> None + def commit(self) -> None: to_await = None if self._cleared: to_await = self._state_handler.clear(self._state_key) @@ -887,16 +866,16 @@ def commit(self) -> None: class OutputTimer(userstate.BaseTimer): - def __init__(self, - key, - window, # type: BoundedWindow - timestamp, # type: timestamp.Timestamp - paneinfo, # type: windowed_value.PaneInfo - time_domain, # type: str - timer_family_id, # type: str - timer_coder_impl, # type: coder_impl.TimerCoderImpl - output_stream # type: data_plane.ClosableOutputStream - ): + def __init__( + self, + key, + window: BoundedWindow, + timestamp: timestamp.Timestamp, + paneinfo: windowed_value.PaneInfo, + time_domain: str, + timer_family_id: str, + timer_coder_impl: coder_impl.TimerCoderImpl, + output_stream: data_plane.ClosableOutputStream): self._key = key self._window = window self._input_timestamp = timestamp @@ -942,15 +921,13 @@ def __init__(self, timer_coder_impl, output_stream=None): class FnApiUserStateContext(userstate.UserStateContext): """Interface for state and timers from SDK to Fn API servicer of state..""" - - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - transform_id, # type: str - key_coder, # type: coders.Coder - window_coder, # type: coders.Coder - ): - # type: (...) -> None - + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + transform_id: str, + key_coder: coders.Coder, + window_coder: coders.Coder, + ) -> None: """Initialize a ``FnApiUserStateContext``. Args: @@ -964,11 +941,10 @@ def __init__(self, self._key_coder = key_coder self._window_coder = window_coder # A mapping of {timer_family_id: TimerInfo} - self._timers_info = {} # type: Dict[str, TimerInfo] - self._all_states = {} # type: Dict[tuple, FnApiUserRuntimeStateTypes] + self._timers_info: Dict[str, TimerInfo] = {} + self._all_states: Dict[tuple, FnApiUserRuntimeStateTypes] = {} - def add_timer_info(self, timer_family_id, timer_info): - # type: (str, TimerInfo) -> None + def add_timer_info(self, timer_family_id: str, timer_info: TimerInfo) -> None: self._timers_info[timer_family_id] = timer_info def get_timer( @@ -987,19 +963,15 @@ def get_timer( timer_coder_impl, output_stream) - def get_state(self, *args): - # type: (*Any) -> FnApiUserRuntimeStateTypes + def get_state(self, *args: Any) -> FnApiUserRuntimeStateTypes: state_handle = self._all_states.get(args) if state_handle is None: state_handle = self._all_states[args] = self._create_state(*args) return state_handle - def _create_state(self, - state_spec, # type: userstate.StateSpec - key, - window # type: BoundedWindow - ): - # type: (...) -> FnApiUserRuntimeStateTypes + def _create_state( + self, state_spec: userstate.StateSpec, key, + window: BoundedWindow) -> FnApiUserRuntimeStateTypes: if isinstance(state_spec, (userstate.BagStateSpec, userstate.CombiningValueStateSpec, @@ -1046,13 +1018,11 @@ def _create_state(self, else: raise NotImplementedError(state_spec) - def commit(self): - # type: () -> None + def commit(self) -> None: for state in self._all_states.values(): state.commit() - def reset(self): - # type: () -> None + def reset(self) -> None: for state in self._all_states.values(): state.finalize() self._all_states = {} @@ -1071,14 +1041,12 @@ def wrapper(*args): return wrapper -def only_element(iterable): - # type: (Iterable[T]) -> T +def only_element(iterable: Iterable[T]) -> T: element, = iterable return element -def _environments_compatible(submission, runtime): - # type: (str, str) -> bool +def _environments_compatible(submission: str, runtime: str) -> bool: if submission == runtime: return True if 'rc' in submission and runtime in submission: @@ -1088,8 +1056,8 @@ def _environments_compatible(submission, runtime): return False -def _verify_descriptor_created_in_a_compatible_env(process_bundle_descriptor): - # type: (beam_fn_api_pb2.ProcessBundleDescriptor) -> None +def _verify_descriptor_created_in_a_compatible_env( + process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor) -> None: runtime_sdk = environments.sdk_base_version_capability() for t in process_bundle_descriptor.transforms.values(): @@ -1111,16 +1079,14 @@ def _verify_descriptor_created_in_a_compatible_env(process_bundle_descriptor): class BundleProcessor(object): """ A class for processing bundles of elements. """ - - def __init__(self, - runner_capabilities, # type: FrozenSet[str] - process_bundle_descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor - state_handler, # type: sdk_worker.CachingStateHandler - data_channel_factory, # type: data_plane.DataChannelFactory - data_sampler=None, # type: Optional[data_sampler.DataSampler] - ): - # type: (...) -> None - + def __init__( + self, + runner_capabilities: FrozenSet[str], + process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, + state_handler: sdk_worker.CachingStateHandler, + data_channel_factory: data_plane.DataChannelFactory, + data_sampler: Optional[data_sampler.DataSampler] = None, + ) -> None: """Initialize a bundle processor. Args: @@ -1136,7 +1102,7 @@ def __init__(self, self.state_handler = state_handler self.data_channel_factory = data_channel_factory self.data_sampler = data_sampler - self.current_instruction_id = None # type: Optional[str] + self.current_instruction_id: Optional[str] = None # Represents whether the SDK is consuming received data. self.consuming_received_data = False @@ -1155,7 +1121,7 @@ def __init__(self, # {(transform_id, timer_family_id): TimerInfo} # The mapping is empty when there is no timer_family_specs in the # ProcessBundleDescriptor. - self.timers_info = {} # type: Dict[Tuple[str, str], TimerInfo] + self.timers_info: Dict[Tuple[str, str], TimerInfo] = {} # TODO(robertwb): Figure out the correct prefix to use for output counters # from StateSampler. @@ -1170,10 +1136,8 @@ def __init__(self, self.splitting_lock = threading.Lock() def create_execution_tree( - self, - descriptor # type: beam_fn_api_pb2.ProcessBundleDescriptor - ): - # type: (...) -> collections.OrderedDict[str, operations.DoOperation] + self, descriptor: beam_fn_api_pb2.ProcessBundleDescriptor + ) -> collections.OrderedDict[str, operations.DoOperation]: transform_factory = BeamTransformFactory( self.runner_capabilities, descriptor, @@ -1192,16 +1156,14 @@ def is_side_input(transform_proto, tag): transform_proto.spec.payload, beam_runner_api_pb2.ParDoPayload).side_inputs - pcoll_consumers = collections.defaultdict( - list) # type: DefaultDict[str, List[str]] + pcoll_consumers: DefaultDict[str, List[str]] = collections.defaultdict(list) for transform_id, transform_proto in descriptor.transforms.items(): for tag, pcoll_id in transform_proto.inputs.items(): if not is_side_input(transform_proto, tag): pcoll_consumers[pcoll_id].append(transform_id) @memoize - def get_operation(transform_id): - # type: (str) -> operations.Operation + def get_operation(transform_id: str) -> operations.Operation: transform_consumers = { tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]] for tag, @@ -1218,8 +1180,7 @@ def get_operation(transform_id): # Operations must be started (hence returned) in order. @memoize - def topological_height(transform_id): - # type: (str) -> int + def topological_height(transform_id: str) -> int: return 1 + max([0] + [ topological_height(consumer) for pcoll in descriptor.transforms[transform_id].outputs.values() @@ -1232,18 +1193,18 @@ def topological_height(transform_id): get_operation(transform_id))) for transform_id in sorted( descriptor.transforms, key=topological_height, reverse=True)]) - def reset(self): - # type: () -> None + def reset(self) -> None: self.counter_factory.reset() self.state_sampler.reset() # Side input caches. for op in self.ops.values(): op.reset() - def process_bundle(self, instruction_id): - # type: (str) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool] + def process_bundle( + self, instruction_id: str + ) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool]: - expected_input_ops = [] # type: List[DataInputOperation] + expected_input_ops: List[DataInputOperation] = [] for op in self.ops.values(): if isinstance(op, DataOutputOperation): @@ -1269,9 +1230,10 @@ def process_bundle(self, instruction_id): # both data input and timer input. The data input is identied by # transform_id. The data input is identified by # (transform_id, timer_family_id). - data_channels = collections.defaultdict( - list - ) # type: DefaultDict[data_plane.DataChannel, List[Union[str, Tuple[str, str]]]] + data_channels: DefaultDict[data_plane.DataChannel, + List[Union[str, Tuple[ + str, + str]]]] = collections.defaultdict(list) # Add expected data inputs for each data channel. input_op_by_transform_id = {} @@ -1337,18 +1299,17 @@ def process_bundle(self, instruction_id): self.current_instruction_id = None self.state_sampler.stop_if_still_running() - def finalize_bundle(self): - # type: () -> beam_fn_api_pb2.FinalizeBundleResponse + def finalize_bundle(self) -> beam_fn_api_pb2.FinalizeBundleResponse: for op in self.ops.values(): op.finalize_bundle() return beam_fn_api_pb2.FinalizeBundleResponse() - def requires_finalization(self): - # type: () -> bool + def requires_finalization(self) -> bool: return any(op.needs_finalization() for op in self.ops.values()) - def try_split(self, bundle_split_request): - # type: (beam_fn_api_pb2.ProcessBundleSplitRequest) -> beam_fn_api_pb2.ProcessBundleSplitResponse + def try_split( + self, bundle_split_request: beam_fn_api_pb2.ProcessBundleSplitRequest + ) -> beam_fn_api_pb2.ProcessBundleSplitResponse: split_response = beam_fn_api_pb2.ProcessBundleSplitResponse() with self.splitting_lock: if bundle_split_request.instruction_id != self.current_instruction_id: @@ -1386,20 +1347,18 @@ def try_split(self, bundle_split_request): return split_response - def delayed_bundle_application(self, - op, # type: operations.DoOperation - deferred_remainder # type: SplitResultResidual - ): - # type: (...) -> beam_fn_api_pb2.DelayedBundleApplication + def delayed_bundle_application( + self, op: operations.DoOperation, deferred_remainder: SplitResultResidual + ) -> beam_fn_api_pb2.DelayedBundleApplication: assert op.input_info is not None # TODO(SDF): For non-root nodes, need main_input_coder + residual_coder. (element_and_restriction, current_watermark, deferred_timestamp) = ( deferred_remainder) if deferred_timestamp: assert isinstance(deferred_timestamp, timestamp.Duration) - proto_deferred_watermark = proto_utils.from_micros( - duration_pb2.Duration, - deferred_timestamp.micros) # type: Optional[duration_pb2.Duration] + proto_deferred_watermark: Optional[ + duration_pb2.Duration] = proto_utils.from_micros( + duration_pb2.Duration, deferred_timestamp.micros) else: proto_deferred_watermark = None return beam_fn_api_pb2.DelayedBundleApplication( @@ -1407,29 +1366,26 @@ def delayed_bundle_application(self, application=self.construct_bundle_application( op.input_info, current_watermark, element_and_restriction)) - def bundle_application(self, - op, # type: operations.DoOperation - primary # type: SplitResultPrimary - ): - # type: (...) -> beam_fn_api_pb2.BundleApplication + def bundle_application( + self, op: operations.DoOperation, + primary: SplitResultPrimary) -> beam_fn_api_pb2.BundleApplication: assert op.input_info is not None return self.construct_bundle_application( op.input_info, None, primary.primary_value) - def construct_bundle_application(self, - op_input_info, # type: operations.OpInputInfo - output_watermark, # type: Optional[timestamp.Timestamp] - element - ): - # type: (...) -> beam_fn_api_pb2.BundleApplication + def construct_bundle_application( + self, + op_input_info: operations.OpInputInfo, + output_watermark: Optional[timestamp.Timestamp], + element) -> beam_fn_api_pb2.BundleApplication: transform_id, main_input_tag, main_input_coder, outputs = op_input_info if output_watermark: proto_output_watermark = proto_utils.from_micros( timestamp_pb2.Timestamp, output_watermark.micros) - output_watermarks = { + output_watermarks: Optional[Dict[str, timestamp_pb2.Timestamp]] = { output: proto_output_watermark for output in outputs - } # type: Optional[Dict[str, timestamp_pb2.Timestamp]] + } else: output_watermarks = None return beam_fn_api_pb2.BundleApplication( @@ -1438,9 +1394,7 @@ def construct_bundle_application(self, output_watermarks=output_watermarks, element=main_input_coder.get_impl().encode_nested(element)) - def monitoring_infos(self): - # type: () -> List[metrics_pb2.MonitoringInfo] - + def monitoring_infos(self) -> List[metrics_pb2.MonitoringInfo]: """Returns the list of MonitoringInfos collected processing this bundle.""" # Construct a new dict first to remove duplicates. all_monitoring_infos_dict = {} @@ -1452,8 +1406,7 @@ def monitoring_infos(self): return list(all_monitoring_infos_dict.values()) - def shutdown(self): - # type: () -> None + def shutdown(self) -> None: for op in self.ops.values(): op.teardown() @@ -1474,15 +1427,16 @@ class ExecutionContext: class BeamTransformFactory(object): """Factory for turning transform_protos into executable operations.""" - def __init__(self, - runner_capabilities, # type: FrozenSet[str] - descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor - data_channel_factory, # type: data_plane.DataChannelFactory - counter_factory, # type: counters.CounterFactory - state_sampler, # type: statesampler.StateSampler - state_handler, # type: sdk_worker.CachingStateHandler - data_sampler, # type: Optional[data_sampler.DataSampler] - ): + def __init__( + self, + runner_capabilities: FrozenSet[str], + descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, + data_channel_factory: data_plane.DataChannelFactory, + counter_factory: counters.CounterFactory, + state_sampler: statesampler.StateSampler, + state_handler: sdk_worker.CachingStateHandler, + data_sampler: Optional[data_sampler.DataSampler], + ): self.runner_capabilities = runner_capabilities self.descriptor = descriptor self.data_channel_factory = data_channel_factory @@ -1499,27 +1453,41 @@ def __init__(self, element_coder_impl)) self.data_sampler = data_sampler - _known_urns = { - } # type: Dict[str, Tuple[ConstructorFn, Union[Type[message.Message], Type[bytes], None]]] + _known_urns: Dict[str, + Tuple[ConstructorFn, + Union[Type[message.Message], Type[bytes], + None]]] = {} @classmethod def register_urn( - cls, - urn, # type: str - parameter_type # type: Optional[Type[T]] - ): - # type: (...) -> Callable[[Callable[[BeamTransformFactory, str, beam_runner_api_pb2.PTransform, T, Dict[str, List[operations.Operation]]], operations.Operation]], Callable[[BeamTransformFactory, str, beam_runner_api_pb2.PTransform, T, Dict[str, List[operations.Operation]]], operations.Operation]] + cls, urn: str, parameter_type: Optional[Type[T]] + ) -> Callable[[ + Callable[[ + BeamTransformFactory, + str, + beam_runner_api_pb2.PTransform, + T, + Dict[str, List[operations.Operation]] + ], + operations.Operation] + ], + Callable[[ + BeamTransformFactory, + str, + beam_runner_api_pb2.PTransform, + T, + Dict[str, List[operations.Operation]] + ], + operations.Operation]]: def wrapper(func): cls._known_urns[urn] = func, parameter_type return func return wrapper - def create_operation(self, - transform_id, # type: str - consumers # type: Dict[str, List[operations.Operation]] - ): - # type: (...) -> operations.Operation + def create_operation( + self, transform_id: str, + consumers: Dict[str, List[operations.Operation]]) -> operations.Operation: transform_proto = self.descriptor.transforms[transform_id] if not transform_proto.unique_name: _LOGGER.debug("No unique name set for transform %s" % transform_id) @@ -1529,8 +1497,7 @@ def create_operation(self, transform_proto.spec.payload, parameter_type) return creator(self, transform_id, transform_proto, payload, consumers) - def extract_timers_info(self): - # type: () -> Dict[Tuple[str, str], TimerInfo] + def extract_timers_info(self) -> Dict[Tuple[str, str], TimerInfo]: timers_info = {} for transform_id, transform_proto in self.descriptor.transforms.items(): if transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn: @@ -1545,8 +1512,7 @@ def extract_timers_info(self): timer_coder_impl=timer_coder_impl) return timers_info - def get_coder(self, coder_id): - # type: (str) -> coders.Coder + def get_coder(self, coder_id: str) -> coders.Coder: if coder_id not in self.descriptor.coders: raise KeyError("No such coder: %s" % coder_id) coder_proto = self.descriptor.coders[coder_id] @@ -1557,8 +1523,7 @@ def get_coder(self, coder_id): return operation_specs.get_coder_from_spec( json.loads(coder_proto.spec.payload.decode('utf-8'))) - def get_windowed_coder(self, pcoll_id): - # type: (str) -> WindowedValueCoder + def get_windowed_coder(self, pcoll_id: str) -> WindowedValueCoder: coder = self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id) # TODO(robertwb): Remove this condition once all runners are consistent. if not isinstance(coder, WindowedValueCoder): @@ -1569,32 +1534,34 @@ def get_windowed_coder(self, pcoll_id): else: return coder - def get_output_coders(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> Dict[str, coders.Coder] + def get_output_coders( + self, transform_proto: beam_runner_api_pb2.PTransform + ) -> Dict[str, coders.Coder]: return { tag: self.get_windowed_coder(pcoll_id) for tag, pcoll_id in transform_proto.outputs.items() } - def get_only_output_coder(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> coders.Coder + def get_only_output_coder( + self, transform_proto: beam_runner_api_pb2.PTransform) -> coders.Coder: return only_element(self.get_output_coders(transform_proto).values()) - def get_input_coders(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> Dict[str, coders.WindowedValueCoder] + def get_input_coders( + self, transform_proto: beam_runner_api_pb2.PTransform + ) -> Dict[str, coders.WindowedValueCoder]: return { tag: self.get_windowed_coder(pcoll_id) for tag, pcoll_id in transform_proto.inputs.items() } - def get_only_input_coder(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> coders.Coder + def get_only_input_coder( + self, transform_proto: beam_runner_api_pb2.PTransform) -> coders.Coder: return only_element(list(self.get_input_coders(transform_proto).values())) - def get_input_windowing(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> Windowing + def get_input_windowing( + self, transform_proto: beam_runner_api_pb2.PTransform) -> Windowing: pcoll_id = only_element(transform_proto.inputs.values()) windowing_strategy_id = self.descriptor.pcollections[ pcoll_id].windowing_strategy_id @@ -1603,12 +1570,10 @@ def get_input_windowing(self, transform_proto): # TODO(robertwb): Update all operations to take these in the constructor. @staticmethod def augment_oldstyle_op( - op, # type: OperationT - step_name, # type: str - consumers, # type: Mapping[str, Iterable[operations.Operation]] - tag_list=None # type: Optional[List[str]] - ): - # type: (...) -> OperationT + op: OperationT, + step_name: str, + consumers: Mapping[str, Iterable[operations.Operation]], + tag_list: Optional[List[str]] = None) -> OperationT: op.step_name = step_name for tag, op_consumers in consumers.items(): for consumer in op_consumers: @@ -1619,13 +1584,11 @@ def augment_oldstyle_op( @BeamTransformFactory.register_urn( DATA_INPUT_URN, beam_fn_api_pb2.RemoteGrpcPort) def create_source_runner( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - grpc_port, # type: beam_fn_api_pb2.RemoteGrpcPort - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> DataInputOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + grpc_port: beam_fn_api_pb2.RemoteGrpcPort, + consumers: Dict[str, List[operations.Operation]]) -> DataInputOperation: output_coder = factory.get_coder(grpc_port.coder_id) return DataInputOperation( @@ -1642,13 +1605,11 @@ def create_source_runner( @BeamTransformFactory.register_urn( DATA_OUTPUT_URN, beam_fn_api_pb2.RemoteGrpcPort) def create_sink_runner( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - grpc_port, # type: beam_fn_api_pb2.RemoteGrpcPort - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> DataOutputOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + grpc_port: beam_fn_api_pb2.RemoteGrpcPort, + consumers: Dict[str, List[operations.Operation]]) -> DataOutputOperation: output_coder = factory.get_coder(grpc_port.coder_id) return DataOutputOperation( common.NameContext(transform_proto.unique_name, transform_id), @@ -1663,13 +1624,12 @@ def create_sink_runner( @BeamTransformFactory.register_urn(OLD_DATAFLOW_RUNNER_HARNESS_READ_URN, None) def create_source_java( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, parameter, - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.ReadOperation + consumers: Dict[str, + List[operations.Operation]]) -> operations.ReadOperation: # The Dataflow runner harness strips the base64 encoding. source = pickler.loads(base64.b64encode(parameter)) spec = operation_specs.WorkerRead( @@ -1688,13 +1648,12 @@ def create_source_java( @BeamTransformFactory.register_urn( common_urns.deprecated_primitives.READ.urn, beam_runner_api_pb2.ReadPayload) def create_deprecated_read( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.ReadPayload - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.ReadOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.ReadPayload, + consumers: Dict[str, + List[operations.Operation]]) -> operations.ReadOperation: source = iobase.BoundedSource.from_runner_api( parameter.source, factory.context) spec = operation_specs.WorkerRead( @@ -1713,13 +1672,12 @@ def create_deprecated_read( @BeamTransformFactory.register_urn( python_urns.IMPULSE_READ_TRANSFORM, beam_runner_api_pb2.ReadPayload) def create_read_from_impulse_python( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.ReadPayload - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.ImpulseReadOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.ReadPayload, + consumers: Dict[str, List[operations.Operation]] +) -> operations.ImpulseReadOperation: return operations.ImpulseReadOperation( common.NameContext(transform_proto.unique_name, transform_id), factory.counter_factory, @@ -1731,12 +1689,11 @@ def create_read_from_impulse_python( @BeamTransformFactory.register_urn(OLD_DATAFLOW_RUNNER_HARNESS_PARDO_URN, None) def create_dofn_javasdk( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, serialized_fn, - consumers # type: Dict[str, List[operations.Operation]] -): + consumers: Dict[str, List[operations.Operation]]): return _create_pardo_operation( factory, transform_id, transform_proto, consumers, serialized_fn) @@ -1820,12 +1777,11 @@ def process(self, element_restriction, *args, **kwargs): common_urns.sdf_components.PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn, beam_runner_api_pb2.ParDoPayload) def create_process_sized_elements_and_restrictions( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.ParDoPayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.ParDoPayload, + consumers: Dict[str, List[operations.Operation]]): return _create_pardo_operation( factory, transform_id, @@ -1867,13 +1823,11 @@ def _create_sdf_operation( @BeamTransformFactory.register_urn( common_urns.primitives.PAR_DO.urn, beam_runner_api_pb2.ParDoPayload) def create_par_do( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.ParDoPayload - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.DoOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.ParDoPayload, + consumers: Dict[str, List[operations.Operation]]) -> operations.DoOperation: return _create_pardo_operation( factory, transform_id, @@ -1885,14 +1839,13 @@ def create_par_do( def _create_pardo_operation( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, consumers, serialized_fn, - pardo_proto=None, # type: Optional[beam_runner_api_pb2.ParDoPayload] - operation_cls=operations.DoOperation -): + pardo_proto: Optional[beam_runner_api_pb2.ParDoPayload] = None, + operation_cls=operations.DoOperation): if pardo_proto and pardo_proto.side_inputs: input_tags_to_coders = factory.get_input_coders(transform_proto) @@ -1924,9 +1877,8 @@ def _create_pardo_operation( if not dofn_data[-1]: # Windowing not set. if pardo_proto: - other_input_tags = set.union( - set(pardo_proto.side_inputs), - set(pardo_proto.timer_family_specs)) # type: Container[str] + other_input_tags: Container[str] = set.union( + set(pardo_proto.side_inputs), set(pardo_proto.timer_family_specs)) else: other_input_tags = () pcoll_id, = [pcoll for tag, pcoll in transform_proto.inputs.items() @@ -1950,12 +1902,12 @@ def _create_pardo_operation( main_input_coder = found_input_coder if pardo_proto.timer_family_specs or pardo_proto.state_specs: - user_state_context = FnApiUserStateContext( - factory.state_handler, - transform_id, - main_input_coder.key_coder(), - main_input_coder.window_coder - ) # type: Optional[FnApiUserStateContext] + user_state_context: Optional[ + FnApiUserStateContext] = FnApiUserStateContext( + factory.state_handler, + transform_id, + main_input_coder.key_coder(), + main_input_coder.window_coder) else: user_state_context = None else: @@ -1989,12 +1941,13 @@ def _create_pardo_operation( return result -def _create_simple_pardo_operation(factory, # type: BeamTransformFactory - transform_id, - transform_proto, - consumers, - dofn, # type: beam.DoFn - ): +def _create_simple_pardo_operation( + factory: BeamTransformFactory, + transform_id, + transform_proto, + consumers, + dofn: beam.DoFn, +): serialized_fn = pickler.dumps((dofn, (), {}, [], None)) return _create_pardo_operation( factory, transform_id, transform_proto, consumers, serialized_fn) @@ -2004,12 +1957,11 @@ def _create_simple_pardo_operation(factory, # type: BeamTransformFactory common_urns.primitives.ASSIGN_WINDOWS.urn, beam_runner_api_pb2.WindowingStrategy) def create_assign_windows( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.WindowingStrategy - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.WindowingStrategy, + consumers: Dict[str, List[operations.Operation]]): class WindowIntoDoFn(beam.DoFn): def __init__(self, windowing): self.windowing = windowing @@ -2036,13 +1988,12 @@ def process( @BeamTransformFactory.register_urn(IDENTITY_DOFN_URN, None) def create_identity_dofn( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, parameter, - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.FlattenOperation + consumers: Dict[str, List[operations.Operation]] +) -> operations.FlattenOperation: return factory.augment_oldstyle_op( operations.FlattenOperation( common.NameContext(transform_proto.unique_name, transform_id), @@ -2058,13 +2009,12 @@ def create_identity_dofn( common_urns.combine_components.COMBINE_PER_KEY_PRECOMBINE.urn, beam_runner_api_pb2.CombinePayload) def create_combine_per_key_precombine( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.PGBKCVOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, + List[operations.Operation]]) -> operations.PGBKCVOperation: serialized_combine_fn = pickler.dumps(( beam.CombineFn.from_runner_api(payload.combine_fn, factory.context), [], {})) @@ -2085,12 +2035,11 @@ def create_combine_per_key_precombine( common_urns.combine_components.COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn, beam_runner_api_pb2.CombinePayload) def create_combbine_per_key_merge_accumulators( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, List[operations.Operation]]): return _create_combine_phase_operation( factory, transform_id, transform_proto, payload, consumers, 'merge') @@ -2099,12 +2048,11 @@ def create_combbine_per_key_merge_accumulators( common_urns.combine_components.COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn, beam_runner_api_pb2.CombinePayload) def create_combine_per_key_extract_outputs( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, List[operations.Operation]]): return _create_combine_phase_operation( factory, transform_id, transform_proto, payload, consumers, 'extract') @@ -2113,12 +2061,11 @@ def create_combine_per_key_extract_outputs( common_urns.combine_components.COMBINE_PER_KEY_CONVERT_TO_ACCUMULATORS.urn, beam_runner_api_pb2.CombinePayload) def create_combine_per_key_convert_to_accumulators( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, List[operations.Operation]]): return _create_combine_phase_operation( factory, transform_id, transform_proto, payload, consumers, 'convert') @@ -2127,19 +2074,18 @@ def create_combine_per_key_convert_to_accumulators( common_urns.combine_components.COMBINE_GROUPED_VALUES.urn, beam_runner_api_pb2.CombinePayload) def create_combine_grouped_values( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, List[operations.Operation]]): return _create_combine_phase_operation( factory, transform_id, transform_proto, payload, consumers, 'all') def _create_combine_phase_operation( - factory, transform_id, transform_proto, payload, consumers, phase): - # type: (...) -> operations.CombineOperation + factory, transform_id, transform_proto, payload, consumers, + phase) -> operations.CombineOperation: serialized_combine_fn = pickler.dumps(( beam.CombineFn.from_runner_api(payload.combine_fn, factory.context), [], {})) @@ -2158,13 +2104,12 @@ def _create_combine_phase_operation( @BeamTransformFactory.register_urn(common_urns.primitives.FLATTEN.urn, None) def create_flatten( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, payload, - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.FlattenOperation + consumers: Dict[str, List[operations.Operation]] +) -> operations.FlattenOperation: return factory.augment_oldstyle_op( operations.FlattenOperation( common.NameContext(transform_proto.unique_name, transform_id), @@ -2179,12 +2124,11 @@ def create_flatten( @BeamTransformFactory.register_urn( common_urns.primitives.MAP_WINDOWS.urn, beam_runner_api_pb2.FunctionSpec) def create_map_windows( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - mapping_fn_spec, # type: beam_runner_api_pb2.FunctionSpec - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + mapping_fn_spec: beam_runner_api_pb2.FunctionSpec, + consumers: Dict[str, List[operations.Operation]]): assert mapping_fn_spec.urn == python_urns.PICKLED_WINDOW_MAPPING_FN window_mapping_fn = pickler.loads(mapping_fn_spec.payload) @@ -2200,12 +2144,11 @@ def process(self, element): @BeamTransformFactory.register_urn( common_urns.primitives.MERGE_WINDOWS.urn, beam_runner_api_pb2.FunctionSpec) def create_merge_windows( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - mapping_fn_spec, # type: beam_runner_api_pb2.FunctionSpec - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + mapping_fn_spec: beam_runner_api_pb2.FunctionSpec, + consumers: Dict[str, List[operations.Operation]]): assert mapping_fn_spec.urn == python_urns.PICKLED_WINDOWFN window_fn = pickler.loads(mapping_fn_spec.payload) @@ -2213,24 +2156,25 @@ class MergeWindows(beam.DoFn): def process(self, element): nonce, windows = element - original_windows = set(windows) # type: Set[window.BoundedWindow] - merged_windows = collections.defaultdict( - set - ) # type: MutableMapping[window.BoundedWindow, Set[window.BoundedWindow]] # noqa: F821 + original_windows: Set[window.BoundedWindow] = set(windows) + merged_windows: MutableMapping[ + window.BoundedWindow, + Set[window.BoundedWindow]] = collections.defaultdict( + set) # noqa: F821 class RecordingMergeContext(window.WindowFn.MergeContext): def merge( self, - to_be_merged, # type: Iterable[window.BoundedWindow] - merge_result, # type: window.BoundedWindow - ): + to_be_merged: Iterable[window.BoundedWindow], + merge_result: window.BoundedWindow, + ): originals = merged_windows[merge_result] - for window in to_be_merged: - if window in original_windows: - originals.add(window) - original_windows.remove(window) + for w in to_be_merged: + if w in original_windows: + originals.add(w) + original_windows.remove(w) else: - originals.update(merged_windows.pop(window)) + originals.update(merged_windows.pop(w)) window_fn.merge(RecordingMergeContext(windows)) yield nonce, (original_windows, merged_windows.items()) @@ -2241,12 +2185,11 @@ def merge( @BeamTransformFactory.register_urn(common_urns.primitives.TO_STRING.urn, None) def create_to_string_fn( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - mapping_fn_spec, # type: beam_runner_api_pb2.FunctionSpec - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + mapping_fn_spec: beam_runner_api_pb2.FunctionSpec, + consumers: Dict[str, List[operations.Operation]]): class ToString(beam.DoFn): def process(self, element): key, value = element