Skip to content

Commit

Permalink
fix: add tags to domains created for inner tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed May 6, 2024
1 parent 41e7d81 commit 2a43aab
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 5 deletions.
19 changes: 15 additions & 4 deletions R/to_tune.R
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,11 @@ print.ObjectTuneToken = function(x, ...) {
#
# Makes liberal use to `pslike_to_ps` (converting Param, ParamSet, Domain to ParamSet)
# param is a data.table that is potentially modified by reference using data.table set() methods.
tunetoken_to_ps = function(tt, param) {
tunetoken_to_ps = function(tt, param, ...) {
UseMethod("tunetoken_to_ps")
}

tunetoken_to_ps.FullTuneToken = function(tt, param) {
tunetoken_to_ps.FullTuneToken = function(tt, param, ...) {
if (!domain_is_bounded(param)) {
stopf("%s must give a range for unbounded parameter %s.", tt$call, param$id)
}
Expand All @@ -258,7 +258,18 @@ tunetoken_to_ps.FullTuneToken = function(tt, param) {
}
}

tunetoken_to_ps.RangeTuneToken = function(tt, param) {
tunetoken_to_ps.InnerTuneToken = function(tt, param, ...) {
# Calling NextMethod with additional arguments behaves weirdly, as the InnerTuneToken only works with ranges right now
# we just call it directly
aggr = if (!is.null(tt$content$aggr)) tt$content$aggr else param$cargo[[1L]]$aggr
if (is.null(aggr)) {
stopf("%s must specify a aggregation function for parameter %s", tt$call, param$id)
}
tunetoken_to_ps.RangeTuneToken(tt = tt, param = param, in_tune_fn = param$cargo[[1L]]$in_tune_fn, tags = "inner_tuning",
aggr = aggr)
}

tunetoken_to_ps.RangeTuneToken = function(tt, param, args = list(), ...) {
if (!domain_is_number(param)) {
stopf("%s for non-numeric param must have zero or one argument.", tt$call)
}
Expand All @@ -280,7 +291,7 @@ tunetoken_to_ps.RangeTuneToken = function(tt, param) {
# create p_int / p_dbl object. Doesn't work if there is a numeric param class that we don't know about :-/
constructor = switch(param$cls, ParamInt = p_int, ParamDbl = p_dbl,
stopf("%s: logscale for parameter %s of class %s not supported", tt$call, param$id, param$class))
content = constructor(lower = bound_lower, upper = bound_upper, logscale = tt$content$logscale, aggr = tt$content$aggr)
content = constructor(lower = bound_lower, upper = bound_upper, logscale = tt$content$logscale, ...)
pslike_to_ps(content, tt$call, param)
}

Expand Down
13 changes: 13 additions & 0 deletions tests/testthat/test_ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -475,3 +475,16 @@ test_that("get_values works with inner_tune", {
param_set$set_values(a = to_tune())
expect_list(param_set$get_values(type = "with_inner"), len = 0L)
})

test_that("InnerTuneToken is translated to 'inner_tuning' tag when creating search space", {
param_set = ps(
a = p_int(0, Inf, tags = "inner_tuning", in_tune_fn = function(domain, param_set) domain$upper, aggr = function(x) round(mean(unlist(x))))
)

param_set$set_values(
a = to_tune(upper = 100, inner = TRUE)
)

ss = param_set$search_space()
expect_true("inner_tuning" %in% ss$tags$a)
})
2 changes: 1 addition & 1 deletion tests/testthat/test_to_tune.R
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ test_that("inner and aggr", {
param_set = ps(a = p_dbl(lower = 1, upper = 2, tags = "inner_tuning", in_tune_fn = function(domain, param_set) domain$upper))

# correct errors
expect_error(param_set$set_values(a = to_tune(inner = TRUE)), "but no aggregation function is available")
expect_error(param_set$set_values(a = to_tune(inner = TRUE)), "aggregation")
expect_error(param_set$set_values(a = to_tune(inner = FALSE, aggr = function(x) 1)))


Expand Down

0 comments on commit 2a43aab

Please sign in to comment.