-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
340 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from collections import deque | ||
from itertools import islice | ||
|
||
from .base_pairwise import RollingPairwise | ||
|
||
|
||
class ApplyPairwise(RollingPairwise): | ||
""" | ||
Apply a binary function to windows over two iterables. | ||
Parameters | ||
---------- | ||
iterable : any iterable object | ||
window_size : integer, the size of the rolling | ||
window moving over the iterable | ||
function : callable, | ||
a binary callable to be applied to the current | ||
window of each iterable | ||
Complexity | ||
---------- | ||
Update time: function dependent | ||
Memory usage: O(k) | ||
where k is the size of the rolling window | ||
Example | ||
-------- | ||
>>> from rolling import ApplyPairwise | ||
>>> from statistics import correlation | ||
>>> seq_1 = [1, 2, 3, 4, 5] | ||
>>> seq_2 = [1, 2, 3, 2, 1] | ||
>>> r_corr = ApplyPairwise(seq_1, seq_2, window_size=3, function=correlation) | ||
>>> list(r_corr) | ||
[1.0, 0.0, -1.0] | ||
""" | ||
def __init__(self, iterable_1, iterable_2, window_size, function, window_type="fixed"): | ||
self._buffer_1 = deque(maxlen=window_size) | ||
self._buffer_2 = deque(maxlen=window_size) | ||
self._function = function | ||
super().__init__(iterable_1, iterable_2, window_size=window_size, window_type=window_type) | ||
|
||
def _init_fixed(self, **kwargs): | ||
pairs = zip(self._iterator_1, self._iterator_2) | ||
for item_1, item_2 in islice(pairs, self.window_size-1): | ||
self._buffer_1.append(item_1) | ||
self._buffer_2.append(item_2) | ||
|
||
def _init_variable(self, **kwargs): | ||
pass # no action required | ||
|
||
@property | ||
def current_value(self): | ||
return self._function(self._buffer_1, self._buffer_2) | ||
|
||
def _add_new(self, new_1, new_2): | ||
self._buffer_1.append(new_1) | ||
self._buffer_2.append(new_2) | ||
|
||
def _remove_old(self): | ||
self._buffer_1.popleft() | ||
self._buffer_2.popleft() | ||
|
||
def _update_window(self, new_1, new_2): | ||
self._buffer_1.append(new_1) | ||
self._buffer_2.append(new_2) | ||
|
||
@property | ||
def _obs(self): | ||
return len(self._buffer_1) | ||
|
||
def __repr__(self): | ||
return "RollingPairwise(operation='{}', window_size={}, window_type='{}')".format( | ||
self._function.__name__, self.window_size, self.window_type | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import abc | ||
from collections.abc import Iterator | ||
from itertools import chain | ||
|
||
|
||
class RollingPairwise(Iterator): | ||
""" | ||
Baseclass for rolling iterators over two iterables. | ||
""" | ||
def __init__(self, iterable_1, iterable_2, window_size, window_type="fixed", **kwargs): | ||
self.window_type = window_type | ||
self.window_size = _validate_window_size(window_size) | ||
self._iterator_1 = iter(iterable_1) | ||
self._iterator_2 = iter(iterable_2) | ||
self._filled = self.window_type == "fixed" | ||
|
||
if window_type == "fixed": | ||
self._init_fixed(**kwargs) | ||
|
||
elif window_type == "variable": | ||
self._init_variable(**kwargs) | ||
|
||
else: | ||
raise ValueError(f"Unknown window_type '{window_type}'") | ||
|
||
def __repr__(self): | ||
return "RollingPairwise(operation='{}', window_size={}, window_type='{}')".format( | ||
self.__class__.__name__, self.window_size, self.window_type | ||
) | ||
|
||
def _next_fixed(self): | ||
new_1 = next(self._iterator_1) | ||
new_2 = next(self._iterator_2) | ||
self._update_window(new_1, new_2) | ||
return self.current_value | ||
|
||
def _next_variable(self): | ||
# while the window size is not reached, add new values | ||
if not self._filled and self._obs < self.window_size: | ||
new_1 = next(self._iterator_1) | ||
new_2 = next(self._iterator_2) | ||
self._add_new(new_1, new_2) | ||
self._filled = self._obs == self.window_size | ||
return self.current_value | ||
|
||
# once the window size is reached, consider fixed until iterator ends | ||
try: | ||
return self._next_fixed() | ||
|
||
# if the iterator finishes, remove the oldest values one at a time | ||
except StopIteration: | ||
if self._obs == 1: | ||
raise | ||
else: | ||
self._remove_old() | ||
return self.current_value | ||
|
||
def __next__(self): | ||
|
||
if self.window_type == "fixed": | ||
return self._next_fixed() | ||
|
||
if self.window_type == "variable": | ||
return self._next_variable() | ||
|
||
raise NotImplementedError(f"next() not implemented for {self.window_type}") | ||
|
||
@property | ||
@abc.abstractmethod | ||
def current_value(self): | ||
""" | ||
Return the current value of the window | ||
""" | ||
pass | ||
|
||
@property | ||
@abc.abstractmethod | ||
def _obs(self): | ||
""" | ||
Return the number of observations in the window | ||
""" | ||
pass | ||
|
||
@abc.abstractmethod | ||
def _init_fixed(self, **kwargs): | ||
""" | ||
Intialise as a fixed-size window | ||
""" | ||
pass | ||
|
||
@abc.abstractmethod | ||
def _init_variable(self, **kwargs): | ||
""" | ||
Intialise as a variable-size window | ||
""" | ||
pass | ||
|
||
@abc.abstractmethod | ||
def _remove_old(self): | ||
""" | ||
Remove the oldest value from the window, decreasing window size by 1 | ||
""" | ||
pass | ||
|
||
@abc.abstractmethod | ||
def _add_new(self, new): | ||
""" | ||
Add a new value to the window, increasing window size by 1 | ||
""" | ||
pass | ||
|
||
@abc.abstractmethod | ||
def _update_window(self, new): | ||
""" | ||
Add a new value to the window and remove the oldest value from the window | ||
""" | ||
pass | ||
|
||
|
||
def _validate_window_size(k): | ||
""" | ||
Check if k is a positive integer | ||
""" | ||
if not isinstance(k, int): | ||
raise TypeError(f"window_size must be integer type, got {type(k).__name__}") | ||
if k <= 0: | ||
raise ValueError("window_size must be positive") | ||
return k |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import pytest | ||
|
||
from rolling.apply_pairwise import ApplyPairwise | ||
|
||
ARRAY_1 = [3, 6, 5, 8, 1] | ||
ARRAY_2 = [1, 2, 3, 4, 5] | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"window_size,expected", | ||
[ | ||
(6, []), | ||
(5, [[(3, 1), (6, 2), (5, 3), (8, 4), (1, 5)]]), | ||
(4, [[(3, 1), (6, 2), (5, 3), (8, 4)], [(6, 2), (5, 3), (8, 4), (1, 5)]]), | ||
(3, [[(3, 1), (6, 2), (5, 3)], [(6, 2), (5, 3), (8, 4)], [(5, 3), (8, 4), (1, 5)]]), | ||
(2, [[(3, 1), (6, 2)], [(6, 2), (5, 3)], [(5, 3), (8, 4)], [(8, 4), (1, 5)]]), | ||
(1, [[(3, 1)], [(6, 2)], [(5, 3)], [(8, 4)], [(1, 5)]]), | ||
], | ||
) | ||
def test_rolling_apply_pairwise_fixed(window_size, expected): | ||
r = ApplyPairwise(ARRAY_1, ARRAY_2, window_size, function=lambda x, y: list(zip(x, y))) | ||
assert list(r) == expected | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"window_size,expected", | ||
[ | ||
( | ||
6, | ||
[ | ||
[(3, 1)], | ||
[(3, 1), (6, 2)], | ||
[(3, 1), (6, 2), (5, 3)], | ||
[(3, 1), (6, 2), (5, 3), (8, 4)], | ||
[(3, 1), (6, 2), (5, 3), (8, 4), (1, 5)], | ||
] | ||
), | ||
( | ||
5, | ||
[ | ||
[(3, 1)], | ||
[(3, 1), (6, 2)], | ||
[(3, 1), (6, 2), (5, 3)], | ||
[(3, 1), (6, 2), (5, 3), (8, 4)], | ||
[(3, 1), (6, 2), (5, 3), (8, 4), (1, 5)], | ||
[(6, 2), (5, 3), (8, 4), (1, 5)], | ||
[(5, 3), (8, 4), (1, 5)], | ||
[(8, 4), (1, 5)], | ||
[(1, 5)], | ||
], | ||
), | ||
( | ||
4, | ||
[ | ||
[(3, 1)], | ||
[(3, 1), (6, 2)], | ||
[(3, 1), (6, 2), (5, 3)], | ||
[(3, 1), (6, 2), (5, 3), (8, 4)], | ||
[(6, 2), (5, 3), (8, 4), (1, 5)], | ||
[(5, 3), (8, 4), (1, 5)], | ||
[(8, 4), (1, 5)], | ||
[(1, 5)], | ||
], | ||
), | ||
( | ||
3, | ||
[ | ||
[(3, 1)], | ||
[(3, 1), (6, 2)], | ||
[(3, 1), (6, 2), (5, 3)], | ||
[(6, 2), (5, 3), (8, 4)], | ||
[(5, 3), (8, 4), (1, 5)], | ||
[(8, 4), (1, 5)], | ||
[(1, 5)], | ||
], | ||
), | ||
( | ||
2, | ||
[ | ||
[(3, 1)], | ||
[(3, 1), (6, 2)], | ||
[(6, 2), (5, 3)], | ||
[(5, 3), (8, 4)], | ||
[(8, 4), (1, 5)], | ||
[(1, 5)], | ||
], | ||
), | ||
( | ||
1, | ||
[ | ||
[(3, 1)], | ||
[(6, 2)], | ||
[(5, 3)], | ||
[(8, 4)], | ||
[(1, 5)], | ||
] | ||
), | ||
], | ||
) | ||
def test_rolling_apply_variable(window_size, expected): | ||
r = ApplyPairwise( | ||
ARRAY_1, | ||
ARRAY_2, | ||
window_size, | ||
window_type="variable", | ||
function=lambda x, y: list(zip(x, y)), | ||
) | ||
assert list(r) == expected | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"array_1,array_2,window_type,expected", | ||
[ | ||
([], [], "fixed", []), | ||
([3], [], "fixed", []), | ||
([], [5], "fixed", []), | ||
([], [], "variable", []), | ||
([1], [1], "variable", [[(1, 1)]]), | ||
], | ||
) | ||
def test_rolling_apply_pairwise_over_short_iterables(array_1, array_2, window_type, expected): | ||
r = ApplyPairwise( | ||
array_1, | ||
array_2, | ||
10, | ||
window_type=window_type, | ||
function=lambda x, y: list(zip(x, y)), | ||
) | ||
assert list(r) == expected |