Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return active learning proposed placements with raw coordinate names #99

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepsensor/active_learning/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def __call__(
self._init_acquisition_fn_ds(self.X_s)

# Dataframe for storing proposed context locations
self.X_new_df = pd.DataFrame(columns=["x1", "x2"])
self.X_new_df = pd.DataFrame(columns=[self.x1_name, self.x2_name])
self.X_new_df.index.name = "iteration"

# List to track indexes into original search grid of chosen sensor locations
Expand Down
25 changes: 25 additions & 0 deletions tests/test_active_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,31 @@ def test_greedy_alg_with_sequential_acquisition_fn(self):
task = self.task_loader("2014-12-31", context_sampling=10)
_ = alg(acquisition_fn, task)

def test_greedy_algorithm_column_names(self):
# Setup
acquisition_fn = Stddev(self.model)
X_s = self.ds_raw
alg = GreedyAlgorithm(
model=self.model,
X_t=X_s,
X_s=X_s,
N_new_context=1,
task_loader=self.task_loader,
)
task = self.task_loader("2014-12-31", context_sampling=10)

# Exercise
X_new_df, acquisition_fn_ds = alg(acquisition_fn, task)

# Assert
expected_columns = ["lat", "lon"] # Replace with actual expected column names
actual_columns = X_new_df.columns.tolist()
self.assertEqual(
expected_columns,
actual_columns,
"Column names do not match the expected names",
)

def test_greedy_alg_with_aux_at_targets_without_task_loader_raises_value_error(
self,
):
Expand Down
Loading