Improve schema type constraint and constraint management (#581)
Showing 20 changed files with 1,133 additions and 356 deletions.
common/src/main/scala/org/neo4j/spark/converter/DataConverter.scala (182 additions, 0 deletions)
```scala
package org.neo4j.spark.converter

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema, UnsafeRow}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.neo4j.driver.internal._
import org.neo4j.driver.types.{IsoDuration, Node, Relationship}
import org.neo4j.driver.{Value, Values}
import org.neo4j.spark.service.SchemaService
import org.neo4j.spark.util.Neo4jUtil

import java.time._
import java.time.format.DateTimeFormatter
import scala.annotation.tailrec
import scala.collection.JavaConverters._

trait DataConverter[T] {
  def convert(value: Any, dataType: DataType = null): T

  @tailrec
  private[converter] final def extractStructType(dataType: DataType): StructType = dataType match {
    case structType: StructType => structType
    case mapType: MapType => extractStructType(mapType.valueType)
    case arrayType: ArrayType => extractStructType(arrayType.elementType)
    case _ => throw new UnsupportedOperationException(s"$dataType not supported")
  }
}
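```

`extractStructType` unwraps nested container types until it reaches a `StructType`, so callers can hand it the struct itself, a map of structs, or an array of structs. A minimal sketch of what it accepts (the nested types below are illustrative, not from this commit):

```scala
import org.apache.spark.sql.types._

// extractStructType resolves all of these to the same StructType:
val struct = StructType(Seq(StructField("x", DoubleType)))
val asArray = ArrayType(struct)                       // array element
val asMap   = MapType(StringType, ArrayType(struct))  // map value, then array element
// Any other DataType (e.g. LongType) throws UnsupportedOperationException.
```

```scala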
object SparkToNeo4jDataConverter {
  def apply(): SparkToNeo4jDataConverter = new SparkToNeo4jDataConverter()
}

class SparkToNeo4jDataConverter extends DataConverter[Value] {
  override def convert(value: Any, dataType: DataType): Value = {
    value match {
      case date: java.sql.Date => convert(date.toLocalDate, dataType)
      case timestamp: java.sql.Timestamp => convert(timestamp.toLocalDateTime, dataType)
      case intValue: Int if dataType == DataTypes.DateType => convert(DateTimeUtils
        .toJavaDate(intValue), dataType)
      case longValue: Long if dataType == DataTypes.TimestampType => convert(DateTimeUtils
        .toJavaTimestamp(longValue), dataType)
      case unsafeRow: UnsafeRow => {
        val structType = extractStructType(dataType)
        val row = new GenericRowWithSchema(unsafeRow.toSeq(structType).toArray, structType)
        convert(row)
      }
      case struct: GenericRow => {
        def toMap(struct: GenericRow): Value = {
          Values.value(
            struct.schema.fields.map(
              f => f.name -> convert(struct.getAs(f.name), f.dataType)
            ).toMap.asJava)
        }

        try {
          struct.getAs[UTF8String]("type").toString match {
            case SchemaService.POINT_TYPE_2D => Values.point(struct.getAs[Number]("srid").intValue(),
              struct.getAs[Number]("x").doubleValue(),
              struct.getAs[Number]("y").doubleValue())
            case SchemaService.POINT_TYPE_3D => Values.point(struct.getAs[Number]("srid").intValue(),
              struct.getAs[Number]("x").doubleValue(),
              struct.getAs[Number]("y").doubleValue(),
              struct.getAs[Number]("z").doubleValue())
            case SchemaService.DURATION_TYPE => Values.isoDuration(struct.getAs[Number]("months").longValue(),
              struct.getAs[Number]("days").longValue(),
              struct.getAs[Number]("seconds").longValue(),
              struct.getAs[Number]("nanoseconds").intValue())
            case SchemaService.TIME_TYPE_OFFSET => Values.value(OffsetTime.parse(struct.getAs[UTF8String]("value").toString))
            case SchemaService.TIME_TYPE_LOCAL => Values.value(LocalTime.parse(struct.getAs[UTF8String]("value").toString))
            case _ => toMap(struct)
          }
        } catch {
          case _: Throwable => toMap(struct)
        }
      }
      case unsafeArray: ArrayData => {
        val sparkType = dataType match {
          case arrayType: ArrayType => arrayType.elementType
          case _ => dataType
        }
        val javaList = unsafeArray.toSeq[AnyRef](sparkType)
          .map(elem => convert(elem, sparkType))
          .asJava
        Values.value(javaList)
      }
      case unsafeMapData: MapData => { // Neo4j only supports Map[String, AnyRef]
        val mapType = dataType.asInstanceOf[MapType]
        val map: Map[String, AnyRef] = (0 until unsafeMapData.numElements())
          .map(i => (unsafeMapData.keyArray().getUTF8String(i).toString, unsafeMapData.valueArray().get(i, mapType.valueType)))
          .toMap[String, AnyRef]
          .mapValues(innerValue => convert(innerValue, mapType.valueType))
          .toMap[String, AnyRef]
        Values.value(map.asJava)
      }
      case string: UTF8String => convert(string.toString)
      case _ => Values.value(value)
    }
  }
}
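```

`SparkToNeo4jDataConverter` normalizes Catalyst-side values (`UnsafeRow`, `ArrayData`, `MapData`, `UTF8String`, and Spark's primitive date/timestamp encodings) into driver `Value`s, falling back to `Values.value` for plain JVM types. A hedged usage sketch (input values are illustrative):

```scala
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.unsafe.types.UTF8String
import org.neo4j.spark.converter.SparkToNeo4jDataConverter

val dataConverter = SparkToNeo4jDataConverter()

// Catalyst strings arrive as UTF8String and become driver string values.
val name = dataConverter.convert(UTF8String.fromString("Alice"))

// Spark stores DateType as days since the epoch (18993 = 2022-01-01);
// the Int is routed through DateTimeUtils.toJavaDate to a date value.
val date = dataConverter.convert(18993, DataTypes.DateType)
```

```scala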
object Neo4jToSparkDataConverter {
  def apply(): Neo4jToSparkDataConverter = new Neo4jToSparkDataConverter()
}

class Neo4jToSparkDataConverter extends DataConverter[Any] {
  override def convert(value: Any, dataType: DataType): Any = {
    if (dataType != null && dataType == DataTypes.StringType && value != null && !value.isInstanceOf[String]) {
      convert(Neo4jUtil.mapper.writeValueAsString(value), dataType)
    } else {
      value match {
        case node: Node => {
          val map = node.asMap()
          val structType = extractStructType(dataType)
          val fields = structType
            .filter(field => field.name != Neo4jUtil.INTERNAL_ID_FIELD && field.name != Neo4jUtil.INTERNAL_LABELS_FIELD)
            .map(field => convert(map.get(field.name), field.dataType))
          InternalRow.fromSeq(Seq(convert(node.id()), convert(node.labels())) ++ fields)
        }
        case rel: Relationship => {
          val map = rel.asMap()
          val structType = extractStructType(dataType)
          val fields = structType
            .filter(field => field.name != Neo4jUtil.INTERNAL_REL_ID_FIELD
              && field.name != Neo4jUtil.INTERNAL_REL_TYPE_FIELD
              && field.name != Neo4jUtil.INTERNAL_REL_SOURCE_ID_FIELD
              && field.name != Neo4jUtil.INTERNAL_REL_TARGET_ID_FIELD)
            .map(field => convert(map.get(field.name), field.dataType))
          InternalRow.fromSeq(Seq(convert(rel.id()),
            convert(rel.`type`()),
            convert(rel.startNodeId()),
            convert(rel.endNodeId())) ++ fields)
        }
        case d: IsoDuration => {
          val months = d.months()
          val days = d.days()
          val nanoseconds: Integer = d.nanoseconds()
          val seconds = d.seconds()
          InternalRow.fromSeq(Seq(UTF8String.fromString(SchemaService.DURATION_TYPE), months, days, seconds, nanoseconds, UTF8String.fromString(d.toString)))
        }
        case zt: ZonedDateTime => DateTimeUtils.instantToMicros(zt.toInstant)
        case dt: LocalDateTime => DateTimeUtils.instantToMicros(dt.toInstant(ZoneOffset.UTC))
        case d: LocalDate => d.toEpochDay.toInt
        case lt: LocalTime => {
          InternalRow.fromSeq(Seq(
            UTF8String.fromString(SchemaService.TIME_TYPE_LOCAL),
            UTF8String.fromString(lt.format(DateTimeFormatter.ISO_TIME))
          ))
        }
        case t: OffsetTime => {
          InternalRow.fromSeq(Seq(
            UTF8String.fromString(SchemaService.TIME_TYPE_OFFSET),
            UTF8String.fromString(t.format(DateTimeFormatter.ISO_TIME))
          ))
        }
        case p: InternalPoint2D => {
          val srid: Integer = p.srid()
          InternalRow.fromSeq(Seq(UTF8String.fromString(SchemaService.POINT_TYPE_2D), srid, p.x(), p.y(), null))
        }
        case p: InternalPoint3D => {
          val srid: Integer = p.srid()
          InternalRow.fromSeq(Seq(UTF8String.fromString(SchemaService.POINT_TYPE_3D), srid, p.x(), p.y(), p.z()))
        }
        case l: java.util.List[_] => {
          val elementType = if (dataType != null) dataType.asInstanceOf[ArrayType].elementType else null
          ArrayData.toArrayData(l.asScala.map(e => convert(e, elementType)).toArray)
        }
        case map: java.util.Map[_, _] => {
          if (dataType != null) {
            val mapType = dataType.asInstanceOf[MapType]
            ArrayBasedMapData(map.asScala.map(t => (convert(t._1, mapType.keyType), convert(t._2, mapType.valueType))))
          } else {
            ArrayBasedMapData(map.asScala.map(t => (convert(t._1), convert(t._2))))
          }
        }
        case s: String => UTF8String.fromString(s)
        case _ => value
      }
    }
  }
}
```
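
`Neo4jToSparkDataConverter` goes the other way, producing Catalyst internal representations: nodes and relationships become `InternalRow`s, temporal values collapse to Spark's micros/epoch-day encodings, and strings become `UTF8String`. A small sketch of the primitive cases (values are illustrative):

```scala
import org.apache.spark.unsafe.types.UTF8String
import org.neo4j.spark.converter.Neo4jToSparkDataConverter

val dataConverter = Neo4jToSparkDataConverter()

// Strings come back as Catalyst UTF8String...
assert(dataConverter.convert("Alice") == UTF8String.fromString("Alice"))

// ...and a LocalDate becomes the epoch-day Int that Spark's DateType expects.
assert(dataConverter.convert(java.time.LocalDate.of(2022, 1, 1)) == 18993)
```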
common/src/main/scala/org/neo4j/spark/converter/TypeConverter.scala (129 additions, 0 deletions)
```scala
package org.neo4j.spark.converter

import org.apache.spark.sql.types.{DataType, DataTypes}
import org.neo4j.driver.types.Entity
import org.neo4j.spark.converter.CypherToSparkTypeConverter.{cleanTerms, durationType, pointType, timeType}
import org.neo4j.spark.converter.SparkToCypherTypeConverter.mapping
import org.neo4j.spark.service.SchemaService.normalizedClassName
import org.neo4j.spark.util.Neo4jImplicits.EntityImplicits

import scala.collection.JavaConverters._

trait TypeConverter[SOURCE_TYPE, DESTINATION_TYPE] {

  def convert(sourceType: SOURCE_TYPE, value: Any = null): DESTINATION_TYPE

}

object CypherToSparkTypeConverter {
  def apply(): CypherToSparkTypeConverter = new CypherToSparkTypeConverter()

  private val cleanTerms: String = "Unmodifiable|Internal|Iso|2D|3D|Offset|Local|Zoned"

  val durationType: DataType = DataTypes.createStructType(Array(
    DataTypes.createStructField("type", DataTypes.StringType, false),
    DataTypes.createStructField("months", DataTypes.LongType, false),
    DataTypes.createStructField("days", DataTypes.LongType, false),
    DataTypes.createStructField("seconds", DataTypes.LongType, false),
    DataTypes.createStructField("nanoseconds", DataTypes.IntegerType, false),
    DataTypes.createStructField("value", DataTypes.StringType, false)
  ))

  val pointType: DataType = DataTypes.createStructType(Array(
    DataTypes.createStructField("type", DataTypes.StringType, false),
    DataTypes.createStructField("srid", DataTypes.IntegerType, false),
    DataTypes.createStructField("x", DataTypes.DoubleType, false),
    DataTypes.createStructField("y", DataTypes.DoubleType, false),
    DataTypes.createStructField("z", DataTypes.DoubleType, true)
  ))

  val timeType: DataType = DataTypes.createStructType(Array(
    DataTypes.createStructField("type", DataTypes.StringType, false),
    DataTypes.createStructField("value", DataTypes.StringType, false)
  ))
}
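```

Duration, point, and time values have no native Spark type, so they travel as structs whose leading `type` field tags the concrete Neo4j type (the tag constants live in `SchemaService`). A sketch of what a 2D point looks like under `pointType` (coordinate values are illustrative):

```scala
import org.apache.spark.sql.Row
import org.neo4j.spark.service.SchemaService

// A WGS-84 (srid 4326) point as a Spark row matching pointType;
// the z field stays null because the point is two-dimensional.
val wgs84Point = Row(SchemaService.POINT_TYPE_2D, 4326, 12.49, 41.89, null)
```

```scala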
class CypherToSparkTypeConverter extends TypeConverter[String, DataType] {
  override def convert(sourceType: String, value: Any = null): DataType = sourceType
    .replaceAll(cleanTerms, "") match {
    case "Node" | "Relationship" => if (value != null) value.asInstanceOf[Entity].toStruct else DataTypes.NullType
    case "NodeArray" | "RelationshipArray" => if (value != null) DataTypes.createArrayType(value.asInstanceOf[Entity].toStruct) else DataTypes.NullType
    case "Boolean" => DataTypes.BooleanType
    case "Long" => DataTypes.LongType
    case "Double" => DataTypes.DoubleType
    case "Point" => pointType
    case "DateTime" | "ZonedDateTime" | "LocalDateTime" => DataTypes.TimestampType
    case "Time" => timeType
    case "Date" => DataTypes.DateType
    case "Duration" => durationType
    case "Map" => {
      val valueType = if (value == null) {
        DataTypes.NullType
      } else {
        val map = value.asInstanceOf[java.util.Map[String, AnyRef]].asScala
        val types = map.values
          .map(normalizedClassName)
          .toSet
        if (types.size == 1) convert(types.head, map.values.head) else DataTypes.StringType
      }
      DataTypes.createMapType(DataTypes.StringType, valueType)
    }
    case "Array" => {
      val valueType = if (value == null) {
        DataTypes.NullType
      } else {
        val list = value.asInstanceOf[java.util.List[AnyRef]].asScala
        val types = list
          .map(normalizedClassName)
          .toSet
        if (types.size == 1) convert(types.head, list.head) else DataTypes.StringType
      }
      DataTypes.createArrayType(valueType)
    }
    // These are from APOC
    case "StringArray" => DataTypes.createArrayType(DataTypes.StringType)
    case "LongArray" => DataTypes.createArrayType(DataTypes.LongType)
    case "DoubleArray" => DataTypes.createArrayType(DataTypes.DoubleType)
    case "BooleanArray" => DataTypes.createArrayType(DataTypes.BooleanType)
    case "PointArray" => DataTypes.createArrayType(pointType)
    case "DateTimeArray" => DataTypes.createArrayType(DataTypes.TimestampType)
    case "TimeArray" => DataTypes.createArrayType(timeType)
    case "DateArray" => DataTypes.createArrayType(DataTypes.DateType)
    case "DurationArray" => DataTypes.createArrayType(durationType)
    // Default is String
    case _ => DataTypes.StringType
  }
}
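```

`CypherToSparkTypeConverter` maps the (cleaned) runtime class name of a Cypher value to a Spark `DataType`; for `Map` and `Array` it inspects a sample value and falls back to `StringType` when the contents are heterogeneous. A hedged sketch, assuming `SchemaService.normalizedClassName` reports a `java.lang.Long` as `"Long"`:

```scala
import org.apache.spark.sql.types.DataTypes
import org.neo4j.spark.converter.CypherToSparkTypeConverter

val typeConverter = CypherToSparkTypeConverter()

assert(typeConverter.convert("Long") == DataTypes.LongType)
// "Zoned" is one of the cleaned terms, so ZonedDateTime -> DateTime -> TimestampType.
assert(typeConverter.convert("ZonedDateTime") == DataTypes.TimestampType)

// A homogeneous list infers its element type from a sample value;
// mixed element types would fall back to StringType.
val longs = java.util.Arrays.asList[AnyRef](Long.box(1L), Long.box(2L))
assert(typeConverter.convert("Array", longs) == DataTypes.createArrayType(DataTypes.LongType))
```

```scala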
object SparkToCypherTypeConverter {
  def apply(): SparkToCypherTypeConverter = new SparkToCypherTypeConverter()

  private val mapping: Map[DataType, String] = Map(
    DataTypes.BooleanType -> "BOOLEAN",
    DataTypes.StringType -> "STRING",
    DataTypes.IntegerType -> "INTEGER",
    DataTypes.LongType -> "INTEGER",
    DataTypes.FloatType -> "FLOAT",
    DataTypes.DoubleType -> "FLOAT",
    DataTypes.DateType -> "DATE",
    DataTypes.TimestampType -> "LOCAL DATETIME",
    durationType -> "DURATION",
    pointType -> "POINT",
    // Cypher graph entities do not allow null values in arrays
    DataTypes.createArrayType(DataTypes.BooleanType, false) -> "LIST<BOOLEAN NOT NULL>",
    DataTypes.createArrayType(DataTypes.StringType, false) -> "LIST<STRING NOT NULL>",
    DataTypes.createArrayType(DataTypes.IntegerType, false) -> "LIST<INTEGER NOT NULL>",
    DataTypes.createArrayType(DataTypes.LongType, false) -> "LIST<INTEGER NOT NULL>",
    DataTypes.createArrayType(DataTypes.FloatType, false) -> "LIST<FLOAT NOT NULL>",
    DataTypes.createArrayType(DataTypes.DoubleType, false) -> "LIST<FLOAT NOT NULL>",
    DataTypes.createArrayType(DataTypes.DateType, false) -> "LIST<DATE NOT NULL>",
    DataTypes.createArrayType(DataTypes.TimestampType, false) -> "LIST<LOCAL DATETIME NOT NULL>",
    DataTypes.createArrayType(DataTypes.TimestampType, true) -> "LIST<LOCAL DATETIME NOT NULL>",
    DataTypes.createArrayType(durationType, false) -> "LIST<DURATION NOT NULL>",
    DataTypes.createArrayType(pointType, false) -> "LIST<POINT NOT NULL>"
  )
}

class SparkToCypherTypeConverter extends TypeConverter[DataType, String] {
  override def convert(sourceType: DataType, value: Any): String = mapping(sourceType)
}
```
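
`SparkToCypherTypeConverter` is a plain lookup into `mapping`, yielding the Cypher type names presumably used when creating the property type constraints named in the commit title; a `DataType` with no entry would surface as a `NoSuchElementException`. Usage sketch:

```scala
import org.apache.spark.sql.types.DataTypes
import org.neo4j.spark.converter.SparkToCypherTypeConverter

val cypherTypes = SparkToCypherTypeConverter()

// Both Spark integer widths collapse onto Cypher's INTEGER...
assert(cypherTypes.convert(DataTypes.LongType) == "INTEGER")

// ...and non-null array columns map to LIST<... NOT NULL>, since Cypher
// properties cannot hold null elements inside lists.
assert(cypherTypes.convert(DataTypes.createArrayType(DataTypes.StringType, false)) == "LIST<STRING NOT NULL>")
```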