Skip to content

Commit

Permalink
Update (very) stale rng code to use platform support that was introdu…
Browse files Browse the repository at this point in the history
…ced later. Closes #39.
  • Loading branch information
kyonifer committed Jan 23, 2018
1 parent a077fa0 commit 7bd793d
Show file tree
Hide file tree
Showing 13 changed files with 61 additions and 166 deletions.
35 changes: 0 additions & 35 deletions backend-matrix-cblas/src/koma.matrix.cblas/CBlasMatrixFactory.kt
Original file line number Diff line number Diff line change
Expand Up @@ -63,39 +63,4 @@ class CBlasMatrixFactory: DoubleFactoryBase<CBlasMatrix>() {
it[row, col] = 0
}
}

@Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("rand(size, size)"))
override fun rand(size: Int): CBlasMatrix {
return rand(size, size)
}

override fun rand(rows: Int, cols: Int): CBlasMatrix
= zeros(rows, cols).also {
it.forEachIndexed { row, col, _ ->
it[row, col] = rng.nextDouble()
}
}

override fun rand(rows: Int, cols: Int, seed: Long): CBlasMatrix {
rng.setSeed(seed.toInt())
return rand(rows, cols)
}

@Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("randn(size, size)"))
override fun randn(size: Int): CBlasMatrix {
return randn(size, size)
}

override fun randn(rows: Int, cols: Int): CBlasMatrix
= zeros(rows, cols).also {
it.forEachIndexed { row, col, _ ->
it[row, col] = rng.nextGaussian()
}
}

override fun randn(rows: Int, cols: Int, seed: Long): CBlasMatrix {
rng.setSeed(seed.toInt())
return randn(rows, cols)
}

}
26 changes: 0 additions & 26 deletions backend-matrix-ejml/src/koma.matrix.ejml/EJMLMatrixFactory.kt
Original file line number Diff line number Diff line change
Expand Up @@ -45,31 +45,5 @@ class EJMLMatrixFactory : DoubleFactoryBase<EJMLMatrix>() {
return EJMLMatrix(out)
}

@Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("rand(size, size)"))
override fun rand(size: Int): EJMLMatrix {
return rand(size, size)
}

override fun rand(rows: Int, cols: Int): EJMLMatrix {
return EJMLMatrix(koma.matrix.ejml.backend.rand(rows, cols))
}

override fun rand(rows: Int, cols: Int, seed: Long): EJMLMatrix {
return EJMLMatrix(koma.matrix.ejml.backend.rand(rows, cols, seed))
}

@Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("randn(size, size)"))
override fun randn(size: Int): EJMLMatrix {
return EJMLMatrix(koma.matrix.ejml.backend.randn(size))
}

override fun randn(rows: Int, cols: Int): EJMLMatrix {
return EJMLMatrix(koma.matrix.ejml.backend.randn(rows, cols))
}

override fun randn(rows: Int, cols: Int, seed: Long): EJMLMatrix {
return EJMLMatrix(koma.matrix.ejml.backend.randn(rows, cols, seed))
}

}

24 changes: 0 additions & 24 deletions backend-matrix-ejml/src/koma.matrix.ejml/backend/ejml.kt
Original file line number Diff line number Diff line change
Expand Up @@ -76,30 +76,6 @@ fun ones(rows: Int, cols: Int): SimpleMatrix {
return out
}

fun rand(rows: Int, cols: Int, seed: Long): SimpleMatrix {
if (seed != curSeed) {
random.setSeed(seed)
curSeed = seed
}
return SimpleMatrix.random(rows, cols, 0.0, 1.0, random)
}
fun rand(rows: Int, cols: Int) = rand(rows, cols, curSeed)
fun rand(len: Int) = rand(1, len, curSeed)

fun randn(rows: Int, cols: Int, seed: Long): SimpleMatrix {
if (seed != curSeed) {
random.setSeed(seed)
curSeed = seed
}
val out = SimpleMatrix(rows, cols)
for (i in 0..rows - 1)
for (j in 0..cols - 1)
out[i, j] = random.nextGaussian()
return out
}
fun randn(len: Int) = randn(len, len, curSeed)
fun randn(rows: Int, cols: Int) = randn(rows, cols, curSeed)

