Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuber21 committed Dec 13, 2024
1 parent e304dae commit 720e447
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions onedal/svm/tests/test_svc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,16 @@
import pytest
import sklearn.utils.estimator_checks
from numpy.testing import assert_array_almost_equal, assert_array_equal
from sklearn import datasets
from sklearn.datasets import make_blobs
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.model_selection import train_test_split

from onedal.svm import SVC
from onedal.tests.utils._device_selection import (
get_queues,
pass_if_not_implemented_for_gpu,
)
from sklearn import datasets
from sklearn.datasets import make_blobs
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.model_selection import train_test_split


def _test_libsvm_parameters(queue, array_constr, dtype):
Expand Down Expand Up @@ -105,7 +106,7 @@ def test_decision_function(queue):
assert_array_almost_equal(dec.ravel(), clf.decision_function(X, queue=queue))


@pass_if_not_implemented_for_gpu(reason="not implemented")
@pass_if_not_implemented_for_gpu(reason="multiclass svm is not implemented")
@pytest.mark.parametrize("queue", get_queues())
def test_iris(queue):
iris = datasets.load_iris()
Expand All @@ -114,7 +115,7 @@ def test_iris(queue):
assert_array_equal(clf.classes_, np.sort(clf.classes_))


@pass_if_not_implemented_for_gpu(reason="not implemented")
@pass_if_not_implemented_for_gpu(reason="multiclass svm is not implemented")
@pytest.mark.parametrize("queue", get_queues())
def test_decision_function_shape(queue):
X, y = make_blobs(n_samples=80, centers=5, random_state=0)
Expand All @@ -131,7 +132,7 @@ def test_decision_function_shape(queue):
SVC(decision_function_shape="bad").fit(X_train, y_train, queue=queue)


@pass_if_not_implemented_for_gpu(reason="not implemented")
@pass_if_not_implemented_for_gpu(reason="multiclass svm is not implemented")
@pytest.mark.parametrize("queue", get_queues())
def test_pickle(queue):
iris = datasets.load_iris()
Expand All @@ -156,7 +157,7 @@ def test_pickle(queue):
pytest.param(
get_queues("gpu"),
marks=pytest.mark.xfail(
reason="raises Unimplemented error with inconsistent error message"
reason="raises Unimplemented error " "with inconsistent error message"
),
)
],
Expand Down

0 comments on commit 720e447

Please sign in to comment.