Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Iris dataset #2

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/new_dataset.rst
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ Try downloading and converting the data file:
cd $FUEL_DATA_PATH
fuel-download iris
fuel-convert iris
fuel-download --clear iris
fuel-download iris --clear
cd -

You can now use the Iris dataset like you would use any other built-in dataset:
Expand Down Expand Up @@ -300,7 +300,7 @@ You can now use the Iris dataset like you would use any other built-in dataset:
... delimiter=',')
... numpy.random.shuffle(data)
... features = data[:, :-1].astype('float32')
... targets = data[:, -1:].astype('uint8')
... targets = data[:, -1:].astype('uint8').reshape((-1, 1))
... train_features = features[:100]
... train_targets = targets[:100]
... valid_features = features[100:120]
Expand Down
2 changes: 2 additions & 0 deletions fuel/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from fuel.converters import caltech101_silhouettes
from fuel.converters import cifar10
from fuel.converters import cifar100
from fuel.converters import iris
from fuel.converters import mnist
from fuel.converters import svhn

Expand All @@ -23,5 +24,6 @@
('caltech101_silhouettes', caltech101_silhouettes.fill_subparser),
('cifar10', cifar10.fill_subparser),
('cifar100', cifar100.fill_subparser),
('iris', iris.fill_subparser),
('mnist', mnist.fill_subparser),
('svhn', svhn.fill_subparser))
58 changes: 58 additions & 0 deletions fuel/converters/iris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os

import h5py
import numpy

from fuel.converters.base import fill_hdf5_file


def convert_iris(directory, output_directory, output_filename='iris.hdf5'):
    """Convert the Iris dataset to HDF5.

    Converts the Iris dataset to an HDF5 dataset compatible with
    :class:`fuel.datasets.Iris`. The converted dataset is saved as
    'iris.hdf5' unless a different `output_filename` is given.
    This method assumes the existence of the file `iris.data`.

    Parameters
    ----------
    directory : str
        Directory in which input files reside.
    output_directory : str
        Directory in which to save the converted dataset.
    output_filename : str, optional
        Name of the saved dataset. Defaults to 'iris.hdf5'.

    Returns
    -------
    output_paths : tuple of str
        Single-element tuple containing the path to the converted dataset.

    """
    classes = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}

    def class_index(name):
        # Depending on the NumPy version (and its `encoding` default),
        # `loadtxt` hands converters either `bytes` or `str`; accept both.
        if isinstance(name, bytes):
            name = name.decode('ascii')
        return classes[name]

    data = numpy.loadtxt(
        os.path.join(directory, 'iris.data'),
        converters={4: class_index},
        delimiter=',')
    features = data[:, :-1].astype('float32')
    # Keep targets two-dimensional: one column holding the class index.
    targets = data[:, -1].astype('uint8').reshape((-1, 1))
    data = (('all', 'features', features),
            ('all', 'targets', targets))

    output_path = os.path.join(output_directory, output_filename)
    # The context manager closes the file even if filling or labelling fails.
    with h5py.File(output_path, mode='w') as h5file:
        fill_hdf5_file(h5file, data)
        h5file['features'].dims[0].label = 'batch'
        h5file['features'].dims[1].label = 'feature'
        h5file['targets'].dims[0].label = 'batch'
        h5file['targets'].dims[1].label = 'index'
        h5file.flush()

    return (output_path,)


def fill_subparser(subparser):
    """Set up a subparser to convert the Iris dataset file.

    Parameters
    ----------
    subparser : :class:`argparse.ArgumentParser`
        Subparser handling the `iris` command.

    """
    defaults = {'func': convert_iris}
    subparser.set_defaults(**defaults)
1 change: 1 addition & 0 deletions fuel/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fuel.datasets.cifar10 import CIFAR10
from fuel.datasets.cifar100 import CIFAR100
from fuel.datasets.caltech101_silhouettes import CalTech101Silhouettes
from fuel.datasets.iris import Iris
from fuel.datasets.mnist import MNIST
from fuel.datasets.svhn import SVHN
from fuel.datasets.text import TextFile
Expand Down
31 changes: 31 additions & 0 deletions fuel/datasets/iris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from fuel.datasets import H5PYDataset
from fuel.utils import find_in_data_path


