diff --git a/ezyrb/linear.py b/ezyrb/linear.py index 9050dd6..3ab3e2a 100644 --- a/ezyrb/linear.py +++ b/ezyrb/linear.py @@ -55,4 +55,4 @@ def predict(self, new_point): :return: the interpolated values. :rtype: numpy.ndarray """ - return self.interpolator(new_point) + return self.interpolator(new_point).squeeze() diff --git a/tests/test_linear.py b/tests/test_linear.py index 0c79629..9ff5c31 100644 --- a/tests/test_linear.py +++ b/tests/test_linear.py @@ -4,7 +4,7 @@ from unittest import TestCase from ezyrb import Linear, Database, POD, ReducedOrderModel -class TestKNeighbors(TestCase): +class TestLinear(TestCase): def test_params(self): reg = Linear(fill_value=0) assert reg.fill_value == 0 @@ -52,6 +52,13 @@ def test_with_db_predict(self): assert rom.predict([2]) == 5 assert rom.predict([3]) == 3 + Y = np.random.uniform(size=(3, 3)) + db = Database(np.array([1, 2, 3]), Y) + rom = ReducedOrderModel(db, POD(), Linear()) + rom.fit() + assert rom.predict([1.]).shape == (3,) + + def test_wrong1(self): # wrong number of params with warnings.catch_warnings():