Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement a lazy representation of a list of sliding windows #2278

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 169 additions & 62 deletions rastervision_core/rastervision/core/box.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import (TYPE_CHECKING, Literal)
from typing import (TYPE_CHECKING, Literal, Sequence, overload)
from collections.abc import Callable
from pydantic import NonNegativeInt as NonNegInt, PositiveInt as PosInt
import math
Expand Down Expand Up @@ -364,6 +364,31 @@ def pad(self, ymin: int, xmin: int, ymax: int, xmax: int) -> 'Self':
ymax=self.ymax + ymax,
xmax=self.xmax + xmax)

def pad_directional(
self,
padding: tuple[NonNegInt, NonNegInt] | NonNegInt,
pad_direction: Literal['both', 'start', 'end'] = 'end') -> 'Self':
"""Pad sides based on given padding and direction."""

padding: tuple[NonNegInt, NonNegInt] = ensure_tuple(padding)

if padding == (0, 0):
return self

if padding[0] < 0 or padding[1] < 0:
raise ValueError('padding must be non-negative.')

h_pad, w_pad = padding
if pad_direction == 'both':
return self.pad(ymin=h_pad, xmin=w_pad, ymax=h_pad, xmax=w_pad)
elif pad_direction == 'end':
return self.pad(ymin=0, xmin=0, ymax=h_pad, xmax=w_pad)
elif pad_direction == 'start':
return self.pad(ymin=h_pad, xmin=w_pad, ymax=0, xmax=0)

raise ValueError('pad_directions must be one of: '
'"both", "start", "end".')

def copy(self) -> 'Self':
return Box(*self)

Expand All @@ -373,78 +398,40 @@ def get_windows(
stride: PosInt | tuple[PosInt, PosInt],
padding: NonNegInt | tuple[NonNegInt, NonNegInt] | None = None,
pad_direction: Literal['both', 'start', 'end'] = 'end'
) -> list['Self']:
) -> 'SlidingWindows':
"""Return sliding windows for given size, stride, and padding.

Each of size, stride, and padding can be either a positive int or
a tuple ``(vertical-component, horizontal-component)`` of positive
ints.
Each of ``size``, ``stride``, and ``padding`` can be either a positive
int or a tuple ``(vertical-component, horizontal-component)`` of
positive ints.

If ``padding`` is not specified and ``stride <= size``, it will be
automatically calculated such that the windows cover the entire extent.

Args:
size: Size (h, w) of the windows.
stride: Step size between windows. Can be 2-tuple (h_step, w_step)
or positive int.
box: Outer box within which to generate sliding windows.
size: Size ``(h, w)`` of the windows.
stride: Step size between windows. Can be a ``(h_step, w_step)``
tuple or positive int.
padding: Optional padding to accommodate windows that overflow the
extent. Can be 2-tuple (h_pad, w_pad) or non-negative int.
If None, will be automatically calculated such that the windows
cover the entire extent. Defaults to ``None``.
pad_direction: If ``'end'``, only pad ymax and xmax (bottom and
right). If ``'start'``, only pad ymin and xmin (top and left).
If ``'both'``, pad all sides. If ``'both'`` pad all sides. Has
no effect if padding is zero. Defaults to ``'end'``.
extent. Can be a ``(h_pad, w_pad)`` tuple or a non-negative
int. If ``None``, will be automatically calculated such that
the windows cover the entire extent. Defaults to ``None``.
pad_direction: Directions to add padding to.
If ``'end'``, only add padding to bottom and right.
If ``'start'``, only add padding to top and left.
If ``'both'``, add padding to all sides.
Has no effect if padding is zero. Defaults to ``'end'``.

Returns:
List of windows.
Lazy list of windows.
"""
size: tuple[PosInt, PosInt] = ensure_tuple(size)
stride: tuple[PosInt, PosInt] = ensure_tuple(stride)

if size[0] <= 0 or size[1] <= 0 or stride[0] <= 0 or stride[1] <= 0:
raise ValueError('size and stride must be positive.')

if padding is None:
if size[0] < stride[0] or size[1] < stride[1]:
padding = (0, 0)
else:
padding = calculate_required_padding(self.size, size, stride,
pad_direction)

padding: tuple[NonNegInt, NonNegInt] = ensure_tuple(padding)

if padding[0] < 0 or padding[1] < 0:
raise ValueError('padding must be non-negative.')

if padding != (0, 0):
h_pad, w_pad = padding
if pad_direction == 'both':
padded_box = self.pad(
ymin=h_pad, xmin=w_pad, ymax=h_pad, xmax=w_pad)
elif pad_direction == 'end':
padded_box = self.pad(ymin=0, xmin=0, ymax=h_pad, xmax=w_pad)
elif pad_direction == 'start':
padded_box = self.pad(ymin=h_pad, xmin=w_pad, ymax=0, xmax=0)
else:
raise ValueError('pad_directions must be one of: '
'"both", "start", "end".')
return padded_box.get_windows(
size=size, stride=stride, padding=(0, 0))

# padding is necessarily (0, 0) at this point, so we ignore it
h, w = size
h_step, w_step = stride
# lb = lower bound, ub = upper bound
ymin_lb = self.ymin
xmin_lb = self.xmin
ymin_ub = self.ymax - h
xmin_ub = self.xmax - w

windows = []
for ymin in range(ymin_lb, ymin_ub + 1, h_step):
for xmin in range(xmin_lb, xmin_ub + 1, w_step):
windows.append(Box(ymin, xmin, ymin + h, xmin + w))
windows = SlidingWindows(
self,
size=size,
stride=stride,
padding=padding,
pad_direction=pad_direction)
return windows

