def one_hot_from_labels(labels, classes=None, dtype=float):
    """Turn labels into a one-hot encoding.

    Parameters
    ==========
    labels : (n,) array_like
        Labels to turn into one-hot encoding.
    classes : int or (n_classes,) array_like (optional)
        Classes for encoding. If integer and ``labels`` has an integer dtype,
        this is the number of classes in the encoding. If a sequence, this is
        the list of classes to place in the one-hot (must be a superset of
        the unique elements in ``labels``); column ``j`` of the result then
        corresponds to ``classes[j]``. If omitted, integer labels use
        ``labels.max() + 1`` classes and any other labels use their sorted
        unique values.
    dtype : dtype (optional)
        Data type of returned one-hot encoding (defaults to ``float``).

    Returns
    =======
    y : (n, n_classes) array
        Array of zeros with a single one in each row, in the column of that
        row's label.

    Raises
    ======
    ValueError
        If ``labels`` (or a sequence ``classes``) is not one-dimensional, if
        an integer label falls outside ``[0, n_classes)``, or if a label is
        not contained in ``classes``.
    """
    labels = np.asarray(labels)
    if labels.ndim != 1:
        raise ValueError(
            "`labels` must be one-dimensional (got %d dimensions)"
            % labels.ndim)
    n = labels.shape[0]

    classes_is_int = isinstance(classes, (int, np.integer))
    if np.issubdtype(labels.dtype, np.integer) and (
            classes is None or classes_is_int):
        # Integer labels are used directly as column indices.
        index = labels
        if n == 0:
            # No labels to inspect: min/max are undefined on empty arrays.
            n_classes = 0 if classes is None else int(classes)
        else:
            # Convert to Python ints so that ``max + 1`` cannot wrap around
            # in a narrow label dtype such as uint8.
            index_min, index_max = int(index.min()), int(index.max())
            n_classes = index_max + 1 if classes is None else int(classes)
            if index_min < 0:
                raise ValueError(
                    "Integer labels must be non-negative (got %d)"
                    % index_min)
            if index_max >= n_classes:
                raise ValueError(
                    "Label %d is out of range for %d classes"
                    % (index_max, n_classes))
    else:
        if classes is None:
            classes = np.unique(labels)
        else:
            # Accept lists/tuples as well as arrays; fancy indexing below
            # requires an ndarray.
            classes = np.asarray(classes)
            if classes.ndim != 1:
                raise ValueError(
                    "`classes` must be an integer or a one-dimensional "
                    "sequence of classes")
        n_classes = len(classes)

        # Locate each label in the sorted classes, then map that position
        # back to the class's position in the caller's ordering.
        c_index = np.argsort(classes)
        c_sorted = classes[c_index]
        pos = np.searchsorted(c_sorted, labels)
        # Labels larger than every class land one past the end; clip so the
        # membership check below can index safely.
        pos = np.minimum(pos, n_classes - 1)
        if n > 0 and (n_classes == 0
                      or not np.array_equal(c_sorted[pos], labels)):
            raise ValueError("`classes` must contain every label in `labels`")
        index = c_index[pos]

    y = np.zeros((n, n_classes), dtype=dtype)
    y[np.arange(n), index] = 1
    return y
import numpy as np

from nengo_extras.data import one_hot_from_labels


def test_one_hot_from_labels_int(rng):
    """Integer labels, with the class count inferred and given explicitly."""
    n_classes = 19
    n_extra = 5
    labels = rng.randint(n_classes, size=1000)
    n_samples = len(labels)

    expected = np.zeros((n_samples, n_classes))
    expected[np.arange(n_samples), labels] = 1

    inferred = one_hot_from_labels(labels)
    padded = one_hot_from_labels(labels, classes=n_classes + n_extra)

    assert np.array_equal(inferred, expected)
    # The explicitly sized encoding matches on the shared columns ...
    assert np.array_equal(inferred, padded[:, :n_classes])
    # ... and the additional columns are never set.
    assert (padded[:, n_classes:] == 0).all()


def test_one_hot_from_labels_skip(rng):
    """Integer labels that skip values (only even labels occur)."""
    labels = 2 * rng.randint(4, size=1000)
    n_samples = len(labels)

    expected = np.zeros((n_samples, labels.max() + 1))
    expected[np.arange(n_samples), labels] = 1

    encoded = one_hot_from_labels(labels)
    assert np.array_equal(encoded, expected)


def test_one_hot_from_labels_float(rng):
    """Float labels drawn from an explicit, unsorted array of classes."""
    classes = rng.uniform(0, 9, size=11)
    class_inds = rng.randint(len(classes), size=1000)
    labels = classes[class_inds]
    n_samples = len(labels)

    expected = np.zeros((n_samples, len(classes)))
    expected[np.arange(n_samples), class_inds] = 1

    encoded = one_hot_from_labels(labels, classes=classes)
    assert np.array_equal(encoded, expected)