fun SimpleMatrix.map(f: (Double) -> Double): SimpleMatrix {
val out = SimpleMatrix(this.numRows(), this.numCols())
for (row in 0..this.numRows() - 1)
Expand Down
3 changes: 0 additions & 3 deletions backend-matrix-ejml/src/koma.matrix.ejml/backend/internal.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,3 @@ package koma.matrix.ejml.backend
import koma.matrix.ejml.*

internal var factoryInstance: EJMLMatrixFactory = EJMLMatrixFactory()
internal var curSeed = System.currentTimeMillis()
internal var random = java.util.Random(curSeed)

20 changes: 0 additions & 20 deletions backend-matrix-jblas/src/koma.matrix.jblas/JBlasMatrixFactory.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,4 @@ class JBlasMatrixFactory : DoubleFactoryBase<JBlasMatrix>() {
out[i, i] = 1.0
return JBlasMatrix(out)
}

@Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("rand(size, size)"))
override fun rand(size: Int) = JBlasMatrix(koma.matrix.jblas.backend.rand(size))
override fun rand(rows: Int, cols: Int) = JBlasMatrix(koma.matrix.jblas.backend.rand(rows, cols))

override fun rand(rows: Int, cols: Int, seed: Long): JBlasMatrix {
println("Warning: JBlas RNG doesnt support seeds")
return JBlasMatrix(koma.matrix.jblas.backend.rand(rows, cols))
}

@Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("randn(size, size)"))
override fun randn(size: Int) = JBlasMatrix(koma.matrix.jblas.backend.randn(size))

override fun randn(rows: Int, cols: Int) = JBlasMatrix(koma.matrix.jblas.backend.randn(rows, cols))

override fun randn(rows: Int, cols: Int, seed: Long): JBlasMatrix {
println("Warning: JBlas RNG doesnt support seeds")
return JBlasMatrix(koma.matrix.jblas.backend.randn(rows, cols))
}

}
10 changes: 0 additions & 10 deletions backend-matrix-mtj/src/koma.matrix.mtj/MTJMatrixFactory.kt
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,4 @@ class MTJMatrixFactory : DoubleFactoryBase<MTJMatrix>() {
return MTJMatrix(out)
}

@Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("rand(size, size)"))
override fun rand(size: Int) = rand(size, size)
override fun rand(rows: Int, cols: Int) = MTJMatrix(koma.matrix.mtj.backend.rand(rows, cols))
override fun rand(rows: Int, cols: Int, seed: Long) = MTJMatrix(koma.matrix.mtj.backend.rand(rows, cols, seed))

@Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("randn(size, size)"))
override fun randn(size: Int) = randn(size, size)
override fun randn(rows: Int, cols: Int) = MTJMatrix(koma.matrix.mtj.backend.randn(rows, cols))
override fun randn(rows: Int, cols: Int, seed: Long) = MTJMatrix(koma.matrix.mtj.backend.randn(rows, cols, seed))

}
2 changes: 0 additions & 2 deletions backend-matrix-mtj/src/koma.matrix.mtj/backend/internal.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,3 @@ import koma.matrix.mtj.*


internal var factoryInstance: MTJMatrixFactory = MTJMatrixFactory()
internal var curSeed = System.currentTimeMillis()
internal var random = java.util.Random(curSeed)
23 changes: 0 additions & 23 deletions backend-matrix-mtj/src/koma.matrix.mtj/backend/mtj.kt
Original file line number Diff line number Diff line change
Expand Up @@ -137,29 +137,6 @@ fun DenseMatrix.det(): Double {

}


fun rand(rows: Int, cols: Int, seed: Long): DenseMatrix {
if (seed != curSeed) {
random.setSeed(seed)
curSeed = seed
}
return DenseMatrix(rows, cols).mapMat { random.nextDouble() }
}

fun rand(rows: Int, cols: Int) = rand(rows, cols, curSeed)
fun rand(len: Int, seed: Long) = rand(1, len, seed)
fun rand(len: Int) = rand(len, curSeed)

fun randn(len: Int) = randn(len, len)
fun randn(rows: Int, cols: Int) = randn(rows, cols, curSeed)
fun randn(rows: Int, cols: Int, seed: Long): DenseMatrix {
if (seed != curSeed) {
random.setSeed(seed)
curSeed = seed
}
return DenseMatrix(rows, cols).mapMat { random.nextGaussian() }
}

