From 09d5fe52a6df8f9acf629f7aa671e1f55f2d7059 Mon Sep 17 00:00:00 2001 From: Jeremy Nimmer Date: Fri, 1 Sep 2023 13:09:08 -0700 Subject: [PATCH] [py systems] Improve error message for missing witness functions (#20106) --- .../pydrake/systems/framework_py_systems.cc | 35 ++++++++++++++----- bindings/pydrake/systems/test/custom_test.py | 25 +++++++++++-- 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/bindings/pydrake/systems/framework_py_systems.cc b/bindings/pydrake/systems/framework_py_systems.cc index 70fc181fdd59..2d46f3f53351 100644 --- a/bindings/pydrake/systems/framework_py_systems.cc +++ b/bindings/pydrake/systems/framework_py_systems.cc @@ -166,15 +166,24 @@ struct Impl { // trampoline if this is needed outside of LeafSystem. void DoGetWitnessFunctions(const Context& context, std::vector*>* witnesses) const override { - auto wrapped = [&]() -> std::vector*> { - PYBIND11_OVERLOAD_INT(std::vector*>, + auto wrapped = + [&]() -> std::optional*>> { + PYBIND11_OVERLOAD_INT( + std::optional*>>, LeafSystem, "DoGetWitnessFunctions", &context); std::vector*> result; // If the macro did not return, use default functionality. Base::DoGetWitnessFunctions(context, &result); - return result; + return {result}; }; - *witnesses = wrapped(); + auto result = wrapped(); + if (!result.has_value()) { + // Give a good error message in case the user forgot to return anything. + throw py::type_error( + "Overrides of DoGetWitnessFunctions() must return " + "List[WitnessFunction], not NoneType."); + } + *witnesses = std::move(*result); } }; @@ -894,10 +903,20 @@ Note: The above is for the C++ documentation. For Python, use .def("MakeWitnessFunction", WrapCallbacks([](PyLeafSystem* self, const std::string& description, const WitnessFunctionDirection& direction_type, - std::function&)> calc) - -> std::unique_ptr> { - return self->MakeWitnessFunction( - description, direction_type, calc); + std::function(const Context&)> + calc) -> std::unique_ptr> { + return self->MakeWitnessFunction(description, direction_type, + [calc](const Context& context) -> T { + const std::optional result = calc(context); + if (!result.has_value()) { + // Give a good error message in case the user forgot to + // return anything. + throw py::type_error( + "The MakeWitnessFunction() calc callback must return " + "a floating point value, not NoneType."); + } + return *result; + }); }), py_rvp::reference_internal, py::arg("description"), py::arg("direction_type"), py::arg("calc"), diff --git a/bindings/pydrake/systems/test/custom_test.py b/bindings/pydrake/systems/test/custom_test.py index 56f50009f360..354e747d4c69 100644 --- a/bindings/pydrake/systems/test/custom_test.py +++ b/bindings/pydrake/systems/test/custom_test.py @@ -423,6 +423,12 @@ def __init__(self): "system reset", WitnessFunctionDirection.kCrossesZero, self._guard, UnrestrictedUpdateEvent( system_callback=self._system_reset)) + self.witness_result = 1.0 + self.getwitness_result = [ + self.witness, + self.reset_witness, + self.system_reset_witness, + ] def DoPublish(self, context, events): # Call base method to ensure we do not get recursion. @@ -450,8 +456,7 @@ def DoCalcDiscreteVariableUpdates( def DoGetWitnessFunctions(self, context): self.called_getwitness = True - return [self.witness, self.reset_witness, - self.system_reset_witness] + return self.getwitness_result def _on_initialize(self, context, event): test.assertIsInstance(context, Context) @@ -550,7 +555,7 @@ def _on_forced_unrestricted(self, context, state): def _witness(self, context): test.assertIsInstance(context, Context) self.called_witness = True - return 1.0 + return self.witness_result def _guard(self, context): test.assertIsInstance(context, Context) @@ -674,6 +679,20 @@ def _system_reset(self, system, context, event, state): self.assertFalse(system.called_reset) self.assertFalse(system.called_system_reset) + # Test witness function error messages. + system = TrivialSystem() + system.getwitness_result = None + simulator = Simulator(system) + with self.assertRaisesRegex(TypeError, "NoneType"): + simulator.AdvanceTo(0.1) + self.assertTrue(system.called_getwitness) + system = TrivialSystem() + system.witness_result = None + simulator = Simulator(system) + with self.assertRaisesRegex(TypeError, "NoneType"): + simulator.AdvanceTo(0.1) + self.assertTrue(system.called_witness) + def test_event_handler_returns_none(self): """Checks that a Python event handler callback function is allowed to (implicitly) return None, instead of an EventStatus. Because of all the