Skip to content

Commit

Permalink
test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mb706 committed Aug 26, 2024
1 parent 5f92c05 commit ee1b088
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
6 changes: 3 additions & 3 deletions R/ParamSetCollection.R
Original file line number Diff line number Diff line change
Expand Up @@ -303,12 +303,12 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,

in_tune_fn = cargo$in_tune_fn

set_ids = info$ids
prefixed_set_ids = private$.add_name_prefix(prefix, info$ids)
cargo$in_tune_fn = crate(function(domain, param_vals) {
param_vals = param_vals[names(param_vals) %in% private$.add_name_prefix(prefix, set_ids)]
param_vals = param_vals[names(param_vals) %in% prefixed_set_ids]
names(param_vals) = gsub(sprintf("^\\Q%s.\\E", prefix), "", names(param_vals))
in_tune_fn(domain, param_vals)
}, in_tune_fn, prefix, set_ids)
}, in_tune_fn, prefix, prefixed_set_ids)

if (length(cargo$disable_in_tune)) {
cargo$disable_in_tune = set_names(
Expand Down
10 changes: 5 additions & 5 deletions tests/testthat/test_ParamSetCollection.R
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ test_that("PSC postfix", {
ps3 = ps(x.y = p_fct(c("a", "b")))
ps4 = ps(y.x = p_lgl())

psc = ParamSetCollection$new(list(y = ps1, z = ps2), postfix = TRUE)
psc = ParamSetCollection$new(list(y = ps1, z = ps2), postfix_names = TRUE)

expect_equal(psc$ids(), c("x.y", "x.z"))

Expand Down Expand Up @@ -432,14 +432,14 @@ test_that("PSC postfix", {
expect_equal(ps4$values, named_list())

# mixed with / without names
psc = ParamSetCollection$new(list(y = ps1, ps4), postfix = TRUE)
psc = ParamSetCollection$new(list(y = ps1, ps4), postfix_names = TRUE)
expect_equal(psc$ids(), c("x.y", "y.x"))
psc$values$x.y = 1
expect_equal(psc$values, list(x.y = 1))
psc$values$y.x = TRUE
expect_equal(psc$values, list(x.y = 1, y.x = TRUE))
expect_equal(ps1$values$x, 1)
expect_equal(ps4$values$y, TRUE)
expect_equal(ps4$values$y.x, TRUE)

ps4$extra_trafo = function(x, param_set) {
x$zzz = 888
Expand All @@ -449,10 +449,10 @@ test_that("PSC postfix", {
expect_equal(psc$trafo(list()), list(x.z.y = 999, zzz = 888))

# x.y generated twice here
expect_error(ParamSetCollection$new(list(y = ps1, ps3), postfix = TRUE), "would contain duplicated parameter.* x.y")
expect_error(ParamSetCollection$new(list(y = ps1, ps3), postfix_names = TRUE), "would contain duplicated parameter.* x.y")

# don't get confused when no names are given
psc = ParamSetCollection$new(list(ps3, ps4), postfix = TRUE)
psc = ParamSetCollection$new(list(ps3, ps4), postfix_names = TRUE)

expect_equal(psc$ids(), c("x.y", "y.x"))
psc$values$x.y = "a"
Expand Down

0 comments on commit ee1b088

Please sign in to comment.