diff --git a/core/src/main/scala/io/qbeast/core/transform/CDFNumericQuantilesTransformation.scala b/core/src/main/scala/io/qbeast/core/transform/CDFNumericQuantilesTransformation.scala index 0b17403fc..089c23d0c 100644 --- a/core/src/main/scala/io/qbeast/core/transform/CDFNumericQuantilesTransformation.scala +++ b/core/src/main/scala/io/qbeast/core/transform/CDFNumericQuantilesTransformation.scala @@ -31,6 +31,8 @@ import com.fasterxml.jackson.databind.SerializerProvider import io.qbeast.core.model.OrderedDataType import org.apache.spark.annotation.Experimental +import scala.collection.Searching._ + /** * CDF Quantiles Transformation for Numeric types * @param quantiles @@ -47,10 +49,10 @@ case class CDFNumericQuantilesTransformation( extends CDFQuantilesTransformation { require(quantiles.size > 1, "Quantiles size should be greater than 1") - override def ordering: Ordering[Any] = + override implicit def ordering: Ordering[Any] = Ordering[Double].asInstanceOf[Ordering[Any]] - override def mapValue(value: Any): Any = { + override def mapValue(value: Any): Double = { value match { case v: Double => v case v: Long => v.toDouble @@ -63,6 +65,39 @@ case class CDFNumericQuantilesTransformation( } } + override def transform(value: Any): Double = { + // If the value is null, we return 0 + if (value == null) return 0d + + val currentValue = mapValue(value) + + // Otherwise, we search for the value in the quantiles + quantiles.search(currentValue) match { + // If an exact match is found, normalize its index to the range [0, 1] + case Found(foundIndex) => foundIndex.toDouble / (quantiles.length - 1) + + // If not found, calculate the interpolated relative position + case InsertionPoint(insertionPoint) => + if (insertionPoint == 0) 0d // Value is below the first quantile + else if (insertionPoint == quantiles.length) 1d // Value is above the last quantile + else { + // InsertionPoint gives the index of the first element in quantiles greater than currentValue. 
+ // Thus, the lowerIndex can safely be derived as insertionPoint - 1. + val lowerIndex = insertionPoint - 1 + val upperIndex = insertionPoint + val lowerValue = quantiles(lowerIndex) + val upperValue = quantiles(upperIndex) + + // Linear interpolation within the bin + // 1. Calculate the relative position of currentValue between the two surrounding quantiles + val fraction = (currentValue - lowerValue) / (upperValue - lowerValue) + // 2. Offset by the index of the lower quantile and normalize to the range [0, 1] + val result = (lowerIndex.toDouble + fraction) / (quantiles.length - 1) + result + } + } + } + } class CDFNumericQuantilesTransformationSerializer diff --git a/core/src/main/scala/io/qbeast/core/transform/CDFStringQuantilesTransformation.scala b/core/src/main/scala/io/qbeast/core/transform/CDFStringQuantilesTransformation.scala index e0f4dd16c..a7163ad14 100644 --- a/core/src/main/scala/io/qbeast/core/transform/CDFStringQuantilesTransformation.scala +++ b/core/src/main/scala/io/qbeast/core/transform/CDFStringQuantilesTransformation.scala @@ -45,7 +45,7 @@ case class CDFStringQuantilesTransformation(quantiles: IndexedSeq[String]) override val dataType: QDataType = StringDataType - override def ordering: Ordering[Any] = + override implicit def ordering: Ordering[Any] = Ordering[String].asInstanceOf[Ordering[Any]] override def mapValue(value: Any): Any = { diff --git a/core/src/test/scala/io/qbeast/core/transform/CDFNumericQuantilesTransformationTest.scala b/core/src/test/scala/io/qbeast/core/transform/CDFNumericQuantilesTransformationTest.scala index f520fa0d3..88a7792f2 100644 --- a/core/src/test/scala/io/qbeast/core/transform/CDFNumericQuantilesTransformationTest.scala +++ b/core/src/test/scala/io/qbeast/core/transform/CDFNumericQuantilesTransformationTest.scala @@ -1,6 +1,5 @@ package io.qbeast.core.transform -import io.qbeast.core.model.DoubleDataType import io.qbeast.core.model.IntegerDataType import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -24,9 +23,31 @@ class 
CDFNumericQuantilesTransformationTest extends AnyFlatSpec with Matchers { qt.transform(4) should be(1.0) } - it should "return correct transformation for insertion point in middle" in { - val qt = CDFNumericQuantilesTransformation(IndexedSeq(1.0, 2.0, 3.0), DoubleDataType) - qt.transform(1.5) should be(0.0) + it should "return correct transformation for insertion point in the bin" in { + val qt = CDFNumericQuantilesTransformation(IndexedSeq(1, 3, 5), orderedDataTypeTest) + // 2 is between 1 and 3, so it should be 0.25: fraction = (2 - 1) / (3 - 1) = 0.5, then (0 + 0.5) / 2 = 0.25 + qt.transform(2) should be(0.25) + } + + it should "return correct transformation for insertion point in the bin with repeated values" in { + val qt = CDFNumericQuantilesTransformation(IndexedSeq(1, 1, 3, 5), orderedDataTypeTest) + // 2 is between 1 (index 1) and 3 (index 2), so it should be 0.5: fraction = 0.5, then (1 + 0.5) / 3 = 0.5 + qt.transform(2) should be(0.5) + } + + it should "return correct transformation point for all values inside the bin" in { + val quantiles = IndexedSeq(1, 1, 1, 1, 1, 1, 100, 100, 100).map(_.toDouble) + val qt = + CDFNumericQuantilesTransformation(quantiles, orderedDataTypeTest) + val valuesToTest = 2.to(99) + val maxIndexQuantiles = quantiles.size - 1 + val results = valuesToTest.map { value => + val transformation = qt.transform(value) + transformation shouldBe >=(5 / maxIndexQuantiles.toDouble) + transformation shouldBe <=(6 / maxIndexQuantiles.toDouble) + transformation + } + results.sorted should be(results) } it should "return true when quantiles are different and neither is default" in {