Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support passing credentials from pipeline options into SpannerIO.readChangeStream #30361

Merged
merged 13 commits into from
Feb 29, 2024
Merged
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@
## Bugfixes

* Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Fixed SpannerIO.readChangeStream to support propagating credentials from pipeline options
to the getDialect calls for authenticating with Spanner ([#30361](https://github.com/apache/beam/pull/30361)).
dengwe1 marked this conversation as resolved.
Show resolved Hide resolved

## Security Fixes
* Fixed (CVE-YYYY-NNNN)[https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN] (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.google.api.gax.rpc.ServerStreamingCallSettings;
import com.google.api.gax.rpc.StatusCode.Code;
import com.google.api.gax.rpc.UnaryCallSettings;
import com.google.auth.Credentials;
import com.google.cloud.NoCredentials;
import com.google.cloud.ServiceFactory;
import com.google.cloud.spanner.BatchClient;
Expand All @@ -41,6 +42,7 @@
import java.util.concurrent.ConcurrentHashMap;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.util.ReleaseInfo;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.joda.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -100,7 +102,8 @@ public static SpannerAccessor getOrCreate(SpannerConfig spannerConfig) {
}
}

private static SpannerAccessor createAndConnect(SpannerConfig spannerConfig) {
@VisibleForTesting
static SpannerOptions buildSpannerOptions(SpannerConfig spannerConfig) {
SpannerOptions.Builder builder = SpannerOptions.newBuilder();

Set<Code> retryableCodes = new HashSet<>();
Expand Down Expand Up @@ -222,8 +225,16 @@ private static SpannerAccessor createAndConnect(SpannerConfig spannerConfig) {
if (databaseRole != null && databaseRole.get() != null && !databaseRole.get().isEmpty()) {
builder.setDatabaseRole(databaseRole.get());
}
SpannerOptions options = builder.build();
ValueProvider<Credentials> credentials = spannerConfig.getCredentials();
if (credentials != null && credentials.get() != null) {
builder.setCredentials(credentials.get());
}

return builder.build();
}

private static SpannerAccessor createAndConnect(SpannerConfig spannerConfig) {
SpannerOptions options = buildSpannerOptions(spannerConfig);
Spanner spanner = options.getService();
String instanceId = spannerConfig.getInstanceId().get();
String databaseId = spannerConfig.getDatabaseId().get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.google.api.gax.retrying.RetrySettings;
import com.google.api.gax.rpc.StatusCode.Code;
import com.google.auth.Credentials;
import com.google.auto.value.AutoValue;
import com.google.cloud.ServiceFactory;
import com.google.cloud.spanner.Options.RpcPriority;
Expand Down Expand Up @@ -84,6 +85,8 @@ public abstract class SpannerConfig implements Serializable {

public abstract @Nullable ValueProvider<Boolean> getDataBoostEnabled();

public abstract @Nullable ValueProvider<Credentials> getCredentials();

abstract Builder toBuilder();

public static SpannerConfig create() {
Expand Down Expand Up @@ -161,6 +164,8 @@ abstract Builder setExecuteStreamingSqlRetrySettings(

abstract Builder setPartitionReadTimeout(ValueProvider<Duration> partitionReadTimeout);

abstract Builder setCredentials(ValueProvider<Credentials> credentials);

public abstract SpannerConfig build();
}

Expand Down Expand Up @@ -302,4 +307,14 @@ public SpannerConfig withPartitionReadTimeout(Duration partitionReadTimeout) {
public SpannerConfig withPartitionReadTimeout(ValueProvider<Duration> partitionReadTimeout) {
return toBuilder().setPartitionReadTimeout(partitionReadTimeout).build();
}

/** Specifies the credentials. */
public SpannerConfig withCredentials(Credentials credentials) {
return withCredentials(ValueProvider.StaticValueProvider.of(credentials));
}

/** Specifies the credentials. */
public SpannerConfig withCredentials(ValueProvider<Credentials> credentials) {
return toBuilder().setCredentials(credentials).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import com.google.api.gax.retrying.RetrySettings;
import com.google.api.gax.rpc.StatusCode.Code;
import com.google.auth.Credentials;
import com.google.auto.value.AutoValue;
import com.google.cloud.ServiceFactory;
import com.google.cloud.Timestamp;
Expand Down Expand Up @@ -68,6 +69,7 @@
import org.apache.beam.runners.core.metrics.ServiceCallMetric;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
import org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamMetrics;
import org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamsConstants;
import org.apache.beam.sdk.io.gcp.spanner.changestreams.MetadataSpannerConfigFactory;
Expand All @@ -86,6 +88,7 @@
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Distribution;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.StreamingOptions;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.schemas.Schema;
Expand Down Expand Up @@ -1667,31 +1670,15 @@ && getInclusiveStartAt().toSqlTimestamp().after(getInclusiveEndAt().toSqlTimesta
getSpannerConfig().getProjectId().get(),
partitionMetadataInstanceId,
partitionMetadataDatabaseId);
SpannerConfig changeStreamSpannerConfig = getSpannerConfig();
// Set default retryable errors for ReadChangeStream
if (changeStreamSpannerConfig.getRetryableCodes() == null) {
ImmutableSet<Code> defaultRetryableCodes = ImmutableSet.of(Code.UNAVAILABLE, Code.ABORTED);
changeStreamSpannerConfig =
changeStreamSpannerConfig.toBuilder().setRetryableCodes(defaultRetryableCodes).build();
}
// Set default retry timeouts for ReadChangeStream
if (changeStreamSpannerConfig.getExecuteStreamingSqlRetrySettings() == null) {
changeStreamSpannerConfig =
changeStreamSpannerConfig
.toBuilder()
.setExecuteStreamingSqlRetrySettings(
RetrySettings.newBuilder()
.setTotalTimeout(org.threeten.bp.Duration.ofMinutes(5))
.setInitialRpcTimeout(org.threeten.bp.Duration.ofMinutes(1))
.setMaxRpcTimeout(org.threeten.bp.Duration.ofMinutes(1))
.build())
.build();
}

final SpannerConfig changeStreamSpannerConfig = buildChangeStreamSpannerConfig();
final SpannerConfig partitionMetadataSpannerConfig =
MetadataSpannerConfigFactory.create(
changeStreamSpannerConfig, partitionMetadataInstanceId, partitionMetadataDatabaseId);
Dialect changeStreamDatabaseDialect = getDialect(changeStreamSpannerConfig);
Dialect metadataDatabaseDialect = getDialect(partitionMetadataSpannerConfig);
final Dialect changeStreamDatabaseDialect =
getDialect(changeStreamSpannerConfig, input.getPipeline().getOptions());
final Dialect metadataDatabaseDialect =
getDialect(partitionMetadataSpannerConfig, input.getPipeline().getOptions());
LOG.info(
"The Spanner database "
+ changeStreamDatabaseId
Expand Down Expand Up @@ -1773,10 +1760,52 @@ && getInclusiveStartAt().toSqlTimestamp().after(getInclusiveEndAt().toSqlTimesta
.apply(ParDo.of(new CleanUpReadChangeStreamDoFn(daoFactory)));
return dataChangeRecordsOut;
}

@VisibleForTesting
SpannerConfig buildChangeStreamSpannerConfig() {
SpannerConfig changeStreamSpannerConfig = getSpannerConfig();
// Set default retryable errors for ReadChangeStream
if (changeStreamSpannerConfig.getRetryableCodes() == null) {
ImmutableSet<Code> defaultRetryableCodes = ImmutableSet.of(Code.UNAVAILABLE, Code.ABORTED);
changeStreamSpannerConfig =
changeStreamSpannerConfig.toBuilder().setRetryableCodes(defaultRetryableCodes).build();
}
// Set default retry timeouts for ReadChangeStream
if (changeStreamSpannerConfig.getExecuteStreamingSqlRetrySettings() == null) {
changeStreamSpannerConfig =
changeStreamSpannerConfig
.toBuilder()
.setExecuteStreamingSqlRetrySettings(
RetrySettings.newBuilder()
.setTotalTimeout(org.threeten.bp.Duration.ofMinutes(5))
.setInitialRpcTimeout(org.threeten.bp.Duration.ofMinutes(1))
.setMaxRpcTimeout(org.threeten.bp.Duration.ofMinutes(1))
.build())
.build();
}
return changeStreamSpannerConfig;
}
}

/** If credentials are not set in spannerConfig, uses the credentials from pipeline options. */
@VisibleForTesting
static SpannerConfig buildSpannerConfigWithCredential(
SpannerConfig spannerConfig, PipelineOptions pipelineOptions) {
if (spannerConfig.getCredentials() == null && pipelineOptions != null) {
final Credentials credentials = pipelineOptions.as(GcpOptions.class).getGcpCredential();
if (credentials != null) {
spannerConfig = spannerConfig.withCredentials(credentials);
}
}
return spannerConfig;
}

private static Dialect getDialect(SpannerConfig spannerConfig) {
DatabaseClient databaseClient = SpannerAccessor.getOrCreate(spannerConfig).getDatabaseClient();
private static Dialect getDialect(SpannerConfig spannerConfig, PipelineOptions pipelineOptions) {
// Allow passing the credential from pipeline options to the getDialect() call.
SpannerConfig spannerConfigWithCredential =
buildSpannerConfigWithCredential(spannerConfig, pipelineOptions);
DatabaseClient databaseClient =
SpannerAccessor.getOrCreate(spannerConfigWithCredential).getDatabaseClient();
return databaseClient.getDialect();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.google.api.gax.retrying.RetrySettings;
import com.google.api.gax.rpc.StatusCode.Code;
import com.google.auth.Credentials;
import com.google.cloud.spanner.Options.RpcPriority;
import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig;
import org.apache.beam.sdk.options.ValueProvider;
Expand Down Expand Up @@ -113,6 +114,11 @@ public static SpannerConfig create(
config = config.withRpcPriority(StaticValueProvider.of(rpcPriority.get()));
}

ValueProvider<Credentials> credentials = primaryConfig.getCredentials();
if (credentials != null) {
config = config.withCredentials(credentials);
}

return config;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
*/
package org.apache.beam.sdk.io.gcp.spanner;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import com.google.cloud.spanner.DatabaseId;
import com.google.cloud.spanner.SpannerOptions;
import org.apache.beam.sdk.extensions.gcp.auth.TestCredential;
import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -141,4 +144,24 @@ public void testCreateWithEmptyDatabaseRole() {
.getDatabaseClient(DatabaseId.of("project", "test1", "test1"));
verify(serviceFactory.mockSpanner(), times(1)).close();
}

@Test
public void testBuildSpannerOptionsWithCredential() {
TestCredential testCredential = new TestCredential();
SpannerConfig config1 =
SpannerConfig.create()
.toBuilder()
.setServiceFactory(serviceFactory)
.setProjectId(StaticValueProvider.of("project"))
.setInstanceId(StaticValueProvider.of("test-instance"))
.setDatabaseId(StaticValueProvider.of("test-db"))
.setDatabaseRole(StaticValueProvider.of("test-role"))
.setCredentials(StaticValueProvider.of(testCredential))
.build();

SpannerOptions options = SpannerAccessor.buildSpannerOptions(config1);
assertEquals("project", options.getProjectId());
assertEquals("test-role", options.getDatabaseRole());
assertEquals(testCredential, options.getCredentials());
}
}
Loading
Loading