Skip to content

Commit

Permalink
[py systems] Improve error message for missing witness functions (#20106
Browse files Browse the repository at this point in the history
)
  • Loading branch information
jwnimmer-tri authored Sep 1, 2023
1 parent 09752cc commit 09d5fe5
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 11 deletions.
35 changes: 27 additions & 8 deletions bindings/pydrake/systems/framework_py_systems.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,15 +166,24 @@ struct Impl {
// trampoline if this is needed outside of LeafSystem.
void DoGetWitnessFunctions(const Context<T>& context,
std::vector<const WitnessFunction<T>*>* witnesses) const override {
auto wrapped = [&]() -> std::vector<const WitnessFunction<T>*> {
PYBIND11_OVERLOAD_INT(std::vector<const WitnessFunction<T>*>,
auto wrapped =
[&]() -> std::optional<std::vector<const WitnessFunction<T>*>> {
PYBIND11_OVERLOAD_INT(
std::optional<std::vector<const WitnessFunction<T>*>>,
LeafSystem<T>, "DoGetWitnessFunctions", &context);
std::vector<const WitnessFunction<T>*> result;
// If the macro did not return, use default functionality.
Base::DoGetWitnessFunctions(context, &result);
return result;
return {result};
};
*witnesses = wrapped();
auto result = wrapped();
if (!result.has_value()) {
// Give a good error message in case the user forgot to return anything.
throw py::type_error(
"Overrides of DoGetWitnessFunctions() must return "
"List[WitnessFunction], not NoneType.");
}
*witnesses = std::move(*result);
}
};

Expand Down Expand Up @@ -894,10 +903,20 @@ Note: The above is for the C++ documentation. For Python, use
.def("MakeWitnessFunction",
WrapCallbacks([](PyLeafSystem* self, const std::string& description,
const WitnessFunctionDirection& direction_type,
std::function<T(const Context<T>&)> calc)
-> std::unique_ptr<WitnessFunction<T>> {
return self->MakeWitnessFunction(
description, direction_type, calc);
std::function<std::optional<T>(const Context<T>&)>
calc) -> std::unique_ptr<WitnessFunction<T>> {
return self->MakeWitnessFunction(description, direction_type,
[calc](const Context<T>& context) -> T {
const std::optional<T> result = calc(context);
if (!result.has_value()) {
// Give a good error message in case the user forgot to
// return anything.
throw py::type_error(
"The MakeWitnessFunction() calc callback must return "
"a floating point value, not NoneType.");
}
return *result;
});
}),
py_rvp::reference_internal, py::arg("description"),
py::arg("direction_type"), py::arg("calc"),
Expand Down
25 changes: 22 additions & 3 deletions bindings/pydrake/systems/test/custom_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,12 @@ def __init__(self):
"system reset", WitnessFunctionDirection.kCrossesZero,
self._guard, UnrestrictedUpdateEvent(
system_callback=self._system_reset))
self.witness_result = 1.0
self.getwitness_result = [
self.witness,
self.reset_witness,
self.system_reset_witness,
]

def DoPublish(self, context, events):
# Call base method to ensure we do not get recursion.
Expand Down Expand Up @@ -450,8 +456,7 @@ def DoCalcDiscreteVariableUpdates(

def DoGetWitnessFunctions(self, context):
self.called_getwitness = True
return [self.witness, self.reset_witness,
self.system_reset_witness]
return self.getwitness_result

def _on_initialize(self, context, event):
test.assertIsInstance(context, Context)
Expand Down Expand Up @@ -550,7 +555,7 @@ def _on_forced_unrestricted(self, context, state):
def _witness(self, context):
test.assertIsInstance(context, Context)
self.called_witness = True
return 1.0
return self.witness_result

def _guard(self, context):
test.assertIsInstance(context, Context)
Expand Down Expand Up @@ -674,6 +679,20 @@ def _system_reset(self, system, context, event, state):
self.assertFalse(system.called_reset)
self.assertFalse(system.called_system_reset)

# Test witness function error messages.
system = TrivialSystem()
system.getwitness_result = None
simulator = Simulator(system)
with self.assertRaisesRegex(TypeError, "NoneType"):
simulator.AdvanceTo(0.1)
self.assertTrue(system.called_getwitness)
system = TrivialSystem()
system.witness_result = None
simulator = Simulator(system)
with self.assertRaisesRegex(TypeError, "NoneType"):
simulator.AdvanceTo(0.1)
self.assertTrue(system.called_witness)

def test_event_handler_returns_none(self):
"""Checks that a Python event handler callback function is allowed to
(implicitly) return None, instead of an EventStatus. Because of all the
Expand Down

0 comments on commit 09d5fe5

Please sign in to comment.