Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: Add test proxy implementation for ExecuteQuery api #2360

Merged
merged 3 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
import com.google.api.gax.retrying.RetrySettings;
import com.google.api.gax.rpc.ApiException;
import com.google.api.gax.rpc.ServerStream;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.OAuth2Credentials;
import com.google.auto.value.AutoValue;
import com.google.bigtable.v2.Column;
import com.google.bigtable.v2.Family;
Expand All @@ -41,6 +42,7 @@
import com.google.cloud.bigtable.data.v2.models.ReadModifyWriteRow;
import com.google.cloud.bigtable.data.v2.models.RowCell;
import com.google.cloud.bigtable.data.v2.models.RowMutation;
import com.google.cloud.bigtable.data.v2.models.sql.ResultSet;
import com.google.cloud.bigtable.data.v2.stub.EnhancedBigtableStubSettings;
import com.google.cloud.bigtable.testproxy.CloudBigtableV2TestProxyGrpc.CloudBigtableV2TestProxyImplBase;
import com.google.common.base.Preconditions;
Expand All @@ -50,26 +52,24 @@
import io.grpc.ManagedChannelBuilder;
import io.grpc.Status;
import io.grpc.StatusException;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;
import io.grpc.stub.StreamObserver;
import java.io.ByteArrayInputStream;
import java.io.Closeable;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.logging.Logger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.threeten.bp.Duration;

