-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
10 changed files
with
140 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package org.jetbrains.bio.viktor | ||
|
||
import org.jetbrains.bio.npy.NpyArray | ||
import org.jetbrains.bio.npy.NpyFile | ||
import org.jetbrains.bio.npy.NpzFile | ||
import java.nio.file.Path | ||
|
||
/** Returns a view of the [NpyArray] as a strided vector. */ | ||
fun NpyArray.asStridedVector() = asDoubleArray().asStrided() | ||
|
||
/** Returns a view of the [NpyArray] as a 2-D strided matrix. */ | ||
fun NpyArray.asStridedMatrix2(): StridedMatrix2 { | ||
val (numRows, numColumns) = shape | ||
return asStridedVector().reshape(numRows, numColumns) | ||
} | ||
|
||
/** Returns a view of the [NpyArray] as a 3-D strided matrix. */ | ||
fun NpyArray.asStridedMatrix3(): StridedMatrix3 { | ||
val (depth, numRows, numColumns) = shape | ||
return asStridedVector().reshape(depth, numRows, numColumns) | ||
} | ||
|
||
/** Writes a given vector to [path] in NPY format. */ | ||
fun NpyFile.write(path: Path, v: StridedVector) { | ||
write(path, v.toArray(), v.shape) | ||
} | ||
|
||
/** Writes a given 2-D matrix to [path] in NPY format. */ | ||
fun NpyFile.write(path: Path, m: StridedMatrix2) { | ||
write(path, m.flatten().toArray(), shape = m.shape) | ||
} | ||
|
||
/** Writes a given 3-D matrix to [path] in NPY format. */ | ||
fun NpyFile.write(path: Path, m: StridedMatrix3) { | ||
write(path, m.flatten().toArray(), m.shape) | ||
} | ||
|
||
/** Adds a given vector to an NPZ format under the specified [name]. */ | ||
fun NpzFile.Writer.write(name: String, v: StridedVector) { | ||
write(name, v.toArray(), v.shape) | ||
} | ||
|
||
/** Writes a given 2-D matrix into an NPZ file under the specified [name]. */ | ||
fun NpzFile.Writer.write(name: String, m: StridedMatrix2) { | ||
write(name, m.flatten().toArray(), m.shape) | ||
} | ||
|
||
/** Writes a given 3-D matrix into an NPZ file under the specified [name]. */ | ||
fun NpzFile.Writer.write(name: String, m: StridedMatrix3) { | ||
write(name, m.flatten().toArray(), m.shape) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
50 changes: 50 additions & 0 deletions
50
src/test/kotlin/org/jetbrains/bio/viktor/SerializationTests.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
package org.jetbrains.bio.viktor | ||
|
||
import org.jetbrains.bio.npy.NpyFile | ||
import org.jetbrains.bio.npy.NpzFile | ||
import org.junit.Test | ||
import kotlin.test.assertEquals | ||
|
||
class TestReadWriteNpy { | ||
@Test fun vector() = withTempFile("v", ".npy") { path -> | ||
val v = StridedVector.of(1.0, 2.0, 3.0, 4.0) | ||
NpyFile.write(path, v) | ||
assertEquals(v, NpyFile.read(path).asStridedVector()) | ||
} | ||
|
||
@Test fun matrix2() = withTempFile("m2", ".npy") { path -> | ||
val m = StridedVector.of(1.0, 2.0, 3.0, 4.0, 5.0, 6.0).reshape(2, 3) | ||
NpyFile.write(path, m) | ||
assertEquals(m, NpyFile.read(path).asStridedMatrix2()) | ||
} | ||
|
||
@Test fun matrix3() = withTempFile("m3", ".npy") { path -> | ||
val m = StridedVector.of(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0) | ||
.reshape(1, 4, 2) | ||
NpyFile.write(path, m) | ||
assertEquals(m, NpyFile.read(path).asStridedMatrix3()) | ||
} | ||
} | ||
|
||
class TestReadWriteNpz { | ||
@Test fun combined() { | ||
val v = StridedVector.of(1.0, 2.0, 3.0, 4.0) | ||
val m2 = StridedVector.of(1.0, 2.0, 3.0, 4.0, 5.0, 6.0).reshape(2, 3) | ||
val m3 = StridedVector.of(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0) | ||
.reshape(1, 4, 2) | ||
|
||
withTempFile("vm2m3", ".npz") { path -> | ||
NpzFile.write(path).use { | ||
it.write("v", v) | ||
it.write("m2", m2) | ||
it.write("m3", m3) | ||
} | ||
|
||
NpzFile.read(path).use { | ||
assertEquals(v, it["v"].asStridedVector()) | ||
assertEquals(m2, it["m2"].asStridedMatrix2()) | ||
assertEquals(m3, it["m3"].asStridedMatrix3()) | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
package org.jetbrains.bio.viktor | ||
|
||
import java.io.IOException | ||
import java.nio.file.Files | ||
import java.nio.file.Path | ||
|
||
internal inline fun withTempFile(prefix: String, suffix: String, | ||
block: (Path) -> Unit) { | ||
val path = Files.createTempFile(prefix, suffix) | ||
try { | ||
block(path) | ||
} finally { | ||
try { | ||
Files.delete(path) | ||
} catch (e: IOException) { | ||
// Mmaped buffer not yet garbage collected. Leave it to the VM. | ||
path.toFile().deleteOnExit() | ||
} | ||
} | ||
} |