Skip to content

Commit

Permalink
Forward worker_id
Browse files Browse the repository at this point in the history
  • Loading branch information
damondouglas committed Dec 26, 2024
1 parent 3636a3c commit 129dd8e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
10 changes: 6 additions & 4 deletions sdks/python/apache_beam/runners/worker/sdk_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def __init__(
self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
credentials, self._worker_id, data_buffer_time_limit_ms)
self._state_handler_factory = GrpcStateHandlerFactory(
self._state_cache, credentials)
self._state_cache, self._worker_id, credentials)
self._profiler_factory = profiler_factory
self.data_sampler = data_sampler
self.runner_capabilities = runner_capabilities
Expand All @@ -228,6 +228,7 @@ def default_factory(id):
status_address,
self._bundle_processor_cache,
self._state_cache,
self._worker_id,
enable_heap_dump) # type: Optional[FnApiWorkerStatusHandler]
except Exception:
traceback_string = traceback.format_exc()
Expand Down Expand Up @@ -893,13 +894,14 @@ class GrpcStateHandlerFactory(StateHandlerFactory):
Caches the created channels by ``state descriptor url``.
"""
def __init__(self, state_cache, credentials=None):
# type: (StateCache, Optional[grpc.ChannelCredentials]) -> None
def __init__(self, state_cache, worker_id, credentials=None):
# type: (StateCache, Optional[str], Optional[grpc.ChannelCredentials]) -> None
self._state_handler_cache = {} # type: Dict[str, CachingStateHandler]
self._lock = threading.Lock()
self._throwing_state_handler = ThrowingStateHandler()
self._credentials = credentials
self._state_cache = state_cache
self._worker_id = worker_id

def create_state_handler(self, api_service_descriptor):
# type: (endpoints_pb2.ApiServiceDescriptor) -> CachingStateHandler
Expand All @@ -926,7 +928,7 @@ def create_state_handler(self, api_service_descriptor):
_LOGGER.info('State channel established.')
# Add workerId to the grpc channel
grpc_channel = grpc.intercept_channel(
grpc_channel, WorkerIdInterceptor())
grpc_channel, WorkerIdInterceptor(self._worker_id))
self._state_handler_cache[url] = GlobalCachingStateHandler(
self._state_cache,
GrpcStateHandler(
Expand Down
3 changes: 2 additions & 1 deletion sdks/python/apache_beam/runners/worker/worker_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def __init__(
bundle_process_cache=None,
state_cache=None,
enable_heap_dump=False,
worker_id=None,
log_lull_timeout_ns=DEFAULT_LOG_LULL_TIMEOUT_NS):
"""Initialize FnApiWorkerStatusHandler.
Expand All @@ -164,7 +165,7 @@ def __init__(
self._state_cache = state_cache
ch = GRPCChannelFactory.insecure_channel(status_address)
grpc.channel_ready_future(ch).result(timeout=60)
self._status_channel = grpc.intercept_channel(ch, WorkerIdInterceptor())
self._status_channel = grpc.intercept_channel(ch, WorkerIdInterceptor(worker_id))
self._status_stub = beam_fn_api_pb2_grpc.BeamFnWorkerStatusStub(
self._status_channel)
self._responses = queue.Queue()
Expand Down

0 comments on commit 129dd8e

Please sign in to comment.