Skip to content

Commit

Permalink
Iris
Browse files Browse the repository at this point in the history
  • Loading branch information
laurent-dinh committed Jul 6, 2015
1 parent ba15181 commit 755f1f2
Show file tree
Hide file tree
Showing 11 changed files with 216 additions and 5 deletions.
4 changes: 2 additions & 2 deletions docs/new_dataset.rst
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ Try downloading and converting the data file:
cd $FUEL_DATA_PATH
fuel-download iris
fuel-convert iris
fuel-download --clear iris
fuel-download iris --clear
cd -
You can now use the Iris dataset like you would use any other built-in dataset:
Expand Down Expand Up @@ -300,7 +300,7 @@ You can now use the Iris dataset like you would use any other built-in dataset:
... delimiter=',')
... numpy.random.shuffle(data)
... features = data[:, :-1].astype('float32')
... targets = data[:, -1:].astype('uint8')
... targets = data[:, -1:].astype('uint8').reshape((-1, 1))
... train_features = features[:100]
... train_targets = targets[:100]
... valid_features = features[100:120]
Expand Down
2 changes: 2 additions & 0 deletions fuel/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from fuel.converters import caltech101_silhouettes
from fuel.converters import cifar10
from fuel.converters import cifar100
from fuel.converters import iris
from fuel.converters import mnist
from fuel.converters import svhn

Expand All @@ -23,5 +24,6 @@
('caltech101_silhouettes', caltech101_silhouettes.fill_subparser),
('cifar10', cifar10.fill_subparser),
('cifar100', cifar100.fill_subparser),
('iris', iris.fill_subparser),
('mnist', mnist.fill_subparser),
('svhn', svhn.fill_subparser))
58 changes: 58 additions & 0 deletions fuel/converters/iris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os

import h5py
import numpy

from fuel.converters.base import fill_hdf5_file


def convert_iris(directory, output_directory, output_filename='iris.hdf5'):
    """Convert the Iris dataset to HDF5.

    Converts the Iris dataset to an HDF5 dataset compatible with
    :class:`fuel.datasets.Iris`. The converted dataset is saved as
    'iris.hdf5' unless a different `output_filename` is given.

    This method assumes the existence of the file `iris.data` in
    `directory`.

    Parameters
    ----------
    directory : str
        Directory in which input files reside.
    output_directory : str
        Directory in which to save the converted dataset.
    output_filename : str, optional
        Name of the saved dataset. Defaults to 'iris.hdf5'.

    Returns
    -------
    output_paths : tuple of str
        Single-element tuple containing the path to the converted dataset.

    """
    classes = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}

    def class_index(name):
        # Depending on the NumPy version, `loadtxt` hands converters either
        # bytes or str; normalise so the lookup works with both.
        if isinstance(name, bytes):
            name = name.decode('ascii')
        return classes[name]

    raw = numpy.loadtxt(
        os.path.join(directory, 'iris.data'),
        converters={4: class_index},
        delimiter=',')
    # The last column holds the class index, the others the measurements.
    features = raw[:, :-1].astype('float32')
    targets = raw[:, -1].astype('uint8').reshape((-1, 1))
    data = (('all', 'features', features),
            ('all', 'targets', targets))

    output_path = os.path.join(output_directory, output_filename)
    # The context manager closes the file even if filling or labelling fails.
    with h5py.File(output_path, mode='w') as h5file:
        fill_hdf5_file(h5file, data)
        h5file['features'].dims[0].label = 'batch'
        h5file['features'].dims[1].label = 'feature'
        h5file['targets'].dims[0].label = 'batch'
        h5file['targets'].dims[1].label = 'index'
        h5file.flush()

    return (output_path,)


def fill_subparser(subparser):
    """Register `convert_iris` as the function run by the iris subparser.

    Parameters
    ----------
    subparser : :class:`argparse.ArgumentParser`
        Subparser handling the `iris` command.

    """
    defaults = {'func': convert_iris}
    subparser.set_defaults(**defaults)
