Skip to content

Commit

Permalink
Added function to turn labels into one-hot
Browse files Browse the repository at this point in the history
  • Loading branch information
hunse committed Nov 9, 2016
1 parent d701090 commit d2114ad
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 1 deletion.
43 changes: 42 additions & 1 deletion nengo_extras/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import tarfile
import urllib

from nengo.utils.compat import pickle
from nengo.utils.compat import is_integer, is_iterable, pickle
import numpy as np


Expand Down Expand Up @@ -289,6 +289,47 @@ def spasafe_names(label_names):
return vocab_names


def one_hot_from_labels(labels, classes=None, dtype=float):
    """Turn labels into a one-hot encoding.

    Parameters
    ----------
    labels : (n,) array_like
        Labels to turn into one-hot encoding.
    classes : int or (n_classes,) array_like (optional)
        Classes for encoding. If integer and ``labels.dtype`` is integer, this
        is the number of classes in the encoding. If iterable, this is the
        list of classes to place in the one-hot (must be a superset of the
        unique elements in ``labels``). If omitted, integer labels get
        ``labels.max() + 1`` columns and any other labels get one column per
        unique label value, in sorted order.
    dtype : dtype (optional)
        Data type of returned one-hot encoding (defaults to ``float``).

    Returns
    -------
    y : (n, n_classes) ndarray
        Array of zeros with a single one per row, in the column of that
        row's label.

    Raises
    ------
    AssertionError
        If ``labels`` is not one-dimensional, if an integer label falls
        outside ``[0, n_classes)``, or if ``classes`` is not iterable or does
        not contain every label.
    """
    # Accept lists/tuples as well as arrays.
    labels = np.asarray(labels)
    assert labels.ndim == 1, "`labels` must be one-dimensional"
    n = labels.shape[0]

    if np.issubdtype(labels.dtype, np.integer) and (
            classes is None or is_integer(classes)):
        # Integer labels are used directly as column indices.
        index = labels
        if classes is None:
            # min()/max() raise on empty arrays, so an empty label set with
            # no explicit class count yields zero columns.
            n_classes = int(index.max()) + 1 if n > 0 else 0
        else:
            n_classes = classes
        if n > 0:
            assert index.min() >= 0, "integer labels must be non-negative"
            assert index.max() < n_classes, (
                "integer labels must be less than the number of classes")
    else:
        if classes is None:
            classes = np.unique(labels)
        else:
            assert is_iterable(classes), "`classes` must be int or iterable"
            if not isinstance(classes, np.ndarray):
                # Fancy indexing below requires an array, not a list/tuple.
                classes = np.asarray(list(classes))
            assert set(np.unique(labels)).issubset(classes), (
                "`classes` must contain every value in `labels`")
        n_classes = len(classes)

        # `classes` may be unsorted: search in a sorted copy, then map the
        # sorted positions back to positions in the caller's ordering.
        c_index = np.argsort(classes)
        c_sorted = classes[c_index]
        index = c_index[np.searchsorted(c_sorted, labels)]

    y = np.zeros((n, n_classes), dtype=dtype)
    y[np.arange(n), index] = 1
    return y


class ZCAWhiten(object):
"""ZCA Whitening
Expand Down
38 changes: 38 additions & 0 deletions nengo_extras/tests/test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import numpy as np

from nengo_extras.data import one_hot_from_labels


def test_one_hot_from_labels_int(rng):
    """Integer labels: inferred class count and an explicit, larger one."""
    n_classes = 19
    labels = rng.randint(n_classes, size=1000)
    n_samples = len(labels)

    # Reference encoding built by hand.
    expected = np.zeros((n_samples, n_classes))
    expected[np.arange(n_samples), labels] = 1

    inferred = one_hot_from_labels(labels)
    padded = one_hot_from_labels(labels, classes=n_classes + 5)

    assert np.array_equal(inferred, expected)
    # The extra classes only append columns, and those stay all-zero.
    assert np.array_equal(inferred, padded[:, :n_classes])
    assert (padded[:, n_classes:] == 0).all()


def test_one_hot_from_labels_skip(rng):
    """Integer labels that use only even values (odd classes never occur)."""
    labels = 2 * rng.randint(4, size=1000)
    n_samples = len(labels)
    n_columns = labels.max() + 1

    expected = np.zeros((n_samples, n_columns))
    expected[np.arange(n_samples), labels] = 1

    encoded = one_hot_from_labels(labels)
    assert np.array_equal(encoded, expected)


def test_one_hot_from_labels_float(rng):
    """Float labels encoded against an explicit, unsorted list of classes."""
    classes = rng.uniform(0, 9, size=11)
    picks = rng.randint(len(classes), size=1000)
    labels = classes[picks]
    n_samples = len(labels)

    # Column i of the reference corresponds to classes[i].
    expected = np.zeros((n_samples, len(classes)))
    expected[np.arange(n_samples), picks] = 1

    encoded = one_hot_from_labels(labels, classes=classes)
    assert np.array_equal(encoded, expected)

0 comments on commit d2114ad

Please sign in to comment.