From 6a80fc431d2782f2561e024d5a3433c95b6e1b42 Mon Sep 17 00:00:00 2001 From: RohitRathore1 Date: Sun, 21 Jan 2024 17:50:02 +0530 Subject: [PATCH] Return active learning proposed placements with raw coordinate names --- deepsensor/active_learning/algorithms.py | 2 +- tests/test_active_learning.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/deepsensor/active_learning/algorithms.py b/deepsensor/active_learning/algorithms.py index 51234704..19bb67c6 100644 --- a/deepsensor/active_learning/algorithms.py +++ b/deepsensor/active_learning/algorithms.py @@ -505,7 +505,7 @@ def __call__( self._init_acquisition_fn_ds(self.X_s) # Dataframe for storing proposed context locations - self.X_new_df = pd.DataFrame(columns=["x1", "x2"]) + self.X_new_df = pd.DataFrame(columns=[self.x1_name, self.x2_name]) self.X_new_df.index.name = "iteration" # List to track indexes into original search grid of chosen sensor locations diff --git a/tests/test_active_learning.py b/tests/test_active_learning.py index e8690e5a..e2b71d63 100644 --- a/tests/test_active_learning.py +++ b/tests/test_active_learning.py @@ -181,6 +181,27 @@ def test_greedy_alg_with_sequential_acquisition_fn(self): task = self.task_loader("2014-12-31", context_sampling=10) _ = alg(acquisition_fn, task) + def test_greedy_algorithm_column_names(self): + # Setup + acquisition_fn = Stddev(self.model) + X_s = self.ds_raw + alg = GreedyAlgorithm( + model=self.model, + X_t=X_s, + X_s=X_s, + N_new_context=1, + task_loader=self.task_loader, + ) + task = self.task_loader("2014-12-31", context_sampling=10) + + # Exercise + X_new_df, acquisition_fn_ds = alg(acquisition_fn, task) + + # Assert + expected_columns = ['lat', 'lon'] # Replace with actual expected column names + actual_columns = X_new_df.columns.tolist() + self.assertEqual(expected_columns, actual_columns, "Column names do not match the expected names") + def test_greedy_alg_with_aux_at_targets_without_task_loader_raises_value_error( self, ):