Skip to content

Commit

Permalink
Support passing credentials from pipeline options into SpannerIO.read…
Browse files Browse the repository at this point in the history
…ChangeStream (#30361)

* Support credentials in SpannerConfig

* Support passing credentials from pipeline options in SpannerIO
  • Loading branch information
dengwe1 authored Feb 29, 2024
1 parent 53cae78 commit c5e1d10
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 26 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@
## Bugfixes

* Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Fixed SpannerIO.readChangeStream to support propagating credentials from pipeline options
to the getDialect calls for authenticating with Spanner (Java) ([#30361](https://github.com/apache/beam/pull/30361)).

## Security Fixes
* Fixed [CVE-YYYY-NNNN](https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN) (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.google.api.gax.rpc.ServerStreamingCallSettings;
import com.google.api.gax.rpc.StatusCode.Code;
import com.google.api.gax.rpc.UnaryCallSettings;
import com.google.auth.Credentials;
import com.google.cloud.NoCredentials;
import com.google.cloud.ServiceFactory;
import com.google.cloud.spanner.BatchClient;
Expand All @@ -41,6 +42,7 @@
import java.util.concurrent.ConcurrentHashMap;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.util.ReleaseInfo;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.joda.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -100,7 +102,8 @@ public static SpannerAccessor getOrCreate(SpannerConfig spannerConfig) {
}
}

private static SpannerAccessor createAndConnect(SpannerConfig spannerConfig) {
@VisibleForTesting
static SpannerOptions buildSpannerOptions(SpannerConfig spannerConfig) {
SpannerOptions.Builder builder = SpannerOptions.newBuilder();

Set<Code> retryableCodes = new HashSet<>();
Expand Down Expand Up @@ -222,8 +225,16 @@ private static SpannerAccessor createAndConnect(SpannerConfig spannerConfig) {
if (databaseRole != null && databaseRole.get() != null && !databaseRole.get().isEmpty()) {
builder.setDatabaseRole(databaseRole.get());
}
SpannerOptions options = builder.build();
ValueProvider<Credentials> credentials = spannerConfig.getCredentials();
if (credentials != null && credentials.get() != null) {
builder.setCredentials(credentials.get());
}

return builder.build();
}

private static SpannerAccessor createAndConnect(SpannerConfig spannerConfig) {
SpannerOptions options = buildSpannerOptions(spannerConfig);
Spanner spanner = options.getService();
String instanceId = spannerConfig.getInstanceId().get();
String databaseId = spannerConfig.getDatabaseId().get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.google.api.gax.retrying.RetrySettings;
import com.google.api.gax.rpc.StatusCode.Code;
import com.google.auth.Credentials;
import com.google.auto.value.AutoValue;
import com.google.cloud.ServiceFactory;
import com.google.cloud.spanner.Options.RpcPriority;
Expand Down Expand Up @@ -84,6 +85,8 @@ public abstract class SpannerConfig implements Serializable {

public abstract @Nullable ValueProvider<Boolean> getDataBoostEnabled();

public abstract @Nullable ValueProvider<Credentials> getCredentials();

abstract Builder toBuilder();

public static SpannerConfig create() {
Expand Down Expand Up @@ -161,6 +164,8 @@ abstract Builder setExecuteStreamingSqlRetrySettings(

abstract Builder setPartitionReadTimeout(ValueProvider<Duration> partitionReadTimeout);

abstract Builder setCredentials(ValueProvider<Credentials> credentials);

public abstract SpannerConfig build();
}

Expand Down Expand Up @@ -302,4 +307,14 @@ public SpannerConfig withPartitionReadTimeout(Duration partitionReadTimeout) {
public SpannerConfig withPartitionReadTimeout(ValueProvider<Duration> partitionReadTimeout) {
  // Rebuild from the current settings, replacing only the partition read timeout.
  final Builder updated = toBuilder().setPartitionReadTimeout(partitionReadTimeout);
  return updated.build();
}

/**
 * Specifies the credentials as a fixed value.
 *
 * <p>The value is wrapped in a static {@link ValueProvider} and handed to the {@code
 * ValueProvider} overload of this method.
 *
 * @param credentials the credentials to store in the returned config
 * @return the config produced by the {@code ValueProvider} overload
 */
public SpannerConfig withCredentials(Credentials credentials) {
  final ValueProvider<Credentials> staticCredentials =
      ValueProvider.StaticValueProvider.of(credentials);
  return withCredentials(staticCredentials);
}

/**
 * Specifies the credentials through a {@link ValueProvider}.
 *
 * @param credentials provider of the credentials to store in the returned config
 * @return a config rebuilt from this one with the credentials provider replaced
 */
public SpannerConfig withCredentials(ValueProvider<Credentials> credentials) {
  final Builder updated = toBuilder().setCredentials(credentials);
  return updated.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import com.google.api.gax.retrying.RetrySettings;
import com.google.api.gax.rpc.StatusCode.Code;
import com.google.auth.Credentials;
import com.google.auto.value.AutoValue;
import com.google.cloud.ServiceFactory;
import com.google.cloud.Timestamp;
Expand Down Expand Up @@ -68,6 +69,7 @@
import org.apache.beam.runners.core.metrics.ServiceCallMetric;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
import org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamMetrics;
import org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamsConstants;
import org.apache.beam.sdk.io.gcp.spanner.changestreams.MetadataSpannerConfigFactory;
Expand All @@ -86,6 +88,7 @@
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Distribution;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.StreamingOptions;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.schemas.Schema;
Expand Down Expand Up @@ -1667,31 +1670,15 @@ && getInclusiveStartAt().toSqlTimestamp().after(getInclusiveEndAt().toSqlTimesta
getSpannerConfig().getProjectId().get(),
partitionMetadataInstanceId,
partitionMetadataDatabaseId);
SpannerConfig changeStreamSpannerConfig = getSpannerConfig();
// Set default retryable errors for ReadChangeStream
if (changeStreamSpannerConfig.getRetryableCodes() == null) {
ImmutableSet<Code> defaultRetryableCodes = ImmutableSet.of(Code.UNAVAILABLE, Code.ABORTED);
changeStreamSpannerConfig =
changeStreamSpannerConfig.toBuilder().setRetryableCodes(defaultRetryableCodes).build();
}
// Set default retry timeouts for ReadChangeStream
if (changeStreamSpannerConfig.getExecuteStreamingSqlRetrySettings() == null) {
changeStreamSpannerConfig =
changeStreamSpannerConfig
.toBuilder()
.setExecuteStreamingSqlRetrySettings(
RetrySettings.newBuilder()
.setTotalTimeout(org.threeten.bp.Duration.ofMinutes(5))
.setInitialRpcTimeout(org.threeten.bp.Duration.ofMinutes(1))
.setMaxRpcTimeout(org.threeten.bp.Duration.ofMinutes(1))
.build())
.build();
}

final SpannerConfig changeStreamSpannerConfig = buildChangeStreamSpannerConfig();
final SpannerConfig partitionMetadataSpannerConfig =
MetadataSpannerConfigFactory.create(
changeStreamSpannerConfig, partitionMetadataInstanceId, partitionMetadataDatabaseId);
Dialect changeStreamDatabaseDialect = getDialect(changeStreamSpannerConfig);
Dialect metadataDatabaseDialect = getDialect(partitionMetadataSpannerConfig);
final Dialect changeStreamDatabaseDialect =
getDialect(changeStreamSpannerConfig, input.getPipeline().getOptions());
final Dialect metadataDatabaseDialect =
getDialect(partitionMetadataSpannerConfig, input.getPipeline().getOptions());
LOG.info(
"The Spanner database "
+ changeStreamDatabaseId
Expand Down Expand Up @@ -1773,10 +1760,52 @@ && getInclusiveStartAt().toSqlTimestamp().after(getInclusiveEndAt().toSqlTimesta
.apply(ParDo.of(new CleanUpReadChangeStreamDoFn(daoFactory)));
return dataChangeRecordsOut;
}

/**
 * Returns the configured {@link SpannerConfig} with ReadChangeStream defaults filled in.
 *
 * <p>Retryable codes and the ExecuteStreamingSql retry settings are only defaulted when the
 * config returned by {@code getSpannerConfig()} leaves them {@code null}; values that are already
 * set are kept as they are.
 */
@VisibleForTesting
SpannerConfig buildChangeStreamSpannerConfig() {
  SpannerConfig config = getSpannerConfig();

  // Set default retryable errors for ReadChangeStream
  if (config.getRetryableCodes() == null) {
    final ImmutableSet<Code> fallbackCodes = ImmutableSet.of(Code.UNAVAILABLE, Code.ABORTED);
    config = config.toBuilder().setRetryableCodes(fallbackCodes).build();
  }

  // Set default retry timeouts for ReadChangeStream
  if (config.getExecuteStreamingSqlRetrySettings() == null) {
    final RetrySettings fallbackRetrySettings =
        RetrySettings.newBuilder()
            .setTotalTimeout(org.threeten.bp.Duration.ofMinutes(5))
            .setInitialRpcTimeout(org.threeten.bp.Duration.ofMinutes(1))
            .setMaxRpcTimeout(org.threeten.bp.Duration.ofMinutes(1))
            .build();
    config = config.toBuilder().setExecuteStreamingSqlRetrySettings(fallbackRetrySettings).build();
  }

  return config;
}
}

/**
 * Falls back to the pipeline options' GCP credential when {@code spannerConfig} has no
 * credentials provider of its own.
 *
 * @param spannerConfig the config to inspect; returned unchanged when it already has a
 *     credentials provider
 * @param pipelineOptions source of the fallback credential; may be {@code null}, in which case
 *     {@code spannerConfig} is returned unchanged
 * @return {@code spannerConfig}, or the result of {@code withCredentials} when a fallback
 *     credential was found
 */
@VisibleForTesting
static SpannerConfig buildSpannerConfigWithCredential(
    SpannerConfig spannerConfig, PipelineOptions pipelineOptions) {
  // Nothing to do when credentials were configured explicitly or there is no fallback source.
  if (spannerConfig.getCredentials() != null || pipelineOptions == null) {
    return spannerConfig;
  }

  final Credentials pipelineCredentials =
      pipelineOptions.as(GcpOptions.class).getGcpCredential();
  return pipelineCredentials == null
      ? spannerConfig
      : spannerConfig.withCredentials(pipelineCredentials);
}

private static Dialect getDialect(SpannerConfig spannerConfig) {
DatabaseClient databaseClient = SpannerAccessor.getOrCreate(spannerConfig).getDatabaseClient();
/**
 * Reads the dialect of the database described by {@code spannerConfig}.
 *
 * <p>If the config has no credentials of its own, the credential from {@code pipelineOptions} is
 * applied first via {@code buildSpannerConfigWithCredential}.
 *
 * @param spannerConfig connection settings for the database to inspect
 * @param pipelineOptions pipeline options used as a fallback credential source; may be null
 * @return the dialect reported by the database client
 */
private static Dialect getDialect(SpannerConfig spannerConfig, PipelineOptions pipelineOptions) {
  // Allow passing the credential from pipeline options to the getDialect() call.
  SpannerConfig spannerConfigWithCredential =
      buildSpannerConfigWithCredential(spannerConfig, pipelineOptions);
  // The accessor is needed only for this single lookup. Closing it balances the getOrCreate()
  // acquisition, which was previously never released.
  try (SpannerAccessor spannerAccessor =
      SpannerAccessor.getOrCreate(spannerConfigWithCredential)) {
    return spannerAccessor.getDatabaseClient().getDialect();
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.google.api.gax.retrying.RetrySettings;
import com.google.api.gax.rpc.StatusCode.Code;
import com.google.auth.Credentials;
import com.google.cloud.spanner.Options.RpcPriority;
import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig;
import org.apache.beam.sdk.options.ValueProvider;
Expand Down Expand Up @@ -113,6 +114,11 @@ public static SpannerConfig create(
config = config.withRpcPriority(StaticValueProvider.of(rpcPriority.get()));
}

ValueProvider<Credentials> credentials = primaryConfig.getCredentials();
if (credentials != null) {
config = config.withCredentials(credentials);
}

return config;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
*/
package org.apache.beam.sdk.io.gcp.spanner;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import com.google.cloud.spanner.DatabaseId;
import com.google.cloud.spanner.SpannerOptions;
import org.apache.beam.sdk.extensions.gcp.auth.TestCredential;
import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -141,4 +144,24 @@ public void testCreateWithEmptyDatabaseRole() {
.getDatabaseClient(DatabaseId.of("project", "test1", "test1"));
verify(serviceFactory.mockSpanner(), times(1)).close();
}

/**
 * Checks that project id, database role and credentials set on a {@link SpannerConfig} are all
 * carried into the {@link SpannerOptions} built by {@code SpannerAccessor.buildSpannerOptions}.
 */
@Test
public void testBuildSpannerOptionsWithCredential() {
  final TestCredential expectedCredential = new TestCredential();

  final SpannerConfig config =
      SpannerConfig.create()
          .toBuilder()
          .setServiceFactory(serviceFactory)
          .setProjectId(StaticValueProvider.of("project"))
          .setInstanceId(StaticValueProvider.of("test-instance"))
          .setDatabaseId(StaticValueProvider.of("test-db"))
          .setDatabaseRole(StaticValueProvider.of("test-role"))
          .setCredentials(StaticValueProvider.of(expectedCredential))
          .build();

  final SpannerOptions builtOptions = SpannerAccessor.buildSpannerOptions(config);

  assertEquals("project", builtOptions.getProjectId());
  assertEquals("test-role", builtOptions.getDatabaseRole());
  assertEquals(expectedCredential, builtOptions.getCredentials());
}
}
Loading

0 comments on commit c5e1d10

Please sign in to comment.