From 0f517a91b0d8702bb5a8e20fb42597cce48f4d38 Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Tue, 19 Nov 2024 14:32:05 -0800 Subject: [PATCH] plumb mtls endpoint to grpc channel provider. --- .../InstantiatingGrpcChannelProvider.java | 38 ++++++++++++++++++- .../api/gax/grpc/GrpcLongRunningTest.java | 2 + .../grpc/testing/LocalChannelProvider.java | 10 +++++ .../InstantiatingHttpJsonChannelProvider.java | 10 +++++ .../com/google/api/gax/rpc/ClientContext.java | 4 ++ .../google/api/gax/rpc/EndpointContext.java | 4 -- .../rpc/FixedTransportChannelProvider.java | 11 ++++++ .../api/gax/rpc/TransportChannelProvider.java | 10 +++++ .../google/api/gax/rpc/ClientContextTest.java | 16 ++++++++ 9 files changed, 99 insertions(+), 6 deletions(-) diff --git a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java index 8cad9f0383..eeff4f34e4 100644 --- a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java +++ b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java @@ -123,6 +123,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP private final HeaderProvider headerProvider; private final boolean useS2A; private final String endpoint; + private final String mtlsEndpoint; // TODO: remove. envProvider currently provides DirectPath environment variable, and is only used // during initial rollout for DirectPath. This provider will be removed once the DirectPath // environment is not used. @@ -152,6 +153,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) { this.executor = builder.executor; this.headerProvider = builder.headerProvider; this.endpoint = builder.endpoint; + this.mtlsEndpoint = builder.mtlsEndpoint; this.useS2A = builder.useS2A; this.mtlsProvider = builder.mtlsProvider; this.s2aConfigProvider = builder.s2aConfigProvider; @@ -229,6 +231,11 @@ public boolean needsEndpoint() { return endpoint == null; } + @Override + public boolean needsMtlsEndpoint() { + return mtlsEndpoint == null; + } + /** * Specify the endpoint the channel should connect to. * @@ -243,6 +250,21 @@ public TransportChannelProvider withEndpoint(String endpoint) { return toBuilder().setEndpoint(endpoint).build(); } + /** + * Specify the mtlsEndpoint the channel should connect to. + * + *

