Skip to content

Commit

Permalink
Bind BranchingGUB to Python and test
Browse files Browse the repository at this point in the history
  • Loading branch information
AntoinePrv committed Jun 7, 2021
1 parent eddb682 commit 888bfe5
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 13 deletions.
17 changes: 17 additions & 0 deletions python/src/ecole/core/dynamics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <pybind11/stl.h>
#include <xtensor-python/pytensor.hpp>

#include "ecole/dynamics/branching-gub.hpp"
#include "ecole/dynamics/branching.hpp"
#include "ecole/dynamics/configuring.hpp"
#include "ecole/scip/model.hpp"
Expand Down Expand Up @@ -62,6 +63,22 @@ void bind_submodule(pybind11::module_ const& m) {
.def_set_dynamics_random_state()
.def(py::init<bool>(), py::arg("pseudo_candidates") = false);

using idx_t = typename BranchingGUBDynamics::Action::value_type;
using array_t = py::array_t<idx_t, py::array::c_style | py::array::forcecast>;
dynamics_class<BranchingGUBDynamics>{m, "BranchingGUBDynamics"}
.def_reset_dynamics()
.def_set_dynamics_random_state()
.def(
"step_dynamics",
[](BranchingGUBDynamics& self, scip::Model& model, array_t const& action) {
auto const vars = nonstd::span{action.data(), static_cast<std::size_t>(action.size())};
auto const release = py::gil_scoped_release{};
return self.step_dynamics(model, vars);
},
py::arg("model"),
py::arg("action"))
.def(py::init<>());

dynamics_class<ConfiguringDynamics>{m, "ConfiguringDynamics"}
.def_reset_dynamics()
.def_step_dynamics()
Expand Down
5 changes: 5 additions & 0 deletions python/src/ecole/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,5 +195,10 @@ class Branching(Environment):
__DefaultObservationFunction__ = ecole.observation.NodeBipartite


class BranchingGUB(Environment):
__Dynamics__ = ecole.dynamics.BranchingGUBDynamics
__DefaultObservationFunction__ = ecole.observation.NodeBipartite


class Configuring(Environment):
__Dynamics__ = ecole.dynamics.ConfiguringDynamics
55 changes: 42 additions & 13 deletions python/tests/test_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def test_full_trajectory(self, model):

def test_exception(self, model):
"""Bad action raise exceptions."""
with pytest.raises((ecole.Exception, ecole.scip.Exception)):
self.dynamics.reset_dynamics(model)
self.dynamics.step_dynamics(model, self.bad_action)
with pytest.raises((ecole.scip.Exception, ValueError)):
_, action_set = self.dynamics.reset_dynamics(model)
self.dynamics.step_dynamics(model, self.bad_policy(action_set))

def test_set_random_state(self, model):
"""Random engine is consumed."""
Expand All @@ -62,38 +62,67 @@ def assert_action_set(action_set):
assert action_set.size > 0
assert action_set.dtype == np.uint64

@staticmethod
def policy(action_set):
return action_set[0]

@staticmethod
def bad_policy(action_set):
return 1 << 31

def setup_method(self, method):
self.dynamics = ecole.dynamics.BranchingDynamics(False)
self.policy = lambda action_set: action_set[0]
self.bad_action = 1 << 31


class TestBranchingPseudocost(DynamicsUnitTests):
class TestBranching_Pseudocandidate(TestBranching):
def setup_method(self, method):
self.dynamics = ecole.dynamics.BranchingDynamics(True)


class TestBranchingGUB_List(DynamicsUnitTests):
@staticmethod
def assert_action_set(action_set):
assert isinstance(action_set, np.ndarray)
assert action_set.ndim == 1
assert action_set.size > 0
assert action_set.dtype == np.uint64

@staticmethod
def policy(action_set):
return [action_set[0]]

@staticmethod
def bad_policy(action_set):
return [1 << 31]

def setup_method(self, method):
self.dynamics = ecole.dynamics.BranchingDynamics(True)
self.policy = lambda action_set: action_set[0]
self.bad_action = 1 << 31
self.dynamics = ecole.dynamics.BranchingGUBDynamics()


class TestBranchingGUB_Numpy(TestBranchingGUB_List):
@staticmethod
def policy(action_set):
return np.array([action_set[0]])


class TestConfiguring(DynamicsUnitTests):
@staticmethod
def assert_action_set(action_set):
assert action_set is None

def setup_method(self, method):
self.dynamics = ecole.dynamics.ConfiguringDynamics()
self.policy = lambda _: {
@staticmethod
def policy(action_set):
return {
"branching/scorefunc": "s",
"branching/scorefac": 0.1,
"branching/divingpscost": False,
"conflict/lpiterations": 0,
"heuristics/undercover/fixingalts": "ln",
}
self.bad_action = {"not/a/parameter": 44}

@staticmethod
def bad_policy(action_set):
return {"not/a/parameter": 44}

def setup_method(self, method):
self.dynamics = ecole.dynamics.ConfiguringDynamics()

0 comments on commit 888bfe5

Please sign in to comment.