Skip to content

Commit

Permalink
chore(query): desugar round/truncate
Browse files Browse the repository at this point in the history
  • Loading branch information
sundy-li committed Dec 20, 2023
1 parent 4705cae commit 757db35
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 79 deletions.
27 changes: 0 additions & 27 deletions src/query/expression/src/type_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@ use crate::function::FunctionSignature;
use crate::types::decimal::DecimalSize;
use crate::types::decimal::MAX_DECIMAL128_PRECISION;
use crate::types::decimal::MAX_DECIMAL256_PRECISION;
use crate::types::ArgType;
use crate::types::DataType;
use crate::types::DecimalDataType;
use crate::types::Int64Type;
use crate::types::Number;
use crate::types::NumberScalar;
use crate::AutoCastRules;
Expand Down Expand Up @@ -127,31 +125,6 @@ pub fn check<Index: ColumnIndex>(
}
}

// inject the params
if ["round", "truncate"].contains(&name.as_str()) && params.is_empty() {
let mut scale = 0;
let mut new_args = args_expr.clone();

if args_expr.len() == 2 {
let scalar_expr = &args_expr[1];
scale = check_number::<_, i64>(
scalar_expr.span(),
&FunctionContext::default(),
scalar_expr,
fn_registry,
)?;
} else {
new_args.push(Expr::Constant {
span: None,
scalar: Scalar::Number(scale.into()),
data_type: Int64Type::data_type(),
})
}
scale = scale.clamp(-76, 76);
let params = vec![Scalar::Number(scale.into())];
return check_function(*span, name, &params, &args_expr, fn_registry);
}

