From 14d63b5642e35d7f2e5e1894fb510763e9a5d3e8 Mon Sep 17 00:00:00 2001 From: Alex Riley Date: Sat, 11 Feb 2023 14:19:02 +0000 Subject: [PATCH] Add RollingPairwise, ApplyPairwise --- doc/changelog.md | 3 + rolling/apply_pairwise.py | 79 +++++++++++++++++++++ rolling/base_pairwise.py | 129 +++++++++++++++++++++++++++++++++++ tests/test_apply_pairwise.py | 129 +++++++++++++++++++++++++++++++++++ 4 files changed, 340 insertions(+) create mode 100644 rolling/apply_pairwise.py create mode 100644 rolling/base_pairwise.py create mode 100644 tests/test_apply_pairwise.py diff --git a/doc/changelog.md b/doc/changelog.md index 1ad60d1..57ef199 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). +## [0.5.0] +### Added +- New `rolling.ApplyPairwise` object ## [0.4.0] ### Added diff --git a/rolling/apply_pairwise.py b/rolling/apply_pairwise.py new file mode 100644 index 0000000..5d00a3f --- /dev/null +++ b/rolling/apply_pairwise.py @@ -0,0 +1,79 @@ +from collections import deque +from itertools import islice + +from .base_pairwise import RollingPairwise + + +class ApplyPairwise(RollingPairwise): + """ + Apply a binary function to windows over two iterables. + + Parameters + ---------- + + iterable : any iterable object + window_size : integer, the size of the rolling + window moving over the iterable + function : callable, + a binary callable to be applied to the current + window of each iterable + + Complexity + ---------- + + Update time: function dependent + Memory usage: O(k) + + where k is the size of the rolling window + + Example + -------- + + >>> from rolling import ApplyPairwise + >>> from statistics import correlation + >>> seq_1 = [1, 2, 3, 4, 5] + >>> seq_2 = [1, 2, 3, 2, 1] + >>> r_corr = ApplyPairwise(seq_1, seq_2, window_size=3, function=correlation) + >>> list(r_corr) + [1.0, 0.0, -1.0] + + """ + def __init__(self, iterable_1, iterable_2, window_size, function, window_type="fixed"): + self._buffer_1 = deque(maxlen=window_size) + self._buffer_2 = deque(maxlen=window_size) + self._function = function + super().__init__(iterable_1, iterable_2, window_size=window_size, window_type=window_type) + + def _init_fixed(self, **kwargs): + pairs = zip(self._iterator_1, self._iterator_2) + for item_1, item_2 in islice(pairs, self.window_size-1): + self._buffer_1.append(item_1) + self._buffer_2.append(item_2) + + def _init_variable(self, **kwargs): + pass # no action required + + @property + def current_value(self): + return self._function(self._buffer_1, self._buffer_2) + + def _add_new(self, new_1, new_2): + self._buffer_1.append(new_1) + self._buffer_2.append(new_2) + + def _remove_old(self): + self._buffer_1.popleft() + self._buffer_2.popleft() + + def _update_window(self, new_1, new_2): + self._buffer_1.append(new_1) + self._buffer_2.append(new_2) + + @property + def _obs(self): + return len(self._buffer_1) + + def __repr__(self): + return "RollingPairwise(operation='{}', window_size={}, window_type='{}')".format( + self._function.__name__, self.window_size, self.window_type + ) diff --git a/rolling/base_pairwise.py b/rolling/base_pairwise.py new file mode 100644 index 0000000..adda361 --- /dev/null +++ b/rolling/base_pairwise.py @@ -0,0 +1,129 @@ +import abc +from collections.abc import Iterator +from itertools import chain + + +class RollingPairwise(Iterator): + """ + Baseclass for rolling iterators over two iterables. + + """ + def __init__(self, iterable_1, iterable_2, window_size, window_type="fixed", **kwargs): + self.window_type = window_type + self.window_size = _validate_window_size(window_size) + self._iterator_1 = iter(iterable_1) + self._iterator_2 = iter(iterable_2) + self._filled = self.window_type == "fixed" + + if window_type == "fixed": + self._init_fixed(**kwargs) + + elif window_type == "variable": + self._init_variable(**kwargs) + + else: + raise ValueError(f"Unknown window_type '{window_type}'") + + def __repr__(self): + return "RollingPairwise(operation='{}', window_size={}, window_type='{}')".format( + self.__class__.__name__, self.window_size, self.window_type + ) + + def _next_fixed(self): + new_1 = next(self._iterator_1) + new_2 = next(self._iterator_2) + self._update_window(new_1, new_2) + return self.current_value + + def _next_variable(self): + # while the window size is not reached, add new values + if not self._filled and self._obs < self.window_size: + new_1 = next(self._iterator_1) + new_2 = next(self._iterator_2) + self._add_new(new_1, new_2) + self._filled = self._obs == self.window_size + return self.current_value + + # once the window size is reached, consider fixed until iterator ends + try: + return self._next_fixed() + + # if the iterator finishes, remove the oldest values one at a time + except StopIteration: + if self._obs == 1: + raise + else: + self._remove_old() + return self.current_value + + def __next__(self): + + if self.window_type == "fixed": + return self._next_fixed() + + if self.window_type == "variable": + return self._next_variable() + + raise NotImplementedError(f"next() not implemented for {self.window_type}") + + @property + @abc.abstractmethod + def current_value(self): + """ + Return the current value of the window + """ + pass + + @property + @abc.abstractmethod + def _obs(self): + """ + Return the number of observations in the window + """ + pass + + @abc.abstractmethod + def _init_fixed(self, **kwargs): + """ + Intialise as a fixed-size window + """ + pass + + @abc.abstractmethod + def _init_variable(self, **kwargs): + """ + Intialise as a variable-size window + """ + pass + + @abc.abstractmethod + def _remove_old(self): + """ + Remove the oldest value from the window, decreasing window size by 1 + """ + pass + + @abc.abstractmethod + def _add_new(self, new): + """ + Add a new value to the window, increasing window size by 1 + """ + pass + + @abc.abstractmethod + def _update_window(self, new): + """ + Add a new value to the window and remove the oldest value from the window + """ + pass + + +def _validate_window_size(k): + """ + Check if k is a positive integer + """ + if not isinstance(k, int): + raise TypeError(f"window_size must be integer type, got {type(k).__name__}") + if k <= 0: + raise ValueError("window_size must be positive") + return k diff --git a/tests/test_apply_pairwise.py b/tests/test_apply_pairwise.py new file mode 100644 index 0000000..39e61a1 --- /dev/null +++ b/tests/test_apply_pairwise.py @@ -0,0 +1,129 @@ +import pytest + +from rolling.apply_pairwise import ApplyPairwise + +ARRAY_1 = [3, 6, 5, 8, 1] +ARRAY_2 = [1, 2, 3, 4, 5] + + +@pytest.mark.parametrize( + "window_size,expected", + [ + (6, []), + (5, [[(3, 1), (6, 2), (5, 3), (8, 4), (1, 5)]]), + (4, [[(3, 1), (6, 2), (5, 3), (8, 4)], [(6, 2), (5, 3), (8, 4), (1, 5)]]), + (3, [[(3, 1), (6, 2), (5, 3)], [(6, 2), (5, 3), (8, 4)], [(5, 3), (8, 4), (1, 5)]]), + (2, [[(3, 1), (6, 2)], [(6, 2), (5, 3)], [(5, 3), (8, 4)], [(8, 4), (1, 5)]]), + (1, [[(3, 1)], [(6, 2)], [(5, 3)], [(8, 4)], [(1, 5)]]), + ], +) +def test_rolling_apply_pairwise_fixed(window_size, expected): + r = ApplyPairwise(ARRAY_1, ARRAY_2, window_size, function=lambda x, y: list(zip(x, y))) + assert list(r) == expected + + +@pytest.mark.parametrize( + "window_size,expected", + [ + ( + 6, + [ + [(3, 1)], + [(3, 1), (6, 2)], + [(3, 1), (6, 2), (5, 3)], + [(3, 1), (6, 2), (5, 3), (8, 4)], + [(3, 1), (6, 2), (5, 3), (8, 4), (1, 5)], + ] + ), + ( + 5, + [ + [(3, 1)], + [(3, 1), (6, 2)], + [(3, 1), (6, 2), (5, 3)], + [(3, 1), (6, 2), (5, 3), (8, 4)], + [(3, 1), (6, 2), (5, 3), (8, 4), (1, 5)], + [(6, 2), (5, 3), (8, 4), (1, 5)], + [(5, 3), (8, 4), (1, 5)], + [(8, 4), (1, 5)], + [(1, 5)], + ], + ), + ( + 4, + [ + [(3, 1)], + [(3, 1), (6, 2)], + [(3, 1), (6, 2), (5, 3)], + [(3, 1), (6, 2), (5, 3), (8, 4)], + [(6, 2), (5, 3), (8, 4), (1, 5)], + [(5, 3), (8, 4), (1, 5)], + [(8, 4), (1, 5)], + [(1, 5)], + ], + ), + ( + 3, + [ + [(3, 1)], + [(3, 1), (6, 2)], + [(3, 1), (6, 2), (5, 3)], + [(6, 2), (5, 3), (8, 4)], + [(5, 3), (8, 4), (1, 5)], + [(8, 4), (1, 5)], + [(1, 5)], + ], + ), + ( + 2, + [ + [(3, 1)], + [(3, 1), (6, 2)], + [(6, 2), (5, 3)], + [(5, 3), (8, 4)], + [(8, 4), (1, 5)], + [(1, 5)], + ], + ), + ( + 1, + [ + [(3, 1)], + [(6, 2)], + [(5, 3)], + [(8, 4)], + [(1, 5)], + ] + ), + ], +) +def test_rolling_apply_variable(window_size, expected): + r = ApplyPairwise( + ARRAY_1, + ARRAY_2, + window_size, + window_type="variable", + function=lambda x, y: list(zip(x, y)), + ) + assert list(r) == expected + + +@pytest.mark.parametrize( + "array_1,array_2,window_type,expected", + [ + ([], [], "fixed", []), + ([3], [], "fixed", []), + ([], [5], "fixed", []), + ([], [], "variable", []), + ([1], [1], "variable", [[(1, 1)]]), + ], +) +def test_rolling_apply_pairwise_over_short_iterables(array_1, array_2, window_type, expected): + r = ApplyPairwise( + array_1, + array_2, + 10, + window_type=window_type, + function=lambda x, y: list(zip(x, y)), + ) + assert list(r) == expected