Skip to content

Commit

Permalink
fix(go/adbc/driver/snowflake): Made GetObjects case insensitive (apac…
Browse files Browse the repository at this point in the history
…he#1328)

### Description:
`GetObjects` API was inconsistent case sensitivity for patterns.
`getObjectsDbSchemas` driver implementation used `LIKE` whereas
`getObjectsTables` used `ILIKE`

### Solution:
Based on discussion here:
apache#1314 changed `GetObjects`
API to use `ILIKE` throughout instead of `LIKE`

### Testing:
Added tests for lowercase, uppercase and `_` wildcard

Fixes apache#1314.
  • Loading branch information
ryan-syed authored Dec 4, 2023
1 parent 163bc03 commit 7a06864
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 140 deletions.
196 changes: 73 additions & 123 deletions csharp/test/Drivers/Snowflake/DriverTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,45 @@ public class DriverTests : IDisposable
readonly AdbcDriver _snowflakeDriver;
readonly AdbcDatabase _database;
readonly AdbcConnection _connection;
readonly List<string> _tableTypes;

public static IEnumerable<object[]> GetPatterns(string name)
{
if (string.IsNullOrEmpty(name)) yield break;

yield return new object[] { name };
yield return new object[] { $"{DriverTests.GetPartialNameForPatternMatch(name)}%" };
yield return new object[] { $"{DriverTests.GetPartialNameForPatternMatch(name).ToLower()}%" };
yield return new object[] { $"{DriverTests.GetPartialNameForPatternMatch(name).ToUpper()}%" };
yield return new object[] { $"_{DriverTests.GetNameWithoutFirstChatacter(name)}" };
yield return new object[] { $"_{DriverTests.GetNameWithoutFirstChatacter(name).ToLower()}" };
yield return new object[] { $"_{DriverTests.GetNameWithoutFirstChatacter(name).ToUpper()}" };
}

public static IEnumerable<object[]> CatalogNamePatternData()
{
string databaseName = SnowflakeTestingUtils.TestConfiguration?.Metadata.Catalog;
return GetPatterns(databaseName);
}

public static IEnumerable<object[]> DbSchemasNamePatternData()
{
string dbSchemaName = SnowflakeTestingUtils.TestConfiguration?.Metadata.Schema;
return GetPatterns(dbSchemaName);
}

public static IEnumerable<object[]> TableNamePatternData()
{
string tableName = SnowflakeTestingUtils.TestConfiguration?.Metadata.Table;
return GetPatterns(tableName);
}

public DriverTests()
{
Skip.IfNot(Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE));
_testConfiguration = Utils.LoadTestConfiguration<SnowflakeTestConfiguration>(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE);
_testConfiguration = SnowflakeTestingUtils.TestConfiguration;

_tableTypes = new List<string> { "BASE TABLE", "VIEW" };
Dictionary<string, string> parameters = new Dictionary<string, string>();
Dictionary<string, string> options = new Dictionary<string, string>();
_snowflakeDriver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(_testConfiguration, out parameters);
Expand Down Expand Up @@ -110,191 +143,98 @@ public void CanGetInfo()
}
}

/// <summary>
/// Validates if the driver can call GetObjects with GetObjectsDepth as Catalogs.
/// </summary>
[SkippableFact, Order(3)]
public void CanGetObjectsCatalogs()
{
string databaseName = _testConfiguration.Metadata.Catalog;

using IArrowArrayStream stream = _connection.GetObjects(
depth: AdbcConnection.GetObjectsDepth.Catalogs,
catalogPattern: databaseName,
dbSchemaPattern: null,
tableNamePattern: null,
tableTypes: new List<string> { "BASE TABLE", "VIEW" },
columnNamePattern: null);

using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result;

List<AdbcCatalog> catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, null);

AdbcCatalog catalog = catalogs.FirstOrDefault();

Assert.True(catalog != null, "catalog should not be null");
Assert.Equal(databaseName, catalog.Name);
}

/// <summary>
/// Validates if the driver can call GetObjects with GetObjectsDepth as Catalogs and CatalogName passed as a pattern.
/// </summary>
[SkippableFact, Order(3)]
public void CanGetObjectsCatalogsWithPattern()
[SkippableTheory, Order(3)]
[MemberData(nameof(CatalogNamePatternData))]
public void CanGetObjectsCatalogs(string catalogPattern)
{
string databaseName = _testConfiguration.Metadata.Catalog;
string schemaName = _testConfiguration.Metadata.Schema;
string partialDatabaseName = GetPartialNameForPatternMatch(databaseName);


using IArrowArrayStream stream = _connection.GetObjects(
depth: AdbcConnection.GetObjectsDepth.Catalogs,
catalogPattern: $"{partialDatabaseName}%",
catalogPattern: catalogPattern,
dbSchemaPattern: null,
tableNamePattern: null,
tableTypes: new List<string> { "BASE TABLE", "VIEW" },
tableTypes: _tableTypes,
columnNamePattern: null);

using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result;

List<AdbcCatalog> catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, null);

