Improve schema type constraint and constraint management (#581)
conker84 authored Feb 9, 2024
1 parent a159a31 · commit e09f94b
Showing 20 changed files with 1,133 additions and 356 deletions.
common/src/main/scala/org/neo4j/spark/converter/DataConverter.scala (182 additions, 0 deletions)
@@ -0,0 +1,182 @@
package org.neo4j.spark.converter

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema, UnsafeRow}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.neo4j.driver.internal._
import org.neo4j.driver.types.{IsoDuration, Node, Relationship}
import org.neo4j.driver.{Value, Values}
import org.neo4j.spark.service.SchemaService
import org.neo4j.spark.util.Neo4jUtil

import java.time._
import java.time.format.DateTimeFormatter
import scala.annotation.tailrec
import scala.collection.JavaConverters._

trait DataConverter[T] {
  def convert(value: Any, dataType: DataType = null): T

  @tailrec
  private[converter] final def extractStructType(dataType: DataType): StructType = dataType match {
    case structType: StructType => structType
    case mapType: MapType => extractStructType(mapType.valueType)
    case arrayType: ArrayType => extractStructType(arrayType.elementType)
    case _ => throw new UnsupportedOperationException(s"$dataType not supported")
  }
}

object SparkToNeo4jDataConverter {
  def apply(): SparkToNeo4jDataConverter = new SparkToNeo4jDataConverter()
}

class SparkToNeo4jDataConverter extends DataConverter[Value] {
  override def convert(value: Any, dataType: DataType): Value = {
    value match {
      case date: java.sql.Date => convert(date.toLocalDate, dataType)
      case timestamp: java.sql.Timestamp => convert(timestamp.toLocalDateTime, dataType)
      case intValue: Int if dataType == DataTypes.DateType => convert(DateTimeUtils
        .toJavaDate(intValue), dataType)
      case longValue: Long if dataType == DataTypes.TimestampType => convert(DateTimeUtils
        .toJavaTimestamp(longValue), dataType)
      case unsafeRow: UnsafeRow => {
        val structType = extractStructType(dataType)
        val row = new GenericRowWithSchema(unsafeRow.toSeq(structType).toArray, structType)
        convert(row)
      }
      case struct: GenericRow => {
        def toMap(struct: GenericRow): Value = {
          Values.value(
            struct.schema.fields.map(
              f => f.name -> convert(struct.getAs(f.name), f.dataType)
            ).toMap.asJava)
        }

        try {
          struct.getAs[UTF8String]("type").toString match {
            case SchemaService.POINT_TYPE_2D => Values.point(struct.getAs[Number]("srid").intValue(),
              struct.getAs[Number]("x").doubleValue(),
              struct.getAs[Number]("y").doubleValue())
            case SchemaService.POINT_TYPE_3D => Values.point(struct.getAs[Number]("srid").intValue(),
              struct.getAs[Number]("x").doubleValue(),
              struct.getAs[Number]("y").doubleValue(),
              struct.getAs[Number]("z").doubleValue())
            case SchemaService.DURATION_TYPE => Values.isoDuration(struct.getAs[Number]("months").longValue(),
              struct.getAs[Number]("days").longValue(),
              struct.getAs[Number]("seconds").longValue(),
              struct.getAs[Number]("nanoseconds").intValue())
            case SchemaService.TIME_TYPE_OFFSET => Values.value(OffsetTime.parse(struct.getAs[UTF8String]("value").toString))
            case SchemaService.TIME_TYPE_LOCAL => Values.value(LocalTime.parse(struct.getAs[UTF8String]("value").toString))
            case _ => toMap(struct)
          }
        } catch {
          case _: Throwable => toMap(struct)
        }
      }
      case unsafeArray: ArrayData => {
        val sparkType = dataType match {
          case arrayType: ArrayType => arrayType.elementType
          case _ => dataType
        }
        val javaList = unsafeArray.toSeq[AnyRef](sparkType)
          .map(elem => convert(elem, sparkType))
          .asJava
        Values.value(javaList)
      }
      case unsafeMapData: MapData => { // Neo4j only supports Map[String, AnyRef]
        val mapType = dataType.asInstanceOf[MapType]
        val map: Map[String, AnyRef] = (0 until unsafeMapData.numElements())
          .map(i => (unsafeMapData.keyArray().getUTF8String(i).toString, unsafeMapData.valueArray().get(i, mapType.valueType)))
          .toMap[String, AnyRef]
          .mapValues(innerValue => convert(innerValue, mapType.valueType))
          .toMap[String, AnyRef]
        Values.value(map.asJava)
      }
      case string: UTF8String => convert(string.toString)
      case _ => Values.value(value)
    }
  }
}

object Neo4jToSparkDataConverter {
  def apply(): Neo4jToSparkDataConverter = new Neo4jToSparkDataConverter()
}

class Neo4jToSparkDataConverter extends DataConverter[Any] {
  override def convert(value: Any, dataType: DataType): Any = {
    if (dataType != null && dataType == DataTypes.StringType && value != null && !value.isInstanceOf[String]) {
      convert(Neo4jUtil.mapper.writeValueAsString(value), dataType)
    } else {
      value match {
        case node: Node => {
          val map = node.asMap()
          val structType = extractStructType(dataType)
          val fields = structType
            .filter(field => field.name != Neo4jUtil.INTERNAL_ID_FIELD && field.name != Neo4jUtil.INTERNAL_LABELS_FIELD)
            .map(field => convert(map.get(field.name), field.dataType))
          InternalRow.fromSeq(Seq(convert(node.id()), convert(node.labels())) ++ fields)
        }
        case rel: Relationship => {
          val map = rel.asMap()
          val structType = extractStructType(dataType)
          val fields = structType
            .filter(field => field.name != Neo4jUtil.INTERNAL_REL_ID_FIELD
              && field.name != Neo4jUtil.INTERNAL_REL_TYPE_FIELD
              && field.name != Neo4jUtil.INTERNAL_REL_SOURCE_ID_FIELD
              && field.name != Neo4jUtil.INTERNAL_REL_TARGET_ID_FIELD)
            .map(field => convert(map.get(field.name), field.dataType))
          InternalRow.fromSeq(Seq(convert(rel.id()),
            convert(rel.`type`()),
            convert(rel.startNodeId()),
            convert(rel.endNodeId())) ++ fields)
        }
        case d: IsoDuration => {
          val months = d.months()
          val days = d.days()
          val nanoseconds: Integer = d.nanoseconds()
          val seconds = d.seconds()
          InternalRow.fromSeq(Seq(UTF8String.fromString(SchemaService.DURATION_TYPE), months, days, seconds, nanoseconds, UTF8String.fromString(d.toString)))
        }
        case zt: ZonedDateTime => DateTimeUtils.instantToMicros(zt.toInstant)
        case dt: LocalDateTime => DateTimeUtils.instantToMicros(dt.toInstant(ZoneOffset.UTC))
        case d: LocalDate => d.toEpochDay.toInt
        case lt: LocalTime => {
          InternalRow.fromSeq(Seq(
            UTF8String.fromString(SchemaService.TIME_TYPE_LOCAL),
            UTF8String.fromString(lt.format(DateTimeFormatter.ISO_TIME))
          ))
        }
        case t: OffsetTime => {
          InternalRow.fromSeq(Seq(
            UTF8String.fromString(SchemaService.TIME_TYPE_OFFSET),
            UTF8String.fromString(t.format(DateTimeFormatter.ISO_TIME))
          ))
        }
        case p: InternalPoint2D => {
          val srid: Integer = p.srid()
          InternalRow.fromSeq(Seq(UTF8String.fromString(SchemaService.POINT_TYPE_2D), srid, p.x(), p.y(), null))
        }
        case p: InternalPoint3D => {
          val srid: Integer = p.srid()
          InternalRow.fromSeq(Seq(UTF8String.fromString(SchemaService.POINT_TYPE_3D), srid, p.x(), p.y(), p.z()))
        }
        case l: java.util.List[_] => {
          val elementType = if (dataType != null) dataType.asInstanceOf[ArrayType].elementType else null
          ArrayData.toArrayData(l.asScala.map(e => convert(e, elementType)).toArray)
        }
        case map: java.util.Map[_, _] => {
          if (dataType != null) {
            val mapType = dataType.asInstanceOf[MapType]
            ArrayBasedMapData(map.asScala.map(t => (convert(t._1, mapType.keyType), convert(t._2, mapType.valueType))))
          } else {
            ArrayBasedMapData(map.asScala.map(t => (convert(t._1), convert(t._2))))
          }
        }
        case s: String => UTF8String.fromString(s)
        case _ => value
      }
    }
  }
}
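
A minimal usage sketch (not part of the commit) showing how the two converters above could be exercised together; ConverterRoundTripSketch is a hypothetical name, and the only assumption is that the Spark SQL and Neo4j Java driver dependencies are available.

import org.apache.spark.sql.types.DataTypes
import org.apache.spark.unsafe.types.UTF8String
import org.neo4j.spark.converter.{Neo4jToSparkDataConverter, SparkToNeo4jDataConverter}

object ConverterRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val toNeo4j = SparkToNeo4jDataConverter()
    val toSpark = Neo4jToSparkDataConverter()

    // Spark strings arrive as UTF8String; the converter unwraps them into a driver Value.
    val neo4jValue = toNeo4j.convert(UTF8String.fromString("hello"), DataTypes.StringType)

    // Going the other way, a plain driver value becomes a Catalyst-friendly type again.
    val sparkValue = toSpark.convert(neo4jValue.asString(), DataTypes.StringType)
    println(sparkValue) // prints "hello" (backed by UTF8String)
  }
}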
common/src/main/scala/org/neo4j/spark/converter/TypeConverter.scala (129 additions, 0 deletions)
@@ -0,0 +1,129 @@
package org.neo4j.spark.converter

import org.apache.spark.sql.types.{DataType, DataTypes}
import org.neo4j.driver.types.Entity
import org.neo4j.spark.converter.CypherToSparkTypeConverter.{cleanTerms, durationType, pointType, timeType}
import org.neo4j.spark.converter.SparkToCypherTypeConverter.mapping
import org.neo4j.spark.service.SchemaService.normalizedClassName
import org.neo4j.spark.util.Neo4jImplicits.EntityImplicits

import scala.collection.JavaConverters._

trait TypeConverter[SOURCE_TYPE, DESTINATION_TYPE] {

  def convert(sourceType: SOURCE_TYPE, value: Any = null): DESTINATION_TYPE

}

object CypherToSparkTypeConverter {
  def apply(): CypherToSparkTypeConverter = new CypherToSparkTypeConverter()

  private val cleanTerms: String = "Unmodifiable|Internal|Iso|2D|3D|Offset|Local|Zoned"

  val durationType: DataType = DataTypes.createStructType(Array(
    DataTypes.createStructField("type", DataTypes.StringType, false),
    DataTypes.createStructField("months", DataTypes.LongType, false),
    DataTypes.createStructField("days", DataTypes.LongType, false),
    DataTypes.createStructField("seconds", DataTypes.LongType, false),
    DataTypes.createStructField("nanoseconds", DataTypes.IntegerType, false),
    DataTypes.createStructField("value", DataTypes.StringType, false)
  ))

  val pointType: DataType = DataTypes.createStructType(Array(
    DataTypes.createStructField("type", DataTypes.StringType, false),
    DataTypes.createStructField("srid", DataTypes.IntegerType, false),
    DataTypes.createStructField("x", DataTypes.DoubleType, false),
    DataTypes.createStructField("y", DataTypes.DoubleType, false),
    DataTypes.createStructField("z", DataTypes.DoubleType, true)
  ))

  val timeType: DataType = DataTypes.createStructType(Array(
    DataTypes.createStructField("type", DataTypes.StringType, false),
    DataTypes.createStructField("value", DataTypes.StringType, false)
  ))
}

class CypherToSparkTypeConverter extends TypeConverter[String, DataType] {
  override def convert(sourceType: String, value: Any = null): DataType = sourceType
    .replaceAll(cleanTerms, "") match {
    case "Node" | "Relationship" => if (value != null) value.asInstanceOf[Entity].toStruct else DataTypes.NullType
    case "NodeArray" | "RelationshipArray" => if (value != null) DataTypes.createArrayType(value.asInstanceOf[Entity].toStruct) else DataTypes.NullType
    case "Boolean" => DataTypes.BooleanType
    case "Long" => DataTypes.LongType
    case "Double" => DataTypes.DoubleType
    case "Point" => pointType
    case "DateTime" | "ZonedDateTime" | "LocalDateTime" => DataTypes.TimestampType
    case "Time" => timeType
    case "Date" => DataTypes.DateType
    case "Duration" => durationType
    case "Map" => {
      val valueType = if (value == null) {
        DataTypes.NullType
      } else {
        val map = value.asInstanceOf[java.util.Map[String, AnyRef]].asScala
        val types = map.values
          .map(normalizedClassName)
          .toSet
        if (types.size == 1) convert(types.head, map.values.head) else DataTypes.StringType
      }
      DataTypes.createMapType(DataTypes.StringType, valueType)
    }
    case "Array" => {
      val valueType = if (value == null) {
        DataTypes.NullType
      } else {
        val list = value.asInstanceOf[java.util.List[AnyRef]].asScala
        val types = list
          .map(normalizedClassName)
          .toSet
        if (types.size == 1) convert(types.head, list.head) else DataTypes.StringType
      }
      DataTypes.createArrayType(valueType)
    }
    // These are from APOC
    case "StringArray" => DataTypes.createArrayType(DataTypes.StringType)
    case "LongArray" => DataTypes.createArrayType(DataTypes.LongType)
    case "DoubleArray" => DataTypes.createArrayType(DataTypes.DoubleType)
    case "BooleanArray" => DataTypes.createArrayType(DataTypes.BooleanType)
    case "PointArray" => DataTypes.createArrayType(pointType)
    case "DateTimeArray" => DataTypes.createArrayType(DataTypes.TimestampType)
    case "TimeArray" => DataTypes.createArrayType(timeType)
    case "DateArray" => DataTypes.createArrayType(DataTypes.DateType)
    case "DurationArray" => DataTypes.createArrayType(durationType)
    // Default is String
    case _ => DataTypes.StringType
  }
}

object SparkToCypherTypeConverter {
  def apply(): SparkToCypherTypeConverter = new SparkToCypherTypeConverter()

  private val mapping: Map[DataType, String] = Map(
    DataTypes.BooleanType -> "BOOLEAN",
    DataTypes.StringType -> "STRING",
    DataTypes.IntegerType -> "INTEGER",
    DataTypes.LongType -> "INTEGER",
    DataTypes.FloatType -> "FLOAT",
    DataTypes.DoubleType -> "FLOAT",
    DataTypes.DateType -> "DATE",
    DataTypes.TimestampType -> "LOCAL DATETIME",
    durationType -> "DURATION",
    pointType -> "POINT",
    // Cypher graph entities do not allow null values in arrays
    DataTypes.createArrayType(DataTypes.BooleanType, false) -> "LIST<BOOLEAN NOT NULL>",
    DataTypes.createArrayType(DataTypes.StringType, false) -> "LIST<STRING NOT NULL>",
    DataTypes.createArrayType(DataTypes.IntegerType, false) -> "LIST<INTEGER NOT NULL>",
    DataTypes.createArrayType(DataTypes.LongType, false) -> "LIST<INTEGER NOT NULL>",
    DataTypes.createArrayType(DataTypes.FloatType, false) -> "LIST<FLOAT NOT NULL>",
    DataTypes.createArrayType(DataTypes.DoubleType, false) -> "LIST<FLOAT NOT NULL>",
    DataTypes.createArrayType(DataTypes.DateType, false) -> "LIST<DATE NOT NULL>",
    DataTypes.createArrayType(DataTypes.TimestampType, false) -> "LIST<LOCAL DATETIME NOT NULL>",
    DataTypes.createArrayType(DataTypes.TimestampType, true) -> "LIST<LOCAL DATETIME NOT NULL>",
    DataTypes.createArrayType(durationType, false) -> "LIST<DURATION NOT NULL>",
    DataTypes.createArrayType(pointType, false) -> "LIST<POINT NOT NULL>"
  )
}

class SparkToCypherTypeConverter extends TypeConverter[DataType, String] {
  override def convert(sourceType: DataType, value: Any): String = mapping(sourceType)
}
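
The commit title mentions schema type constraints, and the mapping in SparkToCypherTypeConverter produces the Cypher type names used by property type constraints (the "IS ::" syntax available from Neo4j 5.9). The connector's actual constraint management code is not part of this excerpt, so the following is only a sketch under that assumption; TypeConstraintSketch and constraintsFor are hypothetical names.

import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
import org.neo4j.spark.converter.SparkToCypherTypeConverter

object TypeConstraintSketch {
  // Build one "IS :: <TYPE>" property type constraint per schema field,
  // using the Spark-to-Cypher type mapping defined above.
  def constraintsFor(label: String, schema: StructType): Seq[String] = {
    val typeConverter = SparkToCypherTypeConverter()
    schema.fields.toSeq.map { field =>
      val cypherType = typeConverter.convert(field.dataType, null)
      s"CREATE CONSTRAINT `${label}_${field.name}_type` IF NOT EXISTS " +
        s"FOR (n:`$label`) REQUIRE n.`${field.name}` IS :: $cypherType"
    }
  }

  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("title", DataTypes.StringType),
      StructField("released", DataTypes.LongType)
    ))
    constraintsFor("Movie", schema).foreach(println)
    // CREATE CONSTRAINT `Movie_title_type` IF NOT EXISTS FOR (n:`Movie`) REQUIRE n.`title` IS :: STRING
    // CREATE CONSTRAINT `Movie_released_type` IF NOT EXISTS FOR (n:`Movie`) REQUIRE n.`released` IS :: INTEGER
  }
}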
common/src/main/scala/org/neo4j/spark/service/MappingService.scala (14 additions, 6 deletions)
@@ -10,6 +10,7 @@ import org.apache.spark.sql.types.StructType
 import org.neo4j.driver.internal.value.MapValue
 import org.neo4j.driver.types.Node
 import org.neo4j.driver.{Record, Value, Values}
+import org.neo4j.spark.converter.{Neo4jToSparkDataConverter, SparkToNeo4jDataConverter}
 import org.neo4j.spark.service.Neo4jWriteMappingStrategy.{KEYS, PROPERTIES}
 import org.neo4j.spark.util.{Neo4jNodeMetadata, Neo4jOptions, Neo4jUtil, QueryType, RelationshipSaveStrategy, ValidateSchemaOptions, Validations}

@@ -21,6 +22,8 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
   extends Neo4jMappingStrategy[InternalRow, java.util.Map[String, AnyRef]]
     with Logging {

+  private val dataConverter = SparkToNeo4jDataConverter()
+
   override def node(row: InternalRow, schema: StructType): java.util.Map[String, AnyRef] = {
     Validations.validate(ValidateSchemaOptions(options, schema))

@@ -35,7 +38,7 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
       override def accept(key: String, value: AnyRef): Unit = if (options.nodeMetadata.nodeKeys.contains(key)) {
         keys.put(options.nodeMetadata.nodeKeys.getOrElse(key, key), value)
       } else {
-        properties.put(options.nodeMetadata.nodeProps.getOrElse(key, key), value)
+        properties.put(options.nodeMetadata.properties.getOrElse(key, key), value)
       }
     })

@@ -70,8 +73,8 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
     if (nodeMetadata.nodeKeys.contains(key)) {
       nodeMap.get(KEYS).put(nodeMetadata.nodeKeys.getOrElse(key, key), value)
     }
-    if (nodeMetadata.nodeProps.contains(key)) {
-      nodeMap.get(PROPERTIES).put(nodeMetadata.nodeProps.getOrElse(key, key), value)
+    if (nodeMetadata.properties.contains(key)) {
+      nodeMap.get(PROPERTIES).put(nodeMetadata.properties.getOrElse(key, key), value)
     }
   }

@@ -83,7 +86,9 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
     addToNodeMap(sourceNodeMap, source, key, value)
     addToNodeMap(targetNodeMap, target, key, value)

-    if (options.relationshipMetadata.properties.contains(key)) {
+    if (options.relationshipMetadata.relationshipKeys.contains(key)) {
+      relMap.get(KEYS).put(options.relationshipMetadata.relationshipKeys.getOrElse(key, key), value)
+    } else {
       relMap.get(PROPERTIES).put(options.relationshipMetadata.properties.getOrElse(key, key), value)
     }
   }
@@ -123,7 +128,7 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
     schema.indices
       .flatMap(i => {
         val field = schema(i)
-        val neo4jValue = Neo4jUtil.convertFromSpark(seq(i), field.dataType)
+        val neo4jValue = dataConverter.convert(seq(i), field.dataType)
         neo4jValue match {
           case map: MapValue =>
             map.asMap().asScala.toMap
@@ -140,6 +145,8 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)

 class Neo4jReadMappingStrategy(private val options: Neo4jOptions, requiredColumns: StructType) extends Neo4jMappingStrategy[Record, InternalRow] {

+  private val dataConverter = Neo4jToSparkDataConverter()
+
   override def node(record: Record, schema: StructType): InternalRow = {
     if (requiredColumns.nonEmpty) {
       query(record, schema)
@@ -158,7 +165,7 @@ class Neo4jReadMappingStrategy(private val options: Neo4jOptions, requiredColumn
                               schema: StructType) = InternalRow
     .fromSeq(
       schema.map(
-        field => Neo4jUtil.convertFromNeo4j(map.get(field.name), field.dataType)
+        field => dataConverter.convert(map.get(field.name), field.dataType)
       )
     )

@@ -254,6 +261,7 @@ private abstract class MappingBiConsumer extends BiConsumer[String, AnyRef] {
   val sourceNodeMap = new util.HashMap[String, util.Map[String, AnyRef]]()
   val targetNodeMap = new util.HashMap[String, util.Map[String, AnyRef]]()

+  relMap.put(KEYS, new util.HashMap[String, AnyRef]())
   relMap.put(PROPERTIES, new util.HashMap[String, AnyRef]())
   sourceNodeMap.put(PROPERTIES, new util.HashMap[String, AnyRef]())
   sourceNodeMap.put(KEYS, new util.HashMap[String, AnyRef]())
(17 more changed files not shown in this excerpt.)
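
The MappingService change above routes relationship keys into a separate KEYS bucket instead of mixing them with ordinary properties. As a rough illustration of why that split matters for writes, here is a sketch (not the connector's actual query builder) of how keys could drive a MERGE pattern while the remaining properties are applied with SET; RelationshipMergeSketch, mergeFragment, and the "event" parameter shape are hypothetical.

import scala.collection.JavaConverters._

object RelationshipMergeSketch {
  // Keys identify the relationship, so they belong in the MERGE pattern;
  // everything else is applied afterwards with SET.
  def mergeFragment(relType: String,
                    keys: java.util.Map[String, AnyRef],
                    properties: java.util.Map[String, AnyRef]): String = {
    val keyPattern = keys.asScala.keys.map(k => s"`$k`: event.keys.`$k`").mkString(", ")
    val propAssignments = properties.asScala.keys.map(k => s"r.`$k` = event.properties.`$k`").mkString(", ")
    s"""MERGE (source)-[r:`$relType` {$keyPattern}]->(target)
       |SET $propAssignments""".stripMargin
  }

  def main(args: Array[String]): Unit = {
    val keys = Map[String, AnyRef]("since" -> Integer.valueOf(2024)).asJava
    val props = Map[String, AnyRef]("weight" -> java.lang.Double.valueOf(0.5)).asJava
    println(mergeFragment("KNOWS", keys, props))
    // MERGE (source)-[r:`KNOWS` {`since`: event.keys.`since`}]->(target)
    // SET r.`weight` = event.properties.`weight`
  }
}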