-
Notifications
You must be signed in to change notification settings - Fork 106
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixing FAISS integration and performance test (#102)
* Working FAISS KNN * Handling custom KNN, add test for the handling custom knn and faiss * Add docstring for knn_classifier * Remove trailing whitespace, remove unused variables * - Updating docstring for the knn_classifier parameter; - Setting default value of knn_classifier to 'knn' (standard scikit-learn implementation) ; * Adding documentation to methods * Handle continous array for faiss and add predict_proba test with IH * Fix error in predict and predict_proba, add instalation guide, add faiss test * Add performance comparison between faiss vs sklearn * Add performance comparison for FAISS * Improved code quality * Update travis for faiss * Fix bug and remove mock * Update travis for anaconda * Fix minlengh=2 in faiss wrapper * Better travis by removing if for conda * Fix anaconda travis * Add bash install for anaconda * Change anaconda to miniconda, fix conda to always yes * Fix old anaconda environment handling * Add source bashrc * Fix travis * Fix travis * Fix travis * Handle skipping faiss test * Fix unused variable * Fix unused variable * Remove unused import * Update code quality
- Loading branch information
Showing
11 changed files
with
176 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import numpy as np | ||
import faiss | ||
from sklearn.neighbors import KNeighborsClassifier | ||
import pandas as pd | ||
import time | ||
from sklearn.model_selection import train_test_split | ||
import threading | ||
import os | ||
import urllib.request | ||
import gzip | ||
import shutil | ||
|
||
def sk_knn(Xtrain, Y, k, Xtest): | ||
start = time.clock() | ||
s_knn = KNeighborsClassifier(k, n_jobs=4) #Half of current cores | ||
s_knn.fit(Xtrain, Y) | ||
s_knn.predict(Xtest) | ||
print("sklearn_knn run_time: {}".format(time.clock() - start)) | ||
|
||
def faiss_knn(Xtrain, Y, k, Xtest): | ||
start = time.clock() | ||
index = faiss.IndexFlatL2(Xtrain.shape[1]) | ||
index.add(np.ascontiguousarray(Xtrain).astype(np.float32)) | ||
index.search(Xtest.astype(np.float32), k) | ||
print("faiss_knn run_time: {}".format(time.clock() - start)) | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
if not os.path.exists("../../HIGGS.csv"): | ||
print("Downloading HIGGS dataset from https://archive.ics.uci.edu/ml/datasets/HIGGS") | ||
if not os.path.exists("../../HIGGS.gz"): | ||
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz" | ||
filedata = urllib.request.urlopen(url) | ||
data2write = filedata.read() | ||
with open('../../HIGSS.gz', 'wb') as f: | ||
f.write(data2write) | ||
print("Finished downloading") | ||
print("Extracting HIGGS.gz") | ||
if not os.path.exists("../../HIGGS.csv"): | ||
with gzip.open('../../HIGGS.gz', 'rb') as f: | ||
with open('../../HIGGS.csv', 'wb') as csv_out: | ||
shutil.copyfileobj(f, csv_out) | ||
print("Extracted csv") | ||
|
||
df = pd.read_csv('../../HIGGS.csv', header=None) | ||
data = df.values | ||
X = data[:, 1:] | ||
Y = data[:, 0] | ||
|
||
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.33) | ||
num_samples_list = [1000000] | ||
num_of_k_list = [1, 2, 5, 7, 10] | ||
num_of_test_inputs = [100, 1000] | ||
|
||
for nsamples in num_samples_list: | ||
for n_k in num_of_k_list: | ||
for n_t in num_of_test_inputs: | ||
print("running experiment: num_of_train_samples: {}, num_of_k: {}, num_of_tests: {}".format( | ||
nsamples, | ||
n_k, | ||
n_t)) | ||
faiss_knn(X_train[:nsamples], Y_train[:nsamples], n_k, X_test[:n_t]) | ||
t = threading.Thread(target=sk_knn, args=(X_train[:nsamples], Y_train[:nsamples], n_k, X_test[:n_t])) | ||
t.start() | ||
t.join(timeout=600) | ||
if t.is_alive(): | ||
print("sklearn_knn, num_of_train_samples: {}, num_of_k: {}, num_of_tests: {}, run_time: {}".format( | ||
nsamples, | ||
n_k, | ||
n_t, | ||
"timeout after 60s")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
|
||
import pytest | ||
import sys | ||
from sklearn.neighbors import KNeighborsClassifier | ||
from deslib.tests.examples_test import * | ||
from deslib.tests.test_des_integration import load_dataset | ||
|
||
try: | ||
from deslib.util.faiss_knn_wrapper import FaissKNNClassifier | ||
except ImportError: | ||
pass | ||
|
||
|
||
@pytest.mark.skipif('faiss' not in sys.modules, | ||
reason="requires the faiss library") | ||
def test_faiss_predict(): | ||
rng = np.random.RandomState(123456) | ||
_, X_test, X_train, _, _, y_train = load_dataset(None, rng) | ||
k = 7 | ||
X_train = X_train.astype(np.float32) | ||
X_test = X_test.astype(np.float32) | ||
f_knn_test = FaissKNNClassifier(n_neighbors=k) | ||
f_knn_test.fit(X_train, y_train) | ||
f_knn_preds = f_knn_test.predict(X_test) | ||
|
||
knn_test = KNeighborsClassifier(n_neighbors=k) | ||
knn_test.fit(X_train, y_train) | ||
knn_preds = knn_test.predict(X_test) | ||
|
||
assert ((f_knn_preds - knn_preds).sum() == 0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters