Skip to content

Commit

Permalink
fix(query): support subquery in pivot (#16631)
Browse files Browse the repository at this point in the history
* fix(query): support subquery in pivot

* add pivot and unpivot sqllogictests, fix unit-test

* code format
  • Loading branch information
Dragonliu2018 authored Oct 20, 2024
1 parent 6a707e1 commit 4557131
Show file tree
Hide file tree
Showing 28 changed files with 1,389 additions and 187 deletions.
12 changes: 12 additions & 0 deletions src/query/ast/src/ast/format/syntax/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ pub(crate) fn pretty_table(table: TableReference) -> RcDoc<'static> {
lateral,
subquery,
alias,
pivot,
unpivot,
} => (if lateral {
RcDoc::text("LATERAL")
} else {
Expand All @@ -379,6 +381,16 @@ pub(crate) fn pretty_table(table: TableReference) -> RcDoc<'static> {
RcDoc::text(format!(" AS {alias}"))
} else {
RcDoc::nil()
})
.append(if let Some(pivot) = pivot {
RcDoc::text(format!(" {pivot}"))
} else {
RcDoc::nil()
})
.append(if let Some(unpivot) = unpivot {
RcDoc::text(format!(" {unpivot}"))
} else {
RcDoc::nil()
}),
TableReference::TableFunction {
span: _,
Expand Down
31 changes: 29 additions & 2 deletions src/query/ast/src/ast/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,17 +533,30 @@ impl Display for TimeTravelPoint {
}
}

#[derive(Debug, Clone, PartialEq, Drive, DriveMut)]
pub enum PivotValues {
ColumnValues(Vec<Expr>),
Subquery(Box<Query>),
}

#[derive(Debug, Clone, PartialEq, Drive, DriveMut)]
pub struct Pivot {
pub aggregate: Expr,
pub value_column: Identifier,
pub values: Vec<Expr>,
pub values: PivotValues,
}