/** Java implementation of the CBT test proxy. Used to test the Java CBT client. */
Expand All @@ -92,50 +92,13 @@ static CbtClient create(BigtableDataSettings settings, BigtableDataClient dataCl

private static final Logger logger = Logger.getLogger(CbtTestProxy.class.getName());

private CbtTestProxy(
boolean encrypted,
@Nullable String rootCerts,
@Nullable String sslTarget,
@Nullable String credential) {
this.encrypted = encrypted;
this.rootCerts = rootCerts;
this.sslTarget = sslTarget;
this.credential = credential;
private CbtTestProxy() {
this.idClientMap = new ConcurrentHashMap<>();
}

/**
* Factory method to return a proxy instance that interacts with server unencrypted and
* unauthenticated.
*/
public static CbtTestProxy createUnencrypted() {
return new CbtTestProxy(false, null, null, null);
}

/**
* Factory method to return a proxy instance that interacts with server encrypted. Default
* authority and public certificates are used if null values are passed in.
*
* @param rootCertsPemPath The path to a root certificate PEM file
* @param sslTarget The override of SSL target name
* @param credentialJsonPath The path to a credential JSON file
*/
public static CbtTestProxy createEncrypted(
@Nullable String rootCertsPemPath,
@Nullable String sslTarget,
@Nullable String credentialJsonPath)
throws IOException {
String tmpRootCerts = null, tmpCredential = null;
if (rootCertsPemPath != null) {
Path file = Paths.get(rootCertsPemPath);
tmpRootCerts = new String(Files.readAllBytes(file), UTF_8);
}
if (credentialJsonPath != null) {
Path file = Paths.get(credentialJsonPath);
tmpCredential = new String(Files.readAllBytes(file), UTF_8);
}

return new CbtTestProxy(true, tmpRootCerts, sslTarget, tmpCredential);
/** Factory method to return a proxy instance. */
public static CbtTestProxy create() {
return new CbtTestProxy();
}

/**
Expand All @@ -159,6 +122,8 @@ private static BigtableDataSettings.Builder overrideTimeoutSetting(
settingsBuilder.stubSettings().readModifyWriteRowSettings().retrySettings(), newTimeout);
updateTimeout(
settingsBuilder.stubSettings().sampleRowKeysSettings().retrySettings(), newTimeout);
updateTimeout(
settingsBuilder.stubSettings().executeQuerySettings().retrySettings(), newTimeout);

return settingsBuilder;
}
Expand Down Expand Up @@ -187,19 +152,26 @@ private CbtClient getClient(String id) throws StatusException {
@Override
public synchronized void createClient(
CreateClientRequest request, StreamObserver<CreateClientResponse> responseObserver) {

Preconditions.checkArgument(!request.getClientId().isEmpty(), "client id must be provided");
Preconditions.checkArgument(!request.getProjectId().isEmpty(), "project id must be provided");
Preconditions.checkArgument(!request.getInstanceId().isEmpty(), "instance id must be provided");
Preconditions.checkArgument(!request.getDataTarget().isEmpty(), "data target must be provided");
Preconditions.checkArgument(
!request.getSecurityOptions().getUseSsl()
|| !request.getSecurityOptions().getSslRootCertsPemBytes().isEmpty(),
"security_options.ssl_root_certs_pem must be provided if security_options.use_ssl is true");

if (idClientMap.contains(request.getClientId())) {
if (idClientMap.containsKey(request.getClientId())) {
responseObserver.onError(
Status.ALREADY_EXISTS
.withDescription("Client " + request.getClientId() + " already exists.")
.asException());
return;
}

// setRefreshingChannel is needed for now.
@SuppressWarnings("deprecation")
BigtableDataSettings.Builder settingsBuilder =
BigtableDataSettings.newBuilder()
// Disable channel refreshing when not using the real server
Expand All @@ -208,9 +180,6 @@ public synchronized void createClient(
.setInstanceId(request.getInstanceId())
.setAppProfileId(request.getAppProfileId());

settingsBuilder.stubSettings().setEnableRoutingCookie(false);
settingsBuilder.stubSettings().setEnableRetryInfo(false);

if (request.hasPerOperationTimeout()) {
Duration newTimeout = Duration.ofMillis(Durations.toMillis(request.getPerOperationTimeout()));
settingsBuilder = overrideTimeoutSetting(newTimeout, settingsBuilder);
Expand Down Expand Up @@ -244,8 +213,13 @@ public synchronized void createClient(
settingsBuilder
.stubSettings()
.setEndpoint(request.getDataTarget())
.setTransportChannelProvider(getTransportChannel())
.setCredentialsProvider(getCredentialsProvider());
.setTransportChannelProvider(
getTransportChannel(
request.getSecurityOptions().getUseSsl(),
request.getSecurityOptions().getSslRootCertsPem(),
request.getSecurityOptions().getSslEndpointOverride()))
.setCredentialsProvider(
getCredentialsProvider(request.getSecurityOptions().getAccessToken()));
}
BigtableDataSettings settings = settingsBuilder.build();
BigtableDataClient client = BigtableDataClient.create(settings);
Expand Down Expand Up @@ -698,6 +672,64 @@ public void readModifyWriteRow(
responseObserver.onCompleted();
}

@Override
public void executeQuery(
ExecuteQueryRequest request, StreamObserver<ExecuteQueryResult> responseObserver) {
CbtClient client;
try {
client = getClient(request.getClientId());
} catch (StatusException e) {
responseObserver.onError(e);
return;
}
try (ResultSet resultSet =
client.dataClient().executeQuery(StatementDeserializer.toStatement(request))) {
responseObserver.onNext(ResultSetSerializer.toExecuteQueryResult(resultSet));
} catch (InterruptedException e) {
responseObserver.onError(e);
return;
} catch (ExecutionException e) {
responseObserver.onError(e);
return;
} catch (ApiException e) {
responseObserver.onNext(
ExecuteQueryResult.newBuilder()
.setStatus(
com.google.rpc.Status.newBuilder()
.setCode(e.getStatusCode().getCode().ordinal())
.setMessage(e.getMessage())
.build())
.build());
responseObserver.onCompleted();
return;
} catch (StatusRuntimeException e) {
responseObserver.onNext(
ExecuteQueryResult.newBuilder()
.setStatus(
com.google.rpc.Status.newBuilder()
.setCode(e.getStatus().getCode().value())
.setMessage(e.getStatus().getDescription())
.build())
.build());
responseObserver.onCompleted();
return;
} catch (RuntimeException e) {
// If client encounters problem, don't return any results.
responseObserver.onNext(
ExecuteQueryResult.newBuilder()
.setStatus(
com.google.rpc.Status.newBuilder()
.setCode(Code.INTERNAL.getNumber())
.setMessage(e.getMessage())
.build())
.build());
responseObserver.onCompleted();
return;
}
responseObserver.onCompleted();
return;
}

@Override
public synchronized void close() {
Iterator<Map.Entry<String, CbtClient>> it = idClientMap.entrySet().iterator();
Expand All @@ -717,52 +749,60 @@ private static String extractTableIdFromTableName(String fullTableName)
return matcher.group(3);
}

private InstantiatingGrpcChannelProvider getTransportChannel() throws IOException {
@SuppressWarnings("rawtypes")
private InstantiatingGrpcChannelProvider getTransportChannel(
boolean encrypted, String rootCertsPem, String sslTarget) {
if (!encrypted) {
return EnhancedBigtableStubSettings.defaultGrpcTransportProviderBuilder()
.setChannelConfigurator(ManagedChannelBuilder::usePlaintext)
.build();
}

if (rootCerts == null) {
return EnhancedBigtableStubSettings.defaultGrpcTransportProviderBuilder().build();
final SslContext sslContext;
if (rootCertsPem.isEmpty()) {
sslContext = null;
} else {
try {
sslContext =
GrpcSslContexts.forClient()
.trustManager(new ByteArrayInputStream(rootCertsPem.getBytes(UTF_8)))
.build();
} catch (IOException e) {
throw new IllegalArgumentException(e);
}
}

final SslContext secureContext =
GrpcSslContexts.forClient()
.trustManager(new ByteArrayInputStream(rootCerts.getBytes(UTF_8)))
.build();
return EnhancedBigtableStubSettings.defaultGrpcTransportProviderBuilder()
.setChannelConfigurator(
new ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder>() {
@Override
public ManagedChannelBuilder apply(ManagedChannelBuilder input) {
NettyChannelBuilder channelBuilder = (NettyChannelBuilder) input;
channelBuilder.sslContext(secureContext).overrideAuthority(sslTarget);

if (sslContext != null) {
channelBuilder.sslContext(sslContext);
}

if (!sslTarget.isEmpty()) {
channelBuilder.overrideAuthority(sslTarget);
}

return channelBuilder;
}
})
.build();
}

private CredentialsProvider getCredentialsProvider() throws IOException {
if (credential == null) {
private CredentialsProvider getCredentialsProvider(String accessToken) {
if (accessToken.isEmpty()) {
return NoCredentialsProvider.create();
}

final GoogleCredentials creds =
GoogleCredentials.fromStream(new ByteArrayInputStream(credential.getBytes(UTF_8)));

return FixedCredentialsProvider.create(creds);
return FixedCredentialsProvider.create(
OAuth2Credentials.create(new AccessToken(accessToken, null)));
}

private final ConcurrentHashMap<String, CbtClient> idClientMap;
private final boolean encrypted;

// Parameters that may be needed when "encrypted" is true.
private final String rootCerts;
private final String sslTarget;
private final String credential;

private static final Pattern tablePattern =
Pattern.compile("projects/([^/]+)/instances/([^/]+)/tables/([^/]+)");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,7 @@ public static void main(String[] args) throws InterruptedException, IOException
throw new IllegalArgumentException(String.format("Port %d is not > 0.", port));
}

CbtTestProxy cbtTestProxy;

// If encryption is specified
boolean encrypted = Boolean.getBoolean("encrypted");
if (encrypted) {
String rootCertsPemPath = System.getProperty("root.certs.pem.path");
String sslTarget = System.getProperty("ssl.target");
String credentialJsonPath = System.getProperty("credential.json.path");
cbtTestProxy = CbtTestProxy.createEncrypted(rootCertsPemPath, sslTarget, credentialJsonPath);
} else {
cbtTestProxy = CbtTestProxy.createUnencrypted();
}

CbtTestProxy cbtTestProxy = CbtTestProxy.create();
logger.info(String.format("Test proxy starting on %d", port));
ServerBuilder.forPort(port).addService(cbtTestProxy).build().start().awaitTermination();
}
Expand Down
Loading
Loading