Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into spark-4.0
Browse files Browse the repository at this point in the history
  • Loading branch information
kazuyukitanimura committed May 21, 2024
2 parents 5463bd9 + 7b0a7e0 commit 0b81d7f
Show file tree
Hide file tree
Showing 15 changed files with 799 additions and 524 deletions.
14 changes: 11 additions & 3 deletions common/src/main/java/org/apache/comet/parquet/ColumnReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.parquet.column.page.DataPageV2;
import org.apache.parquet.column.page.DictionaryPage;
import org.apache.parquet.column.page.PageReader;
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.spark.sql.types.DataType;

import org.apache.comet.CometConf;
Expand Down Expand Up @@ -199,14 +200,19 @@ public CometDecodedVector loadVector() {
currentVector.close();
}

LogicalTypeAnnotation logicalTypeAnnotation =
descriptor.getPrimitiveType().getLogicalTypeAnnotation();
boolean isUuid =
logicalTypeAnnotation instanceof LogicalTypeAnnotation.UUIDLogicalTypeAnnotation;

long[] addresses = Native.currentBatch(nativeHandle);

try (ArrowArray array = ArrowArray.wrap(addresses[0]);
ArrowSchema schema = ArrowSchema.wrap(addresses[1])) {
FieldVector vector = Data.importVector(ALLOCATOR, array, schema, dictionaryProvider);
DictionaryEncoding dictionaryEncoding = vector.getField().getDictionary();

CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128);
CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128, isUuid);

// Update whether the current vector contains any null values. This is used in the following
// batch(es) to determine whether we can skip loading the native vector.
Expand All @@ -229,12 +235,14 @@ public CometDecodedVector loadVector() {
// initialized yet.
Dictionary arrowDictionary = dictionaryProvider.lookup(dictionaryEncoding.getId());
CometPlainVector dictionaryVector =
new CometPlainVector(arrowDictionary.getVector(), useDecimal128);
new CometPlainVector(arrowDictionary.getVector(), useDecimal128, isUuid);
dictionary = new CometDictionary(dictionaryVector);
}

currentVector =
new CometDictionaryVector(cometVector, dictionary, dictionaryProvider, useDecimal128);
new CometDictionaryVector(
cometVector, dictionary, dictionaryProvider, useDecimal128, false, isUuid);