impl Display for Pivot {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
write!(f, "PIVOT({} FOR {} IN (", self.aggregate, self.value_column)?;
write_comma_separated_list(f, &self.values)?;
match &self.values {
PivotValues::ColumnValues(column_values) => {
write_comma_separated_list(f, column_values)?;
}
PivotValues::Subquery(subquery) => {
write!(f, "{}", subquery)?;
}
}
write!(f, "))")?;
Ok(())
}
Expand Down Expand Up @@ -740,6 +753,8 @@ pub enum TableReference {
lateral: bool,
subquery: Box<Query>,
alias: Option<TableAlias>,
pivot: Option<Box<Pivot>>,
unpivot: Option<Box<Unpivot>>,
},
Join {
span: Span,
Expand All @@ -757,13 +772,15 @@ impl TableReference {
pub fn pivot(&self) -> Option<&Pivot> {
match self {
TableReference::Table { pivot, .. } => pivot.as_ref().map(|b| b.as_ref()),
TableReference::Subquery { pivot, .. } => pivot.as_ref().map(|b| b.as_ref()),
_ => None,
}
}

pub fn unpivot(&self) -> Option<&Unpivot> {
match self {
TableReference::Table { unpivot, .. } => unpivot.as_ref().map(|b| b.as_ref()),
TableReference::Subquery { unpivot, .. } => unpivot.as_ref().map(|b| b.as_ref()),
_ => None,
}
}
Expand Down Expand Up @@ -862,6 +879,8 @@ impl Display for TableReference {
lateral,
subquery,
alias,
pivot,
unpivot,
} => {
if *lateral {
write!(f, "LATERAL ")?;
Expand All @@ -870,6 +889,14 @@ impl Display for TableReference {
if let Some(alias) = alias {
write!(f, " AS {alias}")?;
}

if let Some(pivot) = pivot {
write!(f, " {pivot}")?;
}

if let Some(unpivot) = unpivot {
write!(f, " {unpivot}")?;
}
}
TableReference::Join { span: _, join } => {
write!(f, "{}", join.left)?;
Expand Down
2 changes: 2 additions & 0 deletions src/query/ast/src/ast/statements/merge_into.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ impl MergeSource {
lateral: false,
subquery: query.clone(),
alias: Some(source_alias.clone()),
pivot: None,
unpivot: None,
},
Self::Table {
catalog,
Expand Down
69 changes: 45 additions & 24 deletions src/query/ast/src/parser/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,8 @@ pub enum TableReferenceElement {
lateral: bool,
subquery: Box<Query>,
alias: Option<TableAlias>,
pivot: Option<Box<Pivot>>,
unpivot: Option<Box<Unpivot>>,
},
// [NATURAL] [INNER|OUTER|CROSS|...] JOIN
Join {
Expand All @@ -736,28 +738,6 @@ pub enum TableReferenceElement {
}

pub fn table_reference_element(i: Input) -> IResult<WithSpan<TableReferenceElement>> {
// PIVOT(expr FOR col IN (ident, ...))
let pivot = map(
rule! {
PIVOT ~ "(" ~ #expr ~ FOR ~ #ident ~ IN ~ "(" ~ #comma_separated_list1(expr) ~ ")" ~ ")"
},
|(_pivot, _, aggregate, _for, value_column, _in, _, values, _, _)| Pivot {
aggregate,
value_column,
values,
},
);
// UNPIVOT(ident for ident IN (ident, ...))
let unpivot = map(
rule! {
UNPIVOT ~ "(" ~ #ident ~ FOR ~ #ident ~ IN ~ "(" ~ #comma_separated_list1(ident) ~ ")" ~ ")"
},
|(_unpivot, _, value_column, _for, column_name, _in, _, names, _, _)| Unpivot {
value_column,
column_name,
names,
},
);
let aliased_table = map(
rule! {
#dot_separated_idents_1_to_3 ~ #temporal_clause? ~ #with_options? ~ #table_alias? ~ #pivot? ~ #unpivot? ~ SAMPLE? ~ (BLOCK ~ "(" ~ #expr ~ ")")? ~ (ROW ~ "(" ~ #expr ~ ROWS? ~ ")")?
Expand Down Expand Up @@ -825,12 +805,14 @@ pub fn table_reference_element(i: Input) -> IResult<WithSpan<TableReferenceEleme
);
let subquery = map(
rule! {
LATERAL? ~ "(" ~ #query ~ ")" ~ #table_alias?
LATERAL? ~ "(" ~ #query ~ ")" ~ #table_alias? ~ #pivot? ~ #unpivot?
},
|(lateral, _, subquery, _, alias)| TableReferenceElement::Subquery {
|(lateral, _, subquery, _, alias, pivot, unpivot)| TableReferenceElement::Subquery {
lateral: lateral.is_some(),
subquery: Box::new(subquery),
alias,
pivot: pivot.map(Box::new),
unpivot: unpivot.map(Box::new),
},
);

Expand Down Expand Up @@ -869,6 +851,41 @@ pub fn table_reference_element(i: Input) -> IResult<WithSpan<TableReferenceEleme
Ok((rest, WithSpan { span, elem }))
}

// PIVOT(expr FOR col IN (ident, ... | subquery))
fn pivot(i: Input) -> IResult<Pivot> {
map(
rule! {
PIVOT ~ "(" ~ #expr ~ FOR ~ #ident ~ IN ~ "(" ~ #pivot_values ~ ")" ~ ")"
},
|(_pivot, _, aggregate, _for, value_column, _in, _, values, _, _)| Pivot {
aggregate,
value_column,
values,
},
)(i)
}

// UNPIVOT(ident for ident IN (ident, ...))
fn unpivot(i: Input) -> IResult<Unpivot> {
map(
rule! {
UNPIVOT ~ "(" ~ #ident ~ FOR ~ #ident ~ IN ~ "(" ~ #comma_separated_list1(ident) ~ ")" ~ ")"
},
|(_unpivot, _, value_column, _for, column_name, _in, _, names, _, _)| Unpivot {
value_column,
column_name,
names,
},
)(i)
}

fn pivot_values(i: Input) -> IResult<PivotValues> {
alt((
map(comma_separated_list1(expr), PivotValues::ColumnValues),
map(query, |q| PivotValues::Subquery(Box::new(q))),
))(i)
}

fn get_table_sample(
sample: Option<&Token>,
block_level_sample: Option<(&Token, &Token, Expr, &Token)>,
Expand Down Expand Up @@ -966,11 +983,15 @@ impl<'a, I: Iterator<Item = WithSpan<'a, TableReferenceElement>>> PrattParser<I>
lateral,
subquery,
alias,
pivot,
unpivot,
} => TableReference::Subquery {
span: transform_span(input.span.tokens),
lateral,
subquery,
alias,
pivot,
unpivot,
},
TableReferenceElement::Stage {
location,
Expand Down
4 changes: 4 additions & 0 deletions src/query/ast/tests/it/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,11 @@ fn test_query() {
r#"SELECT * FROM (((SELECT *) EXCEPT (SELECT *))) foo"#,
r#"SELECT * FROM (SELECT * FROM xyu ORDER BY x, y) AS xyu"#,
r#"select * from monthly_sales pivot(sum(amount) for month in ('JAN', 'FEB', 'MAR', 'APR')) order by empid"#,
r#"select * from (select * from monthly_sales) pivot(sum(amount) for month in ('JAN', 'FEB', 'MAR', 'APR')) order by empid"#,
r#"select * from monthly_sales pivot(sum(amount) for month in (select distinct month from monthly_sales)) order by empid"#,
r#"select * from (select * from monthly_sales) pivot(sum(amount) for month in ((select distinct month from monthly_sales))) order by empid"#,
r#"select * from monthly_sales_1 unpivot(sales for month in (jan, feb, mar, april)) order by empid"#,
r#"select * from (select * from monthly_sales_1) unpivot(sales for month in (jan, feb, mar, april)) order by empid"#,
r#"select * from range(1, 2)"#,
r#"select sum(a) over w from customer window w as (partition by a order by b)"#,
r#"select a, sum(a) over w, sum(a) over w1, sum(a) over w2 from t1 window w as (partition by a), w2 as (w1 rows current row), w1 as (w order by a) order by a"#,
Expand Down
Loading

0 comments on commit 4557131

Please sign in to comment.