Skip to content

Commit

Permalink
Add RollingPairwise, ApplyPairwise
Browse files Browse the repository at this point in the history
  • Loading branch information
ajcr committed Apr 1, 2023
1 parent b29151c commit 14d63b5
Show file tree
Hide file tree
Showing 4 changed files with 340 additions and 0 deletions.
3 changes: 3 additions & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

## [0.5.0]
### Added
- New `rolling.ApplyPairwise` object

## [0.4.0]
### Added
Expand Down
79 changes: 79 additions & 0 deletions rolling/apply_pairwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from collections import deque
from itertools import islice

from .base_pairwise import RollingPairwise


class ApplyPairwise(RollingPairwise):
    """
    Apply a binary function to windows over two iterables.

    The two iterables are consumed in step with each other. At every step
    ``function`` is called with the current window of each iterable (two
    deques of equal length) and its return value is yielded.

    Parameters
    ----------

    iterable_1 : any iterable object
    iterable_2 : any iterable object
    window_size : integer, the size of the rolling
        window moving over the iterables
    function : callable,
        a binary callable to be applied to the current
        window of each iterable
    window_type : str, "fixed" (default) or "variable"
        "fixed" yields values for full windows only;
        "variable" yields values while the window grows
        from a single pair up to window_size and, once a
        full window has been reached and the iterables
        are exhausted, while it shrinks back to one pair

    Complexity
    ----------

    Update time: function dependent
    Memory usage: O(k)

    where k is the size of the rolling window

    Example
    --------

    >>> from rolling import ApplyPairwise
    >>> from statistics import correlation
    >>> seq_1 = [1, 2, 3, 4, 5]
    >>> seq_2 = [1, 2, 3, 2, 1]
    >>> r_corr = ApplyPairwise(seq_1, seq_2, window_size=3, function=correlation)
    >>> list(r_corr)
    [1.0, 0.0, -1.0]

    """
    def __init__(self, iterable_1, iterable_2, window_size, function, window_type="fixed"):
        # The buffers must exist before the base initialiser runs, because
        # it calls _init_fixed/_init_variable which may fill them.
        self._buffer_1 = deque(maxlen=window_size)
        self._buffer_2 = deque(maxlen=window_size)
        self._function = function
        super().__init__(iterable_1, iterable_2, window_size=window_size, window_type=window_type)

    def _init_fixed(self, **kwargs):
        # Pre-load window_size - 1 pairs so that the first call to next()
        # completes the first full window.
        pairs = zip(self._iterator_1, self._iterator_2)
        for item_1, item_2 in islice(pairs, self.window_size-1):
            self._buffer_1.append(item_1)
            self._buffer_2.append(item_2)

    def _init_variable(self, **kwargs):
        pass  # no action required

    @property
    def current_value(self):
        """
        Return the result of applying the function to the two current windows
        """
        return self._function(self._buffer_1, self._buffer_2)

    def _add_new(self, new_1, new_2):
        self._buffer_1.append(new_1)
        self._buffer_2.append(new_2)

    def _remove_old(self):
        self._buffer_1.popleft()
        self._buffer_2.popleft()

    def _update_window(self, new_1, new_2):
        # The deques are bounded by maxlen, so appending to a full buffer
        # discards its oldest item; no explicit removal is needed here.
        self._buffer_1.append(new_1)
        self._buffer_2.append(new_2)

    @property
    def _obs(self):
        """
        Return the number of pairs currently held in the window
        """
        return len(self._buffer_1)

    def __repr__(self):
        # Not every callable has a __name__ (functools.partial objects and
        # callable instances do not), so fall back to the callable's repr
        # instead of letting repr() of this object raise AttributeError.
        try:
            operation = self._function.__name__
        except AttributeError:
            operation = repr(self._function)
        return "RollingPairwise(operation='{}', window_size={}, window_type='{}')".format(
            operation, self.window_size, self.window_type
        )
129 changes: 129 additions & 0 deletions rolling/base_pairwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import abc
from collections.abc import Iterator
from itertools import chain


