Skip to content

Commit

Permalink
[pydrake] Add bindings for SharedPointerSystem (#17264)
Browse files Browse the repository at this point in the history
* [pydrake] Add bindings for SharedPointerSystem
  • Loading branch information
jwnimmer-tri authored May 31, 2022
1 parent f8d36f4 commit 5a76870
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
30 changes: 30 additions & 0 deletions bindings/pydrake/systems/primitives_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "drake/systems/primitives/pass_through.h"
#include "drake/systems/primitives/random_source.h"
#include "drake/systems/primitives/saturation.h"
#include "drake/systems/primitives/shared_pointer_system.h"
#include "drake/systems/primitives/sine.h"
#include "drake/systems/primitives/symbolic_vector_system.h"
#include "drake/systems/primitives/trajectory_affine_system.h"
Expand Down Expand Up @@ -390,6 +391,35 @@ PYBIND11_MODULE(primitives, m) {
doc.StateInterpolatorWithDiscreteDerivative.set_initial_position
.doc_2args_state_position);

DefineTemplateClassWithDefault<SharedPointerSystem<T>, LeafSystem<T>>(
m, "SharedPointerSystem", GetPyParam<T>(), doc.SharedPointerSystem.doc)
.def(py::init([](py::object value_to_hold) {
auto wrapped = std::make_unique<py::object>(std::move(value_to_hold));
return std::make_unique<SharedPointerSystem<T>>(std::move(wrapped));
}),
py::arg("value_to_hold"), doc.SharedPointerSystem.ctor.doc)
.def_static(
"AddToBuilder",
[](DiagramBuilder<T>* builder, py::object value_to_hold) {
auto wrapped =
std::make_unique<py::object>(std::move(value_to_hold));
return SharedPointerSystem<T>::AddToBuilder(
builder, std::move(wrapped));
},
py::arg("builder"), py::arg("value_to_hold"),
doc.SharedPointerSystem.AddToBuilder.doc)
.def(
"get",
[](const SharedPointerSystem<T>& self) {
py::object result = py::none();
py::object* held = self.template get<py::object>();
if (held != nullptr) {
result = std::move(*held);
}
return result;
},
doc.SharedPointerSystem.get.doc);

DefineTemplateClassWithDefault<SymbolicVectorSystem<T>, LeafSystem<T>>(m,
"SymbolicVectorSystem", GetPyParam<T>(), doc.SymbolicVectorSystem.doc)
.def(py::init<std::optional<Variable>, VectorX<Variable>,
Expand Down
22 changes: 22 additions & 0 deletions bindings/pydrake/systems/test/primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
PerceptronActivationType,
RandomSource,
Saturation, Saturation_,
SharedPointerSystem, SharedPointerSystem_,
Sine, Sine_,
StateInterpolatorWithDiscreteDerivative,
StateInterpolatorWithDiscreteDerivative_,
Expand Down Expand Up @@ -95,6 +96,7 @@ def test_instantiations(self):
self._check_instantiations(MultilayerPerceptron_)
self._check_instantiations(PassThrough_)
self._check_instantiations(Saturation_)
self._check_instantiations(SharedPointerSystem_)
self._check_instantiations(Sine_)
self._check_instantiations(StateInterpolatorWithDiscreteDerivative_)
self._check_instantiations(SymbolicVectorSystem_)
Expand Down Expand Up @@ -624,6 +626,26 @@ def test_ctor_api(self):
period_sec=0.1,
abstract_model_value=AbstractValue.Make("Hello world"))

def test_shared_pointer_system_ctor(self):
dut = SharedPointerSystem(value_to_hold=[1, 2, 3])
readback = dut.get()
self.assertListEqual(readback, [1, 2, 3])
del dut
self.assertListEqual(readback, [1, 2, 3])

def test_shared_pointer_system_builder(self):
builder = DiagramBuilder()
self.assertListEqual(
SharedPointerSystem.AddToBuilder(
builder=builder, value_to_hold=[1, 2, 3]),
[1, 2, 3])
diagram = builder.Build()
del builder
readback = diagram.GetSystems()[0].get()
self.assertListEqual(readback, [1, 2, 3])
del diagram
self.assertListEqual(readback, [1, 2, 3])

def test_sine(self):
# Test scalar output.
sine_source = Sine(amplitude=1, frequency=2, phase=3,
Expand Down

0 comments on commit 5a76870

Please sign in to comment.