Skip to content

Commit

Permalink
Merge pull request #62 from AI4S2S/consistent_output_traintest
Browse files Browse the repository at this point in the history
Consistent output traintest
  • Loading branch information
Yang authored Sep 29, 2023
2 parents 628a6fc + 27a7812 commit ee8ff0a
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 53 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/).

## [Unreleased]

### Changed
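- The train-test split now returns output of the same type as its input: a single `xr.DataArray` yields `xr.DataArray`s, while an iterable of them yields lists ([#62](https://github.com/AI4S2S/lilio/pull/62)).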

## 0.4.1 (2023-09-11)
### Added
- Python 3.11 support ([#60](https://github.com/AI4S2S/lilio/pull/60)).
Expand Down
27 changes: 2 additions & 25 deletions docs/notebooks/tutorial_traintest.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -83,36 +83,13 @@
"\n",
"kfold = KFold(n_splits=3)\n",
"cv = lilio.traintest.TrainTestSplit(kfold)\n",
"for (x1_train, x2_train), (x1_test, x2_test), y_train, y_test in cv.split(x1, x2, y=y):\n",
"for (x1_train, x2_train), (x1_test, x2_test), y_train, y_test in cv.split([x1, x2], y=y):\n",
" print(\"Train:\", x1_train.anchor_year.values)\n",
" print(\"Test:\", x1_test.anchor_year.values)\n",
"\n",
"print(x1_train)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"With an alternative notation we can make this more compact:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Alternative using shorthand notation\n",
"x = [x1, x2]\n",
"for x_train, x_test, y_train, y_test in cv.split(*x, y=y):\n",
" x1_train, x2_train = x_train\n",
" x1_test, x2_test = x_test\n",
" print(\"Train:\", x1_train.anchor_year.values)\n",
" print(\"Test:\", x1_test.anchor_year.values)"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand Down Expand Up @@ -143,7 +120,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.5"
"version": "3.10.12"
},
"vscode": {
"interpreter": {
Expand Down
113 changes: 86 additions & 27 deletions lilio/traintest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,16 @@
from collections.abc import Iterable
from typing import Optional
from typing import Union
from typing import overload
import numpy as np
import xarray as xr
from sklearn.model_selection._split import BaseCrossValidator
from sklearn.model_selection._split import BaseShuffleSplit


# Mypy type aliases
XType = Union[xr.DataArray, list[xr.DataArray]]
CVtype = Union[BaseCrossValidator, BaseShuffleSplit]

# For output types, variables are split in 2
XOnly = tuple[XType, XType]
XAndY = tuple[XType, XType, xr.DataArray, xr.DataArray]
XMaybeY = Iterable[Union[XOnly, XAndY]]


class CoordinateMismatchError(Exception):
"""Custom exception for unmatching coordinates."""
Expand Down Expand Up @@ -55,28 +50,106 @@ def __init__(self, splitter: type[CVtype]) -> None:
"""
self.splitter = splitter

@overload
def split(
self,
x_args: xr.DataArray,
y: Optional[xr.DataArray] = None,
dim: str = "anchor_year",
) -> Iterable[tuple[xr.DataArray, xr.DataArray, xr.DataArray, xr.DataArray]]:
...

@overload
def split(
self,
x_args: Iterable[xr.DataArray],
y: Optional[xr.DataArray] = None,
dim: str = "anchor_year",
) -> Iterable[
tuple[
Iterable[xr.DataArray], Iterable[xr.DataArray], xr.DataArray, xr.DataArray
]
]:
...

def split(
self,
*x_args: xr.DataArray,
x_args: Union[xr.DataArray, Iterable[xr.DataArray]],
y: Optional[xr.DataArray] = None,
dim: str = "anchor_year",
) -> XMaybeY:
) -> Iterable[
Union[
tuple[xr.DataArray, xr.DataArray],
tuple[xr.DataArray, xr.DataArray, xr.DataArray, xr.DataArray],
tuple[Iterable[xr.DataArray], Iterable[xr.DataArray]],
tuple[
Iterable[xr.DataArray],
Iterable[xr.DataArray],
xr.DataArray,
xr.DataArray,
],
]
]:
"""Iterate over splits.
Args:
x_args: one or multiple xr.DataArray's that share the same
coordinate along the given dimension
coordinate along the given dimension.
y: (optional) xr.DataArray that shares the same coordinate along the
given dimension
given dimension.
dim: name of the dimension along which to split the data.
Returns:
Iterator over the splits
"""
x_args_list, x = self._check_dimension_and_type(x_args, y, dim)

# Now we know that all inputs are equal.
for train_indices, test_indices in self.splitter.split(x[dim]):
x_train = [da.isel({dim: train_indices}) for da in x_args_list]
x_test = [da.isel({dim: test_indices}) for da in x_args_list]

if y is None:
if isinstance(x_args, xr.DataArray):
yield x_train.pop(), x_test.pop()
else:
yield x_train, x_test
else:
y_train = y.isel({dim: train_indices})
y_test = y.isel({dim: test_indices})
if isinstance(x_args, xr.DataArray):
yield x_train.pop(), x_test.pop(), y_train, y_test
else:
yield x_train, x_test, y_train, y_test

def _check_dimension_and_type(
self,
x_args: Union[xr.DataArray, Iterable[xr.DataArray]],
y: Optional[xr.DataArray] = None,
dim: str = "anchor_year",
) -> tuple[list[xr.DataArray], xr.DataArray]:
"""Check input dimensions and type and return input as list.
Args:
x_args: one or multiple xr.DataArray's that share the same
coordinate along the given dimension.
y: (optional) xr.DataArray that shares the same coordinate along the
given dimension.
dim: name of the dimension along which to split the data.
Returns:
List of input x and dataarray containing coordinate info
"""
# Check that all inputs share the same dim coordinate
coords = []
x: xr.DataArray # Initialize x to set scope outside loop
for x in x_args:

if isinstance(x_args, xr.DataArray):
x_args_list = [x_args]
else:
x_args_list = list(x_args)

for x in x_args_list:
try:
coords.append(x[dim])
except KeyError as err:
Expand All @@ -96,21 +169,7 @@ def split(

if x[dim].size <= 1:
raise ValueError(
f"Invalid input: need at least 2 values along dimension {dim}"
f"Invalid input: need at least 2 values along dimension {dim}."
)

# Now we know that all inputs are equal.
for train_indices, test_indices in self.splitter.split(x[dim]):
if len(x_args) == 1:
x_train: XType = x.isel({dim: train_indices})
x_test: XType = x.isel({dim: test_indices})
else:
x_train = [da.isel({dim: train_indices}) for da in x_args]
x_test = [da.isel({dim: test_indices}) for da in x_args]

if y is None:
yield x_train, x_test
else:
y_train = y.isel({dim: train_indices})
y_test = y.isel({dim: test_indices})
yield x_train, x_test, y_train, y_test
return x_args_list, x
29 changes: 28 additions & 1 deletion tests/test_traintest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@ def test_kfold_x(dummy_data):
xr.testing.assert_equal(x_test, x1.sel(anchor_year=expected_test))


def test_kfold_x_list(dummy_data):
    """A one-element list of x is split into train and test lists."""
    x1, _, _ = dummy_data
    splitter = lilio.traintest.TrainTestSplit(KFold(n_splits=3))
    first_split = next(splitter.split([x1]))
    x_train, x_test = first_split
    train_years = [2019, 2020, 2021, 2022]
    test_years = [2016, 2017, 2018]
    # List input must come back as a list, not as a bare DataArray.
    assert isinstance(x_train, list)
    assert np.array_equal(x_train[0].anchor_year, train_years)
    xr.testing.assert_equal(x_test[0], x1.sel(anchor_year=test_years))


def test_kfold_xy(dummy_data):
"""Correctly split x and y."""
x1, _, y = dummy_data
Expand All @@ -57,10 +69,25 @@ def test_kfold_xxy(dummy_data):
"""Correctly split x1, x2, and y."""
x1, x2, y = dummy_data
cv = lilio.traintest.TrainTestSplit(KFold(n_splits=3))
x_train, x_test, y_train, y_test = next(cv.split(x1, x2, y=y))
x_train, x_test, y_train, y_test = next(cv.split([x1, x2], y=y))
expected_train = [2019, 2020, 2021, 2022]
expected_test = [2016, 2017, 2018]

assert np.array_equal(x_train[0].anchor_year, expected_train)
xr.testing.assert_equal(x_test[1], x2.sel(anchor_year=expected_test))
assert np.array_equal(y_train.anchor_year, expected_train)
xr.testing.assert_equal(y_test, y.sel(anchor_year=expected_test))


def test_kfold_xxy_tuple(dummy_data):
"""Correctly split x1, x2, and y."""
x1, x2, y = dummy_data
cv = lilio.traintest.TrainTestSplit(KFold(n_splits=3))
x_train, x_test, y_train, y_test = next(cv.split((x1, x2), y=y))
expected_train = [2019, 2020, 2021, 2022]
expected_test = [2016, 2017, 2018]

assert isinstance(x_train, list) # all iterable will be turned into list
assert np.array_equal(x_train[0].anchor_year, expected_train)
xr.testing.assert_equal(x_test[1], x2.sel(anchor_year=expected_test))
assert np.array_equal(y_train.anchor_year, expected_train)
Expand Down

0 comments on commit ee8ff0a

Please sign in to comment.