From 6bd6a5f746e9058a0ae37b1e05140648db91d0ef Mon Sep 17 00:00:00 2001 From: sundyli <543950155@qq.com> Date: Fri, 12 Jan 2024 07:17:18 -0800 Subject: [PATCH] feat(query): add float64 version of distance function overload (#14311) * Update 02_0063_function_vector.test * feat(query): add float64 version of distance function overload * feat(query): add float64 version of distance function overload * fix the verctor.rs use style --------- Co-authored-by: Bohu --- src/common/vector/src/distance.rs | 34 ++++++++++ src/common/vector/src/lib.rs | 2 + src/query/functions/src/scalars/vector.rs | 62 +++++++++++++++++-- .../it/scalars/testdata/function_list.txt | 4 ++ .../02_function/02_0063_function_vector.test | 4 +- 5 files changed, 98 insertions(+), 8 deletions(-) diff --git a/src/common/vector/src/distance.rs b/src/common/vector/src/distance.rs index 56f605dfb81f..97953374d239 100644 --- a/src/common/vector/src/distance.rs +++ b/src/common/vector/src/distance.rs @@ -49,3 +49,37 @@ pub fn l2_distance(from: &[f32], to: &[f32]) -> Result { .sum::() .sqrt()) } + +pub fn cosine_distance_64(from: &[f64], to: &[f64]) -> Result { + if from.len() != to.len() { + return Err(ErrorCode::InvalidArgument(format!( + "Vector length not equal: {:} != {:}", + from.len(), + to.len(), + ))); + } + + let a = ArrayView::from(from); + let b = ArrayView::from(to); + let aa_sum = (&a * &a).sum(); + let bb_sum = (&b * &b).sum(); + + Ok(1.0 - (&a * &b).sum() / ((aa_sum).sqrt() * (bb_sum).sqrt())) +} + +pub fn l2_distance_64(from: &[f64], to: &[f64]) -> Result { + if from.len() != to.len() { + return Err(ErrorCode::InvalidArgument(format!( + "Vector length not equal: {:} != {:}", + from.len(), + to.len(), + ))); + } + + Ok(from + .iter() + .zip(to.iter()) + .map(|(a, b)| (a - b).powi(2)) + .sum::() + .sqrt()) +} diff --git a/src/common/vector/src/lib.rs b/src/common/vector/src/lib.rs index 8ee8bc83e7cd..2988f0db1173 100644 --- a/src/common/vector/src/lib.rs +++ b/src/common/vector/src/lib.rs @@ -15,4 +15,6 @@ mod distance; pub use distance::cosine_distance; +pub use distance::cosine_distance_64; pub use distance::l2_distance; +pub use distance::l2_distance_64; diff --git a/src/query/functions/src/scalars/vector.rs b/src/query/functions/src/scalars/vector.rs index a7ea8a1ec4e3..aa3d65a19fed 100644 --- a/src/query/functions/src/scalars/vector.rs +++ b/src/query/functions/src/scalars/vector.rs @@ -15,15 +15,19 @@ use databend_common_arrow::arrow::buffer::Buffer; use databend_common_expression::types::ArrayType; use databend_common_expression::types::Float32Type; +use databend_common_expression::types::Float64Type; use databend_common_expression::types::StringType; use databend_common_expression::types::F32; +use databend_common_expression::types::F64; use databend_common_expression::vectorize_with_builder_1_arg; use databend_common_expression::vectorize_with_builder_2_arg; use databend_common_expression::FunctionDomain; use databend_common_expression::FunctionRegistry; use databend_common_openai::OpenAI; use databend_common_vector::cosine_distance; +use databend_common_vector::cosine_distance_64; use databend_common_vector::l2_distance; +use databend_common_vector::l2_distance_64; pub fn register(registry: &mut FunctionRegistry) { // cosine_distance @@ -33,12 +37,12 @@ pub fn register(registry: &mut FunctionRegistry) { |_, _, _| FunctionDomain::MayThrow, vectorize_with_builder_2_arg::, ArrayType, Float32Type>( |lhs, rhs, output, ctx| { - let l_f32= + let l= unsafe { std::mem::transmute::, Buffer>(lhs) }; - let r_f32= + let r = unsafe { std::mem::transmute::, Buffer>(rhs) }; - match cosine_distance(l_f32.as_slice(), r_f32.as_slice()) { + match cosine_distance(l.as_slice(), r .as_slice()) { Ok(dist) => { output.push(F32::from(dist)); } @@ -59,12 +63,12 @@ pub fn register(registry: &mut FunctionRegistry) { |_, _, _| FunctionDomain::MayThrow, vectorize_with_builder_2_arg::, ArrayType, Float32Type>( |lhs, rhs, output, ctx| { - let l_f32= + let l= unsafe { std::mem::transmute::, Buffer>(lhs) }; - let r_f32= + let r = unsafe { std::mem::transmute::, Buffer>(rhs) }; - match l2_distance(l_f32.as_slice(), r_f32.as_slice()) { + match l2_distance(l.as_slice(), r .as_slice()) { Ok(dist) => { output.push(F32::from(dist)); } @@ -77,6 +81,52 @@ pub fn register(registry: &mut FunctionRegistry) { ), ); + registry.register_passthrough_nullable_2_arg::, ArrayType, Float64Type, _, _>( + "cosine_distance", + |_, _, _| FunctionDomain::MayThrow, + vectorize_with_builder_2_arg::, ArrayType, Float64Type>( + |lhs, rhs, output, ctx| { + let l = + unsafe { std::mem::transmute::, Buffer>(lhs) }; + let r = + unsafe { std::mem::transmute::, Buffer>(rhs) }; + + match cosine_distance_64(l.as_slice(), r .as_slice()) { + Ok(dist) => { + output.push(F64::from(dist)); + } + Err(err) => { + ctx.set_error(output.len(), err.to_string()); + output.push(F64::from(0.0)); + } + } + } + ), + ); + + registry.register_passthrough_nullable_2_arg::, ArrayType, Float64Type, _, _>( + "l2_distance", + |_, _, _| FunctionDomain::MayThrow, + vectorize_with_builder_2_arg::, ArrayType, Float64Type>( + |lhs, rhs, output, ctx| { + let l= + unsafe { std::mem::transmute::, Buffer>(lhs) }; + let r = + unsafe { std::mem::transmute::, Buffer>(rhs) }; + + match l2_distance_64(l.as_slice(), r .as_slice()) { + Ok(dist) => { + output.push(F64::from(dist)); + } + Err(err) => { + ctx.set_error(output.len(), err.to_string()); + output.push(F64::from(0.0)); + } + } + } + ), + ); + // embedding_vector // This function takes two strings as input, sends an API request to OpenAI, and returns the Float32 array of embeddings. // The OpenAI API key is pre-configured during the binder phase, so we rewrite this function and set the API key. diff --git a/src/query/functions/tests/it/scalars/testdata/function_list.txt b/src/query/functions/tests/it/scalars/testdata/function_list.txt index 6fd08ce249cb..7ae134ba9597 100644 --- a/src/query/functions/tests/it/scalars/testdata/function_list.txt +++ b/src/query/functions/tests/it/scalars/testdata/function_list.txt @@ -1169,6 +1169,8 @@ Functions overloads: 1 cos(Float64 NULL) :: Float64 NULL 0 cosine_distance(Array(Float32), Array(Float32)) :: Float32 1 cosine_distance(Array(Float32) NULL, Array(Float32) NULL) :: Float32 NULL +2 cosine_distance(Array(Float64), Array(Float64)) :: Float64 +3 cosine_distance(Array(Float64) NULL, Array(Float64) NULL) :: Float64 NULL 0 cot(Float64) :: Float64 1 cot(Float64 NULL) :: Float64 NULL 0 crc32(String) :: UInt32 @@ -1892,6 +1894,8 @@ Functions overloads: 1 json_typeof(Variant NULL) :: String NULL 0 l2_distance(Array(Float32), Array(Float32)) :: Float32 1 l2_distance(Array(Float32) NULL, Array(Float32) NULL) :: Float32 NULL +2 l2_distance(Array(Float64), Array(Float64)) :: Float64 +3 l2_distance(Array(Float64) NULL, Array(Float64) NULL) :: Float64 NULL 0 left(String, UInt64) :: String 1 left(String NULL, UInt64 NULL) :: String NULL 0 length(Variant NULL) :: UInt32 NULL diff --git a/tests/sqllogictests/suites/query/02_function/02_0063_function_vector.test b/tests/sqllogictests/suites/query/02_function/02_0063_function_vector.test index 5901a7c1fc43..a5e02dd86500 100644 --- a/tests/sqllogictests/suites/query/02_function/02_0063_function_vector.test +++ b/tests/sqllogictests/suites/query/02_function/02_0063_function_vector.test @@ -1,8 +1,8 @@ # From sklearn.metrics.pairwise import cosine_similarity query F -select cosine_distance([3.0::Float32, 45.0, 7.0, 2.0, 5.0, 20.0, 13.0, 12.0], [2.0::Float32, 54.0, 13.0, 15.0, 22.0, 34.0, 50.0, 1.0]) as sim +select cosine_distance([3.0, 45.0, 7.0, 2.0, 5.0, 20.0, 13.0, 12.0], [2.0, 54.0, 13.0, 15.0, 22.0, 34.0, 50.0, 1.0]) as sim ---- -0.1264193 +0.12641934893868967 query F select [1, 2] <-> [2, 3] as sim