-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add Rolling Indexed * Add rolling nunique for indexed rolling * Small changes for PR #31 * Update rolling/arithmetic/nunique_indexed.py Co-authored-by: Alex Riley <[email protected]> --------- Co-authored-by: Alex Riley <[email protected]>
- Loading branch information
Showing
5 changed files
with
302 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from collections import deque | ||
|
||
from .base_indexed import RollingIndexed | ||
|
||
class ApplyIndexed(RollingIndexed):
    """
    Apply a function to windows over an indexed array

    Parameters
    ----------

    index : the object that will serve as index
    iterable : any iterable object
    window_size : same type as the index, the maximum size (difference between indices)
        of the rolling window moving over the iterable
    function : callable, the function to be applied to the current window of each
        iterable; it receives the internal deque of values currently in the window
    window_type : str, forwarded to ``RollingIndexed`` (default "variable")
    """

    def __init__(self, index, iterable, window_size, function, window_type="variable"):
        # Parallel buffers: position i of each deque describes the same observation.
        self._idx_buffer = deque()
        self._val_buffer = deque()
        self._function = function
        super().__init__(index, iterable, window_size, window_type)

    @property
    def current_value(self):
        """Result of calling the user function on the values currently buffered."""
        return self._function(self._val_buffer)

    def _init_variable(self, **kwargs):
        """Nothing to set up: the buffers are created in ``__init__``."""
        pass

    def _insert(self, idx, val):
        """Append ``val`` under index ``idx`` at the right end of the window.

        Raises
        ------
        ValueError
            If ``idx`` is smaller than the most recently inserted index.
        """
        # Compare against the NEWEST buffered index. Checking only the oldest
        # one would let an out-of-order index slip in whenever it is still
        # larger than the left edge of the window, which breaks the
        # "pop from the left until idx is passed" eviction strategy.
        # Equal indices are allowed (repeated observations).
        if self._idx_buffer and self._idx_buffer[-1] > idx:
            raise ValueError("Indices should be monotonic")

        self._idx_buffer.append(idx)
        self._val_buffer.append(val)

        # Internal invariant, not input validation.
        assert len(self._idx_buffer) == len(self._val_buffer), \
            "Both buffers should have same length"

    def _evict(self, idx):
        """ Removes all values whose index is lower or equal than idx
        """
        # Indices are non-decreasing, so everything to evict sits at the left.
        while self._idx_buffer and self._idx_buffer[0] <= idx:
            self._idx_buffer.popleft()
            self._val_buffer.popleft()

        # Internal invariant, not input validation.
        assert len(self._idx_buffer) == len(self._val_buffer), \
            "Both buffers should have same length"

    @property
    def _obs(self):
        """Difference between the newest and the oldest buffered index.

        Raises ``IndexError`` when the window is empty.
        """
        return self._idx_buffer[-1] - self._idx_buffer[0]

    def __repr__(self):
        # Not every callable has __name__ (functools.partial, callable
        # instances); fall back to the callable's own repr in that case.
        operation = getattr(self._function, "__name__", repr(self._function))
        return "ApplyIndexed(operation='{}', window_size={}, window_type='{}')".format(
            operation, self.window_size, self.window_type
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from collections import Counter, deque | ||
|
||
from rolling.base_indexed import RollingIndexed | ||
|
||
class NuniqueIndexed(RollingIndexed):
    """
    Iterator object that counts the number of unique values in a rolling
    window with an index array

    Parameters
    ----------

    index : the object that will serve as index
    iterable : any iterable object (values must be hashable)
    window_size : same type as the index, the maximum size (difference between indices)
        of the rolling window moving over the iterable
    window_type : str, forwarded to ``RollingIndexed`` (default "variable")

    Complexity
    ----------

    Update time: O(1) amortised (each value is inserted and evicted once)
    Memory usage: O(k)

    where k is the size of the rolling window (which can potentially be n)
    """

    def __init__(self, index, iterable, window_size, window_type="variable"):
        # Parallel buffers: position i of each deque describes the same observation.
        self._idx_buffer = deque()
        self._val_buffer = deque()
        # Multiplicity of every value currently inside the window.
        self._counter = Counter()
        self._nunique = 0
        super().__init__(index, iterable, window_size, window_type)

    def _init_variable(self, **kwargs):
        """Nothing to set up: the buffers are created in ``__init__``."""
        pass

    def _insert(self, idx, val):
        """Append ``val`` under index ``idx`` and update the distinct count.

        Raises
        ------
        ValueError
            If ``idx`` is smaller than the most recently inserted index.
        """
        # Compare against the NEWEST buffered index. Checking only the oldest
        # one would accept out-of-order indices that are still larger than
        # the left edge of the window. Equal indices are allowed.
        if self._idx_buffer and self._idx_buffer[-1] > idx:
            raise ValueError("Indices should be monotonic")

        self._idx_buffer.append(idx)
        self._val_buffer.append(val)
        # Keys are deleted when they reach zero (see _evict), so membership
        # is equivalent to "count > 0".
        if val not in self._counter:
            self._nunique += 1
        self._counter[val] += 1

        # Internal invariant, not input validation.
        assert len(self._idx_buffer) == len(self._val_buffer), \
            "Both buffers should have same length"

    def _evict(self, idx):
        """ Removes all values whose index is lower or equal than idx
        """
        # Indices are non-decreasing, so everything to evict sits at the left.
        while self._idx_buffer and self._idx_buffer[0] <= idx:
            self._idx_buffer.popleft()
            val = self._val_buffer.popleft()

            remaining = self._counter[val] - 1
            if remaining:
                self._counter[val] = remaining
            else:
                # Drop the key instead of leaving a zero behind; otherwise the
                # counter grows with every distinct value ever seen and the
                # documented O(k) memory bound no longer holds.
                del self._counter[val]
                self._nunique -= 1

        # Internal invariant, not input validation.
        assert len(self._idx_buffer) == len(self._val_buffer), \
            "Both buffers should have same length"

    @property
    def current_value(self):
        """Number of distinct values currently inside the window."""
        return self._nunique

    @property
    def _obs(self):
        """Difference between the newest and the oldest buffered index.

        Raises ``IndexError`` when the window is empty.
        """
        return self._idx_buffer[-1] - self._idx_buffer[0]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import abc | ||
from collections.abc import Iterator | ||
|
||
class RollingIndexed(Iterator):
    """
    Baseclass for rolling iterators over _indexed_ or sparse data

    Each call to ``next()`` consumes one (index, value) pair, inserts it into
    the window, evicts every entry whose index is at most
    ``new_index - window_size`` and returns ``current_value``.

    Subclasses must implement ``_init_variable``, ``_insert``, ``_evict``,
    ``current_value`` and ``_obs``.
    """

    def __init__(self, index, iterable, window_size, window_type="variable", **kwargs):
        """Initialize a base rolling indexed class

        Args:
            index: Must be a monotonic iterable of the same type that `window_size`,
                its items must support the (-) operator.
            iterable: Any iterable, must be same length as index
            window_size: Max difference between the first and last index of the
                stored elements
            window_type (str, optional): Defaults to "variable".
            **kwargs: Forwarded unchanged to `_init_variable`.

        Raises:
            ValueError: If `window_type` is not "variable", or if `index` and
                `iterable` both report a length and those lengths differ.
        """
        self.window_type = window_type

        # The up-front size check is only possible for sized inputs.
        # Generators and other unsized iterables are accepted as-is; a values
        # stream that ends early is detected lazily in `_next_variable`.
        try:
            sizes = (len(index), len(iterable))
        except TypeError:
            sizes = None

        # Explicit raise rather than `assert`: assertions are stripped when
        # Python runs with -O, which would silently disable this validation.
        if sizes is not None and sizes[0] != sizes[1]:
            raise ValueError("Index and values should have same size")

        self._iterator_index = iter(index)
        self._iterator_values = iter(iterable)
        self.window_size = window_size

        if window_type == "variable":
            self._init_variable(**kwargs)

        else:
            raise ValueError(f"Unknown window_type '{window_type}'")

    def _next_variable(self):
        """Advance a variable-size window by one observation and return its value."""
        # StopIteration from the index iterator ends the iteration normally.
        newidx = next(self._iterator_index)
        try:
            newval = next(self._iterator_values)
        except StopIteration:
            # Only reachable with unsized inputs: an index without a value.
            raise ValueError("Index and values should have same size") from None

        self._insert(newidx, newval)
        self._evict(newidx - self.window_size)

        return self.current_value

    def __next__(self):
        if self.window_type == "variable":
            return self._next_variable()

        raise NotImplementedError(f"next() not implemented for {self.window_type}")

    @abc.abstractmethod
    def _insert(self, idx, val):
        """ Inserts value into the window with index idx. idx is expected to be
        greater than or equal to all indexes received so far
        """

    @abc.abstractmethod
    def _evict(self, idx):
        """ Removes all values whose index is lower or equal than idx
        """

    @property
    @abc.abstractmethod
    def current_value(self):
        """
        Return the current value of the window
        """
        pass

    @abc.abstractmethod
    def _init_variable(self, **kwargs):
        """
        Initialise as a variable-size window
        """
        pass

    @property
    @abc.abstractmethod
    def _obs(self):
        """
        Return the current extent of the window (the indexed subclasses in
        this package return newest index minus oldest index)
        """
        pass

    def __repr__(self):
        return "RollingIndexed(operation='{}', window_size={}, window_type='{}')".format(
            self.__class__.__name__, self.window_size, self.window_type
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import datetime as dt | ||
|
||
import pytest | ||
|
||
from rolling.apply_indexed import ApplyIndexed | ||
|
||
ARRAY_1 = [1, 2, 3, 4, 5]
IDX_1 = [0, 1, 2, 4, 6]


@pytest.mark.parametrize(
    ("window_size", "expected"),
    [
        (1, ARRAY_1),
        (2, [1, 3, 5, 4, 5]),
        (3, [1, 3, 6, 7, 9]),
        (4, [1, 3, 6, 9, 9]),
        (5, [1, 3, 6, 10, 12]),
        (6, [1, 3, 6, 10, 14]),
        (7, [1, 3, 6, 10, 15]),
    ],
)
def test_rolling_apply_indexed(window_size, expected):
    """Rolling sum over a sparse index, checked on integer and on date indices."""
    # Integer index with an integer window size.
    int_roller = ApplyIndexed(IDX_1, ARRAY_1, window_size, function=sum)
    assert list(int_roller) == expected

    # Same offsets expressed as days in May 2023, window given as a timedelta.
    day_index = [dt.datetime(2023, 5, offset + 1) for offset in IDX_1]
    span = dt.timedelta(days=window_size)
    date_roller = ApplyIndexed(day_index, ARRAY_1, span, function=sum)
    assert list(date_roller) == expected
|
||
ARRAY_2 = [1, 2, 3, 4, 5, 6]
IDX_2 = [0, 0, 1, 2, 4, 6]


@pytest.mark.parametrize(
    ("window_size", "expected"),
    [
        (1, [1, 3, 3, 4, 5, 6]),
        (2, [1, 3, 6, 7, 5, 6]),
    ],
)
def test_rolling_apply_repeated(window_size, expected):
    """Rolling sum when the index contains a repeated entry."""
    roller = ApplyIndexed(IDX_2, ARRAY_2, window_size, function=sum)
    produced = list(roller)
    assert produced == expected
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import datetime as dt | ||
|
||
import pytest | ||
|
||
from rolling.apply_indexed import ApplyIndexed | ||
from rolling.arithmetic.nunique_indexed import NuniqueIndexed | ||
|
||
|
||
@pytest.mark.parametrize("word", ["aabbc", "xooxyzzziiismsdd", "jjjjjj", ""])
@pytest.mark.parametrize("window_size", [1, 2, 3, 4, 5])
def test_rolling_nunique(word, window_size):
    """NuniqueIndexed must agree with a brute-force len(set(window)) reference."""
    positions = range(len(word))

    def distinct(window):
        return len(set(window))

    reference = ApplyIndexed(positions, word, window_size, function=distinct)
    candidate = NuniqueIndexed(positions, word, window_size)
    # Drain the candidate first, then the reference, as the operands of ==.
    assert list(candidate) == list(reference)
|
||
|
||
@pytest.mark.parametrize("window_size", [1, 2, 3, 4, 5, 6])
def test_index_date(window_size):
    """NuniqueIndexed on a datetime index, including a repeated timestamp."""
    observations = [
        (dt.datetime(2023, 5, 1), 'Cat1'),
        (dt.datetime(2023, 5, 2), 'Cat1'),
        (dt.datetime(2023, 5, 2), 'Cat2'),
        (dt.datetime(2023, 5, 3), 'Cat3'),
        (dt.datetime(2023, 5, 6), 'Cat1'),
    ]
    # Split the pairs into two parallel tuples.
    idx = tuple(stamp for stamp, _ in observations)
    val = tuple(label for _, label in observations)

    ws = dt.timedelta(days=window_size)

    def distinct(window):
        return len(set(window))

    got = NuniqueIndexed(idx, val, ws)
    expected = ApplyIndexed(idx, val, ws, function=distinct)

    assert list(got) == list(expected)