From 6c0c5b07cc31377da784220da0a3034fa76dfc8f Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 27 Jun 2016 13:40:38 +0300 Subject: [PATCH] Mirrored 'StridedVector' API for matrices --- CHANGES | 4 +- .../org/jetbrains/bio/viktor/StridedMatrix.kt | 74 +++++++- .../jetbrains/bio/viktor/StridedMatrix2.kt | 43 +++-- .../jetbrains/bio/viktor/StridedMatrix3.kt | 32 ++-- .../org/jetbrains/bio/viktor/StridedVector.kt | 9 +- ...rg_jetbrains_bio_viktor_NativeSpeedups.cpp | 6 +- .../bio/viktor/StridedMatrix2Tests.kt | 146 +++++++++++++++ .../bio/viktor/StridedMatrix3Tests.kt | 137 ++++++++++++++ .../bio/viktor/StridedMatrixTests.kt | 167 ------------------ .../bio/viktor/StridedVectorTests.kt | 24 --- 10 files changed, 407 insertions(+), 235 deletions(-) create mode 100644 src/test/kotlin/org/jetbrains/bio/viktor/StridedMatrix2Tests.kt create mode 100644 src/test/kotlin/org/jetbrains/bio/viktor/StridedMatrix3Tests.kt delete mode 100644 src/test/kotlin/org/jetbrains/bio/viktor/StridedMatrixTests.kt diff --git a/CHANGES b/CHANGES index e54fe2f..5a986a1 100644 --- a/CHANGES +++ b/CHANGES @@ -13,8 +13,10 @@ Version 0.3.0 - Added 'StridedVector.sd' for computing unbiased standard deviation. - Added SIMD speedups for / and /= operations. - Added 'StridedVector.log1p' and 'expm1'. -- Removed 'StridedMatrix2.plus' and 'logAddExp'. - Fixed a bug in scalar division in expressions of the form '1.0 / v'. +- Mirrored 'StridedVector' operations in 'StridedMatrix2' and + 'StridedMatrix3'. Scalar-matrix, e.g. '1.0 / m' operations are to be + added in future releases. Version 0.2.3 ------------- diff --git a/src/main/kotlin/org/jetbrains/bio/viktor/StridedMatrix.kt b/src/main/kotlin/org/jetbrains/bio/viktor/StridedMatrix.kt index b1450b9..96f9a0d 100644 --- a/src/main/kotlin/org/jetbrains/bio/viktor/StridedMatrix.kt +++ b/src/main/kotlin/org/jetbrains/bio/viktor/StridedMatrix.kt @@ -63,7 +63,7 @@ object StridedMatrix { } /** A common interface for whole-matrix operations. */ -internal interface FlatMatrixOps { +internal interface FlatMatrixOps> { /** * Returns a flat view of this matrix. * @@ -74,6 +74,9 @@ internal interface FlatMatrixOps { /** Returns the copy of this matrix. */ fun copy(): T + /** Ensures a given matrix has the same dimensions as this matrix. */ + fun checkDimensions(other: T) + fun fill(init: Double) = flatten().fill(init) fun mean() = flatten().mean() @@ -107,4 +110,71 @@ internal interface FlatMatrixOps { fun log1pInPlace() = flatten().logInPlace() fun log1p() = copy().apply { log1pInPlace() } -} + + infix fun logAddExp(other: T): T = copy().apply { logAddExp(other, this) } + + fun logAddExp(other: T, dst: T) { + checkDimensions(other) + checkDimensions(dst) + flatten().logAddExp(other.flatten(), dst.flatten()) + } + + operator fun unaryPlus() = this + + operator fun unaryMinus() = copy().apply { + val v = flatten() + NativeSpeedups.unsafeNegate(v.data, v.offset, v.data, v.offset, v.size) + } + + operator fun plus(other: T) = copy().apply { this += other } + + operator open fun plusAssign(other: T) { + checkDimensions(other) + flatten() += other.flatten() + } + + operator fun plus(update: Double) = copy().apply { this += update } + + operator open fun plusAssign(update: Double) { + flatten() += update + } + + operator fun minus(other: T) = copy().apply { this -= other } + + operator open fun minusAssign(other: T) { + checkDimensions(other) + flatten() -= other.flatten() + } + + operator fun minus(update: Double) = copy().apply { this -= update } + + operator open fun minusAssign(update: Double) { + flatten() -= update + } + + operator fun times(other: T) = copy().apply { this *= other } + + operator open fun timesAssign(other: T) { + checkDimensions(other) + flatten() *= other.flatten() + } + + operator fun times(update: Double) = copy().apply { this *= update } + + operator open fun timesAssign(update: Double) { + flatten() *= update + } + + operator fun div(other: T) = copy().apply { this /= other } + + operator open fun divAssign(other: T) { + checkDimensions(other) + flatten() /= other.flatten() + } + + operator fun div(update: Double) = copy().apply { this /= update } + + operator open fun divAssign(update: Double) { + flatten() /= update + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/jetbrains/bio/viktor/StridedMatrix2.kt b/src/main/kotlin/org/jetbrains/bio/viktor/StridedMatrix2.kt index 9fa7811..8eff2dd 100644 --- a/src/main/kotlin/org/jetbrains/bio/viktor/StridedMatrix2.kt +++ b/src/main/kotlin/org/jetbrains/bio/viktor/StridedMatrix2.kt @@ -58,7 +58,10 @@ class StridedMatrix2 internal constructor( /** Returns a view of the [r]-th row of this matrix. */ fun rowView(r: Int): StridedVector { - require(r >= 0 && r < rowsNumber) { "r must be in [0, $rowsNumber)" } + if (r < 0 || r >= rowsNumber) { + throw IndexOutOfBoundsException("r must be in [0, $rowsNumber)") + } + return StridedVector.create(data, offset + rowStride * r, columnsNumber, columnStride) } @@ -66,7 +69,10 @@ class StridedMatrix2 internal constructor( * Returns a view of the [c]-th column of this matrix. */ fun columnView(c: Int): StridedVector { - require(c >= 0 && c < columnsNumber) { "c must be in [0, $columnsNumber)" } + if (c < 0 || c >= columnsNumber) { + throw IndexOutOfBoundsException("c must be in [0, $columnsNumber)") + } + return StridedVector.create(data, offset + columnStride * c, rowsNumber, rowStride) } @@ -93,8 +99,6 @@ class StridedMatrix2 internal constructor( operator fun set(any: _I, c: Int, init: Double) = columnView(c).fill(init) - operator fun set(row: Int, any: _I, init: Double) = columnView(row).fill(init) - /** An alias for [transpose]. */ val T: StridedMatrix2 get() = transpose() @@ -102,17 +106,6 @@ class StridedMatrix2 internal constructor( fun transpose() = StridedMatrix2(columnsNumber, rowsNumber, data, offset, columnStride, rowStride) - /** - * Flattens the matrix into a vector in O(1) time. - * - * No data copying is performed, thus the operation is only applicable - * to dense matrices. - */ - override fun flatten(): StridedVector { - check(isDense) { "matrix is not dense" } - return data.asStrided(offset, rowsNumber * columnsNumber) - } - /** Returns a copy of the elements in this matrix. */ override fun copy(): StridedMatrix2 { val copy = StridedMatrix2(rowsNumber, columnsNumber) @@ -133,6 +126,17 @@ class StridedMatrix2 internal constructor( } } + /** + * Flattens the matrix into a vector in O(1) time. + * + * No data copying is performed, thus the operation is only applicable + * to dense matrices. + */ + override fun flatten(): StridedVector { + check(isDense) { "matrix is not dense" } + return data.asStrided(offset, rowsNumber * columnsNumber) + } + /** * Returns a stream of row or column views of the matrix. * @@ -198,9 +202,16 @@ class StridedMatrix2 internal constructor( return acc } - private fun checkDimensions(other: StridedMatrix2) { + override fun checkDimensions(other: StridedMatrix2) { check(this === other || (rowsNumber == other.rowsNumber && columnsNumber == other.columnsNumber)) { "non-conformable matrices" } } } + +/** Reshapes this vector into a matrix in row-major order. */ +fun StridedVector.reshape(numRows: Int, numColumns: Int): StridedMatrix2 { + require(numRows * numColumns == size) + return StridedMatrix2(numRows, numColumns, data, offset, + numColumns * stride, stride) +} diff --git a/src/main/kotlin/org/jetbrains/bio/viktor/StridedMatrix3.kt b/src/main/kotlin/org/jetbrains/bio/viktor/StridedMatrix3.kt index 94a811a..3700924 100644 --- a/src/main/kotlin/org/jetbrains/bio/viktor/StridedMatrix3.kt +++ b/src/main/kotlin/org/jetbrains/bio/viktor/StridedMatrix3.kt @@ -10,7 +10,7 @@ import java.util.* */ class StridedMatrix3 internal constructor( val depth: Int, val rowsNumber: Int, val columnsNumber: Int, - val data: DoubleArray, + val data: DoubleArray, val offset: Int, val depthStride: Int, val rowStride: Int, val columnStride: Int) : FlatMatrixOps { @@ -18,7 +18,7 @@ class StridedMatrix3 internal constructor( constructor(depth: Int, numRows: Int, numColumns: Int) : this(depth, numRows, numColumns, DoubleArray(depth * numRows * numColumns), - numRows * numColumns, numColumns, 1) { + 0, numRows * numColumns, numColumns, 1) { } /** @@ -58,13 +58,13 @@ class StridedMatrix3 internal constructor( @Suppress("nothing_to_inline") private inline fun unsafeIndex(d: Int, r: Int, c: Int): Int { - return d * depthStride + r * rowStride + c * columnStride + return offset + d * depthStride + r * rowStride + c * columnStride } override fun copy(): StridedMatrix3 { - val copy = StridedMatrix(depth, rowsNumber, columnsNumber) - copyTo(copy) - return copy + val m = StridedMatrix(depth, rowsNumber, columnsNumber) + copyTo(m) + return m } fun copyTo(other: StridedMatrix3) { @@ -89,7 +89,10 @@ class StridedMatrix3 internal constructor( operator fun set(d: Int, other: Double) = view(d).fill(other) fun view(d: Int): StridedMatrix2 { - require(d >= 0 && d < depth) { "d must be in [0, $depth)" } + if (d < 0 || d >= depth) { + throw IndexOutOfBoundsException("d must be in [0, $depth)") + } + return StridedMatrix2(rowsNumber, columnsNumber, data, d * depthStride, rowStride, columnStride) } @@ -101,12 +104,6 @@ class StridedMatrix3 internal constructor( operator fun set(d: Int, r: Int, other: StridedVector) = view(d).set(r, other) - fun logAddExp(other: StridedMatrix3, dst: StridedMatrix3) { - checkDimensions(other) - checkDimensions(dst) - flatten().logAddExp(other.flatten(), dst.flatten()) - } - fun toArray() = Array(depth) { view(it).toArray() } override fun toString() = Arrays.deepToString(toArray()) @@ -141,9 +138,16 @@ class StridedMatrix3 internal constructor( return acc } - private fun checkDimensions(other: StridedMatrix3) { + override fun checkDimensions(other: StridedMatrix3) { check(this === other || (depth == other.depth && rowsNumber == other.rowsNumber && columnsNumber == other.columnsNumber)) { "non-conformable matrices" } } } + +/** Reshapes this vector into a matrix in row-major order. */ +fun StridedVector.reshape(depth: Int, numRows: Int, numColumns: Int): StridedMatrix3 { + require(depth * numRows * numColumns == size) + return StridedMatrix3(depth, numRows, numColumns, data, offset, + numRows * numColumns * stride, numColumns * stride, stride) +} diff --git a/src/main/kotlin/org/jetbrains/bio/viktor/StridedVector.kt b/src/main/kotlin/org/jetbrains/bio/viktor/StridedVector.kt index 223b7e5..6af62c3 100644 --- a/src/main/kotlin/org/jetbrains/bio/viktor/StridedVector.kt +++ b/src/main/kotlin/org/jetbrains/bio/viktor/StridedVector.kt @@ -154,13 +154,6 @@ open class StridedVector internal constructor( } } - /** Reshapes this vector into a matrix in row-major order. */ - fun reshape(numRows: Int, numColumns: Int): StridedMatrix2 { - require(numRows * numColumns == size) - return StridedMatrix2(numRows, numColumns, data, offset, - numColumns * stride, stride) - } - /** * Computes a dot product of this vector with an array. */ @@ -585,4 +578,4 @@ open class StridedVector internal constructor( } } } -} +} \ No newline at end of file diff --git a/src/simd/cpp/org_jetbrains_bio_viktor_NativeSpeedups.cpp b/src/simd/cpp/org_jetbrains_bio_viktor_NativeSpeedups.cpp index 839aa8f..9d76081 100644 --- a/src/simd/cpp/org_jetbrains_bio_viktor_NativeSpeedups.cpp +++ b/src/simd/cpp/org_jetbrains_bio_viktor_NativeSpeedups.cpp @@ -492,9 +492,9 @@ JNI_METHOD(jdouble, sd)(JNIEnv *env, jobject, } JNI_METHOD(void, cumSum)(JNIEnv *env, jobject, - jdoubleArray jsrc, jint src_offset, - jdoubleArray jdst, jint dst_offset, - jint length) + jdoubleArray jsrc, jint src_offset, + jdoubleArray jdst, jint dst_offset, + jint length) { jdouble *src = (jdouble *) env->GetPrimitiveArrayCritical(jsrc, NULL); jdouble *dst = (jdouble *) env->GetPrimitiveArrayCritical(jdst, NULL); diff --git a/src/test/kotlin/org/jetbrains/bio/viktor/StridedMatrix2Tests.kt b/src/test/kotlin/org/jetbrains/bio/viktor/StridedMatrix2Tests.kt new file mode 100644 index 0000000..6ea8e14 --- /dev/null +++ b/src/test/kotlin/org/jetbrains/bio/viktor/StridedMatrix2Tests.kt @@ -0,0 +1,146 @@ +package org.jetbrains.bio.viktor + +import org.apache.commons.math3.util.Precision +import org.junit.Assert.assertArrayEquals +import org.junit.Assert.assertEquals +import org.junit.Test +import kotlin.test.assertNotEquals + +class StridedMatrix2Slicing { + private val m = StridedVector.of(0.0, 1.0, + 2.0, 3.0, + 4.0, 5.0).reshape(3, 2) + + @Test fun transposeUnit() { + val m = StridedMatrix(1, 1) + assertEquals(m, m.T) + } + + @Test fun transpose() { + assertEquals(StridedVector.of(0.0, 2.0, 4.0, + 1.0, 3.0, 5.0).reshape(2, 3), + m.T) + } + + @Test fun rowView() { + assertEquals(StridedVector.of(0.0, 1.0), m.rowView(0)) + assertEquals(StridedVector.of(2.0, 3.0), m.rowView(1)) + assertEquals(StridedVector.of(4.0, 5.0), m.rowView(2)) + } + + @Test(expected = IndexOutOfBoundsException::class) fun rowViewOutOfBounds() { + m.rowView(42) + } + + @Test fun columnView() { + assertEquals(StridedVector.of(0.0, 2.0, 4.0), m.columnView(0)) + assertEquals(StridedVector.of(1.0, 3.0, 5.0), m.columnView(1)) + } + + @Test(expected = IndexOutOfBoundsException::class) fun columnViewOutOfBounds() { + m.columnView(42) + } + + @Test fun reshape() { + val v = StridedVector.of(0.0, 1.0, 2.0, 3.0, 4.0, 5.0) + assertArrayEquals(arrayOf(doubleArrayOf(0.0, 1.0, 2.0), + doubleArrayOf(3.0, 4.0, 5.0)), + v.reshape(2, 3).toArray()) + assertArrayEquals(arrayOf(doubleArrayOf(0.0, 1.0), + doubleArrayOf(2.0, 3.0), + doubleArrayOf(4.0, 5.0)), + v.reshape(3, 2).toArray()) + } + + @Test fun reshapeWithStride() { + val v = StridedVector.create(doubleArrayOf(0.0, 1.0, 2.0, 3.0, + 4.0, 5.0, 6.0, 7.0), + 0, 4, stride = 2) + assertArrayEquals(arrayOf(doubleArrayOf(0.0, 2.0), + doubleArrayOf(4.0, 6.0)), + v.reshape(2, 2).toArray()) + } +} + +class StridedMatrix2GetSet { + private val m = StridedVector.of(0.0, 1.0, + 2.0, 3.0, + 4.0, 5.0).reshape(3, 2) + + @Test fun get() { + assertEquals(0.0, m[0, 0], Precision.EPSILON) + assertEquals(1.0, m[0, 1], Precision.EPSILON) + assertEquals(2.0, m[1, 0], Precision.EPSILON) + assertEquals(3.0, m[1, 1], Precision.EPSILON) + assertEquals(4.0, m[2, 0], Precision.EPSILON) + assertEquals(5.0, m[2, 1], Precision.EPSILON) + } + + @Test(expected = IndexOutOfBoundsException::class) fun getOutOfBounds() { + m[42, 42] + } + + @Test fun set() { + val copy = m.copy() + copy[0, 1] = 42.0 + assertEquals(42.0, copy[0, 1], Precision.EPSILON) + } + + @Test(expected = IndexOutOfBoundsException::class) fun setOutOfBounds() { + m[42, 42] = 100500.0 + } + + @Test fun setMagicRowScalar() { + val copy = m.copy() + copy[0] = 42.0 + assertEquals(StridedVector.full(copy.columnsNumber, 42.0), copy[0]) + } + + @Test fun setMagicRowVector() { + val copy = m.copy() + val v = StridedVector.full(copy.columnsNumber, 42.0) + copy[0] = v + assertEquals(v, copy[0]) + + for (r in 1..copy.rowsNumber - 1) { + assertNotEquals(v, copy[r]) + assertEquals(m[r], copy[r]) + } + } + + @Test fun setMagicColumnScalar() { + val copy = m.copy() + copy[_I, 0] = 42.0 + assertEquals(StridedVector.full(copy.rowsNumber, 42.0), copy[_I, 0]) + } + + @Test fun setMagicColumnVector() { + val copy = m.copy() + val v = StridedVector.full(copy.rowsNumber, 42.0) + copy[_I, 0] = v + assertEquals(v, copy[_I, 0]) + + for (c in 1..copy.columnsNumber - 1) { + assertNotEquals(v, copy[_I, c]) + assertEquals(m[_I, c], copy[_I, c]) + } + } +} + +class StridedMatrix2OpsTest { + @Test fun equals() { + val m = StridedVector.of(0.0, 1.0, + 2.0, 3.0, + 4.0, 5.0).reshape(3, 2) + + assertEquals(m, m) + assertEquals(m, m.copy()) + assertNotEquals(m, m.T) + } + + @Test fun _toString() { + assertEquals("[]", StridedMatrix(0, 0).toString()) + assertEquals("[[]]", StridedMatrix(1, 0).toString()) + assertEquals("[[0.0], [0.0]]", StridedMatrix(2, 1).toString()) + } +} \ No newline at end of file diff --git a/src/test/kotlin/org/jetbrains/bio/viktor/StridedMatrix3Tests.kt b/src/test/kotlin/org/jetbrains/bio/viktor/StridedMatrix3Tests.kt new file mode 100644 index 0000000..1e3501e --- /dev/null +++ b/src/test/kotlin/org/jetbrains/bio/viktor/StridedMatrix3Tests.kt @@ -0,0 +1,137 @@ +package org.jetbrains.bio.viktor + +import org.apache.commons.math3.util.Precision +import org.junit.Assert.assertArrayEquals +import org.junit.Assert.assertEquals +import org.junit.Test +import kotlin.test.assertNotEquals + +class StridedMatrix3Slicing { + private val m = StridedVector.of(0.0, 1.0, + 2.0, 3.0, + 4.0, 5.0).reshape(3, 1, 2) + + @Test fun view() { + assertEquals(StridedVector.of(0.0, 1.0).reshape(1, 2), m.view(0)) + assertEquals(StridedVector.of(2.0, 3.0).reshape(1, 2), m.view(1)) + assertEquals(StridedVector.of(4.0, 5.0).reshape(1, 2), m.view(2)) + } + + @Test(expected = IndexOutOfBoundsException::class) fun viewOutOfBounds() { + m.view(42) + } + + @Test fun reshape() { + val v = StridedVector.of(0.0, 1.0, + 2.0, 3.0, + 4.0, 5.0) + assertArrayEquals(arrayOf(arrayOf(doubleArrayOf(0.0, 1.0)), + arrayOf(doubleArrayOf(2.0, 3.0)), + arrayOf(doubleArrayOf(4.0, 5.0))), + v.reshape(3, 1, 2).toArray()) + assertArrayEquals(arrayOf(arrayOf(doubleArrayOf(0.0), + doubleArrayOf(1.0)), + arrayOf(doubleArrayOf(2.0), + doubleArrayOf(3.0)), + arrayOf(doubleArrayOf(4.0), + doubleArrayOf(5.0))), + v.reshape(3, 2, 1).toArray()) + } + + @Test fun reshapeWithStride() { + val v = StridedVector.create(doubleArrayOf(0.0, 1.0, 2.0, 3.0, + 4.0, 5.0, 6.0, 7.0), + 0, 4, stride = 2) + assertArrayEquals(arrayOf(arrayOf(doubleArrayOf(0.0, 2.0)), + arrayOf(doubleArrayOf(4.0, 6.0))), + v.reshape(2, 1, 2).toArray()) + assertArrayEquals(arrayOf(arrayOf(doubleArrayOf(0.0), + doubleArrayOf(2.0)), + arrayOf(doubleArrayOf(4.0), + doubleArrayOf(6.0))), + v.reshape(2, 2, 1).toArray()) + } +} + +class StridedMatrix3GetSet { + private val m = StridedVector.of(0.0, 1.0, + 2.0, 3.0, + 4.0, 5.0).reshape(3, 1, 2) + + @Test fun get() { + assertEquals(0.0, m[0, 0, 0], Precision.EPSILON) + assertEquals(1.0, m[0, 0, 1], Precision.EPSILON) + assertEquals(2.0, m[1, 0, 0], Precision.EPSILON) + assertEquals(3.0, m[1, 0, 1], Precision.EPSILON) + assertEquals(4.0, m[2, 0, 0], Precision.EPSILON) + assertEquals(5.0, m[2, 0, 1], Precision.EPSILON) + } + + @Test(expected = IndexOutOfBoundsException::class) fun getOutOfBounds() { + m[42, 42, 42] + } + + @Test fun set() { + val copy = m.copy() + copy[1, 0, 1] = 42.0 + assertEquals(42.0, copy[1, 0, 1], Precision.EPSILON) + } + + @Test(expected = IndexOutOfBoundsException::class) fun setOutOfBounds() { + m[42, 42, 42] = 100500.0 + } + + @Test fun setMagicMatrix() { + val copy = m.copy() + val replacement = StridedMatrix.full(m.rowsNumber, m.columnsNumber, 42.0) + copy[0] = replacement + assertEquals(replacement, copy[0]) + + for (d in 1..m.depth - 1) { + assertNotEquals(replacement, copy[d]) + assertEquals(m[d], copy[d]) + } + } + + @Test fun setMagicMatrixViaScalar() { + val copy1 = m.copy() + copy1[0] = 42.0 + val copy2 = m.copy() + copy2[0] = StridedMatrix.full(m.rowsNumber, m.columnsNumber, 42.0) + assertEquals(copy1, copy2) + } + + @Test fun setMagicVector() { + val copy = m.copy() + val replacement = StridedVector.full(m.columnsNumber, 42.0) + copy[0, 0] = replacement + assertEquals(replacement, copy[0, 0]) + + for (d in 1..m.depth - 1) { + for (r in 1..m.rowsNumber - 1) { + assertNotEquals(replacement, copy[d, r]) + assertEquals(m[d, r], copy[d, r]) + } + } + } + + @Test fun setMagicVectorViaScalar() { + val copy1 = m.copy() + copy1[1, 0] = 42.0 + val copy2 = m.copy() + copy2[1, 0] = StridedVector.full(m.columnsNumber, 42.0) + assertEquals(copy1, copy2) + } + + @Test fun setMagicScalar() { + val copy = m.copy() + val replacement = StridedMatrix.full(m.rowsNumber, m.columnsNumber, 42.0) + copy[0] = 42.0 + assertEquals(replacement, copy[0]) + + for (d in 1..m.depth - 1) { + assertNotEquals(replacement, copy[d]) + assertEquals(m[d], copy[d]) + } + } +} \ No newline at end of file diff --git a/src/test/kotlin/org/jetbrains/bio/viktor/StridedMatrixTests.kt b/src/test/kotlin/org/jetbrains/bio/viktor/StridedMatrixTests.kt deleted file mode 100644 index 5886f7d..0000000 --- a/src/test/kotlin/org/jetbrains/bio/viktor/StridedMatrixTests.kt +++ /dev/null @@ -1,167 +0,0 @@ -package org.jetbrains.bio.viktor - -import org.junit.Test -import kotlin.test.assertEquals -import kotlin.test.assertNotEquals -import kotlin.test.assertTrue - -class StridedMatrix2Test { - @Test fun testIndex() { - val m = StridedMatrix.full(NUM_ROWS, NUM_COLUMNS, -1.0) - for (r in 0..NUM_ROWS - 1) { - for (c in 0..NUM_COLUMNS - 1) { - m[r, c] = 1.0 - } - } - - assertEquals(NUM_ROWS * NUM_COLUMNS, m.sum().toInt()) - } - - @Test fun testGetSet() { - val m = getMatrix(NUM_ROWS, NUM_COLUMNS) - for (r in 0..NUM_ROWS - 1) { - for (c in 0..NUM_COLUMNS - 1) { - assertEquals(r.toDouble() * NUM_COLUMNS + c, m[r, c]) - } - } - } - - @Test fun testRowView() { - val m = getMatrix(NUM_ROWS, NUM_COLUMNS) - for (r in 0..NUM_ROWS - 1) { - val view = m.rowView(r) - for (c in 0..NUM_COLUMNS - 1) { - assertEquals(view[c], m[r, c]) - } - } - } - - @Test fun testColumnView() { - val m = getMatrix(NUM_ROWS, NUM_COLUMNS) - for (c in 0..NUM_COLUMNS - 1) { - val view = m.columnView(c) - for (r in 0..NUM_ROWS - 1) { - assertEquals(view[r], m[r, c]) - } - } - } - - @Test fun testTranspose() { - val m = getMatrix(NUM_ROWS, NUM_COLUMNS) - val mt = m.transpose() - for (r in 0..NUM_ROWS - 1) { - for (c in 0..NUM_COLUMNS - 1) { - assertEquals(m[r, c], mt[c, r]) - } - } - } - - @Test fun testCopyToFast() { - val src = getMatrix(NUM_ROWS, NUM_COLUMNS) - val dst = StridedMatrix(NUM_ROWS, NUM_COLUMNS) - src.copyTo(dst) - assertEquals(src, dst) - } - - @Test fun testCopyToSlow() { - val src = getMatrix(NUM_ROWS, NUM_COLUMNS) - val dst = StridedMatrix(NUM_COLUMNS, NUM_ROWS).transpose() - src.copyTo(dst) - assertEquals(src, dst) - } - - @Test fun testAlong() { - val m = StridedMatrix(NUM_ROWS, NUM_COLUMNS) - assertTrue(m.along(0).allMatch { it.size == NUM_ROWS }) - assertTrue(m.along(1).allMatch { it.size == NUM_COLUMNS }) - } - - @Test fun testEquals() { - val m = getMatrix(3, 4) - assertEquals(m, m) - assertNotEquals(m, m.exp()) - } - - private fun getMatrix(numRows: Int, numColumns: Int): StridedMatrix2 { - return StridedMatrix(numRows, numColumns) { r, c -> - (r * numColumns + c).toDouble() - } - } - - companion object { - val NUM_ROWS = 3 - val NUM_COLUMNS = 5 - } -} - -class StridedMatrix3Test { - @Test fun testIndex() { - val m = getMatrix() - m.fill(-1.0) - - for (d in 0..m.depth - 1) { - for (r in 0..m.rowsNumber - 1) { - for (c in 0..m.columnsNumber - 1) { - m[d, r, c] = 1.0 - } - } - } - - assertEquals(m.depth * m.rowsNumber * m.columnsNumber, - m.sum().toInt()) - } - - @Test fun testGetSet() { - val m = getMatrix() - var i = 0 - for (d in 0..m.depth - 1) { - for (r in 0..m.rowsNumber - 1) { - for (c in 0..m.columnsNumber - 1) { - assertEquals(i++, m[d, r, c].toInt()) - } - } - } - } - - @Test fun testView() { - val m = getMatrix() - for (d in 0..m.depth - 1) { - val view = m.view(d) - for (r in 0..m.rowsNumber - 1) { - for (c in 0..m.columnsNumber - 1) { - assertEquals(m[d, r, c], view[r, c]) - } - } - } - } - - @Test fun testViewAssignment() { - val magic = 100500.0 - val m = StridedMatrix(3, 4, 5) - for (d in 0..m.depth - 1) { - val copy = m.copy() - copy[d] = magic - assertEquals(StridedMatrix.full(m.rowsNumber, m.columnsNumber, magic), - copy[d]) - - for (other in 0..m.depth - 1) { - if (other != d) { - for (value in copy[other].flatten().toArray()) { - assertNotEquals(magic, value) - } - } - } - } - } - - @Test fun testEquals() { - val m = getMatrix() - assertEquals(m, m) - assertNotEquals(m, m.exp()) - } - - private fun getMatrix(): StridedMatrix3 { - var i = 0 - return StridedMatrix(7, 4, 3) { d, r, c -> i++.toDouble() } - } -} \ No newline at end of file diff --git a/src/test/kotlin/org/jetbrains/bio/viktor/StridedVectorTests.kt b/src/test/kotlin/org/jetbrains/bio/viktor/StridedVectorTests.kt index 2a726b0..a069837 100644 --- a/src/test/kotlin/org/jetbrains/bio/viktor/StridedVectorTests.kt +++ b/src/test/kotlin/org/jetbrains/bio/viktor/StridedVectorTests.kt @@ -93,10 +93,6 @@ class StridedVectorSlicing { StridedVector.of(1.0, 2.0).T.columnView(0)) assertEquals(StridedVector.of(1.0, 2.0, 3.0), StridedVector.of(1.0, 2.0, 3.0).T.columnView(0)) - - val m = StridedMatrix(2, 3) { i, j -> i + j * 42.0 } - assertEquals(StridedVector.of(42.0, 43.0), - m.columnView(1).T.columnView(0)) } @Test fun slice() { @@ -228,26 +224,6 @@ class StridedVectorOpsTest(private val v: StridedVector) { assertEquals(StridedVector.full(copy.size, 42.0), copy) } - @Test fun reshape() { - val v = (0..5).toStrided() - assertArrayEquals(arrayOf(doubleArrayOf(0.0, 1.0, 2.0), - doubleArrayOf(3.0, 4.0, 5.0)), - v.reshape(2, 3).toArray()) - assertArrayEquals(arrayOf(doubleArrayOf(0.0, 1.0), - doubleArrayOf(2.0, 3.0), - doubleArrayOf(4.0, 5.0)), - v.reshape(3, 2).toArray()) - } - - @Test fun reshapeWithStride() { - val v = StridedVector.create(doubleArrayOf(0.0, 1.0, 2.0, 3.0, - 4.0, 5.0, 6.0, 7.0), - 0, 4, stride = 2) - assertArrayEquals(arrayOf(doubleArrayOf(0.0, 2.0), - doubleArrayOf(4.0, 6.0)), - v.reshape(2, 2).toArray()) - } - @Test fun reverse() { val copy = v.copy() copy.reverse()