class RollingPairwise(Iterator):
    """
    Baseclass for rolling iterators over two iterables.

    The two iterables are advanced together, one item from each per step.
    Subclasses hold the window contents and implement the abstract methods
    below; this class drives the iteration for "fixed" and "variable"
    window types.

    Parameters
    ----------

    iterable_1 : any iterable object
    iterable_2 : any iterable object
    window_size : integer, the size of the rolling window
    window_type : str, "fixed" (default) or "variable"
    **kwargs : passed on to _init_fixed or _init_variable

    Raises
    ------

    ValueError : if window_type is not "fixed" or "variable"

    """
    def __init__(self, iterable_1, iterable_2, window_size, window_type="fixed", **kwargs):
        self.window_type = window_type
        self.window_size = _validate_window_size(window_size)
        self._iterator_1 = iter(iterable_1)
        self._iterator_2 = iter(iterable_2)
        # A fixed window counts as filled from the start; a variable window
        # first has to grow to window_size (see _next_variable).
        self._filled = self.window_type == "fixed"

        if window_type == "fixed":
            self._init_fixed(**kwargs)

        elif window_type == "variable":
            self._init_variable(**kwargs)

        else:
            raise ValueError(f"Unknown window_type '{window_type}'")

    def __repr__(self):
        return "RollingPairwise(operation='{}', window_size={}, window_type='{}')".format(
            self.__class__.__name__, self.window_size, self.window_type
        )

    def _next_fixed(self):
        # StopIteration from either iterator ends the iteration
        new_1 = next(self._iterator_1)
        new_2 = next(self._iterator_2)
        self._update_window(new_1, new_2)
        return self.current_value

    def _next_variable(self):
        # while the window size is not reached, add new values
        if not self._filled and self._obs < self.window_size:
            new_1 = next(self._iterator_1)
            new_2 = next(self._iterator_2)
            self._add_new(new_1, new_2)
            self._filled = self._obs == self.window_size
            return self.current_value

        # once the window size is reached, consider fixed until iterator ends
        try:
            return self._next_fixed()

        # if the iterator finishes, remove the oldest values one at a time
        except StopIteration:
            if self._obs == 1:
                raise
            else:
                self._remove_old()
                return self.current_value

    def __next__(self):

        if self.window_type == "fixed":
            return self._next_fixed()

        if self.window_type == "variable":
            return self._next_variable()

        raise NotImplementedError(f"next() not implemented for {self.window_type}")

    @property
    @abc.abstractmethod
    def current_value(self):
        """
        Return the current value of the window
        """
        pass

    @property
    @abc.abstractmethod
    def _obs(self):
        """
        Return the number of observations in the window
        """
        pass

    @abc.abstractmethod
    def _init_fixed(self, **kwargs):
        """
        Initialise as a fixed-size window
        """
        pass

    @abc.abstractmethod
    def _init_variable(self, **kwargs):
        """
        Initialise as a variable-size window
        """
        pass

    @abc.abstractmethod
    def _remove_old(self):
        """
        Remove the oldest pair of values from the window, decreasing window size by 1
        """
        pass

    @abc.abstractmethod
    def _add_new(self, new_1, new_2):
        """
        Add a new pair of values to the window, increasing window size by 1
        """
        pass

    @abc.abstractmethod
    def _update_window(self, new_1, new_2):
        """
        Add a new pair of values to the window and remove the oldest pair from the window
        """
        pass


def _validate_window_size(k):
"""
Check if k is a positive integer
"""
if not isinstance(k, int):
raise TypeError(f"window_size must be integer type, got {type(k).__name__}")
if k <= 0:
raise ValueError("window_size must be positive")
return k
129 changes: 129 additions & 0 deletions tests/test_apply_pairwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import pytest

from rolling.apply_pairwise import ApplyPairwise

# Two equal-length input sequences shared by the parametrised tests below;
# the expected windows are written as lists of (ARRAY_1 item, ARRAY_2 item).
ARRAY_1 = [3, 6, 5, 8, 1]
ARRAY_2 = [1, 2, 3, 4, 5]