The value of {@code mtlsEndpoint} must be of the form {@code host:port}. + * + * @param mtlsEndpoint The mtlsEndpoint to connect to + * @return A new {@link InstantiatingGrpcChannelProvider} with the specified mtlsEndpoint + * configured + */ + @Override + public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) { + validateEndpoint(mtlsEndpoint); + return toBuilder().setMtlsEndpoint(mtlsEndpoint).build(); + } + /** * Specify whether or not to use S2A. * @@ -590,8 +612,7 @@ private ManagedChannel createSingleChannel() throws IOException { } if (channelCredentials != null) { // Create the channel using S2A-secured channel credentials. - // {@code endpoint} is set to mtlsEndpoint in {@link EndpointContext} when useS2A is true. - builder = Grpc.newChannelBuilder(endpoint, channelCredentials); + builder = Grpc.newChannelBuilder(mtlsEndpoint, channelCredentials); } else { // Use default if we cannot initialize channel credentials via DCA or S2A. builder = ManagedChannelBuilder.forAddress(serviceAddress, port); @@ -743,6 +764,7 @@ public static final class Builder { private Executor executor; private HeaderProvider headerProvider; private String endpoint; + private String mtlsEndpoint; private boolean useS2A; private EnvironmentProvider envProvider; private SecureSessionAgent s2aConfigProvider = SecureSessionAgent.create(); @@ -773,6 +795,7 @@ private Builder(InstantiatingGrpcChannelProvider provider) { this.executor = provider.executor; this.headerProvider = provider.headerProvider; this.endpoint = provider.endpoint; + this.mtlsEndpoint = provider.mtlsEndpoint; this.useS2A = provider.useS2A; this.envProvider = provider.envProvider; this.interceptorProvider = provider.interceptorProvider; @@ -843,6 +866,13 @@ public Builder setEndpoint(String endpoint) { return this; } + /** Sets the mtlsEndpoint used to reach the service, eg "localhost:8080". */ + public Builder setMtlsEndpoint(String mtlsEndpoint) { + validateEndpoint(mtlsEndpoint); + this.mtlsEndpoint = mtlsEndpoint; + return this; + } + Builder setUseS2A(boolean useS2A) { this.useS2A = useS2A; return this; @@ -876,6 +906,10 @@ public String getEndpoint() { return endpoint; } + public String getMtlsEndpoint() { + return mtlsEndpoint; + } + /** The maximum message size allowed to be received on the channel. */ public Builder setMaxInboundMessageSize(Integer max) { this.maxInboundMessageSize = max; diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcLongRunningTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcLongRunningTest.java index ac88e4acec..f0fc4278c3 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcLongRunningTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcLongRunningTest.java @@ -103,6 +103,8 @@ void setUp() throws IOException { when(operationsChannelProvider.getTransportChannel()).thenReturn(transportChannel); when(operationsChannelProvider.withUseS2A(Mockito.any(boolean.class))) .thenReturn(operationsChannelProvider); + when(operationsChannelProvider.withMtlsEndpoint(Mockito.any(String.class))) + .thenReturn(operationsChannelProvider); clock = new FakeApiClock(0L); executor = RecordingScheduler.create(clock); diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/testing/LocalChannelProvider.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/testing/LocalChannelProvider.java index 856a2850bb..61c74c3608 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/testing/LocalChannelProvider.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/testing/LocalChannelProvider.java @@ -101,11 +101,21 @@ public boolean needsEndpoint() { return false; } + @Override + public boolean needsMtlsEndpoint() { + return false; + } + @Override public TransportChannelProvider withEndpoint(String endpoint) { throw new UnsupportedOperationException("LocalChannelProvider doesn't need an endpoint"); } + @Override + public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) { + throw new UnsupportedOperationException("LocalChannelProvider doesn't need an mtlsEndpoint"); + } + @Override public TransportChannelProvider withUseS2A(boolean useS2A) { // Overriden for technical reasons. This method is a no-op for LocalChannelProvider. diff --git a/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java b/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java index 170b955c2a..7dd5dac613 100644 --- a/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java +++ b/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java @@ -119,11 +119,21 @@ public boolean needsEndpoint() { return endpoint == null; } + @Override + public boolean needsMtlsEndpoint() { + return false; + } + @Override public TransportChannelProvider withEndpoint(String endpoint) { return toBuilder().setEndpoint(endpoint).build(); } + @Override + public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) { + return this; + } + @Override public TransportChannelProvider withUseS2A(boolean useS2A) { return this; diff --git a/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientContext.java b/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientContext.java index 8e7c9a3090..2b10c0ae45 100644 --- a/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientContext.java +++ b/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientContext.java @@ -222,6 +222,10 @@ public static ClientContext create(StubSettings settings) throws IOException { if (transportChannelProvider.needsEndpoint()) { transportChannelProvider = transportChannelProvider.withEndpoint(endpoint); } + if (transportChannelProvider.needsMtlsEndpoint()) { + transportChannelProvider = + transportChannelProvider.withMtlsEndpoint(endpointContext.mtlsEndpoint()); + } transportChannelProvider = transportChannelProvider.withUseS2A(endpointContext.useS2A()); TransportChannel transportChannel = transportChannelProvider.getTransportChannel(); diff --git a/gax-java/gax/src/main/java/com/google/api/gax/rpc/EndpointContext.java b/gax-java/gax/src/main/java/com/google/api/gax/rpc/EndpointContext.java index 0148c07a01..9b89082df7 100644 --- a/gax-java/gax/src/main/java/com/google/api/gax/rpc/EndpointContext.java +++ b/gax-java/gax/src/main/java/com/google/api/gax/rpc/EndpointContext.java @@ -272,10 +272,6 @@ private String determineUniverseDomain() { /** Determines the fully resolved endpoint and universe domain values */ private String determineEndpoint() throws IOException { - if (shouldUseS2A()) { - return mtlsEndpoint(); - } - MtlsProvider mtlsProvider = mtlsProvider() == null ? new MtlsProvider() : mtlsProvider(); // TransportChannelProvider's endpoint will override the ClientSettings' endpoint String customEndpoint = diff --git a/gax-java/gax/src/main/java/com/google/api/gax/rpc/FixedTransportChannelProvider.java b/gax-java/gax/src/main/java/com/google/api/gax/rpc/FixedTransportChannelProvider.java index 2f70c06b5f..1d0b54f75d 100644 --- a/gax-java/gax/src/main/java/com/google/api/gax/rpc/FixedTransportChannelProvider.java +++ b/gax-java/gax/src/main/java/com/google/api/gax/rpc/FixedTransportChannelProvider.java @@ -83,12 +83,23 @@ public boolean needsEndpoint() { return false; } + @Override + public boolean needsMtlsEndpoint() { + return false; + } + @Override public TransportChannelProvider withEndpoint(String endpoint) { throw new UnsupportedOperationException( "FixedTransportChannelProvider doesn't need an endpoint"); } + @Override + public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) { + throw new UnsupportedOperationException( + "FixedTransportChannelProvider doesn't need an mtlsEndpoint"); + } + @Override public TransportChannelProvider withUseS2A(boolean useS2A) throws UnsupportedOperationException { // Overriden for technical reasons. This method is a no-op for FixedTransportChannelProvider. diff --git a/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannelProvider.java b/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannelProvider.java index f58acffc54..d59cbfa0e6 100644 --- a/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannelProvider.java +++ b/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannelProvider.java @@ -90,6 +90,9 @@ public interface TransportChannelProvider { /** True if the TransportProvider has no endpoint set. */ boolean needsEndpoint(); + /** True if the TransportProvider has no mtlsEndpoint set. */ + boolean needsMtlsEndpoint(); + /** * Sets the endpoint to use when constructing a new {@link TransportChannel}. * @@ -97,6 +100,13 @@ public interface TransportChannelProvider { */ TransportChannelProvider withEndpoint(String endpoint); + /** + * Sets the mtlsEndpoint to use when constructing a new {@link TransportChannel}. + * + *

This method should only be called if {@link #needsMtlsEndpoint()} returns true. + */ + TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint); + /** Sets whether to use S2A when constructing a new {@link TransportChannel}. */ default TransportChannelProvider withUseS2A(boolean useS2A) { throw new UnsupportedOperationException("S2A is not supported"); diff --git a/gax-java/gax/src/test/java/com/google/api/gax/rpc/ClientContextTest.java b/gax-java/gax/src/test/java/com/google/api/gax/rpc/ClientContextTest.java index facc93ed86..46ba62079e 100644 --- a/gax-java/gax/src/test/java/com/google/api/gax/rpc/ClientContextTest.java +++ b/gax-java/gax/src/test/java/com/google/api/gax/rpc/ClientContextTest.java @@ -179,6 +179,11 @@ public boolean needsEndpoint() { return true; } + @Override + public boolean needsMtlsEndpoint() { + return false; + } + @Override public String getEndpoint() { return endpoint; @@ -195,6 +200,17 @@ public TransportChannelProvider withEndpoint(String endpoint) { endpoint); } + @Override + public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) { + return new FakeTransportProvider( + this.transport, + this.executor, + this.shouldAutoClose, + this.headers, + this.credentials, + this.endpoint); + } + @Override public TransportChannelProvider withUseS2A(boolean useS2A) { return new FakeTransportProvider(