diff --git a/docs/new_dataset.rst b/docs/new_dataset.rst index 8dd883cdb..633711236 100644 --- a/docs/new_dataset.rst +++ b/docs/new_dataset.rst @@ -235,7 +235,7 @@ Try downloading and converting the data file: cd $FUEL_DATA_PATH fuel-download iris fuel-convert iris - fuel-download --clear iris + fuel-download iris --clear cd - You can now use the Iris dataset like you would use any other built-in dataset: @@ -300,7 +300,7 @@ You can now use the Iris dataset like you would use any other built-in dataset: ... delimiter=',') ... numpy.random.shuffle(data) ... features = data[:, :-1].astype('float32') - ... targets = data[:, -1:].astype('uint8') + ... targets = data[:, -1:].astype('uint8').reshape((-1, 1)) ... train_features = features[:100] ... train_targets = targets[:100] ... valid_features = features[100:120] diff --git a/fuel/converters/__init__.py b/fuel/converters/__init__.py index 629a291c2..ac84aad0f 100644 --- a/fuel/converters/__init__.py +++ b/fuel/converters/__init__.py @@ -14,6 +14,7 @@ from fuel.converters import caltech101_silhouettes from fuel.converters import cifar10 from fuel.converters import cifar100 +from fuel.converters import iris from fuel.converters import mnist from fuel.converters import svhn @@ -23,5 +24,6 @@ ('caltech101_silhouettes', caltech101_silhouettes.fill_subparser), ('cifar10', cifar10.fill_subparser), ('cifar100', cifar100.fill_subparser), + ('iris', iris.fill_subparser), ('mnist', mnist.fill_subparser), ('svhn', svhn.fill_subparser)) diff --git a/fuel/converters/iris.py b/fuel/converters/iris.py new file mode 100644 index 000000000..cbfd52657 --- /dev/null +++ b/fuel/converters/iris.py @@ -0,0 +1,58 @@ +import os + +import h5py +import numpy + +from fuel.converters.base import fill_hdf5_file + + +def convert_iris(directory, output_directory, output_filename='iris.hdf5'): + """Convert the Iris dataset to HDF5. + + Converts the Iris dataset to an HDF5 dataset compatible with + :class:`fuel.datasets.Iris`. 
The converted dataset is + saved as 'iris.hdf5'. + This function assumes the existence of the file `iris.data`. + + Parameters + ---------- + directory : str + Directory in which input files reside. + output_directory : str + Directory in which to save the converted dataset. + output_filename : str, optional + Name of the converted HDF5 file, created inside + `output_directory`. Defaults to 'iris.hdf5'. + + Returns + ------- + output_paths : tuple of str + Single-element tuple containing the path to the converted dataset. + + """ + classes = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2} + data = numpy.loadtxt( + os.path.join(directory, 'iris.data'), + converters={4: lambda x: classes[x]}, + delimiter=',') + features = data[:, :-1].astype('float32') + targets = data[:, -1].astype('uint8').reshape((-1, 1)) + data = (('all', 'features', features), + ('all', 'targets', targets)) + + output_path = os.path.join(output_directory, output_filename) + h5file = h5py.File(output_path, mode='w') + fill_hdf5_file(h5file, data) + h5file['features'].dims[0].label = 'batch' + h5file['features'].dims[1].label = 'feature' + h5file['targets'].dims[0].label = 'batch' + h5file['targets'].dims[1].label = 'index' + + h5file.flush() + h5file.close() + + return (output_path,) + + +def fill_subparser(subparser): + subparser.set_defaults(func=convert_iris) diff --git a/fuel/datasets/__init__.py b/fuel/datasets/__init__.py index 8b7d91111..7ca62f108 100644 --- a/fuel/datasets/__init__.py +++ b/fuel/datasets/__init__.py @@ -7,6 +7,7 @@ from fuel.datasets.cifar10 import CIFAR10 from fuel.datasets.cifar100 import CIFAR100 from fuel.datasets.caltech101_silhouettes import CalTech101Silhouettes +from fuel.datasets.iris import Iris from fuel.datasets.mnist import MNIST from fuel.datasets.svhn import SVHN from fuel.datasets.text import TextFile diff --git a/fuel/datasets/iris.py b/fuel/datasets/iris.py new file mode 100644 index 000000000..5d6d50f0f --- /dev/null +++ 
b/fuel/datasets/iris.py @@ -0,0 +1,32 @@ +from fuel.datasets import H5PYDataset +from fuel.utils import find_in_data_path + + +class Iris(H5PYDataset): + u"""Iris dataset. + + Iris [IRIS] is a simple pattern recognition dataset, which consists of + 3 classes of 50 examples each, every example having 4 real-valued + features, where each class refers to a type of iris plant. It is + accessible through the UCI Machine Learning repository [UCI]. + + .. [IRIS] Ronald A. Fisher, *The use of multiple measurements in + taxonomic problems*, Annals of Eugenics, 7, Part II, 179-188, + September 1936. + .. [UCI] https://archive.ics.uci.edu/ml/datasets/Iris + + Parameters + ---------- + which_sets : tuple of str + Which split to load. The converted file defines a single + split, 'all', which holds the complete dataset (150 + examples). + + """ + filename = 'iris.hdf5' + + def __init__(self, which_sets, **kwargs): + kwargs.setdefault('load_in_memory', True) + super(Iris, self).__init__( + file_or_path=find_in_data_path(self.filename), + which_sets=which_sets, **kwargs) diff --git a/fuel/downloaders/__init__.py b/fuel/downloaders/__init__.py index 8640a8244..e0b0234d9 100644 --- a/fuel/downloaders/__init__.py +++ b/fuel/downloaders/__init__.py @@ -10,6 +10,7 @@ from fuel.downloaders import caltech101_silhouettes from fuel.downloaders import cifar10 from fuel.downloaders import cifar100 +from fuel.downloaders import iris from fuel.downloaders import mnist from fuel.downloaders import svhn @@ -18,5 +19,6 @@ ('caltech101_silhouettes', caltech101_silhouettes.fill_subparser), ('cifar10', cifar10.fill_subparser), ('cifar100', cifar100.fill_subparser), + ('iris', iris.fill_subparser), ('mnist', mnist.fill_subparser), ('svhn', svhn.fill_subparser)) diff --git a/fuel/downloaders/iris.py b/fuel/downloaders/iris.py new file mode 100644 index 000000000..63cd6e655 --- /dev/null +++ b/fuel/downloaders/iris.py @@ -0,0 +1,22 @@ +from fuel.downloaders.base import default_downloader + 
+ +def fill_subparser(subparser): + """Set up a subparser to download the Iris dataset file. + + The Iris dataset file `iris.data` is downloaded from the UCI + Machine Learning Repository [UCI]. + + .. [UCI] https://archive.ics.uci.edu/ml/datasets/Iris + + Parameters + ---------- + subparser : :class:`argparse.ArgumentParser` + Subparser handling the iris command. + + """ + subparser.set_defaults( + func=default_downloader, + urls=['https://archive.ics.uci.edu/ml/machine-learning-databases/' + 'iris/iris.data'], + filenames=['iris.data']) diff --git a/fuel/downloaders/mnist.py b/fuel/downloaders/mnist.py index 759099213..c3da4d273 100644 --- a/fuel/downloaders/mnist.py +++ b/fuel/downloaders/mnist.py @@ -4,7 +4,7 @@ def fill_subparser(subparser): """Sets up a subparser to download the MNIST dataset files. - The following MNIST dataset files are downoladed from Yann LeCun's + The following MNIST dataset files are downloaded from Yann LeCun's website [LECUN]: `train-images-idx3-ubyte.gz`, `train-labels-idx1-ubyte.gz`, `t10k-images-idx3-ubyte.gz`, `t10k-labels-idx1-ubyte.gz`. 
diff --git a/tests/test_converters.py b/tests/test_converters.py index 67cc8439e..46b395c02 100644 --- a/tests/test_converters.py +++ b/tests/test_converters.py @@ -19,8 +19,9 @@ from fuel.converters.base import (fill_hdf5_file, check_exists, MissingInputFiles) from fuel.converters import (binarized_mnist, caltech101_silhouettes, - cifar10, cifar100, mnist, svhn) + iris, cifar10, cifar100, mnist, svhn) from fuel.downloaders.caltech101_silhouettes import silhouettes_downloader +from fuel.downloaders.base import default_downloader if six.PY3: getbuffer = memoryview @@ -424,6 +425,50 @@ def test_download_and_convert(self, size=16): assert h5['targets'].shape == (8641, 1) +class TestIris(object): + def setUp(self): + self.tempdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def test_download_and_convert(self): + tempdir = self.tempdir + + cwd = os.getcwd() + os.chdir(tempdir) + + assert_raises(IOError, + iris.convert_iris, + directory=tempdir, + output_directory=tempdir) + + default_downloader( + directory=tempdir, + urls=['https://archive.ics.uci.edu/ml/machine-learning-databases/' + 'iris/iris.data'], + filenames=['iris.data']) + + classes = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2} + data = numpy.loadtxt( + os.path.join(tempdir, 'iris.data'), + converters={4: lambda x: classes[x]}, + delimiter=',') + features = data[:, :-1].astype('float32') + targets = data[:, -1].astype('uint8').reshape((-1, 1)) + + iris.convert_iris(directory=tempdir, + output_directory=tempdir) + + os.chdir(cwd) + + output_file = "iris.hdf5" + output_file = os.path.join(tempdir, output_file) + with h5py.File(output_file, 'r') as h5: + assert numpy.allclose(h5['features'], features) + assert numpy.allclose(h5['targets'], targets) + + class TestSVHN(object): def setUp(self): numpy.random.seed(9 + 5 + 2015) diff --git a/tests/test_downloaders.py b/tests/test_downloaders.py index 0907d82a6..962d80ff0 100644 --- a/tests/test_downloaders.py +++ 
b/tests/test_downloaders.py @@ -8,7 +8,7 @@ from numpy.testing import assert_equal, assert_raises from fuel.downloaders import (binarized_mnist, caltech101_silhouettes, - cifar10, cifar100, mnist, svhn) + cifar10, cifar100, iris, mnist, svhn) from fuel.downloaders.base import (download, default_downloader, filename_from_url, NeedURLPrefix, ensure_directory_exists) @@ -128,6 +128,19 @@ def test_caltech101_silhouettes(): assert args.func is caltech101_silhouettes.silhouettes_downloader +def test_iris(): + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + iris.fill_subparser(subparsers.add_parser('iris')) + args = parser.parse_args(['iris']) + urls = ['https://archive.ics.uci.edu/ml/machine-learning-databases/' + 'iris/iris.data'] + filenames = ['iris.data'] + assert_equal(args.filenames, filenames) + assert_equal(args.urls, urls) + assert args.func is default_downloader + + def test_cifar10(): parser = argparse.ArgumentParser() subparsers = parser.add_subparsers() diff --git a/tests/test_iris.py b/tests/test_iris.py new file mode 100644 index 000000000..0b08eb1b1 --- /dev/null +++ b/tests/test_iris.py @@ -0,0 +1,37 @@ +import numpy + +from numpy.testing import assert_raises, assert_equal, assert_allclose + +from fuel import config +from fuel.datasets import Iris +from tests import skip_if_not_available + + +def test_iris_all(): + skip_if_not_available(datasets=['iris.hdf5']) + + dataset = Iris(('all',), load_in_memory=False) + handle = dataset.open() + data, labels = dataset.get_data(handle, slice(0, 10)) + assert data.dtype == config.floatX + assert data.shape == (10, 4) + assert labels.shape == (10, 1) + known = numpy.array([5.1, 3.5, 1.4, 0.2]) + assert_allclose(data[0], known) + assert labels[0][0] == 0 + assert dataset.num_examples == 150 + dataset.close(handle) + + +def test_iris_axes(): + skip_if_not_available(datasets=['iris.hdf5']) + + dataset = Iris(('all',), load_in_memory=False) + assert_equal(dataset.axis_labels['features'], + 
('batch', 'feature')) + + +def test_iris_invalid_split(): + skip_if_not_available(datasets=['iris.hdf5']) + + assert_raises(ValueError, Iris, ('dummy',))