1 change: 1 addition & 0 deletions fuel/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fuel.datasets.cifar10 import CIFAR10
from fuel.datasets.cifar100 import CIFAR100
from fuel.datasets.caltech101_silhouettes import CalTech101Silhouettes
from fuel.datasets.iris import Iris
from fuel.datasets.mnist import MNIST
from fuel.datasets.svhn import SVHN
from fuel.datasets.text import TextFile
Expand Down
31 changes: 31 additions & 0 deletions fuel/datasets/iris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from fuel.datasets import H5PYDataset
from fuel.utils import find_in_data_path


class Iris(H5PYDataset):
    u"""Iris dataset.

    Iris [IRIS] is a simple pattern recognition dataset, which consists of
    3 classes of 50 examples each having 4 real-valued features each, where
    each class refers to a type of iris plant. It is accessible through the
    UCI Machine Learning repository [UCI].

    .. [IRIS] Ronald A. Fisher, *The use of multiple measurements in
       taxonomic problems*, Annals of Eugenics, 7, Part II, 179-188,
       September 1936.

    .. [UCI] https://archive.ics.uci.edu/ml/datasets/Iris

    Parameters
    ----------
    which_sets : tuple of str
        Which split to load. The only valid value is 'all',
        corresponding to all 150 examples.

    """
    filename = 'iris.hdf5'

    def __init__(self, which_sets, **kwargs):
        # Load in memory by default; an explicit `load_in_memory` passed by
        # the caller still takes precedence.
        kwargs.setdefault('load_in_memory', True)
        super(Iris, self).__init__(
            file_or_path=find_in_data_path(self.filename),
            which_sets=which_sets, **kwargs)
2 changes: 2 additions & 0 deletions fuel/downloaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from fuel.downloaders import caltech101_silhouettes
from fuel.downloaders import cifar10
from fuel.downloaders import cifar100
from fuel.downloaders import iris
from fuel.downloaders import mnist
from fuel.downloaders import svhn

Expand All @@ -18,5 +19,6 @@
('caltech101_silhouettes', caltech101_silhouettes.fill_subparser),
('cifar10', cifar10.fill_subparser),
('cifar100', cifar100.fill_subparser),
('iris', iris.fill_subparser),
('mnist', mnist.fill_subparser),
('svhn', svhn.fill_subparser))
22 changes: 22 additions & 0 deletions fuel/downloaders/iris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from fuel.downloaders.base import default_downloader


def fill_subparser(subparser):
    """Configure a subparser for downloading the Iris dataset file.

    The single file `iris.data` is fetched from the UCI Machine Learning
    Repository [UCI].

    .. [UCI] https://archive.ics.uci.edu/ml/datasets/Iris

    Parameters
    ----------
    subparser : :class:`argparse.ArgumentParser`
        Subparser handling the `iris` command.

    """
    filename = 'iris.data'
    repository = 'https://archive.ics.uci.edu/ml/machine-learning-databases/'
    url = repository + 'iris/' + filename
    subparser.set_defaults(
        func=default_downloader,
        urls=[url],
        filenames=[filename])
