Skip to content

Commit

Permalink
feat(csharp/src/Drivers/Apache): add connect and query timeout options (
Browse files Browse the repository at this point in the history
apache#2312)

Adds options for command and query timeout

| Property               | Description | Default |
| :---                   | :---        | :---    |
| `adbc.spark.connect_timeout_ms` | Sets the timeout (in milliseconds)
to open a new session. Values can be 0 (infinite) or greater than zero.
| `30000` |
| `adbc.apache.statement.query_timeout_s` | Sets the maximum time (in
seconds) for a query to complete. Values can be 0 (infinite) or greater
than zero. | `60` |

---------

Co-authored-by: Aman Goyal <[email protected]>
Co-authored-by: David Coe <[email protected]>
  • Loading branch information
3 people authored Dec 4, 2024
1 parent 08ac053 commit 845f9c2
Show file tree
Hide file tree
Showing 23 changed files with 1,098 additions and 353 deletions.
28 changes: 26 additions & 2 deletions csharp/src/Client/AdbcCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
using System.Data;
using System.Data.Common;
using System.Data.SqlTypes;
using System.Globalization;
using System.Linq;
using System.Threading.Tasks;
using Apache.Arrow.Types;
Expand All @@ -32,10 +33,11 @@ namespace Apache.Arrow.Adbc.Client
/// </summary>
public sealed class AdbcCommand : DbCommand
{
private AdbcStatement _adbcStatement;
private readonly AdbcStatement _adbcStatement;
private AdbcParameterCollection? _dbParameterCollection;
private int _timeout = 30;
private bool _disposed;
private string? _commandTimeoutProperty;

/// <summary>
/// Overloaded. Initializes <see cref="AdbcCommand"/>.
Expand Down Expand Up @@ -117,10 +119,32 @@ public override CommandType CommandType
}
}


/// <summary>
/// Gets or sets the name of the command timeout property for the underlying ADBC driver.
/// </summary>
public string AdbcCommandTimeoutProperty
{
get
{
if (string.IsNullOrEmpty(_commandTimeoutProperty))
throw new InvalidOperationException("CommandTimeoutProperty is not set.");

return _commandTimeoutProperty!;
}
set => _commandTimeoutProperty = value;
}

public override int CommandTimeout
{
get => _timeout;
set => _timeout = value;
set
{
// ensures the property exists before setting the CommandTimeout value
string property = AdbcCommandTimeoutProperty;
_adbcStatement.SetOption(property, value.ToString(CultureInfo.InvariantCulture));
_timeout = value;
}
}

protected override DbParameterCollection DbParameterCollection
Expand Down
29 changes: 29 additions & 0 deletions csharp/src/Drivers/Apache/ApacheParameters.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

namespace Apache.Arrow.Adbc.Drivers.Apache
{
/// <summary>
/// Options common to all Apache drivers.
/// </summary>
public class ApacheParameters
{
public const string PollTimeMilliseconds = "adbc.apache.statement.polltime_ms";
public const string BatchSize = "adbc.apache.statement.batch_size";
public const string QueryTimeoutSeconds = "adbc.apache.statement.query_timeout_s";
}
}
141 changes: 141 additions & 0 deletions csharp/src/Drivers/Apache/ApacheUtility.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System;
using System.Threading;

