diff --git a/docs/source/java/flight_sql_jdbc_driver.rst b/docs/source/java/flight_sql_jdbc_driver.rst index 34ccfea47f9e3..0ace2185983a9 100644 --- a/docs/source/java/flight_sql_jdbc_driver.rst +++ b/docs/source/java/flight_sql_jdbc_driver.rst @@ -141,6 +141,17 @@ case-sensitive. The supported parameters are: - true - When TLS is enabled, whether to use the system certificate store + * - retainCookies + - true + - Whether to use cookies from the initial connection in subsequent + internal connections when retrieving streams from separate endpoints. + + * - retainAuth + - true + - Whether to use bearer tokens obtained from the initial connection + in subsequent internal connections used for retrieving streams + from separate endpoints. + Note that URI values must be URI-encoded if they contain characters such as !, @, $, etc. diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientIncomingAuthHeaderMiddleware.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientIncomingAuthHeaderMiddleware.java index be5f3f54d326c..7bb55d145d104 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientIncomingAuthHeaderMiddleware.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientIncomingAuthHeaderMiddleware.java @@ -34,7 +34,7 @@ public class ClientIncomingAuthHeaderMiddleware implements FlightClientMiddlewar */ public static class Factory implements FlightClientMiddleware.Factory { private final ClientHeaderHandler headerHandler; - private CredentialCallOption credentialCallOption; + private CredentialCallOption credentialCallOption = null; /** * Construct a factory with the given header handler. diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java index fdbb9381c0a55..ad19c616ff29a 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java @@ -109,6 +109,8 @@ private static ArrowFlightSqlClientHandler createNewClientHandler( .withDisableCertificateVerification(config.getDisableCertificateVerification()) .withToken(config.getToken()) .withCallOptions(config.toCallOption()) + .withRetainCookies(config.retainCookies()) + .withRetainAuth(config.retainAuth()) .build(); } catch (final SQLException e) { try { diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java index 75e80d45dc669..d6bcfce4e9ddb 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java @@ -49,6 +49,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.util.Preconditions; +import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.calcite.avatica.Meta.StatementType; @@ -425,21 +426,59 @@ public static final class Builder { private final Set options = new HashSet<>(); private String host; private int port; - private String username; - private String password; - private String trustStorePath; - private String trustStorePassword; - private String token; - private boolean useEncryption; - private boolean disableCertificateVerification; - private boolean useSystemTrustStore; - private String tlsRootCertificatesPath; - private String clientCertificatePath; - private String clientKeyPath; + + @VisibleForTesting + String username; + + @VisibleForTesting + String password; + + @VisibleForTesting + String trustStorePath; + + @VisibleForTesting + String trustStorePassword; + + @VisibleForTesting + String token; + + @VisibleForTesting + boolean useEncryption = true; + + @VisibleForTesting + boolean disableCertificateVerification; + + @VisibleForTesting + boolean useSystemTrustStore = true; + + @VisibleForTesting + String tlsRootCertificatesPath; + + @VisibleForTesting + String clientCertificatePath; + + @VisibleForTesting + String clientKeyPath; + + @VisibleForTesting private BufferAllocator allocator; - public Builder() { + @VisibleForTesting + boolean retainCookies = true; + + @VisibleForTesting + boolean retainAuth = true; + + // These two middlewares are for internal use within build() and should not be exposed by builder APIs. + // Note that these middlewares may not necessarily be registered. + @VisibleForTesting + ClientIncomingAuthHeaderMiddleware.Factory authFactory + = new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler()); + @VisibleForTesting + ClientCookieMiddleware.Factory cookieFactory = new ClientCookieMiddleware.Factory(); + + public Builder() { } /** @@ -447,7 +486,8 @@ public Builder() { * * @param original The builder to base this copy off of. */ - private Builder(Builder original) { + @VisibleForTesting + Builder(Builder original) { this.middlewareFactories.addAll(original.middlewareFactories); this.options.addAll(original.options); this.host = original.host; @@ -464,6 +504,14 @@ private Builder(Builder original) { this.clientCertificatePath = original.clientCertificatePath; this.clientKeyPath = original.clientKeyPath; this.allocator = original.allocator; + + if (original.retainCookies) { + this.cookieFactory = original.cookieFactory; + } + + if (original.retainAuth) { + this.authFactory = original.authFactory; + } } /** @@ -622,6 +670,28 @@ public Builder withBufferAllocator(final BufferAllocator allocator) { return this; } + /** + * Indicates if cookies should be re-used by connections spawned for getStreams() calls. + * @param retainCookies The flag indicating if cookies should be re-used. + * @return this builder instance. + */ + public Builder withRetainCookies(boolean retainCookies) { + this.retainCookies = retainCookies; + return this; + } + + /** + * Indicates if bearer tokens negotiated should be re-used by connections + * spawned for getStreams() calls. + * + * @param retainAuth The flag indicating if auth tokens should be re-used. + * @return this builder instance. + */ + public Builder withRetainAuth(boolean retainAuth) { + this.retainAuth = retainAuth; + return this; + } + /** * Adds the provided {@code factories} to the list of {@link #middlewareFactories} of this handler. * @@ -675,13 +745,11 @@ public ArrowFlightSqlClientHandler build() throws SQLException { // Copy middlewares so that the build method doesn't change the state of the builder fields itself. Set buildTimeMiddlewareFactories = new HashSet<>(this.middlewareFactories); FlightClient client = null; + boolean isUsingUserPasswordAuth = username != null && token == null; try { - ClientIncomingAuthHeaderMiddleware.Factory authFactory = null; // Token should take priority since some apps pass in a username/password even when a token is provided - if (username != null && token == null) { - authFactory = - new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler()); + if (isUsingUserPasswordAuth) { buildTimeMiddlewareFactories.add(authFactory); } final FlightClient.Builder clientBuilder = FlightClient.builder().allocator(allocator); @@ -722,10 +790,17 @@ public ArrowFlightSqlClientHandler build() throws SQLException { client = clientBuilder.build(); final ArrayList credentialOptions = new ArrayList<>(); - if (authFactory != null) { - credentialOptions.add( - ClientAuthenticationUtils.getAuthenticate( - client, username, password, authFactory, options.toArray(new CallOption[0]))); + if (username != null && token == null) { + // If the authFactory has already been used for a handshake, use the existing token. + // This can occur if the authFactory is being re-used for a new connection spawned for getStream(). + if (authFactory.getCredentialCallOption() != null) { + credentialOptions.add(authFactory.getCredentialCallOption()); + } else { + // Otherwise do the handshake and get the token if possible. + credentialOptions.add( + ClientAuthenticationUtils.getAuthenticate( + client, username, password, authFactory, options.toArray(new CallOption[0]))); + } } else if (token != null) { credentialOptions.add( ClientAuthenticationUtils.getAuthenticate( diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java index 59118e1d6f788..6237a8b58d68a 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java @@ -143,6 +143,22 @@ public int threadPoolSize() { return ArrowFlightConnectionProperty.THREAD_POOL_SIZE.getInteger(properties); } + /** + * Indicates if sub-connections created for stream retrieval + * should reuse cookies from the main connection. + */ + public boolean retainCookies() { + return ArrowFlightConnectionProperty.RETAIN_COOKIES.getBoolean(properties); + } + + /** + * Indicates if sub-connections created for stream retrieval + * should reuse bearer tokens created from the main connection. + */ + public boolean retainAuth() { + return ArrowFlightConnectionProperty.RETAIN_AUTH.getBoolean(properties); + } + /** * Gets the {@link CallOption}s from this {@link ConnectionConfig}. * @@ -191,7 +207,9 @@ public enum ArrowFlightConnectionProperty implements ConnectionProperty { CLIENT_CERTIFICATE("clientCertificate", null, Type.STRING, false), CLIENT_KEY("clientKey", null, Type.STRING, false), THREAD_POOL_SIZE("threadPoolSize", 1, Type.NUMBER, false), - TOKEN("token", null, Type.STRING, false); + TOKEN("token", null, Type.STRING, false), + RETAIN_COOKIES("retainCookies", true, Type.BOOLEAN, false), + RETAIN_AUTH("retainAuth", true, Type.BOOLEAN, false); private final String camelName; private final Object defaultValue; diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTest.java index bec0ff1e59752..ff0b528b3bf9d 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTest.java @@ -161,6 +161,7 @@ public void testGetBasicClientAuthenticatedShouldOpenConnection() new ArrowFlightSqlClientHandler.Builder() .withHost(FLIGHT_SERVER_TEST_RULE.getHost()) .withPort(FLIGHT_SERVER_TEST_RULE.getPort()) + .withEncryption(false) .withUsername(userTest) .withPassword(passTest) .withBufferAllocator(allocator) diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTlsTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTlsTest.java index 95d591766a836..b8b9a3809b35d 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTlsTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTlsTest.java @@ -127,6 +127,7 @@ public void testGetEncryptedClientAuthenticated() throws Exception { new ArrowFlightSqlClientHandler.Builder() .withHost(FLIGHT_SERVER_TEST_RULE.getHost()) .withPort(FLIGHT_SERVER_TEST_RULE.getPort()) + .withSystemTrustStore(false) .withUsername(credentials.getUserName()) .withPassword(credentials.getPassword()) .withTrustStorePath(trustStorePath) @@ -153,6 +154,7 @@ public void testGetEncryptedClientWithNoCertificateOnKeyStore() throws Exception .withHost(FLIGHT_SERVER_TEST_RULE.getHost()) .withTrustStorePath(noCertificateKeyStorePath) .withTrustStorePassword(noCertificateKeyStorePassword) + .withSystemTrustStore(false) .withBufferAllocator(allocator) .withEncryption(true) .build()) { @@ -170,6 +172,7 @@ public void testGetNonAuthenticatedEncryptedClientNoAuth() throws Exception { try (ArrowFlightSqlClientHandler client = new ArrowFlightSqlClientHandler.Builder() .withHost(FLIGHT_SERVER_TEST_RULE.getHost()) + .withSystemTrustStore(false) .withTrustStorePath(trustStorePath) .withTrustStorePassword(trustStorePass) .withBufferAllocator(allocator) @@ -192,6 +195,7 @@ public void testGetEncryptedClientWithKeyStoreBadPasswordAndNoAuth() throws Exce try (ArrowFlightSqlClientHandler ignored = new ArrowFlightSqlClientHandler.Builder() .withHost(FLIGHT_SERVER_TEST_RULE.getHost()) + .withSystemTrustStore(false) .withTrustStorePath(trustStorePath) .withTrustStorePassword(keyStoreBadPassword) .withBufferAllocator(allocator) diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/FlightServerTestRule.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/FlightServerTestRule.java index df7cbea56ee2f..39eb0a29866f1 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/FlightServerTestRule.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/FlightServerTestRule.java @@ -55,6 +55,9 @@ * and interact with it. */ public class FlightServerTestRule implements TestRule, AutoCloseable { + public static final String DEFAULT_USER = "flight-test-user"; + public static final String DEFAULT_PASSWORD = "flight-test-password"; + private static final Logger LOGGER = LoggerFactory.getLogger(FlightServerTestRule.class); private final Properties properties; @@ -92,7 +95,7 @@ private FlightServerTestRule(final Properties properties, public static FlightServerTestRule createStandardTestRule(final FlightSqlProducer producer) { UserPasswordAuthentication authentication = new UserPasswordAuthentication.Builder() - .user("flight-test-user", "flight-test-password") + .user(DEFAULT_USER, DEFAULT_PASSWORD) .build(); return new Builder() diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java new file mode 100644 index 0000000000000..6565a85ddf99f --- /dev/null +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.client; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +import org.apache.arrow.driver.jdbc.FlightServerTestRule; +import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; + +/** + * Test the behavior of ArrowFlightSqlClientHandler.Builder + */ +public class ArrowFlightSqlClientHandlerBuilderTest { + @ClassRule + public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE = FlightServerTestRule + .createStandardTestRule(CoreMockedSqlProducers.getLegacyProducer()); + + private static BufferAllocator allocator; + + @BeforeClass + public static void setup() { + allocator = new RootAllocator(Long.MAX_VALUE); + } + + @AfterClass + public static void tearDown() { + allocator.close(); + } + + @Test + public void testRetainCookiesOnAuthOff() throws Exception { + // Arrange + final ArrowFlightSqlClientHandler.Builder rootBuilder = new ArrowFlightSqlClientHandler.Builder() + .withHost(FLIGHT_SERVER_TEST_RULE.getHost()) + .withPort(FLIGHT_SERVER_TEST_RULE.getPort()) + .withBufferAllocator(allocator) + .withUsername(FlightServerTestRule.DEFAULT_USER) + .withPassword(FlightServerTestRule.DEFAULT_PASSWORD) + .withEncryption(false) + .withRetainCookies(true) + .withRetainAuth(false); + + try (ArrowFlightSqlClientHandler rootHandler = rootBuilder.build()) { + // Act + final ArrowFlightSqlClientHandler.Builder testBuilder = new ArrowFlightSqlClientHandler.Builder(rootBuilder); + + // Assert + assertSame(rootBuilder.cookieFactory, testBuilder.cookieFactory); + assertNotSame(rootBuilder.authFactory, testBuilder.authFactory); + } + } + + @Test + public void testRetainCookiesOffAuthOff() throws Exception { + // Arrange + final ArrowFlightSqlClientHandler.Builder rootBuilder = new ArrowFlightSqlClientHandler.Builder() + .withHost(FLIGHT_SERVER_TEST_RULE.getHost()) + .withPort(FLIGHT_SERVER_TEST_RULE.getPort()) + .withBufferAllocator(allocator) + .withUsername(FlightServerTestRule.DEFAULT_USER) + .withPassword(FlightServerTestRule.DEFAULT_PASSWORD) + .withEncryption(false) + .withRetainCookies(false) + .withRetainAuth(false); + + try (ArrowFlightSqlClientHandler rootHandler = rootBuilder.build()) { + // Act + final ArrowFlightSqlClientHandler.Builder testBuilder = new ArrowFlightSqlClientHandler.Builder(rootBuilder); + + // Assert + assertNotSame(rootBuilder.cookieFactory, testBuilder.cookieFactory); + assertNotSame(rootBuilder.authFactory, testBuilder.authFactory); + } + } + + @Test + public void testRetainCookiesOnAuthOn() throws Exception { + // Arrange + final ArrowFlightSqlClientHandler.Builder rootBuilder = new ArrowFlightSqlClientHandler.Builder() + .withHost(FLIGHT_SERVER_TEST_RULE.getHost()) + .withPort(FLIGHT_SERVER_TEST_RULE.getPort()) + .withBufferAllocator(allocator) + .withUsername(FlightServerTestRule.DEFAULT_USER) + .withPassword(FlightServerTestRule.DEFAULT_PASSWORD) + .withEncryption(false) + .withRetainCookies(true) + .withRetainAuth(true); + + try (ArrowFlightSqlClientHandler rootHandler = rootBuilder.build()) { + // Act + final ArrowFlightSqlClientHandler.Builder testBuilder = new ArrowFlightSqlClientHandler.Builder(rootBuilder); + + // Assert + assertSame(rootBuilder.cookieFactory, testBuilder.cookieFactory); + assertSame(rootBuilder.authFactory, testBuilder.authFactory); + } + } + + @Test + public void testDefaults() { + final ArrowFlightSqlClientHandler.Builder builder = new ArrowFlightSqlClientHandler.Builder(); + + // Validate all non-mandatory fields against defaults in ArrowFlightConnectionProperty. + assertNull(builder.username); + assertNull(builder.password); + assertTrue(builder.useEncryption); + assertFalse(builder.disableCertificateVerification); + assertNull(builder.trustStorePath); + assertNull(builder.trustStorePassword); + assertTrue(builder.useSystemTrustStore); + assertNull(builder.token); + assertTrue(builder.retainAuth); + assertTrue(builder.retainCookies); + assertNull(builder.tlsRootCertificatesPath); + assertNull(builder.clientCertificatePath); + assertNull(builder.clientKeyPath); + } +}