2 changes: 1 addition & 1 deletion fuel/downloaders/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
def fill_subparser(subparser):
"""Sets up a subparser to download the MNIST dataset files.
The following MNIST dataset files are downoladed from Yann LeCun's
The following MNIST dataset files are downloaded from Yann LeCun's
website [LECUN]:
`train-images-idx3-ubyte.gz`, `train-labels-idx1-ubyte.gz`,
`t10k-images-idx3-ubyte.gz`, `t10k-labels-idx1-ubyte.gz`.
Expand Down
47 changes: 46 additions & 1 deletion tests/test_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
from fuel.converters.base import (fill_hdf5_file, check_exists,
MissingInputFiles)
from fuel.converters import (binarized_mnist, caltech101_silhouettes,
cifar10, cifar100, mnist, svhn)
iris, cifar10, cifar100, mnist, svhn)
from fuel.downloaders.caltech101_silhouettes import silhouettes_downloader
from fuel.downloaders.base import default_downloader

if six.PY3:
getbuffer = memoryview
Expand Down Expand Up @@ -424,6 +425,50 @@ def test_download_and_convert(self, size=16):
assert h5['targets'].shape == (8641, 1)


class TestIris(object):
    """Round-trip test: download `iris.data`, convert it, compare contents."""

    def setUp(self):
        self.tempdir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.tempdir)

    def test_download_and_convert(self):
        tempdir = self.tempdir
        classes = {'Iris-setosa': 0, 'Iris-versicolor': 1,
                   'Iris-virginica': 2}

        def class_index(name):
            # `loadtxt` may pass bytes or str depending on the NumPy version.
            if isinstance(name, bytes):
                name = name.decode('ascii')
            return classes[name]

        cwd = os.getcwd()
        os.chdir(tempdir)
        try:
            # Converting before the input file exists must fail.
            assert_raises(IOError,
                          iris.convert_iris,
                          directory=tempdir,
                          output_directory=tempdir)

            default_downloader(
                directory=tempdir,
                urls=['https://archive.ics.uci.edu/ml/'
                      'machine-learning-databases/iris/iris.data'],
                filenames=['iris.data'])

            # Build the expected arrays straight from the raw file.
            data = numpy.loadtxt(
                os.path.join(tempdir, 'iris.data'),
                converters={4: class_index},
                delimiter=',')
            features = data[:, :-1].astype('float32')
            targets = data[:, -1].astype('uint8').reshape((-1, 1))

            iris.convert_iris(directory=tempdir,
                              output_directory=tempdir)
        finally:
            # Always leave the temporary directory before tearDown removes
            # it, even when a step above fails.
            os.chdir(cwd)

        output_file = os.path.join(tempdir, "iris.hdf5")
        with h5py.File(output_file, 'r') as h5:
            assert numpy.allclose(h5['features'], features)
            assert numpy.allclose(h5['targets'], targets)


class TestSVHN(object):
def setUp(self):
numpy.random.seed(9 + 5 + 2015)
Expand Down
15 changes: 14 additions & 1 deletion tests/test_downloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from numpy.testing import assert_equal, assert_raises

from fuel.downloaders import (binarized_mnist, caltech101_silhouettes,
cifar10, cifar100, mnist, svhn)
cifar10, cifar100, iris, mnist, svhn)
from fuel.downloaders.base import (download, default_downloader,
filename_from_url, NeedURLPrefix,
ensure_directory_exists)
Expand Down Expand Up @@ -128,6 +128,19 @@ def test_caltech101_silhouettes():
assert args.func is caltech101_silhouettes.silhouettes_downloader


def test_iris():
    """The iris subparser sets the expected URL, filename and downloader."""
    expected_urls = ['https://archive.ics.uci.edu/ml/machine-learning-databases/'
                     'iris/iris.data']
    expected_filenames = ['iris.data']

    root = argparse.ArgumentParser()
    iris.fill_subparser(root.add_subparsers().add_parser('iris'))
    parsed = root.parse_args(['iris'])

    assert_equal(parsed.filenames, expected_filenames)
    assert_equal(parsed.urls, expected_urls)
    assert parsed.func is default_downloader


def test_cifar10():
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Expand Down
37 changes: 37 additions & 0 deletions tests/test_iris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import numpy

from numpy.testing import assert_raises, assert_equal, assert_allclose

from fuel import config
from fuel.datasets import Iris
from tests import skip_if_not_available


def test_iris_all():
    """Check dtype, shapes, first example and size of the 'all' split."""
    skip_if_not_available(datasets=['iris.hdf5'])

    dataset = Iris(('all',), load_in_memory=False)
    handle = dataset.open()
    try:
        data, labels = dataset.get_data(handle, slice(0, 10))
        assert data.dtype == config.floatX
        assert data.shape == (10, 4)
        assert labels.shape == (10, 1)
        known = numpy.array([5.1, 3.5, 1.4, 0.2])
        assert_allclose(data[0], known)
        assert labels[0][0] == 0
        assert dataset.num_examples == 150
    finally:
        # Release the handle even when one of the assertions fails.
        dataset.close(handle)


def test_iris_axes():
    """The 'features' source is labelled ('batch', 'feature')."""
    skip_if_not_available(datasets=['iris.hdf5'])

    expected = ('batch', 'feature')
    axis_labels = Iris(('all',), load_in_memory=False).axis_labels
    assert_equal(axis_labels['features'], expected)


def test_iris_invalid_split():
    """Requesting a split other than 'all' raises ValueError."""
    skip_if_not_available(datasets=['iris.hdf5'])

    with assert_raises(ValueError):
        Iris(('dummy',))

0 comments on commit 755f1f2

Please sign in to comment.