def to_dict(self) -> dict[str, int]:
Expand Down Expand Up @@ -518,3 +505,123 @@ def __contains__(self, query: 'Self | tuple[int, int]') -> bool:
return self.xmin <= x <= self.xmax and self.ymin <= y <= self.ymax
else:
raise NotImplementedError()


class SlidingWindows(Sequence[Box]):
"""Lazy representation of a list of sliding windows.

Instead of storing a list of all windows in memory, this class dynamically
computes the coordinates of windows as they are retrieved. Supports
iteration and basic slicing.
"""

def __init__(
self,
box: Box,
*,
size: PosInt | tuple[PosInt, PosInt],
stride: PosInt | tuple[PosInt, PosInt],
padding: NonNegInt | tuple[NonNegInt, NonNegInt] | None = None,
pad_direction: Literal['both', 'start', 'end'] = 'end',
):
"""Constructor.

Each of ``size``, ``stride``, and ``padding`` can be either a positive
int or a tuple ``(vertical-component, horizontal-component)`` of
positive ints.

If ``padding`` is not specified and ``stride <= size``, it will be
automatically calculated such that the windows cover the entire extent.

Args:
box: Outer box within which to generate sliding windows.
size: Size ``(h, w)`` of the windows.
stride: Step size between windows. Can be a ``(h_step, w_step)``
tuple or positive int.
padding: Optional padding to accommodate windows that overflow the
extent. Can be a ``(h_pad, w_pad)`` tuple or a non-negative
int. If ``None``, will be automatically calculated such that
the windows cover the entire extent. Defaults to ``None``.
pad_direction: Directions to add padding to.
If ``'end'``, only add padding to bottom and right.
If ``'start'``, only add padding to top and left.
If ``'both'``, add padding to all sides.
Has no effect if padding is zero. Defaults to ``'end'``.
"""
size: tuple[PosInt, PosInt] = ensure_tuple(size)
stride: tuple[PosInt, PosInt] = ensure_tuple(stride)

if size[0] <= 0 or size[1] <= 0 or stride[0] <= 0 or stride[1] <= 0:
raise ValueError('size and stride must be positive.')

if padding is None:
if size[0] < stride[0] or size[1] < stride[1]:
padding = (0, 0)
else:
padding = calculate_required_padding(box.size, size, stride,
pad_direction)
self.box = box
self.size = size
self.stride = stride
self.padding: tuple[NonNegInt, NonNegInt] = ensure_tuple(padding)
self.pad_direction = pad_direction
self.padded_box = box.pad_directional(self.padding, self.pad_direction)

self.h, self.w = size
self.y_step, self.x_step = stride
self.y_start = self.padded_box.ymin
self.x_start = self.padded_box.xmin
self.y_end = self.padded_box.ymax - self.h
self.x_end = self.padded_box.xmax - self.w
self.nrows = int((self.y_end - self.y_start) // self.y_step + 1)
self.ncols = int((self.x_end - self.x_start) // self.x_step + 1)
self.total = self.nrows * self.ncols

@overload
def __getitem__(self, i: int | np.integer) -> Box:
...

@overload
def __getitem__(self, s: slice) -> list[Box]:
...

@overload
def __getitem__(self, inds: Sequence[int]) -> list[Box]:
...

def __getitem__(self, key: int | slice | Sequence[int]) -> Box | list[Box]:
if isinstance(key, int | np.integer):
row, col = self.index_to_rowcol(key)
return self.get_by_rowcol(row, col)
if isinstance(key, slice):
start = 0 if key.start is None else key.start
stop = len(self) if key.stop is None else key.stop
step = 1 if key.step is None else key.step
if not all(isinstance(v, int) for v in (start, stop, step)):
raise TypeError('Slice indices must be integers.')
windows = [self[i] for i in range(start, stop, step)]
return windows
windows = [self[i] for i in key]
return windows

def get_by_rowcol(self, row: int, col: int) -> Box:
"""Get window at given row and column indices."""
if row >= self.nrows or col >= self.ncols:
raise IndexError()
ymin = self.y_start + self.y_step * row
xmin = self.x_start + self.x_step * col
window = Box(ymin, xmin, ymin + self.h, xmin + self.w)
return window

def index_to_rowcol(self, i: int) -> tuple[int, int]:
"""Get row and column indices of the i-th window."""
if i >= len(self):
raise IndexError()
if i < 0:
i += len(self)
row = i // self.ncols
col = i % self.ncols
return row, col

def __len__(self) -> int:
return self.total
15 changes: 14 additions & 1 deletion tests/core/test_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from shapely.geometry import box as ShapelyBox

from rastervision.core.box import Box, BoxSizeError, RioWindow
from rastervision.core.box import Box, BoxSizeError, RioWindow, SlidingWindows

np.random.seed(1)

Expand Down Expand Up @@ -449,5 +449,18 @@ def test_error_on_nonfinite_inputs(self):
self.assertRaises(ValueError, lambda: Box(np.nan, 0, 0, 0))


class TestSlidingWindows(unittest.TestCase):
def test_getitem(self):
ws = SlidingWindows(Box(0, 0, 2, 4), size=2, stride=1, padding=0)
self.assertEqual(ws[0], Box(0, 0, 2, 2))
self.assertEqual(ws[-1], Box(0, 2, 2, 4))
self.assertListEqual(ws[1:], [Box(0, 1, 2, 3), Box(0, 2, 2, 4)])
self.assertListEqual(ws[::2], [Box(0, 0, 2, 2), Box(0, 2, 2, 4)])
self.assertListEqual(
ws[np.array([1, 0])],
[Box(0, 1, 2, 3), Box(0, 0, 2, 2)],
)


if __name__ == '__main__':
unittest.main()
Loading