Skip to content

Commit

Permalink
fix(functions): improve array_prepend and array_append (#15437)
Browse files: browse the repository at this point in the history
* fix(functions): improve array_prepend and array_append

* fix

* fix

* fix

* fix
  • Loading branch information
andylokandy authored May 9, 2024
1 parent b232ff1 commit 43863f4
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 206 deletions.
4 changes: 2 additions & 2 deletions src/common/tracing/src/structlog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ fn build_trees(spans: &[&SpanRecord]) -> Vec<TreeNode> {

let roots = raw.keys().filter(|id| !span_ids.contains(id)).cloned();
roots
.flat_map(|root| build_sub_tree(root, &raw).pop())
.collect_vec()
.filter_map(|root| build_sub_tree(root, &raw).pop())
.collect()
}

fn build_sub_tree(parent_id: SpanId, raw: &HashMap<SpanId, Vec<&SpanRecord>>) -> Vec<TreeNode> {
Expand Down
239 changes: 56 additions & 183 deletions src/query/functions/src/scalars/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ use std::hash::Hash;
use std::ops::Range;
use std::sync::Arc;

use databend_common_expression::type_check::common_super_type;
use databend_common_expression::types::array::ArrayColumn;
use databend_common_expression::types::array::ArrayColumnBuilder;
use databend_common_expression::types::boolean::BooleanDomain;
use databend_common_expression::types::nullable::NullableDomain;
Expand Down Expand Up @@ -72,7 +70,6 @@ use siphasher::sip128::SipHasher24;

use crate::aggregates::eval_aggr;
use crate::AggregateFunctionFactory;
use crate::BUILTIN_FUNCTIONS;

const ARRAY_AGGREGATE_FUNCTIONS: &[(&str, &str); 14] = &[
("array_avg", "avg"),
Expand Down Expand Up @@ -243,10 +240,10 @@ pub fn register(registry: &mut FunctionRegistry) {
),
);

registry.register_2_arg_core::<NullableType<EmptyArrayType>, NullableType<EmptyArrayType>, EmptyArrayType, _, _>(
registry.register_2_arg::<EmptyArrayType, EmptyArrayType, EmptyArrayType, _, _>(
"array_concat",
|_, _, _| FunctionDomain::Full,
|_, _, _| Value::Scalar(()),
|_, _, _| (),
);

registry.register_passthrough_nullable_2_arg::<ArrayType<GenericType<0>>, ArrayType<GenericType<0>>, ArrayType<GenericType<0>>, _, _>(
Expand Down Expand Up @@ -431,183 +428,57 @@ pub fn register(registry: &mut FunctionRegistry) {
),
);

registry.register_function_factory("array_prepend", |_, args_type| {
if args_type.len() != 2 {
return None;
}
let (common_type, return_type) = match args_type[1].remove_nullable() {
DataType::EmptyArray => (
args_type[0].clone(),
DataType::Array(Box::new(args_type[0].clone())),
),
DataType::Array(box inner_type) => {
let common_type = common_super_type(
inner_type.clone(),
args_type[0].clone(),
&BUILTIN_FUNCTIONS.default_cast_rules,
)?;
(common_type.clone(), DataType::Array(Box::new(common_type)))
}
_ => {
return None;
}
};
let args_type = vec![
common_type,
if args_type[1].is_nullable() {
return_type.wrap_nullable()
} else {
return_type.clone()
},
];
Some(Arc::new(Function {
signature: FunctionSignature {
name: "array_prepend".to_string(),
args_type,
return_type: return_type.clone(),
},
eval: FunctionEval::Scalar {
calc_domain: Box::new(|_, args_domain| {
let array_domain = match &args_domain[1] {
Domain::Nullable(nullable_domain) => nullable_domain.value.clone(),
other => Some(Box::new(other.clone())),
};
let inner_domain = match array_domain {
Some(box Domain::Array(Some(box inner_domain))) => {
inner_domain.merge(&args_domain[0])
}
_ => args_domain[0].clone(),
};
FunctionDomain::Domain(Domain::Array(Some(Box::new(inner_domain))))
}),
eval: Box::new(move |args, _| {
let len = args.iter().find_map(|arg| match arg {
ValueRef::Column(col) => Some(col.len()),
_ => None,
});

let mut offsets = Vec::with_capacity(len.unwrap_or(1) + 1);
offsets.push(0);
let inner_type = return_type.as_array().unwrap();
let mut builder = ColumnBuilder::with_capacity(inner_type, len.unwrap_or(1));

for idx in 0..(len.unwrap_or(1)) {
let val = match &args[0] {
ValueRef::Scalar(scalar) => scalar.clone(),
ValueRef::Column(col) => unsafe { col.index_unchecked(idx) },
};
builder.push(val.clone());
let array_col = match &args[1] {
ValueRef::Scalar(scalar) => scalar.clone(),
ValueRef::Column(col) => unsafe { col.index_unchecked(idx).clone() },
};
if let ScalarRef::Array(col) = array_col {
for val in col.iter() {
builder.push(val.clone());
}
}
offsets.push(builder.len() as u64);
}
match len {
Some(_) => Value::Column(Column::Array(Box::new(ArrayColumn {
values: builder.build(),
offsets: offsets.into(),
}))),
None => Value::Scalar(Scalar::Array(builder.build())),
registry.register_2_arg_core::<GenericType<0>, NullableType<ArrayType<GenericType<0>>>, ArrayType<GenericType<0>>, _, _>(
"array_prepend",
|_, item_domain, array_domain| {
let domain = array_domain
.value
.as_ref()
.map(|box inner_domain| {
inner_domain
.as_ref()
.map(|inner_domain| inner_domain.merge(item_domain))
.unwrap_or(item_domain.clone())
});
FunctionDomain::Domain(domain)
},
vectorize_with_builder_2_arg::<GenericType<0>, NullableType<ArrayType<GenericType<0>>>, ArrayType<GenericType<0>>>(
|val, arr, output, _| {
output.put_item(val);
if let Some(arr) = arr {
for item in arr.iter() {
output.put_item(item);
}
}),
},
}))
});

registry.register_function_factory("array_append", |_, args_type| {
if args_type.len() != 2 {
return None;
}
let (common_type, return_type) = match args_type[0].remove_nullable() {
DataType::EmptyArray => (
args_type[1].clone(),
DataType::Array(Box::new(args_type[1].clone())),
),
DataType::Array(box inner_type) => {
let common_type = common_super_type(
inner_type.clone(),
args_type[1].clone(),
&BUILTIN_FUNCTIONS.default_cast_rules,
)?;
(common_type.clone(), DataType::Array(Box::new(common_type)))
}
_ => {
return None;
}
};
let args_type = vec![
if args_type[0].is_nullable() {
return_type.wrap_nullable()
} else {
return_type.clone()
},
common_type,
];
Some(Arc::new(Function {
signature: FunctionSignature {
name: "array_append".to_string(),
args_type,
return_type: return_type.clone(),
},
eval: FunctionEval::Scalar {
calc_domain: Box::new(|_, args_domain| {
let array_domain = match &args_domain[0] {
Domain::Nullable(nullable_domain) => nullable_domain.value.clone(),
other => Some(Box::new(other.clone())),
};
let inner_domain = match array_domain {
Some(box Domain::Array(Some(box inner_domain))) => {
inner_domain.merge(&args_domain[1])
}
_ => args_domain[1].clone(),
};
FunctionDomain::Domain(Domain::Array(Some(Box::new(inner_domain))))
}),
eval: Box::new(move |args, _| {
let len = args.iter().find_map(|arg| match arg {
ValueRef::Column(col) => Some(col.len()),
_ => None,
});

let mut offsets = Vec::with_capacity(len.unwrap_or(1) + 1);
offsets.push(0);
let inner_type = return_type.as_array().unwrap();
let mut builder = ColumnBuilder::with_capacity(inner_type, len.unwrap_or(1));
}
output.commit_row()
})
);

for idx in 0..(len.unwrap_or(1)) {
let array_col = match &args[0] {
ValueRef::Scalar(scalar) => scalar.clone(),
ValueRef::Column(col) => unsafe { col.index_unchecked(idx).clone() },
};
if let ScalarRef::Array(col) = array_col {
for val in col.iter() {
builder.push(val.clone());
}
}
let val = match &args[1] {
ValueRef::Scalar(scalar) => scalar.clone(),
ValueRef::Column(col) => unsafe { col.index_unchecked(idx) },
};
builder.push(val.clone());
offsets.push(builder.len() as u64);
}
match len {
Some(_) => Value::Column(Column::Array(Box::new(ArrayColumn {
values: builder.build(),
offsets: offsets.into(),
}))),
None => Value::Scalar(Scalar::Array(builder.build())),
registry.register_2_arg_core::<NullableType<ArrayType<GenericType<0>>>, GenericType<0>, ArrayType<GenericType<0>>, _, _>(
"array_append",
|_, array_domain, item_domain| {
let domain = array_domain
.value
.as_ref()
.map(|box inner_domain| {
inner_domain
.as_ref()
.map(|inner_domain| inner_domain.merge(item_domain))
.unwrap_or(item_domain.clone())
});
FunctionDomain::Domain(domain)
},
vectorize_with_builder_2_arg::<NullableType<ArrayType<GenericType<0>>>, GenericType<0>, ArrayType<GenericType<0>>>(
|arr, val, output, _| {
if let Some(arr) = arr {
for item in arr.iter() {
output.put_item(item);
}
}),
},
}))
});
}
output.put_item(val);
output.commit_row()
})
);

fn eval_contains<T: ArgType>(
lhs: ValueRef<ArrayType<T>>,
Expand Down Expand Up @@ -791,12 +662,14 @@ pub fn register(registry: &mut FunctionRegistry) {
}
);

registry.register_2_arg_core::<ArrayType<GenericType<0>>, GenericType<0>, BooleanType, _, _>(
registry.register_2_arg_core::<NullableType<ArrayType<GenericType<0>>>, GenericType<0>, BooleanType, _, _>(
"contains",
|_, _, _| FunctionDomain::Full,
vectorize_2_arg::<ArrayType<GenericType<0>>, GenericType<0>, BooleanType>(|lhs, rhs, _| {
lhs.iter().contains(&rhs)
}),
vectorize_2_arg::<NullableType<ArrayType<GenericType<0>>>, GenericType<0>, BooleanType>(
|lhs, rhs, _| {
lhs.map(|col| col.iter().contains(&rhs)).unwrap_or(false)
}
)
);

registry.register_passthrough_nullable_1_arg::<EmptyArrayType, UInt64Type, _, _>(
Expand Down
20 changes: 20 additions & 0 deletions src/query/functions/tests/it/scalars/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,16 @@ fn test_array_prepend(file: &mut impl Write) {
run_ast(file, "array_prepend(1, [])", &[]);
run_ast(file, "array_prepend(1, [2, 3, NULL, 4])", &[]);
run_ast(file, "array_prepend('a', ['b', NULL, NULL, 'c', 'd'])", &[]);
run_ast(
file,
"array_prepend(NULL, CAST([2, 3] AS Array(INT8 NULL) NULL))",
&[],
);
run_ast(
file,
"array_prepend(1, CAST([2, 3] AS Array(INT8 NULL) NULL))",
&[],
);
run_ast(file, "array_prepend(a, [b, c])", &[
("a", Int16Type::from_data(vec![0i16, 1, 2])),
("b", Int16Type::from_data(vec![3i16, 4, 5])),
Expand All @@ -237,6 +247,16 @@ fn test_array_append(file: &mut impl Write) {
run_ast(file, "array_append([], 1)", &[]);
run_ast(file, "array_append([2, 3, NULL, 4], 5)", &[]);
run_ast(file, "array_append(['b', NULL, NULL, 'c', 'd'], 'e')", &[]);
run_ast(
file,
"array_append(CAST([1, 2] AS Array(INT8 NULL) NULL), NULL)",
&[],
);
run_ast(
file,
"array_append(CAST([1, 2] AS Array(INT8 NULL) NULL), 3)",
&[],
);
run_ast(file, "array_append([b, c], a)", &[
("a", Int16Type::from_data(vec![0i16, 1, 2])),
("b", Int16Type::from_data(vec![3i16, 4, 5])),
Expand Down
Loading

0 comments on commit 43863f4

Please sign in to comment.