Skip to content

Commit

Permalink
fix(csharp/src/Drivers/Apache): set the precision and scale correctly…
Browse files Browse the repository at this point in the history
… on Decimal128Type (apache#1858)

Addresses a TODO item to accurately retrieve the precision and scale for
`DECIMAL` (`Decimal128Type`) when retrieving table schema
(`GetTableSchema`).

* Parses the type name to find the precision and scale, falling back to the
defaults of 10 and 0, respectively, when they are not specified.
  • Loading branch information
birschick-bq authored May 14, 2024
1 parent 4d203b8 commit 57770e3
Showing 1 changed file with 65 additions and 4 deletions.
69 changes: 65 additions & 4 deletions csharp/src/Drivers/Apache/Spark/SparkConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ public class SparkConnection : HiveServer2Connection
const string InfoDriverVersion = "1.0.0";
const string InfoVendorName = "Spark";
const string InfoDriverArrowVersion = "1.0.0";
const int DecimalPrecisionDefault = 10;
const int DecimalScaleDefault = 0;

internal static TSparkGetDirectResults sparkGetDirectResults = new TSparkGetDirectResults(1000);

Expand Down Expand Up @@ -287,6 +289,9 @@ public override Schema GetTableSchema(string? catalog, string? dbSchema, string?
string columnName = columns[3].StringVal.Values.GetString(i);
int? columnType = columns[4].I32Val.Values.GetValue(i);
string typeName = columns[5].StringVal.Values.GetString(i);
// Note: the following two columns do not seem to be set correctly for DECIMAL types.
//int? columnSize = columns[6].I32Val.Values.GetValue(i);
//int? decimalDigits = columns[8].I32Val.Values.GetValue(i);
bool nullable = columns[10].I32Val.Values.GetValue(i) == 1;
IArrowType dataType = SparkConnection.GetArrowType((ColumnTypeId)columnType!.Value, typeName);
fields[i] = new Field(columnName, dataType, nullable);
Expand Down Expand Up @@ -481,8 +486,9 @@ private static IArrowType GetArrowType(ColumnTypeId columnTypeId, string typeNam
case ColumnTypeId.CHAR_TYPE:
return StringType.Default;
case ColumnTypeId.DECIMAL_TYPE:
// TODO: Parse typeName for precision and scale, because not available in other metadata.
return new Decimal128Type(38, 38);
// Note: parsing the type name for SQL DECIMAL types as the precision and scale values
// are not returned in the Thrift call to GetColumns
return SqlDecimalTypeParser.ParseOrDefault(typeName, new Decimal128Type(DecimalPrecisionDefault, DecimalScaleDefault));
case ColumnTypeId.ARRAY_TYPE:
case ColumnTypeId.MAP_TYPE:
case ColumnTypeId.STRUCT_TYPE:
Expand Down Expand Up @@ -521,7 +527,6 @@ private StructArray GetDbSchemas(

}


IReadOnlyList<Field> schema = StandardSchemas.DbSchemaSchema;
IReadOnlyList<IArrowArray> dataArrays = schema.Validate(
new List<IArrowArray>
Expand Down Expand Up @@ -688,9 +693,65 @@ private string PatternToRegEx(string? pattern)

return builder.ToString();
}

/// <summary>
/// Provides a parser for SQL DECIMAL type definitions.
/// </summary>
private static class SqlDecimalTypeParser
{
    // Valid precision bounds, taken from the type definition referenced below.
    private const int MinPrecision = 1;
    private const int MaxPrecision = 38;

    // Pattern is based on this definition
    // https://docs.databricks.com/en/sql/language-manual/data-types/decimal-type.html#syntax
    // { DECIMAL | DEC | NUMERIC } [ ( p [ , s ] ) ]
    // p: Optional maximum precision (total number of digits) of the number between 1 and 38. The default is 10.
    // s: Optional scale of the number between 0 and p. The number of digits to the right of the decimal point. The default is 0.
    // Note: the pattern only limits p and s to one or two digits; the range checks are done in TryParse.
    private static readonly Regex s_expression = new(
        @"^\s*(?<typeName>((DECIMAL)|(DEC)|(NUMERIC)))(\s*\(\s*((?<precision>\d{1,2})(\s*\,\s*(?<scale>\d{1,2}))?)\s*\))?\s*$",
        RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.CultureInvariant);

    /// <summary>
    /// Parses the input string for a valid SQL DECIMAL type definition and returns a new <see cref="Decimal128Type"/> or returns the <c>defaultValue</c>, if invalid.
    /// </summary>
    /// <param name="input">The SQL type definition string to parse. A <c>null</c> value is treated as invalid.</param>
    /// <param name="defaultValue">If input string is an invalid SQL DECIMAL type definition, this value is returned instead.</param>
    /// <returns>If input string is a valid SQL DECIMAL type definition, it returns a new <see cref="Decimal128Type"/>; otherwise <c>defaultValue</c>.</returns>
    public static Decimal128Type ParseOrDefault(string input, Decimal128Type defaultValue)
    {
        // The explicit null test narrows the nullable out value without the null-forgiving operator.
        if (TryParse(input, out Decimal128Type? parsed) && parsed is not null)
        {
            return parsed;
        }

        return defaultValue;
    }

    /// <summary>
    /// Tries to parse the input string for a valid SQL DECIMAL type definition.
    /// </summary>
    /// <param name="input">The SQL type definition string to parse.</param>
    /// <param name="value">If successful, a new <see cref="Decimal128Type"/> with the precision and scale set; otherwise <c>null</c>.</param>
    /// <returns>
    /// True if the input matches the type definition pattern, the precision is between 1 and 38,
    /// and the scale is between 0 and the precision; otherwise false.
    /// </returns>
    private static bool TryParse(string input, out Decimal128Type? value)
    {
        value = null;

        // Regex.Match throws on a null input; report "not a DECIMAL definition" instead.
        if (input is null)
        {
            return false;
        }

        Match match = s_expression.Match(input);
        if (!match.Success)
        {
            return false;
        }

        // Missing clauses fall back to the defaults; present clauses must parse as plain digits.
        if (!TryReadNumber(match.Groups["precision"], DecimalPrecisionDefault, out int precision)
            || !TryReadNumber(match.Groups["scale"], DecimalScaleDefault, out int scale))
        {
            return false;
        }

        // Reject values outside the documented ranges rather than building an invalid type.
        if (precision < MinPrecision || precision > MaxPrecision || scale < 0 || scale > precision)
        {
            return false;
        }

        value = new Decimal128Type(precision, scale);
        return true;
    }

    /// <summary>
    /// Reads the integer captured by an optional regular expression group.
    /// </summary>
    /// <param name="group">The capture group to read.</param>
    /// <param name="fallback">The value to use when the group did not participate in the match.</param>
    /// <param name="number">The parsed value, or <c>fallback</c> if the group was not captured.</param>
    /// <returns>False only if the group was captured but its text could not be parsed as an integer; otherwise true.</returns>
    private static bool TryReadNumber(Group group, int fallback, out int number)
    {
        if (!group.Success)
        {
            number = fallback;
            return true;
        }

        // Parse culture-independently and without signs, whitespace or separators.
        return int.TryParse(
            group.Value,
            System.Globalization.NumberStyles.None,
            System.Globalization.CultureInfo.InvariantCulture,
            out number);
    }
}
}

public struct TableInfoPair
internal struct TableInfoPair
{
public string Type { get; set; }

Expand Down

0 comments on commit 57770e3

Please sign in to comment.