Skip to content

Commit

Permalink
implement a lazy representation of a list of sliding windows
Browse files Browse the repository at this point in the history
allowing arbitrarily long lists of sliding windows (from arbitrarily large extents) to be represented
  • Loading branch information
AdeelH committed Nov 26, 2024
1 parent fb3b199 commit b9b2394
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 63 deletions.
231 changes: 169 additions & 62 deletions rastervision_core/rastervision/core/box.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import (TYPE_CHECKING, Literal)
from typing import (TYPE_CHECKING, Literal, Sequence, overload)
from collections.abc import Callable
from pydantic import NonNegativeInt as NonNegInt, PositiveInt as PosInt
import math
Expand Down Expand Up @@ -364,6 +364,31 @@ def pad(self, ymin: int, xmin: int, ymax: int, xmax: int) -> 'Self':
ymax=self.ymax + ymax,
xmax=self.xmax + xmax)

def pad_directional(
self,
padding: tuple[NonNegInt, NonNegInt] | NonNegInt,
pad_direction: Literal['both', 'start', 'end'] = 'end') -> 'Self':
"""Pad sides based on given padding and direction."""

padding: tuple[NonNegInt, NonNegInt] = ensure_tuple(padding)

if padding == (0, 0):
return self

if padding[0] < 0 or padding[1] < 0:
raise ValueError('padding must be non-negative.')

h_pad, w_pad = padding
if pad_direction == 'both':
return self.pad(ymin=h_pad, xmin=w_pad, ymax=h_pad, xmax=w_pad)
elif pad_direction == 'end':
return self.pad(ymin=0, xmin=0, ymax=h_pad, xmax=w_pad)
elif pad_direction == 'start':
return self.pad(ymin=h_pad, xmin=w_pad, ymax=0, xmax=0)

raise ValueError('pad_directions must be one of: '
'"both", "start", "end".')

def copy(self) -> 'Self':
return Box(*self)

Expand All @@ -373,78 +398,40 @@ def get_windows(
stride: PosInt | tuple[PosInt, PosInt],
padding: NonNegInt | tuple[NonNegInt, NonNegInt] | None = None,
pad_direction: Literal['both', 'start', 'end'] = 'end'
) -> list['Self']:
) -> 'SlidingWindows':
"""Return sliding windows for given size, stride, and padding.
Each of size, stride, and padding can be either a positive int or
a tuple ``(vertical-component, horizontal-component)`` of positive
ints.
Each of ``size``, ``stride``, and ``padding`` can be either a positive
int or a tuple ``(vertical-component, horizontal-component)`` of
positive ints.
If ``padding`` is not specified and ``stride <= size``, it will be
automatically calculated such that the windows cover the entire extent.
Args:
size: Size (h, w) of the windows.
stride: Step size between windows. Can be 2-tuple (h_step, w_step)
or positive int.
box: Outer box within which to generate sliding windows.
size: Size ``(h, w)`` of the windows.
stride: Step size between windows. Can be a ``(h_step, w_step)``
tuple or positive int.
padding: Optional padding to accommodate windows that overflow the
extent. Can be 2-tuple (h_pad, w_pad) or non-negative int.
If None, will be automatically calculated such that the windows
cover the entire extent. Defaults to ``None``.
pad_direction: If ``'end'``, only pad ymax and xmax (bottom and
right). If ``'start'``, only pad ymin and xmin (top and left).
If ``'both'``, pad all sides. If ``'both'`` pad all sides. Has
no effect if padding is zero. Defaults to ``'end'``.
extent. Can be a ``(h_pad, w_pad)`` tuple or a non-negative
int. If ``None``, will be automatically calculated such that
the windows cover the entire extent. Defaults to ``None``.
pad_direction: Directions to add padding to.
If ``'end'``, only add padding to bottom and right.
If ``'start'``, only add padding to top and left.
If ``'both'``, add padding to all sides.
Has no effect if padding is zero. Defaults to ``'end'``.
Returns:
List of windows.
Lazy list of windows.
"""
size: tuple[PosInt, PosInt] = ensure_tuple(size)
stride: tuple[PosInt, PosInt] = ensure_tuple(stride)