AdbcCatalog catalog = catalogs.FirstOrDefault();
AdbcCatalog catalog = catalogs.Where((catalog) => string.Equals(catalog.Name, databaseName)).FirstOrDefault();

Assert.True(catalog != null, "catalog should not be null");
Assert.StartsWith(databaseName, catalog.Name);
}

/// <summary>
/// Validates if the driver can call GetObjects with GetObjectsDepth as DbSchemas.
/// </summary>
[SkippableFact, Order(3)]
public void CanGetObjectsDbSchemas()
{
// need to add the database
string databaseName = _testConfiguration.Metadata.Catalog;
string schemaName = _testConfiguration.Metadata.Schema;

using IArrowArrayStream stream = _connection.GetObjects(
depth: AdbcConnection.GetObjectsDepth.DbSchemas,
catalogPattern: databaseName,
dbSchemaPattern: schemaName,
tableNamePattern: null,
tableTypes: new List<string> { "BASE TABLE", "VIEW" },
columnNamePattern: null);

using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result;

List<AdbcCatalog> catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName);

List<AdbcDbSchema> dbSchemas = catalogs
.Select(s => s.DbSchemas)
.FirstOrDefault();
AdbcDbSchema dbSchema = dbSchemas.FirstOrDefault();

Assert.True(dbSchema != null, "dbSchema should not be null");
Assert.Equal(schemaName, dbSchema.Name);
}

/// <summary>
/// Validates if the driver can call GetObjects with GetObjectsDepth as DbSchemas with DbSchemaName as a pattern.
/// </summary>
[SkippableFact, Order(3)]
public void CanGetObjectsDbSchemasWithPattern()
[SkippableTheory, Order(3)]
[MemberData(nameof(DbSchemasNamePatternData))]
public void CanGetObjectsDbSchemas(string dbSchemaPattern)
{
// need to add the database
string databaseName = _testConfiguration.Metadata.Catalog;
string schemaName = _testConfiguration.Metadata.Schema;
string partialSchemaName = GetPartialNameForPatternMatch(schemaName);

using IArrowArrayStream stream = _connection.GetObjects(
depth: AdbcConnection.GetObjectsDepth.DbSchemas,
catalogPattern: databaseName,
dbSchemaPattern: $"{partialSchemaName}%",
dbSchemaPattern: dbSchemaPattern,
tableNamePattern: null,
tableTypes: new List<string> { "BASE TABLE", "VIEW" },
tableTypes: _tableTypes,
columnNamePattern: null);

using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result;

List<AdbcCatalog> catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName);

List<AdbcDbSchema> dbSchemas = catalogs
.Select(s => s.DbSchemas)
.Where(c => string.Equals(c.Name, databaseName))
.Select(c => c.DbSchemas)
.FirstOrDefault();
AdbcDbSchema dbSchema = dbSchemas.FirstOrDefault();
AdbcDbSchema dbSchema = dbSchemas.Where((dbSchema) => string.Equals(dbSchema.Name, schemaName)).FirstOrDefault();

Assert.True(dbSchema != null, "dbSchema should not be null");
Assert.StartsWith(partialSchemaName, dbSchema.Name);
}

/// <summary>
/// Validates if the driver can call GetObjects with GetObjectsDepth as Tables.
/// </summary>
[SkippableFact, Order(3)]
public void CanGetObjectsTables()
{
// need to add the database
string databaseName = _testConfiguration.Metadata.Catalog;
string schemaName = _testConfiguration.Metadata.Schema;
string tableName = _testConfiguration.Metadata.Table;

using IArrowArrayStream stream = _connection.GetObjects(
depth: AdbcConnection.GetObjectsDepth.All,
catalogPattern: databaseName,
dbSchemaPattern: schemaName,
tableNamePattern: tableName,
tableTypes: new List<string> { "BASE TABLE", "VIEW" },
columnNamePattern: null);

using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result;

List<AdbcCatalog> catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName);

List<AdbcTable> tables = catalogs
.Select(s => s.DbSchemas)
.FirstOrDefault()
.Select(t => t.Tables)
.FirstOrDefault();
AdbcTable table = tables.FirstOrDefault();

Assert.True(table != null, "table should not be null");
Assert.Equal(tableName, table.Name);
}

/// <summary>
/// Validates if the driver can call GetObjects with GetObjectsDepth as Tables with TableName as a pattern.
/// </summary>
[SkippableFact, Order(3)]
public void CanGetObjectsTablesWithPattern()
[SkippableTheory, Order(3)]
[MemberData(nameof(TableNamePatternData))]
public void CanGetObjectsTables(string tableNamePattern)
{
// need to add the database
string databaseName = _testConfiguration.Metadata.Catalog;
string schemaName = _testConfiguration.Metadata.Schema;
string tableName = _testConfiguration.Metadata.Table;
string partialTableName = GetPartialNameForPatternMatch(tableName);

using IArrowArrayStream stream = _connection.GetObjects(
depth: AdbcConnection.GetObjectsDepth.All,
catalogPattern: databaseName,
dbSchemaPattern: schemaName,
tableNamePattern: $"{partialTableName}%",
tableTypes: new List<string> { "BASE TABLE", "VIEW" },
tableNamePattern: tableNamePattern,
tableTypes: _tableTypes,
columnNamePattern: null);

using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result;

List<AdbcCatalog> catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName);

