Skip to content

Commit

Permalink
Release v0.5.0
Browse files Browse the repository at this point in the history
GitOrigin-RevId: c03e185b372a183344a6ba1832b5aeb3e41a771c
  • Loading branch information
Privacy Sandbox Team authored and Amandoj committed Oct 8, 2024
1 parent 9126178 commit 5348395
Show file tree
Hide file tree
Showing 89 changed files with 2,902 additions and 455 deletions.
14 changes: 7 additions & 7 deletions BUILDING.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ Images can be built using the provided script

The expected output for images built within the provided [Dockerfile](Dockerfile):
```
bazel-bin/shuffler/services/aggregator/aggregator_image/index.json: "digest": "sha256:47d794ec4653a135d434e5e5095145c8f91442d0f3fde972fdc367b8aaf13703"
bazel-bin/shuffler/services/collector/collector_image/index.json: "digest": "sha256:08ba57f255f6c6b19cd986aa2b643cfb571fdeb508824215e661798053f39180"
bazel-bin/shuffler/services/modelupdater/model_updater_image/index.json: "digest": "sha256:9f6fb2566681f57aa7ddb30f06ca0086eccfa871e1bae7fa8952611b7c7ddc41"
bazel-bin/shuffler/services/taskassignment/task_assignment_image/index.json: "digest": "sha256:a626d7b680cf5314b663a879a22ebd58799cc0b7b902bda400f158e9153f2843"
bazel-bin/shuffler/services/taskbuilder/task_builder_image/index.json: "digest": "sha256:cee0f537b9332a35c9044aec9497a31cc5c9ad3dfdd5230435d7b76b2b43ffcd"
bazel-bin/shuffler/services/taskmanagement/task_management_image/index.json: "digest": "sha256:516a9311d1de4466401975cde870b6f2acf4a49e05c4d63f988a7ea2c2610803"
bazel-bin/shuffler/services/taskscheduler/task_scheduler_image/index.json: "digest": "sha256:cacfb2b70d7553147f3ec4dff2c7d275281e354c87155e66cabf822d2e81d34b"
bazel-bin/shuffler/services/aggregator/aggregator_image/index.json: "digest": "sha256:a8a66a800c23865604e9ab816e6c3fc24c29bbed34187b5e21c93819f3594533"
bazel-bin/shuffler/services/collector/collector_image/index.json: "digest": "sha256:7be6b26c12d678614e6e92fc225985198afcfae6293abf84ddac7c6db217de25"
bazel-bin/shuffler/services/modelupdater/model_updater_image/index.json: "digest": "sha256:9bc980a096865f4dffcfd44e956441a1be72319f214c1bb0198f5441e9e619db"
bazel-bin/shuffler/services/taskassignment/task_assignment_image/index.json: "digest": "sha256:9169c5bc235022972f9cd1ce370716397314f1dcdd67b605426f99f9a0f46c54"
bazel-bin/shuffler/services/taskbuilder/task_builder_image/index.json: "digest": "sha256:87dcaf4adeb728d67ba73c29e2997aac318335f172f9484986dbb45154ddd398"
bazel-bin/shuffler/services/taskmanagement/task_management_image/index.json: "digest": "sha256:b00553e6648ded8fde0d773be566333dbb3c5eaef06074e54f1c7adea1e82962"
bazel-bin/shuffler/services/taskscheduler/task_scheduler_image/index.json: "digest": "sha256:8f9669d23f8a76f6b22303016accd36809559218562289a2c4c0dfbadf36fb1a"
```

## Publishing
Expand Down
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
# Changelog

## [v0.5.0]

### Changes

- Added support for Cloud CDN in GCP terraform.
- Added several GCP alarm policies in GCP terraform.
- Added additional charts to the Looker project dashboard.
- Upgraded the KAVS and notification clients from Apache HttpClient 4.x to HttpClient 5.x.
- Updated GCP Confidential Space images and added support for in-memory tree parallelism for V1 TensorFlow aggregation.

## [v0.4.0]

