Skip to content

Commit

Permalink
black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Kalle Westerling committed Oct 2, 2023
1 parent 2f62069 commit 94a0750
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 9 deletions.
2 changes: 1 addition & 1 deletion deepsensor/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
) -> None:
"""
Initialise a TaskLoader object.
The behaviour is the following:
- If all data passed as paths, load the data and overwrite the paths with the loaded data
- Either all data is passed as paths, or all data is passed as loaded data (else ValueError)
Expand Down
4 changes: 2 additions & 2 deletions deepsensor/data/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def check_params_computed(self, var_ID, method) -> bool:
and "params" in self.config[var_ID]
):
return True

return False

def add_to_config(self, var_ID, **kwargs):
Expand All @@ -251,7 +251,7 @@ def get_norm_params(self, var_ID, data, method=None):
"""
Get pre-computed normalisation params or compute them for variable
``var_ID``.
.. note:
TODO do we need to pass var_ID? Can we just use name of data?
Expand Down
2 changes: 1 addition & 1 deletion deepsensor/data/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def op(self, f, op_flag=None):
Returns
-------
task : dict.
task : dict.
Task dictionary with f applied to the array elements and
op_flag set in the ``ops`` key.
"""
Expand Down
7 changes: 3 additions & 4 deletions deepsensor/model/convnp.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def __init__(self, model_ID: str):
super().__init__()

self.load(model_ID)

@dispatch
def __init__(
self,
Expand All @@ -251,7 +251,7 @@ def __init__(
):
"""Instantiate a model from a folder containing model weights and config."""
super().__init__(data_processor, task_loader)

self.load(model_ID)

def save(self, model_ID: str):
Expand Down Expand Up @@ -290,7 +290,6 @@ def load(self, model_ID: str):

@classmethod
def modify_task(cls, task):

"""
Cast numpy arrays to TensorFlow or PyTorch tensors, add batch dim, and
mask NaNs.
Expand All @@ -305,7 +304,7 @@ def modify_task(cls, task):
...
...
"""

if "target_nans_removed" not in task["ops"]:
task = task.remove_nans_from_task_Y_t_if_present()
if "batch_dim" not in task["ops"]:
Expand Down
2 changes: 1 addition & 1 deletion deepsensor/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def predict(
X_t_arr = mask_coord_array_normalised(X_t_arr, X_t_mask_normalised)
else:
X_t_arr = (X_t_normalised["x1"].values, X_t_normalised["x2"].values)

elif mode == "off-grid":
X_t_arr = X_t_normalised.reset_index()[["x1", "x2"]].values.T

Expand Down

0 comments on commit 94a0750

Please sign in to comment.