Skip to content

Commit

Permalink
Use idiomatic Kotlin for greater good
Browse files Browse the repository at this point in the history
  • Loading branch information
superbobry committed Jun 27, 2016
1 parent 94b7709 commit c2fd5b3
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 82 deletions.
2 changes: 0 additions & 2 deletions src/main/kotlin/org/jetbrains/bio/viktor/DoubleExtensions.kt
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package org.jetbrains.bio.viktor

import org.jetbrains.bio.viktor.NativeSpeedups

/**
* Operator overloads for [Double] and [StridedVector].
*
Expand Down
104 changes: 24 additions & 80 deletions src/main/kotlin/org/jetbrains/bio/viktor/StridedVector.kt
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,7 @@ open class StridedVector internal constructor(
*/
open fun max() = unsafeGet(argMax())

fun exp(): StridedVector {
val copy = copy()
copy.expInPlace()
return copy
}
fun exp() = copy().apply { expInPlace() }

/**
* Computes the exponent of each element of this vector.
Expand All @@ -254,11 +250,7 @@ open class StridedVector internal constructor(
}
}

fun expm1(): StridedVector {
val copy = copy()
copy.expm1InPlace()
return copy
}
fun expm1() = copy().apply { expm1InPlace() }

/**
* Computes exp(x) - 1 for each element of this vector.
Expand All @@ -273,11 +265,7 @@ open class StridedVector internal constructor(
}
}

fun log(): StridedVector {
val copy = copy()
copy.logInPlace()
return copy
}
fun log() = copy().apply { logInPlace() }

/**
* Computes the natural log of each element of this vector.
Expand All @@ -290,11 +278,7 @@ open class StridedVector internal constructor(
}
}

fun log1p(): StridedVector {
val copy = copy()
copy.log1pInPlace()
return copy
}
fun log1p() = copy().apply { log1pInPlace() }

/**
* Computes log(1 + x) for each element of this vector.
Expand All @@ -315,12 +299,11 @@ open class StridedVector internal constructor(
* The operation is done **in place**.
*/
fun rescale() {
val total = sum() + Precision.EPSILON * size.toDouble()
this /= total
this /= sum() + Precision.EPSILON * size.toDouble()
}

/**
* Rescales the element so that the exponent of the sum is 1.
* Rescales the element so that the exponent of the sum is 1.0.
*
* Optimized for dense vectors.
*
Expand Down Expand Up @@ -374,11 +357,7 @@ open class StridedVector internal constructor(
return v
}

operator fun plus(other: StridedVector): StridedVector {
val v = copy()
v += other
return v
}
operator fun plus(other: StridedVector) = copy().apply { this += other }

operator open fun plusAssign(other: StridedVector) {
checkSize(other)
Expand All @@ -387,23 +366,15 @@ open class StridedVector internal constructor(
}
}

operator fun plus(update: Double): StridedVector {
val v = copy()
v += update
return v
}
operator fun plus(update: Double) = copy().apply { this += update }

operator open fun plusAssign(update: Double) {
for (pos in 0..size - 1) {
unsafeSet(pos, unsafeGet(pos) + update)
}
}

operator fun minus(other: StridedVector): StridedVector {
val v = copy()
v -= other
return v
}
operator fun minus(other: StridedVector) = copy().apply { this -= other }

operator open fun minusAssign(other: StridedVector) {
checkSize(other)
Expand All @@ -412,35 +383,15 @@ open class StridedVector internal constructor(
}
}

operator fun minus(update: Double): StridedVector {
val v = copy()
v -= update
return v
}
operator fun minus(update: Double) = copy().apply { this -= update }

operator open fun minusAssign(update: Double) {
for (pos in 0..size - 1) {
unsafeSet(pos, unsafeGet(pos) - update)
}
}

operator fun times(value: Double): StridedVector {
val v = copy()
v *= value
return v
}

operator open fun timesAssign(update: Double) {
for (pos in 0..size - 1) {
unsafeSet(pos, unsafeGet(pos) * update)
}
}

operator fun times(other: StridedVector): StridedVector {
val v = copy()
v *= other
return v
}
operator fun times(other: StridedVector) = copy().apply { this *= other }

operator open fun timesAssign(other: StridedVector) {
checkSize(other)
Expand All @@ -449,23 +400,15 @@ open class StridedVector internal constructor(
}
}

operator fun div(value: Double): StridedVector {
val v = copy()
v /= value
return v
}
operator fun times(update: Double) = copy().apply { this *= update }

operator open fun divAssign(update: Double) {
operator open fun timesAssign(update: Double) {
for (pos in 0..size - 1) {
unsafeSet(pos, unsafeGet(pos) / update)
unsafeSet(pos, unsafeGet(pos) * update)
}
}

operator fun div(other: StridedVector): StridedVector {
val v = copy()
v /= other
return v
}
operator fun div(other: StridedVector) = copy().apply { this /= other }

operator open fun divAssign(other: StridedVector) {
checkSize(other)
Expand All @@ -474,19 +417,20 @@ open class StridedVector internal constructor(
}
}

fun isEmpty() = size == 0
operator fun div(update: Double) = copy().apply { this /= update }

fun isNotEmpty() = size > 0

open fun toArray(): DoubleArray {
val res = DoubleArray(size)
operator open fun divAssign(update: Double) {
for (pos in 0..size - 1) {
res[pos] = unsafeGet(pos)
unsafeSet(pos, unsafeGet(pos) / update)
}

return res
}

fun isEmpty() = size == 0

fun isNotEmpty() = size > 0

open fun toArray() = DoubleArray(size) { unsafeGet(it) }

/** Creates an iterator over the elements of the array. */
operator fun iterator(): DoubleIterator = object : DoubleIterator() {
var i = 0
Expand Down

0 comments on commit c2fd5b3

Please sign in to comment.