object mat {
operator fun get(vararg ts: Any): DenseMatrix {
// Todo: check for malformed inputs to avoid ambiguous out of bounds exceptions
Expand Down
14 changes: 0 additions & 14 deletions core/src/koma/matrix/MatrixFactory.kt
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,6 @@ interface MatrixFactory<out T> {
@JsName("eye")
fun eye(rows: Int, cols: Int): T

/**
* Creates a vector of [size] many uniform 0-1 random samples
*/
@JsName("rand__")
@Deprecated(DEPRECATE_IMPLICIT_2D)
fun rand(size: Int): T

/**
* Creates a matrix of uniform 0-1 random samples
*/
Expand All @@ -90,13 +83,6 @@ interface MatrixFactory<out T> {
@JsName("randSeed")
fun rand(rows: Int, cols: Int, seed: Long): T

/**
* Creates a vector of [size] many unit-normal random samples
*/
@JsName("randn__Deprecated")
@Deprecated(DEPRECATE_IMPLICIT_2D)
fun randn(size: Int): T

/**
* Creates a matrix of unit-normal random samples
*/
Expand Down
27 changes: 27 additions & 0 deletions core/src/koma/matrix/common/DoubleFactoryBase.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package koma.matrix.common

import koma.DEPRECATE_IMPLICIT_2D
import koma.extensions.*
import koma.matrix.Matrix
import koma.matrix.MatrixFactory
Expand Down Expand Up @@ -32,4 +33,30 @@ abstract class DoubleFactoryBase<T: Matrix<Double>> : MatrixFactory<T> {
val inc = 1.0 * signum(stop.toDouble() - start.toDouble())
return arange(start.toDouble(), stop.toDouble(), inc)
}

override fun rand(rows: Int, cols: Int) = zeros(rows, cols).also {
it.fill { _, _ ->
koma.platformsupport.rng.nextDouble()
}
}

override fun randn(rows: Int, cols: Int) = zeros(rows, cols).also {
it.fill { _, _ ->
koma.platformsupport.rng.nextGaussian()
}
}


@Deprecated("Call koma.setSeed and rand(row,col) separately")
override fun rand(rows: Int, cols: Int, seed: Long): T {
koma.platformsupport.seed = seed
return rand(rows, cols)
}


@Deprecated("Call koma.setSeed and randn(row,col) separately")
override fun randn(rows: Int, cols: Int, seed: Long): T {
koma.platformsupport.seed = seed
return randn(rows, cols)
}
}
3 changes: 3 additions & 0 deletions core/src/koma/misc.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ val SCIENTIFIC_VERY_LONG_NUMBER = "SciNotVLong"
val end = -1
val all = 0..end

fun setSeed(seed: Long) {
koma.platformsupport.seed = seed
}
/**
* Sets the format for Koma to display numbers in. For example, calling
*
Expand Down
12 changes: 3 additions & 9 deletions core/templates/DefaultXMatrixFactory.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class Default${dtype}MatrixFactory: MatrixFactory<Matrix<${dtype}>> {
override fun zeros(rows: Int, cols: Int)
= Default${dtype}Matrix(rows, cols)
@Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("zeros(size, size)"))
override fun zeros(size: Int): Matrix<${dtype}>
override fun zeros(size: Int): Matrix<${dtype}>
= zeros(size, size)

override fun create(data: IntRange): Matrix<${dtype}> {
Expand Down Expand Up @@ -45,10 +45,7 @@ class Default${dtype}MatrixFactory: MatrixFactory<Matrix<${dtype}>> {
= zeros(rows, cols)
.fill {row,col->if (row==col) 1.to${dtype}() else 0.to${dtype}() }

@Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("rand(size, size)"))
override fun rand(size: Int): Matrix<${dtype}>
= rand(size, size)


override fun rand(rows: Int, cols: Int): Matrix<${dtype}>
= zeros(rows, cols)
.fill { _, _ -> koma.platformsupport.rng.nextDouble().to${dtype}()}
Expand All @@ -60,10 +57,7 @@ class Default${dtype}MatrixFactory: MatrixFactory<Matrix<${dtype}>> {
return rand(rows, cols)
}

@Deprecated(DEPRECATE_IMPLICIT_2D, ReplaceWith("randn(size, size)"))
override fun randn(size: Int): Matrix<${dtype}>
= randn(size, size)


override fun randn(rows: Int, cols: Int): Matrix<${dtype}>
= zeros(rows, cols)
.fill { _, _ -> koma.platformsupport.rng.nextGaussian().to${dtype}()}
Expand Down
28 changes: 28 additions & 0 deletions tests/test/koma/CreatorsTests.kt
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,32 @@ class CreatorsTests {
assertFalse { (b-c).any { it == 0.0 } }
}
}

@Test
fun testSeed() {
allBackends {
setSeed(4)
val a = randn(30,30)
val b = randn(30,30)
setSeed(4)
val c = randn(30,30)
val d = randn(30,30)
setSeed(5)
val e = randn(30,30)
val f = randn(30,30)

assertMatrixEquals(a, c)
assertMatrixEquals(b, d)

assertFalse { allclose(a,b) }
assertFalse { allclose(c,d) }
assertFalse { allclose(e,f) }

assertFalse { allclose(a,d) }
assertFalse { allclose(b,c) }

assertFalse { allclose(a,e) }
assertFalse { allclose(b,f) }
}
}
}

0 comments on commit 7bd793d

Please sign in to comment.