Merge pull request #971 from reyoung/feature/mnist_train_api
[Done] Feature/mnist train api
Showing 19 changed files with 633 additions and 88 deletions.
@@ -4,3 +4,4 @@ mnist_vgg_model
plot.png
train.log
*pyc
.ipynb_checkpoints
@@ -0,0 +1,205 @@
"""
A very basic example of how to use the current raw SWIG API to train an
MNIST network.

The current implementation uses raw SWIG, which means API calls are passed
directly to the C++ side of Paddle. The user-facing API could be made
simpler and more carefully designed.
"""
import py_paddle.swig_paddle as api
from py_paddle import DataProviderConverter
import paddle.trainer.PyDataProvider2 as dp
import numpy as np
import random
from mnist_util import read_from_mnist
from paddle.trainer_config_helpers import *


def optimizer_config():
    settings(
        learning_rate=1e-4,
        learning_method=AdamOptimizer(),
        batch_size=1000,
        model_average=ModelAverage(average_window=0.5),
        regularization=L2Regularization(rate=0.5))


def network_config():
    imgs = data_layer(name='pixel', size=784)
    hidden1 = fc_layer(input=imgs, size=200)
    hidden2 = fc_layer(input=hidden1, size=200)
    inference = fc_layer(input=hidden2, size=10, act=SoftmaxActivation())
    cost = classification_cost(
        input=inference, label=data_layer(
            name='label', size=10))
    outputs(cost)


def init_parameter(network):
    assert isinstance(network, api.GradientMachine)
    for each_param in network.getParameters():
        assert isinstance(each_param, api.Parameter)
        array_size = len(each_param)
        array = np.random.uniform(-1.0, 1.0, array_size).astype('float32')
        each_param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(array)


def generator_to_batch(generator, batch_size):
    ret_val = list()
    for each_item in generator:
        ret_val.append(each_item)
        if len(ret_val) == batch_size:
            yield ret_val
            ret_val = list()
    if len(ret_val) != 0:
        yield ret_val


class BatchPool(object):
    def __init__(self, generator, batch_size):
        self.data = list(generator)
        self.batch_size = batch_size

    def __call__(self):
        random.shuffle(self.data)
        for offset in xrange(0, len(self.data), self.batch_size):
            limit = min(offset + self.batch_size, len(self.data))
            yield self.data[offset:limit]


def input_order_converter(generator):
    for each_item in generator:
        yield each_item['pixel'], each_item['label']


def main():
    api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 CPU cores

    # Get enable_types for each optimizer.
    # enable_types = [value, gradient, momentum, etc.]
    # For each optimizer (SGD, Adam), the GradientMachine should enable
    # different buffers.
    opt_config_proto = parse_optimizer_config(optimizer_config)
    opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
    _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
    enable_types = _temp_optimizer_.getParameterTypes()

    # Create a simple GradientMachine.
    model_config = parse_network_config(network_config)
    m = api.GradientMachine.createFromConfigProto(
        model_config, api.CREATE_MODE_NORMAL, enable_types)

    # This type check is not useful at runtime; it only enables type hints
    # in IDEs such as PyCharm.
    assert isinstance(m, api.GradientMachine)

    # Initialize parameters with numpy.
    init_parameter(network=m)

    # Create a local updater. Local means it does not run in a cluster.
    # For cluster training, this could be changed to createRemoteUpdater
    # in the future.
    updater = api.ParameterUpdater.createLocalUpdater(opt_config)
    assert isinstance(updater, api.ParameterUpdater)

    # Initialize the ParameterUpdater.
    updater.init(m)

    # DataProviderConverter is a utility that converts Python objects into
    # Paddle's C++ input. The input format is the same as Paddle's
    # DataProvider.
    converter = DataProviderConverter(
        input_types=[dp.dense_vector(784), dp.integer_value(10)])

    train_file = './data/raw_data/train'
    test_file = './data/raw_data/t10k'

    # Start the gradient machine.
    # The gradient machine must be started before invoking forward/backward,
    # not just for training but also for inference.
    m.start()

    # An evaluator can print the error rate, etc. It is a C++ class.
    batch_evaluator = m.makeEvaluator()
    test_evaluator = m.makeEvaluator()

    # Get the training data.
    # The training data is stored in a data pool. The current implementation
    # does not care about memory or speed; it is just a very naive
    # implementation.
    train_data_generator = input_order_converter(read_from_mnist(train_file))
    train_data = BatchPool(train_data_generator, 512)

    # outArgs holds the neural network's forward result. It is not used here,
    # just passed to gradient_machine.forward.
    outArgs = api.Arguments.createArguments(0)

    for pass_id in xrange(2):  # we train 2 passes.
        updater.startPass()

        for batch_id, data_batch in enumerate(train_data()):
            # data_batch is the input images.
            # For online learning, data_batch could be fetched from the
            # network here instead.

            # Start updating one batch.
            pass_type = updater.startBatch(len(data_batch))

            # Start the batch evaluator.
            # batch_evaluator can be used between start() and finish().
            batch_evaluator.start()

            # forwardBackward is a shortcut for forward and backward.
            # It is sometimes faster than invoking forward/backward
            # separately, because the GradientMachine may run them
            # asynchronously.
            m.forwardBackward(converter(data_batch), outArgs, pass_type)

            for each_param in m.getParameters():
                updater.update(each_param)

            # Get the cost. We use numpy to compute the average cost for
            # this batch.
            cost_vec = outArgs.getSlotValue(0)
            cost_vec = cost_vec.copyToNumpyMat()
            cost = cost_vec.sum() / len(data_batch)

            # Feed the forward results to the evaluator.
            m.eval(batch_evaluator)

            # Print logs.
            print 'Pass id', pass_id, 'Batch id', batch_id, 'with cost=', \
                cost, batch_evaluator

            batch_evaluator.finish()
            # Finish the batch. This will:
            #   * clear the gradients;
            #   * ensure all parameter values are updated.
            updater.finishBatch(cost)

        # Testing stage: use the test data set to test the current network.
        updater.apply()
        test_evaluator.start()
        test_data_generator = input_order_converter(read_from_mnist(test_file))
        for data_batch in generator_to_batch(test_data_generator, 512):
            # In the testing stage, only forward is needed.
            m.forward(converter(data_batch), outArgs, api.PASS_TEST)
            m.eval(test_evaluator)

        # Print the error rate for the test data set.
        print 'Pass', pass_id, ' test evaluator: ', test_evaluator
        test_evaluator.finish()
        updater.restore()

        updater.catchUpWith()
        params = m.getParameters()
        for each_param in params:
            assert isinstance(each_param, api.Parameter)
            value = each_param.getBuf(api.PARAMETER_VALUE)
            value = value.copyToNumpyArray()

            # Here the parameters could be saved anywhere you want.
            print each_param.getName(), value

        updater.finishPass()

    m.finish()


if __name__ == '__main__':
    main()
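
As a side note, the two batching helpers above are plain Python and can be exercised on their own, without Paddle. A minimal sketch (not part of the commit) of their behavior:

# Standalone sketch: exercise the batching helpers without Paddle.
batches = list(generator_to_batch(iter(range(10)), 4))
print batches  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] -- last batch may be short

pool = BatchPool(iter(range(10)), 4)
for batch in pool():   # each call to pool() reshuffles, then yields batches
    print len(batch)   # 4, 4, 2 (contents differ each epoch due to shuffling)

The difference between the two is that generator_to_batch streams batches lazily in input order, while BatchPool materializes the whole data set up front so it can reshuffle it on every pass.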
@@ -0,0 +1,30 @@
import numpy

__all__ = ['read_from_mnist']


def read_from_mnist(filename):
    imgf = filename + "-images-idx3-ubyte"
    labelf = filename + "-labels-idx1-ubyte"
    f = open(imgf, "rb")
    l = open(labelf, "rb")

    # Skip the IDX file headers (16 bytes for images, 8 bytes for labels).
    f.read(16)
    l.read(8)

    # Define the number of samples for train/test.
    if "train" in filename:
        n = 60000
    else:
        n = 10000

    images = numpy.fromfile(
        f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)).astype('float32')
    images = images / 255.0 * 2.0 - 1.0  # scale pixels to [-1, 1]
    labels = numpy.fromfile(l, 'ubyte', count=n).astype("int")

    for i in xrange(n):
        yield {"pixel": images[i, :], 'label': labels[i]}

    f.close()
    l.close()
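
A quick sanity check of this reader (a hypothetical snippet, assuming the raw MNIST files have already been downloaded into ./data/raw_data, the path api_train.py uses):

# Assumes ./data/raw_data/train-images-idx3-ubyte and
# ./data/raw_data/train-labels-idx1-ubyte exist.
sample = next(read_from_mnist('./data/raw_data/train'))
print sample['label']        # an int in [0, 9]
print sample['pixel'].shape  # (784,)
print sample['pixel'].min(), sample['pixel'].max()  # within [-1.0, 1.0]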
@@ -0,0 +1,29 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <sstream>
#include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"

Evaluator::Evaluator() : m(new EvaluatorPrivate()) {}
Evaluator::~Evaluator() { delete m; }

void Evaluator::start() { m->rawPtr->start(); }

void Evaluator::finish() { m->rawPtr->finish(); }

// Collect the accumulated statistics into a string by delegating to the
// underlying C++ evaluator's printStats().
std::string Evaluator::toString() {
  std::ostringstream sout;
  m->rawPtr->printStats(sout);
  return sout.str();
}
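
This is the C++ side of the Evaluator wrapper that api_train.py drives through SWIG. A sketch of the call pattern from Python (assuming, as the training script's print statements suggest, that the SWIG layer maps toString() to the object's string representation):

evaluator = m.makeEvaluator()   # m is an api.GradientMachine
evaluator.start()               # Evaluator::start()
m.forward(converter(data_batch), outArgs, api.PASS_TEST)
m.eval(evaluator)               # accumulate statistics for this batch
print evaluator                 # Evaluator::toString() -> printStats() output
evaluator.finish()              # Evaluator::finish()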