class Iris(H5PYDataset):
    u"""Iris dataset.

    Iris [IRIS] is a simple pattern recognition dataset, which consists of
    3 classes of 50 examples each, every example having 4 real-valued
    features, where each class refers to a type of iris plant. It is
    accessible through the UCI Machine Learning repository [UCI].

    .. [IRIS] Ronald A. Fisher, *The use of multiple measurements in
       taxonomic problems*, Annals of Eugenics, 7, Part II, 179-188,
       September 1936.
    .. [UCI] https://archive.ics.uci.edu/ml/datasets/Iris

    Parameters
    ----------
    which_sets : tuple of str
        Which split to load. The only valid value is 'all',
        corresponding to all 150 examples.
    \*\*kwargs
        Forwarded to the :class:`H5PYDataset` constructor.
        `load_in_memory` defaults to `True` when not given.

    """
    filename = 'iris.hdf5'

    def __init__(self, which_sets, **kwargs):
        # Only fills in the default; an explicit caller value is kept.
        kwargs.setdefault('load_in_memory', True)
        super(Iris, self).__init__(
            file_or_path=find_in_data_path(self.filename),
            which_sets=which_sets, **kwargs)
2 changes: 2 additions & 0 deletions fuel/downloaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from fuel.downloaders import caltech101_silhouettes
from fuel.downloaders import cifar10
from fuel.downloaders import cifar100
from fuel.downloaders import iris
from fuel.downloaders import mnist
from fuel.downloaders import svhn

Expand All @@ -18,5 +19,6 @@
('caltech101_silhouettes', caltech101_silhouettes.fill_subparser),
('cifar10', cifar10.fill_subparser),
('cifar100', cifar100.fill_subparser),
('iris', iris.fill_subparser),
('mnist', mnist.fill_subparser),
('svhn', svhn.fill_subparser))
22 changes: 22 additions & 0 deletions fuel/downloaders/iris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from fuel.downloaders.base import default_downloader


def fill_subparser(subparser):
    """Configure a subparser for downloading the Iris dataset file.

    The single file `iris.data` is fetched from the UCI Machine
    Learning Repository [UCI].

    .. [UCI] https://archive.ics.uci.edu/ml/datasets/Iris

    Parameters
    ----------
    subparser : :class:`argparse.ArgumentParser`
        Subparser handling the iris command.

    """
    iris_url = ('https://archive.ics.uci.edu/ml/machine-learning-databases/'
                'iris/iris.data')
    subparser.set_defaults(func=default_downloader,
                           urls=[iris_url],
                           filenames=['iris.data'])
