From de1d444e3e51e1a5a1034c7f33422eb42117255e Mon Sep 17 00:00:00 2001 From: Le Thanh Date: Tue, 25 Sep 2018 13:21:28 -0400 Subject: [PATCH] Fixing FAISS integration and performance test (#102) * Working FAISS KNN * Handling custom KNN, add test for the handling custom knn and faiss * Add docstring for knn_classifier * Remove trailing whitespace, remove unused variables * - Updating docstring for the knn_classifier parameter; - Setting default value of knn_classifier to 'knn' (standard scikit-learn implementation) ; * Adding documentation to methods * Handle continous array for faiss and add predict_proba test with IH * Fix error in predict and predict_proba, add instalation guide, add faiss test * Add performance comparison between faiss vs sklearn * Add performance comparison for FAISS * Improved code quality * Update travis for faiss * Fix bug and remove mock * Update travis for anaconda * Fix minlengh=2 in faiss wrapper * Better travis by removing if for conda * Fix anaconda travis * Add bash install for anaconda * Change anaconda to miniconda, fix conda to always yes * Fix old anaconda environment handling * Add source bashrc * Fix travis * Fix travis * Fix travis * Handle skipping faiss test * Fix unused variable * Fix unused variable * Remove unused import * Update code quality --- .travis.yml | 11 ++- deslib/base.py | 17 +++-- deslib/des/knop.py | 3 +- deslib/des/meta_des.py | 3 +- .../kne_knn_proba_integration.npy | Bin 0 -> 3136 bytes .../performance/compare_performance_faiss.py | 72 ++++++++++++++++++ deslib/tests/test_base.py | 21 ++++- deslib/tests/test_des_integration.py | 19 ++++- deslib/tests/test_faiss.py | 30 ++++++++ deslib/util/faiss_knn_wrapper.py | 15 ++-- docs/user_guide/installation.rst | 9 ++- 11 files changed, 176 insertions(+), 24 deletions(-) create mode 100644 deslib/tests/expected_values/kne_knn_proba_integration.npy create mode 100644 deslib/tests/performance/compare_performance_faiss.py create mode 100644 deslib/tests/test_faiss.py diff --git a/.travis.yml b/.travis.yml index 14ca568b..5ad7b80b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,13 +2,20 @@ language: python python: - "3.5" - "3.6" -before_install: - - pip install -U pip install: + - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh + - bash miniconda.sh -b -p $HOME/miniconda + - export PATH="$HOME/miniconda/bin:$PATH" + - hash -r + - conda config --set always_yes yes --set changeps1 no + - conda create -n test_env python="$TRAVIS_PYTHON_VERSION" + - echo ". $HOME/miniconda/etc/profile.d/conda.sh" >> "$HOME/.bashrc" + - source activate test_env - travis_wait travis_retry pip install -r requirements-dev.txt - travis_retry pip install codecov - travis_retry python setup.py build - travis_retry python setup.py install + - travis_retry conda install faiss-cpu -c pytorch script: coverage run -m py.test after_success: - codecov diff --git a/deslib/base.py b/deslib/base.py index 46fe39c5..0904d484 100644 --- a/deslib/base.py +++ b/deslib/base.py @@ -51,24 +51,29 @@ def __init__(self, pool_classifiers, k=7, DFP=False, with_IH=False, safe_k=None, self.n_classes = None self.n_samples = None self.n_features = None + self.knn_class = None + if knn_classifier is None: - self.roc_algorithm = functools.partial(KNeighborsClassifier, n_jobs=-1, algorithm="auto") + self.knn_class = functools.partial(KNeighborsClassifier, n_jobs=-1, algorithm="auto") elif isinstance(knn_classifier, str): if knn_classifier == "faiss": - from deslib.util.faiss_knn_wrapper import FaissKNNClassifier - self.roc_algorithm = functools.partial(FaissKNNClassifier, n_jobs=-1, algorithm="auto") + try: + from deslib.util.faiss_knn_wrapper import FaissKNNClassifier + except ImportError: + raise ImportError("FAISS library needs to be manually installed, please check the Installation Guide") + self.knn_class = functools.partial(FaissKNNClassifier, n_jobs=-1, algorithm="auto") elif knn_classifier == "knn": - self.roc_algorithm = functools.partial(KNeighborsClassifier, n_jobs=-1, algorithm="auto") + self.knn_class = functools.partial(KNeighborsClassifier, n_jobs=-1, algorithm="auto") else: raise ValueError('"knn_classifier" should be one of the following ' '["knn", "faiss"] or an estimator class') elif callable(knn_classifier): - self.roc_algorithm = knn_classifier + self.knn_class = knn_classifier else: raise ValueError('"knn_classifier" should be one of the following ' '["knn", "faiss"] or an estimator class') - self.roc_algorithm = self.roc_algorithm(self.k) + self.roc_algorithm = self.knn_class(self.k) # TODO: remove these as class variables self.neighbors = None diff --git a/deslib/des/knop.py b/deslib/des/knop.py index 3e6677ea..1b6a11ff 100644 --- a/deslib/des/knop.py +++ b/deslib/des/knop.py @@ -5,7 +5,6 @@ # License: BSD 3 clause import numpy as np -from sklearn.neighbors import KNeighborsClassifier from deslib.des.base import DES @@ -130,7 +129,7 @@ def _fit_OP(self, X_op, y_op, k): Number of output profiles used in the region of competence estimation. """ - self.op_knn = KNeighborsClassifier(n_neighbors=k, n_jobs=-1, algorithm='auto') + self.op_knn = self.knn_class(k) if self.n_classes == 2: # Get only the scores for one class since they are complementary diff --git a/deslib/des/meta_des.py b/deslib/des/meta_des.py index 6c48de59..50bd32a9 100644 --- a/deslib/des/meta_des.py +++ b/deslib/des/meta_des.py @@ -9,7 +9,6 @@ import numpy as np from sklearn.exceptions import NotFittedError from sklearn.naive_bayes import MultinomialNB -from sklearn.neighbors import KNeighborsClassifier from sklearn.utils.validation import check_is_fitted from deslib.des.base import DES @@ -178,7 +177,7 @@ class labels of each sample in X_op. Number of output profiles used in the estimation. """ - self.op_knn = KNeighborsClassifier(n_neighbors=kp, n_jobs=-1, algorithm='auto') + self.op_knn = self.knn_class(kp) if self.n_classes == 2: # Get only the scores for one class since they are complementary diff --git a/deslib/tests/expected_values/kne_knn_proba_integration.npy b/deslib/tests/expected_values/kne_knn_proba_integration.npy new file mode 100644 index 0000000000000000000000000000000000000000..229717f93b008494cf1665ca58bb0bbe3369cfae GIT binary patch literal 3136 zcmbVOX*kr2+pd!(q@+Vmq>)0F5{I0MTV#nyCM`(WrlMpi**lR+_C!Tv$r{mOvKBHq zWD7~MbS%+~nfd+jwc|JYY_1s&6rHzHviRH_Dmjx=ixcWGIEAFN! z?xpD{s!|kb7rifeJ9?bD=)dErOC&;M_VP09`J|#{Ks-JtuH{;xx;AI19q% zn9|WH#g~lIYrd zWz0yp8Z@cr2BorG$lInyOtQ*Q`Db&{;1Czf{tM%nH*pj8f9?lQ6^Q<05=nt2`~etpGy~h58;A1QZTfDHJ=f z;nrJ4(+qEf{sj}dOEL-0(=A2avb!I_a1@EPvk}IVxdps!-1*tmoTE&`Mxpb zp+W71xD>k<(U++6rJ z{MuBx)>f+*Cs#cwv^mDbMta&EYopg--pJnWDl`heNyZ~B-UlT0dn<;@q(hq&k2vWf6Ncbr3zR7OGmM`a2(;w#pdy#ANmj)m z@Yqg6zHk!Wp$qwkL(1{uPAxBvZj>^3kr+UNm=*`s|Eb6;@bdgoffumMO`7aLNaW@Gi+w@t#?1lpwc?t~*W(%djm|rZ7=!%|1N6iv#u8y_@785)c;MwW?J&2Z_V>iaEDiv7uhc_vlF$ z_(~e|Qte0aJ;5WCY5V~ahZN@Aes;sR`+QjAwP7r5xmfN#P9jIJL@rwC6a-|Pxn)sI zEb4rF+b+X_U#Y-@z!nnPMp;3c2L2#w^Osw*NeGJT+AJ2*aJECQOIEWMG;>cg9Ul&y zX5CbFwg*96`7!TUjau4yZD{6nF z_iW%kLCt-KkM$?`RFal-{EKK=om7E!6axjl!X-caEH8bJ){GlMyZ*8rWnt`bS;1e? zy$~*tS9bnu2m%HsDt6&)jLi;NCyc#AXpzuFwps%m?!FF7nqtFPK3=apqYC$X)w&y0 zxS)G*E{;d1L-N4#Im#>n;#Y!$bRHLx%T9}It*$_apqTlSnKB6cqWdZja$$+o#-QpB z93Oa=yf1we*MCW=$56jvZdq&E@bPvSa)0#jZMhAr-|uoS-XrmclJBjRiP;FD`(1gO zL%?qAwC$Qq5~Ano&TQLz6Su{uhtqmVJgfBOU#;$q)K_KGfmT(xz%soyuEm7|E~Dih0bf=+BA7sG3X^~5=p|Y`ex?d z?a?@ML|(tCmx)p3*&U)n90b(+|7qgY0RtJPt9|__ez=_03O>if&hDQ_BpD3s*?0A= zeRUdU`1}6g|BXa|38hxLgT%iEeWdK9FIp`R&rU?u!s3ZaDAS$;JXRI%PgK1Ui zbxQ)h-;#4}gh@OvR{~dX;F<5>Gqz1mCS`gsiXor#K-d0yb}LOwnEPy|UX} zTIn;4=6P&h$L3*?DcxOe@fE-MH5`g^`GoKt7fyzc4q+yJtI@PZ50b=YGhZJ27mO=> z+qWJWfTh!f9M6ZunvxH1>_X%5k9JaV+fdejDTE)dYqiR@Yya{h!ZhwRK~7ydZJ;;JFcNmoSXeXYvl~#X;P> z%v&^s{6GCc`m3{(!kH+poIw_7g-(Uny-PvOr@B2|%Y&?AkYKZH3(CJy77v!OP+5~`FVyn}9kD%m z!4Vv^NWK=ME53t8MO&p*2^+gNonfa))<8oihyTP64#xhd-rDb84W^8#-F0y;*uu;v zeYIii81p|HlKmM){98V%mVAWr!i7ZkmJyh)`!w)CHXbySIdzdMBr1{^i&M7?kXZ8j zGbr1(o zEMRT`!!1C(!%2|WSqNKu#xYa zeyz_b3_c-q(U+EZwpw?dw!gR<4)-ZO_k_3@{k7L?ZB8#%|Ltfpt2zYdk>V|^!YWY1 K&dnPrbMZfH67voK literal 0 HcmV?d00001 diff --git a/deslib/tests/performance/compare_performance_faiss.py b/deslib/tests/performance/compare_performance_faiss.py new file mode 100644 index 00000000..75576c68 --- /dev/null +++ b/deslib/tests/performance/compare_performance_faiss.py @@ -0,0 +1,72 @@ +import numpy as np +import faiss +from sklearn.neighbors import KNeighborsClassifier +import pandas as pd +import time +from sklearn.model_selection import train_test_split +import threading +import os +import urllib.request +import gzip +import shutil + +def sk_knn(Xtrain, Y, k, Xtest): + start = time.clock() + s_knn = KNeighborsClassifier(k, n_jobs=4) #Half of current cores + s_knn.fit(Xtrain, Y) + s_knn.predict(Xtest) + print("sklearn_knn run_time: {}".format(time.clock() - start)) + +def faiss_knn(Xtrain, Y, k, Xtest): + start = time.clock() + index = faiss.IndexFlatL2(Xtrain.shape[1]) + index.add(np.ascontiguousarray(Xtrain).astype(np.float32)) + index.search(Xtest.astype(np.float32), k) + print("faiss_knn run_time: {}".format(time.clock() - start)) + + +if __name__ == "__main__": + + if not os.path.exists("../../HIGGS.csv"): + print("Downloading HIGGS dataset from https://archive.ics.uci.edu/ml/datasets/HIGGS") + if not os.path.exists("../../HIGGS.gz"): + url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz" + filedata = urllib.request.urlopen(url) + data2write = filedata.read() + with open('../../HIGSS.gz', 'wb') as f: + f.write(data2write) + print("Finished downloading") + print("Extracting HIGGS.gz") + if not os.path.exists("../../HIGGS.csv"): + with gzip.open('../../HIGGS.gz', 'rb') as f: + with open('../../HIGGS.csv', 'wb') as csv_out: + shutil.copyfileobj(f, csv_out) + print("Extracted csv") + + df = pd.read_csv('../../HIGGS.csv', header=None) + data = df.values + X = data[:, 1:] + Y = data[:, 0] + + X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.33) + num_samples_list = [1000000] + num_of_k_list = [1, 2, 5, 7, 10] + num_of_test_inputs = [100, 1000] + + for nsamples in num_samples_list: + for n_k in num_of_k_list: + for n_t in num_of_test_inputs: + print("running experiment: num_of_train_samples: {}, num_of_k: {}, num_of_tests: {}".format( + nsamples, + n_k, + n_t)) + faiss_knn(X_train[:nsamples], Y_train[:nsamples], n_k, X_test[:n_t]) + t = threading.Thread(target=sk_knn, args=(X_train[:nsamples], Y_train[:nsamples], n_k, X_test[:n_t])) + t.start() + t.join(timeout=600) + if t.is_alive(): + print("sklearn_knn, num_of_train_samples: {}, num_of_k: {}, num_of_tests: {}, run_time: {}".format( + nsamples, + n_k, + n_t, + "timeout after 60s")) diff --git a/deslib/tests/test_base.py b/deslib/tests/test_base.py index 4ffac344..1d4d2795 100644 --- a/deslib/tests/test_base.py +++ b/deslib/tests/test_base.py @@ -2,10 +2,11 @@ import pytest from sklearn.exceptions import NotFittedError +from sklearn.neighbors import KNeighborsClassifier from deslib.base import DS from deslib.tests.examples_test import * - +import unittest.mock def test_all_classifiers_agree(): # 10 classifiers that return 1 @@ -80,6 +81,24 @@ def test_valid_selection_mode(knn_method): with pytest.raises(ValueError): DS(create_pool_classifiers(), knn_classifier=knn_method) +def test_import_faiss_mode(): + try: + import sys + sys.modules.pop('deslib.util.faiss_knn_wrapper') + except Exception: + pass + with unittest.mock.patch.dict('sys.modules', {'faiss': None}): + with pytest.raises(ImportError): + DS(create_pool_classifiers(), knn_classifier="faiss") + +def test_none_selection_mode(): + ds = DS(create_pool_classifiers(), knn_classifier=None) + assert(isinstance(ds.roc_algorithm, KNeighborsClassifier)) + +def test_string_selection_mode(): + ds = DS(create_pool_classifiers(), knn_classifier="knn") + assert(isinstance(ds.roc_algorithm, KNeighborsClassifier)) + # In this test the system was trained for a sample containing 2 features and we are passing a sample with 3 as argument. # So it should raise a value error. def test_different_input_shape(): diff --git a/deslib/tests/test_des_integration.py b/deslib/tests/test_des_integration.py index 1b5a2934..2a74e7e7 100644 --- a/deslib/tests/test_des_integration.py +++ b/deslib/tests/test_des_integration.py @@ -4,7 +4,6 @@ from sklearn.ensemble import BaggingClassifier from sklearn.linear_model import LogisticRegression from sklearn.linear_model import Perceptron -from sklearn.neighbors import KNeighborsClassifier from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler from sklearn.svm import SVC @@ -33,12 +32,12 @@ from deslib.static.static_selection import StaticSelection import pytest import warnings +import sys -knn_methods = [None, "knn", KNeighborsClassifier] +knn_methods = [None] try: from deslib.util.faiss_knn_wrapper import FaissKNNClassifier - knn_methods.append("faiss") knn_methods.append(FaissKNNClassifier) except ImportError: warnings.warn("Not testing FAISS for KNN") @@ -281,10 +280,22 @@ def test_kne_proba(knn_methods): expected = np.load('deslib/tests/expected_values/kne_proba_integration.npy') assert np.allclose(probas, expected) +# ------------------------------------------ Testing predict_proba ----------------------------------- + +@pytest.mark.skipif('faiss' not in sys.modules, + reason="requires the faiss library") +def test_compare_faiss_predict_proba_IH(): + pool_classifiers, X_dsel, y_dsel, X_test, y_test = setup_classifiers() + kne = KNORAE(pool_classifiers, knn_classifier="faiss", with_IH=True, IH_rate=0.1) + kne.fit(X_dsel, y_dsel) + probas = kne.predict_proba(X_test) + expected = np.load('deslib/tests/expected_values/kne_knn_proba_integration.npy') + assert np.allclose(probas, expected) + + @pytest.mark.parametrize('knn_methods', knn_methods) def test_desp_proba(knn_methods): pool_classifiers, X_dsel, y_dsel, X_test, y_test = setup_classifiers() - desp = DESP(pool_classifiers, knn_classifier=knn_methods) desp.fit(X_dsel, y_dsel) probas = desp.predict_proba(X_test) diff --git a/deslib/tests/test_faiss.py b/deslib/tests/test_faiss.py new file mode 100644 index 00000000..81ec6912 --- /dev/null +++ b/deslib/tests/test_faiss.py @@ -0,0 +1,30 @@ + +import pytest +import sys +from sklearn.neighbors import KNeighborsClassifier +from deslib.tests.examples_test import * +from deslib.tests.test_des_integration import load_dataset + +try: + from deslib.util.faiss_knn_wrapper import FaissKNNClassifier +except ImportError: + pass + + +@pytest.mark.skipif('faiss' not in sys.modules, + reason="requires the faiss library") +def test_faiss_predict(): + rng = np.random.RandomState(123456) + _, X_test, X_train, _, _, y_train = load_dataset(None, rng) + k = 7 + X_train = X_train.astype(np.float32) + X_test = X_test.astype(np.float32) + f_knn_test = FaissKNNClassifier(n_neighbors=k) + f_knn_test.fit(X_train, y_train) + f_knn_preds = f_knn_test.predict(X_test) + + knn_test = KNeighborsClassifier(n_neighbors=k) + knn_test.fit(X_train, y_train) + knn_preds = knn_test.predict(X_test) + + assert ((f_knn_preds - knn_preds).sum() == 0) diff --git a/deslib/util/faiss_knn_wrapper.py b/deslib/util/faiss_knn_wrapper.py index 1bd7e595..8f8eb3c5 100644 --- a/deslib/util/faiss_knn_wrapper.py +++ b/deslib/util/faiss_knn_wrapper.py @@ -47,7 +47,8 @@ def predict(self, X): """ _, idx = self.kneighbors(X, self.n_neighbors) class_idx = self.y[idx] - preds = np.amax(class_idx, axis=1) + counts = np.apply_along_axis(lambda x: np.bincount(x, minlength=self.num_of_classes), axis=1, arr=class_idx.astype(np.int64)) + preds = np.argmax(counts, axis=1) return preds def kneighbors(self, X, n_neighbors, return_distance=True): @@ -70,12 +71,13 @@ def predict_proba(self, X): """ _, idx = self.kneighbors(X, self.n_neighbors) class_idx = self.y[idx] - preds = np.amax(class_idx, axis=1) + counts = np.apply_along_axis(lambda x: np.bincount(x, minlength=self.num_of_classes), axis=1, arr=class_idx.astype(np.int64)) + preds = np.argmax(counts, axis=1) - #FIXME: can probably be improved for a vectorized version - preds_proba = np.zeros(X.shape[0], self.num_of_classes) - for i in range(preds): - preds_proba[i] = np.bincount(class_idx[i, :]) / self.n_neighbors + #TODO: can probably be improved for a vectorized version + preds_proba = np.zeros((X.shape[0], self.num_of_classes)) + for i in range(preds.shape[0]): + preds_proba[i] = counts[i] / self.n_neighbors return preds_proba @@ -91,6 +93,7 @@ def fit(self, X, y): class labels of each example in X. """ X = np.atleast_2d(X).astype(np.float32) + X = np.ascontiguousarray(X) self.index = faiss.IndexFlatL2(X.shape[1]) self.index.add(X) self.y = y diff --git a/docs/user_guide/installation.rst b/docs/user_guide/installation.rst index ac794c30..ad6142d4 100644 --- a/docs/user_guide/installation.rst +++ b/docs/user_guide/installation.rst @@ -24,4 +24,11 @@ DESlib is tested to work with Python 3.5, and 3.6. The dependency requirements a * numpy(>=1.10.4) * scikit-learn(>=0.19.0) -These dependencies are automatically installed using the pip commands above. \ No newline at end of file +These dependencies are automatically installed using the pip commands above. + +Optional dependencies +===================== + +To use Faiss (Fair AI Similarity Search), a fast implementation of KNN that can use GPUs, follow the instructions below: +https://github.com/facebookresearch/faiss/blob/master/INSTALL.md +