List<AdbcTable> tables = catalogs
.Select(s => s.DbSchemas)
.Where(c => string.Equals(c.Name, databaseName))
.Select(c => c.DbSchemas)
.FirstOrDefault()
.Select(t => t.Tables)
.Where(s => string.Equals(s.Name, schemaName))
.Select(s => s.Tables)
.FirstOrDefault();
AdbcTable table = tables.FirstOrDefault();

AdbcTable table = tables.Where((table) => string.Equals(table.Name, tableName)).FirstOrDefault();
Assert.True(table != null, "table should not be null");
Assert.StartsWith(partialTableName, table.Name);
}

/// <summary>
Expand All @@ -314,19 +254,22 @@ public void CanGetObjectsAll()
catalogPattern: databaseName,
dbSchemaPattern: schemaName,
tableNamePattern: tableName,
tableTypes: new List<string> { "BASE TABLE", "VIEW" },
tableTypes: _tableTypes,
columnNamePattern: columnName);

using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result;

List<AdbcCatalog> catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName);

List<AdbcColumn> columns = catalogs
.Select(s => s.DbSchemas)
.Where(c => string.Equals(c.Name, databaseName))
.Select(c => c.DbSchemas)
.FirstOrDefault()
.Select(t => t.Tables)
.Where(s => string.Equals(s.Name, schemaName))
.Select(s => s.Tables)
.FirstOrDefault()
.Select(c => c.Columns)
.Where(t => string.Equals(t.Name, tableName))
.Select(t => t.Columns)
.FirstOrDefault();

Assert.True(columns != null, "Columns cannot be null");
Expand Down Expand Up @@ -411,13 +354,20 @@ public void CanExecuteQuery()
Tests.DriverTests.CanExecuteQuery(queryResult, _testConfiguration.ExpectedResultsCount);
}

private string GetPartialNameForPatternMatch(string name)
private static string GetPartialNameForPatternMatch(string name)
{
if (string.IsNullOrEmpty(name) || name.Length == 1) return name;

return name.Substring(0, name.Length / 2);
}

private static string GetNameWithoutFirstChatacter(string name)
{
if (string.IsNullOrEmpty(name)) return name;

return name.Substring(1);
}

public void Dispose()
{
_connection.Dispose();
Expand Down
24 changes: 20 additions & 4 deletions csharp/test/Drivers/Snowflake/SnowflakeTestingUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
* limitations under the License.
*/

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Apache.Arrow.Adbc.C;
using Xunit;

namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
{
Expand All @@ -40,8 +42,22 @@ internal class SnowflakeParameters

internal class SnowflakeTestingUtils
{
internal static readonly SnowflakeTestConfiguration TestConfiguration;

internal const string SNOWFLAKE_TEST_CONFIG_VARIABLE = "SNOWFLAKE_TEST_CONFIG_FILE";

static SnowflakeTestingUtils()
{
try
{
TestConfiguration = Utils.LoadTestConfiguration<SnowflakeTestConfiguration>(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE);
}
catch (InvalidOperationException ex)
{
Console.WriteLine(ex.Message);
}
}

/// <summary>
/// Gets a the Snowflake ADBC driver with settings from the
/// <see cref="SnowflakeTestConfiguration"/>.
Expand All @@ -66,12 +82,12 @@ out Dictionary<string, string> parameters
{ SnowflakeParameters.USE_HIGH_PRECISION, testConfiguration.UseHighPrecision.ToString().ToLowerInvariant() }
};

if(!string.IsNullOrWhiteSpace(testConfiguration.Host))
if (!string.IsNullOrWhiteSpace(testConfiguration.Host))
{
parameters[SnowflakeParameters.HOST] = testConfiguration.Host;
}

if(!string.IsNullOrWhiteSpace(testConfiguration.Database))
if (!string.IsNullOrWhiteSpace(testConfiguration.Database))
{
parameters[SnowflakeParameters.DATABASE] = testConfiguration.Database;
}
Expand Down Expand Up @@ -120,9 +136,9 @@ internal static string[] GetQueries(SnowflakeTestConfiguration testConfiguration
{
string modifiedLine = line;

foreach(string key in placeholderValues.Keys)
foreach (string key in placeholderValues.Keys)
{
if(modifiedLine.Contains(key))
if (modifiedLine.Contains(key))
modifiedLine = modifiedLine.Replace(key, placeholderValues[key]);
}

Expand Down
Loading

0 comments on commit 7a06864

Please sign in to comment.