Skip to content

Commit

Permalink
PhBaseWorkChain: add handler for ERROR_SCHEDULER_OUT_OF_WALLTIME (a…
Browse files Browse the repository at this point in the history
…iidateam#754)

Certain scheduler plugins can detect an out-of-walltime error in which
case the `ERROR_SCHEDULER_OUT_OF_WALLTIME` exit code will already have
been set on the node when the actual output parser is called. The
`PhParser` is updated to check for this exit code, and after having
parsed as much as possible from the output, the same exit code is kept
by not returning any other more specific exit code.

The `PhBaseWorkChain` adds a new handler for this exit code and will
perform a full restart by setting `recover = False`. It needs to be a
full restart because with an OOW error from the scheduler, the state of
the files on disk is almost certainly corrupt as the scheduler will
have killed the job when it was writing to disk.

Co-authored-by: Sebastiaan Huber <[email protected]>
  • Loading branch information
2 people authored and bastonero committed Dec 20, 2021
1 parent c72ea07 commit a88c67a
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 11 deletions.
4 changes: 4 additions & 0 deletions aiida_quantumespresso/parsers/ph.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def parse(self, **kwargs):
self.emit_logs(logs)
self.out('output_parameters', orm.Dict(dict=parsed_data))

# If the scheduler detected OOW, simply keep that exit code by not returning anything more specific.
if self.node.exit_status == PhCalculation.exit_codes.ERROR_SCHEDULER_OUT_OF_WALLTIME:
return

if 'ERROR_OUT_OF_WALLTIME' in logs['error']:
return self.exit_codes.ERROR_OUT_OF_WALLTIME

Expand Down
26 changes: 24 additions & 2 deletions aiida_quantumespresso/workflows/ph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,28 @@ def handle_unrecoverable_failure(self, node):
self.report_error_handled(node, 'unrecoverable error, aborting...')
return ProcessHandlerReport(True, self.exit_codes.ERROR_UNRECOVERABLE_FAILURE)

@process_handler(priority=610, exit_codes=PhCalculation.exit_codes.ERROR_SCHEDULER_OUT_OF_WALLTIME)
def handle_scheduler_out_of_walltime(self, node):
    """Handle `ERROR_SCHEDULER_OUT_OF_WALLTIME` exit code: decrease `max_seconds` and restart from scratch.

    The current `max_seconds` is read from the `INPUTPH` namelist of the inputs. If it is not set, it is derived
    from the `max_wallclock_seconds` scheduler option multiplied by `defaults.delta_factor_max_seconds`. The value
    is then halved and written back, and `recover` is set to `False` for the next run.

    :param node: the calculation node that finished with the scheduler out-of-walltime exit code.
    :return: a `ProcessHandlerReport` with `do_break` set; it carries `ERROR_UNRECOVERABLE_FAILURE` when neither
        `max_seconds` nor `max_wallclock_seconds` is available, since there is then no time limit to reduce.
    """
    # Decrease `max_seconds` significantly so that the next calculation has the time to shut down neatly before
    # reaching the scheduler wall time.
    factor = 0.5
    max_seconds = self.ctx.inputs.parameters.get('INPUTPH', {}).get('max_seconds', None)

    if max_seconds is None:
        max_wallclock_seconds = self.ctx.inputs.metadata.options.get('max_wallclock_seconds', None)

        # Without any known time limit, multiplying `None` would raise a `TypeError`; abort explicitly instead.
        if max_wallclock_seconds is None:
            self.report_error_handled(node, 'no `max_seconds` or `max_wallclock_seconds` defined, aborting...')
            return ProcessHandlerReport(True, self.exit_codes.ERROR_UNRECOVERABLE_FAILURE)

        max_seconds = max_wallclock_seconds * self.defaults.delta_factor_max_seconds

    max_seconds_new = max_seconds * factor

    self.ctx.restart_calc = node
    inputph = self.ctx.inputs.parameters.setdefault('INPUTPH', {})
    inputph['recover'] = False
    inputph['max_seconds'] = max_seconds_new

    action = f'reduced max_seconds from {max_seconds} to {max_seconds_new} and restarting'
    self.report_error_handled(node, action)
    return ProcessHandlerReport(True)

@process_handler(priority=580, exit_codes=PhCalculation.exit_codes.ERROR_OUT_OF_WALLTIME)
def handle_out_of_walltime(self, node):
"""Handle `ERROR_OUT_OF_WALLTIME` exit code: calculation shut down neatly and we can simply restart."""
Expand All @@ -117,8 +139,8 @@ def handle_out_of_walltime(self, node):
return ProcessHandlerReport(True)

@process_handler(priority=410, exit_codes=PhCalculation.exit_codes.ERROR_CONVERGENCE_NOT_REACHED)
def handle_convergence_not_achieved(self, node):
"""Handle `ERROR_CONVERGENCE_NOT_REACHED` exit code: decrease the mixing beta and restart from scratch."""
def handle_convergence_not_reached(self, node):
"""Handle `ERROR_CONVERGENCE_NOT_REACHED` exit code: decrease the mixing beta and restart."""
factor = self.defaults.delta_factor_alpha_mix
alpha_mix = self.ctx.inputs.parameters.get('INPUTPH', {}).get('alpha_mix(1)', self.defaults.alpha_mix)
alpha_mix_new = alpha_mix * factor
Expand Down
4 changes: 2 additions & 2 deletions aiida_quantumespresso/workflows/pw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def validate_kpoints(self):
the case of the latter, the `KpointsData` will be constructed for the input `StructureData` using the
`create_kpoints_from_distance` calculation function.
"""
if all([key not in self.inputs for key in ['kpoints', 'kpoints_distance']]):
if all(key not in self.inputs for key in ['kpoints', 'kpoints_distance']):
return self.exit_codes.ERROR_INVALID_INPUT_KPOINTS

try:
Expand Down Expand Up @@ -637,7 +637,7 @@ def handle_relax_recoverable_electronic_convergence_error(self, calculation):
@process_handler(priority=410, exit_codes=[
PwCalculation.exit_codes.ERROR_ELECTRONIC_CONVERGENCE_NOT_REACHED,
])
def handle_electronic_convergence_not_achieved(self, calculation):
def handle_electronic_convergence_not_reached(self, calculation):
"""Handle `ERROR_ELECTRONIC_CONVERGENCE_NOT_REACHED` error.
Decrease the mixing beta and fully restart from the previous calculation.
Expand Down
29 changes: 26 additions & 3 deletions tests/workflows/ph/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,38 @@ def test_handle_out_of_walltime(generate_workchain_ph):
assert result.status == 0


def test_handle_convergence_not_achieved(generate_workchain_ph):
"""Test `PhBaseWorkChain.handle_convergence_not_achieved`."""
def test_handle_scheduler_out_of_walltime(generate_workchain_ph):
    """Check that `PhBaseWorkChain.handle_scheduler_out_of_walltime` halves `max_seconds` and disables `recover`."""
    # The expected value is the default `max_seconds` derived from the scheduler wallclock, multiplied by 0.5.
    ph_options = generate_workchain_ph(return_inputs=True)['ph']['metadata']['options']
    default_max_seconds = ph_options['max_wallclock_seconds'] * PhBaseWorkChain.defaults.delta_factor_max_seconds
    expected_max_seconds = default_max_seconds * 0.5

    process = generate_workchain_ph(exit_code=PhCalculation.exit_codes.ERROR_SCHEDULER_OUT_OF_WALLTIME)
    for step in (process.setup, process.validate_parameters, process.prepare_process):
        step()

    report = process.handle_scheduler_out_of_walltime(process.ctx.children[-1])
    assert isinstance(report, ProcessHandlerReport)
    assert report.do_break

    inputph = process.ctx.inputs.parameters['INPUTPH']
    assert inputph['max_seconds'] == expected_max_seconds
    assert not inputph['recover']

    outcome = process.inspect_process()
    assert outcome.status == 0


def test_handle_convergence_not_reached(generate_workchain_ph):
"""Test `PhBaseWorkChain.handle_convergence_not_reached`."""
process = generate_workchain_ph(exit_code=PhCalculation.exit_codes.ERROR_CONVERGENCE_NOT_REACHED)
process.setup()
process.validate_parameters()

alpha_new = PhBaseWorkChain.defaults.alpha_mix * PhBaseWorkChain.defaults.delta_factor_alpha_mix

result = process.handle_convergence_not_achieved(process.ctx.children[-1])
result = process.handle_convergence_not_reached(process.ctx.children[-1])
assert isinstance(result, ProcessHandlerReport)
assert result.do_break
assert process.ctx.inputs.parameters['INPUTPH']['alpha_mix(1)'] == alpha_new
Expand Down
8 changes: 4 additions & 4 deletions tests/workflows/pw/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_handle_out_of_walltime(generate_workchain_pw, fixture_localhost, genera
)
process.setup()

result = process.handle_electronic_convergence_not_achieved(process.ctx.children[-1])
result = process.handle_electronic_convergence_not_reached(process.ctx.children[-1])
result = process.handle_out_of_walltime(process.ctx.children[-1])
assert isinstance(result, ProcessHandlerReport)
assert process.ctx.inputs.parameters['CONTROL']['restart_mode'] == 'restart'
Expand All @@ -74,8 +74,8 @@ def test_handle_out_of_walltime_structure_changed(generate_workchain_pw, generat
assert result.status == 0


def test_handle_electronic_convergence_not_achieved(generate_workchain_pw, fixture_localhost, generate_remote_data):
"""Test `PwBaseWorkChain.handle_electronic_convergence_not_achieved`."""
def test_handle_electronic_convergence_not_reached(generate_workchain_pw, fixture_localhost, generate_remote_data):
"""Test `PwBaseWorkChain.handle_electronic_convergence_not_reached`."""
remote_data = generate_remote_data(computer=fixture_localhost, remote_path='/path/to/remote')

process = generate_workchain_pw(
Expand All @@ -86,7 +86,7 @@ def test_handle_electronic_convergence_not_achieved(generate_workchain_pw, fixtu

process.ctx.inputs.parameters['ELECTRONS']['mixing_beta'] = 0.5

result = process.handle_electronic_convergence_not_achieved(process.ctx.children[-1])
result = process.handle_electronic_convergence_not_reached(process.ctx.children[-1])
assert isinstance(result, ProcessHandlerReport)
assert process.ctx.inputs.parameters['ELECTRONS']['mixing_beta'] == \
process.defaults.delta_factor_mixing_beta * 0.5
Expand Down

0 comments on commit a88c67a

Please sign in to comment.