diff --git a/pyphare/pyphare/pharesee/hierarchy/hierarchy.py b/pyphare/pyphare/pharesee/hierarchy/hierarchy.py index 2e3b62df5..cab07023c 100644 --- a/pyphare/pyphare/pharesee/hierarchy/hierarchy.py +++ b/pyphare/pyphare/pharesee/hierarchy/hierarchy.py @@ -1,12 +1,13 @@ +import numpy as np +import matplotlib.pyplot as plt + from .patch import Patch from .patchlevel import PatchLevel from ...core.box import Box from ...core import box as boxm -from ...core.phare_utilities import refinement_ratio from ...core.phare_utilities import listify - -import numpy as np -import matplotlib.pyplot as plt +from ...core.phare_utilities import deep_copy +from ...core.phare_utilities import refinement_ratio def format_timestamp(timestamp): @@ -68,6 +69,10 @@ def __init__( self.update() + def __deepcopy__(self, memo): + no_copy_keys = ["data_files"] # do not copy these things + return deep_copy(self, memo, no_copy_keys) + def __getitem__(self, qty): return self.__dict__[qty] diff --git a/pyphare/pyphare/pharesee/hierarchy/hierarchy_utils.py b/pyphare/pyphare/pharesee/hierarchy/hierarchy_utils.py index 369cd0f28..cccd9704a 100644 --- a/pyphare/pyphare/pharesee/hierarchy/hierarchy_utils.py +++ b/pyphare/pyphare/pharesee/hierarchy/hierarchy_utils.py @@ -1,11 +1,17 @@ -from .hierarchy import PatchHierarchy -from .patchdata import FieldData +from dataclasses import dataclass +from copy import deepcopy +import numpy as np + +from .hierarchy import PatchHierarchy, format_timestamp +from .patchdata import FieldData, ParticleData from .patchlevel import PatchLevel from .patch import Patch +from ...core.box import Box +from ...core.gridlayout import GridLayout from ...core.phare_utilities import listify from ...core.phare_utilities import refinement_ratio +from pyphare.pharesee import particles as mparticles -import numpy as np field_qties = { "EM_B_x": "Bx", @@ -552,9 +558,6 @@ def _compute_scalardiv(patch_datas, **kwargs): return tuple(pd_attrs) -from dataclasses import dataclass - - @dataclass class EqualityReport: ok: bool @@ -606,3 +609,46 @@ def hierarchy_compare(this, that): return EqualityReport(False, msg) return EqualityReport(True, "OK") + + +def single_patch_for_LO(hier, qties=None): + def _skip(qty): + return qties is not None and qty not in qties + + cier = deepcopy(hier) + sim = hier.sim + layout = GridLayout( + Box(sim.origin, sim.cells), sim.origin, sim.dl, interp_order=sim.interp_order + ) + p0 = Patch(patch_datas={}, patch_id="", layout=layout) + for t in cier.times(): + cier.time_hier[format_timestamp(t)] = {0: cier.level(0, t)} + cier.level(0, t).patches = [deepcopy(p0)] + l0_pds = cier.level(0, t).patches[0].patch_datas + for k, v in hier.level(0, t).patches[0].patch_datas.items(): + if _skip(k): + continue + if isinstance(v, FieldData): + l0_pds[k] = FieldData( + layout, v.field_name, None, centering=v.centerings + ) + l0_pds[k].dataset = np.zeros(l0_pds[k].size) + + elif isinstance(v, ParticleData): + l0_pds[k] = deepcopy(v) + else: + raise RuntimeError("unexpected state") + + for patch in hier.level(0, t).patches: + for k, v in patch.patch_datas.items(): + if _skip(k): + continue + if isinstance(v, FieldData): + l0_pds[k][patch.box] = v[patch.box] + elif isinstance(v, ParticleData): + l0_pds[k].dataset = mparticles.aggregate( + [l0_pds[k].dataset, v.dataset] + ) + else: + raise RuntimeError("unexpected state") + return cier diff --git a/pyphare/pyphare/pharesee/hierarchy/patchdata.py b/pyphare/pyphare/pharesee/hierarchy/patchdata.py index 4b606ccdb..47c6b3de6 100644 --- a/pyphare/pyphare/pharesee/hierarchy/patchdata.py +++ b/pyphare/pyphare/pharesee/hierarchy/patchdata.py @@ -118,6 +118,9 @@ def __getitem__(self, box_or_slice): return self.dataset[box_or_slice] return self.select(box_or_slice) + def __setitem__(self, box_or_slice, val): + self.__getitem__(box_or_slice)[:] = val + def __init__(self, layout, field_name, data, **kwargs): """ :param layout: A GridLayout representing the domain on which data is defined diff --git a/tests/simulator/test_init_from_restart.py b/tests/simulator/test_init_from_restart.py index f18da1433..82ab70ec8 100644 --- a/tests/simulator/test_init_from_restart.py +++ b/tests/simulator/test_init_from_restart.py @@ -9,22 +9,26 @@ from pyphare.core import phare_utilities as phut from pyphare.pharesee.hierarchy.fromh5 import get_all_available_quantities_from_h5 from pyphare.pharesee.particles import single_patch_per_level_per_pop_from -from pyphare.pharesee.hierarchy.hierarchy_utils import flat_finest_field +from pyphare.pharesee.hierarchy.hierarchy_utils import ( + flat_finest_field, + single_patch_for_LO, + hierarchy_compare, +) from tests.simulator import SimulatorTest, test_restarts from tests.diagnostic import dump_all_diags timestep = 0.001 time_step_nbr = 1 -first_mpi_size = 1 +first_mpi_size = 4 ppc = 100 cells = 200 first_out = "phare_outputs/reinit/first" secnd_out = "phare_outputs/reinit/secnd" -timestamps = [0] # np.array([timestep * 2, timestep * 4]) +timestamps = [0] restart_idx = Z = 0 simInitArgs = dict( - # largest_patch_size=100, + largest_patch_size=100, time_step_nbr=time_step_nbr, time_step=timestep, cells=cells, @@ -63,20 +67,12 @@ def test_reinit(self): Simulator(sim).run().reset() datahier0 = get_all_available_quantities_from_h5(first_out, timestamps[0]) datahier1 = get_all_available_quantities_from_h5(secnd_out, timestamps[0]) - - for k in "xyz": - a = flat_finest_field(datahier0, f"B{k}", timestamps[0], 0) - b = flat_finest_field(datahier1, f"B{k}", timestamps[0], 0) - phut.assert_fp_any_all_close(a, b) - - def get_merged(hier): - return single_patch_per_level_per_pop_from(hier) - - ds = [get_merged(datahier0), get_merged(datahier1)] - for key in ["alpha", "protons"]: - a, b = [d.level(0).patches[0].patch_datas[f"{key}_domain"] for d in ds] - self.assertGreater(a.size(), (cells - 1) * ppc) - self.assertEqual(a, b) + qties = ["protons_domain", "alpha_domain", "Bx", "By", "Bz"] + ds = [single_patch_for_LO(d, qties) for d in [datahier0, datahier1]] + eq = hierarchy_compare(*ds) + if not eq: + print(eq) + self.assertTrue(eq) def run_first_sim():