Skip to content

Commit

Permalink
Minor: Support LargeList List Range indexing and fix large list handl…
Browse files Browse the repository at this point in the history
…ing in ConstEvaluator (apache#9393)

* fix largelist

Signed-off-by: jayzhan211 <[email protected]>

* support large list for list range

Signed-off-by: jayzhan211 <[email protected]>

---------

Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 authored Mar 1, 2024
1 parent 9e39afd commit 2a490e4
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 8 deletions.
5 changes: 3 additions & 2 deletions datafusion/expr/src/field_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,11 @@ impl GetFieldAccessSchema {
Self::ListRange { start_dt, stop_dt, stride_dt } => {
match (data_type, start_dt, stop_dt, stride_dt) {
(DataType::List(_), DataType::Int64, DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)),
(DataType::List(_), _, _, _) => plan_err!(
(DataType::LargeList(_), DataType::Int64, DataType::Int64, DataType::Int64) => Ok(Field::new("large_list", data_type.clone(), true)),
(DataType::List(_), _, _, _) | (DataType::LargeList(_), _, _, _)=> plan_err!(
"Only ints are valid as an indexed field in a list"
),
(other, _, _, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
(other, _, _, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `LargeList` or `Struct` types, got {other}"),
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,13 @@ impl<'a> ConstEvaluator<'a> {
DataFusionError::Execution(format!("Could not evaluate the expression, found a result of length {}", a.len())),
expr,
)
} else if as_list_array(&a).is_ok() || as_large_list_array(&a).is_ok() {
} else if as_list_array(&a).is_ok() {
ConstSimplifyResult::Simplified(ScalarValue::List(
a.as_list().to_owned().into(),
a.as_list::<i32>().to_owned().into(),
))
} else if as_large_list_array(&a).is_ok() {
ConstSimplifyResult::Simplified(ScalarValue::LargeList(
a.as_list::<i64>().to_owned().into(),
))
} else {
// Non-ListArray
Expand Down
10 changes: 6 additions & 4 deletions datafusion/physical-expr/src/expressions/get_indexed_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,16 +268,18 @@ impl PhysicalExpr for GetIndexedFieldExpr {
let stop = stop.evaluate(batch)?.into_array(batch.num_rows())?;
let stride = stride.evaluate(batch)?.into_array(batch.num_rows())?;
match (array.data_type(), start.data_type(), stop.data_type(), stride.data_type()) {
(DataType::List(_), DataType::Int64, DataType::Int64, DataType::Int64) => {
(DataType::List(_), DataType::Int64, DataType::Int64, DataType::Int64) |
(DataType::LargeList(_), DataType::Int64, DataType::Int64, DataType::Int64)=> {
Ok(ColumnarValue::Array((array_slice(&[
array, start, stop, stride
]))?))
},
(DataType::List(_), start, stop, stride) => exec_err!(
"get indexed field is only possible on lists with int64 indexes. \
(DataType::List(_), start, stop, stride) |
(DataType::LargeList(_), start, stop, stride)=> exec_err!(
"get indexed field is only possible on List/LargeList with int64 indexes. \
Tried with {start:?}, {stop:?} and {stride:?} indices"),
(dt, start, stop, stride) => exec_err!(
"get indexed field is only possible on lists with int64 indexes or struct \
"get indexed field is only possible on List/LargeList with int64 indexes or struct \
with utf8 indexes. Tried {dt:?} with {start:?}, {stop:?} and {stride:?} indices"),
}
}
Expand Down
16 changes: 16 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -861,12 +861,28 @@ select make_array(1, 2, 3)[1:2], make_array(1.0, 2.0, 3.0)[2:3], make_array('h',
----
[1, 2] [2.0, 3.0] [e, l, l]

query ???
select arrow_cast([1, 2, 3], 'LargeList(Int64)')[1:2],
arrow_cast([1.0, 2.0, 3.0], 'LargeList(Int64)')[2:3],
arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)')[2:4]
;
----
[1, 2] [2, 3] [e, l, l]

# multiple index with columns #2 (zero index)
query ???
select make_array(1, 2, 3)[0:0], make_array(1.0, 2.0, 3.0)[0:2], make_array('h', 'e', 'l', 'l', 'o')[0:6];
----
[] [1.0, 2.0] [h, e, l, l, o]

query ???
select arrow_cast([1, 2, 3], 'LargeList(Int64)')[0:0],
arrow_cast([1.0, 2.0, 3.0], 'LargeList(Int64)')[0:2],
arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)')[0:6]
;
----
[] [1, 2] [h, e, l, l, o]

# TODO: support multiple negative index
# multiple index with columns #3 (negative index)
# query II
Expand Down

0 comments on commit 2a490e4

Please sign in to comment.