Skip to content

Commit

Permalink
Merge pull request #16 from davidwilby/refactor_sample_sliding
Browse files Browse the repository at this point in the history
Refactor `sample_sliding_window`
  • Loading branch information
davidwilby authored Dec 3, 2024
2 parents 4e028ab + 8b9a8ac commit d620e88
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 75 deletions.
125 changes: 52 additions & 73 deletions deepsensor/data/loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import itertools
import json
import operator
import os
import random
from typing import List, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -1440,96 +1441,74 @@ def sample_sliding_window(
patch_size : Tuple[float]
Tuple of window extent
Stride : Tuple[float]
stride : Tuple[float]
Tuple of step size between each patch along x1 and x2 axis.
Returns:
-------
bbox: List[float]
List[float]
Sequence of patch spatial extent as [x1_min, x1_max, x2_min, x2_max].
"""
# define patch size in x1/x2
x1_extend, x2_extend = patch_size
size = {}
size["x1"], size["x2"] = patch_size

# define stride length in x1/x2 or set to patch_size if undefined
if stride is None:
stride = patch_size

dy, dx = stride
# Calculate the global bounds of context and target set.
x1_min, x1_max, x2_min, x2_max = self.coord_bounds
## start with first patch top left hand corner at x1_min, x2_min
patch_list = []
step = {}
step["x1"], step["x2"] = stride

# Todo: simplify these elif statements
if self.coord_directions["x1"] == False and self.coord_directions["x2"] == True:
for y in np.arange(x1_max, x1_min, -dy):
for x in np.arange(x2_min, x2_max, dx):
if y - x1_extend < x1_min:
y0 = x1_min + x1_extend
else:
y0 = y
if x + x2_extend > x2_max:
x0 = x2_max - x2_extend
else:
x0 = x

# bbox of x1_min, x1_max, x2_min, x2_max per patch
bbox = [y0 - x1_extend, y0, x0, x0 + x2_extend]
patch_list.append(bbox)

elif (
self.coord_directions["x1"] == False
and self.coord_directions["x2"] == False
):
for y in np.arange(x1_max, x1_min, -dy):
for x in np.arange(x2_max, x2_min, -dx):
if y - x1_extend < x1_min:
y0 = x1_min + x1_extend
else:
y0 = y
if x - x2_extend < x2_min:
x0 = x2_min + x2_extend
else:
x0 = x
# Calculate the global bounds of context and target set.
coord_min = {}
coord_max = {}
coord_min["x1"], coord_max["x1"], coord_min["x2"], coord_max["x2"] = (
self.coord_bounds
)

# bbox of x1_min, x1_max, x2_min, x2_max per patch
bbox = [y0 - x1_extend, y0, x0 - x2_extend, x0]
patch_list.append(bbox)
# start with the first patch's top-left corner at coord_min["x1"], coord_min["x2"]
patch_list = []

elif (
self.coord_directions["x1"] == True and self.coord_directions["x2"] == False
# define some lambda functions for use below
# round to 12 decimal places: enough to remove floating point error while making unintentional rounding of genuine coordinate values unlikely
r = lambda x: round(x, 12)
bbox_coords_ascend = lambda a, b: [r(a), r(a + b)]
bbox_coords_descend = lambda a, b: bbox_coords_ascend(a, b)[::-1]

compare = {}
bbox_coords = {}
# for each coordinate direction specify the correct operations for patching
for c in ("x1", "x2"):
if self.coord_directions[c]:
compare[c] = operator.gt
bbox_coords[c] = bbox_coords_ascend
else:
step[c] = -step[c]
coord_min[c], coord_max[c] = coord_max[c], coord_min[c]
size[c] = -size[c]
compare[c] = operator.lt
bbox_coords[c] = bbox_coords_descend

# Define the bounding boxes for all patches, starting in the top-left corner of the DataArray
for y, x in itertools.product(
np.arange(coord_min["x1"], coord_max["x1"], step["x1"]),
np.arange(coord_min["x2"], coord_max["x2"], step["x2"]),
):
for y in np.arange(x1_min, x1_max, dy):
for x in np.arange(x2_max, x2_min, -dx):
if y + x1_extend > x1_max:
y0 = x1_max - x1_extend
else:
y0 = y
if x - x2_extend < x2_min:
x0 = x2_min + x2_extend
else:
x0 = x

# bbox of x1_min, x1_max, x2_min, x2_max per patch
bbox = [y0, y0 + x1_extend, x0 - x2_extend, x0]
patch_list.append(bbox)
else:
for y in np.arange(x1_min, x1_max, dy):
for x in np.arange(x2_min, x2_max, dx):
if y + x1_extend > x1_max:
y0 = x1_max - x1_extend
else:
y0 = y
if x + x2_extend > x2_max:
x0 = x2_max - x2_extend
else:
x0 = x

# bbox of x1_min, x1_max, x2_min, x2_max per patch
bbox = [y0, y0 + x1_extend, x0, x0 + x2_extend]
y0 = (
coord_max["x1"] - size["x1"]
if compare["x1"](y + size["x1"], coord_max["x1"])
else y
)
x0 = (
coord_max["x2"] - size["x2"]
if compare["x2"](x + size["x2"], coord_max["x2"])
else x
)

patch_list.append(bbox)
# bbox of x1_min, x1_max, x2_min, x2_max per patch
bbox = bbox_coords["x1"](y0, size["x1"]) + bbox_coords["x2"](x0, size["x2"])
patch_list.append(bbox)

# Remove duplicate patches while preserving order
seen = set()
Expand Down
13 changes: 11 additions & 2 deletions tests/test_task_loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import itertools
import math
import os
import shutil
import tempfile
Expand Down Expand Up @@ -343,7 +344,7 @@ def test_patch_size(self, patch_size) -> None:
num_samples_per_date=2,
)

@parameterized.expand([[0.5, 0.1], [(0.3, 0.4), (0.1, 0.1)]])
@parameterized.expand([[0.5, 0.45], [(0.3, 0.4), (0.3, 0.35)]])
def test_sliding_window(self, patch_size, stride) -> None:
"""Test sliding window sampling."""
# need to redefine the data generators because the patch size samplin
Expand Down Expand Up @@ -371,7 +372,7 @@ def test_sliding_window(self, patch_size, stride) -> None:
context = [da_data_0_1, da_data_smaller, da_data_larger]
tl = TaskLoader(
context=context, # gridded xarray and off-grid pandas contexts
target=self.df, # off-grid pandas targets
target=self.df, # off-grid pandas targets
)

# test date range
Expand All @@ -384,6 +385,14 @@ def test_sliding_window(self, patch_size, stride) -> None:
stride=stride,
)

# test patch sizes are correct
for task in tasks:
assert math.isclose(task['bbox'][1] - task['bbox'][0], task['patch_size'][0])
assert math.isclose(task['bbox'][3] - task['bbox'][2], task['patch_size'][1])

# test stride sizes are correct
assert math.isclose(abs(tasks[0]['bbox'][2] - tasks[1]['bbox'][2]), tasks[0]['stride'][1])

@parameterized.expand(
[
("sliding", (0.5, 0.5), (0.6, 0.6), Warning), # patch_size and stride as tuples
Expand Down

0 comments on commit d620e88

Please sign in to comment.