Skip to content

Commit

Permalink
fix bugs of aggregate column order mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
shirly121 committed Dec 13, 2024
1 parent c02b1d7 commit 18df95b
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 53 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
*
* * Copyright 2020 Alibaba Group Holding Limited.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/

package com.alibaba.graphscope.cypher.antlr4.visitor;

import com.alibaba.graphscope.common.ir.tools.GraphBuilder;
import com.google.common.base.Objects;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexNode;
import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
 * Keeps the output fields in the same order as they are written in the RETURN clause.
 *
 * <p>The order is stored as a list of {@link FieldSupplier}s, one per output column. The
 * suppliers are only resolved into concrete {@link Field}s when {@link #getFields(RelDataType)}
 * is called with the row type of the operator the columns are read from.
 */
public class ColumnOrder {
    /** A single output column: the expression that produces it and the alias it is named by. */
    public static class Field {
        private final RexNode expr;
        // may be null: callers are allowed to create a field without an explicit alias
        private final @Nullable String alias;

        /**
         * @param expr expression producing the column value
         * @param alias name of the column, or {@code null} if it has no explicit alias
         */
        public Field(RexNode expr, @Nullable String alias) {
            this.expr = expr;
            this.alias = alias;
        }

        /** Returns the expression producing the column value. */
        public RexNode getExpr() {
            return expr;
        }

        /** Returns the column alias, or {@code null} if none was given. */
        public @Nullable String getAlias() {
            return alias;
        }

        /** Two fields are equal when both their expressions and their aliases are equal. */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Field field = (Field) o;
            return Objects.equal(expr, field.expr) && Objects.equal(alias, field.alias);
        }

        @Override
        public int hashCode() {
            return Objects.hashCode(expr, alias);
        }
    }

    /** Produces the {@link Field} of one output column from the row type it is read from. */
    @FunctionalInterface
    public interface FieldSupplier {
        /**
         * @param inputType row type of the operator the column is read from
         * @return the field describing this column
         */
        Field get(RelDataType inputType);

        /**
         * Supplier for a column that already exists in the input row type: the column is
         * looked up by position and referenced as a variable under its existing name.
         */
        class Default implements FieldSupplier {
            private final GraphBuilder builder;
            // evaluated on every call of get(), so the position may depend on state that is
            // only final after this supplier has been constructed
            private final Supplier<Integer> ordinalSupplier;

            /**
             * @param builder builder used to create the variable referencing the column
             * @param ordinalSupplier yields the position of the column in the input row type
             */
            public Default(GraphBuilder builder, Supplier<Integer> ordinalSupplier) {
                this.builder = builder;
                this.ordinalSupplier = ordinalSupplier;
            }

            /**
             * @throws IndexOutOfBoundsException if the supplied ordinal does not denote a field
             *     of {@code inputType}
             */
            @Override
            public Field get(RelDataType inputType) {
                int ordinal = ordinalSupplier.get();
                int fieldCount = inputType.getFieldList().size();
                if (ordinal < 0 || ordinal >= fieldCount) {
                    // same exception type List.get would raise, but naming the culprit
                    throw new IndexOutOfBoundsException(
                            "column ordinal "
                                    + ordinal
                                    + " is out of range, input type "
                                    + inputType
                                    + " has "
                                    + fieldCount
                                    + " fields");
                }
                String aliasName = inputType.getFieldList().get(ordinal).getName();
                return new Field(this.builder.variable(aliasName), aliasName);
            }
        }
    }

    private final List<FieldSupplier> fieldSuppliers;

    /**
     * @param fieldSuppliers one supplier per output column, in the desired output order; must
     *     not be {@code null}
     */
    public ColumnOrder(List<FieldSupplier> fieldSuppliers) {
        // fully qualified because the simple name Objects is bound to the Guava class here
        this.fieldSuppliers = java.util.Objects.requireNonNull(fieldSuppliers, "fieldSuppliers");
    }

    /**
     * Resolves every supplier against {@code inputType}.
     *
     * @param inputType row type of the operator the columns are read from
     * @return the fields in the order of the suppliers given at construction; never
     *     {@code null}
     */
    public List<Field> getFields(RelDataType inputType) {
        return this.fieldSuppliers.stream().map(k -> k.get(inputType)).collect(Collectors.toList());
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@

import com.alibaba.graphscope.common.antlr4.ExprUniqueAliasInfer;
import com.alibaba.graphscope.common.antlr4.ExprVisitorResult;
import com.alibaba.graphscope.common.ir.rel.GraphLogicalAggregate;
import com.alibaba.graphscope.common.ir.rel.GraphProcedureCall;
import com.alibaba.graphscope.common.ir.rel.graph.GraphLogicalGetV;
import com.alibaba.graphscope.common.ir.rel.graph.GraphLogicalPathExpand;
import com.alibaba.graphscope.common.ir.rel.type.group.GraphAggCall;
import com.alibaba.graphscope.common.ir.rex.RexTmpVariableConverter;
import com.alibaba.graphscope.common.ir.rex.RexVariableAliasCollector;
import com.alibaba.graphscope.common.ir.tools.GraphBuilder;
import com.alibaba.graphscope.common.ir.tools.config.GraphOpt;
import com.alibaba.graphscope.grammar.CypherGSBaseVisitor;
Expand All @@ -39,6 +37,7 @@
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexSubQuery;
Expand All @@ -49,6 +48,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

public class GraphBuilderVisitor extends CypherGSBaseVisitor<GraphBuilder> {
Expand Down Expand Up @@ -271,49 +271,34 @@ public GraphBuilder visitOC_ProjectionBody(CypherGSParser.OC_ProjectionBodyConte
List<RexNode> keyExprs = new ArrayList<>();
List<String> keyAliases = new ArrayList<>();
List<RelBuilder.AggCall> aggCalls = new ArrayList<>();
List<RexNode> extraExprs = new ArrayList<>();
List<String> extraAliases = new ArrayList<>();
if (isGroupPattern(ctx, keyExprs, keyAliases, aggCalls, extraExprs, extraAliases)) {
AtomicReference<ColumnOrder> columnManagerRef = new AtomicReference<>();
if (isGroupPattern(ctx, keyExprs, keyAliases, aggCalls, columnManagerRef)) {
RelBuilder.GroupKey groupKey;
if (keyExprs.isEmpty()) {
groupKey = builder.groupKey();
} else {
groupKey = builder.groupKey(keyExprs, keyAliases);
}
builder.aggregate(groupKey, aggCalls);
if (!extraExprs.isEmpty()) {
RelDataType inputType = builder.peek().getRowType();
List<ColumnOrder.Field> originalFields =
inputType.getFieldList().stream()
.map(
k ->
new ColumnOrder.Field(
builder.variable(k.getName()), k.getName()))
.collect(Collectors.toList());
List<ColumnOrder.Field> newFields = columnManagerRef.get().getFields(inputType);
if (!originalFields.equals(newFields)) {
List<RexNode> extraExprs = new ArrayList<>();
List<@Nullable String> extraAliases = new ArrayList<>();
RexTmpVariableConverter converter = new RexTmpVariableConverter(true, builder);
extraExprs =
extraExprs.stream()
.map(k -> k.accept(converter))
.collect(Collectors.toList());
List<RexNode> projectExprs = Lists.newArrayList();
List<String> projectAliases = Lists.newArrayList();
List<String> extraVarNames = Lists.newArrayList();
RexVariableAliasCollector<String> varNameCollector =
new RexVariableAliasCollector<>(
true,
v -> {
String[] splits = v.getName().split("\\.");
return splits[0];
});
extraExprs.forEach(k -> extraVarNames.addAll(k.accept(varNameCollector)));
GraphLogicalAggregate aggregate = (GraphLogicalAggregate) builder.peek();
aggregate
.getRowType()
.getFieldList()
.forEach(
field -> {
if (!extraVarNames.contains(field.getName())) {
projectExprs.add(builder.variable(field.getName()));
projectAliases.add(field.getName());
}
});
for (int i = 0; i < extraExprs.size(); ++i) {
projectExprs.add(extraExprs.get(i));
projectAliases.add(extraAliases.get(i));
}
builder.project(projectExprs, projectAliases, false);
newFields.forEach(
k -> {
extraExprs.add(k.getExpr().accept(converter));
extraAliases.add(k.getAlias());
});
builder.project(extraExprs, extraAliases, false);
}
} else if (isDistinct) {
builder.aggregate(builder.groupKey(keyExprs, keyAliases));
Expand All @@ -334,21 +319,27 @@ private boolean isGroupPattern(
List<RexNode> keyExprs,
List<String> keyAliases,
List<RelBuilder.AggCall> aggCalls,
List<RexNode> extraExprs,
List<String> extraAliases) {
AtomicReference<ColumnOrder> columnManagerRef) {
List<ColumnOrder.FieldSupplier> fieldSuppliers = Lists.newArrayList();
for (CypherGSParser.OC_ProjectionItemContext itemCtx :
ctx.oC_ProjectionItems().oC_ProjectionItem()) {
ExprVisitorResult item = expressionVisitor.visitOC_Expression(itemCtx.oC_Expression());
String alias = (itemCtx.AS() == null) ? null : itemCtx.oC_Variable().getText();
if (item.getAggCalls().isEmpty()) {
int ordinal = keyExprs.size();
fieldSuppliers.add(new ColumnOrder.FieldSupplier.Default(builder, () -> ordinal));
keyExprs.add(item.getExpr());
keyAliases.add(alias);
} else {
if (item.getExpr() instanceof RexCall) {
extraExprs.add(item.getExpr());
extraAliases.add(alias);
fieldSuppliers.add(
(RelDataType type) -> new ColumnOrder.Field(item.getExpr(), alias));
aggCalls.addAll(item.getAggCalls());
} else if (item.getAggCalls().size() == 1) { // count(a.name)
int ordinal = aggCalls.size();
fieldSuppliers.add(
new ColumnOrder.FieldSupplier.Default(
builder, () -> keyExprs.size() + ordinal));
GraphAggCall original = (GraphAggCall) item.getAggCalls().get(0);
aggCalls.add(
new GraphAggCall(
Expand All @@ -362,6 +353,7 @@ private boolean isGroupPattern(
}
}
}
columnManagerRef.set(new ColumnOrder(fieldSuppliers));
return !aggCalls.isEmpty();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ public void bi1_test() {
+ " totalMessageCount)], isAppend=[false])\n"
+ " GraphLogicalProject(totalMessageCount=[totalMessageCount], year=[year],"
+ " isComment=[isComment], lengthCategory=[lengthCategory],"
+ " messageCount=[messageCount], sumMessageLength=[sumMessageLength],"
+ " averageMessageLength=[/(EXPR$2, EXPR$3)], isAppend=[false])\n"
+ " messageCount=[messageCount], averageMessageLength=[/(EXPR$2, EXPR$3)],"
+ " sumMessageLength=[sumMessageLength], isAppend=[false])\n"
+ " GraphLogicalAggregate(keys=[{variables=[totalMessageCount, year, $f0,"
+ " $f1], aliases=[totalMessageCount, year, isComment, lengthCategory]}],"
+ " values=[[{operands=[message], aggFunction=COUNT, alias='messageCount',"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -511,30 +511,32 @@ public void ldbc7_test() {
+ " messageContent=[message.content], messageImageFile=[message.imageFile],"
+ " minutesLatency=[/(/(-(likeTime, message.creationDate), 1000), 60)],"
+ " isNew=[isNew], isAppend=[false])\n"
+ " GraphLogicalAggregate(keys=[{variables=[liker, person, isNew],"
+ " GraphLogicalProject(liker=[liker], person=[person], message=[message],"
+ " likeTime=[likeTime], isNew=[isNew], isAppend=[false])\n"
+ " GraphLogicalAggregate(keys=[{variables=[liker, person, isNew],"
+ " aliases=[liker, person, isNew]}], values=[[{operands=[message],"
+ " aggFunction=FIRST_VALUE, alias='message', distinct=false},"
+ " {operands=[likeTime], aggFunction=FIRST_VALUE, alias='likeTime',"
+ " distinct=false}]])\n"
+ " GraphLogicalSort(sort0=[likeTime], sort1=[message.id], dir0=[DESC],"
+ " GraphLogicalSort(sort0=[likeTime], sort1=[message.id], dir0=[DESC],"
+ " dir1=[ASC])\n"
+ " GraphLogicalProject(liker=[liker], message=[message],"
+ " GraphLogicalProject(liker=[liker], message=[message],"
+ " likeTime=[like.creationDate], person=[person], isNew=[IS NULL(k)],"
+ " isAppend=[false])\n"
+ " MultiJoin(joinFilter=[=(liker, liker)], isFullOuterJoin=[false],"
+ " MultiJoin(joinFilter=[=(liker, liker)], isFullOuterJoin=[false],"
+ " joinTypes=[[INNER, INNER]], outerJoinConditions=[[NULL, NULL]],"
+ " projFields=[[ALL, ALL]])\n"
+ " GraphLogicalGetV(tableConfig=[{isAll=false, tables=[PERSON]}],"
+ " GraphLogicalGetV(tableConfig=[{isAll=false, tables=[PERSON]}],"
+ " alias=[liker], opt=[START])\n"
+ " GraphLogicalExpand(tableConfig=[{isAll=false,"
+ " GraphLogicalExpand(tableConfig=[{isAll=false,"
+ " tables=[LIKES]}], alias=[like], startAlias=[message], opt=[IN])\n"
+ " CommonTableScan(table=[[common#378747223]])\n"
+ " GraphLogicalGetV(tableConfig=[{isAll=false, tables=[PERSON]}],"
+ " CommonTableScan(table=[[common#378747223]])\n"
+ " GraphLogicalGetV(tableConfig=[{isAll=false, tables=[PERSON]}],"
+ " alias=[liker], opt=[OTHER])\n"
+ " GraphLogicalExpand(tableConfig=[{isAll=false,"
+ " GraphLogicalExpand(tableConfig=[{isAll=false,"
+ " tables=[KNOWS]}], alias=[k], startAlias=[person], opt=[BOTH],"
+ " optional=[true])\n"
+ " CommonTableScan(table=[[common#378747223]])\n"
+ " CommonTableScan(table=[[common#378747223]])\n"
+ "common#378747223:\n"
+ "GraphPhysicalExpand(tableConfig=[{isAll=false, tables=[HASCREATOR]}],"
+ " alias=[message], startAlias=[person], opt=[IN], physicalOpt=[VERTEX])\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -713,4 +713,22 @@ public void optional_shortest_path_test() {
+ " alias=[p1], opt=[VERTEX], uniqueKeyFilters=[=(_.id, ?0)])",
after.explain().trim());
}

// The columns of the final plan must appear in the order written in the RETURN clause,
// here: count(n), n, sum(n.age).
@Test
public void aggregate_column_order_test() {
    GraphBuilder graphBuilder =
            com.alibaba.graphscope.common.ir.Utils.mockGraphBuilder(optimizer, irMeta);
    RelNode logicalPlan =
            Utils.eval("Match (n:person) Return count(n), n, sum(n.age)", graphBuilder).build();
    GraphIOProcessor ioProcessor = new GraphIOProcessor(graphBuilder, irMeta);
    RelNode optimizedPlan = optimizer.optimize(logicalPlan, ioProcessor);

    // a project on top of the aggregate restores the RETURN order
    String expectedPlan =
            "GraphLogicalProject($f1=[$f1], n=[n], $f2=[$f2], isAppend=[false])\n"
                    + "  GraphLogicalAggregate(keys=[{variables=[n], aliases=[n]}],"
                    + " values=[[{operands=[n], aggFunction=COUNT, alias='$f1', distinct=false},"
                    + " {operands=[n.age], aggFunction=SUM, alias='$f2', distinct=false}]])\n"
                    + "    GraphLogicalSource(tableConfig=[{isAll=false, tables=[person]}],"
                    + " alias=[n], opt=[VERTEX])";
    String actualPlan = optimizedPlan.explain().trim();
    Assert.assertEquals(expectedPlan, actualPlan);
}
}

0 comments on commit 18df95b

Please sign in to comment.