
Commit

refactor optim
ShigekiKarita committed Jun 19, 2018
1 parent e20852a commit bb36473
Showing 3 changed files with 38 additions and 26 deletions.
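
In the listings below, lines removed by this commit are prefixed with "-", lines added with "+", and unchanged context lines have no prefix. The gist of the refactor: every optimizer now stores a pointer to its target chain, set in the constructor, and is stepped with a no-argument update() instead of update(model); SGD additionally becomes a template SGD(Chain) like the other optimizers. A minimal, self-contained sketch of that pattern follows; Model and MySGD are illustrative stand-ins, not grain's actual types.

```d
import std.stdio;

// toy stand-in for a grain chain: one parameter vector and its gradient
struct Model {
    float[3] weight = [0.0f, 0.0f, 0.0f];
    float[3] grad   = [1.0f, 0.0f, 0.0f];
}

// optimizer templated on the chain type, holding a pointer to its target
struct MySGD(Chain) {
    Chain* target;
    float lr = 1.0;

    this(ref Chain target, float lr = 1.0) {
        this.target = &target;
        this.lr = lr;
    }

    // no model argument: the stored target is updated in place
    void update() {
        foreach (i; 0 .. this.target.weight.length)
            this.target.weight[i] -= this.lr * this.target.grad[i];
    }
}

void main() {
    Model model;
    auto optim = MySGD!(typeof(model))(model, 0.5);
    optim.update();
    writeln(model.weight); // [-0.5, 0, 0]
}
```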
example/char_rnn.d: 5 changes (2 additions & 3 deletions)
@@ -130,7 +130,6 @@ void main() {
auto model = RNN!Storage(vocabSize, hiddenSize);

// for optim
- // SGD optim = { lr: learningRate };
auto optim = AdaGrad!(typeof(model))(model, learningRate);
auto smoothLoss = -log(1.0 / vocabSize) * seqLength;
size_t beginId = 0;
@@ -154,9 +153,9 @@
// forward seq_length characters through the net and fetch gradient
model.zeroGrad();
auto ret = model.accumGrad(ids.sliced.unsqueeze!0, hprev);
- hprev = ret.hprev;
- optim.update(model);
+ optim.update();
smoothLoss = smoothLoss * 0.999 + ret.loss * 0.001;
+ hprev = ret.hprev;
if (nIter % logIter == 0) {
writefln!"iter %d, loss: %f, iter/sec: %f"(
nIter, smoothLoss,
example/mnist.d: 4 changes (2 additions & 2 deletions)
@@ -118,7 +118,7 @@ void main() {
auto trainBatch = datasets.train.makeBatch(batchSize);
auto testBatch = datasets.test.makeBatch(batchSize);
auto model = MLP!(float, S)(inSize, 512, 10);
- SGD optimizer = {lr: 1e-2};
+ auto optimizer = SGD!(typeof(model))(model, 1e-2);

foreach (epoch; 0 .. 10) {
// TODO implement model.train();
@@ -135,7 +135,7 @@ void main() {
accSum += acc;
model.zeroGrad();
loss.backward();
- optimizer.update(model);
+ optimizer.update();
}
writefln!"train loss: %f, acc: %f"(lossSum / niter, accSum / niter);
}
source/grain/optim.d: 55 changes (34 additions & 21 deletions)
@@ -33,7 +33,7 @@ enum bool isOptimizer(T) = is(typeof({


/// kind of std.algorithm.each for iterating variables inside a chain
- void iterVariables(alias proc, C)(ref C chain, string prefix="") {
+ void iterVariables(alias proc, C)(C* chain, string prefix="") {
import std.traits;
import grain.autograd;

@@ -44,11 +44,12 @@ void iterVariables(alias proc, C)(ref C chain, string prefix="") {
static if (isVariable!V) {
proc(fullName, value);
} else static if (hasMember!(V, "tupleof")) {
- iterVariables!proc(value, fullName);
+ iterVariables!proc(&value, fullName);
}
}
}
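
For readers unfamiliar with the reflection used here: iterVariables walks a chain's fields at compile time and calls proc on every variable, recursing into nested structs via their tupleof. The sketch below mirrors that traversal pattern with made-up types (Leaf, Inner, Net) and std.traits.FieldNameTuple; grain's actual implementation iterates its own Variable types and may differ in detail.

```d
import std.stdio;
import std.traits : FieldNameTuple, hasMember;

struct Leaf  { float value; }          // stands in for grain's Variable
struct Inner { Leaf w; }
struct Net   { Leaf bias; Inner fc; }

// visit every Leaf reachable from a pointer to a (possibly nested) struct
void visit(alias proc, C)(C* chain, string prefix = "") {
    foreach (name; FieldNameTuple!C) {
        auto fullName = prefix ~ "." ~ name;
        alias V = typeof(__traits(getMember, *chain, name));
        static if (is(V == Leaf)) {
            proc(fullName, __traits(getMember, *chain, name));
        } else static if (hasMember!(V, "tupleof")) {
            // recurse with a pointer to the nested struct, like iterVariables does
            visit!proc(&__traits(getMember, *chain, name), fullName);
        }
    }
}

void main() {
    auto net = Net(Leaf(0.1f), Inner(Leaf(0.2f)));
    visit!((k, v) { writefln!"%s = %s"(k, v.value); })(&net);
    // prints ".bias = 0.1" and ".fc.w = 0.2"
}
```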

+ /*
enum variableNames(C) = {
string[] ret;
void register(V)(string k, V v) if (isVariable!V) {
@@ -70,12 +71,13 @@ unittest {
StateDict dict;
iterVariables!( (k, v) { dict[k] = UntypedVariable(v); } )(mlp);
}
+ */


alias StateDict = UntypedVariable[string];

- void update(O, C)(ref O optimizer, ref C chain, string attr = "") { // if (isOptimizer!O) {
- iterVariables!( (k, v) {optimizer.step(k, v);} )(chain);
+ void update(O)(ref O optimizer) { // if (isOptimizer!O) {
+ iterVariables!( (k, v) {optimizer.step(k, v);} )(optimizer.target, "");
}
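
Worth noting: update is now a module-level function that takes only the optimizer, so call sites such as optim.update() in the examples go through D's uniform function call syntax (UFCS). A tiny standalone illustration of that mechanism, with made-up names (Counter, bump):

```d
import std.stdio;

struct Counter { int n; }

// a free function; UFCS lets it be called as if it were a member of Counter
void bump(ref Counter c) { c.n += 1; }

void main() {
    Counter c;
    c.bump();     // rewritten by the compiler to bump(c)
    writeln(c.n); // 1
}
```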

void transform(T, size_t dim)(Variable!(T, dim, HostStorage) src, ref Variable!(T, dim, HostStorage) dst, T alpha=1, T beta=0) {
@@ -89,10 +91,15 @@ void transform(T, size_t dim)(Variable!(T, dim, HostStorage) src, ref Variable!(


/// stochastic gradient descent optimizer
- struct SGD {
+ struct SGD(Chain) {
+ Chain* target;
float lr = 1.0;
// float momentum = 0.0;
// float weightDecay = 0.0;
+ this(ref Chain target, float lr=1.0) {
+ this.target = &target;
+ this.lr = lr;
+ }

void step(V)(string name, ref V field) if (isVariable!V) {
// transform(field.gradVariable, field, -this.lr, 1.0);
@@ -143,10 +150,10 @@ unittest {
mlp.zeroGrad();
assert(mlp.fc1.weight.grad[0] == 0.0);

- auto sgd = SGD(0.5);
+ auto sgd = SGD!(typeof(mlp))(mlp, 0.5);
mlp.fc1.weight.data.zero_();
mlp.fc1.weight.grad = [[1.0f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable.data;
- sgd.update(mlp);
+ sgd.update();
assert(mlp.fc1.weight.sliced == [[-0.5, 0.0, 0.0], [0.0, 0.0, 0.0]]);
}
version (grain_cuda) {
@@ -155,10 +162,10 @@
mlp.zeroGrad();
assert(mlp.fc1.weight.to!HostStorage.gradSliced == [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);

auto sgd = SGD(0.5);
auto sgd = SGD!(typeof(mlp))(mlp, 0.5);
mlp.fc1.weight.data.zero_();
mlp.fc1.weight.grad = [[1.0f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable.to!DeviceStorage.data;
- sgd.update(mlp);
+ sgd.update();
assert(mlp.fc1.weight.to!HostStorage.sliced == [[-0.5, 0.0, 0.0], [0.0, 0.0, 0.0]]);
}
}
@@ -168,14 +175,16 @@
struct AdaGrad(Chain) {
import grain.autograd;

+ Chain* target;
float lr = 1.0;
float eps = 1e-8;
StateDict memory;

- this(ref Chain model, float lr=1e-3, float eps=1e-8) {
+ this(ref Chain target, float lr=1e-3, float eps=1e-8) {
+ this.target = &target;
this.lr = lr;
this.eps = eps;
- iterVariables!((k, v) { this.initStates(k, v); })(model);
+ iterVariables!((k, v) { this.initStates(k, v); })(this.target);
}

void initStates(V)(string name, ref V field) if (isVariable!V) {
@@ -210,7 +219,7 @@ unittest {
static assert(isOptimizer!(typeof(optim)));
model.fc1.weight.data.zero_();
model.fc1.weight.grad = [[0.2f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable.data;
- optim.update(model);
+ optim.update();
auto w = model.fc1.weight;
assert(approxEqual(w.sliced, [[-lr * 0.2 / (0.2 * 0.2 + eps) ^^ 0.5, 0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
auto m = optim.memory[".fc1.weight"].to!(typeof(w));
@@ -219,7 +228,7 @@
version (grain_cuda) {
auto model = MLP!(float, DeviceStorage)(3);
auto optim = AdaGrad!(typeof(model))(model, 0.1);
- optim.update(model);
+ optim.update();
}
}
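
For reference, the expected values in the AdaGrad unittest above follow the usual AdaGrad step: accumulate the squared gradient, then divide the step by its square root. A standalone recomputation with the test's 0.2 gradient; the lr value is assumed here since the test's setup lines are outside the shown hunk, and eps = 1e-8 is this struct's default:

```d
import std.stdio;
import std.math : sqrt;

void main() {
    double lr  = 0.1;   // assumed for illustration; the host test defines lr above the shown hunk
    double eps = 1e-8;  // AdaGrad's default in this module
    double g   = 0.2;   // gradient set on fc1.weight in the unittest

    auto mem = 0.0 + g * g;                    // accumulated squared gradient (the "memory" entry)
    auto w   = 0.0 - lr * g / sqrt(mem + eps); // one step from a zero-initialized weight
    writefln!"memory = %f, weight = %f"(mem, w);
}
```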

@@ -229,17 +238,19 @@
struct Adam(Chain) {
import grain.autograd;

+ Chain* target;
float lr = 1.0;
float beta1 = 0.9;
float beta2 = 0.999;
float eps = 1e-8;

StateDict moment1, moment2;

- this(ref Chain model, float lr, float eps=1e-8) {
+ this(ref Chain target, float lr, float eps=1e-8) {
+ this.target = &target;
this.lr = lr;
this.eps = eps;
- iterVariables!((k, v) { this.initStates(k, v); })(model);
+ iterVariables!((k, v) { this.initStates(k, v); })(this.target);
}

void initStates(V)(string name, ref V field) if (isVariable!V) {
@@ -280,7 +291,7 @@ unittest {
static assert(isOptimizer!(typeof(optim)));
model.fc1.weight.data.zero_();
model.fc1.weight.grad = [[0.2f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable.data;
- optim.update(model);
+ optim.update();
auto w = model.fc1.weight;
auto m1 = (1.0 - optim.beta1) * (0.2 - 0.0) + 0.0;
auto m2 = (1.0 - optim.beta2) * (0.2 * 0.2 - 0.0) + 0.0;
@@ -293,7 +304,7 @@
version (grain_cuda) {
auto model = MLP!(float, DeviceStorage)(3);
auto optim = Adam!(typeof(model))(model, 0.1);
- optim.update(model);
+ optim.update();
}
}
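
Likewise, the Adam unittest checks the first and second moment estimates after a single step; the two quantities can be recomputed standalone (beta1/beta2 are the struct defaults above; the final weight update is outside the shown hunk, so only the moments are reproduced):

```d
import std.stdio;

void main() {
    double beta1 = 0.9, beta2 = 0.999; // defaults from the Adam struct above
    double g = 0.2;                    // gradient set on fc1.weight in the unittest

    auto m1 = (1.0 - beta1) * (g - 0.0) + 0.0;     // first moment, starting from zero
    auto m2 = (1.0 - beta2) * (g * g - 0.0) + 0.0; // second moment, starting from zero
    writefln!"m1 = %f, m2 = %f"(m1, m2);
}
```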

@@ -302,17 +313,19 @@
struct AdaDelta(Chain) {
import grain.autograd;

+ Chain* target;
float lr = 1.0;
float rho = 0.95;
float eps = 1e-6;

StateDict den, num;

- this(ref Chain model, float lr=1.0, float rho=0.95, float eps=1e-8) {
+ this(ref Chain target, float lr=1.0, float rho=0.95, float eps=1e-8) {
+ this.target = &target;
this.lr = lr;
this.rho = rho;
this.eps = eps;
- iterVariables!((k, v) { this.initStates(k, v); })(model);
+ iterVariables!((k, v) { this.initStates(k, v); })(this.target);
}

void initStates(V)(string name, ref V field) if (isVariable!V) {
@@ -353,7 +366,7 @@ unittest {
// static assert(isOptimizer!(typeof(optim)));
model.fc1.weight.data.zero_();
model.fc1.weight.grad = [[0.2f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable.data;
- optim.update(model);
+ optim.update();
auto w = model.fc1.weight;
auto d = (1.0 - optim.rho) * 0.2 * 0.2;
auto diff = cast(float) ((0.0 + optim.eps) / (d + optim.eps)) ^^ 0.5;
@@ -367,6 +380,6 @@
version (grain_cuda) {
auto model = MLP!(float, DeviceStorage)(3);
auto optim = AdaDelta!(typeof(model))(model);
- optim.update(model);
+ optim.update();
}
}
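
And for AdaDelta, the unittest derives the accumulated squared gradient and the rescaling factor; recomputed standalone below. rho is the struct default; eps is taken as the constructor default 1e-8 since the host test's constructor arguments are not visible in this hunk:

```d
import std.stdio;
import std.math;

void main() {
    double rho = 0.95;  // default from the AdaDelta struct above
    double eps = 1e-8;  // constructor default; assumed for the host test
    double g   = 0.2;   // gradient set on fc1.weight in the unittest

    auto den  = (1.0 - rho) * g * g;                // accumulated squared gradient
    auto diff = ((0.0 + eps) / (den + eps)) ^^ 0.5; // rescaling factor applied to the step
    writefln!"den = %f, diff = %f"(den, diff);
}
```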
