Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make google docstrings #79

Merged
merged 3 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
362 changes: 156 additions & 206 deletions deepsensor/active_learning/acquisition_fns.py

Large diffs are not rendered by default.

165 changes: 88 additions & 77 deletions deepsensor/active_learning/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@ class GreedyAlgorithm:
def __init__(
self,
model: DeepSensorModel,
X_s: Union[xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index],
X_t: Union[xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index],
X_s: Union[
xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index
],
X_t: Union[
xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index
],
X_s_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
N_new_context: int = 1,
Expand All @@ -53,47 +57,45 @@ def __init__(
"""
...

Parameters
----------
model : :class:`~.model.model.DeepSensorModel`
Trained model to use for proposing new context points.
X_s : :class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index`
Search coordinates.
X_t : :class:`xarray.Dataset` | :class:`xarray.DataArray`
Target coordinates.
X_s_mask : :class:`xarray.Dataset` | :class:`xarray.DataArray`, optional
Mask for search coordinates. If provided, only points where mask
is True will be considered. Defaults to None.
X_t_mask : :class:`xarray.Dataset` | :class:`xarray.DataArray`, optional
..., by default None.
N_new_context : int, optional
..., by default 1.
X_normalised : bool, optional
..., by default False.
model_infill_method : str, optional
..., by default "mean".
query_infill : :class:`xarray.DataArray`, optional
..., by default None.
proposed_infill : :class:`xarray.DataArray`, optional
..., by default None.
context_set_idx : int, optional
..., by default 0.
target_set_idx : int, optional
..., by default 0.
progress_bar : bool, optional
..., by default False.
min_or_max : str, optional
..., by default "min".
task_loader : :class:`~.data.loader.TaskLoader`, optional
..., by default None.
verbose : bool, optional
..., by default False.

Raises
------
ValueError
If the ``model`` passed does not inherit from
:class:`~.model.model.DeepSensorModel`.
Args:
model (:class:`~.model.model.DeepSensorModel`):
Trained model to use for proposing new context points.
X_s (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index`):
Search coordinates.
X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index`):
Target coordinates.
X_s_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional):
Mask for search coordinates. If provided, only points where mask
is True will be considered. Defaults to None.
X_t_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional):
Mask for target coordinates. If provided, target points that lie
outside the mask are removed. Defaults to None.
N_new_context (int, optional):
Number of greedy iterations to run, each proposing one new context
point, defaults to 1.
X_normalised (bool, optional):
[Description of the X_normalised parameter.], defaults to False.
model_infill_method (str, optional):
[Description of the model_infill_method parameter.], defaults to "mean".
query_infill (:class:`xarray.DataArray`, optional):
[Description of the query_infill parameter.], defaults to None.
proposed_infill (:class:`xarray.DataArray`, optional):
[Description of the proposed_infill parameter.], defaults to None.
context_set_idx (int, optional):
[Description of the context_set_idx parameter.], defaults to 0.
target_set_idx (int, optional):
[Description of the target_set_idx parameter.], defaults to 0.
progress_bar (bool, optional):
Whether to display a ``tqdm`` progress bar while searching,
defaults to False.
min_or_max (str, optional):
[Description of the min_or_max parameter.], defaults to "min".
task_loader (:class:`~.data.loader.TaskLoader`, optional):
[Description of the task_loader parameter.], defaults to None.
verbose (bool, optional):
Whether to print status messages (e.g. when an infill dataset is
interpolated onto the search grid), defaults to False.

Raises:
ValueError:
If the ``model`` passed does not inherit from
:class:`~.model.model.DeepSensorModel`.
"""
if not isinstance(model, DeepSensorModel):
raise ValueError(
Expand Down Expand Up @@ -136,11 +138,15 @@ def __init__(
self.X_t_mask = process_X_mask_for_X(self.X_t_mask, self.X_t)

# Interpolate overridden infill datasets at search points if necessary
if query_infill is not None and not da1_da2_same_grid(query_infill, X_s):
if query_infill is not None and not da1_da2_same_grid(
query_infill, X_s
):
if verbose:
print("query_infill not on search grid, interpolating.")
query_infill = interp_da1_to_da2(query_infill, self.X_s)
if proposed_infill is not None and not da1_da2_same_grid(proposed_infill, X_s):
if proposed_infill is not None and not da1_da2_same_grid(
proposed_infill, X_s
):
if verbose:
print("proposed_infill not on search grid, interpolating.")
proposed_infill = interp_da1_to_da2(proposed_infill, self.X_s)
Expand All @@ -153,7 +159,9 @@ def __init__(
self.X_t_arr = xarray_to_coord_array_normalised(X_t)
if self.X_t_mask is not None:
# Remove points that lie outside the mask
self.X_t_arr = mask_coord_array_normalised(self.X_t_arr, self.X_t_mask)
self.X_t_arr = mask_coord_array_normalised(
self.X_t_arr, self.X_t_mask
)
elif isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index)):
# Targets off-grid
self.X_t_arr = X_t.reset_index()[["x1", "x2"]].values.T
Expand Down Expand Up @@ -200,7 +208,9 @@ def _get_times_from_tasks(self):

def _model_infill_at_search_points(
self,
X_s: Union[xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index],
X_s: Union[
xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index
],
):
"""
Computes and sets the model infill y-values over whole search grid
Expand Down Expand Up @@ -345,7 +355,9 @@ def _search(self, acquisition_fn: AcquisitionFunction):
and self.task_loader.aux_at_contexts
):
# Add auxiliary variable sampled at context set as a new context variable
X_c = task_with_new["X_c"][self.task_loader.aux_at_contexts[0]]
X_c = task_with_new["X_c"][
self.task_loader.aux_at_contexts[0]
]
Y_c_aux = self.task_loader.sample_offgrid_aux(
X_c, self.task_loader.aux_at_contexts[1]
)
Expand Down Expand Up @@ -426,32 +438,25 @@ def __call__(
"""
Iteratively propose new context points using ``acquisition_fn``, one
per greedy iteration. (Full docstring TODO.)

Returns a tensor of proposed new sensor locations (in greedy
iteration/priority order) and their corresponding list of indexes in
the search space.

Parameters
----------
acquisition_fn: :class:`~.active_learning.acquisition_fns.AcquisitionFunction`
...
tasks: List[:class:`~.data.task.Task`] | :class:`~.data.task.Task`
...

Returns
-------
X_new_df, acquisition_fn_ds: Tuple[:class:`pandas.DataFrame`, :class:`xarray.Dataset`]
...

Raises
------
ValueError
If ``acquisition_fn`` is an
:class:`~.active_learning.acquisition_fns.AcquisitionFunctionOracle`
and ``task_loader`` is None.
ValueError
If ``min_or_max`` is not ``"min"`` or ``"max"``.
ValueError
If ``Y_t_aux`` is in ``tasks`` but ``task_loader`` is None.
Args:
acquisition_fn (:class:`~.active_learning.acquisition_fns.AcquisitionFunction`):
[Description of the acquisition_fn parameter.]
tasks (List[:class:`~.data.task.Task`] | :class:`~.data.task.Task`):
[Description of the tasks parameter.]

Returns:
Tuple[:class:`pandas.DataFrame`, :class:`xarray.Dataset`]:
X_new_df, acquisition_fn_ds - [Description of the return values.]

Raises:
ValueError:
If ``acquisition_fn`` is an
:class:`~.active_learning.acquisition_fns.AcquisitionFunctionOracle`
and ``task_loader`` is None.
ValueError:
If ``min_or_max`` is not ``"min"`` or ``"max"``.
ValueError:
If ``Y_t_aux`` is in ``tasks`` but ``task_loader`` is None.
"""
if (
isinstance(acquisition_fn, AcquisitionFunctionOracle)
Expand All @@ -465,7 +470,8 @@ def __call__(
self.min_or_max = acquisition_fn.min_or_max
if self.min_or_max not in ["min", "max"]:
raise ValueError(
f"min_or_max must be either 'min' or 'max', got " f"{self.min_or_max}."
f"min_or_max must be either 'min' or 'max', got "
f"{self.min_or_max}."
)

if diff and isinstance(acquisition_fn, AcquisitionFunctionParallel):
Expand Down Expand Up @@ -496,7 +502,10 @@ def __call__(
"Model expects Y_t_aux data but a TaskLoader isn't "
"provided to GreedyAlgorithm."
)
if self.task_loader is not None and self.task_loader.aux_at_target_dims > 0:
if (
self.task_loader is not None
and self.task_loader.aux_at_target_dims > 0
):
tasks[i]["Y_t_aux"] = self.task_loader.sample_offgrid_aux(
self.X_t_arr, self.task_loader.aux_at_targets
)
Expand Down Expand Up @@ -532,7 +541,9 @@ def __call__(
if self.model_infill_method == "sample":
total_iterations *= self.n_samples

with tqdm(total=total_iterations, disable=not self.progress_bar) as self.pbar:
with tqdm(
total=total_iterations, disable=not self.progress_bar
) as self.pbar:
for iteration in range(self.N_new_context):
self.iteration = iteration
x_new = self._single_greedy_iteration(acquisition_fn)
Expand Down
Loading
Loading