return currentVector;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,20 @@ public abstract class CometDecodedVector extends CometVector {
private int numValues;
private int validityByteCacheIndex = -1;
private byte validityByteCache;
protected boolean isUuid;

/**
 * Convenience constructor that delegates to the four-argument constructor, passing
 * {@code false} for {@code isUuid}.
 *
 * @param vector forwarded unchanged to the four-argument constructor
 * @param valueField forwarded unchanged to the four-argument constructor
 * @param useDecimal128 forwarded unchanged to the four-argument constructor
 */
protected CometDecodedVector(ValueVector vector, Field valueField, boolean useDecimal128) {
this(vector, valueField, useDecimal128, false);
}

protected CometDecodedVector(
ValueVector vector, Field valueField, boolean useDecimal128, boolean isUuid) {
super(Utils.fromArrowField(valueField), useDecimal128);
this.valueVector = vector;
this.numNulls = valueVector.getNullCount();
this.numValues = valueVector.getValueCount();
this.hasNull = numNulls != 0;
this.isUuid = isUuid;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,17 @@ public CometDictionaryVector(
CometDictionary values,
DictionaryProvider provider,
boolean useDecimal128) {
this(indices, values, provider, useDecimal128, false);
this(indices, values, provider, useDecimal128, false, false);
}

public CometDictionaryVector(
CometPlainVector indices,
CometDictionary values,
DictionaryProvider provider,
boolean useDecimal128,
boolean isAlias) {
super(indices.valueVector, values.getValueVector().getField(), useDecimal128);
boolean isAlias,
boolean isUuid) {
super(indices.valueVector, values.getValueVector().getField(), useDecimal128, isUuid);
Preconditions.checkArgument(
indices.valueVector instanceof IntVector, "'indices' should be a IntVector");
this.values = values;
Expand Down Expand Up @@ -130,6 +131,6 @@ public CometVector slice(int offset, int length) {
// Set the alias flag to true so that the sliced vector will not close the dictionary vector.
// Otherwise, if the dictionary is closed, the sliced vector will not be able to access the
// dictionary.
return new CometDictionaryVector(sliced, values, provider, useDecimal128, true);
return new CometDictionaryVector(sliced, values, provider, useDecimal128, true, isUuid);
}
}
13 changes: 11 additions & 2 deletions common/src/main/java/org/apache/comet/vector/CometPlainVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ public class CometPlainVector extends CometDecodedVector {
private int booleanByteCacheIndex = -1;

public CometPlainVector(ValueVector vector, boolean useDecimal128) {
super(vector, vector.getField(), useDecimal128);
this(vector, useDecimal128, false);
}

public CometPlainVector(ValueVector vector, boolean useDecimal128, boolean isUuid) {
super(vector, vector.getField(), useDecimal128, isUuid);
// NullType doesn't have data buffer.
if (vector instanceof NullVector) {
this.valueBufferAddress = -1;
Expand Down Expand Up @@ -111,7 +115,12 @@ public UTF8String getUTF8String(int rowId) {
byte[] result = new byte[length];
Platform.copyMemory(
null, valueBufferAddress + offset, result, Platform.BYTE_ARRAY_OFFSET, length);
return UTF8String.fromBytes(result);

if (!isUuid) {
return UTF8String.fromBytes(result);
} else {
return UTF8String.fromString(convertToUuid(result).toString());
}
}
}

Expand Down
57 changes: 11 additions & 46 deletions core/src/execution/datafusion/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ macro_rules! cast_utf8_to_int {
for i in 0..len {
if $array.is_null(i) {
cast_array.append_null()
} else if let Some(cast_value) = $cast_method($array.value(i).trim(), $eval_mode)? {
} else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
cast_array.append_value(cast_value);
} else {
cast_array.append_null()
Expand Down Expand Up @@ -1010,14 +1010,6 @@ fn cast_string_to_int_with_range_check(
}
}

#[derive(PartialEq)]
enum State {
SkipLeadingWhiteSpace,
SkipTrailingWhiteSpace,
ParseSignAndDigits,
ParseFractionalDigits,
}

/// Equivalent to
/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal)
/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal)
Expand All @@ -1029,34 +1021,22 @@ fn do_cast_string_to_int<
type_name: &str,
min_value: T,
) -> CometResult<Option<T>> {
let len = str.len();
if str.is_empty() {
let trimmed_str = str.trim();
if trimmed_str.is_empty() {
return none_or_err(eval_mode, type_name, str);
}

let len = trimmed_str.len();
let mut result: T = T::zero();
let mut negative = false;
let radix = T::from(10);
let stop_value = min_value / radix;
let mut state = State::SkipLeadingWhiteSpace;
let mut parsed_sign = false;

for (i, ch) in str.char_indices() {
// skip leading whitespace
if state == State::SkipLeadingWhiteSpace {
if ch.is_whitespace() {
// consume this char
continue;
}
// change state and fall through to next section
state = State::ParseSignAndDigits;
}
let mut parse_sign_and_digits = true;

if state == State::ParseSignAndDigits {
if !parsed_sign {
for (i, ch) in trimmed_str.char_indices() {
if parse_sign_and_digits {
if i == 0 {
negative = ch == '-';
let positive = ch == '+';
parsed_sign = true;
if negative || positive {
if i + 1 == len {
// input string is just "+" or "-"
Expand All @@ -1070,7 +1050,7 @@ fn do_cast_string_to_int<
if ch == '.' {
if eval_mode == EvalMode::Legacy {
// truncate decimal in legacy mode
state = State::ParseFractionalDigits;
parse_sign_and_digits = false;
continue;
} else {
return none_or_err(eval_mode, type_name, str);
Expand Down Expand Up @@ -1102,27 +1082,12 @@ fn do_cast_string_to_int<
return none_or_err(eval_mode, type_name, str);
}
}
}

if state == State::ParseFractionalDigits {
// This is the case when we've encountered a decimal separator. The fractional
// part will not change the number, but we will verify that the fractional part
// is well-formed.
if ch.is_whitespace() {
// finished parsing fractional digits, now need to skip trailing whitespace
state = State::SkipTrailingWhiteSpace;
// consume this char
continue;
}
} else {
// make sure fractional digits are valid digits but ignore them
if !ch.is_ascii_digit() {
return none_or_err(eval_mode, type_name, str);
}
}

// skip trailing whitespace
if state == State::SkipTrailingWhiteSpace && !ch.is_whitespace() {
return none_or_err(eval_mode, type_name, str);
}
}

if !negative {
Expand Down
Loading

0 comments on commit 0b81d7f

Please sign in to comment.