diff --git a/tests/test_pipeline/test_grouped_pipeline.py b/tests/test_pipeline/test_grouped_pipeline.py
index e49293d..c7c8dda 100644
--- a/tests/test_pipeline/test_grouped_pipeline.py
+++ b/tests/test_pipeline/test_grouped_pipeline.py
@@ -191,10 +191,28 @@ def test_all_groups_missing_raises(input_df):
 )
 def test_iter_groups_non_consecutive_index(index):
     group = [1] * 2 + [2] * (len(index) - 2)
+    target = np.array(group)
     value = np.random.random(len(index))
     input_df = pd.DataFrame(
         [group, value], index=["group", "value"], columns=index
     ).T
     gp = GroupedPipeline(groupby=["group"], pipeline=None)
-    for key, sub_df, _ in gp._iter_groups(input_df):
+    for key, sub_df, sub_target in gp._iter_groups(input_df, target):
         assert (sub_df["group"] == key).all()
+        assert (sub_target == key).all()
+
+
+@pytest.mark.parametrize(
+    "index", [[1, 2, 3, 4], pd.date_range("2019-01-01", periods=4)]
+)
+def test_iter_groups_non_consecutive_index_target_series(index):
+    group = [1] * 2 + [2] * (len(index) - 2)
+    target = pd.Series(group, index=index)
+    value = np.random.random(len(index))
+    input_df = pd.DataFrame(
+        [group, value], index=["group", "value"], columns=index
+    ).T
+    gp = GroupedPipeline(groupby=["group"], pipeline=None)
+    for key, sub_df, sub_target in gp._iter_groups(input_df, target):
+        assert (sub_df["group"] == key).all()
+        assert (sub_target == key).all()
diff --git a/timeserio/pipeline/pipeline.py b/timeserio/pipeline/pipeline.py
index 24e592f..4dbbb57 100644
--- a/timeserio/pipeline/pipeline.py
+++ b/timeserio/pipeline/pipeline.py
@@ -127,7 +127,11 @@ def _iter_groups(self, df, y=None):
         groups = df.groupby(self.groupby).indices
         for key, sub_idx in groups.items():
             sub_df = df.iloc[sub_idx]
-            sub_y = y[sub_idx] if y is not None else None
+            if y is not None:
+                # sub_idx is positional: pandas needs .iloc, numpy plain []
+                sub_y = y.iloc[sub_idx] if hasattr(y, "iloc") else y[sub_idx]
+            else:
+                sub_y = None
             yield key, sub_df, sub_y
 
     def fit(self, df, y=None):