diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/planner/GraphIOProcessor.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/planner/GraphIOProcessor.java index 6d2ff73cfe5a..be525e0e8211 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/planner/GraphIOProcessor.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/planner/GraphIOProcessor.java @@ -35,6 +35,7 @@ import com.alibaba.graphscope.common.ir.rex.RexGraphVariable; import com.alibaba.graphscope.common.ir.tools.AliasInference; import com.alibaba.graphscope.common.ir.tools.GraphBuilder; +import com.alibaba.graphscope.common.ir.tools.GraphStdOperatorTable; import com.alibaba.graphscope.common.ir.tools.Utils; import com.alibaba.graphscope.common.ir.tools.config.*; import com.alibaba.graphscope.common.ir.type.GraphLabelType; @@ -59,7 +60,6 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexUtil; import org.apache.calcite.rex.RexVariable; -import org.apache.calcite.sql.fun.SqlLibraryOperators; import org.apache.calcite.util.ImmutableBitSet; import org.apache.commons.lang3.ObjectUtils; import org.checkerframework.checker.nullness.qual.Nullable; @@ -574,10 +574,12 @@ public RelNode visit(GraphJoinDecomposition decomposition) { decomposition.getBuildPattern(), buildOrderMap, new ParentPattern(decomposition.getParentPatten(), 1)); + RelNode joinLeft = decomposition.getLeft(); + RelNode joinRight = decomposition.getRight(); this.details = probeDetails; - RelNode newLeft = visitChild(decomposition, 0, decomposition.getLeft()).getInput(0); + RelNode newLeft = visitChild(decomposition, 0, joinLeft).getInput(0); this.details = buildDetails; - RelNode newRight = visitChild(decomposition, 1, decomposition.getRight()).getInput(1); + RelNode newRight = visitChild(decomposition, 1, joinRight).getInput(1); RexNode joinCondition = createJoinFilter( jointVertices, @@ -651,9 +653,19 @@ public RelNode visit(GraphJoinDecomposition decomposition) { String buildAlias = buildValue.getAlias(); concatExprs.add( builder.call( - SqlLibraryOperators.ARRAY_CONCAT, + GraphStdOperatorTable.PATH_CONCAT, builder.variable(probeAlias), - builder.variable(buildAlias))); + builder.getRexBuilder() + .makeFlag( + getConcatDirection( + probeJointVertex, + joinLeft)), + builder.variable(buildAlias), + builder.getRexBuilder() + .makeFlag( + getConcatDirection( + buildJointVertex, + joinRight)))); concatAliases.add(probeValue.getParentAlias()); } } @@ -667,6 +679,12 @@ public RelNode visit(GraphJoinDecomposition decomposition) { return builder.build(); } + private GraphOpt.GetV getConcatDirection(PatternVertex concatVertex, RelNode splitPattern) { + ConcatDirectionVisitor visitor = new ConcatDirectionVisitor(concatVertex); + visitor.go(splitPattern); + return visitor.direction; + } + private RexNode createJoinFilter( List jointVertices, Map vertexDetails, @@ -1163,5 +1181,52 @@ private EdgeDataKey createEdgeKey(ExtendEdge edge, GlogueExtendIntersectEdge glo } return new EdgeDataKey(srcOrderId, targetOrderId, edge.getDirection()); } + + // given a concat vertex, help to determine its direction in the split path expand + private class ConcatDirectionVisitor extends RelVisitor { + private GraphOpt.GetV direction; + private final PatternVertex concatVertex; + + public ConcatDirectionVisitor(PatternVertex concatVertex) { + this.concatVertex = concatVertex; + this.direction = null; + } + + @Override + public void visit(RelNode node, int ordinal, @Nullable RelNode parent) { + if (direction != null) return; + if (node instanceof GraphExtendIntersect) { + GlogueExtendIntersectEdge intersect = + ((GraphExtendIntersect) node).getGlogueEdge(); + ExtendStep extendStep = intersect.getExtendStep(); + PatternVertex dstVertex = + intersect + .getDstPattern() + .getVertexByOrder(extendStep.getTargetVertexOrder()); + if (dstVertex.equals(concatVertex)) { + direction = GraphOpt.GetV.END; + return; + } + for (ExtendEdge edge : extendStep.getExtendEdges()) { + PatternVertex srcVertex = + intersect + .getSrcPattern() + .getVertexByOrder(edge.getSrcVertexOrder()); + if (srcVertex.equals(concatVertex)) { + direction = GraphOpt.GetV.START; + return; + } + } + } else if (node instanceof GraphPattern) { + PatternVertex vertex = + ((GraphPattern) node).getPattern().getVertexSet().iterator().next(); + if (vertex.equals(concatVertex)) { + direction = GraphOpt.GetV.START; + return; + } + } + node.childrenAccept(this); + } + } } } diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/RexToProtoConverter.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/RexToProtoConverter.java index 63e6d5696f19..bddfc23f9ca4 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/RexToProtoConverter.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/RexToProtoConverter.java @@ -19,6 +19,7 @@ import com.alibaba.graphscope.common.ir.rex.RexGraphVariable; import com.alibaba.graphscope.common.ir.tools.AliasInference; import com.alibaba.graphscope.common.ir.tools.GraphStdOperatorTable; +import com.alibaba.graphscope.common.ir.tools.config.GraphOpt; import com.alibaba.graphscope.gaia.proto.Common; import com.alibaba.graphscope.gaia.proto.DataType; import com.alibaba.graphscope.gaia.proto.OuterExpression; @@ -58,8 +59,9 @@ public OuterExpression.Expression visitCall(RexCall call) { return visitArrayValueConstructor(call); } else if (operator.getKind() == SqlKind.MAP_VALUE_CONSTRUCTOR) { return visitMapValueConstructor(call); - } else if (operator.getKind() == SqlKind.ARRAY_CONCAT) { - return visitArrayConcat(call); + } else if (operator.getKind() == SqlKind.OTHER + && operator.getName().equals("PATH_CONCAT")) { + return visitPathConcat(call); } else if (operator.getKind() == SqlKind.EXTRACT) { return visitExtract(call); } else if (operator.getKind() == SqlKind.OTHER @@ -72,6 +74,48 @@ public OuterExpression.Expression visitCall(RexCall call) { } } + private OuterExpression.Expression visitPathConcat(RexCall call) { + List operands = call.getOperands(); + return OuterExpression.Expression.newBuilder() + .addOperators( + OuterExpression.ExprOpr.newBuilder() + .setPathConcat( + OuterExpression.PathConcat.newBuilder() + .setLeft(convertPathInfo(operands)) + .setRight( + convertPathInfo( + operands.subList( + 2, operands.size()))))) + .build(); + } + + private OuterExpression.PathConcat.ConcatPathInfo convertPathInfo(List operands) { + Preconditions.checkArgument( + operands.size() >= 2 + && operands.get(0) instanceof RexGraphVariable + && operands.get(1) instanceof RexLiteral); + OuterExpression.Variable variable = operands.get(0).accept(this).getOperators(0).getVar(); + GraphOpt.GetV direction = ((RexLiteral) operands.get(1)).getValueAs(GraphOpt.GetV.class); + return OuterExpression.PathConcat.ConcatPathInfo.newBuilder() + .setPathTag(variable) + .setEndpoint(convertPathDirection(direction)) + .build(); + } + + private OuterExpression.PathConcat.Endpoint convertPathDirection(GraphOpt.GetV direction) { + switch (direction) { + case START: + return OuterExpression.PathConcat.Endpoint.START; + case END: + return OuterExpression.PathConcat.Endpoint.END; + default: + throw new IllegalArgumentException( + "invalid path concat direction [" + + direction.name() + + "], cannot convert to any physical expression"); + } + } + private OuterExpression.Expression visitDateMinus(RexCall call) { OuterExpression.Expression.Builder exprBuilder = OuterExpression.Expression.newBuilder(); RexLiteral interval = (RexLiteral) call.getOperands().get(2); @@ -124,25 +168,6 @@ private OuterExpression.Expression visitCase(RexCall call) { .build(); } - private OuterExpression.Expression visitArrayConcat(RexCall call) { - OuterExpression.Concat.Builder concatBuilder = OuterExpression.Concat.newBuilder(); - call.getOperands() - .forEach( - operand -> { - Preconditions.checkArgument( - operand instanceof RexGraphVariable, - "parameters of 'CONCAT' should be" - + " 'variable' in ir core structure"); - concatBuilder.addVars(operand.accept(this).getOperators(0).getVar()); - }); - return OuterExpression.Expression.newBuilder() - .addOperators( - OuterExpression.ExprOpr.newBuilder() - .setConcat(concatBuilder) - .setNodeType(Utils.protoIrDataType(call.getType(), isColumnId))) - .build(); - } - private OuterExpression.Expression visitArrayValueConstructor(RexCall call) { OuterExpression.VariableKeys.Builder varsBuilder = OuterExpression.VariableKeys.newBuilder(); diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java index 579881448618..2f16b86c4264 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java @@ -547,6 +547,45 @@ public RexGraphVariable variable(@Nullable String alias, String property) { columnField.left, varName, getTypeFactory().createSqlType(SqlTypeName.ANY)); + } else if (property.equals(GraphProperty.START_V_KEY)) { + if (!(aliasField.getType() instanceof GraphPathType)) { + throw new ClassCastException( + "cannot get property='start_v' from type class [" + + aliasField.getType().getClass() + + "], should be [" + + GraphPathType.class + + "]"); + } else { + Preconditions.checkArgument(size() > 0, "frame stack is empty"); + RelNode peek = peek(); + Preconditions.checkArgument( + peek != null && !peek.getInputs().isEmpty(), + "path expand should have start vertex"); + RelNode input = peek.getInput(0); + return RexGraphVariable.of( + aliasField.getIndex(), + new GraphProperty(GraphProperty.Opt.START_V), + columnField.left, + varName, + input.getRowType().getFieldList().get(0).getType()); + } + } else if (property.equals(GraphProperty.END_V_KEY)) { + if (!(aliasField.getType() instanceof GraphPathType)) { + throw new ClassCastException( + "cannot get property='end_v' from type class [" + + aliasField.getType().getClass() + + "], should be [" + + GraphPathType.class + + "]"); + } else { + GraphPathType pathType = (GraphPathType) aliasField.getType(); + return RexGraphVariable.of( + aliasField.getIndex(), + new GraphProperty(GraphProperty.Opt.END_V), + columnField.left, + varName, + pathType.getComponentType().getGetVType()); + } } GraphSchemaType graphType = (GraphSchemaType) aliasField.getType(); List properties = new ArrayList<>(); @@ -815,7 +854,8 @@ private boolean isCurrentSupported(SqlOperator operator) { || sqlKind == SqlKind.BIT_XOR || (sqlKind == SqlKind.OTHER && (operator.getName().equals("IN") - || operator.getName().equals("DATETIME_MINUS"))) + || operator.getName().equals("DATETIME_MINUS") + || operator.getName().equals("PATH_CONCAT"))) || sqlKind == SqlKind.ARRAY_CONCAT; } diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphStdOperatorTable.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphStdOperatorTable.java index 01e33d048a2e..1d8d64ba79e7 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphStdOperatorTable.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphStdOperatorTable.java @@ -20,6 +20,8 @@ import com.alibaba.graphscope.common.ir.rex.operator.CaseOperator; import com.alibaba.graphscope.common.ir.rex.operator.SqlArrayValueConstructor; import com.alibaba.graphscope.common.ir.rex.operator.SqlMapValueConstructor; +import com.alibaba.graphscope.common.ir.type.GraphTypeFamily; +import com.google.common.collect.ImmutableList; import org.apache.calcite.sql.*; import org.apache.calcite.sql.fun.ExtSqlPosixRegexOperator; @@ -296,4 +298,27 @@ public static final SqlFunction USER_DEFINED_PROCEDURE(StoredProcedureMeta meta) ReturnTypes.BOOLEAN_NULLABLE, GraphInferTypes.IN_OPERANDS_TYPE, OperandTypes.ANY); + + public static final SqlOperator PATH_CONCAT = + new SqlFunction( + "PATH_CONCAT", + SqlKind.OTHER, + ReturnTypes.ARG0, + null, + GraphOperandTypes.operandMetadata( + ImmutableList.of( + GraphTypeFamily.PATH, + SqlTypeFamily.IGNORE, + GraphTypeFamily.PATH, + SqlTypeFamily.IGNORE), + typeFactory -> ImmutableList.of(), + i -> + ImmutableList.of( + "LeftPath", + "LeftDirection", + "RightPath", + "RightDirection") + .get(i), + i -> false), + SqlFunctionCategory.SYSTEM); } diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/config/GraphOpt.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/config/GraphOpt.java index 0b2e5aefc4f6..36c8e5926097 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/config/GraphOpt.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/config/GraphOpt.java @@ -16,10 +16,8 @@ package com.alibaba.graphscope.common.ir.tools.config; -import org.apache.calcite.rel.type.RelDataTypeFamily; - public abstract class GraphOpt { - public enum Source implements RelDataTypeFamily { + public enum Source { VERTEX, EDGE } diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphPathType.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphPathType.java index e06a577efdf5..f69106b83760 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphPathType.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphPathType.java @@ -17,6 +17,7 @@ package com.alibaba.graphscope.common.ir.type; import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFamily; import org.apache.calcite.sql.type.AbstractSqlType; import org.apache.calcite.sql.type.ArraySqlType; import org.apache.calcite.sql.type.SqlTypeName; @@ -78,4 +79,9 @@ public String getFullTypeString() { return sb.toString(); } } + + @Override + public RelDataTypeFamily getFamily() { + return GraphTypeFamily.PATH; + } } diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphProperty.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphProperty.java index 70e67addf2db..6faab3cc004a 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphProperty.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphProperty.java @@ -25,6 +25,8 @@ public class GraphProperty { public static final String ALL_KEY = "~all"; public static final String ID_KEY = "~id"; public static final String LABEL_KEY = "~label"; + public static final String START_V_KEY = "~start_v"; + public static final String END_V_KEY = "~end_v"; private final Opt opt; private final @Nullable GraphNameOrId key; @@ -65,6 +67,8 @@ public enum Opt { LABEL, LEN, ALL, - KEY + KEY, + START_V, + END_V } } diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphSchemaType.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphSchemaType.java index 579370b1b0f6..d730365777ae 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphSchemaType.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphSchemaType.java @@ -191,7 +191,7 @@ public boolean isStruct() { @Override public RelDataTypeFamily getFamily() { - return scanOpt; + return scanOpt == GraphOpt.Source.VERTEX ? GraphTypeFamily.VERTEX : GraphTypeFamily.EDGE; } public List getSchemaTypeAsList() { diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphTypeFamily.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphTypeFamily.java new file mode 100644 index 000000000000..7a8396339d99 --- /dev/null +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphTypeFamily.java @@ -0,0 +1,27 @@ +/* + * + * * Copyright 2020 Alibaba Group Holding Limited. + * * + * * Licensed under the Apache License, Version 2.0 (the "License"); + * * you may not use this file except in compliance with the License. + * * You may obtain a copy of the License at + * * + * * http://www.apache.org/licenses/LICENSE-2.0 + * * + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, + * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * * See the License for the specific language governing permissions and + * * limitations under the License. + * + */ + +package com.alibaba.graphscope.common.ir.type; + +import org.apache.calcite.rel.type.RelDataTypeFamily; + +public enum GraphTypeFamily implements RelDataTypeFamily { + PATH, + VERTEX, + EDGE, +} diff --git a/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphFamilyOperandTypeChecker.java b/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphFamilyOperandTypeChecker.java index 182c1389c596..1eb5e5f54fb4 100644 --- a/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphFamilyOperandTypeChecker.java +++ b/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphFamilyOperandTypeChecker.java @@ -21,10 +21,14 @@ import org.apache.calcite.linq4j.Ord; import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeFamily; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlCallBinding; import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlUtil; import org.apache.calcite.sql.validate.implicit.TypeCoercion; import org.apache.calcite.util.Litmus; @@ -38,16 +42,20 @@ * we have to make the subclass under the same package {@code org.apache.calcite.sql.type} */ public class GraphFamilyOperandTypeChecker extends FamilyOperandTypeChecker { + protected final List expectedFamilies; + protected GraphFamilyOperandTypeChecker( - List families, Predicate optional) { - super(families, optional); + List typeFamilies, Predicate optional) { + super(ImmutableList.of(), optional); + this.expectedFamilies = typeFamilies; } - @Override public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { - if (families.size() != callBinding.getOperandCount()) { + if (expectedFamilies.size() != callBinding.getOperandCount()) { Litmus.THROW.fail( - "wrong operand count {} for {}", callBinding.getOperandCount(), families); + "wrong operand count {} for {}", + callBinding.getOperandCount(), + expectedFamilies); } if (!(callBinding instanceof RexCallBinding)) { throw new IllegalArgumentException( @@ -83,10 +91,9 @@ public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFail return true; } - @Override public boolean checkOperandTypesWithoutTypeCoercion( SqlCallBinding callBinding, boolean throwOnFailure) { - if (families.size() != callBinding.getOperandCount()) { + if (expectedFamilies.size() != callBinding.getOperandCount()) { // assume this is an inapplicable sub-rule of a composite rule; // don't throw exception. return false; @@ -124,54 +131,74 @@ public boolean checkSingleOperandType( + " type"); } - private boolean checkSingleOperandType( + @Override + public String getAllowedSignatures(SqlOperator op, String opName) { + return SqlUtil.getAliasedSignature(op, opName, this.expectedFamilies); + } + + protected boolean checkSingleOperandType( SqlCallBinding callBinding, RexNode node, int iFormalOperand, boolean throwOnFailure) { - final SqlTypeFamily family = families.get(iFormalOperand); - switch (family) { - case ANY: - SqlTypeName typeName = node.getType().getSqlTypeName(); - if (typeName == SqlTypeName.CURSOR) { - // We do not allow CURSOR operands, even for ANY - if (throwOnFailure) { - throw callBinding.newValidationSignatureError(); + RelDataTypeFamily expectedFamily = expectedFamilies.get(iFormalOperand); + RelDataType type = node.getType(); + if (expectedFamily instanceof SqlTypeFamily) { + SqlTypeFamily sqlTypeFamily = (SqlTypeFamily) expectedFamily; + switch (sqlTypeFamily) { + case ANY: + SqlTypeName typeName = node.getType().getSqlTypeName(); + if (typeName == SqlTypeName.CURSOR) { + // We do not allow CURSOR operands, even for ANY + if (throwOnFailure) { + throw callBinding.newValidationSignatureError(); + } + return false; } + // fall through + case IGNORE: + // no need to check + return true; + default: + break; + } + if (isNullLiteral(node)) { + if (callBinding.isTypeCoercionEnabled()) { + return true; + } else if (throwOnFailure) { + throw new IllegalArgumentException( + "node " + node + " should not be of null value"); + } else { return false; } - // fall through - case IGNORE: - // no need to check - return true; - default: - break; - } - if (isNullLiteral(node)) { - if (callBinding.isTypeCoercionEnabled()) { - return true; - } else if (throwOnFailure) { - throw new IllegalArgumentException("node " + node + " should not be of null value"); - } else { - return false; } - } - RelDataType type = node.getType(); - SqlTypeName typeName = type.getSqlTypeName(); - // Pass type checking for operators if it's of type 'ANY'. - if (typeName.getFamily() == SqlTypeFamily.ANY) { - return true; - } + SqlTypeName typeName = type.getSqlTypeName(); - if (!getAllowedTypeNames(family, iFormalOperand).contains(typeName)) { - if (throwOnFailure) { - throw callBinding.newValidationSignatureError(); + // Pass type checking for operators if it's of type 'ANY'. + if (typeName.getFamily() == SqlTypeFamily.ANY) { + return true; + } + + if (!getAllowedTypeNames(callBinding.getTypeFactory(), sqlTypeFamily, iFormalOperand) + .contains(typeName)) { + if (throwOnFailure) { + throw callBinding.newValidationSignatureError(); + } + return false; + } + } else { + if (type.getFamily() == SqlTypeFamily.ANY || expectedFamily == SqlTypeFamily.ANY) + return true; + if (type.getFamily() != expectedFamily) { + if (throwOnFailure) { + throw callBinding.newValidationSignatureError(); + } + return false; } - return false; } return true; } protected Collection getAllowedTypeNames( - SqlTypeFamily family, int iFormalOperand) { + RelDataTypeFactory typeFactory, SqlTypeFamily family, int iFormalOperand) { return family.getTypeNames(); } diff --git a/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphOperandMetaDataImpl.java b/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphOperandMetaDataImpl.java index d162c50a7f48..76dfa03a2b87 100644 --- a/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphOperandMetaDataImpl.java +++ b/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphOperandMetaDataImpl.java @@ -22,8 +22,10 @@ import org.apache.calcite.linq4j.function.Functions; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeFamily; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.SqlUtil; +import org.apache.commons.lang3.ObjectUtils; import org.checkerframework.checker.nullness.qual.Nullable; import java.util.Collection; @@ -40,19 +42,19 @@ public class GraphOperandMetaDataImpl extends GraphFamilyOperandTypeChecker private final IntFunction paramNameFn; GraphOperandMetaDataImpl( - List families, + List expectedexpectedFamilies, Function<@Nullable RelDataTypeFactory, List> paramTypesFactory, IntFunction paramNameFn, Predicate optional) { - super(families, optional); + super(expectedexpectedFamilies, optional); this.paramTypesFactory = Objects.requireNonNull(paramTypesFactory, "paramTypesFactory"); this.paramNameFn = paramNameFn; } @Override protected Collection getAllowedTypeNames( - SqlTypeFamily family, int iFormalOperand) { - List paramsAllowedTypes = paramTypes(null); + RelDataTypeFactory typeFactory, SqlTypeFamily family, int iFormalOperand) { + List paramsAllowedTypes = paramTypes(typeFactory); Preconditions.checkArgument( paramsAllowedTypes.size() > iFormalOperand, "cannot find allowed type for type index=" @@ -74,16 +76,18 @@ public List paramTypes(@Nullable RelDataTypeFactory typeFactory) { @Override public List paramNames() { - return Functions.generate(this.families.size(), this.paramNameFn); + return Functions.generate(this.expectedFamilies.size(), this.paramNameFn); } @Override public String getAllowedSignatures(SqlOperator op, String opName) { - return SqlUtil.getAliasedSignature( - op, - opName, - paramTypes(null).stream() - .map(k -> k.getSqlTypeName()) - .collect(Collectors.toList())); + List paramTypes = paramTypes(null); + List signatureTypes = + ObjectUtils.isEmpty(paramTypes) + ? this.expectedFamilies + : paramTypes.stream() + .map(k -> k.getSqlTypeName()) + .collect(Collectors.toList()); + return SqlUtil.getAliasedSignature(op, opName, signatureTypes); } } diff --git a/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphOperandTypes.java b/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphOperandTypes.java index 87cefaf32b97..8d06e8b41902 100644 --- a/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphOperandTypes.java +++ b/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphOperandTypes.java @@ -20,6 +20,7 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeFamily; import java.util.List; import java.util.function.Function; @@ -96,7 +97,7 @@ public static FamilyOperandTypeChecker family(SqlTypeFamily... families) { OperandTypes.or(NUMERIC_NUMERIC, INTERVAL_NUMERIC); public static SqlOperandMetadata operandMetadata( - List families, + List families, Function> typesFactory, IntFunction operandName, Predicate optional) { diff --git a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/common/ir/planner/cbo/STPathTest.java b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/common/ir/planner/cbo/STPathTest.java index d31d763a125e..91c5eebf481d 100644 --- a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/common/ir/planner/cbo/STPathTest.java +++ b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/common/ir/planner/cbo/STPathTest.java @@ -63,7 +63,7 @@ public static void beforeClass() { } @Test - public void st_path_test() { + public void st_path_person_id_knows_person_id() { GraphBuilder builder = Utils.mockGraphBuilder(optimizer, irMeta); RelNode before = com.alibaba.graphscope.cypher.antlr4.Utils.eval( diff --git a/interactive_engine/compiler/src/test/resources/proto/st_path_test.json b/interactive_engine/compiler/src/test/resources/proto/st_path_test.json index f35084ebb859..9a78f1a00e11 100644 --- a/interactive_engine/compiler/src/test/resources/proto/st_path_test.json +++ b/interactive_engine/compiler/src/test/resources/proto/st_path_test.json @@ -585,22 +585,27 @@ "mappings": [{ "expr": { "operators": [{ - "nodeType": { - }, - "concat": { - "vars": [{ - "tag": { - "id": 3 + "pathConcat": { + "left": { + "pathTag": { + "tag": { + "id": 3 + }, + "nodeType": { + } }, - "nodeType": { - } - }, { - "tag": { - "id": 5 + "endpoint": "END" + }, + "right": { + "pathTag": { + "tag": { + "id": 5 + }, + "nodeType": { + } }, - "nodeType": { - } - }] + "endpoint": "END" + } } }] }, @@ -650,4 +655,4 @@ } } }] -} \ No newline at end of file +} diff --git a/interactive_engine/executor/ir/graph_proxy/src/apis/graph/element/path.rs b/interactive_engine/executor/ir/graph_proxy/src/apis/graph/element/path.rs index fa5e2b1961b0..1832e08dfb27 100644 --- a/interactive_engine/executor/ir/graph_proxy/src/apis/graph/element/path.rs +++ b/interactive_engine/executor/ir/graph_proxy/src/apis/graph/element/path.rs @@ -152,6 +152,13 @@ impl GraphPath { } } + pub fn get_path_start(&self) -> Option<&VertexOrEdge> { + match self { + GraphPath::AllPath(ref p) | GraphPath::SimpleAllPath(ref p) => p.first(), + GraphPath::EndV(_) | GraphPath::SimpleEndV(_) => None, + } + } + pub fn get_path_end(&self) -> &VertexOrEdge { match self { GraphPath::AllPath(ref p) | GraphPath::SimpleAllPath(ref p) => p.last().unwrap(), @@ -172,6 +179,41 @@ impl GraphPath { GraphPath::EndV(_) | GraphPath::SimpleEndV(_) => None, } } + + // append another path to the current path, and return the flag of whether the path has been appended or not. + // notice that, if the path is a simple path, we simply concatenate the two paths, without checking the duplication (this may not be as expected) + // e.g., [1,2,3] + [4,5] = [1,2,3,4,5] + pub fn append_path(&mut self, other: GraphPath) -> bool { + match self { + GraphPath::AllPath(ref mut p) | GraphPath::SimpleAllPath(ref mut p) => { + if let Some(other_path) = other.take_path() { + p.extend(other_path); + true + } else { + false + } + } + GraphPath::EndV(_) | GraphPath::SimpleEndV(_) => false, + } + } + + // pop the last element from the path, and return the element. + pub fn pop(&mut self) -> Option { + match self { + GraphPath::AllPath(ref mut p) | GraphPath::SimpleAllPath(ref mut p) => p.pop(), + GraphPath::EndV(_) | GraphPath::SimpleEndV(_) => None, + } + } + + // reverse the path. + pub fn reverse(&mut self) { + match self { + GraphPath::AllPath(ref mut p) | GraphPath::SimpleAllPath(ref mut p) => { + p.reverse(); + } + GraphPath::EndV(_) | GraphPath::SimpleEndV(_) => {} + } + } } impl Element for VertexOrEdge { diff --git a/interactive_engine/executor/ir/proto/expr.proto b/interactive_engine/executor/ir/proto/expr.proto index c4a5979a5b65..0d93d7914d88 100644 --- a/interactive_engine/executor/ir/proto/expr.proto +++ b/interactive_engine/executor/ir/proto/expr.proto @@ -186,10 +186,20 @@ message DateTimeMinus { Extract.Interval interval = 1; } -// e.g., supposing p1 refers to [v1->v2->v3], p2 refers to [v3->v4->v5], -// then CONCAT(p1, p2) outputs [v1->v2->v3->v4->v5] -message Concat { - repeated Variable vars = 1; +// e.g., supposing p1 refers to path[v1->v2->v3], p2 refers to path[v5->v4->v3], +// then PathCONCAT((p1, END), (p2, END))) outputs [v1->v2->v3->v4->v5] +message PathConcat { + // Enum that defines the endpoint of a path where the concatenation will occur. + enum Endpoint { + START = 0; + END = 1; + } + message ConcatPathInfo { + Variable path_tag = 1; + Endpoint endpoint = 2; + } + ConcatPathInfo left = 1; + ConcatPathInfo right = 2; } // An operator of expression is one of Logical, Arithmetic, Const and Variable. @@ -214,7 +224,7 @@ message ExprOpr { VariableKeyValues map = 13; TimeInterval time_interval = 14; DateTimeMinus date_time_minus = 15; - Concat concat = 16; + PathConcat path_concat = 16; } // The data of type of ExprOpr common.IrDataType node_type = 12; diff --git a/interactive_engine/executor/ir/runtime/src/process/operator/map/project.rs b/interactive_engine/executor/ir/runtime/src/process/operator/map/project.rs index 7eace8d6acf3..1ae5bf3c519b 100644 --- a/interactive_engine/executor/ir/runtime/src/process/operator/map/project.rs +++ b/interactive_engine/executor/ir/runtime/src/process/operator/map/project.rs @@ -15,6 +15,7 @@ use std::convert::TryFrom; +use common_pb::path_concat::Endpoint; use dyn_type::Object; use graph_proxy::utils::expr::eval::{Evaluate, Evaluator}; use ir_common::error::ParsePbError; @@ -83,8 +84,8 @@ enum Projector { /// MapProjector will output a collection entry, which is a collection of key-value pairs. The key is a Object (preserve the user-given key), and the value is a projected graph element (computed via TagKey). /// Besides, MapProjector supports nested map. MapProjector(VariableKeyValues), - /// A simple concatenation of multiple entries. - ConcatProjector(Vec), + /// A simple concatenation of two paths. + PathConcatProjector((TagKey, Endpoint), (TagKey, Endpoint)), } // TODO: @@ -111,34 +112,76 @@ fn exec_projector(input: &Record, projector: &Projector) -> FnExecResult map.exec_projector(input)?, - Projector::ConcatProjector(concat_vars) => { - if concat_vars.len() != 2 { - Err(FnExecError::unsupported_error("Only support concatenated 2 entries now"))? - } else { - let left_path = concat_vars[0] - .get_arc_entry(input)? - .as_graph_path() - .cloned(); - let right_path = concat_vars[1] - .get_arc_entry(input)? - .as_graph_path() - .cloned(); - if left_path.is_none() || right_path.is_none() { - Err(FnExecError::unsupported_error("Concatenated entries are not Path"))? - } else { - let mut left_path = left_path.unwrap(); - let right_path = right_path.unwrap().take_path(); - if let Some(mut right_path) = right_path { - // specifically, we pop the last entry of right_path, which is already in left_path - right_path.pop(); + Projector::PathConcatProjector((left, left_endpoint), (right, right_endpoint)) => { + let mut left_path = left + .get_arc_entry(input)? + .as_graph_path() + .ok_or_else(|| FnExecError::unsupported_error("Left entry is not Path in PathConcat"))? + .clone(); + let mut right_path = right + .get_arc_entry(input)? + .as_graph_path() + .ok_or_else(|| FnExecError::unsupported_error("Right entry is not Path in PathConcat"))? + .clone(); + + let mut invalid = false; + let mut concat_success = false; + match (left_endpoint, right_endpoint) { + // e.g., concat [3,2,1], [3,4,5] => [1,2,3,4,5] + (Endpoint::Start, Endpoint::Start) => { + if left_path.get_path_start() != right_path.get_path_start() { + invalid = true; + } else { + left_path.reverse(); + left_path.pop(); + concat_success = left_path.append_path(right_path); + } + } + (Endpoint::Start, Endpoint::End) => { + // e.g., concat [3,2,1], [5,4,3] => [1,2,3,4,5] + if left_path.get_path_start().is_none() + || (left_path.get_path_start().unwrap() != right_path.get_path_end()) + { + invalid = true; + } else { + left_path.reverse(); + left_path.pop(); + right_path.reverse(); + concat_success = left_path.append_path(right_path); + } + } + (Endpoint::End, Endpoint::Start) => { + // e.g., concat [1,2,3], [3,4,5] => [1,2,3,4,5] + if right_path.get_path_start().is_none() + || (right_path.get_path_start().unwrap() != left_path.get_path_end()) + { + invalid = true; + } else { + left_path.pop(); + concat_success = left_path.append_path(right_path); + } + } + (Endpoint::End, Endpoint::End) => { + // e.g., concat [1,2,3], [5,4,3] => [1,2,3,4,5] + if left_path.get_path_end() != right_path.get_path_end() { + invalid = true; + } else { + left_path.pop(); right_path.reverse(); - for entry in right_path { - left_path.append(entry); - } + concat_success = left_path.append_path(right_path); } - DynEntry::new(left_path) } } + + if invalid { + Err(FnExecError::unexpected_data_error(&format!( + "Concat vertices are not the same in PathConcat" + )))? + } else if !concat_success { + Err(FnExecError::unexpected_data_error(&format!("Failed to concat paths in PathConcat")))? + } else { + DynEntry::new(left_path) + } } }; Ok(entry) @@ -221,15 +264,42 @@ impl FilterMapFuncGen for pb::Project { Projector::MapProjector(variable_key_values) } common_pb::ExprOpr { - item: Some(common_pb::expr_opr::Item::Concat(concat_vars)), + item: Some(common_pb::expr_opr::Item::PathConcat(concat_vars)), .. } => { - let tag_keys = concat_vars - .vars - .iter() - .map(|var| TagKey::try_from(var.clone())) - .collect::, _>>()?; - Projector::ConcatProjector(tag_keys) + let left = concat_vars.left.as_ref().ok_or_else(|| { + ParsePbError::EmptyFieldError(format!( + "left in PathConcat Expr {:?}", + concat_vars + )) + })?; + let left_path_tag = left.path_tag.clone().ok_or_else(|| { + ParsePbError::EmptyFieldError(format!( + "path_tag in PathConcat Expr {:?}", + concat_vars + )) + })?; + let left_endpoint: common_pb::path_concat::Endpoint = + unsafe { std::mem::transmute(left.endpoint) }; + let right = concat_vars.right.as_ref().ok_or_else(|| { + ParsePbError::EmptyFieldError(format!( + "right in PathConcat Expr {:?}", + concat_vars + )) + })?; + + let right_path_tag = right.path_tag.clone().ok_or_else(|| { + ParsePbError::EmptyFieldError(format!( + "path_tag in PathConcat Expr {:?}", + concat_vars + )) + })?; + let right_endpoint: common_pb::path_concat::Endpoint = + unsafe { std::mem::transmute(right.endpoint) }; + Projector::PathConcatProjector( + (TagKey::try_from(left_path_tag)?, left_endpoint), + (TagKey::try_from(right_path_tag)?, right_endpoint), + ) } _ => { let evaluator = Evaluator::try_from(expr)?; @@ -283,6 +353,8 @@ impl TryFrom for VariableKeyValues { #[cfg(test)] mod tests { + use std::vec; + use ahash::HashMap; use dyn_type::Object; use graph_proxy::apis::{DynDetails, Edge, GraphElement, GraphPath, Vertex}; @@ -1297,54 +1369,52 @@ mod tests { assert_eq!(results, expected_results); } - #[test] - fn project_concat_allv_path_test() { + fn build_path(vids: Vec) -> GraphPath { let details = DynDetails::default(); - // sub_path1: [1,2] - let mut sub_path1 = GraphPath::new( - Vertex::new(1, None, details.clone()), + let mut path = GraphPath::new( + Vertex::new(vids[0], None, details.clone()), pb::path_expand::PathOpt::Arbitrary, pb::path_expand::ResultOpt::AllV, ); - sub_path1.append(Vertex::new(2, None, details.clone())); - // sub_path2: [3,2] - let mut sub_path2 = GraphPath::new( - Vertex::new(3, None, details.clone()), - pb::path_expand::PathOpt::Arbitrary, - pb::path_expand::ResultOpt::AllV, - ); - sub_path2.append(Vertex::new(2, None, details.clone())); - // concat path: [1,2,3] - let mut concat_path = GraphPath::new( - Vertex::new(1, None, details.clone()), - pb::path_expand::PathOpt::Arbitrary, - pb::path_expand::ResultOpt::AllV, - ); - concat_path.append(Vertex::new(2, None, details.clone())); - concat_path.append(Vertex::new(3, None, details.clone())); - - let mut r1 = Record::new(sub_path1, Some(TAG_A.into())); - r1.append(sub_path2, Some(TAG_B.into())); + for i in 1..vids.len() { + path.append(Vertex::new(vids[i], None, details.clone())); + } + path + } - let source = vec![r1]; - let project_opr_pb = pb::Project { + fn build_project_path_concat( + left_endpoint: common_pb::path_concat::Endpoint, right_endpoint: common_pb::path_concat::Endpoint, + ) -> pb::Project { + let path_concat = common_pb::PathConcat { + left: Some(common_pb::path_concat::ConcatPathInfo { + path_tag: Some(to_var_pb(Some(TAG_A.into()), None)), + endpoint: left_endpoint as i32, + }), + right: Some(common_pb::path_concat::ConcatPathInfo { + path_tag: Some(to_var_pb(Some(TAG_B.into()), None)), + endpoint: right_endpoint as i32, + }), + }; + pb::Project { mappings: vec![pb::project::ExprAlias { expr: Some(common_pb::Expression { operators: vec![common_pb::ExprOpr { - item: Some(common_pb::expr_opr::Item::Concat(common_pb::Concat { - vars: vec![ - to_var_pb(Some(TAG_A.into()), None), - to_var_pb(Some(TAG_B.into()), None), - ], - })), + item: Some(common_pb::expr_opr::Item::PathConcat(path_concat)), node_type: None, }], }), alias: Some(TAG_C.into()), }], is_append: false, - }; - let mut result = project_test(source, project_opr_pb); + } + } + + fn project_concat_allv_path_test( + left_path: GraphPath, right_path: GraphPath, project_opr_pb: pb::Project, concat_path: GraphPath, + ) { + let mut r1 = Record::new(left_path, Some(TAG_A.into())); + r1.append(right_path, Some(TAG_B.into())); + let mut result = project_test(vec![r1], project_opr_pb); let mut results = vec![]; while let Some(Ok(res)) = result.next() { let path = res @@ -1358,6 +1428,70 @@ mod tests { assert_eq!(results, vec![concat_path]); } + #[test] + fn project_concat_allv_path_test_01() { + // sub_path1: [1,2] + let sub_path1 = build_path(vec![1, 2]); + // sub_path2: [3,2] + let sub_path2 = build_path(vec![3, 2]); + // concat project + let project_opr_pb = build_project_path_concat( + common_pb::path_concat::Endpoint::End, + common_pb::path_concat::Endpoint::End, + ); + // concat path: [1,2,3] + let concat_path = build_path(vec![1, 2, 3]); + project_concat_allv_path_test(sub_path1, sub_path2, project_opr_pb, concat_path); + } + + #[test] + fn project_concat_allv_path_test_02() { + // sub_path1: [1,2] + let sub_path1 = build_path(vec![1, 2]); + // sub_path2: [2,3] + let sub_path2 = build_path(vec![2, 3]); + // concat project + let project_opr_pb = build_project_path_concat( + common_pb::path_concat::Endpoint::End, + common_pb::path_concat::Endpoint::Start, + ); + // concat path: [1,2,3] + let concat_path = build_path(vec![1, 2, 3]); + project_concat_allv_path_test(sub_path1, sub_path2, project_opr_pb, concat_path); + } + + #[test] + fn project_concat_allv_path_test_03() { + // sub_path1: [2,1] + let sub_path1 = build_path(vec![2, 1]); + // sub_path2: [3,2] + let sub_path2 = build_path(vec![3, 2]); + // concat project + let project_opr_pb = build_project_path_concat( + common_pb::path_concat::Endpoint::Start, + common_pb::path_concat::Endpoint::End, + ); + // concat path: [1,2,3] + let concat_path = build_path(vec![1, 2, 3]); + project_concat_allv_path_test(sub_path1, sub_path2, project_opr_pb, concat_path); + } + + #[test] + fn project_concat_allv_path_test_04() { + // sub_path1: [2,1] + let sub_path1 = build_path(vec![2, 1]); + // sub_path2: [2,3] + let sub_path2 = build_path(vec![2, 3]); + // concat project + let project_opr_pb = build_project_path_concat( + common_pb::path_concat::Endpoint::Start, + common_pb::path_concat::Endpoint::Start, + ); + // concat path: [1,2,3] + let concat_path = build_path(vec![1, 2, 3]); + project_concat_allv_path_test(sub_path1, sub_path2, project_opr_pb, concat_path); + } + #[test] fn project_concat_allve_path_test() { let details = DynDetails::default(); @@ -1392,23 +1526,10 @@ mod tests { r1.append(sub_path2, Some(TAG_B.into())); let source = vec![r1]; - let project_opr_pb = pb::Project { - mappings: vec![pb::project::ExprAlias { - expr: Some(common_pb::Expression { - operators: vec![common_pb::ExprOpr { - item: Some(common_pb::expr_opr::Item::Concat(common_pb::Concat { - vars: vec![ - to_var_pb(Some(TAG_A.into()), None), - to_var_pb(Some(TAG_B.into()), None), - ], - })), - node_type: None, - }], - }), - alias: Some(TAG_C.into()), - }], - is_append: false, - }; + let project_opr_pb = build_project_path_concat( + common_pb::path_concat::Endpoint::End, + common_pb::path_concat::Endpoint::End, + ); let mut result = project_test(source, project_opr_pb); let mut results = vec![]; while let Some(Ok(res)) = result.next() { @@ -1422,4 +1543,25 @@ mod tests { } assert_eq!(results, vec![concat_path]); } + + // a fail test case + #[test] + fn project_concat_allv_path_error_test() { + // sub_path1: [1,2] + let sub_path1 = build_path(vec![1, 2]); + // sub_path2: [2,3] + let sub_path2 = build_path(vec![2, 3]); + // concat project, if concat sub_path1.start and sub_path2.start, it will fail + let project_opr_pb = build_project_path_concat( + common_pb::path_concat::Endpoint::Start, + common_pb::path_concat::Endpoint::Start, + ); + + let mut r1 = Record::new(sub_path1, Some(TAG_A.into())); + r1.append(sub_path2, Some(TAG_B.into())); + let mut result = project_test(vec![r1], project_opr_pb); + if let Some(res) = result.next() { + assert!(res.is_err()); + } + } }