Skip to content

Commit

Permalink
Merge pull request #39 from octoenergy/fix-grouped-pipeline
Browse files Browse the repository at this point in the history
Fix GroupedPipeline bug selecting from pd.Series
  • Loading branch information
Ali Teeney authored Aug 5, 2021
2 parents 33728b5 + 3ce7b8d commit 44f05ac
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
20 changes: 19 additions & 1 deletion tests/test_pipeline/test_grouped_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,28 @@ def test_all_groups_missing_raises(input_df):
)
def test_iter_groups_non_consecutive_index(index):
group = [1] * 2 + [2] * (len(index) - 2)
target = np.array(group)
value = np.random.random(len(index))
input_df = pd.DataFrame(
[group, value], index=["group", "value"], columns=index
).T
gp = GroupedPipeline(groupby=["group"], pipeline=None)
for key, sub_df, _ in gp._iter_groups(input_df):
for key, sub_df, sub_target in gp._iter_groups(input_df, target):
assert (sub_df["group"] == key).all()
assert (sub_target == key).all()


@pytest.mark.parametrize(
    "index", [[1, 2, 3, 4], pd.date_range("2019-01-01", periods=4)]
)
def test_iter_groups_non_consecutive_index_target_series(index):
    """Check `_iter_groups` keeps a pd.Series target aligned with each group.

    The frame and the target share an index whose labels differ from the
    0-based row positions (integers starting at 1, or dates). The first two
    rows belong to group 1 and the remaining rows to group 2, and the target
    carries the same values as the "group" column, so every yielded
    sub-frame and sub-target must consist solely of the group key.
    """
    n_rows = len(index)
    group = [1] * 2 + [2] * (n_rows - 2)
    target = pd.Series(group, index=index)
    value = np.random.random(n_rows)

    # Build the frame row-wise and transpose so that `index` ends up as the
    # row labels, with "group" and "value" as the columns.
    row_wise = pd.DataFrame(
        [group, value], index=["group", "value"], columns=index
    )
    input_df = row_wise.T

    gp = GroupedPipeline(groupby=["group"], pipeline=None)
    for key, sub_df, sub_target in gp._iter_groups(input_df, target):
        rows_in_group = sub_df["group"] == key
        assert rows_in_group.all()
        targets_in_group = sub_target == key
        assert targets_in_group.all()
6 changes: 5 additions & 1 deletion timeserio/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,11 @@ def _iter_groups(self, df, y=None):
groups = df.groupby(self.groupby).indices
for key, sub_idx in groups.items():
sub_df = df.iloc[sub_idx]
sub_y = y[sub_idx] if y is not None else None
if y is not None:
# y is either a numpy array or a pd.Series so index accordingly
sub_y = y.iloc[sub_idx] if type(y) is pd.Series else y[sub_idx]
else:
sub_y = None
yield key, sub_df, sub_y

def fit(self, df, y=None):
Expand Down

0 comments on commit 44f05ac

Please sign in to comment.