Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
nathaniel-hudson committed Dec 19, 2024
1 parent 990c9fc commit b891740
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 23 deletions.
8 changes: 6 additions & 2 deletions flight/learning/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
from flight.learning.types import Data, FloatDouble, FloatTriple


IID: t.Final[float] = 1e5
NON_IID: t.Final[float] = 1e-5


class FederatedDataModule(TorchDataModule):
"""
This class defines a DataModule that is split across worker nodes in a federation's
Expand Down Expand Up @@ -177,8 +181,8 @@ def federated_split(
topo: Topology,
data: Data,
num_labels: int,
label_alpha: float,
sample_alpha: float,
label_alpha: float = IID,
sample_alpha: float = IID,
train_test_valid_split: FloatTriple | FloatDouble | None = None,
ensure_at_least_one_sample: bool = True,
rng: np.random.Generator | int | None = None,
Expand Down
44 changes: 23 additions & 21 deletions tests/test_commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,26 @@
from flight.commons import proportion_split


class TestProportionSplit:
def test_valid_proportions(self):
lst = list(range(10))
assert proportion_split(lst, (0.5, 0.5)) == ([0, 1, 2, 3, 4], [5, 6, 7, 8, 9])
assert proportion_split(lst, (0.5, 0.2, 0.3)) == (
[0, 1, 2, 3, 4],
[5, 6],
[7, 8, 9],
)

def test_proportions_sum_not_one(self):
with pytest.raises(ValueError):
proportion_split([1, 2, 3], (0.5, 0.6))

def test_negative_proportions(self):
with pytest.raises(ValueError):
proportion_split([1, 2, 3], (-0.5, 1.5))

def test_more_proportions_than_elements(self):
with pytest.raises(ValueError):
proportion_split([1, 2], (0.5, 0.5, 0.5))
def test_valid_proportions():
lst = list(range(10))
assert proportion_split(lst, (0.5, 0.5)) == ([0, 1, 2, 3, 4], [5, 6, 7, 8, 9])
assert proportion_split(lst, (0.5, 0.2, 0.3)) == (
[0, 1, 2, 3, 4],
[5, 6],
[7, 8, 9],
)


def test_proportions_sum_not_one():
with pytest.raises(ValueError):
proportion_split([1, 2, 3], (0.5, 0.6))


def test_negative_proportions():
with pytest.raises(ValueError):
proportion_split([1, 2, 3], (-0.5, 1.5))


def test_more_proportions_than_elements():
with pytest.raises(ValueError):
proportion_split([1, 2], (0.5, 0.5, 0.5))

0 comments on commit b891740

Please sign in to comment.