Skip to content

Commit

Permalink
Issue #506: Add relative position calculation in CDFNumericQuantilesT…
Browse files Browse the repository at this point in the history
…ransformation (#508)
  • Loading branch information
osopardo1 authored Dec 16, 2024
1 parent adf8b39 commit 66b8d79
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import com.fasterxml.jackson.databind.SerializerProvider
import io.qbeast.core.model.OrderedDataType
import org.apache.spark.annotation.Experimental

import scala.collection.Searching._

/**
* CDF Quantiles Transformation for Numeric types
* @param quantiles
Expand All @@ -47,10 +49,10 @@ case class CDFNumericQuantilesTransformation(
extends CDFQuantilesTransformation {
require(quantiles.size > 1, "Quantiles size should be greater than 1")

override def ordering: Ordering[Any] =
override implicit def ordering: Ordering[Any] =
Ordering[Double].asInstanceOf[Ordering[Any]]

override def mapValue(value: Any): Any = {
override def mapValue(value: Any): Double = {
value match {
case v: Double => v
case v: Long => v.toDouble
Expand All @@ -63,6 +65,39 @@ case class CDFNumericQuantilesTransformation(
}
}

/**
 * Computes the relative position of a value within the quantiles, as a Double.
 *
 * The lookup relies on a sorted search, so `quantiles` is assumed to be in ascending order.
 *
 * @param value
 *   the value to transform; it is converted with `mapValue` before the lookup
 * @return
 *   0 for a null value or a value below the first quantile, 1 for a value above the last
 *   quantile, the normalized index when the value matches a quantile, and otherwise the
 *   linearly interpolated position inside the surrounding bin
 */
override def transform(value: Any): Double = value match {
  // Null values are placed at the very beginning of the range
  case null => 0d

  case nonNull =>
    val target = mapValue(nonNull)
    // Index of the last quantile; dividing by it normalizes an index to [0, 1]
    val lastIndex = quantiles.length - 1

    quantiles.search(target) match {
      // Exact hit on a quantile: its normalized index is the answer
      case Found(exactIndex) => exactIndex.toDouble / lastIndex

      // Smaller than every quantile
      case InsertionPoint(0) => 0d

      // Larger than every quantile
      case InsertionPoint(point) if point == quantiles.length => 1d

      // Strictly inside a bin: the insertion point is the index of the first quantile
      // greater than the target, so the bin starts one position earlier
      case InsertionPoint(point) =>
        val binStart = point - 1
        val lower = quantiles(binStart)
        val upper = quantiles(point)
        // How far the target sits between the two bin boundaries
        val offsetInBin = (target - lower) / (upper - lower)
        // Shift by the bin index and normalize to [0, 1]
        (binStart.toDouble + offsetInBin) / lastIndex
    }
}

}

class CDFNumericQuantilesTransformationSerializer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ case class CDFStringQuantilesTransformation(quantiles: IndexedSeq[String])

override val dataType: QDataType = StringDataType

override def ordering: Ordering[Any] =
override implicit def ordering: Ordering[Any] =
Ordering[String].asInstanceOf[Ordering[Any]]

override def mapValue(value: Any): Any = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package io.qbeast.core.transform

import io.qbeast.core.model.DoubleDataType
import io.qbeast.core.model.IntegerDataType
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
Expand All @@ -24,9 +23,31 @@ class CDFNumericQuantilesTransformationTest extends AnyFlatSpec with Matchers {
qt.transform(4) should be(1.0)
}

it should "return correct transformation for insertion point in middle" in {
val qt = CDFNumericQuantilesTransformation(IndexedSeq(1.0, 2.0, 3.0), DoubleDataType)
qt.transform(1.5) should be(0.0)
it should "return correct transformation for insertion point in the bin" in {
  // 2 falls between the quantiles 1 (index 0) and 3 (index 1):
  //   fraction = (2 - 1) / (3 - 1) = 0.5
  //   position = (0 + 0.5) / 2     = 0.25, where 2 is the last quantile index
  val transformation =
    CDFNumericQuantilesTransformation(IndexedSeq(1, 3, 5), orderedDataTypeTest)

  transformation.transform(2) shouldBe 0.25
}

it should "return correct transformation for insertion point in the bin with repeated values" in {
  // 2 falls between the second 1 (index 1) and 3 (index 2):
  //   fraction = (2 - 1) / (3 - 1) = 0.5
  //   position = (1 + 0.5) / 3     = 0.5, where 3 is the last quantile index
  val transformation =
    CDFNumericQuantilesTransformation(IndexedSeq(1, 1, 3, 5), orderedDataTypeTest)

  transformation.transform(2) shouldBe 0.5
}

it should "return correct transformation point for all values inside the bin" in {
  // Six copies of 1 followed by three copies of 100: the only bin with distinct
  // boundaries lies between index 5 and index 6
  val quantiles = IndexedSeq(1, 1, 1, 1, 1, 1, 100, 100, 100).map(_.toDouble)
  val transformation = CDFNumericQuantilesTransformation(quantiles, orderedDataTypeTest)

  val lastIndex = (quantiles.size - 1).toDouble
  val lowerBound = 5 / lastIndex
  val upperBound = 6 / lastIndex

  // Every value strictly between 1 and 100 must land inside that bin
  val positions = for (candidate <- 2 to 99) yield {
    val position = transformation.transform(candidate)
    position should be >= lowerBound
    position should be <= upperBound
    position
  }

  // Increasing inputs must produce non-decreasing positions
  positions.sorted shouldBe positions
}

it should "return true when quantiles are different and neither is default" in {
Expand Down

0 comments on commit 66b8d79

Please sign in to comment.