Implement Top N pushdown (#526)
Fixes #519

Signed-off-by: Florent Biville <[email protected]>
Co-authored-by: Andrea Santurbano <[email protected]>
fbiville and conker84 authored Sep 29, 2023
1 parent d3154ae commit e971d09
Showing 19 changed files with 370 additions and 201 deletions.
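In practice, this commit lets Spark push a sort-plus-limit (Top N) down into the generated Cypher instead of reading everything and sorting on the Spark side. A minimal usage sketch of the reader-side effect — the option keys come from this diff, but the label and column names are hypothetical:

import org.apache.spark.sql.functions.col

// With pushdown.topN.enabled (default true, see Neo4jOptions below), Spark's
// planner hands the sort orders and the limit to the connector, which renders
// them as a single Cypher `ORDER BY ... LIMIT 10` on the database side.
val top10 = spark.read
  .format("org.neo4j.spark.DataSource")
  .option("labels", "Person")        // hypothetical label
  .load()
  .orderBy(col("name").desc)
  .limit(10)                         // planned as a Top N, not a post-read sort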
5 changes: 5 additions & 0 deletions common/src/main/scala/org/neo4j/spark/config/TopN.scala
@@ -0,0 +1,5 @@
package org.neo4j.spark.config

import org.apache.spark.sql.connector.expressions.SortOrder

case class TopN(limit: Int, orders: Array[SortOrder] = Array.empty)
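The new TopN value simply pairs a row limit with the Spark SortOrder expressions that should travel with it; an empty orders array degrades to a plain LIMIT. A small illustrative sketch (values are hypothetical):

// Plain limit, no ordering: equivalent to the old skip/limit pushdown.
val limitOnly = TopN(100)
// A real Top N carries the SortOrder instances Spark's pushdown API supplies:
// val top10 = TopN(10, Array(sortOrderFromSpark))  // sortOrderFromSpark is assumed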
common/src/main/scala/org/neo4j/spark/reader/BasePartitionReader.scala
@@ -6,7 +6,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.neo4j.driver.{Record, Session, Transaction, Values}
import org.neo4j.spark.service.{MappingService, Neo4jQueryReadStrategy, Neo4jQueryService, Neo4jQueryStrategy, Neo4jReadMappingStrategy, PartitionSkipLimit}
import org.neo4j.spark.service.{MappingService, Neo4jQueryReadStrategy, Neo4jQueryService, Neo4jQueryStrategy, Neo4jReadMappingStrategy, PartitionPagination}
import org.neo4j.spark.util.{DriverCache, Neo4jOptions, Neo4jUtil, QueryType}

import java.io.IOException
@@ -17,7 +17,7 @@ abstract class BasePartitionReader(private val options: Neo4jOptions,
private val filters: Array[Filter],
private val schema: StructType,
private val jobId: String,
private val partitionSkipLimit: PartitionSkipLimit,
private val partitionSkipLimit: PartitionPagination,
private val scriptResult: java.util.List[java.util.Map[String, AnyRef]],
private val requiredColumns: StructType,
private val aggregateColumns: Array[AggregateFunc]) extends Logging {
common/src/main/scala/org/neo4j/spark/service/Neo4jQueryService.scala
@@ -1,7 +1,9 @@
package org.neo4j.spark.service

import org.apache.commons.lang3.StringUtils
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.connector.expressions.{SortDirection, SortOrder}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.sources.{And, Filter, Or}
import org.neo4j.cypherdsl.core._
@@ -102,21 +104,28 @@ class Neo4jQueryWriteStrategy(private val saveMode: SaveMode) extends Neo4jQuery
}

class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
partitionSkipLimit: PartitionSkipLimit = PartitionSkipLimit.EMPTY,
partitionPagination: PartitionPagination = PartitionPagination.EMPTY,
requiredColumns: Seq[String] = Seq.empty,
aggregateColumns: Array[AggregateFunc] = Array.empty,
jobId: String = "") extends Neo4jQueryStrategy {
jobId: String = "") extends Neo4jQueryStrategy with Logging {
private val renderer: Renderer = Renderer.getDefaultRenderer

private val hasSkipLimit: Boolean = partitionSkipLimit.skip != -1 && partitionSkipLimit.limit != -1
private val hasSkipLimit: Boolean = partitionPagination.skip != -1 && partitionPagination.topN.limit != -1

override def createStatementForQuery(options: Neo4jOptions): String = {
if (partitionPagination.topN.orders.nonEmpty) {
logWarning(
s"""Top N push-down optimizations with aggregations are not supported for custom queries.
|\tThese aggregations are going to be ignored.
|\tPlease specify the aggregations in the custom query directly""".stripMargin)
}
val limitedQuery = if (hasSkipLimit) {
s"""${options.query.value}
|SKIP ${partitionSkipLimit.skip} LIMIT ${partitionSkipLimit.limit}
|SKIP ${partitionPagination.skip} LIMIT ${partitionPagination.topN.limit}
|""".stripMargin
} else {
options.query.value
s"""${options.query.value}
|""".stripMargin
}
s"""WITH ${"$"}scriptResult AS ${Neo4jQueryStrategy.VARIABLE_SCRIPT_RESULT}
|$limitedQuery""".stripMargin
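For custom queries only SKIP/LIMIT can be appended safely — the connector cannot rewrite an opaque user query to add an ORDER BY, which is why pushed-down sort orders trigger the warning above and are dropped. Assuming VARIABLE_SCRIPT_RESULT renders as scriptResult and a user query of MATCH (n:Person) RETURN n with skip 0 and limit 10, the generated statement would look roughly like:

WITH $scriptResult AS scriptResult
MATCH (n:Person) RETURN n
SKIP 0 LIMIT 10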
@@ -130,16 +139,39 @@ class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
.named(Neo4jUtil.RELATIONSHIP_ALIAS)

val matchQuery: StatementBuilder.OngoingReadingWithoutWhere = filterRelationship(sourceNode, targetNode, relationship)

val returnExpressions: Seq[Expression] = buildReturnExpression(sourceNode, targetNode, relationship)
val stmt = if (aggregateColumns.isEmpty) {
buildStatement(options, matchQuery.returning(returnExpressions : _*), relationship)
val query = matchQuery.returning(returnExpressions: _*)
buildStatement(options, query, relationship)
} else {
buildStatementAggregation(options, matchQuery, relationship, returnExpressions)
}
renderer.render(stmt)
}

private def convertSort(entity: PropertyContainer, order: SortOrder): SortItem = {
val sortExpression = order.expression().describe()

val container: Option[PropertyContainer] = entity match {
case relationship: Relationship =>
if (sortExpression.contains(s"${Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS}.")) {
Some(relationship.getLeft)
} else if (sortExpression.contains(s"${Neo4jUtil.RELATIONSHIP_TARGET_ALIAS}.")) {
Some(relationship.getRight)
} else if (sortExpression.contains(s"${Neo4jUtil.RELATIONSHIP_ALIAS}.")) {
Some(relationship)
} else {
None
}
case _ => Some(entity)
}
val direction = if (order.direction() == SortDirection.ASCENDING) SortItem.Direction.ASC else SortItem.Direction.DESC

Cypher.sort(container
.map(_.property(sortExpression.removeAlias()))
.getOrElse(Cypher.name(sortExpression.unquote())), direction)
}

private def buildReturnExpression(sourceNode: Node, targetNode: Node, relationship: Relationship): Seq[Expression] = {
if (requiredColumns.isEmpty) {
Seq(relationship.getRequiredSymbolicName, sourceNode.as(Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS), targetNode.as(Neo4jUtil.RELATIONSHIP_TARGET_ALIAS))
@@ -186,13 +218,13 @@ class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
}
query
.`with`(entity)
// Spark does not push down limits when aggregation is involved
// Spark does not push down limits/top N when aggregation is involved
.orderBy(id)
.skip(partitionSkipLimit.skip)
.limit(partitionSkipLimit.limit)
.skip(partitionPagination.skip)
.limit(partitionPagination.topN.limit)
.returning(fields: _*)
} else {
val orderByProp = options.orderBy
val orderByProp = options.streamingOrderBy
if (StringUtils.isBlank(orderByProp)) {
query.returning(fields: _*)
} else {
@@ -207,37 +239,40 @@ class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
}

private def buildStatement(options: Neo4jOptions,
returning: StatementBuilder.OngoingReadingAndReturn,
returning: StatementBuilder.TerminalExposesSkip
with StatementBuilder.TerminalExposesLimit
with StatementBuilder.TerminalExposesOrderBy
with StatementBuilder.BuildableStatement[_],
entity: PropertyContainer = null): Statement = {

def addSkipLimit(ret: StatementBuilder.TerminalExposesSkip
with StatementBuilder.TerminalExposesLimit
with StatementBuilder.BuildableStatement[_]) = {
with StatementBuilder.TerminalExposesLimit
with StatementBuilder.BuildableStatement[_]) = {

if (partitionSkipLimit.skip == 0) {
ret.limit(partitionSkipLimit.limit)
if (partitionPagination.skip == 0) {
ret.limit(partitionPagination.topN.limit)
}
else {
ret.skip(partitionSkipLimit.skip).asInstanceOf[StatementBuilder.TerminalExposesLimit]
.limit(partitionSkipLimit.limit)
ret.skip(partitionPagination.skip)
.limit(partitionPagination.topN.limit)
}
}

val ret = if (entity == null) {
if (hasSkipLimit) addSkipLimit(returning) else returning
} else {
if (hasSkipLimit) {
val id = entity match {
case node: Node => Functions.id(node)
case rel: Relationship => Functions.id(rel)
}
if (options.partitions == 1) {
addSkipLimit(returning)
if (options.partitions == 1 || partitionPagination.topN.orders.nonEmpty) {
addSkipLimit(returning.orderBy(partitionPagination.topN.orders.map(order => convertSort(entity, order)): _*))
} else {
val id = entity match {
case node: Node => Functions.id(node)
case rel: Relationship => Functions.id(rel)
}
addSkipLimit(returning.orderBy(id))
}
} else {
val orderByProp = options.orderBy
val orderByProp = options.streamingOrderBy
if (StringUtils.isBlank(orderByProp)) returning else returning.orderBy(entity.property(orderByProp))
}
}
@@ -282,6 +317,7 @@ class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
def propertyOrSymbolicName(col: String) = {
if (entity != null) entity.property(col) else Cypher.name(col)
}

column match {
case Neo4jUtil.INTERNAL_ID_FIELD => Functions.id(entity.asInstanceOf[Node]).as(Neo4jUtil.INTERNAL_ID_FIELD)
case Neo4jUtil.INTERNAL_REL_ID_FIELD => Functions.id(entity.asInstanceOf[Relationship]).as(Neo4jUtil.INTERNAL_REL_ID_FIELD)
@@ -340,7 +376,7 @@ class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
val ret = if (requiredColumns.isEmpty) {
matchQuery.returning(node)
} else {
matchQuery.returning(expressions : _*)
matchQuery.returning(expressions: _*)
}
buildStatement(options, ret, node)
}
@@ -416,9 +452,9 @@ class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
.map(_._1)
.map(Cypher.parameter)
val statement = Cypher.call(options.query.value)
.withArgs(cypherParams : _*)
.`yield`(yieldFields : _*)
.returning(retCols : _*)
.withArgs(cypherParams: _*)
.`yield`(yieldFields: _*)
.returning(retCols: _*)
.build()
renderer.render(statement)
}
@@ -450,7 +486,8 @@ class Neo4jQueryService(private val options: Neo4jOptions,
case QueryType.RELATIONSHIP => strategy.createStatementForRelationships(options)
case QueryType.QUERY => strategy.createStatementForQuery(options)
case QueryType.GDS => strategy.createStatementForGDS(options)
case _ => throw new UnsupportedOperationException(s"""Query Type not supported.
case _ => throw new UnsupportedOperationException(
s"""Query Type not supported.
|You provided ${options.query.queryType},
|supported types: ${QueryType.values.mkString(",")}""".stripMargin)
}
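Tying this file together: convertSort inspects the alias prefix of each Spark sort expression to decide which Cypher entity owns the property — for a relationship read, a sort on the source alias resolves against the left node, the target alias against the right node, and the relationship alias against the relationship itself, falling back to a bare Cypher name otherwise. buildStatement then emits those orders (instead of the synthetic id() ordering used for parallel partitioning) whenever there is a single partition or Spark pushed down a Top N. Assuming the conventional source/target/rel aliases and hypothetical labels and properties, the rendered Cypher shape is roughly:

MATCH (source:Person)-[rel:KNOWS]->(target:Person)
RETURN rel, source, target
ORDER BY source.name ASC, rel.since DESC
LIMIT 10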
52 changes: 27 additions & 25 deletions common/src/main/scala/org/neo4j/spark/service/SchemaService.scala
@@ -6,6 +6,7 @@ import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
import org.neo4j.driver.exceptions.ClientException
import org.neo4j.driver.types.Entity
import org.neo4j.driver.{Record, Session, Transaction, TransactionWork, Value, Values, summary}
import org.neo4j.spark.config.TopN
import org.neo4j.spark.service.SchemaService.{cypherToSparkType, normalizedClassName, normalizedClassNameFromGraphEntity}
import org.neo4j.spark.util.Neo4jImplicits.{CypherImplicits, EntityImplicits}
import org.neo4j.spark.util._
@@ -16,12 +17,12 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

object PartitionSkipLimit {
val EMPTY = PartitionSkipLimit(0, -1, -1)
val EMPTY_FOR_QUERY = PartitionSkipLimit(0, 0, 0)
object PartitionPagination {
val EMPTY = PartitionPagination(0, -1, TopN(-1))
val EMPTY_FOR_QUERY = PartitionPagination(0, 0, TopN(0))
}

case class PartitionSkipLimit(partitionNumber: Int, skip: Long, limit: Long)
case class PartitionPagination(partitionNumber: Int, skip: Long, topN: TopN)
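PartitionPagination generalizes the old PartitionSkipLimit: the flat limit field becomes a TopN, so each partition can carry the pushed-down ordering alongside its skip. For example (hypothetical values):

// Third page of a parallel read: skip 200 rows, take 100, no pushed-down order.
val page = PartitionPagination(partitionNumber = 2, skip = 200, topN = TopN(100))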

case class Neo4jVersion(name: String, versions: Seq[String], edition: String)

@@ -301,7 +302,7 @@ class SchemaService(private val options: Neo4jOptions, private val driverCache:
StructType(fields)
}

def inputForGDSProc(procName: String): Seq[(String, Boolean)] = {
def inputForGDSProc(procName: String): Seq[(String, Boolean)] = {
val query =
"""
|WITH $procName AS procName
@@ -472,27 +473,28 @@ class SchemaService(private val options: Neo4jOptions, private val driverCache:
case QueryType.QUERY => countForQuery()
}

def skipLimitFromPartition(limit: Option[Int]): Seq[PartitionSkipLimit] = if (options.partitions == 1) {
val skipLimit = limit.map(l => PartitionSkipLimit(0, 0, l)).getOrElse(PartitionSkipLimit.EMPTY)
Seq(skipLimit)
} else {
val count: Long = this.count()
if (count <= 0) {
Seq(PartitionSkipLimit.EMPTY)
def skipLimitFromPartition(topN: Option[TopN]): Seq[PartitionPagination] =
if (options.partitions == 1) {
val skipLimit = topN.map(top => PartitionPagination(0, 0, top)).getOrElse(PartitionPagination.EMPTY)
Seq(skipLimit)
} else {
val partitionSize = Math.ceil(count.toDouble / options.partitions).toLong
val partitions = options.query.queryType match {
case QueryType.QUERY => if (options.queryMetadata.queryCount.nonEmpty) {
options.partitions // for custom query count we overfetch
} else {
options.partitions - 1
val count: Long = this.count()
if (count <= 0) {
Seq(PartitionPagination.EMPTY)
} else {
val partitionSize = Math.ceil(count.toDouble / options.partitions).toInt
val partitions = options.query.queryType match {
case QueryType.QUERY => if (options.queryMetadata.queryCount.nonEmpty) {
options.partitions // for custom query count we overfetch
} else {
options.partitions - 1
}
case _ => options.partitions - 1
}
case _ => options.partitions - 1
(0 to partitions)
.map(index => PartitionPagination(index, index * partitionSize, TopN(partitionSize, Array.empty)))
}
(0 to partitions)
.map(index => PartitionSkipLimit(index, index * partitionSize, partitionSize))
}
}
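A worked example of the multi-partition branch, with hypothetical numbers: count = 103 and options.partitions = 4 give partitionSize = ceil(103.0 / 4) = 26 and, for a non-query type, indices 0 to 3:

// PartitionPagination(0, 0,  TopN(26, Array.empty))
// PartitionPagination(1, 26, TopN(26, Array.empty))
// PartitionPagination(2, 52, TopN(26, Array.empty))
// PartitionPagination(3, 78, TopN(26, Array.empty))

The TopN ordering is deliberately empty here: a Top N is only pushed down whole, never split across parallel partitions.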

def isGdsProcedure(procName: String): Boolean = {
val params: util.Map[String, AnyRef] = Map[String, AnyRef]("procName" -> procName).asJava
@@ -614,7 +616,7 @@ class SchemaService(private val options: Neo4jOptions, private val driverCache:
|RETURN count(*) > 0 AS isPresent""".stripMargin
val params: util.Map[String, AnyRef] = Map("labels" -> Seq(label).asJava,
"properties" -> props.asJava).asJava.asInstanceOf[util.Map[String, AnyRef]]
session.run(queryCheck, params)
session.run(queryCheck, params)
.single()
.get("isPresent")
.asBoolean()
@@ -699,7 +701,7 @@ class SchemaService(private val options: Neo4jOptions, private val driverCache:
val label = options.nodeMetadata.labels.head
session.run(
s"""MATCH (n:$label)
|RETURN max(n.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin)
|RETURN max(n.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin)
.single()
.get(options.streamingOptions.propertyName)
.asLong(-1)
@@ -731,7 +733,7 @@ class SchemaService(private val options: Neo4jOptions, private val driverCache:

private def logResolutionChange(message: String, e: ClientException): Unit = {
log.warn(message)
if(!e.code().equals("Neo.ClientError.Procedure.ProcedureNotFound")) {
if (!e.code().equals("Neo.ClientError.Procedure.ProcedureNotFound")) {
log.warn(s"For the following exception", e)
}
}
common/src/main/scala/org/neo4j/spark/streaming/BaseStreamingPartitionReader.scala
@@ -5,7 +5,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.{DataTypes, StructType}
import org.neo4j.spark.reader.BasePartitionReader
import org.neo4j.spark.service.{Neo4jQueryStrategy, PartitionSkipLimit}
import org.neo4j.spark.service.{Neo4jQueryStrategy, PartitionPagination}
import org.neo4j.spark.util.Neo4jImplicits._
import org.neo4j.spark.util.{Neo4jOptions, Neo4jUtil, StreamingFrom}

@@ -16,7 +16,7 @@ class BaseStreamingPartitionReader(private val options: Neo4jOptions,
private val filters: Array[Filter],
private val schema: StructType,
private val jobId: String,
private val partitionSkipLimit: PartitionSkipLimit,
private val partitionSkipLimit: PartitionPagination,
private val scriptResult: java.util.List[java.util.Map[String, AnyRef]],
private val offsetAccumulator: OffsetStorage[java.lang.Long, java.lang.Long],
private val requiredColumns: StructType,
common/src/main/scala/org/neo4j/spark/util/Neo4jImplicits.scala
@@ -27,7 +27,7 @@ object Neo4jImplicits {
def isQuoted(): Boolean = str.startsWith("`");

def removeAlias(): String = {
val splatString = str.split('.')
val splatString = str.unquote().split('.')

if (splatString.size > 1) {
splatString.tail.mkString(".")
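The one-line change above makes removeAlias unquote before splitting, so backticks no longer leak into the returned property name. A hedged before/after, assuming unquote() strips the surrounding backticks:

// Before: "`source.name`".removeAlias()  =>  "name`"  (split ran on the quoted string)
// After:  "`source.name`".removeAlias()  =>  "name"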
common/src/main/scala/org/neo4j/spark/util/Neo4jOptions.scala
@@ -54,6 +54,7 @@ class Neo4jOptions(private val options: java.util.Map[String, String]) extends S
val pushdownColumnsEnabled: Boolean = getParameter(PUSHDOWN_COLUMNS_ENABLED, DEFAULT_PUSHDOWN_COLUMNS_ENABLED.toString).toBoolean
val pushdownAggregateEnabled: Boolean = getParameter(PUSHDOWN_AGGREGATE_ENABLED, DEFAULT_PUSHDOWN_AGGREGATE_ENABLED.toString).toBoolean
val pushdownLimitEnabled: Boolean = getParameter(PUSHDOWN_LIMIT_ENABLED, DEFAULT_PUSHDOWN_LIMIT_ENABLED.toString).toBoolean
val pushdownTopNEnabled: Boolean = getParameter(PUSHDOWN_TOPN_ENABLED, DEFAULT_PUSHDOWN_TOPN_ENABLED.toString).toBoolean

val schemaMetadata: Neo4jSchemaMetadata = Neo4jSchemaMetadata(getParameter(SCHEMA_FLATTEN_LIMIT, DEFAULT_SCHEMA_FLATTEN_LIMIT.toString).toInt,
SchemaStrategy.withCaseInsensitiveName(getParameter(SCHEMA_STRATEGY, DEFAULT_SCHEMA_STRATEGY.toString).toUpperCase),
@@ -202,7 +203,7 @@ class Neo4jOptions(private val options: java.util.Map[String, String]) extends S

val partitions: Int = getParameter(PARTITIONS, DEFAULT_PARTITIONS.toString).toInt

val orderBy: String = getParameter(ORDER_BY, getParameter(STREAMING_PROPERTY_NAME))
val streamingOrderBy: String = getParameter(ORDER_BY, getParameter(STREAMING_PROPERTY_NAME))

val apocConfig: Neo4jApocConfig = Neo4jApocConfig(parameters.asScala
.filterKeys(_.startsWith("apoc."))
@@ -391,6 +392,7 @@ object Neo4jOptions {
val PUSHDOWN_COLUMNS_ENABLED = "pushdown.columns.enabled"
val PUSHDOWN_AGGREGATE_ENABLED = "pushdown.aggregate.enabled"
val PUSHDOWN_LIMIT_ENABLED = "pushdown.limit.enabled"
val PUSHDOWN_TOPN_ENABLED = "pushdown.topN.enabled"

// schema options
val SCHEMA_STRATEGY = "schema.strategy"
@@ -462,6 +464,7 @@ object Neo4jOptions {
val DEFAULT_PUSHDOWN_COLUMNS_ENABLED = true
val DEFAULT_PUSHDOWN_AGGREGATE_ENABLED = true
val DEFAULT_PUSHDOWN_LIMIT_ENABLED = true
val DEFAULT_PUSHDOWN_TOPN_ENABLED = true
val DEFAULT_PARTITIONS = 1
val DEFAULT_OPTIMIZATION_TYPE = OptimizationType.NONE
val DEFAULT_SAVE_MODE = SaveMode.Overwrite
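Like the other pushdown flags, Top N pushdown is on by default and can be disabled per read — useful when comparing plans or sidestepping an ordering edge case. A hypothetical opt-out:

val df = spark.read
  .format("org.neo4j.spark.DataSource")
  .option("labels", "Person")               // hypothetical label
  .option("pushdown.topN.enabled", "false") // key introduced in this diff
  .load()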