2 changes: 1 addition & 1 deletion fuel/downloaders/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
def fill_subparser(subparser):
"""Sets up a subparser to download the MNIST dataset files.

The following MNIST dataset files are downoladed from Yann LeCun's
The following MNIST dataset files are downloaded from Yann LeCun's
website [LECUN]:
`train-images-idx3-ubyte.gz`, `train-labels-idx1-ubyte.gz`,
`t10k-images-idx3-ubyte.gz`, `t10k-labels-idx1-ubyte.gz`.
Expand Down
47 changes: 46 additions & 1 deletion tests/test_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
from fuel.converters.base import (fill_hdf5_file, check_exists,
MissingInputFiles)
from fuel.converters import (binarized_mnist, caltech101_silhouettes,
cifar10, cifar100, mnist, svhn)
iris, cifar10, cifar100, mnist, svhn)
from fuel.downloaders.caltech101_silhouettes import silhouettes_downloader
from fuel.downloaders.base import default_downloader

if six.PY3:
getbuffer = memoryview
Expand Down Expand Up @@ -424,6 +425,50 @@ def test_download_and_convert(self, size=16):
assert h5['targets'].shape == (8641, 1)


class TestIris(object):
    """Round-trip test: download `iris.data`, convert it, compare contents."""

    def setUp(self):
        self.tempdir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.tempdir)

    def test_download_and_convert(self):
        tempdir = self.tempdir

        cwd = os.getcwd()
        os.chdir(tempdir)
        # Always restore the working directory: otherwise a failed download
        # or conversion leaves the process inside a directory that
        # `tearDown` is about to delete.
        try:
            # Converting before the input file exists must fail.
            assert_raises(IOError,
                          iris.convert_iris,
                          directory=tempdir,
                          output_directory=tempdir)

            default_downloader(
                directory=tempdir,
                urls=['https://archive.ics.uci.edu/ml/'
                      'machine-learning-databases/iris/iris.data'],
                filenames=['iris.data'])

            classes = {'Iris-setosa': 0, 'Iris-versicolor': 1,
                       'Iris-virginica': 2}

            def class_index(name):
                # `loadtxt` may pass `bytes` or `str` depending on the
                # NumPy version; accept both.
                if isinstance(name, bytes):
                    name = name.decode('ascii')
                return classes[name]

            # Build the expected arrays straight from the raw file.
            data = numpy.loadtxt(
                os.path.join(tempdir, 'iris.data'),
                converters={4: class_index},
                delimiter=',')
            features = data[:, :-1].astype('float32')
            targets = data[:, -1].astype('uint8').reshape((-1, 1))

            iris.convert_iris(directory=tempdir,
                              output_directory=tempdir)
        finally:
            os.chdir(cwd)

        output_file = "iris.hdf5"
        output_file = os.path.join(tempdir, output_file)
        with h5py.File(output_file, 'r') as h5:
            assert numpy.allclose(h5['features'], features)
            assert numpy.allclose(h5['targets'], targets)


class TestSVHN(object):
def setUp(self):
numpy.random.seed(9 + 5 + 2015)
Expand Down
15 changes: 14 additions & 1 deletion tests/test_downloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from numpy.testing import assert_equal, assert_raises

from fuel.downloaders import (binarized_mnist, caltech101_silhouettes,
cifar10, cifar100, mnist, svhn)
cifar10, cifar100, iris, mnist, svhn)
from fuel.downloaders.base import (download, default_downloader,
filename_from_url, NeedURLPrefix,
ensure_directory_exists)
Expand Down Expand Up @@ -128,6 +128,19 @@ def test_caltech101_silhouettes():
assert args.func is caltech101_silhouettes.silhouettes_downloader


def test_iris():
    """The iris subparser must register the default downloader settings."""
    root_parser = argparse.ArgumentParser()
    iris_parser = root_parser.add_subparsers().add_parser('iris')
    iris.fill_subparser(iris_parser)
    parsed = root_parser.parse_args(['iris'])

    expected_filenames = ['iris.data']
    expected_urls = ['https://archive.ics.uci.edu/ml/'
                     'machine-learning-databases/iris/iris.data']
    assert_equal(parsed.filenames, expected_filenames)
    assert_equal(parsed.urls, expected_urls)
    assert parsed.func is default_downloader


def test_cifar10():
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Expand Down
37 changes: 37 additions & 0 deletions tests/test_iris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import numpy

from numpy.testing import assert_raises, assert_equal, assert_allclose

from fuel import config
from fuel.datasets import Iris
from tests import skip_if_not_available


def test_iris_all():
    """Check dtype, shapes, first example and size of the 'all' split."""
    skip_if_not_available(datasets=['iris.hdf5'])

    iris_set = Iris(('all',), load_in_memory=False)
    state = iris_set.open()
    first_ten = slice(0, 10)
    features, targets = iris_set.get_data(state, first_ten)

    assert features.dtype == config.floatX
    assert features.shape == (10, 4)
    assert targets.shape == (10, 1)

    # First row of the file and its class index.
    expected_first = numpy.array([5.1, 3.5, 1.4, 0.2])
    assert_allclose(features[0], expected_first)
    assert targets[0][0] == 0

    assert iris_set.num_examples == 150
    iris_set.close(state)


def test_iris_axes():
    """The 'features' source must be labelled ('batch', 'feature')."""
    skip_if_not_available(datasets=['iris.hdf5'])

    expected_labels = ('batch', 'feature')
    iris_set = Iris(('all',), load_in_memory=False)
    assert_equal(iris_set.axis_labels['features'], expected_labels)


def test_iris_invalid_split():
    """Requesting a split other than 'all' must raise ValueError."""
    skip_if_not_available(datasets=['iris.hdf5'])

    unknown_split = ('dummy',)
    assert_raises(ValueError, Iris, unknown_split)