### Changes
Expand Down
2 changes: 2 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ maven_install(
"com.google.code.findbugs:jsr305:3.0.2",
"io.github.resilience4j:resilience4j-core:1.7.1",
"io.github.resilience4j:resilience4j-retry:1.7.1",
# Apache 4.x used for coordinator dependency.
"org.apache.httpcomponents:httpcore:4.4.14",
"org.apache.httpcomponents:httpclient:4.5.13",
"org.apache.httpcomponents.client5:httpclient5:5.3.1",
"org.apache.httpcomponents.core5:httpcore5:5.1.4",
"org.apache.httpcomponents.core5:httpcore5-h2:5.1.4", # Explicit transitive dependency to avoid https://issues.apache.org/jira/browse/HTTPCLIENT-2222
"com.fasterxml.jackson.datatype:jackson-datatype-guava:2.15.2",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import com.google.ondevicepersonalization.federatedcompute.shuffler.common.CompressionUtils;
import com.google.ondevicepersonalization.federatedcompute.shuffler.common.Constants;
import com.google.ondevicepersonalization.federatedcompute.shuffler.common.Exceptions;
import com.google.ondevicepersonalization.federatedcompute.shuffler.common.NonRetryableException;
import com.google.ondevicepersonalization.federatedcompute.shuffler.common.crypto.Payload;
import com.google.ondevicepersonalization.federatedcompute.shuffler.common.crypto.PublicKeyEncryptionService;
import com.google.ondevicepersonalization.federatedcompute.shuffler.common.dao.BlobDao;
Expand Down Expand Up @@ -150,16 +151,34 @@ private void processMessageImpl(AggregatorMessage message) {
gradient ->
getGradientFullPath(
message.getGradientBucket(), message.getGradientPrefix(), gradient))
.map(blobDao::downloadAndDecompressIfNeeded)
.map(
(gradient) ->
blobDao
.downloadAndDecompressIfNeeded(gradient)
.orElseThrow(
() ->
new NonRetryableException(
String.format(
"Downloaded gradient for bucket %s and object %s is null or"
+ " does not exist",
gradient.getHost(), gradient.getResourceObject()))))
.map((payload) -> Payload.parseAndDecryptPayload(payload, decryptionKeyService))
.collect(Collectors.toList());

byte[] plan =
blobDao.downloadAndDecompressIfNeeded(
BlobDescription.builder()
.host(message.getServerPlanBucket())
.resourceObject(message.getServerPlanObject())
.build());
blobDao
.downloadAndDecompressIfNeeded(
BlobDescription.builder()
.host(message.getServerPlanBucket())
.resourceObject(message.getServerPlanObject())
.build())
.orElseThrow(
() ->
new NonRetryableException(
String.format(
"Downloaded plan for bucket %s and object %s is null or"
+ " does not exist",
message.getServerPlanBucket(), message.getServerPlanObject())));

byte[] aggregatedResult =
aggregate(encryptedGradients, plan, message.isAccumulateIntermediateUpdates());
Expand All @@ -177,23 +196,28 @@ private void processMessageImpl(AggregatorMessage message) {
throw new RuntimeException("Failed to compressAndUpload aggregated result", e);
}

try {
AggregatorNotification notification =
AggregatorNotification.builder()
.messages(
List.of(
AggregatorNotification.Message.builder()
.attributes(
AggregatorNotification.Attributes.builder()
.requestId(message.getRequestId())
.status(AggregatorNotification.Status.OK)
.build())
.build()))
.build();
httpMessageSender.sendMessage(notification, message.getNotificationEndpoint());
} catch (Exception e) {
logger.atError().setCause(e).log("Failed to send message to provided notification endpoint.");
throw new RuntimeException("Failed to send message to provided notification endpoint.", e);
if (!Strings.isNullOrEmpty(message.getNotificationEndpoint())) {
try {
AggregatorNotification notification =
AggregatorNotification.builder()
.messages(
List.of(
AggregatorNotification.Message.builder()
.attributes(
AggregatorNotification.Attributes.builder()
.requestId(message.getRequestId())
.status(AggregatorNotification.Status.OK)
.build())
.build()))
.build();
httpMessageSender.sendMessage(notification, message.getNotificationEndpoint());
} catch (Exception e) {
logger
.atError()
.setCause(e)
.log("Failed to send message to provided notification endpoint.");
throw new RuntimeException("Failed to send message to provided notification endpoint.", e);
}
}
}

Expand All @@ -213,22 +237,38 @@ private byte[] aggregate(
} catch (InvalidProtocolBufferException e) {
throw new IllegalStateException("Failed to decode plan");
}
if (plan.getPhase(0).hasServerPhaseV2()) {
// Take the sqrt to enforce two layers of in-memory tree aggregation.
int partitionSize = (int) Math.ceil(Math.sqrt(encryptedGradients.size()));
// Take the sqrt to enforce two layers of in-memory tree aggregation.
int partitionSize = (int) Math.ceil(Math.sqrt(encryptedGradients.size()));

// Layer 1
List<List<byte[]>> partitionedGradients = Lists.partition(encryptedGradients, partitionSize);
encryptedGradients =
partitionedGradients.parallelStream()
.map(gradients -> aggregateV2(gradients, planBytes, accumulateIntermediateUpdates))
.collect(Collectors.toList());
// Layer 1
List<List<byte[]>> partitionedGradients = Lists.partition(encryptedGradients, partitionSize);
if (plan.getPhase(0).hasServerPhaseV2()) {
if (partitionedGradients.size() > 1) {
boolean finalAccumulateIntermediateUpdates = accumulateIntermediateUpdates;
encryptedGradients =
partitionedGradients.parallelStream()
.map(
gradients ->
aggregateV2(gradients, planBytes, finalAccumulateIntermediateUpdates))
.collect(Collectors.toList());
accumulateIntermediateUpdates = true;
}

// Layer 2
return aggregateV2(encryptedGradients, planBytes, true);
return aggregateV2(encryptedGradients, planBytes, accumulateIntermediateUpdates);
} else {
// TODO(b/295060730): Support parallel V1 aggregation. This is currently blocked by limited
// tmpfs volume size on supported TEEs.
if (partitionedGradients.size() > 1) {
boolean finalAccumulateIntermediateUpdates = accumulateIntermediateUpdates;
encryptedGradients =
partitionedGradients.parallelStream()
.map(
gradients ->
aggregateV1(gradients, planBytes, finalAccumulateIntermediateUpdates))
.collect(Collectors.toList());
accumulateIntermediateUpdates = true;
}

// Layer 2
return aggregateV1(encryptedGradients, planBytes, accumulateIntermediateUpdates);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ java_library(
"//java/src/main/java/com/google/ondevicepersonalization/federatedcompute/shuffler/aggregator/core/message:aggregator_notification",
"//java/src/main/java/com/google/ondevicepersonalization/federatedcompute/shuffler/common:compression_utils",
"//java/src/main/java/com/google/ondevicepersonalization/federatedcompute/shuffler/common:exceptions",
"//java/src/main/java/com/google/ondevicepersonalization/federatedcompute/shuffler/common:non_retryable_exception",
"//java/src/main/java/com/google/ondevicepersonalization/federatedcompute/shuffler/common:constants",
"//java/src/main/java/com/google/ondevicepersonalization/federatedcompute/shuffler/common/crypto:payload",
"//java/src/main/java/com/google/ondevicepersonalization/federatedcompute/shuffler/common/dao:blob_dao",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import java.util.Set;
import java.util.TreeSet;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.stream.Collectors;
import org.slf4j.Logger;
Expand Down Expand Up @@ -77,6 +78,7 @@ public class CollectorCoreImpl implements CollectorCore {
private final int localComputeTimeoutMinutes;
private final int uploadTimeoutMinutes;
private final int batchSize;
private final Optional<Long> aggregationBatchFailureThreshold;

public CollectorCoreImpl(
TaskDao taskDao,
Expand All @@ -93,7 +95,8 @@ public CollectorCoreImpl(
LockRegistry lockRegistry,
int localComputeTimeoutMinutes,
int uploadTimeoutMinutes,
int collectorBatchSize) {
int collectorBatchSize,
Optional<Long> aggregationBatchFailureThreshold) {
this.taskDao = taskDao;
this.blobDao = blobDao;
this.blobManager = blobManager;
Expand All @@ -109,6 +112,7 @@ public CollectorCoreImpl(
this.localComputeTimeoutMinutes = localComputeTimeoutMinutes;
this.uploadTimeoutMinutes = uploadTimeoutMinutes;
this.batchSize = collectorBatchSize;
this.aggregationBatchFailureThreshold = aggregationBatchFailureThreshold;
}

private static String trimSlash(String folderName) {
Expand Down Expand Up @@ -214,28 +218,62 @@ public void processAggregatorNotifications(AggregatorNotification.Attributes not
}
}

// If iteration is AGGREGATING and total batches sent/completed < reportGoal update status.
IterationEntity iteration = taskDao.getIterationById(iterationId).get();
if (iteration.getStatus() == Status.AGGREGATING) {
if (aggregationBatchDao.querySumOfAggregationBatchesOfStatus(
iteration,
iteration.getAggregationLevel() - 1,
List.of(
AggregationBatchEntity.Status.PUBLISH_COMPLETED,
AggregationBatchEntity.Status.UPLOAD_COMPLETED))
< iteration.getReportGoal()) {
if (!taskDao.updateIterationStatus(
iteration,
iteration.toBuilder().status(Status.COLLECTING).aggregationLevel(0).build())) {
logger.error(
"Failed to update iteration {} from {} to {}",
iteration.getId().toString(),
iteration.getStatus(),
Status.COLLECTING);
// Throw exception to nack the message and retry
throw new IllegalStateException("Failed to update iteration.");
// Obtain lock before updating the iteration
String partition = iterationId.toString();
Lock lock = lockRegistry.obtain(LOCK_PREFIX + partition);
if (lock.tryLock(30, TimeUnit.SECONDS)) {
try {
// If iteration failure count is above threshold, fail the iteration.
IterationEntity iteration = taskDao.getIterationById(iterationId).get();
if (aggregationBatchFailureThreshold.isPresent()) {
long totalFailed =
aggregationBatchDao.querySumOfAggregationBatchesOfStatus(
iteration,
/* AggregationLevel */ 0,
List.of(AggregationBatchEntity.Status.FAILED));
if (totalFailed > aggregationBatchFailureThreshold.get() * batchSize) {
if (!taskDao.updateIterationStatus(
iteration, iteration.toBuilder().status(Status.AGGREGATING_FAILED).build())) {
logger.error(
"Failed to update iteration {} from {} to {}",
iteration.getId().toString(),
iteration.getStatus(),
Status.AGGREGATING_FAILED);
// Throw exception to nack the message and retry
throw new IllegalStateException("Failed to update iteration.");
}
return;
}
}

// If iteration is AGGREGATING and total batches sent/completed < reportGoal update status.
if (iteration.getStatus() == Status.AGGREGATING) {
if (aggregationBatchDao.querySumOfAggregationBatchesOfStatus(
iteration,
iteration.getAggregationLevel() - 1,
List.of(
AggregationBatchEntity.Status.PUBLISH_COMPLETED,
AggregationBatchEntity.Status.UPLOAD_COMPLETED))
< iteration.getReportGoal()) {
if (!taskDao.updateIterationStatus(
iteration,
iteration.toBuilder().status(Status.COLLECTING).aggregationLevel(0).build())) {
logger.error(
"Failed to update iteration {} from {} to {}",
iteration.getId().toString(),
iteration.getStatus(),
Status.COLLECTING);
// Throw exception to nack the message and retry
throw new IllegalStateException("Failed to update iteration.");
}
}
}
} finally {
lock.unlock();
}
} else {
logger.error("Failed to obtain lock during processAggregatorNotifications");
throw new IllegalStateException("Failed to obtain lock during processAggregatorNotifications");
}
} catch (Exception e) {
logger.error("Failed to process message", e);
Expand Down Expand Up @@ -286,6 +324,8 @@ private void processIteration(IterationEntity iteration) {
if (lock.tryLock()) {
try {
MDC.put(Constants.ITERATION_ID, iteration.getId().toString());
// Retrieve the latest status after locking.
iteration = taskDao.getIterationById(iteration.getId()).get();
if (Status.COLLECTING == iteration.getStatus()) {
processCollectingIterationImp(iteration, partition);
} else if (Status.AGGREGATING == iteration.getStatus()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ java_library(
"@coordinator-services-and-shared-libraries//:crypto_client",
"@federatedcompute//fcp/java_src/main/java/com/google/fcp/aggregation:aggregation_exception",
"@federatedcompute//fcp/java_src/main/java/com/google/fcp/tensorflow:tensorflow_exception",
":non_retryable_exception",
"@maven//:com_google_guava_guava",
],
)
Expand All @@ -101,6 +102,13 @@ java_library(
],
)

# Standalone library target for NonRetryableException.java, so that other
# targets can depend on it through the ":non_retryable_exception" label.
java_library(
    name = "non_retryable_exception",
    srcs = ["NonRetryableException.java"],
)

java_library(
name = "proto_parser",
srcs = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ public class Exceptions {
public static boolean isRetryableException(Exception e) {
if (isTensorflowException(e)
|| isAggregationException(e)
|| isNonRetryableKeyFetchException(e)) {
|| isNonRetryableKeyFetchException(e)
|| isNonRetryableException(e)) {
return false;
}
return true;
Expand All @@ -45,6 +46,11 @@ public static boolean isAggregationException(Exception e) {
|| (Throwables.getRootCause(e) instanceof AggregationException);
}

/**
 * Reports whether {@code e} is a {@link NonRetryableException} or has one as its root cause.
 *
 * @param e the exception to inspect
 * @return {@code true} if {@code e} itself, or the root cause reported by
 *     {@code Throwables.getRootCause(e)}, is a {@code NonRetryableException}
 */
public static boolean isNonRetryableException(Exception e) {
  // Direct match: no need to walk the cause chain.
  if (e instanceof NonRetryableException) {
    return true;
  }
  // Otherwise classify by the innermost cause of the exception.
  Throwable rootCause = Throwables.getRootCause(e);
  return rootCause instanceof NonRetryableException;
}

public static boolean isNonRetryableKeyFetchException(Exception e) {
if (e instanceof KeyFetchException) {
// https://github.com/privacysandbox/coordinator-services-and-shared-libraries/blob/main/java/com/google/scp/operator/cpio/cryptoclient/MultiPartyDecryptionKeyServiceImpl.java#L238
Expand Down
Loading

0 comments on commit 5348395

Please sign in to comment.