From 131c43445fb97b64aa28aa8b4855d2548e1e7e4e Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Tue, 10 Oct 2023 19:01:51 +0200 Subject: [PATCH 01/69] stach changes --- deepsensor/data/loader.py | 61 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index bc5d39b0..2474e613 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -3,12 +3,13 @@ import os import json import copy +import random import numpy as np import xarray as xr import pandas as pd -from typing import List, Tuple, Union, Optional +from typing import List, Tuple, Union, Optional, Sequence from deepsensor.errors import InvalidSamplingStrategyError @@ -810,6 +811,58 @@ def sample_offgrid_aux( # Reshape to (variable, *spatial_dims) Y_t_aux = Y_t_aux.reshape(1, *Y_t_aux.shape) return Y_t_aux + + def sample_patch_size_extent(self) -> Sequence[float]: + """Sample patch size. + + :return sequence of patch spatial extent as [lat_min, lat_max, lon_min, lon_max] + """ + # assumption of normalized spatial coordinates between 0 and 1 + + lat_extend, lon_extend = self.patch_size + + lat_side = lat_extend / 2 + lon_side = lon_extend / 2 + + # sample a point that satisfies the boundary and target conditions + continue_looking = True + while continue_looking: + lat_point = random.uniform(lat_side, 1 - lat_side) + lon_point = random.uniform(lon_side, 1 - lon_side) + + # bbox of lat_min, lat_max, lon_min, lon_max + bbox = [lat_point - lat_side, lat_point + lat_side, lon_point - lon_side, lon_point + lon_side] + + x1_slice = slice(bbox[0], bbox[1]) + x2_slice = slice(bbox[2], bbox[3]) + # check whether target is non-empty given this box + target_check: list[bool] = [] + for target_var in self.target: + if isinstance(target_var, (pd.DataFrame, pd.Series)): + data = target_var.loc[(slice(None), x1_slice, x2_slice)] + else: + data = target_var.sel(x1=x1_slice, x2=x2_slice) + + target_check.append(True if len(data)>0 else False) + + # check whether context is non-empty given this box + context_check: list[bool] = [] + for context_var in self.context: + if isinstance(context_var, (pd.DataFrame, pd.Series)): + data = context_var[(context_var.index.get_level_values('x1') >= bbox[0]) & (context_var.index.get_level_values('x1') <= bbox[1]) & + (context_var.index.get_level_values('x2') >= bbox[2]) & (context_var.index.get_level_values('x2') <= bbox[3])] + + # data = context_var.loc[(slice(None), x1_slice, x2_slice)] + else: + data = context_var.sel(x1=x1_slice, x2=x2_slice) + + context_check.append(True if len(data)>0 else False) + + + if all(target_check) and all(context_check): + continue_looking = False + + return bbox def task_generation( self, @@ -1096,6 +1149,12 @@ def sample_variable(var, sampling_strat, seed): context_slices[link[0]].index ) + # sample common patch size for context and target set + if self.patch_size is not None: + sample_patch_size = self.sample_patch_size_extent() + else: + sample_patch_size = None + for i, (var, sampling_strat) in enumerate( zip(context_slices, context_sampling) ): From 3342b96ccbf713c8098ec71546124aecbd390363 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Thu, 12 Oct 2023 12:46:19 +0000 Subject: [PATCH 02/69] draft --- deepsensor/data/loader.py | 51 +++++++++++++++++++++++++++++++++------ tests/test_task_loader.py | 21 ++++++++++++++++ 2 files changed, 65 insertions(+), 7 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 2474e613..2f558104 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -602,6 +602,7 @@ def sample_da( self, da: Union[xr.DataArray, xr.Dataset], sampling_strat: Union[str, int, float, np.ndarray], + sample_patch_size: Optional[list[float]] = None, seed: Optional[int] = None, ) -> (np.ndarray, np.ndarray): """ @@ -614,6 +615,8 @@ def sample_da( sampling_strat : str | int | float | :class:`numpy:numpy.ndarray` Sampling strategy, either "all" or an integer for random grid cell sampling. + sample_patch_size: list + desired patch size extent to sample [lat_min, lat_max, lon_min, lon_max] seed : int, optional Seed for random sampling. Default: None. @@ -634,6 +637,12 @@ def sample_da( if isinstance(da, xr.Dataset): da = da.to_array() + # restric to a certain spatial patch + if sample_patch_size is not None: + x1_slice = slice(sample_patch_size[0], sample_patch_size[1]) + x2_slice = slice(sample_patch_size[2], sample_patch_size[3]) + da = da.sel(x1=x1_slice, x2=x2_slice) + if isinstance(sampling_strat, float): sampling_strat = int(sampling_strat * da.size) @@ -707,6 +716,7 @@ def sample_df( self, df: Union[pd.DataFrame, pd.Series], sampling_strat: Union[str, int, float, np.ndarray], + sample_patch_size: Optional[list[float]] = None, seed: Optional[int] = None, ) -> (np.ndarray, np.ndarray): """ @@ -720,6 +730,8 @@ def sample_df( sampling_strat : str | int | float | :class:`numpy:numpy.ndarray` Sampling strategy, either "all" or an integer for random grid cell sampling. + sample_patch_size: list[float], optional + desired patch size extent to sample [lat_min, lat_max, lon_min, lon_max] seed : int, optional Seed for random sampling. Default: None. @@ -738,6 +750,12 @@ def sample_df( """ df = df.dropna(how="any") # If any obs are NaN, drop them + if sample_patch_size is not None: + # retrieve desired patch size + lat_min, lat_max, lon_min, lon_max = sample_patch_size + df = df[(df.index.get_level_values('x1') >= lat_min) & (df.index.get_level_values('x1') <= lat_max) & + (df.index.get_level_values('x2') >= lon_min) & (df.index.get_level_values('x2') <= lon_max)] + if isinstance(sampling_strat, float): sampling_strat = int(sampling_strat * df.shape[0]) @@ -747,7 +765,7 @@ def sample_df( idx = rng.choice(df.index, N) X_c = df.loc[idx].reset_index()[["x1", "x2"]].values.T.astype(self.dtype) Y_c = df.loc[idx].values.T - elif sampling_strat in ["all", "split"]: + elif isinstance(sampling_strat, str) and sampling_strat in ["all", "split"]: # NOTE if "split", we assume that the context-target split has already been applied to the df # in an earlier scope with access to both the context and target data. This is maybe risky! X_c = df.reset_index()[["x1", "x2"]].values.T.astype(self.dtype) @@ -778,6 +796,7 @@ def sample_offgrid_aux( self, X_t: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]], offgrid_aux: Union[xr.DataArray, xr.Dataset], + sample_patch_size: Optional[list[float]] = None ) -> np.ndarray: """ Sample auxiliary data at off-grid locations. @@ -789,6 +808,8 @@ def sample_offgrid_aux( tuple of two numpy arrays, or a single numpy array. offgrid_aux : :class:`xarray.DataArray` | :class:`xarray.Dataset` Auxiliary data at off-grid locations. + sample_patch_size: list[float], optional + desired patch size extent to sample [lat_min, lat_max, lon_min, lon_max] Returns ------- @@ -801,6 +822,12 @@ def sample_offgrid_aux( xt2 = xt2.ravel() else: xt1, xt2 = xr.DataArray(X_t[0]), xr.DataArray(X_t[1]) + + if sample_patch_size is not None: + x1_slice = slice(sample_patch_size[0], sample_patch_size[1]) + x2_slice = slice(sample_patch_size[2], sample_patch_size[3]) + offgrid_aux = offgrid_aux.sel(x1=x1_slice, x2=x2_slice) + Y_t_aux = offgrid_aux.sel(x1=xt1, x2=xt2, method="nearest") if isinstance(Y_t_aux, xr.Dataset): Y_t_aux = Y_t_aux.to_array() @@ -882,6 +909,7 @@ def task_generation( List[Union[str, int, float, np.ndarray]], ] = "all", split_frac: float = 0.5, + patch_size: Sequence[float] = None, datewise_deterministic: bool = False, seed_override: Optional[int] = None, ) -> Task: @@ -915,6 +943,9 @@ def task_generation( "split" sampling strategy for linked context and target set pairs. The remaining observations are used for the target set. Default is 0.5. + patch_size: Sequence[float], optional + Desired patch size in lat/lon used for patchwise task generation. Usefule when considering + the entire available region is computationally prohibitive for model forward pass datewise_deterministic : bool Whether random sampling is datewise_deterministic based on the date. Default is ``False``. @@ -1034,7 +1065,7 @@ def time_slice_variable(var, delta_t): raise ValueError(f"Unknown variable type {type(var)}") return var - def sample_variable(var, sampling_strat, seed): + def sample_variable(var, sampling_strat, sample_patch_size, seed): """ Sample a variable by a given sampling strategy to get input and output data. @@ -1059,9 +1090,9 @@ def sample_variable(var, sampling_strat, seed): If the variable is of an unknown type. """ if isinstance(var, (xr.Dataset, xr.DataArray)): - X, Y = self.sample_da(var, sampling_strat, seed) + X, Y = self.sample_da(var, sampling_strat, sample_patch_size, seed) elif isinstance(var, (pd.DataFrame, pd.Series)): - X, Y = self.sample_df(var, sampling_strat, seed) + X, Y = self.sample_df(var, sampling_strat, sample_patch_size, seed) else: raise ValueError(f"Unknown type {type(var)} for context set " f"{var}") return X, Y @@ -1104,6 +1135,12 @@ def sample_variable(var, sampling_strat, seed): # 'Truly' random sampling seed = None + # check patch size + if patch_size is not None: + assert len(patch_size) == 2, "Patch size must be a Sequence of two values for lat/lon extent." + assert all(0 < x <= 1 for x in patch_size), "Values specified for patch size must satisfy 0 < x <= 1." + self.patch_size = patch_size + task = {} task["time"] = date @@ -1159,12 +1196,12 @@ def sample_variable(var, sampling_strat, seed): zip(context_slices, context_sampling) ): context_seed = seed + i if seed is not None else None - X_c, Y_c = sample_variable(var, sampling_strat, context_seed) + X_c, Y_c = sample_variable(var, sampling_strat, sample_patch_size, context_seed) task[f"X_c"].append(X_c) task[f"Y_c"].append(Y_c) for j, (var, sampling_strat) in enumerate(zip(target_slices, target_sampling)): target_seed = seed + i + j if seed is not None else None - X_t, Y_t = sample_variable(var, sampling_strat, target_seed) + X_t, Y_t = sample_variable(var, sampling_strat, sample_patch_size, target_seed) task[f"X_t"].append(X_t) task[f"Y_t"].append(Y_t) @@ -1176,7 +1213,7 @@ def sample_variable(var, sampling_strat, seed): X_c_offrid_all = np.empty((2, 0), dtype=self.dtype) else: X_c_offrid_all = np.concatenate(X_c_offgrid, axis=1) - Y_c_aux = self.sample_offgrid_aux(X_c_offrid_all, self.aux_at_contexts) + Y_c_aux = self.sample_offgrid_aux(X_c_offrid_all, self.aux_at_contexts, sample_patch_size) task["X_c"].append(X_c_offrid_all) task["Y_c"].append(Y_c_aux) diff --git a/tests/test_task_loader.py b/tests/test_task_loader.py index 0c04f93b..135a1861 100644 --- a/tests/test_task_loader.py +++ b/tests/test_task_loader.py @@ -1,5 +1,7 @@ import itertools +from typing import Sequence + from parameterized import parameterized import xarray as xr @@ -366,6 +368,25 @@ def test_saving_and_loading(self): tl_loaded.target_delta_t, "target_delta_t not saved and loaded correctly", ) + @parameterized.expand([ + [(0.3, 0.3)], + [(0.6, 0.4)] + ]) + def test_patch_size(self, patch_size: Sequence[float]) -> None: + """Test patch size sampling.""" + context = [self.da, self.df] + + tl = TaskLoader( + context=context, # gridded xarray and off-grid pandas contexts + target=self.df, # off-grid pandas targets + ) + + for context_sampling, target_sampling in self._gen_task_loader_call_args( + len(context), 1 + ): + if isinstance(context_sampling[0], np.ndarray): + continue + task = tl("2020-01-01", context_sampling, target_sampling, patch_size=patch_size) if __name__ == "__main__": From b7cf3fa7c338c406113dd6fd1c1a800743e89c3c Mon Sep 17 00:00:00 2001 From: nilsleh Date: Thu, 12 Oct 2023 12:46:52 +0000 Subject: [PATCH 03/69] draft --- deepsensor/data/loader.py | 60 ++++++++++++++++++++++++++------------- tests/test_task_loader.py | 12 ++++---- 2 files changed, 47 insertions(+), 25 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 2f558104..eed92a16 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -615,7 +615,7 @@ def sample_da( sampling_strat : str | int | float | :class:`numpy:numpy.ndarray` Sampling strategy, either "all" or an integer for random grid cell sampling. - sample_patch_size: list + sample_patch_size: list desired patch size extent to sample [lat_min, lat_max, lon_min, lon_max] seed : int, optional Seed for random sampling. Default: None. @@ -753,9 +753,13 @@ def sample_df( if sample_patch_size is not None: # retrieve desired patch size lat_min, lat_max, lon_min, lon_max = sample_patch_size - df = df[(df.index.get_level_values('x1') >= lat_min) & (df.index.get_level_values('x1') <= lat_max) & - (df.index.get_level_values('x2') >= lon_min) & (df.index.get_level_values('x2') <= lon_max)] - + df = df[ + (df.index.get_level_values("x1") >= lat_min) + & (df.index.get_level_values("x1") <= lat_max) + & (df.index.get_level_values("x2") >= lon_min) + & (df.index.get_level_values("x2") <= lon_max) + ] + if isinstance(sampling_strat, float): sampling_strat = int(sampling_strat * df.shape[0]) @@ -796,7 +800,7 @@ def sample_offgrid_aux( self, X_t: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]], offgrid_aux: Union[xr.DataArray, xr.Dataset], - sample_patch_size: Optional[list[float]] = None + sample_patch_size: Optional[list[float]] = None, ) -> np.ndarray: """ Sample auxiliary data at off-grid locations. @@ -808,7 +812,7 @@ def sample_offgrid_aux( tuple of two numpy arrays, or a single numpy array. offgrid_aux : :class:`xarray.DataArray` | :class:`xarray.Dataset` Auxiliary data at off-grid locations. - sample_patch_size: list[float], optional + sample_patch_size: list[float], optional desired patch size extent to sample [lat_min, lat_max, lon_min, lon_max] Returns @@ -838,7 +842,7 @@ def sample_offgrid_aux( # Reshape to (variable, *spatial_dims) Y_t_aux = Y_t_aux.reshape(1, *Y_t_aux.shape) return Y_t_aux - + def sample_patch_size_extent(self) -> Sequence[float]: """Sample patch size. @@ -858,7 +862,12 @@ def sample_patch_size_extent(self) -> Sequence[float]: lon_point = random.uniform(lon_side, 1 - lon_side) # bbox of lat_min, lat_max, lon_min, lon_max - bbox = [lat_point - lat_side, lat_point + lat_side, lon_point - lon_side, lon_point + lon_side] + bbox = [ + lat_point - lat_side, + lat_point + lat_side, + lon_point - lon_side, + lon_point + lon_side, + ] x1_slice = slice(bbox[0], bbox[1]) x2_slice = slice(bbox[2], bbox[3]) @@ -870,25 +879,28 @@ def sample_patch_size_extent(self) -> Sequence[float]: else: data = target_var.sel(x1=x1_slice, x2=x2_slice) - target_check.append(True if len(data)>0 else False) + target_check.append(True if len(data) > 0 else False) # check whether context is non-empty given this box context_check: list[bool] = [] for context_var in self.context: if isinstance(context_var, (pd.DataFrame, pd.Series)): - data = context_var[(context_var.index.get_level_values('x1') >= bbox[0]) & (context_var.index.get_level_values('x1') <= bbox[1]) & - (context_var.index.get_level_values('x2') >= bbox[2]) & (context_var.index.get_level_values('x2') <= bbox[3])] + data = context_var[ + (context_var.index.get_level_values("x1") >= bbox[0]) + & (context_var.index.get_level_values("x1") <= bbox[1]) + & (context_var.index.get_level_values("x2") >= bbox[2]) + & (context_var.index.get_level_values("x2") <= bbox[3]) + ] # data = context_var.loc[(slice(None), x1_slice, x2_slice)] else: data = context_var.sel(x1=x1_slice, x2=x2_slice) - context_check.append(True if len(data)>0 else False) - + context_check.append(True if len(data) > 0 else False) if all(target_check) and all(context_check): continue_looking = False - + return bbox def task_generation( @@ -1137,8 +1149,12 @@ def sample_variable(var, sampling_strat, sample_patch_size, seed): # check patch size if patch_size is not None: - assert len(patch_size) == 2, "Patch size must be a Sequence of two values for lat/lon extent." - assert all(0 < x <= 1 for x in patch_size), "Values specified for patch size must satisfy 0 < x <= 1." + assert ( + len(patch_size) == 2 + ), "Patch size must be a Sequence of two values for lat/lon extent." + assert all( + 0 < x <= 1 for x in patch_size + ), "Values specified for patch size must satisfy 0 < x <= 1." self.patch_size = patch_size task = {} @@ -1196,12 +1212,16 @@ def sample_variable(var, sampling_strat, sample_patch_size, seed): zip(context_slices, context_sampling) ): context_seed = seed + i if seed is not None else None - X_c, Y_c = sample_variable(var, sampling_strat, sample_patch_size, context_seed) + X_c, Y_c = sample_variable( + var, sampling_strat, sample_patch_size, context_seed + ) task[f"X_c"].append(X_c) task[f"Y_c"].append(Y_c) for j, (var, sampling_strat) in enumerate(zip(target_slices, target_sampling)): target_seed = seed + i + j if seed is not None else None - X_t, Y_t = sample_variable(var, sampling_strat, sample_patch_size, target_seed) + X_t, Y_t = sample_variable( + var, sampling_strat, sample_patch_size, target_seed + ) task[f"X_t"].append(X_t) task[f"Y_t"].append(Y_t) @@ -1213,7 +1233,9 @@ def sample_variable(var, sampling_strat, sample_patch_size, seed): X_c_offrid_all = np.empty((2, 0), dtype=self.dtype) else: X_c_offrid_all = np.concatenate(X_c_offgrid, axis=1) - Y_c_aux = self.sample_offgrid_aux(X_c_offrid_all, self.aux_at_contexts, sample_patch_size) + Y_c_aux = self.sample_offgrid_aux( + X_c_offrid_all, self.aux_at_contexts, sample_patch_size + ) task["X_c"].append(X_c_offrid_all) task["Y_c"].append(Y_c_aux) diff --git a/tests/test_task_loader.py b/tests/test_task_loader.py index 135a1861..296263da 100644 --- a/tests/test_task_loader.py +++ b/tests/test_task_loader.py @@ -368,10 +368,8 @@ def test_saving_and_loading(self): tl_loaded.target_delta_t, "target_delta_t not saved and loaded correctly", ) - @parameterized.expand([ - [(0.3, 0.3)], - [(0.6, 0.4)] - ]) + + @parameterized.expand([[(0.3, 0.3)], [(0.6, 0.4)]]) def test_patch_size(self, patch_size: Sequence[float]) -> None: """Test patch size sampling.""" context = [self.da, self.df] @@ -383,10 +381,12 @@ def test_patch_size(self, patch_size: Sequence[float]) -> None: for context_sampling, target_sampling in self._gen_task_loader_call_args( len(context), 1 - ): + ): if isinstance(context_sampling[0], np.ndarray): continue - task = tl("2020-01-01", context_sampling, target_sampling, patch_size=patch_size) + task = tl( + "2020-01-01", context_sampling, target_sampling, patch_size=patch_size + ) if __name__ == "__main__": From 379e3b2de7692365fc05aa99a48f758905015115 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Thu, 12 Oct 2023 14:25:59 +0000 Subject: [PATCH 04/69] wrong merge --- deepsensor/data/loader.py | 42 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 7eb98af2..5681e610 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1269,6 +1269,48 @@ def sample_variable(var, sampling_strat, sample_patch_size, seed): f"with the `links` attribute if using the 'gapfill' sampling strategy" ) + context_var = context_slices[context_idx] + target_var = target_slices[target_idx] + + for var in [context_var, target_var]: + assert isinstance(var, (xr.DataArray, xr.Dataset)), ( + f"If using 'gapfill' sampling strategy for linked context and target sets, " + f"the context and target sets must be xarray DataArrays or Datasets, " + f"but got {type(var)}." + ) + + split_seed = seed + gapfill_i if seed is not None else None + rng = np.random.default_rng(split_seed) + + # Keep trying until we get a target set with at least one target point + keep_searching = True + while keep_searching: + added_mask_date = rng.choice(self.context[context_idx].time) + added_mask = ( + self.context[context_idx].sel(time=added_mask_date).isnull() + ) + curr_mask = context_var.isnull() + + # Mask out added missing values + context_var = context_var.where(~added_mask) + + # TEMP: Inefficient to convert all non-targets to NaNs and then remove NaNs + # when we could just slice the target values here + target_mask = added_mask & ~curr_mask + if isinstance(target_var, xr.Dataset): + keep_searching = np.all(target_mask.to_array().data == False) + else: + keep_searching = np.all(target_mask.data == False) + if keep_searching: + continue # No target points -- use a different `added_mask` + + target_var = target_var.where( + target_mask + ) # Only keep target locations + + context_slices[context_idx] = context_var + target_slices[target_idx] = target_var + # sample common patch size for context and target set if self.patch_size is not None: sample_patch_size = self.sample_patch_size_extent() From 85cd34b543dda9ac9966abd4ac032bb764fafc34 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Fri, 13 Oct 2023 09:49:22 +0000 Subject: [PATCH 05/69] incorporate some of the feedback --- deepsensor/data/loader.py | 240 +++++++++++++++++++++----------------- tests/test_task_loader.py | 4 +- 2 files changed, 136 insertions(+), 108 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 5681e610..7f2c1374 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -4,6 +4,7 @@ import json import copy import random +import itertools import numpy as np import xarray as xr @@ -186,6 +187,8 @@ def __init__( self.aux_at_target_var_IDs, ) = self.infer_context_and_target_var_IDs() + self.coord_bounds = self._compute_global_coordinate_bounds() + def _set_config(self): """Instantiate a config dictionary for the TaskLoader object""" # Take deepcopy to avoid modifying the original config @@ -588,7 +591,6 @@ def sample_da( self, da: Union[xr.DataArray, xr.Dataset], sampling_strat: Union[str, int, float, np.ndarray], - sample_patch_size: Optional[list[float]] = None, seed: Optional[int] = None, ) -> (np.ndarray, np.ndarray): """ @@ -601,8 +603,6 @@ def sample_da( sampling_strat : str | int | float | :class:`numpy:numpy.ndarray` Sampling strategy, either "all" or an integer for random grid cell sampling. - sample_patch_size: list - desired patch size extent to sample [lat_min, lat_max, lon_min, lon_max] seed : int, optional Seed for random sampling. Default: None. @@ -623,12 +623,6 @@ def sample_da( if isinstance(da, xr.Dataset): da = da.to_array() - # restric to a certain spatial patch - if sample_patch_size is not None: - x1_slice = slice(sample_patch_size[0], sample_patch_size[1]) - x2_slice = slice(sample_patch_size[2], sample_patch_size[3]) - da = da.sel(x1=x1_slice, x2=x2_slice) - if isinstance(sampling_strat, float): sampling_strat = int(sampling_strat * da.size) @@ -697,7 +691,6 @@ def sample_df( self, df: Union[pd.DataFrame, pd.Series], sampling_strat: Union[str, int, float, np.ndarray], - sample_patch_size: Optional[list[float]] = None, seed: Optional[int] = None, ) -> (np.ndarray, np.ndarray): """ @@ -711,8 +704,6 @@ def sample_df( sampling_strat : str | int | float | :class:`numpy:numpy.ndarray` Sampling strategy, either "all" or an integer for random grid cell sampling. - sample_patch_size: list[float], optional - desired patch size extent to sample [lat_min, lat_max, lon_min, lon_max] seed : int, optional Seed for random sampling. Default: None. @@ -731,16 +722,6 @@ def sample_df( """ df = df.dropna(how="any") # If any obs are NaN, drop them - if sample_patch_size is not None: - # retrieve desired patch size - lat_min, lat_max, lon_min, lon_max = sample_patch_size - df = df[ - (df.index.get_level_values("x1") >= lat_min) - & (df.index.get_level_values("x1") <= lat_max) - & (df.index.get_level_values("x2") >= lon_min) - & (df.index.get_level_values("x2") <= lon_max) - ] - if isinstance(sampling_strat, float): sampling_strat = int(sampling_strat * df.shape[0]) @@ -781,7 +762,6 @@ def sample_offgrid_aux( self, X_t: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]], offgrid_aux: Union[xr.DataArray, xr.Dataset], - sample_patch_size: Optional[list[float]] = None, ) -> np.ndarray: """ Sample auxiliary data at off-grid locations. @@ -793,8 +773,6 @@ def sample_offgrid_aux( tuple of two numpy arrays, or a single numpy array. offgrid_aux : :class:`xarray.DataArray` | :class:`xarray.Dataset` Auxiliary data at off-grid locations. - sample_patch_size: list[float], optional - desired patch size extent to sample [lat_min, lat_max, lon_min, lon_max] Returns ------- @@ -813,11 +791,6 @@ def sample_offgrid_aux( else: xt1, xt2 = xr.DataArray(X_t[0]), xr.DataArray(X_t[1]) - if sample_patch_size is not None: - x1_slice = slice(sample_patch_size[0], sample_patch_size[1]) - x2_slice = slice(sample_patch_size[2], sample_patch_size[3]) - offgrid_aux = offgrid_aux.sel(x1=x1_slice, x2=x2_slice) - Y_t_aux = offgrid_aux.sel(x1=xt1, x2=xt2, method="nearest") if isinstance(Y_t_aux, xr.Dataset): Y_t_aux = Y_t_aux.to_array() @@ -828,64 +801,75 @@ def sample_offgrid_aux( # Reshape to (variable, *spatial_dims) Y_t_aux = Y_t_aux.reshape(1, *Y_t_aux.shape) return Y_t_aux - - def sample_patch_size_extent(self) -> Sequence[float]: - """Sample patch size. - - :return sequence of patch spatial extent as [lat_min, lat_max, lon_min, lon_max] + + def _compute_global_coordinate_bounds(self) -> list[float]: """ - # assumption of normalized spatial coordinates between 0 and 1 + Compute global coordinate bounds in order to sample spatial bounds if desired. - lat_extend, lon_extend = self.patch_size + Returns + ------- + bbox: List[float] + sequence of global spatial extent as [x1_min, x1_max, x2_min, x2_max] + """ + x1_min, x1_max, x2_min, x2_max = np.PINF, np.NINF, np.PINF, np.NINF + + for var in itertools.chain(self.context, self.target): + if isinstance(var, (xr.Dataset, xr.DataArray)): + var_x1_min = var.x1.min().item() + var_x1_max = var.x1.max().item() + var_x2_min = var.x2.min().item() + var_x2_max = var.x2.max().item() + elif isinstance(var, (pd.DataFrame, pd.Series)): + var_x1_min = var.index.get_level_values('x1').min() + var_x1_max = var.index.get_level_values('x1').max() + var_x2_min = var.index.get_level_values('x2').min() + var_x2_max = var.index.get_level_values('x2').max() - lat_side = lat_extend / 2 - lon_side = lon_extend / 2 + if var_x1_min < x1_min: + x1_min = var_x1_min - # sample a point that satisfies the boundary and target conditions - continue_looking = True - while continue_looking: - lat_point = random.uniform(lat_side, 1 - lat_side) - lon_point = random.uniform(lon_side, 1 - lon_side) + if var_x1_max > x1_max: + x1_max = var_x1_max - # bbox of lat_min, lat_max, lon_min, lon_max - bbox = [ - lat_point - lat_side, - lat_point + lat_side, - lon_point - lon_side, - lon_point + lon_side, - ] + if var_x2_min < x2_min: + x2_min = var_x2_min - x1_slice = slice(bbox[0], bbox[1]) - x2_slice = slice(bbox[2], bbox[3]) - # check whether target is non-empty given this box - target_check: list[bool] = [] - for target_var in self.target: - if isinstance(target_var, (pd.DataFrame, pd.Series)): - data = target_var.loc[(slice(None), x1_slice, x2_slice)] - else: - data = target_var.sel(x1=x1_slice, x2=x2_slice) - - target_check.append(True if len(data) > 0 else False) - - # check whether context is non-empty given this box - context_check: list[bool] = [] - for context_var in self.context: - if isinstance(context_var, (pd.DataFrame, pd.Series)): - data = context_var[ - (context_var.index.get_level_values("x1") >= bbox[0]) - & (context_var.index.get_level_values("x1") <= bbox[1]) - & (context_var.index.get_level_values("x2") >= bbox[2]) - & (context_var.index.get_level_values("x2") <= bbox[3]) - ] + if var_x2_max > x2_max: + x2_max = var_x2_max + + return [x1_min, x1_max, x2_min, x2_max] - # data = context_var.loc[(slice(None), x1_slice, x2_slice)] - else: - data = context_var.sel(x1=x1_slice, x2=x2_slice) + def sample_random_window(self, window_size: tuple[float]) -> Sequence[float]: + """ + Sample random window uniformly from global coordinats to slice data. - context_check.append(True if len(data) > 0 else False) + Parameters + ---------- + window_size : Tuple[float] + Tuple of window extent - if all(target_check) and all(context_check): - continue_looking = False + Returns + ------- + bbox: List[float] + sequence of patch spatial extent as [x1_min, x1_max, x2_min, x2_max] + """ + x1_extend, x2_extend = window_size + + x1_side = x1_extend / 2 + x2_side = x2_extend / 2 + + # sample a point that satisfies the context and target global bounds + x1_min, x1_max, x2_min, x2_max = self.coord_bounds + x1_point = random.uniform(x1_min + x1_side, x1_max - x1_side) + x2_point = random.uniform(x2_min + x2_side, x2_max - x2_side) + + # bbox of x1_min, x1_max, x2_min, x2_max + bbox = [ + x1_point - x1_side, + x1_point + x1_side, + x2_point - x2_side, + x2_point + x2_side, + ] return bbox @@ -921,7 +905,46 @@ def time_slice_variable(self, var, date, delta_t=0): else: raise ValueError(f"Unknown variable type {type(var)}") return var + + def spatial_slice_variable(self, var, window: list[float]): + """ + Slice a variabel by a given window size. + + Parameters + ---------- + var : ... + Variable to slice + window : ... + list of coordinates specifying the window [x1_min, x1_max, x2_min, x2_max] + + Returns + ------- + var : ... + Sliced variable. + Raises + ------ + ValueError + If the variable is of an unknown type. + """ + x1_min, x1_max, x2_min, x2_max = window + if isinstance(var, (xr.Dataset, xr.DataArray)): + x1_slice = slice(x1_min, x1_max) + x2_slice = slice(x2_min, x2_max) + var = var.sel(x1=x1_slice, x2=x2_slice) + elif isinstance(var, (pd.DataFrame, pd.Series)): + # retrieve desired patch size + var = var[ + (var.index.get_level_values("x1") >= x1_min) + & (var.index.get_level_values("x1") <= x1_max) + & (var.index.get_level_values("x2") >= x2_min) + & (var.index.get_level_values("x2") <= x2_max) + ] + else: + raise ValueError(f"Unknown variable type {type(var)}") + + return var + def task_generation( self, date: pd.Timestamp, @@ -940,7 +963,7 @@ def task_generation( List[Union[str, int, float, np.ndarray]], ] = "all", split_frac: float = 0.5, - patch_size: Sequence[float] = None, + window_size: Sequence[float] = None, datewise_deterministic: bool = False, seed_override: Optional[int] = None, ) -> Task: @@ -974,8 +997,8 @@ def task_generation( "split" sampling strategy for linked context and target set pairs. The remaining observations are used for the target set. Default is 0.5. - patch_size: Sequence[float], optional - Desired patch size in lat/lon used for patchwise task generation. Usefule when considering + window_size : Sequence[float], optional + Desired patch size in x1/x2 used for patchwise task generation. Usefule when considering the entire available region is computationally prohibitive for model forward pass datewise_deterministic : bool Whether random sampling is datewise_deterministic based on the @@ -1067,7 +1090,7 @@ def check_sampling_strat(sampling_strat, set): return sampling_strat - def sample_variable(var, sampling_strat, sample_patch_size, seed): + def sample_variable(var, sampling_strat, seed): """ Sample a variable by a given sampling strategy to get input and output data. @@ -1078,8 +1101,6 @@ def sample_variable(var, sampling_strat, sample_patch_size, seed): Variable to sample. sampling_strat : ... Sampling strategy to use. - sample_patch_size: ... - Desired sample patch size seed : ... Seed for random sampling. @@ -1094,9 +1115,9 @@ def sample_variable(var, sampling_strat, sample_patch_size, seed): If the variable is of an unknown type. """ if isinstance(var, (xr.Dataset, xr.DataArray)): - X, Y = self.sample_da(var, sampling_strat, sample_patch_size, seed) + X, Y = self.sample_da(var, sampling_strat, seed) elif isinstance(var, (pd.DataFrame, pd.Series)): - X, Y = self.sample_df(var, sampling_strat, sample_patch_size, seed) + X, Y = self.sample_df(var, sampling_strat, seed) else: raise ValueError(f"Unknown type {type(var)} for context set " f"{var}") return X, Y @@ -1163,16 +1184,6 @@ def sample_variable(var, sampling_strat, sample_patch_size, seed): # 'Truly' random sampling seed = None - # check patch size - if patch_size is not None: - assert ( - len(patch_size) == 2 - ), "Patch size must be a Sequence of two values for lat/lon extent." - assert all( - 0 < x <= 1 for x in patch_size - ), "Values specified for patch size must satisfy 0 < x <= 1." - self.patch_size = patch_size - task = {} task["time"] = date @@ -1182,6 +1193,7 @@ def sample_variable(var, sampling_strat, sample_patch_size, seed): task["X_t"] = [] task["Y_t"] = [] + # temporal slices context_slices = [ self.time_slice_variable(var, date, delta_t) for var, delta_t in zip(self.context, self.context_delta_t) @@ -1191,6 +1203,27 @@ def sample_variable(var, sampling_strat, sample_patch_size, seed): for var, delta_t in zip(self.target, self.target_delta_t) ] + # check patch size + if window_size is not None: + assert ( + len(window_size) == 2 + ), "Patch size must be a Sequence of two values for x1/x2 extent." + assert all( + 0 < x <= 1 for x in window_size + ), "Values specified for patch size must satisfy 0 < x <= 1." + + window = self.sample_random_window(window_size) + + # spatial slices + context_slices = [ + self.spatial_slice_variable(var, window) + for var in context_slices + ] + target_slices = [ + self.spatial_slice_variable(var, window) + for var in target_slices + ] + # TODO move to method if ( self.links is not None @@ -1311,25 +1344,20 @@ def sample_variable(var, sampling_strat, sample_patch_size, seed): context_slices[context_idx] = context_var target_slices[target_idx] = target_var - # sample common patch size for context and target set - if self.patch_size is not None: - sample_patch_size = self.sample_patch_size_extent() - else: - sample_patch_size = None - + for i, (var, sampling_strat) in enumerate( zip(context_slices, context_sampling) ): context_seed = seed + i if seed is not None else None X_c, Y_c = sample_variable( - var, sampling_strat, sample_patch_size, context_seed + var, sampling_strat, context_seed ) task[f"X_c"].append(X_c) task[f"Y_c"].append(Y_c) for j, (var, sampling_strat) in enumerate(zip(target_slices, target_sampling)): target_seed = seed + i + j if seed is not None else None X_t, Y_t = sample_variable( - var, sampling_strat, sample_patch_size, target_seed + var, sampling_strat, target_seed ) task[f"X_t"].append(X_t) task[f"Y_t"].append(Y_t) @@ -1344,7 +1372,7 @@ def sample_variable(var, sampling_strat, sample_patch_size, seed): X_c_offrid_all = np.concatenate(X_c_offgrid, axis=1) Y_c_aux = ( self.sample_offgrid_aux( - X_c_offrid_all, self.time_slice_variable(self.aux_at_contexts, date), sample_patch_size + X_c_offrid_all, self.time_slice_variable(self.aux_at_contexts, date) ), ) task["X_c"].append(X_c_offrid_all) diff --git a/tests/test_task_loader.py b/tests/test_task_loader.py index 5f807d27..d40bb2d1 100644 --- a/tests/test_task_loader.py +++ b/tests/test_task_loader.py @@ -409,7 +409,7 @@ def test_saving_and_loading(self): ) @parameterized.expand([[(0.3, 0.3)], [(0.6, 0.4)]]) - def test_patch_size(self, patch_size: Sequence[float]) -> None: + def test_window_size(self, window_size: Sequence[float]) -> None: """Test patch size sampling.""" context = [self.da, self.df] @@ -424,7 +424,7 @@ def test_patch_size(self, patch_size: Sequence[float]) -> None: if isinstance(context_sampling[0], np.ndarray): continue task = tl( - "2020-01-01", context_sampling, target_sampling, patch_size=patch_size + "2020-01-01", context_sampling, target_sampling, window_size=window_size ) From be8fffd89ab05b95d7dee621a6a0c8be5e5dc957 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Fri, 13 Oct 2023 09:54:59 +0000 Subject: [PATCH 06/69] run black --- deepsensor/data/loader.py | 37 +++++++++++++++---------------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 7f2c1374..72fb0fbc 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -801,7 +801,7 @@ def sample_offgrid_aux( # Reshape to (variable, *spatial_dims) Y_t_aux = Y_t_aux.reshape(1, *Y_t_aux.shape) return Y_t_aux - + def _compute_global_coordinate_bounds(self) -> list[float]: """ Compute global coordinate bounds in order to sample spatial bounds if desired. @@ -812,7 +812,7 @@ def _compute_global_coordinate_bounds(self) -> list[float]: sequence of global spatial extent as [x1_min, x1_max, x2_min, x2_max] """ x1_min, x1_max, x2_min, x2_max = np.PINF, np.NINF, np.PINF, np.NINF - + for var in itertools.chain(self.context, self.target): if isinstance(var, (xr.Dataset, xr.DataArray)): var_x1_min = var.x1.min().item() @@ -820,13 +820,13 @@ def _compute_global_coordinate_bounds(self) -> list[float]: var_x2_min = var.x2.min().item() var_x2_max = var.x2.max().item() elif isinstance(var, (pd.DataFrame, pd.Series)): - var_x1_min = var.index.get_level_values('x1').min() - var_x1_max = var.index.get_level_values('x1').max() - var_x2_min = var.index.get_level_values('x2').min() - var_x2_max = var.index.get_level_values('x2').max() + var_x1_min = var.index.get_level_values("x1").min() + var_x1_max = var.index.get_level_values("x1").max() + var_x2_min = var.index.get_level_values("x2").min() + var_x2_max = var.index.get_level_values("x2").max() if var_x1_min < x1_min: - x1_min = var_x1_min + x1_min = var_x1_min if var_x1_max > x1_max: x1_max = var_x1_max @@ -836,7 +836,7 @@ def _compute_global_coordinate_bounds(self) -> list[float]: if var_x2_max > x2_max: x2_max = var_x2_max - + return [x1_min, x1_max, x2_min, x2_max] def sample_random_window(self, window_size: tuple[float]) -> Sequence[float]: @@ -872,7 +872,7 @@ def sample_random_window(self, window_size: tuple[float]) -> Sequence[float]: ] return bbox - + def time_slice_variable(self, var, date, delta_t=0): """ Slice a variable by a given time delta. @@ -905,7 +905,7 @@ def time_slice_variable(self, var, date, delta_t=0): else: raise ValueError(f"Unknown variable type {type(var)}") return var - + def spatial_slice_variable(self, var, window: list[float]): """ Slice a variabel by a given window size. @@ -944,7 +944,7 @@ def spatial_slice_variable(self, var, window: list[float]): raise ValueError(f"Unknown variable type {type(var)}") return var - + def task_generation( self, date: pd.Timestamp, @@ -1216,12 +1216,10 @@ def sample_variable(var, sampling_strat, seed): # spatial slices context_slices = [ - self.spatial_slice_variable(var, window) - for var in context_slices + self.spatial_slice_variable(var, window) for var in context_slices ] target_slices = [ - self.spatial_slice_variable(var, window) - for var in target_slices + self.spatial_slice_variable(var, window) for var in target_slices ] # TODO move to method @@ -1344,21 +1342,16 @@ def sample_variable(var, sampling_strat, seed): context_slices[context_idx] = context_var target_slices[target_idx] = target_var - for i, (var, sampling_strat) in enumerate( zip(context_slices, context_sampling) ): context_seed = seed + i if seed is not None else None - X_c, Y_c = sample_variable( - var, sampling_strat, context_seed - ) + X_c, Y_c = sample_variable(var, sampling_strat, context_seed) task[f"X_c"].append(X_c) task[f"Y_c"].append(Y_c) for j, (var, sampling_strat) in enumerate(zip(target_slices, target_sampling)): target_seed = seed + i + j if seed is not None else None - X_t, Y_t = sample_variable( - var, sampling_strat, target_seed - ) + X_t, Y_t = sample_variable(var, sampling_strat, target_seed) task[f"X_t"].append(X_t) task[f"Y_t"].append(Y_t) From 876970ebf59066e68706f6530278944c91ac82df Mon Sep 17 00:00:00 2001 From: nilsleh Date: Fri, 12 Apr 2024 13:45:50 +0000 Subject: [PATCH 07/69] layout code --- deepsensor/data/loader.py | 89 ++++++++++++++++++++++++++++----------- 1 file changed, 65 insertions(+), 24 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 7abe5067..57f38875 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -832,13 +832,13 @@ def _compute_global_coordinate_bounds(self) -> List[float]: return [x1_min, x1_max, x2_min, x2_max] - def sample_random_window(self, window_size: Tuple[float]) -> Sequence[float]: + def sample_random_window(self, patch_size: Tuple[float]) -> Sequence[float]: """ Sample random window uniformly from global coordinats to slice data. Parameters ---------- - window_size : Tuple[float] + patch_size : Tuple[float] Tuple of window extent Returns @@ -846,7 +846,7 @@ def sample_random_window(self, window_size: Tuple[float]) -> Sequence[float]: bbox: List[float] sequence of patch spatial extent as [x1_min, x1_max, x2_min, x2_max] """ - x1_extend, x2_extend = window_size + x1_extend, x2_extend = patch_size x1_side = x1_extend / 2 x2_side = x2_extend / 2 @@ -895,7 +895,6 @@ def time_slice_variable(self, var, date, delta_t=0): else: raise ValueError(f"Unknown variable type {type(var)}") return var - def spatial_slice_variable(self, var, window: List[float]): """ @@ -992,7 +991,7 @@ def task_generation( ] ] = None, split_frac: float = 0.5, - window_size: Sequence[float] = None, + patch_size: Sequence[float] = None, datewise_deterministic: bool = False, seed_override: Optional[int] = None, ) -> Task: @@ -1026,7 +1025,7 @@ def task_generation( "split" sampling strategy for linked context and target set pairs. The remaining observations are used for the target set. Default is 0.5. - window_size : Sequence[float], optional + patch_size : Sequence[float], optional Desired patch size in x1/x2 used for patchwise task generation. Usefule when considering the entire available region is computationally prohibitive for model forward pass datewise_deterministic : bool @@ -1232,22 +1231,22 @@ def sample_variable(var, sampling_strat, seed): ] # check patch size - if window_size is not None: + if patch_size is not None: assert ( - len(window_size) == 2 + len(patch_size) == 2 ), "Patch size must be a Sequence of two values for x1/x2 extent." assert all( - 0 < x <= 1 for x in window_size + 0 < x <= 1 for x in patch_size ), "Values specified for patch size must satisfy 0 < x <= 1." - window = self.sample_random_window(window_size) + patch = self.sample_random_window(patch_size) # spatial slices context_slices = [ - self.spatial_slice_variable(var, window) for var in context_slices + self.spatial_slice_variable(var, patch) for var in context_slices ] target_slices = [ - self.spatial_slice_variable(var, window) for var in target_slices + self.spatial_slice_variable(var, patch) for var in target_slices ] # TODO move to method @@ -1377,6 +1376,40 @@ def sample_variable(var, sampling_strat, seed): return Task(task) + def generate_tasks( + self, + dates: Union[pd.Timestamp, List[pd.Timestamp]], + patch_strategy: Optional[str], + **kwargs, + ) -> List[Task]: + """ + Generate a list of Tasks for Training or Inference. + + Args: + dates: Union[pd.Timestamp, List[pd.Timestamp]] + List of dates for which to generate the task. + patch_strategy: Optional[str] + Patch strategy to use for patchwise task generation. Default is None. + Possible options are 'random' or 'sliding'. + **kwargs: + Additional keyword arguments to pass to the task generation method. + """ + if patch_strategy is None: + tasks = [self.task_generation(date, **kwargs) for date in dates] + elif patch_strategy == "random": + # uniform random sampling of patch + pass + elif patch_strategy == "sliding": + # sliding window sampling of patch + pass + else: + raise ValueError( + f"Invalid patch strategy {patch_strategy}. " + f"Must be one of [None, 'random', 'sliding']." + ) + + return tasks + def __call__( self, date: pd.Timestamp, @@ -1397,6 +1430,8 @@ def __call__( ] ] = None, split_frac: float = 0.5, + patch_size: Sequence[float] = None, + patch_strategy: Optional[str] = None, datewise_deterministic: bool = False, seed_override: Optional[int] = None, ) -> Union[Task, List[Task]]: @@ -1443,9 +1478,12 @@ def __call__( the "split" sampling strategy for linked context and target set pairs. The remaining observations are used for the target set. Default is 0.5. - window_size : Sequence[float], optional + patch_size : Sequence[float], optional Desired patch size in x1/x2 used for patchwise task generation. Usefule when considering the entire available region is computationally prohibitive for model forward pass + patch_strategy: + Patch strategy to use for patchwise task generation. Default is None. + Possible options are 'random' or 'sliding'. datewise_deterministic (bool, optional): Whether random sampling is datewise deterministic based on the date. Default is ``False``. @@ -1459,18 +1497,21 @@ def __call__( Task object or list of task objects for each date containing the context and target data. """ + assert patch_strategy in [None, "random", "sliding"], ( + f"Invalid patch strategy {patch_strategy}. " + f"Must be one of [None, 'random', 'sliding']." + ) if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)): - return [ - self.task_generation( - d, - context_sampling, - target_sampling, - split_frac, - datewise_deterministic, - seed_override, - ) - for d in date - ] + return self.generate_tasks( + dates=date, + patch_strategy=patch_strategy, + context_sampling=context_sampling, + target_sampling=target_sampling, + split_frac=split_frac, + patch_size=patch_size, + datewise_deterministic=datewise_deterministic, + seed_override=seed_override, + ) else: return self.task_generation( date, From d1cb338985e74c4162b6e3638a09e85c16d7222c Mon Sep 17 00:00:00 2001 From: nilsleh Date: Fri, 12 Apr 2024 13:50:35 +0000 Subject: [PATCH 08/69] change __call__ --- deepsensor/data/loader.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 57f38875..2f55fed8 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1431,7 +1431,6 @@ def __call__( ] = None, split_frac: float = 0.5, patch_size: Sequence[float] = None, - patch_strategy: Optional[str] = None, datewise_deterministic: bool = False, seed_override: Optional[int] = None, ) -> Union[Task, List[Task]]: @@ -1497,14 +1496,10 @@ def __call__( Task object or list of task objects for each date containing the context and target data. """ - assert patch_strategy in [None, "random", "sliding"], ( - f"Invalid patch strategy {patch_strategy}. " - f"Must be one of [None, 'random', 'sliding']." - ) if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)): return self.generate_tasks( dates=date, - patch_strategy=patch_strategy, + patch_strategy="random" if patch_size is not None else None, context_sampling=context_sampling, target_sampling=target_sampling, split_frac=split_frac, From 218f791efb1f55debd0218fbb3bd39a632e171a0 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Fri, 12 Apr 2024 13:53:04 +0000 Subject: [PATCH 09/69] revert --- deepsensor/data/loader.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 2f55fed8..57f38875 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1431,6 +1431,7 @@ def __call__( ] = None, split_frac: float = 0.5, patch_size: Sequence[float] = None, + patch_strategy: Optional[str] = None, datewise_deterministic: bool = False, seed_override: Optional[int] = None, ) -> Union[Task, List[Task]]: @@ -1496,10 +1497,14 @@ def __call__( Task object or list of task objects for each date containing the context and target data. """ + assert patch_strategy in [None, "random", "sliding"], ( + f"Invalid patch strategy {patch_strategy}. " + f"Must be one of [None, 'random', 'sliding']." + ) if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)): return self.generate_tasks( dates=date, - patch_strategy="random" if patch_size is not None else None, + patch_strategy=patch_strategy, context_sampling=context_sampling, target_sampling=target_sampling, split_frac=split_frac, From 37fe771843e871a8db040c9eaac89c7cb4b60c38 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Fri, 12 Apr 2024 13:54:05 +0000 Subject: [PATCH 10/69] type annotation --- deepsensor/data/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 57f38875..97c4b7fe 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1378,7 +1378,7 @@ def sample_variable(var, sampling_strat, seed): def generate_tasks( self, - dates: Union[pd.Timestamp, List[pd.Timestamp]], + dates: Union[pd.Timestamp, Sequence[pd.Timestamp]], patch_strategy: Optional[str], **kwargs, ) -> List[Task]: @@ -1412,7 +1412,7 @@ def generate_tasks( def __call__( self, - date: pd.Timestamp, + date: Union[pd.Timestamp, Sequence[pd.Timestamp]], context_sampling: Union[ str, int, From fb20ccc0324cc77982d69a88b9d2f3a4f03d9e59 Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Mon, 15 Apr 2024 10:39:47 +0200 Subject: [PATCH 11/69] patch_size sampling test --- deepsensor/data/loader.py | 42 +++++++++++++++++++++++++++++++-------- tests/test_task_loader.py | 4 ++-- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 97c4b7fe..5641efc4 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1086,6 +1086,11 @@ def check_sampling_strat(sampling_strat, set): raise InvalidSamplingStrategyError( f"Unknown sampling strategy {strat} of type {type(strat)}" ) + elif isinstance(strat, str) and strat == "gapfill": + assert all(isinstance(item, (xr.Dataset, xr.DataArray)) for item in set), ( + "Gapfill sampling strategy can only be used with xarray " + "datasets or data arrays" + ) elif isinstance(strat, str) and strat not in [ "all", "split", @@ -1397,11 +1402,18 @@ def generate_tasks( if patch_strategy is None: tasks = [self.task_generation(date, **kwargs) for date in dates] elif patch_strategy == "random": + assert "patch_size" in kwargs, "Patch size must be specified for random patch sampling." # uniform random sampling of patch - pass + tasks : list[Task] = [] + num_samples_per_date = kwargs.get("num_samples_per_date", 1) + new_kwargs = kwargs.copy() + new_kwargs.pop("num_samples_per_date", None) + for date in dates: + tasks.extend([self.task_generation(date, **new_kwargs) for _ in range(num_samples_per_date)]) + elif patch_strategy == "sliding": # sliding window sampling of patch - pass + tasks : list[Task] = [] else: raise ValueError( f"Invalid patch strategy {patch_strategy}. " @@ -1409,6 +1421,20 @@ def generate_tasks( ) return tasks + + def check_tasks(self, tasks: List[Task]): + """ + Check tasks for consistency, such as target nans etc. + + Args: + tasks List[:class:`~.data.task.Task`]: + List of tasks to check. + + Returns: + List[:class:`~.data.task.Task`]: + updated list of tasks + """ + pass def __call__( self, @@ -1514,10 +1540,10 @@ def __call__( ) else: return self.task_generation( - date, - context_sampling, - target_sampling, - split_frac, - datewise_deterministic, - seed_override, + date=date, + context_sampling=context_sampling, + target_sampling=target_sampling, + split_frac=split_frac, + datewise_deterministic=datewise_deterministic, + seed_override=seed_override, ) diff --git a/tests/test_task_loader.py b/tests/test_task_loader.py index b63cd566..4c176ecd 100644 --- a/tests/test_task_loader.py +++ b/tests/test_task_loader.py @@ -273,7 +273,7 @@ def test_links(self) -> None: task = tl("2020-01-01", "gapfill", "gapfill") @parameterized.expand([[(0.3, 0.3)], [(0.6, 0.4)]]) - def test_window_size(self, window_size) -> None: + def test_patch_size(self, patch_size) -> None: """Test patch size sampling.""" context = [self.da, self.df] @@ -288,7 +288,7 @@ def test_window_size(self, window_size) -> None: if isinstance(context_sampling[0], np.ndarray): continue task = tl( - "2020-01-01", context_sampling, target_sampling, window_size=window_size + "2020-01-01", context_sampling, target_sampling, patch_size=patch_size ) def test_saving_and_loading(self): From 5bda80b6abc4acb8efa7faa19bd8d191ec70ac84 Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Mon, 15 Apr 2024 12:00:13 +0200 Subject: [PATCH 12/69] patchwise test trainer --- deepsensor/data/loader.py | 92 +++++++++++++++------------------------ tests/test_task_loader.py | 82 ++++++++++++++++++++++++---------- tests/test_training.py | 37 +++++++++++++++- 3 files changed, 129 insertions(+), 82 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 5641efc4..63627468 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1,17 +1,15 @@ -from deepsensor.data.task import Task, flatten_X - -import os -import json import copy -import random import itertools +import json +import os +import random +from typing import List, Optional, Sequence, Tuple, Union import numpy as np -import xarray as xr import pandas as pd +import xarray as xr -from typing import List, Tuple, Union, Optional, Sequence - +from deepsensor.data.task import Task, flatten_X from deepsensor.errors import InvalidSamplingStrategyError @@ -853,6 +851,7 @@ def sample_random_window(self, patch_size: Tuple[float]) -> Sequence[float]: # sample a point that satisfies the context and target global bounds x1_min, x1_max, x2_min, x2_max = self.coord_bounds + x1_point = random.uniform(x1_min + x1_side, x1_max - x1_side) x2_point = random.uniform(x2_min + x2_side, x2_max - x2_side) @@ -916,47 +915,15 @@ def spatial_slice_variable(self, var, window: List[float]): """ x1_min, x1_max, x2_min, x2_max = window if isinstance(var, (xr.Dataset, xr.DataArray)): - x1_slice = slice(x1_min, x1_max) - x2_slice = slice(x2_min, x2_max) - var = var.sel(x1=x1_slice, x2=x2_slice) - elif isinstance(var, (pd.DataFrame, pd.Series)): - # retrieve desired patch size - var = var[ - (var.index.get_level_values("x1") >= x1_min) - & (var.index.get_level_values("x1") <= x1_max) - & (var.index.get_level_values("x2") >= x2_min) - & (var.index.get_level_values("x2") <= x2_max) - ] - else: - raise ValueError(f"Unknown variable type {type(var)}") - - return var - - def spatial_slice_variable(self, var, window: List[float]): - """ - Slice a variabel by a given window size. - - Parameters - ---------- - var : ... - Variable to slice - window : ... - list of coordinates specifying the window [x1_min, x1_max, x2_min, x2_max] - - Returns - ------- - var : ... - Sliced variable. - - Raises - ------ - ValueError - If the variable is of an unknown type. - """ - x1_min, x1_max, x2_min, x2_max = window - if isinstance(var, (xr.Dataset, xr.DataArray)): - x1_slice = slice(x1_min, x1_max) - x2_slice = slice(x2_min, x2_max) + # we cannot assume that the coordinates are sorted from small to large + if var.x1[0] > var.x1[-1]: + x1_slice = slice(x1_max, x1_min) + else: + x1_slice = slice(x1_min, x1_max) + if var.x2[0] > var.x2[-1]: + x2_slice = slice(x2_max, x2_min) + else: + x2_slice = slice(x2_min, x2_max) var = var.sel(x1=x1_slice, x2=x2_slice) elif isinstance(var, (pd.DataFrame, pd.Series)): # retrieve desired patch size @@ -1087,7 +1054,9 @@ def check_sampling_strat(sampling_strat, set): f"Unknown sampling strategy {strat} of type {type(strat)}" ) elif isinstance(strat, str) and strat == "gapfill": - assert all(isinstance(item, (xr.Dataset, xr.DataArray)) for item in set), ( + assert all( + isinstance(item, (xr.Dataset, xr.DataArray)) for item in set + ), ( "Gapfill sampling strategy can only be used with xarray " "datasets or data arrays" ) @@ -1243,7 +1212,6 @@ def sample_variable(var, sampling_strat, seed): assert all( 0 < x <= 1 for x in patch_size ), "Values specified for patch size must satisfy 0 < x <= 1." - patch = self.sample_random_window(patch_size) # spatial slices @@ -1402,18 +1370,25 @@ def generate_tasks( if patch_strategy is None: tasks = [self.task_generation(date, **kwargs) for date in dates] elif patch_strategy == "random": - assert "patch_size" in kwargs, "Patch size must be specified for random patch sampling." + assert ( + "patch_size" in kwargs + ), "Patch size must be specified for random patch sampling." # uniform random sampling of patch - tasks : list[Task] = [] + tasks: list[Task] = [] num_samples_per_date = kwargs.get("num_samples_per_date", 1) new_kwargs = kwargs.copy() new_kwargs.pop("num_samples_per_date", None) for date in dates: - tasks.extend([self.task_generation(date, **new_kwargs) for _ in range(num_samples_per_date)]) - + tasks.extend( + [ + self.task_generation(date, **new_kwargs) + for _ in range(num_samples_per_date) + ] + ) + elif patch_strategy == "sliding": # sliding window sampling of patch - tasks : list[Task] = [] + tasks: list[Task] = [] else: raise ValueError( f"Invalid patch strategy {patch_strategy}. " @@ -1421,11 +1396,11 @@ def generate_tasks( ) return tasks - + def check_tasks(self, tasks: List[Task]): """ Check tasks for consistency, such as target nans etc. - + Args: tasks List[:class:`~.data.task.Task`]: List of tasks to check. @@ -1544,6 +1519,7 @@ def __call__( context_sampling=context_sampling, target_sampling=target_sampling, split_frac=split_frac, + patch_size=patch_size, datewise_deterministic=datewise_deterministic, seed_override=seed_override, ) diff --git a/tests/test_task_loader.py b/tests/test_task_loader.py index 4c176ecd..c1249887 100644 --- a/tests/test_task_loader.py +++ b/tests/test_task_loader.py @@ -1,30 +1,28 @@ +import copy import itertools - +import os +import shutil +import tempfile +import unittest from typing import Sequence -from parameterized import parameterized - -import xarray as xr import dask.array import numpy as np import pandas as pd -import unittest - -import os -import shutil -import tempfile -import copy +import pytest +import xarray as xr +from _pytest.fixtures import SubRequest +from parameterized import parameterized +from deepsensor.data.loader import TaskLoader from deepsensor.errors import InvalidSamplingStrategyError from tests.utils import ( - gen_random_data_xr, - gen_random_data_pandas, assert_allclose_pd, assert_allclose_xr, + gen_random_data_pandas, + gen_random_data_xr, ) -from deepsensor.data.loader import TaskLoader - def _gen_data_xr(coords=None, dims=None, data_vars=None, use_dask=False): """Gen random normalised data""" @@ -275,21 +273,59 @@ def test_links(self) -> None: @parameterized.expand([[(0.3, 0.3)], [(0.6, 0.4)]]) def test_patch_size(self, patch_size) -> None: """Test patch size sampling.""" - context = [self.da, self.df] + # need to redefine the data generators because the patch size samplin + # where we want to test that context and or target have different + # spatial extents + da_data_0_1 = self.da + + # smaller normalized coord + da_data_smaller = _gen_data_xr( + coords=dict( + time=pd.date_range("2020-01-01", "2020-01-31", freq="D"), + x1=np.linspace(0.1, 0.9, 25), + x2=np.linspace(0.1, 0.9, 10), + ) + ) + # larger normalized coord + da_data_larger = _gen_data_xr( + coords=dict( + time=pd.date_range("2020-01-01", "2020-01-31", freq="D"), + x1=np.linspace(-0.1, 1.1, 50), + x2=np.linspace(-0.1, 1.1, 50), + ) + ) + context = [da_data_0_1, da_data_smaller, da_data_larger] tl = TaskLoader( context=context, # gridded xarray and off-grid pandas contexts target=self.df, # off-grid pandas targets ) - for context_sampling, target_sampling in self._gen_task_loader_call_args( - len(context), 1 - ): - if isinstance(context_sampling[0], np.ndarray): - continue - task = tl( - "2020-01-01", context_sampling, target_sampling, patch_size=patch_size - ) + # TODO it would be better to do this with pytest.fixtures + # but could not get to work so far + task = tl( + "2020-01-01", "all", "all", patch_size=patch_size, patch_strategy="random" + ) + + # test date range + tasks = tl( + ["2020-01-01", "2020-01-02"], + "all", + "all", + patch_size=patch_size, + patch_strategy="random", + ) + assert len(tasks) == 2 + # test date range with num_samples per date + tasks = tl.generate_tasks( + ["2020-01-01", "2020-01-02"], + context_sampling="all", + target_sampling="all", + patch_size=patch_size, + patch_strategy="random", + num_samples_per_date=2, + ) + assert len(tasks) == 4 def test_saving_and_loading(self): """Test saving and loading TaskLoader""" diff --git a/tests/test_training.py b/tests/test_training.py index 084b82ed..2c4328e1 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -6,7 +6,8 @@ from tqdm import tqdm -import deepsensor.tensorflow as deepsensor +# import deepsensor.tensorflow as deepsensor +import deepsensor.torch from deepsensor.train.train import Trainer from deepsensor.data.processor import DataProcessor @@ -113,3 +114,37 @@ def test_training(self): # Check for NaNs in the loss loss = np.mean(epoch_losses) self.assertFalse(np.isnan(loss)) + + def test_patch_wise_training(self): + """ + Test model training with patch-wise tasks. + """ + tl = TaskLoader(context=self.da, target=self.da) + model = ConvNP(self.data_processor, tl, unet_channels=(5, 5, 5), verbose=False) + + # generate training tasks + n_train_tasks = 10 + dates = [np.random.choice(self.da.time.values) for i in range(n_train_tasks)] + train_tasks = tl.generate_tasks( + dates, + context_sampling="all", + target_sampling="all", + patch_strategy="random", + patch_size=(0.8, 0.8), + ) + + # TODO pytest can also be more succinct with pytest.fixtures + # Train + trainer = Trainer(model, lr=5e-5) + batch_size = None + # TODO check with batch_size > 1 + # batch_size = 5 + n_epochs = 10 + epoch_losses = [] + for epoch in tqdm(range(n_epochs)): + batch_losses = trainer(train_tasks, batch_size=batch_size) + epoch_losses.append(np.mean(batch_losses)) + + # Check for NaNs in the loss + loss = np.mean(epoch_losses) + self.assertFalse(np.isnan(loss)) From c276844de5d63ef52ea8b52bb98747da8feb340d Mon Sep 17 00:00:00 2001 From: Martin Rogers Date: Fri, 19 Apr 2024 12:52:50 +0100 Subject: [PATCH 13/69] gridded window patching --- deepsensor/data/loader.py | 226 ++++++++++++++++++++++++++------------ 1 file changed, 158 insertions(+), 68 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 97c4b7fe..ae39a469 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1,17 +1,15 @@ -from deepsensor.data.task import Task, flatten_X - -import os -import json import copy -import random import itertools +import json +import os +import random +from typing import List, Optional, Sequence, Tuple, Union import numpy as np -import xarray as xr import pandas as pd +import xarray as xr -from typing import List, Tuple, Union, Optional, Sequence - +from deepsensor.data.task import Task, flatten_X from deepsensor.errors import InvalidSamplingStrategyError @@ -834,7 +832,7 @@ def _compute_global_coordinate_bounds(self) -> List[float]: def sample_random_window(self, patch_size: Tuple[float]) -> Sequence[float]: """ - Sample random window uniformly from global coordinats to slice data. + Sample random window uniformly from global coordinates to slice data. Parameters ---------- @@ -853,6 +851,7 @@ def sample_random_window(self, patch_size: Tuple[float]) -> Sequence[float]: # sample a point that satisfies the context and target global bounds x1_min, x1_max, x2_min, x2_max = self.coord_bounds + x1_point = random.uniform(x1_min + x1_side, x1_max - x1_side) x2_point = random.uniform(x2_min + x2_side, x2_max - x2_side) @@ -866,6 +865,57 @@ def sample_random_window(self, patch_size: Tuple[float]) -> Sequence[float]: return bbox + def sample_sliding_window(self, patch_size: Tuple[float], stride: Tuple[float]) -> Sequence[float]: + """ + Sample data using sliding window from global coordinates to slice data. + + Parameters + ---------- + patch_size : Tuple[float] + Tuple of window extent + + Stride : Tuple[float] + Tuple of step size between each patch along x1 and x2 axis. + Returns + ------- + bbox: List[float] ## check type of return. + sequence of patch spatial extent as [x1_min, x1_max, x2_min, x2_max] + """ + # define patch size in x1/x2 + x1_extend, x2_extend = patch_size + + # define stride length in x1/x2 + dy, dx = stride + + # Calculate the global bounds of context and target set. + x1_min, x1_max, x2_min, x2_max = self.coord_bounds + + ## start with first patch top left hand corner at x1_min, x2_min + n_patches = 0 + patch_list = [] + for y in range(x1_min, x1_max, dy): + for x in range(x2_min, x2_max, dx): + n_patches += 1 + if y + x1_extend > x1_max: + y0 = x1_max - x1_extend + else: + y0 = y + if x + x2_extend > x2_max: + x0 = x2_max - x2_extend + else: + x0 = x + + # bbox of x1_min, x1_max, x2_min, x2_max per patch + bbox = [y0, y0 + x1_extend, x0, x0 + x2_extend] + + patch_list.append(bbox) + + ## I don't think we should actually print this here, but somehow we should + ## provide this information back, so users know the number of patches per date. + print("Number of patches per date using sliding window method", n_patches) + + return patch_list + def time_slice_variable(self, var, date, delta_t=0): """ Slice a variable by a given time delta. @@ -898,56 +948,17 @@ def time_slice_variable(self, var, date, delta_t=0): def spatial_slice_variable(self, var, window: List[float]): """ - Slice a variabel by a given window size. - Parameters - ---------- - var : ... - Variable to slice - window : ... - list of coordinates specifying the window [x1_min, x1_max, x2_min, x2_max] - Returns - ------- - var : ... - Sliced variable. - Raises - ------ - ValueError - If the variable is of an unknown type. - """ - x1_min, x1_max, x2_min, x2_max = window - if isinstance(var, (xr.Dataset, xr.DataArray)): - x1_slice = slice(x1_min, x1_max) - x2_slice = slice(x2_min, x2_max) - var = var.sel(x1=x1_slice, x2=x2_slice) - elif isinstance(var, (pd.DataFrame, pd.Series)): - # retrieve desired patch size - var = var[ - (var.index.get_level_values("x1") >= x1_min) - & (var.index.get_level_values("x1") <= x1_max) - & (var.index.get_level_values("x2") >= x2_min) - & (var.index.get_level_values("x2") <= x2_max) - ] - else: - raise ValueError(f"Unknown variable type {type(var)}") - - return var - - def spatial_slice_variable(self, var, window: List[float]): - """ - Slice a variabel by a given window size. - + Slice a variable by a given window size. Parameters ---------- var : ... Variable to slice window : ... list of coordinates specifying the window [x1_min, x1_max, x2_min, x2_max] - Returns ------- var : ... Sliced variable. - Raises ------ ValueError @@ -955,8 +966,15 @@ def spatial_slice_variable(self, var, window: List[float]): """ x1_min, x1_max, x2_min, x2_max = window if isinstance(var, (xr.Dataset, xr.DataArray)): - x1_slice = slice(x1_min, x1_max) - x2_slice = slice(x2_min, x2_max) + # we cannot assume that the coordinates are sorted from small to large + if var.x1[0] > var.x1[-1]: + x1_slice = slice(x1_max, x1_min) + else: + x1_slice = slice(x1_min, x1_max) + if var.x2[0] > var.x2[-1]: + x2_slice = slice(x2_max, x2_min) + else: + x2_slice = slice(x2_min, x2_max) var = var.sel(x1=x1_slice, x2=x2_slice) elif isinstance(var, (pd.DataFrame, pd.Series)): # retrieve desired patch size @@ -974,6 +992,7 @@ def spatial_slice_variable(self, var, window: List[float]): def task_generation( self, date: pd.Timestamp, + patch_strategy: Optional[str], context_sampling: Union[ str, int, @@ -992,6 +1011,7 @@ def task_generation( ] = None, split_frac: float = 0.5, patch_size: Sequence[float] = None, + bbox: Sequence[float] = None, datewise_deterministic: bool = False, seed_override: Optional[int] = None, ) -> Task: @@ -1026,8 +1046,10 @@ def task_generation( The remaining observations are used for the target set. Default is 0.5. patch_size : Sequence[float], optional - Desired patch size in x1/x2 used for patchwise task generation. Usefule when considering + Desired patch size in x1/x2 used for patchwise task generation. Useful when considering the entire available region is computationally prohibitive for model forward pass + bbox : Sequence[float], optional + Bounding box in x1/x2 for patch. Only passed when using sliding window patching function. datewise_deterministic : bool Whether random sampling is datewise_deterministic based on the date. Default is ``False``. @@ -1086,6 +1108,13 @@ def check_sampling_strat(sampling_strat, set): raise InvalidSamplingStrategyError( f"Unknown sampling strategy {strat} of type {type(strat)}" ) + elif isinstance(strat, str) and strat == "gapfill": + assert all( + isinstance(item, (xr.Dataset, xr.DataArray)) for item in set + ), ( + "Gapfill sampling strategy can only be used with xarray " + "datasets or data arrays" + ) elif isinstance(strat, str) and strat not in [ "all", "split", @@ -1235,11 +1264,15 @@ def sample_variable(var, sampling_strat, seed): assert ( len(patch_size) == 2 ), "Patch size must be a Sequence of two values for x1/x2 extent." - assert all( - 0 < x <= 1 for x in patch_size + assert all( ## Will it confuse users to provide a patch size 0-1? Should we add method to convert patch size to 0-1? + 0 < x <= 1 for x in patch_size ), "Values specified for patch size must satisfy 0 < x <= 1." - - patch = self.sample_random_window(patch_size) + + #patch_strategy = kwargs.get("patch_strategy") + if patch_strategy == "random": + patch = self.sample_random_window(patch_size) + elif patch_strategy == "sliding": + patch = bbox # spatial slices context_slices = [ @@ -1248,6 +1281,8 @@ def sample_variable(var, sampling_strat, seed): target_slices = [ self.spatial_slice_variable(var, patch) for var in target_slices ] + ## Do we want patching before "gapfill" and "split" sampling plus adding + ## Auxilary data? # TODO move to method if ( @@ -1376,7 +1411,7 @@ def sample_variable(var, sampling_strat, seed): return Task(task) - def generate_tasks( + def nils( self, dates: Union[pd.Timestamp, Sequence[pd.Timestamp]], patch_strategy: Optional[str], @@ -1394,14 +1429,50 @@ def generate_tasks( **kwargs: Additional keyword arguments to pass to the task generation method. """ + if patch_strategy is None: tasks = [self.task_generation(date, **kwargs) for date in dates] + elif patch_strategy == "random": + assert ( + "patch_size" in kwargs + ), "Patch size must be specified for random patch sampling." # uniform random sampling of patch - pass + tasks: list[Task] = [] + num_samples_per_date = kwargs.get("num_samples_per_date", 1) + ## Run sample_random_window() here once? + new_kwargs = kwargs.copy() + new_kwargs.pop("num_samples_per_date", None) + for date in dates: + tasks.extend( + [ + self.task_generation(date, **new_kwargs) + for _ in range(num_samples_per_date)## Is it risky to run the entire task_generation call each time? + ## e.g. if using the "split" or "gapfill" strategy? + ## Should we run task_generation() once and then patch? + ] + ) + elif patch_strategy == "sliding": + assert ( + "patch_size" in kwargs + ), "Patch size must be specified for sliding window patch sampling." + # sliding window sampling of patch - pass + tasks: list[Task] = [] + + # Extract the x1/x2 length values of the patch defined by user. + patch_size = kwargs.get("patch_size") + # Extract stride size in x1/x2 or default to patch size. + stride = kwargs.get("stride", patch_size) + + patch_extents = self.sample_sliding_window(patch_size, stride) + + + for date in dates: + for bbox in patch_extents: + tasks.extend([self.task_generation(date, bbox, **kwargs)]) + else: raise ValueError( f"Invalid patch strategy {patch_strategy}. " @@ -1410,6 +1481,20 @@ def generate_tasks( return tasks + def check_tasks(self, tasks: List[Task]): + """ + Check tasks for consistency, such as target nans etc. + + Args: + tasks List[:class:`~.data.task.Task`]: + List of tasks to check. + + Returns: + List[:class:`~.data.task.Task`]: + updated list of tasks + """ + pass + def __call__( self, date: Union[pd.Timestamp, Sequence[pd.Timestamp]], @@ -1497,11 +1582,12 @@ def __call__( Task object or list of task objects for each date containing the context and target data. """ + assert patch_strategy in [None, "random", "sliding"], ( f"Invalid patch strategy {patch_strategy}. " f"Must be one of [None, 'random', 'sliding']." ) - if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)): + if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex, pd._libs.tslibs.timestamps.Timestamp)): return self.generate_tasks( dates=date, patch_strategy=patch_strategy, @@ -1511,13 +1597,17 @@ def __call__( patch_size=patch_size, datewise_deterministic=datewise_deterministic, seed_override=seed_override, - ) + ) else: return self.task_generation( - date, - context_sampling, - target_sampling, - split_frac, - datewise_deterministic, - seed_override, - ) + date=date, + context_sampling=context_sampling, + target_sampling=target_sampling, + split_frac=split_frac, + patch_strategy=patch_strategy, + patch_size=patch_size, + datewise_deterministic=datewise_deterministic, + seed_override=seed_override, + )## This set up currently doesn't work for sliding window because the function is not called when an individual date is supplied. + ## I also don't think it could patch using uniform function? + ## Currently I can only run when incluiding pd._libs.tslibs.timestamps.Timestamp \ No newline at end of file From fde7e0265ef3f26e445407d4912596a360c044d8 Mon Sep 17 00:00:00 2001 From: Martin Rogers Date: Fri, 19 Apr 2024 14:42:35 +0100 Subject: [PATCH 14/69] adding sliding window patching function --- deepsensor/data/loader.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index ae39a469..0f205241 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1264,7 +1264,7 @@ def sample_variable(var, sampling_strat, seed): assert ( len(patch_size) == 2 ), "Patch size must be a Sequence of two values for x1/x2 extent." - assert all( ## Will it confuse users to provide a patch size 0-1? Should we add method to convert patch size to 0-1? + assert all( ## Will it confuse users to provide a patch with size 0-1? Should we add method to convert patch size to 0-1? 0 < x <= 1 for x in patch_size ), "Values specified for patch size must satisfy 0 < x <= 1." @@ -1281,7 +1281,7 @@ def sample_variable(var, sampling_strat, seed): target_slices = [ self.spatial_slice_variable(var, patch) for var in target_slices ] - ## Do we want patching before "gapfill" and "split" sampling plus adding + ## Do we want to patch before "gapfill" and "split" sampling plus adding ## Auxilary data? # TODO move to method @@ -1411,7 +1411,7 @@ def sample_variable(var, sampling_strat, seed): return Task(task) - def nils( + def generate_tasks( self, dates: Union[pd.Timestamp, Sequence[pd.Timestamp]], patch_strategy: Optional[str], @@ -1446,8 +1446,8 @@ def nils( for date in dates: tasks.extend( [ - self.task_generation(date, **new_kwargs) - for _ in range(num_samples_per_date)## Is it risky to run the entire task_generation call each time? + self.task_generation(date, patch_strategy, **new_kwargs) + for _ in range(num_samples_per_date)## Could we produce different context/target sets if we call task_generation in a loop? ## e.g. if using the "split" or "gapfill" strategy? ## Should we run task_generation() once and then patch? ] @@ -1470,8 +1470,11 @@ def nils( for date in dates: - for bbox in patch_extents: - tasks.extend([self.task_generation(date, bbox, **kwargs)]) + tasks.extend( + [self.task_generation(date, patch_strategy, bbox, **kwargs) + for bbox in patch_extents + ] + ) else: raise ValueError( @@ -1587,7 +1590,7 @@ def __call__( f"Invalid patch strategy {patch_strategy}. " f"Must be one of [None, 'random', 'sliding']." ) - if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex, pd._libs.tslibs.timestamps.Timestamp)): + if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)): return self.generate_tasks( dates=date, patch_strategy=patch_strategy, @@ -1609,5 +1612,5 @@ def __call__( datewise_deterministic=datewise_deterministic, seed_override=seed_override, )## This set up currently doesn't work for sliding window because the function is not called when an individual date is supplied. - ## I also don't think it could patch using uniform function? - ## Currently I can only run when incluiding pd._libs.tslibs.timestamps.Timestamp \ No newline at end of file + ## I also don't think it could patch using uniform function because it can't run through for _ in range(num_samples_per_date)? + ## Currently I can only run when including pd._libs.tslibs.timestamps.Timestamp \ No newline at end of file From 195a9239cfbbd61a69ef021586bdcb1423201684 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Mon, 22 Apr 2024 08:42:00 +0000 Subject: [PATCH 15/69] loader with bboxes --- deepsensor/data/loader.py | 51 ++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 63627468..cf13c195 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -958,7 +958,7 @@ def task_generation( ] ] = None, split_frac: float = 0.5, - patch_size: Sequence[float] = None, + bbox: Sequence[float] = None, datewise_deterministic: bool = False, seed_override: Optional[int] = None, ) -> Task: @@ -992,8 +992,9 @@ def task_generation( "split" sampling strategy for linked context and target set pairs. The remaining observations are used for the target set. Default is 0.5. - patch_size : Sequence[float], optional - Desired patch size in x1/x2 used for patchwise task generation. Usefule when considering + bbox : Sequence[float], optional + A bounding box to sample the context and target data from. Specified as a list + of coordinates [x1_min, x1_max, x2_min, x2_max]. Useful when considering the entire available region is computationally prohibitive for model forward pass datewise_deterministic : bool Whether random sampling is datewise_deterministic based on the @@ -1205,21 +1206,14 @@ def sample_variable(var, sampling_strat, seed): ] # check patch size - if patch_size is not None: - assert ( - len(patch_size) == 2 - ), "Patch size must be a Sequence of two values for x1/x2 extent." - assert all( - 0 < x <= 1 for x in patch_size - ), "Values specified for patch size must satisfy 0 < x <= 1." - patch = self.sample_random_window(patch_size) - + if bbox is not None: + assert len(bbox) == 4, "Bounding box must be of length 4" # spatial slices context_slices = [ - self.spatial_slice_variable(var, patch) for var in context_slices + self.spatial_slice_variable(var, bbox) for var in context_slices ] target_slices = [ - self.spatial_slice_variable(var, patch) for var in target_slices + self.spatial_slice_variable(var, bbox) for var in target_slices ] # TODO move to method @@ -1353,6 +1347,7 @@ def generate_tasks( self, dates: Union[pd.Timestamp, Sequence[pd.Timestamp]], patch_strategy: Optional[str], + patch_size: Optional[Sequence[float]] = None, **kwargs, ) -> List[Task]: """ @@ -1364,31 +1359,44 @@ def generate_tasks( patch_strategy: Optional[str] Patch strategy to use for patchwise task generation. Default is None. Possible options are 'random' or 'sliding'. + patch_size: Optional[Sequence[float]] + Patch size for random patch sampling or sliding window sampling **kwargs: Additional keyword arguments to pass to the task generation method. """ if patch_strategy is None: tasks = [self.task_generation(date, **kwargs) for date in dates] elif patch_strategy == "random": - assert ( - "patch_size" in kwargs - ), "Patch size must be specified for random patch sampling." - # uniform random sampling of patch - tasks: list[Task] = [] + assert patch_size is not None, "Patch size must be specified for random patch sampling" + num_samples_per_date = kwargs.get("num_samples_per_date", 1) new_kwargs = kwargs.copy() new_kwargs.pop("num_samples_per_date", None) + tasks: list[Task] = [] for date in dates: + bboxes : list[float] = [] + for _ in range(num_samples_per_date): + bboxes.append(self.sample_random_window(patch_size)) tasks.extend( [ - self.task_generation(date, **new_kwargs) - for _ in range(num_samples_per_date) + self.task_generation(date, bbox=bbox, **new_kwargs) + for bbox in bboxes ] ) elif patch_strategy == "sliding": # sliding window sampling of patch + assert patch_size is not None, "Patch size must be specified for sliding window sampling" tasks: list[Task] = [] + + for date in dates: + bboxes = self.sliding_window_sampling(patch_size) + tasks.extend( + [ + self.task_generation(date, bbox=bbox, **kwargs) + for bbox in bboxes + ] + ) else: raise ValueError( f"Invalid patch strategy {patch_strategy}. " @@ -1519,7 +1527,6 @@ def __call__( context_sampling=context_sampling, target_sampling=target_sampling, split_frac=split_frac, - patch_size=patch_size, datewise_deterministic=datewise_deterministic, seed_override=seed_override, ) From 824df24e609e7dd74c7f76bdd63f3459925c2b95 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Mon, 22 Apr 2024 08:56:12 +0000 Subject: [PATCH 16/69] loader with boxes --- deepsensor/data/loader.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index cf13c195..33dd00f7 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -993,9 +993,8 @@ def task_generation( The remaining observations are used for the target set. Default is 0.5. bbox : Sequence[float], optional - A bounding box to sample the context and target data from. Specified as a list - of coordinates [x1_min, x1_max, x2_min, x2_max]. Useful when considering - the entire available region is computationally prohibitive for model forward pass + Bounding box to spatially slice the data, should be of the form [x1_min, x1_max, x2_min, x2_max]. + Usefule when considering the entire available region is computationally prohibitive for model forward pass datewise_deterministic : bool Whether random sampling is datewise_deterministic based on the date. Default is ``False``. @@ -1207,7 +1206,8 @@ def sample_variable(var, sampling_strat, seed): # check patch size if bbox is not None: - assert len(bbox) == 4, "Bounding box must be of length 4" + assert len(bbox) == 4, "bbox must be a list of length 4 with [x1_min, x1_max, x2_min, x2_max]" + # spatial slices context_slices = [ self.spatial_slice_variable(var, bbox) for var in context_slices From e6e1ae8aad0210e47fc2b66188fbef50dabe88d1 Mon Sep 17 00:00:00 2001 From: Martin Rogers Date: Mon, 22 Apr 2024 14:38:42 +0100 Subject: [PATCH 17/69] Altering kwargs to enable for-loop and change sliding function --- deepsensor/data/loader.py | 58 +++++++++++++++++++++++++++++++-------- 1 file changed, 47 insertions(+), 11 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 0f205241..0de9b9cf 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -890,9 +890,33 @@ def sample_sliding_window(self, patch_size: Tuple[float], stride: Tuple[float]) # Calculate the global bounds of context and target set. x1_min, x1_max, x2_min, x2_max = self.coord_bounds + print("all the key variables in sliding window", x1_min, x1_max, dy, x2_min, x2_max, dx) ## start with first patch top left hand corner at x1_min, x2_min n_patches = 0 patch_list = [] + + y = x1_min + while y < x1_max: + x = x2_min + while x < x2_max: + n_patches += 1 + if y + x1_extend > x1_max: + y0 = x1_max - x1_extend + else: + y0 = y + if x + x2_extend > x2_max: + x0 = x2_max - x2_extend + else: + x0 = x + + # bbox of x1_min, x1_max, x2_min, x2_max per patch + bbox = [y0, y0 + x1_extend, x0, x0 + x2_extend] + print('bbox', bbox) + patch_list.append(bbox) + x += dx # Increment x by dx + y += dy # Increment y by dy + + """ for y in range(x1_min, x1_max, dy): for x in range(x2_min, x2_max, dx): n_patches += 1 @@ -909,7 +933,7 @@ def sample_sliding_window(self, patch_size: Tuple[float], stride: Tuple[float]) bbox = [y0, y0 + x1_extend, x0, x0 + x2_extend] patch_list.append(bbox) - + """ ## I don't think we should actually print this here, but somehow we should ## provide this information back, so users know the number of patches per date. print("Number of patches per date using sliding window method", n_patches) @@ -1443,10 +1467,13 @@ def generate_tasks( ## Run sample_random_window() here once? new_kwargs = kwargs.copy() new_kwargs.pop("num_samples_per_date", None) + new_kwargs.pop('stride', None) + #context_sampling = new_kwargs.pop("context_sampling") + print('kwargs', new_kwargs) for date in dates: tasks.extend( [ - self.task_generation(date, patch_strategy, **new_kwargs) + self.task_generation(date, patch_strategy, **new_kwargs) for _ in range(num_samples_per_date)## Could we produce different context/target sets if we call task_generation in a loop? ## e.g. if using the "split" or "gapfill" strategy? ## Should we run task_generation() once and then patch? @@ -1457,24 +1484,28 @@ def generate_tasks( assert ( "patch_size" in kwargs ), "Patch size must be specified for sliding window patch sampling." - + # sliding window sampling of patch tasks: list[Task] = [] # Extract the x1/x2 length values of the patch defined by user. patch_size = kwargs.get("patch_size") - # Extract stride size in x1/x2 or default to patch size. - stride = kwargs.get("stride", patch_size) + # Extract stride size in x1/x2 or set to patch size if undefined. + stride = kwargs.pop("stride", None) + kwargs.pop("num_samples_per_date") + if stride is None: + stride = patch_size patch_extents = self.sample_sliding_window(patch_size, stride) - + #context_sampling = kwargs.pop("context_sampling") for date in dates: - tasks.extend( - [self.task_generation(date, patch_strategy, bbox, **kwargs) - for bbox in patch_extents - ] - ) + for bbox in patch_extents: + kwargs['bbox'] = bbox + tasks.extend( + [self.task_generation(date, patch_strategy, **kwargs) + ] + ) else: raise ValueError( @@ -1519,7 +1550,9 @@ def __call__( ] = None, split_frac: float = 0.5, patch_size: Sequence[float] = None, + stride: Sequence[float] = None, patch_strategy: Optional[str] = None, + num_samples_per_date: Optional[int] = 1, datewise_deterministic: bool = False, seed_override: Optional[int] = None, ) -> Union[Task, List[Task]]: @@ -1598,6 +1631,8 @@ def __call__( target_sampling=target_sampling, split_frac=split_frac, patch_size=patch_size, + stride=stride, + num_samples_per_date=num_samples_per_date, datewise_deterministic=datewise_deterministic, seed_override=seed_override, ) @@ -1609,6 +1644,7 @@ def __call__( split_frac=split_frac, patch_strategy=patch_strategy, patch_size=patch_size, + num_samples_per_date=num_samples_per_date, datewise_deterministic=datewise_deterministic, seed_override=seed_override, )## This set up currently doesn't work for sliding window because the function is not called when an individual date is supplied. From bae0855c674c9e8a3df529a5bad6182a22016f0a Mon Sep 17 00:00:00 2001 From: nilsleh Date: Mon, 22 Apr 2024 16:30:33 +0000 Subject: [PATCH 18/69] move logic to call --- deepsensor/data/loader.py | 237 ++++++++++++++++++++++++-------------- tests/test_task_loader.py | 46 +++++++- tests/test_training.py | 38 +++++- 3 files changed, 230 insertions(+), 91 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 33dd00f7..b0dda732 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1206,7 +1206,9 @@ def sample_variable(var, sampling_strat, seed): # check patch size if bbox is not None: - assert len(bbox) == 4, "bbox must be a list of length 4 with [x1_min, x1_max, x2_min, x2_max]" + assert ( + len(bbox) == 4 + ), "bbox must be a list of length 4 with [x1_min, x1_max, x2_min, x2_max]" # spatial slices context_slices = [ @@ -1343,81 +1345,56 @@ def sample_variable(var, sampling_strat, seed): return Task(task) - def generate_tasks( - self, - dates: Union[pd.Timestamp, Sequence[pd.Timestamp]], - patch_strategy: Optional[str], - patch_size: Optional[Sequence[float]] = None, - **kwargs, - ) -> List[Task]: + def sample_sliding_window( + self, patch_size: Tuple[float], stride: Tuple[int] + ) -> Sequence[float]: """ - Generate a list of Tasks for Training or Inference. + Sample data using sliding window from global coordinates to slice data. + Parameters + ---------- + patch_size : Tuple[float] + Tuple of window extent - Args: - dates: Union[pd.Timestamp, List[pd.Timestamp]] - List of dates for which to generate the task. - patch_strategy: Optional[str] - Patch strategy to use for patchwise task generation. Default is None. - Possible options are 'random' or 'sliding'. - patch_size: Optional[Sequence[float]] - Patch size for random patch sampling or sliding window sampling - **kwargs: - Additional keyword arguments to pass to the task generation method. + Stride : Tuple[float] + Tuple of step size between each patch along x1 and x2 axis. + Returns + ------- + bbox: List[float] ## check type of return. + sequence of patch spatial extent as [x1_min, x1_max, x2_min, x2_max] """ - if patch_strategy is None: - tasks = [self.task_generation(date, **kwargs) for date in dates] - elif patch_strategy == "random": - assert patch_size is not None, "Patch size must be specified for random patch sampling" - - num_samples_per_date = kwargs.get("num_samples_per_date", 1) - new_kwargs = kwargs.copy() - new_kwargs.pop("num_samples_per_date", None) - tasks: list[Task] = [] - for date in dates: - bboxes : list[float] = [] - for _ in range(num_samples_per_date): - bboxes.append(self.sample_random_window(patch_size)) - tasks.extend( - [ - self.task_generation(date, bbox=bbox, **new_kwargs) - for bbox in bboxes - ] - ) + # define patch size in x1/x2 + x1_extend, x2_extend = patch_size - elif patch_strategy == "sliding": - # sliding window sampling of patch - assert patch_size is not None, "Patch size must be specified for sliding window sampling" - tasks: list[Task] = [] + # define stride length in x1/x2 + dy, dx = stride[0] * x1_extend, stride[1] * x2_extend - for date in dates: - bboxes = self.sliding_window_sampling(patch_size) - tasks.extend( - [ - self.task_generation(date, bbox=bbox, **kwargs) - for bbox in bboxes - ] - ) - else: - raise ValueError( - f"Invalid patch strategy {patch_strategy}. " - f"Must be one of [None, 'random', 'sliding']." - ) + # Calculate the global bounds of context and target set. + x1_min, x1_max, x2_min, x2_max = self.coord_bounds - return tasks + ## start with first patch top left hand corner at x1_min, x2_min + patch_list = [] - def check_tasks(self, tasks: List[Task]): - """ - Check tasks for consistency, such as target nans etc. + for y in np.arange(x1_min, x1_max, dy): + for x in np.arange(x2_min, x2_max, dx): + if y + x1_extend > x1_max: + y0 = x1_max - x1_extend + else: + y0 = y + if x + x2_extend > x2_max: + x0 = x2_max - x2_extend + else: + x0 = x - Args: - tasks List[:class:`~.data.task.Task`]: - List of tasks to check. + # bbox of x1_min, x1_max, x2_min, x2_max per patch + bbox = [y0, y0 + x1_extend, x0, x0 + x2_extend] - Returns: - List[:class:`~.data.task.Task`]: - updated list of tasks - """ - pass + patch_list.append(bbox) + + ## I don't think we should actually print this here, but somehow we should + ## provide this information back, so users know the number of patches per date. + print("Number of patches per date using sliding window method", len(patch_list)) + + return patch_list def __call__( self, @@ -1441,6 +1418,8 @@ def __call__( split_frac: float = 0.5, patch_size: Sequence[float] = None, patch_strategy: Optional[str] = None, + stride: Optional[Sequence[int]] = None, + num_samples_per_date: int = 1, datewise_deterministic: bool = False, seed_override: Optional[int] = None, ) -> Union[Task, List[Task]]: @@ -1493,6 +1472,8 @@ def __call__( patch_strategy: Patch strategy to use for patchwise task generation. Default is None. Possible options are 'random' or 'sliding'. + stride: Sequence[int], optional + Step size between each sliding window patch along x1 and x2 axis. Default is None. datewise_deterministic (bool, optional): Whether random sampling is datewise deterministic based on the date. Default is ``False``. @@ -1510,23 +1491,109 @@ def __call__( f"Invalid patch strategy {patch_strategy}. " f"Must be one of [None, 'random', 'sliding']." ) - if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)): - return self.generate_tasks( - dates=date, - patch_strategy=patch_strategy, - context_sampling=context_sampling, - target_sampling=target_sampling, - split_frac=split_frac, - patch_size=patch_size, - datewise_deterministic=datewise_deterministic, - seed_override=seed_override, - ) + + if patch_strategy is None: + if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)): + tasks = [ + self.task_generation( + d, + context_sampling=context_sampling, + target_sampling=target_sampling, + split_frac=split_frac, + datewise_deterministic=datewise_deterministic, + seed_override=seed_override, + ) + for d in date + ] + else: + tasks = self.task_generation( + date=date, + context_sampling=context_sampling, + target_sampling=target_sampling, + split_frac=split_frac, + datewise_deterministic=datewise_deterministic, + seed_override=seed_override, + ) + + elif patch_strategy == "random": + assert ( + patch_size is not None + ), "Patch size must be specified for random patch sampling" + + if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)): + for d in date: + bboxes = [ + self.sample_random_window(patch_size) + for _ in range(num_samples_per_date) + ] + tasks = [ + self.task_generation( + d, + bbox=bbox, + context_sampling=context_sampling, + target_sampling=target_sampling, + split_frac=split_frac, + datewise_deterministic=datewise_deterministic, + seed_override=seed_override, + ) + for bbox in bboxes + ] + + else: + bbox = self.sample_random_window(patch_size) + tasks = self.task_generation( + date=date, + bbox=bbox, + context_sampling=context_sampling, + target_sampling=target_sampling, + split_frac=split_frac, + datewise_deterministic=datewise_deterministic, + seed_override=seed_override, + ) + + elif patch_strategy == "sliding": + # sliding window sampling of patch + assert ( + patch_size is not None + ), "Patch size must be specified for sliding window sampling" + + if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)): + tasks = [] + for d in date: + bboxes = self.sample_sliding_window(patch_size, stride) + tasks.extend( + [ + self.task_generation( + d, + bbox=bbox, + context_sampling=context_sampling, + target_sampling=target_sampling, + split_frac=split_frac, + datewise_deterministic=datewise_deterministic, + seed_override=seed_override, + ) + for bbox in bboxes + ] + ) + else: + bboxes = self.sample_sliding_window(patch_size, stride) + tasks = [ + self.task_generation( + date, + bbox=bbox, + context_sampling=context_sampling, + target_sampling=target_sampling, + split_frac=split_frac, + datewise_deterministic=datewise_deterministic, + seed_override=seed_override, + ) + for bbox in bboxes + ] + else: - return self.task_generation( - date=date, - context_sampling=context_sampling, - target_sampling=target_sampling, - split_frac=split_frac, - datewise_deterministic=datewise_deterministic, - seed_override=seed_override, + raise ValueError( + f"Invalid patch strategy {patch_strategy}. " + f"Must be one of [None, 'random', 'sliding']." ) + + return tasks diff --git a/tests/test_task_loader.py b/tests/test_task_loader.py index c1249887..d8a3d739 100644 --- a/tests/test_task_loader.py +++ b/tests/test_task_loader.py @@ -315,9 +315,9 @@ def test_patch_size(self, patch_size) -> None: patch_size=patch_size, patch_strategy="random", ) - assert len(tasks) == 2 + # test date range with num_samples per date - tasks = tl.generate_tasks( + tasks = tl( ["2020-01-01", "2020-01-02"], context_sampling="all", target_sampling="all", @@ -325,7 +325,47 @@ def test_patch_size(self, patch_size) -> None: patch_strategy="random", num_samples_per_date=2, ) - assert len(tasks) == 4 + + @parameterized.expand([[(0.2, 0.2), (1, 1)], [(0.3, 0.4), (1, 1)]]) + def test_sliding_window(self, patch_size, stride) -> None: + """Test sliding window sampling.""" + # need to redefine the data generators because the patch size samplin + # where we want to test that context and or target have different + # spatial extents + da_data_0_1 = self.da + + # smaller normalized coord + da_data_smaller = _gen_data_xr( + coords=dict( + time=pd.date_range("2020-01-01", "2020-01-31", freq="D"), + x1=np.linspace(0.1, 0.9, 25), + x2=np.linspace(0.1, 0.9, 10), + ) + ) + # larger normalized coord + da_data_larger = _gen_data_xr( + coords=dict( + time=pd.date_range("2020-01-01", "2020-01-31", freq="D"), + x1=np.linspace(-0.1, 1.1, 50), + x2=np.linspace(-0.1, 1.1, 50), + ) + ) + + context = [da_data_0_1, da_data_smaller, da_data_larger] + tl = TaskLoader( + context=context, # gridded xarray and off-grid pandas contexts + target=self.df, # off-grid pandas targets + ) + + # test date range + tasks = tl( + ["2020-01-01", "2020-01-02"], + "all", + "all", + patch_size=patch_size, + patch_strategy="sliding", + stride=stride, + ) def test_saving_and_loading(self): """Test saving and loading TaskLoader""" diff --git a/tests/test_training.py b/tests/test_training.py index 2c4328e1..2f456418 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -123,8 +123,8 @@ def test_patch_wise_training(self): model = ConvNP(self.data_processor, tl, unet_channels=(5, 5, 5), verbose=False) # generate training tasks - n_train_tasks = 10 - dates = [np.random.choice(self.da.time.values) for i in range(n_train_tasks)] + n_train_dates = 10 + dates = [np.random.choice(self.da.time.values) for i in range(n_train_dates)] train_tasks = tl.generate_tasks( dates, context_sampling="all", @@ -139,7 +139,39 @@ def test_patch_wise_training(self): batch_size = None # TODO check with batch_size > 1 # batch_size = 5 - n_epochs = 10 + n_epochs = 5 + epoch_losses = [] + for epoch in tqdm(range(n_epochs)): + batch_losses = trainer(train_tasks, batch_size=batch_size) + epoch_losses.append(np.mean(batch_losses)) + + # Check for NaNs in the loss + loss = np.mean(epoch_losses) + self.assertFalse(np.isnan(loss)) + + def test_sliding_window_training(self): + """ + Test model training with sliding window tasks. + """ + tl = TaskLoader(context=self.da, target=self.da) + model = ConvNP(self.data_processor, tl, unet_channels=(5, 5, 5), verbose=False) + + # generate training tasks + n_train_dates = 3 + dates = [np.random.choice(self.da.time.values) for i in range(n_train_dates)] + train_tasks = tl.generate_tasks( + dates, + context_sampling="all", + target_sampling="all", + patch_strategy="sliding", + patch_size=(0.5, 0.5), + stride=(1, 1), + ) + + # Train + trainer = Trainer(model, lr=5e-5) + batch_size = None + n_epochs = 2 epoch_losses = [] for epoch in tqdm(range(n_epochs)): batch_losses = trainer(train_tasks, batch_size=batch_size) From 7b09119ebc8433ea127e1690bdae577fc5431605 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Tue, 23 Apr 2024 08:15:17 +0000 Subject: [PATCH 19/69] typo --- deepsensor/data/loader.py | 114 ++++++-------------------------------- 1 file changed, 16 insertions(+), 98 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 3065136a..96cf4a0c 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -865,81 +865,6 @@ def sample_random_window(self, patch_size: Tuple[float]) -> Sequence[float]: return bbox - def sample_sliding_window(self, patch_size: Tuple[float], stride: Tuple[float]) -> Sequence[float]: - """ - Sample data using sliding window from global coordinates to slice data. - - Parameters - ---------- - patch_size : Tuple[float] - Tuple of window extent - - Stride : Tuple[float] - Tuple of step size between each patch along x1 and x2 axis. - Returns - ------- - bbox: List[float] ## check type of return. - sequence of patch spatial extent as [x1_min, x1_max, x2_min, x2_max] - """ - # define patch size in x1/x2 - x1_extend, x2_extend = patch_size - - # define stride length in x1/x2 - dy, dx = stride - - # Calculate the global bounds of context and target set. - x1_min, x1_max, x2_min, x2_max = self.coord_bounds - - print("all the key variables in sliding window", x1_min, x1_max, dy, x2_min, x2_max, dx) - ## start with first patch top left hand corner at x1_min, x2_min - n_patches = 0 - patch_list = [] - - y = x1_min - while y < x1_max: - x = x2_min - while x < x2_max: - n_patches += 1 - if y + x1_extend > x1_max: - y0 = x1_max - x1_extend - else: - y0 = y - if x + x2_extend > x2_max: - x0 = x2_max - x2_extend - else: - x0 = x - - # bbox of x1_min, x1_max, x2_min, x2_max per patch - bbox = [y0, y0 + x1_extend, x0, x0 + x2_extend] - print('bbox', bbox) - patch_list.append(bbox) - x += dx # Increment x by dx - y += dy # Increment y by dy - - """ - for y in range(x1_min, x1_max, dy): - for x in range(x2_min, x2_max, dx): - n_patches += 1 - if y + x1_extend > x1_max: - y0 = x1_max - x1_extend - else: - y0 = y - if x + x2_extend > x2_max: - x0 = x2_max - x2_extend - else: - x0 = x - - # bbox of x1_min, x1_max, x2_min, x2_max per patch - bbox = [y0, y0 + x1_extend, x0, x0 + x2_extend] - - patch_list.append(bbox) - """ - ## I don't think we should actually print this here, but somehow we should - ## provide this information back, so users know the number of patches per date. - print("Number of patches per date using sliding window method", n_patches) - - return patch_list - def time_slice_variable(self, var, date, delta_t=0): """ Slice a variable by a given time delta. @@ -973,20 +898,20 @@ def time_slice_variable(self, var, date, delta_t=0): def spatial_slice_variable(self, var, window: List[float]): """ Slice a variable by a given window size. - Parameters - ---------- - var : ... - Variable to slice - window : ... - list of coordinates specifying the window [x1_min, x1_max, x2_min, x2_max] - Returns - ------- - var : ... - Sliced variable. - Raises - ------ - ValueError - If the variable is of an unknown type. + + Args: + var (...): + Variable to slice. + window (List[float]): + List of coordinates specifying the window [x1_min, x1_max, x2_min, x2_max]. + + Returns: + var (...) + Sliced variable. + + Raises: + ValueError + If the variable is of an unknown type. """ x1_min, x1_max, x2_min, x2_max = window if isinstance(var, (xr.Dataset, xr.DataArray)): @@ -1016,7 +941,6 @@ def spatial_slice_variable(self, var, window: List[float]): def task_generation( self, date: pd.Timestamp, - patch_strategy: Optional[str], context_sampling: Union[ str, int, @@ -1070,8 +994,7 @@ def task_generation( 0.5. bbox : Sequence[float], optional Bounding box to spatially slice the data, should be of the form [x1_min, x1_max, x2_min, x2_max]. - Useful when considering the entire available region is computationally prohibitive for model forward pass - and one resorts to patching strategies + Useful when considering the entire available region is computationally prohibitive for model forward pass. datewise_deterministic : bool Whether random sampling is datewise_deterministic based on the date. Default is ``False``. @@ -1281,7 +1204,7 @@ def sample_variable(var, sampling_strat, seed): for var, delta_t in zip(self.target, self.target_delta_t) ] - # check bbox + # check bbox size if bbox is not None: assert ( len(bbox) == 4 @@ -1294,8 +1217,6 @@ def sample_variable(var, sampling_strat, seed): target_slices = [ self.spatial_slice_variable(var, bbox) for var in target_slices ] - ## Do we want to patch before "gapfill" and "split" sampling plus adding - ## Auxilary data? # TODO move to method if ( @@ -1496,7 +1417,6 @@ def __call__( ] = None, split_frac: float = 0.5, patch_size: Sequence[float] = None, - stride: Sequence[float] = None, patch_strategy: Optional[str] = None, stride: Optional[Sequence[int]] = None, num_samples_per_date: int = 1, @@ -1567,7 +1487,6 @@ def __call__( Task object or list of task objects for each date containing the context and target data. """ - assert patch_strategy in [None, "random", "sliding"], ( f"Invalid patch strategy {patch_strategy}. " f"Must be one of [None, 'random', 'sliding']." @@ -1670,7 +1589,6 @@ def __call__( ) for bbox in bboxes ] - else: raise ValueError( f"Invalid patch strategy {patch_strategy}. " From 282c2bea234fc44b0bd35da92489dc1ac83edc33 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Wed, 24 Apr 2024 07:47:42 +0000 Subject: [PATCH 20/69] notebook with patchwise train --- deepsensor/data/loader.py | 27 +- docs/user-guide/patchwise_training.ipynb | 584 +++++++++++++++++++++++ 2 files changed, 601 insertions(+), 10 deletions(-) create mode 100644 docs/user-guide/patchwise_training.ipynb diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 96cf4a0c..4d159ce9 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1516,6 +1516,7 @@ def __call__( ) elif patch_strategy == "random": + assert ( patch_size is not None ), "Patch size must be specified for random patch sampling" @@ -1540,16 +1541,22 @@ def __call__( ] else: - bbox = self.sample_random_window(patch_size) - tasks = self.task_generation( - date=date, - bbox=bbox, - context_sampling=context_sampling, - target_sampling=target_sampling, - split_frac=split_frac, - datewise_deterministic=datewise_deterministic, - seed_override=seed_override, - ) + bboxes = [ + self.sample_random_window(patch_size) + for _ in range(num_samples_per_date) + ] + tasks = [ + self.task_generation( + date, + bbox=bbox, + context_sampling=context_sampling, + target_sampling=target_sampling, + split_frac=split_frac, + datewise_deterministic=datewise_deterministic, + seed_override=seed_override, + ) + for bbox in bboxes + ] elif patch_strategy == "sliding": # sliding window sampling of patch diff --git a/docs/user-guide/patchwise_training.ipynb b/docs/user-guide/patchwise_training.ipynb new file mode 100644 index 00000000..6e45a5d3 --- /dev/null +++ b/docs/user-guide/patchwise_training.ipynb @@ -0,0 +1,584 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Patchwise Training\n", + "\n", + "Environmental data can sometimes span large spatial areas. For example:\n", + "\n", + "- Modelling tasks based on data that span the entire globe\n", + "- Modelling tasks with high-resolution data\n", + "\n", + "In such cases, training and inference with a ConvNP over the entire region of data may be computationally prohibitive. However, we can resort to patchwise training, where the `TaskLoader` does not provide data of the entire region but instead creates smaller patches that are computationally feasible.\n", + "\n", + "The goal of the notebook is to demonstrate patchwise training and inference." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "\n", + "logging.captureWarnings(True)\n", + "\n", + "import deepsensor.torch\n", + "from deepsensor.model import ConvNP\n", + "from deepsensor.train import Trainer, set_gpu_default_device\n", + "from deepsensor.data import DataProcessor, TaskLoader, construct_circ_time_ds\n", + "from deepsensor.data.sources import (\n", + " get_era5_reanalysis_data,\n", + " get_earthenv_auxiliary_data,\n", + " get_gldas_land_mask,\n", + ")\n", + "\n", + "import xarray as xr\n", + "import cartopy.crs as ccrs\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import numpy as np\n", + "from tqdm import tqdm_notebook" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Training/data config\n", + "data_range = (\"2010-01-01\", \"2019-12-31\")\n", + "train_range = (\"2010-01-01\", \"2018-12-31\")\n", + "val_range = (\"2019-01-01\", \"2019-12-31\")\n", + "date_subsample_factor = 2\n", + "extent = \"north_america\"\n", + "era5_var_IDs = [\"2m_temperature\"]\n", + "lowres_auxiliary_var_IDs = [\"elevation\"]\n", + "cache_dir = \"../../.datacache\"\n", + "deepsensor_folder = \"../deepsensor_config/\"\n", + "verbose_download = True" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading ERA5 data from Google Cloud Storage... " + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 120/120 [00:02<00:00, 50.27it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.41 GB loaded in 2.78 s\n" + ] + } + ], + "source": [ + "era5_raw_ds = get_era5_reanalysis_data(\n", + " era5_var_IDs,\n", + " extent,\n", + " date_range=data_range,\n", + " cache=True,\n", + " cache_dir=cache_dir,\n", + " verbose=verbose_download,\n", + " num_processes=8,\n", + ")\n", + "lowres_aux_raw_ds = get_earthenv_auxiliary_data(\n", + " lowres_auxiliary_var_IDs,\n", + " extent,\n", + " \"100KM\",\n", + " cache=True,\n", + " cache_dir=cache_dir,\n", + " verbose=verbose_download,\n", + ")\n", + "land_mask_raw_ds = get_gldas_land_mask(\n", + " extent, cache=True, cache_dir=cache_dir, verbose=verbose_download\n", + ")\n", + "\n", + "data_processor = DataProcessor(x1_name=\"lat\", x2_name=\"lon\")\n", + "era5_ds = data_processor(era5_raw_ds)\n", + "lowres_aux_ds, land_mask_ds = data_processor(\n", + " [lowres_aux_raw_ds, land_mask_raw_ds], method=\"min_max\"\n", + ")\n", + "\n", + "dates = pd.date_range(era5_ds.time.values.min(), era5_ds.time.values.max(), freq=\"D\")\n", + "doy_ds = construct_circ_time_ds(dates, freq=\"D\")\n", + "lowres_aux_ds[\"cos_D\"] = doy_ds[\"cos_D\"]\n", + "lowres_aux_ds[\"sin_D\"] = doy_ds[\"sin_D\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "set_gpu_default_device()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialise TaskLoader and ConvNP model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TaskLoader(3 context sets, 1 target sets)\n", + "Context variable IDs: (('2m_temperature',), ('GLDAS_mask',), ('elevation', 'cos_D', 'sin_D'))\n", + "Target variable IDs: (('2m_temperature',),)\n" + ] + } + ], + "source": [ + "task_loader = TaskLoader(\n", + " context=[era5_ds, land_mask_ds, lowres_aux_ds],\n", + " target=era5_ds,\n", + ")\n", + "task_loader.load_dask()\n", + "print(task_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dim_yc inferred from TaskLoader: (1, 1, 3)\n", + "dim_yt inferred from TaskLoader: 1\n", + "dim_aux_t inferred from TaskLoader: 0\n", + "internal_density inferred from TaskLoader: 400\n", + "encoder_scales inferred from TaskLoader: [0.0012499999720603228, 0.0012499999720603228, 0.00416666641831398]\n", + "decoder_scale inferred from TaskLoader: 0.0025\n" + ] + } + ], + "source": [ + "# Set up model\n", + "model = ConvNP(data_processor, task_loader, unet_channels=(32, 32, 32, 32, 32))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define how Tasks are generated\n", + "\n", + "For the purpose of this notebook, we will use a random patchwise training strategy for our training tasks and a sliding window patch strategy for validation and testing to make sure we cover the entire region of interest." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "def gen_training_tasks(dates, progress=True):\n", + " tasks = []\n", + " for date in tqdm_notebook(dates, disable=not progress):\n", + " tasks_per_date = task_loader(\n", + " date,\n", + " context_sampling=[\"all\", \"all\", \"all\"],\n", + " target_sampling=\"all\",\n", + " patch_strategy=\"random\",\n", + " patch_size=(0.4, 0.4),\n", + " num_samples_per_date=2,\n", + " )\n", + " tasks.extend(tasks_per_date)\n", + " return tasks\n", + "\n", + "\n", + "def gen_validation_tasks(dates, progress=True):\n", + " tasks = []\n", + " for date in tqdm_notebook(dates, disable=not progress):\n", + " tasks_per_date = task_loader(\n", + " date,\n", + " context_sampling=[\"all\", \"all\", \"all\"],\n", + " target_sampling=\"all\",\n", + " patch_strategy=\"sliding\",\n", + " patch_size=(0.5, 0.5),\n", + " stride=(1,1)\n", + " )\n", + " tasks.extend(tasks_per_date)\n", + " return tasks" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generate validation tasks for testing generalisation" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "60a4044f573a45578ae505a11d3a7bc6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/183 [00:00 10\u001b[0m batch_losses \u001b[38;5;241m=\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_tasks\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 11\u001b[0m losses\u001b[38;5;241m.\u001b[39mappend(np\u001b[38;5;241m.\u001b[39mmean(batch_losses))\n\u001b[1;32m 12\u001b[0m val_rmses\u001b[38;5;241m.\u001b[39mappend(compute_val_rmse(model, val_tasks))\n", + "File \u001b[0;32m/mnt/SSD2/nils/deepsensor/deepsensor/train/train.py:177\u001b[0m, in \u001b[0;36mTrainer.__call__\u001b[0;34m(self, tasks, batch_size, progress_bar, tqdm_notebook)\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\n\u001b[1;32m 171\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 172\u001b[0m tasks: List[Task],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 175\u001b[0m tqdm_notebook\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 176\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m List[\u001b[38;5;28mfloat\u001b[39m]:\n\u001b[0;32m--> 177\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrain_epoch\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 178\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 179\u001b[0m \u001b[43m \u001b[49m\u001b[43mtasks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtasks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 180\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 181\u001b[0m \u001b[43m \u001b[49m\u001b[43mopt\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopt\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 182\u001b[0m \u001b[43m \u001b[49m\u001b[43mprogress_bar\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprogress_bar\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 183\u001b[0m \u001b[43m \u001b[49m\u001b[43mtqdm_notebook\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtqdm_notebook\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 184\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/mnt/SSD2/nils/deepsensor/deepsensor/train/train.py:145\u001b[0m, in \u001b[0;36mtrain_epoch\u001b[0;34m(model, tasks, lr, batch_size, opt, progress_bar, tqdm_notebook)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 144\u001b[0m task \u001b[38;5;241m=\u001b[39m tasks[batch_i]\n\u001b[0;32m--> 145\u001b[0m batch_loss \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtask\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 146\u001b[0m batch_losses\u001b[38;5;241m.\u001b[39mappend(batch_loss)\n\u001b[1;32m 148\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m batch_losses\n", + "File \u001b[0;32m/mnt/SSD2/nils/deepsensor/deepsensor/train/train.py:116\u001b[0m, in \u001b[0;36mtrain_epoch..train_step\u001b[0;34m(tasks)\u001b[0m\n\u001b[1;32m 114\u001b[0m task_losses \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m task \u001b[38;5;129;01min\u001b[39;00m tasks:\n\u001b[0;32m--> 116\u001b[0m task_losses\u001b[38;5;241m.\u001b[39mappend(\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloss_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnormalise\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m)\n\u001b[1;32m 117\u001b[0m mean_batch_loss \u001b[38;5;241m=\u001b[39m B\u001b[38;5;241m.\u001b[39mmean(B\u001b[38;5;241m.\u001b[39mstack(\u001b[38;5;241m*\u001b[39mtask_losses))\n\u001b[1;32m 118\u001b[0m mean_batch_loss\u001b[38;5;241m.\u001b[39mbackward()\n", + "File \u001b[0;32m/mnt/SSD2/nils/deepsensor/deepsensor/model/convnp.py:865\u001b[0m, in \u001b[0;36mConvNP.loss_fn\u001b[0;34m(self, task, fix_noise, num_lv_samples, normalise)\u001b[0m\n\u001b[1;32m 839\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mloss_fn\u001b[39m(\n\u001b[1;32m 840\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 841\u001b[0m task: Task,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 844\u001b[0m normalise: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 845\u001b[0m ):\n\u001b[1;32m 846\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 847\u001b[0m \u001b[38;5;124;03m Compute the loss of a task.\u001b[39;00m\n\u001b[1;32m 848\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 863\u001b[0m \u001b[38;5;124;03m float: The loss.\u001b[39;00m\n\u001b[1;32m 864\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 865\u001b[0m task \u001b[38;5;241m=\u001b[39m \u001b[43mConvNP\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodify_task\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtask\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 867\u001b[0m context_data, xt, yt, model_kwargs \u001b[38;5;241m=\u001b[39m convert_task_to_nps_args(task)\n\u001b[1;32m 869\u001b[0m logpdfs \u001b[38;5;241m=\u001b[39m backend\u001b[38;5;241m.\u001b[39mnps\u001b[38;5;241m.\u001b[39mloglik(\n\u001b[1;32m 870\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel,\n\u001b[1;32m 871\u001b[0m context_data,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 877\u001b[0m normalise\u001b[38;5;241m=\u001b[39mnormalise,\n\u001b[1;32m 878\u001b[0m )\n", + "File \u001b[0;32m/mnt/SSD2/nils/deepsensor/deepsensor/model/convnp.py:379\u001b[0m, in \u001b[0;36mConvNP.modify_task\u001b[0;34m(cls, task)\u001b[0m\n\u001b[1;32m 365\u001b[0m \u001b[38;5;129m@classmethod\u001b[39m\n\u001b[1;32m 366\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmodify_task\u001b[39m(\u001b[38;5;28mcls\u001b[39m, task: Task):\n\u001b[1;32m 367\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 368\u001b[0m \u001b[38;5;124;03m Cast numpy arrays to TensorFlow or PyTorch tensors, add batch dim, and\u001b[39;00m\n\u001b[1;32m 369\u001b[0m \u001b[38;5;124;03m mask NaNs.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 376\u001b[0m \u001b[38;5;124;03m ...: ...\u001b[39;00m\n\u001b[1;32m 377\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 379\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch_dim\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m \u001b[43mtask\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mops\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m:\n\u001b[1;32m 380\u001b[0m task \u001b[38;5;241m=\u001b[39m task\u001b[38;5;241m.\u001b[39madd_batch_dim()\n\u001b[1;32m 381\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m task[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mops\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n", + "\u001b[0;31mTypeError\u001b[0m: string indices must be integers" + ] + } + ], + "source": [ + "num_epochs = 10\n", + "losses = []\n", + "val_rmses = []\n", + "\n", + "# Train model\n", + "val_rmse_best = np.inf\n", + "trainer = Trainer(model, lr=5e-5)\n", + "for epoch in tqdm_notebook(range(num_epochs)):\n", + " train_tasks = gen_training_tasks(pd.date_range(train_range[0], train_range[1])[::date_subsample_factor], progress=True)\n", + " batch_losses = trainer(train_tasks)\n", + " losses.append(np.mean(batch_losses))\n", + " val_rmses.append(compute_val_rmse(model, val_tasks))\n", + " if val_rmses[-1] < val_rmse_best:\n", + " val_rmse_best = val_rmses[-1]\n", + " model.save(deepsensor_folder)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n", + "axes[0].plot(losses)\n", + "axes[1].plot(val_rmses)\n", + "_ = axes[0].set_xlabel(\"Epoch\")\n", + "_ = axes[1].set_xlabel(\"Epoch\")\n", + "_ = axes[0].set_title(\"Training loss\")\n", + "_ = axes[1].set_title(\"Validation RMSE\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sensorEnv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From dfa386d12349662612f54b78a14d95a79d559824 Mon Sep 17 00:00:00 2001 From: Martin Rogers Date: Wed, 24 Apr 2024 17:10:19 +0100 Subject: [PATCH 21/69] refining stride to avoid error --- deepsensor/data/loader.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 4d159ce9..540f0a2a 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1365,8 +1365,11 @@ def sample_sliding_window( # define patch size in x1/x2 x1_extend, x2_extend = patch_size - # define stride length in x1/x2 - dy, dx = stride[0] * x1_extend, stride[1] * x2_extend + # define stride length in x1/x2 or set to patch_size if undefined + if stride is None: + stride = patch_size + + dy, dx = stride # Calculate the global bounds of context and target set. x1_min, x1_max, x2_min, x2_max = self.coord_bounds @@ -1390,10 +1393,6 @@ def sample_sliding_window( patch_list.append(bbox) - ## I don't think we should actually print this here, but somehow we should - ## provide this information back, so users know the number of patches per date. - print("Number of patches per date using sliding window method", len(patch_list)) - return patch_list def __call__( From 8d466539c6b338ded20f219c804f8075e715be68 Mon Sep 17 00:00:00 2001 From: Martin Rogers Date: Sat, 27 Apr 2024 21:10:25 +0100 Subject: [PATCH 22/69] inference patching --- deepsensor/data/loader.py | 1 + deepsensor/model/model.py | 41 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 540f0a2a..be723e21 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1185,6 +1185,7 @@ def sample_variable(var, sampling_strat, seed): task["time"] = date task["ops"] = [] + task["bbox"] = bbox task["X_c"] = [] task["Y_c"] = [] if target_sampling is not None: diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index d32d0f07..6d726416 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -621,6 +621,47 @@ def unnormalise_pred_array(arr, **kwargs): return pred + def predict_patch( + self, + tasks: Union[List[Task], Task], + X_t: Union[ + xr.Dataset, + xr.DataArray, + pd.DataFrame, + pd.Series, + pd.Index, + np.ndarray, + ],)-> Prediction: + + """ + Predict patches and subsequently stiching patches to produce prediction at original extent. + Predict on a regular grid or at off-grid locations. + + Args: + tasks (List[Task] | Task): + List of tasks containing context data. + X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`): + Target locations to predict at. Can be an xarray object + containingon-grid locations or a pandas object containing off-grid locations. + Returns: + :class:`~.model.pred.Prediction`): + A `dict`-like object mapping from target variable IDs to xarray or pandas objects + containing model predictions. + - If ``X_t`` is a pandas object, returns pandas objects + containing off-grid predictions. + - If ``X_t`` is an xarray object, returns xarray object + containing on-grid predictions. + - If ``n_samples`` == 0, returns only mean and std predictions. + - If ``n_samples`` > 0, returns mean, std and samples + predictions. + """ + + # Identify extent of original dataframe + for task in tasks: + pred = predict(task, X_t) + + return pred + def main(): # pragma: no cover import deepsensor.tensorflow from deepsensor.data.loader import TaskLoader From acbad8b2842a3439c392c7ffb475a7e24a553156 Mon Sep 17 00:00:00 2001 From: Martin Rogers Date: Sun, 28 Apr 2024 20:49:12 +0100 Subject: [PATCH 23/69] predict_patches --- deepsensor/model/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 6d726416..21d17b54 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -658,7 +658,7 @@ def predict_patch( # Identify extent of original dataframe for task in tasks: - pred = predict(task, X_t) + pred = self.predict(task, X_t) return pred From 3e2994e2f171e6300ae67e57bc5622866f975d30 Mon Sep 17 00:00:00 2001 From: Martin Rogers Date: Fri, 3 May 2024 17:13:15 +0100 Subject: [PATCH 24/69] patchwise predictions during inference and stitching --- deepsensor/model/model.py | 139 ++++++++++++++++++++++++++++++++++---- 1 file changed, 124 insertions(+), 15 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 21d17b54..7e096ae5 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -621,20 +621,39 @@ def unnormalise_pred_array(arr, **kwargs): return pred - def predict_patch( - self, - tasks: Union[List[Task], Task], - X_t: Union[ - xr.Dataset, - xr.DataArray, - pd.DataFrame, - pd.Series, - pd.Index, - np.ndarray, - ],)-> Prediction: - + def predict_patch( + self, + tasks: Union[List[Task], Task], + X_t: Union[ + xr.Dataset, + xr.DataArray, + pd.DataFrame, + pd.Series, + pd.Index, + np.ndarray, + ], + X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None, + X_t_is_normalised: bool = False, + aux_at_targets_override: Union[xr.Dataset, xr.DataArray] = None, + aux_at_targets_override_is_normalised: bool = False, + resolution_factor: int = 1, + pred_params: tuple[str] = ("mean", "std"), + n_samples: int = 0, + ar_sample: bool = False, + ar_subsample_factor: int = 1, + unnormalise: bool = False, + seed: int = 0, + append_indexes: dict = None, + progress_bar: int = 0, + verbose: bool = False, + data_processor: Union[ + xr.DataArray, + xr.Dataset, + pd.DataFrame, + List[Union[xr.DataArray, xr.Dataset, pd.DataFrame]], + ] = None, + ) -> Prediction: """ - Predict patches and subsequently stiching patches to produce prediction at original extent. Predict on a regular grid or at off-grid locations. Args: @@ -643,6 +662,45 @@ def predict_patch( X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`): Target locations to predict at. Can be an xarray object containingon-grid locations or a pandas object containing off-grid locations. + X_t_mask: :class:`xarray.Dataset` | :class:`xarray.DataArray`, optional + 2D mask to apply to gridded ``X_t`` (zero/False will be NaNs). Will be interpolated + to the same grid as ``X_t``. Default None (no mask). + X_t_is_normalised (bool): + Whether the ``X_t`` coords are normalised. If False, will normalise + the coords before passing to model. Default ``False``. + aux_at_targets_override (:class:`xarray.Dataset` | :class:`xarray.DataArray`): + Optional auxiliary xarray data to override from the task_loader. + aux_at_targets_override_is_normalised (bool): + Whether the `aux_at_targets_override` coords are normalised. + If False, the DataProcessor will normalise the coords before passing to model. + Default False. + pred_params (tuple[str]): + Tuple of prediction parameters to return. The strings refer to methods + of the model class which will be called and stored in the Prediction object. + Default ("mean", "std"). + resolution_factor (float): + Optional factor to increase the resolution of the target grid + by. E.g. 2 will double the target resolution, 0.5 will halve + it.Applies to on-grid predictions only. Default 1. + n_samples (int): + Number of joint samples to draw from the model. If 0, will not + draw samples. Default 0. + ar_sample (bool): + Whether to use autoregressive sampling. Default ``False``. + unnormalise (bool): + Whether to unnormalise the predictions. Only works if ``self`` + hasa ``data_processor`` and ``task_loader`` attribute. Default + ``True``. + seed (int): + Random seed for deterministic sampling. Default 0. + append_indexes (dict): + Dictionary of index metadata to append to pandas indexes in the + off-grid case. Default ``None``. + progress_bar (int): + Whether to display a progress bar over tasks. Default 0. + verbose (bool): + Whether to print time taken for prediction. Default ``False``. + Returns: :class:`~.model.pred.Prediction`): A `dict`-like object mapping from target variable IDs to xarray or pandas objects @@ -654,13 +712,64 @@ def predict_patch( - If ``n_samples`` == 0, returns only mean and std predictions. - If ``n_samples`` > 0, returns mean, std and samples predictions. + + Raises: + ValueError + If ``X_t`` is not an xarray object and + ``resolution_factor`` is not 1 or ``ar_subsample_factor`` is + not 1. + ValueError + If ``X_t`` is not a pandas object and ``append_indexes`` is not + ``None``. + ValueError + If ``X_t`` is not an xarray, pandas or numpy object. + ValueError + If ``append_indexes`` are not all the same length as ``X_t``. """ # Identify extent of original dataframe + preds = [] for task in tasks: - pred = self.predict(task, X_t) + bbox = task['bbox'] - return pred + # Determine X_t for the patched task in original coordinates. + x1 = xr.DataArray([bbox[0], bbox[1]], dims='x1', name='x1') + x2 = xr.DataArray([bbox[2], bbox[3]], dims='x2', name='x2') + bbox_norm = xr.Dataset(coords={'x1': x1, 'x2': x2}) + + bbox_unnorm = data_processor.unnormalise(bbox_norm) + unnorm_bbox_x1 = bbox_unnorm['x'].values.min(), bbox_unnorm['x'].values.max() + unnorm_bbox_x2 = bbox_unnorm['y'].values.min(), bbox_unnorm['y'].values.max() + + task_X_t = X_t.sel(x = slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1]), + y = slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1])) + + pred = self.predict(task, task_X_t) + preds.append(pred) + + pred_copy = copy.deepcopy(preds[0]) + + for var_name_copy, data_array_copy in pred_copy.items(): + + # set x and y coords + stitched_preds = xr.Dataset(coords={'x': X_t['x'], 'y': X_t['y']}) + + # Set time to same as patched prediction + stitched_preds['time'] = data_array_copy['time'] + + # set variable names to those in patched prediction, make values blank + for var_name_i in data_array_copy.data_vars: + stitched_preds[var_name_i] = data_array_copy[var_name_i] + stitched_preds.attrs.clear() + pred_copy[var_name_copy]= stitched_preds + + for pred in preds: + for var_name, data_array in pred.items(): + if var_name in pred_copy: + unnorm_patch_x1 = data_array['x'].min().values, data_array['x'].max().values + unnorm_patch_x2 = data_array['y'].min().values, data_array['y'].max().values + pred_copy[var_name].loc[{'x': slice(unnorm_patch_x1[0], unnorm_patch_x1[1]), 'y': slice(unnorm_patch_x2[0], unnorm_patch_x2[1])}] = data_array + return preds def main(): # pragma: no cover import deepsensor.tensorflow From 765849dcc225035327271d25b1b27da16f38f1e2 Mon Sep 17 00:00:00 2001 From: Martin Rogers Date: Fri, 3 May 2024 17:27:59 +0100 Subject: [PATCH 25/69] fix typo --- deepsensor/model/model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 7e096ae5..020ba658 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -744,9 +744,11 @@ def predict_patch( task_X_t = X_t.sel(x = slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1]), y = slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1])) + # Patchwise prediction pred = self.predict(task, task_X_t) preds.append(pred) + # Produce a blank xarray to stitch patched predictions to. pred_copy = copy.deepcopy(preds[0]) for var_name_copy, data_array_copy in pred_copy.items(): @@ -763,13 +765,15 @@ def predict_patch( stitched_preds.attrs.clear() pred_copy[var_name_copy]= stitched_preds + # Stitch patchwise predictions for pred in preds: for var_name, data_array in pred.items(): if var_name in pred_copy: unnorm_patch_x1 = data_array['x'].min().values, data_array['x'].max().values unnorm_patch_x2 = data_array['y'].min().values, data_array['y'].max().values - pred_copy[var_name].loc[{'x': slice(unnorm_patch_x1[0], unnorm_patch_x1[1]), 'y': slice(unnorm_patch_x2[0], unnorm_patch_x2[1])}] = data_array - return preds + pred_copy[var_name].loc[{'x': slice(unnorm_patch_x1[0], unnorm_patch_x1[1]), + 'y': slice(unnorm_patch_x2[0], unnorm_patch_x2[1])}] = data_array + return pred_copy def main(): # pragma: no cover import deepsensor.tensorflow From 7f8ef93d71dabdb0cc5316f7054298df79eb1f32 Mon Sep 17 00:00:00 2001 From: Martin Rogers Date: Mon, 24 Jun 2024 16:23:23 +0100 Subject: [PATCH 26/69] new cropped stitching --- deepsensor/model/model.py | 185 +++++++++++++++++++++++++++++++++----- 1 file changed, 164 insertions(+), 21 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 020ba658..5da78986 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -632,6 +632,12 @@ def predict_patch( pd.Index, np.ndarray, ], + data_processor: Union[ + xr.DataArray, + xr.Dataset, + pd.DataFrame, + List[Union[xr.DataArray, xr.Dataset, pd.DataFrame]], + ], X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None, X_t_is_normalised: bool = False, aux_at_targets_override: Union[xr.Dataset, xr.DataArray] = None, @@ -646,12 +652,7 @@ def predict_patch( append_indexes: dict = None, progress_bar: int = 0, verbose: bool = False, - data_processor: Union[ - xr.DataArray, - xr.Dataset, - pd.DataFrame, - List[Union[xr.DataArray, xr.Dataset, pd.DataFrame]], - ] = None, + ) -> Prediction: """ Predict on a regular grid or at off-grid locations. @@ -659,6 +660,8 @@ def predict_patch( Args: tasks (List[Task] | Task): List of tasks containing context data. + data_processor (:class:`~.data.processor.DataProcessor`): + Used for unnormalising the coordinates of the bounding boxes of patches. X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`): Target locations to predict at. Can be an xarray object containingon-grid locations or a pandas object containing off-grid locations. @@ -726,6 +729,143 @@ def predict_patch( ValueError If ``append_indexes`` are not all the same length as ``X_t``. """ + ## To do, do we need to add patch and stride as an additional argument? + def get_patches_per_row(preds, X_t): + """ + Calculate number of patches per row. + Required to stitch patches back together. + Args: + preds (List[class:`~.model.pred.Prediction`]): + A list of `dict`-like objects containing patchwise predictions. + + Returns: + patches_per_row (int) + Number of patches per row. + """ + patches_per_row = 0 + vars = list(preds[0][0].data_vars) + var = vars[0] + + for p in preds: + if p[0][var].coords['y'].min() == X_t.coords['y'].min(): + patches_per_row = patches_per_row + 1 + return patches_per_row + + + # Calculate overlap between adjacent patches in pixels + def get_patch_overlap(overlap_norm, data_processor, amsr_raw_ds): + overlap_list = [0, overlap_norm[0], 0, overlap_norm[1]] + x1 = xr.DataArray([overlap_list[0], overlap_list[1]], dims='x1', name='x1') + x2 = xr.DataArray([overlap_list[2], overlap_list[3]], dims='x2', name='x2') + overlap_norm_xr = xr.Dataset(coords={'x1': x1, 'x2': x2}) + + # Unnormalise coordinates of bounding boxes + overlap_unnorm_xr = data_processor.unnormalise(overlap_norm_xr) + unnorm_overlap_x1 = overlap_unnorm_xr.coords['x'].values[1] + unnorm_overlap_x2 = overlap_unnorm_xr.coords['y'].values[1] + + # Find the position of these indices within the DataArray + x_overlap_index = int(np.ceil((np.argmin(np.abs(amsr_raw_ds.coords['x'].values - unnorm_overlap_x1))/2))) + y_overlap_index = int(np.ceil((np.argmin(np.abs(amsr_raw_ds.coords['y'].values - unnorm_overlap_x2))/2))) + + return x_overlap_index, y_overlap_index + + + ## To do- change amsr_raw_ds to what? + def get_index(*args, x1 = True) -> Union[int, Tuple[List[int], List[int]]]: + """ + Convert coordinates into pixel row/column (index). + + Parameters + ---------- + args : tuple + If one argument (numeric), it represents the coordinate value. + If two arguments (lists), they represent lists of coordinate values. + + x1 : bool, optional + If True, compute index for x1 (default is True). + + Returns + ------- + Union[int, Tuple[List[int], List[int]]] + If one argument is provided and x1 is True or False, returns the index position. + If two arguments are provided, returns a tuple containing two lists: + - First list: indices corresponding to x1 coordinates. + - Second list: indices corresponding to x2 coordinates. + + """ + if len(args) == 1: + patch_coord = args + if x1: + coord_index = np.argmin(np.abs(amsr_raw_ds.coords['y'].values - patch_coord)) + else: + coord_index = np.argmin(np.abs(amsr_raw_ds.coords['x'].values - patch_coord)) + return coord_index + + elif len(args) == 2: + patch_x1, patch_x2 = args + x1_index = [np.argmin(np.abs(amsr_raw_ds.coords['y'].values - target_x1)) for target_x1 in patch_x1] + x2_index = [np.argmin(np.abs(amsr_raw_ds.coords['x'].values - target_x2)) for target_x2 in patch_x2] + return (x1_index, x2_index) + + + def stitch_clipped_predictions(patches, pred_copy, border): + + data_x1 = amsr_raw_ds.coords['y'].min().values, amsr_raw_ds.coords['y'].max().values + data_x2 = amsr_raw_ds.coords['x'].min().values, amsr_raw_ds.coords['x'].max().values + data_x1_index, data_x2_index = get_index(data_x1, data_x2) + patches_clipped = [] + + for i, patch_pred in enumerate(patch_preds): + for var_name, data_array in patch_pred.items(): #previously patch + if var_name in patch_pred: + # Get row/col index values of each patch + patch_x1 = data_array.coords['y'].min().values, data_array.coords['y'].max().values + patch_x2 = data_array.coords['x'].min().values, data_array.coords['x'].max().values + patch_x1_index, patch_x2_index = get_index(patch_x1, patch_x2) + + b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0] + b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1] + # Do not remove border for the patches along top and left of dataset + # and change overlap size for last patch in rows and columns. + if patch_x2_index[0] == data_x2_index[0]: + b_x2_min = 0 + elif patch_x2_index[1] == data_x2_index[1]: + b_x2_max = 0 + patch_row_prev = preds[i-1] + prev_patch_x2_max = get_index(int(patch_row_prev[var_name].coords['x'].max()), x1 = False) + b_x2_min = (prev_patch_x2_max - patch_x2_index[0])-patch_overlap[1] + + if patch_x1_index[0] == data_x1_index[0]: + b_x1_min = 0 + elif abs(patch_x1_index[1] - data_x1_index[1])<2: + b_x1_max = 0 + patch_prev = preds[i-patches_per_row] + prev_patch_x1_max = get_index(int(patch_prev[var_name].coords['y'].max()), x1 = True) + b_x1_min = (prev_patch_x1_max - patch_x1_index[0])- patch_overlap[0] + + patch_clip_x1_min = int(b_x1_min) + patch_clip_x1_max = int(data_array.sizes['y'] - b_x1_max) + patch_clip_x2_min = int(b_x2_min) + patch_clip_x2_max = int(data_array.sizes['x'] - b_x2_max) + + patch_clip = data_array.isel(y=slice(patch_clip_x1_min, patch_clip_x1_max), + x=slice(patch_clip_x2_min, patch_clip_x2_max)) + + patches_clipped.append(patch_clip) + + combined = xr.combine_by_coords(patches_clipped, compat='no_conflicts') + return combined + + def stitch_predictions(preds, pred_copy): + for pred in preds: + for var_name, data_array in pred.items(): + if var_name in pred_copy: + unnorm_patch_x1 = data_array['x'].min().values, data_array['x'].max().values + unnorm_patch_x2 = data_array['y'].min().values, data_array['y'].max().values + pred_copy[var_name].loc[{'x': slice(unnorm_patch_x1[0], unnorm_patch_x1[1]), 'y': slice(unnorm_patch_x2[0], unnorm_patch_x2[1])}] = data_array + return pred_copy + # Identify extent of original dataframe preds = [] @@ -736,21 +876,29 @@ def predict_patch( x1 = xr.DataArray([bbox[0], bbox[1]], dims='x1', name='x1') x2 = xr.DataArray([bbox[2], bbox[3]], dims='x2', name='x2') bbox_norm = xr.Dataset(coords={'x1': x1, 'x2': x2}) - + # Unnormalise coordinates of bounding boxes bbox_unnorm = data_processor.unnormalise(bbox_norm) unnorm_bbox_x1 = bbox_unnorm['x'].values.min(), bbox_unnorm['x'].values.max() unnorm_bbox_x2 = bbox_unnorm['y'].values.min(), bbox_unnorm['y'].values.max() - + # Determine X_t for patch task_X_t = X_t.sel(x = slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1]), y = slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1])) - # Patchwise prediction pred = self.predict(task, task_X_t) + # Append patchwise DeepSensor prediction object to list preds.append(pred) - - # Produce a blank xarray to stitch patched predictions to. + + overlap_norm = tuple(patch - stride for patch, stride in zip(patch_size, stride_size)) + x_overlap_index, y_overlap_index = get_patch_overlap(overlap_norm, data_processor, amsr_raw_ds) + patch_overlap = (x_overlap_index, y_overlap_index) + patches_per_row = get_patches_per_row(preds, X_t) + + + + pred_copy = copy.deepcopy(preds[0]) + # Generate new blank DeepSensor.prediction object in original coordinate system. for var_name_copy, data_array_copy in pred_copy.items(): # set x and y coords @@ -762,18 +910,13 @@ def predict_patch( # set variable names to those in patched prediction, make values blank for var_name_i in data_array_copy.data_vars: stitched_preds[var_name_i] = data_array_copy[var_name_i] - stitched_preds.attrs.clear() + stitched_preds[var_name_i][:] = np.nan pred_copy[var_name_copy]= stitched_preds - # Stitch patchwise predictions - for pred in preds: - for var_name, data_array in pred.items(): - if var_name in pred_copy: - unnorm_patch_x1 = data_array['x'].min().values, data_array['x'].max().values - unnorm_patch_x2 = data_array['y'].min().values, data_array['y'].max().values - pred_copy[var_name].loc[{'x': slice(unnorm_patch_x1[0], unnorm_patch_x1[1]), - 'y': slice(unnorm_patch_x2[0], unnorm_patch_x2[1])}] = data_array - return pred_copy + + prediction = stitch_predictions(preds, pred_copy) + + return preds def main(): # pragma: no cover import deepsensor.tensorflow From 847a47ca044b0f8dc259ff6ce8db5962b25e7f03 Mon Sep 17 00:00:00 2001 From: Martin Rogers Date: Wed, 26 Jun 2024 16:18:20 +0100 Subject: [PATCH 27/69] clipped patchwise predictions, single date --- deepsensor/model/model.py | 103 ++++++++++++++++++++++++++------------ 1 file changed, 71 insertions(+), 32 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 5da78986..4b0f1b8c 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -638,6 +638,8 @@ def predict_patch( pd.DataFrame, List[Union[xr.DataArray, xr.Dataset, pd.DataFrame]], ], + stride_size: Union[float, tuple[float]], + patch_size: Union[float, tuple[float]], X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None, X_t_is_normalised: bool = False, aux_at_targets_override: Union[xr.Dataset, xr.DataArray] = None, @@ -730,7 +732,7 @@ def predict_patch( If ``append_indexes`` are not all the same length as ``X_t``. """ ## To do, do we need to add patch and stride as an additional argument? - def get_patches_per_row(preds, X_t): + def get_patches_per_row(preds, X_t) -> int: """ Calculate number of patches per row. Required to stitch patches back together. @@ -739,7 +741,7 @@ def get_patches_per_row(preds, X_t): A list of `dict`-like objects containing patchwise predictions. Returns: - patches_per_row (int) + patches_per_row: int Number of patches per row. """ patches_per_row = 0 @@ -752,8 +754,28 @@ def get_patches_per_row(preds, X_t): return patches_per_row - # Calculate overlap between adjacent patches in pixels - def get_patch_overlap(overlap_norm, data_processor, amsr_raw_ds): + + def get_patch_overlap(overlap_norm, data_processor, X_t_ds): + """ + Calculate overlap between adjacent patches in pixels. + + Parameters + ---------- + overlap_norm : tuple[float]. + Normalised size of overlap in x1/x2. + + data_processor (:class:`~.data.processor.DataProcessor`): + Used for unnormalising the coordinates of the bounding boxes of patches. + + X_t_ds (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`): + Data array containing target locations to predict at. + + Returns + ------- + patch_overlap : tuple (int) + Unnormalised size of overlap between adjacent patches. + """ + # Place stride and patch size values in Xarray to pass into unnormalise() overlap_list = [0, overlap_norm[0], 0, overlap_norm[1]] x1 = xr.DataArray([overlap_list[0], overlap_list[1]], dims='x1', name='x1') x2 = xr.DataArray([overlap_list[2], overlap_list[3]], dims='x2', name='x2') @@ -765,10 +787,11 @@ def get_patch_overlap(overlap_norm, data_processor, amsr_raw_ds): unnorm_overlap_x2 = overlap_unnorm_xr.coords['y'].values[1] # Find the position of these indices within the DataArray - x_overlap_index = int(np.ceil((np.argmin(np.abs(amsr_raw_ds.coords['x'].values - unnorm_overlap_x1))/2))) - y_overlap_index = int(np.ceil((np.argmin(np.abs(amsr_raw_ds.coords['y'].values - unnorm_overlap_x2))/2))) + x_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords['x'].values - unnorm_overlap_x1))/2))) + y_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords['y'].values - unnorm_overlap_x2))/2))) + xy_overlap = (x_overlap_index, y_overlap_index) - return x_overlap_index, y_overlap_index + return xy_overlap ## To do- change amsr_raw_ds to what? @@ -797,24 +820,40 @@ def get_index(*args, x1 = True) -> Union[int, Tuple[List[int], List[int]]]: if len(args) == 1: patch_coord = args if x1: - coord_index = np.argmin(np.abs(amsr_raw_ds.coords['y'].values - patch_coord)) + coord_index = np.argmin(np.abs(X_t.coords['y'].values - patch_coord)) else: - coord_index = np.argmin(np.abs(amsr_raw_ds.coords['x'].values - patch_coord)) + coord_index = np.argmin(np.abs(X_t.coords['x'].values - patch_coord)) return coord_index elif len(args) == 2: patch_x1, patch_x2 = args - x1_index = [np.argmin(np.abs(amsr_raw_ds.coords['y'].values - target_x1)) for target_x1 in patch_x1] - x2_index = [np.argmin(np.abs(amsr_raw_ds.coords['x'].values - target_x2)) for target_x2 in patch_x2] + x1_index = [np.argmin(np.abs(X_t.coords['y'].values - target_x1)) for target_x1 in patch_x1] + x2_index = [np.argmin(np.abs(X_t.coords['x'].values - target_x2)) for target_x2 in patch_x2] return (x1_index, x2_index) - def stitch_clipped_predictions(patches, pred_copy, border): + def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row): + """ + Stitch patchwise predictions to form prediction at original extent. + + Parameters + ---------- + args : tuple + If one argument (numeric), it represents the coordinate value. + If two arguments (lists), they represent lists of coordinate values. + + x1 : bool, optional + If True, compute index for x1 (default is True). - data_x1 = amsr_raw_ds.coords['y'].min().values, amsr_raw_ds.coords['y'].max().values - data_x2 = amsr_raw_ds.coords['x'].min().values, amsr_raw_ds.coords['x'].max().values + Returns + ------- + """ + + data_x1 = X_t.coords['y'].min().values, X_t.coords['y'].max().values + data_x2 = X_t.coords['x'].min().values, X_t.coords['x'].max().values data_x1_index, data_x2_index = get_index(data_x1, data_x2) - patches_clipped = [] + patches_clipped = {var_name: [] for var_name in patch_preds[0].keys()} + for i, patch_pred in enumerate(patch_preds): for var_name, data_array in patch_pred.items(): #previously patch @@ -852,9 +891,11 @@ def stitch_clipped_predictions(patches, pred_copy, border): patch_clip = data_array.isel(y=slice(patch_clip_x1_min, patch_clip_x1_max), x=slice(patch_clip_x2_min, patch_clip_x2_max)) - patches_clipped.append(patch_clip) + patches_clipped[var_name].append(patch_clip) - combined = xr.combine_by_coords(patches_clipped, compat='no_conflicts') + combined = {var_name: xr.combine_by_coords(patches, compat='no_conflicts') for var_name, patches in patches_clipped.items()} + + #combined = xr.combine_by_coords(patches_clipped, compat='no_conflicts') return combined def stitch_predictions(preds, pred_copy): @@ -867,16 +908,15 @@ def stitch_predictions(preds, pred_copy): return pred_copy - # Identify extent of original dataframe + # Perform patchwise predictions preds = [] for task in tasks: bbox = task['bbox'] - # Determine X_t for the patched task in original coordinates. + # Unnormalise coordinates of bounding box of patch x1 = xr.DataArray([bbox[0], bbox[1]], dims='x1', name='x1') x2 = xr.DataArray([bbox[2], bbox[3]], dims='x2', name='x2') bbox_norm = xr.Dataset(coords={'x1': x1, 'x2': x2}) - # Unnormalise coordinates of bounding boxes bbox_unnorm = data_processor.unnormalise(bbox_norm) unnorm_bbox_x1 = bbox_unnorm['x'].values.min(), bbox_unnorm['x'].values.max() unnorm_bbox_x2 = bbox_unnorm['y'].values.min(), bbox_unnorm['y'].values.max() @@ -887,19 +927,18 @@ def stitch_predictions(preds, pred_copy): pred = self.predict(task, task_X_t) # Append patchwise DeepSensor prediction object to list preds.append(pred) - + + overlap_norm = tuple(patch - stride for patch, stride in zip(patch_size, stride_size)) - x_overlap_index, y_overlap_index = get_patch_overlap(overlap_norm, data_processor, amsr_raw_ds) - patch_overlap = (x_overlap_index, y_overlap_index) + patch_overlap_unnorm = get_patch_overlap(overlap_norm, data_processor, X_t) patches_per_row = get_patches_per_row(preds, X_t) + stitched_prediction = stitch_clipped_predictions(preds, patch_overlap_unnorm, patches_per_row) - - - - pred_copy = copy.deepcopy(preds[0]) + ## Change prediction into DeepSensor.Prediction object. + prediction= copy.deepcopy(preds[0]) # Generate new blank DeepSensor.prediction object in original coordinate system. - for var_name_copy, data_array_copy in pred_copy.items(): + for var_name_copy, data_array_copy in prediction.items(): # set x and y coords stitched_preds = xr.Dataset(coords={'x': X_t['x'], 'y': X_t['y']}) @@ -911,12 +950,12 @@ def stitch_predictions(preds, pred_copy): for var_name_i in data_array_copy.data_vars: stitched_preds[var_name_i] = data_array_copy[var_name_i] stitched_preds[var_name_i][:] = np.nan - pred_copy[var_name_copy]= stitched_preds + prediction[var_name_copy]= stitched_preds + prediction[var_name_copy] = stitched_prediction[var_name_copy] + #prediction = stitch_predictions(preds, pred_copy) - prediction = stitch_predictions(preds, pred_copy) - - return preds + return prediction def main(): # pragma: no cover import deepsensor.tensorflow From f93fc39588492fdf5721fa0dc35895eb37956fa5 Mon Sep 17 00:00:00 2001 From: Martin Rogers Date: Thu, 27 Jun 2024 16:32:21 +0100 Subject: [PATCH 28/69] correct minor errors/typos --- deepsensor/model/model.py | 53 ++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 4b0f1b8c..f6df23ac 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -664,6 +664,10 @@ def predict_patch( List of tasks containing context data. data_processor (:class:`~.data.processor.DataProcessor`): Used for unnormalising the coordinates of the bounding boxes of patches. + stride_size (Union[float, tuple[float]]): + Length of stride between adjacent patches in x1/x2 normalised coordinates. + patch_size (Union[float, tuple[float]]): + Height and width of patch in x1/x2 normalised coordinates. X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`): Target locations to predict at. Can be an xarray object containingon-grid locations or a pandas object containing off-grid locations. @@ -731,7 +735,7 @@ def predict_patch( ValueError If ``append_indexes`` are not all the same length as ``X_t``. """ - ## To do, do we need to add patch and stride as an additional argument? + def get_patches_per_row(preds, X_t) -> int: """ Calculate number of patches per row. @@ -746,16 +750,18 @@ def get_patches_per_row(preds, X_t) -> int: """ patches_per_row = 0 vars = list(preds[0][0].data_vars) - var = vars[0] - + var = vars[0] + y_val = preds[0][0][var].coords['y'].min() + for p in preds: - if p[0][var].coords['y'].min() == X_t.coords['y'].min(): + if p[0][var].coords['y'].min() == y_val: patches_per_row = patches_per_row + 1 + return patches_per_row - def get_patch_overlap(overlap_norm, data_processor, X_t_ds): + def get_patch_overlap(overlap_norm, data_processor, X_t_ds) -> int: """ Calculate overlap between adjacent patches in pixels. @@ -793,8 +799,6 @@ def get_patch_overlap(overlap_norm, data_processor, X_t_ds): return xy_overlap - - ## To do- change amsr_raw_ds to what? def get_index(*args, x1 = True) -> Union[int, Tuple[List[int], List[int]]]: """ Convert coordinates into pixel row/column (index). @@ -832,21 +836,25 @@ def get_index(*args, x1 = True) -> Union[int, Tuple[List[int], List[int]]]: return (x1_index, x2_index) - def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row): + def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> dict: """ Stitch patchwise predictions to form prediction at original extent. Parameters ---------- - args : tuple - If one argument (numeric), it represents the coordinate value. - If two arguments (lists), they represent lists of coordinate values. - - x1 : bool, optional - If True, compute index for x1 (default is True). + patch_preds : list (class:`~.model.pred.Prediction`) + List of patchwise predictions + patch_overlap: int + Overlap between adjacent patches in pixels. + + patches_per_row: int + Number of patchwise predictions in each row. + Returns ------- + combined: dict + Dictionary object containing the stitched model predictions. """ data_x1 = X_t.coords['y'].min().values, X_t.coords['y'].max().values @@ -895,18 +903,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row): combined = {var_name: xr.combine_by_coords(patches, compat='no_conflicts') for var_name, patches in patches_clipped.items()} - #combined = xr.combine_by_coords(patches_clipped, compat='no_conflicts') return combined - - def stitch_predictions(preds, pred_copy): - for pred in preds: - for var_name, data_array in pred.items(): - if var_name in pred_copy: - unnorm_patch_x1 = data_array['x'].min().values, data_array['x'].max().values - unnorm_patch_x2 = data_array['y'].min().values, data_array['y'].max().values - pred_copy[var_name].loc[{'x': slice(unnorm_patch_x1[0], unnorm_patch_x1[1]), 'y': slice(unnorm_patch_x2[0], unnorm_patch_x2[1])}] = data_array - return pred_copy - # Perform patchwise predictions preds = [] @@ -928,13 +925,13 @@ def stitch_predictions(preds, pred_copy): # Append patchwise DeepSensor prediction object to list preds.append(pred) - overlap_norm = tuple(patch - stride for patch, stride in zip(patch_size, stride_size)) patch_overlap_unnorm = get_patch_overlap(overlap_norm, data_processor, X_t) patches_per_row = get_patches_per_row(preds, X_t) stitched_prediction = stitch_clipped_predictions(preds, patch_overlap_unnorm, patches_per_row) - ## Change prediction into DeepSensor.Prediction object. + ## Cast prediction into DeepSensor.Prediction object. + # Todo: make this into seperate method. prediction= copy.deepcopy(preds[0]) # Generate new blank DeepSensor.prediction object in original coordinate system. @@ -953,8 +950,6 @@ def stitch_predictions(preds, pred_copy): prediction[var_name_copy]= stitched_preds prediction[var_name_copy] = stitched_prediction[var_name_copy] - #prediction = stitch_predictions(preds, pred_copy) - return prediction def main(): # pragma: no cover From d8af31491c864d1c31f58f7a2c9ed4e6559c2c53 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Thu, 27 Jun 2024 16:04:46 +0100 Subject: [PATCH 29/69] use TODO to be uniform --- deepsensor/model/model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index f6df23ac..3cb28a63 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -735,7 +735,7 @@ def predict_patch( ValueError If ``append_indexes`` are not all the same length as ``X_t``. """ - + # TODO, do we need to add patch and stride as an additional argument? def get_patches_per_row(preds, X_t) -> int: """ Calculate number of patches per row. @@ -799,6 +799,8 @@ def get_patch_overlap(overlap_norm, data_processor, X_t_ds) -> int: return xy_overlap + + # TODO - change amsr_raw_ds to what? def get_index(*args, x1 = True) -> Union[int, Tuple[List[int], List[int]]]: """ Convert coordinates into pixel row/column (index). @@ -931,7 +933,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d stitched_prediction = stitch_clipped_predictions(preds, patch_overlap_unnorm, patches_per_row) ## Cast prediction into DeepSensor.Prediction object. - # Todo: make this into seperate method. + # TODO make this into seperate method. prediction= copy.deepcopy(preds[0]) # Generate new blank DeepSensor.prediction object in original coordinate system. From f3b7f1283bc9419f2e7bef225abfb2bcc9bae980 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Fri, 28 Jun 2024 13:39:32 +0100 Subject: [PATCH 30/69] use "stride" as in taskloader --- deepsensor/model/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 3cb28a63..d79d9690 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -638,7 +638,7 @@ def predict_patch( pd.DataFrame, List[Union[xr.DataArray, xr.Dataset, pd.DataFrame]], ], - stride_size: Union[float, tuple[float]], + stride: Union[float, tuple[float]], patch_size: Union[float, tuple[float]], X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None, X_t_is_normalised: bool = False, @@ -664,7 +664,7 @@ def predict_patch( List of tasks containing context data. data_processor (:class:`~.data.processor.DataProcessor`): Used for unnormalising the coordinates of the bounding boxes of patches. - stride_size (Union[float, tuple[float]]): + stride (Union[float, tuple[float]]): Length of stride between adjacent patches in x1/x2 normalised coordinates. patch_size (Union[float, tuple[float]]): Height and width of patch in x1/x2 normalised coordinates. @@ -927,7 +927,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d # Append patchwise DeepSensor prediction object to list preds.append(pred) - overlap_norm = tuple(patch - stride for patch, stride in zip(patch_size, stride_size)) + overlap_norm = tuple(patch - stride for patch, stride in zip(patch_size, stride)) patch_overlap_unnorm = get_patch_overlap(overlap_norm, data_processor, X_t) patches_per_row = get_patches_per_row(preds, X_t) stitched_prediction = stitch_clipped_predictions(preds, patch_overlap_unnorm, patches_per_row) From 5a1766be99a6b498eda90c3df32a7a45aee741ed Mon Sep 17 00:00:00 2001 From: Martin Rogers Date: Wed, 10 Jul 2024 16:44:44 +0100 Subject: [PATCH 31/69] resolve unnormalised coordinate names --- deepsensor/model/model.py | 66 ++++++++++++++++++++++++--------------- 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index f6df23ac..6d79fdd0 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -736,7 +736,14 @@ def predict_patch( If ``append_indexes`` are not all the same length as ``X_t``. """ - def get_patches_per_row(preds, X_t) -> int: + # Get coordinate names of original unnormalised dataset. + unnorm_coord_names = { + "x1": self.data_processor.raw_spatial_coord_names[0], + "x2": self.data_processor.raw_spatial_coord_names[1], + } + + + def get_patches_per_row(preds) -> int: """ Calculate number of patches per row. Required to stitch patches back together. @@ -751,10 +758,10 @@ def get_patches_per_row(preds, X_t) -> int: patches_per_row = 0 vars = list(preds[0][0].data_vars) var = vars[0] - y_val = preds[0][0][var].coords['y'].min() + y_val = preds[0][0][var].coords[unnorm_coord_names['x1']].min() for p in preds: - if p[0][var].coords['y'].min() == y_val: + if p[0][var].coords[unnorm_coord_names['x1']].min() == y_val: patches_per_row = patches_per_row + 1 return patches_per_row @@ -789,12 +796,12 @@ def get_patch_overlap(overlap_norm, data_processor, X_t_ds) -> int: # Unnormalise coordinates of bounding boxes overlap_unnorm_xr = data_processor.unnormalise(overlap_norm_xr) - unnorm_overlap_x1 = overlap_unnorm_xr.coords['x'].values[1] - unnorm_overlap_x2 = overlap_unnorm_xr.coords['y'].values[1] + unnorm_overlap_x1 = overlap_unnorm_xr.coords[unnorm_coord_names['x1']].values[1] + unnorm_overlap_x2 = overlap_unnorm_xr.coords[unnorm_coord_names['x2']].values[1] # Find the position of these indices within the DataArray - x_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords['x'].values - unnorm_overlap_x1))/2))) - y_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords['y'].values - unnorm_overlap_x2))/2))) + x_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[unnorm_coord_names['x1']].values - unnorm_overlap_x1))/2))) + y_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[unnorm_coord_names['x2']].values - unnorm_overlap_x2))/2))) xy_overlap = (x_overlap_index, y_overlap_index) return xy_overlap @@ -824,15 +831,15 @@ def get_index(*args, x1 = True) -> Union[int, Tuple[List[int], List[int]]]: if len(args) == 1: patch_coord = args if x1: - coord_index = np.argmin(np.abs(X_t.coords['y'].values - patch_coord)) + coord_index = np.argmin(np.abs(X_t.coords[unnorm_coord_names['x1']].values - patch_coord)) else: - coord_index = np.argmin(np.abs(X_t.coords['x'].values - patch_coord)) + coord_index = np.argmin(np.abs(X_t.coords[unnorm_coord_names['x2']].values - patch_coord)) return coord_index elif len(args) == 2: patch_x1, patch_x2 = args - x1_index = [np.argmin(np.abs(X_t.coords['y'].values - target_x1)) for target_x1 in patch_x1] - x2_index = [np.argmin(np.abs(X_t.coords['x'].values - target_x2)) for target_x2 in patch_x2] + x1_index = [np.argmin(np.abs(X_t.coords[unnorm_coord_names['x1']].values - target_x1)) for target_x1 in patch_x1] + x2_index = [np.argmin(np.abs(X_t.coords[unnorm_coord_names['x2']].values - target_x2)) for target_x2 in patch_x2] return (x1_index, x2_index) @@ -856,9 +863,11 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d combined: dict Dictionary object containing the stitched model predictions. """ + + - data_x1 = X_t.coords['y'].min().values, X_t.coords['y'].max().values - data_x2 = X_t.coords['x'].min().values, X_t.coords['x'].max().values + data_x1 = X_t.coords[unnorm_coord_names['x1']].min().values, X_t.coords[unnorm_coord_names['x1']].max().values + data_x2 = X_t.coords[unnorm_coord_names['x2']].min().values, X_t.coords[unnorm_coord_names['x2']].max().values data_x1_index, data_x2_index = get_index(data_x1, data_x2) patches_clipped = {var_name: [] for var_name in patch_preds[0].keys()} @@ -867,20 +876,20 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d for var_name, data_array in patch_pred.items(): #previously patch if var_name in patch_pred: # Get row/col index values of each patch - patch_x1 = data_array.coords['y'].min().values, data_array.coords['y'].max().values - patch_x2 = data_array.coords['x'].min().values, data_array.coords['x'].max().values + patch_x1 = data_array.coords[unnorm_coord_names['x1']].min().values, data_array.coords[unnorm_coord_names['x1']].max().values + patch_x2 = data_array.coords[unnorm_coord_names['x2']].min().values, data_array.coords[unnorm_coord_names['x2']].max().values patch_x1_index, patch_x2_index = get_index(patch_x1, patch_x2) b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0] b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1] # Do not remove border for the patches along top and left of dataset - # and change overlap size for last patch in rows and columns. + # and change overlap size for last patch in each row and column. if patch_x2_index[0] == data_x2_index[0]: b_x2_min = 0 elif patch_x2_index[1] == data_x2_index[1]: b_x2_max = 0 patch_row_prev = preds[i-1] - prev_patch_x2_max = get_index(int(patch_row_prev[var_name].coords['x'].max()), x1 = False) + prev_patch_x2_max = get_index(int(patch_row_prev[var_name].coords[unnorm_coord_names['x2']].max()), x1 = False) b_x2_min = (prev_patch_x2_max - patch_x2_index[0])-patch_overlap[1] if patch_x1_index[0] == data_x1_index[0]: @@ -888,13 +897,13 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d elif abs(patch_x1_index[1] - data_x1_index[1])<2: b_x1_max = 0 patch_prev = preds[i-patches_per_row] - prev_patch_x1_max = get_index(int(patch_prev[var_name].coords['y'].max()), x1 = True) + prev_patch_x1_max = get_index(int(patch_prev[var_name].coords[unnorm_coord_names['x1']].max()), x1 = True) b_x1_min = (prev_patch_x1_max - patch_x1_index[0])- patch_overlap[0] patch_clip_x1_min = int(b_x1_min) - patch_clip_x1_max = int(data_array.sizes['y'] - b_x1_max) + patch_clip_x1_max = int(data_array.sizes[unnorm_coord_names['x1']] - b_x1_max) patch_clip_x2_min = int(b_x2_min) - patch_clip_x2_max = int(data_array.sizes['x'] - b_x2_max) + patch_clip_x2_max = int(data_array.sizes[unnorm_coord_names['x2']] - b_x2_max) patch_clip = data_array.isel(y=slice(patch_clip_x1_min, patch_clip_x1_max), x=slice(patch_clip_x2_min, patch_clip_x2_max)) @@ -915,11 +924,16 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d x2 = xr.DataArray([bbox[2], bbox[3]], dims='x2', name='x2') bbox_norm = xr.Dataset(coords={'x1': x1, 'x2': x2}) bbox_unnorm = data_processor.unnormalise(bbox_norm) - unnorm_bbox_x1 = bbox_unnorm['x'].values.min(), bbox_unnorm['x'].values.max() - unnorm_bbox_x2 = bbox_unnorm['y'].values.min(), bbox_unnorm['y'].values.max() + unnorm_bbox_x1 = bbox_unnorm[unnorm_coord_names['x1']].values.min(), bbox_unnorm[unnorm_coord_names['x1']].values.max() + unnorm_bbox_x2 = bbox_unnorm[unnorm_coord_names['x2']].values.min(), bbox_unnorm[unnorm_coord_names['x2']].values.max() + # Determine X_t for patch - task_X_t = X_t.sel(x = slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1]), - y = slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1])) + task_extent_dict = { + unnorm_coord_names['x1']: slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1]), + unnorm_coord_names['x2']: slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1]) + } + task_X_t = X_t.sel(**task_extent_dict) + # Patchwise prediction pred = self.predict(task, task_X_t) # Append patchwise DeepSensor prediction object to list @@ -927,7 +941,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d overlap_norm = tuple(patch - stride for patch, stride in zip(patch_size, stride_size)) patch_overlap_unnorm = get_patch_overlap(overlap_norm, data_processor, X_t) - patches_per_row = get_patches_per_row(preds, X_t) + patches_per_row = get_patches_per_row(preds) stitched_prediction = stitch_clipped_predictions(preds, patch_overlap_unnorm, patches_per_row) ## Cast prediction into DeepSensor.Prediction object. @@ -938,7 +952,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d for var_name_copy, data_array_copy in prediction.items(): # set x and y coords - stitched_preds = xr.Dataset(coords={'x': X_t['x'], 'y': X_t['y']}) + stitched_preds = xr.Dataset(coords={'x1': X_t[unnorm_coord_names['x1']], 'x2': X_t[unnorm_coord_names['x2']]}) # Set time to same as patched prediction stitched_preds['time'] = data_array_copy['time'] From 84d99441ce739fe396753d016ac811b08c4279f9 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Thu, 11 Jul 2024 09:18:30 +0100 Subject: [PATCH 32/69] Handle absent bbox and task as non-iterable --- deepsensor/model/model.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index d79d9690..746af7dd 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -907,10 +907,17 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d return combined + # tasks should be iterable, if only one is provided, make it a list + if type(tasks) is Task: + tasks = [tasks] + # Perform patchwise predictions preds = [] for task in tasks: bbox = task['bbox'] + + if bbox is None: + raise AttributeError("Tasks require non-None ``bbox`` for patchwise inference.") # Unnormalise coordinates of bounding box of patch x1 = xr.DataArray([bbox[0], bbox[1]], dims='x1', name='x1') From aab6f1e5ca520063202e279abbf0152072babb52 Mon Sep 17 00:00:00 2001 From: Martin Rogers Date: Wed, 10 Jul 2024 16:44:44 +0100 Subject: [PATCH 33/69] resolve unnormalised coordinate names --- deepsensor/model/model.py | 68 +++++++++++++++++++++++---------------- 1 file changed, 41 insertions(+), 27 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 746af7dd..19857285 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -735,8 +735,15 @@ def predict_patch( ValueError If ``append_indexes`` are not all the same length as ``X_t``. """ - # TODO, do we need to add patch and stride as an additional argument? - def get_patches_per_row(preds, X_t) -> int: + + # Get coordinate names of original unnormalised dataset. + unnorm_coord_names = { + "x1": self.data_processor.raw_spatial_coord_names[0], + "x2": self.data_processor.raw_spatial_coord_names[1], + } + + + def get_patches_per_row(preds) -> int: """ Calculate number of patches per row. Required to stitch patches back together. @@ -751,10 +758,10 @@ def get_patches_per_row(preds, X_t) -> int: patches_per_row = 0 vars = list(preds[0][0].data_vars) var = vars[0] - y_val = preds[0][0][var].coords['y'].min() + y_val = preds[0][0][var].coords[unnorm_coord_names['x1']].min() for p in preds: - if p[0][var].coords['y'].min() == y_val: + if p[0][var].coords[unnorm_coord_names['x1']].min() == y_val: patches_per_row = patches_per_row + 1 return patches_per_row @@ -789,12 +796,12 @@ def get_patch_overlap(overlap_norm, data_processor, X_t_ds) -> int: # Unnormalise coordinates of bounding boxes overlap_unnorm_xr = data_processor.unnormalise(overlap_norm_xr) - unnorm_overlap_x1 = overlap_unnorm_xr.coords['x'].values[1] - unnorm_overlap_x2 = overlap_unnorm_xr.coords['y'].values[1] + unnorm_overlap_x1 = overlap_unnorm_xr.coords[unnorm_coord_names['x1']].values[1] + unnorm_overlap_x2 = overlap_unnorm_xr.coords[unnorm_coord_names['x2']].values[1] # Find the position of these indices within the DataArray - x_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords['x'].values - unnorm_overlap_x1))/2))) - y_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords['y'].values - unnorm_overlap_x2))/2))) + x_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[unnorm_coord_names['x1']].values - unnorm_overlap_x1))/2))) + y_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[unnorm_coord_names['x2']].values - unnorm_overlap_x2))/2))) xy_overlap = (x_overlap_index, y_overlap_index) return xy_overlap @@ -826,15 +833,15 @@ def get_index(*args, x1 = True) -> Union[int, Tuple[List[int], List[int]]]: if len(args) == 1: patch_coord = args if x1: - coord_index = np.argmin(np.abs(X_t.coords['y'].values - patch_coord)) + coord_index = np.argmin(np.abs(X_t.coords[unnorm_coord_names['x1']].values - patch_coord)) else: - coord_index = np.argmin(np.abs(X_t.coords['x'].values - patch_coord)) + coord_index = np.argmin(np.abs(X_t.coords[unnorm_coord_names['x2']].values - patch_coord)) return coord_index elif len(args) == 2: patch_x1, patch_x2 = args - x1_index = [np.argmin(np.abs(X_t.coords['y'].values - target_x1)) for target_x1 in patch_x1] - x2_index = [np.argmin(np.abs(X_t.coords['x'].values - target_x2)) for target_x2 in patch_x2] + x1_index = [np.argmin(np.abs(X_t.coords[unnorm_coord_names['x1']].values - target_x1)) for target_x1 in patch_x1] + x2_index = [np.argmin(np.abs(X_t.coords[unnorm_coord_names['x2']].values - target_x2)) for target_x2 in patch_x2] return (x1_index, x2_index) @@ -858,9 +865,11 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d combined: dict Dictionary object containing the stitched model predictions. """ + + - data_x1 = X_t.coords['y'].min().values, X_t.coords['y'].max().values - data_x2 = X_t.coords['x'].min().values, X_t.coords['x'].max().values + data_x1 = X_t.coords[unnorm_coord_names['x1']].min().values, X_t.coords[unnorm_coord_names['x1']].max().values + data_x2 = X_t.coords[unnorm_coord_names['x2']].min().values, X_t.coords[unnorm_coord_names['x2']].max().values data_x1_index, data_x2_index = get_index(data_x1, data_x2) patches_clipped = {var_name: [] for var_name in patch_preds[0].keys()} @@ -869,20 +878,20 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d for var_name, data_array in patch_pred.items(): #previously patch if var_name in patch_pred: # Get row/col index values of each patch - patch_x1 = data_array.coords['y'].min().values, data_array.coords['y'].max().values - patch_x2 = data_array.coords['x'].min().values, data_array.coords['x'].max().values + patch_x1 = data_array.coords[unnorm_coord_names['x1']].min().values, data_array.coords[unnorm_coord_names['x1']].max().values + patch_x2 = data_array.coords[unnorm_coord_names['x2']].min().values, data_array.coords[unnorm_coord_names['x2']].max().values patch_x1_index, patch_x2_index = get_index(patch_x1, patch_x2) b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0] b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1] # Do not remove border for the patches along top and left of dataset - # and change overlap size for last patch in rows and columns. + # and change overlap size for last patch in each row and column. if patch_x2_index[0] == data_x2_index[0]: b_x2_min = 0 elif patch_x2_index[1] == data_x2_index[1]: b_x2_max = 0 patch_row_prev = preds[i-1] - prev_patch_x2_max = get_index(int(patch_row_prev[var_name].coords['x'].max()), x1 = False) + prev_patch_x2_max = get_index(int(patch_row_prev[var_name].coords[unnorm_coord_names['x2']].max()), x1 = False) b_x2_min = (prev_patch_x2_max - patch_x2_index[0])-patch_overlap[1] if patch_x1_index[0] == data_x1_index[0]: @@ -890,13 +899,13 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d elif abs(patch_x1_index[1] - data_x1_index[1])<2: b_x1_max = 0 patch_prev = preds[i-patches_per_row] - prev_patch_x1_max = get_index(int(patch_prev[var_name].coords['y'].max()), x1 = True) + prev_patch_x1_max = get_index(int(patch_prev[var_name].coords[unnorm_coord_names['x1']].max()), x1 = True) b_x1_min = (prev_patch_x1_max - patch_x1_index[0])- patch_overlap[0] patch_clip_x1_min = int(b_x1_min) - patch_clip_x1_max = int(data_array.sizes['y'] - b_x1_max) + patch_clip_x1_max = int(data_array.sizes[unnorm_coord_names['x1']] - b_x1_max) patch_clip_x2_min = int(b_x2_min) - patch_clip_x2_max = int(data_array.sizes['x'] - b_x2_max) + patch_clip_x2_max = int(data_array.sizes[unnorm_coord_names['x2']] - b_x2_max) patch_clip = data_array.isel(y=slice(patch_clip_x1_min, patch_clip_x1_max), x=slice(patch_clip_x2_min, patch_clip_x2_max)) @@ -924,11 +933,16 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d x2 = xr.DataArray([bbox[2], bbox[3]], dims='x2', name='x2') bbox_norm = xr.Dataset(coords={'x1': x1, 'x2': x2}) bbox_unnorm = data_processor.unnormalise(bbox_norm) - unnorm_bbox_x1 = bbox_unnorm['x'].values.min(), bbox_unnorm['x'].values.max() - unnorm_bbox_x2 = bbox_unnorm['y'].values.min(), bbox_unnorm['y'].values.max() + unnorm_bbox_x1 = bbox_unnorm[unnorm_coord_names['x1']].values.min(), bbox_unnorm[unnorm_coord_names['x1']].values.max() + unnorm_bbox_x2 = bbox_unnorm[unnorm_coord_names['x2']].values.min(), bbox_unnorm[unnorm_coord_names['x2']].values.max() + # Determine X_t for patch - task_X_t = X_t.sel(x = slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1]), - y = slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1])) + task_extent_dict = { + unnorm_coord_names['x1']: slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1]), + unnorm_coord_names['x2']: slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1]) + } + task_X_t = X_t.sel(**task_extent_dict) + # Patchwise prediction pred = self.predict(task, task_X_t) # Append patchwise DeepSensor prediction object to list @@ -936,7 +950,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d overlap_norm = tuple(patch - stride for patch, stride in zip(patch_size, stride)) patch_overlap_unnorm = get_patch_overlap(overlap_norm, data_processor, X_t) - patches_per_row = get_patches_per_row(preds, X_t) + patches_per_row = get_patches_per_row(preds) stitched_prediction = stitch_clipped_predictions(preds, patch_overlap_unnorm, patches_per_row) ## Cast prediction into DeepSensor.Prediction object. @@ -947,7 +961,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d for var_name_copy, data_array_copy in prediction.items(): # set x and y coords - stitched_preds = xr.Dataset(coords={'x': X_t['x'], 'y': X_t['y']}) + stitched_preds = xr.Dataset(coords={'x1': X_t[unnorm_coord_names['x1']], 'x2': X_t[unnorm_coord_names['x2']]}) # Set time to same as patched prediction stitched_preds['time'] = data_array_copy['time'] From bda71766500ca31da8cfb198fcbe4c6af643e561 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Thu, 11 Jul 2024 16:55:06 +0100 Subject: [PATCH 34/69] use dict format for isel for variable coordinate names --- deepsensor/model/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 19857285..2649286c 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -907,8 +907,8 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d patch_clip_x2_min = int(b_x2_min) patch_clip_x2_max = int(data_array.sizes[unnorm_coord_names['x2']] - b_x2_max) - patch_clip = data_array.isel(y=slice(patch_clip_x1_min, patch_clip_x1_max), - x=slice(patch_clip_x2_min, patch_clip_x2_max)) + patch_clip = data_array[{unnorm_coord_names['x1']: slice(patch_clip_x1_min, patch_clip_x1_max), + unnorm_coord_names['x2']: slice(patch_clip_x2_min, patch_clip_x2_max)}] patches_clipped[var_name].append(patch_clip) From 55bf86fcdb9e6bc2626ebc66a4c0902a9a446597 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Tue, 16 Jul 2024 17:24:15 +0100 Subject: [PATCH 35/69] add basic test for patchwise prediction --- tests/test_model.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/test_model.py b/tests/test_model.py index 80519269..87176c6d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -522,6 +522,33 @@ def test_highlevel_predict_with_invalid_pred_params(self): with self.assertRaises(AttributeError): model.predict(task, X_t=self.da, pred_params=["invalid_param"]) + def test_patchwise_prediction(self): + """Test that ``.predict_patch`` runs correctly.""" + + patch_size = (0.6, 0.6) + stride_size = (0.5, 0.5) + + tl = TaskLoader(context=self.da, target=self.da) + + task = tl( + "2020-01-01", + context_sampling="all", + target_sampling="all", + patch_strategy="sliding", + patch_size=patch_size, + stride=stride_size, + ) + + model = ConvNP(self.dp, tl) + + model.predict_patch( + tasks=task, + X_t=self.da, + data_processor=self.dp, + stride=stride_size, + patch_size=patch_size, + ) + def test_saving_and_loading(self): """Test saving and loading of model""" with tempfile.TemporaryDirectory() as folder: From 323ab46024dddd15701888fc8d4b8a9c91eaadb5 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Thu, 18 Jul 2024 17:38:08 +0100 Subject: [PATCH 36/69] handle patch_size and stride as floats or tuples in task loader and predict_patch --- deepsensor/data/loader.py | 20 ++++++++++++++------ deepsensor/model/model.py | 28 +++++++++++++++++++--------- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 6a59c406..8f06279a 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1416,9 +1416,9 @@ def __call__( ] ] = None, split_frac: float = 0.5, - patch_size: Sequence[float] = None, + patch_size: Union[float, tuple[float]] = None, patch_strategy: Optional[str] = None, - stride: Optional[Sequence[int]] = None, + stride: Union[float, tuple[float]] = None, num_samples_per_date: int = 1, datewise_deterministic: bool = False, seed_override: Optional[int] = None, @@ -1466,14 +1466,16 @@ def __call__( the "split" sampling strategy for linked context and target set pairs. The remaining observations are used for the target set. Default is 0.5. - patch_size : Sequence[float], optional - Desired patch size in x1/x2 used for patchwise task generation. Usefule when considering - the entire available region is computationally prohibitive for model forward pass + patch_size : Union[float, tuple[float]], optional + Desired patch size in x1/x2 used for patchwise task generation. Useful when considering + the entire available region is computationally prohibitive for model forward pass. + If passed a single float, will use value for both x1 & x2. patch_strategy: Patch strategy to use for patchwise task generation. Default is None. Possible options are 'random' or 'sliding'. - stride: Sequence[int], optional + stride: Union[float, tuple[float]], optional Step size between each sliding window patch along x1 and x2 axis. Default is None. + If passed a single float, will use value for both x1 & x2. datewise_deterministic (bool, optional): Whether random sampling is datewise deterministic based on the date. Default is ``False``. @@ -1492,6 +1494,12 @@ def __call__( f"Must be one of [None, 'random', 'sliding']." ) + if isinstance(patch_size, float) and patch_size is not None: + patch_size = (patch_size, patch_size) + + if isinstance(stride, float) and stride is not None: + stride = (stride, stride) + if patch_strategy is None: if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)): tasks = [ diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 2649286c..afe92bcc 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -665,9 +665,9 @@ def predict_patch( data_processor (:class:`~.data.processor.DataProcessor`): Used for unnormalising the coordinates of the bounding boxes of patches. stride (Union[float, tuple[float]]): - Length of stride between adjacent patches in x1/x2 normalised coordinates. + Length of stride between adjacent patches in x1/x2 normalised coordinates. If passed a single float, will use value for both x1 & x2. patch_size (Union[float, tuple[float]]): - Height and width of patch in x1/x2 normalised coordinates. + Height and width of patch in x1/x2 normalised coordinates. If passed a single float, will use value for both x1 & x2. X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`): Target locations to predict at. Can be an xarray object containingon-grid locations or a pandas object containing off-grid locations. @@ -736,13 +736,6 @@ def predict_patch( If ``append_indexes`` are not all the same length as ``X_t``. """ - # Get coordinate names of original unnormalised dataset. - unnorm_coord_names = { - "x1": self.data_processor.raw_spatial_coord_names[0], - "x2": self.data_processor.raw_spatial_coord_names[1], - } - - def get_patches_per_row(preds) -> int: """ Calculate number of patches per row. @@ -916,6 +909,23 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d return combined + if isinstance(patch_size, float) and patch_size is not None: + patch_size = (patch_size, patch_size) + + if isinstance(stride, float) and stride is not None: + stride = (stride, stride) + + if stride[0] > patch_size[0] or stride[1] > patch_size[1]: + raise ValueError( + f"stride must be smaller than patch_size in the corresponding dimensions. Got: patch_size: {patch_size}, stride: {stride}" + ) + + # Get coordinate names of original unnormalised dataset. + unnorm_coord_names = { + "x1": self.data_processor.raw_spatial_coord_names[0], + "x2": self.data_processor.raw_spatial_coord_names[1], + } + # tasks should be iterable, if only one is provided, make it a list if type(tasks) is Task: tasks = [tasks] From 09befb3175d9db6660a52a22010cd29befc063e1 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Fri, 19 Jul 2024 16:48:39 +0100 Subject: [PATCH 37/69] test parameter handling and sizes in patchwise prediction --- tests/test_model.py | 55 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 50 insertions(+), 5 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 87176c6d..017f57f9 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -541,14 +541,59 @@ def test_patchwise_prediction(self): model = ConvNP(self.dp, tl) - model.predict_patch( - tasks=task, - X_t=self.da, - data_processor=self.dp, - stride=stride_size, + pred = model.predict_patch( + tasks=task, + X_t=self.da, + data_processor=self.dp, + stride=stride_size, + patch_size=patch_size, + ) + + # gridded predictions + assert [isinstance(ds, xr.Dataset) for ds in pred.values()] + for var_ID in pred: + assert_shape( + pred[var_ID]["mean"], + (1, self.da.x1.size, self.da.x2.size), + ) + assert_shape( + pred[var_ID]["std"], + (1, self.da.x1.size, self.da.x2.size), + ) + assert( + self.da.x1.size == pred[var_ID].x1.size + ) + assert( + self.da.x2.size == pred[var_ID].x2.size + ) + + + @parameterized.expand([(0.5, 0.6)]) + def test_patchwise_prediction_parameter_handling(self, patch_size, stride_size): + """Test that correct errors and warnings are raised by ``.predict_patch``.""" + + tl = TaskLoader(context=self.da, target=self.da) + + task = tl( + "2020-01-01", + context_sampling="all", + target_sampling="all", + patch_strategy="sliding", patch_size=patch_size, + stride=stride_size, ) + model = ConvNP(self.dp, tl) + + with self.assertRaises(ValueError): + model.predict_patch( + tasks=task, + X_t=self.da, + data_processor=self.dp, + stride=stride_size, + patch_size=patch_size, + ) + def test_saving_and_loading(self): """Test saving and loading of model""" with tempfile.TemporaryDirectory() as folder: From c36455b719324b9945079f8e2460fab0c699b89a Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Fri, 19 Jul 2024 17:55:32 +0100 Subject: [PATCH 38/69] remove resolved TODO --- deepsensor/model/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index afe92bcc..f2a88fa2 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -800,7 +800,6 @@ def get_patch_overlap(overlap_norm, data_processor, X_t_ds) -> int: return xy_overlap - # TODO - change amsr_raw_ds to what? def get_index(*args, x1 = True) -> Union[int, Tuple[List[int], List[int]]]: """ Convert coordinates into pixel row/column (index). From 0cf143db1aa910b1771949683b99d3a34e9ef8ac Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Fri, 19 Jul 2024 18:15:01 +0100 Subject: [PATCH 39/69] check patch_size and stride values in predict_patch and test --- deepsensor/model/model.py | 9 +++++++++ tests/test_model.py | 7 ++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index f2a88fa2..cbf6d423 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -908,6 +908,8 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d return combined + # sanitise patch_size and stride arguments + if isinstance(patch_size, float) and patch_size is not None: patch_size = (patch_size, patch_size) @@ -918,6 +920,13 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d raise ValueError( f"stride must be smaller than patch_size in the corresponding dimensions. Got: patch_size: {patch_size}, stride: {stride}" ) + + for val in zip(stride, patch_size): + if val>1.0 or val<0.0: + raise ValueError( + f"Values of stride and patch_size must be between 0 & 1. Got: patch_size: {patch_size}, stride: {stride}" + ) + # Get coordinate names of original unnormalised dataset. unnorm_coord_names = { diff --git a/tests/test_model.py b/tests/test_model.py index 017f57f9..bac07406 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -568,7 +568,12 @@ def test_patchwise_prediction(self): ) - @parameterized.expand([(0.5, 0.6)]) + @parameterized.expand([ + ((0.5, 0.5), (0.6, 0.6)), # patch_size and stride as tuples + (0.5, 0.6), # as floats + (1.0, 1.2), # one argument above allowed range + (-0.1, 0.6) # and below allowed range + ]) def test_patchwise_prediction_parameter_handling(self, patch_size, stride_size): """Test that correct errors and warnings are raised by ``.predict_patch``.""" From 8da48c106cf019229a6c31327b68d87262987380 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Wed, 7 Aug 2024 10:14:38 +0000 Subject: [PATCH 40/69] test inference --- deepsensor/model/model.py | 77 ++++++---- docs/user-guide/patchwise_training.py | 203 ++++++++++++++++++++++++++ 2 files changed, 253 insertions(+), 27 deletions(-) create mode 100644 docs/user-guide/patchwise_training.py diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index f6df23ac..3831c4ec 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -735,6 +735,9 @@ def predict_patch( ValueError If ``append_indexes`` are not all the same length as ``X_t``. """ + + orig_x1_name = data_processor.x1_name + orig_x2_name = data_processor.x2_name def get_patches_per_row(preds, X_t) -> int: """ @@ -751,10 +754,10 @@ def get_patches_per_row(preds, X_t) -> int: patches_per_row = 0 vars = list(preds[0][0].data_vars) var = vars[0] - y_val = preds[0][0][var].coords['y'].min() + y_val = preds[0][0][var].coords[orig_x2_name].min() for p in preds: - if p[0][var].coords['y'].min() == y_val: + if p[0][var].coords[orig_x2_name].min() == y_val: patches_per_row = patches_per_row + 1 return patches_per_row @@ -789,12 +792,12 @@ def get_patch_overlap(overlap_norm, data_processor, X_t_ds) -> int: # Unnormalise coordinates of bounding boxes overlap_unnorm_xr = data_processor.unnormalise(overlap_norm_xr) - unnorm_overlap_x1 = overlap_unnorm_xr.coords['x'].values[1] - unnorm_overlap_x2 = overlap_unnorm_xr.coords['y'].values[1] + unnorm_overlap_x1 = overlap_unnorm_xr.coords[orig_x1_name].values[1] + unnorm_overlap_x2 = overlap_unnorm_xr.coords[orig_x2_name].values[1] # Find the position of these indices within the DataArray - x_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords['x'].values - unnorm_overlap_x1))/2))) - y_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords['y'].values - unnorm_overlap_x2))/2))) + x_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[orig_x1_name].values - unnorm_overlap_x1))/2))) + y_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[orig_x2_name].values - unnorm_overlap_x2))/2))) xy_overlap = (x_overlap_index, y_overlap_index) return xy_overlap @@ -824,15 +827,15 @@ def get_index(*args, x1 = True) -> Union[int, Tuple[List[int], List[int]]]: if len(args) == 1: patch_coord = args if x1: - coord_index = np.argmin(np.abs(X_t.coords['y'].values - patch_coord)) + coord_index = np.argmin(np.abs(X_t.coords[orig_x2_name].values - patch_coord)) else: - coord_index = np.argmin(np.abs(X_t.coords['x'].values - patch_coord)) + coord_index = np.argmin(np.abs(X_t.coords[orig_x1_name].values - patch_coord)) return coord_index elif len(args) == 2: patch_x1, patch_x2 = args - x1_index = [np.argmin(np.abs(X_t.coords['y'].values - target_x1)) for target_x1 in patch_x1] - x2_index = [np.argmin(np.abs(X_t.coords['x'].values - target_x2)) for target_x2 in patch_x2] + x1_index = [np.argmin(np.abs(X_t.coords[orig_x1_name].values - target_x1)) for target_x1 in patch_x1] + x2_index = [np.argmin(np.abs(X_t.coords[orig_x2_name].values - target_x2)) for target_x2 in patch_x2] return (x1_index, x2_index) @@ -857,8 +860,8 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d Dictionary object containing the stitched model predictions. """ - data_x1 = X_t.coords['y'].min().values, X_t.coords['y'].max().values - data_x2 = X_t.coords['x'].min().values, X_t.coords['x'].max().values + data_x1 = X_t.coords[orig_x2_name].min().values, X_t.coords[orig_x2_name].max().values + data_x2 = X_t.coords[orig_x1_name].min().values, X_t.coords[orig_x1_name].max().values data_x1_index, data_x2_index = get_index(data_x1, data_x2) patches_clipped = {var_name: [] for var_name in patch_preds[0].keys()} @@ -867,8 +870,8 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d for var_name, data_array in patch_pred.items(): #previously patch if var_name in patch_pred: # Get row/col index values of each patch - patch_x1 = data_array.coords['y'].min().values, data_array.coords['y'].max().values - patch_x2 = data_array.coords['x'].min().values, data_array.coords['x'].max().values + patch_x1 = data_array.coords[orig_x2_name].min().values, data_array.coords[orig_x2_name].max().values + patch_x2 = data_array.coords[orig_x1_name].min().values, data_array.coords[orig_x1_name].max().values patch_x1_index, patch_x2_index = get_index(patch_x1, patch_x2) b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0] @@ -880,7 +883,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d elif patch_x2_index[1] == data_x2_index[1]: b_x2_max = 0 patch_row_prev = preds[i-1] - prev_patch_x2_max = get_index(int(patch_row_prev[var_name].coords['x'].max()), x1 = False) + prev_patch_x2_max = get_index(int(patch_row_prev[var_name].coords[orig_x1_name].max()), x1 = False) b_x2_min = (prev_patch_x2_max - patch_x2_index[0])-patch_overlap[1] if patch_x1_index[0] == data_x1_index[0]: @@ -888,16 +891,18 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d elif abs(patch_x1_index[1] - data_x1_index[1])<2: b_x1_max = 0 patch_prev = preds[i-patches_per_row] - prev_patch_x1_max = get_index(int(patch_prev[var_name].coords['y'].max()), x1 = True) + prev_patch_x1_max = get_index(int(patch_prev[var_name].coords[orig_x2_name].max()), x1 = True) b_x1_min = (prev_patch_x1_max - patch_x1_index[0])- patch_overlap[0] patch_clip_x1_min = int(b_x1_min) - patch_clip_x1_max = int(data_array.sizes['y'] - b_x1_max) + patch_clip_x1_max = int(data_array.sizes[orig_x2_name] - b_x1_max) patch_clip_x2_min = int(b_x2_min) - patch_clip_x2_max = int(data_array.sizes['x'] - b_x2_max) + patch_clip_x2_max = int(data_array.sizes[orig_x1_name] - b_x2_max) + + # patch_clip = data_array.isel(y=slice(patch_clip_x1_min, patch_clip_x1_max), + # x=slice(patch_clip_x2_min, patch_clip_x2_max)) - patch_clip = data_array.isel(y=slice(patch_clip_x1_min, patch_clip_x1_max), - x=slice(patch_clip_x2_min, patch_clip_x2_max)) + patch_clip = data_array.isel(**{orig_x1_name: slice(patch_clip_x1_min, patch_clip_x1_max), orig_x2_name: slice(patch_clip_x2_min, patch_clip_x2_max)}) patches_clipped[var_name].append(patch_clip) @@ -905,6 +910,8 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d return combined + + # Perform patchwise predictions preds = [] for task in tasks: @@ -915,11 +922,27 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d x2 = xr.DataArray([bbox[2], bbox[3]], dims='x2', name='x2') bbox_norm = xr.Dataset(coords={'x1': x1, 'x2': x2}) bbox_unnorm = data_processor.unnormalise(bbox_norm) - unnorm_bbox_x1 = bbox_unnorm['x'].values.min(), bbox_unnorm['x'].values.max() - unnorm_bbox_x2 = bbox_unnorm['y'].values.min(), bbox_unnorm['y'].values.max() - # Determine X_t for patch - task_X_t = X_t.sel(x = slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1]), - y = slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1])) + unnorm_bbox_x1 = bbox_unnorm[orig_x1_name].values.min(), bbox_unnorm[orig_x1_name].values.max() + unnorm_bbox_x2 = bbox_unnorm[orig_x2_name].values.min(), bbox_unnorm[orig_x2_name].values.max() + + # Determine X_t for patch, however, cannot assume min/max ordering of slice coordinates + # Check the order of coordinates in X_t, sometimes they are in increasing or decreasing order + x1_coords = X_t.coords[orig_x1_name].values + x2_coords = X_t.coords[orig_x2_name].values + + if x1_coords[0] < x1_coords[-1]: + x1_slice = slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1]) + else: + x1_slice = slice(unnorm_bbox_x1[1], unnorm_bbox_x1[0]) + + if x2_coords[0] < x2_coords[-1]: + x2_slice = slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1]) + else: + x2_slice = slice(unnorm_bbox_x2[1], unnorm_bbox_x2[0]) + + # Determine X_t for patch with correct slice direction + task_X_t = X_t.sel(**{orig_x1_name: x1_slice, orig_x2_name: x2_slice}) + # Patchwise prediction pred = self.predict(task, task_X_t) # Append patchwise DeepSensor prediction object to list @@ -930,7 +953,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d patches_per_row = get_patches_per_row(preds, X_t) stitched_prediction = stitch_clipped_predictions(preds, patch_overlap_unnorm, patches_per_row) - ## Cast prediction into DeepSensor.Prediction object. + ## Cast prediction into DeepSensor.Prediction object.orig_x2_name # Todo: make this into seperate method. prediction= copy.deepcopy(preds[0]) @@ -938,7 +961,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d for var_name_copy, data_array_copy in prediction.items(): # set x and y coords - stitched_preds = xr.Dataset(coords={'x': X_t['x'], 'y': X_t['y']}) + stitched_preds = xr.Dataset(coords={orig_x1_name: X_t[orig_x1_name], orig_x2_name: X_t[orig_x2_name]}) # Set time to same as patched prediction stitched_preds['time'] = data_array_copy['time'] diff --git a/docs/user-guide/patchwise_training.py b/docs/user-guide/patchwise_training.py new file mode 100644 index 00000000..be855396 --- /dev/null +++ b/docs/user-guide/patchwise_training.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python + +import logging +import os + +logging.captureWarnings(True) + +import deepsensor.torch +from deepsensor.model import ConvNP +from deepsensor.train import Trainer, set_gpu_default_device +from deepsensor.data import DataProcessor, TaskLoader, construct_circ_time_ds +from deepsensor.data.sources import ( + get_era5_reanalysis_data, + get_earthenv_auxiliary_data, + get_gldas_land_mask, +) + +import xarray as xr +import cartopy.crs as ccrs +import matplotlib.pyplot as plt +import pandas as pd +import numpy as np +from tqdm import tqdm + + + + +# Training/data config +data_range = ("2010-01-01", "2019-12-31") +train_range = ("2010-01-01", "2018-12-31") +val_range = ("2019-01-01", "2019-12-31") +date_subsample_factor = 2 +extent = "north_america" +era5_var_IDs = ["2m_temperature"] +lowres_auxiliary_var_IDs = ["elevation"] +cache_dir = "../../.datacache" +deepsensor_folder = "../deepsensor_config/" +verbose_download = True + + + + +era5_raw_ds = get_era5_reanalysis_data( + era5_var_IDs, + extent, + date_range=data_range, + cache=True, + cache_dir=cache_dir, + verbose=verbose_download, + num_processes=8, +) +lowres_aux_raw_ds = get_earthenv_auxiliary_data( + lowres_auxiliary_var_IDs, + extent, + "100KM", + cache=True, + cache_dir=cache_dir, + verbose=verbose_download, +) +land_mask_raw_ds = get_gldas_land_mask( + extent, cache=True, cache_dir=cache_dir, verbose=verbose_download +) + +data_processor = DataProcessor(x1_name="lat", x2_name="lon") +era5_ds = data_processor(era5_raw_ds) +lowres_aux_ds, land_mask_ds = data_processor( + [lowres_aux_raw_ds, land_mask_raw_ds], method="min_max" +) + +dates = pd.date_range(era5_ds.time.values.min(), era5_ds.time.values.max(), freq="D") +doy_ds = construct_circ_time_ds(dates, freq="D") +lowres_aux_ds["cos_D"] = doy_ds["cos_D"] +lowres_aux_ds["sin_D"] = doy_ds["sin_D"] + + + + +set_gpu_default_device() + + +# ## Initialise TaskLoader and ConvNP model + + + +task_loader = TaskLoader( + context=[era5_ds, land_mask_ds, lowres_aux_ds], + target=era5_ds, +) +task_loader.load_dask() +print(task_loader) + + + + +# Set up model +model = ConvNP(data_processor, task_loader, unet_channels=(32, 32, 32, 32, 32)) + + +# ## Define how Tasks are generated +# + +def gen_training_tasks(dates, progress=True): + tasks = [] + for date in tqdm(dates, disable=not progress): + tasks_per_date = task_loader( + date, + context_sampling=["all", "all", "all"], + target_sampling="all", + patch_strategy="random", + patch_size=(0.4, 0.4), + num_samples_per_date=2, + ) + tasks.extend(tasks_per_date) + return tasks + + +def gen_validation_tasks(dates, progress=True): + tasks = [] + for date in tqdm(dates, disable=not progress): + tasks_per_date = task_loader( + date, + context_sampling=["all", "all", "all"], + target_sampling="all", + patch_strategy="sliding", + patch_size=(0.5, 0.5), + stride=(1,1) + ) + tasks.extend(tasks_per_date) + return tasks + + +# ## Generate validation tasks for testing generalisation + + + +val_dates = pd.date_range(val_range[0], val_range[1])[::date_subsample_factor] +val_tasks = gen_validation_tasks(val_dates) + + +# ## Training with the Trainer class + + + + +def compute_val_rmse(model, val_tasks): + errors = [] + target_var_ID = task_loader.target_var_IDs[0][0] # assume 1st target set and 1D + for task in val_tasks: + mean = data_processor.map_array(model.mean(task), target_var_ID, unnorm=True) + true = data_processor.map_array(task["Y_t"][0], target_var_ID, unnorm=True) + errors.extend(np.abs(mean - true)) + return np.sqrt(np.mean(np.concatenate(errors) ** 2)) + + + + +num_epochs = 50 +losses = [] +val_rmses = [] + +# # Train model +val_rmse_best = np.inf +trainer = Trainer(model, lr=5e-5) +for epoch in tqdm(range(num_epochs)): + train_tasks = gen_training_tasks(pd.date_range(train_range[0], train_range[1])[::date_subsample_factor], progress=False) + batch_losses = trainer(train_tasks) + losses.append(np.mean(batch_losses)) + val_rmses.append(compute_val_rmse(model, val_tasks)) + if val_rmses[-1] < val_rmse_best: + val_rmse_best = val_rmses[-1] + model.save(deepsensor_folder) + + + + +fig, axes = plt.subplots(1, 2, figsize=(12, 4)) +axes[0].plot(losses) +axes[1].plot(val_rmses) +_ = axes[0].set_xlabel("Epoch") +_ = axes[1].set_xlabel("Epoch") +_ = axes[0].set_title("Training loss") +_ = axes[1].set_title("Validation RMSE") + +fig.savefig(os.path.join(deepsensor_folder, "patchwise_training_loss.png")) + + +# prediction with patches ON-GRID, select one data from the validation tasks +# generate patchwise tasks for a specific date +# pick a random date as datetime64[ns] + +dates = [np.datetime64("2019-06-25")] +eval_task = gen_validation_tasks(dates, progress=False) +# test_task = task_loader(date, [100, "all", "all"], seed_override=42) +pred = model.predict_patch(eval_task, data_processor=data_processor, stride_size=(1, 1), patch_size=(0.5, 0.5), X_t=era5_raw_ds, resolution_factor=2) + +import pdb +pdb.set_trace() + +fig = deepsensor.plot.prediction(pred, dates[0], data_processor, task_loader, eval_task[0], crs=ccrs.PlateCarree()) +fig.savefig(os.path.join(deepsensor_folder, "patchwise_prediction.png")) + +print(0) + From 7cf556efaba63c0994ccabc3d3b9ed3cdd6d45fe Mon Sep 17 00:00:00 2001 From: Martin Rogers Date: Thu, 8 Aug 2024 12:04:05 +0100 Subject: [PATCH 41/69] correct typo --- deepsensor/model/model.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 6d79fdd0..1ce5fa92 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -757,7 +757,8 @@ def get_patches_per_row(preds) -> int: """ patches_per_row = 0 vars = list(preds[0][0].data_vars) - var = vars[0] + + var = vars[0] y_val = preds[0][0][var].coords[unnorm_coord_names['x1']].min() for p in preds: @@ -803,7 +804,6 @@ def get_patch_overlap(overlap_norm, data_processor, X_t_ds) -> int: x_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[unnorm_coord_names['x1']].values - unnorm_overlap_x1))/2))) y_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[unnorm_coord_names['x2']].values - unnorm_overlap_x2))/2))) xy_overlap = (x_overlap_index, y_overlap_index) - return xy_overlap def get_index(*args, x1 = True) -> Union[int, Tuple[List[int], List[int]]]: @@ -918,7 +918,6 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d preds = [] for task in tasks: bbox = task['bbox'] - # Unnormalise coordinates of bounding box of patch x1 = xr.DataArray([bbox[0], bbox[1]], dims='x1', name='x1') x2 = xr.DataArray([bbox[2], bbox[3]], dims='x2', name='x2') @@ -938,9 +937,11 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d pred = self.predict(task, task_X_t) # Append patchwise DeepSensor prediction object to list preds.append(pred) - + + overlap_norm = tuple(patch - stride for patch, stride in zip(patch_size, stride_size)) patch_overlap_unnorm = get_patch_overlap(overlap_norm, data_processor, X_t) + patches_per_row = get_patches_per_row(preds) stitched_prediction = stitch_clipped_predictions(preds, patch_overlap_unnorm, patches_per_row) From bc862df0126a10b952c477b5c2a01dd247f58146 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Fri, 9 Aug 2024 10:11:39 +0100 Subject: [PATCH 42/69] fix stride & patch checking --- deepsensor/model/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index cbf6d423..2d39cc91 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -921,7 +921,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d f"stride must be smaller than patch_size in the corresponding dimensions. Got: patch_size: {patch_size}, stride: {stride}" ) - for val in zip(stride, patch_size): + for val in list(stride + patch_size): if val>1.0 or val<0.0: raise ValueError( f"Values of stride and patch_size must be between 0 & 1. Got: patch_size: {patch_size}, stride: {stride}" From 61dc88ea5baed1d8791debe9e1c9fb761a27d1b9 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Fri, 9 Aug 2024 10:22:09 +0000 Subject: [PATCH 43/69] revert previous commit --- deepsensor/model/model.py | 77 ++++------ docs/user-guide/patchwise_training.py | 203 -------------------------- 2 files changed, 27 insertions(+), 253 deletions(-) delete mode 100644 docs/user-guide/patchwise_training.py diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 3831c4ec..f6df23ac 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -735,9 +735,6 @@ def predict_patch( ValueError If ``append_indexes`` are not all the same length as ``X_t``. """ - - orig_x1_name = data_processor.x1_name - orig_x2_name = data_processor.x2_name def get_patches_per_row(preds, X_t) -> int: """ @@ -754,10 +751,10 @@ def get_patches_per_row(preds, X_t) -> int: patches_per_row = 0 vars = list(preds[0][0].data_vars) var = vars[0] - y_val = preds[0][0][var].coords[orig_x2_name].min() + y_val = preds[0][0][var].coords['y'].min() for p in preds: - if p[0][var].coords[orig_x2_name].min() == y_val: + if p[0][var].coords['y'].min() == y_val: patches_per_row = patches_per_row + 1 return patches_per_row @@ -792,12 +789,12 @@ def get_patch_overlap(overlap_norm, data_processor, X_t_ds) -> int: # Unnormalise coordinates of bounding boxes overlap_unnorm_xr = data_processor.unnormalise(overlap_norm_xr) - unnorm_overlap_x1 = overlap_unnorm_xr.coords[orig_x1_name].values[1] - unnorm_overlap_x2 = overlap_unnorm_xr.coords[orig_x2_name].values[1] + unnorm_overlap_x1 = overlap_unnorm_xr.coords['x'].values[1] + unnorm_overlap_x2 = overlap_unnorm_xr.coords['y'].values[1] # Find the position of these indices within the DataArray - x_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[orig_x1_name].values - unnorm_overlap_x1))/2))) - y_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[orig_x2_name].values - unnorm_overlap_x2))/2))) + x_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords['x'].values - unnorm_overlap_x1))/2))) + y_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords['y'].values - unnorm_overlap_x2))/2))) xy_overlap = (x_overlap_index, y_overlap_index) return xy_overlap @@ -827,15 +824,15 @@ def get_index(*args, x1 = True) -> Union[int, Tuple[List[int], List[int]]]: if len(args) == 1: patch_coord = args if x1: - coord_index = np.argmin(np.abs(X_t.coords[orig_x2_name].values - patch_coord)) + coord_index = np.argmin(np.abs(X_t.coords['y'].values - patch_coord)) else: - coord_index = np.argmin(np.abs(X_t.coords[orig_x1_name].values - patch_coord)) + coord_index = np.argmin(np.abs(X_t.coords['x'].values - patch_coord)) return coord_index elif len(args) == 2: patch_x1, patch_x2 = args - x1_index = [np.argmin(np.abs(X_t.coords[orig_x1_name].values - target_x1)) for target_x1 in patch_x1] - x2_index = [np.argmin(np.abs(X_t.coords[orig_x2_name].values - target_x2)) for target_x2 in patch_x2] + x1_index = [np.argmin(np.abs(X_t.coords['y'].values - target_x1)) for target_x1 in patch_x1] + x2_index = [np.argmin(np.abs(X_t.coords['x'].values - target_x2)) for target_x2 in patch_x2] return (x1_index, x2_index) @@ -860,8 +857,8 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d Dictionary object containing the stitched model predictions. """ - data_x1 = X_t.coords[orig_x2_name].min().values, X_t.coords[orig_x2_name].max().values - data_x2 = X_t.coords[orig_x1_name].min().values, X_t.coords[orig_x1_name].max().values + data_x1 = X_t.coords['y'].min().values, X_t.coords['y'].max().values + data_x2 = X_t.coords['x'].min().values, X_t.coords['x'].max().values data_x1_index, data_x2_index = get_index(data_x1, data_x2) patches_clipped = {var_name: [] for var_name in patch_preds[0].keys()} @@ -870,8 +867,8 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d for var_name, data_array in patch_pred.items(): #previously patch if var_name in patch_pred: # Get row/col index values of each patch - patch_x1 = data_array.coords[orig_x2_name].min().values, data_array.coords[orig_x2_name].max().values - patch_x2 = data_array.coords[orig_x1_name].min().values, data_array.coords[orig_x1_name].max().values + patch_x1 = data_array.coords['y'].min().values, data_array.coords['y'].max().values + patch_x2 = data_array.coords['x'].min().values, data_array.coords['x'].max().values patch_x1_index, patch_x2_index = get_index(patch_x1, patch_x2) b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0] @@ -883,7 +880,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d elif patch_x2_index[1] == data_x2_index[1]: b_x2_max = 0 patch_row_prev = preds[i-1] - prev_patch_x2_max = get_index(int(patch_row_prev[var_name].coords[orig_x1_name].max()), x1 = False) + prev_patch_x2_max = get_index(int(patch_row_prev[var_name].coords['x'].max()), x1 = False) b_x2_min = (prev_patch_x2_max - patch_x2_index[0])-patch_overlap[1] if patch_x1_index[0] == data_x1_index[0]: @@ -891,18 +888,16 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d elif abs(patch_x1_index[1] - data_x1_index[1])<2: b_x1_max = 0 patch_prev = preds[i-patches_per_row] - prev_patch_x1_max = get_index(int(patch_prev[var_name].coords[orig_x2_name].max()), x1 = True) + prev_patch_x1_max = get_index(int(patch_prev[var_name].coords['y'].max()), x1 = True) b_x1_min = (prev_patch_x1_max - patch_x1_index[0])- patch_overlap[0] patch_clip_x1_min = int(b_x1_min) - patch_clip_x1_max = int(data_array.sizes[orig_x2_name] - b_x1_max) + patch_clip_x1_max = int(data_array.sizes['y'] - b_x1_max) patch_clip_x2_min = int(b_x2_min) - patch_clip_x2_max = int(data_array.sizes[orig_x1_name] - b_x2_max) - - # patch_clip = data_array.isel(y=slice(patch_clip_x1_min, patch_clip_x1_max), - # x=slice(patch_clip_x2_min, patch_clip_x2_max)) + patch_clip_x2_max = int(data_array.sizes['x'] - b_x2_max) - patch_clip = data_array.isel(**{orig_x1_name: slice(patch_clip_x1_min, patch_clip_x1_max), orig_x2_name: slice(patch_clip_x2_min, patch_clip_x2_max)}) + patch_clip = data_array.isel(y=slice(patch_clip_x1_min, patch_clip_x1_max), + x=slice(patch_clip_x2_min, patch_clip_x2_max)) patches_clipped[var_name].append(patch_clip) @@ -910,8 +905,6 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d return combined - - # Perform patchwise predictions preds = [] for task in tasks: @@ -922,27 +915,11 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d x2 = xr.DataArray([bbox[2], bbox[3]], dims='x2', name='x2') bbox_norm = xr.Dataset(coords={'x1': x1, 'x2': x2}) bbox_unnorm = data_processor.unnormalise(bbox_norm) - unnorm_bbox_x1 = bbox_unnorm[orig_x1_name].values.min(), bbox_unnorm[orig_x1_name].values.max() - unnorm_bbox_x2 = bbox_unnorm[orig_x2_name].values.min(), bbox_unnorm[orig_x2_name].values.max() - - # Determine X_t for patch, however, cannot assume min/max ordering of slice coordinates - # Check the order of coordinates in X_t, sometimes they are in increasing or decreasing order - x1_coords = X_t.coords[orig_x1_name].values - x2_coords = X_t.coords[orig_x2_name].values - - if x1_coords[0] < x1_coords[-1]: - x1_slice = slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1]) - else: - x1_slice = slice(unnorm_bbox_x1[1], unnorm_bbox_x1[0]) - - if x2_coords[0] < x2_coords[-1]: - x2_slice = slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1]) - else: - x2_slice = slice(unnorm_bbox_x2[1], unnorm_bbox_x2[0]) - - # Determine X_t for patch with correct slice direction - task_X_t = X_t.sel(**{orig_x1_name: x1_slice, orig_x2_name: x2_slice}) - + unnorm_bbox_x1 = bbox_unnorm['x'].values.min(), bbox_unnorm['x'].values.max() + unnorm_bbox_x2 = bbox_unnorm['y'].values.min(), bbox_unnorm['y'].values.max() + # Determine X_t for patch + task_X_t = X_t.sel(x = slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1]), + y = slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1])) # Patchwise prediction pred = self.predict(task, task_X_t) # Append patchwise DeepSensor prediction object to list @@ -953,7 +930,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d patches_per_row = get_patches_per_row(preds, X_t) stitched_prediction = stitch_clipped_predictions(preds, patch_overlap_unnorm, patches_per_row) - ## Cast prediction into DeepSensor.Prediction object.orig_x2_name + ## Cast prediction into DeepSensor.Prediction object. # Todo: make this into seperate method. prediction= copy.deepcopy(preds[0]) @@ -961,7 +938,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d for var_name_copy, data_array_copy in prediction.items(): # set x and y coords - stitched_preds = xr.Dataset(coords={orig_x1_name: X_t[orig_x1_name], orig_x2_name: X_t[orig_x2_name]}) + stitched_preds = xr.Dataset(coords={'x': X_t['x'], 'y': X_t['y']}) # Set time to same as patched prediction stitched_preds['time'] = data_array_copy['time'] diff --git a/docs/user-guide/patchwise_training.py b/docs/user-guide/patchwise_training.py deleted file mode 100644 index be855396..00000000 --- a/docs/user-guide/patchwise_training.py +++ /dev/null @@ -1,203 +0,0 @@ -#!/usr/bin/env python - -import logging -import os - -logging.captureWarnings(True) - -import deepsensor.torch -from deepsensor.model import ConvNP -from deepsensor.train import Trainer, set_gpu_default_device -from deepsensor.data import DataProcessor, TaskLoader, construct_circ_time_ds -from deepsensor.data.sources import ( - get_era5_reanalysis_data, - get_earthenv_auxiliary_data, - get_gldas_land_mask, -) - -import xarray as xr -import cartopy.crs as ccrs -import matplotlib.pyplot as plt -import pandas as pd -import numpy as np -from tqdm import tqdm - - - - -# Training/data config -data_range = ("2010-01-01", "2019-12-31") -train_range = ("2010-01-01", "2018-12-31") -val_range = ("2019-01-01", "2019-12-31") -date_subsample_factor = 2 -extent = "north_america" -era5_var_IDs = ["2m_temperature"] -lowres_auxiliary_var_IDs = ["elevation"] -cache_dir = "../../.datacache" -deepsensor_folder = "../deepsensor_config/" -verbose_download = True - - - - -era5_raw_ds = get_era5_reanalysis_data( - era5_var_IDs, - extent, - date_range=data_range, - cache=True, - cache_dir=cache_dir, - verbose=verbose_download, - num_processes=8, -) -lowres_aux_raw_ds = get_earthenv_auxiliary_data( - lowres_auxiliary_var_IDs, - extent, - "100KM", - cache=True, - cache_dir=cache_dir, - verbose=verbose_download, -) -land_mask_raw_ds = get_gldas_land_mask( - extent, cache=True, cache_dir=cache_dir, verbose=verbose_download -) - -data_processor = DataProcessor(x1_name="lat", x2_name="lon") -era5_ds = data_processor(era5_raw_ds) -lowres_aux_ds, land_mask_ds = data_processor( - [lowres_aux_raw_ds, land_mask_raw_ds], method="min_max" -) - -dates = pd.date_range(era5_ds.time.values.min(), era5_ds.time.values.max(), freq="D") -doy_ds = construct_circ_time_ds(dates, freq="D") -lowres_aux_ds["cos_D"] = doy_ds["cos_D"] -lowres_aux_ds["sin_D"] = doy_ds["sin_D"] - - - - -set_gpu_default_device() - - -# ## Initialise TaskLoader and ConvNP model - - - -task_loader = TaskLoader( - context=[era5_ds, land_mask_ds, lowres_aux_ds], - target=era5_ds, -) -task_loader.load_dask() -print(task_loader) - - - - -# Set up model -model = ConvNP(data_processor, task_loader, unet_channels=(32, 32, 32, 32, 32)) - - -# ## Define how Tasks are generated -# - -def gen_training_tasks(dates, progress=True): - tasks = [] - for date in tqdm(dates, disable=not progress): - tasks_per_date = task_loader( - date, - context_sampling=["all", "all", "all"], - target_sampling="all", - patch_strategy="random", - patch_size=(0.4, 0.4), - num_samples_per_date=2, - ) - tasks.extend(tasks_per_date) - return tasks - - -def gen_validation_tasks(dates, progress=True): - tasks = [] - for date in tqdm(dates, disable=not progress): - tasks_per_date = task_loader( - date, - context_sampling=["all", "all", "all"], - target_sampling="all", - patch_strategy="sliding", - patch_size=(0.5, 0.5), - stride=(1,1) - ) - tasks.extend(tasks_per_date) - return tasks - - -# ## Generate validation tasks for testing generalisation - - - -val_dates = pd.date_range(val_range[0], val_range[1])[::date_subsample_factor] -val_tasks = gen_validation_tasks(val_dates) - - -# ## Training with the Trainer class - - - - -def compute_val_rmse(model, val_tasks): - errors = [] - target_var_ID = task_loader.target_var_IDs[0][0] # assume 1st target set and 1D - for task in val_tasks: - mean = data_processor.map_array(model.mean(task), target_var_ID, unnorm=True) - true = data_processor.map_array(task["Y_t"][0], target_var_ID, unnorm=True) - errors.extend(np.abs(mean - true)) - return np.sqrt(np.mean(np.concatenate(errors) ** 2)) - - - - -num_epochs = 50 -losses = [] -val_rmses = [] - -# # Train model -val_rmse_best = np.inf -trainer = Trainer(model, lr=5e-5) -for epoch in tqdm(range(num_epochs)): - train_tasks = gen_training_tasks(pd.date_range(train_range[0], train_range[1])[::date_subsample_factor], progress=False) - batch_losses = trainer(train_tasks) - losses.append(np.mean(batch_losses)) - val_rmses.append(compute_val_rmse(model, val_tasks)) - if val_rmses[-1] < val_rmse_best: - val_rmse_best = val_rmses[-1] - model.save(deepsensor_folder) - - - - -fig, axes = plt.subplots(1, 2, figsize=(12, 4)) -axes[0].plot(losses) -axes[1].plot(val_rmses) -_ = axes[0].set_xlabel("Epoch") -_ = axes[1].set_xlabel("Epoch") -_ = axes[0].set_title("Training loss") -_ = axes[1].set_title("Validation RMSE") - -fig.savefig(os.path.join(deepsensor_folder, "patchwise_training_loss.png")) - - -# prediction with patches ON-GRID, select one data from the validation tasks -# generate patchwise tasks for a specific date -# pick a random date as datetime64[ns] - -dates = [np.datetime64("2019-06-25")] -eval_task = gen_validation_tasks(dates, progress=False) -# test_task = task_loader(date, [100, "all", "all"], seed_override=42) -pred = model.predict_patch(eval_task, data_processor=data_processor, stride_size=(1, 1), patch_size=(0.5, 0.5), X_t=era5_raw_ds, resolution_factor=2) - -import pdb -pdb.set_trace() - -fig = deepsensor.plot.prediction(pred, dates[0], data_processor, task_loader, eval_task[0], crs=ccrs.PlateCarree()) -fig.savefig(os.path.join(deepsensor_folder, "patchwise_prediction.png")) - -print(0) - From f2bd5bb558691502634135ddbb8bac7af6845c36 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Fri, 9 Aug 2024 15:29:47 +0100 Subject: [PATCH 44/69] fix patchwise training tests --- tests/test_training.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_training.py b/tests/test_training.py index 2ab63716..72e6be5a 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -115,9 +115,9 @@ def test_training(self): loss = np.mean(epoch_losses) self.assertFalse(np.isnan(loss)) - def test_patch_wise_training(self): + def test_patchwise_training(self): """ - Test model training with patch-wise tasks. + Test model training with patchwise tasks. """ tl = TaskLoader(context=self.da, target=self.da) model = ConvNP(self.data_processor, tl, unet_channels=(5, 5, 5), verbose=False) @@ -125,7 +125,7 @@ def test_patch_wise_training(self): # generate training tasks n_train_dates = 10 dates = [np.random.choice(self.da.time.values) for i in range(n_train_dates)] - train_tasks = tl.generate_tasks( + train_tasks = tl( dates, context_sampling="all", target_sampling="all", @@ -159,7 +159,7 @@ def test_sliding_window_training(self): # generate training tasks n_train_dates = 3 dates = [np.random.choice(self.da.time.values) for i in range(n_train_dates)] - train_tasks = tl.generate_tasks( + train_tasks = tl( dates, context_sampling="all", target_sampling="all", From 601102e113be476e78517403340d9bdcecad3ce4 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Fri, 9 Aug 2024 15:36:31 +0100 Subject: [PATCH 45/69] add actual training step to test_sliding_window_training --- tests/test_training.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_training.py b/tests/test_training.py index 72e6be5a..3d4c0388 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -172,6 +172,14 @@ def test_sliding_window_training(self): trainer = Trainer(model, lr=5e-5) batch_size = None n_epochs = 2 + epoch_losses = [] + for epoch in tqdm(range(n_epochs)): + batch_losses = trainer(train_tasks, batch_size=batch_size) + epoch_losses.append(np.mean(batch_losses)) + + # Check for NaNs in the loss + loss = np.mean(epoch_losses) + self.assertFalse(np.isnan(loss)) def test_training_multidim(self): """A basic test of the training loop with multidimensional context sets""" From 64773fc1dc24d5759ad1e79949ac0e446ebd2ea7 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Fri, 9 Aug 2024 17:14:07 +0100 Subject: [PATCH 46/69] try to make printing work for task objects with bbox attribute --- deepsensor/data/task.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/deepsensor/data/task.py b/deepsensor/data/task.py index 2725b2cb..aa7badcc 100644 --- a/deepsensor/data/task.py +++ b/deepsensor/data/task.py @@ -31,7 +31,9 @@ def __init__(self, task_dict: dict) -> None: @classmethod def summarise_str(cls, k, v): - if plum.isinstance(v, B.Numeric): + if isinstance(v, float): + return v + elif plum.isinstance(v, B.Numeric): return v.shape elif plum.isinstance(v, tuple): return tuple(vi.shape for vi in v) @@ -58,6 +60,8 @@ def summarise_repr(cls, k, v) -> str: """ if v is None: return "None" + elif isinstance(v, float): + return f"{type(v).__name__}" elif plum.isinstance(v, B.Numeric): return f"{type(v).__name__}/{v.dtype}/{v.shape}" if plum.isinstance(v, deepsensor.backend.nps.mask.Masked): From 69f0ac605c6d999f787f6a39f9e6b59f297c4f9a Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Fri, 9 Aug 2024 17:18:09 +0100 Subject: [PATCH 47/69] run black --- deepsensor/data/loader.py | 36 ++--- deepsensor/model/model.py | 299 +++++++++++++++++++++++++------------- tests/test_model.py | 37 +++-- 3 files changed, 235 insertions(+), 137 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 8f06279a..6db69a85 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -908,7 +908,7 @@ def spatial_slice_variable(self, var, window: List[float]): Returns: var (...) Sliced variable. - + Raises: ValueError If the variable is of an unknown type. @@ -1369,7 +1369,7 @@ def sample_sliding_window( # define stride length in x1/x2 or set to patch_size if undefined if stride is None: stride = patch_size - + dy, dx = stride # Calculate the global bounds of context and target set. @@ -1496,7 +1496,7 @@ def __call__( if isinstance(patch_size, float) and patch_size is not None: patch_size = (patch_size, patch_size) - + if isinstance(stride, float) and stride is not None: stride = (stride, stride) @@ -1524,7 +1524,7 @@ def __call__( ) elif patch_strategy == "random": - + assert ( patch_size is not None ), "Patch size must be specified for random patch sampling" @@ -1550,21 +1550,21 @@ def __call__( else: bboxes = [ - self.sample_random_window(patch_size) - for _ in range(num_samples_per_date) - ] + self.sample_random_window(patch_size) + for _ in range(num_samples_per_date) + ] tasks = [ - self.task_generation( - date, - bbox=bbox, - context_sampling=context_sampling, - target_sampling=target_sampling, - split_frac=split_frac, - datewise_deterministic=datewise_deterministic, - seed_override=seed_override, - ) - for bbox in bboxes - ] + self.task_generation( + date, + bbox=bbox, + context_sampling=context_sampling, + target_sampling=target_sampling, + split_frac=split_frac, + datewise_deterministic=datewise_deterministic, + seed_override=seed_override, + ) + for bbox in bboxes + ] elif patch_strategy == "sliding": # sliding window sampling of patch diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 2d39cc91..91b04966 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -620,7 +620,6 @@ def unnormalise_pred_array(arr, **kwargs): return pred - def predict_patch( self, tasks: Union[List[Task], Task], @@ -654,7 +653,6 @@ def predict_patch( append_indexes: dict = None, progress_bar: int = 0, verbose: bool = False, - ) -> Prediction: """ Predict on a regular grid or at off-grid locations. @@ -735,75 +733,100 @@ def predict_patch( ValueError If ``append_indexes`` are not all the same length as ``X_t``. """ - + def get_patches_per_row(preds) -> int: """ - Calculate number of patches per row. - Required to stitch patches back together. + Calculate number of patches per row. + Required to stitch patches back together. Args: preds (List[class:`~.model.pred.Prediction`]): A list of `dict`-like objects containing patchwise predictions. - + Returns: patches_per_row: int Number of patches per row. - """ + """ patches_per_row = 0 vars = list(preds[0][0].data_vars) - var = vars[0] - y_val = preds[0][0][var].coords[unnorm_coord_names['x1']].min() - + var = vars[0] + y_val = preds[0][0][var].coords[unnorm_coord_names["x1"]].min() + for p in preds: - if p[0][var].coords[unnorm_coord_names['x1']].min() == y_val: - patches_per_row = patches_per_row + 1 + if p[0][var].coords[unnorm_coord_names["x1"]].min() == y_val: + patches_per_row = patches_per_row + 1 return patches_per_row - - def get_patch_overlap(overlap_norm, data_processor, X_t_ds) -> int: """ - Calculate overlap between adjacent patches in pixels. - + Calculate overlap between adjacent patches in pixels. + Parameters ---------- - overlap_norm : tuple[float]. + overlap_norm : tuple[float]. Normalised size of overlap in x1/x2. - + data_processor (:class:`~.data.processor.DataProcessor`): - Used for unnormalising the coordinates of the bounding boxes of patches. + Used for unnormalising the coordinates of the bounding boxes of patches. X_t_ds (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`): - Data array containing target locations to predict at. - + Data array containing target locations to predict at. + Returns ------- patch_overlap : tuple (int) - Unnormalised size of overlap between adjacent patches. + Unnormalised size of overlap between adjacent patches. """ # Place stride and patch size values in Xarray to pass into unnormalise() overlap_list = [0, overlap_norm[0], 0, overlap_norm[1]] - x1 = xr.DataArray([overlap_list[0], overlap_list[1]], dims='x1', name='x1') - x2 = xr.DataArray([overlap_list[2], overlap_list[3]], dims='x2', name='x2') - overlap_norm_xr = xr.Dataset(coords={'x1': x1, 'x2': x2}) - + x1 = xr.DataArray([overlap_list[0], overlap_list[1]], dims="x1", name="x1") + x2 = xr.DataArray([overlap_list[2], overlap_list[3]], dims="x2", name="x2") + overlap_norm_xr = xr.Dataset(coords={"x1": x1, "x2": x2}) + # Unnormalise coordinates of bounding boxes overlap_unnorm_xr = data_processor.unnormalise(overlap_norm_xr) - unnorm_overlap_x1 = overlap_unnorm_xr.coords[unnorm_coord_names['x1']].values[1] - unnorm_overlap_x2 = overlap_unnorm_xr.coords[unnorm_coord_names['x2']].values[1] + unnorm_overlap_x1 = overlap_unnorm_xr.coords[ + unnorm_coord_names["x1"] + ].values[1] + unnorm_overlap_x2 = overlap_unnorm_xr.coords[ + unnorm_coord_names["x2"] + ].values[1] # Find the position of these indices within the DataArray - x_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[unnorm_coord_names['x1']].values - unnorm_overlap_x1))/2))) - y_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[unnorm_coord_names['x2']].values - unnorm_overlap_x2))/2))) + x_overlap_index = int( + np.ceil( + ( + np.argmin( + np.abs( + X_t_ds.coords[unnorm_coord_names["x1"]].values + - unnorm_overlap_x1 + ) + ) + / 2 + ) + ) + ) + y_overlap_index = int( + np.ceil( + ( + np.argmin( + np.abs( + X_t_ds.coords[unnorm_coord_names["x2"]].values + - unnorm_overlap_x2 + ) + ) + / 2 + ) + ) + ) xy_overlap = (x_overlap_index, y_overlap_index) return xy_overlap - - def get_index(*args, x1 = True) -> Union[int, Tuple[List[int], List[int]]]: + def get_index(*args, x1=True) -> Union[int, Tuple[List[int], List[int]]]: """ Convert coordinates into pixel row/column (index). - + Parameters ---------- args : tuple @@ -812,68 +835,96 @@ def get_index(*args, x1 = True) -> Union[int, Tuple[List[int], List[int]]]: x1 : bool, optional If True, compute index for x1 (default is True). - + Returns ------- Union[int, Tuple[List[int], List[int]]] If one argument is provided and x1 is True or False, returns the index position. If two arguments are provided, returns a tuple containing two lists: - First list: indices corresponding to x1 coordinates. - - Second list: indices corresponding to x2 coordinates. + - Second list: indices corresponding to x2 coordinates. """ if len(args) == 1: patch_coord = args if x1: - coord_index = np.argmin(np.abs(X_t.coords[unnorm_coord_names['x1']].values - patch_coord)) + coord_index = np.argmin( + np.abs( + X_t.coords[unnorm_coord_names["x1"]].values - patch_coord + ) + ) else: - coord_index = np.argmin(np.abs(X_t.coords[unnorm_coord_names['x2']].values - patch_coord)) + coord_index = np.argmin( + np.abs( + X_t.coords[unnorm_coord_names["x2"]].values - patch_coord + ) + ) return coord_index elif len(args) == 2: - patch_x1, patch_x2 = args - x1_index = [np.argmin(np.abs(X_t.coords[unnorm_coord_names['x1']].values - target_x1)) for target_x1 in patch_x1] - x2_index = [np.argmin(np.abs(X_t.coords[unnorm_coord_names['x2']].values - target_x2)) for target_x2 in patch_x2] + patch_x1, patch_x2 = args + x1_index = [ + np.argmin( + np.abs(X_t.coords[unnorm_coord_names["x1"]].values - target_x1) + ) + for target_x1 in patch_x1 + ] + x2_index = [ + np.argmin( + np.abs(X_t.coords[unnorm_coord_names["x2"]].values - target_x2) + ) + for target_x2 in patch_x2 + ] return (x1_index, x2_index) - - - def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> dict: + + def stitch_clipped_predictions( + patch_preds, patch_overlap, patches_per_row + ) -> dict: """ - Stitch patchwise predictions to form prediction at original extent. + Stitch patchwise predictions to form prediction at original extent. Parameters ---------- patch_preds : list (class:`~.model.pred.Prediction`) List of patchwise predictions - + patch_overlap: int Overlap between adjacent patches in pixels. - + patches_per_row: int Number of patchwise predictions in each row. - + Returns ------- combined: dict Dictionary object containing the stitched model predictions. """ - - - data_x1 = X_t.coords[unnorm_coord_names['x1']].min().values, X_t.coords[unnorm_coord_names['x1']].max().values - data_x2 = X_t.coords[unnorm_coord_names['x2']].min().values, X_t.coords[unnorm_coord_names['x2']].max().values + data_x1 = ( + X_t.coords[unnorm_coord_names["x1"]].min().values, + X_t.coords[unnorm_coord_names["x1"]].max().values, + ) + data_x2 = ( + X_t.coords[unnorm_coord_names["x2"]].min().values, + X_t.coords[unnorm_coord_names["x2"]].max().values, + ) data_x1_index, data_x2_index = get_index(data_x1, data_x2) patches_clipped = {var_name: [] for var_name in patch_preds[0].keys()} - for i, patch_pred in enumerate(patch_preds): - for var_name, data_array in patch_pred.items(): #previously patch + for var_name, data_array in patch_pred.items(): # previously patch if var_name in patch_pred: # Get row/col index values of each patch - patch_x1 = data_array.coords[unnorm_coord_names['x1']].min().values, data_array.coords[unnorm_coord_names['x1']].max().values - patch_x2 = data_array.coords[unnorm_coord_names['x2']].min().values, data_array.coords[unnorm_coord_names['x2']].max().values - patch_x1_index, patch_x2_index = get_index(patch_x1, patch_x2) - + patch_x1 = ( + data_array.coords[unnorm_coord_names["x1"]].min().values, + data_array.coords[unnorm_coord_names["x1"]].max().values, + ) + patch_x2 = ( + data_array.coords[unnorm_coord_names["x2"]].min().values, + data_array.coords[unnorm_coord_names["x2"]].max().values, + ) + patch_x1_index, patch_x2_index = get_index(patch_x1, patch_x2) + b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0] b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1] # Do not remove border for the patches along top and left of dataset @@ -882,58 +933,90 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d b_x2_min = 0 elif patch_x2_index[1] == data_x2_index[1]: b_x2_max = 0 - patch_row_prev = preds[i-1] - prev_patch_x2_max = get_index(int(patch_row_prev[var_name].coords[unnorm_coord_names['x2']].max()), x1 = False) - b_x2_min = (prev_patch_x2_max - patch_x2_index[0])-patch_overlap[1] + patch_row_prev = preds[i - 1] + prev_patch_x2_max = get_index( + int( + patch_row_prev[var_name] + .coords[unnorm_coord_names["x2"]] + .max() + ), + x1=False, + ) + b_x2_min = ( + prev_patch_x2_max - patch_x2_index[0] + ) - patch_overlap[1] if patch_x1_index[0] == data_x1_index[0]: b_x1_min = 0 - elif abs(patch_x1_index[1] - data_x1_index[1])<2: + elif abs(patch_x1_index[1] - data_x1_index[1]) < 2: b_x1_max = 0 - patch_prev = preds[i-patches_per_row] - prev_patch_x1_max = get_index(int(patch_prev[var_name].coords[unnorm_coord_names['x1']].max()), x1 = True) - b_x1_min = (prev_patch_x1_max - patch_x1_index[0])- patch_overlap[0] + patch_prev = preds[i - patches_per_row] + prev_patch_x1_max = get_index( + int( + patch_prev[var_name] + .coords[unnorm_coord_names["x1"]] + .max() + ), + x1=True, + ) + b_x1_min = ( + prev_patch_x1_max - patch_x1_index[0] + ) - patch_overlap[0] patch_clip_x1_min = int(b_x1_min) - patch_clip_x1_max = int(data_array.sizes[unnorm_coord_names['x1']] - b_x1_max) + patch_clip_x1_max = int( + data_array.sizes[unnorm_coord_names["x1"]] - b_x1_max + ) patch_clip_x2_min = int(b_x2_min) - patch_clip_x2_max = int(data_array.sizes[unnorm_coord_names['x2']] - b_x2_max) + patch_clip_x2_max = int( + data_array.sizes[unnorm_coord_names["x2"]] - b_x2_max + ) - patch_clip = data_array[{unnorm_coord_names['x1']: slice(patch_clip_x1_min, patch_clip_x1_max), - unnorm_coord_names['x2']: slice(patch_clip_x2_min, patch_clip_x2_max)}] + patch_clip = data_array[ + { + unnorm_coord_names["x1"]: slice( + patch_clip_x1_min, patch_clip_x1_max + ), + unnorm_coord_names["x2"]: slice( + patch_clip_x2_min, patch_clip_x2_max + ), + } + ] patches_clipped[var_name].append(patch_clip) - combined = {var_name: xr.combine_by_coords(patches, compat='no_conflicts') for var_name, patches in patches_clipped.items()} + combined = { + var_name: xr.combine_by_coords(patches, compat="no_conflicts") + for var_name, patches in patches_clipped.items() + } return combined # sanitise patch_size and stride arguments - + if isinstance(patch_size, float) and patch_size is not None: patch_size = (patch_size, patch_size) - + if isinstance(stride, float) and stride is not None: stride = (stride, stride) if stride[0] > patch_size[0] or stride[1] > patch_size[1]: raise ValueError( f"stride must be smaller than patch_size in the corresponding dimensions. Got: patch_size: {patch_size}, stride: {stride}" - ) + ) for val in list(stride + patch_size): - if val>1.0 or val<0.0: + if val > 1.0 or val < 0.0: raise ValueError( f"Values of stride and patch_size must be between 0 & 1. Got: patch_size: {patch_size}, stride: {stride}" ) - - + # Get coordinate names of original unnormalised dataset. unnorm_coord_names = { - "x1": self.data_processor.raw_spatial_coord_names[0], - "x2": self.data_processor.raw_spatial_coord_names[1], - } - + "x1": self.data_processor.raw_spatial_coord_names[0], + "x2": self.data_processor.raw_spatial_coord_names[1], + } + # tasks should be iterable, if only one is provided, make it a list if type(tasks) is Task: tasks = [tasks] @@ -941,23 +1024,31 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d # Perform patchwise predictions preds = [] for task in tasks: - bbox = task['bbox'] - + bbox = task["bbox"] + if bbox is None: - raise AttributeError("Tasks require non-None ``bbox`` for patchwise inference.") + raise AttributeError( + "Tasks require non-None ``bbox`` for patchwise inference." + ) # Unnormalise coordinates of bounding box of patch - x1 = xr.DataArray([bbox[0], bbox[1]], dims='x1', name='x1') - x2 = xr.DataArray([bbox[2], bbox[3]], dims='x2', name='x2') - bbox_norm = xr.Dataset(coords={'x1': x1, 'x2': x2}) + x1 = xr.DataArray([bbox[0], bbox[1]], dims="x1", name="x1") + x2 = xr.DataArray([bbox[2], bbox[3]], dims="x2", name="x2") + bbox_norm = xr.Dataset(coords={"x1": x1, "x2": x2}) bbox_unnorm = data_processor.unnormalise(bbox_norm) - unnorm_bbox_x1 = bbox_unnorm[unnorm_coord_names['x1']].values.min(), bbox_unnorm[unnorm_coord_names['x1']].values.max() - unnorm_bbox_x2 = bbox_unnorm[unnorm_coord_names['x2']].values.min(), bbox_unnorm[unnorm_coord_names['x2']].values.max() - + unnorm_bbox_x1 = ( + bbox_unnorm[unnorm_coord_names["x1"]].values.min(), + bbox_unnorm[unnorm_coord_names["x1"]].values.max(), + ) + unnorm_bbox_x2 = ( + bbox_unnorm[unnorm_coord_names["x2"]].values.min(), + bbox_unnorm[unnorm_coord_names["x2"]].values.max(), + ) + # Determine X_t for patch task_extent_dict = { - unnorm_coord_names['x1']: slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1]), - unnorm_coord_names['x2']: slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1]) + unnorm_coord_names["x1"]: slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1]), + unnorm_coord_names["x2"]: slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1]), } task_X_t = X_t.sel(**task_extent_dict) @@ -965,34 +1056,44 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d pred = self.predict(task, task_X_t) # Append patchwise DeepSensor prediction object to list preds.append(pred) - - overlap_norm = tuple(patch - stride for patch, stride in zip(patch_size, stride)) + + overlap_norm = tuple( + patch - stride for patch, stride in zip(patch_size, stride) + ) patch_overlap_unnorm = get_patch_overlap(overlap_norm, data_processor, X_t) patches_per_row = get_patches_per_row(preds) - stitched_prediction = stitch_clipped_predictions(preds, patch_overlap_unnorm, patches_per_row) - + stitched_prediction = stitch_clipped_predictions( + preds, patch_overlap_unnorm, patches_per_row + ) + ## Cast prediction into DeepSensor.Prediction object. - # TODO make this into seperate method. - prediction= copy.deepcopy(preds[0]) + # TODO make this into seperate method. + prediction = copy.deepcopy(preds[0]) # Generate new blank DeepSensor.prediction object in original coordinate system. for var_name_copy, data_array_copy in prediction.items(): # set x and y coords - stitched_preds = xr.Dataset(coords={'x1': X_t[unnorm_coord_names['x1']], 'x2': X_t[unnorm_coord_names['x2']]}) + stitched_preds = xr.Dataset( + coords={ + "x1": X_t[unnorm_coord_names["x1"]], + "x2": X_t[unnorm_coord_names["x2"]], + } + ) # Set time to same as patched prediction - stitched_preds['time'] = data_array_copy['time'] + stitched_preds["time"] = data_array_copy["time"] # set variable names to those in patched prediction, make values blank for var_name_i in data_array_copy.data_vars: stitched_preds[var_name_i] = data_array_copy[var_name_i] stitched_preds[var_name_i][:] = np.nan - prediction[var_name_copy]= stitched_preds + prediction[var_name_copy] = stitched_preds prediction[var_name_copy] = stitched_prediction[var_name_copy] return prediction + def main(): # pragma: no cover import deepsensor.tensorflow from deepsensor.data.loader import TaskLoader diff --git a/tests/test_model.py b/tests/test_model.py index bac07406..5e6c7f6a 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -542,12 +542,12 @@ def test_patchwise_prediction(self): model = ConvNP(self.dp, tl) pred = model.predict_patch( - tasks=task, - X_t=self.da, - data_processor=self.dp, - stride=stride_size, - patch_size=patch_size, - ) + tasks=task, + X_t=self.da, + data_processor=self.dp, + stride=stride_size, + patch_size=patch_size, + ) # gridded predictions assert [isinstance(ds, xr.Dataset) for ds in pred.values()] @@ -560,20 +560,17 @@ def test_patchwise_prediction(self): pred[var_ID]["std"], (1, self.da.x1.size, self.da.x2.size), ) - assert( - self.da.x1.size == pred[var_ID].x1.size - ) - assert( - self.da.x2.size == pred[var_ID].x2.size - ) - - - @parameterized.expand([ - ((0.5, 0.5), (0.6, 0.6)), # patch_size and stride as tuples - (0.5, 0.6), # as floats - (1.0, 1.2), # one argument above allowed range - (-0.1, 0.6) # and below allowed range - ]) + assert self.da.x1.size == pred[var_ID].x1.size + assert self.da.x2.size == pred[var_ID].x2.size + + @parameterized.expand( + [ + ((0.5, 0.5), (0.6, 0.6)), # patch_size and stride as tuples + (0.5, 0.6), # as floats + (1.0, 1.2), # one argument above allowed range + (-0.1, 0.6), # and below allowed range + ] + ) def test_patchwise_prediction_parameter_handling(self, patch_size, stride_size): """Test that correct errors and warnings are raised by ``.predict_patch``.""" From f5e4a8aa4cf36ff7cdd8291c558d8ca70c448178 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Mon, 12 Aug 2024 10:19:26 +0100 Subject: [PATCH 48/69] re-add missing code from task loader --- deepsensor/data/loader.py | 42 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 6a59c406..2bfcdf2a 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1297,6 +1297,48 @@ def sample_variable(var, sampling_strat, seed): f"with the `links` attribute if using the 'gapfill' sampling strategy" ) + context_var = context_slices[context_idx] + target_var = target_slices[target_idx] + + for var in [context_var, target_var]: + assert isinstance(var, (xr.DataArray, xr.Dataset)), ( + f"If using 'gapfill' sampling strategy for linked context and target sets, " + f"the context and target sets must be xarray DataArrays or Datasets, " + f"but got {type(var)}." + ) + + split_seed = seed + gapfill_i if seed is not None else None + rng = np.random.default_rng(split_seed) + + # Keep trying until we get a target set with at least one target point + keep_searching = True + while keep_searching: + added_mask_date = rng.choice(self.context[context_idx].time) + added_mask = ( + self.context[context_idx].sel(time=added_mask_date).isnull() + ) + curr_mask = context_var.isnull() + + # Mask out added missing values + context_var = context_var.where(~added_mask) + + # TEMP: Inefficient to convert all non-targets to NaNs and then remove NaNs + # when we could just slice the target values here + target_mask = added_mask & ~curr_mask + if isinstance(target_var, xr.Dataset): + keep_searching = np.all(target_mask.to_array().data == False) + else: + keep_searching = np.all(target_mask.data == False) + if keep_searching: + continue # No target points -- use a different `added_mask` + + target_var = target_var.where( + target_mask + ) # Only keep target locations + + context_slices[context_idx] = context_var + target_slices[target_idx] = target_var + for i, (var, sampling_strat) in enumerate( zip(context_slices, context_sampling) ): From 5e29031e9dd2be185a7a4688e2350435c1265842 Mon Sep 17 00:00:00 2001 From: Martin Rogers Date: Mon, 12 Aug 2024 13:49:38 +0100 Subject: [PATCH 49/69] Commit to allow patching irrespective of whether x1 and x2 are ascending/descending --- deepsensor/model/model.py | 115 ++++++++++++++++++++++++++++---------- 1 file changed, 86 insertions(+), 29 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 3aee64cb..fc08ed68 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -761,12 +761,12 @@ def get_patches_per_row(preds) -> int: patches_per_row = 0 vars = list(preds[0][0].data_vars) var = vars[0] - y_val = preds[0][0][var].coords[orig_x2_name].min() + x1_val = preds[0][0][var].coords[orig_x1_name].min() for p in preds: - if p[0][var].coords[orig_x2_name].min() == y_val: + if p[0][var].coords[orig_x1_name].min() == x1_val: patches_per_row = patches_per_row + 1 - + print("patches_per_row", patches_per_row) return patches_per_row @@ -791,7 +791,7 @@ def get_patch_overlap(overlap_norm, data_processor, X_t_ds) -> int: patch_overlap : tuple (int) Unnormalised size of overlap between adjacent patches. """ - # Place stride and patch size values in Xarray to pass into unnormalise() + # Place x1/x2 overlap values in Xarray to pass into unnormalise() overlap_list = [0, overlap_norm[0], 0, overlap_norm[1]] x1 = xr.DataArray([overlap_list[0], overlap_list[1]], dims='x1', name='x1') x2 = xr.DataArray([overlap_list[2], overlap_list[3]], dims='x2', name='x2') @@ -799,14 +799,16 @@ def get_patch_overlap(overlap_norm, data_processor, X_t_ds) -> int: # Unnormalise coordinates of bounding boxes overlap_unnorm_xr = data_processor.unnormalise(overlap_norm_xr) + #print('intermediary unnorm overlap value', overlap_unnorm_xr) unnorm_overlap_x1 = overlap_unnorm_xr.coords[orig_x1_name].values[1] unnorm_overlap_x2 = overlap_unnorm_xr.coords[orig_x2_name].values[1] - + #print('intermediary unnorm overlap value2:', unnorm_overlap_x1, unnorm_overlap_x2) # Find the position of these indices within the DataArray - x_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[orig_x1_name].values - unnorm_overlap_x1))/2))) - y_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[orig_x2_name].values - unnorm_overlap_x2))/2))) - xy_overlap = (x_overlap_index, y_overlap_index) - return xy_overlap + x1_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[orig_x1_name].values - unnorm_overlap_x1))/2))) + x2_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[orig_x2_name].values - unnorm_overlap_x2))/2))) + #print('intermediary unnorm overlap value2:', x1_overlap_index, x2_overlap_index, X_t_ds.coords[orig_x1_name].values, X_t_ds.coords[orig_x2_name].values) + x1x2_overlap = (x1_overlap_index, x2_overlap_index) + return x1x2_overlap def get_index(*args, x1 = True) -> Union[int, Tuple[List[int], List[int]]]: """ @@ -833,9 +835,9 @@ def get_index(*args, x1 = True) -> Union[int, Tuple[List[int], List[int]]]: if len(args) == 1: patch_coord = args if x1: - coord_index = np.argmin(np.abs(X_t.coords[orig_x2_name].values - patch_coord)) + coord_index = np.argmin(np.abs(X_t.coords[orig_x1_name].values - patch_coord)) else: - coord_index = np.argmin(np.abs(X_t.coords[orig_x1_name].values - patch_coord)) + coord_index = np.argmin(np.abs(X_t.coords[orig_x2_name].values - patch_coord)) return coord_index elif len(args) == 2: @@ -845,7 +847,7 @@ def get_index(*args, x1 = True) -> Union[int, Tuple[List[int], List[int]]]: return (x1_index, x2_index) - def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> dict: + def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_ascend=True, x2_ascend=True) -> dict: """ Stitch patchwise predictions to form prediction at original extent. @@ -866,49 +868,98 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d Dictionary object containing the stitched model predictions. """ + # Get row/col index values of X_t. Order depends on whether coordinate is ascending or descending. + if x1_ascend: + data_x1 = X_t.coords[orig_x1_name].min().values, X_t.coords[orig_x1_name].max().values + else: + data_x1 = X_t.coords[orig_x1_name].max().values, X_t.coords[orig_x1_name].min().values + if x2_ascend: + data_x2 = X_t.coords[orig_x2_name].min().values, X_t.coords[orig_x2_name].max().values + else: + data_x2 = X_t.coords[orig_x2_name].max().values, X_t.coords[orig_x2_name].min().values - - data_x1 = X_t.coords[orig_x2_name].min().values, X_t.coords[orig_x2_name].max().values - data_x2 = X_t.coords[orig_x1_name].min().values, X_t.coords[orig_x1_name].max().values data_x1_index, data_x2_index = get_index(data_x1, data_x2) + print('coords of X_t', data_x1[0].item(), data_x1[1].item(), data_x2[0].item(), data_x2[1].item()) + print('row and column values of X_t', data_x1_index, data_x2_index ) patches_clipped = {var_name: [] for var_name in patch_preds[0].keys()} - + print('coords ascending', x1_ascend, x2_ascend) for i, patch_pred in enumerate(patch_preds): for var_name, data_array in patch_pred.items(): #previously patch if var_name in patch_pred: - # Get row/col index values of each patch - patch_x1 = data_array.coords[orig_x2_name].min().values, data_array.coords[orig_x2_name].max().values - patch_x2 = data_array.coords[orig_x1_name].min().values, data_array.coords[orig_x1_name].max().values + # Get row/col index values of each patch. Order depends on whether coordinate is ascending or descending. + if x1_ascend: + patch_x1 = data_array.coords[orig_x1_name].min().values, data_array.coords[orig_x1_name].max().values + else: + patch_x1 = data_array.coords[orig_x1_name].max().values, data_array.coords[orig_x1_name].min().values + if x2_ascend: + patch_x2 = data_array.coords[orig_x2_name].min().values, data_array.coords[orig_x2_name].max().values + else: + patch_x2 = data_array.coords[orig_x2_name].max().values, data_array.coords[orig_x2_name].min().values patch_x1_index, patch_x2_index = get_index(patch_x1, patch_x2) + print('coords of patch', patch_x1[0].item(), patch_x1[1].item(), patch_x2[0].item(), patch_x2[1].item()) + print('row and column values of patch', patch_x1_index, patch_x2_index) b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0] b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1] # Do not remove border for the patches along top and left of dataset # and change overlap size for last patch in each row and column. + """ + At end of row (when patch_x2_index = data_x2_index), to calculate the number of pixels to remove from left hand side of patch: + If x2 is ascending, subtract previous patch x2 max value from current patch x2 min value to get bespoke overlap in column pixels. + To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels + to get the number of pixels to remove from left hand side of patch. + + If x2 is descending. Subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels. + To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels + to get the number of pixels to remove from left hand side of patch. + + """ if patch_x2_index[0] == data_x2_index[0]: b_x2_min = 0 elif patch_x2_index[1] == data_x2_index[1]: b_x2_max = 0 patch_row_prev = preds[i-1] - prev_patch_x2_max = get_index(int(patch_row_prev[var_name].coords[orig_x1_name].max()), x1 = False) - b_x2_min = (prev_patch_x2_max - patch_x2_index[0])-patch_overlap[1] + if x2_ascend: + prev_patch_x2_max = get_index(int(patch_row_prev[var_name].coords[orig_x2_name].max()), x1 = False) + b_x2_min = (prev_patch_x2_max - patch_x2_index[0])-patch_overlap[1] + else: + prev_patch_x2_min = get_index(int(patch_row_prev[var_name].coords[orig_x2_name].min()), x1 = False) + b_x2_min = (patch_x2_index[0] -prev_patch_x2_min)-patch_overlap[1] if patch_x1_index[0] == data_x1_index[0]: b_x1_min = 0 elif abs(patch_x1_index[1] - data_x1_index[1])<2: b_x1_max = 0 patch_prev = preds[i-patches_per_row] - prev_patch_x1_max = get_index(int(patch_prev[var_name].coords[orig_x2_name].max()), x1 = True) - b_x1_min = (prev_patch_x1_max - patch_x1_index[0])- patch_overlap[0] + if x1_ascend: + prev_patch_x1_max = get_index(int(patch_prev[var_name].coords[orig_x1_name].max()), x1 = True) + b_x1_min = (prev_patch_x1_max - patch_x1_index[0])- patch_overlap[0] + else: + prev_patch_x1_min = get_index(int(patch_prev[var_name].coords[orig_x1_name].min()), x1 = True) + b_x1_min = (patch_x1_index[0] - prev_patch_x1_min)- patch_overlap[0] patch_clip_x1_min = int(b_x1_min) - patch_clip_x1_max = int(data_array.sizes[orig_x2_name] - b_x1_max) + patch_clip_x1_max = int(data_array.sizes[orig_x1_name] - b_x1_max) patch_clip_x2_min = int(b_x2_min) - patch_clip_x2_max = int(data_array.sizes[orig_x1_name] - b_x2_max) + patch_clip_x2_max = int(data_array.sizes[orig_x2_name] - b_x2_max) + """ + if x1_ascend: + patch_clip_x1_max = int(data_array.sizes[orig_x1_name] - b_x1_max) + else: + patch_clip_x1_max = int(data_array.sizes[orig_x1_name] - b_x1_max) + patch_clip_x2_min = int(b_x2_min) + if x2_ascend: + patch_clip_x2_max = int(data_array.sizes[orig_x2_name] - b_x2_max) + else: + patch_clip_x2_max = int(b_x2_max - data_array.sizes[orig_x2_name]) + """ + print('x1 and x2 sizes', data_array.sizes[orig_x1_name], data_array.sizes[orig_x2_name]) # patch_clip = data_array.isel(y=slice(patch_clip_x1_min, patch_clip_x1_max), # x=slice(patch_clip_x2_min, patch_clip_x2_max)) + patch_clip_x1_min = int(b_x1_min) + print('final clip coord values', patch_clip_x1_min, patch_clip_x1_max, patch_clip_x2_min, patch_clip_x2_max) patch_clip = data_array.isel(**{orig_x1_name: slice(patch_clip_x1_min, patch_clip_x1_max), orig_x2_name: slice(patch_clip_x2_min, patch_clip_x2_max)}) @@ -939,28 +990,34 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d if x1_coords[0] < x1_coords[-1]: x1_slice = slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1]) + x1_ascending = True else: x1_slice = slice(unnorm_bbox_x1[1], unnorm_bbox_x1[0]) + x1_ascending = False if x2_coords[0] < x2_coords[-1]: x2_slice = slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1]) + x2_ascending = True else: x2_slice = slice(unnorm_bbox_x2[1], unnorm_bbox_x2[0]) + x2_ascending = False # Determine X_t for patch with correct slice direction task_X_t = X_t.sel(**{orig_x1_name: x1_slice, orig_x2_name: x2_slice}) - + # Patchwise prediction pred = self.predict(task, task_X_t) # Append patchwise DeepSensor prediction object to list preds.append(pred) - + print('first pred', preds[0]) overlap_norm = tuple(patch - stride for patch, stride in zip(patch_size, stride_size)) patch_overlap_unnorm = get_patch_overlap(overlap_norm, data_processor, X_t) - + print('pred bbox coords', overlap_norm, patch_overlap_unnorm) + patch_overlap_unnorm = (5,5) + print('pred bbox coords', overlap_norm, patch_overlap_unnorm) patches_per_row = get_patches_per_row(preds) - stitched_prediction = stitch_clipped_predictions(preds, patch_overlap_unnorm, patches_per_row) + stitched_prediction = stitch_clipped_predictions(preds, patch_overlap_unnorm, patches_per_row, x1_ascending, x2_ascending) ## Cast prediction into DeepSensor.Prediction object.orig_x2_name # Todo: make this into seperate method. From 294cc47c694c16dcd71c51cea771c206b9abe0b7 Mon Sep 17 00:00:00 2001 From: Martin Rogers Date: Mon, 12 Aug 2024 17:12:44 +0100 Subject: [PATCH 50/69] changes to loader.py to ensure all patched tasks run left to right and top to bottom --- deepsensor/data/loader.py | 104 +++++++++++++++++++++++++++++++------- deepsensor/model/model.py | 15 +++--- 2 files changed, 95 insertions(+), 24 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 6a59c406..7b809bd6 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -191,6 +191,7 @@ def __init__( ) = self.infer_context_and_target_var_IDs() self.coord_bounds = self._compute_global_coordinate_bounds() + self.coord_directions = self._compute_x1x2_direction() def _set_config(self): """Instantiate a config dictionary for the TaskLoader object""" @@ -829,6 +830,55 @@ def _compute_global_coordinate_bounds(self) -> List[float]: x2_max = var_x2_max return [x1_min, x1_max, x2_min, x2_max] + + def _compute_x1x2_direction(self) -> str: + """ + Compute whether the x1 and x2 coords are ascending or descending. + + Returns + ------- + x1_ascend: str + Boolean: If x1 coords ascend from left to right = True, if descend = False + x1_ascend: str + Boolean: If x2 coords ascend from top to bottom = True, if descend = False + """ + + for var in itertools.chain(self.context, self.target): + if isinstance(var, (xr.Dataset, xr.DataArray)): + coord_x1_left= var.x1[0] + coord_x1_right= var.x1[-1] + coord_x2_top= var.x2[0] + coord_x2_bottom= var.x2[-1] + #Todo- what to input for pd.dataframe + elif isinstance(var, (pd.DataFrame, pd.Series)): + var_x1_min = var.index.get_level_values("x1").min() + var_x1_max = var.index.get_level_values("x1").max() + var_x2_min = var.index.get_level_values("x2").min() + var_x2_max = var.index.get_level_values("x2").max() + + x1_ascend = True + x2_ascend = True + if coord_x1_left < coord_x1_right: + x1_ascend = True + print('x1 ascending') + if coord_x1_left > coord_x1_right: + x1_ascend = False + print("x1 descending") + + if coord_x2_top < coord_x2_bottom: + x2_ascend = True + print('x2 ascending') + if coord_x2_top > coord_x2_bottom: + x2_ascend = False + print("x2 descending") + + + coord_directions = { + "x1": x1_ascend, + "x2": x2_ascend, + } + + return coord_directions def sample_random_window(self, patch_size: Tuple[float]) -> Sequence[float]: """ @@ -1371,29 +1421,47 @@ def sample_sliding_window( stride = patch_size dy, dx = stride - + print('stride size', dy, dx) # Calculate the global bounds of context and target set. x1_min, x1_max, x2_min, x2_max = self.coord_bounds - + print('in sample_sliding_window', self.coord_directions) ## start with first patch top left hand corner at x1_min, x2_min patch_list = [] - for y in np.arange(x1_min, x1_max, dy): - for x in np.arange(x2_min, x2_max, dx): - if y + x1_extend > x1_max: - y0 = x1_max - x1_extend - else: - y0 = y - if x + x2_extend > x2_max: - x0 = x2_max - x2_extend - else: - x0 = x - - # bbox of x1_min, x1_max, x2_min, x2_max per patch - bbox = [y0, y0 + x1_extend, x0, x0 + x2_extend] - - patch_list.append(bbox) - + if self.coord_directions['x1'] == False and self.coord_directions['x2'] == True: + print('rocking the scenario') + for y in np.arange(x1_max, x1_min, -dy): + for x in np.arange(x2_min, x2_max, dx): + if y - x1_extend < x1_min: + y0 = x1_min + x1_extend + else: + y0 = y + if x + x2_extend > x2_max: + x0 = x2_max - x2_extend + else: + x0 = x + + # bbox of x1_min, x1_max, x2_min, x2_max per patch + bbox = [y0 - x1_extend, y0, x0, x0 + x2_extend] + patch_list.append(bbox) + else: + for y in np.arange(x1_min, x1_max, dy): + for x in np.arange(x2_min, x2_max, dx): + if y + x1_extend > x1_max: + y0 = x1_max - x1_extend + else: + y0 = y + if x + x2_extend > x2_max: + x0 = x2_max - x2_extend + else: + x0 = x + + # bbox of x1_min, x1_max, x2_min, x2_max per patch + bbox = [y0, y0 + x1_extend, x0, x0 + x2_extend] + + patch_list.append(bbox) + + print('patch list', patch_list) return patch_list def __call__( diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index fc08ed68..293e3a51 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -898,7 +898,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a patch_x2 = data_array.coords[orig_x2_name].max().values, data_array.coords[orig_x2_name].min().values patch_x1_index, patch_x2_index = get_index(patch_x1, patch_x2) - print('coords of patch', patch_x1[0].item(), patch_x1[1].item(), patch_x2[0].item(), patch_x2[1].item()) + #print('coords of patch', patch_x1[0].item(), patch_x1[1].item(), patch_x2[0].item(), patch_x2[1].item()) print('row and column values of patch', patch_x1_index, patch_x2_index) b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0] b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1] @@ -937,7 +937,9 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a b_x1_min = (prev_patch_x1_max - patch_x1_index[0])- patch_overlap[0] else: prev_patch_x1_min = get_index(int(patch_prev[var_name].coords[orig_x1_name].min()), x1 = True) - b_x1_min = (patch_x1_index[0] - prev_patch_x1_min)- patch_overlap[0] + + b_x1_min = (prev_patch_x1_min- patch_x1_index[0])- patch_overlap[0] + patch_clip_x1_min = int(b_x1_min) patch_clip_x1_max = int(data_array.sizes[orig_x1_name] - b_x1_max) @@ -955,12 +957,13 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a else: patch_clip_x2_max = int(b_x2_max - data_array.sizes[orig_x2_name]) """ - print('x1 and x2 sizes', data_array.sizes[orig_x1_name], data_array.sizes[orig_x2_name]) + #print('x1 and x2 sizes', data_array.sizes[orig_x1_name], data_array.sizes[orig_x2_name]) # patch_clip = data_array.isel(y=slice(patch_clip_x1_min, patch_clip_x1_max), # x=slice(patch_clip_x2_min, patch_clip_x2_max)) patch_clip_x1_min = int(b_x1_min) - print('final clip coord values', patch_clip_x1_min, patch_clip_x1_max, patch_clip_x2_min, patch_clip_x2_max) - + #print('final clip coord values', patch_clip_x1_min, patch_clip_x1_max, patch_clip_x2_min, patch_clip_x2_max) + print('row and column values of clipped patch', patch_x1_index[0]-int(b_x1_min) , patch_x1_index[1]+int(b_x1_max), + patch_x2_index[0]+int(b_x2_min) , patch_x2_index[1]-int(b_x2_max)) patch_clip = data_array.isel(**{orig_x1_name: slice(patch_clip_x1_min, patch_clip_x1_max), orig_x2_name: slice(patch_clip_x2_min, patch_clip_x2_max)}) patches_clipped[var_name].append(patch_clip) @@ -1014,7 +1017,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a overlap_norm = tuple(patch - stride for patch, stride in zip(patch_size, stride_size)) patch_overlap_unnorm = get_patch_overlap(overlap_norm, data_processor, X_t) print('pred bbox coords', overlap_norm, patch_overlap_unnorm) - patch_overlap_unnorm = (5,5) + patch_overlap_unnorm = (10,10) print('pred bbox coords', overlap_norm, patch_overlap_unnorm) patches_per_row = get_patches_per_row(preds) stitched_prediction = stitch_clipped_predictions(preds, patch_overlap_unnorm, patches_per_row, x1_ascending, x2_ascending) From 4e136e39bf4d77440d25e3ca342842e6c2927b01 Mon Sep 17 00:00:00 2001 From: Martin Rogers Date: Tue, 13 Aug 2024 10:27:23 +0100 Subject: [PATCH 51/69] Commit to make model agnostic to coord direction --- deepsensor/data/loader.py | 64 ++++++++++++++++++++++------ deepsensor/model/model.py | 89 +++++++++++++++++---------------------- 2 files changed, 88 insertions(+), 65 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 7b809bd6..cda26f0e 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -837,10 +837,10 @@ def _compute_x1x2_direction(self) -> str: Returns ------- - x1_ascend: str - Boolean: If x1 coords ascend from left to right = True, if descend = False - x1_ascend: str - Boolean: If x2 coords ascend from top to bottom = True, if descend = False + coord_directions: dict(str) + String containing two booleans: x1_ascend and x2_ascend, + defining if these coordings increase or decrease from top left corner. + """ for var in itertools.chain(self.context, self.target): @@ -849,7 +849,7 @@ def _compute_x1x2_direction(self) -> str: coord_x1_right= var.x1[-1] coord_x2_top= var.x2[0] coord_x2_bottom= var.x2[-1] - #Todo- what to input for pd.dataframe + #Todo- what to input for pd.dataframe elif isinstance(var, (pd.DataFrame, pd.Series)): var_x1_min = var.index.get_level_values("x1").min() var_x1_max = var.index.get_level_values("x1").max() @@ -860,17 +860,14 @@ def _compute_x1x2_direction(self) -> str: x2_ascend = True if coord_x1_left < coord_x1_right: x1_ascend = True - print('x1 ascending') if coord_x1_left > coord_x1_right: x1_ascend = False - print("x1 descending") if coord_x2_top < coord_x2_bottom: x2_ascend = True - print('x2 ascending') if coord_x2_top > coord_x2_bottom: x2_ascend = False - print("x2 descending") + coord_directions = { @@ -1421,15 +1418,13 @@ def sample_sliding_window( stride = patch_size dy, dx = stride - print('stride size', dy, dx) # Calculate the global bounds of context and target set. x1_min, x1_max, x2_min, x2_max = self.coord_bounds - print('in sample_sliding_window', self.coord_directions) ## start with first patch top left hand corner at x1_min, x2_min patch_list = [] + # Todo: simplify these elif statements if self.coord_directions['x1'] == False and self.coord_directions['x2'] == True: - print('rocking the scenario') for y in np.arange(x1_max, x1_min, -dy): for x in np.arange(x2_min, x2_max, dx): if y - x1_extend < x1_min: @@ -1444,6 +1439,38 @@ def sample_sliding_window( # bbox of x1_min, x1_max, x2_min, x2_max per patch bbox = [y0 - x1_extend, y0, x0, x0 + x2_extend] patch_list.append(bbox) + + elif self.coord_directions['x1'] == False and self.coord_directions['x2'] == False: + for y in np.arange(x1_max, x1_min, -dy): + for x in np.arange(x2_max, x2_min, -dx): + if y - x1_extend < x1_min: + y0 = x1_min + x1_extend + else: + y0 = y + if x - x2_extend < x2_min: + x0 = x2_min + x2_extend + else: + x0 = x + + # bbox of x1_min, x1_max, x2_min, x2_max per patch + bbox = [y0 - x1_extend, y0, x0 - x2_extend, x0] + patch_list.append(bbox) + + elif self.coord_directions['x1'] == True and self.coord_directions['x2'] == False: + for y in np.arange(x1_min, x1_max, dy): + for x in np.arange(x2_max, x2_min, -dx): + if y + x1_extend > x1_max: + y0 = x1_max - x1_extend + else: + y0 = y + if x - x2_extend < x2_min: + x0 = x2_min + x2_extend + else: + x0 = x + + # bbox of x1_min, x1_max, x2_min, x2_max per patch + bbox = [y0, y0 + x1_extend, x0 - x2_extend, x0] + patch_list.append(bbox) else: for y in np.arange(x1_min, x1_max, dy): for x in np.arange(x2_min, x2_max, dx): @@ -1461,8 +1488,17 @@ def sample_sliding_window( patch_list.append(bbox) - print('patch list', patch_list) - return patch_list + # Remove duplicate patches while preserving order + seen = set() + unique_patch_list = [] + for lst in patch_list: + # Convert list to tuple for immutability + tuple_lst = tuple(lst) + if tuple_lst not in seen: + seen.add(tuple_lst) + unique_patch_list.append(lst) + + return unique_patch_list def __call__( self, diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 293e3a51..327c17a5 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -735,17 +735,10 @@ def predict_patch( ValueError If ``append_indexes`` are not all the same length as ``X_t``. """ - + # Get coordinate names of original unnormalised dataset. orig_x1_name = data_processor.x1_name orig_x2_name = data_processor.x2_name - # Get coordinate names of original unnormalised dataset. - unnorm_coord_names = { - "x1": self.data_processor.raw_spatial_coord_names[0], - "x2": self.data_processor.raw_spatial_coord_names[1], - } - - def get_patches_per_row(preds) -> int: """ Calculate number of patches per row. @@ -763,15 +756,15 @@ def get_patches_per_row(preds) -> int: var = vars[0] x1_val = preds[0][0][var].coords[orig_x1_name].min() - for p in preds: - if p[0][var].coords[orig_x1_name].min() == x1_val: + for pred in preds: + if pred[0][var].coords[orig_x1_name].min() == x1_val: patches_per_row = patches_per_row + 1 - print("patches_per_row", patches_per_row) + return patches_per_row - def get_patch_overlap(overlap_norm, data_processor, X_t_ds) -> int: + def get_patch_overlap(overlap_norm, data_processor, X_t_ds, x1_ascend, x2_ascend) -> int: """ Calculate overlap between adjacent patches in pixels. @@ -786,11 +779,18 @@ def get_patch_overlap(overlap_norm, data_processor, X_t_ds) -> int: X_t_ds (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`): Data array containing target locations to predict at. + x1_ascend : str: + Boolean defining whether the x1 coords ascend (increase) from top to bottom, default = True. + + x2_ascend : str: + Boolean defining whether the x2 coords ascend (increase) from left to right, default = True. + Returns ------- patch_overlap : tuple (int) Unnormalised size of overlap between adjacent patches. """ + # Todo- check if there is simplier and more robust way to convert overlap into pixels. # Place x1/x2 overlap values in Xarray to pass into unnormalise() overlap_list = [0, overlap_norm[0], 0, overlap_norm[1]] x1 = xr.DataArray([overlap_list[0], overlap_list[1]], dims='x1', name='x1') @@ -799,16 +799,23 @@ def get_patch_overlap(overlap_norm, data_processor, X_t_ds) -> int: # Unnormalise coordinates of bounding boxes overlap_unnorm_xr = data_processor.unnormalise(overlap_norm_xr) - #print('intermediary unnorm overlap value', overlap_unnorm_xr) + unnorm_overlap_x1 = overlap_unnorm_xr.coords[orig_x1_name].values[1] unnorm_overlap_x2 = overlap_unnorm_xr.coords[orig_x2_name].values[1] - #print('intermediary unnorm overlap value2:', unnorm_overlap_x1, unnorm_overlap_x2) - # Find the position of these indices within the DataArray - x1_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[orig_x1_name].values - unnorm_overlap_x1))/2))) - x2_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[orig_x2_name].values - unnorm_overlap_x2))/2))) - #print('intermediary unnorm overlap value2:', x1_overlap_index, x2_overlap_index, X_t_ds.coords[orig_x1_name].values, X_t_ds.coords[orig_x2_name].values) - x1x2_overlap = (x1_overlap_index, x2_overlap_index) - return x1x2_overlap + + # Find size of overlap for x1/x2 in pixels + if x1_ascend: + x1_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[orig_x1_name].values - unnorm_overlap_x1))/2))) + else: + x1_overlap_index = int(np.floor((X_t_ds.coords[orig_x1_name].values.size- int(np.ceil((np.argmin(np.abs(X_t_ds.coords[orig_x1_name].values- unnorm_overlap_x1))))))/2)) + if x2_ascend: + x2_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[orig_x2_name].values - unnorm_overlap_x2))/2))) + else: + x2_overlap_index = int(np.floor((X_t_ds.coords[orig_x2_name].values.size- int(np.ceil((np.argmin(np.abs(X_t_ds.coords[orig_x2_name].values- unnorm_overlap_x2))))))/2)) + + x1_x2_overlap = (x1_overlap_index, x2_overlap_index) + + return x1_x2_overlap def get_index(*args, x1 = True) -> Union[int, Tuple[List[int], List[int]]]: """ @@ -861,6 +868,12 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a patches_per_row: int Number of patchwise predictions in each row. + + x1_ascend : str + Boolean defining whether the x1 coords ascend (increase) from top to bottom, default = True. + + x2_ascend : str + Boolean defining whether the x2 coords ascend (increase) from left to right, default = True. Returns ------- @@ -879,10 +892,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a data_x2 = X_t.coords[orig_x2_name].max().values, X_t.coords[orig_x2_name].min().values data_x1_index, data_x2_index = get_index(data_x1, data_x2) - print('coords of X_t', data_x1[0].item(), data_x1[1].item(), data_x2[0].item(), data_x2[1].item()) - print('row and column values of X_t', data_x1_index, data_x2_index ) patches_clipped = {var_name: [] for var_name in patch_preds[0].keys()} - print('coords ascending', x1_ascend, x2_ascend) for i, patch_pred in enumerate(patch_preds): for var_name, data_array in patch_pred.items(): #previously patch @@ -898,13 +908,12 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a patch_x2 = data_array.coords[orig_x2_name].max().values, data_array.coords[orig_x2_name].min().values patch_x1_index, patch_x2_index = get_index(patch_x1, patch_x2) - #print('coords of patch', patch_x1[0].item(), patch_x1[1].item(), patch_x2[0].item(), patch_x2[1].item()) - print('row and column values of patch', patch_x1_index, patch_x2_index) b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0] b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1] - # Do not remove border for the patches along top and left of dataset - # and change overlap size for last patch in each row and column. + """ + Do not remove border for the patches along top and left of dataset and change overlap size for last patch in each row and column. + At end of row (when patch_x2_index = data_x2_index), to calculate the number of pixels to remove from left hand side of patch: If x2 is ascending, subtract previous patch x2 max value from current patch x2 min value to get bespoke overlap in column pixels. To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels @@ -946,24 +955,6 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a patch_clip_x2_min = int(b_x2_min) patch_clip_x2_max = int(data_array.sizes[orig_x2_name] - b_x2_max) - """ - if x1_ascend: - patch_clip_x1_max = int(data_array.sizes[orig_x1_name] - b_x1_max) - else: - patch_clip_x1_max = int(data_array.sizes[orig_x1_name] - b_x1_max) - patch_clip_x2_min = int(b_x2_min) - if x2_ascend: - patch_clip_x2_max = int(data_array.sizes[orig_x2_name] - b_x2_max) - else: - patch_clip_x2_max = int(b_x2_max - data_array.sizes[orig_x2_name]) - """ - #print('x1 and x2 sizes', data_array.sizes[orig_x1_name], data_array.sizes[orig_x2_name]) - # patch_clip = data_array.isel(y=slice(patch_clip_x1_min, patch_clip_x1_max), - # x=slice(patch_clip_x2_min, patch_clip_x2_max)) - patch_clip_x1_min = int(b_x1_min) - #print('final clip coord values', patch_clip_x1_min, patch_clip_x1_max, patch_clip_x2_min, patch_clip_x2_max) - print('row and column values of clipped patch', patch_x1_index[0]-int(b_x1_min) , patch_x1_index[1]+int(b_x1_max), - patch_x2_index[0]+int(b_x2_min) , patch_x2_index[1]-int(b_x2_max)) patch_clip = data_array.isel(**{orig_x1_name: slice(patch_clip_x1_min, patch_clip_x1_max), orig_x2_name: slice(patch_clip_x2_min, patch_clip_x2_max)}) patches_clipped[var_name].append(patch_clip) @@ -987,7 +978,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a unnorm_bbox_x2 = bbox_unnorm[orig_x2_name].values.min(), bbox_unnorm[orig_x2_name].values.max() # Determine X_t for patch, however, cannot assume min/max ordering of slice coordinates - # Check the order of coordinates in X_t, sometimes they are in increasing or decreasing order + # Check the order of coordinates in X_t, sometimes they are increasing or decreasing in order. x1_coords = X_t.coords[orig_x1_name].values x2_coords = X_t.coords[orig_x2_name].values @@ -1013,12 +1004,8 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a # Append patchwise DeepSensor prediction object to list preds.append(pred) - print('first pred', preds[0]) overlap_norm = tuple(patch - stride for patch, stride in zip(patch_size, stride_size)) - patch_overlap_unnorm = get_patch_overlap(overlap_norm, data_processor, X_t) - print('pred bbox coords', overlap_norm, patch_overlap_unnorm) - patch_overlap_unnorm = (10,10) - print('pred bbox coords', overlap_norm, patch_overlap_unnorm) + patch_overlap_unnorm = get_patch_overlap(overlap_norm, data_processor, X_t, x1_ascending, x2_ascending) patches_per_row = get_patches_per_row(preds) stitched_prediction = stitch_clipped_predictions(preds, patch_overlap_unnorm, patches_per_row, x1_ascending, x2_ascending) From 529e8c8951d5229d4e58c34b4a18aaac97adacc8 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Wed, 14 Aug 2024 15:16:09 +0100 Subject: [PATCH 52/69] use more informative error message for predict_patch --- deepsensor/model/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 91b04966..dad3caba 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -1028,7 +1028,9 @@ def stitch_clipped_predictions( if bbox is None: raise AttributeError( - "Tasks require non-None ``bbox`` for patchwise inference." + "For patchwise prediction, only tasks generated using a patch_strategy of 'sliding' are valid. \ + This task has a bbox value of None, indicating that it was generated with a patch_strategy of \ + 'random' or None." ) # Unnormalise coordinates of bounding box of patch From 0344c2a8faf40c3d6f1bba36129d1adf5604e51d Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Thu, 15 Aug 2024 07:40:24 +0100 Subject: [PATCH 53/69] fix use of stride_size --- deepsensor/model/model.py | 2 +- tests/test_model.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 9a3254c9..c1b627e7 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -983,7 +983,7 @@ def stitch_clipped_predictions( preds.append(pred) - overlap_norm = tuple(patch - stride for patch, stride in zip(patch_size, stride_size)) + overlap_norm = tuple(patch - stride for patch, stride in zip(patch_size, stride)) patch_overlap_unnorm = get_patch_overlap(overlap_norm, data_processor, X_t) patches_per_row = get_patches_per_row(preds) diff --git a/tests/test_model.py b/tests/test_model.py index 5e6c7f6a..871c6801 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -526,7 +526,7 @@ def test_patchwise_prediction(self): """Test that ``.predict_patch`` runs correctly.""" patch_size = (0.6, 0.6) - stride_size = (0.5, 0.5) + stride = (0.5, 0.5) tl = TaskLoader(context=self.da, target=self.da) @@ -536,7 +536,7 @@ def test_patchwise_prediction(self): target_sampling="all", patch_strategy="sliding", patch_size=patch_size, - stride=stride_size, + stride=stride, ) model = ConvNP(self.dp, tl) @@ -545,7 +545,7 @@ def test_patchwise_prediction(self): tasks=task, X_t=self.da, data_processor=self.dp, - stride=stride_size, + stride=stride, patch_size=patch_size, ) @@ -571,7 +571,7 @@ def test_patchwise_prediction(self): (-0.1, 0.6), # and below allowed range ] ) - def test_patchwise_prediction_parameter_handling(self, patch_size, stride_size): + def test_patchwise_prediction_parameter_handling(self, patch_size, stride): """Test that correct errors and warnings are raised by ``.predict_patch``.""" tl = TaskLoader(context=self.da, target=self.da) @@ -582,7 +582,7 @@ def test_patchwise_prediction_parameter_handling(self, patch_size, stride_size): target_sampling="all", patch_strategy="sliding", patch_size=patch_size, - stride=stride_size, + stride=stride, ) model = ConvNP(self.dp, tl) @@ -592,7 +592,7 @@ def test_patchwise_prediction_parameter_handling(self, patch_size, stride_size): tasks=task, X_t=self.da, data_processor=self.dp, - stride=stride_size, + stride=stride, patch_size=patch_size, ) From 840838d9181abf597e6efdca4f78f064bd08eb06 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Fri, 16 Aug 2024 10:03:54 +0100 Subject: [PATCH 54/69] move patchwise parameter test to test_task_loader --- tests/test_task_loader.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_task_loader.py b/tests/test_task_loader.py index d8a3d739..4ed8e8de 100644 --- a/tests/test_task_loader.py +++ b/tests/test_task_loader.py @@ -367,6 +367,29 @@ def test_sliding_window(self, patch_size, stride) -> None: stride=stride, ) + @parameterized.expand( + [ + ((0.5, 0.5), (0.6, 0.6)), # patch_size and stride as tuples + (0.5, 0.6), # as floats + (1.0, 1.2), # one argument above allowed range + (-0.1, 0.6), # and below allowed range + ] + ) + def test_patchwise_task_loader_parameter_handling(self, patch_size, stride): + """Test that correct errors and warnings are raised by ``.predict_patch``.""" + + tl = TaskLoader(context=self.da, target=self.da) + + with self.assertRaises(ValueError): + task = tl( + "2020-01-01", + context_sampling="all", + target_sampling="all", + patch_strategy="sliding", + patch_size=patch_size, + stride=stride, + ) + def test_saving_and_loading(self): """Test saving and loading TaskLoader""" with tempfile.TemporaryDirectory() as tmp_dir: From ceeb8ca77c347f2a51c7b6128de3bcf21a3e7979 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Fri, 16 Aug 2024 10:04:17 +0100 Subject: [PATCH 55/69] fix patch_size and stride for sliding window tests --- tests/test_task_loader.py | 2 +- tests/test_training.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_task_loader.py b/tests/test_task_loader.py index 4ed8e8de..61eb7673 100644 --- a/tests/test_task_loader.py +++ b/tests/test_task_loader.py @@ -326,7 +326,7 @@ def test_patch_size(self, patch_size) -> None: num_samples_per_date=2, ) - @parameterized.expand([[(0.2, 0.2), (1, 1)], [(0.3, 0.4), (1, 1)]]) + @parameterized.expand([[0.5, 0.1], [(0.3, 0.4), (0.1, 0.1)]]) def test_sliding_window(self, patch_size, stride) -> None: """Test sliding window sampling.""" # need to redefine the data generators because the patch size samplin diff --git a/tests/test_training.py b/tests/test_training.py index 3d4c0388..2157f047 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -164,8 +164,8 @@ def test_sliding_window_training(self): context_sampling="all", target_sampling="all", patch_strategy="sliding", - patch_size=(0.5, 0.5), - stride=(1, 1), + patch_size=(0.4, 0.4), + stride=(0.1, 0.1), ) # Train From 5fc1fe33b03038aa1fb4e3e1c003fd0bfcca90c1 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Fri, 16 Aug 2024 10:16:57 +0100 Subject: [PATCH 56/69] remove test as moved to test_task_loader --- tests/test_model.py | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 871c6801..38bafe77 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -563,38 +563,6 @@ def test_patchwise_prediction(self): assert self.da.x1.size == pred[var_ID].x1.size assert self.da.x2.size == pred[var_ID].x2.size - @parameterized.expand( - [ - ((0.5, 0.5), (0.6, 0.6)), # patch_size and stride as tuples - (0.5, 0.6), # as floats - (1.0, 1.2), # one argument above allowed range - (-0.1, 0.6), # and below allowed range - ] - ) - def test_patchwise_prediction_parameter_handling(self, patch_size, stride): - """Test that correct errors and warnings are raised by ``.predict_patch``.""" - - tl = TaskLoader(context=self.da, target=self.da) - - task = tl( - "2020-01-01", - context_sampling="all", - target_sampling="all", - patch_strategy="sliding", - patch_size=patch_size, - stride=stride, - ) - - model = ConvNP(self.dp, tl) - - with self.assertRaises(ValueError): - model.predict_patch( - tasks=task, - X_t=self.da, - data_processor=self.dp, - stride=stride, - patch_size=patch_size, - ) def test_saving_and_loading(self): """Test saving and loading of model""" From 1f434cc9a40d9b4c5b57f3c074fa599c3bc12a63 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Fri, 16 Aug 2024 10:26:57 +0100 Subject: [PATCH 57/69] check input parameters in task loader --- deepsensor/data/loader.py | 28 +++++++++++++++++++++++++++- deepsensor/model/model.py | 6 ------ tests/test_task_loader.py | 15 ++++++++------- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index f24b117e..bedbfdcc 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1571,6 +1571,14 @@ def __call__( patch_size is not None ), "Patch size must be specified for random patch sampling" + coord_bounds = [self.coord_bounds[0:2],self.coord_bounds[2:]] + for i,val in enumerate(patch_size): + if val < coord_bounds[i][0] or val > coord_bounds[i][1]: + raise ValueError( + f"Values of stride must be between the normalised coordinate bounds of: {self.coord_bounds}. \ + Got: patch_size: {patch_size}." + ) + if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)): for d in date: bboxes = [ @@ -1612,7 +1620,25 @@ def __call__( # sliding window sampling of patch assert ( patch_size is not None - ), "Patch size must be specified for sliding window sampling" + ), "patch_size must be specified for sliding window sampling" + + assert ( + stride is not None + ), "stride must be specified for sliding window sampling" + + if stride[0] > patch_size[0] or stride[1] > patch_size[1]: + raise ValueError( + f"stride must be smaller than patch_size in the corresponding dimensions. Got: patch_size: {patch_size}, stride: {stride}" + ) + + coord_bounds = [self.coord_bounds[0:2],self.coord_bounds[2:]] + for i in (0,1): + for val in (patch_size[i], stride[i]): + if val < coord_bounds[i][0] or val > coord_bounds[i][1]: + raise ValueError( + f"Values of stride and patch_size must be between the normalised coordinate bounds of: {self.coord_bounds}. \ + Got: patch_size: {patch_size}, stride: {stride}" + ) if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)): tasks = [] diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index c1b627e7..339000a1 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -934,12 +934,6 @@ def stitch_clipped_predictions( f"stride must be smaller than patch_size in the corresponding dimensions. Got: patch_size: {patch_size}, stride: {stride}" ) - for val in list(stride + patch_size): - if val > 1.0 or val < 0.0: - raise ValueError( - f"Values of stride and patch_size must be between 0 & 1. Got: patch_size: {patch_size}, stride: {stride}" - ) - # Get coordinate names of original unnormalised dataset. unnorm_coord_names = { "x1": self.data_processor.raw_spatial_coord_names[0], diff --git a/tests/test_task_loader.py b/tests/test_task_loader.py index 61eb7673..aec802cc 100644 --- a/tests/test_task_loader.py +++ b/tests/test_task_loader.py @@ -369,23 +369,24 @@ def test_sliding_window(self, patch_size, stride) -> None: @parameterized.expand( [ - ((0.5, 0.5), (0.6, 0.6)), # patch_size and stride as tuples - (0.5, 0.6), # as floats - (1.0, 1.2), # one argument above allowed range - (-0.1, 0.6), # and below allowed range + ("sliding", (0.5, 0.5), (0.6, 0.6)), # patch_size and stride as tuples + ("sliding", 0.5, 0.6), # as floats + ("sliding", 1.0, 1.2), # one argument above allowed range + ("sliding", -0.1, 0.6), # and below allowed range + ("random", 1.1, None) # for sliding window as well ] ) - def test_patchwise_task_loader_parameter_handling(self, patch_size, stride): + def test_patchwise_task_loader_parameter_handling(self, patch_strategy, patch_size, stride): """Test that correct errors and warnings are raised by ``.predict_patch``.""" tl = TaskLoader(context=self.da, target=self.da) with self.assertRaises(ValueError): - task = tl( + tl( "2020-01-01", context_sampling="all", target_sampling="all", - patch_strategy="sliding", + patch_strategy=patch_strategy, patch_size=patch_size, stride=stride, ) From 96edce8a258d02c06dc351b6bc2c4689b97bf7c5 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Fri, 16 Aug 2024 11:51:47 +0100 Subject: [PATCH 58/69] For patchwise prediction, get patch_size and stride directly from task --- deepsensor/data/loader.py | 12 ++++++++++++ deepsensor/model/model.py | 9 +-------- tests/test_model.py | 40 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 8 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 2bfcdf2a..19ab9f4b 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -959,6 +959,8 @@ def task_generation( ] = None, split_frac: float = 0.5, bbox: Sequence[float] = None, + patch_size: Union[float, tuple[float]] = None, + stride: Union[float, tuple[float]] = None, datewise_deterministic: bool = False, seed_override: Optional[int] = None, ) -> Task: @@ -995,6 +997,10 @@ def task_generation( bbox : Sequence[float], optional Bounding box to spatially slice the data, should be of the form [x1_min, x1_max, x2_min, x2_max]. Useful when considering the entire available region is computationally prohibitive for model forward pass. + patch_size : Union(Tuple|float), optional + Only used by patchwise inference. Height and width of patch in x1/x2 normalised coordinates. + stride: Union(Tuple|float), optional + Only used by patchwise inference. Length of stride between adjacent patches in x1/x2 normalised coordinates. datewise_deterministic : bool Whether random sampling is datewise_deterministic based on the date. Default is ``False``. @@ -1186,6 +1192,8 @@ def sample_variable(var, sampling_strat, seed): task["time"] = date task["ops"] = [] task["bbox"] = bbox + task["patch_size"] = patch_size # store patch_size and stride in task for use in stitching in prediction + task["stride"] = stride task["X_c"] = [] task["Y_c"] = [] if target_sampling is not None: @@ -1620,6 +1628,8 @@ def __call__( split_frac=split_frac, datewise_deterministic=datewise_deterministic, seed_override=seed_override, + patch_size=patch_size, + stride=stride ) for bbox in bboxes ] @@ -1635,6 +1645,8 @@ def __call__( split_frac=split_frac, datewise_deterministic=datewise_deterministic, seed_override=seed_override, + patch_size=patch_size, + stride=stride ) for bbox in bboxes ] diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 1ce5fa92..9af3443d 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -638,8 +638,6 @@ def predict_patch( pd.DataFrame, List[Union[xr.DataArray, xr.Dataset, pd.DataFrame]], ], - stride_size: Union[float, tuple[float]], - patch_size: Union[float, tuple[float]], X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None, X_t_is_normalised: bool = False, aux_at_targets_override: Union[xr.Dataset, xr.DataArray] = None, @@ -664,10 +662,6 @@ def predict_patch( List of tasks containing context data. data_processor (:class:`~.data.processor.DataProcessor`): Used for unnormalising the coordinates of the bounding boxes of patches. - stride_size (Union[float, tuple[float]]): - Length of stride between adjacent patches in x1/x2 normalised coordinates. - patch_size (Union[float, tuple[float]]): - Height and width of patch in x1/x2 normalised coordinates. X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`): Target locations to predict at. Can be an xarray object containingon-grid locations or a pandas object containing off-grid locations. @@ -938,8 +932,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row) -> d # Append patchwise DeepSensor prediction object to list preds.append(pred) - - overlap_norm = tuple(patch - stride for patch, stride in zip(patch_size, stride_size)) + overlap_norm = tuple(patch - stride for patch, stride in zip(task["patch_size"], task["stride"])) patch_overlap_unnorm = get_patch_overlap(overlap_norm, data_processor, X_t) patches_per_row = get_patches_per_row(preds) diff --git a/tests/test_model.py b/tests/test_model.py index 80519269..c8457d8e 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -522,6 +522,46 @@ def test_highlevel_predict_with_invalid_pred_params(self): with self.assertRaises(AttributeError): model.predict(task, X_t=self.da, pred_params=["invalid_param"]) + def test_patchwise_prediction(self): + """Test that ``.predict_patch`` runs correctly.""" + + patch_size = (0.2, 0.2) + stride = (0.1, 0.1) + + tl = TaskLoader(context=self.da, target=self.da) + + tasks = tl( + "2020-01-01", + context_sampling="all", + target_sampling="all", + patch_strategy="sliding", + patch_size=patch_size, + stride=stride, + ) + + model = ConvNP(self.dp, tl) + + pred = model.predict_patch( + tasks=tasks, + X_t=self.da, + data_processor=self.dp, + ) + + # gridded predictions + assert [isinstance(ds, xr.Dataset) for ds in pred.values()] + for var_ID in pred: + assert_shape( + pred[var_ID]["mean"], + (1, self.da.x1.size, self.da.x2.size), + ) + assert_shape( + pred[var_ID]["std"], + (1, self.da.x1.size, self.da.x2.size), + ) + assert self.da.x1.size == pred[var_ID].x1.size + assert self.da.x2.size == pred[var_ID].x2.size + + def test_saving_and_loading(self): """Test saving and loading of model""" with tempfile.TemporaryDirectory() as folder: From 18f2e5a926fa7e29958a6fe66cf2c340b1bbbb69 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Thu, 22 Aug 2024 10:28:13 +0100 Subject: [PATCH 59/69] raise errors instead of assert --- deepsensor/data/loader.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index bedbfdcc..5658506c 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1531,10 +1531,11 @@ def __call__( Task object or list of task objects for each date containing the context and target data. """ - assert patch_strategy in [None, "random", "sliding"], ( - f"Invalid patch strategy {patch_strategy}. " - f"Must be one of [None, 'random', 'sliding']." - ) + if patch_strategy not in [None, "random", "sliding"]: + raise ValueError( + f"Invalid patch strategy {patch_strategy}. " + f"Must be one of [None, 'random', 'sliding']." + ) if isinstance(patch_size, float) and patch_size is not None: patch_size = (patch_size, patch_size) @@ -1567,10 +1568,9 @@ def __call__( elif patch_strategy == "random": - assert ( - patch_size is not None - ), "Patch size must be specified for random patch sampling" - + if patch_size is None: + raise ValueError("Patch size must be specified for random patch sampling") + coord_bounds = [self.coord_bounds[0:2],self.coord_bounds[2:]] for i,val in enumerate(patch_size): if val < coord_bounds[i][0] or val > coord_bounds[i][1]: @@ -1618,13 +1618,10 @@ def __call__( elif patch_strategy == "sliding": # sliding window sampling of patch - assert ( - patch_size is not None - ), "patch_size must be specified for sliding window sampling" - - assert ( - stride is not None - ), "stride must be specified for sliding window sampling" + + for val in (patch_size, stride): + if val is None: + raise ValueError(f"patch_size and stride must be specified for sliding window sampling, got patch_size: {patch_size} and stride: {stride}.") if stride[0] > patch_size[0] or stride[1] > patch_size[1]: raise ValueError( From 47d0998b50d07424ce8e0d073a5b6a1fcb15a502 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Thu, 22 Aug 2024 10:35:48 +0100 Subject: [PATCH 60/69] use warning for stride > patch size --- deepsensor/data/loader.py | 4 ++-- deepsensor/model/model.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 5658506c..5fab61cb 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1624,8 +1624,8 @@ def __call__( raise ValueError(f"patch_size and stride must be specified for sliding window sampling, got patch_size: {patch_size} and stride: {stride}.") if stride[0] > patch_size[0] or stride[1] > patch_size[1]: - raise ValueError( - f"stride must be smaller than patch_size in the corresponding dimensions. Got: patch_size: {patch_size}, stride: {stride}" + raise Warning( + f"stride should generally be smaller than patch_size in the corresponding dimensions. Got: patch_size: {patch_size}, stride: {stride}" ) coord_bounds = [self.coord_bounds[0:2],self.coord_bounds[2:]] diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 339000a1..7212a05a 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -930,8 +930,8 @@ def stitch_clipped_predictions( stride = (stride, stride) if stride[0] > patch_size[0] or stride[1] > patch_size[1]: - raise ValueError( - f"stride must be smaller than patch_size in the corresponding dimensions. Got: patch_size: {patch_size}, stride: {stride}" + raise Warning( + f"stride should generally be smaller than patch_size in the corresponding dimensions. Got: patch_size: {patch_size}, stride: {stride}" ) # Get coordinate names of original unnormalised dataset. From df0533b2129e945a8e566e025cb359ba55cf1a00 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Thu, 22 Aug 2024 10:37:01 +0100 Subject: [PATCH 61/69] remove comment --- deepsensor/model/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 7212a05a..01e56239 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -869,7 +869,7 @@ def stitch_clipped_predictions( patches_clipped = {var_name: [] for var_name in patch_preds[0].keys()} for i, patch_pred in enumerate(patch_preds): - for var_name, data_array in patch_pred.items(): # previously patch + for var_name, data_array in patch_pred.items(): if var_name in patch_pred: # Get row/col index values of each patch patch_x1 = data_array.coords[unnorm_coord_names['x1']].min().values, data_array.coords[unnorm_coord_names['x1']].max().values From f7d57e965857fb73a361d0d5e0ba6e1d1b14a32d Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Thu, 22 Aug 2024 10:40:10 +0100 Subject: [PATCH 62/69] raise error for stride > patch_size in prediction --- deepsensor/model/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 01e56239..79894a6c 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -930,8 +930,8 @@ def stitch_clipped_predictions( stride = (stride, stride) if stride[0] > patch_size[0] or stride[1] > patch_size[1]: - raise Warning( - f"stride should generally be smaller than patch_size in the corresponding dimensions. Got: patch_size: {patch_size}, stride: {stride}" + raise ValueError( + f"stride must be smaller than patch_size in the corresponding dimensions for patchwise prediction. Got: patch_size: {patch_size}, stride: {stride}" ) # Get coordinate names of original unnormalised dataset. From fed39407ce16affb4491611178dffcff411fee82 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Thu, 22 Aug 2024 10:45:10 +0100 Subject: [PATCH 63/69] alter paramaters for test --- tests/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index 38bafe77..6f5d8561 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -526,7 +526,7 @@ def test_patchwise_prediction(self): """Test that ``.predict_patch`` runs correctly.""" patch_size = (0.6, 0.6) - stride = (0.5, 0.5) + stride = (0.3, 0.3) tl = TaskLoader(context=self.da, target=self.da) From b3a6dab0728c18b47a827d84877f6483b0e3ae8c Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Thu, 22 Aug 2024 11:09:47 +0100 Subject: [PATCH 64/69] raise error for more than one date in predict_patch --- deepsensor/model/model.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 79894a6c..157be6a8 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -933,6 +933,15 @@ def stitch_clipped_predictions( raise ValueError( f"stride must be smaller than patch_size in the corresponding dimensions for patchwise prediction. Got: patch_size: {patch_size}, stride: {stride}" ) + + # patchwise prediction does not yet support more than a single date + num_task_dates = len(set([t["time"] for t in tasks])) + if num_task_dates > 1: + raise NotImplementedError( + f"Patchwise prediction does not yet support more than a single date at a time, got {num_task_dates}. \n\ + Contributions to the DeepSensor package are very welcome. \n\ + Please see the contributing guide at https://alan-turing-institute.github.io/deepsensor/community/contributing.html" + ) # Get coordinate names of original unnormalised dataset. unnorm_coord_names = { From c8a38f2a376481a94058a75f96c31df7219b317b Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Thu, 22 Aug 2024 11:10:18 +0100 Subject: [PATCH 65/69] black --- deepsensor/data/loader.py | 20 ++-- deepsensor/model/model.py | 187 ++++++++++++++++++++++++++++---------- 2 files changed, 151 insertions(+), 56 deletions(-) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 5fab61cb..5d5f5a45 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1569,10 +1569,12 @@ def __call__( elif patch_strategy == "random": if patch_size is None: - raise ValueError("Patch size must be specified for random patch sampling") - - coord_bounds = [self.coord_bounds[0:2],self.coord_bounds[2:]] - for i,val in enumerate(patch_size): + raise ValueError( + "Patch size must be specified for random patch sampling" + ) + + coord_bounds = [self.coord_bounds[0:2], self.coord_bounds[2:]] + for i, val in enumerate(patch_size): if val < coord_bounds[i][0] or val > coord_bounds[i][1]: raise ValueError( f"Values of stride must be between the normalised coordinate bounds of: {self.coord_bounds}. \ @@ -1618,18 +1620,20 @@ def __call__( elif patch_strategy == "sliding": # sliding window sampling of patch - + for val in (patch_size, stride): if val is None: - raise ValueError(f"patch_size and stride must be specified for sliding window sampling, got patch_size: {patch_size} and stride: {stride}.") + raise ValueError( + f"patch_size and stride must be specified for sliding window sampling, got patch_size: {patch_size} and stride: {stride}." + ) if stride[0] > patch_size[0] or stride[1] > patch_size[1]: raise Warning( f"stride should generally be smaller than patch_size in the corresponding dimensions. Got: patch_size: {patch_size}, stride: {stride}" ) - coord_bounds = [self.coord_bounds[0:2],self.coord_bounds[2:]] - for i in (0,1): + coord_bounds = [self.coord_bounds[0:2], self.coord_bounds[2:]] + for i in (0, 1): for val in (patch_size[i], stride[i]): if val < coord_bounds[i][0] or val > coord_bounds[i][1]: raise ValueError( diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 157be6a8..c401c108 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -733,13 +733,12 @@ def predict_patch( ValueError If ``append_indexes`` are not all the same length as ``X_t``. """ - + # Get coordinate names of original unnormalised dataset. unnorm_coord_names = { - "x1": self.data_processor.raw_spatial_coord_names[0], - "x2": self.data_processor.raw_spatial_coord_names[1], - } - + "x1": self.data_processor.raw_spatial_coord_names[0], + "x2": self.data_processor.raw_spatial_coord_names[1], + } def get_patches_per_row(preds) -> int: """ @@ -755,13 +754,13 @@ def get_patches_per_row(preds) -> int: """ patches_per_row = 0 vars = list(preds[0][0].data_vars) - - var = vars[0] - y_val = preds[0][0][var].coords[unnorm_coord_names['x1']].min() - + + var = vars[0] + y_val = preds[0][0][var].coords[unnorm_coord_names["x1"]].min() + for p in preds: - if p[0][var].coords[unnorm_coord_names['x1']].min() == y_val: - patches_per_row = patches_per_row + 1 + if p[0][var].coords[unnorm_coord_names["x1"]].min() == y_val: + patches_per_row = patches_per_row + 1 return patches_per_row @@ -793,12 +792,40 @@ def get_patch_overlap(overlap_norm, data_processor, X_t_ds) -> int: # Unnormalise coordinates of bounding boxes overlap_unnorm_xr = data_processor.unnormalise(overlap_norm_xr) - unnorm_overlap_x1 = overlap_unnorm_xr.coords[unnorm_coord_names['x1']].values[1] - unnorm_overlap_x2 = overlap_unnorm_xr.coords[unnorm_coord_names['x2']].values[1] + unnorm_overlap_x1 = overlap_unnorm_xr.coords[ + unnorm_coord_names["x1"] + ].values[1] + unnorm_overlap_x2 = overlap_unnorm_xr.coords[ + unnorm_coord_names["x2"] + ].values[1] # Find the position of these indices within the DataArray - x_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[unnorm_coord_names['x1']].values - unnorm_overlap_x1))/2))) - y_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[unnorm_coord_names['x2']].values - unnorm_overlap_x2))/2))) + x_overlap_index = int( + np.ceil( + ( + np.argmin( + np.abs( + X_t_ds.coords[unnorm_coord_names["x1"]].values + - unnorm_overlap_x1 + ) + ) + / 2 + ) + ) + ) + y_overlap_index = int( + np.ceil( + ( + np.argmin( + np.abs( + X_t_ds.coords[unnorm_coord_names["x2"]].values + - unnorm_overlap_x2 + ) + ) + / 2 + ) + ) + ) xy_overlap = (x_overlap_index, y_overlap_index) return xy_overlap @@ -827,15 +854,33 @@ def get_index(*args, x1=True) -> Union[int, Tuple[List[int], List[int]]]: if len(args) == 1: patch_coord = args if x1: - coord_index = np.argmin(np.abs(X_t.coords[unnorm_coord_names['x1']].values - patch_coord)) + coord_index = np.argmin( + np.abs( + X_t.coords[unnorm_coord_names["x1"]].values - patch_coord + ) + ) else: - coord_index = np.argmin(np.abs(X_t.coords[unnorm_coord_names['x2']].values - patch_coord)) + coord_index = np.argmin( + np.abs( + X_t.coords[unnorm_coord_names["x2"]].values - patch_coord + ) + ) return coord_index elif len(args) == 2: - patch_x1, patch_x2 = args - x1_index = [np.argmin(np.abs(X_t.coords[unnorm_coord_names['x1']].values - target_x1)) for target_x1 in patch_x1] - x2_index = [np.argmin(np.abs(X_t.coords[unnorm_coord_names['x2']].values - target_x2)) for target_x2 in patch_x2] + patch_x1, patch_x2 = args + x1_index = [ + np.argmin( + np.abs(X_t.coords[unnorm_coord_names["x1"]].values - target_x1) + ) + for target_x1 in patch_x1 + ] + x2_index = [ + np.argmin( + np.abs(X_t.coords[unnorm_coord_names["x2"]].values - target_x2) + ) + for target_x2 in patch_x2 + ] return (x1_index, x2_index) def stitch_clipped_predictions( @@ -861,10 +906,14 @@ def stitch_clipped_predictions( Dictionary object containing the stitched model predictions. """ - - - data_x1 = X_t.coords[unnorm_coord_names['x1']].min().values, X_t.coords[unnorm_coord_names['x1']].max().values - data_x2 = X_t.coords[unnorm_coord_names['x2']].min().values, X_t.coords[unnorm_coord_names['x2']].max().values + data_x1 = ( + X_t.coords[unnorm_coord_names["x1"]].min().values, + X_t.coords[unnorm_coord_names["x1"]].max().values, + ) + data_x2 = ( + X_t.coords[unnorm_coord_names["x2"]].min().values, + X_t.coords[unnorm_coord_names["x2"]].max().values, + ) data_x1_index, data_x2_index = get_index(data_x1, data_x2) patches_clipped = {var_name: [] for var_name in patch_preds[0].keys()} @@ -872,10 +921,16 @@ def stitch_clipped_predictions( for var_name, data_array in patch_pred.items(): if var_name in patch_pred: # Get row/col index values of each patch - patch_x1 = data_array.coords[unnorm_coord_names['x1']].min().values, data_array.coords[unnorm_coord_names['x1']].max().values - patch_x2 = data_array.coords[unnorm_coord_names['x2']].min().values, data_array.coords[unnorm_coord_names['x2']].max().values - patch_x1_index, patch_x2_index = get_index(patch_x1, patch_x2) - + patch_x1 = ( + data_array.coords[unnorm_coord_names["x1"]].min().values, + data_array.coords[unnorm_coord_names["x1"]].max().values, + ) + patch_x2 = ( + data_array.coords[unnorm_coord_names["x2"]].min().values, + data_array.coords[unnorm_coord_names["x2"]].max().values, + ) + patch_x1_index, patch_x2_index = get_index(patch_x1, patch_x2) + b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0] b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1] # Do not remove border for the patches along top and left of dataset @@ -884,22 +939,44 @@ def stitch_clipped_predictions( b_x2_min = 0 elif patch_x2_index[1] == data_x2_index[1]: b_x2_max = 0 - patch_row_prev = preds[i-1] - prev_patch_x2_max = get_index(int(patch_row_prev[var_name].coords[unnorm_coord_names['x2']].max()), x1 = False) - b_x2_min = (prev_patch_x2_max - patch_x2_index[0])-patch_overlap[1] + patch_row_prev = preds[i - 1] + prev_patch_x2_max = get_index( + int( + patch_row_prev[var_name] + .coords[unnorm_coord_names["x2"]] + .max() + ), + x1=False, + ) + b_x2_min = ( + prev_patch_x2_max - patch_x2_index[0] + ) - patch_overlap[1] if patch_x1_index[0] == data_x1_index[0]: b_x1_min = 0 elif abs(patch_x1_index[1] - data_x1_index[1]) < 2: b_x1_max = 0 - patch_prev = preds[i-patches_per_row] - prev_patch_x1_max = get_index(int(patch_prev[var_name].coords[unnorm_coord_names['x1']].max()), x1 = True) - b_x1_min = (prev_patch_x1_max - patch_x1_index[0])- patch_overlap[0] + patch_prev = preds[i - patches_per_row] + prev_patch_x1_max = get_index( + int( + patch_prev[var_name] + .coords[unnorm_coord_names["x1"]] + .max() + ), + x1=True, + ) + b_x1_min = ( + prev_patch_x1_max - patch_x1_index[0] + ) - patch_overlap[0] patch_clip_x1_min = int(b_x1_min) - patch_clip_x1_max = int(data_array.sizes[unnorm_coord_names['x1']] - b_x1_max) + patch_clip_x1_max = int( + data_array.sizes[unnorm_coord_names["x1"]] - b_x1_max + ) patch_clip_x2_min = int(b_x2_min) - patch_clip_x2_max = int(data_array.sizes[unnorm_coord_names['x2']] - b_x2_max) + patch_clip_x2_max = int( + data_array.sizes[unnorm_coord_names["x2"]] - b_x2_max + ) patch_clip = data_array[ { @@ -933,7 +1010,7 @@ def stitch_clipped_predictions( raise ValueError( f"stride must be smaller than patch_size in the corresponding dimensions for patchwise prediction. Got: patch_size: {patch_size}, stride: {stride}" ) - + # patchwise prediction does not yet support more than a single date num_task_dates = len(set([t["time"] for t in tasks])) if num_task_dates > 1: @@ -970,13 +1047,19 @@ def stitch_clipped_predictions( x2 = xr.DataArray([bbox[2], bbox[3]], dims="x2", name="x2") bbox_norm = xr.Dataset(coords={"x1": x1, "x2": x2}) bbox_unnorm = data_processor.unnormalise(bbox_norm) - unnorm_bbox_x1 = bbox_unnorm[unnorm_coord_names['x1']].values.min(), bbox_unnorm[unnorm_coord_names['x1']].values.max() - unnorm_bbox_x2 = bbox_unnorm[unnorm_coord_names['x2']].values.min(), bbox_unnorm[unnorm_coord_names['x2']].values.max() - + unnorm_bbox_x1 = ( + bbox_unnorm[unnorm_coord_names["x1"]].values.min(), + bbox_unnorm[unnorm_coord_names["x1"]].values.max(), + ) + unnorm_bbox_x2 = ( + bbox_unnorm[unnorm_coord_names["x2"]].values.min(), + bbox_unnorm[unnorm_coord_names["x2"]].values.max(), + ) + # Determine X_t for patch task_extent_dict = { - unnorm_coord_names['x1']: slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1]), - unnorm_coord_names['x2']: slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1]) + unnorm_coord_names["x1"]: slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1]), + unnorm_coord_names["x2"]: slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1]), } task_X_t = X_t.sel(**task_extent_dict) @@ -985,13 +1068,16 @@ def stitch_clipped_predictions( # Append patchwise DeepSensor prediction object to list preds.append(pred) - - overlap_norm = tuple(patch - stride for patch, stride in zip(patch_size, stride)) + overlap_norm = tuple( + patch - stride for patch, stride in zip(patch_size, stride) + ) patch_overlap_unnorm = get_patch_overlap(overlap_norm, data_processor, X_t) - + patches_per_row = get_patches_per_row(preds) - stitched_prediction = stitch_clipped_predictions(preds, patch_overlap_unnorm, patches_per_row) - + stitched_prediction = stitch_clipped_predictions( + preds, patch_overlap_unnorm, patches_per_row + ) + ## Cast prediction into DeepSensor.Prediction object. # TODO make this into seperate method. prediction = copy.deepcopy(preds[0]) @@ -1000,7 +1086,12 @@ def stitch_clipped_predictions( for var_name_copy, data_array_copy in prediction.items(): # set x and y coords - stitched_preds = xr.Dataset(coords={'x1': X_t[unnorm_coord_names['x1']], 'x2': X_t[unnorm_coord_names['x2']]}) + stitched_preds = xr.Dataset( + coords={ + "x1": X_t[unnorm_coord_names["x1"]], + "x2": X_t[unnorm_coord_names["x2"]], + } + ) # Set time to same as patched prediction stitched_preds["time"] = data_array_copy["time"] From e10d6454c33fdff7603aef8941dd6f831823844a Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Thu, 22 Aug 2024 15:53:59 +0100 Subject: [PATCH 66/69] fix getting and checking of patch_size and stride --- deepsensor/model/model.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 7d9e846a..52c30123 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -960,8 +960,11 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a return combined - # sanitise patch_size and stride arguments + # load patch_size and stride from task + patch_size = tasks[0]["patch_size"] + stride = tasks[0]["stride"] + # sanitise patch_size and stride arguments if isinstance(patch_size, float) and patch_size is not None: patch_size = (patch_size, patch_size) @@ -1033,7 +1036,7 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a # Append patchwise DeepSensor prediction object to list preds.append(pred) - overlap_norm = tuple(patch - stride for patch, stride in zip(task["patch_size"], task["stride"])) + overlap_norm = tuple(patch - stride for patch, stride in zip(patch_size, stride)) patch_overlap_unnorm = get_patch_overlap(overlap_norm, data_processor, X_t, x1_ascending, x2_ascending) patches_per_row = get_patches_per_row(preds) From f6f843df3e1c05c9c118aaa08526520e87e43881 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Thu, 22 Aug 2024 15:55:11 +0100 Subject: [PATCH 67/69] fix docstrings and defaults --- deepsensor/model/model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 52c30123..eea4a52d 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -771,11 +771,11 @@ def get_patch_overlap(overlap_norm, data_processor, X_t_ds, x1_ascend, x2_ascend X_t_ds (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`): Data array containing target locations to predict at. - x1_ascend : str: - Boolean defining whether the x1 coords ascend (increase) from top to bottom, default = True. + x1_ascend : bool: + Boolean defining whether the x1 coords ascend (increase) from top to bottom. - x2_ascend : str: - Boolean defining whether the x2 coords ascend (increase) from left to right, default = True. + x2_ascend : bool: + Boolean defining whether the x2 coords ascend (increase) from left to right. Returns ------- @@ -861,10 +861,10 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a patches_per_row: int Number of patchwise predictions in each row. - x1_ascend : str + x1_ascend : bool Boolean defining whether the x1 coords ascend (increase) from top to bottom, default = True. - x2_ascend : str + x2_ascend : bool Boolean defining whether the x2 coords ascend (increase) from left to right, default = True. Returns From 2e5c6a86a11942606de98424725a9276b8badbb1 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Thu, 22 Aug 2024 16:12:48 +0100 Subject: [PATCH 68/69] reinstate orig_name patch clip slicing --- deepsensor/model/model.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index eea4a52d..b5481e32 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -771,11 +771,11 @@ def get_patch_overlap(overlap_norm, data_processor, X_t_ds, x1_ascend, x2_ascend X_t_ds (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`): Data array containing target locations to predict at. - x1_ascend : bool: - Boolean defining whether the x1 coords ascend (increase) from top to bottom. + x1_ascend : str: + Boolean defining whether the x1 coords ascend (increase) from top to bottom, default = True. - x2_ascend : bool: - Boolean defining whether the x2 coords ascend (increase) from left to right. + x2_ascend : str: + Boolean defining whether the x2 coords ascend (increase) from left to right, default = True. Returns ------- @@ -861,10 +861,10 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a patches_per_row: int Number of patchwise predictions in each row. - x1_ascend : bool + x1_ascend : str Boolean defining whether the x1 coords ascend (increase) from top to bottom, default = True. - x2_ascend : bool + x2_ascend : str Boolean defining whether the x2 coords ascend (increase) from left to right, default = True. Returns @@ -947,8 +947,8 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a patch_clip_x2_min = int(b_x2_min) patch_clip_x2_max = int(data_array.sizes[orig_x2_name] - b_x2_max) - patch_clip = data_array.isel(y=slice(patch_clip_x1_min, patch_clip_x1_max), - x=slice(patch_clip_x2_min, patch_clip_x2_max)) + patch_clip = data_array.isel(**{orig_x1_name: slice(patch_clip_x1_min, patch_clip_x1_max), + orig_x2_name: slice(patch_clip_x2_min, patch_clip_x2_max)}) patches_clipped[var_name].append(patch_clip) From 51d8c050eb70a2b5c0fa4f85e3e68df7e0288b18 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Fri, 23 Aug 2024 08:34:19 +0100 Subject: [PATCH 69/69] use hypothesis to expand on patchwise predict testing --- .gitignore | 1 + requirements/requirements.dev.txt | 1 + tests/test_model.py | 8 +++++--- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 46ab9b74..3b461c76 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ dist/* _build *.png deepsensor.egg-info/ +.hypothesis/ \ No newline at end of file diff --git a/requirements/requirements.dev.txt b/requirements/requirements.dev.txt index 6240a200..2ae199c2 100644 --- a/requirements/requirements.dev.txt +++ b/requirements/requirements.dev.txt @@ -6,3 +6,4 @@ tox tox-gh-actions coveralls black +hypothesis \ No newline at end of file diff --git a/tests/test_model.py b/tests/test_model.py index 2a7d2022..17b1de26 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -3,6 +3,7 @@ import tempfile from parameterized import parameterized +from hypothesis import example, given, strategies as st import os import xarray as xr @@ -553,11 +554,12 @@ def test_highlevel_predict_with_invalid_pred_params(self): with self.assertRaises(AttributeError): model.predict(task, X_t=self.da, pred_params=["invalid_param"]) - def test_patchwise_prediction(self): + @given(st.data()) + def test_patchwise_prediction(self, data): """Test that ``.predict_patch`` runs correctly.""" - patch_size = (0.2, 0.2) - stride = (0.1, 0.1) + patch_size = data.draw(st.floats(min_value=0.1, max_value=1.0)) + stride = data.draw(st.floats(min_value=0.1, max_value=patch_size)) tl = TaskLoader(context=self.da, target=self.da)