namespace Apache.Arrow.Adbc.Drivers.Apache
{
internal class ApacheUtility
{
internal const int QueryTimeoutSecondsDefault = 60;

public enum TimeUnit
{
Seconds,
Milliseconds
}

public static CancellationToken GetCancellationToken(int timeout, TimeUnit timeUnit)
{
TimeSpan span;

if (timeout == 0 || timeout == int.MaxValue)
{
// the max TimeSpan for CancellationTokenSource is int.MaxValue in milliseconds (not TimeSpan.MaxValue)
// no matter what the unit is
span = TimeSpan.FromMilliseconds(int.MaxValue);
}
else
{
if (timeUnit == TimeUnit.Seconds)
{
span = TimeSpan.FromSeconds(timeout);
}
else
{
span = TimeSpan.FromMilliseconds(timeout);
}
}

return GetCancellationToken(span);
}

private static CancellationToken GetCancellationToken(TimeSpan timeSpan)
{
var cts = new CancellationTokenSource(timeSpan);
return cts.Token;
}

public static bool QueryTimeoutIsValid(string key, string value, out int queryTimeoutSeconds)
{
if (!string.IsNullOrEmpty(value) && int.TryParse(value, out int queryTimeout) && (queryTimeout >= 0))
{
queryTimeoutSeconds = queryTimeout;
return true;
}
else
{
throw new ArgumentOutOfRangeException(key, value, $"The value '{value}' for option '{key}' is invalid. Must be a numeric value of 0 (infinite) or greater.");
}
}

public static bool ContainsException<T>(Exception exception, out T? containedException) where T : Exception
{
if (exception is AggregateException aggregateException)
{
foreach (Exception? ex in aggregateException.InnerExceptions)
{
if (ex is T ce)
{
containedException = ce;
return true;
}
}
}

Exception? e = exception;
while (e != null)
{
if (e is T ce)
{
containedException = ce;
return true;
}
e = e.InnerException;
}

containedException = null;
return false;
}

public static bool ContainsException(Exception exception, Type? exceptionType, out Exception? containedException)
{
if (exception == null || exceptionType == null)
{
containedException = null;
return false;
}

if (exception is AggregateException aggregateException)
{
foreach (Exception? ex in aggregateException.InnerExceptions)
{
if (exceptionType.IsInstanceOfType(ex))
{
containedException = ex;
return true;
}
}
}

Exception? e = exception;
while (e != null)
{
if (exceptionType.IsInstanceOfType(e))
{
containedException = e;
return true;
}
e = e.InnerException;
}

containedException = null;
return false;
}
}
}
106 changes: 73 additions & 33 deletions csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ internal abstract class HiveServer2Connection : AdbcConnection
{
internal const long BatchSizeDefault = 50000;
internal const int PollTimeMillisecondsDefault = 500;

private const int ConnectTimeoutMillisecondsDefault = 30000;
private TTransport? _transport;
private TCLIService.Client? _client;
private readonly Lazy<string> _vendorVersion;
Expand All @@ -45,6 +45,14 @@ internal HiveServer2Connection(IReadOnlyDictionary<string, string> properties)
// https://learn.microsoft.com/en-us/dotnet/framework/performance/lazy-initialization#exceptions-in-lazy-objects
_vendorVersion = new Lazy<string>(() => GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_VER), LazyThreadSafetyMode.PublicationOnly);
_vendorName = new Lazy<string>(() => GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_NAME), LazyThreadSafetyMode.PublicationOnly);

if (properties.TryGetValue(ApacheParameters.QueryTimeoutSeconds, out string? queryTimeoutSecondsSettingValue))
{
if (ApacheUtility.QueryTimeoutIsValid(ApacheParameters.QueryTimeoutSeconds, queryTimeoutSecondsSettingValue, out int queryTimeoutSeconds))
{
QueryTimeoutSeconds = queryTimeoutSeconds;
}
}
}

internal TCLIService.Client Client
Expand All @@ -56,30 +64,48 @@ internal TCLIService.Client Client

internal string VendorName => _vendorName.Value;

protected internal int QueryTimeoutSeconds { get; set; } = ApacheUtility.QueryTimeoutSecondsDefault;

internal IReadOnlyDictionary<string, string> Properties { get; }

internal async Task OpenAsync()
{
TTransport transport = await CreateTransportAsync();
TProtocol protocol = await CreateProtocolAsync(transport);
_transport = protocol.Transport;
_client = new TCLIService.Client(protocol);
TOpenSessionReq request = CreateSessionRequest();
TOpenSessionResp? session = await Client.OpenSession(request);

// Some responses don't raise an exception. Explicitly check the status.
if (session == null)
CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(ConnectTimeoutMilliseconds, ApacheUtility.TimeUnit.Milliseconds);
try
{
throw new HiveServer2Exception("unable to open session. unknown error.");
TTransport transport = CreateTransport();
TProtocol protocol = await CreateProtocolAsync(transport, cancellationToken);
_transport = protocol.Transport;
_client = new TCLIService.Client(protocol);
TOpenSessionReq request = CreateSessionRequest();

TOpenSessionResp? session = await Client.OpenSession(request, cancellationToken);

// Explicitly check the session status
if (session == null)
{
throw new HiveServer2Exception("Unable to open session. Unknown error.");
}
else if (session.Status.StatusCode != TStatusCode.SUCCESS_STATUS)
{
throw new HiveServer2Exception(session.Status.ErrorMessage)
.SetNativeError(session.Status.ErrorCode)
.SetSqlState(session.Status.SqlState);
}

SessionHandle = session.SessionHandle;
}
else if (session.Status.StatusCode != TStatusCode.SUCCESS_STATUS)
catch (Exception ex)
when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) ||
(ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested))
{
throw new HiveServer2Exception(session.Status.ErrorMessage)
.SetNativeError(session.Status.ErrorCode)
.SetSqlState(session.Status.SqlState);
throw new TimeoutException("The operation timed out while attempting to open a session. Please try increasing connect timeout.", ex);
}
catch (Exception ex) when (ex is not HiveServer2Exception)
{
// Handle other exceptions if necessary
throw new HiveServer2Exception($"An unexpected error occurred while opening the session. '{ex.Message}'", ex);
}

