diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index 6472b717f8ef..be834bf03bbd 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -225,86 +225,89 @@ private void flushInternal(Map requests) { } private void issueSingleRequest(long id, PendingRequest pendingRequest) { - if (prepareForSend(id, pendingRequest)) { - StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); - requestBuilder - .addCommitChunkBuilder() - .setComputationId(pendingRequest.computationId()) - .setRequestId(id) - .setShardingKey(pendingRequest.shardingKey()) - .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); - StreamingCommitWorkRequest chunk = requestBuilder.build(); - try { - send(chunk); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. - } - } else { + if (!prepareForSend(id, pendingRequest)) { pendingRequest.abort(); + return; + } + + StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); + requestBuilder + .addCommitChunkBuilder() + .setComputationId(pendingRequest.computationId()) + .setRequestId(id) + .setShardingKey(pendingRequest.shardingKey()) + .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); + StreamingCommitWorkRequest chunk = requestBuilder.build(); + try { + send(chunk); + } catch (IllegalStateException e) { + // Stream was broken, request will be retried when stream is reopened. 
} } private void issueBatchedRequest(Map requests) { - if (prepareForSend(requests)) { - StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); - String lastComputation = null; - for (Map.Entry entry : requests.entrySet()) { - PendingRequest request = entry.getValue(); - StreamingCommitRequestChunk.Builder chunkBuilder = requestBuilder.addCommitChunkBuilder(); - if (lastComputation == null || !lastComputation.equals(request.computationId())) { - chunkBuilder.setComputationId(request.computationId()); - lastComputation = request.computationId(); - } - chunkBuilder - .setRequestId(entry.getKey()) - .setShardingKey(request.shardingKey()) - .setSerializedWorkItemCommit(request.serializedCommit()); - } - StreamingCommitWorkRequest request = requestBuilder.build(); - try { - send(request); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. - } - } else { + if (!prepareForSend(requests)) { requests.forEach((ignored, pendingRequest) -> pendingRequest.abort()); + return; + } + + StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); + String lastComputation = null; + for (Map.Entry entry : requests.entrySet()) { + PendingRequest request = entry.getValue(); + StreamingCommitRequestChunk.Builder chunkBuilder = requestBuilder.addCommitChunkBuilder(); + if (lastComputation == null || !lastComputation.equals(request.computationId())) { + chunkBuilder.setComputationId(request.computationId()); + lastComputation = request.computationId(); + } + chunkBuilder + .setRequestId(entry.getKey()) + .setShardingKey(request.shardingKey()) + .setSerializedWorkItemCommit(request.serializedCommit()); + } + StreamingCommitWorkRequest request = requestBuilder.build(); + try { + send(request); + } catch (IllegalStateException e) { + // Stream was broken, request will be retried when stream is reopened. 
} } private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) { - if (prepareForSend(id, pendingRequest)) { - checkNotNull(pendingRequest.computationId(), "Cannot commit WorkItem w/o a computationId."); - ByteString serializedCommit = pendingRequest.serializedCommit(); - synchronized (this) { - for (int i = 0; - i < serializedCommit.size(); - i += AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { - int end = i + AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE; - ByteString chunk = serializedCommit.substring(i, Math.min(end, serializedCommit.size())); - - StreamingCommitRequestChunk.Builder chunkBuilder = - StreamingCommitRequestChunk.newBuilder() - .setRequestId(id) - .setSerializedWorkItemCommit(chunk) - .setComputationId(pendingRequest.computationId()) - .setShardingKey(pendingRequest.shardingKey()); - int remaining = serializedCommit.size() - end; - if (remaining > 0) { - chunkBuilder.setRemainingBytesForWorkItem(remaining); - } - - StreamingCommitWorkRequest requestChunk = - StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build(); - try { - send(requestChunk); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. 
- break; - } + if (!prepareForSend(id, pendingRequest)) { + pendingRequest.abort(); + return; + } + + checkNotNull(pendingRequest.computationId(), "Cannot commit WorkItem w/o a computationId."); + ByteString serializedCommit = pendingRequest.serializedCommit(); + synchronized (this) { + for (int i = 0; + i < serializedCommit.size(); + i += AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { + int end = i + AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE; + ByteString chunk = serializedCommit.substring(i, Math.min(end, serializedCommit.size())); + + StreamingCommitRequestChunk.Builder chunkBuilder = + StreamingCommitRequestChunk.newBuilder() + .setRequestId(id) + .setSerializedWorkItemCommit(chunk) + .setComputationId(pendingRequest.computationId()) + .setShardingKey(pendingRequest.shardingKey()); + int remaining = serializedCommit.size() - end; + if (remaining > 0) { + chunkBuilder.setRemainingBytesForWorkItem(remaining); + } + + StreamingCommitWorkRequest requestChunk = + StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build(); + try { + send(requestChunk); + } catch (IllegalStateException e) { + // Stream was broken, request will be retried when stream is reopened. 
+ break; + } } - } else { - pendingRequest.abort(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index 4aff1c83fc62..cda246065ab9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -69,7 +69,7 @@ final class GrpcGetDataStream private static final StreamingGetDataRequest HEALTH_CHECK_REQUEST = StreamingGetDataRequest.newBuilder().build(); - /** @implNote insertion and removal is guarded by {@link #shutdownLock} */ + /** @implNote {@link QueuedBatch} objects in the queue are guarded by {@link #shutdownLock} */ private final Deque batches; private final Map pending; @@ -349,21 +349,36 @@ private ResponseT issueRequest(QueuedRequest request, ParseFn