Skip to content

Commit

Permalink
Fix missing worker_id from interceptor (#33453)
Browse files Browse the repository at this point in the history
* Fix missing worker_id from interceptor

* Add worker_id attribute

* reorder and default parameters for GrpcStateHandlerFactory
  • Loading branch information
damondouglas authored Dec 30, 2024
1 parent 669076a commit 18ec331
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
11 changes: 7 additions & 4 deletions sdks/python/apache_beam/runners/worker/sdk_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,9 @@ def __init__(
self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
credentials, self._worker_id, data_buffer_time_limit_ms)
self._state_handler_factory = GrpcStateHandlerFactory(
self._state_cache, credentials)
state_cache=self._state_cache,
credentials=credentials,
worker_id=self._worker_id)
self._profiler_factory = profiler_factory
self.data_sampler = data_sampler
self.runner_capabilities = runner_capabilities
Expand Down Expand Up @@ -893,13 +895,14 @@ class GrpcStateHandlerFactory(StateHandlerFactory):
Caches the created channels by ``state descriptor url``.
"""
def __init__(self, state_cache, credentials=None):
# type: (StateCache, Optional[grpc.ChannelCredentials]) -> None
def __init__(self, state_cache, credentials=None, worker_id=None):
# type: (StateCache, Optional[grpc.ChannelCredentials], Optional[str]) -> None
self._state_handler_cache = {} # type: Dict[str, CachingStateHandler]
self._lock = threading.Lock()
self._throwing_state_handler = ThrowingStateHandler()
self._credentials = credentials
self._state_cache = state_cache
self._worker_id = worker_id

def create_state_handler(self, api_service_descriptor):
# type: (endpoints_pb2.ApiServiceDescriptor) -> CachingStateHandler
Expand All @@ -926,7 +929,7 @@ def create_state_handler(self, api_service_descriptor):
_LOGGER.info('State channel established.')
# Add workerId to the grpc channel
grpc_channel = grpc.intercept_channel(
grpc_channel, WorkerIdInterceptor())
grpc_channel, WorkerIdInterceptor(self._worker_id))
self._state_handler_cache[url] = GlobalCachingStateHandler(
self._state_cache,
GrpcStateHandler(
Expand Down
4 changes: 3 additions & 1 deletion sdks/python/apache_beam/runners/worker/worker_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def __init__(
bundle_process_cache=None,
state_cache=None,
enable_heap_dump=False,
worker_id=None,
log_lull_timeout_ns=DEFAULT_LOG_LULL_TIMEOUT_NS):
"""Initialize FnApiWorkerStatusHandler.
Expand All @@ -164,7 +165,8 @@ def __init__(
self._state_cache = state_cache
ch = GRPCChannelFactory.insecure_channel(status_address)
grpc.channel_ready_future(ch).result(timeout=60)
self._status_channel = grpc.intercept_channel(ch, WorkerIdInterceptor())
self._status_channel = grpc.intercept_channel(
ch, WorkerIdInterceptor(worker_id))
self._status_stub = beam_fn_api_pb2_grpc.BeamFnWorkerStatusStub(
self._status_channel)
self._responses = queue.Queue()
Expand Down

0 comments on commit 18ec331

Please sign in to comment.