-
Notifications
You must be signed in to change notification settings - Fork 2
/
mnist.cpp
145 lines (118 loc) · 3.79 KB
/
mnist.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
// Copyright (c) 2013, Manuel Blum
// All rights reserved.
// Define this symbol to enable runtime tests for allocations
//#define EIGEN_RUNTIME_NO_MALLOC
#include <Eigen/Dense>
#include <iostream>
#include <fstream>
#include <cstdio>
#include <string>
#include "nn.h"
// Reverses the byte order of a 32-bit integer in place, converting between
// big-endian and little-endian representations.
//
// The shuffle is performed on an unsigned copy of the value: right-shifting a
// negative signed int sign-extends (smearing 0xff bytes over the result whenever
// the original low byte ends up as, or the original top byte is, >= 0x80), and
// left-shifting a signed int into the sign bit is not portable. Every term is
// masked explicitly so the result is correct even if unsigned int is wider than
// 32 bits.
//
// @param val value whose four low-order bytes are reversed in place
inline void swap(int &val)
{
    const unsigned int bits = static_cast<unsigned int>(val);
    const unsigned int reversed = ((bits << 24) & 0xff000000u)
                                | ((bits <<  8) & 0x00ff0000u)
                                | ((bits >>  8) & 0x0000ff00u)
                                | ((bits >> 24) & 0x000000ffu);
    val = static_cast<int>(reversed);
}
// Loads an MNIST image file into a matrix holding one image per row.
//
// The file starts with four 32-bit integers: magic number (2051), number of
// images, rows per image and columns per image. They are read in host byte
// order first; if the magic number does not match, all four are byte-swapped
// on the assumption that the file uses the opposite byte order. The pixel data
// follows as one unsigned byte per pixel and is stored unscaled (0..255).
//
// @param filename path of the image file to read
// @return matrix with num_images rows and num_rows*num_columns columns
//
// Prints a message and terminates the process with exit(1) when the file
// cannot be opened, when the header is not a valid image-file header, or when
// the file ends before all announced pixels were read. (Previously a bad header
// was accepted as-is and missing pixels were silently filled with zeros.)
matrix_t read_mnist_images(std::string filename)
{
    std::ifstream fs(filename.c_str(), std::ios::binary);
    if (!fs) {
        std::cout << "error reading file: " << filename << std::endl;
        exit(1);
    }

    int magic_number = 0;
    int num_images = 0;
    int num_rows = 0;
    int num_columns = 0;
    fs.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
    fs.read(reinterpret_cast<char*>(&num_images), sizeof(num_images));
    fs.read(reinterpret_cast<char*>(&num_rows), sizeof(num_rows));
    fs.read(reinterpret_cast<char*>(&num_columns), sizeof(num_columns));

    // A mismatch most likely means the file's byte order differs from the host's.
    if (magic_number != 2051) {
        swap(magic_number);
        swap(num_images);
        swap(num_rows);
        swap(num_columns);
    }

    // Reject short headers, foreign files and nonsensical (negative) dimensions
    // before they are used as allocation sizes and loop bounds.
    if (!fs || magic_number != 2051 || num_images < 0 || num_rows < 0 || num_columns < 0) {
        std::cout << "error reading file: " << filename
                  << " (invalid image file header)" << std::endl;
        exit(1);
    }

    // Widen before multiplying so a large header cannot overflow int.
    const long long num_pixels = static_cast<long long>(num_rows) * num_columns;

    matrix_t X = matrix_t::Zero(num_images, num_pixels);
    for (int i = 0; i < num_images; ++i) {
        for (long long j = 0; j < num_pixels; ++j) {
            unsigned char pixel = 0;
            fs.read(reinterpret_cast<char*>(&pixel), sizeof(pixel));
            X(i, j) = static_cast<double>(pixel);
        }
        // Checked once per image: a failed stream stays failed, so nothing is missed.
        if (!fs) {
            std::cout << "error reading file: " << filename
                      << " (unexpected end of image data)" << std::endl;
            exit(1);
        }
    }

    return X;
}
// Loads an MNIST label file and returns the labels one-hot encoded.
//
// The file starts with two 32-bit integers: magic number (2049) and number of
// labels. They are read in host byte order first; if the magic number does not
// match, both are byte-swapped on the assumption that the file uses the
// opposite byte order. One unsigned byte per label follows.
//
// @param filename path of the label file to read
// @return matrix with one row per label and 10 columns; row i holds 1.0 in the
//         column given by label i and 0 elsewhere
//
// Prints a message and terminates the process with exit(1) when the file
// cannot be opened, when the header is not a valid label-file header, when the
// file ends early, or when a label is outside 0..9. (Previously an early end of
// file silently produced class-0 labels and an out-of-range label indexed
// outside the matrix.)
matrix_t read_mnist_labels(std::string filename)
{
    const int num_classes = 10;

    std::ifstream fs(filename.c_str(), std::ios::binary);
    if (!fs) {
        std::cout << "error reading file: " << filename << std::endl;
        exit(1);
    }

    int magic_number = 0;
    int num_images = 0;
    fs.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
    fs.read(reinterpret_cast<char*>(&num_images), sizeof(num_images));

    // A mismatch most likely means the file's byte order differs from the host's.
    if (magic_number != 2049) {
        swap(magic_number);
        swap(num_images);
    }

    // Reject short headers, foreign files and a negative label count before the
    // count is used as an allocation size and loop bound.
    if (!fs || magic_number != 2049 || num_images < 0) {
        std::cout << "error reading file: " << filename
                  << " (invalid label file header)" << std::endl;
        exit(1);
    }

    matrix_t Y = matrix_t::Zero(num_images, num_classes);
    for (int i = 0; i < num_images; ++i) {
        unsigned char label = 0;
        fs.read(reinterpret_cast<char*>(&label), sizeof(label));
        if (!fs) {
            std::cout << "error reading file: " << filename
                      << " (unexpected end of label data)" << std::endl;
            exit(1);
        }
        // The label is used as a column index, so it must be a valid class.
        if (label >= num_classes) {
            std::cout << "error reading file: " << filename
                      << " (label out of range)" << std::endl;
            exit(1);
        }
        Y(i, static_cast<int>(label)) = 1.0;
    }

    return Y;
}
int main (int argc, const char* argv[]) {
if (argc != 2) {
std::cout << "please provide path to mnist data ..." << std::endl;
std::cout << "you can download the dataset at http://yann.lecun.com/exdb/mnist/" << std::endl;
std::cout << std::endl << "usage: " << argv[0] << " path_to_data" << std::endl << std::endl;
return 1;
}
std::string path = argv[1];
std::cout << "reading data" << std::endl;
matrix_t X_train = read_mnist_images(path + "/train-images-idx3-ubyte");
matrix_t Y_train = read_mnist_labels(path + "/train-labels-idx1-ubyte");
matrix_t X_test = read_mnist_images(path + "/t10k-images-idx3-ubyte");
matrix_t Y_test = read_mnist_labels(path + "/t10k-labels-idx1-ubyte");
// number of optimization steps
int max_steps = 600;
// regularization parameter
double lambda = 0.0;
// specify network topology
Eigen::VectorXi topo(3);
topo << X_train.cols(), 300, Y_test.cols();
std::cout << "topology: " << topo.transpose() << std::endl;
// initialize a neural network with given topology
std::cout << "initializing network" << std::endl;
NeuralNet nn(topo);
std::cout << "scaling the data" << std::endl;
nn.autoscale(X_train, Y_train);
// train the network
std::cout << "starting training" << std::endl;
std::cout << "iter error" << std::endl;
double err;
for (int i = 0; i < max_steps; ++i) {
err = nn.loss(X_train, Y_train, lambda);
nn.rprop();
printf("%4i %10.7f\n", i, err);
}
// test accuracy
nn.forward_pass(X_test);
matrix_t prediction = nn.get_activation();
int correct = 0;
int k;
for (size_t i=0; i<Y_test.rows(); ++i) {
prediction.row(i).maxCoeff(&k);
correct += Y_test(i, k);
}
std::cout << "test accuracy: " << correct*1.0/Y_test.rows() << std::endl;
nn.write("mnist.nn");
return 0;
}