Skip to content

Commit

Permalink
Return active learning proposed placements with raw coordinate names
Browse files Browse the repository at this point in the history
  • Loading branch information
RohitRathore1 authored and tom-andersson committed Feb 2, 2024
1 parent d4b3090 commit 6a80fc4
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
2 changes: 1 addition & 1 deletion deepsensor/active_learning/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def __call__(
self._init_acquisition_fn_ds(self.X_s)

# Dataframe for storing proposed context locations
self.X_new_df = pd.DataFrame(columns=["x1", "x2"])
self.X_new_df = pd.DataFrame(columns=[self.x1_name, self.x2_name])
self.X_new_df.index.name = "iteration"

# List to track indexes into original search grid of chosen sensor locations
Expand Down
21 changes: 21 additions & 0 deletions tests/test_active_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,27 @@ def test_greedy_alg_with_sequential_acquisition_fn(self):
task = self.task_loader("2014-12-31", context_sampling=10)
_ = alg(acquisition_fn, task)

def test_greedy_algorithm_column_names(self):
# Setup
acquisition_fn = Stddev(self.model)
X_s = self.ds_raw
alg = GreedyAlgorithm(
model=self.model,
X_t=X_s,
X_s=X_s,
N_new_context=1,
task_loader=self.task_loader,
)
task = self.task_loader("2014-12-31", context_sampling=10)

# Exercise
X_new_df, acquisition_fn_ds = alg(acquisition_fn, task)

# Assert
expected_columns = ['lat', 'lon'] # Replace with actual expected column names
actual_columns = X_new_df.columns.tolist()
self.assertEqual(expected_columns, actual_columns, "Column names do not match the expected names")

def test_greedy_alg_with_aux_at_targets_without_task_loader_raises_value_error(
self,
):
Expand Down

0 comments on commit 6a80fc4

Please sign in to comment.