From a0bff721566d32ecc50ba4ff5840bdac12989b4b Mon Sep 17 00:00:00 2001 From: Unknown Date: Tue, 23 Jan 2018 07:05:57 +0000 Subject: [PATCH] Update (very) stale rng code to use platform support that was introduced later. Closes #39. --- .../koma.matrix.cblas/CBlasMatrixFactory.kt | 35 ------------------- .../src/koma.matrix.ejml/EJMLMatrixFactory.kt | 26 -------------- .../src/koma.matrix.ejml/backend/ejml.kt | 24 ------------- .../src/koma.matrix.ejml/backend/internal.kt | 3 -- .../koma.matrix.jblas/JBlasMatrixFactory.kt | 20 ----------- .../src/koma.matrix.mtj/MTJMatrixFactory.kt | 10 ------ .../src/koma.matrix.mtj/backend/internal.kt | 2 -- .../src/koma.matrix.mtj/backend/mtj.kt | 23 ------------ core/src/koma/matrix/MatrixFactory.kt | 14 -------- .../koma/matrix/common/DoubleFactoryBase.kt | 27 ++++++++++++++ core/src/koma/misc.kt | 3 ++ core/templates/DefaultXMatrixFactory.kt | 12 ++----- tests/test/koma/CreatorsTests.kt | 28 +++++++++++++++ 13 files changed, 61 insertions(+), 166 deletions(-) diff --git a/backend-matrix-cblas/src/koma.matrix.cblas/CBlasMatrixFactory.kt b/backend-matrix-cblas/src/koma.matrix.cblas/CBlasMatrixFactory.kt index 4e55b935..2d8e6e30 100644 --- a/backend-matrix-cblas/src/koma.matrix.cblas/CBlasMatrixFactory.kt +++ b/backend-matrix-cblas/src/koma.matrix.cblas/CBlasMatrixFactory.kt @@ -63,39 +63,4 @@ class CBlasMatrixFactory: DoubleFactoryBase() { it[row, col] = 0 } } - - @Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("rand(size, size)")) - override fun rand(size: Int): CBlasMatrix { - return rand(size, size) - } - - override fun rand(rows: Int, cols: Int): CBlasMatrix - = zeros(rows, cols).also { - it.forEachIndexed { row, col, _ -> - it[row, col] = rng.nextDouble() - } - } - - override fun rand(rows: Int, cols: Int, seed: Long): CBlasMatrix { - rng.setSeed(seed.toInt()) - return rand(rows, cols) - } - - @Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("randn(size, size)")) - override fun randn(size: Int): CBlasMatrix { - return randn(size, size) - } - - override fun randn(rows: Int, cols: Int): CBlasMatrix - = zeros(rows, cols).also { - it.forEachIndexed { row, col, _ -> - it[row, col] = rng.nextGaussian() - } - } - - override fun randn(rows: Int, cols: Int, seed: Long): CBlasMatrix { - rng.setSeed(seed.toInt()) - return randn(rows, cols) - } - } diff --git a/backend-matrix-ejml/src/koma.matrix.ejml/EJMLMatrixFactory.kt b/backend-matrix-ejml/src/koma.matrix.ejml/EJMLMatrixFactory.kt index 7476af94..09c702b2 100644 --- a/backend-matrix-ejml/src/koma.matrix.ejml/EJMLMatrixFactory.kt +++ b/backend-matrix-ejml/src/koma.matrix.ejml/EJMLMatrixFactory.kt @@ -45,31 +45,5 @@ class EJMLMatrixFactory : DoubleFactoryBase() { return EJMLMatrix(out) } - @Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("rand(size, size)")) - override fun rand(size: Int): EJMLMatrix { - return rand(size, size) - } - - override fun rand(rows: Int, cols: Int): EJMLMatrix { - return EJMLMatrix(koma.matrix.ejml.backend.rand(rows, cols)) - } - - override fun rand(rows: Int, cols: Int, seed: Long): EJMLMatrix { - return EJMLMatrix(koma.matrix.ejml.backend.rand(rows, cols, seed)) - } - - @Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("randn(size, size)")) - override fun randn(size: Int): EJMLMatrix { - return EJMLMatrix(koma.matrix.ejml.backend.randn(size)) - } - - override fun randn(rows: Int, cols: Int): EJMLMatrix { - return EJMLMatrix(koma.matrix.ejml.backend.randn(rows, cols)) - } - - override fun randn(rows: Int, cols: Int, seed: Long): EJMLMatrix { - return EJMLMatrix(koma.matrix.ejml.backend.randn(rows, cols, seed)) - } - } diff --git a/backend-matrix-ejml/src/koma.matrix.ejml/backend/ejml.kt b/backend-matrix-ejml/src/koma.matrix.ejml/backend/ejml.kt index ca63a19f..19cacd6f 100644 --- a/backend-matrix-ejml/src/koma.matrix.ejml/backend/ejml.kt +++ b/backend-matrix-ejml/src/koma.matrix.ejml/backend/ejml.kt @@ -76,30 +76,6 @@ fun ones(rows: Int, cols: Int): SimpleMatrix { return out } -fun rand(rows: Int, cols: Int, seed: Long): SimpleMatrix { - if (seed != curSeed) { - random.setSeed(seed) - curSeed = seed - } - return SimpleMatrix.random(rows, cols, 0.0, 1.0, random) -} -fun rand(rows: Int, cols: Int) = rand(rows, cols, curSeed) -fun rand(len: Int) = rand(1, len, curSeed) - -fun randn(rows: Int, cols: Int, seed: Long): SimpleMatrix { - if (seed != curSeed) { - random.setSeed(seed) - curSeed = seed - } - val out = SimpleMatrix(rows, cols) - for (i in 0..rows - 1) - for (j in 0..cols - 1) - out[i, j] = random.nextGaussian() - return out -} -fun randn(len: Int) = randn(len, len, curSeed) -fun randn(rows: Int, cols: Int) = randn(rows, cols, curSeed) - fun SimpleMatrix.map(f: (Double) -> Double): SimpleMatrix { val out = SimpleMatrix(this.numRows(), this.numCols()) for (row in 0..this.numRows() - 1) diff --git a/backend-matrix-ejml/src/koma.matrix.ejml/backend/internal.kt b/backend-matrix-ejml/src/koma.matrix.ejml/backend/internal.kt index d95d2402..104c2980 100644 --- a/backend-matrix-ejml/src/koma.matrix.ejml/backend/internal.kt +++ b/backend-matrix-ejml/src/koma.matrix.ejml/backend/internal.kt @@ -5,6 +5,3 @@ package koma.matrix.ejml.backend import koma.matrix.ejml.* internal var factoryInstance: EJMLMatrixFactory = EJMLMatrixFactory() -internal var curSeed = System.currentTimeMillis() -internal var random = java.util.Random(curSeed) - diff --git a/backend-matrix-jblas/src/koma.matrix.jblas/JBlasMatrixFactory.kt b/backend-matrix-jblas/src/koma.matrix.jblas/JBlasMatrixFactory.kt index 653dc48f..2b77a98a 100644 --- a/backend-matrix-jblas/src/koma.matrix.jblas/JBlasMatrixFactory.kt +++ b/backend-matrix-jblas/src/koma.matrix.jblas/JBlasMatrixFactory.kt @@ -22,24 +22,4 @@ class JBlasMatrixFactory : DoubleFactoryBase() { out[i, i] = 1.0 return JBlasMatrix(out) } - - @Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("rand(size, size)")) - override fun rand(size: Int) = JBlasMatrix(koma.matrix.jblas.backend.rand(size)) - override fun rand(rows: Int, cols: Int) = JBlasMatrix(koma.matrix.jblas.backend.rand(rows, cols)) - - override fun rand(rows: Int, cols: Int, seed: Long): JBlasMatrix { - println("Warning: JBlas RNG doesnt support seeds") - return JBlasMatrix(koma.matrix.jblas.backend.rand(rows, cols)) - } - - @Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("randn(size, size)")) - override fun randn(size: Int) = JBlasMatrix(koma.matrix.jblas.backend.randn(size)) - - override fun randn(rows: Int, cols: Int) = JBlasMatrix(koma.matrix.jblas.backend.randn(rows, cols)) - - override fun randn(rows: Int, cols: Int, seed: Long): JBlasMatrix { - println("Warning: JBlas RNG doesnt support seeds") - return JBlasMatrix(koma.matrix.jblas.backend.randn(rows, cols)) - } - } diff --git a/backend-matrix-mtj/src/koma.matrix.mtj/MTJMatrixFactory.kt b/backend-matrix-mtj/src/koma.matrix.mtj/MTJMatrixFactory.kt index ad45be82..4e5ab209 100644 --- a/backend-matrix-mtj/src/koma.matrix.mtj/MTJMatrixFactory.kt +++ b/backend-matrix-mtj/src/koma.matrix.mtj/MTJMatrixFactory.kt @@ -44,14 +44,4 @@ class MTJMatrixFactory : DoubleFactoryBase() { return MTJMatrix(out) } - @Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("rand(size, size)")) - override fun rand(size: Int) = rand(size, size) - override fun rand(rows: Int, cols: Int) = MTJMatrix(koma.matrix.mtj.backend.rand(rows, cols)) - override fun rand(rows: Int, cols: Int, seed: Long) = MTJMatrix(koma.matrix.mtj.backend.rand(rows, cols, seed)) - - @Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("randn(size, size)")) - override fun randn(size: Int) = randn(size, size) - override fun randn(rows: Int, cols: Int) = MTJMatrix(koma.matrix.mtj.backend.randn(rows, cols)) - override fun randn(rows: Int, cols: Int, seed: Long) = MTJMatrix(koma.matrix.mtj.backend.randn(rows, cols, seed)) - } \ No newline at end of file diff --git a/backend-matrix-mtj/src/koma.matrix.mtj/backend/internal.kt b/backend-matrix-mtj/src/koma.matrix.mtj/backend/internal.kt index ad36b62d..ece6ff9d 100644 --- a/backend-matrix-mtj/src/koma.matrix.mtj/backend/internal.kt +++ b/backend-matrix-mtj/src/koma.matrix.mtj/backend/internal.kt @@ -6,5 +6,3 @@ import koma.matrix.mtj.* internal var factoryInstance: MTJMatrixFactory = MTJMatrixFactory() -internal var curSeed = System.currentTimeMillis() -internal var random = java.util.Random(curSeed) diff --git a/backend-matrix-mtj/src/koma.matrix.mtj/backend/mtj.kt b/backend-matrix-mtj/src/koma.matrix.mtj/backend/mtj.kt index 516b3c1d..dea98824 100644 --- a/backend-matrix-mtj/src/koma.matrix.mtj/backend/mtj.kt +++ b/backend-matrix-mtj/src/koma.matrix.mtj/backend/mtj.kt @@ -137,29 +137,6 @@ fun DenseMatrix.det(): Double { } - -fun rand(rows: Int, cols: Int, seed: Long): DenseMatrix { - if (seed != curSeed) { - random.setSeed(seed) - curSeed = seed - } - return DenseMatrix(rows, cols).mapMat { random.nextDouble() } -} - -fun rand(rows: Int, cols: Int) = rand(rows, cols, curSeed) -fun rand(len: Int, seed: Long) = rand(1, len, seed) -fun rand(len: Int) = rand(len, curSeed) - -fun randn(len: Int) = randn(len, len) -fun randn(rows: Int, cols: Int) = randn(rows, cols, curSeed) -fun randn(rows: Int, cols: Int, seed: Long): DenseMatrix { - if (seed != curSeed) { - random.setSeed(seed) - curSeed = seed - } - return DenseMatrix(rows, cols).mapMat { random.nextGaussian() } -} - object mat { operator fun get(vararg ts: Any): DenseMatrix { // Todo: check for malformed inputs to avoid ambiguous out of bounds exceptions diff --git a/core/src/koma/matrix/MatrixFactory.kt b/core/src/koma/matrix/MatrixFactory.kt index 17030616..c2f3c246 100644 --- a/core/src/koma/matrix/MatrixFactory.kt +++ b/core/src/koma/matrix/MatrixFactory.kt @@ -70,13 +70,6 @@ interface MatrixFactory { @JsName("eye") fun eye(rows: Int, cols: Int): T - /** - * Creates a vector of [size] many uniform 0-1 random samples - */ - @JsName("rand__") - @Deprecated(DEPRECATE_IMPLICIT_2D) - fun rand(size: Int): T - /** * Creates a matrix of uniform 0-1 random samples */ @@ -90,13 +83,6 @@ interface MatrixFactory { @JsName("randSeed") fun rand(rows: Int, cols: Int, seed: Long): T - /** - * Creates a vector of [size] many unit-normal random samples - */ - @JsName("randn__Deprecated") - @Deprecated(DEPRECATE_IMPLICIT_2D) - fun randn(size: Int): T - /** * Creates a matrix of unit-normal random samples */ diff --git a/core/src/koma/matrix/common/DoubleFactoryBase.kt b/core/src/koma/matrix/common/DoubleFactoryBase.kt index 3badb701..232739ca 100644 --- a/core/src/koma/matrix/common/DoubleFactoryBase.kt +++ b/core/src/koma/matrix/common/DoubleFactoryBase.kt @@ -1,5 +1,6 @@ package koma.matrix.common +import koma.DEPRECATE_IMPLICIT_2D import koma.extensions.* import koma.matrix.Matrix import koma.matrix.MatrixFactory @@ -32,4 +33,30 @@ abstract class DoubleFactoryBase> : MatrixFactory { val inc = 1.0 * signum(stop.toDouble() - start.toDouble()) return arange(start.toDouble(), stop.toDouble(), inc) } + + override fun rand(rows: Int, cols: Int) = zeros(rows, cols).also { + it.fill { _, _ -> + koma.platformsupport.rng.nextDouble() + } + } + + override fun randn(rows: Int, cols: Int) = zeros(rows, cols).also { + it.fill { _, _ -> + koma.platformsupport.rng.nextGaussian() + } + } + + + @Deprecated("Call koma.setSeed and rand(row,col) separately") + override fun rand(rows: Int, cols: Int, seed: Long): T { + koma.platformsupport.seed = seed + return rand(rows, cols) + } + + + @Deprecated("Call koma.setSeed and randn(row,col) separately") + override fun randn(rows: Int, cols: Int, seed: Long): T { + koma.platformsupport.seed = seed + return randn(rows, cols) + } } diff --git a/core/src/koma/misc.kt b/core/src/koma/misc.kt index fc1805c4..200a5a57 100644 --- a/core/src/koma/misc.kt +++ b/core/src/koma/misc.kt @@ -17,6 +17,9 @@ val SCIENTIFIC_VERY_LONG_NUMBER = "SciNotVLong" val end = -1 val all = 0..end +fun setSeed(seed: Long) { + koma.platformsupport.seed = seed +} /** * Sets the format for Koma to display numbers in. For example, calling * diff --git a/core/templates/DefaultXMatrixFactory.kt b/core/templates/DefaultXMatrixFactory.kt index 2bbfe285..de0f689e 100644 --- a/core/templates/DefaultXMatrixFactory.kt +++ b/core/templates/DefaultXMatrixFactory.kt @@ -8,7 +8,7 @@ class Default${dtype}MatrixFactory: MatrixFactory> { override fun zeros(rows: Int, cols: Int) = Default${dtype}Matrix(rows, cols) @Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("zeros(size, size)")) - override fun zeros(size: Int): Matrix<${dtype}> + override fun zeros(size: Int): Matrix<${dtype}> = zeros(size, size) override fun create(data: IntRange): Matrix<${dtype}> { @@ -45,10 +45,7 @@ class Default${dtype}MatrixFactory: MatrixFactory> { = zeros(rows, cols) .fill {row,col->if (row==col) 1.to${dtype}() else 0.to${dtype}() } - @Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("rand(size, size)")) - override fun rand(size: Int): Matrix<${dtype}> - = rand(size, size) - + override fun rand(rows: Int, cols: Int): Matrix<${dtype}> = zeros(rows, cols) .fill { _, _ -> koma.platformsupport.rng.nextDouble().to${dtype}()} @@ -60,10 +57,7 @@ class Default${dtype}MatrixFactory: MatrixFactory> { return rand(rows, cols) } - @Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("randn(size, size)")) - override fun randn(size: Int): Matrix<${dtype}> - = randn(size, size) - + override fun randn(rows: Int, cols: Int): Matrix<${dtype}> = zeros(rows, cols) .fill { _, _ -> koma.platformsupport.rng.nextGaussian().to${dtype}()} diff --git a/tests/test/koma/CreatorsTests.kt b/tests/test/koma/CreatorsTests.kt index 11682a11..b124aa96 100644 --- a/tests/test/koma/CreatorsTests.kt +++ b/tests/test/koma/CreatorsTests.kt @@ -137,4 +137,32 @@ class CreatorsTests { assertFalse { (b-c).any { it == 0.0 } } } } + + @Test + fun testSeed() { + allBackends { + setSeed(4) + val a = randn(30,30) + val b = randn(30,30) + setSeed(4) + val c = randn(30,30) + val d = randn(30,30) + setSeed(5) + val e = randn(30,30) + val f = randn(30,30) + + assertMatrixEquals(a, c) + assertMatrixEquals(b, d) + + assertFalse { allclose(a,b) } + assertFalse { allclose(c,d) } + assertFalse { allclose(e,f) } + + assertFalse { allclose(a,d) } + assertFalse { allclose(b,c) } + + assertFalse { allclose(a,e) } + assertFalse { allclose(b,f) } + } + } }