diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs index 798fc9c470..36d62f89f2 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs @@ -50,6 +50,8 @@ public class SparkConnection : HiveServer2Connection const string InfoDriverVersion = "1.0.0"; const string InfoVendorName = "Spark"; const string InfoDriverArrowVersion = "1.0.0"; + const int DecimalPrecisionDefault = 10; + const int DecimalScaleDefault = 0; internal static TSparkGetDirectResults sparkGetDirectResults = new TSparkGetDirectResults(1000); @@ -287,6 +289,9 @@ public override Schema GetTableSchema(string? catalog, string? dbSchema, string? string columnName = columns[3].StringVal.Values.GetString(i); int? columnType = columns[4].I32Val.Values.GetValue(i); string typeName = columns[5].StringVal.Values.GetString(i); + // Note: the following two columns do not seem to be set correctly for DECIMAL types. + //int? columnSize = columns[6].I32Val.Values.GetValue(i); + //int? decimalDigits = columns[8].I32Val.Values.GetValue(i); bool nullable = columns[10].I32Val.Values.GetValue(i) == 1; IArrowType dataType = SparkConnection.GetArrowType((ColumnTypeId)columnType!.Value, typeName); fields[i] = new Field(columnName, dataType, nullable); @@ -481,8 +486,9 @@ private static IArrowType GetArrowType(ColumnTypeId columnTypeId, string typeNam case ColumnTypeId.CHAR_TYPE: return StringType.Default; case ColumnTypeId.DECIMAL_TYPE: - // TODO: Parse typeName for precision and scale, because not available in other metadata. 
/// <summary>
/// Provides a parser for SQL DECIMAL type definitions.
/// </summary>
private static class SqlDecimalTypeParser
{
    // Bounds taken from the DECIMAL syntax definition referenced below.
    private const int MinPrecision = 1;
    private const int MaxPrecision = 38;
    private const int MinScale = 0;

    // Pattern is based on this definition
    // https://docs.databricks.com/en/sql/language-manual/data-types/decimal-type.html#syntax
    // { DECIMAL | DEC | NUMERIC } [ ( p [ , s ] ) ]
    // p: Optional maximum precision (total number of digits) of the number between 1 and 38. The default is 10.
    // s: Optional scale of the number between 0 and p. The number of digits to the right of the decimal point. The default is 0.
    //
    // The "precision" and "scale" named groups are what TryParse reads; every other group is
    // non-capturing (ExplicitCapture). [0-9] is used instead of \d so only ASCII digits match.
    private static readonly Regex s_expression = new(
        @"^\s*(?:DECIMAL|DEC|NUMERIC)(\s*\(\s*(?<precision>[0-9]{1,2})(\s*,\s*(?<scale>[0-9]{1,2}))?\s*\))?\s*$",
        RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.CultureInvariant | RegexOptions.ExplicitCapture);

    /// <summary>
    /// Parses the input string for a valid SQL DECIMAL type definition and returns a new
    /// <see cref="Decimal128Type"/>, or returns <paramref name="defaultValue"/> if the input is invalid.
    /// </summary>
    /// <param name="input">The SQL type definition string to parse.</param>
    /// <param name="defaultValue">The value returned when <paramref name="input"/> is not a valid SQL DECIMAL type definition.</param>
    /// <returns>
    /// A new <see cref="Decimal128Type"/> when <paramref name="input"/> is a valid SQL DECIMAL type definition;
    /// otherwise <paramref name="defaultValue"/>.
    /// </returns>
    public static Decimal128Type ParseOrDefault(string input, Decimal128Type defaultValue)
    {
        return TryParse(input, out Decimal128Type? candidate) && candidate is not null
            ? candidate
            : defaultValue;
    }

    /// <summary>
    /// Tries to parse the input string for a valid SQL DECIMAL type definition.
    /// </summary>
    /// <param name="input">The SQL type definition string to parse.</param>
    /// <param name="value">
    /// If successful, a new <see cref="Decimal128Type"/> with the precision and scale set; otherwise null.
    /// </param>
    /// <returns>
    /// True if the input matches the DECIMAL syntax and its precision and scale are within the
    /// documented ranges (precision 1..38, scale 0..precision); otherwise false.
    /// </returns>
    private static bool TryParse(string input, out Decimal128Type? value)
    {
        value = null;

        // Regex.Match throws on a null input; treat null/blank as "not a DECIMAL definition".
        if (string.IsNullOrWhiteSpace(input))
        {
            return false;
        }

        Match match = s_expression.Match(input);
        if (!match.Success)
        {
            return false;
        }

        // Defaults apply when the precision/scale clause (or just the scale) is omitted.
        int precision = DecimalPrecisionDefault;
        int scale = DecimalScaleDefault;

        Group precisionGroup = match.Groups["precision"];
        if (precisionGroup.Success && !TryParseDigits(precisionGroup.Value, out precision))
        {
            return false;
        }

        Group scaleGroup = match.Groups["scale"];
        if (scaleGroup.Success && !TryParseDigits(scaleGroup.Value, out scale))
        {
            return false;
        }

        // The pattern only limits the digit count (two digits allow up to 99), so enforce the ranges here.
        if (precision < MinPrecision || precision > MaxPrecision || scale < MinScale || scale > precision)
        {
            return false;
        }

        value = new Decimal128Type(precision, scale);
        return true;
    }

    /// <summary>
    /// Parses a run of ASCII digits into an integer, independent of the current culture.
    /// </summary>
    /// <param name="digits">The digit string captured by the pattern.</param>
    /// <param name="result">The parsed value when successful.</param>
    /// <returns>True if <paramref name="digits"/> could be parsed; otherwise false.</returns>
    private static bool TryParseDigits(string digits, out int result)
    {
        return int.TryParse(
            digits,
            global::System.Globalization.NumberStyles.None,
            global::System.Globalization.CultureInfo.InvariantCulture,
            out result);
    }
}