From 432da9d80b17a5522d5025077cfb8145ea70165d Mon Sep 17 00:00:00 2001 From: Yang Date: Wed, 20 Sep 2023 13:28:33 +0200 Subject: [PATCH 01/10] consistent input output for traintest --- lilio/traintest.py | 34 +++++++++++++++++++++------------- tests/test_traintest.py | 2 +- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/lilio/traintest.py b/lilio/traintest.py index 61e60fb..41ded9f 100644 --- a/lilio/traintest.py +++ b/lilio/traintest.py @@ -57,7 +57,7 @@ def __init__(self, splitter: type[CVtype]) -> None: def split( self, - *x_args: xr.DataArray, + x_args: Union[xr.DataArray, Iterable[xr.DataArray]], y: Optional[xr.DataArray] = None, dim: str = "anchor_year", ) -> XMaybeY: @@ -65,9 +65,9 @@ def split( Args: x_args: one or multiple xr.DataArray's that share the same - coordinate along the given dimension + coordinate along the given dimension. y: (optional) xr.DataArray that shares the same coordinate along the - given dimension + given dimension. dim: name of the dimension along which to split the data. Returns: @@ -76,7 +76,13 @@ def split( # Check that all inputs share the same dim coordinate coords = [] x: xr.DataArray # Initialize x to set scope outside loop - for x in x_args: + + if isinstance(x_args, xr.DataArray): + x_args_list = [x_args] + else: + x_args_list = list(x_args) + + for x in x_args_list: try: coords.append(x[dim]) except KeyError as err: @@ -96,21 +102,23 @@ def split( if x[dim].size <= 1: raise ValueError( - f"Invalid input: need at least 2 values along dimension {dim}" + f"Invalid input: need at least 2 values along dimension {dim}." ) # Now we know that all inputs are equal. for train_indices, test_indices in self.splitter.split(x[dim]): - if len(x_args) == 1: - x_train: XType = x.isel({dim: train_indices}) - x_test: XType = x.isel({dim: test_indices}) - else: - x_train = [da.isel({dim: train_indices}) for da in x_args] - x_test = [da.isel({dim: test_indices}) for da in x_args] + x_train = [da.isel({dim: train_indices}) for da in x_args_list] + x_test = [da.isel({dim: test_indices}) for da in x_args_list] if y is None: - yield x_train, x_test + if isinstance(x_args, xr.DataArray): + yield x_train.pop(), x_test.pop() + else: + x_train, x_test else: y_train = y.isel({dim: train_indices}) y_test = y.isel({dim: test_indices}) - yield x_train, x_test, y_train, y_test + if isinstance(x_args, xr.DataArray): + yield x_train.pop(), x_test.pop(), y_train, y_test + else: + yield x_train, x_test, y_train, y_test diff --git a/tests/test_traintest.py b/tests/test_traintest.py index 0edcf5f..afc3336 100644 --- a/tests/test_traintest.py +++ b/tests/test_traintest.py @@ -57,7 +57,7 @@ def test_kfold_xxy(dummy_data): """Correctly split x1, x2, and y.""" x1, x2, y = dummy_data cv = lilio.traintest.TrainTestSplit(KFold(n_splits=3)) - x_train, x_test, y_train, y_test = next(cv.split(x1, x2, y=y)) + x_train, x_test, y_train, y_test = next(cv.split([x1, x2], y=y)) expected_train = [2019, 2020, 2021, 2022] expected_test = [2016, 2017, 2018] From 6ac822fdc38a3d954268e51d0cc1d98d6b0d09ef Mon Sep 17 00:00:00 2001 From: Yang Date: Wed, 20 Sep 2023 13:56:23 +0200 Subject: [PATCH 02/10] add more tests --- lilio/traintest.py | 2 +- tests/test_traintest.py | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/lilio/traintest.py b/lilio/traintest.py index 41ded9f..f2a56c3 100644 --- a/lilio/traintest.py +++ b/lilio/traintest.py @@ -114,7 +114,7 @@ def split( if isinstance(x_args, xr.DataArray): yield x_train.pop(), x_test.pop() else: - x_train, x_test + yield x_train, x_test else: y_train = y.isel({dim: train_indices}) y_test = y.isel({dim: test_indices}) diff --git a/tests/test_traintest.py b/tests/test_traintest.py index afc3336..d12e0e6 100644 --- a/tests/test_traintest.py +++ b/tests/test_traintest.py @@ -39,6 +39,18 @@ def test_kfold_x(dummy_data): xr.testing.assert_equal(x_test, x1.sel(anchor_year=expected_test)) +def test_kfold_x_list(dummy_data): + """Correctly split x.""" + x1, _, _ = dummy_data + cv = lilio.traintest.TrainTestSplit(KFold(n_splits=3)) + x_train, x_test = next(cv.split([x1])) + expected_train = [2019, 2020, 2021, 2022] + expected_test = [2016, 2017, 2018] + assert isinstance(x_train, list) + assert np.array_equal(x_train[0].anchor_year, expected_train) + xr.testing.assert_equal(x_test[0], x1.sel(anchor_year=expected_test)) + + def test_kfold_xy(dummy_data): """Correctly split x and y.""" x1, _, y = dummy_data @@ -67,6 +79,21 @@ def test_kfold_xxy(dummy_data): xr.testing.assert_equal(y_test, y.sel(anchor_year=expected_test)) +def test_kfold_xxy_tuple(dummy_data): + """Correctly split x1, x2, and y.""" + x1, x2, y = dummy_data + cv = lilio.traintest.TrainTestSplit(KFold(n_splits=3)) + x_train, x_test, y_train, y_test = next(cv.split((x1, x2), y=y)) + expected_train = [2019, 2020, 2021, 2022] + expected_test = [2016, 2017, 2018] + + assert isinstance(x_train, list) # all iterable will be turned into list + assert np.array_equal(x_train[0].anchor_year, expected_train) + xr.testing.assert_equal(x_test[1], x2.sel(anchor_year=expected_test)) + assert np.array_equal(y_train.anchor_year, expected_train) + xr.testing.assert_equal(y_test, y.sel(anchor_year=expected_test)) + + def test_kfold_too_short(dummy_data): "Fail if there is only a single anchor year: no splits can be made" x1, _, _ = dummy_data From f3a8a2d538ae9235a62059ea5fbb4ff97659d227 Mon Sep 17 00:00:00 2001 From: Yang Date: Wed, 20 Sep 2023 13:59:11 +0200 Subject: [PATCH 03/10] update notebook --- docs/notebooks/tutorial_traintest.ipynb | 96 +++++++++++++++++-------- 1 file changed, 65 insertions(+), 31 deletions(-) diff --git a/docs/notebooks/tutorial_traintest.ipynb b/docs/notebooks/tutorial_traintest.ipynb index 52d5229..1c1c145 100644 --- a/docs/notebooks/tutorial_traintest.ipynb +++ b/docs/notebooks/tutorial_traintest.ipynb @@ -16,9 +16,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "-0.3582 1.426 -0.9996 -1.108 3.111 0.4837 ... 1.37 -0.462 0.2982 0.2441 -1.63\n", + "Coordinates:\n", + " * time (time) datetime64[ns] 2015-10-20 2015-12-19 ... 2023-11-07\n" + ] + } + ], "source": [ "import numpy as np\n", "import pandas as pd\n", @@ -47,9 +58,29 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "0.6981 -0.9089 -0.3391 -1.239 -0.6261 ... 1.038 0.574 1.023 -0.09789 0.6119\n", + "Coordinates:\n", + " * anchor_year (anchor_year) int64 2016 2017 2018 2019 2020 2021 2022\n", + " * i_interval (i_interval) int64 -1 1\n", + " left_bound (anchor_year, i_interval) datetime64[ns] 2016-04-18 ... 2022...\n", + " right_bound (anchor_year, i_interval) datetime64[ns] 2016-10-15 ... 2023...\n", + " is_target (i_interval) bool False True\n", + "Attributes:\n", + " lilio_version: 0.4.1\n", + " lilio_calendar_anchor_date: 10-15\n", + " lilio_calendar_code: Calendar(\\n anchor='10-15',\\n allow_ov...\n", + " history: 2023-09-20 11:46:51 UTC - Resampled with a L...\n" + ] + } + ], "source": [ "calendar = lilio.daily_calendar(anchor=\"10-15\", length=\"180d\")\n", "calendar.map_to_data(x1)\n", @@ -73,9 +104,35 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train: [2019 2020 2021 2022]\n", + "Test: [2016 2017 2018]\n", + "Train: [2016 2017 2018 2021 2022]\n", + "Test: [2019 2020]\n", + "Train: [2016 2017 2018 2019 2020]\n", + "Test: [2021 2022]\n", + "\n", + "0.6981 -0.9089 -0.3391 -1.239 -0.6261 -0.07158 -0.2167 -0.05942 -0.07224 1.038\n", + "Coordinates:\n", + " * anchor_year (anchor_year) int64 2016 2017 2018 2019 2020\n", + " * i_interval (i_interval) int64 -1 1\n", + " left_bound (anchor_year, i_interval) datetime64[ns] 2016-04-18 ... 2020...\n", + " right_bound (anchor_year, i_interval) datetime64[ns] 2016-10-15 ... 2021...\n", + " is_target (i_interval) bool False True\n", + "Attributes:\n", + " lilio_version: 0.4.1\n", + " lilio_calendar_anchor_date: 10-15\n", + " lilio_calendar_code: Calendar(\\n anchor='10-15',\\n allow_ov...\n", + " history: 2023-09-20 11:46:51 UTC - Resampled with a L...\n" + ] + } + ], "source": [ "# Cross-validation\n", "from sklearn.model_selection import KFold\n", @@ -83,36 +140,13 @@ "\n", "kfold = KFold(n_splits=3)\n", "cv = lilio.traintest.TrainTestSplit(kfold)\n", - "for (x1_train, x2_train), (x1_test, x2_test), y_train, y_test in cv.split(x1, x2, y=y):\n", + "for (x1_train, x2_train), (x1_test, x2_test), y_train, y_test in cv.split([x1, x2], y=y):\n", " print(\"Train:\", x1_train.anchor_year.values)\n", " print(\"Test:\", x1_test.anchor_year.values)\n", "\n", "print(x1_train)" ] }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "With an alternative notation we can make this more compact:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Alternative using shorthand notation\n", - "x = [x1, x2]\n", - "for x_train, x_test, y_train, y_test in cv.split(*x, y=y):\n", - " x1_train, x2_train = x_train\n", - " x1_test, x2_test = x_test\n", - " print(\"Train:\", x1_train.anchor_year.values)\n", - " print(\"Test:\", x1_test.anchor_year.values)" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -143,7 +177,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.5" + "version": "3.10.12" }, "vscode": { "interpreter": { From 98803f56948e2602c69a6bdc2215cf80a643095a Mon Sep 17 00:00:00 2001 From: Yang Date: Wed, 20 Sep 2023 14:23:57 +0200 Subject: [PATCH 04/10] split checks and make linter happy --- lilio/traintest.py | 58 ++++++++++++++++++++++++++++------------- tests/test_traintest.py | 2 +- 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/lilio/traintest.py b/lilio/traintest.py index f2a56c3..bdb5ca7 100644 --- a/lilio/traintest.py +++ b/lilio/traintest.py @@ -73,6 +73,44 @@ def split( Returns: Iterator over the splits """ + x_args_list, x = self._check_dimension_and_type(x_args, y, dim) + + # Now we know that all inputs are equal. + for train_indices, test_indices in self.splitter.split(x[dim]): + x_train = [da.isel({dim: train_indices}) for da in x_args_list] + x_test = [da.isel({dim: test_indices}) for da in x_args_list] + + if y is None: + if isinstance(x_args, xr.DataArray): + yield x_train.pop(), x_test.pop() + else: + yield x_train, x_test + else: + y_train = y.isel({dim: train_indices}) + y_test = y.isel({dim: test_indices}) + if isinstance(x_args, xr.DataArray): + yield x_train.pop(), x_test.pop(), y_train, y_test + else: + yield x_train, x_test, y_train, y_test + + def _check_dimension_and_type( + self, + x_args: Union[xr.DataArray, Iterable[xr.DataArray]], + y: Optional[xr.DataArray] = None, + dim: str = "anchor_year", + ): + """Check input dimensions and type. + + Args: + x_args: one or multiple xr.DataArray's that share the same + coordinate along the given dimension. + y: (optional) xr.DataArray that shares the same coordinate along the + given dimension. + dim: name of the dimension along which to split the data. + + Returns: + List of input x and dataarray containing coordinate info + """ # Check that all inputs share the same dim coordinate coords = [] x: xr.DataArray # Initialize x to set scope outside loop @@ -81,7 +119,7 @@ def split( x_args_list = [x_args] else: x_args_list = list(x_args) - + for x in x_args_list: try: coords.append(x[dim]) @@ -105,20 +143,4 @@ def split( f"Invalid input: need at least 2 values along dimension {dim}." ) - # Now we know that all inputs are equal. - for train_indices, test_indices in self.splitter.split(x[dim]): - x_train = [da.isel({dim: train_indices}) for da in x_args_list] - x_test = [da.isel({dim: test_indices}) for da in x_args_list] - - if y is None: - if isinstance(x_args, xr.DataArray): - yield x_train.pop(), x_test.pop() - else: - yield x_train, x_test - else: - y_train = y.isel({dim: train_indices}) - y_test = y.isel({dim: test_indices}) - if isinstance(x_args, xr.DataArray): - yield x_train.pop(), x_test.pop(), y_train, y_test - else: - yield x_train, x_test, y_train, y_test + return x_args_list, x diff --git a/tests/test_traintest.py b/tests/test_traintest.py index d12e0e6..f8d6f0c 100644 --- a/tests/test_traintest.py +++ b/tests/test_traintest.py @@ -87,7 +87,7 @@ def test_kfold_xxy_tuple(dummy_data): expected_train = [2019, 2020, 2021, 2022] expected_test = [2016, 2017, 2018] - assert isinstance(x_train, list) # all iterable will be turned into list + assert isinstance(x_train, list) # all iterable will be turned into list assert np.array_equal(x_train[0].anchor_year, expected_train) xr.testing.assert_equal(x_test[1], x2.sel(anchor_year=expected_test)) assert np.array_equal(y_train.anchor_year, expected_train) From d3dba0c86d8d56bc71085e9eeee63be3bf9f7356 Mon Sep 17 00:00:00 2001 From: Yang Date: Wed, 20 Sep 2023 14:46:56 +0200 Subject: [PATCH 05/10] fix rendering of notebook for docs --- docs/notebooks/tutorial_traintest.ipynb | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/docs/notebooks/tutorial_traintest.ipynb b/docs/notebooks/tutorial_traintest.ipynb index 1c1c145..05eeaca 100644 --- a/docs/notebooks/tutorial_traintest.ipynb +++ b/docs/notebooks/tutorial_traintest.ipynb @@ -24,7 +24,7 @@ "output_type": "stream", "text": [ "\n", - "-0.3582 1.426 -0.9996 -1.108 3.111 0.4837 ... 1.37 -0.462 0.2982 0.2441 -1.63\n", + "0.0826 0.5005 0.1339 -0.2692 -0.3126 ... -2.013 -0.5182 -0.1175 0.1007 -0.7869\n", "Coordinates:\n", " * time (time) datetime64[ns] 2015-10-20 2015-12-19 ... 2023-11-07\n" ] @@ -66,7 +66,7 @@ "output_type": "stream", "text": [ "\n", - "0.6981 -0.9089 -0.3391 -1.239 -0.6261 ... 1.038 0.574 1.023 -0.09789 0.6119\n", + "-0.5155 0.06111 -0.09011 0.134 0.2794 ... -0.5539 -1.069 1.103 0.2899 -1.148\n", "Coordinates:\n", " * anchor_year (anchor_year) int64 2016 2017 2018 2019 2020 2021 2022\n", " * i_interval (i_interval) int64 -1 1\n", @@ -77,7 +77,7 @@ " lilio_version: 0.4.1\n", " lilio_calendar_anchor_date: 10-15\n", " lilio_calendar_code: Calendar(\\n anchor='10-15',\\n allow_ov...\n", - " history: 2023-09-20 11:46:51 UTC - Resampled with a L...\n" + " history: 2023-09-20 12:20:07 UTC - Resampled with a L...\n" ] } ], @@ -118,7 +118,7 @@ "Train: [2016 2017 2018 2019 2020]\n", "Test: [2021 2022]\n", "\n", - "0.6981 -0.9089 -0.3391 -1.239 -0.6261 -0.07158 -0.2167 -0.05942 -0.07224 1.038\n", + "-0.5155 0.06111 -0.09011 0.134 0.2794 -0.3576 -0.08088 -0.4669 0.04973 -0.5539\n", "Coordinates:\n", " * anchor_year (anchor_year) int64 2016 2017 2018 2019 2020\n", " * i_interval (i_interval) int64 -1 1\n", @@ -129,7 +129,7 @@ " lilio_version: 0.4.1\n", " lilio_calendar_anchor_date: 10-15\n", " lilio_calendar_code: Calendar(\\n anchor='10-15',\\n allow_ov...\n", - " history: 2023-09-20 11:46:51 UTC - Resampled with a L...\n" + " history: 2023-09-20 12:20:07 UTC - Resampled with a L...\n" ] } ], @@ -154,11 +154,6 @@ "source": [ "Now you are ready to train your models!" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] } ], "metadata": { From bf6402351343569fbfd8e7afdeaa2e0fb55868ad Mon Sep 17 00:00:00 2001 From: Yang Date: Wed, 20 Sep 2023 15:06:20 +0200 Subject: [PATCH 06/10] Revert "fix rendering of notebook for docs" This reverts commit d3dba0c86d8d56bc71085e9eeee63be3bf9f7356. --- docs/notebooks/tutorial_traintest.ipynb | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/docs/notebooks/tutorial_traintest.ipynb b/docs/notebooks/tutorial_traintest.ipynb index 05eeaca..1c1c145 100644 --- a/docs/notebooks/tutorial_traintest.ipynb +++ b/docs/notebooks/tutorial_traintest.ipynb @@ -24,7 +24,7 @@ "output_type": "stream", "text": [ "\n", - "0.0826 0.5005 0.1339 -0.2692 -0.3126 ... -2.013 -0.5182 -0.1175 0.1007 -0.7869\n", + "-0.3582 1.426 -0.9996 -1.108 3.111 0.4837 ... 1.37 -0.462 0.2982 0.2441 -1.63\n", "Coordinates:\n", " * time (time) datetime64[ns] 2015-10-20 2015-12-19 ... 2023-11-07\n" ] @@ -66,7 +66,7 @@ "output_type": "stream", "text": [ "\n", - "-0.5155 0.06111 -0.09011 0.134 0.2794 ... -0.5539 -1.069 1.103 0.2899 -1.148\n", + "0.6981 -0.9089 -0.3391 -1.239 -0.6261 ... 1.038 0.574 1.023 -0.09789 0.6119\n", "Coordinates:\n", " * anchor_year (anchor_year) int64 2016 2017 2018 2019 2020 2021 2022\n", " * i_interval (i_interval) int64 -1 1\n", @@ -77,7 +77,7 @@ " lilio_version: 0.4.1\n", " lilio_calendar_anchor_date: 10-15\n", " lilio_calendar_code: Calendar(\\n anchor='10-15',\\n allow_ov...\n", - " history: 2023-09-20 12:20:07 UTC - Resampled with a L...\n" + " history: 2023-09-20 11:46:51 UTC - Resampled with a L...\n" ] } ], @@ -118,7 +118,7 @@ "Train: [2016 2017 2018 2019 2020]\n", "Test: [2021 2022]\n", "\n", - "-0.5155 0.06111 -0.09011 0.134 0.2794 -0.3576 -0.08088 -0.4669 0.04973 -0.5539\n", + "0.6981 -0.9089 -0.3391 -1.239 -0.6261 -0.07158 -0.2167 -0.05942 -0.07224 1.038\n", "Coordinates:\n", " * anchor_year (anchor_year) int64 2016 2017 2018 2019 2020\n", " * i_interval (i_interval) int64 -1 1\n", @@ -129,7 +129,7 @@ " lilio_version: 0.4.1\n", " lilio_calendar_anchor_date: 10-15\n", " lilio_calendar_code: Calendar(\\n anchor='10-15',\\n allow_ov...\n", - " history: 2023-09-20 12:20:07 UTC - Resampled with a L...\n" + " history: 2023-09-20 11:46:51 UTC - Resampled with a L...\n" ] } ], @@ -154,6 +154,11 @@ "source": [ "Now you are ready to train your models!" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] } ], "metadata": { From fe310974f8b0eba6357f29db16b6dee965db5ecb Mon Sep 17 00:00:00 2001 From: Yang Date: Wed, 20 Sep 2023 15:51:12 +0200 Subject: [PATCH 07/10] empty the notebook --- docs/notebooks/tutorial_traintest.ipynb | 69 +++---------------------- 1 file changed, 6 insertions(+), 63 deletions(-) diff --git a/docs/notebooks/tutorial_traintest.ipynb b/docs/notebooks/tutorial_traintest.ipynb index 1c1c145..c2984bf 100644 --- a/docs/notebooks/tutorial_traintest.ipynb +++ b/docs/notebooks/tutorial_traintest.ipynb @@ -16,20 +16,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "-0.3582 1.426 -0.9996 -1.108 3.111 0.4837 ... 1.37 -0.462 0.2982 0.2441 -1.63\n", - "Coordinates:\n", - " * time (time) datetime64[ns] 2015-10-20 2015-12-19 ... 2023-11-07\n" - ] - } - ], + "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", @@ -58,29 +47,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "0.6981 -0.9089 -0.3391 -1.239 -0.6261 ... 1.038 0.574 1.023 -0.09789 0.6119\n", - "Coordinates:\n", - " * anchor_year (anchor_year) int64 2016 2017 2018 2019 2020 2021 2022\n", - " * i_interval (i_interval) int64 -1 1\n", - " left_bound (anchor_year, i_interval) datetime64[ns] 2016-04-18 ... 2022...\n", - " right_bound (anchor_year, i_interval) datetime64[ns] 2016-10-15 ... 2023...\n", - " is_target (i_interval) bool False True\n", - "Attributes:\n", - " lilio_version: 0.4.1\n", - " lilio_calendar_anchor_date: 10-15\n", - " lilio_calendar_code: Calendar(\\n anchor='10-15',\\n allow_ov...\n", - " history: 2023-09-20 11:46:51 UTC - Resampled with a L...\n" - ] - } - ], + "outputs": [], "source": [ "calendar = lilio.daily_calendar(anchor=\"10-15\", length=\"180d\")\n", "calendar.map_to_data(x1)\n", @@ -104,35 +73,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train: [2019 2020 2021 2022]\n", - "Test: [2016 2017 2018]\n", - "Train: [2016 2017 2018 2021 2022]\n", - "Test: [2019 2020]\n", - "Train: [2016 2017 2018 2019 2020]\n", - "Test: [2021 2022]\n", - "\n", - "0.6981 -0.9089 -0.3391 -1.239 -0.6261 -0.07158 -0.2167 -0.05942 -0.07224 1.038\n", - "Coordinates:\n", - " * anchor_year (anchor_year) int64 2016 2017 2018 2019 2020\n", - " * i_interval (i_interval) int64 -1 1\n", - " left_bound (anchor_year, i_interval) datetime64[ns] 2016-04-18 ... 2020...\n", - " right_bound (anchor_year, i_interval) datetime64[ns] 2016-10-15 ... 2021...\n", - " is_target (i_interval) bool False True\n", - "Attributes:\n", - " lilio_version: 0.4.1\n", - " lilio_calendar_anchor_date: 10-15\n", - " lilio_calendar_code: Calendar(\\n anchor='10-15',\\n allow_ov...\n", - " history: 2023-09-20 11:46:51 UTC - Resampled with a L...\n" - ] - } - ], + "outputs": [], "source": [ "# Cross-validation\n", "from sklearn.model_selection import KFold\n", From 65b82c03373ab510449486a6e3aa25899ef1a922 Mon Sep 17 00:00:00 2001 From: Yang Date: Thu, 28 Sep 2023 10:45:26 +0200 Subject: [PATCH 08/10] revise based on the review --- lilio/traintest.py | 41 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/lilio/traintest.py b/lilio/traintest.py index bdb5ca7..205ec7a 100644 --- a/lilio/traintest.py +++ b/lilio/traintest.py @@ -5,6 +5,7 @@ from collections.abc import Iterable from typing import Optional from typing import Union +from typing import overload import numpy as np import xarray as xr from sklearn.model_selection._split import BaseCrossValidator @@ -55,12 +56,46 @@ def __init__(self, splitter: type[CVtype]) -> None: """ self.splitter = splitter + @overload + def split( + self, + x_args: xr.DataArray, + y: Optional[xr.DataArray] = None, + dim: str = "anchor_year", + ) -> Iterable[tuple[xr.DataArray, xr.DataArray, xr.DataArray, xr.DataArray]]: + ... + + @overload + def split( + self, + x_args: Iterable[xr.DataArray], + y: Optional[xr.DataArray] = None, + dim: str = "anchor_year", + ) -> Iterable[ + tuple[ + Iterable[xr.DataArray], Iterable[xr.DataArray], xr.DataArray, xr.DataArray + ] + ]: + ... + def split( self, x_args: Union[xr.DataArray, Iterable[xr.DataArray]], y: Optional[xr.DataArray] = None, dim: str = "anchor_year", - ) -> XMaybeY: + ) -> Iterable[ + Union[ + tuple[xr.DataArray, xr.DataArray], + tuple[xr.DataArray, xr.DataArray, xr.DataArray, xr.DataArray], + tuple[Iterable[xr.DataArray], Iterable[xr.DataArray]], + tuple[ + Iterable[xr.DataArray], + Iterable[xr.DataArray], + xr.DataArray, + xr.DataArray, + ], + ] + ]: """Iterate over splits. Args: @@ -98,8 +133,8 @@ def _check_dimension_and_type( x_args: Union[xr.DataArray, Iterable[xr.DataArray]], y: Optional[xr.DataArray] = None, dim: str = "anchor_year", - ): - """Check input dimensions and type. + ) -> tuple[list[xr.DataArray], xr.DataArray]: + """Check input dimensions and type and return input as list. Args: x_args: one or multiple xr.DataArray's that share the same From 3d2d0ecabbf3c4f093b78eaa9a9fc0fdf7414680 Mon Sep 17 00:00:00 2001 From: Yang Date: Thu, 28 Sep 2023 11:02:42 +0200 Subject: [PATCH 09/10] remove unused type aliases --- lilio/traintest.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/lilio/traintest.py b/lilio/traintest.py index 205ec7a..f0ce2e2 100644 --- a/lilio/traintest.py +++ b/lilio/traintest.py @@ -13,14 +13,8 @@ # Mypy type aliases -XType = Union[xr.DataArray, list[xr.DataArray]] CVtype = Union[BaseCrossValidator, BaseShuffleSplit] -# For output types, variables are split in 2 -XOnly = tuple[XType, XType] -XAndY = tuple[XType, XType, xr.DataArray, xr.DataArray] -XMaybeY = Iterable[Union[XOnly, XAndY]] - class CoordinateMismatchError(Exception): """Custom exception for unmatching coordinates.""" From 27a7812df388103d3289a3aa6014053520c63782 Mon Sep 17 00:00:00 2001 From: Yang Date: Thu, 28 Sep 2023 11:14:31 +0200 Subject: [PATCH 10/10] update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 574d300..152f572 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/). ## [Unreleased] +### Changed +- Consistent output type of train-test split as input ([#62](https://github.com/AI4S2S/lilio/pull/62)). + ## 0.4.1 (2023-09-11) ### Added - Python 3.11 support ([#60](https://github.com/AI4S2S/lilio/pull/60)).