if size[0] <= 0 or size[1] <= 0 or stride[0] <= 0 or stride[1] <= 0:
raise ValueError('size and stride must be positive.')

if padding is None:
if size[0] < stride[0] or size[1] < stride[1]:
padding = (0, 0)
else:
padding = calculate_required_padding(self.size, size, stride,
pad_direction)

padding: tuple[NonNegInt, NonNegInt] = ensure_tuple(padding)

if padding[0] < 0 or padding[1] < 0:
raise ValueError('padding must be non-negative.')

if padding != (0, 0):
h_pad, w_pad = padding
if pad_direction == 'both':
padded_box = self.pad(
ymin=h_pad, xmin=w_pad, ymax=h_pad, xmax=w_pad)
elif pad_direction == 'end':
padded_box = self.pad(ymin=0, xmin=0, ymax=h_pad, xmax=w_pad)
elif pad_direction == 'start':
padded_box = self.pad(ymin=h_pad, xmin=w_pad, ymax=0, xmax=0)
else:
raise ValueError('pad_directions must be one of: '
'"both", "start", "end".')
return padded_box.get_windows(
size=size, stride=stride, padding=(0, 0))

# padding is necessarily (0, 0) at this point, so we ignore it
h, w = size
h_step, w_step = stride
# lb = lower bound, ub = upper bound
ymin_lb = self.ymin
xmin_lb = self.xmin
ymin_ub = self.ymax - h
xmin_ub = self.xmax - w

windows = []
for ymin in range(ymin_lb, ymin_ub + 1, h_step):
for xmin in range(xmin_lb, xmin_ub + 1, w_step):
windows.append(Box(ymin, xmin, ymin + h, xmin + w))
windows = SlidingWindows(
self,
size=size,
stride=stride,
padding=padding,
pad_direction=pad_direction)
return windows

def to_dict(self) -> dict[str, int]:
Expand Down Expand Up @@ -518,3 +505,123 @@ def __contains__(self, query: 'Self | tuple[int, int]') -> bool:
return self.xmin <= x <= self.xmax and self.ymin <= y <= self.ymax
else:
raise NotImplementedError()


class SlidingWindows(Sequence[Box]):
"""Lazy representation of a list of sliding windows.
Instead of storing a list of all windows in memory, this class dynamically
computes the coordinates of windows as they are retrieved. Supports
iteration and basic slicing.
"""

def __init__(
self,
box: Box,
*,
size: PosInt | tuple[PosInt, PosInt],
stride: PosInt | tuple[PosInt, PosInt],
padding: NonNegInt | tuple[NonNegInt, NonNegInt] | None = None,
pad_direction: Literal['both', 'start', 'end'] = 'end',
):
"""Constructor.
Each of ``size``, ``stride``, and ``padding`` can be either a positive
int or a tuple ``(vertical-component, horizontal-component)`` of
positive ints.
If ``padding`` is not specified and ``stride <= size``, it will be
automatically calculated such that the windows cover the entire extent.
Args:
box: Outer box within which to generate sliding windows.
size: Size ``(h, w)`` of the windows.
stride: Step size between windows. Can be a ``(h_step, w_step)``
tuple or positive int.
padding: Optional padding to accommodate windows that overflow the
extent. Can be a ``(h_pad, w_pad)`` tuple or a non-negative
int. If ``None``, will be automatically calculated such that
the windows cover the entire extent. Defaults to ``None``.
pad_direction: Directions to add padding to.
If ``'end'``, only add padding to bottom and right.
If ``'start'``, only add padding to top and left.
If ``'both'``, add padding to all sides.
Has no effect if padding is zero. Defaults to ``'end'``.
"""
size: tuple[PosInt, PosInt] = ensure_tuple(size)
stride: tuple[PosInt, PosInt] = ensure_tuple(stride)

