Skip to content

Commit

Permalink
Implemented and tested scalar-matrix operations
Browse files Browse the repository at this point in the history
  • Loading branch information
superbobry committed Jun 27, 2016
1 parent 6c0c5b0 commit 6fe13a2
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 24 deletions.
7 changes: 4 additions & 3 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ Version 0.3.0
- Added 'StridedVector.log1p' and 'expm1'.
- Fixed a bug in scalar division in expressions of the form '1.0 / v'.
- Mirrored 'StridedVector' operations in 'StridedMatrix2' and
'StridedMatrix3'. Scalar-matrix, e.g. '1.0 / m' operations are to be
added in future releases.
'StridedMatrix3'.
- Extended operator overloads for 'Double' to 'StridedMatrix2' and
'StidedMatrix3'.

Version 0.2.3
-------------
Expand All @@ -39,7 +40,7 @@ Released on April 29th 2016

- Added unary operator overloads for 'StridedVector'.
- Implemented * and / operations for 'StridedVector'.
- Added extra operator overloads to 'Double', so it is now possible to
- Added extra operator overloads for 'Double', so it is now possible to
write (1.0 + v / 2.0).
- Fixed 'StridedVector.toString' in case of 'NaN' and infinities.
- Changed 'StridedVector.toString' to be more like NumPy for larger vectors.
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ Kotlin. Here're some of the highlights:
m[0] = StridedVector.full(3, 42.0) // row-view.
m[_I, 0] // column-view.
m[0] = 42.0 // broadcasting.
m[0] + m[0] // arithmetic operations.
m + 0.5 * m // arithmetic operations.
m[0].exp() + 1.0 // math functions.
```

[ndarray]: http://docs.scipy.org/doc/numpy/reference/arrays.ndarray.html
Expand Down
57 changes: 42 additions & 15 deletions src/main/kotlin/org/jetbrains/bio/viktor/DoubleExtensions.kt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
@file:Suppress("nothing_to_inline")

package org.jetbrains.bio.viktor

import org.jetbrains.bio.viktor.NativeSpeedups.unsafeNegate
Expand All @@ -10,33 +12,58 @@ import org.jetbrains.bio.viktor.NativeSpeedups.unsafeScalarDiv
* @since 0.2.2
*/

operator fun Double.minus(other: StridedVector): StridedVector {
val v = other.copy()
if (v is LargeDenseVector) {
unsafeNegate(v.data, v.offset, v.data, v.offset, v.size)
unsafePlusScalar(v.data, v.offset, this, v.data, v.offset, v.size)
private inline fun Double.minusInPlace(other: StridedVector) {
if (other is LargeDenseVector) {
unsafeNegate(other.data, other.offset,
other.data, other.offset, other.size)
unsafePlusScalar(other.data, other.offset, this,
other.data, other.offset, other.size)
} else {
for (pos in 0..v.size - 1) {
v.unsafeSet(pos, this - v.unsafeGet(pos))
for (pos in 0..other.size - 1) {
other.unsafeSet(pos, this - other.unsafeGet(pos))
}
}
}

operator fun Double.minus(other: StridedVector): StridedVector {
val v = other.copy()
minusInPlace(v)
return v
}

operator fun Double.plus(other: StridedVector) = other + this
operator fun <T : FlatMatrixOps<T>> Double.minus(other: T): T {
val m = other.copy()
minusInPlace(m.flatten())
return m
}

operator fun Double.times(other: StridedVector) = other * this
inline operator fun Double.plus(other: StridedVector) = other + this

operator fun Double.div(other: StridedVector): StridedVector {
val v = other.copy()
if (v is LargeDenseVector) {
unsafeScalarDiv(this, v.data, v.offset, v.data, v.offset, v.size)
inline operator fun <T : FlatMatrixOps<T>> Double.plus(other: T) = other + this

inline operator fun Double.times(other: StridedVector) = other * this

inline operator fun <T : FlatMatrixOps<T>> Double.times(other: T) = other * this

private inline fun Double.divInPlace(other: StridedVector) {
if (other is LargeDenseVector) {
unsafeScalarDiv(this, other.data, other.offset,
other.data, other.offset, other.size)
} else {
for (pos in 0..v.size - 1) {
v.unsafeSet(pos, this / v.unsafeGet(pos))
for (pos in 0..other.size - 1) {
other.unsafeSet(pos, this / other.unsafeGet(pos))
}
}
}

operator fun Double.div(other: StridedVector): StridedVector {
val v = other.copy()
divInPlace(v)
return v
}

operator fun <T : FlatMatrixOps<T>> Double.div(other: T): T {
val m = other.copy()
divInPlace(m.flatten())
return m
}
4 changes: 3 additions & 1 deletion src/main/kotlin/org/jetbrains/bio/viktor/StridedMatrix.kt
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ object StridedMatrix {
}

/** A common interface for whole-matrix operations. */
internal interface FlatMatrixOps<T : FlatMatrixOps<T>> {
interface FlatMatrixOps<T : FlatMatrixOps<T>> {
/**
* Returns a flat view of this matrix.
*
Expand Down Expand Up @@ -123,6 +123,8 @@ internal interface FlatMatrixOps<T : FlatMatrixOps<T>> {

operator fun unaryMinus() = copy().apply {
val v = flatten()

// XXX this might be slower for small matrices.
NativeSpeedups.unsafeNegate(v.data, v.offset, v.data, v.offset, v.size)
}

Expand Down
32 changes: 28 additions & 4 deletions src/test/kotlin/org/jetbrains/bio/viktor/DoubleExtensionsTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,51 @@ import org.junit.Test
import kotlin.test.assertEquals

class DoubleExtensionsTest {
@Test fun plus() {
@Test fun plusVector() {
val v = StridedVector(10) { it.toDouble() }
val incremented = StridedVector(10) { it + 1.0 }
assertEquals(incremented, 1.0 + v)
}

@Test fun minus() {
@Test fun plusMatrix() {
val m = StridedMatrix(10, 2) { i, j -> i + 2.0 * j }
val incremented = StridedMatrix(10, 2) { i, j -> m[i, j] + 1.0 }
assertEquals(incremented, 1.0 + m)
}

@Test fun minusVector() {
val v = StridedVector(10) { it.toDouble() }
val reversed = v.copy().apply { reverse() }
assertEquals(reversed, 9.0 - v)
}

@Test fun times() {
@Test fun minusMatrix() {
val m = StridedMatrix(10, 2) { i, j -> i + 2.0 * j }
val decremented = StridedMatrix(10, 2) { i, j -> 42.0 - m[i, j] }
assertEquals(decremented, 42.0 - m)
}

@Test fun timesVector() {
val v = StridedVector(10) { it.toDouble() }
val scaled = StridedVector(10) { it * 42.0 }
assertEquals(scaled, 42.0 * v)
}

@Test fun div() {
@Test fun timesMatrix() {
val m = StridedMatrix(10, 2) { i, j -> i + 2.0 * j }
val decremented = StridedMatrix(10, 2) { i, j -> 42.0 * m[i, j] }
assertEquals(decremented, 42.0 * m)
}

@Test fun divVector() {
val v = StridedVector(10) { it.toDouble() }
val scaled = StridedVector(10) { 1.0 / it }
assertEquals(scaled, 1.0 / v)
}

@Test fun divMatrix() {
val m = StridedMatrix(10, 2) { i, j -> i + 2.0 * j }
val decremented = StridedMatrix(10, 2) { i, j -> 42.0 / m[i, j] }
assertEquals(decremented, 42.0 / m)
}
}

0 comments on commit 6fe13a2

Please sign in to comment.