Skip to content

Commit

Permalink
update serializer
Browse files Browse the repository at this point in the history
  • Loading branch information
ShigekiKarita committed Jun 26, 2018
1 parent 5c38c07 commit 3d7b018
Showing 1 changed file with 15 additions and 26 deletions.
41 changes: 15 additions & 26 deletions source/grain/serializer.d
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import std.stdio;
import grain.autograd;
import hdf5.hdf5;

/// enumerate the parameter names inside chain C
enum variableNames(C) = {
string[] ret;
void register(V)(string k, V v) if (isVariable!V) {
Expand All @@ -22,24 +23,8 @@ unittest {
[".fc1.weight", ".fc1.bias", ".fc2.weight", ".fc2.bias", ".fc3.weight", ".fc3.bias"]);
}

auto variableDict(C)(C chain) {
UntypedVariable[string] ret;
void register(V)(string k, V v) if (isVariable!V) {
ret[k] = UntypedVariable(v);
}
iterVariables!( (k, v) { register(k, v); })(&chain, "");
return ret;
}

unittest {
import std.traits;
auto mlp = MLP!(float, HostStorage)(3);
auto keys = [".fc1.weight", ".fc1.bias",
".fc2.weight", ".fc2.bias",
".fc3.weight", ".fc3.bias"];
}

/// test .slice makes slice contiguous
// test .slice makes slice contiguous
unittest {
import numir;
import mir.ndslice;
Expand Down Expand Up @@ -72,7 +57,7 @@ version (unittest) {
}
}

/// https://support.hdfgroup.org/HDF5/doc1.8/RM/PredefDTypes.html
/// convert D type into HDF5 type-id https://support.hdfgroup.org/HDF5/doc1.8/RM/PredefDTypes.html
auto toH5Type(T)() {
import std.traits;
import std.format;
Expand All @@ -85,19 +70,16 @@ auto toH5Type(T)() {
));
}

/// save chain parameters to HDF5 path
void save(bool verbose = true, C)(C chain, string path) {
import std.file : exists;
import std.string : replace, endsWith;
import mir.ndslice : slice;
import grain.utility : castArray;
auto file = H5F.create(path,
// path.exists ? H5F_ACC_RDWR :
H5F_ACC_TRUNC,
H5P_DEFAULT, H5P_DEFAULT);
scope(exit) H5F.close(file);
// auto property = H5P.create (H5P_DATASET_CREATE);
// H5P.set_alloc_time(property, H5DAllocTime.Early); // canbe Late
// scope(exit) H5P.close(property);

void register(T, size_t dim, alias Storage)(string k, Variable!(T, dim, Storage) v) {
auto h = v.to!HostStorage;
Expand All @@ -106,19 +88,16 @@ void save(bool verbose = true, C)(C chain, string path) {
auto data = v.to!HostStorage.data;
auto space = H5S.create_simple(h.shape.castArray!hsize_t);
scope(exit) H5S.close(space);
// FIXME support non-float type
auto dataset = H5D.create2(file, "/" ~ k, toH5Type!T, space,
H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);
scope(exit) H5D.close(dataset);
H5D.write(dataset, toH5Type!T, H5S_ALL, H5S_ALL, H5P_DEFAULT,
cast(ubyte*) data.ptr);
// auto raw = new float[v.data.length];
// H5D.read(dataset, H5T_NATIVE_FLOAT, H5S_ALL, H5S_ALL, H5P_DEFAULT,
// cast(ubyte*) raw.ptr);
}
iterVariables!( (k, v) { register(k, v); })(&chain, "");
}

/// load chain parameters from HDF5 path
void load(C)(ref C chain, string path) {
import std.string : replace, endsWith;
import mir.ndslice : slice, sliced;
Expand Down Expand Up @@ -159,6 +138,11 @@ unittest {
auto model2 = MLP!(float, HostStorage)(3);
model2.load("test_grain.h5");
assert(model1.fc1.bias.sliced == model2.fc1.bias.sliced);

import numir;
import mir.ndslice;
auto x = uniform!float(3, 2).slice.variable;
assert(model1(x).sliced == model2(x).sliced);
}

///
Expand All @@ -169,6 +153,11 @@ version (grain_cuda) unittest {
auto model2 = MLP!(float, DeviceStorage)(3);
model2.load("test_grain.h5");
assert(model1.fc1.bias.to!HostStorage.sliced == model2.fc1.bias.to!HostStorage.sliced);

import numir;
import mir.ndslice;
auto x = uniform!float(3, 2).slice.variable.to!DeviceStorage;
assert(model1(x).to!HostStorage.sliced == model2(x).to!HostStorage.sliced);
}

///
Expand Down

0 comments on commit 3d7b018

Please sign in to comment.