Skip to content

Commit

Permalink
Restructure simulation crash handling and add global condition for exit
Browse files Browse the repository at this point in the history
  • Loading branch information
tjwsch committed Apr 4, 2024
1 parent 69d79d6 commit e808ee6
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 24 deletions.
57 changes: 34 additions & 23 deletions micro_manager/micro_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def __init__(self, config_file: str) -> None:
self._config.get_config_file_name(),
self._rank,
self._size)

micro_file_name = self._config.get_micro_file_name()

self._macro_mesh_name = self._config.get_macro_mesh_name()

Expand All @@ -89,6 +91,8 @@ def __init__(self, config_file: str) -> None:
self._ranks_per_axis = self._config.get_ranks_per_axis()

self._is_micro_solve_time_required = self._config.write_micro_solve_time()

self._crash_threshold = 0.2

self._local_number_of_sims = 0
self._global_number_of_sims = 0
Expand Down Expand Up @@ -256,6 +260,19 @@ def solve(self) -> None:
else:
micro_sims_output = self._solve_micro_simulations(micro_sims_input)

# Check if more than a certain percentage of the micro simulations have crashed and terminate if threshold is exceeded
crashed_sims_on_all_ranks = np.zeros(self._size, dtype=np.int64)
self._comm.Allgather(np.sum(self._crashed_sims), crashed_sims_on_all_ranks)

if self._is_parallel:
crash_ratio = np.sum(crashed_sims_on_all_ranks) / self._global_number_of_sims
else:
crash_ratio = np.sum(self._crashed_sims) / len(self._crashed_sims)
if crash_ratio > self._crash_threshold:
self._logger.info("{:.1%} of the micro simulations have crashed exceeding the threshold of {:.1%}. "
"Exiting simulation.".format(crash_ratio, self._crash_threshold))
sys.exit()

self._write_data_to_precice(micro_sims_output)

self._participant.advance(self._dt)
Expand Down Expand Up @@ -333,11 +350,11 @@ def _initialize(self) -> None:

(
self._mesh_vertex_ids,
mesh_vertex_coords,
self._mesh_vertex_coords,
) = self._participant.get_mesh_vertex_ids_and_coordinates(self._macro_mesh_name)
assert mesh_vertex_coords.size != 0, "Macro mesh has no vertices."
assert self._mesh_vertex_coords.size != 0, "Macro mesh has no vertices."

self._local_number_of_sims, _ = mesh_vertex_coords.shape
self._local_number_of_sims, _ = self._mesh_vertex_coords.shape
self._logger.info(
"Number of local micro simulations = {}".format(self._local_number_of_sims)
)
Expand Down Expand Up @@ -551,34 +568,31 @@ def _solve_micro_simulations(self, micro_sims_input: list) -> list:
micro_sims_output = [None] * self._local_number_of_sims

for count, sim in enumerate(self._micro_sims):

# If micro simulation has not crashed in a previous iteration, attempt to solve it
if not self._crashed_sims[count]:
# Attempt to solve the micro simulation
try:
start_time = time.time()
micro_sims_output[count] = sim.solve(
micro_sims_input[count], self._dt)
end_time = time.time()
# If simulation crashes, log the error and keep the output constant at the previous iteration's output
except Exception as error_message:
_, mesh_vertex_coords = self._participant.get_mesh_vertex_ids_and_coordinates(
self._macro_mesh_name)
self._logger.error("Micro simulation at macro coordinates {} has experienced an error. "
"See next entry for error message. "
"Keeping values constant at results of previous iteration".format(
mesh_vertex_coords[count]))
self._mesh_vertex_coords[count]))
self._logger.error(error_message)
micro_sims_output[count] = self._old_micro_sims_output[count]
self._crashed_sims[count] = True
# If simulation has crashed in a previous iteration, keep the output constant
else:
micro_sims_output[count] = self._old_micro_sims_output[count]
# Write solve time of the macro simulation if required and the simulation has not crashed
if self._is_micro_solve_time_required and not self._crashed_sims[count]:
micro_sims_output[count]["micro_sim_time"] = end_time - start_time

crash_ratio = np.sum(self._crashed_sims) / len(self._crashed_sims)
if crash_ratio > 0.2:
self._logger.info("More than 20% of the micro simulations on rank {} have crashed. "
"Exiting simulation.".format(self._rank))
sys.exit()

# If a simulation crashes in the first iteration it is replaced with the output of the first simulation that ran
set_sims = np.where(micro_sims_output)
none_mask = np.array([item is None for item in micro_sims_output])
unset_sims = np.where(none_mask)[0]
Expand Down Expand Up @@ -642,25 +656,26 @@ def _solve_micro_simulations_with_adaptivity(

# Solve all active micro simulations
for active_id in active_sim_ids:

# If micro simulation has not crashed in a previous iteration, attempt to solve it
if not self._crashed_sims[active_id]:
# Attempt to solve the micro simulation
try:
start_time = time.time()
micro_sims_output[active_id] = self._micro_sims[active_id].solve(
micro_sims_input[active_id], self._dt
)
end_time = time.time()
# If simulation crashes, log the error and keep the output constant at the previous iteration's output
except Exception as error_message:
_, mesh_vertex_coords = self._participant.get_mesh_vertex_ids_and_coordinates(
self._macro_mesh_name)
self._logger.error("Micro simulation at macro coordinates {} has experienced an error. "
"See next entry for error message. "
"Keeping values constant at results of previous iteration".format(
mesh_vertex_coords[active_id]))
self._mesh_vertex_coords[active_id]))
self._logger.error(error_message)
# set the micro simulation value to old value and keep it constant if simulation crashes
micro_sims_output[active_id] = self._old_micro_sims_output[active_id]
self._crashed_sims[active_id] = True
# If simulation has crashed in a previous iteration, keep the output constant
else:
micro_sims_output[active_id] = self._old_micro_sims_output[active_id]

Expand All @@ -670,15 +685,11 @@ def _solve_micro_simulations_with_adaptivity(
"active_steps"
] = self._micro_sims_active_steps[active_id]

# Write solve time of the macro simulation if required and the simulation has not crashed
if self._is_micro_solve_time_required and not self._crashed_sims[active_id]:
micro_sims_output[active_id]["micro_sim_time"] = end_time - start_time

crash_ratio = np.sum(self._crashed_sims) / len(self._crashed_sims)
if crash_ratio > 0.2:
self._logger.info("More than 20% of the micro simulations on rank {} have crashed. "
"Exiting simulation.".format(self._rank))
sys.exit()

# If a simulation crashes in the first iteration it is replaced with the output of the first simulation that ran
set_sims = np.where(micro_sims_output)
unset_sims = []
for active_id in active_sim_ids:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_micro_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_read_write_data_from_precice(self):
self.assertListEqual(data["macro-vector-data"].tolist(),
fake_data["macro-vector-data"].tolist())

def test_solve_micro_sims(self):
def test_solve_mico_sims(self):
"""
Test if the internal function _solve_micro_simulations works as expected.
"""
Expand Down

0 comments on commit e808ee6

Please sign in to comment.