SessionHandle = session.SessionHandle;
}

internal TSessionHandle? SessionHandle { get; private set; }
Expand All @@ -88,11 +114,11 @@ internal async Task OpenAsync()

protected internal HiveServer2TlsOption TlsOptions { get; set; } = HiveServer2TlsOption.Empty;

protected internal int HttpRequestTimeout { get; set; } = 30000;
protected internal int ConnectTimeoutMilliseconds { get; set; } = ConnectTimeoutMillisecondsDefault;

protected abstract Task<TTransport> CreateTransportAsync();
protected abstract TTransport CreateTransport();

protected abstract Task<TProtocol> CreateProtocolAsync(TTransport transport);
protected abstract Task<TProtocol> CreateProtocolAsync(TTransport transport, CancellationToken cancellationToken = default);

protected abstract TOpenSessionReq CreateSessionRequest();

Expand All @@ -110,14 +136,14 @@ public override IArrowArrayStream GetTableTypes()
throw new NotImplementedException();
}

internal static async Task PollForResponseAsync(TOperationHandle operationHandle, TCLIService.IAsync client, int pollTimeMilliseconds)
internal static async Task PollForResponseAsync(TOperationHandle operationHandle, TCLIService.IAsync client, int pollTimeMilliseconds, CancellationToken cancellationToken = default)
{
TGetOperationStatusResp? statusResponse = null;
do
{
if (statusResponse != null) { await Task.Delay(pollTimeMilliseconds); }
if (statusResponse != null) { await Task.Delay(pollTimeMilliseconds, cancellationToken); }
TGetOperationStatusReq request = new(operationHandle);
statusResponse = await client.GetOperationStatus(request);
statusResponse = await client.GetOperationStatus(request, cancellationToken);
} while (statusResponse.OperationState == TOperationState.PENDING_STATE || statusResponse.OperationState == TOperationState.RUNNING_STATE);
}

Expand All @@ -129,24 +155,38 @@ private string GetInfoTypeStringValue(TGetInfoType infoType)
InfoType = infoType,
};

TGetInfoResp getInfoResp = Client.GetInfo(req).Result;
if (getInfoResp.Status.StatusCode == TStatusCode.ERROR_STATUS)
CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds);
try
{
throw new HiveServer2Exception(getInfoResp.Status.ErrorMessage)
.SetNativeError(getInfoResp.Status.ErrorCode)
.SetSqlState(getInfoResp.Status.SqlState);
TGetInfoResp getInfoResp = Client.GetInfo(req, cancellationToken).Result;
if (getInfoResp.Status.StatusCode == TStatusCode.ERROR_STATUS)
{
throw new HiveServer2Exception(getInfoResp.Status.ErrorMessage)
.SetNativeError(getInfoResp.Status.ErrorCode)
.SetSqlState(getInfoResp.Status.SqlState);
}

return getInfoResp.InfoValue.StringValue;
}
catch (Exception ex)
when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) ||
(ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested))
{
throw new TimeoutException("The metadata query execution timed out. Consider increasing the query timeout value.", ex);
}
catch (Exception ex) when (ex is not HiveServer2Exception)
{
throw new HiveServer2Exception($"An unexpected error occurred while running metadata query. '{ex.Message}'", ex);
}

return getInfoResp.InfoValue.StringValue;
}

public override void Dispose()
{
if (_client != null)
{
TCloseSessionReq r6 = new TCloseSessionReq(SessionHandle);
_client.CloseSession(r6).Wait();

CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds);
TCloseSessionReq r6 = new(SessionHandle);
_client.CloseSession(r6, cancellationToken).Wait();
_transport?.Close();
_client.Dispose();
_transport = null;
Expand Down
Loading

0 comments on commit 845f9c2

Please sign in to comment.