Skip to content

Commit

Permalink
Fixed a silly bug in scalar division
Browse files Browse the repository at this point in the history
  1.0 / v /= v / 1.0
  • Loading branch information
superbobry committed Jun 27, 2016
1 parent c2fd5b3 commit 4d3afdb
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Version 0.3.0
- Added SIMD speedups for / and /= operations.
- Added 'StridedVector.log1p' and 'expm1'.
- Removed 'StridedMatrix2.plus' and 'logAddExp'.
- Fixed a bug in scalar division in expressions of the form '1.0 / v'.

Version 0.2.3
-------------
Expand Down
22 changes: 18 additions & 4 deletions src/main/kotlin/org/jetbrains/bio/viktor/DoubleExtensions.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package org.jetbrains.bio.viktor

import org.jetbrains.bio.viktor.NativeSpeedups.unsafeNegate
import org.jetbrains.bio.viktor.NativeSpeedups.unsafePlusScalar
import org.jetbrains.bio.viktor.NativeSpeedups.unsafeScalarDiv

/**
* Operator overloads for [Double] and [StridedVector].
*
Expand All @@ -9,9 +13,8 @@ package org.jetbrains.bio.viktor
operator fun Double.minus(other: StridedVector): StridedVector {
val v = other.copy()
if (v is LargeDenseVector) {
NativeSpeedups.unsafeNegate(v.data, v.offset, v.data, v.offset, v.size)
NativeSpeedups.unsafePlusScalar(
v.data, v.offset, this, v.data, v.offset, v.size)
unsafeNegate(v.data, v.offset, v.data, v.offset, v.size)
unsafePlusScalar(v.data, v.offset, this, v.data, v.offset, v.size)
} else {
for (pos in 0..v.size - 1) {
v.unsafeSet(pos, this - v.unsafeGet(pos))
Expand All @@ -25,4 +28,15 @@ operator fun Double.plus(other: StridedVector) = other + this

operator fun Double.times(other: StridedVector) = other * this

operator fun Double.div(other: StridedVector) = other / this
operator fun Double.div(other: StridedVector): StridedVector {
val v = other.copy()
if (v is LargeDenseVector) {
unsafeScalarDiv(this, v.data, v.offset, v.data, v.offset, v.size)
} else {
for (pos in 0..v.size - 1) {
v.unsafeSet(pos, this / v.unsafeGet(pos))
}
}

return v
}
3 changes: 3 additions & 0 deletions src/main/kotlin/org/jetbrains/bio/viktor/NativeSpeedups.kt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ object NativeSpeedups {
external fun unsafeDivScalar(src1: DoubleArray, srcOffset1: Int, update: Double,
dst: DoubleArray, dstOffset: Int, length: Int)

external fun unsafeScalarDiv(update: Double, src1: DoubleArray, srcOffset1: Int,
dst: DoubleArray, dstOffset: Int, length: Int)

external fun unsafeMin(values: DoubleArray, offset: Int, length: Int): Double

external fun unsafeMax(values: DoubleArray, offset: Int, length: Int): Double
Expand Down
28 changes: 28 additions & 0 deletions src/simd/cpp/org_jetbrains_bio_viktor_NativeSpeedups.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,16 @@ struct div_scalar {
double update_;
};

struct scalar_div {
scalar_div(double const update) : update_(update) {};

template <typename T>
BOOST_FORCEINLINE T operator()(T const &x) const { return update_ / x; }

private:
double update_;
};

}

JNI_METHOD(void, unsafePlusScalar)(JNIEnv *env, jobject,
Expand Down Expand Up @@ -221,6 +231,24 @@ JNI_METHOD(void, unsafeDivScalar)(JNIEnv *env, jobject,
env->ReleasePrimitiveArrayCritical(jdst, dst, JNI_ABORT);
}

JNI_METHOD(void, unsafeScalarDiv)(JNIEnv *env, jobject,
jdouble update,
jdoubleArray jsrc, jint src_offset,
jdoubleArray jdst, jint dst_offset,
jint length)
{
jdouble *src = reinterpret_cast<jdouble *>(
env->GetPrimitiveArrayCritical(jsrc, NULL));
jdouble *dst = reinterpret_cast<jdouble *>(
env->GetPrimitiveArrayCritical(jdst, NULL));
boost::simd::transform(src + src_offset,
src + src_offset + length,
dst + dst_offset,
scalar_div(update));
env->ReleasePrimitiveArrayCritical(jsrc, src, JNI_ABORT);
env->ReleasePrimitiveArrayCritical(jdst, dst, JNI_ABORT);
}

JNI_METHOD(void, unsafeNegate)(JNIEnv *env, jobject,
jdoubleArray jsrc, jint src_offset,
jdoubleArray jdst, jint dst_offset,
Expand Down
20 changes: 19 additions & 1 deletion src/test/kotlin/org/jetbrains/bio/viktor/DoubleExtensionsTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,27 @@ import org.junit.Test
import kotlin.test.assertEquals

class DoubleExtensionsTest {
@Test fun plus() {
val v = StridedVector(10) { it.toDouble() }
val incremented = StridedVector(10) { it + 1.0 }
assertEquals(incremented, 1.0 + v)
}

@Test fun minus() {
val v = StridedVector(10) { it.toDouble() }
val reversed = v.copy().apply { reverse() }
assertEquals(9.0 - v, reversed)
assertEquals(reversed, 9.0 - v)
}

@Test fun times() {
val v = StridedVector(10) { it.toDouble() }
val scaled = StridedVector(10) { it * 42.0 }
assertEquals(scaled, 42.0 * v)
}

@Test fun div() {
val v = StridedVector(10) { it.toDouble() }
val scaled = StridedVector(10) { 1.0 / it }
assertEquals(scaled, 1.0 / v)
}
}

0 comments on commit 4d3afdb

Please sign in to comment.