From 845f9c2d98416ae6871680c4afd98d7ff2e51333 Mon Sep 17 00:00:00 2001 From: Bruce Irschick Date: Wed, 4 Dec 2024 10:27:43 -0800 Subject: [PATCH] feat(csharp/src/Drivers/Apache): add connect and query timeout options (#2312) Adds options for command and query timeout | Property | Description | Default | | :--- | :--- | :--- | | `adbc.spark.connect_timeout_ms` | Sets the timeout (in milliseconds) to open a new session. Values can be 0 (infinite) or greater than zero. | `30000` | | `adbc.apache.statement.query_timeout_s` | Sets the maximum time (in seconds) for a query to complete. Values can be 0 (infinite) or greater than zero. | `60` | --------- Co-authored-by: Aman Goyal Co-authored-by: David Coe --- csharp/src/Client/AdbcCommand.cs | 28 +- csharp/src/Drivers/Apache/ApacheParameters.cs | 29 ++ csharp/src/Drivers/Apache/ApacheUtility.cs | 141 ++++++ .../Apache/Hive2/HiveServer2Connection.cs | 106 +++-- .../Apache/Hive2/HiveServer2Parameters.cs | 2 - .../Drivers/Apache/Hive2/HiveServer2Reader.cs | 34 +- .../Apache/Hive2/HiveServer2Statement.cs | 134 ++++-- .../Drivers/Apache/Impala/ImpalaConnection.cs | 6 +- .../Drivers/Apache/Impala/ImpalaStatement.cs | 2 +- csharp/src/Drivers/Apache/Spark/README.md | 15 +- .../Drivers/Apache/Spark/SparkConnection.cs | 438 ++++++++++-------- .../Apache/Spark/SparkDatabricksConnection.cs | 19 +- .../Apache/Spark/SparkDatabricksReader.cs | 1 - .../Apache/Spark/SparkHttpConnection.cs | 69 ++- .../Drivers/Apache/Spark/SparkParameters.cs | 4 +- .../Apache/Spark/SparkStandardConnection.cs | 10 +- .../Drivers/Apache/Spark/SparkStatement.cs | 7 +- .../Drivers/Apache/ApacheTestConfiguration.cs | 9 +- .../test/Drivers/Apache/Common/ClientTests.cs | 22 + .../Drivers/Apache/Common/StatementTests.cs | 124 ++++- .../Apache/Spark/SparkConnectionTest.cs | 236 +++++++++- .../Apache/Spark/SparkTestEnvironment.cs | 13 +- .../Drivers/Apache/Spark/StatementTests.cs | 2 + 23 files changed, 1098 insertions(+), 353 deletions(-) create mode 100644 csharp/src/Drivers/Apache/ApacheParameters.cs create mode 100644 csharp/src/Drivers/Apache/ApacheUtility.cs diff --git a/csharp/src/Client/AdbcCommand.cs b/csharp/src/Client/AdbcCommand.cs index 8b85be2062..c3695feaf4 100644 --- a/csharp/src/Client/AdbcCommand.cs +++ b/csharp/src/Client/AdbcCommand.cs @@ -21,6 +21,7 @@ using System.Data; using System.Data.Common; using System.Data.SqlTypes; +using System.Globalization; using System.Linq; using System.Threading.Tasks; using Apache.Arrow.Types; @@ -32,10 +33,11 @@ namespace Apache.Arrow.Adbc.Client /// public sealed class AdbcCommand : DbCommand { - private AdbcStatement _adbcStatement; + private readonly AdbcStatement _adbcStatement; private AdbcParameterCollection? _dbParameterCollection; private int _timeout = 30; private bool _disposed; + private string? _commandTimeoutProperty; /// /// Overloaded. Initializes . @@ -117,10 +119,32 @@ public override CommandType CommandType } } + + /// + /// Gets or sets the name of the command timeout property for the underlying ADBC driver. + /// + public string AdbcCommandTimeoutProperty + { + get + { + if (string.IsNullOrEmpty(_commandTimeoutProperty)) + throw new InvalidOperationException("CommandTimeoutProperty is not set."); + + return _commandTimeoutProperty!; + } + set => _commandTimeoutProperty = value; + } + public override int CommandTimeout { get => _timeout; - set => _timeout = value; + set + { + // ensures the property exists before setting the CommandTimeout value + string property = AdbcCommandTimeoutProperty; + _adbcStatement.SetOption(property, value.ToString(CultureInfo.InvariantCulture)); + _timeout = value; + } } protected override DbParameterCollection DbParameterCollection diff --git a/csharp/src/Drivers/Apache/ApacheParameters.cs b/csharp/src/Drivers/Apache/ApacheParameters.cs new file mode 100644 index 0000000000..17c94be32a --- /dev/null +++ b/csharp/src/Drivers/Apache/ApacheParameters.cs @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace Apache.Arrow.Adbc.Drivers.Apache +{ + /// + /// Options common to all Apache drivers. + /// + public class ApacheParameters + { + public const string PollTimeMilliseconds = "adbc.apache.statement.polltime_ms"; + public const string BatchSize = "adbc.apache.statement.batch_size"; + public const string QueryTimeoutSeconds = "adbc.apache.statement.query_timeout_s"; + } +} diff --git a/csharp/src/Drivers/Apache/ApacheUtility.cs b/csharp/src/Drivers/Apache/ApacheUtility.cs new file mode 100644 index 0000000000..f1cb07e07f --- /dev/null +++ b/csharp/src/Drivers/Apache/ApacheUtility.cs @@ -0,0 +1,141 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Threading; + +namespace Apache.Arrow.Adbc.Drivers.Apache +{ + internal class ApacheUtility + { + internal const int QueryTimeoutSecondsDefault = 60; + + public enum TimeUnit + { + Seconds, + Milliseconds + } + + public static CancellationToken GetCancellationToken(int timeout, TimeUnit timeUnit) + { + TimeSpan span; + + if (timeout == 0 || timeout == int.MaxValue) + { + // the max TimeSpan for CancellationTokenSource is int.MaxValue in milliseconds (not TimeSpan.MaxValue) + // no matter what the unit is + span = TimeSpan.FromMilliseconds(int.MaxValue); + } + else + { + if (timeUnit == TimeUnit.Seconds) + { + span = TimeSpan.FromSeconds(timeout); + } + else + { + span = TimeSpan.FromMilliseconds(timeout); + } + } + + return GetCancellationToken(span); + } + + private static CancellationToken GetCancellationToken(TimeSpan timeSpan) + { + var cts = new CancellationTokenSource(timeSpan); + return cts.Token; + } + + public static bool QueryTimeoutIsValid(string key, string value, out int queryTimeoutSeconds) + { + if (!string.IsNullOrEmpty(value) && int.TryParse(value, out int queryTimeout) && (queryTimeout >= 0)) + { + queryTimeoutSeconds = queryTimeout; + return true; + } + else + { + throw new ArgumentOutOfRangeException(key, value, $"The value '{value}' for option '{key}' is invalid. Must be a numeric value of 0 (infinite) or greater."); + } + } + + public static bool ContainsException(Exception exception, out T? containedException) where T : Exception + { + if (exception is AggregateException aggregateException) + { + foreach (Exception? ex in aggregateException.InnerExceptions) + { + if (ex is T ce) + { + containedException = ce; + return true; + } + } + } + + Exception? e = exception; + while (e != null) + { + if (e is T ce) + { + containedException = ce; + return true; + } + e = e.InnerException; + } + + containedException = null; + return false; + } + + public static bool ContainsException(Exception exception, Type? exceptionType, out Exception? containedException) + { + if (exception == null || exceptionType == null) + { + containedException = null; + return false; + } + + if (exception is AggregateException aggregateException) + { + foreach (Exception? ex in aggregateException.InnerExceptions) + { + if (exceptionType.IsInstanceOfType(ex)) + { + containedException = ex; + return true; + } + } + } + + Exception? e = exception; + while (e != null) + { + if (exceptionType.IsInstanceOfType(e)) + { + containedException = e; + return true; + } + e = e.InnerException; + } + + containedException = null; + return false; + } + } +} diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs index c839bbaa70..d420edb2bf 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs @@ -30,7 +30,7 @@ internal abstract class HiveServer2Connection : AdbcConnection { internal const long BatchSizeDefault = 50000; internal const int PollTimeMillisecondsDefault = 500; - + private const int ConnectTimeoutMillisecondsDefault = 30000; private TTransport? _transport; private TCLIService.Client? _client; private readonly Lazy _vendorVersion; @@ -45,6 +45,14 @@ internal HiveServer2Connection(IReadOnlyDictionary properties) // https://learn.microsoft.com/en-us/dotnet/framework/performance/lazy-initialization#exceptions-in-lazy-objects _vendorVersion = new Lazy(() => GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_VER), LazyThreadSafetyMode.PublicationOnly); _vendorName = new Lazy(() => GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_NAME), LazyThreadSafetyMode.PublicationOnly); + + if (properties.TryGetValue(ApacheParameters.QueryTimeoutSeconds, out string? queryTimeoutSecondsSettingValue)) + { + if (ApacheUtility.QueryTimeoutIsValid(ApacheParameters.QueryTimeoutSeconds, queryTimeoutSecondsSettingValue, out int queryTimeoutSeconds)) + { + QueryTimeoutSeconds = queryTimeoutSeconds; + } + } } internal TCLIService.Client Client @@ -56,30 +64,48 @@ internal TCLIService.Client Client internal string VendorName => _vendorName.Value; + protected internal int QueryTimeoutSeconds { get; set; } = ApacheUtility.QueryTimeoutSecondsDefault; + internal IReadOnlyDictionary Properties { get; } internal async Task OpenAsync() { - TTransport transport = await CreateTransportAsync(); - TProtocol protocol = await CreateProtocolAsync(transport); - _transport = protocol.Transport; - _client = new TCLIService.Client(protocol); - TOpenSessionReq request = CreateSessionRequest(); - TOpenSessionResp? session = await Client.OpenSession(request); - - // Some responses don't raise an exception. Explicitly check the status. - if (session == null) + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(ConnectTimeoutMilliseconds, ApacheUtility.TimeUnit.Milliseconds); + try { - throw new HiveServer2Exception("unable to open session. unknown error."); + TTransport transport = CreateTransport(); + TProtocol protocol = await CreateProtocolAsync(transport, cancellationToken); + _transport = protocol.Transport; + _client = new TCLIService.Client(protocol); + TOpenSessionReq request = CreateSessionRequest(); + + TOpenSessionResp? session = await Client.OpenSession(request, cancellationToken); + + // Explicitly check the session status + if (session == null) + { + throw new HiveServer2Exception("Unable to open session. Unknown error."); + } + else if (session.Status.StatusCode != TStatusCode.SUCCESS_STATUS) + { + throw new HiveServer2Exception(session.Status.ErrorMessage) + .SetNativeError(session.Status.ErrorCode) + .SetSqlState(session.Status.SqlState); + } + + SessionHandle = session.SessionHandle; } - else if (session.Status.StatusCode != TStatusCode.SUCCESS_STATUS) + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) { - throw new HiveServer2Exception(session.Status.ErrorMessage) - .SetNativeError(session.Status.ErrorCode) - .SetSqlState(session.Status.SqlState); + throw new TimeoutException("The operation timed out while attempting to open a session. Please try increasing connect timeout.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + // Handle other exceptions if necessary + throw new HiveServer2Exception($"An unexpected error occurred while opening the session. '{ex.Message}'", ex); } - - SessionHandle = session.SessionHandle; } internal TSessionHandle? SessionHandle { get; private set; } @@ -88,11 +114,11 @@ internal async Task OpenAsync() protected internal HiveServer2TlsOption TlsOptions { get; set; } = HiveServer2TlsOption.Empty; - protected internal int HttpRequestTimeout { get; set; } = 30000; + protected internal int ConnectTimeoutMilliseconds { get; set; } = ConnectTimeoutMillisecondsDefault; - protected abstract Task CreateTransportAsync(); + protected abstract TTransport CreateTransport(); - protected abstract Task CreateProtocolAsync(TTransport transport); + protected abstract Task CreateProtocolAsync(TTransport transport, CancellationToken cancellationToken = default); protected abstract TOpenSessionReq CreateSessionRequest(); @@ -110,14 +136,14 @@ public override IArrowArrayStream GetTableTypes() throw new NotImplementedException(); } - internal static async Task PollForResponseAsync(TOperationHandle operationHandle, TCLIService.IAsync client, int pollTimeMilliseconds) + internal static async Task PollForResponseAsync(TOperationHandle operationHandle, TCLIService.IAsync client, int pollTimeMilliseconds, CancellationToken cancellationToken = default) { TGetOperationStatusResp? statusResponse = null; do { - if (statusResponse != null) { await Task.Delay(pollTimeMilliseconds); } + if (statusResponse != null) { await Task.Delay(pollTimeMilliseconds, cancellationToken); } TGetOperationStatusReq request = new(operationHandle); - statusResponse = await client.GetOperationStatus(request); + statusResponse = await client.GetOperationStatus(request, cancellationToken); } while (statusResponse.OperationState == TOperationState.PENDING_STATE || statusResponse.OperationState == TOperationState.RUNNING_STATE); } @@ -129,24 +155,38 @@ private string GetInfoTypeStringValue(TGetInfoType infoType) InfoType = infoType, }; - TGetInfoResp getInfoResp = Client.GetInfo(req).Result; - if (getInfoResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try { - throw new HiveServer2Exception(getInfoResp.Status.ErrorMessage) - .SetNativeError(getInfoResp.Status.ErrorCode) - .SetSqlState(getInfoResp.Status.SqlState); + TGetInfoResp getInfoResp = Client.GetInfo(req, cancellationToken).Result; + if (getInfoResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new HiveServer2Exception(getInfoResp.Status.ErrorMessage) + .SetNativeError(getInfoResp.Status.ErrorCode) + .SetSqlState(getInfoResp.Status.SqlState); + } + + return getInfoResp.InfoValue.StringValue; + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The metadata query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while running metadata query. '{ex.Message}'", ex); } - - return getInfoResp.InfoValue.StringValue; } public override void Dispose() { if (_client != null) { - TCloseSessionReq r6 = new TCloseSessionReq(SessionHandle); - _client.CloseSession(r6).Wait(); - + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + TCloseSessionReq r6 = new(SessionHandle); + _client.CloseSession(r6, cancellationToken).Wait(); _transport?.Close(); _client.Dispose(); _transport = null; diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs index 2170cd17b4..4f2bc62d21 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs @@ -15,8 +15,6 @@ * limitations under the License. */ -using System; - namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 { public static class DataTypeConversionOptions diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs index 08b0675d09..34dbf10f2c 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs @@ -25,6 +25,7 @@ using Apache.Arrow.Ipc; using Apache.Arrow.Types; using Apache.Hive.Service.Rpc.Thrift; +using Thrift.Transport; namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 { @@ -89,19 +90,32 @@ public HiveServer2Reader( return null; } - // Await the fetch response - TFetchResultsResp response = await FetchNext(_statement, cancellationToken); + try + { + // Await the fetch response + TFetchResultsResp response = await FetchNext(_statement, cancellationToken); + + int columnCount = GetColumnCount(response); + int rowCount = GetRowCount(response, columnCount); + if ((_statement.BatchSize > 0 && rowCount < _statement.BatchSize) || rowCount == 0) + { + // This is the last batch + _statement = null; + } - int columnCount = GetColumnCount(response); - int rowCount = GetRowCount(response, columnCount); - if ((_statement.BatchSize > 0 && rowCount < _statement.BatchSize) || rowCount == 0) + // Build the current batch, if any data exists + return rowCount > 0 ? CreateBatch(response, columnCount, rowCount) : null; + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) { - // This is the last batch - _statement = null; + throw new TimeoutException("The query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while fetching results. '{ex.Message}'", ex); } - - // Build the current batch, if any data exists - return rowCount > 0 ? CreateBatch(response, columnCount, rowCount) : null; } private RecordBatch CreateBatch(TFetchResultsResp response, int columnCount, int rowCount) diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs index 824feceb9e..06723e324d 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs @@ -20,6 +20,7 @@ using System.Threading.Tasks; using Apache.Arrow.Ipc; using Apache.Hive.Service.Rpc.Thrift; +using Thrift.Transport; namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 { @@ -32,33 +33,89 @@ protected HiveServer2Statement(HiveServer2Connection connection) protected virtual void SetStatementProperties(TExecuteStatementReq statement) { + statement.QueryTimeout = QueryTimeoutSeconds; } - public override QueryResult ExecuteQuery() => ExecuteQueryAsync().AsTask().Result; + public override QueryResult ExecuteQuery() + { + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try + { + return ExecuteQueryAsyncInternal(cancellationToken).Result; + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while fetching results. '{ex.Message}'", ex); + } + } - public override UpdateResult ExecuteUpdate() => ExecuteUpdateAsync().Result; + public override UpdateResult ExecuteUpdate() + { + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try + { + return ExecuteUpdateAsyncInternal(cancellationToken).Result; + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while fetching results. '{ex.Message}'", ex); + } + } - public override async ValueTask ExecuteQueryAsync() + private async Task ExecuteQueryAsyncInternal(CancellationToken cancellationToken = default) { - await ExecuteStatementAsync(); - await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds); - Schema schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client); + // this could either: + // take QueryTimeoutSeconds * 3 + // OR + // take QueryTimeoutSeconds (but this could be restricting) + await ExecuteStatementAsync(cancellationToken); // --> get QueryTimeout + + await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken); // + poll, up to QueryTimeout + Schema schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client, cancellationToken); // + get the result, up to QueryTimeout - // TODO: Ensure this is set dynamically based on server capabilities return new QueryResult(-1, Connection.NewReader(this, schema)); } + public override async ValueTask ExecuteQueryAsync() + { + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try + { + return await ExecuteQueryAsyncInternal(cancellationToken); + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while fetching results. '{ex.Message}'", ex); + } + } + private async Task GetResultSetSchemaAsync(TOperationHandle operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken = default) { TGetResultSetMetadataResp response = await HiveServer2Connection.GetResultSetMetadataAsync(operationHandle, client, cancellationToken); return Connection.SchemaParser.GetArrowSchema(response.Schema, Connection.DataTypeConversion); } - public override async Task ExecuteUpdateAsync() + public async Task ExecuteUpdateAsyncInternal(CancellationToken cancellationToken = default) { const string NumberOfAffectedRowsColumnName = "num_affected_rows"; - - QueryResult queryResult = await ExecuteQueryAsync(); + QueryResult queryResult = await ExecuteQueryAsyncInternal(cancellationToken); if (queryResult.Stream == null) { throw new AdbcException("no data found"); @@ -79,7 +136,7 @@ public override async Task ExecuteUpdateAsync() long? affectedRows = null; while (true) { - using RecordBatch nextBatch = await stream.ReadNextRecordBatchAsync(); + using RecordBatch nextBatch = await stream.ReadNextRecordBatchAsync(cancellationToken); if (nextBatch == null) { break; } Int64Array numOfModifiedArray = (Int64Array)nextBatch.Column(NumberOfAffectedRowsColumnName); // Note: should only have one item, but iterate for completeness @@ -94,26 +151,51 @@ public override async Task ExecuteUpdateAsync() return new UpdateResult(affectedRows ?? -1); } + public override async Task ExecuteUpdateAsync() + { + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try + { + return await ExecuteUpdateAsyncInternal(cancellationToken); + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while fetching results. '{ex.Message}'", ex); + } + } + public override void SetOption(string key, string value) { switch (key) { - case Options.PollTimeMilliseconds: + case ApacheParameters.PollTimeMilliseconds: UpdatePollTimeIfValid(key, value); break; - case Options.BatchSize: + case ApacheParameters.BatchSize: UpdateBatchSizeIfValid(key, value); break; + case ApacheParameters.QueryTimeoutSeconds: + if (ApacheUtility.QueryTimeoutIsValid(key, value, out int queryTimeoutSeconds)) + { + QueryTimeoutSeconds = queryTimeoutSeconds; + } + break; default: throw AdbcException.NotImplemented($"Option '{key}' is not implemented."); } } - protected async Task ExecuteStatementAsync() + protected async Task ExecuteStatementAsync(CancellationToken cancellationToken = default) { TExecuteStatementReq executeRequest = new TExecuteStatementReq(Connection.SessionHandle, SqlQuery); SetStatementProperties(executeRequest); - TExecuteStatementResp executeResponse = await Connection.Client.ExecuteStatement(executeRequest); + TExecuteStatementResp executeResponse = await Connection.Client.ExecuteStatement(executeRequest, cancellationToken); if (executeResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) { throw new HiveServer2Exception(executeResponse.Status.ErrorMessage) @@ -127,23 +209,20 @@ protected async Task ExecuteStatementAsync() protected internal long BatchSize { get; private set; } = HiveServer2Connection.BatchSizeDefault; + protected internal int QueryTimeoutSeconds + { + // Coordinate updates with the connection + get => Connection.QueryTimeoutSeconds; + set => Connection.QueryTimeoutSeconds = value; + } + public HiveServer2Connection Connection { get; private set; } public TOperationHandle? OperationHandle { get; private set; } - /// - /// Provides the constant string key values to the method. - /// - public class Options - { - // Options common to all HiveServer2Statement-derived drivers go here - public const string PollTimeMilliseconds = "adbc.statement.polltime_milliseconds"; - public const string BatchSize = "adbc.statement.batch_size"; - } - private void UpdatePollTimeIfValid(string key, string value) => PollTimeMilliseconds = !string.IsNullOrEmpty(key) && int.TryParse(value, result: out int pollTimeMilliseconds) && pollTimeMilliseconds >= 0 ? pollTimeMilliseconds - : throw new ArgumentOutOfRangeException(key, value, $"The value '{value}' for option '{key}' is invalid. Must be a numeric value greater than or equal to -1."); + : throw new ArgumentOutOfRangeException(key, value, $"The value '{value}' for option '{key}' is invalid. Must be a numeric value greater than or equal to 0."); private void UpdateBatchSizeIfValid(string key, string value) => BatchSize = !string.IsNullOrEmpty(value) && long.TryParse(value, out long batchSize) && batchSize > 0 ? batchSize @@ -153,8 +232,9 @@ public override void Dispose() { if (OperationHandle != null) { + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); TCloseOperationReq request = new TCloseOperationReq(OperationHandle); - Connection.Client.CloseOperation(request).Wait(); + Connection.Client.CloseOperation(request, cancellationToken).Wait(); OperationHandle = null; } diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs index c6c6cc7969..0e673c7c4a 100644 --- a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs +++ b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs @@ -40,7 +40,7 @@ internal ImpalaConnection(IReadOnlyDictionary properties) { } - protected override Task CreateTransportAsync() + protected override TTransport CreateTransport() { string hostName = Properties["HostName"]; string? tmp; @@ -52,10 +52,10 @@ protected override Task CreateTransportAsync() TConfiguration config = new TConfiguration(); TTransport transport = new ThriftSocketTransport(hostName, port, config); - return Task.FromResult(transport); + return transport; } - protected override Task CreateProtocolAsync(TTransport transport) + protected override Task CreateProtocolAsync(TTransport transport, CancellationToken cancellationToken = default) { return Task.FromResult(new TBinaryProtocol(transport)); } diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs b/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs index 0bd620ee9b..f94ac3970e 100644 --- a/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs +++ b/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs @@ -30,7 +30,7 @@ internal ImpalaStatement(ImpalaConnection connection) /// /// Provides the constant string key values to the method. /// - public new sealed class Options : HiveServer2Statement.Options + public sealed class Options : ApacheParameters { // options specific to Impala go here } diff --git a/csharp/src/Drivers/Apache/Spark/README.md b/csharp/src/Drivers/Apache/Spark/README.md index 7d1f8b5607..3b5a0e79ed 100644 --- a/csharp/src/Drivers/Apache/Spark/README.md +++ b/csharp/src/Drivers/Apache/Spark/README.md @@ -37,9 +37,18 @@ but can also be passed in the call to `AdbcDatabase.Connect`. | `password` | The password for the user name used for basic authentication. | | | `adbc.spark.data_type_conv` | Comma-separated list of data conversion options. Each option indicates the type of conversion to perform on data returned from the Spark server.

Allowed values: `none`, `scalar`.

Option `none` indicates there is no conversion from Spark type to native type (i.e., no conversion from String to Timestamp for Apache Spark over HTTP). Example `adbc.spark.conv_data_type=none`.

Option `scalar` will perform conversion (if necessary) from the Spark data type to corresponding Arrow data types for types `DATE/Date32/DateTime`, `DECIMAL/Decimal128/SqlDecimal`, and `TIMESTAMP/Timestamp/DateTimeOffset`. Example `adbc.spark.conv_data_type=scalar` | `scalar` | | `adbc.spark.tls_options` | Comma-separated list of TLS/SSL options. Each option indicates the TLS/SSL option when connecting to a Spark server.

Allowed values: `allow_self_signed`, `allow_hostname_mismatch`.

Option `allow_self_signed` allows certificate errors due to an unknown certificate authority, typically when using a self-signed certificate. Option `allow_hostname_mismatch` allow certificate errors due to a mismatch of the hostname. (e.g., when connecting through an SSH tunnel). Example `adbc.spark.tls_options=allow_self_signed` | | -| `adbc.spark.http_request_timeout_ms` | Sets the timeout (in milliseconds) when making requests to the Spark server (type: `http`). Set the value higher than the default if you notice errors due to network timeouts. | `30000` | -| `adbc.statement.batch_size` | Sets the maximum number of rows to retrieve in a single batch request. | `50000` | -| `adbc.statement.polltime_milliseconds` | If polling is necessary to get a result, this option sets the length of time (in milliseconds) to wait between polls. | `500` | +| `adbc.spark.connect_timeout_ms` | Sets the timeout (in milliseconds) to open a new session. Values can be 0 (infinite) or greater than zero. | `30000` | +| `adbc.apache.statement.batch_size` | Sets the maximum number of rows to retrieve in a single batch request. | `50000` | +| `adbc.apache.statement.polltime_ms` | If polling is necessary to get a result, this option sets the length of time (in milliseconds) to wait between polls. | `500` | +| `adbc.apache.statement.query_timeout_s` | Sets the maximum time (in seconds) for a query to complete. Values can be 0 (infinite) or greater than zero. | `60` | + +## Timeout Configuration + +Timeouts have a hierarchy to their behavior. As specified above, the `adbc.spark.connect_timeout_ms` is analogous to a ConnectTimeout and used to initially establish a new session with the server. + +The `adbc.apache.statement.query_timeout_s` is analogous to a CommandTimeout for any subsequent calls to the server for requests, including metadata calls and executing queries. + +The `adbc.apache.statement.polltime_ms` specifies the time between polls to the service, up to the limit specifed by `adbc.apache.statement.query_timeout_s`. ## Spark Types diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs index f532369e62..b3c0c56ba1 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs @@ -19,8 +19,6 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; -using System.Net; -using System.Net.Http; using System.Reflection; using System.Text; using System.Text.RegularExpressions; @@ -32,6 +30,7 @@ using Apache.Arrow.Ipc; using Apache.Arrow.Types; using Apache.Hive.Service.Rpc.Thrift; +using Thrift.Transport; namespace Apache.Arrow.Adbc.Drivers.Apache.Spark { @@ -420,26 +419,42 @@ public override IArrowArrayStream GetTableTypes() SessionHandle = SessionHandle ?? throw new InvalidOperationException("session not created"), GetDirectResults = sparkGetDirectResults }; - TGetTableTypesResp resp = Client.GetTableTypes(req).Result; - if (resp.Status.StatusCode == TStatusCode.ERROR_STATUS) + + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try { - throw new HiveServer2Exception(resp.Status.ErrorMessage) - .SetNativeError(resp.Status.ErrorCode) - .SetSqlState(resp.Status.SqlState); - } + TGetTableTypesResp resp = Client.GetTableTypes(req, cancellationToken).Result; - TRowSet rowSet = GetRowSetAsync(resp).Result; - StringArray tableTypes = rowSet.Columns[0].StringVal.Values; + if (resp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new HiveServer2Exception(resp.Status.ErrorMessage) + .SetNativeError(resp.Status.ErrorCode) + .SetSqlState(resp.Status.SqlState); + } - StringArray.Builder tableTypesBuilder = new StringArray.Builder(); - tableTypesBuilder.AppendRange(tableTypes); + TRowSet rowSet = GetRowSetAsync(resp, cancellationToken).Result; + StringArray tableTypes = rowSet.Columns[0].StringVal.Values; - IArrowArray[] dataArrays = new IArrowArray[] - { + StringArray.Builder tableTypesBuilder = new StringArray.Builder(); + tableTypesBuilder.AppendRange(tableTypes); + + IArrowArray[] dataArrays = new IArrowArray[] + { tableTypesBuilder.Build() - }; + }; - return new SparkInfoArrowStream(StandardSchemas.TableTypesSchema, dataArrays); + return new SparkInfoArrowStream(StandardSchemas.TableTypesSchema, dataArrays); + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The metadata query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while running metadata query. '{ex.Message}'", ex); + } } public override Schema GetTableSchema(string? catalog, string? dbSchema, string? tableName) @@ -450,221 +465,248 @@ public override Schema GetTableSchema(string? catalog, string? dbSchema, string? getColumnsReq.TableName = tableName; getColumnsReq.GetDirectResults = sparkGetDirectResults; - var columnsResponse = Client.GetColumns(getColumnsReq).Result; - if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try { - throw new Exception(columnsResponse.Status.ErrorMessage); - } + var columnsResponse = Client.GetColumns(getColumnsReq, cancellationToken).Result; + if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(columnsResponse.Status.ErrorMessage); + } - TRowSet rowSet = GetRowSetAsync(columnsResponse).Result; - List columns = rowSet.Columns; - int rowCount = rowSet.Columns[3].StringVal.Values.Length; + TRowSet rowSet = GetRowSetAsync(columnsResponse, cancellationToken).Result; + List columns = rowSet.Columns; + int rowCount = rowSet.Columns[3].StringVal.Values.Length; - Field[] fields = new Field[rowCount]; - for (int i = 0; i < rowCount; i++) + Field[] fields = new Field[rowCount]; + for (int i = 0; i < rowCount; i++) + { + string columnName = columns[3].StringVal.Values.GetString(i); + int? columnType = columns[4].I32Val.Values.GetValue(i); + string typeName = columns[5].StringVal.Values.GetString(i); + // Note: the following two columns do not seem to be set correctly for DECIMAL types. + //int? columnSize = columns[6].I32Val.Values.GetValue(i); + //int? decimalDigits = columns[8].I32Val.Values.GetValue(i); + bool nullable = columns[10].I32Val.Values.GetValue(i) == 1; + IArrowType dataType = SparkConnection.GetArrowType(columnType!.Value, typeName); + fields[i] = new Field(columnName, dataType, nullable); + } + return new Schema(fields, null); + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) { - string columnName = columns[3].StringVal.Values.GetString(i); - int? columnType = columns[4].I32Val.Values.GetValue(i); - string typeName = columns[5].StringVal.Values.GetString(i); - // Note: the following two columns do not seem to be set correctly for DECIMAL types. - //int? columnSize = columns[6].I32Val.Values.GetValue(i); - //int? decimalDigits = columns[8].I32Val.Values.GetValue(i); - bool nullable = columns[10].I32Val.Values.GetValue(i) == 1; - IArrowType dataType = SparkConnection.GetArrowType(columnType!.Value, typeName); - fields[i] = new Field(columnName, dataType, nullable); + throw new TimeoutException("The metadata query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while running metadata query. '{ex.Message}'", ex); } - return new Schema(fields, null); } public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern, IReadOnlyList? tableTypes, string? columnNamePattern) { - Trace.TraceError($"getting objects with depth={depth.ToString()}, catalog = {catalogPattern}, dbschema = {dbSchemaPattern}, tablename = {tableNamePattern}"); - Dictionary>> catalogMap = new Dictionary>>(); - if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Catalogs) + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try { - TGetCatalogsReq getCatalogsReq = new TGetCatalogsReq(SessionHandle); - getCatalogsReq.GetDirectResults = sparkGetDirectResults; - - TGetCatalogsResp getCatalogsResp = Client.GetCatalogs(getCatalogsReq).Result; - if (getCatalogsResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Catalogs) { - throw new Exception(getCatalogsResp.Status.ErrorMessage); - } - var catalogsMetadata = GetResultSetMetadataAsync(getCatalogsResp).Result; - IReadOnlyDictionary columnMap = GetColumnIndexMap(catalogsMetadata.Schema.Columns); + TGetCatalogsReq getCatalogsReq = new TGetCatalogsReq(SessionHandle); + getCatalogsReq.GetDirectResults = sparkGetDirectResults; - string catalogRegexp = PatternToRegEx(catalogPattern); - TRowSet rowSet = GetRowSetAsync(getCatalogsResp).Result; - IReadOnlyList list = rowSet.Columns[columnMap[TableCat]].StringVal.Values; - for (int i = 0; i < list.Count; i++) - { - string col = list[i]; - string catalog = col; + TGetCatalogsResp getCatalogsResp = Client.GetCatalogs(getCatalogsReq, cancellationToken).Result; - if (Regex.IsMatch(catalog, catalogRegexp, RegexOptions.IgnoreCase)) + if (getCatalogsResp.Status.StatusCode == TStatusCode.ERROR_STATUS) { - catalogMap.Add(catalog, new Dictionary>()); + throw new Exception(getCatalogsResp.Status.ErrorMessage); } - } - // Handle the case where server does not support 'catalog' in the namespace. - if (list.Count == 0 && string.IsNullOrEmpty(catalogPattern)) - { - catalogMap.Add(string.Empty, []); - } - } + var catalogsMetadata = GetResultSetMetadataAsync(getCatalogsResp, cancellationToken).Result; + IReadOnlyDictionary columnMap = GetColumnIndexMap(catalogsMetadata.Schema.Columns); - if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.DbSchemas) - { - TGetSchemasReq getSchemasReq = new TGetSchemasReq(SessionHandle); - getSchemasReq.CatalogName = catalogPattern; - getSchemasReq.SchemaName = dbSchemaPattern; - getSchemasReq.GetDirectResults = sparkGetDirectResults; + string catalogRegexp = PatternToRegEx(catalogPattern); + TRowSet rowSet = GetRowSetAsync(getCatalogsResp, cancellationToken).Result; + IReadOnlyList list = rowSet.Columns[columnMap[TableCat]].StringVal.Values; + for (int i = 0; i < list.Count; i++) + { + string col = list[i]; + string catalog = col; - TGetSchemasResp getSchemasResp = Client.GetSchemas(getSchemasReq).Result; - if (getSchemasResp.Status.StatusCode == TStatusCode.ERROR_STATUS) - { - throw new Exception(getSchemasResp.Status.ErrorMessage); + if (Regex.IsMatch(catalog, catalogRegexp, RegexOptions.IgnoreCase)) + { + catalogMap.Add(catalog, new Dictionary>()); + } + } + // Handle the case where server does not support 'catalog' in the namespace. + if (list.Count == 0 && string.IsNullOrEmpty(catalogPattern)) + { + catalogMap.Add(string.Empty, []); + } } - TGetResultSetMetadataResp schemaMetadata = GetResultSetMetadataAsync(getSchemasResp).Result; - IReadOnlyDictionary columnMap = GetColumnIndexMap(schemaMetadata.Schema.Columns); - TRowSet rowSet = GetRowSetAsync(getSchemasResp).Result; - - IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCatalog]].StringVal.Values; - IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; - - for (int i = 0; i < catalogList.Count; i++) + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.DbSchemas) { - string catalog = catalogList[i]; - string schemaDb = schemaList[i]; - // It seems Spark sometimes returns empty string for catalog on some schema (temporary tables). - catalogMap.GetValueOrDefault(catalog)?.Add(schemaDb, new Dictionary()); - } - } + TGetSchemasReq getSchemasReq = new TGetSchemasReq(SessionHandle); + getSchemasReq.CatalogName = catalogPattern; + getSchemasReq.SchemaName = dbSchemaPattern; + getSchemasReq.GetDirectResults = sparkGetDirectResults; - if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Tables) - { - TGetTablesReq getTablesReq = new TGetTablesReq(SessionHandle); - getTablesReq.CatalogName = catalogPattern; - getTablesReq.SchemaName = dbSchemaPattern; - getTablesReq.TableName = tableNamePattern; - getTablesReq.GetDirectResults = sparkGetDirectResults; - - TGetTablesResp getTablesResp = Client.GetTables(getTablesReq).Result; - if (getTablesResp.Status.StatusCode == TStatusCode.ERROR_STATUS) - { - throw new Exception(getTablesResp.Status.ErrorMessage); - } + TGetSchemasResp getSchemasResp = Client.GetSchemas(getSchemasReq, cancellationToken).Result; + if (getSchemasResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getSchemasResp.Status.ErrorMessage); + } - TGetResultSetMetadataResp tableMetadata = GetResultSetMetadataAsync(getTablesResp).Result; - IReadOnlyDictionary columnMap = GetColumnIndexMap(tableMetadata.Schema.Columns); - TRowSet rowSet = GetRowSetAsync(getTablesResp).Result; + TGetResultSetMetadataResp schemaMetadata = GetResultSetMetadataAsync(getSchemasResp, cancellationToken).Result; + IReadOnlyDictionary columnMap = GetColumnIndexMap(schemaMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(getSchemasResp, cancellationToken).Result; - IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCat]].StringVal.Values; - IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; - IReadOnlyList tableList = rowSet.Columns[columnMap[TableName]].StringVal.Values; - IReadOnlyList tableTypeList = rowSet.Columns[columnMap[TableType]].StringVal.Values; + IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCatalog]].StringVal.Values; + IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; - for (int i = 0; i < catalogList.Count; i++) - { - string catalog = catalogList[i]; - string schemaDb = schemaList[i]; - string tableName = tableList[i]; - string tableType = tableTypeList[i]; - TableInfo tableInfo = new(tableType); - catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.Add(tableName, tableInfo); + for (int i = 0; i < catalogList.Count; i++) + { + string catalog = catalogList[i]; + string schemaDb = schemaList[i]; + // It seems Spark sometimes returns empty string for catalog on some schema (temporary tables). + catalogMap.GetValueOrDefault(catalog)?.Add(schemaDb, new Dictionary()); + } } - } - if (depth == GetObjectsDepth.All) - { - TGetColumnsReq columnsReq = new TGetColumnsReq(SessionHandle); - columnsReq.CatalogName = catalogPattern; - columnsReq.SchemaName = dbSchemaPattern; - columnsReq.TableName = tableNamePattern; - columnsReq.GetDirectResults = sparkGetDirectResults; + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Tables) + { + TGetTablesReq getTablesReq = new TGetTablesReq(SessionHandle); + getTablesReq.CatalogName = catalogPattern; + getTablesReq.SchemaName = dbSchemaPattern; + getTablesReq.TableName = tableNamePattern; + getTablesReq.GetDirectResults = sparkGetDirectResults; + + TGetTablesResp getTablesResp = Client.GetTables(getTablesReq, cancellationToken).Result; + if (getTablesResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getTablesResp.Status.ErrorMessage); + } - if (!string.IsNullOrEmpty(columnNamePattern)) - columnsReq.ColumnName = columnNamePattern; + TGetResultSetMetadataResp tableMetadata = GetResultSetMetadataAsync(getTablesResp, cancellationToken).Result; + IReadOnlyDictionary columnMap = GetColumnIndexMap(tableMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(getTablesResp, cancellationToken).Result; - var columnsResponse = Client.GetColumns(columnsReq).Result; - if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) - { - throw new Exception(columnsResponse.Status.ErrorMessage); + IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCat]].StringVal.Values; + IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; + IReadOnlyList tableList = rowSet.Columns[columnMap[TableName]].StringVal.Values; + IReadOnlyList tableTypeList = rowSet.Columns[columnMap[TableType]].StringVal.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + string catalog = catalogList[i]; + string schemaDb = schemaList[i]; + string tableName = tableList[i]; + string tableType = tableTypeList[i]; + TableInfo tableInfo = new(tableType); + catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.Add(tableName, tableInfo); + } } - TGetResultSetMetadataResp columnsMetadata = GetResultSetMetadataAsync(columnsResponse).Result; - IReadOnlyDictionary columnMap = GetColumnIndexMap(columnsMetadata.Schema.Columns); - TRowSet rowSet = GetRowSetAsync(columnsResponse).Result; - - IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCat]].StringVal.Values; - IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; - IReadOnlyList tableList = rowSet.Columns[columnMap[TableName]].StringVal.Values; - IReadOnlyList columnNameList = rowSet.Columns[columnMap[ColumnName]].StringVal.Values; - ReadOnlySpan columnTypeList = rowSet.Columns[columnMap[DataType]].I32Val.Values.Values; - IReadOnlyList typeNameList = rowSet.Columns[columnMap[TypeName]].StringVal.Values; - ReadOnlySpan nullableList = rowSet.Columns[columnMap[Nullable]].I32Val.Values.Values; - IReadOnlyList columnDefaultList = rowSet.Columns[columnMap[ColumnDef]].StringVal.Values; - ReadOnlySpan ordinalPosList = rowSet.Columns[columnMap[OrdinalPosition]].I32Val.Values.Values; - IReadOnlyList isNullableList = rowSet.Columns[columnMap[IsNullable]].StringVal.Values; - IReadOnlyList isAutoIncrementList = rowSet.Columns[columnMap[IsAutoIncrement]].StringVal.Values; - - for (int i = 0; i < catalogList.Count; i++) + if (depth == GetObjectsDepth.All) { - // For systems that don't support 'catalog' in the namespace - string catalog = catalogList[i] ?? string.Empty; - string schemaDb = schemaList[i]; - string tableName = tableList[i]; - string columnName = columnNameList[i]; - short colType = (short)columnTypeList[i]; - string typeName = typeNameList[i]; - short nullable = (short)nullableList[i]; - string? isAutoIncrementString = isAutoIncrementList[i]; - bool isAutoIncrement = (!string.IsNullOrEmpty(isAutoIncrementString) && (isAutoIncrementString.Equals("YES", StringComparison.InvariantCultureIgnoreCase) || isAutoIncrementString.Equals("TRUE", StringComparison.InvariantCultureIgnoreCase))); - string isNullable = isNullableList[i] ?? "YES"; - string columnDefault = columnDefaultList[i] ?? ""; - // Spark/Databricks reports ordinal index zero-indexed, instead of one-indexed - int ordinalPos = ordinalPosList[i] + 1; - TableInfo? tableInfo = catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.GetValueOrDefault(tableName); - tableInfo?.ColumnName.Add(columnName); - tableInfo?.ColType.Add(colType); - tableInfo?.Nullable.Add(nullable); - tableInfo?.IsAutoIncrement.Add(isAutoIncrement); - tableInfo?.IsNullable.Add(isNullable); - tableInfo?.ColumnDefault.Add(columnDefault); - tableInfo?.OrdinalPosition.Add(ordinalPos); - SetPrecisionScaleAndTypeName(colType, typeName, tableInfo); - } - } + TGetColumnsReq columnsReq = new TGetColumnsReq(SessionHandle); + columnsReq.CatalogName = catalogPattern; + columnsReq.SchemaName = dbSchemaPattern; + columnsReq.TableName = tableNamePattern; + columnsReq.GetDirectResults = sparkGetDirectResults; - StringArray.Builder catalogNameBuilder = new StringArray.Builder(); - List catalogDbSchemasValues = new List(); + if (!string.IsNullOrEmpty(columnNamePattern)) + columnsReq.ColumnName = columnNamePattern; - foreach (KeyValuePair>> catalogEntry in catalogMap) - { - catalogNameBuilder.Append(catalogEntry.Key); + var columnsResponse = Client.GetColumns(columnsReq, cancellationToken).Result; + if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(columnsResponse.Status.ErrorMessage); + } - if (depth == GetObjectsDepth.Catalogs) - { - catalogDbSchemasValues.Add(null); + TGetResultSetMetadataResp columnsMetadata = GetResultSetMetadataAsync(columnsResponse, cancellationToken).Result; + IReadOnlyDictionary columnMap = GetColumnIndexMap(columnsMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(columnsResponse, cancellationToken).Result; + + IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCat]].StringVal.Values; + IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; + IReadOnlyList tableList = rowSet.Columns[columnMap[TableName]].StringVal.Values; + IReadOnlyList columnNameList = rowSet.Columns[columnMap[ColumnName]].StringVal.Values; + ReadOnlySpan columnTypeList = rowSet.Columns[columnMap[DataType]].I32Val.Values.Values; + IReadOnlyList typeNameList = rowSet.Columns[columnMap[TypeName]].StringVal.Values; + ReadOnlySpan nullableList = rowSet.Columns[columnMap[Nullable]].I32Val.Values.Values; + IReadOnlyList columnDefaultList = rowSet.Columns[columnMap[ColumnDef]].StringVal.Values; + ReadOnlySpan ordinalPosList = rowSet.Columns[columnMap[OrdinalPosition]].I32Val.Values.Values; + IReadOnlyList isNullableList = rowSet.Columns[columnMap[IsNullable]].StringVal.Values; + IReadOnlyList isAutoIncrementList = rowSet.Columns[columnMap[IsAutoIncrement]].StringVal.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + // For systems that don't support 'catalog' in the namespace + string catalog = catalogList[i] ?? string.Empty; + string schemaDb = schemaList[i]; + string tableName = tableList[i]; + string columnName = columnNameList[i]; + short colType = (short)columnTypeList[i]; + string typeName = typeNameList[i]; + short nullable = (short)nullableList[i]; + string? isAutoIncrementString = isAutoIncrementList[i]; + bool isAutoIncrement = (!string.IsNullOrEmpty(isAutoIncrementString) && (isAutoIncrementString.Equals("YES", StringComparison.InvariantCultureIgnoreCase) || isAutoIncrementString.Equals("TRUE", StringComparison.InvariantCultureIgnoreCase))); + string isNullable = isNullableList[i] ?? "YES"; + string columnDefault = columnDefaultList[i] ?? ""; + // Spark/Databricks reports ordinal index zero-indexed, instead of one-indexed + int ordinalPos = ordinalPosList[i] + 1; + TableInfo? tableInfo = catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.GetValueOrDefault(tableName); + tableInfo?.ColumnName.Add(columnName); + tableInfo?.ColType.Add(colType); + tableInfo?.Nullable.Add(nullable); + tableInfo?.IsAutoIncrement.Add(isAutoIncrement); + tableInfo?.IsNullable.Add(isNullable); + tableInfo?.ColumnDefault.Add(columnDefault); + tableInfo?.OrdinalPosition.Add(ordinalPos); + SetPrecisionScaleAndTypeName(colType, typeName, tableInfo); + } } - else + + StringArray.Builder catalogNameBuilder = new StringArray.Builder(); + List catalogDbSchemasValues = new List(); + + foreach (KeyValuePair>> catalogEntry in catalogMap) { - catalogDbSchemasValues.Add(GetDbSchemas( - depth, catalogEntry.Value)); + catalogNameBuilder.Append(catalogEntry.Key); + + if (depth == GetObjectsDepth.Catalogs) + { + catalogDbSchemasValues.Add(null); + } + else + { + catalogDbSchemasValues.Add(GetDbSchemas( + depth, catalogEntry.Value)); + } } - } - Schema schema = StandardSchemas.GetObjectsSchema; - IReadOnlyList dataArrays = schema.Validate( - new List - { + Schema schema = StandardSchemas.GetObjectsSchema; + IReadOnlyList dataArrays = schema.Validate( + new List + { catalogNameBuilder.Build(), catalogDbSchemasValues.BuildListArrayForType(new StructType(StandardSchemas.DbSchemaSchema)), - }); + }); - return new SparkInfoArrowStream(schema, dataArrays); + return new SparkInfoArrowStream(schema, dataArrays); + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The metadata query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while running metadata query. '{ex.Message}'", ex); + } } private static IReadOnlyDictionary GetColumnIndexMap(List columns) => columns @@ -998,15 +1040,15 @@ protected static Uri GetBaseAddress(string? uri, string? hostName, string? path, protected abstract void ValidateAuthentication(); protected abstract void ValidateOptions(); - protected abstract Task GetRowSetAsync(TGetTableTypesResp response); - protected abstract Task GetRowSetAsync(TGetColumnsResp response); - protected abstract Task GetRowSetAsync(TGetTablesResp response); - protected abstract Task GetRowSetAsync(TGetCatalogsResp getCatalogsResp); - protected abstract Task GetRowSetAsync(TGetSchemasResp getSchemasResp); - protected abstract Task GetResultSetMetadataAsync(TGetSchemasResp response); - protected abstract Task GetResultSetMetadataAsync(TGetCatalogsResp response); - protected abstract Task GetResultSetMetadataAsync(TGetColumnsResp response); - protected abstract Task GetResultSetMetadataAsync(TGetTablesResp response); + protected abstract Task GetRowSetAsync(TGetTableTypesResp response, CancellationToken cancellationToken = default); + protected abstract Task GetRowSetAsync(TGetColumnsResp response, CancellationToken cancellationToken = default); + protected abstract Task GetRowSetAsync(TGetTablesResp response, CancellationToken cancellationToken = default); + protected abstract Task GetRowSetAsync(TGetCatalogsResp getCatalogsResp, CancellationToken cancellationToken = default); + protected abstract Task GetRowSetAsync(TGetSchemasResp getSchemasResp, CancellationToken cancellationToken = default); + protected abstract Task GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken cancellationToken = default); + protected abstract Task GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default); + protected abstract Task GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken cancellationToken = default); + protected abstract Task GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken cancellationToken = default); internal abstract SparkServerType ServerType { get; } diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs index 7d187fc71b..d51ef42b9b 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs @@ -16,6 +16,7 @@ */ using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Ipc; using Apache.Hive.Service.Rpc.Thrift; @@ -43,24 +44,24 @@ protected override TOpenSessionReq CreateSessionRequest() return req; } - protected override Task GetResultSetMetadataAsync(TGetSchemasResp response) => + protected override Task GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSetMetadata); - protected override Task GetResultSetMetadataAsync(TGetCatalogsResp response) => + protected override Task GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSetMetadata); - protected override Task GetResultSetMetadataAsync(TGetColumnsResp response) => + protected override Task GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSetMetadata); - protected override Task GetResultSetMetadataAsync(TGetTablesResp response) => + protected override Task GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSetMetadata); - protected override Task GetRowSetAsync(TGetTableTypesResp response) => + protected override Task GetRowSetAsync(TGetTableTypesResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSet.Results); - protected override Task GetRowSetAsync(TGetColumnsResp response) => + protected override Task GetRowSetAsync(TGetColumnsResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSet.Results); - protected override Task GetRowSetAsync(TGetTablesResp response) => + protected override Task GetRowSetAsync(TGetTablesResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSet.Results); - protected override Task GetRowSetAsync(TGetCatalogsResp response) => + protected override Task GetRowSetAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSet.Results); - protected override Task GetRowSetAsync(TGetSchemasResp response) => + protected override Task GetRowSetAsync(TGetSchemasResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSet.Results); } } diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs b/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs index 77ecdb6a20..059ab1690b 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs @@ -15,7 +15,6 @@ * limitations under the License. */ -using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; diff --git a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs index 9d34ac75c7..4c068aaa57 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs @@ -120,24 +120,19 @@ protected override void ValidateOptions() DataTypeConversion = DataTypeConversionParser.Parse(dataTypeConv); Properties.TryGetValue(SparkParameters.TLSOptions, out string? tlsOptions); TlsOptions = TlsOptionsParser.Parse(tlsOptions); - Properties.TryGetValue(SparkParameters.HttpRequestTimeoutMilliseconds, out string? requestTimeoutMs); - if (requestTimeoutMs != null) + Properties.TryGetValue(SparkParameters.ConnectTimeoutMilliseconds, out string? connectTimeoutMs); + if (connectTimeoutMs != null) { - HttpRequestTimeout = int.TryParse(requestTimeoutMs, NumberStyles.Integer, CultureInfo.InvariantCulture, out int requestTimeoutMsValue) && requestTimeoutMsValue > 0 - ? requestTimeoutMsValue - : throw new ArgumentOutOfRangeException(SparkParameters.HttpRequestTimeoutMilliseconds, requestTimeoutMs, $"must be a value between 1 .. {int.MaxValue}. default is 30000 milliseconds."); + ConnectTimeoutMilliseconds = int.TryParse(connectTimeoutMs, NumberStyles.Integer, CultureInfo.InvariantCulture, out int connectTimeoutMsValue) && (connectTimeoutMsValue >= 0) + ? connectTimeoutMsValue + : throw new ArgumentOutOfRangeException(SparkParameters.ConnectTimeoutMilliseconds, connectTimeoutMs, $"must be a value of 0 (infinite) or between 1 .. {int.MaxValue}. default is 30000 milliseconds."); } } internal override IArrowArrayStream NewReader(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion); - protected override Task CreateTransportAsync() + protected override TTransport CreateTransport() { - foreach (var property in Properties.Keys) - { - Trace.TraceError($"key = {property} value = {Properties[property]}"); - } - // Assumption: parameters have already been validated. Properties.TryGetValue(SparkParameters.HostName, out string? hostName); Properties.TryGetValue(SparkParameters.Path, out string? path); @@ -164,9 +159,12 @@ protected override Task CreateTransportAsync() TConfiguration config = new(); ThriftHttpTransport transport = new(httpClient, config) { - ConnectTimeout = HttpRequestTimeout, + // This value can only be set before the first call/request. So if a new value for query timeout + // is set, we won't be able to update the value. Setting to ~infinite and relying on cancellation token + // to ensure cancelled correctly. + ConnectTimeout = int.MaxValue, }; - return Task.FromResult(transport); + return transport; } private HttpClientHandler NewHttpClientHandler() @@ -211,11 +209,9 @@ private HttpClientHandler NewHttpClientHandler() } } - protected override async Task CreateProtocolAsync(TTransport transport) + protected override async Task CreateProtocolAsync(TTransport transport, CancellationToken cancellationToken = default) { - Trace.TraceError($"create protocol with {Properties.Count} properties."); - - if (!transport.IsOpen) await transport.OpenAsync(CancellationToken.None); + if (!transport.IsOpen) await transport.OpenAsync(cancellationToken); return new TBinaryProtocol(transport); } @@ -228,28 +224,29 @@ protected override TOpenSessionReq CreateSessionRequest() return req; } - protected override Task GetResultSetMetadataAsync(TGetSchemasResp response) => - GetResultSetMetadataAsync(response.OperationHandle, Client); - protected override Task GetResultSetMetadataAsync(TGetCatalogsResp response) => - GetResultSetMetadataAsync(response.OperationHandle, Client); - protected override Task GetResultSetMetadataAsync(TGetColumnsResp response) => - GetResultSetMetadataAsync(response.OperationHandle, Client); - protected override Task GetResultSetMetadataAsync(TGetTablesResp response) => - GetResultSetMetadataAsync(response.OperationHandle, Client); - protected override Task GetRowSetAsync(TGetTableTypesResp response) => - FetchResultsAsync(response.OperationHandle); - protected override Task GetRowSetAsync(TGetColumnsResp response) => - FetchResultsAsync(response.OperationHandle); - protected override Task GetRowSetAsync(TGetTablesResp response) => - FetchResultsAsync(response.OperationHandle); - protected override Task GetRowSetAsync(TGetCatalogsResp response) => - FetchResultsAsync(response.OperationHandle); - protected override Task GetRowSetAsync(TGetSchemasResp response) => - FetchResultsAsync(response.OperationHandle); + protected override Task GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken cancellationToken = default) => + GetResultSetMetadataAsync(response.OperationHandle, Client, cancellationToken); + protected override Task GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default) => + GetResultSetMetadataAsync(response.OperationHandle, Client, cancellationToken); + protected override Task GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken cancellationToken = default) => + GetResultSetMetadataAsync(response.OperationHandle, Client, cancellationToken); + protected override Task GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken cancellationToken = default) => + GetResultSetMetadataAsync(response.OperationHandle, Client, cancellationToken); + protected override Task GetRowSetAsync(TGetTableTypesResp response, CancellationToken cancellationToken = default) => + FetchResultsAsync(response.OperationHandle, cancellationToken: cancellationToken); + protected override Task GetRowSetAsync(TGetColumnsResp response, CancellationToken cancellationToken = default) => + FetchResultsAsync(response.OperationHandle, cancellationToken: cancellationToken); + protected override Task GetRowSetAsync(TGetTablesResp response, CancellationToken cancellationToken = default) => + FetchResultsAsync(response.OperationHandle, cancellationToken: cancellationToken); + protected override Task GetRowSetAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default) => + FetchResultsAsync(response.OperationHandle, cancellationToken: cancellationToken); + protected override Task GetRowSetAsync(TGetSchemasResp response, CancellationToken cancellationToken = default) => + FetchResultsAsync(response.OperationHandle, cancellationToken: cancellationToken); private async Task FetchResultsAsync(TOperationHandle operationHandle, long batchSize = BatchSizeDefault, CancellationToken cancellationToken = default) { - await PollForResponseAsync(operationHandle, Client, PollTimeMillisecondsDefault); + await PollForResponseAsync(operationHandle, Client, PollTimeMillisecondsDefault, cancellationToken); + TFetchResultsResp fetchResp = await FetchNextAsync(operationHandle, Client, batchSize, cancellationToken); if (fetchResp.Status.StatusCode == TStatusCode.ERROR_STATUS) { diff --git a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs index 4722efce54..6cb96dd5f1 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs @@ -15,8 +15,6 @@ * limitations under the License. */ -using static System.Net.WebRequestMethods; - namespace Apache.Arrow.Adbc.Drivers.Apache.Spark { /// @@ -32,7 +30,7 @@ public static class SparkParameters public const string Type = "adbc.spark.type"; public const string DataTypeConv = "adbc.spark.data_type_conv"; public const string TLSOptions = "adbc.spark.tls_options"; - public const string HttpRequestTimeoutMilliseconds = "adbc.spark.http_request_timeout_ms"; + public const string ConnectTimeoutMilliseconds = "adbc.spark.connect_timeout_ms"; } public static class SparkAuthTypeConstants diff --git a/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs index 51813ed6c4..c8ab5772c9 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs @@ -18,6 +18,7 @@ using System; using System.Collections.Generic; using System.Net; +using System.Threading; using System.Threading.Tasks; using Apache.Hive.Service.Rpc.Thrift; using Thrift.Protocol; @@ -85,7 +86,7 @@ protected override void ValidateConnection() } - protected override Task CreateTransportAsync() + protected override TTransport CreateTransport() { // Assumption: hostName and port have already been validated. Properties.TryGetValue(SparkParameters.HostName, out string? hostName); @@ -94,14 +95,13 @@ protected override Task CreateTransportAsync() // Delay the open connection until later. bool connectClient = false; ThriftSocketTransport transport = new(hostName!, int.Parse(port!), connectClient, config: new()); - return Task.FromResult(transport); + return transport; } - protected override async Task CreateProtocolAsync(TTransport transport) + protected override async Task CreateProtocolAsync(TTransport transport, CancellationToken cancellationToken = default) { - return await base.CreateProtocolAsync(transport); + return await base.CreateProtocolAsync(transport, cancellationToken); - //Trace.TraceError($"create protocol with {Properties.Count} properties."); //if (!transport.IsOpen) await transport.OpenAsync(CancellationToken.None); //return new TBinaryProtocol(transport); } diff --git a/csharp/src/Drivers/Apache/Spark/SparkStatement.cs b/csharp/src/Drivers/Apache/Spark/SparkStatement.cs index e4bc3f6cd3..25888b1a3b 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkStatement.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkStatement.cs @@ -32,6 +32,7 @@ internal SparkStatement(SparkConnection connection) { case Options.BatchSize: case Options.PollTimeMilliseconds: + case Options.QueryTimeoutSeconds: { SetOption(kvp.Key, kvp.Value); break; @@ -45,7 +46,9 @@ protected override void SetStatementProperties(TExecuteStatementReq statement) // TODO: Ensure this is set dynamically depending on server capabilities. statement.EnforceResultPersistenceMode = false; statement.ResultPersistenceMode = 2; - + // This seems like a good idea to have the server timeout so it doesn't keep processing unnecessarily. + // Set in combination with a CancellationToken. + statement.QueryTimeout = QueryTimeoutSeconds; statement.CanReadArrowResult = true; statement.CanDownloadResult = true; statement.ConfOverlay = SparkConnection.timestampConfig; @@ -65,7 +68,7 @@ protected override void SetStatementProperties(TExecuteStatementReq statement) /// /// Provides the constant string key values to the method. /// - public new sealed class Options : HiveServer2Statement.Options + public sealed class Options : ApacheParameters { // options specific to Spark go here } diff --git a/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs b/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs index fb62ccd9a7..ea3d7d16ec 100644 --- a/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs +++ b/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs @@ -45,11 +45,14 @@ public class ApacheTestConfiguration : TestConfiguration [JsonPropertyName("batch_size"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public string BatchSize { get; set; } = string.Empty; - [JsonPropertyName("polltime_milliseconds"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + [JsonPropertyName("polltime_ms"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public string PollTimeMilliseconds { get; set; } = string.Empty; - [JsonPropertyName("http_request_timeout_ms"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] - public string HttpRequestTimeoutMilliseconds { get; set; } = string.Empty; + [JsonPropertyName("connect_timeout_ms"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string ConnectTimeoutMilliseconds { get; set; } = string.Empty; + + [JsonPropertyName("query_timeout_s"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string QueryTimeoutSeconds { get; set; } = string.Empty; [JsonPropertyName("type"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public string Type { get; set; } = string.Empty; diff --git a/csharp/test/Drivers/Apache/Common/ClientTests.cs b/csharp/test/Drivers/Apache/Common/ClientTests.cs index e3b0309d08..9148d72811 100644 --- a/csharp/test/Drivers/Apache/Common/ClientTests.cs +++ b/csharp/test/Drivers/Apache/Common/ClientTests.cs @@ -17,6 +17,7 @@ using System; using System.Collections.Generic; +using Apache.Arrow.Adbc.Client; using Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2; using Apache.Arrow.Adbc.Tests.Xunit; using Xunit; @@ -203,6 +204,27 @@ public void VerifySchemaTables() } } + [SkippableFact] + public void VerifyTimeoutsSet() + { + using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection()) + { + int timeout = 99; + using AdbcCommand cmd = adbcConnection.CreateCommand(); + + // setting the timout before the property value + Assert.Throws(() => + { + cmd.CommandTimeout = 1; + }); + + cmd.AdbcCommandTimeoutProperty = "adbc.apache.statement.query_timeout_s"; + cmd.CommandTimeout = timeout; + + Assert.True(cmd.CommandTimeout == timeout, $"ConnectionTimeout is not set to {timeout}"); + } + } + private Adbc.Client.AdbcConnection GetAdbcConnection(bool includeTableConstraints = true) { return new Adbc.Client.AdbcConnection( diff --git a/csharp/test/Drivers/Apache/Common/StatementTests.cs b/csharp/test/Drivers/Apache/Common/StatementTests.cs index 69eec0dd26..b793b7686c 100644 --- a/csharp/test/Drivers/Apache/Common/StatementTests.cs +++ b/csharp/test/Drivers/Apache/Common/StatementTests.cs @@ -18,6 +18,7 @@ using System; using System.Collections.Generic; using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache; using Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2; using Apache.Arrow.Adbc.Tests.Xunit; using Xunit; @@ -68,11 +69,11 @@ public void CanSetOptionPollTime(string value, bool throws = false) AdbcStatement statement = NewConnection().CreateStatement(); if (throws) { - Assert.Throws(() => statement.SetOption(Adbc.Drivers.Apache.Hive2.HiveServer2Statement.Options.PollTimeMilliseconds, value)); + Assert.Throws(() => statement.SetOption(ApacheParameters.PollTimeMilliseconds, value)); } else { - statement.SetOption(Adbc.Drivers.Apache.Hive2.HiveServer2Statement.Options.PollTimeMilliseconds, value); + statement.SetOption(ApacheParameters.PollTimeMilliseconds, value); } } @@ -101,11 +102,74 @@ public void CanSetOptionBatchSize(string value, bool throws = false) AdbcStatement statement = NewConnection().CreateStatement(); if (throws) { - Assert.Throws(() => statement!.SetOption(Adbc.Drivers.Apache.Hive2.HiveServer2Statement.Options.BatchSize, value)); + Assert.Throws(() => statement!.SetOption(ApacheParameters.BatchSize, value)); } else { - statement.SetOption(Adbc.Drivers.Apache.Hive2.HiveServer2Statement.Options.BatchSize, value); + statement.SetOption(ApacheParameters.BatchSize, value); + } + } + + /// + /// Validates if the SetOption handle valid/invalid data correctly for the QueryTimeout option. + /// + [SkippableTheory] + [InlineData("zero", true)] + [InlineData("-2147483648", true)] + [InlineData("2147483648", true)] + [InlineData("0", false)] + [InlineData("-1", true)] + [InlineData("1")] + [InlineData("2147483647")] + public void CanSetOptionQueryTimeout(string value, bool throws = false) + { + var testConfiguration = TestConfiguration.Clone() as TConfig; + testConfiguration!.QueryTimeoutSeconds = value; + if (throws) + { + Assert.Throws(() => NewConnection(testConfiguration).CreateStatement()); + } + + AdbcStatement statement = NewConnection().CreateStatement(); + if (throws) + { + Assert.Throws(() => statement.SetOption(ApacheParameters.QueryTimeoutSeconds, value)); + } + else + { + statement.SetOption(ApacheParameters.QueryTimeoutSeconds, value); + } + } + + /// + /// Queries the backend with various timeouts. + /// + /// + [SkippableTheory] + [ClassData(typeof(StatementTimeoutTestData))] + internal void StatementTimeoutTest(StatementWithExceptions statementWithExceptions) + { + TConfig testConfiguration = (TConfig)TestConfiguration.Clone(); + + if (statementWithExceptions.QueryTimeoutSeconds.HasValue) + testConfiguration.QueryTimeoutSeconds = statementWithExceptions.QueryTimeoutSeconds.Value.ToString(); + + if (!string.IsNullOrEmpty(statementWithExceptions.Query)) + testConfiguration.Query = statementWithExceptions.Query!; + + OutputHelper?.WriteLine($"QueryTimeoutSeconds: {testConfiguration.QueryTimeoutSeconds}. ShouldSucceed: {statementWithExceptions.ExceptionType == null}. Query: [{testConfiguration.Query}]"); + + try + { + AdbcStatement st = NewConnection(testConfiguration).CreateStatement(); + st.SqlQuery = testConfiguration.Query; + QueryResult qr = st.ExecuteQuery(); + + OutputHelper?.WriteLine($"QueryResultRowCount: {qr.RowCount}"); + } + catch (Exception ex) when (ApacheUtility.ContainsException(ex, statementWithExceptions.ExceptionType, out Exception? containedException)) + { + Assert.IsType(statementWithExceptions.ExceptionType!, containedException!); } } @@ -116,10 +180,58 @@ public void CanSetOptionBatchSize(string value, bool throws = false) public async Task CanInteractUsingSetOptions() { const string columnName = "INDEX"; - Statement.SetOption(Adbc.Drivers.Apache.Hive2.HiveServer2Statement.Options.PollTimeMilliseconds, "100"); - Statement.SetOption(Adbc.Drivers.Apache.Hive2.HiveServer2Statement.Options.BatchSize, "10"); + Statement.SetOption(ApacheParameters.PollTimeMilliseconds, "100"); + Statement.SetOption(ApacheParameters.BatchSize, "10"); using TemporaryTable temporaryTable = await NewTemporaryTableAsync(Statement, $"{columnName} INT"); await ValidateInsertSelectDeleteSingleValueAsync(temporaryTable.TableName, columnName, 1); } } + + /// + /// Data type used for metadata timeout tests. + /// + internal class StatementWithExceptions + { + public StatementWithExceptions(int? queryTimeoutSeconds, string? query, Type? exceptionType) + { + QueryTimeoutSeconds = queryTimeoutSeconds; + Query = query; + ExceptionType = exceptionType; + } + + /// + /// If null, uses the default timeout. + /// + public int? QueryTimeoutSeconds { get; } + + /// + /// If null, expected to succeed. + /// + public Type? ExceptionType { get; } + + /// + /// If null, uses the default TestConfiguration. + /// + public string? Query { get; } + } + + /// + /// Collection of for testing statement timeouts."/> + /// + internal class StatementTimeoutTestData : TheoryData + { + public StatementTimeoutTestData() + { + string longRunningQuery = "SELECT COUNT(*) AS total_count\nFROM (\n SELECT t1.id AS id1, t2.id AS id2\n FROM RANGE(1000000) t1\n CROSS JOIN RANGE(10000) t2\n) subquery\nWHERE MOD(id1 + id2, 2) = 0"; + + Add(new(0, null, null)); + Add(new(null, null, null)); + Add(new(1, null, typeof(TimeoutException))); + Add(new(5, null, null)); + Add(new(30, null, null)); + Add(new(5, longRunningQuery, typeof(TimeoutException))); + Add(new(null, longRunningQuery, typeof(TimeoutException))); + Add(new(0, longRunningQuery, null)); + } + } } diff --git a/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs b/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs index c2faa9d12b..34e971bd8f 100644 --- a/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs +++ b/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs @@ -19,7 +19,10 @@ using System.Collections.Generic; using System.Globalization; using System.Net; +using Apache.Arrow.Adbc.Drivers.Apache; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; using Apache.Arrow.Adbc.Drivers.Apache.Spark; +using Thrift.Transport; using Xunit; using Xunit.Abstractions; @@ -48,6 +51,231 @@ internal void CanDetectConnectionParameterErrors(ParametersWithExceptions test) OutputHelper?.WriteLine(exeption.Message); } + /// + /// Tests connection timeout to establish a session with the backend. + /// + /// The timeout (in ms) + /// The exception type to expect (if any) + /// An alternate exception that may occur (if any) + [SkippableTheory] + [InlineData(0, null, null)] + [InlineData(1, typeof(TimeoutException), typeof(TTransportException))] + [InlineData(10, typeof(TimeoutException), typeof(TTransportException))] + [InlineData(30000, null, null)] + [InlineData(null, null, null)] + public void ConnectionTimeoutTest(int? connectTimeoutMilliseconds, Type? exceptionType, Type? alternateExceptionType) + { + SparkTestConfiguration testConfiguration = (SparkTestConfiguration)TestConfiguration.Clone(); + + if (connectTimeoutMilliseconds.HasValue) + testConfiguration.ConnectTimeoutMilliseconds = connectTimeoutMilliseconds.Value.ToString(); + + OutputHelper?.WriteLine($"ConnectTimeoutMilliseconds: {testConfiguration.ConnectTimeoutMilliseconds}. ShouldSucceed: {exceptionType == null}"); + + try + { + NewConnection(testConfiguration); + } + catch(AggregateException aex) + { + if (exceptionType != null) + { + if (alternateExceptionType != null && aex.InnerException?.GetType() != exceptionType) + { + if (aex.InnerException?.GetType() == typeof(HiveServer2Exception)) + { + // a TTransportException is inside a HiveServer2Exception + Assert.IsType(alternateExceptionType, aex.InnerException!.InnerException); + } + else + { + throw; + } + } + else + { + Assert.IsType(exceptionType, aex.InnerException); + } + } + else + { + throw; + } + } + } + + /// + /// Tests the various metadata calls on a SparkConnection + /// + /// + [SkippableTheory] + [ClassData(typeof(MetadataTimeoutTestData))] + internal void MetadataTimeoutTest(MetadataWithExceptions metadataWithException) + { + SparkTestConfiguration testConfiguration = (SparkTestConfiguration)TestConfiguration.Clone(); + + if (metadataWithException.QueryTimeoutSeconds.HasValue) + testConfiguration.QueryTimeoutSeconds = metadataWithException.QueryTimeoutSeconds.Value.ToString(); + + OutputHelper?.WriteLine($"Action: {metadataWithException.ActionName}. QueryTimeoutSeconds: {testConfiguration.QueryTimeoutSeconds}. ShouldSucceed: {metadataWithException.ExceptionType == null}"); + + try + { + metadataWithException.MetadataAction(testConfiguration); + } + catch (Exception ex) when (ApacheUtility.ContainsException(ex, metadataWithException.ExceptionType, out Exception? containedException)) + { + Assert.IsType(metadataWithException.ExceptionType!, containedException); + } + catch (Exception ex) when (ApacheUtility.ContainsException(ex, metadataWithException.AlternateExceptionType, out Exception? containedException)) + { + Assert.IsType(metadataWithException.AlternateExceptionType!, containedException); + } + } + + /// + /// Data type used for metadata timeout tests. + /// + internal class MetadataWithExceptions + { + public MetadataWithExceptions(int? queryTimeoutSeconds, string actionName, Action action, Type? exceptionType, Type? alternateExceptionType) + { + QueryTimeoutSeconds = queryTimeoutSeconds; + ActionName = actionName; + MetadataAction = action; + ExceptionType = exceptionType; + AlternateExceptionType = alternateExceptionType; + } + + /// + /// If null, uses the default timeout. + /// + public int? QueryTimeoutSeconds { get; } + + public string ActionName { get; } + + /// + /// If null, expected to succeed. + /// + public Type? ExceptionType { get; } + + /// + /// Sometimes you can expect one but may get another. + /// For example, on GetObjectsAll, sometimes a TTransportException is expected but a TaskCanceledException is received during the test. + /// + public Type? AlternateExceptionType { get; } + + /// + /// The metadata action to perform. + /// + public Action MetadataAction { get; } + } + + /// + /// Used for testing timeouts on metadata calls. + /// + internal class MetadataTimeoutTestData : TheoryData + { + public MetadataTimeoutTestData() + { + SparkConnectionTest sparkConnectionTest = new SparkConnectionTest(null); + + Action getObjectsAll = (testConfiguration) => + { + AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration); + cn.GetObjects(AdbcConnection.GetObjectsDepth.All, testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Table, null, null); + }; + + Action getObjectsCatalogs = (testConfiguration) => + { + AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration); + cn.GetObjects(AdbcConnection.GetObjectsDepth.Catalogs, testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Schema, null, null); + }; + + Action getObjectsDbSchemas = (testConfiguration) => + { + AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration); + cn.GetObjects(AdbcConnection.GetObjectsDepth.DbSchemas, testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Schema, null, null); + }; + + Action getObjectsTables = (testConfiguration) => + { + AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration); + cn.GetObjects(AdbcConnection.GetObjectsDepth.Tables, testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Schema, null, null); + }; + + AddAction("getObjectsAll", getObjectsAll, new List() { null, typeof(TimeoutException), null, null, null } ); + AddAction("getObjectsCatalogs", getObjectsCatalogs); + AddAction("getObjectsDbSchemas", getObjectsDbSchemas); + AddAction("getObjectsTables", getObjectsTables); + + Action getTableTypes = (testConfiguration) => + { + AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration); + cn.GetTableTypes(); + }; + + AddAction("getTableTypes", getTableTypes); + + Action getTableSchema = (testConfiguration) => + { + AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration); + cn.GetTableSchema(testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Table); + }; + + AddAction("getTableSchema", getTableSchema); + } + + /// + /// Adds the action with the default timeouts. + /// + /// The friendly name of the action. + /// The action to perform. + /// Optional list of alternate exceptions that are possible. Must have 5 items if present. + private void AddAction(string name, Action action, List? alternateExceptions = null) + { + List expectedExceptions = new List() + { + null, // QueryTimeout = 0 + typeof(TTransportException), // QueryTimeout = 1 + typeof(TimeoutException), // QueryTimeout = 10 + null, // QueryTimeout = default + null // QueryTimeout = 300 + }; + + AddAction(name, action, expectedExceptions, alternateExceptions); + } + + /// + /// Adds the action with the default timeouts. + /// + /// The action to perform. + /// The expected exceptions. + /// + /// For List the position is based on the behavior when: + /// [0] QueryTimeout = 0 + /// [1] QueryTimeout = 1 + /// [2] QueryTimeout = 10 + /// [3] QueryTimeout = default + /// [4] QueryTimeout = 300 + /// + private void AddAction(string name, Action action, List expectedExceptions, List? alternateExceptions) + { + Assert.True(expectedExceptions.Count == 5); + + if (alternateExceptions != null) + { + Assert.True(alternateExceptions.Count == 5); + } + + Add(new(0, name, action, expectedExceptions[0], alternateExceptions?[0])); + Add(new(1, name, action, expectedExceptions[1], alternateExceptions?[1])); + Add(new(10, name, action, expectedExceptions[2], alternateExceptions?[2])); + Add(new(null, name, action, expectedExceptions[3], alternateExceptions?[3])); + Add(new(300, name, action, expectedExceptions[4], alternateExceptions?[4])); + } + } + internal class ParametersWithExceptions { public ParametersWithExceptions(Dictionary parameters, Type exceptionType) @@ -85,11 +313,9 @@ public InvalidConnectionParametersTestData() Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] = "httpxxz://hostname.com" }, typeof(ArgumentOutOfRangeException))); Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] = "http-//hostname.com" }, typeof(UriFormatException))); Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] = "httpxxz://hostname.com:1234567890" }, typeof(UriFormatException))); - Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword" , [SparkParameters.HttpRequestTimeoutMilliseconds] = "0" }, typeof(ArgumentOutOfRangeException))); - Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [SparkParameters.HttpRequestTimeoutMilliseconds] = "-1" }, typeof(ArgumentOutOfRangeException))); - Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [SparkParameters.HttpRequestTimeoutMilliseconds] = ((long)int.MaxValue + 1).ToString() }, typeof(ArgumentOutOfRangeException))); - Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [SparkParameters.HttpRequestTimeoutMilliseconds] = "non-numeric" }, typeof(ArgumentOutOfRangeException))); - Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [SparkParameters.HttpRequestTimeoutMilliseconds] = "" }, typeof(ArgumentOutOfRangeException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [SparkParameters.ConnectTimeoutMilliseconds] = ((long)int.MaxValue + 1).ToString() }, typeof(ArgumentOutOfRangeException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [SparkParameters.ConnectTimeoutMilliseconds] = "non-numeric" }, typeof(ArgumentOutOfRangeException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [SparkParameters.ConnectTimeoutMilliseconds] = "" }, typeof(ArgumentOutOfRangeException))); } } } diff --git a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs index 54c5368536..16a5501118 100644 --- a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs +++ b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.Data.SqlTypes; using System.Text; +using Apache.Arrow.Adbc.Drivers.Apache; using Apache.Arrow.Adbc.Drivers.Apache.Hive2; using Apache.Arrow.Adbc.Drivers.Apache.Spark; using Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2; @@ -102,15 +103,19 @@ public override Dictionary GetDriverParameters(SparkTestConfigur } if (!string.IsNullOrEmpty(testConfiguration.BatchSize)) { - parameters.Add(HiveServer2Statement.Options.BatchSize, testConfiguration.BatchSize!); + parameters.Add(ApacheParameters.BatchSize, testConfiguration.BatchSize!); } if (!string.IsNullOrEmpty(testConfiguration.PollTimeMilliseconds)) { - parameters.Add(HiveServer2Statement.Options.PollTimeMilliseconds, testConfiguration.PollTimeMilliseconds!); + parameters.Add(ApacheParameters.PollTimeMilliseconds, testConfiguration.PollTimeMilliseconds!); } - if (!string.IsNullOrEmpty(testConfiguration.HttpRequestTimeoutMilliseconds)) + if (!string.IsNullOrEmpty(testConfiguration.ConnectTimeoutMilliseconds)) { - parameters.Add(SparkParameters.HttpRequestTimeoutMilliseconds, testConfiguration.HttpRequestTimeoutMilliseconds!); + parameters.Add(SparkParameters.ConnectTimeoutMilliseconds, testConfiguration.ConnectTimeoutMilliseconds!); + } + if (!string.IsNullOrEmpty(testConfiguration.QueryTimeoutSeconds)) + { + parameters.Add(ApacheParameters.QueryTimeoutSeconds, testConfiguration.QueryTimeoutSeconds!); } return parameters; diff --git a/csharp/test/Drivers/Apache/Spark/StatementTests.cs b/csharp/test/Drivers/Apache/Spark/StatementTests.cs index 25d27179ac..aaafc31ba7 100644 --- a/csharp/test/Drivers/Apache/Spark/StatementTests.cs +++ b/csharp/test/Drivers/Apache/Spark/StatementTests.cs @@ -15,6 +15,8 @@ * limitations under the License. */ +using System; +using Xunit; using Xunit.Abstractions; namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark