diff --git a/tests/unit/gaussian_naive_bayes/gaussian_nb_test.py b/tests/unit/gaussian_naive_bayes/gaussian_nb_test.py index 3c66c09..73b994b 100644 --- a/tests/unit/gaussian_naive_bayes/gaussian_nb_test.py +++ b/tests/unit/gaussian_naive_bayes/gaussian_nb_test.py @@ -3,6 +3,7 @@ import torch import torchml as ml from sklearn.naive_bayes import GaussianNB +from torch.autograd import gradcheck BSZ = 128 @@ -25,6 +26,10 @@ def test_fit(self): self.assertTrue(np.allclose(ref_preds, model_preds.numpy())) self.assertTrue(np.allclose(ref_preds, model_forward.numpy())) + inputX = torch.from_numpy(X) + inputX.requires_grad = True + self.assertTrue(gradcheck(model.predict, inputX, eps=1e-6, atol=1e-3)) + self.assertTrue(gradcheck(model, inputX, eps=1e-6, atol=1e-3)) if __name__ == "__main__": diff --git a/tests/unit/linear_model/lasso_test.py b/tests/unit/linear_model/lasso_test.py index b02a15e..4ed6094 100644 --- a/tests/unit/linear_model/lasso_test.py +++ b/tests/unit/linear_model/lasso_test.py @@ -3,6 +3,7 @@ import torch import torchml as ml import sklearn.linear_model as linear_model +from torch.autograd import gradcheck BSZ = 128 @@ -11,70 +12,110 @@ class TestLasso(unittest.TestCase): def test_fit(self): - X = np.random.randn(BSZ, DIM) - y = np.random.randn(BSZ, 1) - - ref = linear_model.Lasso(fit_intercept=False) - ref.fit(X, y) - ref_preds = ref.predict(X) - - model = ml.linear_model.Lasso() - model.fit(torch.from_numpy(X), torch.from_numpy(y)) - model_preds = model.predict(torch.from_numpy(X)) - model_forward = model(torch.from_numpy(X)) - - self.assertTrue( - np.allclose(ref_preds, model_preds[0].detach().numpy().flatten(), atol=1e-3) - ) - self.assertTrue( - np.allclose( - ref_preds, model_forward[0].detach().numpy().flatten(), atol=1e-3 + for i in range(2): + device = torch.device("cuda" if torch.cuda.is_available() and i else "cpu") + + X = np.random.randn(BSZ, DIM) + y = np.random.randn(BSZ, 1) + + ref = linear_model.Lasso(fit_intercept=False) + ref.fit(X, 
y) + ref_preds = ref.predict(X) + + model = ml.linear_model.Lasso() + model.fit(torch.from_numpy(X).to(device), torch.from_numpy(y).to(device)) + model_preds = model.predict(torch.from_numpy(X).to(device)) + model_forward = model(torch.from_numpy(X).to(device)) + + self.assertTrue( + np.allclose( + ref_preds, + model_preds[0].detach().cpu().numpy().flatten(), + atol=1e-3, + ) ) - ) + self.assertTrue( + np.allclose( + ref_preds, + model_forward[0].detach().cpu().numpy().flatten(), + atol=1e-3, + ) + ) + + inputX = torch.from_numpy(X).to(device) + inputX.requires_grad = True + self.assertTrue(gradcheck(model.predict, inputX, eps=1e-6, atol=1e-3)) + self.assertTrue(gradcheck(model, inputX, eps=1e-6, atol=1e-3)) def test_fit_intercept(self): - X = np.random.randn(BSZ, DIM) - y = np.random.randn(BSZ, 1) - - ref = linear_model.Lasso(fit_intercept=True) - ref.fit(X, y) - ref_preds = ref.predict(X) - - model = ml.linear_model.Lasso(fit_intercept=True) - model.fit(torch.from_numpy(X), torch.from_numpy(y)) - model_preds = model.predict(torch.from_numpy(X)) - model_forward = model(torch.from_numpy(X)) - - self.assertTrue( - np.allclose(ref_preds, model_preds[0].detach().numpy().flatten(), atol=1e-3) - ) - self.assertTrue( - np.allclose( - ref_preds, model_forward[0].detach().numpy().flatten(), atol=1e-3 + for i in range(2): + device = torch.device("cuda" if torch.cuda.is_available() and i else "cpu") + X = np.random.randn(BSZ, DIM) + y = np.random.randn(BSZ, 1) + + ref = linear_model.Lasso(fit_intercept=True) + ref.fit(X, y) + ref_preds = ref.predict(X) + + model = ml.linear_model.Lasso(fit_intercept=True) + model.fit(torch.from_numpy(X).to(device), torch.from_numpy(y).to(device)) + model_preds = model.predict(torch.from_numpy(X).to(device)) + model_forward = model(torch.from_numpy(X).to(device)) + + self.assertTrue( + np.allclose( + ref_preds, + model_preds[0].detach().cpu().numpy().flatten(), + atol=1e-3, + ) + ) + self.assertTrue( + np.allclose( + ref_preds, + 
model_forward[0].detach().cpu().numpy().flatten(), + atol=1e-3, + ) ) - ) + + inputX = torch.from_numpy(X).to(device) + inputX.requires_grad = True + self.assertTrue(gradcheck(model.predict, inputX, eps=1e-6, atol=1e-3)) + self.assertTrue(gradcheck(model, inputX, eps=1e-6, atol=1e-3)) def test_fit_positive(self): - X = np.random.randn(BSZ, DIM) - y = np.random.randn(BSZ, 1) - - ref = linear_model.Lasso(fit_intercept=False, positive=True) - ref.fit(X, y) - ref_preds = ref.predict(X) - - model = ml.linear_model.Lasso(fit_intercept=False, positive=True) - model.fit(torch.from_numpy(X), torch.from_numpy(y)) - model_preds = model.predict(torch.from_numpy(X)) - model_forward = model(torch.from_numpy(X)) - - self.assertTrue( - np.allclose(ref_preds, model_preds[0].detach().numpy().flatten(), atol=1e-3) - ) - self.assertTrue( - np.allclose( - ref_preds, model_forward[0].detach().numpy().flatten(), atol=1e-3 + for i in range(2): + device = torch.device("cuda" if torch.cuda.is_available() and i else "cpu") + X = np.random.randn(BSZ, DIM) + y = np.random.randn(BSZ, 1) + + ref = linear_model.Lasso(fit_intercept=False, positive=True) + ref.fit(X, y) + ref_preds = ref.predict(X) + + model = ml.linear_model.Lasso(fit_intercept=False, positive=True) + model.fit(torch.from_numpy(X).to(device), torch.from_numpy(y).to(device)) + model_preds = model.predict(torch.from_numpy(X).to(device)) + model_forward = model(torch.from_numpy(X).to(device)) + + self.assertTrue( + np.allclose( + ref_preds, + model_preds[0].detach().cpu().numpy().flatten(), + atol=1e-3, + ) + ) + self.assertTrue( + np.allclose( + ref_preds, + model_forward[0].detach().cpu().numpy().flatten(), + atol=1e-3, + ) ) - ) + + inputX = torch.from_numpy(X).to(device) + inputX.requires_grad = True + self.assertTrue(gradcheck(model.predict, inputX, eps=1e-6, atol=1e-3)) + self.assertTrue(gradcheck(model, inputX, eps=1e-6, atol=1e-3)) if __name__ == "__main__": diff --git a/tests/unit/linear_model/linear_regression_test.py 
b/tests/unit/linear_model/linear_regression_test.py index 0d36710..ed4c201 100644 --- a/tests/unit/linear_model/linear_regression_test.py +++ b/tests/unit/linear_model/linear_regression_test.py @@ -3,6 +3,7 @@ import torch import torchml as ml import sklearn.linear_model as linear_model +from torch.autograd import gradcheck BSZ = 128 @@ -11,20 +12,27 @@ class TestLinearRegression(unittest.TestCase): def test_fit(self): - X = np.random.randn(BSZ, DIM) - y = np.random.randn(BSZ, 1) - - ref = linear_model.LinearRegression(fit_intercept=False) - ref.fit(X, y) - ref_preds = ref.predict(X) - - model = ml.linear_model.LinearRegression(fit_intercept=False) - model.fit(torch.from_numpy(X), torch.from_numpy(y)) - model_preds = model.predict(torch.from_numpy(X)) - model_forward = model(torch.from_numpy(X)) - - self.assertTrue(np.allclose(ref_preds, model_preds.numpy())) - self.assertTrue(np.allclose(ref_preds, model_forward.numpy())) + for i in range(2): + device = torch.device("cuda" if torch.cuda.is_available() and i else "cpu") + X = np.random.randn(BSZ, DIM) + y = np.random.randn(BSZ, 1) + + ref = linear_model.LinearRegression(fit_intercept=False) + ref.fit(X, y) + ref_preds = ref.predict(X) + + model = ml.linear_model.LinearRegression(fit_intercept=False) + model.fit(torch.from_numpy(X).to(device), torch.from_numpy(y).to(device)) + model_preds = model.predict(torch.from_numpy(X).to(device)) + model_forward = model(torch.from_numpy(X).to(device)) + + self.assertTrue(np.allclose(ref_preds, model_preds.cpu().numpy())) + self.assertTrue(np.allclose(ref_preds, model_forward.cpu().numpy())) + + inputX = torch.from_numpy(X).to(device) + inputX.requires_grad = True + self.assertTrue(gradcheck(model.predict, inputX, eps=1e-6, atol=1e-3)) + self.assertTrue(gradcheck(model, inputX, eps=1e-6, atol=1e-3)) if __name__ == "__main__": diff --git a/tests/unit/linear_model/ridge_test.py b/tests/unit/linear_model/ridge_test.py index 315275e..cd78111 100644 --- 
a/tests/unit/linear_model/ridge_test.py +++ b/tests/unit/linear_model/ridge_test.py @@ -3,6 +3,7 @@ import torch import torchml as ml import sklearn.linear_model as linear_model +from torch.autograd import gradcheck BSZ = 128 @@ -11,36 +12,50 @@ class TestRidge(unittest.TestCase): def test_fit(self): - X = np.random.randn(BSZ, DIM) - y = np.random.randn(BSZ, 1) + for i in range(2): + device = torch.device("cuda" if torch.cuda.is_available() and i else "cpu") + X = np.random.randn(BSZ, DIM) + y = np.random.randn(BSZ, 1) - ref = linear_model.Ridge(fit_intercept=False) - ref.fit(X, y) - ref_preds = ref.predict(X) + ref = linear_model.Ridge(fit_intercept=False) + ref.fit(X, y) + ref_preds = ref.predict(X) - model = ml.linear_model.Ridge(fit_intercept=False) - model.fit(torch.from_numpy(X), torch.from_numpy(y)) - model_preds = model.predict(torch.from_numpy(X)) - model_forward = model(torch.from_numpy(X)) + model = ml.linear_model.Ridge(fit_intercept=False) + model.fit(torch.from_numpy(X).to(device), torch.from_numpy(y).to(device)) + model_preds = model.predict(torch.from_numpy(X).to(device)) + model_forward = model(torch.from_numpy(X).to(device)) - self.assertTrue(np.allclose(ref_preds, model_preds.numpy())) - self.assertTrue(np.allclose(ref_preds, model_forward.numpy())) + self.assertTrue(np.allclose(ref_preds, model_preds.cpu().numpy())) + self.assertTrue(np.allclose(ref_preds, model_forward.cpu().numpy())) - def test_fit_intercept(self): - X = np.random.randn(BSZ, DIM) - y = np.random.randn(BSZ, 1) - - ref = linear_model.Ridge(fit_intercept=True) - ref.fit(X, y) - ref_preds = ref.predict(X) + inputX = torch.from_numpy(X).to(device) + inputX.requires_grad = True + self.assertTrue(gradcheck(model.predict, inputX, eps=1e-6, atol=1e-3)) + self.assertTrue(gradcheck(model, inputX, eps=1e-6, atol=1e-3)) - model = ml.linear_model.Ridge(fit_intercept=True) - model.fit(torch.from_numpy(X), torch.from_numpy(y)) - model_preds = model.predict(torch.from_numpy(X)) - model_forward 
= model(torch.from_numpy(X)) - - self.assertTrue(np.allclose(ref_preds, model_preds.numpy())) - self.assertTrue(np.allclose(ref_preds, model_forward.numpy())) + def test_fit_intercept(self): + for i in range(2): + device = torch.device("cuda" if torch.cuda.is_available() and i else "cpu") + X = np.random.randn(BSZ, DIM) + y = np.random.randn(BSZ, 1) + + ref = linear_model.Ridge(fit_intercept=True) + ref.fit(X, y) + ref_preds = ref.predict(X) + + model = ml.linear_model.Ridge(fit_intercept=True) + model.fit(torch.from_numpy(X).to(device), torch.from_numpy(y).to(device)) + model_preds = model.predict(torch.from_numpy(X).to(device)) + model_forward = model(torch.from_numpy(X).to(device)) + + self.assertTrue(np.allclose(ref_preds, model_preds.cpu().numpy())) + self.assertTrue(np.allclose(ref_preds, model_forward.cpu().numpy())) + + inputX = torch.from_numpy(X).to(device) + inputX.requires_grad = True + self.assertTrue(gradcheck(model.predict, inputX, eps=1e-6, atol=1e-3)) + self.assertTrue(gradcheck(model, inputX, eps=1e-6, atol=1e-3)) if __name__ == "__main__": diff --git a/tests/unit/neighbors/k_neighbors_classifier_test.py b/tests/unit/neighbors/k_neighbors_classifier_test.py index 1db5479..84f5d13 100644 --- a/tests/unit/neighbors/k_neighbors_classifier_test.py +++ b/tests/unit/neighbors/k_neighbors_classifier_test.py @@ -3,6 +3,8 @@ import torch import torchml as ml import sklearn.neighbors as neighbors +from torch.autograd import gradcheck + BSZ = 1000 DIM = 50 @@ -10,31 +12,38 @@ class TestkneighborsClassifier(unittest.TestCase): def test_knn_classifier(self): - for i in range(1, 20, 1): - X = np.random.randn(BSZ, DIM) - y = np.random.randint(low=-100, high=100, size=BSZ) - p = np.random.randn(5, DIM) - - ref = neighbors.KNeighborsClassifier( - weights="distance" if i % 2 else "uniform", p=i - ) - ref.fit(X, y) - refr = ref.predict(p) - refp = ref.predict_proba(p) - - test = ml.neighbors.KNeighborsClassifier( - weights="distance" if i % 2 else "uniform", p=i - ) 
- test.fit(torch.from_numpy(X), torch.from_numpy(y)) - testr = test.predict(torch.from_numpy(p)) - testp = test.predict_proba(torch.from_numpy(p)) - self.assertTrue(np.allclose(refr, testr.numpy())) - self.assertTrue(np.allclose(refp, testp.numpy())) - - refr2 = ref.kneighbors(p) - testr2 = test.kneighbors(torch.from_numpy(p)) - self.assertTrue(np.allclose(refr2[0], testr2[0].numpy())) - self.assertTrue(np.allclose(refr2[1], testr2[1].numpy())) + for i in range(2): + device = torch.device("cuda" if torch.cuda.is_available() and i else "cpu") + for i in range(1, 5, 1): + X = np.random.randn(BSZ, DIM) + y = np.random.randint(low=-100, high=100, size=BSZ) + p = np.random.randn(5, DIM) + + ref = neighbors.KNeighborsClassifier( + weights="distance" if i % 2 else "uniform", p=i + ) + ref.fit(X, y) + refr = ref.predict(p) + refp = ref.predict_proba(p) + + test = ml.neighbors.KNeighborsClassifier( + weights="distance" if i % 2 else "uniform", p=i + ) + test.fit(torch.from_numpy(X).to(device), torch.from_numpy(y).to(device)) + inputP = torch.from_numpy(p).to(device).double() + inputP.requires_grad = True + + testr = test.predict(torch.from_numpy(p).to(device)) + testp = test.predict_proba(torch.from_numpy(p).to(device)) + self.assertTrue(gradcheck(test.predict, inputP, eps=1e-6, atol=1e-3)) + self.assertTrue(np.allclose(refr, testr.cpu().numpy())) + self.assertTrue(np.allclose(refp, testp.cpu().numpy())) + + refr2 = ref.kneighbors(p) + testr2 = test.kneighbors(torch.from_numpy(p).to(device)) + self.assertTrue(gradcheck(test.kneighbors, inputP, eps=1e-6, atol=1e-3)) + self.assertTrue(np.allclose(refr2[0], testr2[0].cpu().numpy())) + self.assertTrue(np.allclose(refr2[1], testr2[1].cpu().numpy())) if __name__ == "__main__": diff --git a/tests/unit/neighbors/nearest_centroids_test.py b/tests/unit/neighbors/nearest_centroids_test.py index 4a6dc5e..bcbe872 100644 --- a/tests/unit/neighbors/nearest_centroids_test.py +++ b/tests/unit/neighbors/nearest_centroids_test.py @@ -3,6 +3,7 
@@ import torch import torchml as ml import sklearn.neighbors as neighbors +from torch.autograd import gradcheck # define numbers of classes & features SAMPLES = 10 @@ -12,20 +13,24 @@ class Testcentroids(unittest.TestCase): def test_kneighbors(self): - - for i in range(100): - X = np.random.randn(SAMPLES, FEA) - y = np.random.randint(1, CLS, size=SAMPLES) - torchX = torch.from_numpy(X) - torchy = torch.from_numpy(y) - ref = neighbors.NearestCentroid() - cent = ml.neighbors.NearestCentroid() - ref.fit(X, y) - cent.fit(torchX, torchy) - samp = np.random.randn(SAMPLES, FEA) - refres = ref.predict(samp) - centres = cent.predict(torch.from_numpy(samp)).numpy() - self.assertTrue(np.array_equal(refres, centres)) + for i in range(2): + device = torch.device("cuda" if torch.cuda.is_available() and i else "cpu") + for i in range(100): + X = np.random.randn(SAMPLES, FEA) + y = np.random.randint(1, CLS, size=SAMPLES) + torchX = torch.from_numpy(X).to(device) + torchy = torch.from_numpy(y).to(device) + ref = neighbors.NearestCentroid() + cent = ml.neighbors.NearestCentroid() + ref.fit(X, y) + cent.fit(torchX, torchy) + samp = np.random.randn(SAMPLES, FEA) + refres = ref.predict(samp) + centres = cent.predict(torch.from_numpy(samp).to(device)).cpu().numpy() + self.assertTrue(np.array_equal(refres, centres)) + inputSamp = torch.from_numpy(samp).to(device) + inputSamp.requires_grad = True + self.assertTrue(gradcheck(cent.predict, inputSamp, eps=1e-6, atol=1e-3)) if __name__ == "__main__": diff --git a/tests/unit/neighbors/nearest_neighbors_test.py b/tests/unit/neighbors/nearest_neighbors_test.py index 28ceb54..a0cec86 100644 --- a/tests/unit/neighbors/nearest_neighbors_test.py +++ b/tests/unit/neighbors/nearest_neighbors_test.py @@ -3,38 +3,51 @@ import torch import torchml as ml import sklearn.neighbors as neighbors +from torch.autograd import gradcheck BSZ = 128 DIM = 5 class Testkneighbors(unittest.TestCase): - def test_kneighbors(self): - for i in range(1, 200, 1): - X = 
np.random.randn(BSZ, DIM) - y = np.random.randn(5, DIM) - ref = neighbors.NearestNeighbors(p=i) - ref.fit(X) - test = ref.kneighbors(y) - - model = ml.neighbors.NearestNeighbors(p=i) - model.fit(torch.from_numpy(X)) - res = model.kneighbors(torch.from_numpy(y)) - - # return distance is true - self.assertTrue(np.allclose(test[0], res[0].numpy())) - self.assertTrue(np.allclose(test[1], res[1].numpy())) - - ref = neighbors.NearestNeighbors(p=i) - ref.fit(X) - test = ref.kneighbors(y, return_distance=False) - - model = ml.neighbors.NearestNeighbors(p=i) - model.fit(torch.from_numpy(X)) - res = model.kneighbors(torch.from_numpy(y), return_distance=False) - - # return distance is false - self.assertTrue(np.allclose(test, res.numpy())) + def test_kneighbors_classifier(self): + for i in range(2): + device = torch.device("cuda" if torch.cuda.is_available() and i else "cpu") + for i in range(1, 5, 1): + X = np.random.randn(BSZ, DIM) + y = np.random.randn(5, DIM) + ref = neighbors.NearestNeighbors(p=i) + ref.fit(X) + test = ref.kneighbors(y) + + model = ml.neighbors.NearestNeighbors(p=i) + model.fit(torch.from_numpy(X).to(device)) + res = model.kneighbors(torch.from_numpy(y).to(device)) + + # return distance is true + self.assertTrue(np.allclose(test[0], res[0].cpu().numpy())) + self.assertTrue(np.allclose(test[1], res[1].cpu().numpy())) + inputY = torch.from_numpy(y).to(device) + inputY.requires_grad = True + self.assertTrue( + gradcheck(model.kneighbors, inputY, eps=1e-6, atol=1e-3) + ) + + ref = neighbors.NearestNeighbors(p=i) + ref.fit(X) + test = ref.kneighbors(y, return_distance=False) + + model = ml.neighbors.NearestNeighbors(p=i) + model.fit(torch.from_numpy(X).to(device)) + res = model.kneighbors( + torch.from_numpy(y).to(device), return_distance=False + ) + + # return distance is false + self.assertTrue(np.allclose(test, res.cpu().numpy())) + self.assertTrue( + gradcheck(model.kneighbors, inputY, eps=1e-6, atol=1e-3) + ) if __name__ == "__main__": diff --git 
a/tests/unit/svm/__init__.py b/tests/unit/svm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/svm/linear_svc_test.py b/tests/unit/svm/linear_svc_test.py new file mode 100644 index 0000000..93d70a6 --- /dev/null +++ b/tests/unit/svm/linear_svc_test.py @@ -0,0 +1,74 @@ +import unittest +import numpy as np +import torch +from sklearn.datasets import make_classification +import sklearn.svm as svm +import time +from torch.autograd import gradcheck + +from torchml.svm import LinearSVC + +n_samples = 5000 +n_features = 4 +n_classes = 2 +n_informative = 4 + + +class TestLinearSVC(unittest.TestCase): + def test_LinearSVC(self): + for i in range(2): + device = torch.device("cuda" if torch.cuda.is_available() and i else "cpu") + x, y = make_classification( + n_samples=n_samples, + n_features=n_features, + n_classes=n_classes, + n_informative=n_informative, + n_redundant=n_features - n_informative, + ) + lsvc = LinearSVC(max_iter=1000) + start = time.time() + lsvc.fit(torch.from_numpy(x).to(device), torch.from_numpy(y).to(device)) + end = time.time() + # print(end - start) + start = time.time() + reflsvc = svm.LinearSVC(max_iter=100000) + reflsvc.fit(x, y) + + end = time.time() + # print(end - start) + self.assertTrue( + np.allclose(lsvc.coef_.cpu().numpy(), reflsvc.coef_, atol=1e-2) + ) + self.assertTrue( + np.allclose( + lsvc.intercept_.cpu().numpy(), reflsvc.intercept_, atol=1e-2 + ) + ) + self.assertTrue( + np.allclose( + lsvc.decision_function(torch.from_numpy(x).to(device)) + .cpu() + .numpy(), + reflsvc.decision_function(x), + atol=1e-2, + ) + ) + + inputX = torch.from_numpy(x).to(device) + inputX.requires_grad = True + self.assertTrue( + gradcheck(lsvc.decision_function, inputX, eps=1e-6, atol=1e-3) + ) + + self.assertTrue( + np.allclose( + lsvc.predict(torch.from_numpy(x).to(device)).cpu().numpy(), + reflsvc.predict(x), + atol=1e-2, + ) + ) + self.assertTrue(gradcheck(lsvc.predict, inputX, eps=1e-6, atol=1e-3)) + + +if __name__ == 
"__main__": + unittest.main() diff --git a/tests/unit/svm/linear_svr_test.py b/tests/unit/svm/linear_svr_test.py new file mode 100644 index 0000000..c8a5e9a --- /dev/null +++ b/tests/unit/svm/linear_svr_test.py @@ -0,0 +1,58 @@ +import unittest +import numpy as np +import torch +from sklearn.datasets import make_regression +import sklearn.svm as svm +import time +from torch.autograd import gradcheck + +from torchml.svm import LinearSVR + +n_samples = 5000 +n_features = 4 +n_informative = 3 + + +class TestLinearSVR(unittest.TestCase): + def test_LinearSVR(self): + for i in range(2): + device = torch.device("cuda" if torch.cuda.is_available() and i else "cpu") + x, y = make_regression( + n_samples=n_samples, + n_features=n_features, + n_informative=n_informative, + ) + lsvr = LinearSVR(max_iter=1000) + start = time.time() + lsvr.fit(torch.from_numpy(x).to(device), torch.from_numpy(y).to(device)) + end = time.time() + # print(end - start) + start = time.time() + reflsvr = svm.LinearSVR(max_iter=100000) + reflsvr.fit(x, y) + + end = time.time() + # print(end - start) + self.assertTrue( + np.allclose(lsvr.coef_.cpu().numpy(), reflsvr.coef_, atol=1e-2) + ) + self.assertTrue( + np.allclose( + lsvr.intercept_.cpu().numpy(), reflsvr.intercept_, atol=1e-2 + ) + ) + self.assertTrue( + np.allclose( + lsvr.predict(torch.from_numpy(x).to(device)).cpu().numpy(), + reflsvr.predict(x), + atol=1e-2, + ) + ) + + inputX = torch.from_numpy(x).to(device) + inputX.requires_grad = True + self.assertTrue(gradcheck(lsvr.predict, inputX, eps=1e-6, atol=1e-3)) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchml/linear_model/lasso.py b/torchml/linear_model/lasso.py index a6ab155..5ebe4eb 100644 --- a/torchml/linear_model/lasso.py +++ b/torchml/linear_model/lasso.py @@ -81,6 +81,8 @@ def fit(self, X: torch.Tensor, y: torch.Tensor): assert X.shape[0] == y.shape[0], "Number of X and y rows don't match" + device = X.device + m, n = X.shape w = cp.Variable((n, 1)) @@ -120,10 
+122,12 @@ def fit(self, X: torch.Tensor, y: torch.Tensor): # this object is now callable with pytorch tensors if self.fit_intercept: self.weight, self.intercept = fit_lr( - X, y, torch.tensor(self.alpha, dtype=torch.float64) + X, y, torch.tensor(self.alpha, dtype=torch.float64, device=device) ) else: - self.weight = fit_lr(X, y, torch.tensor(self.alpha, dtype=torch.float64)) + self.weight = fit_lr( + X, y, torch.tensor(self.alpha, dtype=torch.float64, device=device) + ) self.weight = torch.stack(list(self.weight), dim=0) def predict(self, X: torch.Tensor): diff --git a/torchml/linear_model/ridge.py b/torchml/linear_model/ridge.py index 1c43555..cace70b 100644 --- a/torchml/linear_model/ridge.py +++ b/torchml/linear_model/ridge.py @@ -80,14 +80,16 @@ def fit(self, X: torch.Tensor, y: torch.Tensor): """ assert X.shape[0] == y.shape[0], "Number of X and y rows don't match" + device = X.device + if self.fit_intercept: - X = torch.cat([torch.ones(X.shape[0], 1), X], dim=1) + X = torch.cat([torch.ones(X.shape[0], 1, device=device), X], dim=1) # L2 penalty term will not apply when alpha is 0 if self.alpha == 0: self.weight = torch.pinverse(X.T @ X) @ X.T @ y else: - ridge = self.alpha * torch.eye(X.shape[1]) + ridge = self.alpha * torch.eye(X.shape[1], device=device) # intercept term is not penalized when fit_intercept is true if self.fit_intercept: ridge[0][0] = 0 @@ -112,5 +114,5 @@ def predict(self, X: torch.Tensor): ~~~ """ if self.fit_intercept: - X = torch.cat([torch.ones(X.shape[0], 1), X], dim=1) + X = torch.cat([torch.ones(X.shape[0], 1, device=X.device), X], dim=1) return X @ self.weight diff --git a/torchml/neighbors/k_neighbors_classifier.py b/torchml/neighbors/k_neighbors_classifier.py index ed56256..9a97053 100644 --- a/torchml/neighbors/k_neighbors_classifier.py +++ b/torchml/neighbors/k_neighbors_classifier.py @@ -1,7 +1,10 @@ import numbers import warnings +from typing import Tuple, Any import torch +from torch import Tensor + import torchml as ml @@ 
-107,6 +110,7 @@ def fit(self, X: torch.Tensor, y: torch.Tensor): """ self.KNN.fit(X) self.weights = self._check_weights(weights=self.weights) + device = X.device if y.ndim == 1 or y.ndim == 2 and y.shape[1] == 1: if y.ndim != 1: warnings.warn( @@ -122,7 +126,7 @@ def fit(self, X: torch.Tensor, y: torch.Tensor): self.outputs_2d_ = True self.classes_ = [] - self._y = torch.empty(size=y.shape, dtype=torch.long) + self._y = torch.empty(size=y.shape, dtype=torch.long, device=device) for k in range(self._y.shape[1]): classes, self._y[:, k] = torch.unique(y[:, k], return_inverse=True) self.classes_.append(classes) @@ -141,6 +145,7 @@ def predict(self, X: torch.Tensor) -> torch.Tensor: * `X` (torch.Tensor): the target point """ + device = X.device if self.weights == "uniform": neigh_ind = self.KNN.kneighbors(X, return_distance=False) neigh_dist = None @@ -157,7 +162,9 @@ def predict(self, X: torch.Tensor) -> torch.Tensor: n_queries = len(X) weights = self._get_weights(neigh_dist, self.weights) - y_pred = torch.empty((n_queries, n_outputs), dtype=classes_[0].dtype) + y_pred = torch.empty( + (n_queries, n_outputs), dtype=classes_[0].dtype, device=device + ) for k, classes_k in enumerate(classes_): if weights is None: @@ -182,6 +189,7 @@ def predict_proba(self, X: torch.Tensor) -> torch.Tensor: * `X` (torch.Tensor): the target point """ + device = X.device if self.weights == "uniform": neigh_ind = self.KNN.kneighbors(X, return_distance=False) neigh_dist = None @@ -198,13 +206,13 @@ def predict_proba(self, X: torch.Tensor) -> torch.Tensor: weights = self._get_weights(neigh_dist, self.weights) if weights is None: - weights = torch.ones_like(neigh_ind) + weights = torch.ones_like(neigh_ind, device=device) all_rows = torch.arange(n_queries) probabilities = [] for k, classes_k in enumerate(classes_): pred_labels = _y[:, k][neigh_ind] - proba_k = torch.zeros((n_queries, len(classes_k))) + proba_k = torch.zeros((n_queries, len(classes_k)), device=device) for i, idx in 
enumerate(pred_labels.T): proba_k[all_rows, idx] += weights[:, i] @@ -264,21 +272,27 @@ def _get_weights(self, dist: torch.Tensor, weights: str) -> torch.Tensor: "'distance', or a callable function" ) - def _weighted_mode(self, a: torch.Tensor, w: torch.Tensor) -> torch.Tensor: - res = torch.empty(0) - resi = torch.empty(0) + def _weighted_mode( + self, a: torch.Tensor, w: torch.Tensor + ) -> tuple[Tensor | Any, Tensor | Any]: + device = a.device + res = torch.empty(0, device=device) + resi = torch.empty(0, device=device) for i, x in enumerate(a): res1 = self._weighted_mode_util(x, w) - res = torch.cat((res, torch.tensor([res1[0]]))) - resi = torch.cat((resi, torch.tensor([res1[1]]))) + res = torch.cat((res, torch.tensor([res1[0]], device=device))) + resi = torch.cat((resi, torch.tensor([res1[1]], device=device))) return res, resi - def _weighted_mode_util(self, a: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + def _weighted_mode_util( + self, a: torch.Tensor, w: torch.Tensor + ) -> tuple[Any, Tensor]: + device = a.device unique_a = torch.unique(a) - res = torch.empty(0) + res = torch.empty(0, device=device) for i, x in enumerate(unique_a): cleared = (a == x).float() cleared_weights = cleared * w sum = torch.sum(cleared_weights) - res = torch.cat((res, torch.tensor([sum]))) + res = torch.cat((res, torch.tensor([sum], device=device))) return unique_a[torch.argmax(res)], torch.max(res) diff --git a/torchml/neighbors/nearest_centroid.py b/torchml/neighbors/nearest_centroid.py index 96dbf7b..7d7ee23 100644 --- a/torchml/neighbors/nearest_centroid.py +++ b/torchml/neighbors/nearest_centroid.py @@ -64,6 +64,8 @@ def fit(self, X: torch.Tensor, y: torch.Tensor): * `y` (torch.Tensor): array-like of shape (n_samples,) Target values """ + device = X.device + n_samples, n_features = X.shape # y_ind: idx, y_classes: unique tensor @@ -79,7 +81,7 @@ def fit(self, X: torch.Tensor, y: torch.Tensor): # Mask mapping each class to its members. 
self.centroids_ = torch.empty( - (n_classes, n_features), dtype=X.dtype, device=torch.device("cpu") + (n_classes, n_features), dtype=X.dtype, device=device ) # Number of clusters in each class. @@ -109,15 +111,15 @@ def predict(self, X: torch.tensor) -> torch.tensor: * (torch.Tensor): the predicted classes """ + device = X.device if X is None or X.size(dim=0) < 1: print("Warning: check input size") - ret = torch.empty(X.size(dim=0)) + ret = torch.empty(X.size(dim=0), device=device) for i in range(X.size(dim=0)): ret[i] = self.classes_[ - torch.argmin(torch.nn.PairwiseDistance(p=2) - (X[i], self.centroids_)) + torch.argmin(torch.nn.PairwiseDistance(p=2)(X[i], self.centroids_)) ] # return ret.to(self.y_type) diff --git a/torchml/svm/__init__.py b/torchml/svm/__init__.py new file mode 100644 index 0000000..09ebc60 --- /dev/null +++ b/torchml/svm/__init__.py @@ -0,0 +1,2 @@ +from .linear_svc import LinearSVC +from .linear_svr import LinearSVR diff --git a/torchml/svm/linear_svc.py b/torchml/svm/linear_svc.py new file mode 100644 index 0000000..127f4f1 --- /dev/null +++ b/torchml/svm/linear_svc.py @@ -0,0 +1,240 @@ +import torch + +import torchml as ml +import cvxpy as cp +from cvxpylayers.torch import CvxpyLayer + + +class LinearSVC(ml.Model): + """ + ## Description + + Support vector classifier with cvxpy + + ## References + + 1. Bernhard E. Boser, Isabelle M. Guyon, and Vladimir N. Vapnik. 1992. A training algorithm for optimal margin classifiers. In Proceedings of the fifth annual workshop on Computational learning theory (COLT '92). Association for Computing Machinery, New York, NY, USA, 144–152. https://doi.org/10.1145/130385.130401 + 2. MIT 6.034 Artificial Intelligence, Fall 2010, [16. Learning: Support Vector Machines](https://youtu.be/_PwhiWxHK8o) + 3. The scikit-learn [documentation page](https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html) for LinearSVC. 
+ + ## Arguments + + * `penalty` (str {'l1', 'l2'}, default='l2'): + Specifies the norm used in the penalization. Currently a dummy variable: the L2 norm is always used. + + * `loss` (str {'hinge', 'squared_hinge'}, default='squared_hinge'): + Specifies the loss function. 'hinge' is the standard SVM loss. + + * `dual` (bool, default=True): + Dummy variable to keep consistency with sklearn's API, always treated as 'False' for now. + + * `tol` (float, default=1e-4): + Tolerance for stopping criteria. + + * `C` (float, default=1.0): + Regularization parameter. The strength of the regularization is inversely proportional to C. Must be strictly positive. + + * `multi_class` (str {'ovr', 'crammer_singer'}, default='ovr'): + Dummy variable, always 'ovr' (one-vs-rest: each class is fit against all the other classes combined). + + * `fit_intercept` (bool, default=True): + Whether to calculate the intercept for this model. + + * `intercept_scaling` (float, default=1): + Dummy variable to mimic the sklearn API, always 1 for now. + + * `class_weight` (dict or str 'balanced', default=None): + Dummy variable to mimic the sklearn API, always None for now. + + * `verbose` (int, default=0): + Dummy variable to mimic the sklearn API, always 0 for now. + + * `random_state` (int, RandomState instance or None, default=None): + Dummy variable to mimic the sklearn API, always None for now. + + * `max_iter` (int, default=1000): + The maximum number of iterations to be run by the underlying convex solver. 
+ + + ## Example + + ~~~python + import torch + from torchml.svm import LinearSVC + from sklearn.datasets import make_classification + + x, y = make_classification( + n_samples=n_samples, + n_features=n_features, + n_classes=n_classes, + n_informative=n_informative, + n_redundant=n_features - n_informative, + ) + svc = LinearSVC(max_iter=1000) + svc.fit(torch.from_numpy(x), torch.from_numpy(y)) + svc.decision_function(torch.from_numpy(x)) + svc.predict(torch.from_numpy(x)) + ~~~ + """ + + def __init__( + self, + penalty="l2", + loss="squared_hinge", + *, + dual=True, + tol=1e-4, + C=1.0, + multi_class="ovr", + fit_intercept=True, + intercept_scaling=1, + class_weight=None, + verbose=0, + random_state=None, + max_iter=1000, + ): + super(LinearSVC, self).__init__() + self.coef_ = None + self.intercept_ = None + self.classes_ = None + self.dual = dual + self.tol = tol + self.C = C + self.multi_class = multi_class + self.fit_intercept = fit_intercept + self.intercept_scaling = intercept_scaling + self.class_weight = class_weight + self.verbose = verbose + self.random_state = random_state + self.max_iter = max_iter + self.penalty = penalty + self.loss = loss + + def fit(self, X: torch.Tensor, y: torch.Tensor, sample_weight=None): + """ + ## Description + + Fit the model to the training data + + ## Arguments + * `X` (torch.Tensor): the training set + * `y` (torch.Tensor): the class labels for each sample + + """ + if self.C < 0: + raise ValueError("Penalty term must be positive; got (C=%r)" % self.C) + device = X.device + self.classes_ = torch.unique(y) + assert X.shape[0] == y.shape[0], "Number of X and y rows don't match" + m, n = X.shape + self.coef_ = torch.empty((0, n), device=device) + self.intercept_ = torch.empty((0), device=device) + if self.classes_.shape[0] == 2: + self._fit_with_one_class( + X, y, self.classes_[1], sample_weight=sample_weight + ) + else: + for i, x in enumerate(self.classes_): + self._fit_with_one_class(X, y, x, 
sample_weight=sample_weight) + + def decision_function(self, X: torch.Tensor) -> torch.Tensor: + """ + ## Description + + Predict confidence scores for samples. + + ## Arguments + * `X` (torch.Tensor): the data set for which we want to get the confidence scores. + + """ + scores = X @ self.coef_.T + self.intercept_ + return scores.ravel() if scores.shape[1] == 1 else scores + + def predict(self, X: torch.Tensor) -> torch.Tensor: + """ + ## Description + + Predict the class labels for the provided data. + + ## Arguments + + * `X` (torch.Tensor): the target point + """ + scores = self.decision_function(X) + if len(scores.shape) == 1: + indices = (scores > 0).long() + else: + indices = scores.argmax(dim=1) + return self.classes_[indices] + + def _fit_with_one_class( + self, X: torch.Tensor, y: torch.Tensor, fitting_class: any, sample_weight=None + ): + device = X.device + m, n = X.shape + + y = torch.unsqueeze(y, 1) + + y = (y == fitting_class).float() + y *= 2 + y -= 1 + + w = cp.Variable((n, 1)) + if self.fit_intercept: + b = cp.Variable() + X_param = cp.Parameter((m, n)) + ones = torch.ones((m, 1)) + + loss = cp.multiply((1 / 2.0), cp.norm(w, 2)) + + if self.fit_intercept: + hinge = cp.pos(ones - cp.multiply(y.cpu(), X_param @ w + b)) + else: + hinge = cp.pos(ones - cp.multiply(y.cpu(), X_param @ w)) + + if self.loss == "squared_hinge": + loss += cp.multiply(self.C, cp.sum(cp.square(hinge))) + elif self.loss == "hinge": + loss += cp.multiply(self.C, cp.sum(hinge)) + + objective = loss + + # set up constraints + constraints = [] + + prob = cp.Problem(cp.Minimize(objective), constraints) + assert prob.is_dpp() + + # convert into pytorch layer + if self.fit_intercept: + fit_lr = CvxpyLayer(prob, [X_param], [w, b]) + else: + fit_lr = CvxpyLayer(prob, [X_param], [w]) + + # prob.solve(solver="ECOS", abstol=self.tol, max_iters=self.max_iter) + if self.fit_intercept: + weight, intercept = fit_lr( + X, + solver_args={ + "solve_method": "ECOS", + "abstol": self.tol, + 
"max_iters": self.max_iter, + }, + ) + else: + weight = fit_lr( + X, + solver_args={ + "solve_method": "ECOS", + "abstol": self.tol, + "max_iters": self.max_iter, + }, + ) + + self.coef_ = torch.cat((self.coef_, torch.t(weight))) + + if self.fit_intercept: + self.intercept_ = torch.cat( + (self.intercept_, torch.unsqueeze(intercept, 0)) + ) + return self diff --git a/torchml/svm/linear_svr.py b/torchml/svm/linear_svr.py new file mode 100644 index 0000000..0e2c67b --- /dev/null +++ b/torchml/svm/linear_svr.py @@ -0,0 +1,185 @@ +import torch + +import torchml as ml +import cvxpy as cp +from cvxpylayers.torch import CvxpyLayer + + +class LinearSVR(ml.Model): + """ + ## Description + + Support vector regressor with cvxpy + + ## References + + 1. Bernhard E. Boser, Isabelle M. Guyon, and Vladimir N. Vapnik. 1992. A training algorithm for optimal margin classifiers. In Proceedings of the fifth annual workshop on Computational learning theory (COLT '92). Association for Computing Machinery, New York, NY, USA, 144–152. https://doi.org/10.1145/130385.130401 + 2. MIT 6.034 Artificial Intelligence, Fall 2010, [16. Learning: Support Vector Machines](https://youtu.be/_PwhiWxHK8o) + 3. The scikit-learn [documentation page](https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html) for LinearSVC. + + ## Arguments + + * `loss` (str {‘epsilon_insensitive’, ‘squared_epsilon_insensitive’}, default=’epsilon_insensitive’): + Specifies the loss function. + + * `epsilon` (float, default=0.0): + Epsilon parameter in the epsilon-insensitive loss function. + + * `tol` (float, default=1e-4) + Tolerance for stopping criteria. + + * `C` (float, default=1.0): + Regularization parameter. The strength of the regularization is inversely proportional to C. Must be strictly positive. + + * `fit_intercept` (bool, default=True): + Whether to calculate the intercept for this model. 
+ + * `intercept_scaling` (float, default=1): + Dummy variable to mimic the sklearn API, always 1 for now + + * `dual` (bool, default=True): + Dummy variable to keep consistency with SKlearn's API, always 'False' for now. + + * `verbose` (int, default=0): + Dummy variable to mimic the sklearn API, always 0 for now + + * `random_state` (int, RandomState instance or None, default=None): + Dummy variable to mimic the sklearn API, always None for now + + * `max_iter` (int, default=1000): + The maximum number of iterations to be run for the underneath convex solver. + + + ## Example + + ~~~python + import numpy as np + from torchml.svm import LinearSVR + from sklearn.datasets import make_regression + + x, y = make_regression( + n_samples=n_samples, + n_features=n_features, + n_informative=n_informative, + ) + svr = LinearSVR(max_iter=1000) + svr.fit(torch.from_numpy(x), torch.from_numpy(y)) + svr.predict(torch.from_numpy(x)) + ~~~ + """ + + def __init__( + self, + *, + epsilon=0.0, + tol=1e-4, + C=1.0, + loss="epsilon_insensitive", + fit_intercept=True, + intercept_scaling=1.0, + dual=True, + verbose=0, + random_state=None, + max_iter=1000, + ): + super(LinearSVR, self).__init__() + self.intercept_ = None + self.coef_ = None + self.classes_ = None + self.tol = tol + self.C = C + self.epsilon = epsilon + self.fit_intercept = fit_intercept + self.intercept_scaling = intercept_scaling + self.verbose = verbose + self.random_state = random_state + self.max_iter = max_iter + self.dual = dual + self.loss = loss + + def fit(self, X: torch.Tensor, y: torch.Tensor, sample_weight=None): + """ + ## Description + + Initialize the class with training sets + + ## Arguments + * `X` (torch.Tensor): the training set + * `y` (torch.Tensor): Target vector relative to X. + * `sample_weight` (default=None): Dummy variable for feature not supported yet. 
+ """ + + if self.C < 0: + raise ValueError("Penalty term must be positive; got (C=%r)" % self.C) + assert X.shape[0] == y.shape[0], "Number of X and y rows don't match" + m, n = X.shape + m, n = X.shape + + y = torch.unsqueeze(y, 1) + + w = cp.Variable((n, 1)) + if self.fit_intercept: + b = cp.Variable() + X_param = cp.Parameter((m, n)) + + loss = cp.multiply((1 / 2.0), cp.norm(w, 2)) + + if self.fit_intercept: + hinge = cp.pos(cp.abs(y.cpu() - (X_param @ w + b)) - self.epsilon) + else: + hinge = cp.pos(cp.abs(y.cpu() - (X_param @ w + b)) - self.epsilon) + + if self.loss == "epsilon_insensitive": + loss += self.C * cp.sum(cp.square(hinge)) + elif self.loss == "squared_epsilon_insensitive": + loss += self.C * cp.sum(hinge) + + objective = loss + + # set up constraints + constraints = [] + + prob = cp.Problem(cp.Minimize(objective), constraints) + assert prob.is_dpp() + + if self.fit_intercept: + fit_lr = CvxpyLayer(prob, [X_param], [w, b]) + else: + fit_lr = CvxpyLayer(prob, [X_param], [w]) + + if self.fit_intercept: + self.coef_, self.intercept_ = fit_lr( + X, + solver_args={ + "solve_method": "ECOS", + "abstol": self.tol, + "max_iters": self.max_iter, + }, + ) + else: + (self.coef_,) = fit_lr( + X, + solver_args={ + "solve_method": "ECOS", + "abstol": self.tol, + "max_iters": self.max_iter, + }, + ) + + self.coef_ = torch.flatten(self.coef_) + if self.fit_intercept: + self.intercept_ = torch.flatten(self.intercept_) + + return self + + def predict(self, X: torch.Tensor) -> torch.Tensor: + """ + ## Description + + Predict using the linear model + + ## Arguments + + * `X` (torch.Tensor): Samples. + """ + return X @ self.coef_ + self.intercept_