From 858e4900011c159c47597a26ad5d03ee369c3979 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 27 May 2022 15:06:18 +0200 Subject: [PATCH] relative error in loo_error method --- ezyrb/reducedordermodel.py | 11 ++++++----- tests/test_reducedordermodel.py | 6 +++--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/ezyrb/reducedordermodel.py b/ezyrb/reducedordermodel.py index 8f654ab5..cdcb228c 100644 --- a/ezyrb/reducedordermodel.py +++ b/ezyrb/reducedordermodel.py @@ -217,15 +217,16 @@ def loo_error(self, *args, norm=np.linalg.norm, **kwargs): db_range = list(range(len(self.database))) for j in db_range: - remaining_index = db_range[:] - remaining_index.remove(j) - new_db = self.database[remaining_index] + indeces = np.array([True] * len(self.database)) + indeces[j] = False + + new_db = self.database[indeces] + test_db = self.database[~indeces] rom = type(self)(new_db, copy.deepcopy(self.reduction), copy.deepcopy(self.approximation)).fit( *args, **kwargs) - error[j] = norm(self.database.snapshots[j] - - rom.predict(self.database.parameters[j])) + error[j] = rom.test_error(test_db) return error diff --git a/tests/test_reducedordermodel.py b/tests/test_reducedordermodel.py index 788bae66..7a94432c 100644 --- a/tests/test_reducedordermodel.py +++ b/tests/test_reducedordermodel.py @@ -177,7 +177,7 @@ def test_loo_error_01(self): err = rom.loo_error() np.testing.assert_allclose( err, - np.array([421.299091, 344.571787, 48.711501, 300.490491]), + np.array([0.540029, 1.211744, 0.271776, 0.919509]), rtol=1e-4) def test_loo_error_02(self): @@ -188,7 +188,7 @@ def test_loo_error_02(self): err = rom.loo_error(normalizer=False) np.testing.assert_allclose( err[0], - np.array(498.703803), + np.array(0.639247), rtol=1e-3) def test_loo_error_singular_values(self): @@ -206,5 +206,5 @@ def test_optimal_mu(self): db = Database(param, snapshots.T) rom = ROM(db, pod, rbf).fit() opt_mu = rom.optimal_mu() - np.testing.assert_allclose(opt_mu, [[-0.17687147, -0.21820951]], + np.testing.assert_allclose(opt_mu, [[-0.046381, -0.15578 ]], rtol=1e-4)