@pytest.mark.parametrize(
    "window_size,expected",
    [
        # window larger than the inputs: no full window is ever produced
        (6, []),
        (5, [[(3, 1), (6, 2), (5, 3), (8, 4), (1, 5)]]),
        (
            4,
            [
                [(3, 1), (6, 2), (5, 3), (8, 4)],
                [(6, 2), (5, 3), (8, 4), (1, 5)],
            ],
        ),
        (
            3,
            [
                [(3, 1), (6, 2), (5, 3)],
                [(6, 2), (5, 3), (8, 4)],
                [(5, 3), (8, 4), (1, 5)],
            ],
        ),
        (
            2,
            [
                [(3, 1), (6, 2)],
                [(6, 2), (5, 3)],
                [(5, 3), (8, 4)],
                [(8, 4), (1, 5)],
            ],
        ),
        (1, [[(3, 1)], [(6, 2)], [(5, 3)], [(8, 4)], [(1, 5)]]),
    ],
)
def test_rolling_apply_pairwise_fixed(window_size, expected):
    """Fixed windows yield one result per full window over both inputs."""

    def paired(window_1, window_2):
        # expose the raw window contents so they can be compared directly
        return list(zip(window_1, window_2))

    roller = ApplyPairwise(ARRAY_1, ARRAY_2, window_size, function=paired)
    produced = list(roller)
    assert produced == expected


@pytest.mark.parametrize(
    "window_size,expected",
    [
        # window_size 6 exceeds the input length: the window only grows
        (
            6,
            [
                [(3, 1)],
                [(3, 1), (6, 2)],
                [(3, 1), (6, 2), (5, 3)],
                [(3, 1), (6, 2), (5, 3), (8, 4)],
                [(3, 1), (6, 2), (5, 3), (8, 4), (1, 5)],
            ],
        ),
        # remaining cases: grow to window_size, slide, then shrink to one pair
        (
            5,
            [
                [(3, 1)],
                [(3, 1), (6, 2)],
                [(3, 1), (6, 2), (5, 3)],
                [(3, 1), (6, 2), (5, 3), (8, 4)],
                [(3, 1), (6, 2), (5, 3), (8, 4), (1, 5)],
                [(6, 2), (5, 3), (8, 4), (1, 5)],
                [(5, 3), (8, 4), (1, 5)],
                [(8, 4), (1, 5)],
                [(1, 5)],
            ],
        ),
        (
            4,
            [
                [(3, 1)],
                [(3, 1), (6, 2)],
                [(3, 1), (6, 2), (5, 3)],
                [(3, 1), (6, 2), (5, 3), (8, 4)],
                [(6, 2), (5, 3), (8, 4), (1, 5)],
                [(5, 3), (8, 4), (1, 5)],
                [(8, 4), (1, 5)],
                [(1, 5)],
            ],
        ),
        (
            3,
            [
                [(3, 1)],
                [(3, 1), (6, 2)],
                [(3, 1), (6, 2), (5, 3)],
                [(6, 2), (5, 3), (8, 4)],
                [(5, 3), (8, 4), (1, 5)],
                [(8, 4), (1, 5)],
                [(1, 5)],
            ],
        ),
        (
            2,
            [
                [(3, 1)],
                [(3, 1), (6, 2)],
                [(6, 2), (5, 3)],
                [(5, 3), (8, 4)],
                [(8, 4), (1, 5)],
                [(1, 5)],
            ],
        ),
        (
            1,
            [
                [(3, 1)],
                [(6, 2)],
                [(5, 3)],
                [(8, 4)],
                [(1, 5)],
            ],
        ),
    ],
)
def test_rolling_apply_variable(window_size, expected):
    """Variable windows also yield results for the partially filled windows."""

    def paired(window_1, window_2):
        # expose the raw window contents so they can be compared directly
        return list(zip(window_1, window_2))

    roller = ApplyPairwise(
        ARRAY_1,
        ARRAY_2,
        window_size,
        function=paired,
        window_type="variable",
    )
    produced = list(roller)
    assert produced == expected


@pytest.mark.parametrize(
    "array_1,array_2,window_type,expected",
    [
        # fixed windows: inputs shorter than the window yield nothing
        ([], [], "fixed", []),
        ([3], [], "fixed", []),
        ([], [5], "fixed", []),
        # variable windows: only the pairs actually available are yielded
        ([], [], "variable", []),
        ([1], [1], "variable", [[(1, 1)]]),
    ],
)
def test_rolling_apply_pairwise_over_short_iterables(array_1, array_2, window_type, expected):
    """Inputs shorter than the window size of 10 are handled without error."""

    def paired(window_1, window_2):
        # expose the raw window contents so they can be compared directly
        return list(zip(window_1, window_2))

    roller = ApplyPairwise(
        array_1,
        array_2,
        10,
        function=paired,
        window_type=window_type,
    )
    produced = list(roller)
    assert produced == expected

0 comments on commit 14d63b5

Please sign in to comment.