diff --git a/flight/learning/torch/utils.py b/flight/learning/torch/utils.py index 2816250..91c2b66 100644 --- a/flight/learning/torch/utils.py +++ b/flight/learning/torch/utils.py @@ -20,6 +20,10 @@ from flight.learning.types import Data, FloatDouble, FloatTriple +IID: t.Final[float] = 1e5 +NON_IID: t.Final[float] = 1e-5 + + class FederatedDataModule(TorchDataModule): """ This class defines a DataModule that is split across worker nodes in a federation's @@ -177,8 +181,8 @@ def federated_split( topo: Topology, data: Data, num_labels: int, - label_alpha: float, - sample_alpha: float, + label_alpha: float = IID, + sample_alpha: float = IID, train_test_valid_split: FloatTriple | FloatDouble | None = None, ensure_at_least_one_sample: bool = True, rng: np.random.Generator | int | None = None, diff --git a/tests/test_commons.py b/tests/test_commons.py index ad9dc4a..1f3b777 100644 --- a/tests/test_commons.py +++ b/tests/test_commons.py @@ -3,24 +3,26 @@ from flight.commons import proportion_split -class TestProportionSplit: - def test_valid_proportions(self): - lst = list(range(10)) - assert proportion_split(lst, (0.5, 0.5)) == ([0, 1, 2, 3, 4], [5, 6, 7, 8, 9]) - assert proportion_split(lst, (0.5, 0.2, 0.3)) == ( - [0, 1, 2, 3, 4], - [5, 6], - [7, 8, 9], - ) - - def test_proportions_sum_not_one(self): - with pytest.raises(ValueError): - proportion_split([1, 2, 3], (0.5, 0.6)) - - def test_negative_proportions(self): - with pytest.raises(ValueError): - proportion_split([1, 2, 3], (-0.5, 1.5)) - - def test_more_proportions_than_elements(self): - with pytest.raises(ValueError): - proportion_split([1, 2], (0.5, 0.5, 0.5)) +def test_valid_proportions(): + lst = list(range(10)) + assert proportion_split(lst, (0.5, 0.5)) == ([0, 1, 2, 3, 4], [5, 6, 7, 8, 9]) + assert proportion_split(lst, (0.5, 0.2, 0.3)) == ( + [0, 1, 2, 3, 4], + [5, 6], + [7, 8, 9], + ) + + +def test_proportions_sum_not_one(): + with pytest.raises(ValueError): + proportion_split([1, 2, 3], (0.5, 0.6)) + + +def test_negative_proportions(): + with pytest.raises(ValueError): + proportion_split([1, 2, 3], (-0.5, 1.5)) + + +def test_more_proportions_than_elements(): + with pytest.raises(ValueError): + proportion_split([1, 2], (0.5, 0.5, 0.5))