Skip to content

Commit

Permalink
feat(query): add float64 version of distance function overload
Browse files Browse the repository at this point in the history
  • Loading branch information
sundy-li committed Jan 12, 2024
1 parent 8791af1 commit ccb65b0
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 7 deletions.
34 changes: 34 additions & 0 deletions src/common/vector/src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,37 @@ pub fn l2_distance(from: &[f32], to: &[f32]) -> Result<f32> {
.sum::<f32>()
.sqrt())
}

pub fn cosine_distance_64(from: &[f64], to: &[f64]) -> Result<f64> {
if from.len() != to.len() {
return Err(ErrorCode::InvalidArgument(format!(
"Vector length not equal: {:} != {:}",
from.len(),
to.len(),
)));
}

let a = ArrayView::from(from);
let b = ArrayView::from(to);
let aa_sum = (&a * &a).sum();
let bb_sum = (&b * &b).sum();

Ok(1.0 - (&a * &b).sum() / ((aa_sum).sqrt() * (bb_sum).sqrt()))
}

pub fn l2_distance_64(from: &[f64], to: &[f64]) -> Result<f64> {
if from.len() != to.len() {
return Err(ErrorCode::InvalidArgument(format!(
"Vector length not equal: {:} != {:}",
from.len(),
to.len(),
)));
}

Ok(from
.iter()
.zip(to.iter())
.map(|(a, b)| (a - b).powi(2))
.sum::<f64>()
.sqrt())
}
2 changes: 2 additions & 0 deletions src/common/vector/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@
mod distance;

pub use distance::cosine_distance;
pub use distance::cosine_distance_64;
pub use distance::l2_distance;
pub use distance::l2_distance_64;
64 changes: 58 additions & 6 deletions src/query/functions/src/scalars/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,22 @@

use databend_common_arrow::arrow::buffer::Buffer;
use databend_common_expression::types::ArrayType;
use databend_common_expression::types::DataType;
use databend_common_expression::types::Float32Type;
use databend_common_expression::types::Float64Type;
use databend_common_expression::types::NumberDataType;
use databend_common_expression::types::StringType;
use databend_common_expression::types::F32;
use databend_common_expression::types::F64;
use databend_common_expression::vectorize_with_builder_1_arg;
use databend_common_expression::vectorize_with_builder_2_arg;
use databend_common_expression::FunctionDomain;
use databend_common_expression::FunctionRegistry;
use databend_common_openai::OpenAI;
use databend_common_vector::cosine_distance;
use databend_common_vector::cosine_distance_64;
use databend_common_vector::l2_distance;
use databend_common_vector::l2_distance_64;

