Skip to content

Commit

Permalink
make random_state sklearn compliant
Browse files Browse the repository at this point in the history
  • Loading branch information
rpreen committed Nov 3, 2023
1 parent ab835aa commit 3cd548d
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 38 deletions.
2 changes: 1 addition & 1 deletion cfg/default.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"omp_num_threads": 8,
"random_state": 0,
"random_state": -1,
"pop_init": true,
"max_trials": 100000,
"perf_trials": 1000,
Expand Down
8 changes: 4 additions & 4 deletions xcsf/param.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ param_init(struct XCSF *xcsf, const int x_dim, const int y_dim,
param_set_x_dim(xcsf, x_dim);
param_set_y_dim(xcsf, y_dim);
param_set_omp_num_threads(xcsf, 8);
param_set_random_state(xcsf, 0);
param_set_random_state(xcsf, -1);
param_set_pop_init(xcsf, true);
param_set_max_trials(xcsf, 100000);
param_set_perf_trials(xcsf, 1000);
Expand Down Expand Up @@ -608,10 +608,10 @@ const char *
param_set_random_state(struct XCSF *xcsf, const int a)
{
xcsf->RANDOM_STATE = a;
if (a > 0) {
rand_init_seed(a);
} else {
if (a < 0) {
rand_init();
} else {
rand_init_seed(a);
}
return NULL;
}
Expand Down
52 changes: 19 additions & 33 deletions xcsf/pybind_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,36 +327,6 @@ class XCS

/* Supervised learning */

/**
* @brief Sets XCSF input and output dimensions.
* @param [in] n_x_dim Number of input dimensions.
* @param [in] x1 Size of second input dimension.
* @param [in] n_y_dim Number of output dimensions.
* @param [in] y1 Size of second output dimension.
*/
void
set_dims(int n_x_dim, int x1, int n_y_dim, int y1)
{
py::dict kwargs;
kwargs["n_actions"] = 1;
if (n_x_dim > 1) {
kwargs["x_dim"] = x1;
} else {
kwargs["x_dim"] = 1;
}
if (n_y_dim > 1) {
kwargs["y_dim"] = y1;
} else {
kwargs["y_dim"] = 1;
}
// update external params dict
for (const auto &item : kwargs) {
params[item.first] = item.second;
}
// flush param update to make sure neural nets resize
set_params(params);
}

/**
* @brief Loads an input data structure for fitting.
* @param [in,out] data Input data structure used to point to the data.
Expand Down Expand Up @@ -853,6 +823,13 @@ class XCS
py::object parsed_json = json.attr("loads")(json_str);
py::dict result(parsed_json);
params = result;
// map None types
if (params.contains("random_state")) {
py::object rs = params["random_state"];
if (py::isinstance<py::int_>(rs) && py::int_(rs) < 0) {
params["random_state"] = py::none();
}
}
free(json_str);
}

Expand All @@ -877,14 +854,23 @@ class XCS
set_params(py::kwargs kwargs)
{
py::dict kwargs_dict(kwargs);
// update external params dict
for (const auto &item : kwargs_dict) {
params[item.first] = item.second;
}
// map None types
if (kwargs_dict.contains("random_state")) {
py::object rs = kwargs["random_state"];
if (rs.is_none()) {
kwargs_dict["random_state"] = -1;
}
}
// convert dict to JSON and parse parameters
py::module json_module = py::module::import("json");
py::object json_dumps = json_module.attr("dumps")(kwargs_dict);
std::string json_str = json_dumps.cast<std::string>();
const char *json_params = json_str.c_str();
param_json_import(&xcs, json_params);
for (const auto &item : kwargs_dict) {
params[item.first] = item.second;
}
return *this;
}

Expand Down

0 comments on commit 3cd548d

Please sign in to comment.