if size[0] <= 0 or size[1] <= 0 or stride[0] <= 0 or stride[1] <= 0:
raise ValueError('size and stride must be positive.')

if padding is None:
if size[0] < stride[0] or size[1] < stride[1]:
padding = (0, 0)
else:
padding = calculate_required_padding(box.size, size, stride,
pad_direction)
self.box = box
self.size = size
self.stride = stride
self.padding: tuple[NonNegInt, NonNegInt] = ensure_tuple(padding)
self.pad_direction = pad_direction
self.padded_box = box.pad_directional(self.padding, self.pad_direction)

self.h, self.w = size
self.y_step, self.x_step = stride
self.y_start = self.padded_box.ymin
self.x_start = self.padded_box.xmin
self.y_end = self.padded_box.ymax - self.h
self.x_end = self.padded_box.xmax - self.w
self.nrows = int((self.y_end - self.y_start) // self.y_step + 1)
self.ncols = int((self.x_end - self.x_start) // self.x_step + 1)
self.total = self.nrows * self.ncols

@overload
def __getitem__(self, i: int | np.integer) -> Box:
...

@overload
def __getitem__(self, s: slice) -> list[Box]:
...

@overload
def __getitem__(self, inds: Sequence[int]) -> list[Box]:
...

def __getitem__(self, key: int | slice | Sequence[int]) -> Box | list[Box]:
if isinstance(key, int | np.integer):
row, col = self.index_to_rowcol(key)
return self.get_by_rowcol(row, col)
if isinstance(key, slice):
start = 0 if key.start is None else key.start
stop = len(self) if key.stop is None else key.stop
step = 1 if key.step is None else key.step
if not all(isinstance(v, int) for v in (start, stop, step)):
raise TypeError('Slice indices must be integers.')
windows = [self[i] for i in range(start, stop, step)]
return windows
windows = [self[i] for i in key]
return windows

def get_by_rowcol(self, row: int, col: int) -> Box:
"""Get window at given row and column indices."""
if row >= self.nrows or col >= self.ncols:
raise IndexError()
ymin = self.y_start + self.y_step * row
xmin = self.x_start + self.x_step * col
window = Box(ymin, xmin, ymin + self.h, xmin + self.w)
return window

def index_to_rowcol(self, i: int) -> tuple[int, int]:
"""Get row and column indices of the i-th window."""
if i >= len(self):
raise IndexError()
if i < 0:
i += len(self)
row = i // self.ncols
col = i % self.ncols
return row, col

def __len__(self) -> int:
return self.total
15 changes: 14 additions & 1 deletion tests/core/test_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from shapely.geometry import box as ShapelyBox

from rastervision.core.box import Box, BoxSizeError, RioWindow
from rastervision.core.box import Box, BoxSizeError, RioWindow, SlidingWindows

np.random.seed(1)

Expand Down Expand Up @@ -449,5 +449,18 @@ def test_error_on_nonfinite_inputs(self):
self.assertRaises(ValueError, lambda: Box(np.nan, 0, 0, 0))


class TestSlidingWindows(unittest.TestCase):
def test_getitem(self):
ws = SlidingWindows(Box(0, 0, 2, 4), size=2, stride=1, padding=0)
self.assertEqual(ws[0], Box(0, 0, 2, 2))
self.assertEqual(ws[-1], Box(0, 2, 2, 4))
self.assertListEqual(ws[1:], [Box(0, 1, 2, 3), Box(0, 2, 2, 4)])
self.assertListEqual(ws[::2], [Box(0, 0, 2, 2), Box(0, 2, 2, 4)])
self.assertListEqual(
ws[np.array([1, 0])],
[Box(0, 1, 2, 3), Box(0, 0, 2, 2)],
)


if __name__ == '__main__':
unittest.main()

0 comments on commit b9b2394

Please sign in to comment.