From 7e46017995e2ce0b9a6eaf0258af587c85c9ee47 Mon Sep 17 00:00:00 2001 From: Tristan Vuong <85768771+tristanvuong2021@users.noreply.github.com> Date: Fri, 22 Nov 2024 09:07:32 -0800 Subject: [PATCH 1/3] test: Add report creation to correctness test (#1845) test: Add report creation to correctness test --- .github/workflows/build-test.yml | 2 + .github/workflows/run-k8s-tests.yml | 2 + docs/gke/correctness-test.md | 3 +- src/main/k8s/local/testing/BUILD.bazel | 30 ++- .../testing/config_files_kustomization.yaml | 2 + ...empty_encryption_key_pair_config.textproto | 2 + .../testing/mc_config_kustomization.yaml | 18 ++ .../kingdom/service/api/v2alpha/BUILD.bazel | 1 + .../loadtest/reporting/BUILD.bazel | 26 ++ .../reporting/ReportingUserSimulator.kt | 250 ++++++++++++++++++ .../loadtest/resourcesetup/BUILD.bazel | 2 + .../loadtest/resourcesetup/ResourceSetup.kt | 50 ++++ .../k8s/testing/correctness_test_config.proto | 9 + .../k8s/AbstractCorrectnessTest.kt | 18 ++ .../measurement/integration/k8s/BUILD.bazel | 5 + .../k8s/EmptyClusterCorrectnessTest.kt | 84 +++++- .../k8s/SyntheticGeneratorCorrectnessTest.kt | 38 ++- .../correctness_test_config.tmpl.textproto | 2 + .../EmptyClusterPanelMatchCorrectnessTest.kt | 2 +- 19 files changed, 534 insertions(+), 12 deletions(-) create mode 100644 src/main/k8s/local/testing/empty_encryption_key_pair_config.textproto create mode 100644 src/main/k8s/local/testing/mc_config_kustomization.yaml create mode 100644 src/main/kotlin/org/wfanet/measurement/loadtest/reporting/BUILD.bazel create mode 100644 src/main/kotlin/org/wfanet/measurement/loadtest/reporting/ReportingUserSimulator.kt diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index ea039a28bc9..93528038444 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -67,6 +67,8 @@ jobs: build --define worker2_id=worker2 build --define worker2_public_api_target=worker2.example.com:8443 build --define mc_name=measurementConsumers/foo + build --define mc_api_key=foo + build --define mc_cert_name=measurementConsumers/foo/certificates/bar build --define edp1_name=dataProviders/foo1 build --define edp1_cert_name=dataProviders/foo1/certificates/bar1 build --define edp2_name=dataProviders/foo2 diff --git a/.github/workflows/run-k8s-tests.yml b/.github/workflows/run-k8s-tests.yml index 0f6ccea3c86..5906348104b 100644 --- a/.github/workflows/run-k8s-tests.yml +++ b/.github/workflows/run-k8s-tests.yml @@ -52,6 +52,7 @@ jobs: MC_NAME: ${{ vars.MC_NAME }} MC_API_KEY: ${{ secrets.MC_API_KEY }} GCLOUD_PROJECT: ${{ vars.GCLOUD_PROJECT }} + REPORTING_PUBLIC_API_TARGET: ${{ vars.REPORTING_PUBLIC_API_TARGET }} run: | cat << EOF > ~/.bazelrc common --config=ci @@ -59,6 +60,7 @@ jobs: build --define mc_name=$MC_NAME build --define mc_api_key=$MC_API_KEY build --define google_cloud_project=$GCLOUD_PROJECT + build --define reporting_public_api_target=$REPORTING_PUBLIC_API_TARGET test --test_output=streamed test --test_timeout=3600 EOF diff --git a/docs/gke/correctness-test.md b/docs/gke/correctness-test.md index 5b738737809..055def71ac0 100644 --- a/docs/gke/correctness-test.md +++ b/docs/gke/correctness-test.md @@ -80,7 +80,8 @@ bazel test //src/test/kotlin/org/wfanet/measurement/integration/k8s:SyntheticGen --test_output=streamed \ --define=kingdom_public_api_target=v2alpha.kingdom.dev.halo-cmm.org:8443 \ --define=mc_name=measurementConsumers/Rcn7fKd25C8 \ ---define=mc_api_key=W9q4zad246g +--define=mc_api_key=W9q4zad246g \ +--define=reporting_public_api_target=v2alpha.reporting.dev.halo-cmm.org:8443 ``` The time the test takes depends on the size of the data set. With the default diff --git a/src/main/k8s/local/testing/BUILD.bazel b/src/main/k8s/local/testing/BUILD.bazel index 51b188a2f9d..c92878a7261 100644 --- a/src/main/k8s/local/testing/BUILD.bazel +++ b/src/main/k8s/local/testing/BUILD.bazel @@ -63,8 +63,25 @@ kustomization_dir( "config_files_kustomization.yaml", "//src/main/k8s/testing/data:synthetic_generation_specs_small", "//src/main/k8s/testing/secretfiles:known_event_group_metadata_type_set", + "//src/main/k8s/testing/secretfiles:metric_spec_config.textproto", ], renames = {"config_files_kustomization.yaml": "kustomization.yaml"}, + tags = ["manual"], +) + +kustomization_dir( + name = "config_files_for_panel_match", + srcs = [ + "config_files_kustomization.yaml", + "empty_encryption_key_pair_config.textproto", + "//src/main/k8s/testing/data:synthetic_generation_specs_small", + "//src/main/k8s/testing/secretfiles:known_event_group_metadata_type_set", + "//src/main/k8s/testing/secretfiles:metric_spec_config.textproto", + ], + renames = { + "config_files_kustomization.yaml": "kustomization.yaml", + "empty_encryption_key_pair_config.textproto": "encryption_key_pair_config.textproto", + }, ) kustomization_dir( @@ -75,6 +92,15 @@ kustomization_dir( }, ) +kustomization_dir( + name = "mc_config", + srcs = ["mc_config_kustomization.yaml"], + renames = { + "mc_config_kustomization.yaml": "kustomization.yaml", + }, + tags = ["manual"], +) + kustomization_dir( name = "cmms", srcs = [ @@ -83,12 +109,14 @@ kustomization_dir( "//src/main/k8s/local:emulators", "//src/main/k8s/local:kingdom", "//src/main/k8s/local:postgres_database", + "//src/main/k8s/local:reporting_v2", ], generate_kustomization = True, tags = ["manual"], deps = [ ":config_files", ":db_creds", + ":mc_config", "//src/main/k8s/testing/secretfiles:kustomization", ], ) @@ -103,7 +131,7 @@ kustomization_dir( generate_kustomization = True, tags = ["manual"], deps = [ - ":config_files", + ":config_files_for_panel_match", "//src/main/k8s/testing/secretfiles:kustomization", ], ) diff --git a/src/main/k8s/local/testing/config_files_kustomization.yaml b/src/main/k8s/local/testing/config_files_kustomization.yaml index d882167e307..71c188ba7fb 100644 --- a/src/main/k8s/local/testing/config_files_kustomization.yaml +++ b/src/main/k8s/local/testing/config_files_kustomization.yaml @@ -16,6 +16,8 @@ configMapGenerator: - name: config-files files: - authority_key_identifier_to_principal_map.textproto + - encryption_key_pair_config.textproto + - metric_spec_config.textproto - synthetic_population_spec_small.textproto - synthetic_event_group_spec_small_1.textproto - synthetic_event_group_spec_small_2.textproto diff --git a/src/main/k8s/local/testing/empty_encryption_key_pair_config.textproto b/src/main/k8s/local/testing/empty_encryption_key_pair_config.textproto new file mode 100644 index 00000000000..51c20dc96c6 --- /dev/null +++ b/src/main/k8s/local/testing/empty_encryption_key_pair_config.textproto @@ -0,0 +1,2 @@ +# proto-file: wfa/measurement/config/reporting/encryption_key_pair_config.proto +# proto-message: EncryptionKeyPairConfig diff --git a/src/main/k8s/local/testing/mc_config_kustomization.yaml b/src/main/k8s/local/testing/mc_config_kustomization.yaml new file mode 100644 index 00000000000..8e329fcff4d --- /dev/null +++ b/src/main/k8s/local/testing/mc_config_kustomization.yaml @@ -0,0 +1,18 @@ +# Copyright 2024 The Cross-Media Measurement Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +secretGenerator: +- name: mc-config + files: + - measurement_consumer_config.textproto diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/BUILD.bazel index 2865c224525..15694208ce0 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/BUILD.bazel @@ -8,6 +8,7 @@ package(default_visibility = [ "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/testing:__pkg__", "//src/main/kotlin/org/wfanet/measurement/loadtest/panelmatch:__pkg__", "//src/main/kotlin/org/wfanet/measurement/loadtest/panelmatchresourcesetup:__pkg__", + "//src/main/kotlin/org/wfanet/measurement/loadtest/reporting:__pkg__", "//src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup:__pkg__", "//src/test/kotlin/org/wfanet/measurement/integration/common:__pkg__", ]) diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/reporting/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/loadtest/reporting/BUILD.bazel new file mode 100644 index 00000000000..962beda9818 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/reporting/BUILD.bazel @@ -0,0 +1,26 @@ +load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") + +package( + default_testonly = True, + default_visibility = [ + "//src/main/kotlin/org/wfanet/measurement/integration:__subpackages__", + "//src/main/kotlin/org/wfanet/measurement/loadtest:__subpackages__", + "//src/test/kotlin/org/wfanet/measurement/integration:__subpackages__", + "//src/test/kotlin/org/wfanet/measurement/loadtest:__subpackages__", + ], +) + +kt_jvm_library( + name = "simulator", + srcs = ["ReportingUserSimulator.kt"], + deps = [ + "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:data_providers_service", + "//src/main/kotlin/org/wfanet/measurement/loadtest/config:test_identifiers", + "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha:event_groups_service", + "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha:metric_calculation_specs_service", + "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha:reporting_sets_service", + "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha:reports_service", + "//src/main/proto/wfa/measurement/api/v2alpha/event_templates/testing:test_event_kt_jvm_proto", + "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/reporting/ReportingUserSimulator.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/reporting/ReportingUserSimulator.kt new file mode 100644 index 00000000000..99a80a5dfff --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/reporting/ReportingUserSimulator.kt @@ -0,0 +1,250 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.loadtest.reporting + +import com.google.common.truth.Truth.assertThat +import com.google.type.DayOfWeek +import com.google.type.date +import com.google.type.dateTime +import com.google.type.timeZone +import io.grpc.StatusException +import java.time.Duration +import java.util.logging.Logger +import kotlinx.coroutines.time.delay +import org.wfanet.measurement.api.v2alpha.DataProvider +import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt +import org.wfanet.measurement.api.v2alpha.getDataProviderRequest +import org.wfanet.measurement.common.ExponentialBackoff +import org.wfanet.measurement.common.coerceAtMost +import org.wfanet.measurement.loadtest.config.TestIdentifiers +import org.wfanet.measurement.reporting.v2alpha.EventGroup +import org.wfanet.measurement.reporting.v2alpha.EventGroupsGrpcKt +import org.wfanet.measurement.reporting.v2alpha.ListEventGroupsResponse +import org.wfanet.measurement.reporting.v2alpha.MetricCalculationSpec +import org.wfanet.measurement.reporting.v2alpha.MetricCalculationSpecKt +import org.wfanet.measurement.reporting.v2alpha.MetricCalculationSpecsGrpcKt +import org.wfanet.measurement.reporting.v2alpha.MetricSpecKt +import org.wfanet.measurement.reporting.v2alpha.Report +import org.wfanet.measurement.reporting.v2alpha.ReportKt +import org.wfanet.measurement.reporting.v2alpha.ReportingSet +import org.wfanet.measurement.reporting.v2alpha.ReportingSetKt +import org.wfanet.measurement.reporting.v2alpha.ReportingSetsGrpcKt +import org.wfanet.measurement.reporting.v2alpha.ReportsGrpcKt +import org.wfanet.measurement.reporting.v2alpha.createMetricCalculationSpecRequest +import org.wfanet.measurement.reporting.v2alpha.createReportRequest +import org.wfanet.measurement.reporting.v2alpha.createReportingSetRequest +import org.wfanet.measurement.reporting.v2alpha.getReportRequest +import org.wfanet.measurement.reporting.v2alpha.listEventGroupsRequest +import org.wfanet.measurement.reporting.v2alpha.metricCalculationSpec +import org.wfanet.measurement.reporting.v2alpha.metricSpec +import org.wfanet.measurement.reporting.v2alpha.report +import org.wfanet.measurement.reporting.v2alpha.reportingSet + +/** Simulator for Reporting operations on the Reporting public API. */ +class ReportingUserSimulator( + private val measurementConsumerName: String, + private val dataProvidersClient: DataProvidersGrpcKt.DataProvidersCoroutineStub, + private val eventGroupsClient: EventGroupsGrpcKt.EventGroupsCoroutineStub, + private val reportingSetsClient: ReportingSetsGrpcKt.ReportingSetsCoroutineStub, + private val metricCalculationSpecsClient: + MetricCalculationSpecsGrpcKt.MetricCalculationSpecsCoroutineStub, + private val reportsClient: ReportsGrpcKt.ReportsCoroutineStub, + private val initialResultPollingDelay: Duration = Duration.ofSeconds(1), + private val maximumResultPollingDelay: Duration = Duration.ofMinutes(1), +) { + suspend fun testCreateReport(runId: String) { + logger.info("Creating report...") + + val eventGroup = + listEventGroups() + .filter { + it.eventGroupReferenceId.startsWith( + TestIdentifiers.SIMULATOR_EVENT_GROUP_REFERENCE_ID_PREFIX + ) + } + .firstOrNull { + getDataProvider(it.cmmsDataProvider).capabilities.honestMajorityShareShuffleSupported + } ?: listEventGroups().first() + val createdPrimitiveReportingSet = createPrimitiveReportingSet(eventGroup) + val createdMetricCalculationSpec = createMetricCalculationSpec() + + val report = report { + reportingMetricEntries += + ReportKt.reportingMetricEntry { + key = createdPrimitiveReportingSet.name + value = + ReportKt.reportingMetricCalculationSpec { + metricCalculationSpecs += createdMetricCalculationSpec.name + } + } + reportingInterval = + ReportKt.reportingInterval { + reportStart = dateTime { + year = 2024 + month = 1 + day = 3 + timeZone = timeZone { id = "America/Los_Angeles" } + } + reportEnd = date { + year = 2024 + month = 1 + day = 4 + } + } + } + + val createdReport = + try { + reportsClient.createReport( + createReportRequest { + parent = measurementConsumerName + this.report = report + reportId = "a-$runId" + } + ) + } catch (e: StatusException) { + throw Exception("Error creating Report", e) + } + + val completedReport = pollForCompletedReport(createdReport.name) + + assertThat(completedReport.state).isEqualTo(Report.State.SUCCEEDED) + logger.info("Report creation succeeded") + } + + private suspend fun listEventGroups(): List { + try { + return buildList { + var response: ListEventGroupsResponse = ListEventGroupsResponse.getDefaultInstance() + do { + response = + eventGroupsClient.listEventGroups( + listEventGroupsRequest { + parent = measurementConsumerName + pageToken = response.nextPageToken + } + ) + addAll(response.eventGroupsList) + } while (response.nextPageToken.isNotEmpty()) + } + } catch (e: StatusException) { + throw Exception("Error listing EventGroups", e) + } + } + + private suspend fun getDataProvider(dataProviderName: String): DataProvider { + try { + return dataProvidersClient.getDataProvider(getDataProviderRequest { name = dataProviderName }) + } catch (e: StatusException) { + throw Exception("Error getting DataProvider $dataProviderName", e) + } + } + + private suspend fun createPrimitiveReportingSet(eventGroup: EventGroup): ReportingSet { + val primitiveReportingSet = reportingSet { + primitive = ReportingSetKt.primitive { cmmsEventGroups += eventGroup.cmmsEventGroup } + } + + try { + return reportingSetsClient.createReportingSet( + createReportingSetRequest { + parent = measurementConsumerName + reportingSet = primitiveReportingSet + reportingSetId = "a-123" + } + ) + } catch (e: StatusException) { + throw Exception("Error creating ReportingSet", e) + } + } + + private suspend fun createMetricCalculationSpec(): MetricCalculationSpec { + try { + return metricCalculationSpecsClient.createMetricCalculationSpec( + createMetricCalculationSpecRequest { + parent = measurementConsumerName + metricCalculationSpecId = "a-123" + metricCalculationSpec = metricCalculationSpec { + displayName = "union reach" + metricSpecs += metricSpec { + reach = + MetricSpecKt.reachParams { + singleDataProviderParams = + MetricSpecKt.samplingAndPrivacyParams { + privacyParams = MetricSpecKt.differentialPrivacyParams {} + } + multipleDataProviderParams = + MetricSpecKt.samplingAndPrivacyParams { + privacyParams = MetricSpecKt.differentialPrivacyParams {} + } + } + } + metricFrequencySpec = + MetricCalculationSpecKt.metricFrequencySpec { + weekly = + MetricCalculationSpecKt.MetricFrequencySpecKt.weekly { + dayOfWeek = DayOfWeek.WEDNESDAY + } + } + trailingWindow = + MetricCalculationSpecKt.trailingWindow { + count = 1 + increment = MetricCalculationSpec.TrailingWindow.Increment.WEEK + } + } + } + ) + } catch (e: StatusException) { + throw Exception("Error creating MetricCalculationSpec", e) + } + } + + private suspend fun pollForCompletedReport(reportName: String): Report { + val backoff = + ExponentialBackoff(initialDelay = initialResultPollingDelay, randomnessFactor = 0.0) + var attempt = 1 + while (true) { + val retrievedReport = + try { + reportsClient.getReport(getReportRequest { name = reportName }) + } catch (e: StatusException) { + throw Exception("Error getting Report", e) + } + + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. + when (retrievedReport.state) { + Report.State.SUCCEEDED, + Report.State.FAILED -> return retrievedReport + Report.State.RUNNING, + Report.State.UNRECOGNIZED, + Report.State.STATE_UNSPECIFIED -> { + val resultPollingDelay = + backoff.durationForAttempt(attempt).coerceAtMost(maximumResultPollingDelay) + logger.info { + "Report not completed yet. Waiting for ${resultPollingDelay.seconds} seconds." + } + delay(resultPollingDelay) + attempt++ + } + } + } + } + + companion object { + private val logger: Logger = Logger.getLogger(this::class.java.name) + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup/BUILD.bazel index 33554b80de4..185e2f3da8c 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup/BUILD.bazel @@ -37,6 +37,8 @@ kt_jvm_library( "//src/main/proto/wfa/measurement/api/v2alpha:measurement_consumer_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha:measurement_consumers_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/config:authority_key_to_principal_map_kt_jvm_proto", + "//src/main/proto/wfa/measurement/config/reporting:encryption_key_pair_config_kt_jvm_proto", + "//src/main/proto/wfa/measurement/config/reporting:measurement_consumer_config_kt_jvm_proto", "//src/main/proto/wfa/measurement/internal/kingdom:account_kt_jvm_proto", "//src/main/proto/wfa/measurement/internal/kingdom:accounts_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/internal/kingdom:certificate_kt_jvm_proto", diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup/ResourceSetup.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup/ResourceSetup.kt index b3571bebf38..6bbc96d00f6 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup/ResourceSetup.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup/ResourceSetup.kt @@ -55,6 +55,10 @@ import org.wfanet.measurement.common.crypto.tink.SelfIssuedIdTokens.generateIdTo import org.wfanet.measurement.common.identity.externalIdToApiId import org.wfanet.measurement.config.AuthorityKeyToPrincipalMapKt import org.wfanet.measurement.config.authorityKeyToPrincipalMap +import org.wfanet.measurement.config.reporting.EncryptionKeyPairConfigKt +import org.wfanet.measurement.config.reporting.encryptionKeyPairConfig +import org.wfanet.measurement.config.reporting.measurementConsumerConfig +import org.wfanet.measurement.config.reporting.measurementConsumerConfigs import org.wfanet.measurement.consent.client.measurementconsumer.signEncryptionPublicKey import org.wfanet.measurement.internal.kingdom.Account as InternalAccount import org.wfanet.measurement.internal.kingdom.AccountsGrpcKt @@ -224,6 +228,47 @@ class ResourceSetup( TextFormat.printer().print(akidMap, writer) } + val measurementConsumerConfig = measurementConsumerConfigs { + for (resource in resources) { + when (resource.resourceCase) { + Resources.Resource.ResourceCase.MEASUREMENT_CONSUMER -> + configs.put( + resource.name, + measurementConsumerConfig { + apiKey = resource.measurementConsumer.apiKey + signingCertificateName = resource.measurementConsumer.certificate + signingPrivateKeyPath = MEASUREMENT_CONSUMER_SIGNING_PRIVATE_KEY_PATH + }, + ) + else -> continue + } + } + } + output.resolve(MEASUREMENT_CONSUMER_CONFIG_FILE).writer().use { writer -> + TextFormat.printer().print(measurementConsumerConfig, writer) + } + + val encryptionKeyPairConfig = encryptionKeyPairConfig { + for (resource in resources) { + when (resource.resourceCase) { + Resources.Resource.ResourceCase.MEASUREMENT_CONSUMER -> + principalKeyPairs += + EncryptionKeyPairConfigKt.principalKeyPairs { + principal = resource.name + keyPairs += + EncryptionKeyPairConfigKt.keyPair { + publicKeyFile = MEASUREMENT_CONSUMER_ENCRYPTION_PUBLIC_KEY_PATH + privateKeyFile = MEASUREMENT_CONSUMER_ENCRYPTION_PRIVATE_KEY_PATH + } + } + else -> continue + } + } + } + output.resolve(ENCRYPTION_KEY_PAIR_CONFIG_FILE).writer().use { writer -> + TextFormat.printer().print(encryptionKeyPairConfig, writer) + } + val configName = bazelConfigName output.resolve(BAZEL_RC_FILE).writer().use { writer -> for (resource in resources) { @@ -446,6 +491,11 @@ class ResourceSetup( const val RESOURCES_OUTPUT_FILE = "resources.textproto" const val AKID_PRINCIPAL_MAP_FILE = "authority_key_identifier_to_principal_map.textproto" const val BAZEL_RC_FILE = "resource-setup.bazelrc" + const val MEASUREMENT_CONSUMER_CONFIG_FILE = "measurement_consumer_config.textproto" + const val ENCRYPTION_KEY_PAIR_CONFIG_FILE = "encryption_key_pair_config.textproto" + const val MEASUREMENT_CONSUMER_SIGNING_PRIVATE_KEY_PATH = "mc_cs_private.der" + const val MEASUREMENT_CONSUMER_ENCRYPTION_PUBLIC_KEY_PATH = "mc_enc_public.tink" + const val MEASUREMENT_CONSUMER_ENCRYPTION_PRIVATE_KEY_PATH = "mc_enc_private.tink" private val logger: Logger = Logger.getLogger(this::class.java.name) } diff --git a/src/main/proto/wfa/measurement/integration/k8s/testing/correctness_test_config.proto b/src/main/proto/wfa/measurement/integration/k8s/testing/correctness_test_config.proto index 4c894b4ca4f..25b14b56827 100644 --- a/src/main/proto/wfa/measurement/integration/k8s/testing/correctness_test_config.proto +++ b/src/main/proto/wfa/measurement/integration/k8s/testing/correctness_test_config.proto @@ -38,4 +38,13 @@ message CorrectnessTestConfig { // Authentication key for the CMMS public API. string api_authentication_key = 4; + + // gRPC target of Reporting public API server. + string reporting_public_api_target = 5; + + // Expected hostname (DNS-ID) in the reporting public API server's TLS + // certificate. + // + // If not specified, standard TLS DNS-ID derivation will be used. + string reporting_public_api_cert_host = 6; } diff --git a/src/test/kotlin/org/wfanet/measurement/integration/k8s/AbstractCorrectnessTest.kt b/src/test/kotlin/org/wfanet/measurement/integration/k8s/AbstractCorrectnessTest.kt index f4a28f4d1cd..64ed27e3e9b 100644 --- a/src/test/kotlin/org/wfanet/measurement/integration/k8s/AbstractCorrectnessTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/integration/k8s/AbstractCorrectnessTest.kt @@ -28,6 +28,7 @@ import org.wfanet.measurement.common.crypto.SigningKeyHandle import org.wfanet.measurement.integration.common.loadEncryptionPrivateKey import org.wfanet.measurement.integration.common.loadSigningKey import org.wfanet.measurement.loadtest.measurementconsumer.MeasurementConsumerSimulator +import org.wfanet.measurement.loadtest.reporting.ReportingUserSimulator /** Test for correctness of the CMMS on Kubernetes. */ abstract class AbstractCorrectnessTest(private val measurementSystem: MeasurementSystem) { @@ -37,6 +38,9 @@ abstract class AbstractCorrectnessTest(private val measurementSystem: Measuremen private val testHarness: MeasurementConsumerSimulator get() = measurementSystem.testHarness + private val reportingTestHarness: ReportingUserSimulator + get() = measurementSystem.reportingTestHarness + @Test(timeout = 1 * 60 * 1000) fun `impression measurement completes with expected result`() = runBlocking { testHarness.testImpression("$runId-impression") @@ -63,9 +67,15 @@ abstract class AbstractCorrectnessTest(private val measurementSystem: Measuremen ) } + @Test(timeout = 1 * 60 * 1000) + fun `report can be created`() = runBlocking { + reportingTestHarness.testCreateReport("$runId-test-report") + } + interface MeasurementSystem { val runId: String val testHarness: MeasurementConsumerSimulator + val reportingTestHarness: ReportingUserSimulator } companion object { @@ -97,6 +107,14 @@ abstract class AbstractCorrectnessTest(private val measurementSystem: Measuremen SigningCerts.fromPemFiles(cert, key, trustedCerts) } + val REPORTING_SIGNING_CERTS: SigningCerts by lazy { + val secretFiles = getRuntimePath(SECRET_FILES_PATH) + val trustedCerts = secretFiles.resolve("reporting_root.pem").toFile() + val cert = secretFiles.resolve("mc_tls.pem").toFile() + val key = secretFiles.resolve("mc_tls.key").toFile() + SigningCerts.fromPemFiles(cert, key, trustedCerts) + } + val MC_ENCRYPTION_PRIVATE_KEY: PrivateKeyHandle by lazy { loadEncryptionPrivateKey(MC_ENCRYPTION_PRIVATE_KEY_NAME) } diff --git a/src/test/kotlin/org/wfanet/measurement/integration/k8s/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/integration/k8s/BUILD.bazel index 3883661dbdb..2f0097c1e6f 100644 --- a/src/test/kotlin/org/wfanet/measurement/integration/k8s/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/integration/k8s/BUILD.bazel @@ -13,11 +13,13 @@ kt_jvm_library( srcs = ["AbstractCorrectnessTest.kt"], data = [ "//src/main/k8s/testing/secretfiles:mc_trusted_certs.pem", + "//src/main/k8s/testing/secretfiles:reporting_root.pem", "//src/main/k8s/testing/secretfiles:secret_files", ], deps = [ "//src/main/kotlin/org/wfanet/measurement/integration/common:configs", "//src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer:simulator", + "//src/main/kotlin/org/wfanet/measurement/loadtest/reporting:simulator", "@wfa_common_jvm//imports/java/com/google/common/truth", "@wfa_common_jvm//imports/java/org/junit", "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", @@ -58,6 +60,7 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/integration/common:synthetic_generation_specs", "//src/main/kotlin/org/wfanet/measurement/loadtest/config:vid_sampling", "//src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer:synthetic_generator_event_query", + "//src/main/kotlin/org/wfanet/measurement/loadtest/reporting:simulator", "//src/main/proto/wfa/measurement/integration/k8s/testing:correctness_test_config_kt_jvm_proto", "@wfa_common_jvm//imports/java/org/junit", "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", @@ -74,6 +77,8 @@ expand_template( "{kingdom_public_api_cert_host}": "localhost", "{mc_name}": TEST_K8S_SETTINGS.mc_name, "{mc_api_key}": TEST_K8S_SETTINGS.mc_api_key, + "{reporting_public_api_target}": "$(reporting_public_api_target)", + "{reporting_public_api_cert_host}": "localhost", }, tags = ["manual"], template = "correctness_test_config.tmpl.textproto", diff --git a/src/test/kotlin/org/wfanet/measurement/integration/k8s/EmptyClusterCorrectnessTest.kt b/src/test/kotlin/org/wfanet/measurement/integration/k8s/EmptyClusterCorrectnessTest.kt index 0ddd08ec41b..95f24dc4307 100644 --- a/src/test/kotlin/org/wfanet/measurement/integration/k8s/EmptyClusterCorrectnessTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/integration/k8s/EmptyClusterCorrectnessTest.kt @@ -69,10 +69,14 @@ import org.wfanet.measurement.internal.kingdom.AccountsGrpcKt import org.wfanet.measurement.loadtest.measurementconsumer.MeasurementConsumerData import org.wfanet.measurement.loadtest.measurementconsumer.MeasurementConsumerSimulator import org.wfanet.measurement.loadtest.measurementconsumer.MetadataSyntheticGeneratorEventQuery +import org.wfanet.measurement.loadtest.reporting.ReportingUserSimulator import org.wfanet.measurement.loadtest.resourcesetup.DuchyCert import org.wfanet.measurement.loadtest.resourcesetup.EntityContent import org.wfanet.measurement.loadtest.resourcesetup.ResourceSetup import org.wfanet.measurement.loadtest.resourcesetup.Resources +import org.wfanet.measurement.reporting.v2alpha.MetricCalculationSpecsGrpcKt +import org.wfanet.measurement.reporting.v2alpha.ReportingSetsGrpcKt +import org.wfanet.measurement.reporting.v2alpha.ReportsGrpcKt /** * Test for correctness of the CMMS on a single "empty" Kubernetes cluster using the `local` @@ -108,6 +112,7 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { val worker1Cert: String, val worker2Cert: String, val measurementConsumer: String, + val measurementConsumerCert: String, val apiKey: String, val dataProviders: Map, ) { @@ -117,6 +122,7 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { var worker1Cert: String? = null var worker2Cert: String? = null var measurementConsumer: String? = null + var measurementConsumerCert: String? = null var apiKey: String? = null val dataProviders = mutableMapOf() @@ -126,6 +132,7 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { Resources.Resource.ResourceCase.MEASUREMENT_CONSUMER -> { measurementConsumer = resource.name apiKey = resource.measurementConsumer.apiKey + measurementConsumerCert = resource.measurementConsumer.certificate } Resources.Resource.ResourceCase.DATA_PROVIDER -> { val displayName = resource.dataProvider.displayName @@ -146,12 +153,13 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { } return ResourceInfo( - requireNotNull(aggregatorCert), - requireNotNull(worker1Cert), - requireNotNull(worker2Cert), - requireNotNull(measurementConsumer), - requireNotNull(apiKey), - dataProviders, + aggregatorCert = requireNotNull(aggregatorCert), + worker1Cert = requireNotNull(worker1Cert), + worker2Cert = requireNotNull(worker2Cert), + measurementConsumer = requireNotNull(measurementConsumer), + measurementConsumerCert = requireNotNull(measurementConsumerCert), + apiKey = requireNotNull(apiKey), + dataProviders = dataProviders, ) } } @@ -174,6 +182,10 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { override val testHarness: MeasurementConsumerSimulator get() = _testHarness + private lateinit var _reportingTestHarness: ReportingUserSimulator + override val reportingTestHarness: ReportingUserSimulator + get() = _reportingTestHarness + override fun apply(base: Statement, description: Description): Statement { return object : Statement() { override fun evaluate() { @@ -182,6 +194,7 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { withTimeout(Duration.ofMinutes(5)) { val measurementConsumerData = populateCluster() _testHarness = createTestHarness(measurementConsumerData) + _reportingTestHarness = createReportingUserSimulator(measurementConsumerData) } } base.evaluate() @@ -212,7 +225,12 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { val resourceSetupOutput = runResourceSetup(duchyCerts, edpEntityContents, measurementConsumerContent) val resourceInfo = ResourceInfo.from(resourceSetupOutput.resources) - loadFullCmms(resourceInfo, resourceSetupOutput.akidPrincipalMap) + loadFullCmms( + resourceInfo, + resourceSetupOutput.akidPrincipalMap, + resourceSetupOutput.measurementConsumerConfig, + resourceSetupOutput.encryptionKeyPairConfig, + ) val encryptionPrivateKey: TinkPrivateKeyHandle = withContext(Dispatchers.IO) { @@ -259,6 +277,35 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { ) } + private suspend fun createReportingUserSimulator( + measurementConsumerData: MeasurementConsumerData + ): ReportingUserSimulator { + val reportingPublicPod: V1Pod = getPod(REPORTING_PUBLIC_DEPLOYMENT_NAME) + + val publicApiForwarder = PortForwarder(reportingPublicPod, SERVER_PORT) + portForwarders.add(publicApiForwarder) + + val publicApiAddress: InetSocketAddress = + withContext(Dispatchers.IO) { publicApiForwarder.start() } + val publicApiChannel: Channel = + buildMutualTlsChannel(publicApiAddress.toTarget(), REPORTING_SIGNING_CERTS) + .also { channels.add(it) } + .withDefaultDeadline(DEFAULT_RPC_DEADLINE) + + return ReportingUserSimulator( + measurementConsumerName = measurementConsumerData.name, + dataProvidersClient = DataProvidersGrpcKt.DataProvidersCoroutineStub(publicApiChannel), + eventGroupsClient = + org.wfanet.measurement.reporting.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub( + publicApiChannel + ), + reportingSetsClient = ReportingSetsGrpcKt.ReportingSetsCoroutineStub(publicApiChannel), + metricCalculationSpecsClient = + MetricCalculationSpecsGrpcKt.MetricCalculationSpecsCoroutineStub(publicApiChannel), + reportsClient = ReportsGrpcKt.ReportsCoroutineStub(publicApiChannel), + ) + } + fun stopPortForwarding() { for (channel in channels) { channel.shutdown() @@ -268,7 +315,12 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { } } - private suspend fun loadFullCmms(resourceInfo: ResourceInfo, akidPrincipalMap: File) { + private suspend fun loadFullCmms( + resourceInfo: ResourceInfo, + akidPrincipalMap: File, + measurementConsumerConfig: File, + encryptionKeyPairConfig: File, + ) { val appliedObjects: List = withContext(Dispatchers.IO) { val outputDir = tempDir.newFolder("cmms") @@ -278,6 +330,13 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { logger.info("Copying $akidPrincipalMap to $CONFIG_FILES_PATH") akidPrincipalMap.copyTo(configFilesDir.resolve(akidPrincipalMap.name)) + logger.info("Copying $encryptionKeyPairConfig to $CONFIG_FILES_PATH") + encryptionKeyPairConfig.copyTo(configFilesDir.resolve(encryptionKeyPairConfig.name)) + + val mcConfigDir = outputDir.toPath().resolve(MC_CONFIG_PATH).toFile() + logger.info("Copying $measurementConsumerConfig to $MC_CONFIG_PATH") + measurementConsumerConfig.copyTo(mcConfigDir.resolve(measurementConsumerConfig.name)) + val configTemplate: File = outputDir.resolve("config.yaml") kustomize( outputDir.toPath().resolve(LOCAL_K8S_TESTING_PATH).resolve("cmms").toFile(), @@ -291,6 +350,8 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { .replace("{worker1_cert_name}", resourceInfo.worker1Cert) .replace("{worker2_cert_name}", resourceInfo.worker2Cert) .replace("{mc_name}", resourceInfo.measurementConsumer) + .replace("{mc_api_key}", resourceInfo.apiKey) + .replace("{mc_cert_name}", resourceInfo.measurementConsumerCert) .let { var config = it for ((displayName, resource) in resourceInfo.dataProviders) { @@ -378,6 +439,8 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { return ResourceSetupOutput( resources, outputDir.resolve(ResourceSetup.AKID_PRINCIPAL_MAP_FILE), + outputDir.resolve(ResourceSetup.MEASUREMENT_CONSUMER_CONFIG_FILE), + outputDir.resolve(ResourceSetup.ENCRYPTION_KEY_PAIR_CONFIG_FILE), ) } @@ -439,6 +502,8 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { data class ResourceSetupOutput( val resources: List, val akidPrincipalMap: File, + val measurementConsumerConfig: File, + val encryptionKeyPairConfig: File, ) } @@ -456,6 +521,8 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { private val DEFAULT_RPC_DEADLINE = Duration.ofSeconds(30) private const val KINGDOM_INTERNAL_DEPLOYMENT_NAME = "gcp-kingdom-data-server-deployment" private const val KINGDOM_PUBLIC_DEPLOYMENT_NAME = "v2alpha-public-api-server-deployment" + private const val REPORTING_PUBLIC_DEPLOYMENT_NAME = + "reporting-v2alpha-public-api-server-deployment" private const val NUM_DATA_PROVIDERS = 6 private val EDP_DISPLAY_NAMES: List = (1..NUM_DATA_PROVIDERS).map { "edp$it" } private val READY_TIMEOUT = Duration.ofMinutes(2L) @@ -463,6 +530,7 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { private val LOCAL_K8S_PATH = Paths.get("src", "main", "k8s", "local") private val LOCAL_K8S_TESTING_PATH = LOCAL_K8S_PATH.resolve("testing") private val CONFIG_FILES_PATH = LOCAL_K8S_TESTING_PATH.resolve("config_files") + private val MC_CONFIG_PATH = LOCAL_K8S_TESTING_PATH.resolve("mc_config") private val IMAGE_PUSHER_PATH = Paths.get("src", "main", "docker", "push_all_local_images.bash") private val tempDir = TemporaryFolder() diff --git a/src/test/kotlin/org/wfanet/measurement/integration/k8s/SyntheticGeneratorCorrectnessTest.kt b/src/test/kotlin/org/wfanet/measurement/integration/k8s/SyntheticGeneratorCorrectnessTest.kt index 86608dd80fa..6108c68c62f 100644 --- a/src/test/kotlin/org/wfanet/measurement/integration/k8s/SyntheticGeneratorCorrectnessTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/integration/k8s/SyntheticGeneratorCorrectnessTest.kt @@ -41,6 +41,9 @@ import org.wfanet.measurement.loadtest.dataprovider.SyntheticGeneratorEventQuery import org.wfanet.measurement.loadtest.measurementconsumer.MeasurementConsumerData import org.wfanet.measurement.loadtest.measurementconsumer.MeasurementConsumerSimulator import org.wfanet.measurement.loadtest.measurementconsumer.MetadataSyntheticGeneratorEventQuery +import org.wfanet.measurement.loadtest.reporting.ReportingUserSimulator +import org.wfanet.measurement.reporting.v2alpha.MetricCalculationSpecsGrpcKt +import org.wfanet.measurement.reporting.v2alpha.ReportsGrpcKt /** * Test for correctness of an existing CMMS on Kubernetes where the EDP simulators use @@ -48,16 +51,21 @@ import org.wfanet.measurement.loadtest.measurementconsumer.MetadataSyntheticGene * The computation composition is using ACDP by assumption. * * This currently assumes that the CMMS instance is using the certificates and keys from this Bazel - * workspace. + * workspace. It also assumes that there is a Reporting system connected to the CMMS. */ class SyntheticGeneratorCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { private class RunningMeasurementSystem : MeasurementSystem, TestRule { override val runId: String by lazy { UUID.randomUUID().toString() } private lateinit var _testHarness: MeasurementConsumerSimulator + private lateinit var _reportingTestHarness: ReportingUserSimulator + override val testHarness: MeasurementConsumerSimulator get() = _testHarness + override val reportingTestHarness: ReportingUserSimulator + get() = _reportingTestHarness + private val channels = mutableListOf() override fun apply(base: Statement, description: Description): Statement { @@ -65,6 +73,7 @@ class SyntheticGeneratorCorrectnessTest : AbstractCorrectnessTest(measurementSys override fun evaluate() { try { _testHarness = createTestHarness() + _reportingTestHarness = createReportingTestHarness() base.evaluate() } finally { shutDownChannels() @@ -110,6 +119,33 @@ class SyntheticGeneratorCorrectnessTest : AbstractCorrectnessTest(measurementSys ) } + private fun createReportingTestHarness(): ReportingUserSimulator { + val publicApiChannel = + buildMutualTlsChannel( + TEST_CONFIG.reportingPublicApiTarget, + REPORTING_SIGNING_CERTS, + TEST_CONFIG.reportingPublicApiCertHost, + ) + .also { channels.add(it) } + .withDefaultDeadline(RPC_DEADLINE_DURATION) + + return ReportingUserSimulator( + measurementConsumerName = TEST_CONFIG.measurementConsumer, + dataProvidersClient = DataProvidersGrpcKt.DataProvidersCoroutineStub(publicApiChannel), + eventGroupsClient = + org.wfanet.measurement.reporting.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub( + publicApiChannel + ), + reportingSetsClient = + org.wfanet.measurement.reporting.v2alpha.ReportingSetsGrpcKt.ReportingSetsCoroutineStub( + publicApiChannel + ), + metricCalculationSpecsClient = + MetricCalculationSpecsGrpcKt.MetricCalculationSpecsCoroutineStub(publicApiChannel), + reportsClient = ReportsGrpcKt.ReportsCoroutineStub(publicApiChannel), + ) + } + private fun shutDownChannels() { for (channel in channels) { channel.shutdown() diff --git a/src/test/kotlin/org/wfanet/measurement/integration/k8s/correctness_test_config.tmpl.textproto b/src/test/kotlin/org/wfanet/measurement/integration/k8s/correctness_test_config.tmpl.textproto index a6b54b93d85..1242cd723ec 100644 --- a/src/test/kotlin/org/wfanet/measurement/integration/k8s/correctness_test_config.tmpl.textproto +++ b/src/test/kotlin/org/wfanet/measurement/integration/k8s/correctness_test_config.tmpl.textproto @@ -4,3 +4,5 @@ kingdom_public_api_target: "{kingdom_public_api_target}" kingdom_public_api_cert_host: "{kingdom_public_api_cert_host}" measurement_consumer: "{mc_name}" api_authentication_key: "{mc_api_key}" +reporting_public_api_target: "{reporting_public_api_target}" +reporting_public_api_cert_host: "{reporting_public_api_cert_host}" diff --git a/src/test/kotlin/org/wfanet/panelmatch/integration/k8s/EmptyClusterPanelMatchCorrectnessTest.kt b/src/test/kotlin/org/wfanet/panelmatch/integration/k8s/EmptyClusterPanelMatchCorrectnessTest.kt index 8c126916597..07de2db818b 100644 --- a/src/test/kotlin/org/wfanet/panelmatch/integration/k8s/EmptyClusterPanelMatchCorrectnessTest.kt +++ b/src/test/kotlin/org/wfanet/panelmatch/integration/k8s/EmptyClusterPanelMatchCorrectnessTest.kt @@ -494,7 +494,7 @@ class EmptyClusterPanelMatchCorrectnessTest : AbstractPanelMatchCorrectnessTest( private val LOCAL_K8S_PATH = Paths.get("src", "main", "k8s", "local") private val LOCAL_K8S_TESTING_PATH = LOCAL_K8S_PATH.resolve("testing") - private val CONFIG_FILES_PATH = LOCAL_K8S_TESTING_PATH.resolve("config_files") + private val CONFIG_FILES_PATH = LOCAL_K8S_TESTING_PATH.resolve("config_files_for_panel_match") private val LOCAL_K8S_PANELMATCH_PATH = Paths.get("src", "main", "k8s", "panelmatch", "local") private val PANELMATCH_CONFIG_FILES_PATH = LOCAL_K8S_PANELMATCH_PATH.resolve("config_files") From 48561a1f8a36ab3060f11e2c36f6a7657a2d2015 Mon Sep 17 00:00:00 2001 From: Sanjay Vasandani Date: Fri, 22 Nov 2024 10:15:57 -0800 Subject: [PATCH 2/3] refactor: Extract listResources utility function for handling pagination (#1923) This also fixes an issue where EventGroupsServiceTest was waiting for a real 30s deadline to pass. Also fixes the simulator issue of PR https://github.com/world-federation-of-advertisers/cross-media-measurement/pull/1927: The MC and EDP simulators were only reading the first page of results from ListEventGroups. This means that in environments with more EventGroups, not all test EventGroups would be read. This is exacerbated by https://github.com/world-federation-of-advertisers/cross-media-measurement/pull/1916, as the default page size is now smaller. --- .../measurement/api/v2alpha/BUILD.bazel | 2 +- .../measurement/common/api/grpc/BUILD.bazel | 12 +- .../common/api/grpc/ListResources.kt | 92 +++++++++++ .../integration/common/BUILD.bazel | 1 + ...sMeasurementSystemProberIntegrationTest.kt | 52 +++--- .../measurement/kingdom/batch/BUILD.bazel | 1 + .../kingdom/batch/MeasurementSystemProber.kt | 151 ++++++++---------- .../loadtest/dataprovider/BUILD.bazel | 1 + .../loadtest/dataprovider/EdpSimulator.kt | 43 ++--- .../loadtest/measurementconsumer/BUILD.bazel | 1 + .../MeasurementConsumerSimulator.kt | 50 ++++-- .../reporting/service/api/v1alpha/BUILD.bazel | 2 +- .../reporting/service/api/v2alpha/BUILD.bazel | 3 +- .../service/api/v2alpha/EventGroupsService.kt | 125 +++++++++------ .../api/v2alpha/EventGroupsServiceTest.kt | 39 ++++- 15 files changed, 377 insertions(+), 198 deletions(-) create mode 100644 src/main/kotlin/org/wfanet/measurement/common/api/grpc/ListResources.kt diff --git a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel index 45bc9a2576f..e53b6c87c1c 100644 --- a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel @@ -52,7 +52,7 @@ kt_jvm_library( deps = [ "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:context_keys", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:measurement_principal", - "//src/main/kotlin/org/wfanet/measurement/common/api/grpc", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:akid_principal_server_interceptor", "//src/main/kotlin/org/wfanet/measurement/common/grpc:context", "//src/main/kotlin/org/wfanet/measurement/common/identity", "//src/main/proto/wfa/measurement/api/v2alpha:duchy_kt_jvm_proto", diff --git a/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel index 3a897c0ad81..4ffd63ed7ab 100644 --- a/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel @@ -3,7 +3,7 @@ load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") package(default_visibility = ["//visibility:public"]) kt_jvm_library( - name = "grpc", + name = "akid_principal_server_interceptor", srcs = ["AkidPrincipalServerInterceptor.kt"], deps = [ "//src/main/kotlin/org/wfanet/measurement/common/api:principal", @@ -14,3 +14,13 @@ kt_jvm_library( "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", ], ) + +kt_jvm_library( + name = "list_resources", + srcs = ["ListResources.kt"], + deps = [ + "@wfa_common_jvm//imports/java/com/google/protobuf", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + "@wfa_rules_kotlin_jvm//imports/io/gprc/kotlin:stub", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/common/api/grpc/ListResources.kt b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/ListResources.kt new file mode 100644 index 00000000000..ebd768b2412 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/ListResources.kt @@ -0,0 +1,92 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.common.api.grpc + +import com.google.protobuf.Message +import io.grpc.kotlin.AbstractCoroutineStub +import kotlin.coroutines.coroutineContext +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.ensureActive +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.flattenConcat +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.map + +/** A [List] of resources from a paginated List method. */ +data class ResourceList( + val resources: List, + /** + * A token that can be sent on subsequent requests to retrieve the next page. If empty, there are + * no subsequent pages. + */ + val nextPageToken: String, +) : List by resources + +/** + * Lists resources from a paginated List method on this stub. + * + * @param pageToken page token for initial request + * @param list function which calls the appropriate List method on the stub + */ +fun > S.listResources( + pageToken: String = "", + list: suspend S.(pageToken: String) -> ResourceList, +): Flow> = + listResources(Int.MAX_VALUE, pageToken) { nextPageToken, _ -> list(nextPageToken) } + +/** + * Lists resources from a paginated List method on this stub. + * + * @param limit maximum number of resources to emit + * @param pageToken page token for initial request + * @param list function which calls the appropriate List method on the stub, returning no more than + * the specified remaining number of resources + */ +fun > S.listResources( + limit: Int, + pageToken: String = "", + list: suspend S.(pageToken: String, remaining: Int) -> ResourceList, +): Flow> { + require(limit > 0) { "limit must be positive" } + return flow { + var remaining: Int = limit + var nextPageToken = pageToken + + while (true) { + coroutineContext.ensureActive() + + val resourceList: ResourceList = list(nextPageToken, remaining) + require(resourceList.size <= remaining) { + "List call must ensure that limit is not exceeded. " + + "Returned ${resourceList.size} items when only $remaining were remaining" + } + emit(resourceList) + + remaining -= resourceList.size + nextPageToken = resourceList.nextPageToken + if (nextPageToken.isEmpty() || remaining == 0) { + break + } + } + } +} + +/** @see [flattenConcat] */ +@ExperimentalCoroutinesApi // Overloads experimental `flattenConcat` function. +fun Flow>.flattenConcat(): Flow = + map { it.asFlow() }.flattenConcat() diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel index f47074b4ae3..8764c226d65 100644 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel @@ -319,6 +319,7 @@ kt_jvm_library( ], deps = [ ":in_process_cmms_components", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/kotlin/org/wfanet/measurement/kingdom/batch:measurement_system_prober", "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/service:data_services", "@wfa_common_jvm//imports/java/com/google/common/truth", diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt index 0a015ca80d3..d9a14f29152 100644 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt @@ -22,6 +22,9 @@ import java.io.File import java.nio.file.Paths import java.time.Clock import java.time.Duration +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking import org.junit.After import org.junit.Before @@ -30,13 +33,15 @@ import org.junit.Rule import org.junit.Test import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCoroutineStub import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub -import org.wfanet.measurement.api.v2alpha.ListMeasurementsResponse import org.wfanet.measurement.api.v2alpha.Measurement import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt.RequisitionsCoroutineStub import org.wfanet.measurement.api.v2alpha.listMeasurementsRequest import org.wfanet.measurement.api.withAuthenticationKey +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.getRuntimePath import org.wfanet.measurement.common.identity.withPrincipalName import org.wfanet.measurement.common.testing.ProviderRule @@ -129,33 +134,32 @@ abstract class InProcessMeasurementSystemProberIntegrationTest( assertThat(measurements.size).isEqualTo(1) } + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. private suspend fun listMeasurements(): List { - var nextPageToken = "" val measurementConsumerData = inProcessCmmsComponents.getMeasurementConsumerData() - do { - val response: ListMeasurementsResponse = - try { - publicMeasurementsClient - .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) - .listMeasurements( - listMeasurementsRequest { - parent = measurementConsumerData.name - pageToken = nextPageToken - } - ) - } catch (e: StatusException) { - throw Exception( - "Unable to list measurements for measurement consumer ${measurementConsumerData.name}", - e, - ) + val measurementLists: Flow> = + publicMeasurementsClient + .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) + .listResources { pageToken -> + val response = + try { + listMeasurements( + listMeasurementsRequest { + parent = measurementConsumerData.name + this.pageToken = pageToken + } + ) + } catch (e: StatusException) { + throw Exception( + "Unable to list measurements for measurement consumer ${measurementConsumerData.name}", + e, + ) + } + ResourceList(response.measurementsList, response.nextPageToken) } - if (response.measurementsList.isNotEmpty()) { - return response.measurementsList - } - nextPageToken = response.nextPageToken - } while (nextPageToken.isNotEmpty()) - return emptyList() + + return measurementLists.flattenConcat().toList() } companion object { diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel index 12d1d96644d..2d455bdfd32 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel @@ -47,6 +47,7 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/proto/wfa/measurement/api/v2alpha:data_provider_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha:data_providers_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/api/v2alpha:event_group_kt_jvm_proto", diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt index b64a3a55842..e7c95bdd42c 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt @@ -27,8 +27,12 @@ import java.security.SecureRandom import java.time.Clock import java.time.Duration import java.util.logging.Logger +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.single +import kotlinx.coroutines.flow.singleOrNull import org.wfanet.measurement.api.v2alpha.CanonicalRequisitionKey -import org.wfanet.measurement.api.v2alpha.DataProvider import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt import org.wfanet.measurement.api.v2alpha.EventGroup import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt @@ -60,6 +64,9 @@ import org.wfanet.measurement.api.v2alpha.requisitionSpec import org.wfanet.measurement.api.v2alpha.unpack import org.wfanet.measurement.api.withAuthenticationKey import org.wfanet.measurement.common.Instrumentation +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.crypto.Hashing import org.wfanet.measurement.common.crypto.SigningKeyHandle import org.wfanet.measurement.common.crypto.readCertificate @@ -198,43 +205,31 @@ class MeasurementSystemProber( private suspend fun buildDataProviderNameToEventGroup(): Map { val dataProviderNameToEventGroup = mutableMapOf() for (dataProviderName in dataProviderNames) { - val getDataProviderRequest = getDataProviderRequest { name = dataProviderName } - val dataProvider: DataProvider = - try { - dataProvidersStub - .withAuthenticationKey(apiAuthenticationKey) - .getDataProvider(getDataProviderRequest) - } catch (e: StatusException) { - throw Exception("Unable to get DataProvider with name $dataProviderName", e) - } - - // TODO(@roaminggypsy): Implement QA event group logic using simulatorEventGroupName - val listEventGroupsRequest = listEventGroupsRequest { - parent = measurementConsumerName - filter = ListEventGroupsRequestKt.filter { dataProviders += dataProviderName } - } - - val eventGroups: List = - try { - eventGroupsStub - .withAuthenticationKey(apiAuthenticationKey) - .listEventGroups(listEventGroupsRequest) - .eventGroupsList - .toList() - } catch (e: StatusException) { - throw Exception( - "Unable to get event groups associated with measurement consumer $measurementConsumerName and data provider $dataProviderName", - e, - ) - } - - if (eventGroups.size != 1) { - throw IllegalStateException( - "here should be exactly 1:1 mapping between a data provider and an event group, but data provider $dataProvider is related to ${eventGroups.size} event groups" - ) - } + val eventGroup: EventGroup = + eventGroupsStub + .withAuthenticationKey(apiAuthenticationKey) + .listResources(1) { pageToken, remaining -> + val request = listEventGroupsRequest { + parent = measurementConsumerName + filter = ListEventGroupsRequestKt.filter { dataProviders += dataProviderName } + this.pageToken = pageToken + pageSize = remaining + } + val response = + try { + listEventGroups(request) + } catch (e: StatusException) { + throw Exception( + "Unable to get event groups associated with measurement consumer $measurementConsumerName and data provider $dataProviderName", + e, + ) + } + ResourceList(response.eventGroupsList, response.nextPageToken) + } + .map { it.single() } + .single() - dataProviderNameToEventGroup[dataProviderName] = eventGroups[0] + dataProviderNameToEventGroup[dataProviderName] = eventGroup } return dataProviderNameToEventGroup } @@ -253,55 +248,47 @@ class MeasurementSystemProber( return clock.instant() >= nextMeasurementEarliestInstant } + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. private suspend fun getLastUpdatedMeasurement(): Measurement? { - var nextPageToken = "" - do { - val response: ListMeasurementsResponse = - try { - measurementsStub - .withAuthenticationKey(apiAuthenticationKey) - .listMeasurements( + val measurements: Flow> = + measurementsStub.withAuthenticationKey(apiAuthenticationKey).listResources(1) { + pageToken, + remaining -> + val response: ListMeasurementsResponse = + try { + listMeasurements( listMeasurementsRequest { parent = measurementConsumerName - this.pageSize = 1 - pageToken = nextPageToken + this.pageToken = pageToken + this.pageSize = remaining } ) - } catch (e: StatusException) { - throw Exception( - "Unable to list measurements for measurement consumer $measurementConsumerName", - e, - ) - } - if (response.measurementsList.isNotEmpty()) { - return response.measurementsList.single() + } catch (e: StatusException) { + throw Exception( + "Unable to list measurements for measurement consumer $measurementConsumerName", + e, + ) + } + ResourceList(response.measurementsList, response.nextPageToken) } - nextPageToken = response.nextPageToken - } while (nextPageToken.isNotEmpty()) - return null + + return measurements.flattenConcat().singleOrNull() } - private suspend fun getRequisitionsForMeasurement(measurementName: String): List { - var nextPageToken = "" - val requisitions = mutableListOf() - do { - val response: ListRequisitionsResponse = - try { - requisitionsStub - .withAuthenticationKey(apiAuthenticationKey) - .listRequisitions( - listRequisitionsRequest { - parent = measurementName - pageToken = nextPageToken - } - ) - } catch (e: StatusException) { - throw Exception("Unable to list requisitions for measurement $measurementName", e) - } - requisitions.addAll(response.requisitionsList) - nextPageToken = response.nextPageToken - } while (nextPageToken.isNotEmpty()) - return requisitions + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. + private fun getRequisitionsForMeasurement(measurementName: String): Flow { + return requisitionsStub + .withAuthenticationKey(apiAuthenticationKey) + .listResources { pageToken -> + val response: ListRequisitionsResponse = + try { + listRequisitions(listRequisitionsRequest { this.pageToken = pageToken }) + } catch (e: StatusException) { + throw Exception("Unable to list requisitions for measurement $measurementName", e) + } + ResourceList(response.requisitionsList, response.nextPageToken) + } + .flattenConcat() } private suspend fun getDataProviderEntry( @@ -360,10 +347,12 @@ class MeasurementSystemProber( private suspend fun updateLastTerminalRequisitionGauge(lastUpdatedMeasurement: Measurement) { val requisitions = getRequisitionsForMeasurement(lastUpdatedMeasurement.name) - for (requisition in requisitions) { + requisitions.collect { requisition -> if (requisition.state == Requisition.State.FULFILLED) { - val requisitionKey = CanonicalRequisitionKey.fromName(requisition.name) - require(requisitionKey != null) { "CanonicalRequisitionKey cannot be null" } + val requisitionKey = + requireNotNull(CanonicalRequisitionKey.fromName(requisition.name)) { + "Requisition name ${requisition.name} is invalid" + } val dataProviderName: String = requisitionKey.dataProviderId val attributes = Attributes.of(DATA_PROVIDER_ATTRIBUTE_KEY, dataProviderName) lastTerminalRequisitionTimeGauge.set( diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/BUILD.bazel index 64479522bc0..99d3e013189 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/BUILD.bazel @@ -123,6 +123,7 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", "//src/main/kotlin/org/wfanet/measurement/common:health", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/kotlin/org/wfanet/measurement/dataprovider:requisition_fulfiller", "//src/main/kotlin/org/wfanet/measurement/eventdataprovider/noiser", "//src/main/kotlin/org/wfanet/measurement/eventdataprovider/privacybudgetmanagement:privacy_budget_manager", diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt index 7a29bd5f9d3..4737bcf56bd 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt @@ -36,9 +36,11 @@ import kotlin.math.roundToInt import kotlin.random.Random import kotlin.random.asJavaRandom import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.asFlow import kotlinx.coroutines.flow.emitAll +import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map import kotlinx.coroutines.withContext @@ -109,6 +111,9 @@ import org.wfanet.measurement.api.v2alpha.updateEventGroupRequest import org.wfanet.measurement.common.Health import org.wfanet.measurement.common.ProtoReflection import org.wfanet.measurement.common.SettableHealth +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.asBufferedFlow import org.wfanet.measurement.common.crypto.authorityKeyIdentifier import org.wfanet.measurement.common.crypto.readCertificate @@ -339,27 +344,29 @@ class EdpSimulator( * Returns the first [EventGroup] for this `DataProvider` and [MeasurementConsumer] with * [eventGroupReferenceId], or `null` if not found. */ + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. private suspend fun getEventGroupByReferenceId(eventGroupReferenceId: String): EventGroup? { - val response = - try { - eventGroupsStub.listEventGroups( - listEventGroupsRequest { - parent = edpData.name - filter = - ListEventGroupsRequestKt.filter { measurementConsumers += measurementConsumerName } - pageSize = Int.MAX_VALUE + return eventGroupsStub + .listResources { pageToken -> + val response = + try { + listEventGroups( + listEventGroupsRequest { + parent = edpData.name + filter = + ListEventGroupsRequestKt.filter { + measurementConsumers += measurementConsumerName + } + this.pageToken = pageToken + } + ) + } catch (e: StatusException) { + throw Exception("Error listing EventGroups", e) } - ) - } catch (e: StatusException) { - throw Exception("Error listing EventGroups", e) + ResourceList(response.eventGroupsList, response.nextPageToken) } - - // TODO(@SanjayVas): Support filtering by reference ID so we don't need to handle multiple pages - // of EventGroups. - check(response.nextPageToken.isEmpty()) { - "Too many EventGroups for ${edpData.name} and $measurementConsumerName" - } - return response.eventGroupsList.find { it.eventGroupReferenceId == eventGroupReferenceId } + .flattenConcat() + .firstOrNull { it.eventGroupReferenceId == eventGroupReferenceId } } private suspend fun ensureMetadataDescriptor( diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/BUILD.bazel index a6b9b428e95..1451a300dae 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/BUILD.bazel @@ -28,6 +28,7 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha/testing", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/kotlin/org/wfanet/measurement/common/identity", "//src/main/kotlin/org/wfanet/measurement/integration/common:configs", "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:api_key_authentication_server_interceptor", diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/MeasurementConsumerSimulator.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/MeasurementConsumerSimulator.kt index bbdf684ba45..f008d4ef94a 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/MeasurementConsumerSimulator.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/MeasurementConsumerSimulator.kt @@ -31,6 +31,10 @@ import kotlin.math.log2 import kotlin.math.max import kotlin.math.sqrt import kotlin.random.Random +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.filter +import kotlinx.coroutines.flow.toList import kotlinx.coroutines.time.delay import org.projectnessie.cel.Program import org.wfanet.measurement.api.v2alpha.Certificate @@ -92,6 +96,9 @@ import org.wfanet.measurement.api.v2alpha.unpack import org.wfanet.measurement.api.withAuthenticationKey import org.wfanet.measurement.common.ExponentialBackoff import org.wfanet.measurement.common.OpenEndTimeRange +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.coerceAtMost import org.wfanet.measurement.common.crypto.Hashing import org.wfanet.measurement.common.crypto.PrivateKeyHandle @@ -812,11 +819,13 @@ class MeasurementConsumerSimulator( maxDataProviders: Int = 20, ): MeasurementInfo { val eventGroups: List = - listEventGroups(measurementConsumer.name).filter { - it.eventGroupReferenceId.startsWith( - TestIdentifiers.SIMULATOR_EVENT_GROUP_REFERENCE_ID_PREFIX - ) - } + listEventGroups(measurementConsumer.name) + .filter { + it.eventGroupReferenceId.startsWith( + TestIdentifiers.SIMULATOR_EVENT_GROUP_REFERENCE_ID_PREFIX + ) + } + .toList() check(eventGroups.isNotEmpty()) { "No event groups found for ${measurementConsumer.name}" } val nonceHashes = mutableListOf() val keyToDataProviderMap: Map = @@ -1255,16 +1264,25 @@ class MeasurementConsumerSimulator( } } - private suspend fun listEventGroups(measurementConsumer: String): List { - val request = listEventGroupsRequest { parent = measurementConsumer } - try { - return eventGroupsClient - .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) - .listEventGroups(request) - .eventGroupsList - } catch (e: StatusException) { - throw Exception("Error listing event groups for MC $measurementConsumer", e) - } + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. + private fun listEventGroups(measurementConsumer: String): Flow { + return eventGroupsClient + .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) + .listResources { pageToken -> + val response = + try { + listEventGroups( + listEventGroupsRequest { + parent = measurementConsumer + this.pageToken = pageToken + } + ) + } catch (e: StatusException) { + throw Exception("Error listing event groups for MC $measurementConsumer", e) + } + ResourceList(response.eventGroupsList, response.nextPageToken) + } + .flattenConcat() } private fun extractDataProviderKey(eventGroupName: String): DataProviderKey { @@ -1283,7 +1301,7 @@ class MeasurementConsumerSimulator( } } - private suspend fun buildRequisitionInfo( + private fun buildRequisitionInfo( dataProvider: DataProvider, eventGroups: List, measurementConsumer: MeasurementConsumer, diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel index 445524ac557..3cdded1d308 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel @@ -132,7 +132,7 @@ kt_jvm_library( "context_keys", ":reporting_principal", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", - "//src/main/kotlin/org/wfanet/measurement/common/api/grpc", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:akid_principal_server_interceptor", "//src/main/kotlin/org/wfanet/measurement/common/identity", "@wfa_common_jvm//imports/java/com/google/protobuf", "@wfa_common_jvm//imports/java/io/grpc:api", diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel index 22d6343b537..fc750cdc976 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel @@ -52,7 +52,7 @@ kt_jvm_library( "context_keys", ":reporting_principal", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", - "//src/main/kotlin/org/wfanet/measurement/common/api/grpc", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:akid_principal_server_interceptor", "//src/main/kotlin/org/wfanet/measurement/common/identity", "@wfa_common_jvm//imports/java/com/google/protobuf", "@wfa_common_jvm//imports/java/io/grpc:api", @@ -120,6 +120,7 @@ kt_jvm_library( "//imports/java/org/projectnessie/cel", "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:cel_env_provider", "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:encryption_key_pair_store", "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha:principal_server_interceptor", diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt index 825e90c2bff..650a883fa68 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt @@ -20,10 +20,13 @@ import com.google.protobuf.DynamicMessage import com.google.protobuf.kotlin.unpack import io.grpc.Context import io.grpc.Deadline +import io.grpc.Deadline.Ticker import io.grpc.Status import io.grpc.StatusException import java.security.GeneralSecurityException import java.util.concurrent.TimeUnit +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.transformWhile import org.projectnessie.cel.common.types.Err import org.projectnessie.cel.common.types.ref.Val import org.wfanet.measurement.api.v2alpha.DataProviderKey @@ -31,9 +34,12 @@ import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey import org.wfanet.measurement.api.v2alpha.EventGroup as CmmsEventGroup import org.wfanet.measurement.api.v2alpha.EventGroupKey as CmmsEventGroupKey import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub as CmmsEventGroupsCoroutineStub +import org.wfanet.measurement.api.v2alpha.ListEventGroupsResponse as CmmsListEventGroupsResponse import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey import org.wfanet.measurement.api.v2alpha.listEventGroupsRequest import org.wfanet.measurement.api.withAuthenticationKey +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.crypto.PrivateKeyHandle import org.wfanet.measurement.common.grpc.grpcRequire import org.wfanet.measurement.common.grpc.grpcRequireNotNull @@ -52,6 +58,7 @@ class EventGroupsService( private val cmmsEventGroupsStub: CmmsEventGroupsCoroutineStub, private val encryptionKeyPairStore: EncryptionKeyPairStore, private val celEnvProvider: CelEnvProvider, + private val ticker: Ticker = Deadline.getSystemTicker(), ) : EventGroupsCoroutineImplBase() { override suspend fun listEventGroups(request: ListEventGroupsRequest): ListEventGroupsResponse { val parentKey = @@ -71,66 +78,79 @@ class EventGroupsService( } } + val deadline: Deadline = + Context.current().deadline + ?: Deadline.after(RPC_DEFAULT_DEADLINE_MILLIS, TimeUnit.MILLISECONDS, ticker) val apiAuthenticationKey: String = principal.config.apiKey grpcRequire(request.pageSize >= 0) { "page_size cannot be negative" } - val pageSize = - when { - request.pageSize < MIN_PAGE_SIZE -> DEFAULT_PAGE_SIZE - request.pageSize > MAX_PAGE_SIZE -> MAX_PAGE_SIZE - else -> request.pageSize - } + val limit = + if (request.pageSize > 0) request.pageSize.coerceAtMost(MAX_PAGE_SIZE) else DEFAULT_PAGE_SIZE + val parent = parentKey.toName() + val eventGroupLists: Flow> = + cmmsEventGroupsStub.withAuthenticationKey(apiAuthenticationKey).listResources( + limit, + request.pageToken, + ) { pageToken, remaining -> + val response: CmmsListEventGroupsResponse = + listEventGroups( + listEventGroupsRequest { + this.parent = parent + this.pageSize = remaining + this.pageToken = pageToken + } + ) - var nextPageToken = request.pageToken - val deadline = Context.current().deadline ?: Deadline.after(30, TimeUnit.SECONDS) - do { - val cmmsListEventGroupResponse = - try { - cmmsEventGroupsStub - .withAuthenticationKey(apiAuthenticationKey) - .listEventGroups( - listEventGroupsRequest { - parent = parentKey.toName() - this.pageSize = pageSize - pageToken = nextPageToken + val eventGroups: List = + response.eventGroupsList.map { + val cmmsMetadata: CmmsEventGroup.Metadata? = + if (it.hasEncryptedMetadata()) { + decryptMetadata(it, principal.resourceKey.toName()) + } else { + null } - ) - } catch (e: StatusException) { - throw when (e.status.code) { - Status.Code.DEADLINE_EXCEEDED -> Status.DEADLINE_EXCEEDED - Status.Code.CANCELLED -> Status.CANCELLED - else -> Status.UNKNOWN - } - .withCause(e) - .asRuntimeException() - } - val cmmsEventGroups = cmmsListEventGroupResponse.eventGroupsList - val eventGroups = - cmmsEventGroups.map { - val cmmsMetadata: CmmsEventGroup.Metadata? = - if (it.hasEncryptedMetadata()) { - decryptMetadata(it, principal.resourceKey.toName()) - } else { - null - } + it.toEventGroup(cmmsMetadata) + } - it.toEventGroup(cmmsMetadata) - } + ResourceList(filterEventGroups(eventGroups, request.filter), response.nextPageToken) + } - val filteredEventGroups = filterEventGroups(eventGroups, request.filter) - if (filteredEventGroups.size > 0) { - return listEventGroupsResponse { - this.eventGroups += filteredEventGroups - this.nextPageToken = cmmsListEventGroupResponse.nextPageToken + var hasResponse = false + return listEventGroupsResponse { + try { + eventGroupLists + .transformWhile { + emit(it) + deadline.timeRemaining(TimeUnit.MILLISECONDS) > RPC_DEADLINE_OVERHEAD_MILLIS + } + .collect { eventGroupList -> + this.eventGroups += eventGroupList + nextPageToken = eventGroupList.nextPageToken + hasResponse = true + } + } catch (e: StatusException) { + when (e.status.code) { + Status.Code.DEADLINE_EXCEEDED, + Status.Code.CANCELLED -> { + if (!hasResponse) { + // Only throw an error if we don't have any response yet. Otherwise, just return what + // we have so far. + throw Status.DEADLINE_EXCEEDED.withDescription( + "Timed out listing EventGroups from backend" + ) + .withCause(e) + .asRuntimeException() + } + } + else -> + throw Status.UNKNOWN.withDescription("Error listing EventGroups from backend") + .withCause(e) + .asRuntimeException() } - } else { - nextPageToken = cmmsListEventGroupResponse.nextPageToken } - } while (deadline.timeRemaining(TimeUnit.SECONDS) > 5) - - return listEventGroupsResponse { this.nextPageToken = nextPageToken } + } } private suspend fun filterEventGroups( @@ -259,8 +279,13 @@ class EventGroupsService( companion object { private const val METADATA_FIELD = "metadata.metadata" - private const val MIN_PAGE_SIZE = 1 private const val DEFAULT_PAGE_SIZE = 50 private const val MAX_PAGE_SIZE = 1000 + + /** Overhead to allow for RPC deadlines in milliseconds. */ + private const val RPC_DEADLINE_OVERHEAD_MILLIS = 100L + + /** Default RPC deadline in milliseconds. */ + private const val RPC_DEFAULT_DEADLINE_MILLIS = 30_000L } } diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt index befbdcedd0b..42df07ae256 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt @@ -19,11 +19,13 @@ package org.wfanet.measurement.reporting.service.api.v2alpha import com.google.common.truth.Truth.assertThat import com.google.common.truth.extensions.proto.ProtoTruth.assertThat import com.google.protobuf.Any +import io.grpc.Deadline import io.grpc.Status import io.grpc.StatusRuntimeException import java.nio.file.Path import java.nio.file.Paths import java.time.Duration +import java.util.concurrent.TimeUnit import kotlin.test.assertFailsWith import kotlinx.coroutines.runBlocking import org.junit.After @@ -110,6 +112,7 @@ class EventGroupsServiceTest { private lateinit var celEnvCacheProvider: CelEnvCacheProvider private lateinit var service: EventGroupsService + private val fakeTicker = SettableSystemTicker() @Before fun initService() { @@ -126,6 +129,7 @@ class EventGroupsServiceTest { EventGroupsCoroutineStub(grpcTestServerRule.channel), ENCRYPTION_KEY_PAIR_STORE, celEnvCacheProvider, + fakeTicker, ) } @@ -152,13 +156,13 @@ class EventGroupsServiceTest { whenever(publicKingdomEventGroupsMock.listEventGroups(any())) .thenReturn( listEventGroupsResponse { - nextPageToken = "1" eventGroups += cmmsEventGroup2 + nextPageToken = "1" } ) .thenReturn( listEventGroupsResponse { - eventGroups += listOf(CMMS_EVENT_GROUP, cmmsEventGroup2) + eventGroups += CMMS_EVENT_GROUP nextPageToken = "2" } ) @@ -170,6 +174,7 @@ class EventGroupsServiceTest { listEventGroupsRequest { parent = MEASUREMENT_CONSUMER_NAME filter = "metadata.metadata.publisher_id > 5" + pageSize = 1 } ) } @@ -200,7 +205,10 @@ class EventGroupsServiceTest { eventGroups += CMMS_EVENT_GROUP } ) - + .then { + // Advance time. + fakeTicker.setNanoTime(fakeTicker.nanoTime() + TimeUnit.SECONDS.toNanos(30)) + } val response = withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { runBlocking { @@ -443,7 +451,12 @@ class EventGroupsServiceTest { val response = withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { runBlocking { - service.listEventGroups(listEventGroupsRequest { parent = MEASUREMENT_CONSUMER_NAME }) + service.listEventGroups( + listEventGroupsRequest { + parent = MEASUREMENT_CONSUMER_NAME + pageSize = 2 + } + ) } } @@ -455,7 +468,7 @@ class EventGroupsServiceTest { .isEqualTo( cmmsListEventGroupsRequest { parent = MEASUREMENT_CONSUMER_NAME - pageSize = DEFAULT_PAGE_SIZE + pageSize = 2 } ) } @@ -813,4 +826,20 @@ class EventGroupsServiceTest { } } } + + /** + * Fake [Deadline.Ticker] implementation that allows time to be specified to override delegation + * to the system ticker. + */ + private class SettableSystemTicker : Deadline.Ticker() { + private var nanoTime: Long? = null + + fun setNanoTime(value: Long) { + nanoTime = value + } + + override fun nanoTime(): Long { + return this.nanoTime ?: Deadline.getSystemTicker().nanoTime() + } + } } From b29ff7b52e9305dabd9e11e1908b1e581bcbfd12 Mon Sep 17 00:00:00 2001 From: Sanjay Vasandani Date: Mon, 25 Nov 2024 10:56:15 -0800 Subject: [PATCH 3/3] refactor: Delete Reporting v1 (#1917) Closes #1882 --- .github/workflows/configure-reporting.yml | 184 - .github/workflows/scan-images.yml | 9 +- .github/workflows/update-cmms.yml | 9 - docs/gke/metrics-deployment.md | 2 +- docs/gke/reporting-server-deployment.md | 394 -- docs/gke/reporting-v2-server-deployment.md | 5 - src/main/docker/images.bzl | 42 +- src/main/k8s/BUILD.bazel | 10 - src/main/k8s/dev/BUILD.bazel | 52 - .../reporting_config_files_kustomization.yaml | 21 - src/main/k8s/dev/reporting_gke.cue | 88 - src/main/k8s/local/BUILD.bazel | 41 - src/main/k8s/local/README.md | 9 +- .../k8s/local/config_files_kustomization.yaml | 1 - src/main/k8s/local/reporting.cue | 77 - src/main/k8s/reporting.cue | 176 - src/main/k8s/testing/secretfiles/BUILD.bazel | 1 - .../measurement_spec_config.textproto | 82 - .../integration/common/reporting/BUILD.bazel | 63 - .../InProcessLifeOfAReportIntegrationTest.kt | 434 -- .../reporting/InProcessReportingServer.kt | 280 - .../common/reporting/identity/BUILD.bazel | 34 - .../MetadataPrincipalServerInterceptor.kt | 81 - .../identity/ReportingPrincipalIdentity.kt | 34 - .../reporting/deploy/common/BUILD.bazel | 34 - .../deploy/common/EncryptionKeyPairMap.kt | 55 - .../deploy/common/InternalApiFlags.kt | 38 - .../deploy/common/KingdomApiFlags.kt | 38 - .../deploy/common/server/BUILD.bazel | 78 - .../common/server/ReportingApiServerFlags.kt | 46 - .../common/server/ReportingDataServer.kt | 46 - .../common/server/V1AlphaPublicApiServer.kt | 220 - .../deploy/common/server/postgres/BUILD.bazel | 20 - .../server/postgres/PostgresServices.kt | 33 - .../reporting/deploy/config/BUILD.bazel | 13 - .../config/MeasurementSpecConfigValidator.kt | 107 - .../deploy/gcloud/postgres/server/BUILD.bazel | 35 - .../GCloudPostgresReportingDataServer.kt | 49 - .../deploy/gcloud/postgres/tools/BUILD.bazel | 21 - .../reporting/deploy/postgres/BUILD.bazel | 22 - .../postgres/PostgresMeasurementsService.kt | 89 - .../postgres/PostgresReportingSetsService.kt | 83 - .../deploy/postgres/PostgresReportsService.kt | 93 - .../deploy/postgres/readers/BUILD.bazel | 26 - .../postgres/readers/MeasurementReader.kt | 82 - .../readers/MeasurementResultsReader.kt | 74 - .../deploy/postgres/readers/ReportReader.kt | 527 -- .../postgres/readers/ReportingSetReader.kt | 202 - .../deploy/postgres/server/BUILD.bazel | 29 - .../server/PostgresReportingDataServer.kt | 47 - .../deploy/postgres/testing/BUILD.bazel | 18 - .../deploy/postgres/testing/Schemata.kt | 32 - .../deploy/postgres/tools/BUILD.bazel | 20 - .../deploy/postgres/writers/BUILD.bazel | 27 - .../postgres/writers/CreateMeasurements.kt | 53 - .../deploy/postgres/writers/CreateReport.kt | 482 -- .../postgres/writers/CreateReportingSet.kt | 92 - .../postgres/writers/SetMeasurementFailure.kt | 97 - .../postgres/writers/SetMeasurementResult.kt | 411 -- .../api/v1alpha/AkidPrincipalLookup.kt | 72 - .../reporting/service/api/v1alpha/BUILD.bazel | 143 - .../service/api/v1alpha/ContextKeys.kt | 22 - .../service/api/v1alpha/EventGroupKey.kt | 55 - .../api/v1alpha/EventGroupParentKey.kt | 49 - .../service/api/v1alpha/EventGroupsService.kt | 237 - .../service/api/v1alpha/IdVariable.kt | 36 - .../api/v1alpha/PrincipalServerInterceptor.kt | 86 - .../service/api/v1alpha/ReportKey.kt | 40 - .../service/api/v1alpha/ReportingPrincipal.kt | 47 - .../service/api/v1alpha/ReportingSetKey.kt | 47 - .../api/v1alpha/ReportingSetsService.kt | 235 - .../service/api/v1alpha/ReportsService.kt | 2260 -------- .../api/v1alpha/SetOperationCompiler.kt | 500 -- .../service/api/v1alpha/tools/BUILD.bazel | 32 - .../service/api/v1alpha/tools/README.md | 138 - .../service/api/v1alpha/tools/Reporting.kt | 497 -- .../measurement/config/reporting/BUILD.bazel | 11 - .../reporting/measurement_spec_config.proto | 108 - .../measurement/reporting/v1alpha/BUILD.bazel | 148 - .../reporting/v1alpha/event_group.proto | 72 - .../v1alpha/event_groups_service.proto | 79 - .../reporting/v1alpha/metric.proto | 129 - .../reporting/v1alpha/page_token.proto | 40 - .../reporting/v1alpha/report.proto | 146 - .../reporting/v1alpha/reporting_set.proto | 53 - .../v1alpha/reporting_sets_service.proto | 97 - .../reporting/v1alpha/reports_service.proto | 114 - .../reporting/v1alpha/time_interval.proto | 61 - src/main/terraform/gcloud/cmms/reporting.tf | 44 - .../deploy/common/postgres/BUILD.bazel | 21 - ...esInProcessLifeOfAReportIntegrationTest.kt | 36 - .../reporting/deploy/common/BUILD.bazel | 21 - .../deploy/common/EncryptionKeyPairMapTest.kt | 144 - .../deploy/common/key_pair_map.textproto | 26 - .../reporting/deploy/config/BUILD.bazel | 12 - .../MeasurementSpecConfigValidatorTest.kt | 295 -- .../reporting/deploy/postgres/BUILD.bazel | 58 - .../PostgresMeasurementsServiceTest.kt | 42 - .../PostgresReportingSetsServiceTest.kt | 38 - .../postgres/PostgresReportsServiceTest.kt | 42 - .../reporting/service/api/BUILD.bazel | 2 +- .../service/api/CelEnvProviderTest.kt | 6 +- .../reporting/service/api/v1alpha/BUILD.bazel | 108 - .../api/v1alpha/EventGroupsServiceTest.kt | 489 -- .../api/v1alpha/ReportingSetsServiceTest.kt | 750 --- .../service/api/v1alpha/ReportsServiceTest.kt | 4656 ----------------- .../api/v1alpha/SetOperationCompilerTest.kt | 197 - .../service/api/v1alpha/tools/BUILD.bazel | 29 - .../api/v1alpha/tools/ReportingTest.kt | 561 -- .../api/v1alpha/tools/metric1.textproto | 15 - .../api/v1alpha/tools/metric2.textproto | 36 - 111 files changed, 15 insertions(+), 18635 deletions(-) delete mode 100644 .github/workflows/configure-reporting.yml delete mode 100644 docs/gke/reporting-server-deployment.md delete mode 100644 src/main/k8s/dev/reporting_config_files_kustomization.yaml delete mode 100644 src/main/k8s/dev/reporting_gke.cue delete mode 100644 src/main/k8s/local/reporting.cue delete mode 100644 src/main/k8s/reporting.cue delete mode 100644 src/main/k8s/testing/secretfiles/measurement_spec_config.textproto delete mode 100644 src/main/kotlin/org/wfanet/measurement/integration/common/reporting/BUILD.bazel delete mode 100644 src/main/kotlin/org/wfanet/measurement/integration/common/reporting/InProcessLifeOfAReportIntegrationTest.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/integration/common/reporting/InProcessReportingServer.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/BUILD.bazel delete mode 100644 src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/MetadataPrincipalServerInterceptor.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/ReportingPrincipalIdentity.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/BUILD.bazel delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/EncryptionKeyPairMap.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/InternalApiFlags.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/KingdomApiFlags.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/BUILD.bazel delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/ReportingApiServerFlags.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/ReportingDataServer.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/V1AlphaPublicApiServer.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres/BUILD.bazel delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres/PostgresServices.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/config/BUILD.bazel delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/config/MeasurementSpecConfigValidator.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/server/BUILD.bazel delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/server/GCloudPostgresReportingDataServer.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/tools/BUILD.bazel delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/BUILD.bazel delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresMeasurementsService.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportingSetsService.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportsService.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/BUILD.bazel delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/MeasurementReader.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/MeasurementResultsReader.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/ReportReader.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/ReportingSetReader.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/server/BUILD.bazel delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/server/PostgresReportingDataServer.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing/BUILD.bazel delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing/Schemata.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/tools/BUILD.bazel delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/BUILD.bazel delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateMeasurements.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateReport.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateReportingSet.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/SetMeasurementFailure.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/SetMeasurementResult.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/AkidPrincipalLookup.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ContextKeys.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupKey.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupParentKey.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsService.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/IdVariable.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/PrincipalServerInterceptor.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportKey.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingPrincipal.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetKey.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetsService.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsService.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/SetOperationCompiler.kt delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/BUILD.bazel delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/README.md delete mode 100644 src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/Reporting.kt delete mode 100644 src/main/proto/wfa/measurement/config/reporting/measurement_spec_config.proto delete mode 100644 src/main/proto/wfa/measurement/reporting/v1alpha/BUILD.bazel delete mode 100644 src/main/proto/wfa/measurement/reporting/v1alpha/event_group.proto delete mode 100644 src/main/proto/wfa/measurement/reporting/v1alpha/event_groups_service.proto delete mode 100644 src/main/proto/wfa/measurement/reporting/v1alpha/metric.proto delete mode 100644 src/main/proto/wfa/measurement/reporting/v1alpha/page_token.proto delete mode 100644 src/main/proto/wfa/measurement/reporting/v1alpha/report.proto delete mode 100644 src/main/proto/wfa/measurement/reporting/v1alpha/reporting_set.proto delete mode 100644 src/main/proto/wfa/measurement/reporting/v1alpha/reporting_sets_service.proto delete mode 100644 src/main/proto/wfa/measurement/reporting/v1alpha/reports_service.proto delete mode 100644 src/main/proto/wfa/measurement/reporting/v1alpha/time_interval.proto delete mode 100644 src/main/terraform/gcloud/cmms/reporting.tf delete mode 100644 src/test/kotlin/org/wfanet/measurement/integration/deploy/common/postgres/BUILD.bazel delete mode 100644 src/test/kotlin/org/wfanet/measurement/integration/deploy/common/postgres/PostgresInProcessLifeOfAReportIntegrationTest.kt delete mode 100644 src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/BUILD.bazel delete mode 100644 src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/EncryptionKeyPairMapTest.kt delete mode 100644 src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/key_pair_map.textproto delete mode 100644 src/test/kotlin/org/wfanet/measurement/reporting/deploy/config/BUILD.bazel delete mode 100644 src/test/kotlin/org/wfanet/measurement/reporting/deploy/config/MeasurementSpecConfigValidatorTest.kt delete mode 100644 src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/BUILD.bazel delete mode 100644 src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresMeasurementsServiceTest.kt delete mode 100644 src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportingSetsServiceTest.kt delete mode 100644 src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportsServiceTest.kt delete mode 100644 src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel delete mode 100644 src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsServiceTest.kt delete mode 100644 src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetsServiceTest.kt delete mode 100644 src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsServiceTest.kt delete mode 100644 src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/SetOperationCompilerTest.kt delete mode 100644 src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/BUILD.bazel delete mode 100644 src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/ReportingTest.kt delete mode 100644 src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/metric1.textproto delete mode 100644 src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/metric2.textproto diff --git a/.github/workflows/configure-reporting.yml b/.github/workflows/configure-reporting.yml deleted file mode 100644 index 1dd99809b09..00000000000 --- a/.github/workflows/configure-reporting.yml +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright 2023 The Cross-Media Measurement Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: "Configure Reporting" - -on: - workflow_call: - inputs: - environment: - type: string - required: true - image-tag: - description: "Tag of container images" - type: string - required: true - apply: - description: "Apply the new configuration" - type: boolean - required: true - workflow_dispatch: - inputs: - environment: - description: "GitHub-managed environment" - required: true - type: choice - options: - - dev - - qa - - head - image-tag: - description: "Tag of container images" - type: string - required: true - apply: - description: "Apply the new configuration" - type: boolean - default: false - -permissions: - id-token: write - -env: - KUSTOMIZATION_PATH: "k8s/reporting" - -jobs: - update-reporting: - runs-on: ubuntu-22.04 - environment: ${{ inputs.environment }} - steps: - - uses: actions/checkout@v4 - - # Authenticate to Google Cloud. This will export some environment - # variables, including GCLOUD_PROJECT. - - name: Authenticate to Google Cloud - uses: google-github-actions/auth@v2 - with: - workload_identity_provider: ${{ vars.WORKLOAD_IDENTITY_PROVIDER }} - service_account: ${{ vars.GKE_CONFIG_SERVICE_ACCOUNT }} - - - name: Write auth.bazelrc - env: - BUILDBUDDY_API_KEY: ${{ secrets.BUILDBUDDY_API_KEY }} - run: | - cat << EOF > auth.bazelrc - build --remote_header=x-buildbuddy-api-key=$BUILDBUDDY_API_KEY - EOF - - - name: Write ~/.bazelrc - env: - IMAGE_TAG: ${{ inputs.image-tag }} - POSTGRES_INSTANCE: ${{ vars.POSTGRES_INSTANCE }} - GCLOUD_REGION: ${{ vars.GCLOUD_REGION }} - KINGDOM_PUBLIC_API_TARGET: ${{ vars.KINGDOM_PUBLIC_API_TARGET }} - run: | - cat << EOF > ~/.bazelrc - common --config=ci - build --remote_download_outputs=toplevel # Need build output. - common --config=ghcr - build --define "image_tag=$IMAGE_TAG" - build --define "google_cloud_project=$GCLOUD_PROJECT" - build --define "postgres_instance=$POSTGRES_INSTANCE" - build --define "postgres_region=$GCLOUD_REGION" - build --define "kingdom_public_api_target=$KINGDOM_PUBLIC_API_TARGET" - build --define reporting_public_api_address_name=reporting-v1alpha - EOF - - - - name: Export BAZEL_BIN - run: echo "BAZEL_BIN=$(bazelisk info bazel-bin)" >> $GITHUB_ENV - - - name: Get GKE cluster credentials - uses: google-github-actions/get-gke-credentials@v2 - with: - cluster_name: reporting - location: ${{ vars.GCLOUD_ZONE }} - - - name: Configure metrics - uses: ./.github/actions/configure-metrics - if: ${{ inputs.apply }} - - - name: Generate archives - run: > - bazelisk build - //src/main/k8s/dev:reporting.tar - //src/main/k8s/testing/secretfiles:archive - - - name: Make Kustomization dir - run: mkdir -p "$KUSTOMIZATION_PATH" - - - name: Extract Kustomization archive - run: > - tar -xf "$BAZEL_BIN/src/main/k8s/dev/reporting.tar" - -C "$KUSTOMIZATION_PATH" - - - name: Extract secret files archive - run: > - tar -xf "$BAZEL_BIN/src/main/k8s/testing/secretfiles/archive.tar" - -C "$KUSTOMIZATION_PATH/src/main/k8s/dev/reporting_secrets" - - # Write files from configuration variables. Since it appears that GitHub - # configuration variables use DOS (CRLF) line endings, we convert these to - # Unix (LF) line endings. - - - name: Write AKID to principal map - env: - AKID_TO_PRINCIPAL_MAP: ${{ vars.AKID_TO_PRINCIPAL_MAP }} - run: > - echo "$AKID_TO_PRINCIPAL_MAP" | sed $'s/\r$//' > - "$KUSTOMIZATION_PATH/src/main/k8s/dev/reporting_config_files/authority_key_identifier_to_principal_map.textproto" - - - name: Write encryption key-pair config - env: - ENCRYPTION_KEY_PAIR_CONFIG: ${{ vars.ENCRYPTION_KEY_PAIR_CONFIG }} - run: > - echo "$ENCRYPTION_KEY_PAIR_CONFIG" | sed $'s/\r$//' > - "$KUSTOMIZATION_PATH/src/main/k8s/dev/reporting_config_files/encryption_key_pair_config.textproto" - - - name: Copy measurement spec config - run: > - cp src/main/k8s/testing/secretfiles/measurement_spec_config.textproto - "$KUSTOMIZATION_PATH/src/main/k8s/dev/reporting_config_files/" - - - name: Write measurement consumer config - env: - MEASUREMENT_CONSUMER_CONFIG: ${{ secrets.MEASUREMENT_CONSUMER_CONFIG }} - run: > - echo "$MEASUREMENT_CONSUMER_CONFIG" | sed $'s/\r$//' > - "$KUSTOMIZATION_PATH/src/main/k8s/dev/reporting_secrets/measurement_consumer_config.textproto" - - - name: Copy secret generator - run: > - cp src/main/k8s/testing/secretfiles/reporting_secrets_kustomization.yaml - "$KUSTOMIZATION_PATH/src/main/k8s/dev/reporting_secrets/kustomization.yaml" - - - name: Export KUSTOMIZE_PATH - run: echo "KUSTOMIZE_PATH=$KUSTOMIZATION_PATH/src/main/k8s/dev/reporting" >> $GITHUB_ENV - - # Run kubectl diff, treating the command as succeeded even if the exit - # code is 1 as kubectl uses this code to indicate there's a diff. - - name: kubectl diff - id: kubectl-diff - run: kubectl diff -k "$KUSTOMIZE_PATH" || (( $? == 1 )) - - - name: kubectl apply - if: ${{ inputs.apply }} - run: kubectl apply -k "$KUSTOMIZE_PATH" - - - name: Wait for rollout - if: ${{ inputs.apply }} - run: | - for deployment in $(kubectl get deployments -o name); do - kubectl rollout status "$deployment" --timeout=5m - done diff --git a/.github/workflows/scan-images.yml b/.github/workflows/scan-images.yml index 158513221f1..ee6d925037f 100644 --- a/.github/workflows/scan-images.yml +++ b/.github/workflows/scan-images.yml @@ -66,16 +66,13 @@ jobs: - kingdom/system-api - duchy/herald - duchy/spanner-computations - - reporting/postgres-data-server - kingdom/exchanges-deletion - - reporting/v1alpha-public-api - - reporting/postgres-update-schema - - reporting/v2/v2alpha-public-api - - reporting/v2/postgres-update-schema - panel-exchange/gcloud-example-daemon - panel-exchange/aws-example-daemon - - simulator/synthetic-generator-edp + - reporting/v2/v2alpha-public-api + - reporting/v2/postgres-update-schema - reporting/v2/postgres-internal-server + - simulator/synthetic-generator-edp - duchy/postgres-update-schema - duchy/gcloud-postgres-update-schema - duchy/postgres-internal-server diff --git a/.github/workflows/update-cmms.yml b/.github/workflows/update-cmms.yml index c6611ca358c..a08130e5c49 100644 --- a/.github/workflows/update-cmms.yml +++ b/.github/workflows/update-cmms.yml @@ -124,15 +124,6 @@ jobs: # Update the Reporting system. # # This isn't technically part of the CMMS, but we do it here for simplicity. - update-reporting: - uses: ./.github/workflows/configure-reporting.yml - secrets: inherit - needs: [publish-images, terraform] - with: - image-tag: ${{ needs.publish-images.outputs.image-tag }} - environment: ${{ inputs.environment }} - apply: ${{ inputs.apply }} - update-reporting-v2: uses: ./.github/workflows/configure-reporting-v2.yml secrets: inherit diff --git a/docs/gke/metrics-deployment.md b/docs/gke/metrics-deployment.md index d0a8550439c..497e1a03994 100644 --- a/docs/gke/metrics-deployment.md +++ b/docs/gke/metrics-deployment.md @@ -30,7 +30,7 @@ free to use whichever you prefer. Deploy a Halo component. See the related guides: [Create Kingdom Cluster](kingdom-deployment.md), [Create Duchy Cluster](duchy-deployment.md), or -[Create Reporting Cluster](reporting-server-deployment.md). +[Create Reporting Cluster](reporting-server-v2-deployment.md). ## Google Cloud APIs diff --git a/docs/gke/reporting-server-deployment.md b/docs/gke/reporting-server-deployment.md deleted file mode 100644 index bb9259e2800..00000000000 --- a/docs/gke/reporting-server-deployment.md +++ /dev/null @@ -1,394 +0,0 @@ -# Halo Reporting Server Deployment on GKE - -## Important Note - -This is the deployment guide for the old V1. For the new V2, see -[Reporting V2](reporting-v2-server-deployment.md). - -## Background - -The configuration for the [`dev` environment](../../src/main/k8s/dev) can be -used as the basis for deploying CMMS components using Google Kubernetes Engine -(GKE) on another Google Cloud project. - -Many operations can be done either via the gcloud CLI or the Google Cloud web -console. This guide picks whichever is most convenient for that operation. Feel -free to use whichever you prefer. - -### What are we creating/deploying? - -- 1 Cloud SQL managed PostgreSQL database -- 1 GKE cluster - - 1 Kubernetes secret - - `certs-and-configs` - - 1 Kubernetes configmap - - `config-files` - - 2 Kubernetes services - - `postgres-reporting-data-server` (Cluster IP) - - `v1alpha-public-api-server` (External load balancer) - - 2 Kubernetes deployments - - `postgres-reporting-data-server-deployment` - - `v1alpha-public-api-server-deployment` - - 3 Kubernetes network policies - - `internal-data-server-network-policy` - - `public-api-server-network-policy` - - `default-deny-ingress-and-egress` - -## Before you start - -See [Machine Setup](machine-setup.md). - -### Managed PostgreSQL Quick Start - -If you don't have a managed PostgreSQL instance in your project, you can create -one in the -[Cloud Console](https://console.cloud.google.com/sql/instances/create;engine=PostgreSQL). -For the purposes of this guide, we assume the instance ID is `dev-postgres`. - -Make sure that the instance has the `cloudsql.iam_authentication` flag set to -`On`. Set the machine type and storage based on your expected usage. - -## Create the database - -The Reporting server expects its own database within your PostgreSQL instance. -You can create one with the `gcloud` CLI. For example, a database named -`reporting` in the `dev-postgres` instance. - -```shell -gcloud sql databases create reporting --instance=dev-postgres -``` - -## Build and push the container images - -If you aren't using pre-built release images, you can build the images yourself -from source and push them to a container registry. For example, if you're using -the [Google Container Registry](https://cloud.google.com/container-registry), -you would specify `gcr.io` as your container registry and your Cloud project -name as your image repository prefix. - -Assuming a project named `halo-cmm-dev` and an image tag `build-0001`, run the -following to build and push the images: - -```shell -bazel run -c opt //src/main/docker:push_all_reporting_gke_images \ - --define container_registry=gcr.io \ - --define image_repo_prefix=halo-cmm-dev --define image_tag=build-0001 -``` - -Tip: If you're using [Hybrid Development](../building.md#hybrid-development) for -containerized builds, replace `bazel build` with `tools/bazel-container build` -and `bazel run` with `tools/bazel-container-run`. - -## Create resources for the cluster - -See [GKE Cluster Configuration](cluster-config.md) for background. - -### IAM Service Accounts - -We'll want to -[create a least privilege service account](https://cloud.google.com/kubernetes-engine/docs/how-to/hardening-your-cluster#use_least_privilege_sa) -that our cluster will run under. Follow the steps in the linked guide to do -this. - -We'll additionally want to create a service account that we'll use to allow the -internal API server to access the database. See -[Granting Cloud SQL database access](cluster-config.md#granting-cloud-sql-instance-access) -for how to make sure this service account has the appropriate role. - -### KMS key for secret encryption - -Follow the steps in -[Create a Cloud KMS key](https://cloud.google.com/kubernetes-engine/docs/how-to/encrypting-secrets#creating-key) -to create a KMS key and grant permission to the GKE service agent to use it. - -Let's assume we've created a key named `k8s-secret` in a key ring named -`test-key-ring` in the `us-central1` region under the `halo-cmm-dev` project. -The resource name would be the following: -`projects/halo-cmm-dev/locations/us-central1/keyRings/test-key-ring/cryptoKeys/k8s-secret`. -We'll use this when creating the cluster. - -Tip: For convenience, there is a "Copy resource name" action on the key in the -Cloud console. - -## Create the cluster - -See [GKE Cluster Configuration](cluster-config.md) for tips on cluster creation -parameters, or follow the quick start instructions below. - -After creating the cluster, we can configure `kubectl` to be able to access it - -```shell -gcloud container clusters get-credentials reporting -``` - -### Add Metrics to the cluster - -See [Metrics Deployment](metrics-deployment.md). - -### Quick start - -Supposing you want to create a cluster named `reporting` for the Reporting -server, running under the `gke-cluster` service account in the `halo-cmm-dev` -project, the command would be - -```shell -gcloud container clusters create reporting \ - --enable-network-policy --workload-pool=halo-cmm-dev.svc.id.goog \ - --service-account="gke-cluster@halo-cmm-dev.iam.gserviceaccount.com" \ - --database-encryption-key=projects/halo-cmm-dev/locations/us-central1/keyRings/test-key-ring/cryptoKeys/k8s-secret \ - --num-nodes=3 --enable-autoscaling --min-nodes=2 --max-nodes=4 \ - --machine-type=e2-small -``` - -Adjust the number of nodes and machine type according to your expected usage. -The cluster version should be no older than `1.24.0` in order to support -built-in gRPC health probe. - -## Create the K8s ServiceAccount - -In order to use the IAM service account that we created earlier from our -cluster, we need to create a K8s ServiceAccount and give it access to that IAM -service account. - -For example, to create a K8s ServiceAccount named `internal-reporting-server`, -run - -```shell -kubectl create serviceaccount internal-reporting-server -``` - -Supposing the IAM service account you created in a previous step is named -`reporting-internal` within the `halo-cmm-dev` project. You'll need to allow the -K8s service account to impersonate it - -```shell -gcloud iam service-accounts add-iam-policy-binding \ - reporting-internal@halo-cmm-dev.iam.gserviceaccount.com \ - --role roles/iam.workloadIdentityUser \ - --member "serviceAccount:halo-cmm-dev.svc.id.goog[default/internal-reporting-server]" -``` - -Finally, add an annotation to link the K8s service account to the IAM service -account: - -```shell -kubectl annotate serviceaccount internal-reporting-server \ - iam.gke.io/gcp-service-account=reporting-internal@halo-cmm-dev.iam.gserviceaccount.com -``` - -## Generate the K8s Kustomization - -Populating a cluster is generally done by applying a K8s Kustomization. You can -use the `dev` configuration as a base to get started. The Kustomization is -generated using Bazel rules from files written in [CUE](https://cuelang.org/). - -To generate the `dev` Kustomization, run the following (substituting your own -values): - -```shell -bazel build //src/main/k8s/dev:reporting.tar \ - --define reporting_public_api_address_name=reporting-v1alpha \ - --define google_cloud_project=halo-cmm-dev \ - --define postgres_instance=dev-postgres \ - --define postgres_region=us-central1 \ - --define kingdom_public_api_target=v2alpha.kingdom.dev.halo-cmm.org:8443 \ - --define container_registry=gcr.io \ - --define image_repo_prefix=halo-kingdom-demo --define image_tag=build-0001 -``` - -Extract the generated archive to some directory. - -You can customize this generated object configuration with your own settings -such as the number of replicas per deployment, the memory and CPU requirements -of each container, and the JVM options of each container. - -## Customize the K8s secrets - -We use K8s secrets to hold sensitive information, such as private keys. - -### Certificates and signing keys - -First, prepare all the files we want to include in the Kubernetes secret. The -`dev` configuration assumes the files have the following names: - -1. `all_root_certs.pem` - - This makes up the trusted root CA store. It's the concatenation of the root - CA certificates for all the entities that the Reporting server interacts - with, including: - - * All Measurement Consumers - * Any entity which produces Measurement results (e.g. the Aggregator Duchy - and Data Providers) - * The Kingdom - * The Reporting server itself (for internal traffic) - - Supposing your root certs are all in a single folder and end with - `_root.pem`, you can concatenate them all with a simple shell command: - - ```shell - cat *_root.pem > all_root_certs.pem - ``` - - Note: This assumes that all your root certificate PEM files end in newline. - -1. `reporting_tls.pem` - - The Reporting server's TLS certificate. - -1. `reporting_tls.key` - - The private key for the Reporting server's TLS certificate. - -In addition, you'll need to include the encryption and signing private keys for -the Measurement Consumers that this Reporting server instance needs to act on -behalf of. The encryption keys are assumed to be in Tink's binary keyset format. -The signing private keys are assumed to be DER-encoded unencrypted PKCS #8. - -#### Testing keys - -There are some [testing keys](../../src/main/k8s/testing/secretfiles) within the -repository. These can be used to create the above secret for testing, but **must -not** be used for production environments as doing so would be highly insecure. - -Generate the archive: - -```shell -bazel build //src/main/k8s/testing/secretfiles:archive -``` - -Extract the generated archive to the `src/main/k8s/dev/reporting_secrets/` path -within the Kustomization directory. - -### Measurement Consumer config - -Contents: - -1. `measurement_consumer_config.textproto` - - [`MeasurementConsumerConfig`](../../src/main/proto/wfa/measurement/config/reporting/measurement_consumer_config.proto) - protobuf message in text format. - -### Generator - -Place the above files into the `src/main/k8s/dev/reporting_secrets/` path within -the Kustomization directory. - -Create a `kustomization.yaml` file in that path with the following content, -substituting the names of your own keys: - -```yaml -secretGenerator: -- name: signing - files: - - all_root_certs.pem - - reporting_tls.key - - reporting_tls.pem - - mc_enc_public.tink - - mc_enc_private.tink - - mc_cs_private.der -- name: mc-config - files: - - measurement_consumer_config.textproto -``` - -## Customize the K8s ConfigMap - -Configuration that may frequently change is stored in a K8s configMap. The `dev` -configuration uses one named `config-files`. - -* `authority_key_identifier_to_principal_map.textproto` - * [`AuthorityKeyToPrincipalMap`](../../src/main/proto/wfa/measurement/config/authority_key_to_principal_map.proto) -* `encryption_key_pair_config.textproto` - * [`EncryptionKeyPairConfig`](../../src/main/proto/wfa/measurement/config/reporting/encryption_key_pair_config.proto) -* `measurement_spec_config.textproto` - * [`MeasurementSpecConfig`](../../src/main/proto/wfa/measurement/config/reporting/measurement_spec_config.proto) -* `known_event_group_metadata_type_set.pb` - * Protobuf `FileDescriptorSet` containing known `EventGroup` metadata - types. - -Place these files into the `src/main/k8s/dev/reporting_config_files/` path -within the Kustomization directory. - -## Apply the K8s Kustomization - -Within the Kustomization directory, run - -```shell -kubectl apply -k src/main/k8s/dev/reporting -``` - -Now all components should be successfully deployed to your GKE cluster. You can -verify by running - -```shell -kubectl get deployments -``` - -and - -```shell -kubectl get services -``` - -You should see something like the following: - -``` -NAME READY UP-TO-DATE AVAILABLE AGE -postgres-reporting-data-server-deployment 1/1 1 1 254d -reporting-public-api-v1alpha-server-deployment 1/1 1 1 9m2s -``` - -``` -NAME TYPE CLUSTER-IP EXTERNAL-IP PORT(S) AGE -kubernetes ClusterIP 10.16.32.1 443/TCP 260d -postgres-reporting-data-server ClusterIP 10.16.39.47 8443/TCP 254d -reporting-public-api-v1alpha-server LoadBalancer 10.16.32.255 34.135.79.68 8443:30104/TCP 8m45s -``` - -## Appendix - -### Troubleshooting - -* `notAuthorized` error - - You see an error that looks something like this: - - ``` - { - "code": 403, - "errors": [ - { - "domain": "global", - "message": "The client is not authorized to make this request.", - "reason": "notAuthorized" - } - ], - "message": "The client is not authorized to make this request." - } - ``` - - Make sure that your Cloud SQL instance has the `cloudsql.iam_authentication` - flag set to `On`, and that you've followed all the steps for using Workload - Identity and IAM Authentication. See the - [Workload Identity](cluster-config.md#workload-identity) section in the - Cluster Configuration doc. - - If you believe you have everything configured correctly, try deleting and - recreating the IAM service account for DB access. Apparently there's a - glitch with Cloud SQL that this sometimes resolves. - -### Manual testing via CLI - -The -[`Reporting`](../../src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools) -CLI tool can be used for manual testing as well as examples of how to call the -API. - -### HTTP/REST - -The public API is a set of gRPC services following the -[API Improvement Proposals](https://google.aip.dev/). These can be exposed as an -HTTP REST API using the -[gRPC Gateway](https://github.com/grpc-ecosystem/grpc-gateway). The service -definitions include the appropriate protobuf annotations for this purpose. diff --git a/docs/gke/reporting-v2-server-deployment.md b/docs/gke/reporting-v2-server-deployment.md index 5e5d0056e3a..df8aeeb0c0f 100644 --- a/docs/gke/reporting-v2-server-deployment.md +++ b/docs/gke/reporting-v2-server-deployment.md @@ -1,10 +1,5 @@ # Halo Reporting V2 Server Deployment on GKE -## Important Note - -This is the deployment guide for the new V2. For the old V1, see -[Reporting V1](reporting-server-deployment.md). - ## Background The configuration for the [`dev` environment](../../src/main/k8s/dev) can be diff --git a/src/main/docker/images.bzl b/src/main/docker/images.bzl index 0fa56f1e971..08125217c30 100644 --- a/src/main/docker/images.bzl +++ b/src/main/docker/images.bzl @@ -232,40 +232,6 @@ LOCAL_IMAGES = [ ), ] -REPORTING_COMMON_IMAGES = [ - struct( - name = "reporting_v1alpha_public_api_server_image", - image = "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server:v1alpha_public_api_server_image", - repository = _PREFIX + "/reporting/v1alpha-public-api", - ), -] - -REPORTING_LOCAL_IMAGES = [ - struct( - name = "reporting_data_server_image", - image = "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/server:postgres_reporting_data_server_image", - repository = _PREFIX + "/reporting/local-postgres-internal", - ), - struct( - name = "reporting_postgres_update_schema_image", - image = "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/tools:update_schema_image", - repository = _PREFIX + "/reporting/local-postgres-update-schema", - ), -] - -REPORTING_GKE_IMAGES = [ - struct( - name = "gcloud_reporting_data_server_image", - image = "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/server:gcloud_postgres_reporting_data_server_image", - repository = _PREFIX + "/reporting/postgres-data-server", - ), - struct( - name = "gcloud_reporting_postgres_update_schema_image", - image = "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/tools:update_schema_image", - repository = _PREFIX + "/reporting/postgres-update-schema", - ), -] - REPORTING_V2_COMMON_IMAGES = [ struct( name = "reporting_v2alpha_public_api_server_image", @@ -305,12 +271,12 @@ REPORTING_V2_GKE_IMAGES = [ ), ] -ALL_GKE_IMAGES = COMMON_IMAGES + GKE_IMAGES + REPORTING_COMMON_IMAGES + REPORTING_GKE_IMAGES + REPORTING_V2_COMMON_IMAGES + REPORTING_V2_GKE_IMAGES +ALL_GKE_IMAGES = COMMON_IMAGES + GKE_IMAGES + REPORTING_V2_COMMON_IMAGES + REPORTING_V2_GKE_IMAGES -ALL_LOCAL_IMAGES = COMMON_IMAGES + LOCAL_IMAGES + REPORTING_COMMON_IMAGES + REPORTING_LOCAL_IMAGES + REPORTING_V2_COMMON_IMAGES + REPORTING_V2_LOCAL_IMAGES +ALL_LOCAL_IMAGES = COMMON_IMAGES + LOCAL_IMAGES + REPORTING_V2_COMMON_IMAGES + REPORTING_V2_LOCAL_IMAGES -ALL_IMAGES = COMMON_IMAGES + LOCAL_IMAGES + GKE_IMAGES + REPORTING_COMMON_IMAGES + REPORTING_LOCAL_IMAGES + REPORTING_GKE_IMAGES + REPORTING_V2_COMMON_IMAGES + REPORTING_V2_LOCAL_IMAGES + REPORTING_V2_GKE_IMAGES + EKS_IMAGES +ALL_IMAGES = COMMON_IMAGES + LOCAL_IMAGES + GKE_IMAGES + REPORTING_V2_COMMON_IMAGES + REPORTING_V2_LOCAL_IMAGES + REPORTING_V2_GKE_IMAGES + EKS_IMAGES -ALL_REPORTING_GKE_IMAGES = REPORTING_COMMON_IMAGES + REPORTING_GKE_IMAGES + REPORTING_V2_COMMON_IMAGES + REPORTING_V2_GKE_IMAGES +ALL_REPORTING_GKE_IMAGES = REPORTING_V2_COMMON_IMAGES + REPORTING_V2_GKE_IMAGES ALL_EKS_IMAGES = COMMON_IMAGES + EKS_IMAGES diff --git a/src/main/k8s/BUILD.bazel b/src/main/k8s/BUILD.bazel index e815499a116..1e930ced64a 100644 --- a/src/main/k8s/BUILD.bazel +++ b/src/main/k8s/BUILD.bazel @@ -78,16 +78,6 @@ cue_library( srcs = ["postgres.cue"], ) -cue_library( - name = "reporting", - srcs = ["reporting.cue"], - deps = [ - ":base", - ":config", - ":postgres", - ], -) - cue_library( name = "reporting_v2", srcs = ["reporting_v2.cue"], diff --git a/src/main/k8s/dev/BUILD.bazel b/src/main/k8s/dev/BUILD.bazel index 432063212be..67136a60f35 100644 --- a/src/main/k8s/dev/BUILD.bazel +++ b/src/main/k8s/dev/BUILD.bazel @@ -419,58 +419,6 @@ kustomization_dir( ], ) -cue_dump( - name = "reporting_gke", - srcs = ["reporting_gke.cue"], - cue_tags = { - "secret_name": SIGNING_SECRET_NAME, - "mc_config_secret_name": MC_CONFIG_SECRET_NAME, - "container_registry": IMAGE_REPOSITORY_SETTINGS.container_registry, - "image_repo_prefix": IMAGE_REPOSITORY_SETTINGS.repository_prefix, - "image_tag": IMAGE_REPOSITORY_SETTINGS.image_tag, - "public_api_address_name": REPORTING_K8S_SETTINGS.public_api_address_name, - "google_cloud_project": GCLOUD_SETTINGS.project, - "postgres_instance": GCLOUD_SETTINGS.postgres_instance, - "postgres_region": GCLOUD_SETTINGS.postgres_region, - "kingdom_public_api_target": KINGDOM_K8S_SETTINGS.public_api_target, - }, - tags = ["manual"], - deps = [ - ":base_gke", - ":config_gke", - "//src/main/k8s:reporting", - ], -) - -kustomization_dir( - name = "reporting_config_files", - testonly = True, - srcs = [ - "reporting_config_files_kustomization.yaml", - "//src/main/k8s/testing/secretfiles:known_event_group_metadata_type_set", - ], - renames = {"reporting_config_files_kustomization.yaml": "kustomization.yaml"}, -) - -kustomization_dir( - name = "reporting_secrets", -) - -kustomization_dir( - name = "reporting", - testonly = True, - srcs = [ - "resource_requirements.yaml", - ":reporting_gke", - ], - generate_kustomization = True, - tags = ["manual"], - deps = [ - ":reporting_config_files", - ":reporting_secrets", - ], -) - cue_dump( name = "reporting_v2_gke", srcs = ["reporting_v2_gke.cue"], diff --git a/src/main/k8s/dev/reporting_config_files_kustomization.yaml b/src/main/k8s/dev/reporting_config_files_kustomization.yaml deleted file mode 100644 index a15c366fe23..00000000000 --- a/src/main/k8s/dev/reporting_config_files_kustomization.yaml +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2022 The Cross-Media Measurement Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -configMapGenerator: -- name: config-files - files: - - authority_key_identifier_to_principal_map.textproto - - encryption_key_pair_config.textproto - - measurement_spec_config.textproto - - known_event_group_metadata_type_set.pb diff --git a/src/main/k8s/dev/reporting_gke.cue b/src/main/k8s/dev/reporting_gke.cue deleted file mode 100644 index 7af25490016..00000000000 --- a/src/main/k8s/dev/reporting_gke.cue +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package k8s - -_reportingSecretName: string @tag("secret_name") -_reportingMcConfigSecretName: string @tag("mc_config_secret_name") -_publicApiAddressName: string @tag("public_api_address_name") - -#KingdomApiTarget: #GrpcTarget & { - target: string @tag("kingdom_public_api_target") -} - -// Name of K8s service account for the internal API server. -#InternalServerServiceAccount: "internal-reporting-server" - -#InternalServerResourceRequirements: #ResourceRequirements & { - requests: { - cpu: "100m" - } -} -#PublicServerResourceRequirements: ResourceRequirements=#ResourceRequirements & { - requests: { - cpu: "25m" - memory: "256Mi" - } - limits: { - memory: ResourceRequirements.requests.memory - } -} - -objectSets: [ - defaultNetworkPolicies, - reporting.serviceAccounts, - reporting.configMaps, - reporting.deployments, - reporting.services, - reporting.networkPolicies, -] - -reporting: #Reporting & { - _secretName: _reportingSecretName - _mcConfigSecretName: _reportingMcConfigSecretName - _kingdomApiTarget: #KingdomApiTarget - _internalApiTarget: certificateHost: "localhost" - - _postgresConfig: { - iamUserLocal: "reporting-internal" - database: "reporting" - } - - _verboseGrpcServerLogging: true - - serviceAccounts: { - "\(#InternalServerServiceAccount)": #WorkloadIdentityServiceAccount & { - _iamServiceAccountName: "reporting-internal" - } - } - - configMaps: "java": #JavaConfigMap - - deployments: { - "postgres-reporting-data-server": { - _container: resources: #InternalServerResourceRequirements - spec: template: spec: #ServiceAccountPodSpec & { - serviceAccountName: #InternalServerServiceAccount - } - } - "reporting-public-api-v1alpha-server": { - _container: resources: #PublicServerResourceRequirements - } - } - - services: { - "reporting-public-api-v1alpha-server": _ipAddressName: _publicApiAddressName - } -} diff --git a/src/main/k8s/local/BUILD.bazel b/src/main/k8s/local/BUILD.bazel index 552219510fe..9d9c7cb5612 100644 --- a/src/main/k8s/local/BUILD.bazel +++ b/src/main/k8s/local/BUILD.bazel @@ -200,25 +200,6 @@ cue_dump( ], ) -cue_dump( - name = "reporting", - srcs = ["reporting.cue"], - cue_tags = { - "secret_name": SECRET_NAME, - "db_secret_name": DB_SECRET_NAME, - "mc_config_secret_name": MC_CONFIG_SECRET_NAME, - "container_registry": IMAGE_REPOSITORY_SETTINGS.container_registry, - "image_repo_prefix": IMAGE_REPOSITORY_SETTINGS.repository_prefix, - "image_tag": IMAGE_REPOSITORY_SETTINGS.image_tag, - }, - tags = ["manual"], - deps = [ - ":config_cue", - "//src/main/k8s:postgres", - "//src/main/k8s:reporting", - ], -) - cue_dump( name = "reporting_v2", srcs = ["reporting_v2.cue"], @@ -292,7 +273,6 @@ kustomization_dir( "empty_encryption_key_pair_config.textproto", "//src/main/k8s/testing/data:synthetic_generation_specs_small", "//src/main/k8s/testing/secretfiles:known_event_group_metadata_type_set", - "//src/main/k8s/testing/secretfiles:measurement_spec_config.textproto", "//src/main/k8s/testing/secretfiles:metric_spec_config.textproto", ], renames = { @@ -309,7 +289,6 @@ kustomization_dir( ":encryption_key_pair_config.textproto", "//src/main/k8s/testing/data:synthetic_generation_specs_small", "//src/main/k8s/testing/secretfiles:known_event_group_metadata_type_set", - "//src/main/k8s/testing/secretfiles:measurement_spec_config.textproto", "//src/main/k8s/testing/secretfiles:metric_spec_config.textproto", ], renames = { @@ -384,26 +363,6 @@ kustomization_dir( ], ) -kustomization_dir( - name = "cmms_with_reporting", - srcs = [ - ":duchies", - ":edp_simulators", - ":emulators", - ":kingdom", - ":postgres_database", - ":reporting", - ], - generate_kustomization = True, - tags = ["manual"], - deps = [ - ":config_files", - ":db_creds", - ":mc_config", - "//src/main/k8s/testing/secretfiles:kustomization", - ], -) - kustomization_dir( name = "cmms_with_reporting_v2", srcs = [ diff --git a/src/main/k8s/local/README.md b/src/main/k8s/local/README.md index d1d6a20c609..a0912320e0d 100644 --- a/src/main/k8s/local/README.md +++ b/src/main/k8s/local/README.md @@ -198,8 +198,8 @@ This is an alternate version of the section above. This assumes you've already done the Initial Setup and have the output from the `ResourceSetup` tool. Use the command in the above section to build the tar archive, swapping the -target with `//src/main/k8s/local:cmms_with_reporting.tar`. Extract this archive -to some directory (e.g. `/tmp/cmms`). +target with `//src/main/k8s/local:cmms_with_reporting_v2.tar`. Extract this +archive to some directory (e.g. `/tmp/cmms`). Copy the `authority_key_identifier_to_principal_map.textproto` output from the `ResourceSetup` tool to the `src/main/k8s/local/config_files` path within this @@ -209,12 +209,9 @@ You can then apply the Kustomization from the directory where you extracted the archive: ```shell -kubectl apply -k src/main/k8s/local/cmms_with_reporting/ +kubectl apply -k src/main/k8s/local/cmms_with_reporting_v2/ ``` -To deploy Reporting V2, swap out `cmms_with_reporting` with -`cmms_with_reporting_v2`. - To use the Reporting CLI tool, you need to forward the port again. For example: ```shell diff --git a/src/main/k8s/local/config_files_kustomization.yaml b/src/main/k8s/local/config_files_kustomization.yaml index 56e23573e98..71c188ba7fb 100644 --- a/src/main/k8s/local/config_files_kustomization.yaml +++ b/src/main/k8s/local/config_files_kustomization.yaml @@ -17,7 +17,6 @@ configMapGenerator: files: - authority_key_identifier_to_principal_map.textproto - encryption_key_pair_config.textproto - - measurement_spec_config.textproto - metric_spec_config.textproto - synthetic_population_spec_small.textproto - synthetic_event_group_spec_small_1.textproto diff --git a/src/main/k8s/local/reporting.cue b/src/main/k8s/local/reporting.cue deleted file mode 100644 index 7d7857b5229..00000000000 --- a/src/main/k8s/local/reporting.cue +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package k8s - -_reportingSecretName: string @tag("secret_name") -_reportingDbSecretName: string @tag("db_secret_name") -_reportingMcConfigSecretName: string @tag("mc_config_secret_name") - -objectSets: [ for objectSet in reporting {objectSet}] - -reporting: #Reporting & { - _secretName: _reportingSecretName - _mcConfigSecretName: _reportingMcConfigSecretName - _imageSuffixes: { - "update-reporting-schema": "reporting/local-postgres-update-schema" - "postgres-reporting-data-server": "reporting/local-postgres-internal" - } - - _postgresConfig: { - serviceName: "postgres" - password: "$(POSTGRES_PASSWORD)" - user: "$(POSTGRES_USER)" - } - _kingdomApiTarget: { - serviceName: "v2alpha-public-api-server" - certificateHost: "localhost" - } - _internalApiTarget: { - certificateHost: "localhost" - } - _verboseGrpcServerLogging: true - _verboseGrpcClientLogging: true - - let EnvVars = #EnvVarMap & { - "POSTGRES_USER": { - valueFrom: - secretKeyRef: { - name: _reportingDbSecretName - key: "username" - } - } - "POSTGRES_PASSWORD": { - valueFrom: - secretKeyRef: { - name: _reportingDbSecretName - key: "password" - } - } - } - - deployments: { - "postgres-reporting-data-server": { - _container: _envVars: EnvVars - _updateSchemaContainer: _envVars: EnvVars - } - "reporting-public-api-v1alpha-server": { - spec: template: spec: { - _dependencies: [ - "postgres-reporting-data-server", - "v2alpha-public-api-server", // Kingdom public API server. - ] - } - } - } -} diff --git a/src/main/k8s/reporting.cue b/src/main/k8s/reporting.cue deleted file mode 100644 index 8cae8d4b4a4..00000000000 --- a/src/main/k8s/reporting.cue +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package k8s - -#Reporting: Reporting={ - _verboseGrpcServerLogging: bool | *false - _verboseGrpcClientLogging: bool | *false - - _postgresConfig: #PostgresConfig - - _internalApiTarget: #GrpcTarget & { - serviceName: "postgres-reporting-data-server" - targetOption: "--internal-api-target" - certificateHostOption: "--internal-api-cert-host" - } - _kingdomApiTarget: #GrpcTarget & { - targetOption: "--kingdom-api-target" - certificateHostOption: "--kingdom-api-cert-host" - } - - _imageSuffixes: [_=string]: string - _imageSuffixes: { - "update-reporting-schema": string | *"reporting/postgres-update-schema" - "postgres-reporting-data-server": string | *"reporting/postgres-data-server" - "reporting-public-api-v1alpha-server": string | *"reporting/v1alpha-public-api" - } - _imageConfigs: [_=string]: #ImageConfig - _imageConfigs: { - for name, suffix in _imageSuffixes { - "\(name)": {repoSuffix: suffix} - } - } - _images: { - for name, config in _imageConfigs { - "\(name)": config.image - } - } - _secretName: string - _mcConfigSecretName: string - - _tlsArgs: [ - "--tls-cert-file=/var/run/secrets/files/reporting_tls.pem", - "--tls-key-file=/var/run/secrets/files/reporting_tls.key", - ] - _reportingCertCollectionFileFlag: "--cert-collection-file=/var/run/secrets/files/all_root_certs.pem" - _akidToPrincipalMapFileFlag: "--authority-key-identifier-to-principal-map-file=/etc/\(#AppName)/config-files/authority_key_identifier_to_principal_map.textproto" - _measurementConsumerConfigFileFlag: "--measurement-consumer-config-file=/var/run/secrets/files/config/mc/measurement_consumer_config.textproto" - _signingPrivateKeyStoreDirFlag: "--signing-private-key-store-dir=/var/run/secrets/files" - _encryptionKeyPairDirFlag: "--key-pair-dir=/var/run/secrets/files" - _encryptionKeyPairConfigFileFlag: "--key-pair-config-file=/etc/\(#AppName)/config-files/encryption_key_pair_config.textproto" - _measurementSpecConfigFileFlag: "--measurement-spec-config-file=/etc/\(#AppName)/config-files/measurement_spec_config.textproto" - _knownEventGroupMetadataTypeFlag: "--known-event-group-metadata-type=/etc/\(#AppName)/config-files/known_event_group_metadata_type_set.pb" - _debugVerboseGrpcClientLoggingFlag: "--debug-verbose-grpc-client-logging=\(_verboseGrpcClientLogging)" - _debugVerboseGrpcServerLoggingFlag: "--debug-verbose-grpc-server-logging=\(_verboseGrpcServerLogging)" - - services: [Name=_]: #GrpcService & { - metadata: { - _component: "reporting" - name: Name - } - } - services: { - "postgres-reporting-data-server": {} - "reporting-public-api-v1alpha-server": #ExternalService - } - - deployments: [Name=_]: #ServerDeployment & { - _name: Name - _secretName: Reporting._secretName - _system: "reporting" - _container: { - image: _images[_name] - } - } - deployments: { - "postgres-reporting-data-server": { - _container: args: [ - _reportingCertCollectionFileFlag, - _debugVerboseGrpcServerLoggingFlag, - "--port=8443", - "--health-port=8080", - ] + _postgresConfig.flags + _tlsArgs - - _updateSchemaContainer: Container=#Container & { - image: _images[Container.name] - args: _postgresConfig.flags - imagePullPolicy?: _container.imagePullPolicy - } - - spec: template: spec: _initContainers: { - "update-reporting-schema": _updateSchemaContainer - } - } - - "reporting-public-api-v1alpha-server": { - _container: args: [ - _debugVerboseGrpcClientLoggingFlag, - _debugVerboseGrpcServerLoggingFlag, - _reportingCertCollectionFileFlag, - _akidToPrincipalMapFileFlag, - _measurementConsumerConfigFileFlag, - _signingPrivateKeyStoreDirFlag, - _encryptionKeyPairDirFlag, - _encryptionKeyPairConfigFileFlag, - _measurementSpecConfigFileFlag, - _knownEventGroupMetadataTypeFlag, - "--port=8443", - "--health-port=8080", - "--event-group-metadata-descriptor-cache-duration=1h", - ] + _tlsArgs + _internalApiTarget.args + _kingdomApiTarget.args - - spec: template: spec: { - _mounts: { - "mc-config": { - volume: secret: secretName: Reporting._mcConfigSecretName - volumeMount: mountPath: "/var/run/secrets/files/config/mc/" - } - "config-files": #ConfigMapMount - } - _dependencies: _ | *["postgres-reporting-data-server"] - } - } - } - - networkPolicies: [Name=_]: #NetworkPolicy & { - _name: Name - } - - networkPolicies: { - "internal-reporting-data-server": { - _app_label: "postgres-reporting-data-server-app" - _sourceMatchLabels: [ - "reporting-public-api-v1alpha-server-app", - ] - _egresses: { - // Needs to call out to Postgres server. - any: {} - } - } - "public-reporting-api-server": { - _app_label: "reporting-public-api-v1alpha-server-app" - _destinationMatchLabels: ["postgres-reporting-data-server-app"] - _ingresses: { - gRpc: { - ports: [{ - port: #GrpcPort - }] - } - } - _egresses: { - // Needs to call out to Kingdom. - any: {} - } - } - } - - configMaps: [Name=string]: #ConfigMap & { - metadata: name: Name - } - - serviceAccounts: [Name=string]: #ServiceAccount & { - metadata: name: Name - } -} diff --git a/src/main/k8s/testing/secretfiles/BUILD.bazel b/src/main/k8s/testing/secretfiles/BUILD.bazel index 2ad33a4952b..611511c00bc 100644 --- a/src/main/k8s/testing/secretfiles/BUILD.bazel +++ b/src/main/k8s/testing/secretfiles/BUILD.bazel @@ -19,7 +19,6 @@ package( ) exports_files([ - "measurement_spec_config.textproto", "metric_spec_config.textproto", ]) diff --git a/src/main/k8s/testing/secretfiles/measurement_spec_config.textproto b/src/main/k8s/testing/secretfiles/measurement_spec_config.textproto deleted file mode 100644 index 43bf7b2c38b..00000000000 --- a/src/main/k8s/testing/secretfiles/measurement_spec_config.textproto +++ /dev/null @@ -1,82 +0,0 @@ -# proto-file: wfa/measurement/config/reporting/measurement_spec_config.proto -# proto-message: MeasurementSpecConfig -reach_single_data_provider { - privacy_params { - epsilon: 0.000207 - delta: 1e-15 - } - vid_sampling_interval { - fixed_start { - start: 0f - width: 1f - } - } -} -reach { - privacy_params { - epsilon: 0.0007444 - delta: 1e-15 - } - vid_sampling_interval { - random_start { - width: 256 - num_vid_buckets: 300 - } - } -} -reach_and_frequency_single_data_provider { - reach_privacy_params { - epsilon: 0.000207 - delta: 1e-15 - } - frequency_privacy_params { - epsilon: 0.004728 - delta: 1e-15 - } - vid_sampling_interval { - fixed_start { - start: 0f - width: 1f - } - } -} -reach_and_frequency { - reach_privacy_params { - epsilon: 0.0007444 - delta: 1e-15 - } - frequency_privacy_params { - epsilon: 0.014638 - delta: 1e-15 - } - vid_sampling_interval { - random_start { - width: 256 - num_vid_buckets: 300 - } - } -} -impression { - privacy_params { - epsilon: 0.003592 - delta: 1e-15 - } - vid_sampling_interval { - fixed_start { - start: 0f - width: 1f - } - } -} -duration { - privacy_params { - epsilon: 0.007418 - delta: 1e-15 - } - vid_sampling_interval { - fixed_start { - start: 0f - width: 1f - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/BUILD.bazel deleted file mode 100644 index 5be39730803..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/BUILD.bazel +++ /dev/null @@ -1,63 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -package( - default_testonly = True, - default_visibility = [ - "//src/test/kotlin/org/wfanet/measurement/integration:__subpackages__", - ], -) - -kt_jvm_library( - name = "in_process_life_of_a_report_integration_test", - srcs = [ - "InProcessLifeOfAReportIntegrationTest.kt", - ], - data = [ - "//src/main/k8s/testing/secretfiles:secret_files", - ], - deps = [ - ":in_process_reporting_server", - "//src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity:reporting_principal_identity", - "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/testing:fake_measurements_service", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server:reporting_data_server", - "//src/main/proto/wfa/measurement/config/reporting:encryption_key_pair_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:event_groups_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reporting_sets_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reports_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", - "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/duchy", - ], -) - -kt_jvm_library( - name = "in_process_reporting_server", - srcs = [ - "InProcessReportingServer.kt", - ], - data = [ - "//src/main/k8s/testing/secretfiles:secret_files", - ], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity:metadata_principal_interceptor", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server:reporting_data_server", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:cel_env_provider", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:event_groups_service", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reporting_sets_service", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reports_service", - "//src/main/proto/wfa/measurement/api/v2alpha:certificates_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:data_providers_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:event_group_metadata_descriptors_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:event_groups_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:measurement_consumers_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:measurements_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/config/reporting:encryption_key_pair_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/config/reporting:measurement_spec_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:event_groups_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reporting_sets_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reports_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/InProcessLifeOfAReportIntegrationTest.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/InProcessLifeOfAReportIntegrationTest.kt deleted file mode 100644 index 936f8dfad6f..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/InProcessLifeOfAReportIntegrationTest.kt +++ /dev/null @@ -1,434 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.integration.common.reporting - -import com.google.common.truth.Truth.assertThat -import com.google.common.truth.extensions.proto.ProtoTruth.assertThat -import com.google.protobuf.ByteString -import com.google.protobuf.DescriptorProtos -import com.google.protobuf.duration -import com.google.protobuf.kotlin.toByteString -import com.google.protobuf.timestamp -import java.io.File -import java.nio.file.Paths -import java.time.Clock -import kotlinx.coroutines.launch -import kotlinx.coroutines.runBlocking -import org.junit.Rule -import org.junit.Test -import org.junit.rules.TestRule -import org.mockito.kotlin.any -import org.wfanet.measurement.api.v2alpha.Certificate -import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt -import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt -import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey -import org.wfanet.measurement.api.v2alpha.EventGroup -import org.wfanet.measurement.api.v2alpha.EventGroupKt -import org.wfanet.measurement.api.v2alpha.EventGroupMetadataDescriptor -import org.wfanet.measurement.api.v2alpha.EventGroupMetadataDescriptorsGrpcKt -import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt -import org.wfanet.measurement.api.v2alpha.MeasurementConsumer -import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt -import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt -import org.wfanet.measurement.api.v2alpha.certificate -import org.wfanet.measurement.api.v2alpha.dataProvider -import org.wfanet.measurement.api.v2alpha.eventGroup -import org.wfanet.measurement.api.v2alpha.eventGroupMetadataDescriptor -import org.wfanet.measurement.api.v2alpha.listEventGroupMetadataDescriptorsResponse -import org.wfanet.measurement.api.v2alpha.listEventGroupsResponse -import org.wfanet.measurement.api.v2alpha.measurementConsumer -import org.wfanet.measurement.common.crypto.SigningKeyHandle -import org.wfanet.measurement.common.crypto.readCertificate -import org.wfanet.measurement.common.crypto.readCertificateCollection -import org.wfanet.measurement.common.crypto.subjectKeyIdentifier -import org.wfanet.measurement.common.crypto.testing.loadSigningKey -import org.wfanet.measurement.common.crypto.tink.loadPublicKey -import org.wfanet.measurement.common.getRuntimePath -import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule -import org.wfanet.measurement.common.grpc.testing.mockService -import org.wfanet.measurement.common.identity.RandomIdGenerator -import org.wfanet.measurement.common.pack -import org.wfanet.measurement.common.readByteString -import org.wfanet.measurement.common.testing.chainRulesSequentially -import org.wfanet.measurement.config.reporting.EncryptionKeyPairConfig -import org.wfanet.measurement.config.reporting.EncryptionKeyPairConfigKt.keyPair -import org.wfanet.measurement.config.reporting.EncryptionKeyPairConfigKt.principalKeyPairs -import org.wfanet.measurement.config.reporting.encryptionKeyPairConfig -import org.wfanet.measurement.config.reporting.measurementConsumerConfig -import org.wfanet.measurement.consent.client.common.encryptMessage -import org.wfanet.measurement.consent.client.common.toEncryptionPublicKey -import org.wfanet.measurement.consent.client.common.toPublicKeyHandle -import org.wfanet.measurement.consent.client.measurementconsumer.signEncryptionPublicKey -import org.wfanet.measurement.integration.common.reporting.identity.withPrincipalName -import org.wfanet.measurement.kingdom.service.api.v2alpha.testing.FakeMeasurementsService -import org.wfanet.measurement.reporting.deploy.common.server.ReportingDataServer -import org.wfanet.measurement.reporting.v1alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub -import org.wfanet.measurement.reporting.v1alpha.ListEventGroupsResponse -import org.wfanet.measurement.reporting.v1alpha.ListReportingSetsResponse -import org.wfanet.measurement.reporting.v1alpha.ListReportsResponse -import org.wfanet.measurement.reporting.v1alpha.Metric -import org.wfanet.measurement.reporting.v1alpha.MetricKt -import org.wfanet.measurement.reporting.v1alpha.Report -import org.wfanet.measurement.reporting.v1alpha.ReportKt -import org.wfanet.measurement.reporting.v1alpha.ReportingSet -import org.wfanet.measurement.reporting.v1alpha.ReportingSetsGrpcKt.ReportingSetsCoroutineStub -import org.wfanet.measurement.reporting.v1alpha.ReportsGrpcKt.ReportsCoroutineStub -import org.wfanet.measurement.reporting.v1alpha.copy -import org.wfanet.measurement.reporting.v1alpha.createReportRequest -import org.wfanet.measurement.reporting.v1alpha.createReportingSetRequest -import org.wfanet.measurement.reporting.v1alpha.getReportRequest -import org.wfanet.measurement.reporting.v1alpha.listEventGroupsRequest -import org.wfanet.measurement.reporting.v1alpha.listReportingSetsRequest -import org.wfanet.measurement.reporting.v1alpha.listReportsRequest -import org.wfanet.measurement.reporting.v1alpha.metric -import org.wfanet.measurement.reporting.v1alpha.periodicTimeInterval -import org.wfanet.measurement.reporting.v1alpha.report -import org.wfanet.measurement.reporting.v1alpha.reportingSet - -private const val NUM_SET_OPERATIONS = 50 - -/** - * Test that everything is wired up properly. - * - * This is abstract so that different implementations of dependencies can all run the same tests - * easily. - */ -abstract class InProcessLifeOfAReportIntegrationTest { - abstract val reportingServerDataServices: ReportingDataServer.Services - - private val publicKingdomCertificatesMock: CertificatesGrpcKt.CertificatesCoroutineImplBase = - mockService { - onBlocking { getCertificate(any()) }.thenReturn(CERTIFICATE) - } - private val publicKingdomDataProvidersMock: DataProvidersGrpcKt.DataProvidersCoroutineImplBase = - mockService { - onBlocking { getDataProvider(any()) }.thenReturn(DATA_PROVIDER) - } - private val publicKingdomEventGroupsMock: EventGroupsGrpcKt.EventGroupsCoroutineImplBase = - mockService { - onBlocking { listEventGroups(any()) } - .thenReturn(listEventGroupsResponse { eventGroups += EVENT_GROUP }) - } - private val publicKingdomEventGroupMetadataDescriptorsMock: - EventGroupMetadataDescriptorsGrpcKt.EventGroupMetadataDescriptorsCoroutineImplBase = - mockService { - onBlocking { listEventGroupMetadataDescriptors(any()) } - .thenReturn( - listEventGroupMetadataDescriptorsResponse { - eventGroupMetadataDescriptors += EVENT_GROUP_METADATA_DESCRIPTOR - } - ) - } - private val publicKingdomMeasurementConsumersMock: - MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineImplBase = - mockService { - onBlocking { getMeasurementConsumer(any()) }.thenReturn(MEASUREMENT_CONSUMER) - } - private val publicFakeKingdomMeasurementsService: - MeasurementsGrpcKt.MeasurementsCoroutineImplBase = - FakeMeasurementsService( - RandomIdGenerator(Clock.systemUTC()), - EDP_SIGNING_KEY_HANDLE, - DATA_PROVIDER_CERTIFICATE_NAME, - ) - - private val publicKingdomServer = GrpcTestServerRule { - addService(publicKingdomCertificatesMock) - addService(publicKingdomDataProvidersMock) - addService(publicKingdomEventGroupsMock) - addService(publicKingdomEventGroupMetadataDescriptorsMock) - addService(publicKingdomMeasurementConsumersMock) - addService(publicFakeKingdomMeasurementsService) - } - - private val encryptionKeyPairConfig: EncryptionKeyPairConfig = encryptionKeyPairConfig { - principalKeyPairs += principalKeyPairs { - principal = MEASUREMENT_CONSUMER_NAME - keyPairs += keyPair { - publicKeyFile = "mc_enc_public.tink" - privateKeyFile = "mc_enc_private.tink" - } - } - } - - private val measurementConsumerConfig = measurementConsumerConfig { - apiKey = API_KEY - signingCertificateName = MEASUREMENT_CONSUMER_CERTIFICATE_NAME - signingPrivateKeyPath = MC_SIGNING_PRIVATE_KEY_PATH - } - - private val reportingServer: InProcessReportingServer by lazy { - InProcessReportingServer( - reportingServerDataServices, - publicKingdomServer.channel, - encryptionKeyPairConfig, - SECRETS_DIR, - measurementConsumerConfig, - TRUSTED_CERTIFICATES, - verboseGrpcLogging = false, - ) - } - - @get:Rule - val ruleChain: TestRule by lazy { chainRulesSequentially(publicKingdomServer, reportingServer) } - - private val publicEventGroupsClient by lazy { - EventGroupsCoroutineStub(reportingServer.publicApiChannel) - } - private val publicReportingSetsClient by lazy { - ReportingSetsCoroutineStub(reportingServer.publicApiChannel) - } - private val publicReportsClient by lazy { ReportsCoroutineStub(reportingServer.publicApiChannel) } - - @Test - fun `create Report and get the result successfully`() = runBlocking { - createReportingSet("1", MEASUREMENT_CONSUMER_NAME) - createReportingSet("2", MEASUREMENT_CONSUMER_NAME) - createReportingSet("3", MEASUREMENT_CONSUMER_NAME) - - val createdReport = createReport("1234", MEASUREMENT_CONSUMER_NAME) - val reports = listReports(MEASUREMENT_CONSUMER_NAME) - assertThat(reports.reportsList).hasSize(1) - val completedReport = getReport(createdReport.name, createdReport.measurementConsumer) - assertThat(assertThat(completedReport.state).isEqualTo(Report.State.SUCCEEDED)) - val reportResult = computeReportResult(completedReport) - // each measurement has a result of 100.0 and there are two time intervals - assertThat(reportResult).isEqualTo(200.0 * NUM_SET_OPERATIONS) - } - - @Test - fun `create multiple Reports concurrently successfully`() = runBlocking { - createReportingSet("1", MEASUREMENT_CONSUMER_NAME) - createReportingSet("2", MEASUREMENT_CONSUMER_NAME) - createReportingSet("3", MEASUREMENT_CONSUMER_NAME) - launch { createReport("5", MEASUREMENT_CONSUMER_NAME, true) } - for (i in 1..4) { - launch { createReport("$i", MEASUREMENT_CONSUMER_NAME) } - } - } - - private fun computeReportResult(completedReport: Report): Double { - var sum = 0.0 - completedReport.result.scalarTable.columnsList.forEach { column -> - column.setOperationsList.forEach { sum += it } - } - return sum - } - - private suspend fun listEventGroups(measurementConsumerName: String): ListEventGroupsResponse { - return publicEventGroupsClient - .withPrincipalName(measurementConsumerName) - .listEventGroups( - listEventGroupsRequest { parent = "$measurementConsumerName/dataProviders/-" } - ) - } - - private suspend fun createReportingSet( - runId: String, - measurementConsumerName: String, - ): ReportingSet { - val eventGroupsList = listEventGroups(measurementConsumerName).eventGroupsList - return publicReportingSetsClient - .withPrincipalName(measurementConsumerName) - .createReportingSet( - createReportingSetRequest { - parent = measurementConsumerName - reportingSet = reportingSet { - displayName = "reporting-set-$runId" - eventGroups += eventGroupsList.map { it.name } - } - } - ) - } - - private suspend fun listReportingSets( - measurementConsumerName: String - ): ListReportingSetsResponse { - return publicReportingSetsClient - .withPrincipalName(measurementConsumerName) - .listReportingSets(listReportingSetsRequest { parent = measurementConsumerName }) - } - - private suspend fun createReport( - runId: String, - measurementConsumerName: String, - cumulative: Boolean = false, - ): Report { - val eventGroupsList = listEventGroups(measurementConsumerName).eventGroupsList - val reportingSets = listReportingSets(measurementConsumerName).reportingSetsList - assertThat(reportingSets.size).isAtLeast(3) - val createReportRequest = createReportRequest { - parent = measurementConsumerName - report = report { - measurementConsumer = measurementConsumerName - reportIdempotencyKey = runId - eventGroupUniverse = - ReportKt.eventGroupUniverse { - eventGroupsList.forEach { - eventGroupEntries += ReportKt.EventGroupUniverseKt.eventGroupEntry { key = it.name } - } - } - periodicTimeInterval = periodicTimeInterval { - startTime = timestamp { seconds = 100 } - increment = duration { seconds = 5 } - intervalCount = 2 - } - metrics += metric { - this.cumulative = cumulative - impressionCount = MetricKt.impressionCountParams { maximumFrequencyPerUser = 5 } - val setOperation = - MetricKt.namedSetOperation { - uniqueName = "set-operation" - setOperation = - MetricKt.setOperation { - type = Metric.SetOperation.Type.UNION - lhs = - MetricKt.SetOperationKt.operand { - operation = - MetricKt.setOperation { - type = Metric.SetOperation.Type.UNION - lhs = - MetricKt.SetOperationKt.operand { reportingSet = reportingSets[0].name } - rhs = - MetricKt.SetOperationKt.operand { reportingSet = reportingSets[1].name } - } - } - rhs = MetricKt.SetOperationKt.operand { reportingSet = reportingSets[2].name } - } - } - - for (i in 1..NUM_SET_OPERATIONS) { - setOperations += setOperation.copy { uniqueName = "$uniqueName-$i" } - } - } - } - } - - val report = - publicReportsClient - .withPrincipalName(measurementConsumerName) - .createReport(createReportRequest) - - // Verify concurrent operations process the metrics without skipping set operations. - assertThat(report.metricsList) - .ignoringRepeatedFieldOrder() - .containsExactlyElementsIn(createReportRequest.report.metricsList) - return report - } - - private suspend fun getReport(reportName: String, principalName: String): Report { - return publicReportsClient - .withPrincipalName(principalName) - .getReport(getReportRequest { name = reportName }) - } - - private suspend fun listReports(measurementConsumerName: String): ListReportsResponse { - return publicReportsClient - .withPrincipalName(measurementConsumerName) - .listReports(listReportsRequest { parent = measurementConsumerName }) - } - - companion object { - private val SECRETS_DIR: File = - getRuntimePath( - Paths.get("wfa_measurement_system", "src", "main", "k8s", "testing", "secretfiles") - )!! - .toFile() - - private val TRUSTED_CERTIFICATES = - readCertificateCollection(SECRETS_DIR.resolve("all_root_certs.pem")).associateBy { - it.subjectKeyIdentifier!! - } - - private val MC_CERTIFICATE_DER: ByteString = - SECRETS_DIR.resolve("mc_cs_cert.der").readByteString() - private val MC_SIGNING_KEY_HANDLE: SigningKeyHandle = - loadSigningKey( - SECRETS_DIR.resolve("mc_cs_cert.der"), - SECRETS_DIR.resolve("mc_cs_private.der"), - ) - private val MC_ENCRYPTION_PUBLIC_KEY: EncryptionPublicKey = - loadPublicKey(SECRETS_DIR.resolve("mc_enc_public.tink")).toEncryptionPublicKey() - private const val MC_SIGNING_PRIVATE_KEY_PATH = "mc_cs_private.der" - private const val API_KEY = "AAAAAAAAAHs" - const val MEASUREMENT_CONSUMER_NAME = "measurementConsumers/AAAAAAAAAHs" - private const val MEASUREMENT_CONSUMER_CERTIFICATE_NAME = - "$MEASUREMENT_CONSUMER_NAME/certificates/AAAAAAAAAHs" - - private const val DATA_PROVIDER_NAME = "dataProviders/AAAAAAAAAHs" - private const val DATA_PROVIDER_CERTIFICATE_NAME = - "$DATA_PROVIDER_NAME/certificates/AAAAAAAAAHs" - private val EDP_CERTIFICATE_DER: ByteString = - readCertificate(SECRETS_DIR.resolve("edp1_cs_cert.der").readByteString()) - .encoded - .toByteString() - private val EDP_SIGNING_KEY_HANDLE: SigningKeyHandle = - loadSigningKey( - SECRETS_DIR.resolve("edp1_cs_cert.der"), - SECRETS_DIR.resolve("edp1_cs_private.der"), - ) - - private val CERTIFICATE: Certificate = certificate { - name = DATA_PROVIDER_CERTIFICATE_NAME - x509Der = EDP_CERTIFICATE_DER - } - - private val DATA_PROVIDER = dataProvider { - name = DATA_PROVIDER_NAME - certificate = DATA_PROVIDER_CERTIFICATE_NAME - certificateDer = EDP_CERTIFICATE_DER - publicKey = - signEncryptionPublicKey( - loadPublicKey(SECRETS_DIR.resolve("edp1_enc_public.tink")).toEncryptionPublicKey(), - EDP_SIGNING_KEY_HANDLE, - ) - } - - private val MEASUREMENT_CONSUMER: MeasurementConsumer = measurementConsumer { - name = MEASUREMENT_CONSUMER_NAME - certificate = MEASUREMENT_CONSUMER_CERTIFICATE_NAME - certificateDer = MC_CERTIFICATE_DER - publicKey = signEncryptionPublicKey(MC_ENCRYPTION_PUBLIC_KEY, MC_SIGNING_KEY_HANDLE) - } - - private const val EVENT_GROUP_NAME = "$DATA_PROVIDER_NAME/eventGroups/AAAAAAAAAHs" - private val VID_MODEL_LINES = listOf("model1", "model2") - private val EVENT_TEMPLATE_TYPES = listOf("type1", "type2") - private val EVENT_TEMPLATES = - EVENT_TEMPLATE_TYPES.map { type -> EventGroupKt.eventTemplate { this.type = type } } - - private val EVENT_GROUP: EventGroup = eventGroup { - name = EVENT_GROUP_NAME - measurementConsumer = MEASUREMENT_CONSUMER_NAME - eventGroupReferenceId = "aaa" - measurementConsumerPublicKey = MEASUREMENT_CONSUMER.publicKey.message - vidModelLines.addAll(VID_MODEL_LINES) - eventTemplates.addAll(EVENT_TEMPLATES) - encryptedMetadata = - MC_ENCRYPTION_PUBLIC_KEY.toPublicKeyHandle() - .encryptMessage(EventGroup.Metadata.getDefaultInstance().pack()) - } - - private const val EVENT_GROUP_METADATA_DESCRIPTOR_NAME = - "$DATA_PROVIDER_NAME/eventGroupMetadataDescriptors/AAAAAAAAAHs" - private val FILE_DESCRIPTOR_SET = DescriptorProtos.FileDescriptorSet.getDefaultInstance() - - private val EVENT_GROUP_METADATA_DESCRIPTOR: EventGroupMetadataDescriptor = - eventGroupMetadataDescriptor { - name = EVENT_GROUP_METADATA_DESCRIPTOR_NAME - descriptorSet = FILE_DESCRIPTOR_SET - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/InProcessReportingServer.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/InProcessReportingServer.kt deleted file mode 100644 index 483a04abc22..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/InProcessReportingServer.kt +++ /dev/null @@ -1,280 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.integration.common.reporting - -import com.google.protobuf.ByteString -import io.grpc.Channel -import java.io.File -import java.security.SecureRandom -import java.security.cert.X509Certificate -import java.time.Duration -import java.util.logging.Logger -import kotlin.random.asKotlinRandom -import org.junit.rules.TestRule -import org.junit.runner.Description -import org.junit.runners.model.Statement -import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt.CertificatesCoroutineStub as PublicKingdomCertificatesCoroutineStub -import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCoroutineStub as PublicKingdomDataProvidersCoroutineStub -import org.wfanet.measurement.api.v2alpha.EventGroupMetadataDescriptorsGrpcKt.EventGroupMetadataDescriptorsCoroutineStub as PublicKingdomEventGroupMetadataDescriptorsCoroutineStub -import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub as PublicKingdomEventGroupsCoroutineStub -import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub as PublicKingdomMeasurementConsumersCoroutineStub -import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub as PublicKingdomMeasurementsCoroutineStub -import org.wfanet.measurement.api.withAuthenticationKey -import org.wfanet.measurement.common.crypto.tink.loadPrivateKey -import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule -import org.wfanet.measurement.common.grpc.withVerboseLogging -import org.wfanet.measurement.common.readByteString -import org.wfanet.measurement.common.testing.chainRulesSequentially -import org.wfanet.measurement.config.reporting.EncryptionKeyPairConfig -import org.wfanet.measurement.config.reporting.MeasurementConsumerConfig -import org.wfanet.measurement.config.reporting.MeasurementSpecConfigKt -import org.wfanet.measurement.config.reporting.measurementSpecConfig -import org.wfanet.measurement.integration.common.reporting.identity.withMetadataPrincipalIdentities -import org.wfanet.measurement.internal.reporting.MeasurementsGrpcKt.MeasurementsCoroutineStub as InternalMeasurementsCoroutineStub -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt.ReportingSetsCoroutineStub as InternalReportingSetsCoroutineStub -import org.wfanet.measurement.internal.reporting.ReportsGrpcKt.ReportsCoroutineStub as InternalReportsCoroutineStub -import org.wfanet.measurement.reporting.deploy.common.server.ReportingDataServer -import org.wfanet.measurement.reporting.deploy.common.server.ReportingDataServer.Companion.toList -import org.wfanet.measurement.reporting.service.api.CelEnvCacheProvider -import org.wfanet.measurement.reporting.service.api.InMemoryEncryptionKeyPairStore -import org.wfanet.measurement.reporting.service.api.v1alpha.EventGroupsService -import org.wfanet.measurement.reporting.service.api.v1alpha.ReportingSetsService -import org.wfanet.measurement.reporting.service.api.v1alpha.ReportsService -import org.wfanet.measurement.reporting.v1alpha.EventGroup - -/** TestRule that starts and stops all Reporting Server gRPC services. */ -class InProcessReportingServer( - private val reportingServerDataServices: ReportingDataServer.Services, - private val publicKingdomChannel: Channel, - private val encryptionKeyPairConfig: EncryptionKeyPairConfig, - private val signingPrivateKeyDir: File, - private val measurementConsumerConfig: MeasurementConsumerConfig, - trustedCertificates: Map, - private val verboseGrpcLogging: Boolean = true, -) : TestRule { - private val publicKingdomMeasurementConsumersClient by lazy { - PublicKingdomMeasurementConsumersCoroutineStub(publicKingdomChannel) - } - private val publicKingdomMeasurementsClient by lazy { - PublicKingdomMeasurementsCoroutineStub(publicKingdomChannel) - } - private val publicKingdomCertificatesClient by lazy { - PublicKingdomCertificatesCoroutineStub(publicKingdomChannel) - } - private val publicKingdomDataProvidersClient by lazy { - PublicKingdomDataProvidersCoroutineStub(publicKingdomChannel) - } - private val publicKingdomEventGroupsClient by lazy { - PublicKingdomEventGroupsCoroutineStub(publicKingdomChannel) - } - private val publicKingdomEventGroupMetadataDescriptorsClient by lazy { - PublicKingdomEventGroupMetadataDescriptorsCoroutineStub(publicKingdomChannel) - } - - private val internalApiChannel by lazy { internalDataServer.channel } - private val internalMeasurementsClient by lazy { - InternalMeasurementsCoroutineStub(internalApiChannel) - } - private val internalReportingSetsClient by lazy { - InternalReportingSetsCoroutineStub(internalApiChannel) - } - private val internalReportsClient by lazy { InternalReportsCoroutineStub(internalApiChannel) } - - private val internalDataServer = - GrpcTestServerRule(logAllRequests = verboseGrpcLogging) { - logger.info("Building Reporting Server's internal Data services") - reportingServerDataServices.toList().forEach { - addService(it.withVerboseLogging(verboseGrpcLogging)) - } - } - - private val publicApiServer = - GrpcTestServerRule(logAllRequests = verboseGrpcLogging) { - logger.info("Building Reporting Server's public API services") - - val encryptionKeyPairStore = - InMemoryEncryptionKeyPairStore( - encryptionKeyPairConfig.principalKeyPairsList.associateBy( - { it.principal }, - { - it.keyPairsList.map { keyPair -> - Pair( - signingPrivateKeyDir.resolve(keyPair.publicKeyFile).readByteString(), - loadPrivateKey(signingPrivateKeyDir.resolve(keyPair.privateKeyFile)), - ) - } - }, - ) - ) - - val celEnvCacheProvider = - CelEnvCacheProvider( - publicKingdomEventGroupMetadataDescriptorsClient.withAuthenticationKey( - measurementConsumerConfig.apiKey - ), - EventGroup.getDescriptor(), - Duration.ofSeconds(5), - knownMetadataTypes = emptyList(), - ) - - listOf( - EventGroupsService( - publicKingdomEventGroupsClient, - encryptionKeyPairStore, - celEnvCacheProvider, - ) - .withMetadataPrincipalIdentities(measurementConsumerConfig), - ReportingSetsService(internalReportingSetsClient) - .withMetadataPrincipalIdentities(measurementConsumerConfig), - ReportsService( - internalReportsClient, - internalReportingSetsClient, - internalMeasurementsClient, - publicKingdomDataProvidersClient, - publicKingdomMeasurementConsumersClient, - publicKingdomMeasurementsClient, - publicKingdomCertificatesClient, - encryptionKeyPairStore, - SecureRandom().asKotlinRandom(), - signingPrivateKeyDir, - trustedCertificates, - MEASUREMENT_SPEC_CONFIG, - ) - .withMetadataPrincipalIdentities(measurementConsumerConfig), - ) - .forEach { addService(it.withVerboseLogging(verboseGrpcLogging)) } - } - - /** Provides a gRPC channel to the Reporting Server's public API. */ - val publicApiChannel: Channel - get() = publicApiServer.channel - - override fun apply(statement: Statement, description: Description): Statement { - return chainRulesSequentially(internalDataServer, publicApiServer).apply(statement, description) - } - - companion object { - private val logger: Logger = Logger.getLogger(this::class.java.name) - - private val MEASUREMENT_SPEC_CONFIG = measurementSpecConfig { - reachSingleDataProvider = - MeasurementSpecConfigKt.reachSingleDataProvider { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.000207 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - reach = - MeasurementSpecConfigKt.reach { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.0007444 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - randomStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.randomStart { - width = 256 - numVidBuckets = 300 - } - } - } - reachAndFrequencySingleDataProvider = - MeasurementSpecConfigKt.reachAndFrequencySingleDataProvider { - reachPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.000207 - delta = 1e-15 - } - frequencyPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.004728 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - reachAndFrequency = - MeasurementSpecConfigKt.reachAndFrequency { - reachPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.0007444 - delta = 1e-15 - } - frequencyPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.014638 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - randomStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.randomStart { - width = 256 - numVidBuckets = 300 - } - } - } - impression = - MeasurementSpecConfigKt.impression { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.003592 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - duration = - MeasurementSpecConfigKt.duration { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.007418 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/BUILD.bazel deleted file mode 100644 index a1e471db718..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/BUILD.bazel +++ /dev/null @@ -1,34 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -package( - default_testonly = True, - default_visibility = [ - "//src/main/kotlin/org/wfanet/measurement/integration/common/reporting:__subpackages__", - ], -) - -kt_jvm_library( - name = "metadata_principal_interceptor", - srcs = ["MetadataPrincipalServerInterceptor.kt"], - deps = [ - ":reporting_principal_identity", - "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:context_keys", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:context_keys", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:principal_server_interceptor", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reporting_principal", - "//src/main/proto/wfa/measurement/config:duchy_cert_config_kt_jvm_proto", - "@wfa_common_jvm//imports/java/io/grpc:api", - "@wfa_common_jvm//imports/java/io/grpc/stub", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - ], -) - -kt_jvm_library( - name = "reporting_principal_identity", - srcs = ["ReportingPrincipalIdentity.kt"], - deps = [ - "@wfa_common_jvm//imports/java/io/grpc:api", - "@wfa_common_jvm//imports/java/io/grpc/stub", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/MetadataPrincipalServerInterceptor.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/MetadataPrincipalServerInterceptor.kt deleted file mode 100644 index 629ddcbbe96..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/MetadataPrincipalServerInterceptor.kt +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.integration.common.reporting.identity - -import io.grpc.BindableService -import io.grpc.Context -import io.grpc.Contexts -import io.grpc.Metadata -import io.grpc.ServerCall -import io.grpc.ServerCallHandler -import io.grpc.ServerInterceptor -import io.grpc.ServerInterceptors -import io.grpc.ServerServiceDefinition -import io.grpc.Status -import org.wfanet.measurement.config.reporting.MeasurementConsumerConfig -import org.wfanet.measurement.reporting.service.api.v1alpha.ContextKeys -import org.wfanet.measurement.reporting.service.api.v1alpha.ReportingPrincipal -import org.wfanet.measurement.reporting.service.api.v1alpha.principalFromCurrentContext -import org.wfanet.measurement.reporting.service.api.v1alpha.withPrincipal - -/** - * Extracts a name from the gRPC [Metadata] and creates a [ReportingPrincipal] to add to the gRPC - * [Context]. - * - * To install, wrap a service with: - * ``` - * yourService.withMetadataPrincipalIdentities() - * ``` - * - * The principal can be accessed within gRPC services via [principalFromCurrentContext]. - * - * This expects the Metadata to have a key "reporting_principal" associated with a value equal to - * the resource name of the principal. The recommended way to set this is to use [withPrincipalName] - * on a stub. - */ -class MetadataPrincipalServerInterceptor(private val config: MeasurementConsumerConfig) : - ServerInterceptor { - override fun interceptCall( - call: ServerCall, - headers: Metadata, - next: ServerCallHandler, - ): ServerCall.Listener { - if (ContextKeys.PRINCIPAL_CONTEXT_KEY.get() != null) { - return Contexts.interceptCall(Context.current(), call, headers, next) - } - - val principalName = headers[REPORTING_PRINCIPAL_NAME_METADATA_KEY] - if (principalName == null) { - call.close( - Status.UNAUTHENTICATED.withDescription("$REPORTING_PRINCIPAL_NAME_METADATA_KEY not found"), - Metadata(), - ) - return object : ServerCall.Listener() {} - } - val principal = ReportingPrincipal.fromConfigs(principalName, config) - if (principal == null) { - call.close(Status.UNAUTHENTICATED.withDescription("No valid Principal found"), Metadata()) - return object : ServerCall.Listener() {} - } - val context = Context.current().withPrincipal(principal) - return Contexts.interceptCall(context, call, headers, next) - } -} - -/** Installs [MetadataPrincipalServerInterceptor] on the service. */ -fun BindableService.withMetadataPrincipalIdentities( - config: MeasurementConsumerConfig -): ServerServiceDefinition = - ServerInterceptors.interceptForward(this, MetadataPrincipalServerInterceptor(config)) diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/ReportingPrincipalIdentity.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/ReportingPrincipalIdentity.kt deleted file mode 100644 index ab42b0aa409..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/ReportingPrincipalIdentity.kt +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.integration.common.reporting.identity - -import io.grpc.Metadata -import io.grpc.stub.AbstractStub -import io.grpc.stub.MetadataUtils - -private const val KEY_NAME = "reporting_principal" -val REPORTING_PRINCIPAL_NAME_METADATA_KEY: Metadata.Key = - Metadata.Key.of(KEY_NAME, Metadata.ASCII_STRING_MARSHALLER) - -/** - * Sets metadata key "reporting_principal" on all outgoing requests. On the server side, use - * [MetadataPrincipalServerInterceptor]. Note that this should only be used in in-process tests - * where mTLS isn't used. - */ -fun > T.withPrincipalName(name: String): T { - val extraHeaders = Metadata() - extraHeaders.put(REPORTING_PRINCIPAL_NAME_METADATA_KEY, name) - return withInterceptors(MetadataUtils.newAttachHeadersInterceptor(extraHeaders)) -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/BUILD.bazel deleted file mode 100644 index c576b711d5d..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/BUILD.bazel +++ /dev/null @@ -1,34 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -package( - default_visibility = [ - "//src/main/kotlin/org/wfanet/measurement/reporting:__subpackages__", - "//src/test/kotlin/org/wfanet/measurement/reporting:__subpackages__", - ], -) - -kt_jvm_library( - name = "encryption_key_pair_map", - srcs = ["EncryptionKeyPairMap.kt"], - deps = [ - "//src/main/proto/wfa/measurement/config/reporting:encryption_key_pair_config_kt_jvm_proto", - "@wfa_common_jvm//imports/java/picocli", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/tink", - ], -) - -kt_jvm_library( - name = "flags", - srcs = ["InternalApiFlags.kt"], - deps = [ - "@wfa_common_jvm//imports/java/picocli", - ], -) - -kt_jvm_library( - name = "kingdom_flags", - srcs = ["KingdomApiFlags.kt"], - deps = [ - "@wfa_common_jvm//imports/java/picocli", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/EncryptionKeyPairMap.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/EncryptionKeyPairMap.kt deleted file mode 100644 index dcd3135ee9c..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/EncryptionKeyPairMap.kt +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.common - -import com.google.protobuf.ByteString -import java.io.File -import org.wfanet.measurement.common.crypto.PrivateKeyHandle -import org.wfanet.measurement.common.crypto.tink.loadPrivateKey -import org.wfanet.measurement.common.parseTextProto -import org.wfanet.measurement.common.readByteString -import org.wfanet.measurement.config.reporting.encryptionKeyPairConfig -import picocli.CommandLine.Option - -class EncryptionKeyPairMap { - @Option( - names = ["--key-pair-dir"], - description = ["Path to the directory of MeasurementConsumer's encryption keys"], - ) - private lateinit var keyFilesDirectory: File - - @Option( - names = ["--key-pair-config-file"], - description = ["Path to the textproto file of EncryptionKeyPairConfig that contains key pairs"], - required = true, - ) - private lateinit var keyPairConfigFile: File - - private fun loadKeyPairs(): Map>> { - val keyPairConfig = - parseTextProto(keyPairConfigFile, encryptionKeyPairConfig {}).principalKeyPairsList - return keyPairConfig.associate { config -> - val keyPairs = - config.keyPairsList.map { keyPair -> - val publicKey = keyFilesDirectory.resolve(keyPair.publicKeyFile).readByteString() - val privateKey = loadPrivateKey(keyFilesDirectory.resolve(keyPair.privateKeyFile)) - publicKey to privateKey - } - checkNotNull(config.principal) to keyPairs - } - } - - val keyPairs: Map>> by lazy { loadKeyPairs() } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/InternalApiFlags.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/InternalApiFlags.kt deleted file mode 100644 index 122b61bc0a2..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/InternalApiFlags.kt +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.common - -import picocli.CommandLine - -class InternalApiFlags { - @set:CommandLine.Option( - names = ["--internal-api-target"], - description = ["gRPC target (authority) of the Reporting internal API server"], - required = true, - ) - lateinit var target: String - - @CommandLine.Option( - names = ["--internal-api-cert-host"], - description = - [ - "Expected hostname (DNS-ID) in the Reporting internal API server's TLS certificate.", - "This overrides derivation of the TLS DNS-ID from --internal-api-target.", - ], - required = false, - ) - var certHost: String? = null - private set -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/KingdomApiFlags.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/KingdomApiFlags.kt deleted file mode 100644 index 608a72797c7..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/KingdomApiFlags.kt +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.common - -import picocli.CommandLine - -class KingdomApiFlags { - @set:CommandLine.Option( - names = ["--kingdom-api-target"], - description = ["gRPC target (authority) of the Kingdom public API server"], - required = true, - ) - lateinit var target: String - - @CommandLine.Option( - names = ["--kingdom-api-cert-host"], - description = - [ - "Expected hostname (DNS-ID) in the Kingdom public API server's TLS certificate.", - "This overrides derivation of the TLS DNS-ID from --kingdom-api-target.", - ], - required = false, - ) - var certHost: String? = null - private set -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/BUILD.bazel deleted file mode 100644 index a98176a11ef..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/BUILD.bazel +++ /dev/null @@ -1,78 +0,0 @@ -load("@rules_java//java:defs.bzl", "java_binary") -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") -load("//src/main/docker:macros.bzl", "java_image") - -kt_jvm_library( - name = "reporting_data_server", - srcs = ["ReportingDataServer.kt"], - visibility = [ - "//src/main/kotlin/org/wfanet/measurement/integration/common/reporting:__subpackages__", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy:__subpackages__", - ], - deps = [ - "//src/main/proto/wfa/measurement/internal/reporting:measurements_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reporting_sets_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reports_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - ], -) - -kt_jvm_library( - name = "reporting_api_server_flags", - srcs = ["ReportingApiServerFlags.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common:flags", - "@wfa_common_jvm//imports/java/picocli", - ], -) - -kt_jvm_library( - name = "v1alpha_public_api_server", - srcs = ["V1AlphaPublicApiServer.kt"], - runtime_deps = ["@wfa_common_jvm//imports/java/io/grpc/netty"], - deps = [ - ":reporting_api_server_flags", - "//src/main/kotlin/org/wfanet/measurement/common/api:memoizing_principal_lookup", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common:encryption_key_pair_map", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common:kingdom_flags", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/config:measurement_spec_config_validator", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:cel_env_provider", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:encryption_key_pair_store", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:akid_principal_lookup", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:event_groups_service", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reporting_sets_service", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reports_service", - "//src/main/proto/wfa/measurement/api/v2alpha:certificates_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:data_providers_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:event_group_metadata_descriptors_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:event_groups_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:measurement_consumers_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:measurements_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/config/reporting:measurement_consumer_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/config/reporting:measurement_spec_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/reporting:measurements_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reporting_sets_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reports_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//imports/java/io/grpc:api", - "@wfa_common_jvm//imports/java/picocli", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - ], -) - -java_binary( - name = "V1AlphaPublicApiServer", - main_class = "org.wfanet.measurement.reporting.deploy.common.server.V1AlphaPublicApiServerKt", - runtime_deps = [ - ":v1alpha_public_api_server", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/logging", - ], -) - -java_image( - name = "v1alpha_public_api_server_image", - binary = ":V1AlphaPublicApiServer", - main_class = "org.wfanet.measurement.reporting.deploy.common.server.V1AlphaPublicApiServerKt", - visibility = ["//src:docker_image_deployment"], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/ReportingApiServerFlags.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/ReportingApiServerFlags.kt deleted file mode 100644 index fc3960af21a..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/ReportingApiServerFlags.kt +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.common.server - -import java.time.Duration -import kotlin.properties.Delegates -import org.wfanet.measurement.reporting.deploy.common.InternalApiFlags -import picocli.CommandLine - -class ReportingApiServerFlags { - @CommandLine.Mixin - lateinit var internalApiFlags: InternalApiFlags - private set - - @set:CommandLine.Option( - names = ["--debug-verbose-grpc-client-logging"], - description = ["Enables full gRPC request and response logging for outgoing gRPCs"], - defaultValue = "false", - ) - var debugVerboseGrpcClientLogging by Delegates.notNull() - private set - - @CommandLine.Option( - names = ["--event-group-metadata-descriptor-cache-duration"], - description = - [ - "How long the event group metadata descriptors are cached for before refreshing in format 1d1h1m1s1ms1ns" - ], - defaultValue = "1h", - required = false, - ) - lateinit var eventGroupMetadataDescriptorCacheDuration: Duration - private set -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/ReportingDataServer.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/ReportingDataServer.kt deleted file mode 100644 index 9f3b903ff41..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/ReportingDataServer.kt +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.common.server - -import io.grpc.BindableService -import kotlin.reflect.full.declaredMemberProperties -import kotlinx.coroutines.runInterruptible -import org.wfanet.measurement.common.grpc.CommonServer -import org.wfanet.measurement.internal.reporting.MeasurementsGrpcKt -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt -import org.wfanet.measurement.internal.reporting.ReportsGrpcKt -import picocli.CommandLine - -abstract class ReportingDataServer : Runnable { - data class Services( - val measurementsService: MeasurementsGrpcKt.MeasurementsCoroutineImplBase, - val reportingSetsService: ReportingSetsGrpcKt.ReportingSetsCoroutineImplBase, - val reportsService: ReportsGrpcKt.ReportsCoroutineImplBase, - ) - - @CommandLine.Mixin private lateinit var serverFlags: CommonServer.Flags - - protected suspend fun run(services: Services) { - val server = CommonServer.fromFlags(serverFlags, this::class.simpleName!!, services.toList()) - - runInterruptible { server.start().blockUntilShutdown() } - } - - companion object { - fun Services.toList(): List { - return Services::class.declaredMemberProperties.map { it.get(this) as BindableService } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/V1AlphaPublicApiServer.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/V1AlphaPublicApiServer.kt deleted file mode 100644 index 2b1b64bdb11..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/V1AlphaPublicApiServer.kt +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.common.server - -import com.google.protobuf.ByteString -import com.google.protobuf.DescriptorProtos -import com.google.protobuf.Descriptors -import io.grpc.Channel -import io.grpc.ServerServiceDefinition -import java.io.File -import java.security.SecureRandom -import kotlin.random.asKotlinRandom -import kotlinx.coroutines.Dispatchers -import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt.CertificatesCoroutineStub as KingdomCertificatesCoroutineStub -import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCoroutineStub as KingdomDataProvidersCoroutineStub -import org.wfanet.measurement.api.v2alpha.EventGroupMetadataDescriptorsGrpcKt.EventGroupMetadataDescriptorsCoroutineStub as KingdomEventGroupMetadataDescriptorsCoroutineStub -import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub as KingdomEventGroupsCoroutineStub -import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub as KingdomMeasurementConsumersCoroutineStub -import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub as KingdomMeasurementsCoroutineStub -import org.wfanet.measurement.api.withAuthenticationKey -import org.wfanet.measurement.common.ProtoReflection -import org.wfanet.measurement.common.api.PrincipalLookup -import org.wfanet.measurement.common.api.memoizing -import org.wfanet.measurement.common.commandLineMain -import org.wfanet.measurement.common.crypto.SigningCerts -import org.wfanet.measurement.common.grpc.CommonServer -import org.wfanet.measurement.common.grpc.buildMutualTlsChannel -import org.wfanet.measurement.common.grpc.withVerboseLogging -import org.wfanet.measurement.common.parseTextProto -import org.wfanet.measurement.config.reporting.MeasurementConsumerConfigs -import org.wfanet.measurement.config.reporting.MeasurementSpecConfig -import org.wfanet.measurement.internal.reporting.MeasurementsGrpcKt.MeasurementsCoroutineStub as InternalMeasurementsCoroutineStub -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt.ReportingSetsCoroutineStub as InternalReportingSetsCoroutineStub -import org.wfanet.measurement.internal.reporting.ReportsGrpcKt.ReportsCoroutineStub as InternalReportsCoroutineStub -import org.wfanet.measurement.reporting.deploy.common.EncryptionKeyPairMap -import org.wfanet.measurement.reporting.deploy.common.KingdomApiFlags -import org.wfanet.measurement.reporting.deploy.config.validate -import org.wfanet.measurement.reporting.service.api.CelEnvCacheProvider -import org.wfanet.measurement.reporting.service.api.InMemoryEncryptionKeyPairStore -import org.wfanet.measurement.reporting.service.api.v1alpha.AkidPrincipalLookup -import org.wfanet.measurement.reporting.service.api.v1alpha.EventGroupsService -import org.wfanet.measurement.reporting.service.api.v1alpha.ReportingPrincipal -import org.wfanet.measurement.reporting.service.api.v1alpha.ReportingSetsService -import org.wfanet.measurement.reporting.service.api.v1alpha.ReportsService -import org.wfanet.measurement.reporting.service.api.v1alpha.withPrincipalsFromX509AuthorityKeyIdentifiers -import org.wfanet.measurement.reporting.v1alpha.EventGroup -import picocli.CommandLine - -private const val SERVER_NAME = "V1AlphaPublicApiServer" - -@CommandLine.Command( - name = SERVER_NAME, - description = ["Server daemon for Reporting v1alpha public API services."], - mixinStandardHelpOptions = true, - showDefaultValues = true, -) -private fun run( - @CommandLine.Mixin reportingApiServerFlags: ReportingApiServerFlags, - @CommandLine.Mixin kingdomApiFlags: KingdomApiFlags, - @CommandLine.Mixin commonServerFlags: CommonServer.Flags, - @CommandLine.Mixin v1AlphaFlags: V1AlphaFlags, - @CommandLine.Mixin encryptionKeyPairMap: EncryptionKeyPairMap, -) { - val clientCerts = - SigningCerts.fromPemFiles( - certificateFile = commonServerFlags.tlsFlags.certFile, - privateKeyFile = commonServerFlags.tlsFlags.privateKeyFile, - trustedCertCollectionFile = commonServerFlags.tlsFlags.certCollectionFile, - ) - val channel: Channel = - buildMutualTlsChannel( - reportingApiServerFlags.internalApiFlags.target, - clientCerts, - reportingApiServerFlags.internalApiFlags.certHost, - ) - .withVerboseLogging(reportingApiServerFlags.debugVerboseGrpcClientLogging) - - val kingdomChannel: Channel = - buildMutualTlsChannel( - target = kingdomApiFlags.target, - clientCerts = clientCerts, - hostName = kingdomApiFlags.certHost, - ) - .withVerboseLogging(reportingApiServerFlags.debugVerboseGrpcClientLogging) - - val principalLookup: PrincipalLookup = - AkidPrincipalLookup( - v1AlphaFlags.authorityKeyIdentifierToPrincipalMapFile, - v1AlphaFlags.measurementConsumerConfigFile, - ) - .memoizing() - - val measurementConsumerConfigs = - parseTextProto( - v1AlphaFlags.measurementConsumerConfigFile, - MeasurementConsumerConfigs.getDefaultInstance(), - ) - - val apiKey = measurementConsumerConfigs.configsMap.values.first().apiKey - val celEnvCacheProvider = - CelEnvCacheProvider( - KingdomEventGroupMetadataDescriptorsCoroutineStub(kingdomChannel) - .withAuthenticationKey(apiKey), - EventGroup.getDescriptor(), - reportingApiServerFlags.eventGroupMetadataDescriptorCacheDuration, - v1AlphaFlags.knownEventGroupMetadataTypes, - Dispatchers.Default, - ) - - val measurementSpecConfig = - parseTextProto( - v1AlphaFlags.measurementSpecConfigFile, - MeasurementSpecConfig.getDefaultInstance(), - ) - - try { - measurementSpecConfig.validate() - } catch (e: IllegalStateException) { - throw IllegalArgumentException("Invalid MeasurementSpeConfig", e) - } - - val services: List = - listOf( - EventGroupsService( - KingdomEventGroupsCoroutineStub(kingdomChannel), - InMemoryEncryptionKeyPairStore(encryptionKeyPairMap.keyPairs), - celEnvCacheProvider, - ) - .withPrincipalsFromX509AuthorityKeyIdentifiers(principalLookup), - ReportingSetsService(InternalReportingSetsCoroutineStub(channel)) - .withPrincipalsFromX509AuthorityKeyIdentifiers(principalLookup), - ReportsService( - InternalReportsCoroutineStub(channel), - InternalReportingSetsCoroutineStub(channel), - InternalMeasurementsCoroutineStub(channel), - KingdomDataProvidersCoroutineStub(kingdomChannel), - KingdomMeasurementConsumersCoroutineStub(kingdomChannel), - KingdomMeasurementsCoroutineStub(kingdomChannel), - KingdomCertificatesCoroutineStub(kingdomChannel), - InMemoryEncryptionKeyPairStore(encryptionKeyPairMap.keyPairs), - SecureRandom().asKotlinRandom(), - v1AlphaFlags.signingPrivateKeyStoreDir, - commonServerFlags.tlsFlags.signingCerts.trustedCertificates, - measurementSpecConfig, - ) - .withPrincipalsFromX509AuthorityKeyIdentifiers(principalLookup), - ) - CommonServer.fromFlags(commonServerFlags, SERVER_NAME, services).start().blockUntilShutdown() -} - -fun main(args: Array) = commandLineMain(::run, args) - -/** Flags specific to the V1Alpha API version. */ -private class V1AlphaFlags { - @CommandLine.Option( - names = ["--authority-key-identifier-to-principal-map-file"], - description = ["File path to a AuthorityKeyToPrincipalMap textproto"], - required = true, - ) - lateinit var authorityKeyIdentifierToPrincipalMapFile: File - private set - - @CommandLine.Option( - names = ["--measurement-consumer-config-file"], - description = ["File path to a MeasurementConsumerConfig textproto"], - required = true, - ) - lateinit var measurementConsumerConfigFile: File - private set - - @CommandLine.Option( - names = ["--measurement-spec-config-file"], - description = ["File path to a MeasurementSpecConfig textproto"], - required = true, - ) - lateinit var measurementSpecConfigFile: File - private set - - @CommandLine.Option( - names = ["--signing-private-key-store-dir"], - description = ["File path to the signing private key store directory"], - required = true, - ) - lateinit var signingPrivateKeyStoreDir: File - private set - - @CommandLine.Option( - names = ["--known-event-group-metadata-type"], - description = - [ - "File path to FileDescriptorSet containing known EventGroup metadata types.", - "This is in addition to standard protobuf well-known types.", - "Can be specified multiple times.", - ], - required = false, - defaultValue = "", - ) - private fun setKnownEventGroupMetadataTypes(fileDescriptorSetFiles: List) { - val fileDescriptorSets = - fileDescriptorSetFiles.map { file -> - file.inputStream().use { input -> DescriptorProtos.FileDescriptorSet.parseFrom(input) } - } - knownEventGroupMetadataTypes = ProtoReflection.buildFileDescriptors(fileDescriptorSets) - } - - lateinit var knownEventGroupMetadataTypes: List - private set -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres/BUILD.bazel deleted file mode 100644 index bc5bae4cdce..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres/BUILD.bazel +++ /dev/null @@ -1,20 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -package( - default_visibility = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/server:__subpackages__", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/server:__subpackages__", - "//src/test/kotlin/org/wfanet/measurement/integration/deploy/common/postgres:__pkg__", - ], -) - -kt_jvm_library( - name = "services", - srcs = ["PostgresServices.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server:reporting_data_server", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres:services", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres/PostgresServices.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres/PostgresServices.kt deleted file mode 100644 index 93b52a7c012..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres/PostgresServices.kt +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.common.server.postgres - -import org.wfanet.measurement.common.db.r2dbc.DatabaseClient -import org.wfanet.measurement.common.identity.IdGenerator -import org.wfanet.measurement.reporting.deploy.common.server.ReportingDataServer.Services -import org.wfanet.measurement.reporting.deploy.postgres.PostgresMeasurementsService -import org.wfanet.measurement.reporting.deploy.postgres.PostgresReportingSetsService -import org.wfanet.measurement.reporting.deploy.postgres.PostgresReportsService - -object PostgresServices { - @JvmStatic - fun create(idGenerator: IdGenerator, client: DatabaseClient): Services { - return Services( - PostgresMeasurementsService(idGenerator, client), - PostgresReportingSetsService(idGenerator, client), - PostgresReportsService(idGenerator, client), - ) - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/config/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/config/BUILD.bazel deleted file mode 100644 index 37ae63fa26b..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/config/BUILD.bazel +++ /dev/null @@ -1,13 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -kt_jvm_library( - name = "measurement_spec_config_validator", - srcs = ["MeasurementSpecConfigValidator.kt"], - visibility = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common:__subpackages__", - "//src/test/kotlin/org/wfanet/measurement/reporting/deploy/config:__subpackages__", - ], - deps = [ - "//src/main/proto/wfa/measurement/config/reporting:measurement_spec_config_kt_jvm_proto", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/config/MeasurementSpecConfigValidator.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/config/MeasurementSpecConfigValidator.kt deleted file mode 100644 index 8858d48f807..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/config/MeasurementSpecConfigValidator.kt +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright 2023 The Cross-Media Measurement Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.wfanet.measurement.reporting.deploy.config - -import org.wfanet.measurement.config.reporting.MeasurementSpecConfig - -/** - * Validates a [MeasurementSpecConfig] - * - * @throws [IllegalStateException] if the [MeasurementSpecConfig] is invalid. - */ -fun MeasurementSpecConfig.validate() { - check(this.reachSingleDataProvider.privacyParams.isValid()) { - "reach_single_data_provider privacy_params is invalid." - } - check(this.reachSingleDataProvider.vidSamplingInterval.isValid()) { - "reach_single_data_provider vid_sampling_interval is invalid." - } - - check(this.reach.privacyParams.isValid()) { "reach privacy_params is invalid." } - check(this.reach.vidSamplingInterval.isValid()) { "reach vid_sampling_interval is invalid." } - - check(this.reachAndFrequencySingleDataProvider.reachPrivacyParams.isValid()) { - "reach_and_frequency_single_data_provider reach_privacy_params is invalid." - } - check(this.reachAndFrequencySingleDataProvider.frequencyPrivacyParams.isValid()) { - "reach_and_frequency_single_data_provider frequency_privacy_params is invalid." - } - check(this.reachAndFrequencySingleDataProvider.vidSamplingInterval.isValid()) { - "reach_and_frequency_single_data_provider vid_sampling_interval is invalid." - } - - check(this.reachAndFrequency.reachPrivacyParams.isValid()) { - "reach_and_frequency reach_privacy_params is invalid." - } - check(this.reachAndFrequency.frequencyPrivacyParams.isValid()) { - "reach_and_frequency frequency_privacy_params is invalid." - } - check(this.reachAndFrequency.vidSamplingInterval.isValid()) { - "reach_and_frequency vid_sampling_interval is invalid." - } - - check(this.impression.privacyParams.isValid()) { "impression privacy_params is invalid." } - check(this.impression.vidSamplingInterval.isValid()) { - "impression vid_sampling_interval is invalid." - } - - check(this.duration.privacyParams.isValid()) { "duration privacy_params is invalid." } - check(this.duration.vidSamplingInterval.isValid()) { - "duration vid_sampling_interval is invalid." - } -} - -private fun MeasurementSpecConfig.DifferentialPrivacyParams.isValid(): Boolean { - return (this.epsilon > 0 && this.delta >= 0) -} - -private fun MeasurementSpecConfig.VidSamplingInterval.isValid(): Boolean { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") - return when (this.startCase) { - MeasurementSpecConfig.VidSamplingInterval.StartCase.FIXED_START -> this.fixedStart.isValid() - MeasurementSpecConfig.VidSamplingInterval.StartCase.RANDOM_START -> this.randomStart.isValid() - MeasurementSpecConfig.VidSamplingInterval.StartCase.START_NOT_SET -> true - } -} - -private fun MeasurementSpecConfig.VidSamplingInterval.FixedStart.isValid(): Boolean { - if (this.start < 0.0) { - return false - } - - if (this.width <= 0.0) { - return false - } - - if (this.start + this.width > 1.0) { - return false - } - - return true -} - -private fun MeasurementSpecConfig.VidSamplingInterval.RandomStart.isValid(): Boolean { - if (this.numVidBuckets <= 0) { - return false - } - - if (this.width <= 0 || this.width > this.numVidBuckets) { - return false - } - - return true -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/server/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/server/BUILD.bazel deleted file mode 100644 index a90ab6868a3..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/server/BUILD.bazel +++ /dev/null @@ -1,35 +0,0 @@ -load("@rules_java//java:defs.bzl", "java_binary") -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") -load("//src/main/docker:macros.bzl", "java_image") - -kt_jvm_library( - name = "gcloud_postgres_reporting_data_server", - srcs = ["GCloudPostgresReportingDataServer.kt"], - runtime_deps = ["@wfa_common_jvm//imports/java/com/google/cloud/sql/postgres:r2dbc"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server:reporting_data_server", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres:services", - "@wfa_common_jvm//imports/java/io/r2dbc", - "@wfa_common_jvm//imports/java/picocli", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/postgres:factories", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/postgres:flags", - ], -) - -java_binary( - name = "GCloudPostgresReportingDataServer", - main_class = "org.wfanet.measurement.reporting.deploy.gcloud.postgres.server.GCloudPostgresReportingDataServerKt", - runtime_deps = [ - ":gcloud_postgres_reporting_data_server", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/logging", - ], -) - -java_image( - name = "gcloud_postgres_reporting_data_server_image", - binary = ":GCloudPostgresReportingDataServer", - main_class = "org.wfanet.measurement.reporting.deploy.gcloud.postgres.server.GCloudPostgresReportingDataServerKt", - visibility = ["//src:docker_image_deployment"], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/server/GCloudPostgresReportingDataServer.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/server/GCloudPostgresReportingDataServer.kt deleted file mode 100644 index a59917f794e..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/server/GCloudPostgresReportingDataServer.kt +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.gcloud.postgres.server - -import java.time.Clock -import kotlinx.coroutines.runBlocking -import org.wfanet.measurement.common.commandLineMain -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresDatabaseClient -import org.wfanet.measurement.common.identity.RandomIdGenerator -import org.wfanet.measurement.gcloud.postgres.PostgresConnectionFactories -import org.wfanet.measurement.gcloud.postgres.PostgresFlags as GCloudPostgresFlags -import org.wfanet.measurement.reporting.deploy.common.server.ReportingDataServer -import org.wfanet.measurement.reporting.deploy.common.server.postgres.PostgresServices -import picocli.CommandLine - -/** Implementation of [ReportingDataServer] using Google Cloud Postgres. */ -@CommandLine.Command( - name = "GCloudPostgresReportingDataServer", - description = ["Start the internal Reporting data-layer services in a single blocking server."], - mixinStandardHelpOptions = true, - showDefaultValues = true, -) -class GCloudPostgresReportingDataServer : ReportingDataServer() { - @CommandLine.Mixin private lateinit var gCloudPostgresFlags: GCloudPostgresFlags - - override fun run() = runBlocking { - val clock = Clock.systemUTC() - val idGenerator = RandomIdGenerator(clock) - - val factory = PostgresConnectionFactories.buildConnectionFactory(gCloudPostgresFlags) - val client = PostgresDatabaseClient.fromConnectionFactory(factory) - - run(PostgresServices.create(idGenerator, client)) - } -} - -fun main(args: Array) = commandLineMain(GCloudPostgresReportingDataServer(), args) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/tools/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/tools/BUILD.bazel deleted file mode 100644 index 613680b7c64..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/tools/BUILD.bazel +++ /dev/null @@ -1,21 +0,0 @@ -load("@rules_java//java:defs.bzl", "java_binary") -load("//src/main/docker:macros.bzl", "java_image") - -java_binary( - name = "UpdateSchema", - args = ["--changelog=reporting/postgres/changelog.yaml"], - main_class = "org.wfanet.measurement.gcloud.postgres.tools.UpdateSchema", - resources = ["//src/main/resources/reporting/postgres"], - runtime_deps = [ - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/logging", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/postgres/tools:update_schema", - ], -) - -java_image( - name = "update_schema_image", - args = ["--changelog=reporting/postgres/changelog.yaml"], - binary = ":UpdateSchema", - main_class = "org.wfanet.measurement.gcloud.postgres.tools.UpdateSchema", - visibility = ["//src:docker_image_deployment"], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/BUILD.bazel deleted file mode 100644 index 49aa526e488..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/BUILD.bazel +++ /dev/null @@ -1,22 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -kt_jvm_library( - name = "services", - srcs = glob(["*Service.kt"]), - visibility = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres:__subpackages__", - "//src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres:__pkg__", - ], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers", - "//src/main/proto/wfa/measurement/internal/reporting:measurements_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reporting_sets_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reports_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//imports/java/io/grpc:api", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresMeasurementsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresMeasurementsService.kt deleted file mode 100644 index b1d02bfc0bc..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresMeasurementsService.kt +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres - -import io.grpc.Status -import org.wfanet.measurement.common.db.r2dbc.DatabaseClient -import org.wfanet.measurement.common.db.r2dbc.postgres.SerializableErrors -import org.wfanet.measurement.common.identity.IdGenerator -import org.wfanet.measurement.internal.reporting.BatchCreateMeasurementsRequest -import org.wfanet.measurement.internal.reporting.BatchCreateMeasurementsResponse -import org.wfanet.measurement.internal.reporting.GetMeasurementRequest -import org.wfanet.measurement.internal.reporting.Measurement -import org.wfanet.measurement.internal.reporting.MeasurementsGrpcKt.MeasurementsCoroutineImplBase -import org.wfanet.measurement.internal.reporting.SetMeasurementFailureRequest -import org.wfanet.measurement.internal.reporting.SetMeasurementResultRequest -import org.wfanet.measurement.internal.reporting.batchCreateMeasurementsResponse -import org.wfanet.measurement.reporting.deploy.postgres.readers.MeasurementReader -import org.wfanet.measurement.reporting.deploy.postgres.writers.CreateMeasurements -import org.wfanet.measurement.reporting.deploy.postgres.writers.SetMeasurementFailure -import org.wfanet.measurement.reporting.deploy.postgres.writers.SetMeasurementResult -import org.wfanet.measurement.reporting.service.internal.MeasurementNotFoundException -import org.wfanet.measurement.reporting.service.internal.MeasurementStateInvalidException - -class PostgresMeasurementsService( - private val idGenerator: IdGenerator, - private val client: DatabaseClient, -) : MeasurementsCoroutineImplBase() { - override suspend fun batchCreateMeasurements( - request: BatchCreateMeasurementsRequest - ): BatchCreateMeasurementsResponse { - return batchCreateMeasurementsResponse { - measurements += CreateMeasurements(request.measurementsList).execute(client, idGenerator) - } - } - - override suspend fun getMeasurement(request: GetMeasurementRequest): Measurement { - val measurementResult = - SerializableErrors.retrying { - MeasurementReader() - .readMeasurementByReferenceIds( - client.singleUse(), - request.measurementConsumerReferenceId, - request.measurementReferenceId, - ) - } - ?: throw MeasurementNotFoundException() - .asStatusRuntimeException(Status.Code.NOT_FOUND, "Measurement not found.") - - return measurementResult.measurement - } - - override suspend fun setMeasurementResult(request: SetMeasurementResultRequest): Measurement { - return try { - SetMeasurementResult(request).execute(client, idGenerator) - } catch (e: MeasurementNotFoundException) { - throw e.asStatusRuntimeException(Status.Code.NOT_FOUND, "Measurement not found.") - } catch (e: MeasurementStateInvalidException) { - throw e.asStatusRuntimeException( - Status.Code.FAILED_PRECONDITION, - "Measurement has already been updated.", - ) - } - } - - override suspend fun setMeasurementFailure(request: SetMeasurementFailureRequest): Measurement { - return try { - SetMeasurementFailure(request).execute(client, idGenerator) - } catch (e: MeasurementNotFoundException) { - throw e.asStatusRuntimeException(Status.Code.NOT_FOUND, "Measurement not found.") - } catch (e: MeasurementStateInvalidException) { - throw e.asStatusRuntimeException( - Status.Code.FAILED_PRECONDITION, - "Measurement has already been updated.", - ) - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportingSetsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportingSetsService.kt deleted file mode 100644 index c429fb046e3..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportingSetsService.kt +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres - -import io.grpc.Status -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.map -import org.wfanet.measurement.common.db.r2dbc.DatabaseClient -import org.wfanet.measurement.common.db.r2dbc.postgres.SerializableErrors -import org.wfanet.measurement.common.db.r2dbc.postgres.SerializableErrors.withSerializableErrorRetries -import org.wfanet.measurement.common.identity.ExternalId -import org.wfanet.measurement.common.identity.IdGenerator -import org.wfanet.measurement.internal.reporting.BatchGetReportingSetRequest -import org.wfanet.measurement.internal.reporting.GetReportingSetRequest -import org.wfanet.measurement.internal.reporting.ReportingSet -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt.ReportingSetsCoroutineImplBase -import org.wfanet.measurement.internal.reporting.StreamReportingSetsRequest -import org.wfanet.measurement.reporting.deploy.postgres.readers.ReportingSetReader -import org.wfanet.measurement.reporting.deploy.postgres.writers.CreateReportingSet -import org.wfanet.measurement.reporting.service.internal.ReportingSetAlreadyExistsException -import org.wfanet.measurement.reporting.service.internal.ReportingSetNotFoundException - -class PostgresReportingSetsService( - private val idGenerator: IdGenerator, - private val client: DatabaseClient, -) : ReportingSetsCoroutineImplBase() { - override suspend fun createReportingSet(request: ReportingSet): ReportingSet { - return try { - CreateReportingSet(request).execute(client, idGenerator) - } catch (e: ReportingSetAlreadyExistsException) { - throw e.asStatusRuntimeException( - Status.Code.ALREADY_EXISTS, - "IDs generated for Reporting Set already exist", - ) - } - } - - override suspend fun getReportingSet(request: GetReportingSetRequest): ReportingSet { - return try { - SerializableErrors.retrying { - ReportingSetReader() - .readReportingSetByExternalId( - client.singleUse(), - request.measurementConsumerReferenceId, - ExternalId(request.externalReportingSetId), - ) - .reportingSet - } - } catch (e: ReportingSetNotFoundException) { - throw e.asStatusRuntimeException(Status.Code.NOT_FOUND, "Reporting Set not found") - } - } - - override fun streamReportingSets(request: StreamReportingSetsRequest): Flow { - return ReportingSetReader() - .listReportingSets(client, request.filter, request.limit) - .map { result -> result.reportingSet } - .withSerializableErrorRetries() - } - - override fun batchGetReportingSet(request: BatchGetReportingSetRequest): Flow { - return ReportingSetReader() - .getReportingSetsByExternalIds( - client, - request.measurementConsumerReferenceId, - request.externalReportingSetIdsList, - ) - .map { result -> result.reportingSet } - .withSerializableErrorRetries() - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportsService.kt deleted file mode 100644 index 0ee2a9ae64a..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportsService.kt +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres - -import io.grpc.Status -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.map -import org.wfanet.measurement.common.db.r2dbc.DatabaseClient -import org.wfanet.measurement.common.db.r2dbc.postgres.SerializableErrors -import org.wfanet.measurement.common.db.r2dbc.postgres.SerializableErrors.withSerializableErrorRetries -import org.wfanet.measurement.common.identity.IdGenerator -import org.wfanet.measurement.internal.reporting.CreateReportRequest -import org.wfanet.measurement.internal.reporting.GetReportByIdempotencyKeyRequest -import org.wfanet.measurement.internal.reporting.GetReportRequest -import org.wfanet.measurement.internal.reporting.Report -import org.wfanet.measurement.internal.reporting.ReportsGrpcKt.ReportsCoroutineImplBase -import org.wfanet.measurement.internal.reporting.StreamReportsRequest -import org.wfanet.measurement.reporting.deploy.postgres.readers.ReportReader -import org.wfanet.measurement.reporting.deploy.postgres.writers.CreateReport -import org.wfanet.measurement.reporting.service.internal.MeasurementCalculationTimeIntervalNotFoundException -import org.wfanet.measurement.reporting.service.internal.ReportNotFoundException -import org.wfanet.measurement.reporting.service.internal.ReportingSetNotFoundException - -class PostgresReportsService( - private val idGenerator: IdGenerator, - private val client: DatabaseClient, -) : ReportsCoroutineImplBase() { - override suspend fun createReport(request: CreateReportRequest): Report { - try { - return CreateReport(request).execute(client, idGenerator) - } catch (e: ReportingSetNotFoundException) { - throw e.asStatusRuntimeException(Status.Code.NOT_FOUND, "Reporting Set not found") - } catch (e: MeasurementCalculationTimeIntervalNotFoundException) { - throw e.asStatusRuntimeException( - Status.Code.INVALID_ARGUMENT, - "Measurement Calculation Time Interval not found in Report", - ) - } - } - - override suspend fun getReport(request: GetReportRequest): Report { - try { - return SerializableErrors.retrying { - ReportReader() - .getReportByExternalId( - client.singleUse(), - request.measurementConsumerReferenceId, - request.externalReportId, - ) - .report - } - } catch (e: ReportNotFoundException) { - throw e.asStatusRuntimeException(Status.Code.NOT_FOUND, "Report not found") - } - } - - override suspend fun getReportByIdempotencyKey( - request: GetReportByIdempotencyKeyRequest - ): Report { - try { - return SerializableErrors.retrying { - ReportReader() - .getReportByIdempotencyKey( - client.singleUse(), - request.measurementConsumerReferenceId, - request.reportIdempotencyKey, - ) - .report - } - } catch (e: ReportNotFoundException) { - throw e.asStatusRuntimeException(Status.Code.NOT_FOUND, "Report not found") - } - } - - override fun streamReports(request: StreamReportsRequest): Flow { - return ReportReader() - .listReports(client, request.filter, request.limit) - .map { result -> result.report } - .withSerializableErrorRetries() - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/BUILD.bazel deleted file mode 100644 index 53bcb7330d3..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/BUILD.bazel +++ /dev/null @@ -1,26 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -package(default_visibility = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres:__subpackages__", -]) - -kt_jvm_library( - name = "readers", - srcs = glob(["*.kt"]), - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/service/internal:internal_exception", - "//src/main/proto/wfa/measurement/internal/reporting:measurement_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/reporting:measurements_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:metric_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/reporting:report_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reporting_set_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reporting_sets_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reports_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//imports/java/com/google/gson", - "@wfa_common_jvm//imports/java/io/r2dbc", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/MeasurementReader.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/MeasurementReader.kt deleted file mode 100644 index 21e9218f31f..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/MeasurementReader.kt +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.readers - -import kotlinx.coroutines.flow.firstOrNull -import org.wfanet.measurement.common.db.r2dbc.ReadContext -import org.wfanet.measurement.common.db.r2dbc.ResultRow -import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.internal.reporting.Measurement -import org.wfanet.measurement.internal.reporting.measurement - -class MeasurementReader { - data class Result(val measurement: Measurement) - - private val baseSql: String = - """ - SELECT - MeasurementConsumerReferenceId, - MeasurementReferenceId, - State, - Failure, - Result - FROM - Measurements - """ - - fun translate(row: ResultRow): Result = Result(buildMeasurement(row)) - - /** - * Reads a Measurement using reference IDs. - * - * @return null when the Measurement is not found. - */ - suspend fun readMeasurementByReferenceIds( - readContext: ReadContext, - measurementConsumerReferenceId: String, - measurementReferenceId: String, - ): Result? { - val builder = - boundStatement( - baseSql + - """ - WHERE MeasurementConsumerReferenceId = $1 - AND MeasurementReferenceId = $2 - """ - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", measurementReferenceId) - } - - return readContext.executeQuery(builder).consume(::translate).firstOrNull() - } - - private fun buildMeasurement(row: ResultRow): Measurement { - return measurement { - measurementConsumerReferenceId = row["MeasurementConsumerReferenceId"] - measurementReferenceId = row["MeasurementReferenceId"] - state = Measurement.State.forNumber(row["State"]) - val failure: Measurement.Failure? = - row.getProtoMessage("Failure", Measurement.Failure.parser()) - if (failure != null) { - this.failure = failure - } - val result: Measurement.Result? = row.getProtoMessage("Result", Measurement.Result.parser()) - if (result != null) { - this.result = result - } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/MeasurementResultsReader.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/MeasurementResultsReader.kt deleted file mode 100644 index ae63b9c7232..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/MeasurementResultsReader.kt +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.readers - -import kotlinx.coroutines.flow.Flow -import org.wfanet.measurement.common.db.r2dbc.ReadContext -import org.wfanet.measurement.common.db.r2dbc.ResultRow -import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.common.identity.InternalId -import org.wfanet.measurement.internal.reporting.Measurement - -class MeasurementResultsReader { - data class Result( - val reportId: InternalId, - val measurementReferenceId: String, - val state: Measurement.State, - val result: Measurement.Result?, - ) - - private val sql = - """ - SELECT - ReportId, - MeasurementReferenceId, - Measurements.State, - Result - FROM - ( - SELECT - MeasurementConsumerReferenceId, - ReportId - FROM - ReportMeasurements - WHERE - MeasurementConsumerReferenceId = $1 AND MeasurementReferenceId = $2 - ) AS Reports - JOIN ReportMeasurements USING (MeasurementConsumerReferenceId, ReportId) - JOIN Measurements USING (MeasurementConsumerReferenceId, MeasurementReferenceId) - """ - - fun translate(row: ResultRow): Result = - Result( - reportId = row["ReportId"], - measurementReferenceId = row["MeasurementReferenceId"], - state = Measurement.State.forNumber(row["State"]), - result = row.getProtoMessage("Result", Measurement.Result.parser()), - ) - - suspend fun listMeasurementsForReportsByMeasurementReferenceId( - readContext: ReadContext, - measurementConsumerReferenceId: String, - measurementReferenceId: String, - ): Flow { - val statement = - boundStatement(sql) { - bind("$1", measurementConsumerReferenceId) - bind("$2", measurementReferenceId) - } - - return readContext.executeQuery(statement).consume(::translate) - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/ReportReader.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/ReportReader.kt deleted file mode 100644 index 127a5d04cd6..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/ReportReader.kt +++ /dev/null @@ -1,527 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.readers - -import com.google.gson.JsonObject -import com.google.gson.JsonParser -import com.google.protobuf.duration -import com.google.protobuf.timestamp -import java.time.Instant -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.emitAll -import kotlinx.coroutines.flow.flow -import kotlinx.coroutines.flow.singleOrNull -import org.wfanet.measurement.common.base64MimeDecode -import org.wfanet.measurement.common.db.r2dbc.DatabaseClient -import org.wfanet.measurement.common.db.r2dbc.ReadContext -import org.wfanet.measurement.common.db.r2dbc.ResultRow -import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.common.identity.InternalId -import org.wfanet.measurement.common.toProtoTime -import org.wfanet.measurement.internal.reporting.Measurement -import org.wfanet.measurement.internal.reporting.Metric -import org.wfanet.measurement.internal.reporting.Metric.SetOperation -import org.wfanet.measurement.internal.reporting.MetricKt -import org.wfanet.measurement.internal.reporting.MetricKt.namedSetOperation -import org.wfanet.measurement.internal.reporting.MetricKt.setOperation -import org.wfanet.measurement.internal.reporting.PeriodicTimeInterval -import org.wfanet.measurement.internal.reporting.Report -import org.wfanet.measurement.internal.reporting.StreamReportsRequest -import org.wfanet.measurement.internal.reporting.TimeIntervals -import org.wfanet.measurement.internal.reporting.measurement -import org.wfanet.measurement.internal.reporting.metric -import org.wfanet.measurement.internal.reporting.periodicTimeInterval -import org.wfanet.measurement.internal.reporting.report -import org.wfanet.measurement.internal.reporting.timeInterval -import org.wfanet.measurement.internal.reporting.timeIntervals -import org.wfanet.measurement.reporting.service.internal.ReportNotFoundException - -class ReportReader { - data class Result( - val measurementConsumerReferenceId: String, - val reportId: InternalId, - val report: Report, - ) - - private val baseSql: String = - """ - SELECT - MeasurementConsumerReferenceId, - ReportId, - ExternalReportId, - State, - ReportDetails, - ReportIdempotencyKey, - CreateTime, - ARRAY( - SELECT - json_build_object( - 'timeIntervalId', TimeIntervalId, - 'startSeconds', StartSeconds, - 'startNanos', StartNanos, - 'endSeconds', EndSeconds, - 'endNanos', EndNanos - ) - FROM TimeIntervals - WHERE TimeIntervals.MeasurementConsumerReferenceId = Reports.MeasurementConsumerReferenceId - AND TimeIntervals.ReportId = Reports.ReportId - ) AS TimeIntervals, - ARRAY( - SELECT - json_build_object( - 'measurementReferenceId', MeasurementReferenceId, - 'state', Measurements.State, - 'failure', encode(Measurements.failure, 'base64'), - 'result', encode(Measurements.result, 'base64') - ) - FROM ( - SELECT - MeasurementConsumerReferenceId, - MeasurementReferenceId - FROM ReportMeasurements - WHERE ReportMeasurements.ReportId = Reports.ReportId - ) AS ReportMeasurements - JOIN Measurements USING(MeasurementConsumerReferenceId, MeasurementReferenceId) - ) AS Measurements, - IntervalCount, - StartSeconds, - StartNanos, - IncrementSeconds, - IncrementNanos, - ( - SELECT array_agg( - json_build_object( - 'metricId', MetricId, - 'metricDetails', MetricDetails, - 'namedSetOperations', NamedSetOperations, - 'setOperations', SetOperations - ) - ) - FROM ( - SELECT - MetricId, - encode(MetricDetails, 'base64') AS MetricDetails, - ( - SELECT json_agg( - json_build_object( - 'displayName', DisplayName, - 'setOperationId', SetOperationId, - 'measurementCalculations', MeasurementCalculations - ) - ) - FROM ( - SELECT - DisplayName, - SetOperationId, - ( - SELECT json_agg( - json_build_object( - 'timeInterval', TimeInterval, - 'weightedMeasurements', WeightedMeasurements - ) - ) - FROM ( - SELECT - ( - SELECT json_build_object( - 'startSeconds', StartSeconds, - 'startNanos', StartNanos, - 'endSeconds', EndSeconds, - 'endNanos', EndNanos - ) - FROM TimeIntervals - WHERE MeasurementCalculations.MeasurementConsumerReferenceId = TimeIntervals.MeasurementConsumerReferenceId - AND MeasurementCalculations.ReportId = TimeIntervals.ReportId - AND MeasurementCalculations.TimeIntervalId = TimeIntervals.TimeIntervalId - ) AS TimeInterval, - ( - SELECT json_agg( - json_build_object( - 'measurementReferenceId', MeasurementReferenceId, - 'coefficient', Coefficient - ) - ) - FROM WeightedMeasurements - Where WeightedMeasurements.MeasurementConsumerReferenceId = MeasurementCalculations.MeasurementConsumerReferenceId - AND WeightedMeasurements.ReportId = MeasurementCalculations.ReportId - AND WeightedMeasurements.MetricId = MeasurementCalculations.MetricId - AND WeightedMeasurements.NamedSetOperationId = MeasurementCalculations.NamedSetOperationId - AND WeightedMeasurements.MeasurementCalculationId = MeasurementCalculations.MeasurementCalculationId - ) AS WeightedMeasurements - FROM MeasurementCalculations - Where MeasurementCalculations.MeasurementConsumerReferenceId = NamedSetOperations.MeasurementConsumerReferenceId - AND MeasurementCalculations.ReportId = NamedSetOperations.ReportId - AND MeasurementCalculations.MetricId = NamedSetOperations.MetricId - AND MeasurementCalculations.NamedSetOperationId = NamedSetOperations.NamedSetOperationId - ) MeasurementCalculations - ) AS MeasurementCalculations - FROM NamedSetOperations - WHERE NamedSetOperations.MeasurementConsumerReferenceId = Metrics.MeasurementConsumerReferenceId - AND NamedSetOperations.ReportId = Metrics.ReportId - AND NamedSetOperations.MetricId = Metrics.MetricId - ) AS NamedSetOperations) AS NamedSetOperations, - ( - SELECT json_agg( - json_build_object( - 'type', Type, - 'setOperationId', SetOperationId, - 'leftHandSetOperationId', LeftHandSetOperationId, - 'rightHandSetOperationId', RightHandSetOperationId, - 'leftHandReportingSetId', - ( - SELECT ExternalReportingSetId - FROM ReportingSets - WHERE SetOperations.MeasurementConsumerReferenceId = ReportingSets.MeasurementConsumerReferenceId - AND ReportingSetId = LeftHandReportingSetId - ), - 'rightHandReportingSetId', - ( - SELECT ExternalReportingSetId - FROM ReportingSets - WHERE SetOperations.MeasurementConsumerReferenceId = ReportingSets.MeasurementConsumerReferenceId - AND ReportingSetId = RightHandReportingSetId - ) - ) - ) - FROM SetOperations - Where SetOperations.MeasurementConsumerReferenceId = Metrics.MeasurementConsumerReferenceId - AND SetOperations.ReportId = Metrics.ReportId - AND SetOperations.MetricId = Metrics.MetricId - ) AS SetOperations - FROM Metrics - WHERE Metrics.MeasurementConsumerReferenceId = Reports.MeasurementConsumerReferenceId - AND Metrics.ReportId = Reports.ReportId - ) AS Metrics - ) AS Metrics - FROM Reports - LEFT JOIN PeriodicTimeIntervals USING(MeasurementConsumerReferenceId, ReportId) - """ - - fun translate(row: ResultRow): Result = - Result(row["MeasurementConsumerReferenceId"], row["ReportId"], buildReport(row)) - - /** - * Gets the report by external report id. - * - * @throws [ReportNotFoundException] - */ - suspend fun getReportByExternalId( - readContext: ReadContext, - measurementConsumerReferenceId: String, - externalReportId: Long, - ): Result { - val statement = - boundStatement( - (baseSql + - """ - WHERE MeasurementConsumerReferenceId = $1 - AND ExternalReportId = $2 - """) - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", externalReportId) - } - - return readContext.executeQuery(statement).consume(::translate).singleOrNull() - ?: throw ReportNotFoundException() - } - - /** - * Gets the report by report id. - * - * @throws [ReportNotFoundException] - */ - suspend fun getReportById( - readContext: ReadContext, - measurementConsumerReferenceId: String, - reportId: Long, - ): Result { - val statement = - boundStatement( - (baseSql + - """ - WHERE MeasurementConsumerReferenceId = $1 - AND ReportId = $2 - """) - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", reportId) - } - - return readContext.executeQuery(statement).consume(::translate).singleOrNull() - ?: throw ReportNotFoundException() - } - - /** - * Gets the report by report idempotency key. - * - * @throws [ReportNotFoundException] - */ - suspend fun getReportByIdempotencyKey( - readContext: ReadContext, - measurementConsumerReferenceId: String, - reportIdempotencyKey: String, - ): Result { - val statement = - boundStatement( - (baseSql + - """ - WHERE MeasurementConsumerReferenceId = $1 - AND ReportIdempotencyKey = $2 - """) - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", reportIdempotencyKey) - } - - return readContext.executeQuery(statement).consume(::translate).singleOrNull() - ?: throw ReportNotFoundException() - } - - fun listReports( - client: DatabaseClient, - filter: StreamReportsRequest.Filter, - limit: Int = 0, - ): Flow { - val statement = - boundStatement( - baseSql + - """ - WHERE MeasurementConsumerReferenceId = $1 - AND ExternalReportId > $2 - ORDER BY ExternalReportId ASC - LIMIT $3 - """ - ) { - bind("$1", filter.measurementConsumerReferenceId) - bind("$2", filter.externalReportIdAfter) - if (limit > 0) { - bind("$3", limit) - } else { - bind("$3", 50) - } - } - - return flow { - val readContext = client.readTransaction() - try { - emitAll(readContext.executeQuery(statement).consume(::translate)) - } finally { - readContext.close() - } - } - } - - private fun buildReport(row: ResultRow): Report { - return report { - measurementConsumerReferenceId = row["MeasurementConsumerReferenceId"] - externalReportId = row["ExternalReportId"] - state = Report.State.forNumber(row["State"]) - details = row.getProtoMessage("ReportDetails", Report.Details.parser()) - reportIdempotencyKey = row["ReportIdempotencyKey"] - createTime = row.get("CreateTime").toProtoTime() - val intervalCount: Int? = row["IntervalCount"] - if (intervalCount != null) { - this.periodicTimeInterval = buildPeriodicTimeInterval(row) - } else { - timeIntervals = buildTimeIntervals(row["TimeIntervals"]) - } - metrics += buildMetrics(measurementConsumerReferenceId, row["Metrics"]) - measurements.putAll(buildMeasurements(measurementConsumerReferenceId, row["Measurements"])) - } - } - - private fun buildPeriodicTimeInterval(row: ResultRow): PeriodicTimeInterval { - return periodicTimeInterval { - startTime = timestamp { - seconds = row["StartSeconds"] - nanos = row["StartNanos"] - } - increment = duration { - seconds = row["IncrementSeconds"] - nanos = row["IncrementNanos"] - } - intervalCount = row["IntervalCount"] - } - } - - private fun buildTimeIntervals(timeIntervalsArr: Array): TimeIntervals { - return timeIntervals { - timeIntervalsArr.forEach { - val timeIntervalObject = JsonParser.parseString(it).asJsonObject - timeIntervals += timeInterval { - startTime = timestamp { - seconds = timeIntervalObject.getAsJsonPrimitive("startSeconds").asLong - nanos = timeIntervalObject.getAsJsonPrimitive("startNanos").asInt - } - endTime = timestamp { - seconds = timeIntervalObject.getAsJsonPrimitive("endSeconds").asLong - nanos = timeIntervalObject.getAsJsonPrimitive("endNanos").asInt - } - } - } - } - } - - private fun buildMeasurements( - measurementConsumerReferenceId: String, - measurementsArr: Array, - ): Map { - return measurementsArr - .map { JsonParser.parseString(it).asJsonObject } - .associateBy( - { it.getAsJsonPrimitive("measurementReferenceId").asString }, - { - measurement { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - measurementReferenceId = it.getAsJsonPrimitive("measurementReferenceId").asString - state = Measurement.State.forNumber(it.getAsJsonPrimitive("state").asInt) - if (!it.get("failure").isJsonNull) { - failure = - Measurement.Failure.parseFrom(it.getAsJsonPrimitive("failure").base64MimeDecode()) - } - if (!it.get("result").isJsonNull) { - result = - Measurement.Result.parseFrom(it.getAsJsonPrimitive("result").base64MimeDecode()) - } - } - }, - ) - } - - private fun buildMetrics( - measurementConsumerReferenceId: String, - metricsArr: Array, - ): Collection { - val metricsList = ArrayList(metricsArr.size) - metricsArr.forEach { - val metricObject = JsonParser.parseString(it).asJsonObject - metricsList.add( - metric { - details = - Metric.Details.parseFrom( - metricObject.getAsJsonPrimitive("metricDetails").base64MimeDecode() - ) - - val setOperationsArr = metricObject.getAsJsonArray("setOperations") - val setOperationsMap = mutableMapOf() - setOperationsArr.forEach { setOperationElement -> - val setOperationObject = setOperationElement.asJsonObject - setOperationsMap[setOperationObject.getAsJsonPrimitive("setOperationId").asLong] = - setOperationObject - } - - val namedSetOperationsArr = metricObject.getAsJsonArray("namedSetOperations") - namedSetOperationsArr.forEach { namedSetOperationElement -> - val namedSetOperationObject = namedSetOperationElement.asJsonObject - namedSetOperations += namedSetOperation { - displayName = namedSetOperationObject.getAsJsonPrimitive("displayName").asString - val setOperationId = - namedSetOperationObject.getAsJsonPrimitive("setOperationId").asLong - setOperation = - buildSetOperation(measurementConsumerReferenceId, setOperationId, setOperationsMap) - val measurementCalculationsArr = - namedSetOperationObject.getAsJsonArray("measurementCalculations") - measurementCalculationsArr.forEach { measurementCalculationElement -> - val measurementCalculationObject = measurementCalculationElement.asJsonObject - measurementCalculations += - MetricKt.measurementCalculation { - val timeIntervalObject = - measurementCalculationObject.getAsJsonObject("timeInterval") - timeInterval = timeInterval { - startTime = timestamp { - seconds = timeIntervalObject.getAsJsonPrimitive("startSeconds").asLong - nanos = timeIntervalObject.getAsJsonPrimitive("startNanos").asInt - } - endTime = timestamp { - seconds = timeIntervalObject.getAsJsonPrimitive("endSeconds").asLong - nanos = timeIntervalObject.getAsJsonPrimitive("endNanos").asInt - } - } - val weightedMeasurementsArr = - measurementCalculationObject.getAsJsonArray("weightedMeasurements") - weightedMeasurementsArr.forEach { weightedMeasurementElement -> - val weightedMeasurementObject = weightedMeasurementElement.asJsonObject - weightedMeasurements += - MetricKt.MeasurementCalculationKt.weightedMeasurement { - measurementReferenceId = - weightedMeasurementObject - .getAsJsonPrimitive("measurementReferenceId") - .asString - coefficient = - weightedMeasurementObject.getAsJsonPrimitive("coefficient").asInt - } - } - } - } - } - } - } - ) - } - return metricsList - } - - private fun buildSetOperation( - measurementConsumerReferenceId: String, - setOperationId: Long, - setOperationMap: Map, - ): SetOperation { - val setOperationObject = setOperationMap[setOperationId] - return setOperation { - type = SetOperation.Type.forNumber(setOperationObject!!.getAsJsonPrimitive("type").asInt) - lhs = - MetricKt.SetOperationKt.operand { - if (setOperationObject.get("leftHandSetOperationId").isJsonNull) { - if (!setOperationObject.get("leftHandReportingSetId").isJsonNull) { - reportingSetId = - MetricKt.SetOperationKt.reportingSetKey { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - externalReportingSetId = - setOperationObject.getAsJsonPrimitive("leftHandReportingSetId").asLong - } - } - } else { - operation = - buildSetOperation( - measurementConsumerReferenceId, - setOperationObject.getAsJsonPrimitive("leftHandSetOperationId").asLong, - setOperationMap, - ) - } - } - rhs = - MetricKt.SetOperationKt.operand { - if (setOperationObject.get("rightHandSetOperationId").isJsonNull) { - if (!setOperationObject.get("rightHandReportingSetId").isJsonNull) { - reportingSetId = - MetricKt.SetOperationKt.reportingSetKey { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - externalReportingSetId = - setOperationObject.getAsJsonPrimitive("rightHandReportingSetId").asLong - } - } - } else { - operation = - buildSetOperation( - measurementConsumerReferenceId, - setOperationObject.getAsJsonPrimitive("rightHandSetOperationId").asLong, - setOperationMap, - ) - } - } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/ReportingSetReader.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/ReportingSetReader.kt deleted file mode 100644 index 8ded14f23f3..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/ReportingSetReader.kt +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.readers - -import com.google.gson.JsonObject -import com.google.gson.JsonParser -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.emitAll -import kotlinx.coroutines.flow.emptyFlow -import kotlinx.coroutines.flow.firstOrNull -import kotlinx.coroutines.flow.flow -import org.wfanet.measurement.common.db.r2dbc.DatabaseClient -import org.wfanet.measurement.common.db.r2dbc.ReadContext -import org.wfanet.measurement.common.db.r2dbc.ResultRow -import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.common.identity.ExternalId -import org.wfanet.measurement.common.identity.InternalId -import org.wfanet.measurement.internal.reporting.ReportingSet -import org.wfanet.measurement.internal.reporting.ReportingSet.EventGroupKey -import org.wfanet.measurement.internal.reporting.ReportingSetKt -import org.wfanet.measurement.internal.reporting.StreamReportingSetsRequest -import org.wfanet.measurement.internal.reporting.reportingSet -import org.wfanet.measurement.reporting.service.internal.ReportingInternalException -import org.wfanet.measurement.reporting.service.internal.ReportingSetNotFoundException - -class ReportingSetReader { - data class Result( - val measurementConsumerReferenceId: String, - val reportingSetId: InternalId, - var reportingSet: ReportingSet, - ) - - private val baseSql: String = - """ - SELECT - MeasurementConsumerReferenceId, - ReportingSetId, - ExternalReportingSetId, - Filter, - DisplayName, - ( - SELECT ARRAY( - SELECT - json_build_object( - 'measurementConsumerReferenceId', MeasurementConsumerReferenceId, - 'dataProviderReferenceId', DataProviderReferenceId, - 'eventGroupReferenceId', EventGroupReferenceId - ) - FROM ReportingSetEventGroups - WHERE ReportingSetEventGroups.MeasurementConsumerReferenceId = ReportingSets.MeasurementConsumerReferenceId - AND ReportingSetEventGroups.ReportingSetId = ReportingSets.ReportingSetId - ) - ) AS EventGroups - FROM - ReportingSets - """ - - fun translate(row: ResultRow): Result = - Result(row["MeasurementConsumerReferenceId"], row["ReportingSetId"], buildReportingSet(row)) - - /** - * Reads a Reporting Set using external ID. - * - * Throws a subclass of [ReportingInternalException]. - * - * @throws [ReportingSetNotFoundException] Reporting Set not found. - */ - suspend fun readReportingSetByExternalId( - readContext: ReadContext, - measurementConsumerReferenceId: String, - externalReportingSetId: ExternalId, - ): Result { - val statement = - boundStatement( - baseSql + - """ - WHERE MeasurementConsumerReferenceId = $1 - AND ExternalReportingSetId = $2 - """ - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", externalReportingSetId) - } - - return readContext.executeQuery(statement).consume(::translate).firstOrNull() - ?: throw ReportingSetNotFoundException() - } - - fun listReportingSets( - client: DatabaseClient, - filter: StreamReportingSetsRequest.Filter, - limit: Int = 0, - ): Flow { - val statement = - boundStatement( - baseSql + - """ - WHERE MeasurementConsumerReferenceId = $1 - AND ExternalReportingSetId > $2 - ORDER BY ExternalReportingSetId ASC - LIMIT $3 - """ - ) { - bind("$1", filter.measurementConsumerReferenceId) - bind("$2", filter.externalReportingSetIdAfter) - if (limit > 0) { - bind("$3", limit) - } else { - bind("$3", 50) - } - } - - return flow { - val readContext = client.readTransaction() - try { - emitAll(readContext.executeQuery(statement).consume(::translate)) - } finally { - readContext.close() - } - } - } - - fun getReportingSetsByExternalIds( - client: DatabaseClient, - measurementConsumerReferenceId: String, - externalReportingSetIds: List, - ): Flow { - val sql = - StringBuilder( - baseSql + - """ - WHERE MeasurementConsumerReferenceId = $1 - AND ExternalReportingSetId IN - """ - ) - - if (externalReportingSetIds.isEmpty()) { - return emptyFlow() - } - - var i = 2 - val bindingMap = mutableMapOf() - val inList = - externalReportingSetIds.joinToString(separator = ",", prefix = "(", postfix = ")") { - val index = "$$i" - bindingMap[it] = "$$i" - i++ - index - } - sql.append(inList) - - val statement = - boundStatement(sql.toString()) { - bind("$1", measurementConsumerReferenceId) - - externalReportingSetIds.forEach { bind(bindingMap.getValue(it), it) } - } - - return flow { - val readContext = client.readTransaction() - try { - emitAll(readContext.executeQuery(statement).consume(::translate)) - } finally { - readContext.close() - } - } - } - - private fun buildReportingSet(row: ResultRow): ReportingSet { - return reportingSet { - measurementConsumerReferenceId = row["MeasurementConsumerReferenceId"] - externalReportingSetId = row["ExternalReportingSetId"] - filter = row["Filter"] - displayName = row["DisplayName"] - val eventGroupsArr = row.get>("EventGroups") - eventGroupsArr.forEach { - eventGroupKeys += buildEventGroupKey(JsonParser.parseString(it).asJsonObject) - } - } - } - - private fun buildEventGroupKey(eventGroup: JsonObject): EventGroupKey { - return ReportingSetKt.eventGroupKey { - measurementConsumerReferenceId = - eventGroup.getAsJsonPrimitive("measurementConsumerReferenceId").asString - dataProviderReferenceId = eventGroup.getAsJsonPrimitive("dataProviderReferenceId").asString - eventGroupReferenceId = eventGroup.getAsJsonPrimitive("eventGroupReferenceId").asString - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/server/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/server/BUILD.bazel deleted file mode 100644 index fddb2884cc9..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/server/BUILD.bazel +++ /dev/null @@ -1,29 +0,0 @@ -load("@rules_java//java:defs.bzl", "java_binary") -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") -load("//src/main/docker:macros.bzl", "java_image") - -kt_jvm_library( - name = "postgres_reporting_data_server", - srcs = ["PostgresReportingDataServer.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server:reporting_data_server", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres:services", - "@wfa_common_jvm//imports/java/picocli", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/postgres:flags", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres", - ], -) - -java_binary( - name = "PostgresReportingDataServer", - main_class = "org.wfanet.measurement.reporting.deploy.postgres.server.PostgresReportingDataServerKt", - runtime_deps = [":postgres_reporting_data_server"], -) - -java_image( - name = "postgres_reporting_data_server_image", - binary = ":PostgresReportingDataServer", - main_class = "org.wfanet.measurement.reporting.deploy.postgres.server.PostgresReportingDataServerKt", - visibility = ["//src:docker_image_deployment"], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/server/PostgresReportingDataServer.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/server/PostgresReportingDataServer.kt deleted file mode 100644 index b84b5ad9acd..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/server/PostgresReportingDataServer.kt +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.server - -import java.time.Clock -import kotlinx.coroutines.runBlocking -import org.wfanet.measurement.common.commandLineMain -import org.wfanet.measurement.common.db.postgres.PostgresFlags -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresDatabaseClient -import org.wfanet.measurement.common.identity.RandomIdGenerator -import org.wfanet.measurement.reporting.deploy.common.server.ReportingDataServer -import org.wfanet.measurement.reporting.deploy.common.server.postgres.PostgresServices -import picocli.CommandLine - -/** Implementation of [ReportingDataServer] using Postgres. */ -@CommandLine.Command( - name = "PostgresReportingDataServer", - description = ["Start the internal Reporting data-layer services in a single blocking server."], - mixinStandardHelpOptions = true, - showDefaultValues = true, -) -class PostgresReportingDataServer : ReportingDataServer() { - @CommandLine.Mixin private lateinit var postgresFlags: PostgresFlags - - override fun run() = runBlocking { - val clock = Clock.systemUTC() - val idGenerator = RandomIdGenerator(clock) - - val client = PostgresDatabaseClient.fromFlags(postgresFlags) - - run(PostgresServices.create(idGenerator, client)) - } -} - -fun main(args: Array) = commandLineMain(PostgresReportingDataServer(), args) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing/BUILD.bazel deleted file mode 100644 index 45fba8267c3..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing/BUILD.bazel +++ /dev/null @@ -1,18 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -package( - default_testonly = True, - default_visibility = [ - "//src/test/kotlin/org/wfanet/measurement/integration/deploy/common/postgres:__subpackages__", - "//src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres:__subpackages__", - ], -) - -kt_jvm_library( - name = "testing", - srcs = glob(["*.kt"]), - resources = ["//src/main/resources/reporting/postgres"], - deps = [ - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing/Schemata.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing/Schemata.kt deleted file mode 100644 index bf9233ecdda..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing/Schemata.kt +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.testing - -import java.nio.file.Path -import org.wfanet.measurement.common.getJarResourcePath - -object Schemata { - private const val REPORTING_POSTGRES_RESOURCE_PREFIX = "reporting/postgres" - - private fun getResourcePath(fileName: String): Path { - val resourceName = "$REPORTING_POSTGRES_RESOURCE_PREFIX/$fileName" - val classLoader: ClassLoader = Thread.currentThread().contextClassLoader - return requireNotNull(classLoader.getJarResourcePath(resourceName)) { - "Resource $resourceName not found" - } - } - - val REPORTING_CHANGELOG_PATH: Path = getResourcePath("changelog.yaml") -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/tools/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/tools/BUILD.bazel deleted file mode 100644 index 31c1f110702..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/tools/BUILD.bazel +++ /dev/null @@ -1,20 +0,0 @@ -load("@rules_java//java:defs.bzl", "java_binary") -load("//src/main/docker:macros.bzl", "java_image") - -java_binary( - name = "UpdateSchema", - args = ["--changelog=reporting/postgres/changelog.yaml"], - main_class = "org.wfanet.measurement.common.db.postgres.tools.UpdateSchema", - resources = ["//src/main/resources/reporting/postgres"], - runtime_deps = [ - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/postgres/tools:update_schema", - ], -) - -java_image( - name = "update_schema_image", - args = ["--changelog=reporting/postgres/changelog.yaml"], - binary = ":UpdateSchema", - main_class = "org.wfanet.measurement.common.db.postgres.tools.UpdateSchema", - visibility = ["//src:docker_image_deployment"], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/BUILD.bazel deleted file mode 100644 index 6073e762463..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/BUILD.bazel +++ /dev/null @@ -1,27 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -package(default_visibility = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres:__subpackages__", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres:__subpackages__", -]) - -kt_jvm_library( - name = "writers", - srcs = glob(["*.kt"]), - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/internal:internal_exception", - "//src/main/proto/wfa/measurement/internal/reporting:measurement_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/reporting:metric_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/reporting:report_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reporting_set_kt_jvm_proto", - "@wfa_common_jvm//imports/java/com/google/protobuf", - "@wfa_common_jvm//imports/java/io/r2dbc", - "@wfa_common_jvm//imports/java/org/postgresql:r2dbc", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateMeasurements.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateMeasurements.kt deleted file mode 100644 index 65dd19af31b..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateMeasurements.kt +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.writers - -import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter -import org.wfanet.measurement.internal.reporting.Measurement -import org.wfanet.measurement.internal.reporting.copy -import org.wfanet.measurement.reporting.service.internal.MeasurementAlreadyExistsException - -/** - * Inserts Measurements into the database. - * - * Throws the following on [execute]: - * * [MeasurementAlreadyExistsException] Measurement already exists - */ -class CreateMeasurements(private val measurements: Iterable) : - PostgresWriter>() { - override suspend fun TransactionScope.runTransaction(): Iterable { - transactionContext.run { - for (measurement in measurements) { - val builder = - boundStatement( - """ - INSERT INTO Measurements (MeasurementConsumerReferenceId, MeasurementReferenceId, State) - VALUES ($1, $2, $3) - ON CONFLICT DO NOTHING - """ - ) { - bind("$1", measurement.measurementConsumerReferenceId) - bind("$2", measurement.measurementReferenceId) - bind("$3", Measurement.State.PENDING_VALUE) - } - - executeStatement(builder) - } - } - - return measurements.map { it.copy { state = Measurement.State.PENDING } } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateReport.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateReport.kt deleted file mode 100644 index fd358b20f9a..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateReport.kt +++ /dev/null @@ -1,482 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.writers - -import com.google.protobuf.util.Timestamps -import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter -import org.wfanet.measurement.common.identity.ExternalId -import org.wfanet.measurement.internal.reporting.CreateReportRequest -import org.wfanet.measurement.internal.reporting.CreateReportRequest.MeasurementKey -import org.wfanet.measurement.internal.reporting.Measurement -import org.wfanet.measurement.internal.reporting.Metric -import org.wfanet.measurement.internal.reporting.Metric.MeasurementCalculation.WeightedMeasurement -import org.wfanet.measurement.internal.reporting.Metric.NamedSetOperation -import org.wfanet.measurement.internal.reporting.Metric.SetOperation -import org.wfanet.measurement.internal.reporting.PeriodicTimeInterval -import org.wfanet.measurement.internal.reporting.Report -import org.wfanet.measurement.internal.reporting.TimeInterval -import org.wfanet.measurement.internal.reporting.copy -import org.wfanet.measurement.internal.reporting.timeInterval -import org.wfanet.measurement.reporting.deploy.postgres.readers.ReportingSetReader -import org.wfanet.measurement.reporting.service.internal.MeasurementCalculationTimeIntervalNotFoundException -import org.wfanet.measurement.reporting.service.internal.ReportingSetNotFoundException - -/** - * Inserts a Report into the database. - * - * Throws the following on [execute]: - * * [ReportingSetNotFoundException] ReportingSet not found - * * [MeasurementCalculationTimeIntervalNotFoundException] MeasurementCalculation TimeInterval not - * found. - */ -class CreateReport(private val request: CreateReportRequest) : PostgresWriter() { - override suspend fun TransactionScope.runTransaction(): Report { - val report = request.report - val internalReportId = idGenerator.generateInternalId().value - val externalReportId = idGenerator.generateExternalId().value - - val timeIntervals: List - val isPeriodic = report.timeCase == Report.TimeCase.PERIODIC_TIME_INTERVAL - if (isPeriodic) { - timeIntervals = ArrayList(report.periodicTimeInterval.intervalCount) - timeIntervals.add( - timeInterval { - startTime = report.periodicTimeInterval.startTime - endTime = Timestamps.add(startTime, report.periodicTimeInterval.increment) - } - ) - for (i in 1 until report.periodicTimeInterval.intervalCount) { - timeIntervals.add( - timeInterval { - startTime = timeIntervals[i - 1].endTime - endTime = Timestamps.add(startTime, report.periodicTimeInterval.increment) - } - ) - } - } else { - timeIntervals = report.timeIntervals.timeIntervalsList - } - - val statement = - boundStatement( - """ - INSERT INTO Reports (MeasurementConsumerReferenceId, ReportId, ExternalReportId, State, ReportDetails, ReportIdempotencyKey, CreateTime) - VALUES ($1, $2, $3, $4, $5, $6, now() at time zone 'utc') - """ - ) { - bind("$1", report.measurementConsumerReferenceId) - bind("$2", internalReportId) - bind("$3", externalReportId) - bind("$4", Report.State.RUNNING_VALUE) - bind("$5", report.details) - bind("$6", report.reportIdempotencyKey) - } - - transactionContext.run { - executeStatement(statement) - if (isPeriodic) { - insertPeriodicTimeInterval( - report.measurementConsumerReferenceId, - internalReportId, - report.periodicTimeInterval, - ) - } - insertMeasurements(request.measurementsList) - val timeIntervalMap = - insertTimeIntervals(report.measurementConsumerReferenceId, internalReportId, timeIntervals) - - report.metricsList.forEach { - insertMetric(report.measurementConsumerReferenceId, internalReportId, timeIntervalMap, it) - } - insertReportMeasurements(request.measurementsList, internalReportId) - } - - return report.copy { - this.externalReportId = externalReportId - state = Report.State.RUNNING - } - } - - private suspend fun TransactionScope.insertMeasurements( - measurements: Collection - ) { - val sql = - StringBuilder( - """ - INSERT INTO Measurements (MeasurementConsumerReferenceId, MeasurementReferenceId, State) - VALUES ($1, $2, $3) - """ - ) - val numParameters = 3 - var firstParam = 1 - for (i in 1 until measurements.size) { - firstParam += numParameters - sql.append(",($$firstParam, $${firstParam + 1}, $${firstParam + 2})") - } - sql.append("ON CONFLICT DO NOTHING") - - firstParam = 1 - val statement = - boundStatement(sql.toString()) { - measurements.forEach { - bind("$$firstParam", it.measurementConsumerReferenceId) - bind("$${firstParam + 1}", it.measurementReferenceId) - bind("$${firstParam + 2}", Measurement.State.PENDING_VALUE) - firstParam += numParameters - } - } - transactionContext.executeStatement(statement) - } - - private suspend fun TransactionScope.insertReportMeasurements( - measurements: Collection, - reportId: Long, - ) { - val sql = - StringBuilder( - """ - INSERT INTO ReportMeasurements (MeasurementConsumerReferenceId, MeasurementReferenceId, ReportId) - VALUES ($1, $2, $3) - """ - ) - val numParameters = 3 - var firstParam = 1 - for (i in 1 until measurements.size) { - firstParam += numParameters - sql.append(",($$firstParam, $${firstParam + 1}, $${firstParam + 2})") - } - - firstParam = 1 - val statement = - boundStatement(sql.toString()) { - measurements.forEach { - bind("$$firstParam", it.measurementConsumerReferenceId) - bind("$${firstParam + 1}", it.measurementReferenceId) - bind("$${firstParam + 2}", reportId) - firstParam += numParameters - } - } - transactionContext.executeStatement(statement) - } - - private suspend fun TransactionScope.insertPeriodicTimeInterval( - measurementConsumerReferenceId: String, - reportId: Long, - periodicTimeInterval: PeriodicTimeInterval, - ) { - val statement = - boundStatement( - """ - INSERT INTO PeriodicTimeIntervals (MeasurementConsumerReferenceId, ReportId, StartSeconds, StartNanos, IncrementSeconds, IncrementNanos, IntervalCount) - VALUES ($1, $2, $3, $4, $5, $6, $7) - """ - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", reportId) - bind("$3", periodicTimeInterval.startTime.seconds) - bind("$4", periodicTimeInterval.startTime.nanos) - bind("$5", periodicTimeInterval.increment.seconds) - bind("$6", periodicTimeInterval.increment.nanos) - bind("$7", periodicTimeInterval.intervalCount) - } - - transactionContext.executeStatement(statement) - } - - private suspend fun TransactionScope.insertTimeIntervals( - measurementConsumerReferenceId: String, - reportId: Long, - timeIntervals: Collection, - ): Map { - val sql = - StringBuilder( - """ - INSERT INTO TimeIntervals (MeasurementConsumerReferenceId, ReportId, TimeIntervalId, StartSeconds, StartNanos, EndSeconds, EndNanos) - VALUES ($1, $2, $3, $4, $5, $6, $7) - """ - ) - val numParameters = 7 - var firstParam = 1 - for (i in 1 until timeIntervals.size) { - firstParam += numParameters - sql.append( - """,($$firstParam, $${firstParam + 1}, $${firstParam + 2}, $${firstParam + 3}, - $${firstParam + 4}, $${firstParam + 5}, $${firstParam + 6}) - """ - ) - } - - val timeIntervalMap = mutableMapOf() - firstParam = 1 - val statement = - boundStatement(sql.toString()) { - timeIntervals.forEach { - val timeIntervalId = idGenerator.generateInternalId().value - timeIntervalMap[it] = timeIntervalId - bind("$$firstParam", measurementConsumerReferenceId) - bind("$${firstParam + 1}", reportId) - bind("$${firstParam + 2}", timeIntervalId) - bind("$${firstParam + 3}", it.startTime.seconds) - bind("$${firstParam + 4}", it.startTime.nanos) - bind("$${firstParam + 5}", it.endTime.seconds) - bind("$${firstParam + 6}", it.endTime.nanos) - firstParam += numParameters - } - } - transactionContext.executeStatement(statement) - return timeIntervalMap - } - - private suspend fun TransactionScope.insertMetric( - measurementConsumerReferenceId: String, - reportId: Long, - timeIntervalMap: Map, - metric: Metric, - ) { - val metricId = idGenerator.generateInternalId().value - - val statement = - boundStatement( - """ - INSERT INTO Metrics (MeasurementConsumerReferenceId, ReportId, MetricId, MetricDetails) - VALUES ($1, $2, $3, $4) - """ - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", reportId) - bind("$3", metricId) - bind("$4", metric.details) - } - - transactionContext.executeStatement(statement) - - metric.namedSetOperationsList.forEach { - insertNamedSetOperation( - measurementConsumerReferenceId, - reportId, - metricId, - timeIntervalMap, - it, - ) - } - } - - private suspend fun TransactionScope.insertNamedSetOperation( - measurementConsumerReferenceId: String, - reportId: Long, - metricId: Long, - timeIntervalMap: Map, - namedSetOperation: NamedSetOperation, - ) { - val namedSetOperationId = idGenerator.generateInternalId().value - val setOperationId = - insertSetOperation( - measurementConsumerReferenceId, - reportId, - metricId, - namedSetOperation.setOperation, - ) - - val statement = - boundStatement( - """ - INSERT INTO NamedSetOperations(MeasurementConsumerReferenceId, ReportId, MetricId, NamedSetOperationId, DisplayName, SetOperationId) - VALUES ($1, $2, $3, $4, $5, $6) - """ - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", reportId) - bind("$3", metricId) - bind("$4", namedSetOperationId) - bind("$5", namedSetOperation.displayName) - bind("$6", setOperationId) - } - - transactionContext.executeStatement(statement) - - namedSetOperation.measurementCalculationsList.forEach { - insertMeasurementCalculations( - measurementConsumerReferenceId, - reportId, - metricId, - namedSetOperationId, - timeIntervalMap[it.timeInterval] - ?: throw MeasurementCalculationTimeIntervalNotFoundException(), - it, - ) - } - } - - private suspend fun TransactionScope.insertMeasurementCalculations( - measurementConsumerReferenceId: String, - reportId: Long, - metricId: Long, - namedSetOperationId: Long, - timeIntervalId: Long, - measurementCalculation: Metric.MeasurementCalculation, - ) { - val measurementCalculationId = idGenerator.generateInternalId().value - val statement = - boundStatement( - """ - INSERT INTO MeasurementCalculations(MeasurementConsumerReferenceId, ReportId, MetricId, NamedSetOperationId, MeasurementCalculationId, TimeIntervalId) - VALUES ($1, $2, $3, $4, $5, $6) - """ - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", reportId) - bind("$3", metricId) - bind("$4", namedSetOperationId) - bind("$5", measurementCalculationId) - bind("$6", timeIntervalId) - } - - transactionContext.executeStatement(statement) - insertWeightedMeasurements( - measurementConsumerReferenceId, - reportId, - metricId, - namedSetOperationId, - measurementCalculationId, - measurementCalculation.weightedMeasurementsList, - ) - } - - private suspend fun TransactionScope.insertWeightedMeasurements( - measurementConsumerReferenceId: String, - reportId: Long, - metricId: Long, - namedSetOperationId: Long, - measurementCalculationId: Long, - weightedMeasurements: Collection, - ) { - transactionContext.run { - weightedMeasurements.forEach { - val weightedMeasurementId = idGenerator.generateInternalId().value - val statement = - boundStatement( - """ - INSERT INTO WeightedMeasurements(MeasurementConsumerReferenceId, ReportId, MetricId, NamedSetOperationId, MeasurementCalculationId, WeightedMeasurementId, MeasurementReferenceId, Coefficient) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - """ - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", reportId) - bind("$3", metricId) - bind("$4", namedSetOperationId) - bind("$5", measurementCalculationId) - bind("$6", weightedMeasurementId) - bind("$7", it.measurementReferenceId) - bind("$8", it.coefficient) - } - executeStatement(statement) - } - } - } - - private suspend fun TransactionScope.insertSetOperation( - measurementConsumerReferenceId: String, - reportId: Long, - metricId: Long, - setOperation: SetOperation, - ): Long { - val setOperationId = idGenerator.generateInternalId().value - val lhsReportingSetId: Long? - val rhsReportingSetId: Long? - val lhsSetOperationId: Long? - val rhsSetOperationId: Long? - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") - when (setOperation.lhs.operandCase) { - SetOperation.Operand.OperandCase.REPORTINGSETID -> { - lhsSetOperationId = null - val reportingSetResult = - ReportingSetReader() - .readReportingSetByExternalId( - transactionContext, - measurementConsumerReferenceId, - ExternalId(setOperation.lhs.reportingSetId.externalReportingSetId), - ) - lhsReportingSetId = reportingSetResult.reportingSetId.value - } - SetOperation.Operand.OperandCase.OPERATION -> { - lhsSetOperationId = - insertSetOperation( - measurementConsumerReferenceId, - reportId, - metricId, - setOperation.lhs.operation, - ) - lhsReportingSetId = null - } - SetOperation.Operand.OperandCase.OPERAND_NOT_SET -> { - lhsSetOperationId = null - lhsReportingSetId = null - } - } - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") - when (setOperation.rhs.operandCase) { - SetOperation.Operand.OperandCase.REPORTINGSETID -> { - rhsSetOperationId = null - val reportingSetResult = - ReportingSetReader() - .readReportingSetByExternalId( - transactionContext, - measurementConsumerReferenceId, - ExternalId(setOperation.rhs.reportingSetId.externalReportingSetId), - ) - rhsReportingSetId = reportingSetResult.reportingSetId.value - } - SetOperation.Operand.OperandCase.OPERATION -> { - rhsSetOperationId = - insertSetOperation( - measurementConsumerReferenceId, - reportId, - metricId, - setOperation.rhs.operation, - ) - rhsReportingSetId = null - } - SetOperation.Operand.OperandCase.OPERAND_NOT_SET -> { - rhsSetOperationId = null - rhsReportingSetId = null - } - } - - val statement = - boundStatement( - """ - INSERT INTO SetOperations(MeasurementConsumerReferenceId, ReportId, MetricId, SetOperationId, Type, LeftHandSetOperationId, LeftHandReportingSetId, RightHandSetOperationId, RightHandReportingSetId) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) - """ - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", reportId) - bind("$3", metricId) - bind("$4", setOperationId) - bind("$5", setOperation.typeValue) - bind("$6", lhsSetOperationId) - bind("$7", lhsReportingSetId) - bind("$8", rhsSetOperationId) - bind("$9", rhsReportingSetId) - } - - transactionContext.executeStatement(statement) - - return setOperationId - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateReportingSet.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateReportingSet.kt deleted file mode 100644 index 12e734bc410..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateReportingSet.kt +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.writers - -import io.r2dbc.spi.R2dbcDataIntegrityViolationException -import kotlinx.coroutines.coroutineScope -import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter -import org.wfanet.measurement.common.identity.ExternalId -import org.wfanet.measurement.common.identity.InternalId -import org.wfanet.measurement.internal.reporting.ReportingSet -import org.wfanet.measurement.internal.reporting.copy -import org.wfanet.measurement.reporting.service.internal.ReportingSetAlreadyExistsException - -/** - * Inserts a Reporting Set into the database. - * - * Throws the following on [execute]: - * * [ReportingSetAlreadyExistsException] ReportingSet already exists - */ -class CreateReportingSet(private val request: ReportingSet) : PostgresWriter() { - override suspend fun TransactionScope.runTransaction(): ReportingSet { - val internalReportingSetId = idGenerator.generateInternalId() - val externalReportingSetId = idGenerator.generateExternalId() - - insertReportingSet(internalReportingSetId, externalReportingSetId) - coroutineScope { - for (i in 0 until request.eventGroupKeysList.size) { - insertReportingSetEventGroup(request.eventGroupKeysList[i], internalReportingSetId) - } - } - - return request.copy { this.externalReportingSetId = externalReportingSetId.value } - } - - private suspend fun TransactionScope.insertReportingSet( - internalReportingSetId: InternalId, - externalReportingSetId: ExternalId, - ) { - val statement = - boundStatement( - """ - INSERT INTO ReportingSets (MeasurementConsumerReferenceId, ReportingSetId, ExternalReportingSetId, Filter, DisplayName) - VALUES ($1, $2, $3, $4, $5) - """ - ) { - bind("$1", request.measurementConsumerReferenceId) - bind("$2", internalReportingSetId) - bind("$3", externalReportingSetId) - bind("$4", request.filter) - bind("$5", request.displayName) - } - - try { - transactionContext.executeStatement(statement) - } catch (e: R2dbcDataIntegrityViolationException) { - throw ReportingSetAlreadyExistsException() - } - } - - private suspend fun TransactionScope.insertReportingSetEventGroup( - eventGroupKey: ReportingSet.EventGroupKey, - reportingSetId: InternalId, - ) { - val statement = - boundStatement( - """ - INSERT INTO ReportingSetEventGroups (MeasurementConsumerReferenceId, DataProviderReferenceId, EventGroupReferenceId, ReportingSetId) - VALUES ($1, $2, $3, $4) - """ - ) { - bind("$1", eventGroupKey.measurementConsumerReferenceId) - bind("$2", eventGroupKey.dataProviderReferenceId) - bind("$3", eventGroupKey.eventGroupReferenceId) - bind("$4", reportingSetId) - } - - transactionContext.executeStatement(statement) - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/SetMeasurementFailure.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/SetMeasurementFailure.kt deleted file mode 100644 index 830f7d2a6e1..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/SetMeasurementFailure.kt +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.writers - -import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter -import org.wfanet.measurement.internal.reporting.Measurement -import org.wfanet.measurement.internal.reporting.Report -import org.wfanet.measurement.internal.reporting.SetMeasurementFailureRequest -import org.wfanet.measurement.internal.reporting.measurement -import org.wfanet.measurement.reporting.deploy.postgres.readers.MeasurementReader -import org.wfanet.measurement.reporting.service.internal.MeasurementNotFoundException -import org.wfanet.measurement.reporting.service.internal.MeasurementStateInvalidException - -/** - * Update a [Measurement] to be in a failure state along with any dependent [Report]. - * - * Throws the following on [execute]: - * * [MeasurementNotFoundException] Measurement not found. - * * [MeasurementStateInvalidException] Measurement does not have PENDING state. - */ -class SetMeasurementFailure(private val request: SetMeasurementFailureRequest) : - PostgresWriter() { - override suspend fun TransactionScope.runTransaction(): Measurement { - val measurementResult = - MeasurementReader() - .readMeasurementByReferenceIds( - transactionContext, - measurementConsumerReferenceId = request.measurementConsumerReferenceId, - measurementReferenceId = request.measurementReferenceId, - ) ?: throw MeasurementNotFoundException() - - if (measurementResult.measurement.state != Measurement.State.PENDING) { - throw MeasurementStateInvalidException() - } - - val updateMeasurementStatement = - boundStatement( - """ - UPDATE Measurements - SET State = $1, Failure = $2 - WHERE MeasurementConsumerReferenceId = $3 AND MeasurementReferenceId = $4 - """ - ) { - bind("$1", Measurement.State.FAILED_VALUE) - bind("$2", request.failure) - bind("$3", request.measurementConsumerReferenceId) - bind("$4", request.measurementReferenceId) - } - - val updateReportStatement = - boundStatement( - """ - UPDATE Reports - SET State = $1 - FROM ( - SELECT - ReportId - FROM ReportMeasurements - WHERE MeasurementConsumerReferenceId = $2 AND MeasurementReferenceId = $3 - ) AS ReportMeasurements - WHERE Reports.ReportId = ReportMeasurements.ReportId - """ - ) { - bind("$1", Report.State.FAILED_VALUE) - bind("$2", request.measurementConsumerReferenceId) - bind("$3", request.measurementReferenceId) - } - - transactionContext.run { - val numRowsUpdated = executeStatement(updateMeasurementStatement).numRowsUpdated - if (numRowsUpdated == 0L) { - return@run - } - executeStatement(updateReportStatement) - } - - return measurement { - measurementConsumerReferenceId = request.measurementConsumerReferenceId - measurementReferenceId = request.measurementReferenceId - state = Measurement.State.FAILED - failure = request.failure - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/SetMeasurementResult.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/SetMeasurementResult.kt deleted file mode 100644 index 67db2e99ae2..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/SetMeasurementResult.kt +++ /dev/null @@ -1,411 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.writers - -import com.google.protobuf.Duration -import com.google.protobuf.duration -import com.google.protobuf.util.Durations -import com.google.protobuf.util.Timestamps -import java.time.Instant -import java.util.concurrent.TimeUnit -import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter -import org.wfanet.measurement.internal.reporting.Measurement -import org.wfanet.measurement.internal.reporting.Metric -import org.wfanet.measurement.internal.reporting.Metric.MeasurementCalculation -import org.wfanet.measurement.internal.reporting.PeriodicTimeInterval -import org.wfanet.measurement.internal.reporting.Report -import org.wfanet.measurement.internal.reporting.ReportKt -import org.wfanet.measurement.internal.reporting.SetMeasurementResultRequest -import org.wfanet.measurement.internal.reporting.TimeInterval -import org.wfanet.measurement.internal.reporting.copy -import org.wfanet.measurement.internal.reporting.measurement -import org.wfanet.measurement.internal.reporting.timeInterval -import org.wfanet.measurement.reporting.deploy.postgres.readers.MeasurementReader -import org.wfanet.measurement.reporting.deploy.postgres.readers.MeasurementResultsReader -import org.wfanet.measurement.reporting.deploy.postgres.readers.ReportReader -import org.wfanet.measurement.reporting.service.internal.MeasurementNotFoundException -import org.wfanet.measurement.reporting.service.internal.MeasurementStateInvalidException - -private const val NANOS_PER_SECOND = 1_000_000_000 - -/** - * Updates the result for a Measurement and for the corresponding Report too if the report result - * has been computed. - * - * Throws the following on [execute]: - * * [MeasurementNotFoundException] Measurement not found. - * * [MeasurementStateInvalidException] Measurement does not have PENDING state. - */ -class SetMeasurementResult(private val request: SetMeasurementResultRequest) : - PostgresWriter() { - data class MeasurementResult(val result: Measurement.Result, val coefficient: Int) - - override suspend fun TransactionScope.runTransaction(): Measurement { - val measurementResult = - MeasurementReader() - .readMeasurementByReferenceIds( - transactionContext, - measurementConsumerReferenceId = request.measurementConsumerReferenceId, - measurementReferenceId = request.measurementReferenceId, - ) ?: throw MeasurementNotFoundException() - - if (measurementResult.measurement.state != Measurement.State.PENDING) { - throw MeasurementStateInvalidException() - } - - val updateMeasurementStatement = - boundStatement( - """ - UPDATE Measurements - SET State = $1, Result = $2 - WHERE MeasurementConsumerReferenceId = $3 AND MeasurementReferenceId = $4 - """ - ) { - bind("$1", Measurement.State.SUCCEEDED_VALUE) - bind("$2", request.result) - bind("$3", request.measurementConsumerReferenceId) - bind("$4", request.measurementReferenceId) - } - - transactionContext.run { - val numRowsUpdated = executeStatement(updateMeasurementStatement).numRowsUpdated - if (numRowsUpdated == 0L) { - return@run - } - - val measurementResultsMap = mutableMapOf() - val reportsSet = mutableSetOf() - val incompleteReportsSet = mutableSetOf() - MeasurementResultsReader() - .listMeasurementsForReportsByMeasurementReferenceId( - transactionContext, - request.measurementConsumerReferenceId, - request.measurementReferenceId, - ) - .collect { result -> - if (result.state == Measurement.State.SUCCEEDED) { - reportsSet.add(result.reportId.value) - measurementResultsMap[result.measurementReferenceId] = result - } else { - incompleteReportsSet.add(result.reportId.value) - } - } - - reportsSet.forEach { - if (incompleteReportsSet.contains(it)) { - return@forEach - } - - val reportResult = - ReportReader() - .getReportById(transactionContext, request.measurementConsumerReferenceId, it) - - val updatedDetails = - reportResult.report.details.copy { - this.result = constructResult(reportResult.report, measurementResultsMap) - } - - val updateReportStatement = - boundStatement( - """ - UPDATE Reports - SET ReportDetails = $1, State = $2 - WHERE MeasurementConsumerReferenceId = $3 - AND ReportId = $4 - """ - ) { - bind("$1", updatedDetails) - bind("$2", Report.State.SUCCEEDED.number) - bind("$3", request.measurementConsumerReferenceId) - bind("$4", it) - } - executeStatement(updateReportStatement) - } - } - - return measurement { - measurementConsumerReferenceId = request.measurementConsumerReferenceId - measurementReferenceId = request.measurementReferenceId - state = Measurement.State.SUCCEEDED - result = request.result - } - } - - private fun constructResult( - report: Report, - measurementResultsMap: Map, - ): Report.Details.Result { - return ReportKt.DetailsKt.result { - val rowHeaders = getRowHeaders(report) - val scalarTableColumnsList = ArrayList() - // Each metric contains results for several columns - for (metric in report.metricsList) { - when (val metricType = metric.details.metricTypeCase) { - // REACH, IMPRESSION_COUNT, and WATCH_DURATION are aggregated in one table. - Metric.Details.MetricTypeCase.REACH, - Metric.Details.MetricTypeCase.IMPRESSION_COUNT, - Metric.Details.MetricTypeCase.WATCH_DURATION -> { - // One namedSetOperation is one column in the report - for (namedSetOperation in metric.namedSetOperationsList) { - scalarTableColumnsList += - ReportKt.DetailsKt.ResultKt.column { - columnHeader = buildColumnHeader(metricType.name, namedSetOperation.displayName) - setOperations += - namedSetOperation.sortedMeasurementCalculations.map { - calculateScalarResult( - metricType, - it.getMeasurementResults(measurementResultsMap), - ) - } - } - } - } - Metric.Details.MetricTypeCase.FREQUENCY_HISTOGRAM -> { - histogramTables += metric.toHistogramTable(rowHeaders, measurementResultsMap) - } - Metric.Details.MetricTypeCase.METRICTYPE_NOT_SET -> error("Metric Type should be set.") - } - } - if (scalarTableColumnsList.size > 0) { - scalarTable = - ReportKt.DetailsKt.ResultKt.scalarTable { - this.rowHeaders += rowHeaders - columns += scalarTableColumnsList - } - } - } - } - - /** Generate row headers of [Report.Details.Result] from a [Report]. */ - private fun getRowHeaders(report: Report): List { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - return when (report.timeCase) { - Report.TimeCase.TIME_INTERVALS -> { - report.timeIntervals.timeIntervalsList - .sortedWith { a, b -> - val start = Timestamps.compare(a.startTime, b.startTime) - if (start != 0) { - start - } else { - Timestamps.compare(a.endTime, b.endTime) - } - } - .map(TimeInterval::toRowHeader) - } - Report.TimeCase.PERIODIC_TIME_INTERVAL -> { - - report.periodicTimeInterval.toTimeIntervals().map(TimeInterval::toRowHeader) - } - Report.TimeCase.TIME_NOT_SET -> { - error("Time should be set.") - } - } - } - - private fun PeriodicTimeInterval.toTimeIntervals(): List { - val source = this - var startTime = checkNotNull(source.startTime) - return (0 until source.intervalCount).map { - timeInterval { - this.startTime = startTime - this.endTime = Timestamps.add(startTime, source.increment) - startTime = this.endTime - } - } - } - - private val Metric.NamedSetOperation.sortedMeasurementCalculations: List - get() { - return measurementCalculationsList.sortedWith { a, b -> - val start = Timestamps.compare(a.timeInterval.startTime, b.timeInterval.startTime) - if (start != 0) { - start - } else { - Timestamps.compare(a.timeInterval.endTime, b.timeInterval.endTime) - } - } - } - - private fun MeasurementCalculation.getMeasurementResults( - measurementResultsMap: Map - ): List { - return weightedMeasurementsList.map { - MeasurementResult( - measurementResultsMap[it.measurementReferenceId]?.result - ?: throw MeasurementNotFoundException(), - it.coefficient, - ) - } - } - - /** Build a column header given the metric and set operation name. */ - private fun buildColumnHeader(metricTypeName: String, setOperationName: String): String { - return "${metricTypeName}_$setOperationName" - } - - /** Calculate the equation to get the scalar result. */ - private fun calculateScalarResult( - metricType: Metric.Details.MetricTypeCase, - measurementResultsList: List, - ): Double { - return when (metricType) { - Metric.Details.MetricTypeCase.REACH -> calculateReachResult(measurementResultsList) - Metric.Details.MetricTypeCase.IMPRESSION_COUNT -> - calculateImpressionResult(measurementResultsList) - Metric.Details.MetricTypeCase.WATCH_DURATION -> - calculateWatchDurationResult(measurementResultsList) - Metric.Details.MetricTypeCase.FREQUENCY_HISTOGRAM -> - error("$metricType is not a scalar metric type") - Metric.Details.MetricTypeCase.METRICTYPE_NOT_SET -> error("Metric Type should be set.") - } - } - - /** Calculate the reach result by summing up weighted [Measurement]s. */ - private fun calculateReachResult(measurementResultsList: List): Double { - return measurementResultsList - .sumOf { (result, coefficient) -> - if (!result.hasReach()) { - error("Reach measurement is missing.") - } - result.reach.value * coefficient - } - .toDouble() - } - - /** Calculate the impression result by summing up weighted [Measurement]s. */ - private fun calculateImpressionResult( - measurementCoefficientPairsList: List - ): Double { - return measurementCoefficientPairsList - .sumOf { (result, coefficient) -> - if (!result.hasImpression()) { - error("Impression measurement is missing.") - } - result.impression.value * coefficient - } - .toDouble() - } - - /** Calculate the watch duration result by summing up weighted [Measurement]s. */ - private fun calculateWatchDurationResult( - measurementCoefficientPairsList: List - ): Double { - val watchDuration = - measurementCoefficientPairsList - .map { (result, coefficient) -> - if (!result.hasWatchDuration()) { - error("Watch duration measurement is missing.") - } - result.watchDuration.value * coefficient - } - .reduce { sum, element -> sum + element } - - return watchDuration.seconds + (watchDuration.nanos.toDouble() / NANOS_PER_SECOND) - } - - private operator fun Duration.times(coefficient: Int): Duration { - val source = this - return duration { - val weightedTotalNanos: Long = - (TimeUnit.SECONDS.toNanos(source.seconds) + source.nanos) * coefficient - seconds = TimeUnit.NANOSECONDS.toSeconds(weightedTotalNanos) - nanos = (weightedTotalNanos % NANOS_PER_SECOND).toInt() - } - } - - private operator fun Duration.plus(other: Duration): Duration { - val source = this - return Durations.add(source, other) - } - - /** Calculate the frequency histogram result by summing up weighted [Measurement]s. */ - private fun calculateFrequencyHistogram( - measurementCoefficientPairsList: List - ): Map { - val aggregatedFrequencyHistogramMap = - measurementCoefficientPairsList - .map { (result, coefficient) -> - if (!result.hasFrequency() || !result.hasReach()) { - error("Reach-Frequency measurement is missing.") - } - val reach = result.reach.value - result.frequency.relativeFrequencyDistributionMap.mapValues { - it.value * coefficient * reach - } - } - .fold(mutableMapOf().withDefault { 0.0 }) { - aggregatedFrequencyHistogramMap: MutableMap, - weightedFrequencyHistogramMap -> - for ((frequency, count) in weightedFrequencyHistogramMap) { - aggregatedFrequencyHistogramMap[frequency] = - aggregatedFrequencyHistogramMap.getValue(frequency) + count - } - aggregatedFrequencyHistogramMap - } - - return aggregatedFrequencyHistogramMap - } - - /** Convert a [Metric] to a [Report.Details.Result.HistogramTable] of a [Report] */ - private fun Metric.toHistogramTable( - rowHeaders: List, - measurementResultsMap: Map, - ): Report.Details.Result.HistogramTable { - val setOperationFrequencyHistograms: List>> = - namedSetOperationsList.map { - it.sortedMeasurementCalculations.map { calculation -> - calculateFrequencyHistogram(calculation.getMeasurementResults(measurementResultsMap)) - } - } - val largestFrequency = - setOperationFrequencyHistograms.flatten().flatMap { it.keys }.maxOrNull() ?: 1L - - val source = this - return ReportKt.DetailsKt.ResultKt.histogramTable { - for (rowHeader in rowHeaders) { - for (frequency in 1..largestFrequency) { - rows += - ReportKt.DetailsKt.ResultKt.HistogramTableKt.row { - this.rowHeader = rowHeader - this.frequency = frequency.toInt() - } - } - } - for ((namedSetOperation, frequencyHistograms) in - source.namedSetOperationsList.zip(setOperationFrequencyHistograms)) { - columns += - ReportKt.DetailsKt.ResultKt.column { - columnHeader = - buildColumnHeader(source.details.metricTypeCase.name, namedSetOperation.displayName) - for (frequencyHistogram in frequencyHistograms) { - for (frequency in 1..largestFrequency) { - setOperations += frequencyHistogram.getOrDefault(frequency, 0.0) - } - } - } - } - } - } -} - -/** Convert a [TimeInterval] to a row header in String. */ -private fun TimeInterval.toRowHeader(): String { - val source = this - val startTimeInstant = - Instant.ofEpochSecond(source.startTime.seconds, source.startTime.nanos.toLong()) - val endTimeInstant = Instant.ofEpochSecond(source.endTime.seconds, source.endTime.nanos.toLong()) - return "$startTimeInstant-$endTimeInstant" -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/AkidPrincipalLookup.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/AkidPrincipalLookup.kt deleted file mode 100644 index afba11912f2..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/AkidPrincipalLookup.kt +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright 2022 The Cross-Media Measurement Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import com.google.protobuf.ByteString -import java.io.File -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey -import org.wfanet.measurement.common.api.AkidConfigPrincipalLookup -import org.wfanet.measurement.common.api.AkidConfigResourceNameLookup -import org.wfanet.measurement.common.api.PrincipalLookup -import org.wfanet.measurement.common.api.ResourceKey -import org.wfanet.measurement.common.api.toResourceKeyLookup -import org.wfanet.measurement.common.parseTextProto -import org.wfanet.measurement.config.AuthorityKeyToPrincipalMap -import org.wfanet.measurement.config.reporting.MeasurementConsumerConfig -import org.wfanet.measurement.config.reporting.MeasurementConsumerConfigs - -/** [PrincipalLookup] of [ReportingPrincipal] by authority key identifier (AKID). */ -class AkidPrincipalLookup( - akidConfig: AuthorityKeyToPrincipalMap, - measurementConsumerConfigs: MeasurementConsumerConfigs, -) : PrincipalLookup { - - /** - * Constructs [AkidConfigPrincipalLookup] from a file. - * - * @param akidConfig a [File] containing an [AuthorityKeyToPrincipalMap] message in text format - * @param measurementConsumerConfigs a [File] containing a [MeasurementConsumerConfigs] message in - * text format - */ - constructor( - akidConfig: File, - measurementConsumerConfigs: File, - ) : this( - parseTextProto(akidConfig, AuthorityKeyToPrincipalMap.getDefaultInstance()), - parseTextProto(measurementConsumerConfigs, MeasurementConsumerConfigs.getDefaultInstance()), - ) - - private val measurementConsumerConfigs: Map = - measurementConsumerConfigs.configsMap - - private val resourceKeyLookup = - AkidConfigResourceNameLookup(akidConfig).toResourceKeyLookup(MeasurementConsumerKey.FACTORY) - - override suspend fun getPrincipal(lookupKey: ByteString): ReportingPrincipal? { - val resourceKey: ResourceKey = resourceKeyLookup.getResourceKey(lookupKey) ?: return null - return when (resourceKey) { - is MeasurementConsumerKey -> { - val resourceName: String = resourceKey.toName() - val config = - measurementConsumerConfigs[resourceName] - ?: error("Missing MeasurementConsumerConfig for $resourceName") - MeasurementConsumerPrincipal(resourceKey, config) - } - else -> null - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel deleted file mode 100644 index 3cdded1d308..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel +++ /dev/null @@ -1,143 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -package(default_visibility = ["//visibility:public"]) - -kt_jvm_library( - name = "resource_key", - srcs = glob(["*Key.kt"]) + ["IdVariable.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/common/api:resource_key", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - ], -) - -kt_jvm_library( - name = "reporting_sets_service", - srcs = ["ReportingSetsService.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/api:public_api_version", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:principal_server_interceptor", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reporting_principal", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:resource_key", - "//src/main/proto/wfa/measurement/internal/reporting:reporting_sets_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:page_token_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reporting_sets_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", - ], -) - -kt_jvm_library( - name = "event_groups_service", - srcs = ["EventGroupsService.kt"], - deps = [ - ":resource_key", - "//imports/java/org/projectnessie/cel", - "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", - "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:cel_env_provider", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:encryption_key_pair_store", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:principal_server_interceptor", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reporting_principal", - "//src/main/proto/wfa/measurement/api/v2alpha:event_group_kt_jvm_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:event_groups_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:event_group_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:event_groups_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/tink", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/measurementconsumer", - ], -) - -kt_jvm_library( - name = "reports_service", - srcs = ["ReportsService.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", - "//src/main/kotlin/org/wfanet/measurement/api:public_api_version", - "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:encryption_key_pair_store", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:principal_server_interceptor", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reporting_principal", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:resource_key", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:set_operation_compiler", - "//src/main/proto/wfa/measurement/api/v2alpha:certificate_kt_jvm_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:certificates_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:data_providers_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:measurement_consumers_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:measurement_kt_jvm_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:measurements_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/config/reporting:measurement_spec_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/reporting:measurements_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reporting_sets_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reports_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:page_token_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reports_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//imports/java/com/google/protobuf/util", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto:security_provider", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", - "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/dataprovider", - "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/measurementconsumer", - ], -) - -kt_jvm_library( - name = "akid_principal_lookup", - srcs = ["AkidPrincipalLookup.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", - "//src/main/kotlin/org/wfanet/measurement/common/api:akid_config_lookup", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reporting_principal", - "//src/main/proto/wfa/measurement/config/reporting:measurement_consumer_config_kt_jvm_proto", - ], -) - -kt_jvm_library( - name = "set_operation_compiler", - srcs = ["SetOperationCompiler.kt"], - deps = [ - "//src/main/proto/wfa/measurement/reporting/v1alpha:report_kt_jvm_proto", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - ], -) - -kt_jvm_library( - name = "context_keys", - srcs = ["ContextKeys.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/common/api:principal", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reporting_principal", - "@wfa_common_jvm//imports/java/io/grpc:api", - "@wfa_common_jvm//imports/java/io/grpc:context", - ], -) - -kt_jvm_library( - name = "reporting_principal", - srcs = ["ReportingPrincipal.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", - "//src/main/kotlin/org/wfanet/measurement/common/api:principal", - "//src/main/proto/wfa/measurement/config/reporting:measurement_consumer_config_kt_jvm_proto", - ], -) - -kt_jvm_library( - name = "principal_server_interceptor", - srcs = ["PrincipalServerInterceptor.kt"], - deps = [ - "context_keys", - ":reporting_principal", - "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", - "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:akid_principal_server_interceptor", - "//src/main/kotlin/org/wfanet/measurement/common/identity", - "@wfa_common_jvm//imports/java/com/google/protobuf", - "@wfa_common_jvm//imports/java/io/grpc:api", - "@wfa_common_jvm//imports/java/io/grpc:context", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ContextKeys.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ContextKeys.kt deleted file mode 100644 index 23cb010cca8..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ContextKeys.kt +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import io.grpc.Context - -object ContextKeys { - /** This is the context key for the authenticated [ReportingPrincipal]. */ - val PRINCIPAL_CONTEXT_KEY: Context.Key = Context.key("principal") -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupKey.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupKey.kt deleted file mode 100644 index 98286678227..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupKey.kt +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import org.wfanet.measurement.common.ResourceNameParser -import org.wfanet.measurement.common.api.ResourceKey - -private val parser = - ResourceNameParser( - "measurementConsumers/{measurement_consumer}" + - "/dataProviders/{data_provider}/eventGroups/{event_group}" - ) - -/** [ResourceKey] of an EventGroup. */ -class EventGroupKey( - val measurementConsumerReferenceId: String, - val dataProviderReferenceId: String, - val eventGroupReferenceId: String, -) : ResourceKey { - override fun toName(): String { - return parser.assembleName( - mapOf( - IdVariable.MEASUREMENT_CONSUMER to measurementConsumerReferenceId, - IdVariable.DATA_PROVIDER to dataProviderReferenceId, - IdVariable.EVENT_GROUP to eventGroupReferenceId, - ) - ) - } - - companion object FACTORY : ResourceKey.Factory { - val defaultValue = EventGroupKey("", "", "") - - override fun fromName(resourceName: String): EventGroupKey? { - return parser.parseIdVars(resourceName)?.let { - EventGroupKey( - it.getValue(IdVariable.MEASUREMENT_CONSUMER), - it.getValue(IdVariable.DATA_PROVIDER), - it.getValue(IdVariable.EVENT_GROUP), - ) - } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupParentKey.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupParentKey.kt deleted file mode 100644 index 53526048649..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupParentKey.kt +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import org.wfanet.measurement.common.ResourceNameParser -import org.wfanet.measurement.common.api.ResourceKey - -private val parser = - ResourceNameParser("measurementConsumers/{measurement_consumer}/dataProviders/{data_provider}") - -/** [ResourceKey] of a EventGroupParent. */ -data class EventGroupParentKey( - val measurementConsumerId: String, - val dataProviderReferenceId: String, -) : ResourceKey { - override fun toName(): String { - return parser.assembleName( - mapOf( - IdVariable.MEASUREMENT_CONSUMER to measurementConsumerId, - IdVariable.DATA_PROVIDER to dataProviderReferenceId, - ) - ) - } - - companion object FACTORY : ResourceKey.Factory { - val defaultValue = EventGroupParentKey("", "") - - override fun fromName(resourceName: String): EventGroupParentKey? { - return parser.parseIdVars(resourceName)?.let { - EventGroupParentKey( - it.getValue(IdVariable.MEASUREMENT_CONSUMER), - it.getValue(IdVariable.DATA_PROVIDER), - ) - } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsService.kt deleted file mode 100644 index e180f567cb7..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsService.kt +++ /dev/null @@ -1,237 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import com.google.protobuf.DynamicMessage -import com.google.protobuf.kotlin.unpack -import io.grpc.Status -import io.grpc.StatusException -import java.security.GeneralSecurityException -import org.projectnessie.cel.common.types.Err -import org.projectnessie.cel.common.types.ref.Val -import org.wfanet.measurement.api.v2alpha.DataProviderKey as CmmsDataProviderKey -import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey as CmmsEncryptionPublicKey -import org.wfanet.measurement.api.v2alpha.EventGroup as CmmsEventGroup -import org.wfanet.measurement.api.v2alpha.EventGroupKey as CmmsEventGroupKey -import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub -import org.wfanet.measurement.api.v2alpha.ListEventGroupsRequestKt.filter -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey as CmmsMeasurementConsumerKey -import org.wfanet.measurement.api.v2alpha.listEventGroupsRequest as cmmsListEventGroupsRequest -import org.wfanet.measurement.api.withAuthenticationKey -import org.wfanet.measurement.common.api.ResourceKey -import org.wfanet.measurement.common.crypto.PrivateKeyHandle -import org.wfanet.measurement.common.grpc.failGrpc -import org.wfanet.measurement.consent.client.measurementconsumer.decryptMetadata -import org.wfanet.measurement.reporting.service.api.CelEnvProvider -import org.wfanet.measurement.reporting.service.api.EncryptionKeyPairStore -import org.wfanet.measurement.reporting.v1alpha.EventGroup -import org.wfanet.measurement.reporting.v1alpha.EventGroupKt.eventTemplate -import org.wfanet.measurement.reporting.v1alpha.EventGroupKt.metadata -import org.wfanet.measurement.reporting.v1alpha.EventGroupsGrpcKt.EventGroupsCoroutineImplBase -import org.wfanet.measurement.reporting.v1alpha.ListEventGroupsRequest -import org.wfanet.measurement.reporting.v1alpha.ListEventGroupsResponse -import org.wfanet.measurement.reporting.v1alpha.eventGroup -import org.wfanet.measurement.reporting.v1alpha.listEventGroupsResponse - -private const val METADATA_FIELD = "metadata.metadata" - -private const val MIN_PAGE_SIZE = 1 -private const val DEFAULT_PAGE_SIZE = 50 -private const val MAX_PAGE_SIZE = 1000 - -class EventGroupsService( - private val cmmsEventGroupsStub: EventGroupsCoroutineStub, - private val encryptionKeyPairStore: EncryptionKeyPairStore, - private val celEnvProvider: CelEnvProvider, -) : EventGroupsCoroutineImplBase() { - override suspend fun listEventGroups(request: ListEventGroupsRequest): ListEventGroupsResponse { - val principal: ReportingPrincipal = principalFromCurrentContext - - if (principal !is MeasurementConsumerPrincipal) { - failGrpc(Status.PERMISSION_DENIED) { - "Cannot list event groups with entities other than measurement consumer." - } - } - - val principalName = principal.resourceKey.toName() - val apiAuthenticationKey: String = principal.config.apiKey - val parentKey = - EventGroupParentKey.fromName(request.parent) - ?: failGrpc(Status.INVALID_ARGUMENT) { "parent malformed or unspecified" } - val pageSize = - when { - request.pageSize < MIN_PAGE_SIZE -> DEFAULT_PAGE_SIZE - request.pageSize > MAX_PAGE_SIZE -> MAX_PAGE_SIZE - else -> request.pageSize - } - - val cmmsListEventGroupResponse = - try { - cmmsEventGroupsStub - .withAuthenticationKey(apiAuthenticationKey) - .listEventGroups( - cmmsListEventGroupsRequest { - parent = CmmsMeasurementConsumerKey(parentKey.measurementConsumerId).toName() - this.pageSize = pageSize - pageToken = request.pageToken - filter = filter { - if (parentKey.dataProviderReferenceId != ResourceKey.WILDCARD_ID) { - dataProviders += CmmsDataProviderKey(parentKey.dataProviderReferenceId).toName() - } - } - } - ) - } catch (e: StatusException) { - throw when (e.status.code) { - Status.Code.DEADLINE_EXCEEDED -> Status.DEADLINE_EXCEEDED - Status.Code.CANCELLED -> Status.CANCELLED - else -> Status.UNKNOWN - } - .withCause(e) - .asRuntimeException() - } - val cmmsEventGroups = cmmsListEventGroupResponse.eventGroupsList - - val eventGroups = - cmmsEventGroups.map { - val cmmsMetadata: CmmsEventGroup.Metadata? = - if (it.hasEncryptedMetadata()) { - decryptMetadata(it, principalName) - } else { - null - } - - it.toEventGroup(cmmsMetadata) - } - - return listEventGroupsResponse { - this.eventGroups += filterEventGroups(eventGroups, request.filter) - nextPageToken = cmmsListEventGroupResponse.nextPageToken - } - } - - private suspend fun filterEventGroups( - eventGroups: List, - filter: String, - ): List { - if (filter.isEmpty()) { - return eventGroups - } - - val typeRegistryAndEnv = celEnvProvider.getTypeRegistryAndEnv() - val env = typeRegistryAndEnv.env - val typeRegistry = typeRegistryAndEnv.typeRegistry - - val astAndIssues = env.compile(filter) - if (astAndIssues.hasIssues()) { - throw Status.INVALID_ARGUMENT.withDescription( - "filter is not a valid CEL expression: ${astAndIssues.issues}" - ) - .asRuntimeException() - } - val program = env.program(astAndIssues.ast) - - eventGroups - .filter { it.hasMetadata() } - .distinctBy { it.metadata.metadata.typeUrl } - .forEach { - val typeUrl = it.metadata.metadata.typeUrl - typeRegistry.getDescriptorForTypeUrl(typeUrl) - ?: throw Status.FAILED_PRECONDITION.withDescription( - "${it.metadata.eventGroupMetadataDescriptor} does not contain descriptor for $typeUrl" - ) - .asRuntimeException() - } - - return eventGroups.filter { eventGroup -> - val variables: Map = - mutableMapOf().apply { - for (fieldDescriptor in eventGroup.descriptorForType.fields) { - put(fieldDescriptor.name, eventGroup.getField(fieldDescriptor)) - } - // TODO(projectnessie/cel-java#295): Remove when fixed. - if (eventGroup.hasMetadata()) { - val metadata: com.google.protobuf.Any = eventGroup.metadata.metadata - put( - METADATA_FIELD, - DynamicMessage.parseFrom( - typeRegistry.getDescriptorForTypeUrl(metadata.typeUrl), - metadata.value, - ), - ) - } - } - val result: Val = program.eval(variables).`val` - if (result is Err) { - // For when the field in the filter doesn't exist in the event group. - if (result.toString().contains("undeclared reference to")) { - return@filter false - } - throw result.toRuntimeException() - } - result.booleanValue() - } - } - - private suspend fun decryptMetadata( - cmmsEventGroup: CmmsEventGroup, - principalName: String, - ): CmmsEventGroup.Metadata { - if (!cmmsEventGroup.hasMeasurementConsumerPublicKey()) { - failGrpc(Status.FAILED_PRECONDITION) { - "EventGroup ${cmmsEventGroup.name} has encrypted metadata but no encryption public key" - } - } - val encryptionKey: CmmsEncryptionPublicKey = - cmmsEventGroup.measurementConsumerPublicKey.unpack() - val decryptionKeyHandle: PrivateKeyHandle = - encryptionKeyPairStore.getPrivateKeyHandle(principalName, encryptionKey.data) - ?: failGrpc(Status.FAILED_PRECONDITION) { - "Public key does not have corresponding private key" - } - return try { - decryptMetadata(cmmsEventGroup.encryptedMetadata, decryptionKeyHandle) - } catch (e: GeneralSecurityException) { - throw Status.FAILED_PRECONDITION.withCause(e) - .withDescription("Metadata cannot be decrypted") - .asRuntimeException() - } - } -} - -private fun CmmsEventGroup.toEventGroup(cmmsMetadata: CmmsEventGroup.Metadata?): EventGroup { - val source = this - val cmmsEventGroupKey = requireNotNull(CmmsEventGroupKey.fromName(name)) - val measurementConsumerKey = - requireNotNull(CmmsMeasurementConsumerKey.fromName(measurementConsumer)) - return eventGroup { - name = - EventGroupKey( - measurementConsumerKey.measurementConsumerId, - cmmsEventGroupKey.dataProviderId, - cmmsEventGroupKey.eventGroupId, - ) - .toName() - dataProvider = CmmsDataProviderKey(cmmsEventGroupKey.dataProviderId).toName() - eventGroupReferenceId = source.eventGroupReferenceId - eventTemplates += source.eventTemplatesList.map { eventTemplate { type = it.type } } - if (cmmsMetadata != null) { - metadata = metadata { - eventGroupMetadataDescriptor = cmmsMetadata.eventGroupMetadataDescriptor - metadata = cmmsMetadata.metadata - } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/IdVariable.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/IdVariable.kt deleted file mode 100644 index ed764300324..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/IdVariable.kt +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import java.util.Locale -import org.wfanet.measurement.common.ResourceNameParser - -internal enum class IdVariable { - DATA_PROVIDER, - EVENT_GROUP, - MEASUREMENT_CONSUMER, - REPORTING_SET, - REPORT, -} - -internal fun ResourceNameParser.assembleName(idMap: Map): String { - return assembleName(idMap.mapKeys { it.key.name.lowercase(Locale.getDefault()) }) -} - -internal fun ResourceNameParser.parseIdVars(resourceName: String): Map? { - return parseIdSegments(resourceName)?.mapKeys { - IdVariable.valueOf(it.key.uppercase(Locale.getDefault())) - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/PrincipalServerInterceptor.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/PrincipalServerInterceptor.kt deleted file mode 100644 index e52fbc49861..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/PrincipalServerInterceptor.kt +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import com.google.protobuf.ByteString -import io.grpc.BindableService -import io.grpc.Context -import io.grpc.ServerInterceptors -import io.grpc.ServerServiceDefinition -import io.grpc.Status -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey -import org.wfanet.measurement.common.api.PrincipalLookup -import org.wfanet.measurement.common.api.grpc.AkidPrincipalServerInterceptor -import org.wfanet.measurement.common.grpc.failGrpc -import org.wfanet.measurement.common.identity.AuthorityKeyServerInterceptor -import org.wfanet.measurement.config.reporting.MeasurementConsumerConfig - -/** - * Returns a [ReportingPrincipal] in the current gRPC context. Requires [PrincipalServerInterceptor] - * to be installed. - * - * Callers can trust that the [ReportingPrincipal] is authenticated (but not necessarily - * authorized). - */ -val principalFromCurrentContext: ReportingPrincipal - get() = - ContextKeys.PRINCIPAL_CONTEXT_KEY.get() - ?: failGrpc(Status.UNAUTHENTICATED) { "No ReportingPrincipal found" } - -/** - * Executes [block] with [principal] installed in a new [Context]. - * - * The caller of [withPrincipal] is responsible for guaranteeing that [block] can act as [principal] - * -- in other words, [principal] is treated as already authenticated. - */ -fun withPrincipal(principal: ReportingPrincipal, block: () -> T): T { - return Context.current().withPrincipal(principal).call(block) -} - -/** Executes [block] with a [MeasurementConsumerPrincipal] installed in a new [Context]. */ -fun withMeasurementConsumerPrincipal( - measurementConsumerName: String, - config: MeasurementConsumerConfig, - block: () -> T, -): T { - return Context.current() - .withPrincipal( - MeasurementConsumerPrincipal( - MeasurementConsumerKey.fromName(measurementConsumerName)!!, - config, - ) - ) - .call(block) -} - -/** Adds [principal] to the receiver and returns the new [Context]. */ -fun Context.withPrincipal(principal: ReportingPrincipal): Context { - return withValue(ContextKeys.PRINCIPAL_CONTEXT_KEY, principal) -} - -/** Convenience helper for [AkidPrincipalServerInterceptor]. */ -fun BindableService.withPrincipalsFromX509AuthorityKeyIdentifiers( - akidPrincipalLookup: PrincipalLookup -): ServerServiceDefinition { - return ServerInterceptors.interceptForward( - this, - AuthorityKeyServerInterceptor(), - AkidPrincipalServerInterceptor( - ContextKeys.PRINCIPAL_CONTEXT_KEY, - AuthorityKeyServerInterceptor.AUTHORITY_KEY_IDENTIFIERS_CONTEXT_KEY, - akidPrincipalLookup, - ), - ) -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportKey.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportKey.kt deleted file mode 100644 index 11037f6824a..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportKey.kt +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import org.wfanet.measurement.common.ResourceNameParser -import org.wfanet.measurement.common.api.ResourceKey - -private val parser = - ResourceNameParser("measurementConsumers/{measurement_consumer}/reports/{report}") - -/** [ResourceKey] of a Report. */ -data class ReportKey(val measurementConsumerId: String, val reportId: String) : ResourceKey { - override fun toName(): String { - return parser.assembleName( - mapOf(IdVariable.MEASUREMENT_CONSUMER to measurementConsumerId, IdVariable.REPORT to reportId) - ) - } - - companion object FACTORY : ResourceKey.Factory { - val defaultValue = ReportKey("", "") - - override fun fromName(resourceName: String): ReportKey? { - return parser.parseIdVars(resourceName)?.let { - ReportKey(it.getValue(IdVariable.MEASUREMENT_CONSUMER), it.getValue(IdVariable.REPORT)) - } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingPrincipal.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingPrincipal.kt deleted file mode 100644 index 76c43f1d704..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingPrincipal.kt +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerCertificateKey -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey -import org.wfanet.measurement.common.api.Principal -import org.wfanet.measurement.common.api.ResourcePrincipal -import org.wfanet.measurement.config.reporting.MeasurementConsumerConfig - -/** Identifies the sender of an inbound gRPC request. */ -sealed interface ReportingPrincipal : Principal { - val config: MeasurementConsumerConfig - - companion object { - fun fromConfigs(name: String, config: MeasurementConsumerConfig): ReportingPrincipal? { - return when (name.substringBefore('/')) { - MeasurementConsumerKey.COLLECTION_NAME -> { - require( - config.apiKey.isNotBlank() && - MeasurementConsumerCertificateKey.fromName(config.signingCertificateName) != null && - config.signingPrivateKeyPath.isNotBlank() - ) - MeasurementConsumerKey.fromName(name)?.let { MeasurementConsumerPrincipal(it, config) } - } - else -> null - } - } - } -} - -data class MeasurementConsumerPrincipal( - override val resourceKey: MeasurementConsumerKey, - override val config: MeasurementConsumerConfig, -) : ReportingPrincipal, ResourcePrincipal diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetKey.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetKey.kt deleted file mode 100644 index 7869652a1ff..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetKey.kt +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import org.wfanet.measurement.common.ResourceNameParser -import org.wfanet.measurement.common.api.ResourceKey - -private val parser = - ResourceNameParser("measurementConsumers/{measurement_consumer}/reportingSets/{reporting_set}") - -/** [ResourceKey] of a ReportingSet. */ -data class ReportingSetKey(val measurementConsumerId: String, val reportingSetId: String) : - ResourceKey { - override fun toName(): String { - return parser.assembleName( - mapOf( - IdVariable.MEASUREMENT_CONSUMER to measurementConsumerId, - IdVariable.REPORTING_SET to reportingSetId, - ) - ) - } - - companion object FACTORY : ResourceKey.Factory { - val defaultValue = ReportingSetKey("", "") - - override fun fromName(resourceName: String): ReportingSetKey? { - return parser.parseIdVars(resourceName)?.let { - ReportingSetKey( - it.getValue(IdVariable.MEASUREMENT_CONSUMER), - it.getValue(IdVariable.REPORTING_SET), - ) - } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetsService.kt deleted file mode 100644 index 12ecd237e27..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetsService.kt +++ /dev/null @@ -1,235 +0,0 @@ -/* - * Copyright 2022 The Cross-Media Measurement Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import io.grpc.Status -import kotlin.math.min -import kotlinx.coroutines.flow.toList -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey -import org.wfanet.measurement.common.base64UrlDecode -import org.wfanet.measurement.common.base64UrlEncode -import org.wfanet.measurement.common.grpc.failGrpc -import org.wfanet.measurement.common.grpc.grpcRequire -import org.wfanet.measurement.common.grpc.grpcRequireNotNull -import org.wfanet.measurement.common.identity.externalIdToApiId -import org.wfanet.measurement.internal.reporting.ReportingSet as InternalReportingSet -import org.wfanet.measurement.internal.reporting.ReportingSet.EventGroupKey as InternalEventGroupKey -import org.wfanet.measurement.internal.reporting.ReportingSetKt.eventGroupKey as internalEventGroupKey -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt.ReportingSetsCoroutineStub -import org.wfanet.measurement.internal.reporting.StreamReportingSetsRequest -import org.wfanet.measurement.internal.reporting.StreamReportingSetsRequestKt.filter -import org.wfanet.measurement.internal.reporting.reportingSet as internalReportingSet -import org.wfanet.measurement.internal.reporting.streamReportingSetsRequest -import org.wfanet.measurement.reporting.v1alpha.CreateReportingSetRequest -import org.wfanet.measurement.reporting.v1alpha.ListReportingSetsPageToken -import org.wfanet.measurement.reporting.v1alpha.ListReportingSetsPageTokenKt.previousPageEnd -import org.wfanet.measurement.reporting.v1alpha.ListReportingSetsRequest -import org.wfanet.measurement.reporting.v1alpha.ListReportingSetsResponse -import org.wfanet.measurement.reporting.v1alpha.ReportingSet -import org.wfanet.measurement.reporting.v1alpha.ReportingSetsGrpcKt.ReportingSetsCoroutineImplBase -import org.wfanet.measurement.reporting.v1alpha.copy -import org.wfanet.measurement.reporting.v1alpha.listReportingSetsPageToken -import org.wfanet.measurement.reporting.v1alpha.listReportingSetsResponse -import org.wfanet.measurement.reporting.v1alpha.reportingSet - -private const val MIN_PAGE_SIZE = 1 -private const val DEFAULT_PAGE_SIZE = 50 -private const val MAX_PAGE_SIZE = 1000 - -class ReportingSetsService(private val internalReportingSetsStub: ReportingSetsCoroutineStub) : - ReportingSetsCoroutineImplBase() { - override suspend fun createReportingSet(request: CreateReportingSetRequest): ReportingSet { - val parentKey = - grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { - "Parent is either unspecified or invalid." - } - - when (val principal: ReportingPrincipal = principalFromCurrentContext) { - is MeasurementConsumerPrincipal -> { - if (request.parent != principal.resourceKey.toName()) { - failGrpc(Status.PERMISSION_DENIED) { - "Cannot create a ReportingSet for another MeasurementConsumer." - } - } - } - } - - grpcRequire(request.hasReportingSet()) { "ReportingSet is not specified." } - - grpcRequire(request.reportingSet.eventGroupsList.isNotEmpty()) { - "EventGroups in ReportingSet cannot be empty." - } - - return internalReportingSetsStub - .createReportingSet(request.reportingSet.toInternal(parentKey)) - .toReportingSet() - } - - override suspend fun listReportingSets( - request: ListReportingSetsRequest - ): ListReportingSetsResponse { - val listReportingSetsPageToken = request.toListReportingSetsPageToken() - - // Based on AIP-132#Errors - when (val principal: ReportingPrincipal = principalFromCurrentContext) { - is MeasurementConsumerPrincipal -> { - if (request.parent != principal.resourceKey.toName()) { - failGrpc(Status.PERMISSION_DENIED) { - "Cannot list ReportingSets belonging to other MeasurementConsumers." - } - } - } - } - - val results: List = - internalReportingSetsStub - .streamReportingSets(listReportingSetsPageToken.toStreamReportingSetsRequest()) - .toList() - - if (results.isEmpty()) { - return ListReportingSetsResponse.getDefaultInstance() - } - - return listReportingSetsResponse { - reportingSets += - results - .subList(0, min(results.size, listReportingSetsPageToken.pageSize)) - .map(InternalReportingSet::toReportingSet) - - if (results.size > listReportingSetsPageToken.pageSize) { - val pageToken = - listReportingSetsPageToken.copy { - lastReportingSet = previousPageEnd { - measurementConsumerReferenceId = - results[results.lastIndex - 1].measurementConsumerReferenceId - externalReportingSetId = results[results.lastIndex - 1].externalReportingSetId - } - } - nextPageToken = pageToken.toByteString().base64UrlEncode() - } - } - } -} - -/** - * Converts an internal [ListReportingSetsPageToken] to an internal [StreamReportingSetsRequest]. - */ -private fun ListReportingSetsPageToken.toStreamReportingSetsRequest(): StreamReportingSetsRequest { - val source = this - return streamReportingSetsRequest { - // get 1 more than the actual page size for deciding whether or not to set page token - limit = pageSize + 1 - filter = filter { - measurementConsumerReferenceId = source.measurementConsumerReferenceId - externalReportingSetIdAfter = source.lastReportingSet.externalReportingSetId - } - } -} - -/** Converts a public [ListReportingSetsRequest] to an internal [ListReportingSetsPageToken]. */ -private fun ListReportingSetsRequest.toListReportingSetsPageToken(): ListReportingSetsPageToken { - grpcRequire(pageSize >= 0) { "Page size cannot be less than 0" } - - val source = this - val parentKey: MeasurementConsumerKey = - grpcRequireNotNull(MeasurementConsumerKey.fromName(parent)) { - "Parent is either unspecified or invalid." - } - val measurementConsumerReferenceId = parentKey.measurementConsumerId - - return if (pageToken.isNotBlank()) { - ListReportingSetsPageToken.parseFrom(pageToken.base64UrlDecode()).copy { - grpcRequire(this.measurementConsumerReferenceId == measurementConsumerReferenceId) { - "Arguments must be kept the same when using a page token" - } - - if ( - source.pageSize != 0 && source.pageSize >= MIN_PAGE_SIZE && source.pageSize <= MAX_PAGE_SIZE - ) { - pageSize = source.pageSize - } - } - } else { - listReportingSetsPageToken { - pageSize = - when { - source.pageSize < MIN_PAGE_SIZE -> DEFAULT_PAGE_SIZE - source.pageSize > MAX_PAGE_SIZE -> MAX_PAGE_SIZE - else -> source.pageSize - } - this.measurementConsumerReferenceId = measurementConsumerReferenceId - } - } -} - -/** Converts a public [ReportingSet] to an internal [InternalReportingSet]. */ -private fun ReportingSet.toInternal( - measurementConsumerKey: MeasurementConsumerKey -): InternalReportingSet { - val source = this - - return internalReportingSet { - measurementConsumerReferenceId = measurementConsumerKey.measurementConsumerId - - for (eventGroup: String in source.eventGroupsList) { - eventGroupKeys += - grpcRequireNotNull(EventGroupKey.fromName(eventGroup)) { - "EventGroup is either unspecified or invalid." - } - .toInternal(measurementConsumerKey.measurementConsumerId) - } - filter = source.filter - displayName = source.displayName - } -} - -/** Converts a public [EventGroupKey] to an internal [InternalEventGroupKey] */ -private fun EventGroupKey.toInternal( - measurementConsumerReferenceId: String -): InternalEventGroupKey { - val source = this - return internalEventGroupKey { - dataProviderReferenceId = source.dataProviderReferenceId - eventGroupReferenceId = source.eventGroupReferenceId - this.measurementConsumerReferenceId = measurementConsumerReferenceId - } -} - -/** Converts an internal [InternalReportingSet] to a public [ReportingSet]. */ -private fun InternalReportingSet.toReportingSet(): ReportingSet { - val source = this - return reportingSet { - name = - ReportingSetKey( - measurementConsumerId = source.measurementConsumerReferenceId, - reportingSetId = externalIdToApiId(source.externalReportingSetId), - ) - .toName() - eventGroups.addAll( - eventGroupKeysList.map { - EventGroupKey( - measurementConsumerReferenceId = it.measurementConsumerReferenceId, - dataProviderReferenceId = it.dataProviderReferenceId, - eventGroupReferenceId = it.eventGroupReferenceId, - ) - .toName() - } - ) - filter = source.filter - displayName = source.displayName - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsService.kt deleted file mode 100644 index 6308cdbbc9e..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsService.kt +++ /dev/null @@ -1,2260 +0,0 @@ -/* - * Copyright 2022 The Cross-Media Measurement Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import com.google.protobuf.Any as ProtoAny -import com.google.protobuf.ByteString -import com.google.protobuf.Duration as ProtoDuration -import com.google.protobuf.kotlin.unpack -import com.google.protobuf.util.Durations -import com.google.protobuf.util.Timestamps -import com.google.type.Interval -import com.google.type.interval -import io.grpc.Status -import io.grpc.StatusException -import java.io.File -import java.security.PrivateKey -import java.security.SignatureException -import java.security.cert.CertPathValidatorException -import java.security.cert.X509Certificate -import java.time.Instant -import kotlin.math.min -import kotlin.random.Random -import kotlinx.coroutines.Deferred -import kotlinx.coroutines.async -import kotlinx.coroutines.awaitAll -import kotlinx.coroutines.coroutineScope -import kotlinx.coroutines.flow.toList -import org.wfanet.measurement.api.v2alpha.Certificate -import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt.CertificatesCoroutineStub -import org.wfanet.measurement.api.v2alpha.CreateMeasurementRequest -import org.wfanet.measurement.api.v2alpha.DataProvider -import org.wfanet.measurement.api.v2alpha.DataProviderKey -import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCoroutineStub -import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey -import org.wfanet.measurement.api.v2alpha.EventGroupKey as CmmsEventGroupKey -import org.wfanet.measurement.api.v2alpha.Measurement -import org.wfanet.measurement.api.v2alpha.Measurement.DataProviderEntry -import org.wfanet.measurement.api.v2alpha.MeasurementConsumer -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey -import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub -import org.wfanet.measurement.api.v2alpha.MeasurementKey -import org.wfanet.measurement.api.v2alpha.MeasurementKt.DataProviderEntryKt -import org.wfanet.measurement.api.v2alpha.MeasurementKt.dataProviderEntry -import org.wfanet.measurement.api.v2alpha.MeasurementSpec -import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt -import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub -import org.wfanet.measurement.api.v2alpha.RequisitionSpec.EventGroupEntry -import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt -import org.wfanet.measurement.api.v2alpha.SignedMessage -import org.wfanet.measurement.api.v2alpha.createMeasurementRequest -import org.wfanet.measurement.api.v2alpha.differentialPrivacyParams -import org.wfanet.measurement.api.v2alpha.getCertificateRequest -import org.wfanet.measurement.api.v2alpha.getDataProviderRequest -import org.wfanet.measurement.api.v2alpha.getMeasurementConsumerRequest -import org.wfanet.measurement.api.v2alpha.getMeasurementRequest -import org.wfanet.measurement.api.v2alpha.measurement -import org.wfanet.measurement.api.v2alpha.measurementSpec -import org.wfanet.measurement.api.v2alpha.requisitionSpec -import org.wfanet.measurement.api.v2alpha.unpack -import org.wfanet.measurement.api.withAuthenticationKey -import org.wfanet.measurement.common.base64UrlDecode -import org.wfanet.measurement.common.base64UrlEncode -import org.wfanet.measurement.common.crypto.Hashing -import org.wfanet.measurement.common.crypto.PrivateKeyHandle -import org.wfanet.measurement.common.crypto.SigningKeyHandle -import org.wfanet.measurement.common.crypto.authorityKeyIdentifier -import org.wfanet.measurement.common.crypto.readCertificate -import org.wfanet.measurement.common.crypto.readPrivateKey -import org.wfanet.measurement.common.grpc.failGrpc -import org.wfanet.measurement.common.grpc.grpcRequire -import org.wfanet.measurement.common.grpc.grpcRequireNotNull -import org.wfanet.measurement.common.identity.apiIdToExternalId -import org.wfanet.measurement.common.identity.externalIdToApiId -import org.wfanet.measurement.common.readByteString -import org.wfanet.measurement.config.reporting.MeasurementSpecConfig -import org.wfanet.measurement.consent.client.measurementconsumer.decryptResult -import org.wfanet.measurement.consent.client.measurementconsumer.encryptRequisitionSpec -import org.wfanet.measurement.consent.client.measurementconsumer.signMeasurementSpec -import org.wfanet.measurement.consent.client.measurementconsumer.signRequisitionSpec -import org.wfanet.measurement.consent.client.measurementconsumer.verifyEncryptionPublicKey -import org.wfanet.measurement.consent.client.measurementconsumer.verifyResult -import org.wfanet.measurement.internal.reporting.CreateReportRequest as InternalCreateReportRequest -import org.wfanet.measurement.internal.reporting.CreateReportRequest.MeasurementKey as InternalMeasurementKey -import org.wfanet.measurement.internal.reporting.CreateReportRequestKt as InternalCreateReportRequestKt -import org.wfanet.measurement.internal.reporting.Measurement as InternalMeasurement -import org.wfanet.measurement.internal.reporting.Measurement.Result as InternalMeasurementResult -import org.wfanet.measurement.internal.reporting.MeasurementKt as InternalMeasurementKt -import org.wfanet.measurement.internal.reporting.MeasurementsGrpcKt.MeasurementsCoroutineStub as InternalMeasurementsCoroutineStub -import org.wfanet.measurement.internal.reporting.Metric as InternalMetric -import org.wfanet.measurement.internal.reporting.Metric.Details as InternalMetricDetails -import org.wfanet.measurement.internal.reporting.Metric.Details.MetricTypeCase as InternalMetricTypeCase -import org.wfanet.measurement.internal.reporting.Metric.FrequencyHistogramParams as InternalFrequencyHistogramParams -import org.wfanet.measurement.internal.reporting.Metric.ImpressionCountParams as InternalImpressionCountParams -import org.wfanet.measurement.internal.reporting.Metric.MeasurementCalculation -import org.wfanet.measurement.internal.reporting.Metric.NamedSetOperation as InternalNamedSetOperation -import org.wfanet.measurement.internal.reporting.Metric.SetOperation as InternalSetOperation -import org.wfanet.measurement.internal.reporting.Metric.SetOperation.Operand as InternalOperand -import org.wfanet.measurement.internal.reporting.Metric.WatchDurationParams as InternalWatchDurationParams -import org.wfanet.measurement.internal.reporting.MetricKt as InternalMetricKt -import org.wfanet.measurement.internal.reporting.MetricKt.MeasurementCalculationKt -import org.wfanet.measurement.internal.reporting.MetricKt.SetOperationKt as InternalSetOperationKt -import org.wfanet.measurement.internal.reporting.PeriodicTimeInterval as InternalPeriodicTimeInterval -import org.wfanet.measurement.internal.reporting.Report as InternalReport -import org.wfanet.measurement.internal.reporting.ReportKt as InternalReportKt -import org.wfanet.measurement.internal.reporting.ReportingSet as InternalReportingSet -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt.ReportingSetsCoroutineStub as InternalReportingSetsCoroutineStub -import org.wfanet.measurement.internal.reporting.ReportsGrpcKt.ReportsCoroutineStub as InternalReportsCoroutineStub -import org.wfanet.measurement.internal.reporting.SetMeasurementResultRequest as SetInternalMeasurementResultRequest -import org.wfanet.measurement.internal.reporting.StreamReportsRequest as StreamInternalReportsRequest -import org.wfanet.measurement.internal.reporting.StreamReportsRequestKt.filter -import org.wfanet.measurement.internal.reporting.TimeInterval as InternalTimeInterval -import org.wfanet.measurement.internal.reporting.TimeIntervals as InternalTimeIntervals -import org.wfanet.measurement.internal.reporting.batchCreateMeasurementsRequest -import org.wfanet.measurement.internal.reporting.batchGetReportingSetRequest -import org.wfanet.measurement.internal.reporting.createReportRequest as internalCreateReportRequest -import org.wfanet.measurement.internal.reporting.getReportByIdempotencyKeyRequest -import org.wfanet.measurement.internal.reporting.getReportRequest as getInternalReportRequest -import org.wfanet.measurement.internal.reporting.measurement as internalMeasurement -import org.wfanet.measurement.internal.reporting.metric as internalMetric -import org.wfanet.measurement.internal.reporting.periodicTimeInterval as internalPeriodicTimeInterval -import org.wfanet.measurement.internal.reporting.report as internalReport -import org.wfanet.measurement.internal.reporting.setMeasurementFailureRequest as setInternalMeasurementFailureRequest -import org.wfanet.measurement.internal.reporting.setMeasurementResultRequest as setInternalMeasurementResultRequest -import org.wfanet.measurement.internal.reporting.streamReportsRequest as streamInternalReportsRequest -import org.wfanet.measurement.internal.reporting.timeInterval as internalTimeInterval -import org.wfanet.measurement.internal.reporting.timeIntervals as internalTimeIntervals -import org.wfanet.measurement.reporting.service.api.EncryptionKeyPairStore -import org.wfanet.measurement.reporting.v1alpha.CreateReportRequest -import org.wfanet.measurement.reporting.v1alpha.GetReportRequest -import org.wfanet.measurement.reporting.v1alpha.ListReportsPageToken -import org.wfanet.measurement.reporting.v1alpha.ListReportsPageTokenKt.previousPageEnd -import org.wfanet.measurement.reporting.v1alpha.ListReportsRequest -import org.wfanet.measurement.reporting.v1alpha.ListReportsResponse -import org.wfanet.measurement.reporting.v1alpha.Metric -import org.wfanet.measurement.reporting.v1alpha.Metric.FrequencyHistogramParams -import org.wfanet.measurement.reporting.v1alpha.Metric.ImpressionCountParams -import org.wfanet.measurement.reporting.v1alpha.Metric.MetricTypeCase -import org.wfanet.measurement.reporting.v1alpha.Metric.NamedSetOperation -import org.wfanet.measurement.reporting.v1alpha.Metric.SetOperation -import org.wfanet.measurement.reporting.v1alpha.Metric.SetOperation.Operand -import org.wfanet.measurement.reporting.v1alpha.Metric.WatchDurationParams -import org.wfanet.measurement.reporting.v1alpha.MetricKt.SetOperationKt.operand -import org.wfanet.measurement.reporting.v1alpha.MetricKt.frequencyHistogramParams -import org.wfanet.measurement.reporting.v1alpha.MetricKt.impressionCountParams -import org.wfanet.measurement.reporting.v1alpha.MetricKt.namedSetOperation -import org.wfanet.measurement.reporting.v1alpha.MetricKt.reachParams -import org.wfanet.measurement.reporting.v1alpha.MetricKt.setOperation -import org.wfanet.measurement.reporting.v1alpha.MetricKt.watchDurationParams -import org.wfanet.measurement.reporting.v1alpha.PeriodicTimeInterval -import org.wfanet.measurement.reporting.v1alpha.Report -import org.wfanet.measurement.reporting.v1alpha.Report.Result -import org.wfanet.measurement.reporting.v1alpha.ReportKt.EventGroupUniverseKt -import org.wfanet.measurement.reporting.v1alpha.ReportKt.ResultKt.HistogramTableKt.row -import org.wfanet.measurement.reporting.v1alpha.ReportKt.ResultKt.column -import org.wfanet.measurement.reporting.v1alpha.ReportKt.ResultKt.histogramTable -import org.wfanet.measurement.reporting.v1alpha.ReportKt.ResultKt.scalarTable -import org.wfanet.measurement.reporting.v1alpha.ReportKt.eventGroupUniverse -import org.wfanet.measurement.reporting.v1alpha.ReportKt.result -import org.wfanet.measurement.reporting.v1alpha.ReportsGrpcKt.ReportsCoroutineImplBase -import org.wfanet.measurement.reporting.v1alpha.TimeInterval -import org.wfanet.measurement.reporting.v1alpha.TimeIntervals -import org.wfanet.measurement.reporting.v1alpha.copy -import org.wfanet.measurement.reporting.v1alpha.listReportsPageToken -import org.wfanet.measurement.reporting.v1alpha.listReportsResponse -import org.wfanet.measurement.reporting.v1alpha.metric -import org.wfanet.measurement.reporting.v1alpha.periodicTimeInterval -import org.wfanet.measurement.reporting.v1alpha.report -import org.wfanet.measurement.reporting.v1alpha.timeInterval -import org.wfanet.measurement.reporting.v1alpha.timeIntervals - -private const val MIN_PAGE_SIZE = 1 -private const val DEFAULT_PAGE_SIZE = 50 -private const val MAX_PAGE_SIZE = 1000 - -private val timeIntervalComparator: (TimeInterval, TimeInterval) -> Int = { a, b -> - val start = Timestamps.compare(a.startTime, b.startTime) - if (start != 0) { - start - } else { - Timestamps.compare(a.endTime, b.endTime) - } -} - -class ReportsService( - private val internalReportsStub: InternalReportsCoroutineStub, - private val internalReportingSetsStub: InternalReportingSetsCoroutineStub, - private val internalMeasurementsStub: InternalMeasurementsCoroutineStub, - private val dataProvidersStub: DataProvidersCoroutineStub, - private val measurementConsumersStub: MeasurementConsumersCoroutineStub, - private val measurementsStub: MeasurementsCoroutineStub, - private val certificateStub: CertificatesCoroutineStub, - private val encryptionKeyPairStore: EncryptionKeyPairStore, - private val secureRandom: Random, - private val signingPrivateKeyDir: File, - private val trustedCertificates: Map, - measurementSpecConfig: MeasurementSpecConfig, -) : ReportsCoroutineImplBase() { - private val setOperationCompiler = SetOperationCompiler() - private val measurementSpecComponentFactory = - MeasurementSpecComponentFactory(measurementSpecConfig, secureRandom) - - private data class ReportInfo( - val measurementConsumerReferenceId: String, - val reportIdempotencyKey: String, - val eventGroupFilters: Map, - ) - - private data class SigningConfig( - val signingCertificateName: String, - val signingCertificateDer: ByteString, - val signingPrivateKey: PrivateKey, - ) - - private data class WeightedMeasurementInfo( - val reportingMeasurementId: String, - val weightedMeasurement: WeightedMeasurement, - val timeInterval: TimeInterval, - val reportTimeInterval: TimeInterval, - var kingdomMeasurementId: String? = null, - ) - - private data class SetOperationResult( - val weightedMeasurementInfoList: List, - val internalMetricDetails: InternalMetricDetails, - ) - - private data class DataProviderInfo( - val dataProviderName: String, - val publicKey: SignedMessage, - val certificateName: String, - ) - - override suspend fun createReport(request: CreateReportRequest): Report { - grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { - "Parent is either unspecified or invalid." - } - val principal: ReportingPrincipal = principalFromCurrentContext - - when (principal) { - is MeasurementConsumerPrincipal -> { - if (request.parent != principal.resourceKey.toName()) { - failGrpc(Status.PERMISSION_DENIED) { - "Cannot create a Report for another MeasurementConsumer." - } - } - } - } - - val resourceKey = principal.resourceKey - val apiAuthenticationKey: String = principal.config.apiKey - - grpcRequire(request.hasReport()) { "Report is not specified." } - - // TODO(@riemanli) Put the check here as the reportIdempotencyKey will be moved to the request - // level in the future. - grpcRequire(request.report.reportIdempotencyKey.isNotBlank()) { - "ReportIdempotencyKey is not specified." - } - grpcRequire(request.report.measurementConsumer == request.parent) { - "Cannot create a Report for another MeasurementConsumer." - } - - grpcRequire(request.report.metricsList.isNotEmpty()) { "Metrics in Report cannot be empty." } - request.report.metricsList.forEach { - grpcRequire(it.setOperationsList.isNotEmpty()) { "Metric setOperationsList cannot be empty." } - it.setOperationsList.forEach { namedSetOperation -> - grpcRequire(namedSetOperation.uniqueName.isNotBlank()) { - "NamedSetOperation uniqueName is unspecified." - } - grpcRequire( - !namedSetOperation.setOperation.lhs.operandCase.equals( - Operand.OperandCase.OPERAND_NOT_SET - ) - ) { - "NamedSetOperation SetOperation Operand is unspecified." - } - grpcRequire( - !namedSetOperation.setOperation.type.equals(SetOperation.Type.TYPE_UNSPECIFIED) - ) { - "NamedSetOperation SetOperation Type is unspecified." - } - } - } - checkSetOperationNamesUniqueness(request.report.metricsList) - - val existingInternalReport: InternalReport? = - getInternalReport(resourceKey.measurementConsumerId, request.report.reportIdempotencyKey) - - if (existingInternalReport != null) return existingInternalReport.toReport() - - val reportInfo: ReportInfo = buildReportInfo(request, resourceKey.measurementConsumerId) - - val metrics = - Metrics(reportInfo, internalReportingSetsStub, setOperationCompiler, request.report) - metrics.process() - val namedSetOperationResults: Map = - metrics.getNamedSetOperationResults() - val internalReportingSetMap: Map = - metrics.getInternalReportingSetsMap() - - val measurementConsumer = - try { - measurementConsumersStub - .withAuthenticationKey(apiAuthenticationKey) - .getMeasurementConsumer( - getMeasurementConsumerRequest { - name = MeasurementConsumerKey(resourceKey.measurementConsumerId).toName() - } - ) - } catch (e: StatusException) { - throw Exception( - "Unable to retrieve the measurement consumer " + - "[${MeasurementConsumerKey(resourceKey.measurementConsumerId).toName()}].", - e, - ) - } - - val dataProviderNames = mutableSetOf() - for (internalReportingSet in internalReportingSetMap.values) { - for (eventGroupKey in internalReportingSet.eventGroupKeysList) { - dataProviderNames.add(DataProviderKey(eventGroupKey.dataProviderReferenceId).toName()) - } - } - val dataProviderInfoMap: Map = - buildDataProviderInfoMap(apiAuthenticationKey, dataProviderNames) - - // TODO: Factor this out to a separate class similar to EncryptionKeyPairStore. - val signingPrivateKeyDer: ByteString = - signingPrivateKeyDir.resolve(principal.config.signingPrivateKeyPath).readByteString() - - val signingCertificateDer: ByteString = - getSigningCertificateDer(apiAuthenticationKey, principal.config.signingCertificateName) - - val signingConfig = - SigningConfig( - principal.config.signingCertificateName, - signingCertificateDer, - readPrivateKey( - signingPrivateKeyDer, - readCertificate(signingCertificateDer).publicKey.algorithm, - ), - ) - - createMeasurements( - request, - namedSetOperationResults, - reportInfo, - measurementConsumer, - apiAuthenticationKey, - signingConfig, - internalReportingSetMap, - dataProviderInfoMap, - ) - - val internalCreateReportRequest: InternalCreateReportRequest = - buildInternalCreateReportRequest(request, reportInfo, namedSetOperationResults) - try { - return internalReportsStub.createReport(internalCreateReportRequest).toReport() - } catch (e: StatusException) { - throw Exception("Unable to create a report in the reporting database.", e) - } - } - - /** Gets a signing certificate x509Der in ByteString. */ - private suspend fun getSigningCertificateDer( - apiAuthenticationKey: String, - signingCertificateName: String, - ): ByteString { - // TODO: Replace this with caching certificates or having them stored alongside the private key. - return try { - certificateStub - .withAuthenticationKey(apiAuthenticationKey) - .getCertificate(getCertificateRequest { name = signingCertificateName }) - .x509Der - } catch (e: StatusException) { - throw Exception( - "Unable to retrieve the signing certificate for the measurement consumer " + - "[$signingCertificateName].", - e, - ) - } - } - - /** Builds a [ReportInfo] from a [CreateReportRequest]. */ - private fun buildReportInfo( - request: CreateReportRequest, - measurementConsumerReferenceId: String, - ): ReportInfo { - grpcRequire(request.report.hasEventGroupUniverse()) { "EventGroupUniverse is not specified." } - grpcRequire(request.report.eventGroupUniverse.eventGroupEntriesList.isNotEmpty()) { - "EventGroupUniverse's eventGroupEntriesList cannot be empty." - } - - val eventGroupFilters = - request.report.eventGroupUniverse.eventGroupEntriesList.associate { - grpcRequireNotNull(EventGroupKey.fromName(it.key)) { - "EventGroupEntry key is not specified or invalid." - } - it.key to it.value - } - - return ReportInfo( - measurementConsumerReferenceId, - request.report.reportIdempotencyKey, - eventGroupFilters, - ) - } - - /** Creates CMM public [Measurement]s and [InternalMeasurement]s from [SetOperationResult]s. */ - private suspend fun createMeasurements( - request: CreateReportRequest, - namedSetOperationResults: Map, - reportInfo: ReportInfo, - measurementConsumer: MeasurementConsumer, - apiAuthenticationKey: String, - signingConfig: SigningConfig, - internalReportingSetMap: Map, - dataProviderInfoMap: Map, - ) = coroutineScope { - val deferredMeasurements = mutableListOf>() - for (metric in request.report.metricsList) { - val internalMetricDetails = buildInternalMetricDetails(metric) - - for (namedSetOperation in metric.setOperationsList) { - val setOperationId = - buildSetOperationId( - reportInfo.reportIdempotencyKey, - internalMetricDetails, - namedSetOperation.uniqueName, - ) - - val setOperationResult: SetOperationResult = - namedSetOperationResults[setOperationId] ?: continue - - for (weightedMeasurementInfo in setOperationResult.weightedMeasurementInfoList) { - deferredMeasurements.add( - async { - createMeasurement( - weightedMeasurementInfo, - reportInfo, - setOperationResult.internalMetricDetails, - measurementConsumer, - apiAuthenticationKey, - signingConfig, - internalReportingSetMap, - dataProviderInfoMap, - ) - .also { - weightedMeasurementInfo.kingdomMeasurementId = - checkNotNull(MeasurementKey.fromName(it.name)).measurementId - } - } - ) - } - } - } - - val internalMeasurements = mutableListOf() - for (measurement in deferredMeasurements.awaitAll()) { - val measurementKey = checkNotNull(MeasurementKey.fromName(measurement.name)) - internalMeasurements.add( - internalMeasurement { - this.measurementConsumerReferenceId = measurementKey.measurementConsumerId - this.measurementReferenceId = measurementKey.measurementId - state = InternalMeasurement.State.PENDING - } - ) - } - - try { - internalMeasurementsStub.batchCreateMeasurements( - batchCreateMeasurementsRequest { measurements += internalMeasurements } - ) - } catch (e: StatusException) { - throw Status.UNKNOWN.withDescription( - "Unable to create measurement in the reporting database." - ) - .withCause(e) - .asRuntimeException() - } - } - - /** Creates a kingdom measurement for a [WeightedMeasurement]. */ - private suspend fun createMeasurement( - weightedMeasurementInfo: WeightedMeasurementInfo, - reportInfo: ReportInfo, - internalMetricDetails: InternalMetricDetails, - measurementConsumer: MeasurementConsumer, - apiAuthenticationKey: String, - signingConfig: SigningConfig, - internalReportingSetMap: Map, - dataProviderInfoMap: Map, - ): Measurement { - val eventGroupEntriesByDataProvider = - groupEventGroupEntriesByDataProvider( - weightedMeasurementInfo.weightedMeasurement.reportingSets, - weightedMeasurementInfo.timeInterval.toMeasurementTimeInterval(), - reportInfo.eventGroupFilters, - internalReportingSetMap, - ) - - val createMeasurementRequest: CreateMeasurementRequest = - buildCreateMeasurementRequest( - measurementConsumer, - eventGroupEntriesByDataProvider, - internalMetricDetails, - weightedMeasurementInfo.reportingMeasurementId, - signingConfig, - dataProviderInfoMap, - ) - - try { - return measurementsStub - .withAuthenticationKey(apiAuthenticationKey) - .createMeasurement(createMeasurementRequest) - } catch (e: StatusException) { - throw Exception( - "Unable to create Measurement with request ID ${createMeasurementRequest.requestId}", - e, - ) - } - } - - /** Gets an [InternalReport]. */ - private suspend fun getInternalReport( - measurementConsumerReferenceId: String, - reportIdempotencyKey: String, - ): InternalReport? { - return try { - internalReportsStub.getReportByIdempotencyKey( - getReportByIdempotencyKeyRequest { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - this.reportIdempotencyKey = reportIdempotencyKey - } - ) - } catch (e: StatusException) { - if (e.status.code != Status.Code.NOT_FOUND) { - throw Exception( - "Unable to retrieve a report from the reporting database using the provided " + - "reportIdempotencyKey [$reportIdempotencyKey].", - e, - ) - } - null - } - } - - override suspend fun listReports(request: ListReportsRequest): ListReportsResponse { - val listReportsPageToken = request.toListReportsPageToken() - - // Based on AIP-132#Errors - val principal: ReportingPrincipal = principalFromCurrentContext - when (principal) { - is MeasurementConsumerPrincipal -> { - if (request.parent != principal.resourceKey.toName()) { - failGrpc(Status.PERMISSION_DENIED) { - "Cannot list Reports belonging to other MeasurementConsumers." - } - } - } - } - val principalName = principal.resourceKey.toName() - - val apiAuthenticationKey: String = principal.config.apiKey - - val streamInternalReportsRequest: StreamInternalReportsRequest = - listReportsPageToken.toStreamReportsRequest() - val results: List = - try { - internalReportsStub.streamReports(streamInternalReportsRequest).toList() - } catch (e: StatusException) { - throw Exception("Unable to list reports from the reporting database.", e) - } - - if (results.isEmpty()) { - return ListReportsResponse.getDefaultInstance() - } - - val nextPageToken: ListReportsPageToken? = - if (results.size > listReportsPageToken.pageSize) { - listReportsPageToken.copy { - lastReport = previousPageEnd { - measurementConsumerReferenceId = - results[results.lastIndex - 1].measurementConsumerReferenceId - externalReportId = results[results.lastIndex - 1].externalReportId - } - } - } else null - - return listReportsResponse { - reports += - results - .subList(0, min(results.size, listReportsPageToken.pageSize)) - .map { syncReport(it, apiAuthenticationKey, principalName) } - .map(InternalReport::toReport) - - if (nextPageToken != null) { - this.nextPageToken = nextPageToken.toByteString().base64UrlEncode() - } - } - } - - override suspend fun getReport(request: GetReportRequest): Report { - val reportKey = - grpcRequireNotNull(ReportKey.fromName(request.name)) { - "Report name is either unspecified or invalid" - } - - val principal: ReportingPrincipal = principalFromCurrentContext - when (principal) { - is MeasurementConsumerPrincipal -> { - if (reportKey.measurementConsumerId != principal.resourceKey.measurementConsumerId) { - failGrpc(Status.PERMISSION_DENIED) { - "Cannot get Report belonging to other MeasurementConsumers." - } - } - } - } - val principalName = principal.resourceKey.toName() - - val apiAuthenticationKey: String = principal.config.apiKey - - val internalReport = - try { - internalReportsStub.getReport( - getInternalReportRequest { - measurementConsumerReferenceId = reportKey.measurementConsumerId - externalReportId = apiIdToExternalId(reportKey.reportId) - } - ) - } catch (e: StatusException) { - throw Exception("Unable to get the report from the reporting database.", e) - } - - val syncedInternalReport = syncReport(internalReport, apiAuthenticationKey, principalName) - - return syncedInternalReport.toReport() - } - - /** Syncs the [InternalReport] and all [InternalMeasurement]s used by it. */ - private suspend fun syncReport( - internalReport: InternalReport, - apiAuthenticationKey: String, - principalName: String, - ): InternalReport { - // Report with SUCCEEDED or FAILED state is already synced. - if ( - internalReport.state == InternalReport.State.SUCCEEDED || - internalReport.state == InternalReport.State.FAILED - ) { - return internalReport - } else if ( - internalReport.state == InternalReport.State.STATE_UNSPECIFIED || - internalReport.state == InternalReport.State.UNRECOGNIZED - ) { - error( - "The measurements cannot be synced because the report state was not set correctly as it " + - "should've been." - ) - } - - // Syncs measurements - syncMeasurements( - internalReport.measurementsMap, - internalReport.measurementConsumerReferenceId, - apiAuthenticationKey, - principalName, - ) - - return try { - internalReportsStub.getReport( - getInternalReportRequest { - measurementConsumerReferenceId = internalReport.measurementConsumerReferenceId - externalReportId = internalReport.externalReportId - } - ) - } catch (e: StatusException) { - val reportName = - ReportKey( - internalReport.measurementConsumerReferenceId, - externalIdToApiId(internalReport.externalReportId), - ) - .toName() - throw Exception("Unable to get the report [$reportName] from the reporting database.", e) - } - } - - /** Syncs [InternalMeasurement]s. */ - private suspend fun syncMeasurements( - measurementsMap: Map, - measurementConsumerReferenceId: String, - apiAuthenticationKey: String, - principalName: String, - ) { - for ((measurementReferenceId, internalMeasurement) in measurementsMap) { - // Measurement with SUCCEEDED state is already synced - if (internalMeasurement.state == InternalMeasurement.State.SUCCEEDED) continue - - syncMeasurement( - measurementReferenceId, - measurementConsumerReferenceId, - apiAuthenticationKey, - principalName, - ) - } - } - - /** Syncs [InternalMeasurement] with the CMM [Measurement] given the measurement reference ID. */ - private suspend fun syncMeasurement( - measurementReferenceId: String, - measurementConsumerReferenceId: String, - apiAuthenticationKey: String, - principalName: String, - ) { - val measurementResourceName = - MeasurementKey(measurementConsumerReferenceId, measurementReferenceId).toName() - val measurement = - try { - measurementsStub - .withAuthenticationKey(apiAuthenticationKey) - .getMeasurement(getMeasurementRequest { name = measurementResourceName }) - } catch (e: StatusException) { - throw Exception("Unable to retrieve the measurement [$measurementResourceName].", e) - } - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (measurement.state) { - Measurement.State.SUCCEEDED -> { - // Converts a Measurement to an InternalMeasurement and store it into the database with - // SUCCEEDED state - val measurementSpec: MeasurementSpec = measurement.measurementSpec.unpack() - val encryptionPrivateKeyHandle = - encryptionKeyPairStore.getPrivateKeyHandle( - principalName, - measurementSpec.measurementPublicKey.unpack().data, - ) ?: failGrpc(Status.PERMISSION_DENIED) { "Encryption private key not found" } - - val setInternalMeasurementResultRequest = - buildSetInternalMeasurementResultRequest( - measurementConsumerReferenceId, - measurementReferenceId, - measurement.resultsList, - encryptionPrivateKeyHandle, - apiAuthenticationKey, - ) - - try { - internalMeasurementsStub.setMeasurementResult(setInternalMeasurementResultRequest) - } catch (e: StatusException) { - throw Exception( - "Unable to update the measurement [$measurementResourceName] in the reporting " + - "database.", - e, - ) - } - } - Measurement.State.AWAITING_REQUISITION_FULFILLMENT, - Measurement.State.COMPUTING -> {} // No action needed - Measurement.State.FAILED, - Measurement.State.CANCELLED -> { - val setInternalMeasurementFailureRequest = setInternalMeasurementFailureRequest { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - this.measurementReferenceId = measurementReferenceId - failure = measurement.failure.toInternal() - } - - try { - internalMeasurementsStub.setMeasurementFailure(setInternalMeasurementFailureRequest) - } catch (e: StatusException) { - throw Exception( - "Unable to update the measurement [$measurementResourceName] in the reporting " + - "database.", - e, - ) - } - } - Measurement.State.STATE_UNSPECIFIED -> error("The measurement state should've been set.") - Measurement.State.UNRECOGNIZED -> error("Unrecognized measurement state.") - } - } - - /** Builds a [SetInternalMeasurementResultRequest]. */ - private suspend fun buildSetInternalMeasurementResultRequest( - measurementConsumerReferenceId: String, - measurementReferenceId: String, - resultsList: List, - privateKeyHandle: PrivateKeyHandle, - apiAuthenticationKey: String, - ): SetInternalMeasurementResultRequest { - - return setInternalMeasurementResultRequest { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - this.measurementReferenceId = measurementReferenceId - result = - aggregateResults( - resultsList - .map { decryptMeasurementResultOutput(it, privateKeyHandle, apiAuthenticationKey) } - .map(Measurement.Result::toInternal) - ) - } - } - - /** Decrypts a [Measurement.ResultOutput] to [Measurement.Result] */ - private suspend fun decryptMeasurementResultOutput( - measurementResultOutput: Measurement.ResultOutput, - encryptionPrivateKeyHandle: PrivateKeyHandle, - apiAuthenticationKey: String, - ): Measurement.Result { - // TODO: Cache the certificate - val certificate = - try { - certificateStub - .withAuthenticationKey(apiAuthenticationKey) - .getCertificate(getCertificateRequest { name = measurementResultOutput.certificate }) - } catch (e: StatusException) { - throw Exception( - "Unable to retrieve the certificate [${measurementResultOutput.certificate}].", - e, - ) - } - - val signedResult: SignedMessage = - decryptResult(measurementResultOutput.encryptedResult, encryptionPrivateKeyHandle) - - val x509Certificate: X509Certificate = readCertificate(certificate.x509Der) - val trustedIssuer: X509Certificate = - checkNotNull(trustedCertificates[checkNotNull(x509Certificate.authorityKeyIdentifier)]) { - "${certificate.name} not issued by trusted CA" - } - // TODO: Record verification failure in internal Measurement rather than having the RPC fail. - try { - verifyResult(signedResult, x509Certificate, trustedIssuer) - } catch (e: CertPathValidatorException) { - throw Exception("Certificate path for ${certificate.name} is invalid", e) - } catch (e: SignatureException) { - throw Exception("Measurement result signature is invalid", e) - } - return signedResult.unpack() - } - - /** Builds an [InternalCreateReportRequest] from a public [CreateReportRequest]. */ - private suspend fun buildInternalCreateReportRequest( - request: CreateReportRequest, - reportInfo: ReportInfo, - namedSetOperationResults: Map, - ): InternalCreateReportRequest { - val internalReport: InternalReport = internalReport { - this.measurementConsumerReferenceId = reportInfo.measurementConsumerReferenceId - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (request.report.timeCase) { - Report.TimeCase.TIME_INTERVALS -> { - this.timeIntervals = request.report.timeIntervals.toInternal() - } - Report.TimeCase.PERIODIC_TIME_INTERVAL -> { - this.periodicTimeInterval = request.report.periodicTimeInterval.toInternal() - } - Report.TimeCase.TIME_NOT_SET -> - failGrpc(Status.INVALID_ARGUMENT) { "The time in Report is not specified." } - } - - for (metric in request.report.metricsList) { - this@internalReport.metrics += - buildInternalMetric(metric, reportInfo, namedSetOperationResults) - } - - details = - InternalReportKt.details { this.eventGroupFilters.putAll(reportInfo.eventGroupFilters) } - - this.reportIdempotencyKey = reportInfo.reportIdempotencyKey - } - - return internalCreateReportRequest { - report = internalReport - measurements += - internalReport.metricsList.flatMap { internalMetric -> - buildInternalMeasurementKeys(internalMetric, reportInfo.measurementConsumerReferenceId) - } - } - } - - /** Builds an [InternalMetric] from a public [Metric]. */ - private suspend fun buildInternalMetric( - metric: Metric, - reportInfo: ReportInfo, - namedSetOperationResults: Map, - ): InternalMetric { - return internalMetric { - details = buildInternalMetricDetails(metric) - - metric.setOperationsList.map { setOperation -> - val setOperationId = - buildSetOperationId(reportInfo.reportIdempotencyKey, details, setOperation.uniqueName) - - namedSetOperationResults[setOperationId]?.let { setOperationResult -> - val internalNamedSetOperation = - buildInternalNamedSetOperation(setOperation, reportInfo, setOperationResult) - namedSetOperations += internalNamedSetOperation - } - } - } - } - - /** Builds an [InternalNamedSetOperation] from a public [NamedSetOperation]. */ - private suspend fun buildInternalNamedSetOperation( - namedSetOperation: NamedSetOperation, - reportInfo: ReportInfo, - setOperationResult: SetOperationResult, - ): InternalNamedSetOperation { - return InternalMetricKt.namedSetOperation { - displayName = namedSetOperation.uniqueName - setOperation = - buildInternalSetOperation( - namedSetOperation.setOperation, - reportInfo.measurementConsumerReferenceId, - ) - - this.measurementCalculations += buildMeasurementCalculationList(setOperationResult) - } - } - - /** Builds an [InternalSetOperation] from a public [SetOperation]. */ - private suspend fun buildInternalSetOperation( - setOperation: SetOperation, - measurementConsumerReferenceId: String, - ): InternalSetOperation { - return InternalMetricKt.setOperation { - this.type = setOperation.type.toInternal() - this.lhs = buildInternalOperand(setOperation.lhs, measurementConsumerReferenceId) - this.rhs = buildInternalOperand(setOperation.rhs, measurementConsumerReferenceId) - } - } - - /** Builds an [InternalOperand] from an [Operand]. */ - private suspend fun buildInternalOperand( - operand: Operand, - measurementConsumerReferenceId: String, - ): InternalOperand { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - return when (operand.operandCase) { - Operand.OperandCase.OPERATION -> - InternalSetOperationKt.operand { - operation = buildInternalSetOperation(operand.operation, measurementConsumerReferenceId) - } - Operand.OperandCase.REPORTING_SET -> { - val reportingSetId = - grpcRequireNotNull(ReportingSetKey.fromName(operand.reportingSet)) { - "Invalid reporting set name ${operand.reportingSet}." - } - .reportingSetId - - InternalSetOperationKt.operand { - this.reportingSetId = - InternalSetOperationKt.reportingSetKey { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - externalReportingSetId = apiIdToExternalId(reportingSetId) - } - } - } - Operand.OperandCase.OPERAND_NOT_SET -> InternalSetOperationKt.operand {} - } - } - - /** - * Builds a list of [MeasurementCalculation]s from a list of [WeightedMeasurement]s and a list of - * [InternalTimeInterval]s. - */ - private fun buildMeasurementCalculationList( - setOperationResult: SetOperationResult - ): List { - val measurementCalculations = mutableListOf() - setOperationResult.weightedMeasurementInfoList - .groupBy { it.reportTimeInterval } - .forEach { (reportTimeInterval, weightedMeasurementInfos) -> - measurementCalculations.add( - InternalMetricKt.measurementCalculation { - timeInterval = reportTimeInterval.toInternal() - - for (weightedMeasurementInfo in weightedMeasurementInfos) { - weightedMeasurements += - MeasurementCalculationKt.weightedMeasurement { - measurementReferenceId = - checkNotNull(weightedMeasurementInfo.kingdomMeasurementId) - coefficient = weightedMeasurementInfo.weightedMeasurement.coefficient - } - } - } - ) - } - return measurementCalculations - } - - /** Builds a [CreateMeasurementRequest]. */ - private fun buildCreateMeasurementRequest( - measurementConsumer: MeasurementConsumer, - eventGroupEntriesByDataProvider: Map>, - internalMetricDetails: InternalMetricDetails, - requestId: String, - signingConfig: SigningConfig, - dataProviderInfoMap: Map, - ): CreateMeasurementRequest { - val measurementConsumerCertificate: X509Certificate = - readCertificate(signingConfig.signingCertificateDer) - val measurementConsumerSigningKey = - SigningKeyHandle(measurementConsumerCertificate, signingConfig.signingPrivateKey) - val measurementEncryptionPublicKey: ProtoAny = measurementConsumer.publicKey.message - - return createMeasurementRequest { - parent = measurementConsumer.name - measurement = measurement { - this.measurementConsumerCertificate = signingConfig.signingCertificateName - - dataProviders += - buildDataProviderEntries( - eventGroupEntriesByDataProvider, - measurementEncryptionPublicKey, - measurementConsumerSigningKey, - dataProviderInfoMap, - ) - - val unsignedMeasurementSpec: MeasurementSpec = - buildUnsignedMeasurementSpec( - measurementEncryptionPublicKey, - dataProviders.map { it.value.nonceHash }, - internalMetricDetails, - ) - - measurementSpec = - signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey) - } - this.requestId = requestId - } - } - - /** - * Converts internal event group entries into [EventGroupEntry] messages, grouping them by - * DataProvider. - */ - private fun groupEventGroupEntriesByDataProvider( - reportingSetNames: List, - timeInterval: Interval, - eventGroupFilters: Map, - internalReportingSetMap: Map, - ): Map> { - return reportingSetNames - .flatMap { - val reportingSetKey = - grpcRequireNotNull(ReportingSetKey.fromName(it)) { "Invalid reporting set name $it" } - val internalReportingSet = - internalReportingSetMap.getValue(apiIdToExternalId(reportingSetKey.reportingSetId)) - internalReportingSet.eventGroupKeysList.map { internalEventGroupKey -> - val eventGroupKey = - EventGroupKey( - internalEventGroupKey.measurementConsumerReferenceId, - internalEventGroupKey.dataProviderReferenceId, - internalEventGroupKey.eventGroupReferenceId, - ) - val eventGroupName = eventGroupKey.toName() - val filter = - combineEventGroupFilters(internalReportingSet.filter, eventGroupFilters[eventGroupName]) - - eventGroupKey to - RequisitionSpecKt.eventGroupEntry { - key = - CmmsEventGroupKey( - internalEventGroupKey.dataProviderReferenceId, - internalEventGroupKey.eventGroupReferenceId, - ) - .toName() - value = - RequisitionSpecKt.EventGroupEntryKt.value { - collectionInterval = timeInterval - if (filter != null) { - this.filter = RequisitionSpecKt.eventFilter { expression = filter } - } - } - } - } - } - .groupBy( - { (eventGroupKey, _) -> DataProviderKey(eventGroupKey.dataProviderReferenceId) }, - { (_, eventGroupEntry) -> eventGroupEntry }, - ) - } - - /** Builds a [Map] of [DataProvider] name to [DataProviderInfo]. */ - private suspend fun buildDataProviderInfoMap( - apiAuthenticationKey: String, - dataProviderNames: Collection, - ): Map { - val dataProviderInfoMap = mutableMapOf() - - if (dataProviderNames.isEmpty()) { - return dataProviderInfoMap - } - - val deferredDataProviderInfoList = mutableListOf>() - coroutineScope { - for (dataProviderName in dataProviderNames) { - deferredDataProviderInfoList.add( - async { - val dataProvider: DataProvider = - try { - dataProvidersStub - .withAuthenticationKey(apiAuthenticationKey) - .getDataProvider(getDataProviderRequest { name = dataProviderName }) - } catch (e: StatusException) { - throw when (e.status.code) { - Status.Code.NOT_FOUND -> - Status.FAILED_PRECONDITION.withDescription("$dataProviderName not found") - else -> Status.UNKNOWN.withDescription("Unable to retrieve $dataProviderName") - } - .withCause(e) - .asRuntimeException() - } - - val certificate: Certificate = - try { - certificateStub - .withAuthenticationKey(apiAuthenticationKey) - .getCertificate(getCertificateRequest { name = dataProvider.certificate }) - } catch (e: StatusException) { - throw Exception("Unable to retrieve Certificate ${dataProvider.certificate}", e) - } - if ( - certificate.revocationState != - Certificate.RevocationState.REVOCATION_STATE_UNSPECIFIED - ) { - throw Status.FAILED_PRECONDITION.withDescription( - "${certificate.name} revocation state is ${certificate.revocationState}" - ) - .asRuntimeException() - } - - val x509Certificate: X509Certificate = readCertificate(certificate.x509Der) - val trustedIssuer: X509Certificate = - trustedCertificates[checkNotNull(x509Certificate.authorityKeyIdentifier)] - ?: throw Status.FAILED_PRECONDITION.withDescription( - "${certificate.name} not issued by trusted CA" - ) - .asRuntimeException() - try { - verifyEncryptionPublicKey(dataProvider.publicKey, x509Certificate, trustedIssuer) - } catch (e: CertPathValidatorException) { - throw Status.FAILED_PRECONDITION.withCause(e) - .withDescription("Certificate path for ${certificate.name} is invalid") - .asRuntimeException() - } catch (e: SignatureException) { - throw Status.FAILED_PRECONDITION.withCause(e) - .withDescription("DataProvider public key signature is invalid") - .asRuntimeException() - } - - DataProviderInfo(dataProvider.name, dataProvider.publicKey, certificate.name) - } - ) - } - - for (deferredDataProviderInfo in deferredDataProviderInfoList.awaitAll()) { - dataProviderInfoMap[deferredDataProviderInfo.dataProviderName] = deferredDataProviderInfo - } - } - - return dataProviderInfoMap - } - - /** Builds a [List] of [DataProviderEntry] messages from [eventGroupEntriesByDataProvider]. */ - private fun buildDataProviderEntries( - eventGroupEntriesByDataProvider: Map>, - measurementEncryptionPublicKey: ProtoAny, - measurementConsumerSigningKey: SigningKeyHandle, - dataProviderInfoMap: Map, - ): List { - return eventGroupEntriesByDataProvider.map { (dataProviderKey, eventGroupEntriesList) -> - // TODO(@SanjayVas): Consider caching the public key and certificate. - val dataProviderName: String = dataProviderKey.toName() - val dataProviderInfo = dataProviderInfoMap.getValue(dataProviderName) - - val requisitionSpec = requisitionSpec { - events = RequisitionSpecKt.events { eventGroups += eventGroupEntriesList } - measurementPublicKey = measurementEncryptionPublicKey - // TODO(world-federation-of-advertisers/cross-media-measurement#1301): Stop setting this - // field. - serializedMeasurementPublicKey = measurementEncryptionPublicKey.value - nonce = secureRandom.nextLong() - } - val encryptRequisitionSpec = - encryptRequisitionSpec( - signRequisitionSpec(requisitionSpec, measurementConsumerSigningKey), - dataProviderInfo.publicKey.unpack(), - ) - - dataProviderEntry { - key = dataProviderName - value = - DataProviderEntryKt.value { - dataProviderCertificate = dataProviderInfo.certificateName - dataProviderPublicKey = dataProviderInfo.publicKey.message - this.encryptedRequisitionSpec = encryptRequisitionSpec - nonceHash = Hashing.hashSha256(requisitionSpec.nonce) - } - } - } - } - - /** Builds the unsigned [MeasurementSpec]. */ - private fun buildUnsignedMeasurementSpec( - measurementEncryptionPublicKey: ProtoAny, - nonceHashes: List, - internalMetricDetails: InternalMetricDetails, - ): MeasurementSpec { - val isSingleDataProvider = nonceHashes.size == 1 - return measurementSpec { - measurementPublicKey = measurementEncryptionPublicKey - // TODO(world-federation-of-advertisers/cross-media-measurement#1301): Stop setting this - // field. - serializedMeasurementPublicKey = measurementEncryptionPublicKey.value - this.nonceHashes += nonceHashes - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (internalMetricDetails.metricTypeCase) { - InternalMetricTypeCase.REACH -> { - if (isSingleDataProvider) { - reach = measurementSpecComponentFactory.getReachSingleDataProviderType() - vidSamplingInterval = - measurementSpecComponentFactory.getReachSingleDataProviderVidSamplingInterval() - } else { - reach = measurementSpecComponentFactory.getReachType() - vidSamplingInterval = measurementSpecComponentFactory.getReachVidSamplingInterval() - } - } - InternalMetricTypeCase.FREQUENCY_HISTOGRAM -> { - if (isSingleDataProvider) { - reachAndFrequency = - measurementSpecComponentFactory.getReachAndFrequencySingleDataProviderType( - internalMetricDetails.frequencyHistogram.maximumFrequency - ) - vidSamplingInterval = - measurementSpecComponentFactory - .getReachAndFrequencySingleDataProviderVidSamplingInterval() - } else { - reachAndFrequency = - measurementSpecComponentFactory.getReachAndFrequencyType( - internalMetricDetails.frequencyHistogram.maximumFrequency - ) - vidSamplingInterval = - measurementSpecComponentFactory.getReachAndFrequencyVidSamplingInterval() - } - } - InternalMetricTypeCase.IMPRESSION_COUNT -> { - impression = - measurementSpecComponentFactory.getImpressionType( - internalMetricDetails.impressionCount.maximumFrequencyPerUser - ) - vidSamplingInterval = measurementSpecComponentFactory.getImpressionVidSamplingInterval() - } - InternalMetricTypeCase.WATCH_DURATION -> { - duration = - measurementSpecComponentFactory.getDurationType( - internalMetricDetails.watchDuration.maximumWatchDurationPerUserSeconds - ) - vidSamplingInterval = measurementSpecComponentFactory.getDurationVidSamplingInterval() - } - InternalMetricTypeCase.METRICTYPE_NOT_SET -> - error("Unset metric type should've already raised error.") - } - } - } - - private class Metrics( - private val reportInfo: ReportInfo, - private val internalReportingSetsStub: ReportingSetsGrpcKt.ReportingSetsCoroutineStub, - private val setOperationCompiler: SetOperationCompiler, - private val report: Report, - ) { - - private val reportingSetExternalIds: MutableSet = mutableSetOf() - private lateinit var namedSetOperationResults: Map - private lateinit var internalReportingSetMap: Map - - suspend fun process() { - if (this::namedSetOperationResults.isInitialized) { - return - } - namedSetOperationResults = compileAllSetOperations() - internalReportingSetMap = getReportingSets() - internalReportingSetMap.values.forEach { - it.checkReportingSetEventGroupFilters(reportInfo.eventGroupFilters) - } - } - - suspend fun getNamedSetOperationResults(): Map { - if (!this::namedSetOperationResults.isInitialized) { - process() - } - return namedSetOperationResults - } - - suspend fun getInternalReportingSetsMap(): Map { - if (!this::namedSetOperationResults.isInitialized) { - process() - } - return internalReportingSetMap - } - - /** Compiles all [SetOperation]s and outputs each result with measurement reference ID. */ - private suspend fun compileAllSetOperations(): Map { - val namedSetOperationResults = mutableMapOf() - - var hasCumulativeMetric = false - for (metric in report.metricsList) { - if (metric.cumulative) { - hasCumulativeMetric = true - break - } - } - val timeIntervalsList = report.timeIntervalsList(hasCumulativeMetric) - val sortedTimeIntervalsList = timeIntervalsList.sortedWith(timeIntervalComparator) - val cumulativeTimeIntervalsList = - sortedTimeIntervalsList.map { timeInterval -> - timeInterval.copy { this.startTime = sortedTimeIntervalsList.first().startTime } - } - - for (metric in report.metricsList) { - val metricTimeIntervalsList = - if (metric.cumulative) cumulativeTimeIntervalsList else sortedTimeIntervalsList - val internalMetricDetails: InternalMetricDetails = buildInternalMetricDetails(metric) - - for (namedSetOperation in metric.setOperationsList) { - checkSetOperationReportingSetName(namedSetOperation.setOperation) - - val setOperationId = - buildSetOperationId( - reportInfo.reportIdempotencyKey, - internalMetricDetails, - namedSetOperation.uniqueName, - ) - - val weightedMeasurementInfoList = - compileSetOperation( - namedSetOperation.setOperation, - setOperationId, - metricTimeIntervalsList, - sortedTimeIntervalsList, - ) - namedSetOperationResults[setOperationId] = - SetOperationResult(weightedMeasurementInfoList, internalMetricDetails) - } - } - - return namedSetOperationResults.toMap() - } - - /** Checks if all reporting sets under a [SetOperation] have valid names. */ - private fun checkSetOperationReportingSetName(setOperation: SetOperation) { - checkOperandReportingSetName(setOperation.lhs) - checkOperandReportingSetName(setOperation.rhs) - } - - /** Checks if all reporting sets under a [Operand] have valid names. */ - private fun checkOperandReportingSetName(operand: Operand) { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (operand.operandCase) { - Operand.OperandCase.OPERATION -> checkSetOperationReportingSetName(operand.operation) - Operand.OperandCase.REPORTING_SET -> checkReportingSetName(operand.reportingSet) - Operand.OperandCase.OPERAND_NOT_SET -> {} - } - } - - /** Check if the event groups in the public [ReportingSet] have valid names. */ - private fun checkReportingSetName(reportingSetName: String) { - val reportingSetKey = - grpcRequireNotNull(ReportingSetKey.fromName(reportingSetName)) { - "Invalid reporting set name $reportingSetName." - } - - grpcRequire( - reportingSetKey.measurementConsumerId == reportInfo.measurementConsumerReferenceId - ) { - "No access to the reporting set [$reportingSetName]." - } - - reportingSetExternalIds.add(apiIdToExternalId(reportingSetKey.reportingSetId)) - } - - /** - * Compiles a [SetOperation] and outputs each result with measurement reference ID. - * - * metricTimeIntervalsList and timeIntervalsList are required to be the same size. - */ - private suspend fun compileSetOperation( - setOperation: SetOperation, - setOperationId: String, - metricTimeIntervalsList: List, - reportTimeIntervalsList: List, - ): List { - if (metricTimeIntervalsList.size != reportTimeIntervalsList.size) { - throw IllegalArgumentException() - } - val sortedReportTimeIntervalsList = reportTimeIntervalsList.sortedWith(timeIntervalComparator) - - val weightedMeasurementsList = setOperationCompiler.compileSetOperation(setOperation) - - return metricTimeIntervalsList.sortedWith(timeIntervalComparator).flatMapIndexed { - timeIntervalsIndex, - timeInterval -> - weightedMeasurementsList.mapIndexed { index, weightedMeasurement -> - val measurementReferenceId = - buildMeasurementReferenceId(setOperationId, timeInterval, index) - - WeightedMeasurementInfo( - measurementReferenceId, - weightedMeasurement, - timeInterval = timeInterval, - reportTimeInterval = sortedReportTimeIntervalsList[timeIntervalsIndex], - ) - } - } - } - - private suspend fun getReportingSets(): Map { - val batchGetReportingSetRequest = batchGetReportingSetRequest { - measurementConsumerReferenceId = reportInfo.measurementConsumerReferenceId - reportingSetExternalIds.forEach { externalReportingSetIds += it } - } - - val internalReportingSetsList = - internalReportingSetsStub.batchGetReportingSet(batchGetReportingSetRequest).toList() - - if (internalReportingSetsList.size < reportingSetExternalIds.size) { - val errorMessage = StringBuilder("The following reporting set names were not found:") - internalReportingSetsList.forEach { - reportingSetExternalIds.remove(it.externalReportingSetId) - } - reportingSetExternalIds.forEach { - errorMessage.append( - " ${ReportingSetKey(reportInfo.measurementConsumerReferenceId, externalIdToApiId(it)).toName()}" - ) - } - failGrpc(Status.NOT_FOUND) { errorMessage.toString() } - } - - return internalReportingSetsList.associateBy { it.externalReportingSetId } - } - } - - private class MeasurementSpecComponentFactory( - private val measurementSpecConfig: MeasurementSpecConfig, - private val secureRandom: Random, - ) { - private val DEFAULT_VID_START = 0.0f - private val DEFAULT_VID_WIDTH = 1.0f - - private val reachSingleDataProviderType = - MeasurementSpecKt.reach { - privacyParams = differentialPrivacyParams { - epsilon = measurementSpecConfig.reachSingleDataProvider.privacyParams.epsilon - delta = measurementSpecConfig.reachSingleDataProvider.privacyParams.delta - } - } - - private val reachType = - MeasurementSpecKt.reach { - privacyParams = differentialPrivacyParams { - epsilon = measurementSpecConfig.reach.privacyParams.epsilon - delta = measurementSpecConfig.reach.privacyParams.delta - } - } - - private val reachAndFrequencySingleDataProviderReachPrivacyParams = differentialPrivacyParams { - epsilon = measurementSpecConfig.reachAndFrequencySingleDataProvider.reachPrivacyParams.epsilon - delta = measurementSpecConfig.reachAndFrequencySingleDataProvider.reachPrivacyParams.delta - } - - private val reachAndFrequencySingleDataProviderFrequencyPrivacyParams = - differentialPrivacyParams { - epsilon = - measurementSpecConfig.reachAndFrequencySingleDataProvider.frequencyPrivacyParams.epsilon - delta = - measurementSpecConfig.reachAndFrequencySingleDataProvider.frequencyPrivacyParams.delta - } - - private val reachAndFrequencyReachPrivacyParams = differentialPrivacyParams { - epsilon = measurementSpecConfig.reachAndFrequency.reachPrivacyParams.epsilon - delta = measurementSpecConfig.reachAndFrequency.reachPrivacyParams.delta - } - - private val reachAndFrequencyFrequencyPrivacyParams = differentialPrivacyParams { - epsilon = measurementSpecConfig.reachAndFrequency.frequencyPrivacyParams.epsilon - delta = measurementSpecConfig.reachAndFrequency.frequencyPrivacyParams.delta - } - - private val impressionPrivacyParams = differentialPrivacyParams { - epsilon = measurementSpecConfig.impression.privacyParams.epsilon - delta = measurementSpecConfig.impression.privacyParams.delta - } - - private val durationPrivacyParams = differentialPrivacyParams { - epsilon = measurementSpecConfig.duration.privacyParams.epsilon - delta = measurementSpecConfig.duration.privacyParams.delta - } - - private fun createVidSamplingInterval( - vidSamplingInterval: MeasurementSpecConfig.VidSamplingInterval - ): MeasurementSpec.VidSamplingInterval { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") - return when (vidSamplingInterval.startCase) { - MeasurementSpecConfig.VidSamplingInterval.StartCase.FIXED_START -> - MeasurementSpecKt.vidSamplingInterval { - start = vidSamplingInterval.fixedStart.start - width = vidSamplingInterval.fixedStart.width - } - MeasurementSpecConfig.VidSamplingInterval.StartCase.RANDOM_START -> - MeasurementSpecKt.vidSamplingInterval { - start = - calculateRandomVidStart( - vidSamplingInterval.randomStart.width, - vidSamplingInterval.randomStart.numVidBuckets, - ) - width = - vidSamplingInterval.randomStart.width.toFloat() / - vidSamplingInterval.randomStart.numVidBuckets - } - MeasurementSpecConfig.VidSamplingInterval.StartCase.START_NOT_SET -> - MeasurementSpecKt.vidSamplingInterval { - start = DEFAULT_VID_START - width = DEFAULT_VID_WIDTH - } - } - } - - private fun calculateRandomVidStart(width: Int, numVidBuckets: Int): Float { - val maxStart = numVidBuckets - width - val start = secureRandom.nextInt(maxStart + 1) - return start.toFloat() / numVidBuckets - } - - fun getReachSingleDataProviderVidSamplingInterval(): MeasurementSpec.VidSamplingInterval { - val vidSamplingInterval = measurementSpecConfig.reachSingleDataProvider.vidSamplingInterval - return createVidSamplingInterval(vidSamplingInterval) - } - - fun getReachSingleDataProviderType(): MeasurementSpec.Reach { - return reachSingleDataProviderType - } - - fun getReachVidSamplingInterval(): MeasurementSpec.VidSamplingInterval { - val vidSamplingInterval = measurementSpecConfig.reach.vidSamplingInterval - return createVidSamplingInterval(vidSamplingInterval) - } - - fun getReachType(): MeasurementSpec.Reach { - return reachType - } - - fun getReachAndFrequencySingleDataProviderVidSamplingInterval(): - MeasurementSpec.VidSamplingInterval { - val vidSamplingInterval = - measurementSpecConfig.reachAndFrequencySingleDataProvider.vidSamplingInterval - return createVidSamplingInterval(vidSamplingInterval) - } - - fun getReachAndFrequencySingleDataProviderType( - maximumFrequency: Int - ): MeasurementSpec.ReachAndFrequency { - return MeasurementSpecKt.reachAndFrequency { - reachPrivacyParams = reachAndFrequencySingleDataProviderReachPrivacyParams - frequencyPrivacyParams = reachAndFrequencySingleDataProviderFrequencyPrivacyParams - this.maximumFrequency = maximumFrequency - } - } - - fun getReachAndFrequencyVidSamplingInterval(): MeasurementSpec.VidSamplingInterval { - val vidSamplingInterval = measurementSpecConfig.reachAndFrequency.vidSamplingInterval - return createVidSamplingInterval(vidSamplingInterval) - } - - fun getReachAndFrequencyType(maximumFrequency: Int): MeasurementSpec.ReachAndFrequency { - return MeasurementSpecKt.reachAndFrequency { - reachPrivacyParams = reachAndFrequencyReachPrivacyParams - frequencyPrivacyParams = reachAndFrequencyFrequencyPrivacyParams - this.maximumFrequency = maximumFrequency - } - } - - fun getImpressionVidSamplingInterval(): MeasurementSpec.VidSamplingInterval { - val vidSamplingInterval = measurementSpecConfig.impression.vidSamplingInterval - return createVidSamplingInterval(vidSamplingInterval) - } - - fun getImpressionType(maximumFrequencyPerUser: Int): MeasurementSpec.Impression { - return MeasurementSpecKt.impression { - privacyParams = impressionPrivacyParams - this.maximumFrequencyPerUser = maximumFrequencyPerUser - } - } - - fun getDurationVidSamplingInterval(): MeasurementSpec.VidSamplingInterval { - val vidSamplingInterval = measurementSpecConfig.duration.vidSamplingInterval - return createVidSamplingInterval(vidSamplingInterval) - } - - fun getDurationType(maximumWatchDurationPerUserSeconds: Int): MeasurementSpec.Duration { - return MeasurementSpecKt.duration { - privacyParams = durationPrivacyParams - maximumWatchDurationPerUser = - Durations.fromSeconds(maximumWatchDurationPerUserSeconds.toLong()) - } - } - } -} - -/** Converts the time in [Report] to a list of [TimeInterval]. */ -private fun Report.timeIntervalsList(hasCumulativeMetric: Boolean): List { - val source = this - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - return when (source.timeCase) { - Report.TimeCase.TIME_INTERVALS -> { - if (hasCumulativeMetric) { - failGrpc(Status.INVALID_ARGUMENT) { "Cannot use TimeIntervals with a cumulative Metric." } - } - grpcRequire(source.timeIntervals.timeIntervalsList.isNotEmpty()) { - "TimeIntervals timeIntervalsList is empty." - } - source.timeIntervals.timeIntervalsList.forEach { - grpcRequire(it.startTime.seconds > 0 || it.startTime.nanos > 0) { - "TimeInterval startTime is unspecified." - } - grpcRequire(it.endTime.seconds > 0 || it.endTime.nanos > 0) { - "TimeInterval endTime is unspecified." - } - grpcRequire( - it.endTime.seconds > it.startTime.seconds || it.endTime.nanos > it.startTime.nanos - ) { - "TimeInterval endTime is not later than startTime." - } - } - source.timeIntervals.timeIntervalsList.map { it } - } - Report.TimeCase.PERIODIC_TIME_INTERVAL -> { - grpcRequire( - source.periodicTimeInterval.startTime.seconds > 0 || - source.periodicTimeInterval.startTime.nanos > 0 - ) { - "PeriodicTimeInterval startTime is unspecified." - } - grpcRequire( - source.periodicTimeInterval.increment.seconds > 0 || - source.periodicTimeInterval.increment.nanos > 0 - ) { - "PeriodicTimeInterval increment is unspecified." - } - grpcRequire(source.periodicTimeInterval.intervalCount > 0) { - "PeriodicTimeInterval intervalCount is unspecified." - } - source.periodicTimeInterval.toTimeIntervalsList() - } - Report.TimeCase.TIME_NOT_SET -> - failGrpc(Status.INVALID_ARGUMENT) { "The time in Report is not specified." } - } -} - -/** - * Check if the event groups in the internal [InternalReportingSet] are covered by the event group - * universe. - */ -private fun InternalReportingSet.checkReportingSetEventGroupFilters( - eventGroupFilters: Map -) { - for (eventGroupKey in this.eventGroupKeysList) { - val eventGroupName = - EventGroupKey( - eventGroupKey.measurementConsumerReferenceId, - eventGroupKey.dataProviderReferenceId, - eventGroupKey.eventGroupReferenceId, - ) - .toName() - val internalReportingSetDisplayName = this.displayName - grpcRequire(eventGroupFilters.containsKey(eventGroupName)) { - "The event group [$eventGroupName] in the reporting set " + - "[$internalReportingSetDisplayName] is not included in the event group universe." - } - } -} - -/** Check if the names of the set operations within the same metric type are unique. */ -private fun checkSetOperationNamesUniqueness(metricsList: List) { - val seenNames = mutableMapOf>().withDefault { mutableSetOf() } - - for (metric in metricsList) { - for (setOperation in metric.setOperationsList) { - grpcRequire(!seenNames.getValue(metric.metricTypeCase).contains(setOperation.uniqueName)) { - "The names of the set operations within the same metric type should be unique." - } - seenNames.getOrPut(metric.metricTypeCase, ::mutableSetOf) += setOperation.uniqueName - } - } -} - -/** Builds a list of [InternalMeasurementKey]s from an [InternalMetric]. */ -private fun buildInternalMeasurementKeys( - internalMetric: InternalMetric, - measurementConsumerReferenceId: String, -): List { - return internalMetric.namedSetOperationsList - .flatMap { namedSetOperation -> - namedSetOperation.measurementCalculationsList.flatMap { measurementCalculation -> - measurementCalculation.weightedMeasurementsList.map { it.measurementReferenceId } - } - } - .map { measurementReferenceId -> - InternalCreateReportRequestKt.measurementKey { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - this.measurementReferenceId = measurementReferenceId - } - } -} - -/** Converts an [TimeInterval] to a [Interval] for measurement request. */ -private fun TimeInterval.toMeasurementTimeInterval(): Interval { - val source = this - return interval { - startTime = source.startTime - endTime = source.endTime - } -} - -/** Builds an [InternalMetricDetails] from a [Metric]. */ -private fun buildInternalMetricDetails(metric: Metric): InternalMetricDetails { - return InternalMetricKt.details { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (metric.metricTypeCase) { - MetricTypeCase.REACH -> reach = InternalMetricKt.reachParams {} - MetricTypeCase.FREQUENCY_HISTOGRAM -> - frequencyHistogram = metric.frequencyHistogram.toInternal() - MetricTypeCase.IMPRESSION_COUNT -> impressionCount = metric.impressionCount.toInternal() - MetricTypeCase.WATCH_DURATION -> watchDuration = metric.watchDuration.toInternal() - MetricTypeCase.METRICTYPE_NOT_SET -> - failGrpc(Status.INVALID_ARGUMENT) { "The metric type in Report is not specified." } - } - - cumulative = metric.cumulative - } -} - -/** Builds a unique ID for a [SetOperation]. */ -private fun buildSetOperationId( - reportIdempotencyKey: String, - internalMetricDetails: InternalMetricDetails, - setOperationUniqueName: String, -): String { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - val metricType = - when (internalMetricDetails.metricTypeCase) { - InternalMetricTypeCase.REACH -> "Reach" - InternalMetricTypeCase.FREQUENCY_HISTOGRAM -> "FrequencyHistogram" - InternalMetricTypeCase.IMPRESSION_COUNT -> "ImpressionCount" - InternalMetricTypeCase.WATCH_DURATION -> "WatchDuration" - InternalMetricTypeCase.METRICTYPE_NOT_SET -> - error("Unset metric type should've already raised error.") - } - - return "$reportIdempotencyKey-$metricType-$setOperationUniqueName" -} - -/** Builds a unique reference ID for a [Measurement]. */ -private fun buildMeasurementReferenceId( - setOperationId: String, - timeInterval: TimeInterval, - index: Int, -): String { - val rowHeader = buildRowHeader(timeInterval) - return "$setOperationId-$rowHeader-measurement-$index" -} - -/** Combines two event group filters. */ -private fun combineEventGroupFilters(filter1: String?, filter2: String?): String? { - if (filter1.isNullOrBlank()) return filter2 - - return if (filter2.isNullOrBlank()) filter1 - else { - "($filter1) && ($filter2)" - } -} - -/** Converts a public [SetOperation.Type] to an [InternalSetOperation.Type]. */ -private fun SetOperation.Type.toInternal(): InternalSetOperation.Type { - return when (this) { - SetOperation.Type.UNION -> InternalSetOperation.Type.UNION - SetOperation.Type.INTERSECTION -> InternalSetOperation.Type.INTERSECTION - SetOperation.Type.DIFFERENCE -> InternalSetOperation.Type.DIFFERENCE - SetOperation.Type.TYPE_UNSPECIFIED -> error("Set operator type is not specified.") - SetOperation.Type.UNRECOGNIZED -> error("Unrecognized Set operator type.") - } -} - -/** Converts a [WatchDurationParams] to an [InternalWatchDurationParams]. */ -private fun WatchDurationParams.toInternal(): InternalWatchDurationParams { - val source = this - return InternalMetricKt.watchDurationParams { - maximumWatchDurationPerUserSeconds = source.maximumWatchDurationPerUser - } -} - -/** Converts a [ImpressionCountParams] to an [InternalImpressionCountParams]. */ -private fun ImpressionCountParams.toInternal(): InternalImpressionCountParams { - val source = this - return InternalMetricKt.impressionCountParams { - maximumFrequencyPerUser = source.maximumFrequencyPerUser - } -} - -/** Converts a [FrequencyHistogramParams] to an [InternalFrequencyHistogramParams]. */ -private fun FrequencyHistogramParams.toInternal(): InternalFrequencyHistogramParams { - val source = this - return InternalMetricKt.frequencyHistogramParams { - maximumFrequency = source.maximumFrequencyPerUser - } -} - -/** Converts a public [PeriodicTimeInterval] to an [InternalPeriodicTimeInterval]. */ -private fun PeriodicTimeInterval.toInternal(): InternalPeriodicTimeInterval { - val source = this - return internalPeriodicTimeInterval { - startTime = source.startTime - increment = source.increment - intervalCount = source.intervalCount - } -} - -/** Converts a public [TimeInterval] to an [InternalTimeInterval]. */ -private fun TimeInterval.toInternal(): InternalTimeInterval { - val source = this - return internalTimeInterval { - startTime = source.startTime - endTime = source.endTime - } -} - -/** Converts a public [TimeIntervals] to an [InternalTimeIntervals]. */ -private fun TimeIntervals.toInternal(): InternalTimeIntervals { - val source = this - return internalTimeIntervals { - for (timeInternal in source.timeIntervalsList) { - this.timeIntervals += internalTimeInterval { - startTime = timeInternal.startTime - endTime = timeInternal.endTime - } - } - } -} - -/** Convert an [PeriodicTimeInterval] to a list of [TimeInterval]s. */ -private fun PeriodicTimeInterval.toTimeIntervalsList(): List { - val source = this - var startTime = checkNotNull(source.startTime) - return (0 until source.intervalCount).map { - timeInterval { - this.startTime = startTime - this.endTime = Timestamps.add(startTime, source.increment) - startTime = this.endTime - } - } -} - -/** Builds a row header in String from an [TimeInterval]. */ -private fun buildRowHeader(timeInterval: TimeInterval): String { - val startTimeInstant = - Instant.ofEpochSecond(timeInterval.startTime.seconds, timeInterval.startTime.nanos.toLong()) - val endTimeInstant = - Instant.ofEpochSecond(timeInterval.endTime.seconds, timeInterval.endTime.nanos.toLong()) - return "$startTimeInstant-$endTimeInstant" -} - -private operator fun ProtoDuration.plus(other: ProtoDuration): ProtoDuration { - return Durations.add(this, other) -} - -/** Converts a CMM [Measurement.Failure] to an [InternalMeasurement.Failure]. */ -private fun Measurement.Failure.toInternal(): InternalMeasurement.Failure { - val source = this - - return InternalMeasurementKt.failure { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - reason = - when (source.reason) { - Measurement.Failure.Reason.REASON_UNSPECIFIED -> - InternalMeasurement.Failure.Reason.REASON_UNSPECIFIED - Measurement.Failure.Reason.CERTIFICATE_REVOKED -> - InternalMeasurement.Failure.Reason.CERTIFICATE_REVOKED - Measurement.Failure.Reason.REQUISITION_REFUSED -> - InternalMeasurement.Failure.Reason.REQUISITION_REFUSED - Measurement.Failure.Reason.COMPUTATION_PARTICIPANT_FAILED -> - InternalMeasurement.Failure.Reason.COMPUTATION_PARTICIPANT_FAILED - Measurement.Failure.Reason.UNRECOGNIZED -> InternalMeasurement.Failure.Reason.UNRECOGNIZED - } - message = source.message - } -} - -/** Aggregate a list of [InternalMeasurementResult] to a [InternalMeasurementResult] */ -private fun aggregateResults( - internalResultsList: List -): InternalMeasurementResult { - if (internalResultsList.isEmpty()) { - error("No measurement result.") - } - - var reachValue = 0L - var impressionValue = 0L - val frequencyDistribution = mutableMapOf() - var watchDurationValue = ProtoDuration.getDefaultInstance() - - // Aggregation - for (result in internalResultsList) { - if (result.hasFrequency()) { - if (!result.hasReach()) { - error("Missing reach measurement in the Reach-Frequency measurement.") - } - for ((frequency, percentage) in result.frequency.relativeFrequencyDistributionMap) { - val previousTotalReachCount = - frequencyDistribution.getOrDefault(frequency, 0.0) * reachValue - val currentReachCount = percentage * result.reach.value - frequencyDistribution[frequency] = - (previousTotalReachCount + currentReachCount) / (reachValue + result.reach.value) - } - } - if (result.hasReach()) { - reachValue += result.reach.value - } - if (result.hasImpression()) { - impressionValue += result.impression.value - } - if (result.hasWatchDuration()) { - watchDurationValue += result.watchDuration.value - } - } - - return InternalMeasurementKt.result { - if (internalResultsList.first().hasReach()) { - this.reach = InternalMeasurementKt.ResultKt.reach { value = reachValue } - } - if (internalResultsList.first().hasFrequency()) { - this.frequency = - InternalMeasurementKt.ResultKt.frequency { - relativeFrequencyDistribution.putAll(frequencyDistribution) - } - } - if (internalResultsList.first().hasImpression()) { - this.impression = InternalMeasurementKt.ResultKt.impression { value = impressionValue } - } - if (internalResultsList.first().hasWatchDuration()) { - this.watchDuration = - InternalMeasurementKt.ResultKt.watchDuration { value = watchDurationValue } - } - } -} - -/** Converts a CMM [Measurement.Result] to an [InternalMeasurementResult]. */ -private fun Measurement.Result.toInternal(): InternalMeasurementResult { - val source = this - - return InternalMeasurementKt.result { - if (source.hasReach()) { - this.reach = InternalMeasurementKt.ResultKt.reach { value = source.reach.value } - } - if (source.hasFrequency()) { - this.frequency = - InternalMeasurementKt.ResultKt.frequency { - relativeFrequencyDistribution.putAll(source.frequency.relativeFrequencyDistributionMap) - } - } - if (source.hasImpression()) { - this.impression = - InternalMeasurementKt.ResultKt.impression { value = source.impression.value } - } - if (source.hasWatchDuration()) { - this.watchDuration = - InternalMeasurementKt.ResultKt.watchDuration { value = source.watchDuration.value } - } - } -} - -/** Converts an internal [InternalReport] to a public [Report]. */ -private fun InternalReport.toReport(): Report { - val source = this - val reportResourceName = - ReportKey( - measurementConsumerId = source.measurementConsumerReferenceId, - reportId = externalIdToApiId(source.externalReportId), - ) - .toName() - val measurementConsumerResourceName = - MeasurementConsumerKey(source.measurementConsumerReferenceId).toName() - val eventGroupEntries = - source.details.eventGroupFiltersMap.map { (eventGroupResourceName, filterPredicate) -> - EventGroupUniverseKt.eventGroupEntry { - key = eventGroupResourceName - value = filterPredicate - } - } - - return report { - name = reportResourceName - reportIdempotencyKey = source.reportIdempotencyKey - measurementConsumer = measurementConsumerResourceName - eventGroupUniverse = eventGroupUniverse { this.eventGroupEntries += eventGroupEntries } - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (source.timeCase) { - InternalReport.TimeCase.TIME_INTERVALS -> - this.timeIntervals = source.timeIntervals.toTimeIntervals() - InternalReport.TimeCase.PERIODIC_TIME_INTERVAL -> - this.periodicTimeInterval = source.periodicTimeInterval.toPeriodicTimeInterval() - InternalReport.TimeCase.TIME_NOT_SET -> - error("The time in the internal report should've been set.") - } - - for (metric in source.metricsList) { - this.metrics += metric.toMetric() - } - - this.state = source.state.toState() - if (source.details.hasResult()) { - this.result = source.details.result.toResult() - } - } -} - -/** Converts an [InternalReport.State] to a public [Report.State]. */ -private fun InternalReport.State.toState(): Report.State { - return when (this) { - InternalReport.State.RUNNING -> Report.State.RUNNING - InternalReport.State.SUCCEEDED -> Report.State.SUCCEEDED - InternalReport.State.FAILED -> Report.State.FAILED - InternalReport.State.STATE_UNSPECIFIED -> error("Report state should've been set.") - InternalReport.State.UNRECOGNIZED -> error("Unrecognized report state.") - } -} - -/** Converts an [InternalReport.Details.Result] to a public [Report.Result]. */ -private fun InternalReport.Details.Result.toResult(): Result { - val source = this - return result { - scalarTable = scalarTable { - rowHeaders += source.scalarTable.rowHeadersList - for (sourceColumn in source.scalarTable.columnsList) { - columns += column { - columnHeader = sourceColumn.columnHeader - setOperations += sourceColumn.setOperationsList - } - } - } - for (sourceHistogram in source.histogramTablesList) { - histogramTables += histogramTable { - for (sourceRow in sourceHistogram.rowsList) { - rows += row { - rowHeader = sourceRow.rowHeader - frequency = sourceRow.frequency - } - } - for (sourceColumn in sourceHistogram.columnsList) { - columns += column { - columnHeader = sourceColumn.columnHeader - setOperations += sourceColumn.setOperationsList - } - } - } - } - } -} - -/** Converts an internal [InternalMetric] to a public [Metric]. */ -private fun InternalMetric.toMetric(): Metric { - val source = this - - return metric { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (source.details.metricTypeCase) { - InternalMetricTypeCase.REACH -> reach = reachParams {} - InternalMetricTypeCase.FREQUENCY_HISTOGRAM -> - frequencyHistogram = source.details.frequencyHistogram.toFrequencyHistogram() - InternalMetricTypeCase.IMPRESSION_COUNT -> - impressionCount = source.details.impressionCount.toImpressionCount() - InternalMetricTypeCase.WATCH_DURATION -> - watchDuration = source.details.watchDuration.toWatchDuration() - InternalMetricTypeCase.METRICTYPE_NOT_SET -> - error("The metric type in the internal report should've been set.") - } - - cumulative = source.details.cumulative - - for (internalSetOperation in source.namedSetOperationsList) { - setOperations += internalSetOperation.toNamedSetOperation() - } - } -} - -/** Converts an internal [InternalNamedSetOperation] to a public [NamedSetOperation]. */ -private fun InternalNamedSetOperation.toNamedSetOperation(): NamedSetOperation { - val source = this - - return namedSetOperation { - uniqueName = source.displayName - setOperation = source.setOperation.toSetOperation() - } -} - -/** Converts an internal [InternalSetOperation] to a public [SetOperation]. */ -private fun InternalSetOperation.toSetOperation(): SetOperation { - val source = this - - return setOperation { - this.type = source.type.toType() - this.lhs = source.lhs.toOperand() - this.rhs = source.rhs.toOperand() - } -} - -/** Converts an internal [InternalOperand] to a public [Operand]. */ -private fun InternalOperand.toOperand(): Operand { - val source = this - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - return when (source.operandCase) { - InternalOperand.OperandCase.OPERATION -> - operand { operation = source.operation.toSetOperation() } - InternalOperand.OperandCase.REPORTINGSETID -> - operand { - reportingSet = - ReportingSetKey( - source.reportingSetId.measurementConsumerReferenceId, - externalIdToApiId(source.reportingSetId.externalReportingSetId), - ) - .toName() - } - InternalOperand.OperandCase.OPERAND_NOT_SET -> operand {} - } -} - -/** Converts an internal [InternalSetOperation.Type] to a public [SetOperation.Type]. */ -private fun InternalSetOperation.Type.toType(): SetOperation.Type { - return when (this) { - InternalSetOperation.Type.UNION -> SetOperation.Type.UNION - InternalSetOperation.Type.INTERSECTION -> SetOperation.Type.INTERSECTION - InternalSetOperation.Type.DIFFERENCE -> SetOperation.Type.DIFFERENCE - InternalSetOperation.Type.TYPE_UNSPECIFIED -> error("Set operator type should've been set.") - InternalSetOperation.Type.UNRECOGNIZED -> error("Unrecognized Set operator type.") - } -} - -/** Converts an internal [InternalWatchDurationParams] to a public [WatchDurationParams]. */ -private fun InternalWatchDurationParams.toWatchDuration(): WatchDurationParams { - val source = this - return watchDurationParams { - maximumWatchDurationPerUser = source.maximumWatchDurationPerUserSeconds - } -} - -/** Converts an internal [InternalImpressionCountParams] to a public [ImpressionCountParams]. */ -private fun InternalImpressionCountParams.toImpressionCount(): ImpressionCountParams { - val source = this - return impressionCountParams { maximumFrequencyPerUser = source.maximumFrequencyPerUser } -} - -/** - * Converts an internal [InternalFrequencyHistogramParams] to a public [FrequencyHistogramParams]. - */ -private fun InternalFrequencyHistogramParams.toFrequencyHistogram(): FrequencyHistogramParams { - val source = this - return frequencyHistogramParams { maximumFrequencyPerUser = source.maximumFrequency } -} - -/** Converts an internal [InternalPeriodicTimeInterval] to a public [PeriodicTimeInterval]. */ -private fun InternalPeriodicTimeInterval.toPeriodicTimeInterval(): PeriodicTimeInterval { - val source = this - return periodicTimeInterval { - startTime = source.startTime - increment = source.increment - intervalCount = source.intervalCount - } -} - -/** Converts an internal [InternalTimeIntervals] to a public [TimeIntervals]. */ -private fun InternalTimeIntervals.toTimeIntervals(): TimeIntervals { - val source = this - return timeIntervals { - for (internalTimeInternal in source.timeIntervalsList) { - this.timeIntervals += timeInterval { - startTime = internalTimeInternal.startTime - endTime = internalTimeInternal.endTime - } - } - } -} - -/** Converts an internal [ListReportsPageToken] to an internal [StreamInternalReportsRequest]. */ -private fun ListReportsPageToken.toStreamReportsRequest(): StreamInternalReportsRequest { - val source = this - return streamInternalReportsRequest { - // get 1 more than the actual page size for deciding whether or not to set page token - limit = pageSize + 1 - filter = filter { - measurementConsumerReferenceId = source.measurementConsumerReferenceId - externalReportIdAfter = source.lastReport.externalReportId - } - } -} - -/** Converts a public [ListReportsRequest] to an internal [ListReportsPageToken]. */ -private fun ListReportsRequest.toListReportsPageToken(): ListReportsPageToken { - grpcRequire(pageSize >= 0) { "Page size cannot be less than 0" } - - val source = this - val parentKey: MeasurementConsumerKey = - grpcRequireNotNull(MeasurementConsumerKey.fromName(parent)) { - "Parent is either unspecified or invalid." - } - val measurementConsumerReferenceId = parentKey.measurementConsumerId - - val isValidPageSize = - source.pageSize != 0 && source.pageSize >= MIN_PAGE_SIZE && source.pageSize <= MAX_PAGE_SIZE - - return if (pageToken.isNotBlank()) { - ListReportsPageToken.parseFrom(pageToken.base64UrlDecode()).copy { - grpcRequire(this.measurementConsumerReferenceId == measurementConsumerReferenceId) { - "Arguments must be kept the same when using a page token" - } - - if (isValidPageSize) { - pageSize = source.pageSize - } - } - } else { - listReportsPageToken { - pageSize = - when { - source.pageSize < MIN_PAGE_SIZE -> DEFAULT_PAGE_SIZE - source.pageSize > MAX_PAGE_SIZE -> MAX_PAGE_SIZE - else -> source.pageSize - } - this.measurementConsumerReferenceId = measurementConsumerReferenceId - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/SetOperationCompiler.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/SetOperationCompiler.kt deleted file mode 100644 index b2be4a9f684..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/SetOperationCompiler.kt +++ /dev/null @@ -1,500 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import kotlin.math.pow -import kotlinx.coroutines.coroutineScope -import kotlinx.coroutines.launch -import org.wfanet.measurement.reporting.v1alpha.Metric.SetOperation - -/** - * A primitive region of a Venn diagram is the intersection of a set of reporting sets, and it is - * represented by a bit representation of an integer. Only the reporting sets with IDs equal to the - * bit positions of set bits constitute the primitive region. Ex: Given a Venn Diagram of 3 - * reporting sets (rs0, rs1, rs2), a primitive region with an integer value equal to 3 has the bit - * representation b’011’. This means this primitive region is only covered by the intersection of - * rs0 and rs1 and not covered by rs2 (the order of the bit positions is from right to left). In - * other words, rs0 INTERSECT rs1 INTERSECT COMPLEMENT(rs2). Note that primitive regions are - * disjoint. - */ -private typealias PrimitiveRegion = ULong - -/** - * A union set is the union of a set of reporting sets, and it is represented by a bit - * representation of an integer. Only the reporting sets with IDs equal to the bit positions of set - * bits constitute the union set. Given a Venn Diagram of 3 reporting sets (rs0, rs1, rs2), a - * union-set with an integer value equal to 3 has the bit representation b’011’. This means this - * union-set is only covered by the union of rs0 and rs1 (the order of the bit positions is from - * right to left). - */ -private typealias UnionSet = ULong - -private typealias NumberReportingSets = Int - -/** A mapping from a [UnionSet] to its coefficient in the Venn diagram region decomposition. */ -private typealias UnionSetCoefficientMap = Map - -/** - * A mapping for cardinality computation from a [PrimitiveRegion] to its decomposition in terms of - * union-sets represented by [UnionSetCoefficientMap]. Take a case of 3 reporting sets (rs0, rs1, - * rs2) as an example. A primitive region with its value equal to 3 (b'011') means rs0 INTERSECT rs1 - * INTERSECT COMPLEMENT(rs2). The decomposition of the cardinality of the region = - * PrimitiveRegionToUnionSetCoefficientMap\[region\] = {4: -1, 5: 1, 6: 1, 7: -1}, i.e. |union-set5| - * + |union-set6| - |union-set4| - |union-set7|. - */ -private typealias PrimitiveRegionToUnionSetCoefficientMap = - MutableMap - -/** - * A memory cache that stores the Venn diagram region cardinality decompositions for different - * numbers of reporting sets. - */ -private typealias PrimitiveRegionCache = - MutableMap - -private enum class Operator { - UNION, - INTERSECT, - DIFFERENCE -} - -private interface Operand - -private data class ReportingSet(val id: Int, val resourceName: String) : Operand - -private data class SetOperationExpression( - val setOperator: Operator, - val lhs: Operand, - val rhs: Operand?, -) : Operand - -data class WeightedMeasurement(val reportingSets: List, val coefficient: Int) - -class SetOperationCompiler { - - private var primitiveRegionCache: PrimitiveRegionCache = mutableMapOf() - - // For unit test only. - fun getPrimitiveRegionCache(): - Map> { - return primitiveRegionCache.mapValues { it.value.toMap() }.toMap() - } - - /** - * Compiles a set operation to a list of [WeightedMeasurement]s which will be used for the - * cardinality computation. For example, given a set = primitiveRegion1 UNION primitiveRegion2, - * Count(set) = Count(primitiveRegion1) + Count(primitiveRegion2) = Count(unionSet1) - - * Count(unionSet2) + Count(unionSet3) - Count(unionSet2) = Count(unionSet1) + Count(unionSet3) - - * 2 * Count(unionSet2). - */ - suspend fun compileSetOperation(setOperation: SetOperation): List { - val reportingSetNames = mutableSetOf() - setOperation.storeReportingSetNames(reportingSetNames) - - // Sorts the list in alphabetical order to make sure the IDs are consistent for the same run. - val sortedReportingSetNames = reportingSetNames.sortedBy { it } - val reportingSetsMap = createReportingSetsMap(sortedReportingSetNames) - val numReportingSets = reportingSetsMap.size - - val setOperationExpression = setOperation.toSetOperationExpression(reportingSetsMap) - - // Step 1 - Gets the primitive regions that form the set operation - val primitiveRegions = - setOperationExpressionToPrimitiveRegions(numReportingSets, setOperationExpression) - - // Step 2 - Converts a set of primitive regions to a map of union-set to its coefficients for - // cardinality computation. - val unionSetCoefficientMap = - convertPrimitiveRegionsToUnionSetCoefficientMap(numReportingSets, primitiveRegions) - - return unionSetCoefficientMap.map { (unionSet, coefficient) -> - convertUnionSetToWeightedMeasurements(unionSet, coefficient, sortedReportingSetNames) - } - } - - /** Converts unionSetCoefficientMap to WeightedMeasurements. */ - private fun convertUnionSetToWeightedMeasurements( - unionSet: UnionSet, - coefficient: Int, - sortedReportingSetNames: List, - ): WeightedMeasurement { - // Find the reporting sets in the union-set. - val reportingSetNames = - (sortedReportingSetNames.indices).mapNotNull { bitPosition -> - if (isBitSet(unionSet, bitPosition)) sortedReportingSetNames[bitPosition] else null - } - - return WeightedMeasurement(reportingSetNames, coefficient) - } - - /** - * Converts a set of primitive regions to a map of union-set to its coefficients for cardinality - * computation - */ - private suspend fun convertPrimitiveRegionsToUnionSetCoefficientMap( - numReportingSets: Int, - primitiveRegions: Set, - ): UnionSetCoefficientMap { - - val primitiveRegionsToUnionSetCoefficients: PrimitiveRegionToUnionSetCoefficientMap = - mutableMapOf() - - coroutineScope { - for (region in primitiveRegions) { - // Reuse previous computation if available - if ( - reusePreviousComputation(numReportingSets, region, primitiveRegionsToUnionSetCoefficients) - ) { - continue - } - - launch { - convertSinglePrimitiveRegionToUnionSetCoefficientMap( - numReportingSets, - region, - primitiveRegionsToUnionSetCoefficients, - ) - } - } - } - - // Updates the memory cache with new computation result. - primitiveRegionCache.getOrPut(numReportingSets, ::mutableMapOf) += - primitiveRegionsToUnionSetCoefficients - - return aggregateCoefficientsByUnionSets(primitiveRegionsToUnionSetCoefficients) - } - - /** Aggregates the coefficients by union-sets. */ - private fun aggregateCoefficientsByUnionSets( - primitiveRegionsToUnionSetCoefficients: PrimitiveRegionToUnionSetCoefficientMap - ): UnionSetCoefficientMap { - val aggregatedResult = mutableMapOf() - for ((_, unionSetCoefficients) in primitiveRegionsToUnionSetCoefficients) { - for ((unionSet, coefficient) in unionSetCoefficients) { - aggregatedResult[unionSet] = aggregatedResult.getOrDefault(unionSet, 0) + coefficient - - // Remove the entry if its coefficient is zero. - if (aggregatedResult[unionSet] == 0) { - aggregatedResult.remove(unionSet) - } - } - } - // Sort the aggregatedResult to make sure the result is consistent every time. - return aggregatedResult.toSortedMap().toMap() - } - - /** - * Converts a single primitive region to a map of union-set to its coefficients, where the - * cardinality of the input primitive region is equal to the linear combination of the - * cardinalities of the union-sets with the coefficients. - * - * The algorithm is based on the observation on the linear transformation matrix from primitive - * regions to union-sets. Ex: - * ``` - * b'01' b'10' b'11' - * A 0 -1 1 - * B -1 0 1 - * A U B 1 1 -1 - * ``` - */ - private fun convertSinglePrimitiveRegionToUnionSetCoefficientMap( - numReportingSets: Int, - region: PrimitiveRegion, - primitiveRegionsToUnionSetCoefficients: PrimitiveRegionToUnionSetCoefficientMap, - ) { - // For a given region, we first find which bit positions are set and which are not. - val setBitPositions = mutableListOf() - val unsetBitPositions = mutableListOf() - - for (bitPosition in 0 until numReportingSets) { - if (isBitSet(region, bitPosition)) { - setBitPositions.add(bitPosition) - } else { - unsetBitPositions.add(bitPosition) - } - } - val primitiveRegionWeight = setBitPositions.size - - // Always starts from -1 unless the primitive region = (2^numReportingSets - 1) = b'11...1'. - val baseSign = if (primitiveRegionWeight != numReportingSets) -1 else 1 - var count = 0 - - val unionSetCoefficients = mutableMapOf() - - for (size in 1..numReportingSets) { - // Skips it if the union-only set is too light - if (size + primitiveRegionWeight < numReportingSets) { - continue - } - - // Instead of flipping the sign at the end of the loop, this could avoid the race condition. - count++ - val sign = if (count % 2 == 1) baseSign else -baseSign - - val composingUnionSets = findComposingUnionSets(setBitPositions, unsetBitPositions, size) - unionSetCoefficients += composingUnionSets.associateWith { sign } - } - - if (unionSetCoefficients.isNotEmpty()) { - primitiveRegionsToUnionSetCoefficients[region] = unionSetCoefficients.toMap() - } - } - - /** Reuses previous result in the memory cache if there is any. */ - private fun reusePreviousComputation( - numReportingSets: Int, - region: PrimitiveRegion, - primitiveRegionsToUnionSetCoefficients: PrimitiveRegionToUnionSetCoefficientMap, - ): Boolean { - // If the compiler has already run the case where the number of reporting sets equal to - // `numReportingSets`. - primitiveRegionCache[numReportingSets]?.also { cachedPrimitiveRegionsToUnionSetCoefficients -> - // If the compiler has calculated this region before. - cachedPrimitiveRegionsToUnionSetCoefficients[region]?.also { cachedUnionSetCoefficientMap -> - primitiveRegionsToUnionSetCoefficients[region] = cachedUnionSetCoefficientMap - } - } - return primitiveRegionsToUnionSetCoefficients.containsKey(region) - } - - /** - * Finds the union-sets which will be part of the combination to form the target region. - * Essentially, given the size of a combination, we are finding all the combinations of the bit - * positions where at least unset bit positions are selected. - */ - private fun findComposingUnionSets( - setBitPositions: MutableList, - unsetBitPositions: MutableList, - size: Int, - ): MutableList { - val composingUnionSets = mutableListOf() - - // If the size is not large enough to at least contain all unset bit positions or the size is - // too large to fill, return empty result. - if (unsetBitPositions.size > size || setBitPositions.size + unsetBitPositions.size < size) { - return composingUnionSets - } - - findValidUnionSets(size, 0, setBitPositions, unsetBitPositions, composingUnionSets) - - return composingUnionSets - } - - /** Finds the valid combinations as [UnionSet]s using backtracking. */ - private fun findValidUnionSets( - size: Int, - start: Int, - choices: MutableList, - combination: MutableList, - result: MutableList, - ) { - if (combination.size == size) { - result.add(combination.sumOf { 1.toUnionSet() shl it }) - return - } - - for (i in start until choices.size) { - combination.add(choices[i]) - findValidUnionSets(size, i + 1, choices, combination, result) - combination.removeLast() - } - - return - } - - /** Gets the set of the primitive regions that form the set from the set operation expression. */ - private fun setOperationExpressionToPrimitiveRegions( - numReportingSets: Int, - setOperationExpression: SetOperationExpression, - ): Set { - val allPrimitiveRegionSetsList = buildAllPrimitiveRegions(numReportingSets) - return setOperationExpression.decompose(allPrimitiveRegionSetsList) - } -} - -/** - * Decomposes the set operation expression to a set of primitive regions by calculating the set - * operation between each two operands. - */ -private fun SetOperationExpression.decompose( - allPrimitiveRegionSetsList: List> -): Set { - val source = this - val lhsPrimitiveRegions = source.lhs.decompose(allPrimitiveRegionSetsList) - val rhsPrimitiveRegions = source.rhs.decompose(allPrimitiveRegionSetsList) - return calculateBinarySetOperation(lhsPrimitiveRegions, rhsPrimitiveRegions, source.setOperator) -} - -/** Decomposes the operand to a set of primitive regions. */ -private fun Operand?.decompose( - allPrimitiveRegionSetsList: List> -): Set { - return when (val operand = this) { - is SetOperationExpression -> { - operand.decompose(allPrimitiveRegionSetsList) - } - is ReportingSet -> { - allPrimitiveRegionSetsList[operand.id] - } - else -> setOf() - } -} - -/** Calculates the binary set operation. */ -private fun calculateBinarySetOperation( - lhs: Set, - rhs: Set, - operator: Operator, -): Set { - return when (operator) { - Operator.UNION -> lhs union rhs - Operator.INTERSECT -> lhs intersect rhs - Operator.DIFFERENCE -> lhs subtract rhs - } -} - -/** - * Builds a list of primitive regions where the index represents the reporting set ID and the - * element is the set of primitive regions which forms the corresponding reporting set. For example, - * if reportingSetId = 1, then allPrimitiveRegionSetsList\[reportingSetId\] = setOf(1(=b’001’), - * 3(=b’011’), 5(=b’101’), 7(=b’111’)). - */ -private fun buildAllPrimitiveRegions(numReportingSets: Int): List> { - val numPrimitiveRegions = 2.0.pow(numReportingSets).toPrimitiveRegion() - 1.toPrimitiveRegion() - val allPrimitiveRegionSetsList: List> = - List(numReportingSets) { mutableSetOf() } - - // A region is in the set of reportingSet when its bit at bit position == reportingSetId is set. - for (region in 1.toPrimitiveRegion()..numPrimitiveRegions) { - for (reportingSetId in 0 until numReportingSets) { - if (isBitSet(region, reportingSetId)) { - allPrimitiveRegionSetsList[reportingSetId].add(region) - } - } - } - - return allPrimitiveRegionSetsList.map(MutableSet::toSet) -} - -/** Converts a [Int] to a [PrimitiveRegion] */ -private fun Int.toPrimitiveRegion(): PrimitiveRegion { - return this.toULong() -} - -/** Converts a [Double] to a [PrimitiveRegion] */ -private fun Double.toPrimitiveRegion(): PrimitiveRegion { - return this.toULong() -} - -/** Converts a [Int] to a [UnionSet] */ -private fun Int.toUnionSet(): UnionSet { - return this.toULong() -} - -/** Checks if the bit at `bitPosition` of a number is set or not. */ -fun isBitSet(number: ULong, bitPosition: Int): Boolean { - return (number and (1UL shl bitPosition)) != 0UL -} - -/** Creates a map of resource names of reporting sets to [ReportingSet]s. */ -private fun createReportingSetsMap( - sortedReportingSetNames: List -): Map { - val reportingSetsMap: MutableMap = mutableMapOf() - for ((id, reportingSetName) in sortedReportingSetNames.withIndex()) { - reportingSetsMap[reportingSetName] = ReportingSet(id, reportingSetName) - } - return reportingSetsMap.toMap() -} - -/** Gets all resource names of the reporting sets used in this [SetOperation]. */ -private fun SetOperation.storeReportingSetNames(reportingSetNames: MutableSet) { - val root = this - if (!root.hasLhs()) { - throw IllegalArgumentException("lhs in SetOperation must be set.") - } - if (!root.lhs.hasReportingSet() && !root.lhs.hasOperation()) { - throw IllegalArgumentException("Operand type of lhs in SetOperation must be set.") - } - - root.lhs.storeReportingSetNames(reportingSetNames) - root.rhs.storeReportingSetNames(reportingSetNames) -} - -/** Gets all resource names of the reporting sets used in this [SetOperation.Operand]. */ -private fun SetOperation.Operand.storeReportingSetNames(reportingSetNames: MutableSet) { - val node = this - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (node.operandCase) { - // Leaf node - SetOperation.Operand.OperandCase.REPORTING_SET -> reportingSetNames.add(node.reportingSet) - SetOperation.Operand.OperandCase.OPERATION -> { - node.operation.storeReportingSetNames(reportingSetNames) - } - // Empty node. No further action. - SetOperation.Operand.OperandCase.OPERAND_NOT_SET -> return - } -} - -/** Converts a public [SetOperation] to a [SetOperationExpression]. */ -private fun SetOperation.toSetOperationExpression( - reportingSetsMap: Map -): SetOperationExpression { - val root = this - - if (!root.hasLhs()) { - throw IllegalArgumentException("lhs in SetOperation must be set.") - } - - val lhs = - root.lhs.toOperand(reportingSetsMap) - ?: throw IllegalArgumentException("Operand type of lhs in SetOperation must be set.") - - val rhs = root.rhs.toOperand(reportingSetsMap) - - return SetOperationExpression(root.type.toOperator(), lhs, rhs) -} - -/** Converts a public [SetOperation.Operand] to a nullable [Operand]. */ -private fun SetOperation.Operand.toOperand(reportingSetsMap: Map): Operand? { - val source = this - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - return when (source.operandCase) { - SetOperation.Operand.OperandCase.REPORTING_SET -> { - return reportingSetsMap[source.reportingSet] - } - SetOperation.Operand.OperandCase.OPERATION -> - return source.operation.toSetOperationExpression(reportingSetsMap) - SetOperation.Operand.OperandCase.OPERAND_NOT_SET -> null - } -} - -/** Converts a public [SetOperation.Type] to a [Operator]. */ -private fun SetOperation.Type.toOperator(): Operator { - val source = this - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - return when (source) { - SetOperation.Type.TYPE_UNSPECIFIED -> - throw IllegalArgumentException("Set operator type is not specified.") - SetOperation.Type.UNION -> Operator.UNION - SetOperation.Type.DIFFERENCE -> Operator.DIFFERENCE - SetOperation.Type.INTERSECTION -> Operator.INTERSECT - SetOperation.Type.UNRECOGNIZED -> - throw IllegalArgumentException("Unrecognized Set operator type.") - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/BUILD.bazel deleted file mode 100644 index eba73091447..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/BUILD.bazel +++ /dev/null @@ -1,32 +0,0 @@ -load("@rules_java//java:defs.bzl", "java_binary") -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -package( - default_visibility = ["//src/test/kotlin/org/wfanet/measurement/reporting:__subpackages__"], -) - -kt_jvm_library( - name = "reporting", - srcs = ["Reporting.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:flags", - "//src/main/proto/wfa/measurement/reporting/v1alpha:event_group_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:event_groups_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:report_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reporting_set_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reporting_sets_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reports_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//imports/java/picocli", - "@wfa_common_jvm//imports/kotlin/com/google/protobuf/kotlin", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - ], -) - -java_binary( - name = "Reporting", - main_class = "org.wfanet.measurement.reporting.service.api.v1alpha.tools.ReportingKt", - tags = ["manual"], - runtime_deps = [":reporting"], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/README.md b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/README.md deleted file mode 100644 index 1126c7effd8..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/README.md +++ /dev/null @@ -1,138 +0,0 @@ -# Reporting CLI Tools - -Command-line tools for Reporting API. Use the `help` subcommand for help with -any of the subcommands. - -Note that instead of specifying arguments on the command-line, you can specify -an arguments file using `@` followed by the path. For example, - -```shell -Reporting @/home/foo/args.txt -``` - -## Certificate Host - -In the event that the host you specify to the `--reporting-server-api-target` -option doesn't match what's in the Subject Alternative Name (SAN) extension of -the server's certificate, you'll need to specify a host that does match using -the `--reporting-server-api-cert-host` option. - -## Examples - -### reporting-sets - -#### create - -```shell -Reporting \ - --tls-cert-file=src/main/k8s/testing/secretfiles/mc_tls.pem \ - --tls-key-file=src/main/k8s/testing/secretfiles/mc_tls.key \ - --cert-collection-file src/main/k8s/testing/secretfiles/reporting_root.pem \ - --reporting-server-api-target v1alpha.reporting.dev.halo-cmm.org:8443 \ - reporting-sets create --parent=measurementConsumers/VCTqwV_vFXw \ - --event-group=measurementConsumers/VCTqwV_vFXw/dataProviders/1/eventGroups/1 \ - --event-group=measurementConsumers/VCTqwV_vFXw/dataProviders/1/eventGroups/2 \ - --event-group=measurementConsumers/VCTqwV_vFXw/dataProviders/2/eventGroups/1 \ - --filter='video_ad.age == 1' --display-name='test-reporting-set' -``` - -#### list - -```shell -Reporting \ - --tls-cert-file=src/main/k8s/testing/secretfiles/mc_tls.pem \ - --tls-key-file=src/main/k8s/testing/secretfiles/mc_tls.key \ - --cert-collection-file src/main/k8s/testing/secretfiles/reporting_root.pem \ - --reporting-server-api-target v1alpha.reporting.dev.halo-cmm.org:8443 \ - reporting-sets list --parent=measurementConsumers/VCTqwV_vFXw -``` - -To retrieve the next page of reports, use the `--page-token` option to specify -the token returned from the previous response. - -### reports - -#### create - -```shell -Reporting \ - --tls-cert-file=secretfiles/mc_tls.pem \ - --tls-key-file=secretfiles/mc_tls.key \ - --cert-collection-file=secretfiles/reporting_root.pem \ - --reporting-server-api-target=v1alpha.reporting.dev.halo-cmm.org:8443 \ - reports create \ - --idempotency-key="report001" \ - --parent=measurementConsumers/777 \ - --event-group-key=$EVENT_GROUP_NAME_1 \ - --event-group-value="video_ad.age == 1" \ - --event-group-key=$EVENT_GROUP_NAME_2 \ - --event-group-value="video_ad.age == 12" \ - --interval-start-time=2017-01-15T01:30:15.01Z \ - --interval-end-time=2018-10-27T23:19:12.99Z \ - --interval-start-time=2019-01-19T09:48:35.57Z \ - --interval-end-time=2022-06-13T11:57:54.21Z \ - --metric=' - reach { } - set_operations { - unique_name: "operation1" - set_operation { - type: 1 - lhs { - reporting_set: "measurementConsumers/1/reportingSets/1" - } - rhs { - reporting_set: "measurementConsumers/1/reportingSets/2" - } - } - } - ' -``` - -User specifies the type of time args by using either repeated interval params( -`--interval-start-time`, `--interval-end-time`) or periodic time args( -`--periodic-interval-start-time`, `--periodic-interval-increment` and -`--periodic-interval-count`) - -The `--metric` option expects a -[`Metric`](../../../../../../../../../proto/wfa/measurement/reporting/v1alpha/metric.proto) -protobuf message in text format. See -[`metric1.textproto`](../../../../../../../../../../test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/metric1.textproto) -for a complicated example. You can use shell quoting for a multiline string, or -use command substitution to read the message from a file e.g. `--metric=$(cat -metric1.textproto)`. - -#### list - -```shell -Reporting \ - --tls-cert-file=secretfiles/mc_tls.pem \ - --tls-key-file=secretfiles/mc_tls.key \ - --cert-collection-file=secretfiles/reporting_root.pem \ - --reporting-server-api-target=v1alpha.reporting.dev.halo-cmm.org:8443 \ - reports list --parent=measurementConsumers/777 -``` - -#### get - -```shell -Reporting \ - --tls-cert-file=secretfiles/mc_tls.pem \ - --tls-key-file=secretfiles/mc_tls.key \ - --cert-collection-file=secretfiles/reporting_root.pem \ - --reporting-server-api-target=v1alpha.reporting.dev.halo-cmm.org:8443 \ - reports get measurementConsumers/777/reports/5 -``` - -### event-groups - -#### list - -```shell -Reporting \ - --tls-cert-file=secretfiles/mc_tls.pem \ - --tls-key-file=secretfiles/mc_tls.key \ - --cert-collection-file=secretfiles/reporting_root.pem \ - --reporting-server-api-target=v1alpha.reporting.dev.halo-cmm.org:8443 \ - event-groups list \ - --parent=measurementConsumers/777/dataProviders/1 -``` diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/Reporting.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/Reporting.kt deleted file mode 100644 index c2254b6eb01..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/Reporting.kt +++ /dev/null @@ -1,497 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha.tools - -import io.grpc.ManagedChannel -import java.time.Duration -import java.time.Instant -import kotlin.properties.Delegates -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.runBlocking -import org.wfanet.measurement.common.DurationFormat -import org.wfanet.measurement.common.commandLineMain -import org.wfanet.measurement.common.crypto.SigningCerts -import org.wfanet.measurement.common.grpc.TlsFlags -import org.wfanet.measurement.common.grpc.buildMutualTlsChannel -import org.wfanet.measurement.common.grpc.withShutdownTimeout -import org.wfanet.measurement.common.parseTextProto -import org.wfanet.measurement.common.toProtoDuration -import org.wfanet.measurement.common.toProtoTime -import org.wfanet.measurement.reporting.v1alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub -import org.wfanet.measurement.reporting.v1alpha.Metric -import org.wfanet.measurement.reporting.v1alpha.ReportKt.EventGroupUniverseKt.eventGroupEntry -import org.wfanet.measurement.reporting.v1alpha.ReportKt.eventGroupUniverse -import org.wfanet.measurement.reporting.v1alpha.ReportingSetsGrpcKt.ReportingSetsCoroutineStub -import org.wfanet.measurement.reporting.v1alpha.ReportsGrpcKt.ReportsCoroutineStub -import org.wfanet.measurement.reporting.v1alpha.createReportRequest -import org.wfanet.measurement.reporting.v1alpha.createReportingSetRequest -import org.wfanet.measurement.reporting.v1alpha.getReportRequest -import org.wfanet.measurement.reporting.v1alpha.listEventGroupsRequest -import org.wfanet.measurement.reporting.v1alpha.listReportingSetsRequest -import org.wfanet.measurement.reporting.v1alpha.listReportsRequest -import org.wfanet.measurement.reporting.v1alpha.periodicTimeInterval -import org.wfanet.measurement.reporting.v1alpha.report -import org.wfanet.measurement.reporting.v1alpha.reportingSet -import org.wfanet.measurement.reporting.v1alpha.timeInterval -import org.wfanet.measurement.reporting.v1alpha.timeIntervals -import picocli.CommandLine - -private class ReportingApiFlags { - @CommandLine.Option( - names = ["--reporting-server-api-target"], - description = ["gRPC target (authority) of the reporting server's public API"], - required = true, - ) - lateinit var apiTarget: String - private set - - @CommandLine.Option( - names = ["--reporting-server-api-cert-host"], - description = - [ - "Expected hostname (DNS-ID) in the reporting server's TLS certificate.", - "This overrides derivation of the TLS DNS-ID from --reporting-server-api-target.", - ], - required = false, - ) - var apiCertHost: String? = null - private set -} - -private class PageParams { - @CommandLine.Option( - names = ["--page-size"], - description = ["The maximum number of items to return. The maximum value is 1000"], - required = false, - ) - var pageSize: Int = 1000 - private set - - @CommandLine.Option( - names = ["--page-token"], - description = ["Page token from a previous list call to retrieve the next page"], - defaultValue = "", - required = false, - ) - lateinit var pageToken: String - private set -} - -@CommandLine.Command(name = "create", description = ["Creates a reporting set"]) -class CreateReportingSetCommand : Runnable { - @CommandLine.ParentCommand private lateinit var parent: ReportingSetsCommand - - @CommandLine.Option( - names = ["--parent"], - description = ["API resource name of the Measurement Consumer"], - required = true, - ) - private lateinit var measurementConsumerName: String - - @CommandLine.Option( - names = ["--event-group"], - description = ["List of EventGroup's API resource names"], - required = true, - ) - private lateinit var eventGroups: List - - @CommandLine.Option( - names = ["--filter"], - description = ["CEL filter predicate that applies to all `event_groups`"], - required = false, - defaultValue = "", - ) - private lateinit var filterExpression: String - - @CommandLine.Option( - names = ["--display-name"], - description = ["Human-readable name for display purposes"], - required = false, - defaultValue = "", - ) - private lateinit var displayNameInput: String - - override fun run() { - val request = createReportingSetRequest { - parent = measurementConsumerName - reportingSet = reportingSet { - eventGroups += this@CreateReportingSetCommand.eventGroups - filter = filterExpression - displayName = displayNameInput - } - } - val reportingSet = - runBlocking(Dispatchers.IO) { parent.reportingSetStub.createReportingSet(request) } - println(reportingSet) - } -} - -@CommandLine.Command(name = "list", description = ["List reporting sets"]) -class ListReportingSetsCommand : Runnable { - @CommandLine.ParentCommand private lateinit var parent: ReportingSetsCommand - - @CommandLine.Option( - names = ["--parent"], - description = ["API resource name of the Measurement Consumer"], - required = true, - ) - private lateinit var measurementConsumerName: String - - @CommandLine.Mixin private lateinit var pageParams: PageParams - - override fun run() { - val request = listReportingSetsRequest { - parent = measurementConsumerName - pageSize = pageParams.pageSize - pageToken = pageParams.pageToken - } - - val response = - runBlocking(Dispatchers.IO) { parent.reportingSetStub.listReportingSets(request) } - - println(response) - } -} - -@CommandLine.Command( - name = "reporting-sets", - sortOptions = false, - subcommands = - [ - CommandLine.HelpCommand::class, - CreateReportingSetCommand::class, - ListReportingSetsCommand::class, - ], -) -class ReportingSetsCommand : Runnable { - @CommandLine.ParentCommand lateinit var parent: Reporting - val reportingSetStub: ReportingSetsCoroutineStub by lazy { - ReportingSetsCoroutineStub(parent.channel) - } - - override fun run() {} -} - -@CommandLine.Command(name = "create", description = ["Create a set operation report"]) -class CreateReportCommand : Runnable { - @CommandLine.ParentCommand private lateinit var parent: ReportsCommand - - @CommandLine.Option( - names = ["--parent"], - description = ["API resource name of the Measurement Consumer"], - required = true, - ) - private lateinit var measurementConsumerName: String - - @CommandLine.Option( - names = ["--idempotency-key"], - description = ["Used as the prefix of the idempotency keys of measurements"], - required = true, - ) - private lateinit var idempotencyKey: String - - class EventGroupInput { - @CommandLine.Option( - names = ["--event-group-key"], - description = ["Event Group Entry's key"], - required = true, - ) - lateinit var key: String - private set - - @CommandLine.Option( - names = ["--event-group-value"], - description = ["Event Group Entry's value"], - defaultValue = "", - required = false, - ) - lateinit var value: String - private set - } - - @CommandLine.ArgGroup(exclusive = false, multiplicity = "1..*", heading = "Event Group Entries\n") - private lateinit var eventGroups: List - - class TimeInput { - class TimeIntervalInput { - @CommandLine.Option( - names = ["--interval-start-time"], - description = ["Start of time interval in ISO 8601 format of UTC"], - required = true, - ) - lateinit var intervalStartTime: Instant - private set - - @CommandLine.Option( - names = ["--interval-end-time"], - description = ["End of time interval in ISO 8601 format of UTC"], - required = true, - ) - lateinit var intervalEndTime: Instant - private set - } - - class PeriodicTimeIntervalInput { - @CommandLine.Option( - names = ["--periodic-interval-start-time"], - description = ["Start of the first time interval in ISO 8601 format of UTC"], - required = true, - ) - lateinit var periodicIntervalStartTime: Instant - private set - - @CommandLine.Option( - names = ["--periodic-interval-increment"], - description = ["Increment for each time interval in ISO-8601 format of PnDTnHnMn"], - required = true, - ) - lateinit var periodicIntervalIncrement: Duration - private set - - @set:CommandLine.Option( - names = ["--periodic-interval-count"], - description = ["Number of periodic intervals"], - required = true, - ) - var periodicIntervalCount by Delegates.notNull() - private set - } - - @CommandLine.ArgGroup(exclusive = false, multiplicity = "1..*", heading = "Time intervals\n") - var timeIntervals: List? = null - private set - - @CommandLine.ArgGroup( - exclusive = false, - multiplicity = "1", - heading = "Periodic time interval specification\n", - ) - var periodicTimeIntervalInput: PeriodicTimeIntervalInput? = null - private set - } - - @CommandLine.ArgGroup( - exclusive = true, - multiplicity = "1", - heading = "Time interval or periodic time interval\n", - ) - private lateinit var timeInput: TimeInput - - @CommandLine.Option( - names = ["--metric"], - description = ["Metric protobuf messages in text format"], - required = true, - ) - private lateinit var textFormatMetrics: List - - override fun run() { - val request = createReportRequest { - parent = measurementConsumerName - report = report { - reportIdempotencyKey = idempotencyKey - measurementConsumer = measurementConsumerName - eventGroupUniverse = eventGroupUniverse { - eventGroups.forEach { - eventGroupEntries += eventGroupEntry { - key = it.key - value = it.value - } - } - } - - // Either timeIntervals or periodicTimeIntervalInput are set. - if (timeInput.timeIntervals != null) { - val intervals = checkNotNull(timeInput.timeIntervals) - timeIntervals = timeIntervals { - intervals.forEach { - timeIntervals += timeInterval { - startTime = it.intervalStartTime.toProtoTime() - endTime = it.intervalEndTime.toProtoTime() - } - } - } - } else { - val periodicIntervals = checkNotNull(timeInput.periodicTimeIntervalInput) - periodicTimeInterval = periodicTimeInterval { - startTime = periodicIntervals.periodicIntervalStartTime.toProtoTime() - increment = periodicIntervals.periodicIntervalIncrement.toProtoDuration() - intervalCount = periodicIntervals.periodicIntervalCount - } - } - - for (textFormatMetric in textFormatMetrics) { - metrics += - textFormatMetric.reader().use { parseTextProto(it, Metric.getDefaultInstance()) } - } - } - } - val report = runBlocking(Dispatchers.IO) { parent.reportsStub.createReport(request) } - - println(report) - } -} - -@CommandLine.Command(name = "list", description = ["List set operation reports"]) -class ListReportsCommand : Runnable { - @CommandLine.ParentCommand private lateinit var parent: ReportsCommand - - @CommandLine.Option( - names = ["--parent"], - description = ["API resource name of the Measurement Consumer"], - required = true, - ) - private lateinit var measurementConsumerName: String - - @CommandLine.Mixin private lateinit var pageParams: PageParams - - override fun run() { - val request = listReportsRequest { - parent = measurementConsumerName - pageSize = pageParams.pageSize - pageToken = pageParams.pageToken - } - - val response = runBlocking(Dispatchers.IO) { parent.reportsStub.listReports(request) } - - response.reportsList.forEach { println(it.name + " " + it.state.toString()) } - if (response.nextPageToken.isNotEmpty()) { - println("nextPageToken: ${response.nextPageToken}") - } - } -} - -@CommandLine.Command(name = "get", description = ["Get a set operation report"]) -class GetReportCommand : Runnable { - @CommandLine.ParentCommand private lateinit var parent: ReportsCommand - - @CommandLine.Parameters(description = ["API resource name of the Report"]) - private lateinit var reportName: String - - override fun run() { - val request = getReportRequest { name = reportName } - - val report = runBlocking(Dispatchers.IO) { parent.reportsStub.getReport(request) } - println(report) - } -} - -@CommandLine.Command( - name = "reports", - sortOptions = false, - subcommands = - [ - CommandLine.HelpCommand::class, - CreateReportCommand::class, - ListReportsCommand::class, - GetReportCommand::class, - ], -) -class ReportsCommand : Runnable { - @CommandLine.ParentCommand lateinit var parent: Reporting - - val reportsStub: ReportsCoroutineStub by lazy { ReportsCoroutineStub(parent.channel) } - - override fun run() {} -} - -@CommandLine.Command(name = "list", description = ["List event groups"]) -class ListEventGroups : Runnable { - @CommandLine.ParentCommand private lateinit var parent: EventGroupsCommand - - @CommandLine.Option( - names = ["--parent"], - description = ["API resource name of the Data Provider"], - required = true, - ) - private lateinit var dataProviderName: String - - @CommandLine.Option( - names = ["--filter"], - description = ["Result filter in format of raw CEL expression"], - required = false, - defaultValue = "", - ) - private lateinit var celFilter: String - - @CommandLine.Mixin private lateinit var pageParams: PageParams - - override fun run() { - val request = listEventGroupsRequest { - parent = dataProviderName - pageSize = pageParams.pageSize - pageToken = pageParams.pageToken - filter = celFilter - } - - val response = runBlocking(Dispatchers.IO) { parent.eventGroupStub.listEventGroups(request) } - - println(response) - } -} - -@CommandLine.Command( - name = "event-groups", - sortOptions = false, - subcommands = [CommandLine.HelpCommand::class, ListEventGroups::class], -) -class EventGroupsCommand : Runnable { - @CommandLine.ParentCommand lateinit var parent: Reporting - - val eventGroupStub: EventGroupsCoroutineStub by lazy { EventGroupsCoroutineStub(parent.channel) } - - override fun run() {} -} - -@CommandLine.Command( - name = "reporting", - description = ["Reporting CLI tool"], - sortOptions = false, - subcommands = - [ - CommandLine.HelpCommand::class, - ReportingSetsCommand::class, - ReportsCommand::class, - EventGroupsCommand::class, - ], -) -class Reporting : Runnable { - @CommandLine.Mixin private lateinit var tlsFlags: TlsFlags - @CommandLine.Mixin private lateinit var apiFlags: ReportingApiFlags - - val channel: ManagedChannel by lazy { - val clientCerts = - SigningCerts.fromPemFiles( - certificateFile = tlsFlags.certFile, - privateKeyFile = tlsFlags.privateKeyFile, - trustedCertCollectionFile = tlsFlags.certCollectionFile, - ) - buildMutualTlsChannel(apiFlags.apiTarget, clientCerts, apiFlags.apiCertHost) - .withShutdownTimeout(Duration.ofSeconds(1)) - } - - override fun run() {} - - companion object { - @JvmStatic - fun main(args: Array) = commandLineMain(Reporting(), args, DurationFormat.ISO_8601) - } -} - -/** - * Create, List and Get reporting set or report. - * - * Use the `help` command to see usage details. - */ -fun main(args: Array) = commandLineMain(Reporting(), args) diff --git a/src/main/proto/wfa/measurement/config/reporting/BUILD.bazel b/src/main/proto/wfa/measurement/config/reporting/BUILD.bazel index ff5989eb7bd..ebcb30f5080 100644 --- a/src/main/proto/wfa/measurement/config/reporting/BUILD.bazel +++ b/src/main/proto/wfa/measurement/config/reporting/BUILD.bazel @@ -27,17 +27,6 @@ kt_jvm_proto_library( deps = [":measurement_consumer_config_proto"], ) -proto_library( - name = "measurement_spec_config_proto", - srcs = ["measurement_spec_config.proto"], - strip_import_prefix = IMPORT_PREFIX, -) - -kt_jvm_proto_library( - name = "measurement_spec_config_kt_jvm_proto", - deps = [":measurement_spec_config_proto"], -) - proto_library( name = "metric_spec_config_proto", srcs = ["metric_spec_config.proto"], diff --git a/src/main/proto/wfa/measurement/config/reporting/measurement_spec_config.proto b/src/main/proto/wfa/measurement/config/reporting/measurement_spec_config.proto deleted file mode 100644 index 15cfe90fd12..00000000000 --- a/src/main/proto/wfa/measurement/config/reporting/measurement_spec_config.proto +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright 2023 The Cross-Media Measurement Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package wfa.measurement.config; - -option java_package = "org.wfanet.measurement.config.reporting"; -option java_multiple_files = true; - -// The configuration for MeasurementSpecs used in Measurements created by the -// Reporting server. -message MeasurementSpecConfig { - // Parameters for differential privacy (DP). - // - // For detail, refer to "Dwork, C. and Roth, A., 2014. The algorithmic - // foundations of differential privacy. Foundations and Trends in Theoretical - // Computer Science, 9(3-4), pp.211-407." - message DifferentialPrivacyParams { - double epsilon = 1; - double delta = 2; - } - - // Specifies a range of VIDs to be sampled. - message VidSamplingInterval { - message FixedStart { - // The start of the sampling interval in [0, 1) - float start = 1; - // The width of the sampling interval. - float width = 2; - } - - message RandomStart { - int32 num_vid_buckets = 1; - // The width of the sampling interval in [1, `num_vid_buckets`]. For - // example, if `num_vid_buckets` is 300, then this width can be in the - // range [1, 300]. If 100 is chosen, then the width in - // `vid_sampling_interval` of `MeasurementSpec` is 100/300. - int32 width = 2; - } - - // Defaults to start of 0 width of 1 if not set. - oneof start { - FixedStart fixed_start = 1; - RandomStart random_start = 2; - } - } - - message ReachSingleDataProvider { - // Differential privacy parameters for reach. - DifferentialPrivacyParams privacy_params = 1; - VidSamplingInterval vid_sampling_interval = 2; - } - ReachSingleDataProvider reach_single_data_provider = 1; - - message Reach { - // Differential privacy parameters for reach. - DifferentialPrivacyParams privacy_params = 1; - VidSamplingInterval vid_sampling_interval = 2; - } - Reach reach = 2; - - message ReachAndFrequencySingleDataProvider { - // Differential privacy parameters for reach. - DifferentialPrivacyParams reach_privacy_params = 1; - // Differential privacy parameters for frequency. - DifferentialPrivacyParams frequency_privacy_params = 2; - VidSamplingInterval vid_sampling_interval = 3; - } - ReachAndFrequencySingleDataProvider reach_and_frequency_single_data_provider = - 3; - - message ReachAndFrequency { - // Differential privacy parameters for reach. - DifferentialPrivacyParams reach_privacy_params = 1; - // Differential privacy parameters for frequency. - DifferentialPrivacyParams frequency_privacy_params = 2; - VidSamplingInterval vid_sampling_interval = 3; - } - ReachAndFrequency reach_and_frequency = 4; - - message Impression { - // Differential privacy parameters. - DifferentialPrivacyParams privacy_params = 1; - VidSamplingInterval vid_sampling_interval = 2; - } - Impression impression = 5; - - message Duration { - // Differential privacy parameters. - DifferentialPrivacyParams privacy_params = 1; - VidSamplingInterval vid_sampling_interval = 2; - } - Duration duration = 6; -} diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/BUILD.bazel b/src/main/proto/wfa/measurement/reporting/v1alpha/BUILD.bazel deleted file mode 100644 index 5df47be011f..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/BUILD.bazel +++ /dev/null @@ -1,148 +0,0 @@ -load("@rules_proto//proto:defs.bzl", "proto_library") -load( - "@wfa_rules_kotlin_jvm//kotlin:defs.bzl", - "kt_jvm_grpc_proto_library", - "kt_jvm_proto_library", -) - -package(default_visibility = ["//visibility:public"]) - -IMPORT_PREFIX = "/src/main/proto" - -# Resources and shared message types. - -proto_library( - name = "event_group_proto", - srcs = ["event_group.proto"], - strip_import_prefix = IMPORT_PREFIX, - deps = [ - "@com_google_googleapis//google/api:field_behavior_proto", - "@com_google_googleapis//google/api:resource_proto", - "@com_google_protobuf//:any_proto", - ], -) - -kt_jvm_proto_library( - name = "event_group_kt_jvm_proto", - deps = [":event_group_proto"], -) - -proto_library( - name = "metric", - srcs = ["metric.proto"], - strip_import_prefix = IMPORT_PREFIX, - deps = [ - "@com_google_googleapis//google/api:field_behavior_proto", - "@com_google_googleapis//google/api:resource_proto", - ], -) - -proto_library( - name = "report_proto", - srcs = ["report.proto"], - strip_import_prefix = IMPORT_PREFIX, - deps = [ - ":metric", - ":time_interval_proto", - "@com_google_googleapis//google/api:field_behavior_proto", - "@com_google_googleapis//google/api:resource_proto", - ], -) - -kt_jvm_proto_library( - name = "report_kt_jvm_proto", - deps = [":report_proto"], -) - -proto_library( - name = "reporting_set_proto", - srcs = ["reporting_set.proto"], - strip_import_prefix = IMPORT_PREFIX, - deps = [ - "@com_google_googleapis//google/api:field_behavior_proto", - "@com_google_googleapis//google/api:resource_proto", - ], -) - -kt_jvm_proto_library( - name = "reporting_set_kt_jvm_proto", - deps = [":reporting_set_proto"], -) - -proto_library( - name = "time_interval_proto", - srcs = ["time_interval.proto"], - strip_import_prefix = IMPORT_PREFIX, - deps = [ - "@com_google_googleapis//google/api:field_behavior_proto", - "@com_google_protobuf//:duration_proto", - "@com_google_protobuf//:timestamp_proto", - ], -) - -proto_library( - name = "page_token_proto", - srcs = ["page_token.proto"], - deps = [], -) - -kt_jvm_proto_library( - name = "page_token_kt_jvm_proto", - deps = [":page_token_proto"], -) - -# Services. - -proto_library( - name = "event_groups_service_proto", - srcs = ["event_groups_service.proto"], - strip_import_prefix = IMPORT_PREFIX, - deps = [ - ":event_group_proto", - "@com_google_googleapis//google/api:annotations_proto", - "@com_google_googleapis//google/api:client_proto", - "@com_google_googleapis//google/api:field_behavior_proto", - "@com_google_googleapis//google/api:resource_proto", - ], -) - -kt_jvm_grpc_proto_library( - name = "event_groups_service_kt_jvm_grpc_proto", - deps = [":event_groups_service_proto"], -) - -proto_library( - name = "reports_service_proto", - srcs = ["reports_service.proto"], - strip_import_prefix = IMPORT_PREFIX, - deps = [ - ":report_proto", - "@com_google_googleapis//google/api:annotations_proto", - "@com_google_googleapis//google/api:client_proto", - "@com_google_googleapis//google/api:field_behavior_proto", - "@com_google_googleapis//google/api:resource_proto", - ], -) - -kt_jvm_grpc_proto_library( - name = "reports_service_kt_jvm_grpc_proto", - deps = [":reports_service_proto"], -) - -proto_library( - name = "reporting_sets_service_proto", - srcs = ["reporting_sets_service.proto"], - strip_import_prefix = IMPORT_PREFIX, - deps = [ - ":reporting_set_proto", - "@com_google_googleapis//google/api:annotations_proto", - "@com_google_googleapis//google/api:client_proto", - "@com_google_googleapis//google/api:field_behavior_proto", - "@com_google_googleapis//google/api:resource_proto", - ], -) - -kt_jvm_grpc_proto_library( - name = "reporting_sets_service_kt_jvm_grpc_proto", - deps = [":reporting_sets_service_proto"], -) diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/event_group.proto b/src/main/proto/wfa/measurement/reporting/v1alpha/event_group.proto deleted file mode 100644 index a40be039697..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/event_group.proto +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.reporting.v1alpha; - -import "google/api/field_behavior.proto"; -import "google/api/resource.proto"; -import "google/protobuf/any.proto"; - -option java_package = "org.wfanet.measurement.reporting.v1alpha"; -option java_multiple_files = true; -option java_outer_classname = "EventGroupProto"; - -// A grouping of events defined by a `DataProvider`. For example, a single -// campaign or creative defined in a publisher's ad system. -message EventGroup { - option (google.api.resource) = { - type: "reporting.halo-cmm.org/EventGroup" - pattern: "measurementConsumers/{measurement_consumer}/dataProviders/{data_provider}/eventGroups/{event_group}" - }; - - // Resource name. - string name = 1; - - // Resource name of the `DataProvider` associated with this `EventGroup`. - string data_provider = 2 [ - (google.api.field_behavior) = IMMUTABLE, - (google.api.resource_reference).type = "halo.wfanet.org/DataProvider" - ]; - - // ID referencing the `EventGroup` in an external system, provided by the - // `DataProvider`. - // - // If set, this value must be unique among `EventGroup`s for the parent - // `DataProvider`. - string event_group_reference_id = 3; - - // The template that events associated with this `EventGroup` conform to. - message EventTemplate { - // The type of the Event Template. A fully-qualified protobuf message type. - string type = 1 [(google.api.field_behavior) = REQUIRED]; - } - // The `EventTemplate`s that events associated with this `EventGroup` conform - // to. - repeated EventTemplate event_templates = 4; - - // Wrapper for per-EDP Event Group metadata. - message Metadata { - // The resource name of the metadata descriptor. - string event_group_metadata_descriptor = 1 - [(google.api.resource_reference).type = - "halo.wfanet.org/EventGroupMetadataDescriptor"]; - - // The serialized value of the metadata message. - google.protobuf.Any metadata = 2; - } - - Metadata metadata = 5; -} diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/event_groups_service.proto b/src/main/proto/wfa/measurement/reporting/v1alpha/event_groups_service.proto deleted file mode 100644 index 79bc69d0a0c..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/event_groups_service.proto +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.reporting.v1alpha; - -import "google/api/client.proto"; -import "google/api/field_behavior.proto"; -import "google/api/annotations.proto"; -import "google/api/resource.proto"; -import "wfa/measurement/reporting/v1alpha/event_group.proto"; - -option java_package = "org.wfanet.measurement.reporting.v1alpha"; -option java_multiple_files = true; -option java_outer_classname = "EventGroupsServiceProto"; - -// Service for interacting with `EventGroup` resources. -service EventGroups { - // Lists `EventGroup`s. Results in a `PERMISSION_DENIED` error if attempting - // to list `EventGroup`s that the authenticated user does not have access to. - rpc ListEventGroups(ListEventGroupsRequest) - returns (ListEventGroupsResponse) { - option (google.api.http) = { - get: "/v1alpha/{parent=measurementConsumers/*/dataProviders/*}/eventGroups" - }; - option (google.api.method_signature) = "parent"; - } -} - -// Request message for `ListEventGroups` method. -message ListEventGroupsRequest { - // Resource name of the parent `DataProvider` under a given - // `MeasurementConsumer`, in the form - // `measurementConsumers/{measurement_consumer}/dataProviders/{data_provider}`. - // The wildcard ID (`-`) may be used in place of the `DataProvider` ID to list - // across `DataProvider`s, in which case a filter should be specified. - string parent = 1 [ - (google.api.field_behavior) = REQUIRED, - (google.api.resource_reference) = { - child_type: "reporting.halo-cmm.org/EventGroup" - } - ]; - - // The maximum number of `EventGroup`s to return. The service may return fewer - // than this value. If unspecified, at most 50 `EventGroup`s will be - // returned. The maximum value is 1000; values above 1000 will be coerced to - // 1000. - int32 page_size = 2; - - // A token from a previous call, specified to retrieve the next page. See - // https://aip.dev/158. - string page_token = 3; - - // Result filter. Raw CEL expression that is applied to a message which has a - // field for each event group template. - string filter = 4; -} - -// Response message for `ListEventGroups` method. -message ListEventGroupsResponse { - // The `EventGroup` resources. - repeated EventGroup event_groups = 1; - - // A token that can be specified in a subsequent call to retrieve the next - // page. See https://aip.dev/158. - string next_page_token = 2; -} diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/metric.proto b/src/main/proto/wfa/measurement/reporting/v1alpha/metric.proto deleted file mode 100644 index 2a18e2b4650..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/metric.proto +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.reporting.v1alpha; - -import "google/api/field_behavior.proto"; -import "google/api/resource.proto"; - -option java_package = "org.wfanet.measurement.reporting.v1alpha"; -option java_multiple_files = true; -option java_outer_classname = "MetricProto"; - -// Definition of a computed metric for a `Report` in terms of set operations. -message Metric { - // Parameters that are used to generate `Reach` metric. - message ReachParams {} - // Parameters that are used to generate `Frequency Histogram` metric. - message FrequencyHistogramParams { - // Maximum frequency to reveal in the histogram. - int32 maximum_frequency_per_user = 1; - } - // Parameters that are used to generate `Impression Count` metric. - message ImpressionCountParams { - // Setting the maximum frequency for each user is for noising the impression - // estimation with the noise proportional to maximum_frequency_per_user to - // guarantee epsilon-DP, i.e. the higher maximum_frequency_per_user, the - // larger the variance. On the other hand, if maximum_frequency_per_user is - // too small, there's truncation bias. Through optimization, the recommended - // value for maximum_frequency_per_user = 60 for the case with 1M audience - // size. - int32 maximum_frequency_per_user = 1; - } - // Parameters that are used to generate `Watch Duration` metric. - message WatchDurationParams { - // Maximum frequency per user that will be included in this measurement. - // - // Deprecated: Not supported by the CMMS. - int32 maximum_frequency_per_user = 1 [deprecated = true]; - // Maximum watch duration per user that will be included in this - // measurement. Recommended maximum_watch_duration_per_user = cap on the - // total watch duration of all the impressions of a user = 4000 sec for the - // case with 1M audience size. - int32 maximum_watch_duration_per_user = 2; - } - - // Types of metrics that can be selected to be in a `Report`. - // REQUIRED - oneof metric_type { - // The count of unique audiences reached given a set of event groups. - ReachParams reach = 1; - // The reach frequency histogram given a set of event groups. Currently, we - // only support union operations for frequency histograms. Any other - // operations on frequency histograms won't guarantee the result is a - // frequency histogram. - FrequencyHistogramParams frequency_histogram = 2; - // The impression count given a set of event groups. - ImpressionCountParams impression_count = 3; - // The watch duration given a set of event groups. - WatchDurationParams watch_duration = 4; - } - - // Whether the results for a given time interval is cumulative with those of - // previous time intervals. Only supported when using `PeriodicTimeInterval`. - bool cumulative = 6; - - // Represents a binary set operation. - message SetOperation { - // Types of set operators. - enum Type { - // Default value. This value is unused. - TYPE_UNSPECIFIED = 0; - // The set union operation. - UNION = 1; - // The set difference operation. - DIFFERENCE = 2; - // The set intersection operation. - INTERSECTION = 3; - } - // The type of set operator that will be applied on the operands. - Type type = 1 [(google.api.field_behavior) = REQUIRED]; - - // The object of a set operation. - message Operand { - oneof operand { - // Resource name of a `ReportingSet` describing a set operand. Note that - // the reporting set is constrained by the `EventGroupUniverse` defined - // in the `Report`. - string reporting_set = 1 [(google.api.resource_reference).type = - "reporting.halo-cmm.org/ReportingSet"]; - // Nested `SetOperation` to allow for expressions with more terms. - SetOperation operation = 2; - } - } - - // Left-hand side operand of the operation. - Operand lhs = 3 [(google.api.field_behavior) = REQUIRED]; - // Right-hand side operand of the operation. If not specified, implies the - // empty set. - Operand rhs = 4; - } - - // A `SetOperation` associated with a name. - message NamedSetOperation { - // Unique name of the set operation for display purposes and creation of - // measurement reference ID. The name should be unique for the SAME metric - // type among all metrics in a report. - string unique_name = 1 [(google.api.field_behavior) = REQUIRED]; - - // A set operation that specifies the set of event groups. - SetOperation set_operation = 2 [(google.api.field_behavior) = REQUIRED]; - } - - // A list of named `SetOperations` on which the same metric will be applied. - repeated NamedSetOperation set_operations = 7 - [(google.api.field_behavior) = REQUIRED]; -} diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/page_token.proto b/src/main/proto/wfa/measurement/reporting/v1alpha/page_token.proto deleted file mode 100644 index 159cee195ad..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/page_token.proto +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2023 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.reporting.v1alpha; - -option java_package = "org.wfanet.measurement.reporting.v1alpha"; -option java_multiple_files = true; - -message ListReportingSetsPageToken { - int32 page_size = 1; - string measurement_consumer_reference_id = 2; - message PreviousPageEnd { - string measurement_consumer_reference_id = 1; - fixed64 external_reporting_set_id = 2; - } - PreviousPageEnd last_reporting_set = 3; -} - -message ListReportsPageToken { - int32 page_size = 1; - string measurement_consumer_reference_id = 2; - message PreviousPageEnd { - string measurement_consumer_reference_id = 1; - fixed64 external_report_id = 2; - } - PreviousPageEnd last_report = 3; -} diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/report.proto b/src/main/proto/wfa/measurement/reporting/v1alpha/report.proto deleted file mode 100644 index d41acb92203..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/report.proto +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.reporting.v1alpha; - -import "google/api/field_behavior.proto"; -import "google/api/resource.proto"; -import "wfa/measurement/reporting/v1alpha/metric.proto"; -import "wfa/measurement/reporting/v1alpha/time_interval.proto"; - -option java_package = "org.wfanet.measurement.reporting.v1alpha"; -option java_multiple_files = true; -option java_outer_classname = "ReportProto"; - -// Resource representing a report. -message Report { - option (google.api.resource) = { - type: "reporting.halo-cmm.org/Report" - pattern: "measurementConsumers/{measurement_consumer}/reports/{report}" - }; - - // Resource name. - string name = 1; - - // Used as the prefix of the idempotency keys of internal measurements. This - // value must be unique among `Report`s for the parent `MeasurementConsumer`. - // TODO(@riemanli) Moved the idempotency key to request messages. - string report_idempotency_key = 2 [(google.api.field_behavior) = REQUIRED]; - - // Representation of a Measurement Consumer entity. - string measurement_consumer = 3 [ - (google.api.field_behavior) = REQUIRED, - (google.api.resource_reference) = { - type: "halo.wfanet.org/MeasurementConsumer" - } - ]; - - // Define the universe for all sets of event groups in `metrics`. - message EventGroupUniverse { - // Map entry of `EventGroup` to filter predicate. - message EventGroupEntry { - // Key of the map entry, which is an `EventGroup` resource name. - string key = 1 [ - (google.api.field_behavior) = REQUIRED, - (google.api.resource_reference) = { - type: "reporting.halo-cmm.org/EventGroup" - } - ]; - - // Filter predicate in CEL. If unspecified, evaluates to `true`. - string value = 2; - } - - // A Map with `EventGroup`s as the keys and filter predicates as the values. - repeated EventGroupEntry event_group_entries = 1 - [(google.api.field_behavior) = REQUIRED]; - } - - // The universe that defines all sets of event groups in `metrics`. Events - // from event groups that aren't listed here won't be included. - EventGroupUniverse event_group_universe = 4 - [(google.api.field_behavior) = REQUIRED]; - - // Types of time intervals for metric aggregation. - // REQUIRED - oneof time { - // A list of time intervals with different start times and end times. - TimeIntervals time_intervals = 5; - // A series of time intervals with the same length. - PeriodicTimeInterval periodic_time_interval = 6; - } - - // The metrics that are included in `Report`. - repeated Metric metrics = 7 [(google.api.field_behavior) = REQUIRED]; - - // Possible states of a `Report`. - enum State { - // Default value. This value is unused. - STATE_UNSPECIFIED = 0; - // Computation is running. - RUNNING = 1; - // Completed successfully. Terminal state. - SUCCEEDED = 2; - // Completed with failure. Terminal state. - FAILED = 3; - } - - // Report state. - State state = 8 [(google.api.field_behavior) = OUTPUT_ONLY]; - - // Set operation calculations are done for each set operation for a given - // `Metric`, which is column by column. - message Result { - // A column with a header and values. - message Column { - // The header of the column. - string column_header = 1; - // Must be in the same order as `row_headers`. - repeated double set_operations = 2; - } - // The measurement result of Reach, ImpressionCount, or WatchDuration - message ScalarTable { - // Must be in the same order as column's set_operations. - repeated string row_headers = 1; - // Must be in the same order as `row_headers`. - repeated Column columns = 2; - } - // For Reach, ImpressionCount, or WatchDuration - ScalarTable scalar_table = 1; - - // HistogramTable has an additional field frequency for each row, thus - // different from ScalarTable. - message HistogramTable { - // Combine header and frequency as Row. - message Row { - // Must be in the same order as column's set_operations. - string row_header = 1; - // frequency of the row - int32 frequency = 2; - } - // Must be in the same order as column's set_operations. - repeated Row rows = 1; - // Must be in the same order as row's `row_headers`. - repeated Column columns = 2; - } - // Each HistogramTable contains the result of one frequency histogram - // metric. - repeated HistogramTable histogram_tables = 2; - } - - // Set only when `state` is `SUCCEEDED`. - Result result = 9 [(google.api.field_behavior) = OUTPUT_ONLY]; -} diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/reporting_set.proto b/src/main/proto/wfa/measurement/reporting/v1alpha/reporting_set.proto deleted file mode 100644 index e1c5ce1ce45..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/reporting_set.proto +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.reporting.v1alpha; - -import "google/api/field_behavior.proto"; -import "google/api/resource.proto"; - -option java_package = "org.wfanet.measurement.reporting.v1alpha"; -option java_multiple_files = true; -option java_outer_classname = "ReportingSetProto"; - -// Resource that describes a set of events that can be used as an operand for a -// `SetOperation`. -message ReportingSet { - option (google.api.resource) = { - type: "reporting.halo-cmm.org/ReportingSet" - pattern: "measurementConsumers/{measurement_consumer}/reportingSets/{reporting_set}" - }; - - // Resource name. - string name = 1; - - // Set of EventGroup resource names. - repeated string event_groups = 2 [ - (google.api.resource_reference).type = "reporting.halo-cmm.org/EventGroup", - (google.api.field_behavior) = REQUIRED, - (google.api.field_behavior) = IMMUTABLE - ]; - - // CEL filter predicate that applies to all `event_groups`. - // - // This filter and the one in the `event_group_filters` form a conjunction - // that specifies which impressions are included in this set. If - // unspecified, evaluates to `true`. - string filter = 3 [(google.api.field_behavior) = IMMUTABLE]; - - // Human-readable name for display purposes. - string display_name = 4; -} diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/reporting_sets_service.proto b/src/main/proto/wfa/measurement/reporting/v1alpha/reporting_sets_service.proto deleted file mode 100644 index 721df2f02fc..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/reporting_sets_service.proto +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.reporting.v1alpha; - -import "google/api/client.proto"; -import "google/api/field_behavior.proto"; -import "google/api/annotations.proto"; -import "google/api/resource.proto"; -import "wfa/measurement/reporting/v1alpha/reporting_set.proto"; - -option java_package = "org.wfanet.measurement.reporting.v1alpha"; -option java_multiple_files = true; -option java_outer_classname = "ReportingSetsServiceProto"; - -// Service for interacting with `ReportingSet` resources. -service ReportingSets { - // Lists `ReportingSet`s. - rpc ListReportingSets(ListReportingSetsRequest) - returns (ListReportingSetsResponse) { - option (google.api.http) = { - get: "/v1alpha/{parent=measurementConsumers/*}/reportingSets" - }; - option (google.api.method_signature) = "parent"; - } - - // Creates a `ReportingSet`. - rpc CreateReportingSet(CreateReportingSetRequest) returns (ReportingSet) { - option (google.api.http) = { - post: "/v1alpha/{parent=measurementConsumers/*}/reportingSets" - body: "reporting_set" - }; - option (google.api.method_signature) = "parent,reporting_set"; - } -} - -// Request message for `ListReportingSet` method. -message ListReportingSetsRequest { - // Format: measurementConsumers/{measurement_consumer} - string parent = 1 [ - (google.api.field_behavior) = REQUIRED, - (google.api.resource_reference) = { - child_type: "reporting.halo-cmm.org/ReportingSet" - } - ]; - - // The maximum number of reportingSets to return. The service may return fewer - // than this value. - // If unspecified, at most 50 reportingSets will be returned. - // The maximum value is 1000; values above 1000 will be coerced to 1000. - int32 page_size = 2; - - // A page token, received from a previous `ListReportingSets` call. - // Provide this to retrieve the subsequent page. - // - // When paginating, all other parameters provided to `ListReportingSets` must - // match the call that provided the page token. - string page_token = 3; -} - -// Response message for `ListReportingSet` method. -message ListReportingSetsResponse { - // The reportingSets from the specified measurement consumer. - repeated ReportingSet reporting_sets = 1; - - // A token, which can be sent as `page_token` to retrieve the next page. - // If this field is omitted, there are no subsequent pages. - string next_page_token = 2; -} - -// Request message for `CreateReportingSet` method. -message CreateReportingSetRequest { - // The parent resource where this reportingSet will be created. - // Format: measurementConsumers/{measurement_consumer} - string parent = 1 [ - (google.api.field_behavior) = REQUIRED, - (google.api.resource_reference) = { - child_type: "reporting.halo-cmm.org/ReportingSet" - } - ]; - - // The ReportingSet to create. - ReportingSet reporting_set = 2 [(google.api.field_behavior) = REQUIRED]; -} diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/reports_service.proto b/src/main/proto/wfa/measurement/reporting/v1alpha/reports_service.proto deleted file mode 100644 index e36ce26c3ca..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/reports_service.proto +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.reporting.v1alpha; - -import "google/api/client.proto"; -import "google/api/field_behavior.proto"; -import "google/api/annotations.proto"; -import "google/api/resource.proto"; -import "wfa/measurement/reporting/v1alpha/report.proto"; - -option java_package = "org.wfanet.measurement.reporting.v1alpha"; -option java_multiple_files = true; -option java_outer_classname = "ReportsServiceProto"; - -// Service for interacting with `Report` resources. -service Reports { - // Returns the `Report` with the specified resource key. - rpc GetReport(GetReportRequest) returns (Report) { - option (google.api.http) = { - get: "/v1alpha/{name=measurementConsumers/*/reports/*}" - }; - option (google.api.method_signature) = "name"; - } - - // Lists `Report`s. - rpc ListReports(ListReportsRequest) returns (ListReportsResponse) { - option (google.api.http) = { - get: "/v1alpha/{parent=measurementConsumers/*}/reports" - }; - option (google.api.method_signature) = "parent"; - } - - // Creates a `Report`. - rpc CreateReport(CreateReportRequest) returns (Report) { - option (google.api.http) = { - post: "/v1alpha/{parent=measurementConsumers/*}/reports" - body: "report" - }; - option (google.api.method_signature) = "parent,report"; - } -} - -// Request message for `GetReport` method. -message GetReportRequest { - // The name of the report to retrieve. - // Format: measurementConsumers/{measurement_consumer}/reports/{report} - string name = 1 [ - (google.api.field_behavior) = REQUIRED, - (google.api.resource_reference) = { type: "reporting.halo-cmm.org/Report" } - ]; -} - -// Request message for `ListReports` method. -message ListReportsRequest { - // Format: measurementConsumers/{measurement_consumer} - string parent = 1 [ - (google.api.field_behavior) = REQUIRED, - (google.api.resource_reference) = { - child_type: "reporting.halo-cmm.org/Report" - } - ]; - - // The maximum number of reports to return. The service may return fewer than - // this value. - // If unspecified, at most 50 reports will be returned. - // The maximum value is 1000; values above 1000 will be coerced to 1000. - int32 page_size = 2; - - // A page token, received from a previous `ListReports` call. - // Provide this to retrieve the subsequent page. - // - // When paginating, all other parameters provided to `ListReports` must match - // the call that provided the page token. - string page_token = 3; -} - -// Response message for `ListReports` method. -message ListReportsResponse { - // The reports from the specified measurement consumer. - repeated Report reports = 1; - - // A token, which can be sent as `page_token` to retrieve the next page. - // If this field is omitted, there are no subsequent pages. - string next_page_token = 2; -} - -// Request message for `CreateReport` method. -message CreateReportRequest { - // The parent resource where this report will be created. - // Format: measurementConsumers/{measurement_consumer} - string parent = 1 [ - (google.api.field_behavior) = REQUIRED, - (google.api.resource_reference) = { - child_type: "reporting.halo-cmm.org/Report" - } - ]; - - // The report to create. - Report report = 2 [(google.api.field_behavior) = REQUIRED]; -} diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/time_interval.proto b/src/main/proto/wfa/measurement/reporting/v1alpha/time_interval.proto deleted file mode 100644 index ef7ea15f083..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/time_interval.proto +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.reporting.v1alpha; - -import "google/api/field_behavior.proto"; -import "google/protobuf/duration.proto"; -import "google/protobuf/timestamp.proto"; - -option java_package = "org.wfanet.measurement.reporting.v1alpha"; -option java_multiple_files = true; -option java_outer_classname = "TimeIntervalProto"; - -// A time interval. -message TimeInterval { - // Start of the time interval, inclusive. - google.protobuf.Timestamp start_time = 1 - [(google.api.field_behavior) = REQUIRED]; - // End of the time interval, exclusive. This must be later than the start - // time. - google.protobuf.Timestamp end_time = 2 - [(google.api.field_behavior) = REQUIRED]; -} - -// A list of time intervals with different start times and end times. -message TimeIntervals { - // A list of time intervals. - repeated TimeInterval time_intervals = 1 - [(google.api.field_behavior) = REQUIRED]; -} - -// A series of time intervals with the same length. -message PeriodicTimeInterval { - // Start of the first time interval, inclusive. - google.protobuf.Timestamp start_time = 1 - [(google.api.field_behavior) = REQUIRED]; - - // Increment for each time interval. The first interval will be [start_time, - // start_time + increment), the second [start_time + increment, start_time + - // increment * 2) and so forth. - // - // TODO(@SanjayVas): Consider whether we want to use civil time instead. - google.protobuf.Duration increment = 2 - [(google.api.field_behavior) = REQUIRED]; - - // Number of intervals. - int32 interval_count = 3 [(google.api.field_behavior) = REQUIRED]; -} diff --git a/src/main/terraform/gcloud/cmms/reporting.tf b/src/main/terraform/gcloud/cmms/reporting.tf deleted file mode 100644 index d0e416cba56..00000000000 --- a/src/main/terraform/gcloud/cmms/reporting.tf +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2023 The Cross-Media Measurement Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -module "reporting_cluster" { - source = "../modules/cluster" - - name = local.reporting_cluster_name - location = local.cluster_location - release_channel = var.cluster_release_channel - secret_key = module.common.cluster_secret_key -} - -module "reporting_default_node_pool" { - source = "../modules/node-pool" - - name = "default" - cluster = module.reporting_cluster.cluster - service_account = module.common.cluster_service_account - machine_type = "e2-small" - max_node_count = 8 -} - -module "reporting" { - source = "../modules/reporting" - - iam_service_account_name = "reporting-internal" - postgres_instance = google_sql_database_instance.postgres - postgres_database_name = "reporting" -} - -resource "google_compute_address" "reporting_v1alpha" { - name = "reporting-v1alpha" -} diff --git a/src/test/kotlin/org/wfanet/measurement/integration/deploy/common/postgres/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/integration/deploy/common/postgres/BUILD.bazel deleted file mode 100644 index b0ef7121c51..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/integration/deploy/common/postgres/BUILD.bazel +++ /dev/null @@ -1,21 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_test") - -kt_jvm_test( - name = "PostgresInProcessLifeOfAReportIntegrationTest", - srcs = ["PostgresInProcessLifeOfAReportIntegrationTest.kt"], - tags = [ - "cpu:2", - "no-remote-exec", - ], - test_class = "org.wfanet.measurement.integration.deploy.common.postgres.PostgresInProcessLifeOfAReportIntegrationTest", - runtime_deps = [ - "@wfa_common_jvm//imports/java/org/yaml:snakeyaml", - ], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/integration/common/reporting:in_process_life_of_a_report_integration_test", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres:services", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing:database_provider", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", - ], -) diff --git a/src/test/kotlin/org/wfanet/measurement/integration/deploy/common/postgres/PostgresInProcessLifeOfAReportIntegrationTest.kt b/src/test/kotlin/org/wfanet/measurement/integration/deploy/common/postgres/PostgresInProcessLifeOfAReportIntegrationTest.kt deleted file mode 100644 index 35c3dfad5de..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/integration/deploy/common/postgres/PostgresInProcessLifeOfAReportIntegrationTest.kt +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.integration.deploy.common.postgres - -import java.time.Clock -import org.junit.ClassRule -import org.wfanet.measurement.common.db.r2dbc.postgres.testing.PostgresDatabaseProviderRule -import org.wfanet.measurement.common.identity.RandomIdGenerator -import org.wfanet.measurement.integration.common.reporting.InProcessLifeOfAReportIntegrationTest -import org.wfanet.measurement.reporting.deploy.common.server.postgres.PostgresServices -import org.wfanet.measurement.reporting.deploy.postgres.testing.Schemata - -/** Implementation of [InProcessLifeOfAReportIntegrationTest] for Postgres. */ -class PostgresInProcessLifeOfAReportIntegrationTest : InProcessLifeOfAReportIntegrationTest() { - override val reportingServerDataServices by lazy { - PostgresServices.create(RandomIdGenerator(Clock.systemUTC()), databaseProvider.createDatabase()) - } - - companion object { - @get:ClassRule - @JvmStatic - val databaseProvider = PostgresDatabaseProviderRule(Schemata.REPORTING_CHANGELOG_PATH) - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/BUILD.bazel deleted file mode 100644 index 95b258038e9..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/BUILD.bazel +++ /dev/null @@ -1,21 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_test") - -filegroup( - name = "textproto_files", - srcs = glob(["*.textproto"]), -) - -kt_jvm_test( - name = "EncryptionKeyPairMapTest", - srcs = ["EncryptionKeyPairMapTest.kt"], - data = [ - "//src/main/k8s/testing/secretfiles:secret_files", - "//src/test/kotlin/org/wfanet/measurement/reporting/deploy/common:textproto_files", - ], - test_class = "org.wfanet.measurement.reporting.deploy.common.EncryptionKeyPairMapTest", - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common:encryption_key_pair_map", - "@wfa_common_jvm//imports/java/com/google/common/truth", - "@wfa_common_jvm//imports/kotlin/kotlin/test", - ], -) diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/EncryptionKeyPairMapTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/EncryptionKeyPairMapTest.kt deleted file mode 100644 index 3c3c4c0a771..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/EncryptionKeyPairMapTest.kt +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.common - -import com.google.common.truth.Truth.assertThat -import com.google.protobuf.ByteString -import com.google.protobuf.kotlin.toByteStringUtf8 -import java.nio.file.Path -import java.nio.file.Paths -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.wfanet.measurement.common.crypto.PrivateKeyHandle -import org.wfanet.measurement.common.crypto.tink.TinkPublicKeyHandle -import org.wfanet.measurement.common.getRuntimePath -import org.wfanet.measurement.common.readByteString -import picocli.CommandLine -import picocli.CommandLine.Mixin - -private val SECRETS_DIR: Path = - getRuntimePath( - Paths.get("wfa_measurement_system", "src", "main", "k8s", "testing", "secretfiles") - )!! - -private val ENCRYPTION_KEY_PAIR_MAP: Path = - getRuntimePath( - Paths.get( - "wfa_measurement_system", - "src", - "test", - "kotlin", - "org", - "wfanet", - "measurement", - "reporting", - "deploy", - "common", - "key_pair_map.textproto", - ) - )!! - -private val PUBLIC_KEY_FILE_1 = SECRETS_DIR.resolve("mc_enc_public.tink").toFile() -private val PUBLIC_KEY_1 = PUBLIC_KEY_FILE_1.readByteString() -private val PUBLIC_KEY_FILE_2 = SECRETS_DIR.resolve("edp1_enc_public.tink").toFile() -private val PUBLIC_KEY_2 = PUBLIC_KEY_FILE_2.readByteString() -private val PUBLIC_KEY_FILE_3 = SECRETS_DIR.resolve("edp2_enc_public.tink").toFile() -private val PUBLIC_KEY_3 = PUBLIC_KEY_FILE_3.readByteString() -private val NON_EXISTENT_PUBLIC_KEY = "non existent public key".toByteStringUtf8() -private val MEASUREMENT_CONSUMER1 = "measurement_consumer1" -private val MEASUREMENT_CONSUMER2 = "measurement_consumer2" -private val NON_EXISTENT_MEASUREMENT_CONSUMER = "non_existent_measurement_consumer" -private val PLAIN_TEXT = "This is plain text".toByteStringUtf8() - -@RunWith(JUnit4::class) -class EncryptionKeyPairMapTest { - - private fun findKeyPair( - principal: String, - publicKey: ByteString, - keyPairMap: Map>>, - ) = - keyPairMap[principal]?.find { (key, _): Pair -> key == publicKey } - - @Test - fun `keyPairMap returns corresponding private keys`() { - val args = - arrayOf("--key-pair-dir=$SECRETS_DIR", "--key-pair-config-file=$ENCRYPTION_KEY_PAIR_MAP") - - runTest(args) { keyPairMap -> - val findPrivateKey = { principal: String, publicKey: ByteString -> - findKeyPair(principal, publicKey, keyPairMap)?.second - } - verifyKeyPair( - PUBLIC_KEY_1, - requireNotNull(findPrivateKey(MEASUREMENT_CONSUMER1, PUBLIC_KEY_1)), - ) - verifyKeyPair( - PUBLIC_KEY_2, - requireNotNull(findPrivateKey(MEASUREMENT_CONSUMER1, PUBLIC_KEY_2)), - ) - verifyKeyPair( - PUBLIC_KEY_3, - requireNotNull(findPrivateKey(MEASUREMENT_CONSUMER2, PUBLIC_KEY_3)), - ) - } - } - - @Test - fun `keyPairMap returns null when private key is not found`() { - val args = - arrayOf("--key-pair-dir=$SECRETS_DIR", "--key-pair-config-file=$ENCRYPTION_KEY_PAIR_MAP") - - runTest(args) { keyPairMap -> - assertThat(findKeyPair(MEASUREMENT_CONSUMER1, NON_EXISTENT_PUBLIC_KEY, keyPairMap)).isNull() - } - } - - @Test - fun `keyPairMap returns null when principal is not found`() { - val args = - arrayOf("--key-pair-dir=$SECRETS_DIR", "--key-pair-config-file=$ENCRYPTION_KEY_PAIR_MAP") - - runTest(args) { keyPairMap -> - assertThat(findKeyPair(NON_EXISTENT_MEASUREMENT_CONSUMER, PUBLIC_KEY_1, keyPairMap)).isNull() - } - } - - private class KeyPairMapWrapper( - val verifyBlock: (keyPairs: Map>>) -> Unit - ) : Runnable { - @Mixin lateinit var encryptionKeyPairMap: EncryptionKeyPairMap - - override fun run() { - verifyBlock(encryptionKeyPairMap.keyPairs) - } - } - - private fun runTest( - args: Array, - verifyBlock: (Map>>) -> Unit, - ) { - val returnCode = CommandLine(KeyPairMapWrapper(verifyBlock)).execute(*args) - assertThat(returnCode).isEqualTo(0) - } - - private fun verifyKeyPair(publicKeyData: ByteString, privateKeyHandle: PrivateKeyHandle) { - val publicKeyHandle = TinkPublicKeyHandle(publicKeyData) - val encryptedText = publicKeyHandle.hybridEncrypt(PLAIN_TEXT) - val decryptedText = privateKeyHandle.hybridDecrypt(encryptedText) - assertThat(decryptedText).isEqualTo(PLAIN_TEXT) - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/key_pair_map.textproto b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/key_pair_map.textproto deleted file mode 100644 index 3a8e52ec9ce..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/key_pair_map.textproto +++ /dev/null @@ -1,26 +0,0 @@ -# proto-file: -# src/main/proto/wfa/measurement/config/reporting/encryption_key_pair_config.proto -# proto-message: EncryptionKeyPairConfig - -principal_key_pairs { - principal: "measurement_consumer1" - key_pairs { - public_key_file: "mc_enc_public.tink" - private_key_file: "mc_enc_private.tink" - } - key_pairs { - public_key_file: "edp1_enc_public.tink" - private_key_file: "edp1_enc_private.tink" - } -} -principal_key_pairs { - principal: "measurement_consumer2" - key_pairs { - public_key_file: "edp2_enc_public.tink" - private_key_file: "edp2_enc_private.tink" - } - key_pairs { - public_key_file: "edp3_enc_public.tink" - private_key_file: "edp3_enc_private.tink" - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/config/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/config/BUILD.bazel deleted file mode 100644 index 6a543aadc48..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/config/BUILD.bazel +++ /dev/null @@ -1,12 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_test") - -kt_jvm_test( - name = "MeasurementSpecConfigValidatorTest", - srcs = ["MeasurementSpecConfigValidatorTest.kt"], - test_class = "org.wfanet.measurement.reporting.deploy.config.MeasurementSpecConfigValidatorTest", - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/config:measurement_spec_config_validator", - "@wfa_common_jvm//imports/java/com/google/common/truth", - "@wfa_common_jvm//imports/kotlin/kotlin/test", - ], -) diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/config/MeasurementSpecConfigValidatorTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/config/MeasurementSpecConfigValidatorTest.kt deleted file mode 100644 index a13494f6e3a..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/config/MeasurementSpecConfigValidatorTest.kt +++ /dev/null @@ -1,295 +0,0 @@ -/* - * Copyright 2023 The Cross-Media Measurement Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.wfanet.measurement.reporting.deploy.config - -import com.google.common.truth.Truth.assertThat -import kotlin.test.assertFailsWith -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.wfanet.measurement.config.reporting.MeasurementSpecConfigKt -import org.wfanet.measurement.config.reporting.copy -import org.wfanet.measurement.config.reporting.measurementSpecConfig - -@RunWith(JUnit4::class) -class MeasurementSpecConfigValidatorTest { - @Test - fun `validate throws no exception when config is valid`() { - MEASUREMENT_SPEC_CONFIG.validate() - } - - @Test - fun `validate throws ILLEGAL_STATE_EXCEPTION when epsilon is 0`() { - val invalidMeasurementSpecConfig = - MEASUREMENT_SPEC_CONFIG.copy { - reachSingleDataProvider = - MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.copy { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.0 - delta = 0.0 - } - } - } - val exception = - assertFailsWith { invalidMeasurementSpecConfig.validate() } - assertThat(exception.message).contains("privacy_params") - } - - @Test - fun `validate throws ILLEGAL_STATE_EXCEPTION when delta is negaitve`() { - val invalidMeasurementSpecConfig = - MEASUREMENT_SPEC_CONFIG.copy { - reachSingleDataProvider = - MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.copy { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 1.0 - delta = -1.0 - } - } - } - val exception = - assertFailsWith { invalidMeasurementSpecConfig.validate() } - assertThat(exception.message).contains("privacy_params") - } - - @Test - fun `validate throws ILLEGAL_STATE_EXCEPTION when fixed_start width is 0`() { - val invalidMeasurementSpecConfig = - MEASUREMENT_SPEC_CONFIG.copy { - reachSingleDataProvider = - MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.copy { - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { width = 0.0f } - } - } - } - val exception = - assertFailsWith { invalidMeasurementSpecConfig.validate() } - assertThat(exception.message).contains("vid_sampling_interval") - } - - @Test - fun `validate throws ILLEGAL_STATE_EXCEPTION when fixed_start start is negative`() { - val invalidMeasurementSpecConfig = - MEASUREMENT_SPEC_CONFIG.copy { - reachSingleDataProvider = - MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.copy { - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { start = -1.0f } - } - } - } - val exception = - assertFailsWith { invalidMeasurementSpecConfig.validate() } - assertThat(exception.message).contains("vid_sampling_interval") - } - - @Test - fun `validate throws ILLEGAL_STATE_EXCEPTION when fixed_start interval is more than 1`() { - val invalidMeasurementSpecConfig = - MEASUREMENT_SPEC_CONFIG.copy { - reachSingleDataProvider = - MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.copy { - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { width = 1.1f } - } - } - } - val exception = - assertFailsWith { invalidMeasurementSpecConfig.validate() } - assertThat(exception.message).contains("vid_sampling_interval") - } - - @Test - fun `validate throws ILLEGAL_STATE_EXCEPTION when random_start width is negative`() { - val invalidMeasurementSpecConfig = - MEASUREMENT_SPEC_CONFIG.copy { - reachSingleDataProvider = - MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.copy { - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - randomStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.randomStart { - width = -5 - numVidBuckets = 5 - } - } - } - } - val exception = - assertFailsWith { invalidMeasurementSpecConfig.validate() } - assertThat(exception.message).contains("vid_sampling_interval") - } - - @Test - fun `validate throws ILLEGAL_STATE_EXCEPTION when random_start num_vid_buckets is 0`() { - val invalidMeasurementSpecConfig = - MEASUREMENT_SPEC_CONFIG.copy { - reachSingleDataProvider = - MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.copy { - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - randomStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.randomStart { numVidBuckets = 0 } - } - } - } - val exception = - assertFailsWith { invalidMeasurementSpecConfig.validate() } - assertThat(exception.message).contains("vid_sampling_interval") - } - - @Test - fun `validate throws ILLEGAL_STATE_EXCEPTION when random_start width exceeds vid buckets`() { - val invalidMeasurementSpecConfig = - MEASUREMENT_SPEC_CONFIG.copy { - reachSingleDataProvider = - MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.copy { - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - randomStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.randomStart { - width = 6 - numVidBuckets = 5 - } - } - } - } - val exception = - assertFailsWith { invalidMeasurementSpecConfig.validate() } - assertThat(exception.message).contains("vid_sampling_interval") - } - - companion object { - private val MEASUREMENT_SPEC_CONFIG = measurementSpecConfig { - reachSingleDataProvider = - MeasurementSpecConfigKt.reachSingleDataProvider { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.000207 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - reach = - MeasurementSpecConfigKt.reach { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.0007444 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - randomStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.randomStart { - width = 256 - numVidBuckets = 300 - } - } - } - reachAndFrequencySingleDataProvider = - MeasurementSpecConfigKt.reachAndFrequencySingleDataProvider { - reachPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.000207 - delta = 1e-15 - } - frequencyPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.004728 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - reachAndFrequency = - MeasurementSpecConfigKt.reachAndFrequency { - reachPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.0007444 - delta = 1e-15 - } - frequencyPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.014638 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - randomStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.randomStart { - width = 256 - numVidBuckets = 300 - } - } - } - impression = - MeasurementSpecConfigKt.impression { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.003592 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - duration = - MeasurementSpecConfigKt.duration { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.007418 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - } - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/BUILD.bazel deleted file mode 100644 index ef43a61efeb..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/BUILD.bazel +++ /dev/null @@ -1,58 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_test") - -kt_jvm_test( - name = "PostgresMeasurementsServiceTest", - srcs = ["PostgresMeasurementsServiceTest.kt"], - tags = [ - "cpu:2", - "no-remote-exec", - ], - test_class = "org.wfanet.measurement.reporting.deploy.postgres.PostgresMeasurementsServiceTest", - runtime_deps = [ - "@wfa_common_jvm//imports/java/org/yaml:snakeyaml", - ], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres:services", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing:database_provider", - ], -) - -kt_jvm_test( - name = "PostgresReportingSetsServiceTest", - srcs = ["PostgresReportingSetsServiceTest.kt"], - tags = [ - "cpu:2", - "no-remote-exec", - ], - test_class = "org.wfanet.measurement.reporting.deploy.postgres.PostgresReportingSetsServiceTest", - runtime_deps = [ - "@wfa_common_jvm//imports/java/org/yaml:snakeyaml", - ], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres:services", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing:database_provider", - ], -) - -kt_jvm_test( - name = "PostgresReportsServiceTest", - srcs = ["PostgresReportsServiceTest.kt"], - tags = [ - "cpu:2", - "no-remote-exec", - ], - test_class = "org.wfanet.measurement.reporting.deploy.postgres.PostgresReportsServiceTest", - runtime_deps = [ - "@wfa_common_jvm//imports/java/org/yaml:snakeyaml", - ], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres:services", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing:database_provider", - ], -) diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresMeasurementsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresMeasurementsServiceTest.kt deleted file mode 100644 index 37ae12e69cb..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresMeasurementsServiceTest.kt +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres - -import org.junit.ClassRule -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresDatabaseClient -import org.wfanet.measurement.common.db.r2dbc.postgres.testing.PostgresDatabaseProviderRule -import org.wfanet.measurement.common.identity.IdGenerator -import org.wfanet.measurement.reporting.deploy.postgres.testing.Schemata -import org.wfanet.measurement.reporting.service.internal.testing.MeasurementsServiceTest - -@RunWith(JUnit4::class) -class PostgresMeasurementsServiceTest : MeasurementsServiceTest() { - - override fun newServices(idGenerator: IdGenerator): Services { - val client: PostgresDatabaseClient = databaseProvider.createDatabase() - return Services( - PostgresMeasurementsService(idGenerator, client), - PostgresReportsService(idGenerator, client), - ) - } - - companion object { - @get:ClassRule - @JvmStatic - val databaseProvider = PostgresDatabaseProviderRule(Schemata.REPORTING_CHANGELOG_PATH) - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportingSetsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportingSetsServiceTest.kt deleted file mode 100644 index 538b69a136d..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportingSetsServiceTest.kt +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres - -import org.junit.ClassRule -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresDatabaseClient -import org.wfanet.measurement.common.db.r2dbc.postgres.testing.PostgresDatabaseProviderRule -import org.wfanet.measurement.common.identity.IdGenerator -import org.wfanet.measurement.reporting.deploy.postgres.testing.Schemata -import org.wfanet.measurement.reporting.service.internal.testing.ReportingSetsServiceTest - -@RunWith(JUnit4::class) -class PostgresReportingSetsServiceTest : ReportingSetsServiceTest() { - override fun newService(idGenerator: IdGenerator): PostgresReportingSetsService { - val dbClient: PostgresDatabaseClient = databaseProvider.createDatabase() - return PostgresReportingSetsService(idGenerator, dbClient) - } - - companion object { - @get:ClassRule - @JvmStatic - val databaseProvider = PostgresDatabaseProviderRule(Schemata.REPORTING_CHANGELOG_PATH) - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportsServiceTest.kt deleted file mode 100644 index d928a4cfcd9..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportsServiceTest.kt +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres - -import org.junit.ClassRule -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresDatabaseClient -import org.wfanet.measurement.common.db.r2dbc.postgres.testing.PostgresDatabaseProviderRule -import org.wfanet.measurement.common.identity.IdGenerator -import org.wfanet.measurement.reporting.deploy.postgres.testing.Schemata -import org.wfanet.measurement.reporting.service.internal.testing.ReportsServiceTest - -@RunWith(JUnit4::class) -class PostgresReportsServiceTest : ReportsServiceTest() { - override fun newServices(idGenerator: IdGenerator): Services { - val client: PostgresDatabaseClient = databaseProvider.createDatabase() - return Services( - PostgresReportsService(idGenerator, client), - PostgresMeasurementsService(idGenerator, client), - PostgresReportingSetsService(idGenerator, client), - ) - } - - companion object { - @get:ClassRule - @JvmStatic - val databaseProvider = PostgresDatabaseProviderRule(Schemata.REPORTING_CHANGELOG_PATH) - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/BUILD.bazel index e601ec97363..bfad3ea9c46 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/BUILD.bazel @@ -10,7 +10,7 @@ kt_jvm_test( "//src/main/proto/wfa/measurement/api/v2alpha:event_group_metadata_descriptors_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/api/v2alpha/event_group_metadata/testing:test_metadata_message_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha/event_group_metadata/testing:test_parent_metadata_message_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:event_group_kt_jvm_proto", + "//src/main/proto/wfa/measurement/reporting/v2alpha:event_group_kt_jvm_proto", "@wfa_common_jvm//imports/java/com/google/common/truth", "@wfa_common_jvm//imports/kotlin/kotlin/test", "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/CelEnvProviderTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/CelEnvProviderTest.kt index 3d76581af90..fa6a2397805 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/CelEnvProviderTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/CelEnvProviderTest.kt @@ -53,9 +53,9 @@ import org.wfanet.measurement.common.ProtoReflection import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule import org.wfanet.measurement.common.grpc.testing.mockService import org.wfanet.measurement.common.testing.verifyProtoArgument -import org.wfanet.measurement.reporting.v1alpha.EventGroup -import org.wfanet.measurement.reporting.v1alpha.EventGroupKt -import org.wfanet.measurement.reporting.v1alpha.eventGroup +import org.wfanet.measurement.reporting.v2alpha.EventGroup +import org.wfanet.measurement.reporting.v2alpha.EventGroupKt +import org.wfanet.measurement.reporting.v2alpha.eventGroup private const val METADATA_FIELD = "metadata.metadata" private const val MAX_PAGE_SIZE = 1000 diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel deleted file mode 100644 index b7e72ce8ea5..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel +++ /dev/null @@ -1,108 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_test") - -kt_jvm_test( - name = "ReportingSetsServiceTest", - srcs = ["ReportingSetsServiceTest.kt"], - associates = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reporting_sets_service", - ], - test_class = "org.wfanet.measurement.reporting.service.api.v1alpha.ReportingSetsServiceTest", - deps = [ - "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:principal_server_interceptor", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:principal_server_interceptor", - "//src/main/proto/wfa/measurement/config/reporting:measurement_consumer_config_kt_jvm_proto", - "@wfa_common_jvm//imports/java/com/google/common/truth", - "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", - "@wfa_common_jvm//imports/java/com/google/protobuf", - "@wfa_common_jvm//imports/kotlin/kotlin/test", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - "@wfa_common_jvm//imports/kotlin/org/mockito/kotlin", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing", - ], -) - -kt_jvm_test( - name = "SetOperationCompilerTest", - srcs = ["SetOperationCompilerTest.kt"], - associates = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:set_operation_compiler", - ], - test_class = "org.wfanet.measurement.reporting.service.api.v1alpha.SetOperationCompilerTest", - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:resource_key", - "@wfa_common_jvm//imports/java/com/google/common/truth", - "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", - "@wfa_common_jvm//imports/java/com/google/protobuf", - "@wfa_common_jvm//imports/kotlin/kotlin/test", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing", - ], -) - -kt_jvm_test( - name = "EventGroupsServiceTest", - srcs = ["EventGroupsServiceTest.kt"], - data = [ - "//src/main/k8s/testing/secretfiles:all_der_files", - "//src/main/k8s/testing/secretfiles:all_tink_keysets", - ], - test_class = "org.wfanet.measurement.reporting.service.api.v1alpha.EventGroupsServiceTest", - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:cel_env_provider", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:event_groups_service", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:principal_server_interceptor", - "//src/main/proto/wfa/measurement/api/v2alpha:event_group_kt_jvm_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:event_group_metadata_descriptors_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:event_groups_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha/event_group_metadata/testing:test_metadata_message_kt_jvm_proto", - "//src/main/proto/wfa/measurement/api/v2alpha/event_group_metadata/testing:test_parent_metadata_message_kt_jvm_proto", - "//src/main/proto/wfa/measurement/config/reporting:measurement_consumer_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:event_groups_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//imports/kotlin/kotlin/test", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/tink", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing", - "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/dataprovider", - ], -) - -kt_jvm_test( - name = "ReportsServiceTest", - srcs = ["ReportsServiceTest.kt"], - associates = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reports_service", - ], - data = [ - "//src/main/k8s/testing/secretfiles:root_certs", - "//src/main/k8s/testing/secretfiles:secret_files", - ], - test_class = "org.wfanet.measurement.reporting.service.api.v1alpha.ReportsServiceTest", - deps = [ - "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:principal_server_interceptor", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:principal_server_interceptor", - "//src/main/proto/wfa/measurement/api/v2alpha:certificate_kt_jvm_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:certificates_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/config/reporting:measurement_consumer_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/config/reporting:measurement_spec_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reporting_set_kt_jvm_proto", - "@wfa_common_jvm//imports/java/com/google/common/truth", - "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", - "@wfa_common_jvm//imports/java/com/google/protobuf", - "@wfa_common_jvm//imports/kotlin/kotlin/test", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - "@wfa_common_jvm//imports/kotlin/org/mockito/kotlin", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/tink", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing", - "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/common:key_handles", - "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/dataprovider", - "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/duchy", - ], -) diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsServiceTest.kt deleted file mode 100644 index 17238a9d814..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsServiceTest.kt +++ /dev/null @@ -1,489 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import com.google.common.truth.Truth.assertThat -import com.google.common.truth.extensions.proto.ProtoTruth.assertThat -import com.google.protobuf.ByteString -import io.grpc.Status -import io.grpc.StatusRuntimeException -import java.nio.file.Path -import java.nio.file.Paths -import java.time.Duration -import kotlin.test.assertFailsWith -import kotlinx.coroutines.runBlocking -import org.junit.Before -import org.junit.Rule -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.mockito.kotlin.any -import org.mockito.kotlin.stub -import org.mockito.kotlin.whenever -import org.wfanet.measurement.api.v2alpha.EventGroupKt as CmmsEventGroup -import org.wfanet.measurement.api.v2alpha.EventGroupMetadataDescriptorsGrpcKt.EventGroupMetadataDescriptorsCoroutineImplBase -import org.wfanet.measurement.api.v2alpha.EventGroupMetadataDescriptorsGrpcKt.EventGroupMetadataDescriptorsCoroutineStub -import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineImplBase -import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub -import org.wfanet.measurement.api.v2alpha.ListEventGroupsRequest as CmmsListEventGroupsRequest -import org.wfanet.measurement.api.v2alpha.ListEventGroupsRequestKt -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey -import org.wfanet.measurement.api.v2alpha.copy -import org.wfanet.measurement.api.v2alpha.encryptionPublicKey -import org.wfanet.measurement.api.v2alpha.eventGroup as cmmsEventGroup -import org.wfanet.measurement.api.v2alpha.eventGroupMetadataDescriptor -import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.testMetadataMessage -import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.testParentMetadataMessage -import org.wfanet.measurement.api.v2alpha.listEventGroupMetadataDescriptorsResponse -import org.wfanet.measurement.api.v2alpha.listEventGroupsRequest as cmmsListEventGroupsRequest -import org.wfanet.measurement.api.v2alpha.listEventGroupsResponse as cmmsListEventGroupsResponse -import org.wfanet.measurement.common.ProtoReflection -import org.wfanet.measurement.common.crypto.tink.TinkPrivateKeyHandle -import org.wfanet.measurement.common.crypto.tink.TinkPublicKeyHandle -import org.wfanet.measurement.common.crypto.tink.loadPrivateKey -import org.wfanet.measurement.common.crypto.tink.loadPublicKey -import org.wfanet.measurement.common.getRuntimePath -import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule -import org.wfanet.measurement.common.grpc.testing.mockService -import org.wfanet.measurement.common.pack -import org.wfanet.measurement.common.testing.verifyProtoArgument -import org.wfanet.measurement.config.reporting.measurementConsumerConfig -import org.wfanet.measurement.consent.client.common.toEncryptionPublicKey -import org.wfanet.measurement.consent.client.dataprovider.encryptMetadata -import org.wfanet.measurement.reporting.service.api.CelEnvCacheProvider -import org.wfanet.measurement.reporting.service.api.InMemoryEncryptionKeyPairStore -import org.wfanet.measurement.reporting.v1alpha.EventGroup -import org.wfanet.measurement.reporting.v1alpha.EventGroupKt.metadata -import org.wfanet.measurement.reporting.v1alpha.eventGroup -import org.wfanet.measurement.reporting.v1alpha.listEventGroupsRequest -import org.wfanet.measurement.reporting.v1alpha.listEventGroupsResponse - -private const val DEFAULT_PAGE_SIZE = 50 - -private const val API_AUTHENTICATION_KEY = "nR5QPN7ptx" -private val CONFIG = measurementConsumerConfig { apiKey = API_AUTHENTICATION_KEY } -private val SECRET_FILES_PATH: Path = - checkNotNull( - getRuntimePath( - Paths.get("wfa_measurement_system", "src", "main", "k8s", "testing", "secretfiles") - ) - ) -private val ENCRYPTION_PRIVATE_KEY_HANDLE = loadEncryptionPrivateKey("mc_enc_private.tink") -private val ENCRYPTION_PUBLIC_KEY = - loadEncryptionPublicKey("mc_enc_public.tink").toEncryptionPublicKey() -private const val MEASUREMENT_CONSUMER_REFERENCE_ID = "measurementConsumerRefId" -private val MEASUREMENT_CONSUMER_NAME = - MeasurementConsumerKey(MEASUREMENT_CONSUMER_REFERENCE_ID).toName() -private val ENCRYPTION_KEY_PAIR_STORE = - InMemoryEncryptionKeyPairStore( - mapOf( - MEASUREMENT_CONSUMER_NAME to - listOf(ENCRYPTION_PUBLIC_KEY.data to ENCRYPTION_PRIVATE_KEY_HANDLE) - ) - ) -private val TEST_MESSAGE = testMetadataMessage { publisherId = 15 } -private const val CMMS_EVENT_GROUP_ID = "AAAAAAAAAHs" -private val CMMS_EVENT_GROUP = cmmsEventGroup { - name = "$DATA_PROVIDER_NAME/eventGroups/$CMMS_EVENT_GROUP_ID" - measurementConsumer = MEASUREMENT_CONSUMER_NAME - eventGroupReferenceId = EVENT_GROUP_REFERENCE_ID - measurementConsumerPublicKey = ENCRYPTION_PUBLIC_KEY.pack() - encryptedMetadata = - encryptMetadata( - CmmsEventGroup.metadata { - eventGroupMetadataDescriptor = METADATA_NAME - metadata = TEST_MESSAGE.pack() - }, - ENCRYPTION_PUBLIC_KEY, - ) -} -private val TEST_MESSAGE_2 = testMetadataMessage { publisherId = 5 } -private const val CMMS_EVENT_GROUP_ID_2 = "AAAAAAAAAGs" -private val CMMS_EVENT_GROUP_2 = - CMMS_EVENT_GROUP.copy { - name = "$DATA_PROVIDER_NAME/eventGroups/$CMMS_EVENT_GROUP_ID_2" - eventGroupReferenceId = "id2" - encryptedMetadata = - encryptMetadata( - CmmsEventGroup.metadata { - eventGroupMetadataDescriptor = METADATA_NAME - metadata = TEST_MESSAGE_2.pack() - }, - ENCRYPTION_PUBLIC_KEY, - ) - } -private val EVENT_GROUP = eventGroup { - name = - EventGroupKey( - MEASUREMENT_CONSUMER_REFERENCE_ID, - DATA_PROVIDER_REFERENCE_ID, - CMMS_EVENT_GROUP_ID, - ) - .toName() - dataProvider = DATA_PROVIDER_NAME - eventGroupReferenceId = EVENT_GROUP_REFERENCE_ID - metadata = metadata { - eventGroupMetadataDescriptor = METADATA_NAME - metadata = TEST_MESSAGE.pack() - } -} -private const val PAGE_TOKEN = "base64encodedtoken" -private const val NEXT_PAGE_TOKEN = "base64encodedtoken2" -private const val DATA_PROVIDER_REFERENCE_ID = "123" -private const val DATA_PROVIDER_NAME = "dataProviders/$DATA_PROVIDER_REFERENCE_ID" -private const val EVENT_GROUP_REFERENCE_ID = "edpRefId1" -private const val EVENT_GROUP_PARENT = - "measurementConsumers/$MEASUREMENT_CONSUMER_REFERENCE_ID/dataProviders/$DATA_PROVIDER_REFERENCE_ID" -private const val METADATA_NAME = "$DATA_PROVIDER_NAME/eventGroupMetadataDescriptors/abc" -private val EVENT_GROUP_METADATA_DESCRIPTOR = eventGroupMetadataDescriptor { - name = METADATA_NAME - descriptorSet = ProtoReflection.buildFileDescriptorSet(TEST_MESSAGE.descriptorForType) -} - -@RunWith(JUnit4::class) -class EventGroupsServiceTest { - private val cmmsEventGroupsServiceMock: EventGroupsCoroutineImplBase = mockService { - onBlocking { listEventGroups(any()) } - .thenReturn( - cmmsListEventGroupsResponse { - eventGroups += listOf(CMMS_EVENT_GROUP, CMMS_EVENT_GROUP_2) - nextPageToken = NEXT_PAGE_TOKEN - } - ) - } - private val cmmsEventGroupMetadataDescriptorsServiceMock: - EventGroupMetadataDescriptorsCoroutineImplBase = - mockService { - onBlocking { listEventGroupMetadataDescriptors(any()) } - .thenReturn( - listEventGroupMetadataDescriptorsResponse { - eventGroupMetadataDescriptors += EVENT_GROUP_METADATA_DESCRIPTOR - } - ) - } - - @get:Rule - val grpcTestServerRule = GrpcTestServerRule { - addService(cmmsEventGroupsServiceMock) - addService(cmmsEventGroupMetadataDescriptorsServiceMock) - } - - private lateinit var service: EventGroupsService - - @Before - fun initService() { - val celEnvCacheProvider = - CelEnvCacheProvider( - EventGroupMetadataDescriptorsCoroutineStub(grpcTestServerRule.channel), - EventGroup.getDescriptor(), - Duration.ofSeconds(5), - emptyList(), - ) - - service = - EventGroupsService( - EventGroupsCoroutineStub(grpcTestServerRule.channel), - ENCRYPTION_KEY_PAIR_STORE, - celEnvCacheProvider, - ) - } - - @Test - fun `listEventGroups returns list with no filter`() { - cmmsEventGroupsServiceMock.stub { - onBlocking { listEventGroups(any()) } - .thenReturn( - cmmsListEventGroupsResponse { - eventGroups += - listOf( - CMMS_EVENT_GROUP, - // When there's no filter applied to metadata, it doesn't need to be set on all EGs. - CMMS_EVENT_GROUP_2.copy { clearEncryptedMetadata() }, - ) - nextPageToken = NEXT_PAGE_TOKEN - } - ) - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { - service.listEventGroups( - listEventGroupsRequest { - parent = EVENT_GROUP_PARENT - pageSize = 10 - pageToken = PAGE_TOKEN - } - ) - } - } - - assertThat(result) - .isEqualTo( - listEventGroupsResponse { - eventGroups += - listOf( - EVENT_GROUP, - eventGroup { - name = - EventGroupKey( - MeasurementConsumerKey.fromName(CMMS_EVENT_GROUP_2.measurementConsumer)!! - .measurementConsumerId, - DATA_PROVIDER_REFERENCE_ID, - CMMS_EVENT_GROUP_ID_2, - ) - .toName() - dataProvider = DATA_PROVIDER_NAME - eventGroupReferenceId = "id2" - }, - ) - nextPageToken = NEXT_PAGE_TOKEN - } - ) - - val expectedCmmsEventGroupsRequest = cmmsListEventGroupsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = 10 - pageToken = PAGE_TOKEN - filter = ListEventGroupsRequestKt.filter { dataProviders += DATA_PROVIDER_NAME } - } - - verifyProtoArgument(cmmsEventGroupsServiceMock, EventGroupsCoroutineImplBase::listEventGroups) - .isEqualTo(expectedCmmsEventGroupsRequest) - } - - @Test - fun `listEventGroups returns list with filter`() { - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { - service.listEventGroups( - listEventGroupsRequest { - parent = EVENT_GROUP_PARENT - filter = "metadata.metadata.publisher_id > 10" - pageToken = PAGE_TOKEN - } - ) - } - } - - assertThat(result) - .isEqualTo( - listEventGroupsResponse { - eventGroups += EVENT_GROUP - nextPageToken = NEXT_PAGE_TOKEN - } - ) - - val expectedCmmsEventGroupsRequest = cmmsListEventGroupsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = DEFAULT_PAGE_SIZE - pageToken = PAGE_TOKEN - filter = ListEventGroupsRequestKt.filter { dataProviders += DATA_PROVIDER_NAME } - } - - verifyProtoArgument(cmmsEventGroupsServiceMock, EventGroupsCoroutineImplBase::listEventGroups) - .isEqualTo(expectedCmmsEventGroupsRequest) - } - - @Test - fun `listEventGroups omits DataProvider filter in CMMS request when ID is wildcard`() { - val request = listEventGroupsRequest { - parent = "measurementConsumers/$MEASUREMENT_CONSUMER_REFERENCE_ID/dataProviders/-" - } - - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listEventGroups(request) } - } - - verifyProtoArgument(cmmsEventGroupsServiceMock, EventGroupsCoroutineImplBase::listEventGroups) - .isEqualTo( - cmmsListEventGroupsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = DEFAULT_PAGE_SIZE - filter = CmmsListEventGroupsRequest.Filter.getDefaultInstance() - } - ) - } - - @Test - fun `listEventGroups returns list with filter when event group with metadata and one without`() { - runBlocking { - whenever(cmmsEventGroupsServiceMock.listEventGroups(any())) - .thenReturn( - cmmsListEventGroupsResponse { - eventGroups += - listOf(CMMS_EVENT_GROUP, CMMS_EVENT_GROUP_2.copy { clearEncryptedMetadata() }) - } - ) - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { - service.listEventGroups( - listEventGroupsRequest { - parent = EVENT_GROUP_PARENT - filter = "metadata.metadata.publisher_id > 10" - pageToken = PAGE_TOKEN - } - ) - } - } - - assertThat(result).isEqualTo(listEventGroupsResponse { eventGroups += EVENT_GROUP }) - - val expectedCmmsEventGroupsRequest = cmmsListEventGroupsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = DEFAULT_PAGE_SIZE - pageToken = PAGE_TOKEN - filter = ListEventGroupsRequestKt.filter { dataProviders += DATA_PROVIDER_NAME } - } - - verifyProtoArgument(cmmsEventGroupsServiceMock, EventGroupsCoroutineImplBase::listEventGroups) - .isEqualTo(expectedCmmsEventGroupsRequest) - } - - @Test - fun `listEventGroups throws FAILED_PRECONDITION if message descriptor not found`() { - val eventGroupInvalidMetadata = cmmsEventGroup { - name = "$DATA_PROVIDER_NAME/eventGroups/$CMMS_EVENT_GROUP_ID" - measurementConsumer = MEASUREMENT_CONSUMER_NAME - eventGroupReferenceId = "id1" - measurementConsumerPublicKey = ENCRYPTION_PUBLIC_KEY.pack() - encryptedMetadata = - encryptMetadata( - CmmsEventGroup.metadata { - eventGroupMetadataDescriptor = METADATA_NAME - metadata = testParentMetadataMessage { name = "name" }.pack() - }, - ENCRYPTION_PUBLIC_KEY, - ) - } - cmmsEventGroupsServiceMock.stub { - onBlocking { listEventGroups(any()) } - .thenReturn( - cmmsListEventGroupsResponse { - eventGroups += listOf(CMMS_EVENT_GROUP, CMMS_EVENT_GROUP_2, eventGroupInvalidMetadata) - } - ) - } - - val result = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { - service.listEventGroups( - listEventGroupsRequest { - parent = EVENT_GROUP_PARENT - filter = "metadata.metadata.publisher_id > 10" - } - ) - } - } - } - - assertThat(result.status.code).isEqualTo(Status.Code.FAILED_PRECONDITION) - } - - @Test - fun `listEventGroups throws FAILED_PRECONDITION if private key not found`() { - val eventGroupInvalidPublicKey = cmmsEventGroup { - name = "$DATA_PROVIDER_NAME/eventGroups/$CMMS_EVENT_GROUP_ID" - measurementConsumer = MEASUREMENT_CONSUMER_NAME - eventGroupReferenceId = "id1" - measurementConsumerPublicKey = - encryptionPublicKey { data = ByteString.copyFromUtf8("consumerkey") }.pack() - encryptedMetadata = - encryptMetadata( - CmmsEventGroup.metadata { - eventGroupMetadataDescriptor = METADATA_NAME - metadata = testParentMetadataMessage { name = "name" }.pack() - }, - ENCRYPTION_PUBLIC_KEY, - ) - } - cmmsEventGroupsServiceMock.stub { - onBlocking { listEventGroups(any()) } - .thenReturn( - cmmsListEventGroupsResponse { - eventGroups += listOf(CMMS_EVENT_GROUP, CMMS_EVENT_GROUP_2, eventGroupInvalidPublicKey) - } - ) - } - - val result = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { - service.listEventGroups( - listEventGroupsRequest { - parent = EVENT_GROUP_PARENT - filter = "metadata.metadata.publisher_id > 10" - } - ) - } - } - } - - assertThat(result.status.code).isEqualTo(Status.Code.FAILED_PRECONDITION) - } - - @Test - fun `listEventGroups throws INVALID_ARGUMENT if parent not specified`() { - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { - service.listEventGroups( - listEventGroupsRequest { - filter = "metadata.metadata.publisher_id > 10" - pageToken = PAGE_TOKEN - ENCRYPTION_KEY_PAIR_STORE - } - ) - } - } - } - - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception).hasMessageThat().ignoringCase().contains("parent") - } - - @Test - fun `listEventGroups throws UNAUTHENTICATED if principal not found`() { - val result = - assertFailsWith { - runBlocking { - service.listEventGroups( - listEventGroupsRequest { - parent = EVENT_GROUP_PARENT - filter = "metadata.metadata.publisher_id > 10" - } - ) - } - } - - assertThat(result.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - } -} - -private fun loadEncryptionPrivateKey(fileName: String): TinkPrivateKeyHandle { - return loadPrivateKey(SECRET_FILES_PATH.resolve(fileName).toFile()) -} - -private fun loadEncryptionPublicKey(fileName: String): TinkPublicKeyHandle { - return loadPublicKey(SECRET_FILES_PATH.resolve(fileName).toFile()) -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetsServiceTest.kt deleted file mode 100644 index 20b13d92928..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetsServiceTest.kt +++ /dev/null @@ -1,750 +0,0 @@ -/* - * Copyright 2022 The Cross-Media Measurement Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import com.google.common.truth.Truth.assertThat -import com.google.common.truth.extensions.proto.ProtoTruth.assertThat -import io.grpc.Status -import io.grpc.StatusRuntimeException -import kotlin.test.assertFailsWith -import kotlinx.coroutines.flow.flowOf -import kotlinx.coroutines.runBlocking -import org.junit.Before -import org.junit.Rule -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.mockito.kotlin.any -import org.wfanet.measurement.api.v2alpha.DataProviderKey -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey -import org.wfanet.measurement.api.v2alpha.withDataProviderPrincipal -import org.wfanet.measurement.common.base64UrlEncode -import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule -import org.wfanet.measurement.common.grpc.testing.mockService -import org.wfanet.measurement.common.identity.externalIdToApiId -import org.wfanet.measurement.common.testing.verifyProtoArgument -import org.wfanet.measurement.config.reporting.measurementConsumerConfig -import org.wfanet.measurement.internal.reporting.ReportingSet as InternalReportingSet -import org.wfanet.measurement.internal.reporting.ReportingSetKt.eventGroupKey -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt.ReportingSetsCoroutineImplBase -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt.ReportingSetsCoroutineStub -import org.wfanet.measurement.internal.reporting.StreamReportingSetsRequestKt -import org.wfanet.measurement.internal.reporting.copy -import org.wfanet.measurement.internal.reporting.reportingSet as internalReportingSet -import org.wfanet.measurement.internal.reporting.streamReportingSetsRequest -import org.wfanet.measurement.reporting.v1alpha.ListReportingSetsPageTokenKt.previousPageEnd -import org.wfanet.measurement.reporting.v1alpha.ListReportingSetsRequest -import org.wfanet.measurement.reporting.v1alpha.ReportingSet -import org.wfanet.measurement.reporting.v1alpha.copy -import org.wfanet.measurement.reporting.v1alpha.createReportingSetRequest -import org.wfanet.measurement.reporting.v1alpha.listReportingSetsPageToken -import org.wfanet.measurement.reporting.v1alpha.listReportingSetsRequest -import org.wfanet.measurement.reporting.v1alpha.listReportingSetsResponse -import org.wfanet.measurement.reporting.v1alpha.reportingSet - -private const val DEFAULT_PAGE_SIZE = 50 -private const val MAX_PAGE_SIZE = 1000 -private const val PAGE_SIZE = 2 - -private const val API_AUTHENTICATION_KEY = "nR5QPN7ptx" -private val CONFIG = measurementConsumerConfig { apiKey = API_AUTHENTICATION_KEY } - -// Measurement consumer IDs and names -private const val MEASUREMENT_CONSUMER_EXTERNAL_ID = 111L -private const val MEASUREMENT_CONSUMER_EXTERNAL_ID_2 = 112L -private val MEASUREMENT_CONSUMER_REFERENCE_ID = externalIdToApiId(MEASUREMENT_CONSUMER_EXTERNAL_ID) -private val MEASUREMENT_CONSUMER_REFERENCE_ID_2 = - externalIdToApiId(MEASUREMENT_CONSUMER_EXTERNAL_ID_2) -private val MEASUREMENT_CONSUMER_NAME = - MeasurementConsumerKey(MEASUREMENT_CONSUMER_REFERENCE_ID).toName() -private val MEASUREMENT_CONSUMER_NAME_2 = - MeasurementConsumerKey(MEASUREMENT_CONSUMER_REFERENCE_ID_2).toName() - -// Data provider IDs and names -private const val DATA_PROVIDER_EXTERNAL_ID = 221L -private const val DATA_PROVIDER_EXTERNAL_ID_2 = 222L -private const val DATA_PROVIDER_EXTERNAL_ID_3 = 223L -private val DATA_PROVIDER_REFERENCE_ID = externalIdToApiId(DATA_PROVIDER_EXTERNAL_ID) -private val DATA_PROVIDER_REFERENCE_ID_2 = externalIdToApiId(DATA_PROVIDER_EXTERNAL_ID_2) -private val DATA_PROVIDER_REFERENCE_ID_3 = externalIdToApiId(DATA_PROVIDER_EXTERNAL_ID_3) - -private val DATA_PROVIDER_NAME = DataProviderKey(DATA_PROVIDER_REFERENCE_ID).toName() -private val DATA_PROVIDER_NAME_2 = DataProviderKey(DATA_PROVIDER_REFERENCE_ID_2).toName() -private val DATA_PROVIDER_NAME_3 = DataProviderKey(DATA_PROVIDER_REFERENCE_ID_3).toName() - -// Reporting set IDs and names -private val REPORTING_SET_EXTERNAL_ID = 331L -private val REPORTING_SET_EXTERNAL_ID_2 = 332L -private val REPORTING_SET_EXTERNAL_ID_3 = 333L - -private val REPORTING_SET_NAME = - ReportingSetKey(MEASUREMENT_CONSUMER_REFERENCE_ID, externalIdToApiId(REPORTING_SET_EXTERNAL_ID)) - .toName() -private val REPORTING_SET_NAME_2 = - ReportingSetKey(MEASUREMENT_CONSUMER_REFERENCE_ID, externalIdToApiId(REPORTING_SET_EXTERNAL_ID_2)) - .toName() -private val REPORTING_SET_NAME_3 = - ReportingSetKey(MEASUREMENT_CONSUMER_REFERENCE_ID, externalIdToApiId(REPORTING_SET_EXTERNAL_ID_3)) - .toName() - -// Event group IDs and names -private val EVENT_GROUP_EXTERNAL_ID = 441L -private val EVENT_GROUP_EXTERNAL_ID_2 = 442L -private val EVENT_GROUP_EXTERNAL_ID_3 = 443L -private val EVENT_GROUP_REFERENCE_ID = externalIdToApiId(EVENT_GROUP_EXTERNAL_ID) -private val EVENT_GROUP_REFERENCE_ID_2 = externalIdToApiId(EVENT_GROUP_EXTERNAL_ID_2) -private val EVENT_GROUP_REFERENCE_ID_3 = externalIdToApiId(EVENT_GROUP_EXTERNAL_ID_3) - -private val EVENT_GROUP_NAME = - EventGroupKey( - MEASUREMENT_CONSUMER_REFERENCE_ID, - DATA_PROVIDER_REFERENCE_ID, - EVENT_GROUP_REFERENCE_ID, - ) - .toName() -private val EVENT_GROUP_NAME_2 = - EventGroupKey( - MEASUREMENT_CONSUMER_REFERENCE_ID, - DATA_PROVIDER_REFERENCE_ID_2, - EVENT_GROUP_REFERENCE_ID_2, - ) - .toName() -private val EVENT_GROUP_NAME_3 = - EventGroupKey( - MEASUREMENT_CONSUMER_REFERENCE_ID, - DATA_PROVIDER_REFERENCE_ID_3, - EVENT_GROUP_REFERENCE_ID_3, - ) - .toName() -private val EVENT_GROUP_NAMES = listOf(EVENT_GROUP_NAME, EVENT_GROUP_NAME_2, EVENT_GROUP_NAME_3) - -// Event group keys -private val EVENT_GROUP_KEY = eventGroupKey { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - dataProviderReferenceId = DATA_PROVIDER_REFERENCE_ID - eventGroupReferenceId = EVENT_GROUP_REFERENCE_ID -} -private val EVENT_GROUP_KEY_2 = eventGroupKey { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - dataProviderReferenceId = DATA_PROVIDER_REFERENCE_ID_2 - eventGroupReferenceId = EVENT_GROUP_REFERENCE_ID_2 -} -private val EVENT_GROUP_KEY_3 = eventGroupKey { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - dataProviderReferenceId = DATA_PROVIDER_REFERENCE_ID_3 - eventGroupReferenceId = EVENT_GROUP_REFERENCE_ID_3 -} -private val EVENT_GROUP_KEYS = listOf(EVENT_GROUP_KEY, EVENT_GROUP_KEY_2, EVENT_GROUP_KEY_3) - -// Event filters -private const val FILTER = "AGE>20" - -// Reporting sets -private val DISPLAY_NAME = REPORTING_SET_NAME + FILTER -private val DISPLAY_NAME_2 = REPORTING_SET_NAME_2 + FILTER -private val DISPLAY_NAME_3 = REPORTING_SET_NAME_3 + FILTER - -private val REPORTING_SET: ReportingSet = reportingSet { - name = REPORTING_SET_NAME - eventGroups.addAll(EVENT_GROUP_NAMES) - filter = FILTER - displayName = DISPLAY_NAME -} - -// Internal reporting sets -private val INTERNAL_REPORTING_SET: InternalReportingSet = internalReportingSet { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetId = REPORTING_SET_EXTERNAL_ID - eventGroupKeys.addAll(EVENT_GROUP_KEYS) - filter = FILTER - displayName = DISPLAY_NAME -} - -@RunWith(JUnit4::class) -class ReportingSetsServiceTest { - - private val internalReportingSetsMock: ReportingSetsCoroutineImplBase = - mockService() { - onBlocking { createReportingSet(any()) }.thenReturn(INTERNAL_REPORTING_SET) - onBlocking { streamReportingSets(any()) } - .thenReturn( - flowOf( - INTERNAL_REPORTING_SET, - INTERNAL_REPORTING_SET.copy { - externalReportingSetId = REPORTING_SET_EXTERNAL_ID_2 - displayName = DISPLAY_NAME_2 - }, - INTERNAL_REPORTING_SET.copy { - externalReportingSetId = REPORTING_SET_EXTERNAL_ID_3 - displayName = DISPLAY_NAME_3 - }, - ) - ) - } - - @get:Rule val grpcTestServerRule = GrpcTestServerRule { addService(internalReportingSetsMock) } - - private lateinit var service: ReportingSetsService - - @Before - fun initService() { - service = ReportingSetsService(ReportingSetsCoroutineStub(grpcTestServerRule.channel)) - } - - @Test - fun `createReportingSet returns reporting set`() { - val request = createReportingSetRequest { - parent = MEASUREMENT_CONSUMER_NAME - reportingSet = REPORTING_SET - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.createReportingSet(request) } - } - - val expected = REPORTING_SET - - verifyProtoArgument( - internalReportingSetsMock, - ReportingSetsCoroutineImplBase::createReportingSet, - ) - .isEqualTo(INTERNAL_REPORTING_SET.copy { clearExternalReportingSetId() }) - - assertThat(result).isEqualTo(expected) - } - - @Test - fun `createReportingSet throws UNAUTHENTICATED when no principal is found`() { - val request = createReportingSetRequest { - parent = MEASUREMENT_CONSUMER_NAME - reportingSet = REPORTING_SET - } - - val exception = - assertFailsWith { - runBlocking { service.createReportingSet(request) } - } - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - } - - @Test - fun `createReportingSet throws PERMISSION_DENIED when MC caller doesn't match`() { - val request = createReportingSetRequest { - parent = MEASUREMENT_CONSUMER_NAME - reportingSet = REPORTING_SET - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME_2, CONFIG) { - runBlocking { service.createReportingSet(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) - assertThat(exception.status.description) - .isEqualTo("Cannot create a ReportingSet for another MeasurementConsumer.") - } - - @Test - fun `createReportingSet throws UNAUTHENTICATED when caller is not MeasurementConsumer`() { - val request = createReportingSetRequest { - parent = MEASUREMENT_CONSUMER_NAME - reportingSet = REPORTING_SET - } - - val exception = - assertFailsWith { - withDataProviderPrincipal(DATA_PROVIDER_NAME) { - runBlocking { service.createReportingSet(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - assertThat(exception.status.description).isEqualTo("No ReportingPrincipal found") - } - - @Test - fun `createReportingSet throws INVALID_ARGUMENT when parent is missing`() { - val request = createReportingSetRequest { reportingSet = REPORTING_SET } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.createReportingSet(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReportingSet throws INVALID_ARGUMENT if ReportingSet is not specified`() { - val request = createReportingSetRequest { parent = MEASUREMENT_CONSUMER_NAME } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.createReportingSet(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReportingSet throws INVALID_ARGUMENT when EventGroups in ReportingSet is empty`() { - val request = createReportingSetRequest { - parent = MEASUREMENT_CONSUMER_NAME - reportingSet = REPORTING_SET.copy { eventGroups.clear() } - } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.createReportingSet(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description) - .isEqualTo("EventGroups in ReportingSet cannot be empty.") - } - - @Test - fun `createReportingSet throws INVALID_ARGUMENT when there is any invalid EventGroup`() { - val invalidEventGroupName = "invalid" - val request = createReportingSetRequest { - parent = MEASUREMENT_CONSUMER_NAME - reportingSet = reportingSet { - name = REPORTING_SET_NAME - eventGroups.addAll(listOf(EVENT_GROUP_NAME, invalidEventGroupName)) - filter = FILTER - displayName = DISPLAY_NAME - } - } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.createReportingSet(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description) - .isEqualTo("EventGroup is either unspecified or invalid.") - } - - @Test - fun `listReportingSets returns without a next page token when there is no previous page token`() { - val request = listReportingSetsRequest { parent = MEASUREMENT_CONSUMER_NAME } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listReportingSets(request) } - } - - val expected = listReportingSetsResponse { - reportingSets += REPORTING_SET - reportingSets += - REPORTING_SET.copy { - name = REPORTING_SET_NAME_2 - displayName = DISPLAY_NAME_2 - } - reportingSets += - REPORTING_SET.copy { - name = REPORTING_SET_NAME_3 - displayName = DISPLAY_NAME_3 - } - } - - verifyProtoArgument( - internalReportingSetsMock, - ReportingSetsCoroutineImplBase::streamReportingSets, - ) - .isEqualTo( - streamReportingSetsRequest { - limit = DEFAULT_PAGE_SIZE + 1 - filter = - StreamReportingSetsRequestKt.filter { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReportingSets returns with a next page token when there is no previous page token`() { - val request = listReportingSetsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = PAGE_SIZE - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listReportingSets(request) } - } - - val expected = listReportingSetsResponse { - reportingSets += REPORTING_SET - reportingSets += - REPORTING_SET.copy { - name = REPORTING_SET_NAME_2 - displayName = DISPLAY_NAME_2 - } - nextPageToken = - listReportingSetsPageToken { - pageSize = PAGE_SIZE - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - lastReportingSet = previousPageEnd { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetId = REPORTING_SET_EXTERNAL_ID_2 - } - } - .toByteString() - .base64UrlEncode() - } - - verifyProtoArgument( - internalReportingSetsMock, - ReportingSetsCoroutineImplBase::streamReportingSets, - ) - .isEqualTo( - streamReportingSetsRequest { - limit = PAGE_SIZE + 1 - filter = - StreamReportingSetsRequestKt.filter { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReportingSets returns with a next page token when there is a previous page token`() { - val request = listReportingSetsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = PAGE_SIZE - pageToken = - listReportingSetsPageToken { - pageSize = PAGE_SIZE - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - lastReportingSet = previousPageEnd { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetId = REPORTING_SET_EXTERNAL_ID - } - } - .toByteString() - .base64UrlEncode() - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listReportingSets(request) } - } - - val expected = listReportingSetsResponse { - reportingSets += REPORTING_SET - reportingSets += - REPORTING_SET.copy { - name = REPORTING_SET_NAME_2 - displayName = DISPLAY_NAME_2 - } - nextPageToken = - listReportingSetsPageToken { - pageSize = PAGE_SIZE - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - lastReportingSet = previousPageEnd { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetId = REPORTING_SET_EXTERNAL_ID_2 - } - } - .toByteString() - .base64UrlEncode() - } - - verifyProtoArgument( - internalReportingSetsMock, - ReportingSetsCoroutineImplBase::streamReportingSets, - ) - .isEqualTo( - streamReportingSetsRequest { - limit = PAGE_SIZE + 1 - filter = - StreamReportingSetsRequestKt.filter { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetIdAfter = REPORTING_SET_EXTERNAL_ID - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReportingSets with page size replaced with a valid value and no previous page token`() { - val invalidPageSize = MAX_PAGE_SIZE * 2 - val request = listReportingSetsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = invalidPageSize - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listReportingSets(request) } - } - - val expected = listReportingSetsResponse { - reportingSets += REPORTING_SET - reportingSets += - REPORTING_SET.copy { - name = REPORTING_SET_NAME_2 - displayName = DISPLAY_NAME_2 - } - reportingSets += - REPORTING_SET.copy { - name = REPORTING_SET_NAME_3 - displayName = DISPLAY_NAME_3 - } - } - - verifyProtoArgument( - internalReportingSetsMock, - ReportingSetsCoroutineImplBase::streamReportingSets, - ) - .isEqualTo( - streamReportingSetsRequest { - limit = MAX_PAGE_SIZE + 1 - filter = - StreamReportingSetsRequestKt.filter { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReportingSets with invalid page size replaced with the one in previous page token`() { - val invalidPageSize = MAX_PAGE_SIZE * 2 - val previousPageSize = PAGE_SIZE - val request = listReportingSetsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = invalidPageSize - pageToken = - listReportingSetsPageToken { - pageSize = previousPageSize - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - lastReportingSet = previousPageEnd { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetId = REPORTING_SET_EXTERNAL_ID - } - } - .toByteString() - .base64UrlEncode() - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listReportingSets(request) } - } - - val expected = listReportingSetsResponse { - reportingSets += REPORTING_SET - reportingSets += - REPORTING_SET.copy { - name = REPORTING_SET_NAME_2 - displayName = DISPLAY_NAME_2 - } - nextPageToken = - listReportingSetsPageToken { - pageSize = previousPageSize - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - lastReportingSet = previousPageEnd { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetId = REPORTING_SET_EXTERNAL_ID_2 - } - } - .toByteString() - .base64UrlEncode() - } - - verifyProtoArgument( - internalReportingSetsMock, - ReportingSetsCoroutineImplBase::streamReportingSets, - ) - .isEqualTo( - streamReportingSetsRequest { - limit = previousPageSize + 1 - filter = - StreamReportingSetsRequestKt.filter { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetIdAfter = REPORTING_SET_EXTERNAL_ID - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReportingSets with page size replacing the one in previous page token`() { - val newPageSize = PAGE_SIZE - val previousPageSize = 1 - val request = listReportingSetsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = newPageSize - pageToken = - listReportingSetsPageToken { - pageSize = previousPageSize - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - lastReportingSet = previousPageEnd { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetId = REPORTING_SET_EXTERNAL_ID - } - } - .toByteString() - .base64UrlEncode() - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listReportingSets(request) } - } - - val expected = listReportingSetsResponse { - reportingSets += REPORTING_SET - reportingSets += - REPORTING_SET.copy { - name = REPORTING_SET_NAME_2 - displayName = DISPLAY_NAME_2 - } - nextPageToken = - listReportingSetsPageToken { - pageSize = newPageSize - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - lastReportingSet = previousPageEnd { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetId = REPORTING_SET_EXTERNAL_ID_2 - } - } - .toByteString() - .base64UrlEncode() - } - - verifyProtoArgument( - internalReportingSetsMock, - ReportingSetsCoroutineImplBase::streamReportingSets, - ) - .isEqualTo( - streamReportingSetsRequest { - limit = newPageSize + 1 - filter = - StreamReportingSetsRequestKt.filter { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetIdAfter = REPORTING_SET_EXTERNAL_ID - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReportingSets throws UNAUTHENTICATED when no principal is found`() { - val request = listReportingSetsRequest { parent = MEASUREMENT_CONSUMER_NAME } - val exception = - assertFailsWith { runBlocking { service.listReportingSets(request) } } - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - } - - @Test - fun `listReportingSets throws PERMISSION_DENIED when MeasurementConsumer caller doesn't match`() { - val request = listReportingSetsRequest { parent = MEASUREMENT_CONSUMER_NAME } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME_2, CONFIG) { - runBlocking { service.listReportingSets(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) - assertThat(exception.status.description) - .isEqualTo("Cannot list ReportingSets belonging to other MeasurementConsumers.") - } - - @Test - fun `listReportingSets throws UNAUTHENTICATED when the caller is not MeasurementConsumer`() { - val request = listReportingSetsRequest { parent = MEASUREMENT_CONSUMER_NAME } - val exception = - assertFailsWith { - withDataProviderPrincipal(DATA_PROVIDER_NAME) { - runBlocking { service.listReportingSets(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - assertThat(exception.status.description).isEqualTo("No ReportingPrincipal found") - } - - @Test - fun `listReportingSets throws INVALID_ARGUMENT when page size is less than 0`() { - val request = listReportingSetsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = -1 - } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listReportingSets(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description).isEqualTo("Page size cannot be less than 0") - } - - @Test - fun `listReportingSets throws INVALID_ARGUMENT when parent is unspecified`() { - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listReportingSets(ListReportingSetsRequest.getDefaultInstance()) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `listReportingSets throws INVALID_ARGUMENT when mc id doesn't match one in page token`() { - val request = listReportingSetsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageToken = - listReportingSetsPageToken { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID_2 - lastReportingSet = previousPageEnd { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID_2 - externalReportingSetId = REPORTING_SET_EXTERNAL_ID - } - } - .toByteString() - .base64UrlEncode() - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listReportingSets(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsServiceTest.kt deleted file mode 100644 index 3a31d8ad53a..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsServiceTest.kt +++ /dev/null @@ -1,4656 +0,0 @@ -/* - * Copyright 2022 The Cross-Media Measurement Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import com.google.common.truth.Truth.assertThat -import com.google.common.truth.extensions.proto.ProtoTruth.assertThat -import com.google.protobuf.duration -import com.google.protobuf.kotlin.toByteString -import com.google.protobuf.kotlin.toByteStringUtf8 -import com.google.protobuf.timestamp -import com.google.protobuf.util.Durations -import com.google.protobuf.util.Timestamps -import com.google.type.interval -import io.grpc.Status -import io.grpc.StatusException -import io.grpc.StatusRuntimeException -import java.nio.file.Paths -import java.security.cert.X509Certificate -import java.time.Duration -import java.time.Instant -import kotlin.random.Random -import kotlin.test.assertFails -import kotlin.test.assertFailsWith -import kotlinx.coroutines.flow.flowOf -import kotlinx.coroutines.runBlocking -import org.junit.Before -import org.junit.Rule -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.mockito.kotlin.KArgumentCaptor -import org.mockito.kotlin.any -import org.mockito.kotlin.argumentCaptor -import org.mockito.kotlin.doReturn -import org.mockito.kotlin.eq -import org.mockito.kotlin.mock -import org.mockito.kotlin.stub -import org.mockito.kotlin.times -import org.mockito.kotlin.verify -import org.mockito.kotlin.verifyBlocking -import org.mockito.kotlin.whenever -import org.wfanet.measurement.api.v2alpha.Certificate -import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt.CertificatesCoroutineImplBase -import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt.CertificatesCoroutineStub -import org.wfanet.measurement.api.v2alpha.CreateMeasurementRequest -import org.wfanet.measurement.api.v2alpha.DataProviderCertificateKey -import org.wfanet.measurement.api.v2alpha.DataProviderKey -import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCoroutineImplBase -import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCoroutineStub -import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey -import org.wfanet.measurement.api.v2alpha.EventGroupKey as CmmsEventGroupKey -import org.wfanet.measurement.api.v2alpha.GetDataProviderRequest -import org.wfanet.measurement.api.v2alpha.Measurement -import org.wfanet.measurement.api.v2alpha.Measurement.DataProviderEntry.Value.ENCRYPTED_REQUISITION_SPEC_FIELD_NUMBER -import org.wfanet.measurement.api.v2alpha.MeasurementConsumer -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerCertificateKey -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey -import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineImplBase -import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub -import org.wfanet.measurement.api.v2alpha.MeasurementKey -import org.wfanet.measurement.api.v2alpha.MeasurementKt -import org.wfanet.measurement.api.v2alpha.MeasurementKt.DataProviderEntryKt -import org.wfanet.measurement.api.v2alpha.MeasurementKt.dataProviderEntry -import org.wfanet.measurement.api.v2alpha.MeasurementKt.failure -import org.wfanet.measurement.api.v2alpha.MeasurementKt.resultOutput -import org.wfanet.measurement.api.v2alpha.MeasurementSpec -import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt -import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.vidSamplingInterval -import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineImplBase -import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub -import org.wfanet.measurement.api.v2alpha.RequisitionSpec -import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt -import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt.EventGroupEntryKt -import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt.eventFilter -import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt.eventGroupEntry -import org.wfanet.measurement.api.v2alpha.certificate -import org.wfanet.measurement.api.v2alpha.copy -import org.wfanet.measurement.api.v2alpha.createMeasurementRequest -import org.wfanet.measurement.api.v2alpha.dataProvider -import org.wfanet.measurement.api.v2alpha.differentialPrivacyParams -import org.wfanet.measurement.api.v2alpha.encryptionPublicKey -import org.wfanet.measurement.api.v2alpha.getCertificateRequest -import org.wfanet.measurement.api.v2alpha.getDataProviderRequest -import org.wfanet.measurement.api.v2alpha.getMeasurementConsumerRequest -import org.wfanet.measurement.api.v2alpha.getMeasurementRequest -import org.wfanet.measurement.api.v2alpha.measurement -import org.wfanet.measurement.api.v2alpha.measurementConsumer -import org.wfanet.measurement.api.v2alpha.measurementSpec -import org.wfanet.measurement.api.v2alpha.requisitionSpec -import org.wfanet.measurement.api.v2alpha.unpack -import org.wfanet.measurement.api.v2alpha.withDataProviderPrincipal -import org.wfanet.measurement.common.base64UrlEncode -import org.wfanet.measurement.common.crypto.Hashing -import org.wfanet.measurement.common.crypto.PrivateKeyHandle -import org.wfanet.measurement.common.crypto.SigningKeyHandle -import org.wfanet.measurement.common.crypto.readCertificate -import org.wfanet.measurement.common.crypto.subjectKeyIdentifier -import org.wfanet.measurement.common.crypto.testing.loadSigningKey -import org.wfanet.measurement.common.crypto.tink.loadPrivateKey -import org.wfanet.measurement.common.getRuntimePath -import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule -import org.wfanet.measurement.common.grpc.testing.mockService -import org.wfanet.measurement.common.identity.ExternalId -import org.wfanet.measurement.common.identity.externalIdToApiId -import org.wfanet.measurement.common.pack -import org.wfanet.measurement.common.readByteString -import org.wfanet.measurement.common.testing.captureFirst -import org.wfanet.measurement.common.testing.verifyProtoArgument -import org.wfanet.measurement.common.toProtoDuration -import org.wfanet.measurement.common.toProtoTime -import org.wfanet.measurement.config.reporting.MeasurementSpecConfigKt -import org.wfanet.measurement.config.reporting.measurementConsumerConfig -import org.wfanet.measurement.config.reporting.measurementSpecConfig -import org.wfanet.measurement.consent.client.dataprovider.decryptRequisitionSpec -import org.wfanet.measurement.consent.client.dataprovider.verifyMeasurementSpec -import org.wfanet.measurement.consent.client.dataprovider.verifyRequisitionSpec -import org.wfanet.measurement.consent.client.duchy.encryptResult -import org.wfanet.measurement.consent.client.duchy.signResult -import org.wfanet.measurement.consent.client.measurementconsumer.encryptRequisitionSpec -import org.wfanet.measurement.consent.client.measurementconsumer.signEncryptionPublicKey -import org.wfanet.measurement.consent.client.measurementconsumer.signMeasurementSpec -import org.wfanet.measurement.consent.client.measurementconsumer.signRequisitionSpec -import org.wfanet.measurement.internal.reporting.CreateReportRequestKt as InternalCreateReportRequestKt -import org.wfanet.measurement.internal.reporting.GetReportRequest as GetInternalReportRequest -import org.wfanet.measurement.internal.reporting.Measurement as InternalMeasurement -import org.wfanet.measurement.internal.reporting.MeasurementKt as InternalMeasurementKt -import org.wfanet.measurement.internal.reporting.MeasurementKt.ResultKt as InternalMeasurementResultKt -import org.wfanet.measurement.internal.reporting.MeasurementsGrpcKt.MeasurementsCoroutineImplBase as InternalMeasurementsCoroutineImplBase -import org.wfanet.measurement.internal.reporting.MeasurementsGrpcKt.MeasurementsCoroutineStub as InternalMeasurementsCoroutineStub -import org.wfanet.measurement.internal.reporting.Metric as InternalMetric -import org.wfanet.measurement.internal.reporting.MetricKt as InternalMetricKt -import org.wfanet.measurement.internal.reporting.MetricKt.MeasurementCalculationKt.weightedMeasurement -import org.wfanet.measurement.internal.reporting.MetricKt.SetOperationKt.reportingSetKey -import org.wfanet.measurement.internal.reporting.MetricKt.measurementCalculation -import org.wfanet.measurement.internal.reporting.Report as InternalReport -import org.wfanet.measurement.internal.reporting.ReportKt as InternalReportKt -import org.wfanet.measurement.internal.reporting.ReportKt.DetailsKt as InternalReportDetailsKt -import org.wfanet.measurement.internal.reporting.ReportKt.DetailsKt.ResultKt as InternalReportResultKt -import org.wfanet.measurement.internal.reporting.ReportingSet as InternalReportingSet -import org.wfanet.measurement.internal.reporting.ReportingSetKt as InternalReportingSetKt -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt.ReportingSetsCoroutineImplBase as InternalReportingSetsCoroutineImplBase -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt.ReportingSetsCoroutineStub as InternalReportingSetsCoroutineStub -import org.wfanet.measurement.internal.reporting.ReportsGrpcKt.ReportsCoroutineImplBase -import org.wfanet.measurement.internal.reporting.ReportsGrpcKt.ReportsCoroutineStub as InternalReportsCoroutineStub -import org.wfanet.measurement.internal.reporting.StreamReportsRequestKt.filter -import org.wfanet.measurement.internal.reporting.batchCreateMeasurementsRequest -import org.wfanet.measurement.internal.reporting.batchGetReportingSetRequest -import org.wfanet.measurement.internal.reporting.copy -import org.wfanet.measurement.internal.reporting.createReportRequest as internalCreateReportRequest -import org.wfanet.measurement.internal.reporting.getReportByIdempotencyKeyRequest -import org.wfanet.measurement.internal.reporting.getReportRequest as getInternalReportRequest -import org.wfanet.measurement.internal.reporting.measurement as internalMeasurement -import org.wfanet.measurement.internal.reporting.metric as internalMetric -import org.wfanet.measurement.internal.reporting.periodicTimeInterval as internalPeriodicTimeInterval -import org.wfanet.measurement.internal.reporting.report as internalReport -import org.wfanet.measurement.internal.reporting.reportingSet as internalReportingSet -import org.wfanet.measurement.internal.reporting.setMeasurementFailureRequest -import org.wfanet.measurement.internal.reporting.setMeasurementResultRequest -import org.wfanet.measurement.internal.reporting.streamReportsRequest -import org.wfanet.measurement.internal.reporting.timeInterval as internalTimeInterval -import org.wfanet.measurement.internal.reporting.timeIntervals as internalTimeIntervals -import org.wfanet.measurement.reporting.service.api.InMemoryEncryptionKeyPairStore -import org.wfanet.measurement.reporting.v1alpha.ListReportsPageTokenKt.previousPageEnd -import org.wfanet.measurement.reporting.v1alpha.ListReportsRequest -import org.wfanet.measurement.reporting.v1alpha.Metric.SetOperation -import org.wfanet.measurement.reporting.v1alpha.MetricKt.SetOperationKt -import org.wfanet.measurement.reporting.v1alpha.MetricKt.frequencyHistogramParams -import org.wfanet.measurement.reporting.v1alpha.MetricKt.impressionCountParams -import org.wfanet.measurement.reporting.v1alpha.MetricKt.namedSetOperation -import org.wfanet.measurement.reporting.v1alpha.MetricKt.reachParams -import org.wfanet.measurement.reporting.v1alpha.MetricKt.setOperation -import org.wfanet.measurement.reporting.v1alpha.MetricKt.watchDurationParams -import org.wfanet.measurement.reporting.v1alpha.Report -import org.wfanet.measurement.reporting.v1alpha.ReportKt -import org.wfanet.measurement.reporting.v1alpha.ReportKt.EventGroupUniverseKt -import org.wfanet.measurement.reporting.v1alpha.ReportKt.ResultKt.HistogramTableKt.row -import org.wfanet.measurement.reporting.v1alpha.ReportKt.ResultKt.column -import org.wfanet.measurement.reporting.v1alpha.ReportKt.ResultKt.histogramTable -import org.wfanet.measurement.reporting.v1alpha.ReportKt.ResultKt.scalarTable -import org.wfanet.measurement.reporting.v1alpha.ReportKt.eventGroupUniverse -import org.wfanet.measurement.reporting.v1alpha.copy -import org.wfanet.measurement.reporting.v1alpha.createReportRequest -import org.wfanet.measurement.reporting.v1alpha.getReportRequest -import org.wfanet.measurement.reporting.v1alpha.listReportsPageToken -import org.wfanet.measurement.reporting.v1alpha.listReportsRequest -import org.wfanet.measurement.reporting.v1alpha.listReportsResponse -import org.wfanet.measurement.reporting.v1alpha.metric -import org.wfanet.measurement.reporting.v1alpha.periodicTimeInterval -import org.wfanet.measurement.reporting.v1alpha.report -import org.wfanet.measurement.reporting.v1alpha.timeInterval -import org.wfanet.measurement.reporting.v1alpha.timeIntervals - -private const val DEFAULT_PAGE_SIZE = 50 -private const val MAX_PAGE_SIZE = 1000 -private const val PAGE_SIZE = 3 - -private const val NUMBER_VID_BUCKETS = 300 -private const val WIDTH = 256 -private const val DELTA = 1e-15 - -private const val MAXIMUM_FREQUENCY = 10 - -private val MEASUREMENT_SPEC_CONFIG = measurementSpecConfig { - reachSingleDataProvider = - MeasurementSpecConfigKt.reachSingleDataProvider { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.000207 - delta = DELTA - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - reach = - MeasurementSpecConfigKt.reach { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.0007444 - delta = DELTA - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - randomStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.randomStart { - width = WIDTH - numVidBuckets = NUMBER_VID_BUCKETS - } - } - } - reachAndFrequencySingleDataProvider = - MeasurementSpecConfigKt.reachAndFrequencySingleDataProvider { - reachPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.004728 - delta = DELTA - } - frequencyPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.004728 - delta = DELTA - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - reachAndFrequency = - MeasurementSpecConfigKt.reachAndFrequency { - reachPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.014638 - delta = DELTA - } - frequencyPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.014638 - delta = DELTA - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - randomStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.randomStart { - width = WIDTH - numVidBuckets = NUMBER_VID_BUCKETS - } - } - } - impression = - MeasurementSpecConfigKt.impression { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.003592 - delta = DELTA - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - duration = - MeasurementSpecConfigKt.duration { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.007418 - delta = DELTA - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } -} - -private const val SECURE_RANDOM_OUTPUT_INT = 0 -private const val SECURE_RANDOM_OUTPUT_LONG = 0L - -private val SECRETS_DIR = - getRuntimePath( - Paths.get("wfa_measurement_system", "src", "main", "k8s", "testing", "secretfiles") - )!! - .toFile() - -// Authentication key -private const val API_AUTHENTICATION_KEY = "nR5QPN7ptx" - -// Aggregator certificate - -private val AGGREGATOR_SIGNING_KEY: SigningKeyHandle by lazy { - loadSigningKey( - SECRETS_DIR.resolve("aggregator_cs_cert.der"), - SECRETS_DIR.resolve("aggregator_cs_private.der"), - ) -} -private val AGGREGATOR_CERTIFICATE = certificate { - name = "duchies/aggregator/certificates/abc123" - x509Der = AGGREGATOR_SIGNING_KEY.certificate.encoded.toByteString() -} -private val AGGREGATOR_ROOT_CERTIFICATE: X509Certificate = - readCertificate(SECRETS_DIR.resolve("aggregator_root.pem")) - -private val INVALID_MEASUREMENT_PUBLIC_KEY_DATA = "Invalid public key".toByteStringUtf8() - -// Measurement consumer crypto - -private val TRUSTED_MEASUREMENT_CONSUMER_ISSUER: X509Certificate = - readCertificate(SECRETS_DIR.resolve("mc_root.pem")) -private val MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE = - loadSigningKey(SECRETS_DIR.resolve("mc_cs_cert.der"), SECRETS_DIR.resolve("mc_cs_private.der")) -private val MEASUREMENT_CONSUMER_CERTIFICATE = MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE.certificate -private val MEASUREMENT_CONSUMER_PRIVATE_KEY_HANDLE: PrivateKeyHandle = - loadPrivateKey(SECRETS_DIR.resolve("mc_enc_private.tink")) -private val MEASUREMENT_CONSUMER_PUBLIC_KEY = encryptionPublicKey { - format = EncryptionPublicKey.Format.TINK_KEYSET - data = SECRETS_DIR.resolve("mc_enc_public.tink").readByteString() -} - -private val MEASUREMENT_CONSUMERS: Map = - (1L..2L).associate { - val measurementConsumerKey = MeasurementConsumerKey(ExternalId(it + 110L).apiId.value) - val certificateKey = - MeasurementConsumerCertificateKey( - measurementConsumerKey.measurementConsumerId, - ExternalId(it + 120L).apiId.value, - ) - measurementConsumerKey to - measurementConsumer { - name = measurementConsumerKey.toName() - certificate = certificateKey.toName() - certificateDer = MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE.certificate.encoded.toByteString() - publicKey = - signEncryptionPublicKey( - MEASUREMENT_CONSUMER_PUBLIC_KEY, - MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE, - ) - } - } - -private val CONFIG = measurementConsumerConfig { - apiKey = API_AUTHENTICATION_KEY - signingCertificateName = MEASUREMENT_CONSUMERS.values.first().certificate - signingPrivateKeyPath = "mc_cs_private.der" -} - -// InMemoryEncryptionKeyPairStore -private val ENCRYPTION_KEY_PAIR_STORE = - InMemoryEncryptionKeyPairStore( - MEASUREMENT_CONSUMERS.values.associateBy( - { it.name }, - { - listOf( - it.publicKey.unpack().data to MEASUREMENT_CONSUMER_PRIVATE_KEY_HANDLE - ) - }, - ) - ) - -// Report IDs and names -private val REPORT_EXTERNAL_IDS = listOf(331L, 332L, 333L, 334L) -private val REPORT_NAMES = - REPORT_EXTERNAL_IDS.map { - ReportKey(MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId, externalIdToApiId(it)) - .toName() - } - -// Typo causes invalid name -private const val INVALID_REPORT_NAME = "measurementConsumer/AAAAAAAAAG8/report/AAAAAAAAAU0" - -private val DATA_PROVIDER_PUBLIC_KEY = encryptionPublicKey { - format = EncryptionPublicKey.Format.TINK_KEYSET - data = SECRETS_DIR.resolve("edp1_enc_public.tink").readByteString() -} -private val DATA_PROVIDER_PRIVATE_KEY_HANDLE = - loadPrivateKey(SECRETS_DIR.resolve("edp1_enc_private.tink")) -private val DATA_PROVIDER_SIGNING_KEY = - loadSigningKey( - SECRETS_DIR.resolve("edp1_cs_cert.der"), - SECRETS_DIR.resolve("edp1_cs_private.der"), - ) -private val DATA_PROVIDER_ROOT_CERTIFICATE = readCertificate(SECRETS_DIR.resolve("edp1_root.pem")) - -// Data providers - -private val DATA_PROVIDERS = - (1L..3L).associate { - val dataProviderKey = DataProviderKey(ExternalId(it + 550L).apiId.value) - val certificateKey = - DataProviderCertificateKey(dataProviderKey.dataProviderId, ExternalId(it + 560L).apiId.value) - dataProviderKey to - dataProvider { - name = dataProviderKey.toName() - certificate = certificateKey.toName() - publicKey = signEncryptionPublicKey(DATA_PROVIDER_PUBLIC_KEY, DATA_PROVIDER_SIGNING_KEY) - } - } -private val DATA_PROVIDERS_LIST = DATA_PROVIDERS.values.toList() - -// Event group keys - -private val COVERED_EVENT_GROUP_KEYS = - DATA_PROVIDERS.keys.mapIndexed { index, dataProviderKey -> - val measurementConsumerKey = MEASUREMENT_CONSUMERS.keys.first() - EventGroupKey( - measurementConsumerKey.measurementConsumerId, - dataProviderKey.dataProviderId, - ExternalId(index + 660L).apiId.value, - ) - } -private val UNCOVERED_EVENT_GROUP_KEY = - EventGroupKey( - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId, - DATA_PROVIDERS.keys.last().dataProviderId, - ExternalId(664L).apiId.value, - ) -private val UNCOVERED_EVENT_GROUP_NAME = UNCOVERED_EVENT_GROUP_KEY.toName() -private val UNCOVERED_INTERNAL_EVENT_GROUP_KEY = UNCOVERED_EVENT_GROUP_KEY.toInternal() - -// Reporting sets -private const val REPORTING_SET_FILTER = "AGE>18" - -private val INTERNAL_REPORTING_SETS = - COVERED_EVENT_GROUP_KEYS.mapIndexed { index, eventGroupKey -> - internalReportingSet { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetId = index + 220L - eventGroupKeys += eventGroupKey.toInternal() - filter = REPORTING_SET_FILTER - displayName = "$measurementConsumerReferenceId-$externalReportingSetId-$filter" - } - } -private val UNCOVERED_INTERNAL_REPORTING_SET = internalReportingSet { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetId = INTERNAL_REPORTING_SETS.last().externalReportingSetId + 1 - eventGroupKeys += UNCOVERED_INTERNAL_EVENT_GROUP_KEY - filter = REPORTING_SET_FILTER - displayName = "$measurementConsumerReferenceId-$externalReportingSetId-$filter" -} - -// Reporting set IDs and names - -private const val REPORTING_SET_EXTERNAL_ID_FOR_MC_2 = 241L - -private const val INVALID_REPORTING_SET_NAME = "INVALID_REPORTING_SET_NAME" -private val REPORTING_SET_NAME_FOR_MC_2 = - ReportingSetKey( - MEASUREMENT_CONSUMERS.keys.last().measurementConsumerId, - externalIdToApiId(REPORTING_SET_EXTERNAL_ID_FOR_MC_2), - ) - .toName() - -// Time intervals -private val START_INSTANT = Instant.now() -private val END_INSTANT = START_INSTANT.plus(Duration.ofDays(1)) - -private val START_TIME = START_INSTANT.toProtoTime() -private val TIME_INTERVAL_INCREMENT = Duration.ofDays(1).toProtoDuration() -private const val INTERVAL_COUNT = 1 -private val END_TIME = END_INSTANT.toProtoTime() -private val MEASUREMENT_TIME_INTERVAL = interval { - startTime = START_TIME - endTime = END_TIME -} -private val INTERNAL_TIME_INTERVAL = internalTimeInterval { - startTime = START_TIME - endTime = END_TIME -} -private val INTERNAL_PERIODIC_TIME_INTERVAL = internalPeriodicTimeInterval { - startTime = START_TIME - increment = TIME_INTERVAL_INCREMENT - intervalCount = INTERVAL_COUNT -} -private val PERIODIC_TIME_INTERVAL = periodicTimeInterval { - startTime = START_TIME - increment = TIME_INTERVAL_INCREMENT - intervalCount = INTERVAL_COUNT -} - -// Report idempotency keys -private const val REACH_REPORT_IDEMPOTENCY_KEY = "TEST_REACH_REPORT" -private const val IMPRESSION_REPORT_IDEMPOTENCY_KEY = "TEST_IMPRESSION_REPORT" -private const val WATCH_DURATION_REPORT_IDEMPOTENCY_KEY = "TEST_WATCH_DURATION_REPORT" -private const val FREQUENCY_HISTOGRAM_REPORT_IDEMPOTENCY_KEY = "TEST_FREQUENCY_HISTOGRAM_REPORT" - -// Set operation unique names -private const val REACH_SET_OPERATION_UNIQUE_NAME = "Reach Set Operation" -private const val FREQUENCY_HISTOGRAM_SET_OPERATION_UNIQUE_NAME = - "Frequency Histogram Set Operation" -private const val IMPRESSION_SET_OPERATION_UNIQUE_NAME = "Impression Set Operation" -private const val WATCH_DURATION_SET_OPERATION_UNIQUE_NAME = "Watch Duration Set Operation" - -// Measurement IDs and names -private val REACH_MEASUREMENT_CREATE_REQUEST_ID = - "$REACH_REPORT_IDEMPOTENCY_KEY-Reach-$REACH_SET_OPERATION_UNIQUE_NAME-$START_INSTANT-" + - "$END_INSTANT-measurement-0" - -private val FREQUENCY_HISTOGRAM_MEASUREMENT_CREATE_REQUEST_ID = - "$FREQUENCY_HISTOGRAM_REPORT_IDEMPOTENCY_KEY-FrequencyHistogram-$FREQUENCY_HISTOGRAM_SET_OPERATION_UNIQUE_NAME-$START_INSTANT-" + - "$END_INSTANT-measurement-0" - -private val IMPRESSION_MEASUREMENT_CREATE_REQUEST_ID = - "$IMPRESSION_REPORT_IDEMPOTENCY_KEY-ImpressionCount-$IMPRESSION_SET_OPERATION_UNIQUE_NAME-$START_INSTANT-" + - "$END_INSTANT-measurement-0" - -private val WATCH_DURATION_MEASUREMENT_CREATE_REQUEST_ID = - "$WATCH_DURATION_REPORT_IDEMPOTENCY_KEY-WatchDuration-$WATCH_DURATION_SET_OPERATION_UNIQUE_NAME-$START_INSTANT-" + - "$END_INSTANT-measurement-0" - -private val REACH_MEASUREMENT_KEY = - MeasurementKey( - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId, - ExternalId(111).apiId.value, - ) -private val REACH_MEASUREMENT_KEY_2 = - MeasurementKey(REACH_MEASUREMENT_KEY.measurementConsumerId, ExternalId(222).apiId.value) -private val FREQUENCY_HISTOGRAM_MEASUREMENT_KEY = - MeasurementKey(REACH_MEASUREMENT_KEY.measurementConsumerId, ExternalId(333).apiId.value) -private val IMPRESSION_MEASUREMENT_KEY = - MeasurementKey(REACH_MEASUREMENT_KEY.measurementConsumerId, ExternalId(444).apiId.value) -private val WATCH_DURATION_MEASUREMENT_KEY = - MeasurementKey(REACH_MEASUREMENT_KEY.measurementConsumerId, ExternalId(555).apiId.value) - -// Set operations -private val INTERNAL_SET_OPERATION = - InternalMetricKt.setOperation { - type = InternalMetric.SetOperation.Type.UNION - lhs = - InternalMetricKt.SetOperationKt.operand { - reportingSetId = reportingSetKey { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetId = INTERNAL_REPORTING_SETS[0].externalReportingSetId - } - } - rhs = - InternalMetricKt.SetOperationKt.operand { - reportingSetId = reportingSetKey { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetId = INTERNAL_REPORTING_SETS[1].externalReportingSetId - } - } - } - -private val SET_OPERATION = setOperation { - type = SetOperation.Type.UNION - lhs = SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[0].resourceName } - rhs = SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[1].resourceName } -} -private val DATA_PROVIDER_KEYS_IN_SET_OPERATION = DATA_PROVIDERS.keys.take(2) - -private val SET_OPERATION_WITH_INVALID_REPORTING_SET = setOperation { - type = SetOperation.Type.UNION - lhs = SetOperationKt.operand { reportingSet = INVALID_REPORTING_SET_NAME } - rhs = SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[1].resourceName } -} - -private val SET_OPERATION_WITH_INACCESSIBLE_REPORTING_SET = setOperation { - type = SetOperation.Type.UNION - lhs = SetOperationKt.operand { reportingSet = REPORTING_SET_NAME_FOR_MC_2 } - rhs = SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[1].resourceName } -} - -// Event group filters -private const val EVENT_GROUP_FILTER = "AGE>20" -private val EVENT_GROUP_FILTERS_MAP = - COVERED_EVENT_GROUP_KEYS.associateBy(EventGroupKey::toName) { EVENT_GROUP_FILTER } - -// Event group entries -private val EVENT_GROUP_ENTRIES = - COVERED_EVENT_GROUP_KEYS.groupBy( - { DataProviderKey(it.dataProviderReferenceId) }, - { - eventGroupEntry { - key = CmmsEventGroupKey(it.dataProviderReferenceId, it.eventGroupReferenceId).toName() - value = - EventGroupEntryKt.value { - collectionInterval = MEASUREMENT_TIME_INTERVAL - filter = eventFilter { - expression = "($REPORTING_SET_FILTER) AND ($EVENT_GROUP_FILTER)" - } - } - } - }, - ) - -// Requisition specs -private val REQUISITION_SPECS: Map = - EVENT_GROUP_ENTRIES.mapValues { - requisitionSpec { - events = RequisitionSpecKt.events { eventGroups += it.value } - measurementPublicKey = MEASUREMENT_CONSUMERS.values.first().publicKey.message - nonce = SECURE_RANDOM_OUTPUT_LONG - } - } - -// Data provider entries -private val DATA_PROVIDER_ENTRIES = - REQUISITION_SPECS.mapValues { (dataProviderKey, requisitionSpec) -> - val dataProvider = DATA_PROVIDERS.getValue(dataProviderKey) - dataProviderEntry { - key = dataProvider.name - value = - DataProviderEntryKt.value { - dataProviderCertificate = dataProvider.certificate - dataProviderPublicKey = dataProvider.publicKey.message - encryptedRequisitionSpec = - encryptRequisitionSpec( - signRequisitionSpec(requisitionSpec, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE), - dataProvider.publicKey.unpack(), - ) - nonceHash = Hashing.hashSha256(requisitionSpec.nonce) - } - } - } - -// Measurements -private val BASE_MEASUREMENT = measurement { - measurementConsumerCertificate = MEASUREMENT_CONSUMERS.values.first().certificate -} -private val BASE_MEASUREMENT_SPEC = measurementSpec { - measurementPublicKey = MEASUREMENT_CONSUMER_PUBLIC_KEY.pack() - // TODO(world-federation-of-advertisers/cross-media-measurement#1301): Stop setting this field. - serializedMeasurementPublicKey = measurementPublicKey.value -} - -// Measurement values -private const val REACH_VALUE = 100_000L -private val FREQUENCY_DISTRIBUTION = mapOf(1L to 1.0 / 6, 2L to 2.0 / 6, 3L to 3.0 / 6) -private val IMPRESSION_VALUES = listOf(100L, 150L) -private val TOTAL_IMPRESSION_VALUE = IMPRESSION_VALUES.sum() -private val WATCH_DURATION_SECOND_LIST = listOf(100L, 200L) -private val WATCH_DURATION_LIST = WATCH_DURATION_SECOND_LIST.map { duration { seconds = it } } -private val TOTAL_WATCH_DURATION = duration { seconds = WATCH_DURATION_SECOND_LIST.sum() } - -// Reach measurement -private val BASE_REACH_MEASUREMENT = BASE_MEASUREMENT.copy { name = REACH_MEASUREMENT_KEY.toName() } -private val BASE_REACH_MEASUREMENT_2 = - BASE_MEASUREMENT.copy { name = REACH_MEASUREMENT_KEY_2.toName() } - -private val PENDING_REACH_MEASUREMENT = - BASE_REACH_MEASUREMENT.copy { state = Measurement.State.COMPUTING } - -private val REACH_SINGLE_DATA_PROVIDER_MEASUREMENT_SPEC = - BASE_MEASUREMENT_SPEC.copy { - nonceHashes.addAll(listOf(Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG))) - - reach = - MeasurementSpecKt.reach { - privacyParams = differentialPrivacyParams { - epsilon = MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.privacyParams.epsilon - delta = MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.privacyParams.delta - } - } - vidSamplingInterval = vidSamplingInterval { - start = 0f - width = 1f - } - } - -private val REACH_SINGLE_DATA_PROVIDER_MEASUREMENT_REQUEST = createMeasurementRequest { - parent = MeasurementConsumerKey(REACH_MEASUREMENT_KEY.measurementConsumerId).toName() - measurement = - BASE_MEASUREMENT.copy { - dataProviders += DATA_PROVIDER_ENTRIES.getValue(DataProviderKey(ExternalId(551L).apiId.value)) - measurementSpec = - signMeasurementSpec( - REACH_SINGLE_DATA_PROVIDER_MEASUREMENT_SPEC, - MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE, - ) - } - requestId = REACH_MEASUREMENT_CREATE_REQUEST_ID -} - -private val REACH_MEASUREMENT_SPEC = - BASE_MEASUREMENT_SPEC.copy { - nonceHashes.addAll( - listOf( - Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG), - Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG), - ) - ) - - reach = - MeasurementSpecKt.reach { - privacyParams = differentialPrivacyParams { - epsilon = MEASUREMENT_SPEC_CONFIG.reach.privacyParams.epsilon - delta = MEASUREMENT_SPEC_CONFIG.reach.privacyParams.delta - } - } - vidSamplingInterval = vidSamplingInterval { - start = SECURE_RANDOM_OUTPUT_INT.toFloat() / NUMBER_VID_BUCKETS - width = WIDTH.toFloat() / NUMBER_VID_BUCKETS - } - } - -private val REACH_MEASUREMENT_REQUEST = createMeasurementRequest { - parent = MeasurementConsumerKey(REACH_MEASUREMENT_KEY.measurementConsumerId).toName() - measurement = - BASE_MEASUREMENT.copy { - dataProviders += - DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - measurementSpec = - signMeasurementSpec(REACH_MEASUREMENT_SPEC, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) - } - requestId = REACH_MEASUREMENT_CREATE_REQUEST_ID -} - -private val SUCCEEDED_REACH_MEASUREMENT = - BASE_REACH_MEASUREMENT.copy { - dataProviders += DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - - measurementSpec = - signMeasurementSpec(REACH_MEASUREMENT_SPEC, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) - - state = Measurement.State.SUCCEEDED - - results += resultOutput { - val result = - MeasurementKt.result { - reach = MeasurementKt.ResultKt.reach { value = REACH_VALUE } - frequency = - MeasurementKt.ResultKt.frequency { - relativeFrequencyDistribution.putAll(FREQUENCY_DISTRIBUTION) - } - } - encryptedResult = - encryptResult(signResult(result, AGGREGATOR_SIGNING_KEY), MEASUREMENT_CONSUMER_PUBLIC_KEY) - certificate = AGGREGATOR_CERTIFICATE.name - } - } - -private val INTERNAL_PENDING_REACH_MEASUREMENT = internalMeasurement { - measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId - state = InternalMeasurement.State.PENDING -} -private val INTERNAL_SUCCEEDED_REACH_MEASUREMENT = - INTERNAL_PENDING_REACH_MEASUREMENT.copy { - state = InternalMeasurement.State.SUCCEEDED - result = - InternalMeasurementKt.result { - reach = InternalMeasurementResultKt.reach { value = REACH_VALUE } - frequency = - InternalMeasurementResultKt.frequency { - relativeFrequencyDistribution.putAll(FREQUENCY_DISTRIBUTION) - } - } - } - -// Frequency histogram measurement -private val BASE_REACH_FREQUENCY_HISTOGRAM_MEASUREMENT = - BASE_MEASUREMENT.copy { name = FREQUENCY_HISTOGRAM_MEASUREMENT_KEY.toName() } - -private val REACH_FREQUENCY_SINGLE_DATA_PROVIDER_MEASUREMENT_SPEC = - BASE_MEASUREMENT_SPEC.copy { - nonceHashes.addAll(listOf(Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG))) - - reachAndFrequency = - MeasurementSpecKt.reachAndFrequency { - reachPrivacyParams = differentialPrivacyParams { - epsilon = - MEASUREMENT_SPEC_CONFIG.reachAndFrequencySingleDataProvider.reachPrivacyParams.epsilon - delta = - MEASUREMENT_SPEC_CONFIG.reachAndFrequencySingleDataProvider.reachPrivacyParams.delta - } - frequencyPrivacyParams = differentialPrivacyParams { - epsilon = - MEASUREMENT_SPEC_CONFIG.reachAndFrequencySingleDataProvider.frequencyPrivacyParams - .epsilon - delta = - MEASUREMENT_SPEC_CONFIG.reachAndFrequencySingleDataProvider.frequencyPrivacyParams.delta - } - maximumFrequency = MAXIMUM_FREQUENCY - } - vidSamplingInterval = vidSamplingInterval { - start = 0f - width = 1f - } - } - -private val REACH_FREQUENCY_SINGLE_DATA_PROVIDER_MEASUREMENT_REQUEST = createMeasurementRequest { - parent = MeasurementConsumerKey(REACH_MEASUREMENT_KEY.measurementConsumerId).toName() - measurement = - BASE_MEASUREMENT.copy { - dataProviders += DATA_PROVIDER_ENTRIES.getValue(DataProviderKey(ExternalId(551L).apiId.value)) - measurementSpec = - signMeasurementSpec( - REACH_FREQUENCY_SINGLE_DATA_PROVIDER_MEASUREMENT_SPEC, - MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE, - ) - } - requestId = FREQUENCY_HISTOGRAM_MEASUREMENT_CREATE_REQUEST_ID -} - -private val REACH_FREQUENCY_MEASUREMENT_SPEC = - BASE_MEASUREMENT_SPEC.copy { - nonceHashes.addAll( - listOf( - Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG), - Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG), - ) - ) - - reachAndFrequency = - MeasurementSpecKt.reachAndFrequency { - reachPrivacyParams = differentialPrivacyParams { - epsilon = MEASUREMENT_SPEC_CONFIG.reachAndFrequency.reachPrivacyParams.epsilon - delta = MEASUREMENT_SPEC_CONFIG.reachAndFrequency.reachPrivacyParams.delta - } - frequencyPrivacyParams = differentialPrivacyParams { - epsilon = MEASUREMENT_SPEC_CONFIG.reachAndFrequency.frequencyPrivacyParams.epsilon - delta = MEASUREMENT_SPEC_CONFIG.reachAndFrequency.frequencyPrivacyParams.delta - } - maximumFrequency = MAXIMUM_FREQUENCY - } - vidSamplingInterval = vidSamplingInterval { - start = SECURE_RANDOM_OUTPUT_INT.toFloat() / NUMBER_VID_BUCKETS - width = WIDTH.toFloat() / NUMBER_VID_BUCKETS - } - } - -private val REACH_FREQUENCY_MEASUREMENT_REQUEST = createMeasurementRequest { - parent = MeasurementConsumerKey(REACH_MEASUREMENT_KEY.measurementConsumerId).toName() - measurement = - BASE_MEASUREMENT.copy { - dataProviders += - DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - measurementSpec = - signMeasurementSpec( - REACH_FREQUENCY_MEASUREMENT_SPEC, - MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE, - ) - } - requestId = FREQUENCY_HISTOGRAM_MEASUREMENT_CREATE_REQUEST_ID -} - -private val SUCCEEDED_FREQUENCY_HISTOGRAM_MEASUREMENT = - BASE_REACH_FREQUENCY_HISTOGRAM_MEASUREMENT.copy { - dataProviders += DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - - measurementSpec = - signMeasurementSpec(REACH_FREQUENCY_MEASUREMENT_SPEC, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) - - state = Measurement.State.SUCCEEDED - results += resultOutput { - val result = - MeasurementKt.result { - reach = MeasurementKt.ResultKt.reach { value = REACH_VALUE } - frequency = - MeasurementKt.ResultKt.frequency { - relativeFrequencyDistribution.putAll(FREQUENCY_DISTRIBUTION) - } - } - encryptedResult = - encryptResult(signResult(result, AGGREGATOR_SIGNING_KEY), MEASUREMENT_CONSUMER_PUBLIC_KEY) - certificate = AGGREGATOR_CERTIFICATE.name - } - } - -private val INTERNAL_PENDING_FREQUENCY_HISTOGRAM_MEASUREMENT = internalMeasurement { - measurementConsumerReferenceId = FREQUENCY_HISTOGRAM_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = FREQUENCY_HISTOGRAM_MEASUREMENT_KEY.measurementId - state = InternalMeasurement.State.PENDING -} - -private val INTERNAL_SUCCEEDED_FREQUENCY_HISTOGRAM_MEASUREMENT = - INTERNAL_PENDING_FREQUENCY_HISTOGRAM_MEASUREMENT.copy { - state = InternalMeasurement.State.SUCCEEDED - result = - InternalMeasurementKt.result { - reach = InternalMeasurementResultKt.reach { value = REACH_VALUE } - frequency = - InternalMeasurementResultKt.frequency { - relativeFrequencyDistribution.putAll(FREQUENCY_DISTRIBUTION) - } - } - } - -// Impression measurement -private val BASE_IMPRESSION_MEASUREMENT = - BASE_MEASUREMENT.copy { name = IMPRESSION_MEASUREMENT_KEY.toName() } - -private val IMPRESSION_MEASUREMENT_SPEC = - BASE_MEASUREMENT_SPEC.copy { - nonceHashes.addAll( - listOf( - Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG), - Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG), - ) - ) - - impression = - MeasurementSpecKt.impression { - privacyParams = differentialPrivacyParams { - epsilon = MEASUREMENT_SPEC_CONFIG.impression.privacyParams.epsilon - delta = MEASUREMENT_SPEC_CONFIG.impression.privacyParams.delta - } - maximumFrequencyPerUser = MAXIMUM_FREQUENCY - } - vidSamplingInterval = vidSamplingInterval { - start = MEASUREMENT_SPEC_CONFIG.impression.vidSamplingInterval.fixedStart.start - width = MEASUREMENT_SPEC_CONFIG.impression.vidSamplingInterval.fixedStart.width - } - } - -private val IMPRESSION_MEASUREMENT_REQUEST = createMeasurementRequest { - parent = MeasurementConsumerKey(REACH_MEASUREMENT_KEY.measurementConsumerId).toName() - measurement = - BASE_MEASUREMENT.copy { - dataProviders += - DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - measurementSpec = - signMeasurementSpec(IMPRESSION_MEASUREMENT_SPEC, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) - } - requestId = IMPRESSION_MEASUREMENT_CREATE_REQUEST_ID -} - -private val SUCCEEDED_IMPRESSION_MEASUREMENT = - BASE_IMPRESSION_MEASUREMENT.copy { - dataProviders += DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - - measurementSpec = - signMeasurementSpec(IMPRESSION_MEASUREMENT_SPEC, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) - - state = Measurement.State.SUCCEEDED - - results += - DATA_PROVIDER_KEYS_IN_SET_OPERATION.zip(IMPRESSION_VALUES).map { - (dataProviderKey, numImpressions) -> - val dataProvider = DATA_PROVIDERS.getValue(dataProviderKey) - resultOutput { - val result = - MeasurementKt.result { - impression = MeasurementKt.ResultKt.impression { value = numImpressions } - } - encryptedResult = - encryptResult( - signResult(result, DATA_PROVIDER_SIGNING_KEY), - MEASUREMENT_CONSUMER_PUBLIC_KEY, - ) - certificate = dataProvider.certificate - } - } - } - -private val INTERNAL_PENDING_IMPRESSION_MEASUREMENT = internalMeasurement { - measurementConsumerReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementId - state = InternalMeasurement.State.PENDING -} - -private val INTERNAL_SUCCEEDED_IMPRESSION_MEASUREMENT = - INTERNAL_PENDING_IMPRESSION_MEASUREMENT.copy { - state = InternalMeasurement.State.SUCCEEDED - result = - InternalMeasurementKt.result { - impression = InternalMeasurementResultKt.impression { value = TOTAL_IMPRESSION_VALUE } - } - } - -// Watch Duration measurement -private val BASE_WATCH_DURATION_MEASUREMENT = - BASE_MEASUREMENT.copy { name = WATCH_DURATION_MEASUREMENT_KEY.toName() } - -private val PENDING_WATCH_DURATION_MEASUREMENT = - BASE_WATCH_DURATION_MEASUREMENT.copy { state = Measurement.State.COMPUTING } - -private val WATCH_DURATION_MEASUREMENT_SPEC = - BASE_MEASUREMENT_SPEC.copy { - nonceHashes.addAll( - listOf( - Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG), - Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG), - ) - ) - - duration = - MeasurementSpecKt.duration { - privacyParams = differentialPrivacyParams { - epsilon = MEASUREMENT_SPEC_CONFIG.duration.privacyParams.epsilon - delta = MEASUREMENT_SPEC_CONFIG.duration.privacyParams.delta - } - maximumWatchDurationPerUser = - Durations.fromSeconds(MAXIMUM_WATCH_DURATION_PER_USER.toLong()) - } - vidSamplingInterval = vidSamplingInterval { - start = MEASUREMENT_SPEC_CONFIG.duration.vidSamplingInterval.fixedStart.start - width = MEASUREMENT_SPEC_CONFIG.duration.vidSamplingInterval.fixedStart.width - } - } - -private val WATCH_DURATION_MEASUREMENT_REQUEST = createMeasurementRequest { - parent = MeasurementConsumerKey(REACH_MEASUREMENT_KEY.measurementConsumerId).toName() - measurement = - BASE_MEASUREMENT.copy { - dataProviders += - DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - measurementSpec = - signMeasurementSpec( - WATCH_DURATION_MEASUREMENT_SPEC, - MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE, - ) - } - requestId = WATCH_DURATION_MEASUREMENT_CREATE_REQUEST_ID -} - -private val SUCCEEDED_WATCH_DURATION_MEASUREMENT = - BASE_WATCH_DURATION_MEASUREMENT.copy { - dataProviders += DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - - measurementSpec = - signMeasurementSpec(WATCH_DURATION_MEASUREMENT_SPEC, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) - - state = Measurement.State.SUCCEEDED - - results += - DATA_PROVIDER_KEYS_IN_SET_OPERATION.zip(WATCH_DURATION_LIST).map { - (dataProviderKey, watchDuration) -> - val dataProvider = DATA_PROVIDERS.getValue(dataProviderKey) - resultOutput { - val result = - MeasurementKt.result { - this.watchDuration = MeasurementKt.ResultKt.watchDuration { value = watchDuration } - } - encryptedResult = - encryptResult( - signResult(result, DATA_PROVIDER_SIGNING_KEY), - MEASUREMENT_CONSUMER_PUBLIC_KEY, - ) - certificate = dataProvider.certificate - } - } - } - -private val INTERNAL_PENDING_WATCH_DURATION_MEASUREMENT = internalMeasurement { - measurementConsumerReferenceId = WATCH_DURATION_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = WATCH_DURATION_MEASUREMENT_KEY.measurementId - state = InternalMeasurement.State.PENDING -} -private val INTERNAL_SUCCEEDED_WATCH_DURATION_MEASUREMENT = - INTERNAL_PENDING_WATCH_DURATION_MEASUREMENT.copy { - state = InternalMeasurement.State.SUCCEEDED - result = - InternalMeasurementKt.result { - watchDuration = InternalMeasurementResultKt.watchDuration { value = TOTAL_WATCH_DURATION } - } - } - -// Weighted measurements -private val WEIGHTED_REACH_MEASUREMENT = weightedMeasurement { - measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId - coefficient = 1 -} - -private val WEIGHTED_FREQUENCY_HISTOGRAM_MEASUREMENT = weightedMeasurement { - measurementReferenceId = FREQUENCY_HISTOGRAM_MEASUREMENT_KEY.measurementId - coefficient = 1 -} - -private val WEIGHTED_IMPRESSION_MEASUREMENT = weightedMeasurement { - measurementReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementId - coefficient = 1 -} - -private val WEIGHTED_WATCH_DURATION_MEASUREMENT = weightedMeasurement { - measurementReferenceId = WATCH_DURATION_MEASUREMENT_KEY.measurementId - coefficient = 1 -} - -// Measurement Calculations -private val REACH_MEASUREMENT_CALCULATION = measurementCalculation { - timeInterval = INTERNAL_TIME_INTERVAL - weightedMeasurements.add(WEIGHTED_REACH_MEASUREMENT) -} - -private val FREQUENCY_HISTOGRAM_MEASUREMENT_CALCULATION = measurementCalculation { - timeInterval = INTERNAL_TIME_INTERVAL - weightedMeasurements.add(WEIGHTED_FREQUENCY_HISTOGRAM_MEASUREMENT) -} - -private val IMPRESSION_MEASUREMENT_CALCULATION = measurementCalculation { - timeInterval = INTERNAL_TIME_INTERVAL - weightedMeasurements.add(WEIGHTED_IMPRESSION_MEASUREMENT) -} - -private val WATCH_DURATION_MEASUREMENT_CALCULATION = measurementCalculation { - timeInterval = INTERNAL_TIME_INTERVAL - weightedMeasurements.add(WEIGHTED_WATCH_DURATION_MEASUREMENT) -} - -// Named set operations -// Reach set operation -private val INTERNAL_NAMED_REACH_SET_OPERATION = - InternalMetricKt.namedSetOperation { - displayName = REACH_SET_OPERATION_UNIQUE_NAME - setOperation = INTERNAL_SET_OPERATION - measurementCalculations += REACH_MEASUREMENT_CALCULATION - } -private val NAMED_REACH_SET_OPERATION = namedSetOperation { - uniqueName = REACH_SET_OPERATION_UNIQUE_NAME - setOperation = SET_OPERATION -} - -// Frequency histogram set operation -private val INTERNAL_NAMED_FREQUENCY_HISTOGRAM_SET_OPERATION = - InternalMetricKt.namedSetOperation { - displayName = FREQUENCY_HISTOGRAM_SET_OPERATION_UNIQUE_NAME - setOperation = INTERNAL_SET_OPERATION - measurementCalculations += FREQUENCY_HISTOGRAM_MEASUREMENT_CALCULATION - } -private val NAMED_FREQUENCY_HISTOGRAM_SET_OPERATION = namedSetOperation { - uniqueName = FREQUENCY_HISTOGRAM_SET_OPERATION_UNIQUE_NAME - setOperation = SET_OPERATION -} - -// Impression set operation -private val INTERNAL_NAMED_IMPRESSION_SET_OPERATION = - InternalMetricKt.namedSetOperation { - displayName = IMPRESSION_SET_OPERATION_UNIQUE_NAME - setOperation = INTERNAL_SET_OPERATION - measurementCalculations += IMPRESSION_MEASUREMENT_CALCULATION - } -private val NAMED_IMPRESSION_SET_OPERATION = namedSetOperation { - uniqueName = IMPRESSION_SET_OPERATION_UNIQUE_NAME - setOperation = SET_OPERATION -} - -// Watch duration set operation -private val INTERNAL_NAMED_WATCH_DURATION_SET_OPERATION = - InternalMetricKt.namedSetOperation { - displayName = WATCH_DURATION_SET_OPERATION_UNIQUE_NAME - setOperation = INTERNAL_SET_OPERATION - measurementCalculations += WATCH_DURATION_MEASUREMENT_CALCULATION - } -private val NAMED_WATCH_DURATION_SET_OPERATION = namedSetOperation { - uniqueName = WATCH_DURATION_SET_OPERATION_UNIQUE_NAME - setOperation = SET_OPERATION -} - -// Internal metrics -private const val MAXIMUM_FREQUENCY_PER_USER = 10 -private const val MAXIMUM_WATCH_DURATION_PER_USER = 300 - -// Reach metric -private val REACH_METRIC = metric { - reach = reachParams {} - cumulative = false - setOperations.add(NAMED_REACH_SET_OPERATION) -} -private val INTERNAL_REACH_METRIC = internalMetric { - details = - InternalMetricKt.details { - reach = InternalMetricKt.reachParams {} - cumulative = false - } - namedSetOperations.add(INTERNAL_NAMED_REACH_SET_OPERATION) -} - -// Frequency histogram metric -private val FREQUENCY_HISTOGRAM_METRIC = metric { - frequencyHistogram = frequencyHistogramParams { - maximumFrequencyPerUser = MAXIMUM_FREQUENCY_PER_USER - } - cumulative = false - setOperations.add(NAMED_FREQUENCY_HISTOGRAM_SET_OPERATION) -} -private val INTERNAL_FREQUENCY_HISTOGRAM_METRIC = internalMetric { - details = - InternalMetricKt.details { - frequencyHistogram = - InternalMetricKt.frequencyHistogramParams { maximumFrequency = MAXIMUM_FREQUENCY_PER_USER } - cumulative = false - } - namedSetOperations.add(INTERNAL_NAMED_FREQUENCY_HISTOGRAM_SET_OPERATION) -} - -// Impression metric -private val IMPRESSION_METRIC = metric { - impressionCount = impressionCountParams { maximumFrequencyPerUser = MAXIMUM_FREQUENCY_PER_USER } - cumulative = false - setOperations.add(NAMED_IMPRESSION_SET_OPERATION) -} -private val INTERNAL_IMPRESSION_METRIC = internalMetric { - details = - InternalMetricKt.details { - impressionCount = - InternalMetricKt.impressionCountParams { - maximumFrequencyPerUser = MAXIMUM_FREQUENCY_PER_USER - } - cumulative = false - } - namedSetOperations.add(INTERNAL_NAMED_IMPRESSION_SET_OPERATION) -} - -// Watch duration metric -private val WATCH_DURATION_METRIC = metric { - watchDuration = watchDurationParams { - maximumWatchDurationPerUser = MAXIMUM_WATCH_DURATION_PER_USER - } - cumulative = false - setOperations.add(NAMED_WATCH_DURATION_SET_OPERATION) -} -private val INTERNAL_WATCH_DURATION_METRIC = internalMetric { - details = - InternalMetricKt.details { - watchDuration = - InternalMetricKt.watchDurationParams { - maximumWatchDurationPerUserSeconds = MAXIMUM_WATCH_DURATION_PER_USER - } - cumulative = false - } - namedSetOperations.add(INTERNAL_NAMED_WATCH_DURATION_SET_OPERATION) -} - -// Internal reports with running states -// Internal reports of reach -private val INTERNAL_PENDING_REACH_REPORT = internalReport { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[0] - periodicTimeInterval = INTERNAL_PERIODIC_TIME_INTERVAL - metrics.add(INTERNAL_REACH_METRIC) - state = InternalReport.State.RUNNING - measurements.put(REACH_MEASUREMENT_KEY.measurementId, INTERNAL_PENDING_REACH_MEASUREMENT) - details = InternalReportKt.details { eventGroupFilters.putAll(EVENT_GROUP_FILTERS_MAP) } - createTime = timestamp { seconds = 1000 } - reportIdempotencyKey = REACH_REPORT_IDEMPOTENCY_KEY -} -private val INTERNAL_SUCCEEDED_REACH_REPORT = - INTERNAL_PENDING_REACH_REPORT.copy { - state = InternalReport.State.SUCCEEDED - measurements.put(REACH_MEASUREMENT_KEY.measurementId, INTERNAL_SUCCEEDED_REACH_MEASUREMENT) - } - -// Internal reports of impression -private val INTERNAL_PENDING_IMPRESSION_REPORT = internalReport { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[1] - periodicTimeInterval = INTERNAL_PERIODIC_TIME_INTERVAL - metrics.add(INTERNAL_IMPRESSION_METRIC) - state = InternalReport.State.RUNNING - measurements.put( - IMPRESSION_MEASUREMENT_KEY.measurementId, - INTERNAL_PENDING_IMPRESSION_MEASUREMENT, - ) - details = InternalReportKt.details { eventGroupFilters.putAll(EVENT_GROUP_FILTERS_MAP) } - createTime = timestamp { seconds = 2000 } - reportIdempotencyKey = IMPRESSION_REPORT_IDEMPOTENCY_KEY -} -private val INTERNAL_SUCCEEDED_IMPRESSION_REPORT = - INTERNAL_PENDING_IMPRESSION_REPORT.copy { - state = InternalReport.State.SUCCEEDED - measurements.put( - IMPRESSION_MEASUREMENT_KEY.measurementId, - INTERNAL_SUCCEEDED_IMPRESSION_MEASUREMENT, - ) - } - -// Internal reports of watch duration -private val INTERNAL_PENDING_WATCH_DURATION_REPORT = internalReport { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - periodicTimeInterval = INTERNAL_PERIODIC_TIME_INTERVAL - metrics.add(INTERNAL_WATCH_DURATION_METRIC) - state = InternalReport.State.RUNNING - measurements.put( - WATCH_DURATION_MEASUREMENT_KEY.measurementId, - INTERNAL_PENDING_WATCH_DURATION_MEASUREMENT, - ) - details = InternalReportKt.details { eventGroupFilters.putAll(EVENT_GROUP_FILTERS_MAP) } - createTime = timestamp { seconds = 3000 } - reportIdempotencyKey = WATCH_DURATION_REPORT_IDEMPOTENCY_KEY -} -private val INTERNAL_SUCCEEDED_WATCH_DURATION_REPORT = - INTERNAL_PENDING_WATCH_DURATION_REPORT.copy { - state = InternalReport.State.SUCCEEDED - measurements.put( - WATCH_DURATION_MEASUREMENT_KEY.measurementId, - INTERNAL_SUCCEEDED_WATCH_DURATION_MEASUREMENT, - ) - } - -// Internal reports of frequency histogram -private val INTERNAL_PENDING_FREQUENCY_HISTOGRAM_REPORT = internalReport { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[3] - periodicTimeInterval = INTERNAL_PERIODIC_TIME_INTERVAL - metrics.add(INTERNAL_FREQUENCY_HISTOGRAM_METRIC) - state = InternalReport.State.RUNNING - measurements.put( - FREQUENCY_HISTOGRAM_MEASUREMENT_KEY.measurementId, - INTERNAL_PENDING_FREQUENCY_HISTOGRAM_MEASUREMENT, - ) - details = InternalReportKt.details { eventGroupFilters.putAll(EVENT_GROUP_FILTERS_MAP) } - createTime = timestamp { seconds = 4000 } - reportIdempotencyKey = FREQUENCY_HISTOGRAM_REPORT_IDEMPOTENCY_KEY -} -private val INTERNAL_SUCCEEDED_FREQUENCY_HISTOGRAM_REPORT = - INTERNAL_PENDING_FREQUENCY_HISTOGRAM_REPORT.copy { - state = InternalReport.State.SUCCEEDED - measurements.put( - FREQUENCY_HISTOGRAM_MEASUREMENT_KEY.measurementId, - INTERNAL_SUCCEEDED_FREQUENCY_HISTOGRAM_MEASUREMENT, - ) - } - -private val EVENT_GROUP_UNIVERSE = eventGroupUniverse { - eventGroupEntries += - COVERED_EVENT_GROUP_KEYS.map { - EventGroupUniverseKt.eventGroupEntry { - key = it.toName() - value = EVENT_GROUP_FILTER - } - } -} - -private val EVENT_GROUP_UNIVERSE_WITH_ONE_DATA_PROVIDER = eventGroupUniverse { - eventGroupEntries += - EventGroupUniverseKt.eventGroupEntry { - key = COVERED_EVENT_GROUP_KEYS[0].toName() - value = EVENT_GROUP_FILTER - } -} - -// Public reports with running states -// Reports of reach -private val PENDING_REACH_REPORT = report { - name = REPORT_NAMES[0] - reportIdempotencyKey = REACH_REPORT_IDEMPOTENCY_KEY - measurementConsumer = MEASUREMENT_CONSUMERS.values.first().name - eventGroupUniverse = EVENT_GROUP_UNIVERSE - periodicTimeInterval = PERIODIC_TIME_INTERVAL - metrics.add(REACH_METRIC) - state = Report.State.RUNNING -} -private val SUCCEEDED_REACH_REPORT = PENDING_REACH_REPORT.copy { state = Report.State.SUCCEEDED } - -// Reports of impression -private val PENDING_IMPRESSION_REPORT = report { - name = REPORT_NAMES[1] - reportIdempotencyKey = IMPRESSION_REPORT_IDEMPOTENCY_KEY - measurementConsumer = MEASUREMENT_CONSUMERS.values.first().name - eventGroupUniverse = EVENT_GROUP_UNIVERSE - periodicTimeInterval = PERIODIC_TIME_INTERVAL - metrics.add(IMPRESSION_METRIC) - state = Report.State.RUNNING -} -private val SUCCEEDED_IMPRESSION_REPORT = - PENDING_IMPRESSION_REPORT.copy { state = Report.State.SUCCEEDED } - -// Reports of watch duration -private val PENDING_WATCH_DURATION_REPORT = report { - name = REPORT_NAMES[2] - reportIdempotencyKey = WATCH_DURATION_REPORT_IDEMPOTENCY_KEY - measurementConsumer = MEASUREMENT_CONSUMERS.values.first().name - eventGroupUniverse = EVENT_GROUP_UNIVERSE - periodicTimeInterval = PERIODIC_TIME_INTERVAL - metrics.add(WATCH_DURATION_METRIC) - state = Report.State.RUNNING -} -private val SUCCEEDED_WATCH_DURATION_REPORT = - PENDING_WATCH_DURATION_REPORT.copy { state = Report.State.SUCCEEDED } - -// Reports of frequency histogram -private val PENDING_FREQUENCY_HISTOGRAM_REPORT = report { - name = REPORT_NAMES[3] - reportIdempotencyKey = FREQUENCY_HISTOGRAM_REPORT_IDEMPOTENCY_KEY - measurementConsumer = MEASUREMENT_CONSUMERS.values.first().name - eventGroupUniverse = EVENT_GROUP_UNIVERSE - periodicTimeInterval = PERIODIC_TIME_INTERVAL - metrics.add(FREQUENCY_HISTOGRAM_METRIC) - state = Report.State.RUNNING -} -private val SUCCEEDED_FREQUENCY_HISTOGRAM_REPORT = - PENDING_FREQUENCY_HISTOGRAM_REPORT.copy { state = Report.State.SUCCEEDED } - -@RunWith(JUnit4::class) -class ReportsServiceTest { - - private val internalReportsMock: ReportsCoroutineImplBase = mockService { - onBlocking { createReport(any()) } - .thenReturn( - INTERNAL_PENDING_REACH_REPORT, - INTERNAL_PENDING_IMPRESSION_REPORT, - INTERNAL_PENDING_WATCH_DURATION_REPORT, - INTERNAL_PENDING_FREQUENCY_HISTOGRAM_REPORT, - ) - onBlocking { streamReports(any()) } - .thenReturn( - flowOf( - INTERNAL_PENDING_REACH_REPORT, - INTERNAL_PENDING_IMPRESSION_REPORT, - INTERNAL_PENDING_WATCH_DURATION_REPORT, - INTERNAL_PENDING_FREQUENCY_HISTOGRAM_REPORT, - ) - ) - onBlocking { getReport(any()) } - .thenReturn( - INTERNAL_SUCCEEDED_REACH_REPORT, - INTERNAL_SUCCEEDED_IMPRESSION_REPORT, - INTERNAL_SUCCEEDED_WATCH_DURATION_REPORT, - INTERNAL_SUCCEEDED_FREQUENCY_HISTOGRAM_REPORT, - ) - onBlocking { getReportByIdempotencyKey(any()) } - .thenThrow(StatusRuntimeException(Status.NOT_FOUND)) - } - - private val internalReportingSetsMock: InternalReportingSetsCoroutineImplBase = mockService { - onBlocking { batchGetReportingSet(any()) } - .thenReturn( - flowOf( - INTERNAL_REPORTING_SETS[0], - INTERNAL_REPORTING_SETS[1], - INTERNAL_REPORTING_SETS[0], - INTERNAL_REPORTING_SETS[1], - ) - ) - } - - private val internalMeasurementsMock: InternalMeasurementsCoroutineImplBase = mockService { - onBlocking { getMeasurement(any()) }.thenThrow(StatusRuntimeException(Status.NOT_FOUND)) - } - - private val measurementsMock: MeasurementsCoroutineImplBase = mockService { - onBlocking { getMeasurement(any()) } - .thenReturn( - SUCCEEDED_REACH_MEASUREMENT, - SUCCEEDED_IMPRESSION_MEASUREMENT, - SUCCEEDED_WATCH_DURATION_MEASUREMENT, - SUCCEEDED_FREQUENCY_HISTOGRAM_MEASUREMENT, - ) - - onBlocking { createMeasurement(any()) }.thenReturn(BASE_REACH_MEASUREMENT) - } - - private val measurementConsumersMock: MeasurementConsumersCoroutineImplBase = mockService { - onBlocking { getMeasurementConsumer(any()) }.thenReturn(MEASUREMENT_CONSUMERS.values.first()) - } - - private val dataProvidersMock: DataProvidersCoroutineImplBase = mockService { - var stubbing = onBlocking { getDataProvider(any()) } - for (dataProvider in DATA_PROVIDERS.values) { - stubbing = stubbing.thenReturn(dataProvider) - } - } - - private val certificateMock: CertificatesCoroutineImplBase = mockService { - onBlocking { getCertificate(eq(getCertificateRequest { name = AGGREGATOR_CERTIFICATE.name })) } - .thenReturn(AGGREGATOR_CERTIFICATE) - for (dataProvider in DATA_PROVIDERS.values) { - onBlocking { getCertificate(eq(getCertificateRequest { name = dataProvider.certificate })) } - .thenReturn( - certificate { - name = dataProvider.certificate - x509Der = DATA_PROVIDER_SIGNING_KEY.certificate.encoded.toByteString() - } - ) - } - for (measurementConsumer in MEASUREMENT_CONSUMERS.values) { - onBlocking { - getCertificate(eq(getCertificateRequest { name = measurementConsumer.certificate })) - } - .thenReturn( - certificate { - name = measurementConsumer.certificate - x509Der = measurementConsumer.certificateDer - } - ) - } - } - - private val randomMock: Random = mock() - - @get:Rule - val grpcTestServerRule = GrpcTestServerRule { - addService(internalReportsMock) - addService(internalReportingSetsMock) - addService(internalMeasurementsMock) - addService(measurementsMock) - addService(measurementConsumersMock) - addService(dataProvidersMock) - addService(certificateMock) - } - - private lateinit var service: ReportsService - - @Before - fun initService() { - randomMock.stub { - on { nextInt(any()) } doReturn SECURE_RANDOM_OUTPUT_INT - on { nextLong() } doReturn SECURE_RANDOM_OUTPUT_LONG - } - - service = - ReportsService( - InternalReportsCoroutineStub(grpcTestServerRule.channel), - InternalReportingSetsCoroutineStub(grpcTestServerRule.channel), - InternalMeasurementsCoroutineStub(grpcTestServerRule.channel), - DataProvidersCoroutineStub(grpcTestServerRule.channel), - MeasurementConsumersCoroutineStub(grpcTestServerRule.channel), - MeasurementsCoroutineStub(grpcTestServerRule.channel), - CertificatesCoroutineStub(grpcTestServerRule.channel), - ENCRYPTION_KEY_PAIR_STORE, - randomMock, - SECRETS_DIR, - listOf(AGGREGATOR_ROOT_CERTIFICATE, DATA_PROVIDER_ROOT_CERTIFICATE).associateBy { - it.subjectKeyIdentifier!! - }, - MEASUREMENT_SPEC_CONFIG, - ) - } - - @Test - fun `createReport returns a report of reach with RUNNING state`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - val expected = PENDING_REACH_REPORT - - // Verify proto argument of ReportsCoroutineImplBase::getReportByIdempotencyKey - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReportByIdempotencyKey) - .isEqualTo( - getReportByIdempotencyKeyRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - reportIdempotencyKey = REACH_REPORT_IDEMPOTENCY_KEY - } - ) - - // Verify proto argument of InternalReportingSetsCoroutineImplBase::batchGetReportingSet - verifyProtoArgument( - internalReportingSetsMock, - InternalReportingSetsCoroutineImplBase::batchGetReportingSet, - ) - .ignoringRepeatedFieldOrder() - .isEqualTo( - batchGetReportingSetRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_REPORTING_SETS[0].externalReportingSetId - externalReportingSetIds += INTERNAL_REPORTING_SETS[1].externalReportingSetId - } - ) - - // Verify proto argument of MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - verifyProtoArgument( - measurementConsumersMock, - MeasurementConsumersCoroutineImplBase::getMeasurementConsumer, - ) - .isEqualTo(getMeasurementConsumerRequest { name = MEASUREMENT_CONSUMERS.values.first().name }) - - // Verify proto argument of DataProvidersCoroutineImplBase::getDataProvider - val dataProvidersCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(dataProvidersMock, times(2)) { getDataProvider(dataProvidersCaptor.capture()) } - val capturedDataProviderRequests = dataProvidersCaptor.allValues - assertThat(capturedDataProviderRequests) - .containsExactly( - getDataProviderRequest { name = DATA_PROVIDERS_LIST[0].name }, - getDataProviderRequest { name = DATA_PROVIDERS_LIST[1].name }, - ) - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val capturedMeasurementRequest = - captureFirst { - runBlocking { verify(measurementsMock).createMeasurement(capture()) } - } - assertThat(capturedMeasurementRequest) - .ignoringRepeatedFieldOrder() - .ignoringFieldDescriptors( - MEASUREMENT_SPEC_FIELD_DESCRIPTOR, - ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, - ) - .isEqualTo(REACH_MEASUREMENT_REQUEST) - - verifyMeasurementSpec( - capturedMeasurementRequest.measurement.measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - val measurementSpec: MeasurementSpec = - capturedMeasurementRequest.measurement.measurementSpec.unpack() - assertThat(measurementSpec).isEqualTo(REACH_MEASUREMENT_SPEC) - - val dataProvidersList = - capturedMeasurementRequest.measurement.dataProvidersList.sortedBy { it.key } - - dataProvidersList.map { dataProviderEntry -> - val signedRequisitionSpec = - decryptRequisitionSpec( - dataProviderEntry.value.encryptedRequisitionSpec, - DATA_PROVIDER_PRIVATE_KEY_HANDLE, - ) - val requisitionSpec: RequisitionSpec = signedRequisitionSpec.unpack() - verifyRequisitionSpec( - signedRequisitionSpec, - requisitionSpec, - measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - } - - // Verify proto argument of InternalMeasurementsCoroutineImplBase::batchCreateMeasurements - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::batchCreateMeasurements, - ) - .isEqualTo( - batchCreateMeasurementsRequest { measurements += INTERNAL_PENDING_REACH_MEASUREMENT } - ) - - // Verify proto argument of InternalReportsCoroutineImplBase::createReport - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::createReport) - .ignoringRepeatedFieldOrder() - .isEqualTo( - internalCreateReportRequest { - report = - INTERNAL_PENDING_REACH_REPORT.copy { - clearState() - clearExternalReportId() - measurements.clear() - clearCreateTime() - } - measurements += - InternalCreateReportRequestKt.measurementKey { - measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId - } - } - ) - - assertThat(result).isEqualTo(expected) - } - - @Test - fun `createReport creates reach single data provider measurement when report needs reach`() { - runBlocking { - whenever(internalReportingSetsMock.batchGetReportingSet(any())) - .thenReturn(flowOf(INTERNAL_REPORTING_SETS[0])) - } - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - eventGroupUniverse = EVENT_GROUP_UNIVERSE_WITH_ONE_DATA_PROVIDER - metrics.clear() - metrics += metric { - reach = reachParams {} - setOperations += namedSetOperation { - uniqueName = REACH_SET_OPERATION_UNIQUE_NAME - setOperation = setOperation { - type = SetOperation.Type.UNION - lhs = - SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[0].resourceName } - } - } - } - clearState() - } - } - - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val capturedMeasurementRequest = - captureFirst { - runBlocking { verify(measurementsMock).createMeasurement(capture()) } - } - assertThat(capturedMeasurementRequest) - .ignoringRepeatedFieldOrder() - .ignoringFieldDescriptors( - MEASUREMENT_SPEC_FIELD_DESCRIPTOR, - ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, - ) - .isEqualTo(REACH_SINGLE_DATA_PROVIDER_MEASUREMENT_REQUEST) - - verifyMeasurementSpec( - capturedMeasurementRequest.measurement.measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - val measurementSpec: MeasurementSpec = - capturedMeasurementRequest.measurement.measurementSpec.unpack() - assertThat(measurementSpec).isEqualTo(REACH_SINGLE_DATA_PROVIDER_MEASUREMENT_SPEC) - } - - @Test - fun `createReport creates rf single data provider measurement when report needs rf`() { - runBlocking { - whenever(internalReportingSetsMock.batchGetReportingSet(any())) - .thenReturn(flowOf(INTERNAL_REPORTING_SETS[0])) - } - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_FREQUENCY_HISTOGRAM_REPORT.copy { - eventGroupUniverse = EVENT_GROUP_UNIVERSE_WITH_ONE_DATA_PROVIDER - metrics.clear() - metrics += metric { - frequencyHistogram = frequencyHistogramParams { - maximumFrequencyPerUser = MAXIMUM_FREQUENCY - } - setOperations += namedSetOperation { - uniqueName = FREQUENCY_HISTOGRAM_SET_OPERATION_UNIQUE_NAME - setOperation = setOperation { - type = SetOperation.Type.UNION - lhs = - SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[0].resourceName } - } - } - } - clearState() - } - } - - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val capturedMeasurementRequest = - captureFirst { - runBlocking { verify(measurementsMock).createMeasurement(capture()) } - } - assertThat(capturedMeasurementRequest) - .ignoringRepeatedFieldOrder() - .ignoringFieldDescriptors( - MEASUREMENT_SPEC_FIELD_DESCRIPTOR, - ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, - ) - .isEqualTo(REACH_FREQUENCY_SINGLE_DATA_PROVIDER_MEASUREMENT_REQUEST) - - verifyMeasurementSpec( - capturedMeasurementRequest.measurement.measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - val measurementSpec: MeasurementSpec = - capturedMeasurementRequest.measurement.measurementSpec.unpack() - assertThat(measurementSpec).isEqualTo(REACH_FREQUENCY_SINGLE_DATA_PROVIDER_MEASUREMENT_SPEC) - } - - @Test - fun `createReport creates reach and requency measurement when report needs frequency`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_FREQUENCY_HISTOGRAM_REPORT.copy { clearState() } - } - - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val capturedMeasurementRequest = - captureFirst { - runBlocking { verify(measurementsMock).createMeasurement(capture()) } - } - assertThat(capturedMeasurementRequest) - .ignoringRepeatedFieldOrder() - .ignoringFieldDescriptors( - MEASUREMENT_SPEC_FIELD_DESCRIPTOR, - ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, - ) - .isEqualTo(REACH_FREQUENCY_MEASUREMENT_REQUEST) - - verifyMeasurementSpec( - capturedMeasurementRequest.measurement.measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - val measurementSpec: MeasurementSpec = - capturedMeasurementRequest.measurement.measurementSpec.unpack() - assertThat(measurementSpec).isEqualTo(REACH_FREQUENCY_MEASUREMENT_SPEC) - } - - @Test - fun `createReport creates impression measurement when report needs impression`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_IMPRESSION_REPORT.copy { clearState() } - } - - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val capturedMeasurementRequest = - captureFirst { - runBlocking { verify(measurementsMock).createMeasurement(capture()) } - } - assertThat(capturedMeasurementRequest) - .ignoringRepeatedFieldOrder() - .ignoringFieldDescriptors( - MEASUREMENT_SPEC_FIELD_DESCRIPTOR, - ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, - ) - .isEqualTo(IMPRESSION_MEASUREMENT_REQUEST) - - verifyMeasurementSpec( - capturedMeasurementRequest.measurement.measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - val measurementSpec: MeasurementSpec = - capturedMeasurementRequest.measurement.measurementSpec.unpack() - assertThat(measurementSpec).isEqualTo(IMPRESSION_MEASUREMENT_SPEC) - } - - @Test - fun `createReport creates duration measurement when report needs duration`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_WATCH_DURATION_REPORT.copy { clearState() } - } - - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val capturedMeasurementRequest = - captureFirst { - runBlocking { verify(measurementsMock).createMeasurement(capture()) } - } - assertThat(capturedMeasurementRequest) - .ignoringRepeatedFieldOrder() - .ignoringFieldDescriptors( - MEASUREMENT_SPEC_FIELD_DESCRIPTOR, - ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, - ) - .isEqualTo(WATCH_DURATION_MEASUREMENT_REQUEST) - - verifyMeasurementSpec( - capturedMeasurementRequest.measurement.measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - val measurementSpec: MeasurementSpec = - capturedMeasurementRequest.measurement.measurementSpec.unpack() - assertThat(measurementSpec).isEqualTo(WATCH_DURATION_MEASUREMENT_SPEC) - } - - @Test - fun `createReport returns a report of reach when no event filter at all`(): Unit = runBlocking { - val internalReportingSets: List = - INTERNAL_REPORTING_SETS.map { internalReportingSet -> - internalReportingSet.copy { - clearFilter() - displayName = "$measurementConsumerReferenceId-$externalReportingSetId-$filter" - } - } - - whenever( - internalReportingSetsMock.batchGetReportingSet( - eq( - batchGetReportingSetRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += internalReportingSets[0].externalReportingSetId - externalReportingSetIds += internalReportingSets[1].externalReportingSetId - } - ) - ) - ) - .thenReturn( - flowOf( - internalReportingSets[0], - internalReportingSets[1], - internalReportingSets[0], - internalReportingSets[1], - ) - ) - - val requestingReport = - PENDING_REACH_REPORT.copy { - clearState() - eventGroupUniverse = - EVENT_GROUP_UNIVERSE.copy { - eventGroupEntries.clear() - eventGroupEntries += - COVERED_EVENT_GROUP_KEYS.map { - EventGroupUniverseKt.eventGroupEntry { key = it.toName() } - } - } - } - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = requestingReport - } - - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val dataProviderEntries = - REQUISITION_SPECS.mapValues { (dataProviderKey, requisitionSpec) -> - val dataProvider = DATA_PROVIDERS.getValue(dataProviderKey) - dataProviderEntry { - key = dataProvider.name - - val requisitionSpecWithNoFilter = - requisitionSpec.copy { - events = - RequisitionSpecKt.events { - val eventGroupsWithNoFilter = - eventGroups.map { eventGroup -> - eventGroup.copy { - value = - EventGroupEntryKt.value { - collectionInterval = MEASUREMENT_TIME_INTERVAL - filter = eventFilter { expression = "" } - } - } - } - eventGroups.clear() - eventGroups += eventGroupsWithNoFilter - } - } - value = - DataProviderEntryKt.value { - dataProviderCertificate = dataProvider.certificate - dataProviderPublicKey = dataProvider.publicKey.message - encryptedRequisitionSpec = - encryptRequisitionSpec( - signRequisitionSpec( - requisitionSpecWithNoFilter, - MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE, - ), - dataProvider.publicKey.unpack(), - ) - nonceHash = Hashing.hashSha256(requisitionSpecWithNoFilter.nonce) - } - } - } - - val reachMeasurementRequest = - REACH_MEASUREMENT_REQUEST.copy { - measurement = - measurement.copy { - dataProviders.clear() - dataProviders += - DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { dataProviderEntries.getValue(it) } - } - } - - val capturedMeasurementRequest = - captureFirst { - runBlocking { verify(measurementsMock).createMeasurement(capture()) } - } - assertThat(capturedMeasurementRequest) - .ignoringRepeatedFieldOrder() - .ignoringFieldDescriptors( - MEASUREMENT_SPEC_FIELD_DESCRIPTOR, - ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, - ) - .isEqualTo(reachMeasurementRequest) - - verifyMeasurementSpec( - capturedMeasurementRequest.measurement.measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - val measurementSpec: MeasurementSpec = - capturedMeasurementRequest.measurement.measurementSpec.unpack() - assertThat(measurementSpec).isEqualTo(REACH_MEASUREMENT_SPEC) - - val dataProvidersList = - capturedMeasurementRequest.measurement.dataProvidersList.sortedBy { it.key } - - val filters = - dataProvidersList.flatMap { dataProviderEntry -> - val signedRequisitionSpec = - decryptRequisitionSpec( - dataProviderEntry.value.encryptedRequisitionSpec, - DATA_PROVIDER_PRIVATE_KEY_HANDLE, - ) - val requisitionSpec: RequisitionSpec = signedRequisitionSpec.unpack() - - verifyRequisitionSpec( - signedRequisitionSpec, - requisitionSpec, - measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - - requisitionSpec.events.eventGroupsList.map { eventGroupEntry -> - eventGroupEntry.value.filter.expression - } - } - - for (filter in filters) { - assertThat(filter).isEqualTo("") - } - } - - @Test - fun `createReport returns a report of reach with RUNNING state when timeIntervals set`() { - val internalReport = - INTERNAL_PENDING_REACH_REPORT.copy { - clearTime() - timeIntervals = internalTimeIntervals { - timeIntervals += internalTimeInterval { - startTime = START_TIME - endTime = Timestamps.add(START_TIME, TIME_INTERVAL_INCREMENT) - } - } - } - runBlocking { whenever(internalReportsMock.createReport(any())).thenReturn(internalReport) } - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - timeIntervals = timeIntervals { - timeIntervals += timeInterval { - startTime = START_TIME - endTime = Timestamps.add(START_TIME, TIME_INTERVAL_INCREMENT) - } - } - } - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - val expected = - PENDING_REACH_REPORT.copy { - clearTime() - timeIntervals = timeIntervals { - timeIntervals += timeInterval { - startTime = START_TIME - endTime = Timestamps.add(START_TIME, TIME_INTERVAL_INCREMENT) - } - } - } - - // Verify proto argument of ReportsCoroutineImplBase::getReportByIdempotencyKey - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReportByIdempotencyKey) - .isEqualTo( - getReportByIdempotencyKeyRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - reportIdempotencyKey = REACH_REPORT_IDEMPOTENCY_KEY - } - ) - - // Verify proto argument of InternalReportingSetsCoroutineImplBase::batchGetReportingSet - verifyProtoArgument( - internalReportingSetsMock, - InternalReportingSetsCoroutineImplBase::batchGetReportingSet, - ) - .ignoringRepeatedFieldOrder() - .isEqualTo( - batchGetReportingSetRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_REPORTING_SETS[0].externalReportingSetId - externalReportingSetIds += INTERNAL_REPORTING_SETS[1].externalReportingSetId - } - ) - - // Verify proto argument of MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - verifyProtoArgument( - measurementConsumersMock, - MeasurementConsumersCoroutineImplBase::getMeasurementConsumer, - ) - .isEqualTo(getMeasurementConsumerRequest { name = MEASUREMENT_CONSUMERS.values.first().name }) - - // Verify proto argument of DataProvidersCoroutineImplBase::getDataProvider - val dataProvidersCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(dataProvidersMock, times(2)) { getDataProvider(dataProvidersCaptor.capture()) } - val capturedDataProviderRequests = dataProvidersCaptor.allValues - assertThat(capturedDataProviderRequests) - .containsExactly( - getDataProviderRequest { name = DATA_PROVIDERS_LIST[0].name }, - getDataProviderRequest { name = DATA_PROVIDERS_LIST[1].name }, - ) - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val capturedMeasurementRequest = - captureFirst { - runBlocking { verify(measurementsMock).createMeasurement(capture()) } - } - - assertThat(capturedMeasurementRequest) - .ignoringRepeatedFieldOrder() - .ignoringFieldDescriptors( - MEASUREMENT_SPEC_FIELD_DESCRIPTOR, - ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, - ) - .isEqualTo(REACH_MEASUREMENT_REQUEST) - - verifyMeasurementSpec( - capturedMeasurementRequest.measurement.measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - val measurementSpec: MeasurementSpec = - capturedMeasurementRequest.measurement.measurementSpec.unpack() - assertThat(measurementSpec).isEqualTo(REACH_MEASUREMENT_SPEC) - - val dataProvidersList = - capturedMeasurementRequest.measurement.dataProvidersList.sortedBy { it.key } - - dataProvidersList.map { dataProviderEntry -> - val signedRequisitionSpec = - decryptRequisitionSpec( - dataProviderEntry.value.encryptedRequisitionSpec, - DATA_PROVIDER_PRIVATE_KEY_HANDLE, - ) - val requisitionSpec: RequisitionSpec = signedRequisitionSpec.unpack() - verifyRequisitionSpec( - signedRequisitionSpec, - requisitionSpec, - measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - } - - // Verify proto argument of InternalMeasurementsCoroutineImplBase::batchCreateMeasurements - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::batchCreateMeasurements, - ) - .isEqualTo( - batchCreateMeasurementsRequest { measurements += INTERNAL_PENDING_REACH_MEASUREMENT } - ) - - // Verify proto argument of InternalReportsCoroutineImplBase::createReport - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::createReport) - .ignoringRepeatedFieldOrder() - .isEqualTo( - internalCreateReportRequest { - report = - internalReport.copy { - clearState() - clearExternalReportId() - measurements.clear() - clearCreateTime() - } - measurements += - InternalCreateReportRequestKt.measurementKey { - measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId - } - } - ) - - assertThat(result).isEqualTo(expected) - } - - @Test - fun `createReport returns a report with a cumulative metric`() { - val internalCumulativeReport = - INTERNAL_PENDING_REACH_REPORT.copy { - metrics.clear() - metrics += - INTERNAL_PENDING_REACH_REPORT.metricsList[0].copy { - details = - INTERNAL_PENDING_REACH_REPORT.metricsList[0].details.copy { cumulative = true } - } - } - runBlocking { - whenever(internalReportsMock.createReport(any())).thenReturn(internalCumulativeReport) - } - - val cumulativeReport = - PENDING_REACH_REPORT.copy { - metrics.clear() - metrics += PENDING_REACH_REPORT.metricsList[0].copy { cumulative = true } - } - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = cumulativeReport.copy { clearState() } - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - // Verify proto argument of ReportsCoroutineImplBase::getReportByIdempotencyKey - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReportByIdempotencyKey) - .isEqualTo( - getReportByIdempotencyKeyRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - reportIdempotencyKey = REACH_REPORT_IDEMPOTENCY_KEY - } - ) - - // Verify proto argument of InternalReportingSetsCoroutineImplBase::batchGetReportingSet - verifyProtoArgument( - internalReportingSetsMock, - InternalReportingSetsCoroutineImplBase::batchGetReportingSet, - ) - .ignoringRepeatedFieldOrder() - .isEqualTo( - batchGetReportingSetRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_REPORTING_SETS[0].externalReportingSetId - externalReportingSetIds += INTERNAL_REPORTING_SETS[1].externalReportingSetId - } - ) - - // Verify proto argument of MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - verifyProtoArgument( - measurementConsumersMock, - MeasurementConsumersCoroutineImplBase::getMeasurementConsumer, - ) - .isEqualTo(getMeasurementConsumerRequest { name = MEASUREMENT_CONSUMERS.values.first().name }) - - // Verify proto argument of DataProvidersCoroutineImplBase::getDataProvider - val dataProvidersCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(dataProvidersMock, times(2)) { getDataProvider(dataProvidersCaptor.capture()) } - val capturedDataProviderRequests = dataProvidersCaptor.allValues - assertThat(capturedDataProviderRequests) - .containsExactly( - getDataProviderRequest { name = DATA_PROVIDERS_LIST[0].name }, - getDataProviderRequest { name = DATA_PROVIDERS_LIST[1].name }, - ) - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val capturedMeasurementRequest = - captureFirst { - runBlocking { verify(measurementsMock).createMeasurement(capture()) } - } - assertThat(capturedMeasurementRequest) - .ignoringRepeatedFieldOrder() - .ignoringFieldDescriptors( - MEASUREMENT_SPEC_FIELD_DESCRIPTOR, - ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, - ) - .isEqualTo(REACH_MEASUREMENT_REQUEST) - - verifyMeasurementSpec( - capturedMeasurementRequest.measurement.measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - val measurementSpec: MeasurementSpec = - capturedMeasurementRequest.measurement.measurementSpec.unpack() - assertThat(measurementSpec).isEqualTo(REACH_MEASUREMENT_SPEC) - - val dataProvidersList = - capturedMeasurementRequest.measurement.dataProvidersList.sortedBy { it.key } - - dataProvidersList.map { dataProviderEntry -> - val signedRequisitionSpec = - decryptRequisitionSpec( - dataProviderEntry.value.encryptedRequisitionSpec, - DATA_PROVIDER_PRIVATE_KEY_HANDLE, - ) - val requisitionSpec: RequisitionSpec = signedRequisitionSpec.unpack() - verifyRequisitionSpec( - signedRequisitionSpec, - requisitionSpec, - measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - } - - // Verify proto argument of InternalMeasurementsCoroutineImplBase::batchCreateMeasurements - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::batchCreateMeasurements, - ) - .isEqualTo( - batchCreateMeasurementsRequest { measurements += INTERNAL_PENDING_REACH_MEASUREMENT } - ) - - // Verify proto argument of InternalReportsCoroutineImplBase::createReport - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::createReport) - .ignoringRepeatedFieldOrder() - .isEqualTo( - internalCreateReportRequest { - report = - internalCumulativeReport.copy { - clearState() - clearExternalReportId() - measurements.clear() - clearCreateTime() - } - measurements += - InternalCreateReportRequestKt.measurementKey { - measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId - } - } - ) - - assertThat(result).isEqualTo(cumulativeReport) - } - - @Test - fun `createReport returns a report with set operation type DIFFERENCE`() { - val internalPendingReachReportWithSetDifference = - INTERNAL_PENDING_REACH_REPORT.copy { - val source = this - measurements.clear() - clearCreateTime() - val metric = internalMetric { - details = InternalMetricKt.details { reach = InternalMetricKt.reachParams {} } - namedSetOperations += - source.metrics[0].namedSetOperationsList[0].copy { - setOperation = - setOperation.copy { type = InternalMetric.SetOperation.Type.DIFFERENCE } - measurementCalculations.clear() - measurementCalculations += - source.metrics[0].namedSetOperationsList[0].measurementCalculationsList[0].copy { - weightedMeasurements.clear() - weightedMeasurements += weightedMeasurement { - measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId - coefficient = -1 - } - weightedMeasurements += weightedMeasurement { - measurementReferenceId = REACH_MEASUREMENT_KEY_2.measurementId - coefficient = 1 - } - } - } - } - metrics.clear() - metrics += metric - } - - runBlocking { - whenever(internalReportsMock.createReport(any())) - .thenReturn(internalPendingReachReportWithSetDifference) - whenever(measurementsMock.createMeasurement(any())) - .thenReturn(BASE_REACH_MEASUREMENT, BASE_REACH_MEASUREMENT_2) - } - - val pendingReachReportWithSetDifference = - PENDING_REACH_REPORT.copy { - metrics.clear() - metrics += metric { - reach = reachParams {} - cumulative = false - setOperations += namedSetOperation { - uniqueName = REACH_SET_OPERATION_UNIQUE_NAME - setOperation = setOperation { - type = SetOperation.Type.DIFFERENCE - lhs = - SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[0].resourceName } - rhs = - SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[1].resourceName } - } - } - } - } - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = pendingReachReportWithSetDifference - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - // Verify proto argument of ReportsCoroutineImplBase::getReportByIdempotencyKey - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReportByIdempotencyKey) - .isEqualTo( - getReportByIdempotencyKeyRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - reportIdempotencyKey = REACH_REPORT_IDEMPOTENCY_KEY - } - ) - - // Verify proto argument of InternalReportingSetsCoroutineImplBase::batchGetReportingSet - verifyProtoArgument( - internalReportingSetsMock, - InternalReportingSetsCoroutineImplBase::batchGetReportingSet, - ) - .ignoringRepeatedFieldOrder() - .isEqualTo( - batchGetReportingSetRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_REPORTING_SETS[0].externalReportingSetId - externalReportingSetIds += INTERNAL_REPORTING_SETS[1].externalReportingSetId - } - ) - - // Verify proto argument of MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - verifyProtoArgument( - measurementConsumersMock, - MeasurementConsumersCoroutineImplBase::getMeasurementConsumer, - ) - .isEqualTo(getMeasurementConsumerRequest { name = MEASUREMENT_CONSUMERS.values.first().name }) - - // Verify proto argument of DataProvidersCoroutineImplBase::getDataProvider - val dataProvidersCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(dataProvidersMock, times(2)) { getDataProvider(dataProvidersCaptor.capture()) } - val capturedDataProviderRequests = dataProvidersCaptor.allValues - assertThat(capturedDataProviderRequests) - .containsExactly( - getDataProviderRequest { name = DATA_PROVIDERS_LIST[0].name }, - getDataProviderRequest { name = DATA_PROVIDERS_LIST[1].name }, - ) - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val measurementCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(measurementsMock, times(2)) { createMeasurement(measurementCaptor.capture()) } - assertThat(measurementCaptor.allValues.map { it.measurement }).containsNoDuplicates() - - // Verify proto argument of InternalMeasurementsCoroutineImplBase::batchCreateMeasurements - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::batchCreateMeasurements, - ) - .ignoringRepeatedFieldOrder() - .isEqualTo( - batchCreateMeasurementsRequest { - measurements += INTERNAL_PENDING_REACH_MEASUREMENT - measurements += - INTERNAL_PENDING_REACH_MEASUREMENT.copy { - measurementReferenceId = REACH_MEASUREMENT_KEY_2.measurementId - } - } - ) - - // Verify proto argument of InternalReportsCoroutineImplBase::createReport - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::createReport) - .ignoringRepeatedFieldOrder() - .isEqualTo( - internalCreateReportRequest { - report = - internalPendingReachReportWithSetDifference.copy { - clearState() - clearExternalReportId() - } - measurements += - InternalCreateReportRequestKt.measurementKey { - measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId - } - measurements += - InternalCreateReportRequestKt.measurementKey { - measurementConsumerReferenceId = REACH_MEASUREMENT_KEY_2.measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_KEY_2.measurementId - } - } - ) - - assertThat(result).isEqualTo(pendingReachReportWithSetDifference) - } - - @Test - fun `createReport succeeds when the internal createMeasurement throws ALREADY_EXISTS`() = - runBlocking { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val report = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - assertThat(report.state).isEqualTo(Report.State.RUNNING) - } - - @Test - fun `createReport throws UNAUTHENTICATED when no principal is found`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - val exception = - assertFailsWith { runBlocking { service.createReport(request) } } - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - } - - @Test - fun `createReport throws PERMISSION_DENIED when MeasurementConsumer caller doesn't match`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.last().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) - assertThat(exception.status.description) - .isEqualTo("Cannot create a Report for another MeasurementConsumer.") - } - - @Test - fun `createReport throws PERMISSION_DENIED when report doesn't belong to caller`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.last().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) - assertThat(exception.status.description) - .isEqualTo("Cannot create a Report for another MeasurementConsumer.") - } - - @Test - fun `createReport throws UNAUTHENTICATED when the caller is not MeasurementConsumer`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - val exception = - assertFailsWith { - withDataProviderPrincipal(DATA_PROVIDERS_LIST[0].name) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - assertThat(exception.status.description).isEqualTo("No ReportingPrincipal found") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when parent is unspecified`() { - val request = createReportRequest { report = PENDING_REACH_REPORT.copy { clearState() } } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description).isEqualTo("Parent is either unspecified or invalid.") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when report is unspecified`() { - val request = createReportRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description).isEqualTo("Report is not specified.") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when reportIdempotencyKey is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearReportIdempotencyKey() - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description).isEqualTo("ReportIdempotencyKey is not specified.") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when eventGroupUniverse in Report is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearEventGroupUniverse() - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description).isEqualTo("EventGroupUniverse is not specified.") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when eventGroupUniverse entries list is empty`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - eventGroupUniverse = eventGroupUniverse { eventGroupEntries.clear() } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when eventGroupUniverse entry is missing key`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - eventGroupUniverse = eventGroupUniverse { - eventGroupEntries += EventGroupUniverseKt.eventGroupEntry {} - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when setOperationName duplicate for same metricType`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - metrics.add(REACH_METRIC) - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description) - .isEqualTo("The names of the set operations within the same metric type should be unique.") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when time in Report is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description).isEqualTo("The time in Report is not specified.") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when TimeIntervals is set and cumulative is true`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - timeIntervals = timeIntervals { - timeIntervals += timeInterval { - startTime = timestamp { seconds = 1 } - endTime = timestamp { seconds = 5 } - } - } - metrics.clear() - metrics += PENDING_REACH_REPORT.metricsList[0].copy { cumulative = true } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when TimeIntervals timeIntervalsList is empty`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - timeIntervals = timeIntervals {} - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when TimeInterval startTime is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - timeIntervals = timeIntervals { - timeIntervals += timeInterval { endTime = timestamp { seconds = 5 } } - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when TimeInterval endTime is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - timeIntervals = timeIntervals { - timeIntervals += timeInterval { startTime = timestamp { seconds = 5 } } - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when TimeInterval endTime is before startTime`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - timeIntervals = timeIntervals { - timeIntervals += timeInterval { - startTime = timestamp { - seconds = 5 - nanos = 5 - } - endTime = timestamp { - seconds = 5 - nanos = 1 - } - } - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when PeriodicTimeInterval startTime is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - periodicTimeInterval = periodicTimeInterval { - increment = duration { seconds = 5 } - intervalCount = 3 - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when PeriodicTimeInterval increment is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - periodicTimeInterval = periodicTimeInterval { - startTime = timestamp { - seconds = 5 - nanos = 5 - } - intervalCount = 3 - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when PeriodicTimeInterval intervalCount is 0`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - periodicTimeInterval = periodicTimeInterval { - startTime = timestamp { - seconds = 5 - nanos = 5 - } - increment = duration { seconds = 5 } - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when any metric type in Report is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - metrics.clear() - metrics.add(REACH_METRIC.copy { clearReach() }) - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description) - .isEqualTo("The metric type in Report is not specified.") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when metrics list is empty`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - metrics.clear() - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when namedSetOperation uniqueName is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - metrics.clear() - metrics += metric { - reach = reachParams {} - setOperations += namedSetOperation { - setOperation = setOperation { - type = SetOperation.Type.UNION - lhs = - SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[0].resourceName } - } - } - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when setOperation type is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - metrics.clear() - metrics += metric { - reach = reachParams {} - setOperations += namedSetOperation { - uniqueName = "name" - setOperation = setOperation { - lhs = - SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[0].resourceName } - } - } - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when setOperation lhs is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - metrics.clear() - metrics += metric { - reach = reachParams {} - setOperations += namedSetOperation { - uniqueName = "name" - setOperation = setOperation { type = SetOperation.Type.UNION } - } - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when provided reporting set name is invalid`() { - val invalidMetric = metric { - reach = reachParams {} - cumulative = false - setOperations.add( - NAMED_REACH_SET_OPERATION.copy { setOperation = SET_OPERATION_WITH_INVALID_REPORTING_SET } - ) - } - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - metrics.clear() - metrics.add(invalidMetric) - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description) - .isEqualTo("Invalid reporting set name $INVALID_REPORTING_SET_NAME.") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when any reporting set is not accessible to caller`() { - val invalidMetric = metric { - reach = reachParams {} - cumulative = false - setOperations.add( - NAMED_REACH_SET_OPERATION.copy { - setOperation = SET_OPERATION_WITH_INACCESSIBLE_REPORTING_SET - } - ) - } - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - metrics.clear() - metrics.add(invalidMetric) - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description) - .isEqualTo("No access to the reporting set [$REPORTING_SET_NAME_FOR_MC_2].") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when eventGroup isn't covered by eventGroupUniverse`() = - runBlocking { - whenever(internalReportingSetsMock.batchGetReportingSet(any())) - .thenReturn(flowOf(INTERNAL_REPORTING_SETS[0], UNCOVERED_INTERNAL_REPORTING_SET)) - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - val expectedExceptionDescription = - "The event group [$UNCOVERED_EVENT_GROUP_NAME] in the reporting set" + - " [${UNCOVERED_INTERNAL_REPORTING_SET.displayName}] is not included in the event group " + - "universe." - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description).isEqualTo(expectedExceptionDescription) - } - - @Test - fun `createReport throws NOT_FOUND when reporting set is not found`() = runBlocking { - whenever(internalReportingSetsMock.batchGetReportingSet(any())) - .thenReturn(flowOf(INTERNAL_REPORTING_SETS[0])) - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.NOT_FOUND) - } - - @Test - fun `createReport throws FAILED_PRECONDITION when EDP cert is revoked`() = runBlocking { - val dataProvider = DATA_PROVIDERS.values.first() - whenever( - certificateMock.getCertificate( - eq(getCertificateRequest { name = dataProvider.certificate }) - ) - ) - .thenReturn( - certificate { - name = dataProvider.certificate - x509Der = DATA_PROVIDER_SIGNING_KEY.certificate.encoded.toByteString() - revocationState = Certificate.RevocationState.REVOKED - } - ) - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - - assertThat(exception).hasMessageThat().ignoringCase().contains("revoked") - } - - @Test - fun `createReport throws FAILED_PRECONDITION when EDP public key signature is invalid`() = - runBlocking { - val dataProvider = DATA_PROVIDERS.values.first() - whenever( - dataProvidersMock.getDataProvider(eq(getDataProviderRequest { name = dataProvider.name })) - ) - .thenReturn( - dataProvider.copy { - publicKey = publicKey.copy { signature = "invalid sig".toByteStringUtf8() } - } - ) - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - - assertThat(exception).hasMessageThat().ignoringCase().contains("signature") - } - - @Test - fun `createReport throws exception from getReportByIdempotencyKey when status isn't NOT_FOUND`() = - runBlocking { - whenever(internalReportsMock.getReportByIdempotencyKey(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - val expectedExceptionDescription = - "Unable to retrieve a report from the reporting database using the provided " + - "reportIdempotencyKey [${PENDING_REACH_REPORT.reportIdempotencyKey}]." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) - } - - @Test - fun `createReport throws exception when internal createReport throws exception`() = runBlocking { - val status = Status.INVALID_ARGUMENT.withDescription("Bad CreateReport request") - whenever(internalReportsMock.createReport(any())).thenThrow(StatusRuntimeException(status)) - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.cause).isInstanceOf(StatusException::class.java) - val actualStatus = (exception.cause as StatusException).status - assertThat(actualStatus.code).isEqualTo(status.code) - assertThat(actualStatus.description).isEqualTo(status.description) - } - - @Test - fun `createReport throws exception when the CMM createMeasurement throws exception`() = - runBlocking { - whenever(measurementsMock.createMeasurement(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.message).contains(REACH_MEASUREMENT_CREATE_REQUEST_ID) - } - - @Test - fun `createReport throws exception when the internal createMeasurement throws exception`() = - runBlocking { - whenever(internalMeasurementsMock.batchCreateMeasurements(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.UNKNOWN) - assertThat(exception.message).contains("Unable to create measurement") - } - - @Test - fun `createReport throws exception when getMeasurementConsumer throws exception`() = runBlocking { - whenever(measurementConsumersMock.getMeasurementConsumer(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - val expectedExceptionDescription = - "Unable to retrieve the measurement consumer [${MEASUREMENT_CONSUMERS.values.first().name}]." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) - } - - @Test - fun `createReport throws exception when the internal batchGetReportingSet throws exception`(): - Unit = runBlocking { - whenever(internalReportingSetsMock.batchGetReportingSet(any())) - .thenThrow(StatusRuntimeException(Status.UNKNOWN)) - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - assertFails { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - } - - @Test - fun `createReport throws exception when getDataProvider throws exception`() = runBlocking { - whenever(dataProvidersMock.getDataProvider(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception).hasMessageThat().contains("dataProviders/") - } - - @Test - fun `listReports returns without a next page token when there is no previous page token`() { - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { - reports.add(SUCCEEDED_REACH_REPORT) - reports.add(SUCCEEDED_IMPRESSION_REPORT) - reports.add(SUCCEEDED_WATCH_DURATION_REPORT) - reports.add(SUCCEEDED_FREQUENCY_HISTOGRAM_REPORT) - } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = DEFAULT_PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports returns with a next page token when there is no previous page token`() { - val request = listReportsRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - pageSize = PAGE_SIZE - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { - reports.add(SUCCEEDED_REACH_REPORT) - reports.add(SUCCEEDED_IMPRESSION_REPORT) - reports.add(SUCCEEDED_WATCH_DURATION_REPORT) - - nextPageToken = - listReportsPageToken { - pageSize = PAGE_SIZE - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - lastReport = previousPageEnd { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - } - } - .toByteString() - .base64UrlEncode() - } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports returns with a next page token when there is a previous page token`() { - val request = listReportsRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - pageSize = PAGE_SIZE - pageToken = - listReportsPageToken { - pageSize = PAGE_SIZE - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - lastReport = previousPageEnd { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[0] - } - } - .toByteString() - .base64UrlEncode() - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { - reports.add(SUCCEEDED_REACH_REPORT) - reports.add(SUCCEEDED_IMPRESSION_REPORT) - reports.add(SUCCEEDED_WATCH_DURATION_REPORT) - - nextPageToken = - listReportsPageToken { - pageSize = PAGE_SIZE - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - lastReport = previousPageEnd { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - } - } - .toByteString() - .base64UrlEncode() - } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportIdAfter = REPORT_EXTERNAL_IDS[0] - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports with page size replaced with a valid value and no previous page token`() { - val invalidPageSize = MAX_PAGE_SIZE * 2 - val request = listReportsRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - pageSize = invalidPageSize - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { - reports.add(SUCCEEDED_REACH_REPORT) - reports.add(SUCCEEDED_IMPRESSION_REPORT) - reports.add(SUCCEEDED_WATCH_DURATION_REPORT) - reports.add(SUCCEEDED_FREQUENCY_HISTOGRAM_REPORT) - } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = MAX_PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports with invalid page size replaced with the one in previous page token`() { - val invalidPageSize = MAX_PAGE_SIZE * 2 - val previousPageSize = PAGE_SIZE - val request = listReportsRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - pageSize = invalidPageSize - pageToken = - listReportsPageToken { - pageSize = previousPageSize - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - lastReport = previousPageEnd { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[0] - } - } - .toByteString() - .base64UrlEncode() - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { - reports.add(SUCCEEDED_REACH_REPORT) - reports.add(SUCCEEDED_IMPRESSION_REPORT) - reports.add(SUCCEEDED_WATCH_DURATION_REPORT) - - nextPageToken = - listReportsPageToken { - pageSize = previousPageSize - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - lastReport = previousPageEnd { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - } - } - .toByteString() - .base64UrlEncode() - } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = previousPageSize + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportIdAfter = REPORT_EXTERNAL_IDS[0] - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports with page size replacing the one in previous page token`() { - val newPageSize = PAGE_SIZE - val previousPageSize = 1 - val request = listReportsRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - pageSize = newPageSize - pageToken = - listReportsPageToken { - pageSize = previousPageSize - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - lastReport = previousPageEnd { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[0] - } - } - .toByteString() - .base64UrlEncode() - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { - reports.add(SUCCEEDED_REACH_REPORT) - reports.add(SUCCEEDED_IMPRESSION_REPORT) - reports.add(SUCCEEDED_WATCH_DURATION_REPORT) - - nextPageToken = - listReportsPageToken { - pageSize = newPageSize - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - lastReport = previousPageEnd { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - } - } - .toByteString() - .base64UrlEncode() - } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = newPageSize + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportIdAfter = REPORT_EXTERNAL_IDS[0] - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports throws UNAUTHENTICATED when no principal is found`() { - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - val exception = - assertFailsWith { runBlocking { service.listReports(request) } } - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - } - - @Test - fun `listReports throws PERMISSION_DENIED when MeasurementConsumer caller doesn't match`() { - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.last().name, CONFIG) { - runBlocking { service.listReports(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) - assertThat(exception.status.description) - .isEqualTo("Cannot list Reports belonging to other MeasurementConsumers.") - } - - @Test - fun `listReports throws UNAUTHENTICATED when the caller is not MeasurementConsumer`() { - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - val exception = - assertFailsWith { - withDataProviderPrincipal(DATA_PROVIDERS.values.first().name) { - runBlocking { service.listReports(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - assertThat(exception.status.description).isEqualTo("No ReportingPrincipal found") - } - - @Test - fun `listReports throws INVALID_ARGUMENT when page size is less than 0`() { - val request = listReportsRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - pageSize = -1 - } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description).isEqualTo("Page size cannot be less than 0") - } - - @Test - fun `listReports throws INVALID_ARGUMENT when parent is unspecified`() { - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(ListReportsRequest.getDefaultInstance()) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `listReports throws INVALID_ARGUMENT when mc id doesn't match one in page token`() { - val measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.last().measurementConsumerId - val request = listReportsRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - pageToken = - listReportsPageToken { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - lastReport = previousPageEnd { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - externalReportId = REPORT_EXTERNAL_IDS[0] - } - } - .toByteString() - .base64UrlEncode() - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `listReports throws Exception when the internal streamReports throws Exception`() = - runBlocking { - whenever(internalReportsMock.streamReports(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - } - val expectedExceptionDescription = "Unable to list reports from the reporting database." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) - } - - @Test - fun `listReports throws Exception when the internal getReport throws Exception`() = runBlocking { - whenever(internalReportsMock.getReport(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - } - val expectedExceptionDescription = - "Unable to get the report [${REPORT_NAMES[0]}] from the reporting database." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) - } - - @Test - fun `listReports throws Exception when the CMM getMeasurement throws Exception`() = runBlocking { - whenever(measurementsMock.getMeasurement(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - } - val expectedExceptionDescription = - "Unable to retrieve the measurement [${REACH_MEASUREMENT_KEY.toName()}]." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) - } - - @Test - fun `listReports throws Exception when the internal setMeasurementResult throws Exception`() = - runBlocking { - whenever(internalMeasurementsMock.setMeasurementResult(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - } - val expectedExceptionDescription = - "Unable to update the measurement [${REACH_MEASUREMENT_KEY.toName()}] in the reporting database." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) - } - - @Test - fun `listReports throws Exception when the internal setMeasurementFailure throws Exception`() = - runBlocking { - whenever(internalMeasurementsMock.setMeasurementFailure(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - whenever(internalReportsMock.streamReports(any())) - .thenReturn(flowOf(INTERNAL_PENDING_REACH_REPORT)) - whenever(measurementsMock.getMeasurement(any())) - .thenReturn( - PENDING_REACH_MEASUREMENT.copy { - state = Measurement.State.FAILED - failure = failure { - reason = Measurement.Failure.Reason.REQUISITION_REFUSED - message = "Privacy budget exceeded." - } - } - ) - whenever(internalReportsMock.getReport(any())) - .thenReturn(INTERNAL_PENDING_REACH_REPORT.copy { state = InternalReport.State.FAILED }) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - } - val expectedExceptionDescription = - "Unable to update the measurement [${REACH_MEASUREMENT_KEY.toName()}] in the reporting database." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) - } - - @Test - fun `listReports throws Exception when the getCertificate throws Exception`() = runBlocking { - whenever(certificateMock.getCertificate(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - } - - assertThat(exception).hasMessageThat().contains(AGGREGATOR_CERTIFICATE.name) - } - - @Test - fun `listReports returns reports with SUCCEEDED states when reports are already succeeded`() { - whenever(internalReportsMock.streamReports(any())) - .thenReturn( - flowOf( - INTERNAL_SUCCEEDED_REACH_REPORT, - INTERNAL_SUCCEEDED_IMPRESSION_REPORT, - INTERNAL_SUCCEEDED_WATCH_DURATION_REPORT, - INTERNAL_SUCCEEDED_FREQUENCY_HISTOGRAM_REPORT, - ) - ) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { - reports.add(SUCCEEDED_REACH_REPORT) - reports.add(SUCCEEDED_IMPRESSION_REPORT) - reports.add(SUCCEEDED_WATCH_DURATION_REPORT) - reports.add(SUCCEEDED_FREQUENCY_HISTOGRAM_REPORT) - } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = DEFAULT_PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports returns reports with FAILED states when reports are already failed`() { - whenever(internalReportsMock.streamReports(any())) - .thenReturn( - flowOf( - INTERNAL_PENDING_REACH_REPORT.copy { state = InternalReport.State.FAILED }, - INTERNAL_PENDING_IMPRESSION_REPORT.copy { state = InternalReport.State.FAILED }, - INTERNAL_PENDING_WATCH_DURATION_REPORT.copy { state = InternalReport.State.FAILED }, - INTERNAL_PENDING_FREQUENCY_HISTOGRAM_REPORT.copy { state = InternalReport.State.FAILED }, - ) - ) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { - reports.add(PENDING_REACH_REPORT.copy { state = Report.State.FAILED }) - reports.add(PENDING_IMPRESSION_REPORT.copy { state = Report.State.FAILED }) - reports.add(PENDING_WATCH_DURATION_REPORT.copy { state = Report.State.FAILED }) - reports.add(PENDING_FREQUENCY_HISTOGRAM_REPORT.copy { state = Report.State.FAILED }) - } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = DEFAULT_PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports returns reports with RUNNING states when measurements are PENDING`() = - runBlocking { - whenever(internalReportsMock.streamReports(any())) - .thenReturn(flowOf(INTERNAL_PENDING_REACH_REPORT)) - whenever(measurementsMock.getMeasurement(any())) - .thenReturn( - PENDING_REACH_MEASUREMENT.copy { - state = Measurement.State.COMPUTING - results.clear() - } - ) - whenever(internalReportsMock.getReport(any())).thenReturn(INTERNAL_PENDING_REACH_REPORT) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { reports.add(PENDING_REACH_REPORT) } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = DEFAULT_PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = REACH_MEASUREMENT_KEY.toName() }) - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReport) - .isEqualTo( - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[0] - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports returns reports with FAILED states when measurements are FAILED`() = - runBlocking { - whenever(internalReportsMock.streamReports(any())) - .thenReturn(flowOf(INTERNAL_PENDING_REACH_REPORT)) - whenever(measurementsMock.getMeasurement(any())) - .thenReturn( - PENDING_REACH_MEASUREMENT.copy { - state = Measurement.State.FAILED - failure = failure { - reason = Measurement.Failure.Reason.REQUISITION_REFUSED - message = "Privacy budget exceeded." - } - } - ) - whenever(internalReportsMock.getReport(any())) - .thenReturn(INTERNAL_PENDING_REACH_REPORT.copy { state = InternalReport.State.FAILED }) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { - reports.add(PENDING_REACH_REPORT.copy { state = Report.State.FAILED }) - } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = DEFAULT_PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = REACH_MEASUREMENT_KEY.toName() }) - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::setMeasurementFailure, - ) - .isEqualTo( - setMeasurementFailureRequest { - measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId - failure = - InternalMeasurementKt.failure { - reason = InternalMeasurement.Failure.Reason.REQUISITION_REFUSED - message = "Privacy budget exceeded." - } - } - ) - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReport) - .isEqualTo( - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[0] - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports returns reports with SUCCEEDED states when measurements are SUCCEEDED`() = - runBlocking { - whenever(internalReportsMock.streamReports(any())) - .thenReturn(flowOf(INTERNAL_PENDING_REACH_REPORT)) - whenever(measurementsMock.getMeasurement(any())).thenReturn(SUCCEEDED_REACH_MEASUREMENT) - whenever(internalReportsMock.getReport(any())).thenReturn(INTERNAL_SUCCEEDED_REACH_REPORT) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { reports.add(SUCCEEDED_REACH_REPORT) } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = DEFAULT_PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = REACH_MEASUREMENT_KEY.toName() }) - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::setMeasurementResult, - ) - .usingDoubleTolerance(1e-12) - .isEqualTo( - setMeasurementResultRequest { - measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId - this.result = - InternalMeasurementKt.result { - reach = InternalMeasurementResultKt.reach { value = REACH_VALUE } - frequency = - InternalMeasurementResultKt.frequency { - relativeFrequencyDistribution.putAll(FREQUENCY_DISTRIBUTION) - } - } - } - ) - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReport) - .isEqualTo( - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[0] - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports returns an impression report with aggregated results`() = runBlocking { - whenever(internalReportsMock.streamReports(any())) - .thenReturn(flowOf(INTERNAL_PENDING_IMPRESSION_REPORT)) - whenever(measurementsMock.getMeasurement(any())).thenReturn(SUCCEEDED_IMPRESSION_MEASUREMENT) - whenever(internalReportsMock.getReport(any())).thenReturn(INTERNAL_SUCCEEDED_IMPRESSION_REPORT) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { reports.add(SUCCEEDED_IMPRESSION_REPORT) } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = DEFAULT_PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = IMPRESSION_MEASUREMENT_KEY.toName() }) - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::setMeasurementResult, - ) - .isEqualTo( - setMeasurementResultRequest { - measurementConsumerReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementId - this.result = - InternalMeasurementKt.result { - impression = InternalMeasurementResultKt.impression { value = TOTAL_IMPRESSION_VALUE } - } - } - ) - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReport) - .isEqualTo( - getInternalReportRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[1] - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports returns a watch duration report with aggregated results`() = runBlocking { - whenever(internalReportsMock.streamReports(any())) - .thenReturn(flowOf(INTERNAL_PENDING_WATCH_DURATION_REPORT)) - whenever(measurementsMock.getMeasurement(any())) - .thenReturn(SUCCEEDED_WATCH_DURATION_MEASUREMENT) - whenever(internalReportsMock.getReport(any())) - .thenReturn(INTERNAL_SUCCEEDED_WATCH_DURATION_REPORT) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { reports.add(SUCCEEDED_WATCH_DURATION_REPORT) } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = DEFAULT_PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = WATCH_DURATION_MEASUREMENT_KEY.toName() }) - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::setMeasurementResult, - ) - .isEqualTo( - setMeasurementResultRequest { - measurementConsumerReferenceId = WATCH_DURATION_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = WATCH_DURATION_MEASUREMENT_KEY.measurementId - this.result = - InternalMeasurementKt.result { - watchDuration = - InternalMeasurementResultKt.watchDuration { value = TOTAL_WATCH_DURATION } - } - } - ) - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReport) - .isEqualTo( - getInternalReportRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `getReport returns the report with SUCCEEDED when the report is already succeeded`() = - runBlocking { - whenever(internalReportsMock.getReport(any())) - .thenReturn(INTERNAL_SUCCEEDED_WATCH_DURATION_REPORT) - - val request = getReportRequest { name = REPORT_NAMES[2] } - - val report = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.getReport(request) } - } - - assertThat(report).isEqualTo(SUCCEEDED_WATCH_DURATION_REPORT) - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReport) - .isEqualTo( - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - } - ) - } - - @Test - fun `getReport returns the report with FAILED when the report is already failed`() = runBlocking { - whenever(internalReportsMock.getReport(any())) - .thenReturn( - INTERNAL_PENDING_WATCH_DURATION_REPORT.copy { state = InternalReport.State.FAILED } - ) - - val request = getReportRequest { name = REPORT_NAMES[2] } - - val report = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.getReport(request) } - } - - assertThat(report).isEqualTo(PENDING_WATCH_DURATION_REPORT.copy { state = Report.State.FAILED }) - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReport) - .isEqualTo( - getInternalReportRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - } - ) - } - - @Test - fun `getReport returns the report with RUNNING when measurements are pending`(): Unit = - runBlocking { - whenever(internalReportsMock.getReport(any())) - .thenReturn(INTERNAL_PENDING_WATCH_DURATION_REPORT) - whenever(measurementsMock.getMeasurement(any())) - .thenReturn(PENDING_WATCH_DURATION_MEASUREMENT) - - val request = getReportRequest { name = REPORT_NAMES[2] } - - val report = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.getReport(request) } - } - - assertThat(report).isEqualTo(PENDING_WATCH_DURATION_REPORT) - - verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .comparingExpectedFieldsOnly() - .isEqualTo(getMeasurementRequest { name = WATCH_DURATION_MEASUREMENT_KEY.toName() }) - - val internalReportCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(internalReportsMock, times(2)) { getReport(internalReportCaptor.capture()) } - assertThat(internalReportCaptor.allValues) - .containsExactly( - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - }, - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - }, - ) - } - - @Test - fun `getReport syncs and returns an SUCCEEDED report with aggregated results`(): Unit = - runBlocking { - whenever(measurementsMock.getMeasurement(any())).thenReturn(SUCCEEDED_IMPRESSION_MEASUREMENT) - whenever(internalReportsMock.getReport(any())) - .thenReturn(INTERNAL_PENDING_IMPRESSION_REPORT, INTERNAL_SUCCEEDED_IMPRESSION_REPORT) - - val request = getReportRequest { name = REPORT_NAMES[1] } - - val report = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.getReport(request) } - } - - assertThat(report).isEqualTo(SUCCEEDED_IMPRESSION_REPORT) - - verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = IMPRESSION_MEASUREMENT_KEY.toName() }) - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::setMeasurementResult, - ) - .isEqualTo( - setMeasurementResultRequest { - measurementConsumerReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementId - this.result = - InternalMeasurementKt.result { - impression = - InternalMeasurementResultKt.impression { value = TOTAL_IMPRESSION_VALUE } - } - } - ) - - val internalReportCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(internalReportsMock, times(2)) { getReport(internalReportCaptor.capture()) } - assertThat(internalReportCaptor.allValues) - .containsExactly( - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[1] - }, - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[1] - }, - ) - } - - @Test - fun `getReport syncs and returns an FAILED report when measurements failed`(): Unit = - runBlocking { - whenever(measurementsMock.getMeasurement(any())) - .thenReturn( - BASE_IMPRESSION_MEASUREMENT.copy { - state = Measurement.State.FAILED - failure = failure { - reason = Measurement.Failure.Reason.REQUISITION_REFUSED - message = "Privacy budget exceeded." - } - } - ) - whenever(internalReportsMock.getReport(any())) - .thenReturn( - INTERNAL_PENDING_IMPRESSION_REPORT, - INTERNAL_PENDING_IMPRESSION_REPORT.copy { state = InternalReport.State.FAILED }, - ) - - val request = getReportRequest { name = REPORT_NAMES[1] } - - val report = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.getReport(request) } - } - - assertThat(report).isEqualTo(PENDING_IMPRESSION_REPORT.copy { state = Report.State.FAILED }) - - verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = IMPRESSION_MEASUREMENT_KEY.toName() }) - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::setMeasurementFailure, - ) - .isEqualTo( - setMeasurementFailureRequest { - measurementConsumerReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementId - failure = - InternalMeasurementKt.failure { - reason = InternalMeasurement.Failure.Reason.REQUISITION_REFUSED - message = "Privacy budget exceeded." - } - } - ) - - val internalReportCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(internalReportsMock, times(2)) { getReport(internalReportCaptor.capture()) } - assertThat(internalReportCaptor.allValues) - .containsExactly( - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[1] - }, - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[1] - }, - ) - } - - @Test - fun `getReport throws INVALID_ARGUMENT when Report name is invalid`() { - val request = getReportRequest { name = INVALID_REPORT_NAME } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.getReport(request) } - } - } - - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `getReport throws PERMISSION_DENIED when MeasurementConsumer's identity does not match`() { - val request = getReportRequest { name = REPORT_NAMES[0] } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.last().name, CONFIG) { - runBlocking { service.getReport(request) } - } - } - - assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) - } - - @Test - fun `getReport throws UNAUTHENTICATED when the caller is not a MeasurementConsumer`() { - val request = getReportRequest { name = REPORT_NAMES[0] } - - val exception = - assertFailsWith { - withDataProviderPrincipal(DATA_PROVIDERS.values.first().name) { - runBlocking { service.getReport(request) } - } - } - - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - } - - @Test - fun `getReport throws PERMISSION_DENIED when encryption private key not found`() = runBlocking { - whenever(internalReportsMock.getReport(any())) - .thenReturn(INTERNAL_PENDING_WATCH_DURATION_REPORT) - - whenever(measurementsMock.getMeasurement(any())) - .thenReturn( - SUCCEEDED_WATCH_DURATION_MEASUREMENT.copy { - val measurementSpec = measurementSpec { - measurementPublicKey = - MEASUREMENT_CONSUMER_PUBLIC_KEY.copy { data = INVALID_MEASUREMENT_PUBLIC_KEY_DATA } - .pack() - } - this.measurementSpec = - signMeasurementSpec(measurementSpec, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) - } - ) - - val request = getReportRequest { name = REPORT_NAMES[2] } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.getReport(request) } - } - } - - assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) - assertThat(exception.status.description).contains("private key") - } - - @Test - fun `getReport throws Exception when the internal GetReport throws Exception`() = runBlocking { - whenever(internalReportsMock.getReport(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = getReportRequest { name = REPORT_NAMES[2] } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.getReport(request) } - } - } - val expectedExceptionDescription = "Unable to get the report from the reporting database." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) - } - - @Test - fun `toResult converts internal result to external result with the same content`() = runBlocking { - val internalResult = - InternalReportDetailsKt.result { - scalarTable = - InternalReportResultKt.scalarTable { - rowHeaders += listOf("row1", "row2", "row3") - columns += - InternalReportResultKt.column { - columnHeader = "column1" - setOperations += listOf(1.0, 2.0, 3.0) - } - } - histogramTables += - InternalReportResultKt.histogramTable { - rows += - InternalReportResultKt.HistogramTableKt.row { - rowHeader = "row4" - frequency = 100 - } - rows += - InternalReportResultKt.HistogramTableKt.row { - rowHeader = "row5" - frequency = 101 - } - columns += - InternalReportResultKt.column { - columnHeader = "column1" - setOperations += listOf(10.0, 11.0, 12.0) - } - columns += - InternalReportResultKt.column { - columnHeader = "column2" - setOperations += listOf(20.0, 21.0, 22.0) - } - } - } - - whenever(internalReportsMock.getReport(any())) - .thenReturn( - INTERNAL_SUCCEEDED_WATCH_DURATION_REPORT.copy { - details = InternalReportKt.details { result = internalResult } - } - ) - - val request = getReportRequest { name = REPORT_NAMES[2] } - - val report = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.getReport(request) } - } - - assertThat(report.result) - .isEqualTo( - ReportKt.result { - scalarTable = scalarTable { - rowHeaders += listOf("row1", "row2", "row3") - columns += column { - columnHeader = "column1" - setOperations += listOf(1.0, 2.0, 3.0) - } - } - histogramTables += histogramTable { - rows += row { - rowHeader = "row4" - frequency = 100 - } - rows += row { - rowHeader = "row5" - frequency = 101 - } - columns += column { - columnHeader = "column1" - setOperations += listOf(10.0, 11.0, 12.0) - } - columns += column { - columnHeader = "column2" - setOperations += listOf(20.0, 21.0, 22.0) - } - } - } - ) - } - - companion object { - private val MEASUREMENT_SPEC_FIELD_DESCRIPTOR = - Measurement.getDescriptor().findFieldByNumber(Measurement.MEASUREMENT_SPEC_FIELD_NUMBER) - private val ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR = - Measurement.DataProviderEntry.Value.getDescriptor() - .findFieldByNumber(ENCRYPTED_REQUISITION_SPEC_FIELD_NUMBER) - } -} - -private fun EventGroupKey.toInternal(): InternalReportingSet.EventGroupKey { - val source = this - return InternalReportingSetKt.eventGroupKey { - measurementConsumerReferenceId = source.measurementConsumerReferenceId - dataProviderReferenceId = source.dataProviderReferenceId - eventGroupReferenceId = source.eventGroupReferenceId - } -} - -private val InternalReportingSet.resourceKey: ReportingSetKey - get() = - ReportingSetKey(measurementConsumerReferenceId, ExternalId(externalReportingSetId).apiId.value) -private val InternalReportingSet.resourceName: String - get() = resourceKey.toName() diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/SetOperationCompilerTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/SetOperationCompilerTest.kt deleted file mode 100644 index 0fe757be12d..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/SetOperationCompilerTest.kt +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import com.google.common.truth.Truth.assertThat -import kotlinx.coroutines.runBlocking -import org.junit.Assert.assertThrows -import org.junit.Before -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.wfanet.measurement.common.identity.externalIdToApiId -import org.wfanet.measurement.reporting.v1alpha.Metric -import org.wfanet.measurement.reporting.v1alpha.MetricKt.SetOperationKt.operand -import org.wfanet.measurement.reporting.v1alpha.MetricKt.namedSetOperation -import org.wfanet.measurement.reporting.v1alpha.MetricKt.setOperation -import org.wfanet.measurement.reporting.v1alpha.copy - -// Measurement consumer IDs and names -private const val MEASUREMENT_CONSUMER_EXTERNAL_ID = 111L -private val MEASUREMENT_CONSUMER_REFERENCE_ID = externalIdToApiId(MEASUREMENT_CONSUMER_EXTERNAL_ID) - -// Reporting set IDs and names -private const val REPORTING_SET_EXTERNAL_ID = 331L -private const val REPORTING_SET_EXTERNAL_ID_2 = 332L -private const val REPORTING_SET_EXTERNAL_ID_3 = 333L - -private val REPORTING_SET_NAME = - ReportingSetKey(MEASUREMENT_CONSUMER_REFERENCE_ID, externalIdToApiId(REPORTING_SET_EXTERNAL_ID)) - .toName() -private val REPORTING_SET_NAME_2 = - ReportingSetKey(MEASUREMENT_CONSUMER_REFERENCE_ID, externalIdToApiId(REPORTING_SET_EXTERNAL_ID_2)) - .toName() -private val REPORTING_SET_NAME_3 = - ReportingSetKey(MEASUREMENT_CONSUMER_REFERENCE_ID, externalIdToApiId(REPORTING_SET_EXTERNAL_ID_3)) - .toName() - -private val EXPECTED_REPORTING_SET_NAMES_LIST_ALL_UNION = - listOf(REPORTING_SET_NAME, REPORTING_SET_NAME_2, REPORTING_SET_NAME_3).sorted() - -private val SET_OPERATION = setOperation { - type = Metric.SetOperation.Type.UNION - lhs = operand { reportingSet = REPORTING_SET_NAME } - rhs = operand { reportingSet = REPORTING_SET_NAME_2 } -} - -private val SET_OPERATION_ALL_UNION = setOperation { - type = Metric.SetOperation.Type.UNION - lhs = operand { operation = SET_OPERATION } - rhs = operand { reportingSet = REPORTING_SET_NAME_3 } -} - -// SetOperation = A + B + C - B -private val SET_OPERATION_ALL_UNION_BUT_ONE = setOperation { - type = Metric.SetOperation.Type.DIFFERENCE - lhs = operand { operation = SET_OPERATION_ALL_UNION } - rhs = operand { reportingSet = EXPECTED_REPORTING_SET_NAMES_LIST_ALL_UNION[1] } -} - -private const val SET_OPERATION_ALL_UNION_DISPLAY_NAME = "SET_OPERATION_ALL_UNION" -private const val SET_OPERATION_ALL_UNION_BUT_ONE_DISPLAY_NAME = "SET_OPERATION_ALL_UNION_BUT_ONE" - -private val NAMED_SET_OPERATION_ALL_UNION = namedSetOperation { - uniqueName = SET_OPERATION_ALL_UNION_DISPLAY_NAME - setOperation = SET_OPERATION_ALL_UNION -} - -private val NAMED_SET_OPERATION_ALL_UNION_BUT_ONE = namedSetOperation { - uniqueName = SET_OPERATION_ALL_UNION_BUT_ONE_DISPLAY_NAME - setOperation = SET_OPERATION_ALL_UNION_BUT_ONE -} - -private val EXPECTED_RESULT_FOR_ALL_UNION_SET_OPERATION = - listOf(WeightedMeasurement(EXPECTED_REPORTING_SET_NAMES_LIST_ALL_UNION, coefficient = 1)) - -private val EXPECTED_RESULT_FOR_ALL_UNION_BUT_ONE_SET_OPERATION = - listOf( - WeightedMeasurement(EXPECTED_REPORTING_SET_NAMES_LIST_ALL_UNION, coefficient = 1), - WeightedMeasurement(listOf(EXPECTED_REPORTING_SET_NAMES_LIST_ALL_UNION[1]), coefficient = -1), - ) - -private val EXPECTED_CACHE_FOR_ALL_UNION_SET_OPERATION = - mapOf( - EXPECTED_REPORTING_SET_NAMES_LIST_ALL_UNION.size to - mapOf( - 1UL to mapOf(6UL to -1, 7UL to 1), - 2UL to mapOf(5UL to -1, 7UL to 1), - 3UL to mapOf(4UL to -1, 5UL to 1, 6UL to 1, 7UL to -1), - 4UL to mapOf(3UL to -1, 7UL to 1), - 5UL to mapOf(2UL to -1, 3UL to 1, 6UL to 1, 7UL to -1), - 6UL to mapOf(1UL to -1, 3UL to 1, 5UL to 1, 7UL to -1), - 7UL to mapOf(1UL to 1, 2UL to 1, 3UL to -1, 4UL to 1, 5UL to -1, 6UL to -1, 7UL to 1), - ) - ) - -// {4: {3: -1, 7: 1}, 1: {6: -1, 7: 1}, 5: {2: -1, 3: 1, 6: 1, 7: -1}} -private val EXPECTED_CACHE_FOR_ALL_UNION_BUT_ONE_SET_OPERATION = - mapOf( - EXPECTED_REPORTING_SET_NAMES_LIST_ALL_UNION.size to - mapOf( - 1UL to mapOf(6UL to -1, 7UL to 1), - 4UL to mapOf(3UL to -1, 7UL to 1), - 5UL to mapOf(2UL to -1, 3UL to 1, 6UL to 1, 7UL to -1), - ) - ) - -@RunWith(JUnit4::class) -class SetOperationCompilerTest { - private lateinit var reportResultCompiler: SetOperationCompiler - - @Before - fun initService() { - reportResultCompiler = SetOperationCompiler() - } - - @Test - fun `compileSetOperation returns a list of weightedMeasurements and store it in the cache`() { - val resultAllUnionButOne = runBlocking { - reportResultCompiler.compileSetOperation(SET_OPERATION_ALL_UNION_BUT_ONE) - } - val primitiveRegionCacheAllUnionButOne = reportResultCompiler.getPrimitiveRegionCache() - - assertThat(resultAllUnionButOne) - .containsExactlyElementsIn(EXPECTED_RESULT_FOR_ALL_UNION_BUT_ONE_SET_OPERATION) - assertThat(primitiveRegionCacheAllUnionButOne) - .isEqualTo(EXPECTED_CACHE_FOR_ALL_UNION_BUT_ONE_SET_OPERATION) - - val resultAllUnion = runBlocking { - reportResultCompiler.compileSetOperation(SET_OPERATION_ALL_UNION) - } - val primitiveRegionCacheAllUnion = reportResultCompiler.getPrimitiveRegionCache() - - assertThat(resultAllUnion) - .containsExactlyElementsIn(EXPECTED_RESULT_FOR_ALL_UNION_SET_OPERATION) - assertThat(primitiveRegionCacheAllUnion).isEqualTo(EXPECTED_CACHE_FOR_ALL_UNION_SET_OPERATION) - } - - @Test - fun `compileSetOperation reuses the computation in the cache when there exists one`() { - runBlocking { reportResultCompiler.compileSetOperation(SET_OPERATION_ALL_UNION) } - val firstRoundPrimitiveRegionCache = reportResultCompiler.getPrimitiveRegionCache() - - runBlocking { reportResultCompiler.compileSetOperation(SET_OPERATION_ALL_UNION) } - val secondRoundPrimitiveRegionCache = reportResultCompiler.getPrimitiveRegionCache() - - assertThat(firstRoundPrimitiveRegionCache).isEqualTo(secondRoundPrimitiveRegionCache) - } - - @Test - fun `compileSetOperation throws IllegalArgumentException when lhs in SetOperation is not set`() { - val setOperationWithLhsNotSet = SET_OPERATION_ALL_UNION.copy { clearLhs() } - - val exception = - assertThrows(IllegalArgumentException::class.java) { - runBlocking { reportResultCompiler.compileSetOperation(setOperationWithLhsNotSet) } - } - assertThat(exception.message).isEqualTo("lhs in SetOperation must be set.") - } - - @Test - fun `compileSetOperation throws IllegalArgumentException when lhs operand type is not set`() { - val setOperationWithLhsOperandTypeNotSet = SET_OPERATION_ALL_UNION.copy { lhs = operand {} } - - val exception = - assertThrows(IllegalArgumentException::class.java) { - runBlocking { - reportResultCompiler.compileSetOperation(setOperationWithLhsOperandTypeNotSet) - } - } - assertThat(exception.message).isEqualTo("Operand type of lhs in SetOperation must be set.") - } - - @Test - fun `compileSetOperation throws IllegalArgumentException when a set operator type is not set`() { - val setOperationWithSetOperatorTypeNotSet = SET_OPERATION_ALL_UNION.copy { clearType() } - - val exception = - assertThrows(IllegalArgumentException::class.java) { - runBlocking { - reportResultCompiler.compileSetOperation(setOperationWithSetOperatorTypeNotSet) - } - } - assertThat(exception.message).isEqualTo("Set operator type is not specified.") - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/BUILD.bazel deleted file mode 100644 index 28ecd80eb1d..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/BUILD.bazel +++ /dev/null @@ -1,29 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_test") - -filegroup( - name = "textproto_files", - srcs = glob(["*.textproto"]), -) - -kt_jvm_test( - name = "ReportingTest", - srcs = ["ReportingTest.kt"], - data = [ - "textproto_files", - "//src/main/k8s/testing/secretfiles:root_certs", - "//src/main/k8s/testing/secretfiles:secret_files", - ], - jvm_flags = ["-Dcom.google.testing.junit.runner.shouldInstallTestSecurityManager=false"], - test_class = "org.wfanet.measurement.reporting.service.api.v1alpha.tools.ReportingTest", - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools:reporting", - "@wfa_common_jvm//imports/java/com/google/common/truth", - "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", - "@wfa_common_jvm//imports/java/io/grpc/netty", - "@wfa_common_jvm//imports/java/org/junit", - "@wfa_common_jvm//imports/kotlin/kotlin/test", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing", - ], -) diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/ReportingTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/ReportingTest.kt deleted file mode 100644 index 90060571776..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/ReportingTest.kt +++ /dev/null @@ -1,561 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha.tools - -import com.google.common.truth.Truth.assertThat -import io.grpc.Server -import io.grpc.ServerServiceDefinition -import io.grpc.netty.NettyServerBuilder -import java.nio.file.Path -import java.nio.file.Paths -import java.time.Duration -import java.time.Instant -import java.time.LocalDate -import java.time.ZoneOffset -import java.util.concurrent.TimeUnit.SECONDS -import org.junit.After -import org.junit.Before -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.mockito.kotlin.any -import org.wfanet.measurement.common.crypto.SigningCerts -import org.wfanet.measurement.common.getRuntimePath -import org.wfanet.measurement.common.grpc.testing.mockService -import org.wfanet.measurement.common.grpc.toServerTlsContext -import org.wfanet.measurement.common.parseTextProto -import org.wfanet.measurement.common.testing.CommandLineTesting -import org.wfanet.measurement.common.testing.CommandLineTesting.assertThat -import org.wfanet.measurement.common.testing.ExitInterceptingSecurityManager -import org.wfanet.measurement.common.testing.verifyProtoArgument -import org.wfanet.measurement.common.toProtoDuration -import org.wfanet.measurement.common.toProtoTime -import org.wfanet.measurement.reporting.v1alpha.EventGroupsGrpcKt.EventGroupsCoroutineImplBase -import org.wfanet.measurement.reporting.v1alpha.ListEventGroupsResponse -import org.wfanet.measurement.reporting.v1alpha.ListReportingSetsResponse -import org.wfanet.measurement.reporting.v1alpha.Metric -import org.wfanet.measurement.reporting.v1alpha.MetricKt.SetOperationKt.operand -import org.wfanet.measurement.reporting.v1alpha.MetricKt.namedSetOperation -import org.wfanet.measurement.reporting.v1alpha.MetricKt.reachParams -import org.wfanet.measurement.reporting.v1alpha.MetricKt.setOperation -import org.wfanet.measurement.reporting.v1alpha.Report -import org.wfanet.measurement.reporting.v1alpha.ReportKt.EventGroupUniverseKt.eventGroupEntry -import org.wfanet.measurement.reporting.v1alpha.ReportKt.eventGroupUniverse -import org.wfanet.measurement.reporting.v1alpha.ReportingSet -import org.wfanet.measurement.reporting.v1alpha.ReportingSetsGrpcKt.ReportingSetsCoroutineImplBase -import org.wfanet.measurement.reporting.v1alpha.ReportsGrpcKt.ReportsCoroutineImplBase -import org.wfanet.measurement.reporting.v1alpha.createReportRequest -import org.wfanet.measurement.reporting.v1alpha.createReportingSetRequest -import org.wfanet.measurement.reporting.v1alpha.eventGroup -import org.wfanet.measurement.reporting.v1alpha.getReportRequest -import org.wfanet.measurement.reporting.v1alpha.listEventGroupsRequest -import org.wfanet.measurement.reporting.v1alpha.listEventGroupsResponse -import org.wfanet.measurement.reporting.v1alpha.listReportingSetsRequest -import org.wfanet.measurement.reporting.v1alpha.listReportingSetsResponse -import org.wfanet.measurement.reporting.v1alpha.listReportsRequest -import org.wfanet.measurement.reporting.v1alpha.listReportsResponse -import org.wfanet.measurement.reporting.v1alpha.metric -import org.wfanet.measurement.reporting.v1alpha.periodicTimeInterval -import org.wfanet.measurement.reporting.v1alpha.report -import org.wfanet.measurement.reporting.v1alpha.reportingSet -import org.wfanet.measurement.reporting.v1alpha.timeInterval -import org.wfanet.measurement.reporting.v1alpha.timeIntervals - -private const val HOST = "localhost" -private val SECRETS_DIR: Path = - getRuntimePath( - Paths.get("wfa_measurement_system", "src", "main", "k8s", "testing", "secretfiles") - )!! - -private val TEXTPROTO_DIR: Path = - getRuntimePath( - Paths.get( - "wfa_measurement_system", - "src", - "test", - "kotlin", - "org", - "wfanet", - "measurement", - "reporting", - "service", - "api", - "v1alpha", - "tools", - ) - )!! - -private const val REPORT_IDEMPOTENCY_KEY = "report001" -private const val MEASUREMENT_CONSUMER_NAME = "measurementConsumers/1" -private const val EVENT_GROUP_NAME_1 = "$MEASUREMENT_CONSUMER_NAME/dataProviders/1/eventGroups/1" -private const val EVENT_GROUP_NAME_2 = "$MEASUREMENT_CONSUMER_NAME/dataProviders/1/eventGroups/2" -private const val EVENT_GROUP_NAME_3 = "$MEASUREMENT_CONSUMER_NAME/dataProviders/2/eventGroups/1" - -private val REPORTING_SET = reportingSet { name = "$MEASUREMENT_CONSUMER_NAME/reportingSets/1" } -private val LIST_REPORTING_SETS_RESPONSE = listReportingSetsResponse { - reportingSets += reportingSet { - name = "$MEASUREMENT_CONSUMER_NAME/reportingSets/1" - eventGroups += listOf(EVENT_GROUP_NAME_1, EVENT_GROUP_NAME_2, EVENT_GROUP_NAME_3) - filter = "some.filter1" - displayName = "test-reporting-set1" - } - reportingSets += reportingSet { - name = "$MEASUREMENT_CONSUMER_NAME/reportingSets/2" - eventGroups += listOf(EVENT_GROUP_NAME_1) - filter = "some.filter2" - displayName = "test-reporting-set2" - } - nextPageToken = "TokenToGetTheNextPage" -} - -private const val REPORT_NAME = "$MEASUREMENT_CONSUMER_NAME/reports/1" -private val REPORT = report { - name = REPORT_NAME - measurementConsumer = MEASUREMENT_CONSUMER_NAME - eventGroupUniverse = eventGroupUniverse { - eventGroupEntries += eventGroupEntry { - key = "measurementConsumers/1/dataProviders/1/eventGroups/1" - value = "" - } - eventGroupEntries += eventGroupEntry { - key = "measurementConsumers/1/dataProviders/2/eventGroups/3" - value = "partner=abc" - } - } - timeIntervals = timeIntervals { - timeIntervals += timeInterval { - startTime = - LocalDate.now().minusDays(1).atStartOfDay().toInstant(ZoneOffset.UTC).toProtoTime() - endTime = LocalDate.now().atStartOfDay().toInstant(ZoneOffset.UTC).toProtoTime() - } - } -} - -private val LIST_REPORTS_RESPONSE = listReportsResponse { - reports += - listOf( - report { - name = "$MEASUREMENT_CONSUMER_NAME/reports/1" - measurementConsumer = MEASUREMENT_CONSUMER_NAME - state = Report.State.RUNNING - }, - report { - name = "$MEASUREMENT_CONSUMER_NAME/reports/2" - measurementConsumer = MEASUREMENT_CONSUMER_NAME - state = Report.State.SUCCEEDED - }, - report { - name = "$MEASUREMENT_CONSUMER_NAME/reports/3" - measurementConsumer = MEASUREMENT_CONSUMER_NAME - state = Report.State.FAILED - }, - ) -} - -private const val DATA_PROVIDER_NAME = "dataProviders/1" - -private val LIST_EVENT_GROUPS_RESPONSE = listEventGroupsResponse { eventGroups += eventGroup {} } - -@RunWith(JUnit4::class) -class ReportingTest { - private val reportingSetsServiceMock: ReportingSetsCoroutineImplBase = - mockService() { - onBlocking { createReportingSet(any()) }.thenReturn(REPORTING_SET) - onBlocking { listReportingSets(any()) }.thenReturn(LIST_REPORTING_SETS_RESPONSE) - } - private val reportsServiceMock: ReportsCoroutineImplBase = - mockService() { - onBlocking { createReport(any()) }.thenReturn(REPORT) - onBlocking { listReports(any()) }.thenReturn(LIST_REPORTS_RESPONSE) - onBlocking { getReport(any()) }.thenReturn(REPORT) - } - private val eventGroupsServiceMock: EventGroupsCoroutineImplBase = - mockService() { onBlocking { listEventGroups(any()) }.thenReturn(LIST_EVENT_GROUPS_RESPONSE) } - - private val serverCerts = - SigningCerts.fromPemFiles( - certificateFile = SECRETS_DIR.resolve("reporting_tls.pem").toFile(), - privateKeyFile = SECRETS_DIR.resolve("reporting_tls.key").toFile(), - trustedCertCollectionFile = SECRETS_DIR.resolve("reporting_root.pem").toFile(), - ) - - private val services: List = - listOf( - reportingSetsServiceMock.bindService(), - reportsServiceMock.bindService(), - eventGroupsServiceMock.bindService(), - ) - - private val server: Server = - NettyServerBuilder.forPort(0) - .sslContext(serverCerts.toServerTlsContext()) - .addServices(services) - .build() - - @Before - fun initServer() { - server.start() - } - - @After - fun shutdownServer() { - server.shutdown() - server.awaitTermination(1, SECONDS) - } - - private fun callCli(args: Array): String { - val capturedOutput = CommandLineTesting.capturingOutput(args, Reporting::main) - assertThat(capturedOutput).status().isEqualTo(0) - return capturedOutput.out - } - - @Test - fun `reporting_sets create calls api with valid request`() { - val args = - arrayOf( - "--tls-cert-file=$SECRETS_DIR/mc_tls.pem", - "--tls-key-file=$SECRETS_DIR/mc_tls.key", - "--cert-collection-file=$SECRETS_DIR/reporting_root.pem", - "--reporting-server-api-target=$HOST:${server.port}", - "reporting-sets", - "create", - "--parent=$MEASUREMENT_CONSUMER_NAME", - "--event-group=$EVENT_GROUP_NAME_1", - "--event-group=$EVENT_GROUP_NAME_2", - "--event-group=$EVENT_GROUP_NAME_3", - "--filter=some.filter", - "--display-name=test-reporting-set", - ) - - val output = callCli(args) - - verifyProtoArgument( - reportingSetsServiceMock, - ReportingSetsCoroutineImplBase::createReportingSet, - ) - .isEqualTo( - createReportingSetRequest { - parent = MEASUREMENT_CONSUMER_NAME - reportingSet = reportingSet { - eventGroups += listOf(EVENT_GROUP_NAME_1, EVENT_GROUP_NAME_2, EVENT_GROUP_NAME_3) - filter = "some.filter" - displayName = "test-reporting-set" - } - } - ) - assertThat(parseTextProto(output.reader(), ReportingSet.getDefaultInstance())) - .isEqualTo(REPORTING_SET) - } - - @Test - fun `reporting_sets list calls api with valid request`() { - val args = - arrayOf( - "--tls-cert-file=$SECRETS_DIR/mc_tls.pem", - "--tls-key-file=$SECRETS_DIR/mc_tls.key", - "--cert-collection-file=$SECRETS_DIR/reporting_root.pem", - "--reporting-server-api-target=$HOST:${server.port}", - "reporting-sets", - "list", - "--parent=$MEASUREMENT_CONSUMER_NAME", - "--page-size=50", - ) - - val output = callCli(args) - - verifyProtoArgument(reportingSetsServiceMock, ReportingSetsCoroutineImplBase::listReportingSets) - .isEqualTo( - listReportingSetsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = 50 - } - ) - assertThat(parseTextProto(output.reader(), ListReportingSetsResponse.getDefaultInstance())) - .isEqualTo(LIST_REPORTING_SETS_RESPONSE) - } - - @Test - fun `Reports create calls api with valid request`() { - val metric = - """ - reach { } - set_operations { - unique_name: "operation1" - set_operation { - type: 1 - lhs { - reporting_set: "measurementConsumers/1/reportingSets/1" - } - rhs { - reporting_set: "measurementConsumers/1/reportingSets/2" - } - } - } - """ - .trimIndent() - val args = - arrayOf( - "--tls-cert-file=$SECRETS_DIR/mc_tls.pem", - "--tls-key-file=$SECRETS_DIR/mc_tls.key", - "--cert-collection-file=$SECRETS_DIR/reporting_root.pem", - "--reporting-server-api-target=$HOST:${server.port}", - "reports", - "create", - "--idempotency-key=$REPORT_IDEMPOTENCY_KEY", - "--parent=$MEASUREMENT_CONSUMER_NAME", - "--event-group-key=$EVENT_GROUP_NAME_1", - "--event-group-value=", - "--event-group-key=$EVENT_GROUP_NAME_2", - "--event-group-value=partner=abc", - "--periodic-interval-start-time=2017-01-15T01:30:15.01Z", - "--periodic-interval-increment=P1DT3H5M12.99S", - "--periodic-interval-count=3", - "--metric=$metric", - ) - - val output = callCli(args) - - verifyProtoArgument(reportsServiceMock, ReportsCoroutineImplBase::createReport) - .isEqualTo( - createReportRequest { - parent = MEASUREMENT_CONSUMER_NAME - report = report { - reportIdempotencyKey = REPORT_IDEMPOTENCY_KEY - measurementConsumer = MEASUREMENT_CONSUMER_NAME - eventGroupUniverse = eventGroupUniverse { - eventGroupEntries += eventGroupEntry { key = EVENT_GROUP_NAME_1 } - eventGroupEntries += eventGroupEntry { - key = EVENT_GROUP_NAME_2 - value = "partner=abc" - } - } - periodicTimeInterval = periodicTimeInterval { - startTime = Instant.parse("2017-01-15T01:30:15.01Z").toProtoTime() - increment = Duration.parse("P1DT3H5M12.99S").toProtoDuration() - intervalCount = 3 - } - metrics += metric { - reach = reachParams {} - setOperations += namedSetOperation { - uniqueName = "operation1" - setOperation = setOperation { - type = Metric.SetOperation.Type.UNION - lhs = operand { reportingSet = "measurementConsumers/1/reportingSets/1" } - rhs = operand { reportingSet = "measurementConsumers/1/reportingSets/2" } - } - } - } - } - } - ) - assertThat(parseTextProto(output.reader(), Report.getDefaultInstance())).isEqualTo(REPORT) - } - - @Test - fun `Reports create calls api with correct time intervals params`() { - val textFormatMetric = - """ - reach { } - set_operations { - unique_name: "operation1" - set_operation { - type: 1 - lhs { - reporting_set: "measurementConsumers/1/reportingSets/1" - } - rhs { - reporting_set: "measurementConsumers/1/reportingSets/2" - } - } - } - """ - .trimIndent() - - val args = - arrayOf( - "--tls-cert-file=$SECRETS_DIR/mc_tls.pem", - "--tls-key-file=$SECRETS_DIR/mc_tls.key", - "--cert-collection-file=$SECRETS_DIR/reporting_root.pem", - "--reporting-server-api-target=$HOST:${server.port}", - "reports", - "create", - "--idempotency-key=$REPORT_IDEMPOTENCY_KEY", - "--parent=$MEASUREMENT_CONSUMER_NAME", - "--event-group-key=$EVENT_GROUP_NAME_1", - "--event-group-value=", - "--interval-start-time=2017-01-15T01:30:15.01Z", - "--interval-end-time=2018-10-27T23:19:12.99Z", - "--interval-start-time=2019-01-19T09:48:35.57Z", - "--interval-end-time=2022-06-13T11:57:54.21Z", - "--metric=$textFormatMetric", - ) - val output = callCli(args) - - verifyProtoArgument(reportsServiceMock, ReportsCoroutineImplBase::createReport) - .comparingExpectedFieldsOnly() - .isEqualTo( - createReportRequest { - report = report { - timeIntervals = timeIntervals { - timeIntervals += timeInterval { - startTime = Instant.parse("2017-01-15T01:30:15.01Z").toProtoTime() - endTime = Instant.parse("2018-10-27T23:19:12.99Z").toProtoTime() - } - timeIntervals += timeInterval { - startTime = Instant.parse("2019-01-19T09:48:35.57Z").toProtoTime() - endTime = Instant.parse("2022-06-13T11:57:54.21Z").toProtoTime() - } - } - } - } - ) - assertThat(parseTextProto(output.reader(), Report.getDefaultInstance())).isEqualTo(REPORT) - } - - @Test - fun `Reports create calls api with complex metric`() { - val textFormatMetric = TEXTPROTO_DIR.resolve("metric2.textproto").toFile().readText() - - val args = - arrayOf( - "--tls-cert-file=$SECRETS_DIR/mc_tls.pem", - "--tls-key-file=$SECRETS_DIR/mc_tls.key", - "--cert-collection-file=$SECRETS_DIR/reporting_root.pem", - "--reporting-server-api-target=$HOST:${server.port}", - "reports", - "create", - "--idempotency-key=$REPORT_IDEMPOTENCY_KEY", - "--parent=$MEASUREMENT_CONSUMER_NAME", - "--event-group-key=$EVENT_GROUP_NAME_1", - "--event-group-value=", - "--periodic-interval-start-time=2017-01-15T01:30:15.01Z", - "--periodic-interval-increment=P1DT3H5M12.99S", - "--periodic-interval-count=3", - "--metric=$textFormatMetric", - ) - val output = callCli(args) - - verifyProtoArgument(reportsServiceMock, ReportsCoroutineImplBase::createReport) - .comparingExpectedFieldsOnly() - .isEqualTo( - createReportRequest { - report = report { - metrics += metric { - reach = reachParams {} - cumulative = true - setOperations += namedSetOperation { - uniqueName = "operation1" - setOperation = setOperation { - type = Metric.SetOperation.Type.UNION - lhs = operand { reportingSet = "measurementConsumers/1/reportingSets/1" } - rhs = operand { reportingSet = "measurementConsumers/1/reportingSets/2" } - } - } - setOperations += namedSetOperation { - uniqueName = "operation2" - setOperation = setOperation { - type = Metric.SetOperation.Type.DIFFERENCE - lhs = operand { reportingSet = "measurementConsumers/1/reportingSets/3" } - rhs = operand { - operation = setOperation { - type = Metric.SetOperation.Type.INTERSECTION - lhs = operand { reportingSet = "measurementConsumers/1/reportingSets/4" } - rhs = operand { reportingSet = "measurementConsumers/1/reportingSets/5" } - } - } - } - } - } - } - } - ) - assertThat(parseTextProto(output.reader(), Report.getDefaultInstance())).isEqualTo(REPORT) - } - - @Test - fun `Reports list calls api with valid request`() { - val args = - arrayOf( - "--tls-cert-file=$SECRETS_DIR/mc_tls.pem", - "--tls-key-file=$SECRETS_DIR/mc_tls.key", - "--cert-collection-file=$SECRETS_DIR/reporting_root.pem", - "--reporting-server-api-target=$HOST:${server.port}", - "reports", - "list", - "--parent=$MEASUREMENT_CONSUMER_NAME", - ) - callCli(args) - - verifyProtoArgument(reportsServiceMock, ReportsCoroutineImplBase::listReports) - .isEqualTo( - listReportsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = 1000 - } - ) - } - - @Test - fun `Reports get calls api with valid request`() { - val args = - arrayOf( - "--tls-cert-file=$SECRETS_DIR/mc_tls.pem", - "--tls-key-file=$SECRETS_DIR/mc_tls.key", - "--cert-collection-file=$SECRETS_DIR/reporting_root.pem", - "--reporting-server-api-target=$HOST:${server.port}", - "reports", - "get", - REPORT_NAME, - ) - val output = callCli(args) - - verifyProtoArgument(reportsServiceMock, ReportsCoroutineImplBase::getReport) - .isEqualTo(getReportRequest { name = REPORT_NAME }) - assertThat(parseTextProto(output.reader(), Report.getDefaultInstance())).isEqualTo(REPORT) - } - - @Test - fun `EventGroups list callls api with valid request`() { - val args = - arrayOf( - "--tls-cert-file=$SECRETS_DIR/mc_tls.pem", - "--tls-key-file=$SECRETS_DIR/mc_tls.key", - "--cert-collection-file=$SECRETS_DIR/reporting_root.pem", - "--reporting-server-api-target=$HOST:${server.port}", - "event-groups", - "list", - "--parent=$DATA_PROVIDER_NAME", - "--filter=abcd", - ) - val output = callCli(args) - - verifyProtoArgument(eventGroupsServiceMock, EventGroupsCoroutineImplBase::listEventGroups) - .isEqualTo( - listEventGroupsRequest { - parent = DATA_PROVIDER_NAME - filter = "abcd" - pageSize = 1000 - } - ) - assertThat(parseTextProto(output.reader(), ListEventGroupsResponse.getDefaultInstance())) - .isEqualTo(LIST_EVENT_GROUPS_RESPONSE) - } - - companion object { - init { - System.setSecurityManager(ExitInterceptingSecurityManager) - } - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/metric1.textproto b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/metric1.textproto deleted file mode 100644 index c5bf5885454..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/metric1.textproto +++ /dev/null @@ -1,15 +0,0 @@ -# proto-file: wfa/measurement/reporting/v1alpha/metric.proto -# proto-message: Metric -reach { } -set_operations { - unique_name: "operation1" - set_operation { - type: 1 - lhs { - reporting_set: "measurementConsumers/1/reportingSets/1" - } - rhs { - reporting_set: "measurementConsumers/1/reportingSets/2" - } - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/metric2.textproto b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/metric2.textproto deleted file mode 100644 index 48a44bd3e55..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/metric2.textproto +++ /dev/null @@ -1,36 +0,0 @@ -# proto-file: wfa/measurement/reporting/v1alpha/metric.proto -# proto-message: Metric -reach { } -cumulative: true -set_operations { - unique_name: "operation1" - set_operation { - type: 1 - lhs { - reporting_set: "measurementConsumers/1/reportingSets/1" - } - rhs { - reporting_set: "measurementConsumers/1/reportingSets/2" - } - } -} -set_operations { - unique_name: "operation2" - set_operation { - type: 2 - lhs { - reporting_set: "measurementConsumers/1/reportingSets/3" - } - rhs { - operation { - type: 3 - lhs { - reporting_set: "measurementConsumers/1/reportingSets/4" - } - rhs { - reporting_set: "measurementConsumers/1/reportingSets/5" - } - } - } - } -} \ No newline at end of file