-
Notifications
You must be signed in to change notification settings - Fork 0
/
mnist.py
223 lines (170 loc) · 6.43 KB
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
# Based on https://github.com/datapythonista/mnist
import array
import functools
import gzip
import operator
import os
import struct
import sys
import tempfile

try:
    from urllib.request import urlretrieve
except ImportError:
    from urllib import urlretrieve  # py2
try:
    from urllib.parse import urljoin
except ImportError:
    from urlparse import urljoin  # py2

import numpy
__version__ = '0.2.2'

# Module-level settings that users may override after import, e.g.:
# >>> mnist.datasets_url = 'http://my.mnist.url'
# >>> mnist.temporary_dir = lambda: '/tmp/mnist'
datasets_url = 'http://yann.lecun.com/exdb/mnist/'


def temporary_dir():
    """Return the default download directory for the IDX files.

    This is a ``data`` folder located next to this module; a system-wide
    location such as ``tempfile.gettempdir()`` is a possible alternative.
    """
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, 'data')
class IdxDecodeError(ValueError):
    """Signal that a byte stream could not be decoded as an IDX file.

    Derives from ValueError, so callers may catch either type.
    """
def image_shape():
    """Return the (rows, columns) size of one MNIST image."""
    side = 28
    return (side, side)
def transform(matrix):
    """Flatten each sample of ``matrix`` and scale its values by 1/255.

    Parameters
    ----------
    matrix : numpy.ndarray
        Array whose first axis indexes samples (e.g. stacked images of shape
        ``(n, rows, cols)``). All trailing dimensions are flattened, so input
        that is already ``(n, k)`` is accepted unchanged in shape.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n, product of trailing dims)``. For uint8
        pixel data the values lie in the range [0, 1].
    """
    # Reshape image data so each sample is represented by one long array.
    # The row length is computed explicitly (rather than using -1) so that an
    # empty batch with shape (0, rows, cols) still reshapes to (0, rows*cols).
    row_length = functools.reduce(operator.mul, matrix.shape[1:], 1)
    matrix = matrix.reshape(matrix.shape[0], row_length)
    # Divide by a float literal: under Python 2 (still supported by this
    # module's imports) `int_array / 255` would be floor division.
    return matrix / 255.0
def load_data():
    """Return the full MNIST dataset, downloading files when needed.

    Returns
    -------
    X_train : numpy.ndarray
        Train images, one flattened row per sample, passed through
        ``transform`` (values divided by 255).
    X_test : numpy.ndarray
        Test images, processed the same way as ``X_train``.
    y_train : numpy.ndarray
        Train labels as parsed from the IDX file.
    y_test : numpy.ndarray
        Test labels as parsed from the IDX file.
    """
    # Train split first, then test split; images before labels in each.
    train_x, train_y = transform(train_images()), train_labels()
    test_x, test_y = transform(test_images()), test_labels()
    return train_x, test_x, train_y, test_y
def download_file(fname, target_dir=None, force=False):
    """Download fname from the datasets_url, and save it to target_dir,
    unless the file already exists, and force is False.

    Parameters
    ----------
    fname : str
        Name of the file to download (joined onto ``datasets_url``)
    target_dir : str, optional
        Directory where to store the file. Defaults to ``temporary_dir()``.
        It is created if it does not exist yet.
    force : bool
        Force downloading the file, if it already exists

    Returns
    -------
    fname : str
        Full path of the downloaded file
    """
    target_dir = target_dir or temporary_dir()
    target_fname = os.path.join(target_dir, fname)
    if force or not os.path.isfile(target_fname):
        # The default directory (<module dir>/data) usually does not exist on
        # a fresh checkout; urlretrieve cannot create it by itself.
        if not os.path.isdir(target_dir):
            os.makedirs(target_dir)
        url = urljoin(datasets_url, fname)
        # Download to a side file and move it into place only on success, so
        # an interrupted transfer never leaves a truncated file that later
        # calls would mistake for a complete download.
        partial_fname = target_fname + '.part'
        try:
            urlretrieve(url, partial_fname)
        except BaseException:
            if os.path.isfile(partial_fname):
                os.remove(partial_fname)
            raise
        # os.rename does not overwrite on every platform (and os.replace is
        # unavailable on Python 2), so clear a stale target first.
        if os.path.isfile(target_fname):
            os.remove(target_fname)
        os.rename(partial_fname, target_fname)
    return target_fname
