diff --git a/rastervision_core/rastervision/core/box.py b/rastervision_core/rastervision/core/box.py index 3521a5716..4fa177297 100644 --- a/rastervision_core/rastervision/core/box.py +++ b/rastervision_core/rastervision/core/box.py @@ -1,4 +1,4 @@ -from typing import (TYPE_CHECKING, Literal) +from typing import (TYPE_CHECKING, Literal, Sequence, overload) from collections.abc import Callable from pydantic import NonNegativeInt as NonNegInt, PositiveInt as PosInt import math @@ -364,6 +364,31 @@ def pad(self, ymin: int, xmin: int, ymax: int, xmax: int) -> 'Self': ymax=self.ymax + ymax, xmax=self.xmax + xmax) + def pad_directional( + self, + padding: tuple[NonNegInt, NonNegInt] | NonNegInt, + pad_direction: Literal['both', 'start', 'end'] = 'end') -> 'Self': + """Pad sides based on given padding and direction.""" + + padding: tuple[NonNegInt, NonNegInt] = ensure_tuple(padding) + + if padding == (0, 0): + return self + + if padding[0] < 0 or padding[1] < 0: + raise ValueError('padding must be non-negative.') + + h_pad, w_pad = padding + if pad_direction == 'both': + return self.pad(ymin=h_pad, xmin=w_pad, ymax=h_pad, xmax=w_pad) + elif pad_direction == 'end': + return self.pad(ymin=0, xmin=0, ymax=h_pad, xmax=w_pad) + elif pad_direction == 'start': + return self.pad(ymin=h_pad, xmin=w_pad, ymax=0, xmax=0) + + raise ValueError('pad_directions must be one of: ' + '"both", "start", "end".') + def copy(self) -> 'Self': return Box(*self) @@ -373,78 +398,40 @@ def get_windows( stride: PosInt | tuple[PosInt, PosInt], padding: NonNegInt | tuple[NonNegInt, NonNegInt] | None = None, pad_direction: Literal['both', 'start', 'end'] = 'end' - ) -> list['Self']: + ) -> 'SlidingWindows': """Return sliding windows for given size, stride, and padding. - Each of size, stride, and padding can be either a positive int or - a tuple ``(vertical-component, horizontal-component)`` of positive - ints. + Each of ``size``, ``stride``, and ``padding`` can be either a positive + int or a tuple ``(vertical-component, horizontal-component)`` of + positive ints. If ``padding`` is not specified and ``stride <= size``, it will be automatically calculated such that the windows cover the entire extent. Args: - size: Size (h, w) of the windows. - stride: Step size between windows. Can be 2-tuple (h_step, w_step) - or positive int. + box: Outer box within which to generate sliding windows. + size: Size ``(h, w)`` of the windows. + stride: Step size between windows. Can be a ``(h_step, w_step)`` + tuple or positive int. padding: Optional padding to accommodate windows that overflow the - extent. Can be 2-tuple (h_pad, w_pad) or non-negative int. - If None, will be automatically calculated such that the windows - cover the entire extent. Defaults to ``None``. - pad_direction: If ``'end'``, only pad ymax and xmax (bottom and - right). If ``'start'``, only pad ymin and xmin (top and left). - If ``'both'``, pad all sides. If ``'both'`` pad all sides. Has - no effect if padding is zero. Defaults to ``'end'``. + extent. Can be a ``(h_pad, w_pad)`` tuple or a non-negative + int. If ``None``, will be automatically calculated such that + the windows cover the entire extent. Defaults to ``None``. + pad_direction: Directions to add padding to. + If ``'end'``, only add padding to bottom and right. + If ``'start'``, only add padding to top and left. + If ``'both'``, add padding to all sides. + Has no effect if padding is zero. Defaults to ``'end'``. Returns: - List of windows. + Lazy list of windows. """ - size: tuple[PosInt, PosInt] = ensure_tuple(size) - stride: tuple[PosInt, PosInt] = ensure_tuple(stride) - - if size[0] <= 0 or size[1] <= 0 or stride[0] <= 0 or stride[1] <= 0: - raise ValueError('size and stride must be positive.') - - if padding is None: - if size[0] < stride[0] or size[1] < stride[1]: - padding = (0, 0) - else: - padding = calculate_required_padding(self.size, size, stride, - pad_direction) - - padding: tuple[NonNegInt, NonNegInt] = ensure_tuple(padding) - - if padding[0] < 0 or padding[1] < 0: - raise ValueError('padding must be non-negative.') - - if padding != (0, 0): - h_pad, w_pad = padding - if pad_direction == 'both': - padded_box = self.pad( - ymin=h_pad, xmin=w_pad, ymax=h_pad, xmax=w_pad) - elif pad_direction == 'end': - padded_box = self.pad(ymin=0, xmin=0, ymax=h_pad, xmax=w_pad) - elif pad_direction == 'start': - padded_box = self.pad(ymin=h_pad, xmin=w_pad, ymax=0, xmax=0) - else: - raise ValueError('pad_directions must be one of: ' - '"both", "start", "end".') - return padded_box.get_windows( - size=size, stride=stride, padding=(0, 0)) - - # padding is necessarily (0, 0) at this point, so we ignore it - h, w = size - h_step, w_step = stride - # lb = lower bound, ub = upper bound - ymin_lb = self.ymin - xmin_lb = self.xmin - ymin_ub = self.ymax - h - xmin_ub = self.xmax - w - - windows = [] - for ymin in range(ymin_lb, ymin_ub + 1, h_step): - for xmin in range(xmin_lb, xmin_ub + 1, w_step): - windows.append(Box(ymin, xmin, ymin + h, xmin + w)) + windows = SlidingWindows( + self, + size=size, + stride=stride, + padding=padding, + pad_direction=pad_direction) return windows def to_dict(self) -> dict[str, int]: @@ -518,3 +505,123 @@ def __contains__(self, query: 'Self | tuple[int, int]') -> bool: return self.xmin <= x <= self.xmax and self.ymin <= y <= self.ymax else: raise NotImplementedError() + + +class SlidingWindows(Sequence[Box]): + """Lazy representation of a list of sliding windows. + + Instead of storing a list of all windows in memory, this class dynamically + computes the coordinates of windows as they are retrieved. Supports + iteration and basic slicing. + """ + + def __init__( + self, + box: Box, + *, + size: PosInt | tuple[PosInt, PosInt], + stride: PosInt | tuple[PosInt, PosInt], + padding: NonNegInt | tuple[NonNegInt, NonNegInt] | None = None, + pad_direction: Literal['both', 'start', 'end'] = 'end', + ): + """Constructor. + + Each of ``size``, ``stride``, and ``padding`` can be either a positive + int or a tuple ``(vertical-component, horizontal-component)`` of + positive ints. + + If ``padding`` is not specified and ``stride <= size``, it will be + automatically calculated such that the windows cover the entire extent. + + Args: + box: Outer box within which to generate sliding windows. + size: Size ``(h, w)`` of the windows. + stride: Step size between windows. Can be a ``(h_step, w_step)`` + tuple or positive int. + padding: Optional padding to accommodate windows that overflow the + extent. Can be a ``(h_pad, w_pad)`` tuple or a non-negative + int. If ``None``, will be automatically calculated such that + the windows cover the entire extent. Defaults to ``None``. + pad_direction: Directions to add padding to. + If ``'end'``, only add padding to bottom and right. + If ``'start'``, only add padding to top and left. + If ``'both'``, add padding to all sides. + Has no effect if padding is zero. Defaults to ``'end'``. + """ + size: tuple[PosInt, PosInt] = ensure_tuple(size) + stride: tuple[PosInt, PosInt] = ensure_tuple(stride) + + if size[0] <= 0 or size[1] <= 0 or stride[0] <= 0 or stride[1] <= 0: + raise ValueError('size and stride must be positive.') + + if padding is None: + if size[0] < stride[0] or size[1] < stride[1]: + padding = (0, 0) + else: + padding = calculate_required_padding(box.size, size, stride, + pad_direction) + self.box = box + self.size = size + self.stride = stride + self.padding: tuple[NonNegInt, NonNegInt] = ensure_tuple(padding) + self.pad_direction = pad_direction + self.padded_box = box.pad_directional(self.padding, self.pad_direction) + + self.h, self.w = size + self.y_step, self.x_step = stride + self.y_start = self.padded_box.ymin + self.x_start = self.padded_box.xmin + self.y_end = self.padded_box.ymax - self.h + self.x_end = self.padded_box.xmax - self.w + self.nrows = int((self.y_end - self.y_start) // self.y_step + 1) + self.ncols = int((self.x_end - self.x_start) // self.x_step + 1) + self.total = self.nrows * self.ncols + + @overload + def __getitem__(self, i: int | np.integer) -> Box: + ... + + @overload + def __getitem__(self, s: slice) -> list[Box]: + ... + + @overload + def __getitem__(self, inds: Sequence[int]) -> list[Box]: + ... + + def __getitem__(self, key: int | slice | Sequence[int]) -> Box | list[Box]: + if isinstance(key, int | np.integer): + row, col = self.index_to_rowcol(key) + return self.get_by_rowcol(row, col) + if isinstance(key, slice): + start = 0 if key.start is None else key.start + stop = len(self) if key.stop is None else key.stop + step = 1 if key.step is None else key.step + if not all(isinstance(v, int) for v in (start, stop, step)): + raise TypeError('Slice indices must be integers.') + windows = [self[i] for i in range(start, stop, step)] + return windows + windows = [self[i] for i in key] + return windows + + def get_by_rowcol(self, row: int, col: int) -> Box: + """Get window at given row and column indices.""" + if row >= self.nrows or col >= self.ncols: + raise IndexError() + ymin = self.y_start + self.y_step * row + xmin = self.x_start + self.x_step * col + window = Box(ymin, xmin, ymin + self.h, xmin + self.w) + return window + + def index_to_rowcol(self, i: int) -> tuple[int, int]: + """Get row and column indices of the i-th window.""" + if i >= len(self): + raise IndexError() + if i < 0: + i += len(self) + row = i // self.ncols + col = i % self.ncols + return row, col + + def __len__(self) -> int: + return self.total diff --git a/tests/core/test_box.py b/tests/core/test_box.py index e13c24207..7751cb8a1 100644 --- a/tests/core/test_box.py +++ b/tests/core/test_box.py @@ -3,7 +3,7 @@ import numpy as np from shapely.geometry import box as ShapelyBox -from rastervision.core.box import Box, BoxSizeError, RioWindow +from rastervision.core.box import Box, BoxSizeError, RioWindow, SlidingWindows np.random.seed(1) @@ -449,5 +449,18 @@ def test_error_on_nonfinite_inputs(self): self.assertRaises(ValueError, lambda: Box(np.nan, 0, 0, 0)) +class TestSlidingWindows(unittest.TestCase): + def test_getitem(self): + ws = SlidingWindows(Box(0, 0, 2, 4), size=2, stride=1, padding=0) + self.assertEqual(ws[0], Box(0, 0, 2, 2)) + self.assertEqual(ws[-1], Box(0, 2, 2, 4)) + self.assertListEqual(ws[1:], [Box(0, 1, 2, 3), Box(0, 2, 2, 4)]) + self.assertListEqual(ws[::2], [Box(0, 0, 2, 2), Box(0, 2, 2, 4)]) + self.assertListEqual( + ws[np.array([1, 0])], + [Box(0, 1, 2, 3), Box(0, 0, 2, 2)], + ) + + if __name__ == '__main__': unittest.main()