From 6fe13a21952cde22030c0640133b38c87259d2b2 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 27 Jun 2016 14:51:08 +0300 Subject: [PATCH] Implemented and tested scalar-matrix operations --- CHANGES | 7 ++- README.md | 3 +- .../jetbrains/bio/viktor/DoubleExtensions.kt | 57 ++++++++++++++----- .../org/jetbrains/bio/viktor/StridedMatrix.kt | 4 +- .../bio/viktor/DoubleExtensionsTest.kt | 32 +++++++++-- 5 files changed, 79 insertions(+), 24 deletions(-) diff --git a/CHANGES b/CHANGES index 5a986a1..c0becc5 100644 --- a/CHANGES +++ b/CHANGES @@ -15,8 +15,9 @@ Version 0.3.0 - Added 'StridedVector.log1p' and 'expm1'. - Fixed a bug in scalar division in expressions of the form '1.0 / v'. - Mirrored 'StridedVector' operations in 'StridedMatrix2' and - 'StridedMatrix3'. Scalar-matrix, e.g. '1.0 / m' operations are to be - added in future releases. + 'StridedMatrix3'. +- Extended operator overloads for 'Double' to 'StridedMatrix2' and + 'StidedMatrix3'. Version 0.2.3 ------------- @@ -39,7 +40,7 @@ Released on April 29th 2016 - Added unary operator overloads for 'StridedVector'. - Implemented * and / operations for 'StridedVector'. -- Added extra operator overloads to 'Double', so it is now possible to +- Added extra operator overloads for 'Double', so it is now possible to write (1.0 + v / 2.0). - Fixed 'StridedVector.toString' in case of 'NaN' and infinities. - Changed 'StridedVector.toString' to be more like NumPy for larger vectors. diff --git a/README.md b/README.md index 8a4d67d..e6ec4f6 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,8 @@ Kotlin. Here're some of the highlights: m[0] = StridedVector.full(3, 42.0) // row-view. m[_I, 0] // column-view. m[0] = 42.0 // broadcasting. - m[0] + m[0] // arithmetic operations. + m + 0.5 * m // arithmetic operations. + m[0].exp() + 1.0 // math functions. ``` [ndarray]: http://docs.scipy.org/doc/numpy/reference/arrays.ndarray.html diff --git a/src/main/kotlin/org/jetbrains/bio/viktor/DoubleExtensions.kt b/src/main/kotlin/org/jetbrains/bio/viktor/DoubleExtensions.kt index 14d2768..1a068cb 100644 --- a/src/main/kotlin/org/jetbrains/bio/viktor/DoubleExtensions.kt +++ b/src/main/kotlin/org/jetbrains/bio/viktor/DoubleExtensions.kt @@ -1,3 +1,5 @@ +@file:Suppress("nothing_to_inline") + package org.jetbrains.bio.viktor import org.jetbrains.bio.viktor.NativeSpeedups.unsafeNegate @@ -10,33 +12,58 @@ import org.jetbrains.bio.viktor.NativeSpeedups.unsafeScalarDiv * @since 0.2.2 */ -operator fun Double.minus(other: StridedVector): StridedVector { - val v = other.copy() - if (v is LargeDenseVector) { - unsafeNegate(v.data, v.offset, v.data, v.offset, v.size) - unsafePlusScalar(v.data, v.offset, this, v.data, v.offset, v.size) +private inline fun Double.minusInPlace(other: StridedVector) { + if (other is LargeDenseVector) { + unsafeNegate(other.data, other.offset, + other.data, other.offset, other.size) + unsafePlusScalar(other.data, other.offset, this, + other.data, other.offset, other.size) } else { - for (pos in 0..v.size - 1) { - v.unsafeSet(pos, this - v.unsafeGet(pos)) + for (pos in 0..other.size - 1) { + other.unsafeSet(pos, this - other.unsafeGet(pos)) } } +} +operator fun Double.minus(other: StridedVector): StridedVector { + val v = other.copy() + minusInPlace(v) return v } -operator fun Double.plus(other: StridedVector) = other + this +operator fun > Double.minus(other: T): T { + val m = other.copy() + minusInPlace(m.flatten()) + return m +} -operator fun Double.times(other: StridedVector) = other * this +inline operator fun Double.plus(other: StridedVector) = other + this -operator fun Double.div(other: StridedVector): StridedVector { - val v = other.copy() - if (v is LargeDenseVector) { - unsafeScalarDiv(this, v.data, v.offset, v.data, v.offset, v.size) +inline operator fun > Double.plus(other: T) = other + this + +inline operator fun Double.times(other: StridedVector) = other * this + +inline operator fun > Double.times(other: T) = other * this + +private inline fun Double.divInPlace(other: StridedVector) { + if (other is LargeDenseVector) { + unsafeScalarDiv(this, other.data, other.offset, + other.data, other.offset, other.size) } else { - for (pos in 0..v.size - 1) { - v.unsafeSet(pos, this / v.unsafeGet(pos)) + for (pos in 0..other.size - 1) { + other.unsafeSet(pos, this / other.unsafeGet(pos)) } } +} +operator fun Double.div(other: StridedVector): StridedVector { + val v = other.copy() + divInPlace(v) return v } + +operator fun > Double.div(other: T): T { + val m = other.copy() + divInPlace(m.flatten()) + return m +} \ No newline at end of file diff --git a/src/main/kotlin/org/jetbrains/bio/viktor/StridedMatrix.kt b/src/main/kotlin/org/jetbrains/bio/viktor/StridedMatrix.kt index 96f9a0d..fadeffb 100644 --- a/src/main/kotlin/org/jetbrains/bio/viktor/StridedMatrix.kt +++ b/src/main/kotlin/org/jetbrains/bio/viktor/StridedMatrix.kt @@ -63,7 +63,7 @@ object StridedMatrix { } /** A common interface for whole-matrix operations. */ -internal interface FlatMatrixOps> { +interface FlatMatrixOps> { /** * Returns a flat view of this matrix. * @@ -123,6 +123,8 @@ internal interface FlatMatrixOps> { operator fun unaryMinus() = copy().apply { val v = flatten() + + // XXX this might be slower for small matrices. NativeSpeedups.unsafeNegate(v.data, v.offset, v.data, v.offset, v.size) } diff --git a/src/test/kotlin/org/jetbrains/bio/viktor/DoubleExtensionsTest.kt b/src/test/kotlin/org/jetbrains/bio/viktor/DoubleExtensionsTest.kt index 376cc96..8b4937f 100644 --- a/src/test/kotlin/org/jetbrains/bio/viktor/DoubleExtensionsTest.kt +++ b/src/test/kotlin/org/jetbrains/bio/viktor/DoubleExtensionsTest.kt @@ -4,27 +4,51 @@ import org.junit.Test import kotlin.test.assertEquals class DoubleExtensionsTest { - @Test fun plus() { + @Test fun plusVector() { val v = StridedVector(10) { it.toDouble() } val incremented = StridedVector(10) { it + 1.0 } assertEquals(incremented, 1.0 + v) } - @Test fun minus() { + @Test fun plusMatrix() { + val m = StridedMatrix(10, 2) { i, j -> i + 2.0 * j } + val incremented = StridedMatrix(10, 2) { i, j -> m[i, j] + 1.0 } + assertEquals(incremented, 1.0 + m) + } + + @Test fun minusVector() { val v = StridedVector(10) { it.toDouble() } val reversed = v.copy().apply { reverse() } assertEquals(reversed, 9.0 - v) } - @Test fun times() { + @Test fun minusMatrix() { + val m = StridedMatrix(10, 2) { i, j -> i + 2.0 * j } + val decremented = StridedMatrix(10, 2) { i, j -> 42.0 - m[i, j] } + assertEquals(decremented, 42.0 - m) + } + + @Test fun timesVector() { val v = StridedVector(10) { it.toDouble() } val scaled = StridedVector(10) { it * 42.0 } assertEquals(scaled, 42.0 * v) } - @Test fun div() { + @Test fun timesMatrix() { + val m = StridedMatrix(10, 2) { i, j -> i + 2.0 * j } + val decremented = StridedMatrix(10, 2) { i, j -> 42.0 * m[i, j] } + assertEquals(decremented, 42.0 * m) + } + + @Test fun divVector() { val v = StridedVector(10) { it.toDouble() } val scaled = StridedVector(10) { 1.0 / it } assertEquals(scaled, 1.0 / v) } + + @Test fun divMatrix() { + val m = StridedMatrix(10, 2) { i, j -> i + 2.0 * j } + val decremented = StridedMatrix(10, 2) { i, j -> 42.0 / m[i, j] } + assertEquals(decremented, 42.0 / m) + } }