From f79db552b9a47e9d05620952f52b886d68731f7f Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 15 Oct 2015 14:22:36 +0300 Subject: [PATCH] Added 'MoreMath' tests --- .../org/jetbrains/bio/viktor/MoreMath.kt | 48 +++++-------------- .../org/jetbrains/bio/viktor/StridedVector.kt | 8 ++-- .../org/jetbrains/bio/viktor/MoreMathTests.kt | 40 ++++++++++++++++ 3 files changed, 55 insertions(+), 41 deletions(-) create mode 100644 src/test/kotlin/org/jetbrains/bio/viktor/MoreMathTests.kt diff --git a/src/main/kotlin/org/jetbrains/bio/viktor/MoreMath.kt b/src/main/kotlin/org/jetbrains/bio/viktor/MoreMath.kt index 5803f1c..6324b58 100644 --- a/src/main/kotlin/org/jetbrains/bio/viktor/MoreMath.kt +++ b/src/main/kotlin/org/jetbrains/bio/viktor/MoreMath.kt @@ -3,30 +3,18 @@ package org.jetbrains.bio.viktor import org.apache.commons.math3.util.FastMath /** - * Useful mathematical routines absent in [java.util.Math] - * and [org.apache.commons.math3.util.FastMath]. + * Evaluates log(exp(a) + exp(b)) using the following trick * - * When adding new functionality please consider reading - * http://blog.juma.me.uk/2011/02/23/performance-of-fastmath-from-commons-math. + * log(exp(a) + log(exp(b)) = a + log(1 + exp(b - a)) * - * @author Alexey Dievsky - * @author Sergei Lebedev - * @since 0.1.0 + * assuming a >= b. */ -object MoreMath { - /** - * Evaluates log(exp(a) + exp(b)) using the following trick - * - * log(exp(a) + log(exp(b)) = a + log(1 + exp(b - a)) - * - * assuming a >= b. - */ - @JvmStatic fun logAddExp(a: Double, b: Double): Double { - return when { - a.isInfinite() && a < 0 -> b - b.isInfinite() && b < 0 -> a - else -> Math.max(a, b) + StrictMath.log1p(FastMath.exp(-Math.abs(a - b))) - } +fun Double.logAddExp(b: Double): Double { + val a = this + return when { + a.isInfinite() && a < 0 -> b + b.isInfinite() && b < 0 -> a + else -> Math.max(a, b) + StrictMath.log1p(FastMath.exp(-Math.abs(a - b))) } } @@ -42,7 +30,7 @@ object MoreMath { * @author Alexey Dievsky * @since 0.1.0 */ -class KahanSum private constructor(private var accumulator: Double) { +class KahanSum @JvmOverloads constructor(private var accumulator: Double = 0.0) { private var compensator = 0.0 /** @@ -68,19 +56,5 @@ class KahanSum private constructor(private var accumulator: Double) { /** * Returns the sum accumulated so far. */ - fun result(): Double = accumulator + compensator - - companion object { - /** - * Creates and returns a zero-initiated accumulator which can be - * fed doubles and polled for the accumulated sum. - */ - @JvmStatic fun create(): KahanSum = create(0.0) - - /** - * Creates and returns an accumulator which can be fed - * doubles and polled for the accumulated sum. - */ - @JvmStatic fun create(initial: Double): KahanSum = KahanSum(initial) - } + fun result() = accumulator + compensator } diff --git a/src/main/kotlin/org/jetbrains/bio/viktor/StridedVector.kt b/src/main/kotlin/org/jetbrains/bio/viktor/StridedVector.kt index 9a77eaa..3c1b765 100644 --- a/src/main/kotlin/org/jetbrains/bio/viktor/StridedVector.kt +++ b/src/main/kotlin/org/jetbrains/bio/viktor/StridedVector.kt @@ -219,7 +219,7 @@ open class StridedVector(protected val data: DoubleArray, } open fun sum(): Double { - val acc = KahanSum.create() + val acc = KahanSum() for (pos in 0..size - 1) { acc += unsafeGet(pos) } @@ -228,7 +228,7 @@ open class StridedVector(protected val data: DoubleArray, } open fun cumSum() { - val acc = KahanSum.create() + val acc = KahanSum() for (pos in 0..size - 1) { acc += unsafeGet(pos) unsafeSet(pos, acc.result()) @@ -307,7 +307,7 @@ open class StridedVector(protected val data: DoubleArray, open fun logSumExp(): Double { val offset = max() - val sum = KahanSum.create() + val sum = KahanSum() for (pos in 0..size - 1) { sum += FastMath.exp(unsafeGet(pos) - offset) } @@ -325,7 +325,7 @@ open class StridedVector(protected val data: DoubleArray, checkSize(other) checkSize(dst) for (pos in 0..size - 1) { - dst.unsafeSet(pos, MoreMath.logAddExp(unsafeGet(pos), other.unsafeGet(pos))) + dst.unsafeSet(pos, unsafeGet(pos) logAddExp other.unsafeGet(pos)) } } diff --git a/src/test/kotlin/org/jetbrains/bio/viktor/MoreMathTests.kt b/src/test/kotlin/org/jetbrains/bio/viktor/MoreMathTests.kt new file mode 100644 index 0000000..d928c45 --- /dev/null +++ b/src/test/kotlin/org/jetbrains/bio/viktor/MoreMathTests.kt @@ -0,0 +1,40 @@ +package org.jetbrains.bio.viktor + +import org.junit.Test +import java.util.* +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class MoreMathTest { + @Test fun testLogAddExpEdgeCases() { + val r = Random() + val logx = -Math.abs(r.nextDouble()) + assertEquals(logx, Double.NEGATIVE_INFINITY logAddExp logx) + assertEquals(logx, logx logAddExp Double.NEGATIVE_INFINITY) + assertEquals(Double.NEGATIVE_INFINITY, + Double.NEGATIVE_INFINITY logAddExp Double.NEGATIVE_INFINITY) + } +} + +class KahanSumTest { + @Test fun testPrecision() { + val bigNumber = 10000000 + for (d in 9..15) { + // note that in each case 1/d is not precisely representable as a double, + // which is bound to lead to accumulating rounding errors. + val oneDth = 1.0 / d + val preciseSum = KahanSum() + var impreciseSum = 0.0 + for (i in 0..bigNumber * d - 1) { + preciseSum += oneDth + impreciseSum += oneDth + } + + val imprecision = Math.abs(impreciseSum - bigNumber) + val precision = Math.abs(preciseSum.result() - bigNumber) + assertTrue(imprecision >= precision, + "Kahan's algorithm yielded worse precision than ordinary summation: " + + "$precision is greater than $imprecision") + } + } +} \ No newline at end of file