Skip to content

Commit

Permalink
add test for time-varying point data
Browse files Browse the repository at this point in the history
  • Loading branch information
jvwilliams23 committed Nov 1, 2024
1 parent 509a3f8 commit 00b4f7c
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 6 deletions.
11 changes: 10 additions & 1 deletion src/meshio/exodus/_exodus.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,12 @@ def write(filename, mesh):
import netCDF4

with netCDF4.Dataset(filename, "w") as rootgrp:
# if time-dependent, pass in mesh and time_step as list
if type(mesh) is list:
mesh, time_step = mesh
else:
time_step = None

# set global data
now = datetime.datetime.now().isoformat()
rootgrp.title = f"Created by meshio v{__version__}, {now}"
Expand All @@ -308,7 +314,7 @@ def write(filename, mesh):
rootgrp.createDimension("len_string", 33)
rootgrp.createDimension("len_line", 81)
rootgrp.createDimension("four", 4)
rootgrp.createDimension("time_step", None)
rootgrp.createDimension("time_step", time_step)

# dummy time step
data = rootgrp.createVariable("time_whole", "f4", ("time_step",))
Expand Down Expand Up @@ -377,6 +383,9 @@ def write(filename, mesh):
fill_value=False,
)
node_data[0] = data
if time_step is not None:
for time_index in range(1, time_step):
node_data[time_index] = data

# node sets
num_point_sets = len(mesh.point_sets)
Expand Down
41 changes: 36 additions & 5 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,18 @@ def add_point_data(mesh, dim, num_tags=2, seed=0, dtype=float):
return mesh2


def add_timevarying_point_data(mesh, dim, time_step=2, num_tags=2, seed=0, dtype=float):
rng = np.random.default_rng(seed)

mesh2 = copy.deepcopy(mesh)

shape = (len(mesh.points),) if dim == 1 else (len(mesh.points), dim)
data = [(100 * rng.random(shape)).astype(dtype) for _ in range(num_tags)]

for k, d in enumerate(data):
mesh2.point_data[string.ascii_lowercase[k]] = d
return [mesh2, time_step]

def add_cell_data(mesh, specs: list[tuple[str, tuple[int, ...], type]]):
mesh2 = copy.deepcopy(mesh)

Expand Down Expand Up @@ -653,11 +665,20 @@ def add_cell_sets(mesh):

def write_read(tmp_path, writer, reader, input_mesh, atol, extension=".dat"):
"""Write and read a file, and make sure the data is the same as before."""
in_mesh = copy.deepcopy(input_mesh)
if type(input_mesh) is list:
in_mesh = copy.deepcopy(input_mesh[0])
else:
in_mesh = copy.deepcopy(input_mesh)

p = tmp_path / ("test" + extension)
print(input_mesh)
writer(p, input_mesh)
# when using time-varying data, we pass time_step in list with
# mesh to avoid excessive code changes
if type(input_mesh) is list:
input_mesh, time_step = input_mesh
else:
time_step = None
mesh = reader(p)

# Make sure the output is writeable
Expand Down Expand Up @@ -718,10 +739,20 @@ def cell_sorter(cell):
print("b", cells1.data)
assert np.array_equal(cells0.data, cells1.data)

for key in input_mesh.point_data.keys():
assert np.allclose(
input_mesh.point_data[key], mesh.point_data[key], atol=atol, rtol=0.0
)
if time_step is None:
for key in input_mesh.point_data.keys():
assert np.allclose(
input_mesh.point_data[key], mesh.point_data[key], atol=atol, rtol=0.0
)
else:
# we cannot write time-dependent data to input_mesh.point_data
# (as far as I can see). So for testing we set all times equal
for key in input_mesh.point_data.keys():
for time_index in range(time_step):
time_dep_key = f"{key}_time{time_index}"
assert np.allclose(
input_mesh.point_data[key], mesh.point_data[time_dep_key], atol=atol, rtol=0.0
)

print(input_mesh.cell_data)
print()
Expand Down
1 change: 1 addition & 0 deletions tests/test_exodus.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
helpers.add_point_data(helpers.tri_mesh, 2),
helpers.add_point_data(helpers.tri_mesh, 3),
helpers.add_cell_data(helpers.tri_mesh, [("a", (3,), np.float64)]),
helpers.add_timevarying_point_data(helpers.tri_mesh, 1, 2),
helpers.add_point_sets(helpers.tri_mesh),
helpers.add_point_sets(helpers.tet_mesh),
]
Expand Down

0 comments on commit 00b4f7c

Please sign in to comment.