pub fn register(registry: &mut FunctionRegistry) {
// cosine_distance
Expand All @@ -33,12 +39,12 @@ pub fn register(registry: &mut FunctionRegistry) {
|_, _, _| FunctionDomain::MayThrow,
vectorize_with_builder_2_arg::<ArrayType<Float32Type>, ArrayType<Float32Type>, Float32Type>(
|lhs, rhs, output, ctx| {
let l_f32=
let l=
unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(lhs) };
let r_f32=
let r =
unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(rhs) };

match cosine_distance(l_f32.as_slice(), r_f32.as_slice()) {
match cosine_distance(l.as_slice(), r .as_slice()) {
Ok(dist) => {
output.push(F32::from(dist));
}
Expand All @@ -59,12 +65,12 @@ pub fn register(registry: &mut FunctionRegistry) {
|_, _, _| FunctionDomain::MayThrow,
vectorize_with_builder_2_arg::<ArrayType<Float32Type>, ArrayType<Float32Type>, Float32Type>(
|lhs, rhs, output, ctx| {
let l_f32=
let l=
unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(lhs) };
let r_f32=
let r =
unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(rhs) };

match l2_distance(l_f32.as_slice(), r_f32.as_slice()) {
match l2_distance(l.as_slice(), r .as_slice()) {
Ok(dist) => {
output.push(F32::from(dist));
}
Expand All @@ -77,6 +83,52 @@ pub fn register(registry: &mut FunctionRegistry) {
),
);

registry.register_passthrough_nullable_2_arg::<ArrayType<Float64Type>, ArrayType<Float64Type>, Float64Type, _, _>(
"cosine_distance",
|_, _, _| FunctionDomain::MayThrow,
vectorize_with_builder_2_arg::<ArrayType<Float64Type>, ArrayType<Float64Type>, Float64Type>(
|lhs, rhs, output, ctx| {
let l =
unsafe { std::mem::transmute::<Buffer<F64>, Buffer<f64>>(lhs) };
let r =
unsafe { std::mem::transmute::<Buffer<F64>, Buffer<f64>>(rhs) };

match cosine_distance_64(l.as_slice(), r .as_slice()) {
Ok(dist) => {
output.push(F64::from(dist));
}
Err(err) => {
ctx.set_error(output.len(), err.to_string());
output.push(F64::from(0.0));
}
}
}
),
);

registry.register_passthrough_nullable_2_arg::<ArrayType<Float64Type>, ArrayType<Float64Type>, Float64Type, _, _>(
"l2_distance",
|_, _, _| FunctionDomain::MayThrow,
vectorize_with_builder_2_arg::<ArrayType<Float64Type>, ArrayType<Float64Type>, Float64Type>(
|lhs, rhs, output, ctx| {
let l=
unsafe { std::mem::transmute::<Buffer<F64>, Buffer<f64>>(lhs) };
let r =
unsafe { std::mem::transmute::<Buffer<F64>, Buffer<f64>>(rhs) };

match l2_distance_64(l.as_slice(), r .as_slice()) {
Ok(dist) => {
output.push(F64::from(dist));
}
Err(err) => {
ctx.set_error(output.len(), err.to_string());
output.push(F64::from(0.0));
}
}
}
),
);

// embedding_vector
// This function takes two strings as input, sends an API request to OpenAI, and returns the Float32 array of embeddings.
// The OpenAI API key is pre-configured during the binder phase, so we rewrite this function and set the API key.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,8 @@ Functions overloads:
1 cos(Float64 NULL) :: Float64 NULL
0 cosine_distance(Array(Float32), Array(Float32)) :: Float32
1 cosine_distance(Array(Float32) NULL, Array(Float32) NULL) :: Float32 NULL
2 cosine_distance(Array(Float64), Array(Float64)) :: Float64
3 cosine_distance(Array(Float64) NULL, Array(Float64) NULL) :: Float64 NULL
0 cot(Float64) :: Float64
1 cot(Float64 NULL) :: Float64 NULL
0 crc32(String) :: UInt32
Expand Down Expand Up @@ -1892,6 +1894,8 @@ Functions overloads:
1 json_typeof(Variant NULL) :: String NULL
0 l2_distance(Array(Float32), Array(Float32)) :: Float32
1 l2_distance(Array(Float32) NULL, Array(Float32) NULL) :: Float32 NULL
2 l2_distance(Array(Float64), Array(Float64)) :: Float64
3 l2_distance(Array(Float64) NULL, Array(Float64) NULL) :: Float64 NULL
0 left(String, UInt64) :: String
1 left(String NULL, UInt64 NULL) :: String NULL
0 length(Variant NULL) :: UInt32 NULL
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# From sklearn.metrics.pairwise import cosine_similarity
query F
select cosine_distance([3.0::Float32, 45.0, 7.0, 2.0, 5.0, 20.0, 13.0, 12.0], [2.0::Float32, 54.0, 13.0, 15.0, 22.0, 34.0, 50.0, 1.0]) as sim
select cosine_distance([3.0, 45.0, 7.0, 2.0, 5.0, 20.0, 13.0, 12.0], [2.0, 54.0, 13.0, 15.0, 22.0, 34.0, 50.0, 1.0]) as sim
----
0.1264193

Expand Down

0 comments on commit ccb65b0

Please sign in to comment.