Skip to content

Commit

Permalink
implement basic-types/cuda hdf5 support
Browse files Browse the repository at this point in the history
  • Loading branch information
ShigekiKarita committed Jun 25, 2018
1 parent 38a02c1 commit dcf2201
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 13 deletions.
6 changes: 6 additions & 0 deletions example/mnist.d
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import snck : snck;
import grain.autograd;
import grain.chain; // : Linear, relu;
import grain.optim; // : zeroGrad;
import grain.serializer : save, load;

enum files = [
"train-images-idx3-ubyte",
Expand Down Expand Up @@ -115,6 +116,7 @@ version (grain_cuda) {
}

void main() {
import std.file : exists;
RNG.setSeed(0);
grain.autograd.backprop = true;
auto datasets = prepareDataset();
Expand All @@ -124,6 +126,9 @@ void main() {
auto testBatch = datasets.test.makeBatch(batchSize);
auto model = Model!(float, S)(inSize, 512, 10);
auto optimizer = SGD!(typeof(model))(model, 1e-2);
if ("mnist.h5".exists) {
model.load("mnist.h5");
}

foreach (epoch; 0 .. 10) {
// TODO implement model.train();
Expand Down Expand Up @@ -159,5 +164,6 @@ void main() {
}
writefln!"test loss: %f, acc: %f"(lossSum / niter, accSum / niter);
}
model.save("mnist.h5");
}
}
70 changes: 57 additions & 13 deletions source/grain/serializer.d
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module grain.serializer;

import std.stdio;
import grain.autograd;
import hdf5.hdf5;

enum variableNames(C) = {
string[] ret;
Expand Down Expand Up @@ -71,11 +72,23 @@ version (unittest) {
}
}

/// https://support.hdfgroup.org/HDF5/doc1.8/RM/PredefDTypes.html
auto toH5Type(T)() {
import std.traits;
import std.format;
static assert(isBasicType!T);
mixin("return H5T_%s%dLE;".format(
isFloatingPoint!T
? "IEEE_F"
: (isSigned!T ? "STD_I" : "STD_U"),
T.sizeof * 8
));
}

void save(bool verbose = true, C)(C chain, string path) {
import std.file : exists;
import std.string : replace, endsWith;
import mir.ndslice : slice;
import hdf5.hdf5;
import grain.utility : castArray;
auto file = H5F.create(path,
// path.exists ? H5F_ACC_RDWR :
Expand All @@ -86,19 +99,18 @@ void save(bool verbose = true, C)(C chain, string path) {
// H5P.set_alloc_time(property, H5DAllocTime.Early); // canbe Late
// scope(exit) H5P.close(property);

void register(V)(string k, V v) if (isVariable!V) {
void register(T, size_t dim, alias Storage)(string k, Variable!(T, dim, Storage) v) {
auto h = v.to!HostStorage;
// FIXME support check contiguous
// auto s = h.sliced.slice;
auto data = v.to!HostStorage.data;
auto space = H5S.create_simple(h.shape.castArray!hsize_t);
scope(exit) H5S.close(space);
auto h5key = "/" ~ k.replace(".", "_");
// FIXME support non-float type
auto dataset = H5D.create2(file, h5key, H5T_IEEE_F32LE, space,
auto dataset = H5D.create2(file, "/" ~ k, toH5Type!T, space,
H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);
scope(exit) H5D.close(dataset);
H5D.write(dataset, H5T_NATIVE_FLOAT, H5S_ALL, H5S_ALL, H5P_DEFAULT,
H5D.write(dataset, toH5Type!T, H5S_ALL, H5S_ALL, H5P_DEFAULT,
cast(ubyte*) data.ptr);
// auto raw = new float[v.data.length];
// H5D.read(dataset, H5T_NATIVE_FLOAT, H5S_ALL, H5S_ALL, H5P_DEFAULT,
Expand All @@ -121,28 +133,60 @@ void load(C)(ref C chain, string path) {

void register(T, size_t dim, alias Storage)(string k, ref Variable!(T, dim, Storage) v) {
// writeln(k, v.sliced);
auto h5key = "/" ~ k.replace(".", "_");
// writeln(h5key);
auto dataset = H5D.open2(file, h5key, H5P_DEFAULT);
auto dataset = H5D.open2(file, "/" ~ k, H5P_DEFAULT);
scope(exit) H5D.close(dataset);
// FIXME support non-float type
auto raw = new float[v.data.length];
H5D.read(dataset, H5T_NATIVE_FLOAT, H5S_ALL, H5S_ALL, H5P_DEFAULT,
auto raw = new T[v.data.length];
H5D.read(dataset, toH5Type!T, H5S_ALL, H5S_ALL, H5P_DEFAULT,
cast(ubyte*) &raw[0]);
auto src = raw.sliced(v.shape.castArray!size_t).variable;
// TODO cuda support
v.data[] = src.to!Storage.data;
v.strides = src.strides;
static if (is(Storage!T == HostStorage!T)) {
v.sliced[] = src.sliced;
} else {
import grain.cudnn : transform;
transform(src.to!Storage, v);
}
}
refIterVariables!( (k, ref v) { register(k, v); })(chain, "");
}

///
unittest {
import numir;
auto model1 = MLP!(float, HostStorage)(3);
model1.save("test_grain.h5");

auto model2 = MLP!(float, HostStorage)(3);
model2.load("test_grain.h5");
assert(model1.fc1.bias.sliced == model2.fc1.bias.sliced);
}

///
version (grain_cuda) unittest {
auto model1 = MLP!(float, DeviceStorage)(3);
model1.save("test_grain.h5");

auto model2 = MLP!(float, DeviceStorage)(3);
model2.load("test_grain.h5");
assert(model1.fc1.bias.to!HostStorage.sliced == model2.fc1.bias.to!HostStorage.sliced);
}

///
version (grain_cuda) unittest {
auto model1 = MLP!(float, HostStorage)(3);
model1.save("test_grain.h5");

auto model2 = MLP!(float, DeviceStorage)(3);
model2.load("test_grain.h5");
assert(model1.fc1.bias.to!HostStorage.sliced == model2.fc1.bias.to!HostStorage.sliced);
}

///
version (grain_cuda) unittest {
auto model1 = MLP!(float, DeviceStorage)(3);
model1.save("test_grain.h5");

auto model2 = MLP!(float, HostStorage)(3);
model2.load("test_grain.h5");
assert(model1.fc1.bias.to!HostStorage.sliced == model2.fc1.bias.to!HostStorage.sliced);
}

0 comments on commit dcf2201

Please sign in to comment.