From 00b4f7c0ac41740a909bd0eaf3e0be87737e8dfa Mon Sep 17 00:00:00 2001 From: jvwilliams23 Date: Fri, 1 Nov 2024 17:53:07 +0000 Subject: [PATCH] add test for time-varying point data --- src/meshio/exodus/_exodus.py | 11 +++++++++- tests/helpers.py | 41 +++++++++++++++++++++++++++++++----- tests/test_exodus.py | 1 + 3 files changed, 47 insertions(+), 6 deletions(-) diff --git a/src/meshio/exodus/_exodus.py b/src/meshio/exodus/_exodus.py index 3a2335fb5..11c71a99c 100644 --- a/src/meshio/exodus/_exodus.py +++ b/src/meshio/exodus/_exodus.py @@ -291,6 +291,12 @@ def write(filename, mesh): import netCDF4 with netCDF4.Dataset(filename, "w") as rootgrp: + # if time-dependent, pass in mesh and time_step as list + if type(mesh) is list: + mesh, time_step = mesh + else: + time_step = None + # set global data now = datetime.datetime.now().isoformat() rootgrp.title = f"Created by meshio v{__version__}, {now}" @@ -308,7 +314,7 @@ def write(filename, mesh): rootgrp.createDimension("len_string", 33) rootgrp.createDimension("len_line", 81) rootgrp.createDimension("four", 4) - rootgrp.createDimension("time_step", None) + rootgrp.createDimension("time_step", time_step) # dummy time step data = rootgrp.createVariable("time_whole", "f4", ("time_step",)) @@ -377,6 +383,9 @@ def write(filename, mesh): fill_value=False, ) node_data[0] = data + if time_step is not None: + for time_index in range(1, time_step): + node_data[time_index] = data # node sets num_point_sets = len(mesh.point_sets) diff --git a/tests/helpers.py b/tests/helpers.py index da261cf0f..db7946c5a 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -605,6 +605,18 @@ def add_point_data(mesh, dim, num_tags=2, seed=0, dtype=float): return mesh2 +def add_timevarying_point_data(mesh, dim, time_step=2, num_tags=2, seed=0, dtype=float): + rng = np.random.default_rng(seed) + + mesh2 = copy.deepcopy(mesh) + + shape = (len(mesh.points),) if dim == 1 else (len(mesh.points), dim) + data = [(100 * rng.random(shape)).astype(dtype) for _ in range(num_tags)] + + for k, d in enumerate(data): + mesh2.point_data[string.ascii_lowercase[k]] = d + return [mesh2, time_step] + def add_cell_data(mesh, specs: list[tuple[str, tuple[int, ...], type]]): mesh2 = copy.deepcopy(mesh) @@ -653,11 +665,20 @@ def add_cell_sets(mesh): def write_read(tmp_path, writer, reader, input_mesh, atol, extension=".dat"): """Write and read a file, and make sure the data is the same as before.""" - in_mesh = copy.deepcopy(input_mesh) + if type(input_mesh) is list: + in_mesh = copy.deepcopy(input_mesh[0]) + else: + in_mesh = copy.deepcopy(input_mesh) p = tmp_path / ("test" + extension) print(input_mesh) writer(p, input_mesh) + # when using time-varying data, we pass time_step in list with + # mesh to avoid excessive code changes + if type(input_mesh) is list: + input_mesh, time_step = input_mesh + else: + time_step = None mesh = reader(p) # Make sure the output is writeable @@ -718,10 +739,20 @@ def cell_sorter(cell): print("b", cells1.data) assert np.array_equal(cells0.data, cells1.data) - for key in input_mesh.point_data.keys(): - assert np.allclose( - input_mesh.point_data[key], mesh.point_data[key], atol=atol, rtol=0.0 - ) + if time_step is None: + for key in input_mesh.point_data.keys(): + assert np.allclose( + input_mesh.point_data[key], mesh.point_data[key], atol=atol, rtol=0.0 + ) + else: + # we cannot write time-dependent data to input_mesh.point_data + # (as far as I can see). So for testing we set all times equal + for key in input_mesh.point_data.keys(): + for time_index in range(time_step): + time_dep_key = f"{key}_time{time_index}" + assert np.allclose( + input_mesh.point_data[key], mesh.point_data[time_dep_key], atol=atol, rtol=0.0 + ) print(input_mesh.cell_data) print() diff --git a/tests/test_exodus.py b/tests/test_exodus.py index 67e3a1328..f68930ecb 100644 --- a/tests/test_exodus.py +++ b/tests/test_exodus.py @@ -23,6 +23,7 @@ helpers.add_point_data(helpers.tri_mesh, 2), helpers.add_point_data(helpers.tri_mesh, 3), helpers.add_cell_data(helpers.tri_mesh, [("a", (3,), np.float64)]), + helpers.add_timevarying_point_data(helpers.tri_mesh, 1, 2), helpers.add_point_sets(helpers.tri_mesh), helpers.add_point_sets(helpers.tet_mesh), ]