check_function(*span, name, params, &args_expr, fn_registry)
}
RawExpr::LambdaFunctionCall {
Expand Down
1 change: 1 addition & 0 deletions src/query/functions/src/scalars/decimal/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pub fn register_decimal_math(registry: &mut FunctionRegistry) {
let from_decimal_type = from_type.as_decimal().unwrap();

let scale = if params.is_empty() {
debug_assert!(matches!(round_mode, RoundMode::Ceil | RoundMode::Floor));
0
} else {
params[0].get_i64()?
Expand Down
28 changes: 17 additions & 11 deletions src/query/functions/tests/it/scalars/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,17 @@ fn test_exp(file: &mut impl Write) {
}

fn test_round(file: &mut impl Write) {
run_ast(file, "round(-1.23)", &[]);
run_ast(file, "round(1.298, 1)", &[]);
run_ast(file, "round(1.298, 0)", &[]);
run_ast(file, "round(23.298, -1)", &[]);
run_ast(file, "round(0.12345678901234567890123456789012345, 35)", &[
]);
run_ast(file, "round(0)(-1.23)", &[]);
run_ast(file, "round(1)(1.298, 1)", &[]);
run_ast(file, "round(0)(1.298, 0)", &[]);

// Currently we don't support negative param
// run_ast(file, "round(-1)(23.298, -1)", &[]);
run_ast(
file,
"round(35)(0.12345678901234567890123456789012345, 35)",
&[],
);
run_ast(file, "round(a)", &[(
"a",
Float64Type::from_data(vec![22.22f64, -22.23, 10.0]),
Expand Down Expand Up @@ -131,11 +136,12 @@ fn test_factorial(file: &mut impl Write) {
}

fn test_truncate(file: &mut impl Write) {
run_ast(file, "truncate(1.223, 1)", &[]);
run_ast(file, "truncate(1.999)", &[]);
run_ast(file, "truncate(1.999, 1)", &[]);
run_ast(file, "truncate(122, -2)", &[]);
run_ast(file, "truncate(10.28*100, 0)", &[]);
run_ast(file, "truncate(1)(1.223, 1)", &[]);
run_ast(file, "truncate(0)(1.999)", &[]);
run_ast(file, "truncate(1)(1.999, 1)", &[]);
// todo(negative param)
// run_ast(file, "truncate(-2)(122, -2)", &[]);
run_ast(file, "truncate(0)(10.28*100, 0)", &[]);
run_ast(file, "truncate(a, 1)", &[(
"a",
Float64Type::from_data(vec![22.22f64, -22.23, 10.0]),
Expand Down
76 changes: 36 additions & 40 deletions src/query/functions/tests/it/scalars/testdata/math.txt
Original file line number Diff line number Diff line change
Expand Up @@ -279,44 +279,35 @@ evaluation (internal):
+--------+----------------------------------------------------------+


ast : round(-1.23)
raw expr : round(minus(1.23))
ast : round(0)(-1.23)
raw expr : round(0)(minus(1.23))
checked expr : round<Decimal(3, 2)>(0)(minus<Decimal(3, 2)>(1.23_d128(3,2)))
optimized expr : -1_d128(3,0)
output type : Decimal(3, 0)
output domain : {-1..=-1}
output : -1


ast : round(1.298, 1)
raw expr : round(1.298, 1)
ast : round(1)(1.298, 1)
raw expr : round(1)(1.298, 1)
checked expr : round<Decimal(4, 3), UInt8>(1)(1.298_d128(4,3), 1_u8)
optimized expr : 1.3_d128(4,1)
output type : Decimal(4, 1)
output domain : {1.3..=1.3}
output : 1.3


ast : round(1.298, 0)
raw expr : round(1.298, 0)
ast : round(0)(1.298, 0)
raw expr : round(0)(1.298, 0)
checked expr : round<Decimal(4, 3), UInt8>(0)(1.298_d128(4,3), 0_u8)
optimized expr : 1_d128(4,0)
output type : Decimal(4, 0)
output domain : {1..=1}
output : 1


ast : round(23.298, -1)
raw expr : round(23.298, minus(1))
checked expr : round<Decimal(5, 3), Int16>(-1)(23.298_d128(5,3), minus<UInt8>(1_u8))
optimized expr : 20_d128(5,0)
output type : Decimal(5, 0)
output domain : {20..=20}
output : 20


ast : round(0.12345678901234567890123456789012345, 35)
raw expr : round(0.12345678901234567890123456789012345, 35)
ast : round(35)(0.12345678901234567890123456789012345, 35)
raw expr : round(35)(0.12345678901234567890123456789012345, 35)
checked expr : round<Decimal(35, 35), UInt8>(35)(0.12345678901234567890123456789012345_d128(35,35), 35_u8)
optimized expr : 0.12345678901234567890123456789012345_d128(35,35)
output type : Decimal(35, 35)
Expand Down Expand Up @@ -408,44 +399,35 @@ evaluation (internal):
+--------+----------------------------------------------------+


ast : truncate(1.223, 1)
raw expr : truncate(1.223, 1)
ast : truncate(1)(1.223, 1)
raw expr : truncate(1)(1.223, 1)
checked expr : truncate<Decimal(4, 3), UInt8>(1)(1.223_d128(4,3), 1_u8)
optimized expr : 1.2_d128(4,1)
output type : Decimal(4, 1)
output domain : {1.2..=1.2}
output : 1.2


ast : truncate(1.999)
raw expr : truncate(1.999)
ast : truncate(0)(1.999)
raw expr : truncate(0)(1.999)
checked expr : truncate<Decimal(4, 3)>(0)(1.999_d128(4,3))
optimized expr : 1_d128(4,0)
output type : Decimal(4, 0)
output domain : {1..=1}
output : 1


ast : truncate(1.999, 1)
raw expr : truncate(1.999, 1)
ast : truncate(1)(1.999, 1)
raw expr : truncate(1)(1.999, 1)
checked expr : truncate<Decimal(4, 3), UInt8>(1)(1.999_d128(4,3), 1_u8)
optimized expr : 1.9_d128(4,1)
output type : Decimal(4, 1)
output domain : {1.9..=1.9}
output : 1.9


ast : truncate(122, -2)
raw expr : truncate(122, minus(2))
checked expr : truncate<UInt8, Int64>(122_u8, to_int64<Int16>(minus<UInt8>(2_u8)))
optimized expr : 100_f64
output type : Float64
output domain : {100..=100}
output : 100


ast : truncate(10.28*100, 0)
raw expr : truncate(multiply(10.28, 100), 0)
ast : truncate(0)(10.28*100, 0)
raw expr : truncate(0)(multiply(10.28, 100), 0)
checked expr : truncate<Decimal(7, 2), UInt8>(0)(multiply<Decimal(7, 2), Decimal(7, 0)>(to_decimal<Decimal(4, 2)>(7, 2)(10.28_d128(4,2)), to_decimal<UInt8>(7, 0)(100_u8)), 0_u8)
optimized expr : 1028_d128(7,0)
output type : Decimal(7, 0)
Expand Down Expand Up @@ -521,12 +503,26 @@ output domain : {0.6931471805..=0.6931471805}
output : 0.6931471805


error:
--> SQL:1:10
|
1 | round(2, a)
| ^ Need constant number.

ast : round(2, a)
raw expr : round(2, a::Int64)
checked expr : round<UInt8, Int64>(2_u8, a)
evaluation:
+--------+--------------+--------------+
| | a | Output |
+--------+--------------+--------------+
| Type | Int64 | Float64 |
| Domain | {10..=65536} | {-inf..=NaN} |
| Row 0 | 22 | 2 |
| Row 1 | 65536 | 2 |
| Row 2 | 10 | 2 |
+--------+--------------+--------------+
evaluation (internal):
+--------+------------------------+
| Column | Data |
+--------+------------------------+
| a | Int64([22, 65536, 10]) |
| Output | Float64([2, 2, 2]) |
+--------+------------------------+


ast : factorial(5)
Expand Down
26 changes: 25 additions & 1 deletion src/query/sql/src/planner/semantic/type_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1885,11 +1885,35 @@ impl<'a> TypeChecker<'a> {
&self,
span: Span,
func_name: &str,
params: Vec<i64>,
mut params: Vec<i64>,
args: Vec<ScalarExpr>,
) -> Result<Box<(ScalarExpr, DataType)>> {
// Type check
let arguments = args.iter().map(|v| v.as_raw_expr()).collect::<Vec<_>>();

// inject the params
if ["round", "truncate"].contains(&func_name)
&& !args.is_empty()
&& params.is_empty()
&& args[0].data_type()?.remove_nullable().is_decimal()
{
let scale = if args.len() == 2 {
let scalar_expr = &arguments[1];
let expr = type_check::check(scalar_expr, &BUILTIN_FUNCTIONS)?;

let scale = check_number::<_, i64>(
expr.span(),
&FunctionContext::default(),
&expr,
&BUILTIN_FUNCTIONS,
)?;
scale.clamp(-76, 76)
} else {
0
};
params.push(scale);
}

let raw_expr = RawExpr::FunctionCall {
span,
name: func_name.to_string(),
Expand Down

0 comments on commit 757db35

Please sign in to comment.