Issue #515: Add ColumnStats Schema for JSON parsing #522

Open: wants to merge 7 commits into base: main
127 changes: 127 additions & 0 deletions core/src/main/scala/io/qbeast/core/model/QbeastColumnStats.scala
@@ -0,0 +1,127 @@
package io.qbeast.core.model

import io.qbeast.core.transform.CDFNumericQuantilesTransformer
import io.qbeast.core.transform.CDFStringQuantilesTransformer
import io.qbeast.core.transform.LinearTransformer
import io.qbeast.core.transform.StringHistogramTransformer
import io.qbeast.core.transform.Transformer
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.AnalysisExceptionFactory
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession

/**
* Container for Qbeast Column Stats
*
* @param columnStatsSchema
* the column stats schema
* @param columnStatsRow
* the column stats row
*/
case class QbeastColumnStats(columnStatsSchema: StructType, columnStatsRow: Row)

object QbeastColumnStatsBuilder {

/**
* Builds the column stats schema
*
* For each column transformer, it creates a StructField with its stats names
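*
* For example, for a LinearTransformer on a numeric column "a", the stats
* names are typically "a_min" and "a_max", typed with the column's own data
* type, while the quantile-based transformers produce array-typed fields
* (see the match cases below).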
* @param dataSchema
* the data schema
* @param columnTransformers
* the column transformers
* @return
* the column stats schema
*/
def buildColumnStatsSchema(
dataSchema: StructType,
columnTransformers: Seq[Transformer]): StructType = {
val columnStatsSchema = StructType(columnTransformers.flatMap { transformer =>
val transformerStatsNames = transformer.stats.statsNames
val transformerColumnName = transformer.columnName
val sparkDataType = dataSchema.find(_.name == transformerColumnName) match {
case Some(field) => field.dataType
case None =>
throw AnalysisExceptionFactory.create(
s"Column $transformerColumnName not found in the data schema")
}

transformer match {
case LinearTransformer(_, _) =>
transformerStatsNames.map(statName =>
StructField(statName, sparkDataType, nullable = true))
case CDFNumericQuantilesTransformer(_, _) =>
transformerStatsNames.map(statName =>
StructField(statName, ArrayType(DoubleType), nullable = true))
case CDFStringQuantilesTransformer(_) =>
transformerStatsNames.map(statName =>
StructField(statName, ArrayType(StringType), nullable = true))
case StringHistogramTransformer(_, _) =>
transformerStatsNames.map(statName =>
StructField(statName, ArrayType(StringType), nullable = true))
case _ => // TODO: Add support for other transformers
Seq.empty
}
})
columnStatsSchema
}

/**
* Builds the column stats row
*
* @param stats
* the stats in a JSON string
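* (e.g. {"a_min": 0, "a_max": 10} when the stats names are "a_min" and "a_max")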
* @param columnStatsSchema
* the column stats schema
* @return
* the column stats Row parsed from the stats JSON string
*/

def buildColumnStatsRow(stats: String, columnStatsSchema: StructType): Row = {
// If the stats are empty, return an empty row
if (stats.isEmpty) return Row.empty
// Otherwise, parse the stats
val spark = SparkSession.active
import spark.implicits._
val columnStatsJSON = Seq(stats).toDS()
val row = spark.read
.option("inferTimestamp", "true")
.option("timestampFormat", "yyyy-MM-dd HH:mm:ss.SSSSSS'Z'")
.schema(columnStatsSchema)
.json(columnStatsJSON)
.first()
// If the input is non-empty but every parsed value is null,
// the JSON did not match the expected columnStats schema
val areAllStatsNull = row.toSeq.forall(_ == null)
if (areAllStatsNull) {
throw AnalysisExceptionFactory.create(
s"The columnStats provided are not valid JSON: $stats")
}
row
}

/**
* Builds the QbeastColumnStats
*
* @param statsString
* the stats in a JSON string
* @param columnTransformers
* the set of columnTransformers to build the Stats from
* @param dataSchema
* the data schema to build the Stats from
* @return
* the QbeastColumnStats containing the schema and the parsed Row
*/
def build(
statsString: String,
columnTransformers: Seq[Transformer],
dataSchema: StructType): QbeastColumnStats = {
val columnStatsSchema = buildColumnStatsSchema(dataSchema, columnTransformers)
val columnStatsRow = buildColumnStatsRow(statsString, columnStatsSchema)
QbeastColumnStats(columnStatsSchema, columnStatsRow)
}

}
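For orientation, a minimal usage sketch of the builder (not part of the PR): it assumes an active SparkSession, that LinearTransformer takes a column name plus a QDataType such as DoubleDataType, and that its stats names follow the "<column>_min"/"<column>_max" convention.

import io.qbeast.core.model.DoubleDataType
import io.qbeast.core.transform.LinearTransformer
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

// Hypothetical data schema with a single indexed column "price"
val dataSchema = StructType(Seq(StructField("price", DoubleType)))
// Assumed constructor: LinearTransformer(columnName, qDataType)
val transformers = Seq(LinearTransformer("price", DoubleDataType))
// JSON keyed by the transformer's stats names (assumed "price_min"/"price_max")
val statsJson = """{"price_min": 0.0, "price_max": 1000.0}"""

// Requires an active SparkSession (buildColumnStatsRow calls SparkSession.active)
val stats = QbeastColumnStatsBuilder.build(statsJson, transformers, dataSchema)
// stats.columnStatsSchema => StructType(price_min: Double, price_max: Double)
// stats.columnStatsRow    => [0.0, 1000.0]

If the JSON keys do not match the derived schema, every field parses as null and buildColumnStatsRow raises the AnalysisException shown above.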
@@ -17,6 +17,7 @@ package io.qbeast.spark.index

import io.qbeast.core.model.ColumnToIndex
import io.qbeast.core.model.QTableID
import io.qbeast.core.model.QbeastColumnStatsBuilder
import io.qbeast.core.model.QbeastOptions
import io.qbeast.core.model.Revision
import io.qbeast.core.model.RevisionChange
@@ -33,7 +34,6 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.AnalysisExceptionFactory
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession

/**
* Spark implementation of RevisionBuilder
@@ -86,19 +86,31 @@ trait SparkRevisionChangesUtils extends StagingUtils with Logging {
options: QbeastOptions,
data: DataFrame): (Option[RevisionChange], Long) = {
checkColumnChanges(revision, options)
// 1. Compute transformer changes
val transformerChanges =
computeTransformerChanges(revision.columnTransformers, options, data.schema)
// 2. Update transformers if necessary
val updatedTransformers =
computeUpdatedTransformers(revision.columnTransformers, transformerChanges)
val statsRow = getDataFrameStats(data, updatedTransformers)
val numElements = statsRow.getAs[Long]("count")
// 3. Get the stats from the DataFrame
val dataFrameStats = getDataFrameStats(data, updatedTransformers)
val numElements = dataFrameStats.getAs[Long]("count")
// 4. Compute the cube size changes
val cubeSizeChanges = computeCubeSizeChanges(revision, options)
// 5. Compute the Transformation changes given the input data and the user input
val transformationChanges =
computeTransformationChanges(
updatedTransformers,
revision.transformations,
options,
statsRow)
dataFrameStats,
data.schema)
// 6. Return the RevisionChanges.
//
// The revision should change if:
// - the cube size has changed
// - the transformer types have changed
// - the transformations have changed
val hasRevisionChanges =
cubeSizeChanges.isDefined ||
transformerChanges.flatten.nonEmpty ||
@@ -237,12 +249,17 @@ trait SparkRevisionChangesUtils extends StagingUtils with Logging {
transformers: IISeq[Transformer],
transformations: IISeq[Transformation],
options: QbeastOptions,
row: Row): IISeq[Option[Transformation]] = {
// Compute transformations from columnStats and DataFrame stats, and merge them
row: Row,
dataSchema: StructType): IISeq[Option[Transformation]] = {
// Compute transformations from dataFrameStats
val transformationsFromDataFrameStats =
computeTransformationsFromDataFrameStats(transformers, row)

// Compute transformations from columnStats
val transformationsFromColumnsStats =
computeTransformationsFromColumnStats(transformers, options)
computeTransformationsFromColumnStats(transformers, options, dataSchema)

// Merge transformations from DataFrame and columnStats
val newTransformations = transformationsFromDataFrameStats
.zip(transformationsFromColumnsStats)
.map {
@@ -283,15 +300,26 @@ trait SparkRevisionChangesUtils extends StagingUtils with Logging {
*/
private[index] def computeTransformationsFromColumnStats(
transformers: IISeq[Transformer],
options: QbeastOptions): IISeq[Option[Transformation]] = {
val (columnStats, availableColumnStats) = parseColumnStats(options)
options: QbeastOptions,
dataSchema: StructType): IISeq[Option[Transformation]] = {
// 1. Get the columnStats from the options
val columnStatsString = options.columnStats.getOrElse("")
// 2. Build the QbeastColumnStats
val qbeastColumnStats =
QbeastColumnStatsBuilder.build(columnStatsString, transformers, dataSchema)
// 3. Compute transformations from the columnStats
val columnStatsRow = qbeastColumnStats.columnStatsRow
transformers.map { t =>
if (t.stats.statsNames.forall(availableColumnStats.contains)) {
try {
// Create transformation with columnStats
Some(t.makeTransformation(columnStats.getAs[Object]))
} else {
// Ignore the transformation if the stats are not available
None
Some(t.makeTransformation(columnStatsRow.getAs[Object]))
} catch {
case e: Throwable =>
logWarning(
s"Error creating transformation for column ${t.columnName} with columnStats: $columnStatsString",
e)
// Ignore the transformation if the stats are not available
None
}
}
}
@@ -310,22 +338,4 @@ trait SparkRevisionChangesUtils extends StagingUtils with Logging {
transformers.map(_.makeTransformation(row.getAs[Object]))
}

private[index] def parseColumnStats(options: QbeastOptions): (Row, Set[String]) = {
val (row, statsNames) = if (options.columnStats.isDefined) {
val spark = SparkSession.active
import spark.implicits._
val stats = spark.read
.option("inferTimestamp", "true")
.option("timestampFormat", "yyyy-MM-dd HH:mm:ss.SSSSSS'Z'")
.json(Seq(options.columnStats.get).toDS())
.first()
(stats, stats.schema.fieldNames.toSet)
} else (Row.empty, Set.empty[String])
if (statsNames.contains("_corrupt_record")) {
throw AnalysisExceptionFactory.create(
"The columnStats provided is not a valid JSON: " + row.getAs[String]("_corrupt_record"))
}
(row, statsNames)
}

}
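For context, a sketch of the user-facing write path that supplies the JSON this trait parses. It assumes the documented Qbeast writer options columnsToIndex and columnStats, and the "a_min"/"a_max" stats naming for a linearly transformed column.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq(1, 5, 10).toDF("a")
df.write
  .format("qbeast")
  .option("columnsToIndex", "a")
  // JSON parsed by QbeastColumnStatsBuilder.buildColumnStatsRow against the
  // schema derived from the column transformers
  .option("columnStats", """{"a_min": 0, "a_max": 10}""")
  .save("/tmp/qbeast/table")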
8 changes: 8 additions & 0 deletions src/test/scala/io/qbeast/TestClasses.scala
@@ -28,6 +28,14 @@ object TestClasses {
case class TestBigDecimal(a: BigDecimal, b: BigDecimal, c: BigDecimal)
case class TestInt(a: Int, b: Int, c: Int)
case class TestLong(a: Long, b: Long, c: Long)

case class TestAll(
string_value: String,
double_value: Double,
float_value: Float,
int_value: Int,
long_value: Long)

case class TestNull(a: Option[String], b: Option[Double], c: Option[Long])

case class IndexData(id: Long, cube: Array[Byte], weight: Double)