forked from mila-iqia/fuel
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ba15181
commit 755f1f2
Showing
11 changed files
with
216 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import os | ||
|
||
import h5py | ||
import numpy | ||
|
||
from fuel.converters.base import fill_hdf5_file | ||
|
||
|
||
def convert_iris(directory, output_directory, output_filename='iris.hdf5'): | ||
"""Convert the Iris dataset to HDF5. | ||
Converts the Iris dataset to an HDF5 dataset compatible with | ||
:class:`fuel.datasets.Iris`. The converted dataset is | ||
saved as 'iris.hdf5'. | ||
This method assumes the existence of the file `iris.data`. | ||
Parameters | ||
---------- | ||
directory : str | ||
Directory in which input files reside. | ||
output_directory : str | ||
Directory in which to save the converted dataset. | ||
output_filename : str, optional | ||
Name of the saved dataset. Defaults to `None`, in which case a name | ||
based on `dtype` will be used. | ||
Returns | ||
------- | ||
output_paths : tuple of str | ||
Single-element tuple containing the path to the converted dataset. | ||
""" | ||
classes = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2} | ||
data = numpy.loadtxt( | ||
os.path.join(directory, 'iris.data'), | ||
converters={4: lambda x: classes[x]}, | ||
delimiter=',') | ||
features = data[:, :-1].astype('float32') | ||
targets = data[:, -1].astype('uint8').reshape((-1, 1)) | ||
data = (('all', 'features', features), | ||
('all', 'targets', targets)) | ||
|
||
output_path = os.path.join(output_directory, output_filename) | ||
h5file = h5py.File(output_path, mode='w') | ||
fill_hdf5_file(h5file, data) | ||
h5file['features'].dims[0].label = 'batch' | ||
h5file['features'].dims[1].label = 'feature' | ||
h5file['targets'].dims[0].label = 'batch' | ||
h5file['targets'].dims[1].label = 'index' | ||
|
||
h5file.flush() | ||
h5file.close() | ||
|
||
return (output_path,) | ||
|
||
|
||
def fill_subparser(subparser): | ||
subparser.set_defaults(func=convert_iris) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from fuel.datasets import H5PYDataset | ||
from fuel.utils import find_in_data_path | ||
|
||
|
||
class Iris(H5PYDataset): | ||
u"""Iris dataset. | ||
Iris [LBBH] is a simple pattern recognition dataset, which consist of | ||
3 classes of 50 examples each having 4 real-valued features each, where | ||
each class refers to a type of iris plant. It is accessible through the | ||
UCI Machine Learning repository [UCI]. | ||
.. [IRIS] Ronald A. Fisher, *The use of multiple measurements in | ||
taxonomic problems*, Annual Eugenics, 7, Part II, 179-188, | ||
September 1936. | ||
.. [UCI] https://archive.ics.uci.edu/ml/datasets/Iris | ||
Parameters | ||
---------- | ||
which_sets : tuple of str | ||
Which split to load. Valid value is 'all' | ||
corresponding to 150 examples. | ||
""" | ||
filename = 'iris.hdf5' | ||
|
||
def __init__(self, which_sets, **kwargs): | ||
kwargs.setdefault('load_in_memory', True) | ||
super(Iris, self).__init__( | ||
file_or_path=find_in_data_path(self.filename), | ||
which_sets=which_sets, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from fuel.downloaders.base import default_downloader | ||
|
||
|
||
def fill_subparser(subparser): | ||
"""Set up a subparser to download the Iris dataset file. | ||
The Iris dataset file `iris.data` is downloaded from the UCI | ||
Machine Learning Repository [UCI]. | ||
.. [UCI] https://archive.ics.uci.edu/ml/datasets/Iris | ||
Parameters | ||
---------- | ||
subparser : :class:`argparse.ArgumentParser` | ||
Subparser handling the iris command. | ||
""" | ||
subparser.set_defaults( | ||
func=default_downloader, | ||
urls=['https://archive.ics.uci.edu/ml/machine-learning-databases/' | ||
'iris/iris.data'], | ||
filenames=['iris.data']) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import numpy | ||
|
||
from numpy.testing import assert_raises, assert_equal, assert_allclose | ||
|
||
from fuel import config | ||
from fuel.datasets import Iris | ||
from tests import skip_if_not_available | ||
|
||
|
||
def test_iris_all(): | ||
skip_if_not_available(datasets=['iris.hdf5']) | ||
|
||
dataset = Iris(('all',), load_in_memory=False) | ||
handle = dataset.open() | ||
data, labels = dataset.get_data(handle, slice(0, 10)) | ||
assert data.dtype == config.floatX | ||
assert data.shape == (10, 4) | ||
assert labels.shape == (10, 1) | ||
known = numpy.array([5.1, 3.5, 1.4, 0.2]) | ||
assert_allclose(data[0], known) | ||
assert labels[0][0] == 0 | ||
assert dataset.num_examples == 150 | ||
dataset.close(handle) | ||
|
||
|
||
def test_iris_axes(): | ||
skip_if_not_available(datasets=['iris.hdf5']) | ||
|
||
dataset = Iris(('all',), load_in_memory=False) | ||
assert_equal(dataset.axis_labels['features'], | ||
('batch', 'feature')) | ||
|
||
|
||
def test_iris_invalid_split(): | ||
skip_if_not_available(datasets=['iris.hdf5']) | ||
|
||
assert_raises(ValueError, Iris, ('dummy',)) |