Skip to content

Commit

Permalink
feat: Port Datafusion Covariance to Comet
Browse files Browse the repository at this point in the history
  • Loading branch information
Huaxin Gao committed Apr 10, 2024
1 parent 06bbb36 commit e6c4d6f
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 58 deletions.
58 changes: 27 additions & 31 deletions core/src/execution/datafusion/expressions/covariance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,22 @@

//! Defines physical expressions that can evaluated at runtime during query execution
use std::any::Any;
use std::sync::Arc;
use std::{any::Any, sync::Arc};

use arrow::array::Float64Array;
use arrow::{
array::{ArrayRef, Int64Array},
array::{ArrayRef, Float64Array},
compute::cast,
datatypes::DataType,
datatypes::Field,
datatypes::{DataType, Field},
};
use datafusion::logical_expr::Accumulator;
use datafusion_common::{downcast_value, unwrap_or_internal_err, ScalarValue};
use datafusion_common::{DataFusionError, Result};
use datafusion_physical_expr::{aggregate::utils::down_cast_any_ref, expressions::format_state_name, AggregateExpr, PhysicalExpr};
use datafusion_physical_expr::expressions::StatsType;
use datafusion_common::{
downcast_value, unwrap_or_internal_err, DataFusionError, Result, ScalarValue,
};
use datafusion_physical_expr::{
aggregate::utils::down_cast_any_ref,
expressions::{format_state_name, StatsType},
AggregateExpr, PhysicalExpr,
};

/// COVAR and COVAR_SAMP aggregate expression
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -85,7 +86,7 @@ impl AggregateExpr for Covariance {
Ok(vec![
Field::new(
format_state_name(&self.name, "count"),
DataType::Int64,
DataType::Float64,
true,
),
Field::new(
Expand Down Expand Up @@ -119,9 +120,7 @@ impl PartialEq<dyn Any> for Covariance {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.name == x.name && self.expr1.eq(&x.expr1) && self.expr2.eq(&x.expr2)
})
.map(|x| self.name == x.name && self.expr1.eq(&x.expr1) && self.expr2.eq(&x.expr2))
.unwrap_or(false)
}
}
Expand Down Expand Up @@ -164,7 +163,7 @@ impl AggregateExpr for CovariancePop {
Ok(vec![
Field::new(
format_state_name(&self.name, "count"),
DataType::Int64,
DataType::Float64,
true,
),
Field::new(
Expand Down Expand Up @@ -198,9 +197,7 @@ impl PartialEq<dyn Any> for CovariancePop {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.name == x.name && self.expr1.eq(&x.expr1) && self.expr2.eq(&x.expr2)
})
.map(|x| self.name == x.name && self.expr1.eq(&x.expr1) && self.expr2.eq(&x.expr2))
.unwrap_or(false)
}
}
Expand All @@ -211,7 +208,7 @@ pub struct CovarianceAccumulator {
algo_const: f64,
mean1: f64,
mean2: f64,
count: i64,
count: f64,
stats_type: StatsType,
}

Expand All @@ -222,12 +219,12 @@ impl CovarianceAccumulator {
algo_const: 0_f64,
mean1: 0_f64,
mean2: 0_f64,
count: 0_i64,
count: 0_f64,
stats_type: s_type,
})
}

