Skip to content

Commit

Permalink
use unittest setUpClass instead of overriding __init__
Browse files Browse the repository at this point in the history
  • Loading branch information
davidwilby committed Jun 14, 2024
1 parent bead637 commit a0d61c2
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 51 deletions.
33 changes: 17 additions & 16 deletions tests/test_active_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,33 +30,34 @@


class TestActiveLearning(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@classmethod
def setUpClass(cls):

# It's safe to share data between tests because the TaskLoader does not modify data
ds_raw = xr.tutorial.open_dataset("air_temperature")["air"]
self.ds_raw = ds_raw
self.data_processor = DataProcessor(x1_name="lat", x2_name="lon")
self.ds = self.data_processor(ds_raw)
cls.ds_raw = ds_raw
cls.data_processor = DataProcessor(x1_name="lat", x2_name="lon")
cls.ds = cls.data_processor(ds_raw)
# Set up a model with two context sets and two target sets for generality
self.task_loader = TaskLoader(
context=[self.ds, self.ds], target=[self.ds, self.ds]
cls.task_loader = TaskLoader(
context=[cls.ds, cls.ds], target=[cls.ds, cls.ds]
)
self.model = ConvNP(
self.data_processor,
self.task_loader,
cls.model = ConvNP(
cls.data_processor,
cls.task_loader,
unet_channels=(5, 5, 5),
verbose=False,
)

# Set up model with aux-at-target data
aux_at_targets = self.ds.isel(time=0).drop_vars("time")
self.task_loader_with_aux = TaskLoader(
context=self.ds, target=self.ds, aux_at_targets=aux_at_targets
aux_at_targets = cls.ds.isel(time=0).drop_vars("time")
cls.task_loader_with_aux = TaskLoader(
context=cls.ds, target=cls.ds, aux_at_targets=aux_at_targets
)
self.model_with_aux = ConvNP(
self.data_processor,
self.task_loader_with_aux,
cls.model_with_aux = ConvNP(
cls.data_processor,
cls.task_loader_with_aux,
unet_channels=(5, 5, 5),
verbose=False,
)
Expand Down
13 changes: 7 additions & 6 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,15 @@ class TestModel(unittest.TestCase):
A test class for the ``ConvNP`` model.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@classmethod
def setUpClass(cls):
# super().__init__(*args, **kwargs)
# It's safe to share data between tests because the TaskLoader does not modify data
self.da = _gen_data_xr()
self.df = _gen_data_pandas()
cls.da = _gen_data_xr()
cls.df = _gen_data_pandas()

self.dp = DataProcessor()
_ = self.dp([self.da, self.df]) # Compute normalisation parameters
cls.dp = DataProcessor()
_ = cls.dp([cls.da, cls.df]) # Compute normalisation parameters

def _gen_task_loader_call_args(self, n_context, n_target):
"""Generate arguments for TaskLoader.__call__"""
Expand Down
21 changes: 11 additions & 10 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,23 @@


class TestPlotting(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@classmethod
def setUpClass(cls):
# It's safe to share data between tests because the TaskLoader does not modify data
ds_raw = xr.tutorial.open_dataset("air_temperature")
self.ds_raw = ds_raw
self.data_processor = DataProcessor(x1_name="lat", x2_name="lon")
ds = self.data_processor(ds_raw)
self.task_loader = TaskLoader(context=ds, target=ds)
self.model = ConvNP(
self.data_processor,
self.task_loader,
cls.ds_raw = ds_raw
cls.data_processor = DataProcessor(x1_name="lat", x2_name="lon")
ds = cls.data_processor(ds_raw)
cls.task_loader = TaskLoader(context=ds, target=ds)
cls.model = ConvNP(
cls.data_processor,
cls.task_loader,
unet_channels=(5, 5, 5),
verbose=False,
)
# Sample a task with 10 random context points
self.task = self.task_loader(
cls.task = cls.task_loader(
"2014-12-31", context_sampling=10, target_sampling="all"
)

Expand Down
19 changes: 10 additions & 9 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,18 @@


class TestConcatTasks(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@classmethod
def setUpClass(cls):
# It's safe to share data between tests because the TaskLoader does not modify data
ds_raw = xr.tutorial.open_dataset("air_temperature")
self.ds_raw = ds_raw
self.data_processor = DataProcessor(x1_name="lat", x2_name="lon")
ds = self.data_processor(ds_raw)
self.task_loader = TaskLoader(context=ds, target=ds)
self.model = ConvNP(
self.data_processor,
self.task_loader,
cls.ds_raw = ds_raw
cls.data_processor = DataProcessor(x1_name="lat", x2_name="lon")
ds = cls.data_processor(ds_raw)
cls.task_loader = TaskLoader(context=ds, target=ds)
cls.model = ConvNP(
cls.data_processor,
cls.task_loader,
unet_channels=(5, 5, 5),
verbose=False,
)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_task_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ class TestTaskLoader(unittest.TestCase):
- Task batching shape as expected
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@classmethod
def setUpClass(cls):
# It's safe to share data between tests because the TaskLoader does not modify data
self.da = _gen_data_xr()
self.aux_da = self.da.isel(time=0)
self.df = _gen_data_pandas()
cls.da = _gen_data_xr()
cls.aux_da = cls.da.isel(time=0)
cls.df = _gen_data_pandas()

def _gen_task_loader_call_args(self, n_context_sets, n_target_sets):
"""Generate arguments for ``TaskLoader.__call__``."""
Expand Down
11 changes: 6 additions & 5 deletions tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@


class TestTraining(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@classmethod
def setUpClass(cls):
# It's safe to share data between tests because the TaskLoader does not modify data
ds_raw = xr.tutorial.open_dataset("air_temperature")

self.ds_raw = ds_raw
self.data_processor = DataProcessor(x1_name="lat", x2_name="lon")
cls.ds_raw = ds_raw
cls.data_processor = DataProcessor(x1_name="lat", x2_name="lon")

self.da = self.data_processor(ds_raw)
cls.da = cls.data_processor(ds_raw)

def test_concat_tasks(self):
tl = TaskLoader(context=self.da, target=self.da)
Expand Down

0 comments on commit a0d61c2

Please sign in to comment.