From 888bfe51b2cbf87d5e695e1d2880ecddf4cf2e5e Mon Sep 17 00:00:00 2001 From: AntoinePrv Date: Mon, 7 Jun 2021 09:33:57 -0400 Subject: [PATCH] Bind BranchingGUB to Python and test --- python/src/ecole/core/dynamics.cpp | 17 +++++++++ python/src/ecole/environment.py | 5 +++ python/tests/test_dynamics.py | 55 +++++++++++++++++++++++------- 3 files changed, 64 insertions(+), 13 deletions(-) diff --git a/python/src/ecole/core/dynamics.cpp b/python/src/ecole/core/dynamics.cpp index ec434665a..d05329f5e 100644 --- a/python/src/ecole/core/dynamics.cpp +++ b/python/src/ecole/core/dynamics.cpp @@ -4,6 +4,7 @@ #include #include +#include "ecole/dynamics/branching-gub.hpp" #include "ecole/dynamics/branching.hpp" #include "ecole/dynamics/configuring.hpp" #include "ecole/scip/model.hpp" @@ -62,6 +63,22 @@ void bind_submodule(pybind11::module_ const& m) { .def_set_dynamics_random_state() .def(py::init(), py::arg("pseudo_candidates") = false); + using idx_t = typename BranchingGUBDynamics::Action::value_type; + using array_t = py::array_t; + dynamics_class{m, "BranchingGUBDynamics"} + .def_reset_dynamics() + .def_set_dynamics_random_state() + .def( + "step_dynamics", + [](BranchingGUBDynamics& self, scip::Model& model, array_t const& action) { + auto const vars = nonstd::span{action.data(), static_cast(action.size())}; + auto const release = py::gil_scoped_release{}; + return self.step_dynamics(model, vars); + }, + py::arg("model"), + py::arg("action")) + .def(py::init<>()); + dynamics_class{m, "ConfiguringDynamics"} .def_reset_dynamics() .def_step_dynamics() diff --git a/python/src/ecole/environment.py b/python/src/ecole/environment.py index 63bc41e8c..05bbe942f 100644 --- a/python/src/ecole/environment.py +++ b/python/src/ecole/environment.py @@ -195,5 +195,10 @@ class Branching(Environment): __DefaultObservationFunction__ = ecole.observation.NodeBipartite +class BranchingGUB(Environment): + __Dynamics__ = ecole.dynamics.BranchingGUBDynamics + __DefaultObservationFunction__ = ecole.observation.NodeBipartite + + class Configuring(Environment): __Dynamics__ = ecole.dynamics.ConfiguringDynamics diff --git a/python/tests/test_dynamics.py b/python/tests/test_dynamics.py index 48c8ef924..6a9c5054a 100644 --- a/python/tests/test_dynamics.py +++ b/python/tests/test_dynamics.py @@ -43,9 +43,9 @@ def test_full_trajectory(self, model): def test_exception(self, model): """Bad action raise exceptions.""" - with pytest.raises((ecole.Exception, ecole.scip.Exception)): - self.dynamics.reset_dynamics(model) - self.dynamics.step_dynamics(model, self.bad_action) + with pytest.raises((ecole.scip.Exception, ValueError)): + _, action_set = self.dynamics.reset_dynamics(model) + self.dynamics.step_dynamics(model, self.bad_policy(action_set)) def test_set_random_state(self, model): """Random engine is consumed.""" @@ -62,13 +62,24 @@ def assert_action_set(action_set): assert action_set.size > 0 assert action_set.dtype == np.uint64 + @staticmethod + def policy(action_set): + return action_set[0] + + @staticmethod + def bad_policy(action_set): + return 1 << 31 + def setup_method(self, method): self.dynamics = ecole.dynamics.BranchingDynamics(False) - self.policy = lambda action_set: action_set[0] - self.bad_action = 1 << 31 -class TestBranchingPseudocost(DynamicsUnitTests): +class TestBranching_Pseudocandidate(TestBranching): + def setup_method(self, method): + self.dynamics = ecole.dynamics.BranchingDynamics(True) + + +class TestBranchingGUB_List(DynamicsUnitTests): @staticmethod def assert_action_set(action_set): assert isinstance(action_set, np.ndarray) @@ -76,10 +87,22 @@ def assert_action_set(action_set): assert action_set.size > 0 assert action_set.dtype == np.uint64 + @staticmethod + def policy(action_set): + return [action_set[0]] + + @staticmethod + def bad_policy(action_set): + return [1 << 31] + def setup_method(self, method): - self.dynamics = ecole.dynamics.BranchingDynamics(True) - self.policy = lambda action_set: action_set[0] - self.bad_action = 1 << 31 + self.dynamics = ecole.dynamics.BranchingGUBDynamics() + + +class TestBranchingGUB_Numpy(TestBranchingGUB_List): + @staticmethod + def policy(action_set): + return np.array([action_set[0]]) class TestConfiguring(DynamicsUnitTests): @@ -87,13 +110,19 @@ class TestConfiguring(DynamicsUnitTests): def assert_action_set(action_set): assert action_set is None - def setup_method(self, method): - self.dynamics = ecole.dynamics.ConfiguringDynamics() - self.policy = lambda _: { + @staticmethod + def policy(action_set): + return { "branching/scorefunc": "s", "branching/scorefac": 0.1, "branching/divingpscost": False, "conflict/lpiterations": 0, "heuristics/undercover/fixingalts": "ln", } - self.bad_action = {"not/a/parameter": 44} + + @staticmethod + def bad_policy(action_set): + return {"not/a/parameter": 44} + + def setup_method(self, method): + self.dynamics = ecole.dynamics.ConfiguringDynamics()