pub fn get_count(&self) -> i64 {
pub fn get_count(&self) -> f64 {
self.count
}

Expand Down Expand Up @@ -279,14 +276,14 @@ impl Accumulator for CovarianceAccumulator {

let value1 = unwrap_or_internal_err!(value1);
let value2 = unwrap_or_internal_err!(value2);
let new_count = self.count + 1;
let new_count = self.count + 1.0;
let delta1 = value1 - self.mean1;
let new_mean1 = delta1 / new_count as f64 + self.mean1;
let delta2 = value2 - self.mean2;
let new_mean2 = delta2 / new_count as f64 + self.mean2;
let new_c = delta1 * (value2 - new_mean2) + self.algo_const;

self.count += 1;
self.count += 1.0;
self.mean1 = new_mean1;
self.mean2 = new_mean2;
self.algo_const = new_c;
Expand Down Expand Up @@ -320,14 +317,14 @@ impl Accumulator for CovarianceAccumulator {
let value1 = unwrap_or_internal_err!(value1);
let value2 = unwrap_or_internal_err!(value2);

let new_count = self.count - 1;
let new_count = self.count - 1.0;
let delta1 = self.mean1 - value1;
let new_mean1 = delta1 / new_count as f64 + self.mean1;
let delta2 = self.mean2 - value2;
let new_mean2 = delta2 / new_count as f64 + self.mean2;
let new_c = self.algo_const - delta1 * (new_mean2 - value2);

self.count -= 1;
self.count -= 1.0;
self.mean1 = new_mean1;
self.mean2 = new_mean2;
self.algo_const = new_c;
Expand All @@ -337,14 +334,14 @@ impl Accumulator for CovarianceAccumulator {
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
let counts = downcast_value!(states[0], Int64Array);
let counts = downcast_value!(states[0], Float64Array);
let means1 = downcast_value!(states[1], Float64Array);
let means2 = downcast_value!(states[2], Float64Array);
let cs = downcast_value!(states[3], Float64Array);

for i in 0..counts.len() {
let c = counts.value(i);
if c == 0 {
if c == 0.0 {
continue;
}
let new_count = self.count + c;
Expand All @@ -369,19 +366,18 @@ impl Accumulator for CovarianceAccumulator {
}

fn evaluate(&mut self) -> Result<ScalarValue> {
println!("evaluate evaluate evaluate");
let count = match self.stats_type {
datafusion_physical_expr::expressions::StatsType::Population => self.count,
StatsType::Sample => {
if self.count > 0 {
self.count - 1
if self.count > 0.0 {
self.count - 1.0
} else {
self.count
}
}
};

if count == 0 {
if count == 0.0 {
Ok(ScalarValue::Float64(None))
} else {
Ok(ScalarValue::Float64(Some(self.algo_const / count as f64)))
Expand Down
2 changes: 1 addition & 1 deletion core/src/execution/datafusion/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ pub use normalize_nan::NormalizeNaNAndZero;
pub mod avg;
pub mod avg_decimal;
pub mod bloom_filter_might_contain;
pub mod covariance;
pub mod strings;
pub mod subquery;
pub mod sum_decimal;
pub mod temporal;
mod utils;
pub mod covariance;
3 changes: 1 addition & 2 deletions core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,11 @@ use crate::{
expressions::{
avg::Avg,
avg_decimal::AvgDecimal,
covariance::Covariance,
covariance::CovariancePop,
bitwise_not::BitwiseNotExpr,
bloom_filter_might_contain::BloomFilterMightContain,
cast::Cast,
checkoverflow::CheckOverflow,
covariance::{Covariance, CovariancePop},
if_expr::IfExpr,
scalar_funcs::create_comet_physical_fun,
strings::{Contains, EndsWith, Like, StartsWith, StringSpaceExec, SubstringExec},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -998,30 +998,38 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

test("covar_pop and covar_samp") {
withSQLConf(
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
Seq(false).foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
val table = "test"
withTable(table) {
sql(
s"create table $table(col1 int, col2 int, col3 int, col4 float, col5 double, col6 double, col7 int) using parquet")
sql(
s"insert into $table values(1, 4, null, 1.1, 2.2, null, 1), (2, 5, 6, 3.4, 5.6, null, 1), (3, 6, null, 7.9, 2.4, null, 2)")
val expectedNumOfCometAggregates = 2
checkSparkAnswerAndNumOfAggregates(
"SELECT covar_samp(col1, col2), covar_samp(col1, col3), covar_samp(col4, col5), covar_samp(col4, col6) FROM test",
expectedNumOfCometAggregates)
checkSparkAnswerAndNumOfAggregates(
"SELECT covar_pop(col1, col2), covar_pop(col1, col3), covar_pop(col4, col5), covar_pop(col4, col6) FROM test",
expectedNumOfCometAggregates)
checkSparkAnswerAndNumOfAggregates(
"SELECT covar_samp(col1, col2), covar_samp(col1, col3), covar_samp(col4, col5), covar_samp(col4, col6) FROM test GROUP BY col7",
expectedNumOfCometAggregates)
checkSparkAnswerAndNumOfAggregates(
"SELECT covar_pop(col1, col2), covar_pop(col1, col3), covar_pop(col4, col5), covar_pop(col4, col6) FROM test GROUP BY col7",
expectedNumOfCometAggregates)
withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
Seq(true, false).foreach { cometColumnShuffleEnabled =>
withSQLConf(
CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> cometColumnShuffleEnabled.toString) {
Seq(true, false).foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
val table = "test"
withTable(table) {
sql(
s"create table $table(col1 int, col2 int, col3 int, col4 float, col5 double," +
s" col6 double, col7 int) using parquet")
sql(s"insert into $table values(1, 4, null, 1.1, 2.2, null, 1)," +
s" (2, 5, 6, 3.4, 5.6, null, 1), (3, 6, null, 7.9, 2.4, null, 2)")
val expectedNumOfCometAggregates = 2
checkSparkAnswerAndNumOfAggregates(
"SELECT covar_samp(col1, col2), covar_samp(col1, col3), covar_samp(col4, col5)," +
" covar_samp(col4, col6) FROM test",
expectedNumOfCometAggregates)
checkSparkAnswerAndNumOfAggregates(
"SELECT covar_pop(col1, col2), covar_pop(col1, col3), covar_pop(col4, col5)," +
" covar_pop(col4, col6) FROM test",
expectedNumOfCometAggregates)
checkSparkAnswerAndNumOfAggregates(
"SELECT covar_samp(col1, col2), covar_samp(col1, col3), covar_samp(col4, col5)," +
" covar_samp(col4, col6) FROM test GROUP BY col7",
expectedNumOfCometAggregates)
checkSparkAnswerAndNumOfAggregates(
"SELECT covar_pop(col1, col2), covar_pop(col1, col3), covar_pop(col4, col5)," +
" covar_pop(col4, col6) FROM test GROUP BY col7",
expectedNumOfCometAggregates)
}
}
}
}
}
Expand Down

0 comments on commit e6c4d6f

Please sign in to comment.