def parse_idx(fd):
    """Parse an IDX file, and return it as a numpy array.

    Parameters
    ----------
    fd : file
        Binary file object positioned at the start of the IDX data

    Returns
    -------
    data : numpy.ndarray
        Numpy array with the dimensions and the data in the IDX file

    Raises
    ------
    IdxDecodeError
        If the header is truncated or malformed, the data type code is
        unknown, or the payload does not match the declared dimensions.

    Notes
    -----
    All multi-byte values in an IDX file are big-endian; see
    https://docs.python.org/3/library/struct.html#byte-order-size-and-alignment
    """
    DATA_TYPES = {0x08: 'B',  # unsigned byte
                  0x09: 'b',  # signed byte
                  0x0b: 'h',  # short (2 bytes)
                  0x0c: 'i',  # int (4 bytes)
                  0x0d: 'f',  # float (4 bytes)
                  0x0e: 'd'}  # double (8 bytes)

    header = fd.read(4)
    if len(header) != 4:
        raise IdxDecodeError('Invalid IDX file, '
                             'file empty or does not contain a full header.')
    zeros, data_type, num_dimensions = struct.unpack('>HBB', header)
    if zeros != 0:
        raise IdxDecodeError('Invalid IDX file, '
                             'file must start with two zero bytes. '
                             'Found 0x%02x' % zeros)
    try:
        data_type = DATA_TYPES[data_type]
    except KeyError:
        raise IdxDecodeError('Unknown data type '
                             '0x%02x in IDX file' % data_type)

    # One big-endian uint32 per dimension follows the 4-byte magic header.
    dims_bytes = fd.read(4 * num_dimensions)
    if len(dims_bytes) != 4 * num_dimensions:
        raise IdxDecodeError('Invalid IDX file, '
                             'header does not contain all dimension sizes.')
    dimension_sizes = struct.unpack('>' + 'I' * num_dimensions, dims_bytes)

    payload = fd.read()
    item_size = array.array(data_type).itemsize
    if len(payload) % item_size:
        raise IdxDecodeError('IDX file data size (%d bytes) is not a '
                             'multiple of the item size (%d bytes).'
                             % (len(payload), item_size))
    data = array.array(data_type, payload)
    # array.array decodes bytes in the host's native order, while IDX data
    # is big-endian, so only little-endian hosts need to swap.
    if sys.byteorder == 'little':
        data.byteswap()

    # Initial value 1 keeps 0-dimensional (single scalar) files valid.
    expected_items = functools.reduce(operator.mul, dimension_sizes, 1)
    if len(data) != expected_items:
        raise IdxDecodeError('IDX file has wrong number of items. '
                             'Expected: %d. Found: %d' % (expected_items,
                                                          len(data)))
    return numpy.array(data).reshape(dimension_sizes)
def download_and_parse_mnist_file(fname, target_dir=None, force=False):
    """Fetch the IDX file ``fname`` from ``datasets_url`` and decode it.

    Files whose extension is ``.gz`` are read through gzip; any other file
    is read as raw bytes.

    Parameters
    ----------
    fname : str
        File name to download and parse
    target_dir : str
        Directory where to store the file
    force : bool
        Force downloading the file, if it already exists

    Returns
    -------
    data : numpy.ndarray
        Numpy array with the dimensions and the data in the IDX file
    """
    local_path = download_file(fname, target_dir=target_dir, force=force)
    _, extension = os.path.splitext(local_path)
    if extension == '.gz':
        opener = gzip.open
    else:
        opener = open
    with opener(local_path, 'rb') as stream:
        return parse_idx(stream)
def train_images(target_dir=None, force=False):
    """Return train images from Yann LeCun MNIST database as a numpy array.

    Download the file, if not already found in ``target_dir``.

    Parameters
    ----------
    target_dir : str, optional
        Directory where the downloaded file is stored. Defaults to
        ``temporary_dir()``.
    force : bool, optional
        Download the file again even if it already exists.

    Returns
    -------
    train_images : numpy.ndarray
        Numpy array with the images in the train MNIST database. The first
        dimension indexes each sample, while the other two index rows and
        columns of the image
    """
    return download_and_parse_mnist_file('train-images-idx3-ubyte.gz',
                                         target_dir=target_dir, force=force)
def test_images(target_dir=None, force=False):
    """Return test images from Yann LeCun MNIST database as a numpy array.

    Download the file, if not already found in ``target_dir``.

    Parameters
    ----------
    target_dir : str, optional
        Directory where the downloaded file is stored. Defaults to
        ``temporary_dir()``.
    force : bool, optional
        Download the file again even if it already exists.

    Returns
    -------
    test_images : numpy.ndarray
        Numpy array with the images in the test MNIST database. The first
        dimension indexes each sample, while the other two index rows and
        columns of the image
    """
    return download_and_parse_mnist_file('t10k-images-idx3-ubyte.gz',
                                         target_dir=target_dir, force=force)
def train_labels(target_dir=None, force=False):
    """Return train labels from Yann LeCun MNIST database as a numpy array.

    Download the file, if not already found in ``target_dir``.

    Parameters
    ----------
    target_dir : str, optional
        Directory where the downloaded file is stored. Defaults to
        ``temporary_dir()``.
    force : bool, optional
        Download the file again even if it already exists.

    Returns
    -------
    train_labels : numpy.ndarray
        Numpy array with the labels 0 to 9 in the train MNIST database.
    """
    return download_and_parse_mnist_file('train-labels-idx1-ubyte.gz',
                                         target_dir=target_dir, force=force)
def test_labels(target_dir=None, force=False):
    """Return test labels from Yann LeCun MNIST database as a numpy array.

    Download the file, if not already found in ``target_dir``.

    Parameters
    ----------
    target_dir : str, optional
        Directory where the downloaded file is stored. Defaults to
        ``temporary_dir()``.
    force : bool, optional
        Download the file again even if it already exists.

    Returns
    -------
    test_labels : numpy.ndarray
        Numpy array with the labels 0 to 9 in the test MNIST database.
    """
    return download_and_parse_mnist_file('t10k-labels-idx1-ubyte.gz',
                                         target_dir=target_dir, force=force)