Skip to content

Commit

Permalink
feat: aggr function for inner tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Apr 16, 2024
1 parent 30addd7 commit 441d754
Show file tree
Hide file tree
Showing 23 changed files with 139 additions and 45 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ Config/testthat/edition: 3
Config/testthat/parallel: false
NeedsCompilation: no
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
VignetteBuilder: knitr
Collate:
'Condition.R'
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ S3method(format,Condition)
S3method(print,Condition)
S3method(print,Domain)
S3method(print,FullTuneToken)
S3method(print,InnerTuneToken)
S3method(print,ObjectTuneToken)
S3method(print,RangeTuneToken)
S3method(rd_info,ParamSet)
Expand Down Expand Up @@ -85,6 +86,7 @@ export(generate_design_grid)
export(generate_design_lhs)
export(generate_design_random)
export(generate_design_sobol)
export(in_tune)
export(p_dbl)
export(p_fct)
export(p_int)
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# dev

* feat: added `aggr`(egation function) to `Domain` which can be used for inner
tuning.

# paradox 0.12.0
* Removed `Param` objects. `ParamSet` now uses a `data.table` internally; individual parameters are more like `Domain` objects now. `ParamSets` should be constructed using the `ps()` shorthand and `Domain` objects. This entails the following major changes:
* `ParamSet` now supports `extra_trafo` natively; it behaves like `.extra_trafo` of the `ps()` call.
Expand Down
18 changes: 14 additions & 4 deletions R/Domain.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@
#' @param init (`any`)\cr
#' Initial value. When this is given, then the corresponding entry in `ParamSet$values` is initialized with this
#' value upon construction.
#' @param aggr (`function`)\cr
#' Function with one argument, which is a list of parameter values.
#' The function specifies how this list of parameter values is aggregated to form one parameter value.
#' This is used in the context of inner tuning. The default is to aggregate the values.
#'
#' @return A `Domain` object.
#'
Expand Down Expand Up @@ -134,7 +138,8 @@ Domain = function(cls, grouping,
trafo = NULL,
depends_expr = NULL,
storage_type = "list",
init) {
init,
aggr = NULL) {

assert_string(cls)
assert_string(grouping)
Expand All @@ -146,7 +151,11 @@ Domain = function(cls, grouping,
if (length(special_vals) && !is.null(trafo)) stop("trafo and special_values can not both be given at the same time.")
assert_character(tags, any.missing = FALSE, unique = TRUE)
assert_function(trafo, null.ok = TRUE)
assert_function(aggr, null.ok = TRUE, nargs = 1L)

if (is.null(aggr) && "inner_tuning" %in% tags) {
aggr = default_aggr
}

# depends may be an expression, but may also be quote() or expression()
if (length(depends_expr) == 1) {
Expand All @@ -168,9 +177,9 @@ Domain = function(cls, grouping,
.tags = list(tags),
.trafo = list(trafo),
.requirements = list(parse_depends(depends_expr, parent.frame(2))),

.init_given = !missing(init),
.init = list(if (!missing(init)) init)
.init = list(if (!missing(init)) init),
.aggr = list(aggr)
)

class(param) = c(cls, "Domain", class(param))
Expand Down Expand Up @@ -215,7 +224,8 @@ empty_domain = data.table(id = character(0), cls = character(0), grouping = char
.trafo = list(),
.requirements = list(),
.init_given = logical(0),
.init = list()
.init = list(),
.aggr = list()
)

domain_names = names(empty_domain)
Expand Down
4 changes: 2 additions & 2 deletions R/ParamDbl.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' @rdname Domain
#' @export
p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init) {
p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL) {
assert_number(tolerance, lower = 0)
assert_number(lower)
assert_number(upper)
Expand All @@ -18,7 +18,7 @@ p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_
}

Domain(cls = "ParamDbl", grouping = "ParamDbl", lower = real_lower, upper = real_upper, special_vals = special_vals, default = default, tags = tags, tolerance = tolerance, trafo = trafo, storage_type = "numeric",
depends_expr = substitute(depends), init = init, cargo = if (logscale) "logscale")
depends_expr = substitute(depends), init = init, cargo = if (logscale) "logscale", aggr = aggr)
}

#' @export
Expand Down
4 changes: 2 additions & 2 deletions R/ParamFct.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' @rdname Domain
#' @export
p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init) {
p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL) {
constargs = as.list(match.call()[-1])
levels = eval.parent(constargs$levels)
if (!is.character(levels)) {
Expand All @@ -22,7 +22,7 @@ p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = charact
# group p_fct by levels, so the group can be checked in a vectorized fashion.
# We escape '"' and '\' to '\"' and '\\', respectively.
grouping = str_collapse(gsub("([\\\\\"])", "\\\\\\1", sort(real_levels)), quote = '"', sep = ",")
Domain(cls = "ParamFct", grouping = grouping, levels = real_levels, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "character", depends_expr = substitute(depends), init = init)
Domain(cls = "ParamFct", grouping = grouping, levels = real_levels, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "character", depends_expr = substitute(depends), init = init, aggr = aggr)
}

#' @export
Expand Down
4 changes: 2 additions & 2 deletions R/ParamInt.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

#' @rdname Domain
#' @export
p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init) {
p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL) {
assert_number(tolerance, lower = 0, upper = 0.5)
# assert_int will stop for `Inf` values, which we explicitly allow as lower / upper bound
if (!isTRUE(is.infinite(lower))) assert_int(lower, tol = 1e-300) else assert_number(lower)
Expand All @@ -25,7 +25,7 @@ p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_

Domain(cls = cls, grouping = cls, lower = real_lower, upper = real_upper, special_vals = special_vals, default = default, tags = tags, tolerance = tolerance, trafo = trafo,
storage_type = storage_type,
depends_expr = substitute(depends), init = init, cargo = if (logscale) "logscale")
depends_expr = substitute(depends), init = init, cargo = if (logscale) "logscale", aggr = aggr)
}

#' @export
Expand Down
4 changes: 2 additions & 2 deletions R/ParamLgl.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#' @rdname Domain
#' @export
p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init) {
p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL) {
Domain(cls = "ParamLgl", grouping = "ParamLgl", levels = c(TRUE, FALSE), special_vals = special_vals, default = default,
tags = tags, trafo = trafo, storage_type = "logical", depends_expr = substitute(depends), init = init)
tags = tags, trafo = trafo, storage_type = "logical", depends_expr = substitute(depends), init = init, aggr = aggr)
}

#' @export
Expand Down
7 changes: 6 additions & 1 deletion R/ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ ParamSet = R6Class("ParamSet",
private$.trafos = setkeyv(paramtbl[!map_lgl(.trafo, is.null), .(id, trafo = .trafo)], "id")
}

if (".aggr" %in% names(paramtbl)) {
private$.aggrs = setkeyv(paramtbl[!map_lgl(.aggr, is.null), .(id, aggr = .aggr)], "id")
}

if (".requirements" %in% names(paramtbl)) {
requirements = paramtbl$.requirements
private$.params = paramtbl # self$add_dep needs this
Expand Down Expand Up @@ -645,7 +649,7 @@ ParamSet = R6Class("ParamSet",
if (nrow(deps)) { # add a nice extra charvec-col to the tab, which lists all parents-ids
on = NULL
dd = deps[, list(parents = list(unlist(on))), by = "id"]
d = merge(d, dd, on = "id", all.x = TRUE)
d = merge(d, dd, by = "id", all.x = TRUE)
}
v = named_list(d$id) # add values to last col of print-dt as list col
v = insert_named(v, self$values)
Expand Down Expand Up @@ -872,6 +876,7 @@ ParamSet = R6Class("ParamSet",
.tags = data.table(id = character(0L), tag = character(0), key = "id"),
.deps = data.table(id = character(0L), on = character(0L), cond = list()),
.trafos = data.table(id = character(0L), trafo = list(), key = "id"),
.aggrs = data.table(id = character(0L), aggr = list(), key = "id"),

get_tune_ps = function(values) {
values = keep(values, inherits, "TuneToken")
Expand Down
4 changes: 2 additions & 2 deletions R/ParamUty.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

#' @rdname Domain
#' @export
p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, repr = substitute(default), init) {
p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, repr = substitute(default), init, aggr = NULL) {
assert_function(custom_check, null.ok = TRUE)
if (!is.null(custom_check)) {
custom_check_result = custom_check(1)
Expand All @@ -12,7 +12,7 @@ p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, t
} else {
"NoDefault"
}
Domain(cls = "ParamUty", grouping = "ParamUty", cargo = list(custom_check = custom_check, repr = repr), special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "list", depends_expr = substitute(depends), init = init)
Domain(cls = "ParamUty", grouping = "ParamUty", cargo = list(custom_check = custom_check, repr = repr), special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "list", depends_expr = substitute(depends), init = init, aggr = aggr)
}

#' @export
Expand Down
8 changes: 8 additions & 0 deletions R/helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,11 @@ col_to_nl = function(dt, col = 1, idcol = 2) {
names(data) = dt[[idcol]]
data
}

default_aggr = function(x) {
if (!test_numeric(x[[1]], len = 1L)) {
stopf("Provide a custom aggregator for non-numeric and non-scalar parameters.")
}
ceiling(mean(unlist(x)))
}

35 changes: 35 additions & 0 deletions R/to_tune.R
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,27 @@ to_tune = function(...) {
set_class(list(content = content, call = deparse1(call)), c(type, "TuneToken"))
}

#' @title Create an Inner Tuning Token
#' @description
#' Works just like [`to_tune()`], but marks the parameter for inner tuning.
#' See [`mlr3::Learner`] for more information.
#' @inheritParams to_tune
#' @param aggr (`function`)\cr
#' The aggregator function that determines how to aggregate a list of parameter values into one value.
#' a single parameter value. The default is to average them.
#' @export
in_tune = function(..., aggr = NULL) {
if (is.null(aggr)) {
aggr = default_aggr
} else {
test_function(aggr, nargs = 1L)
}
tt = to_tune(...)
tt$aggr = aggr
tt = set_class(tt, classes = c("InnerTuneToken", class(tt)))
return(tt)
}

#' @export
print.FullTuneToken = function(x, ...) {
catf("Tuning over:\n<entire parameter range%s>\n",
Expand All @@ -201,6 +222,12 @@ print.ObjectTuneToken = function(x, ...) {
print(x$content)
}

#' @export
print.InnerTuneToken = function(x, ...) {
cat("Inner ")
NextMethod()
}

# tunetoken_to_ps: Convert a `TuneToken` to a `ParamSet` that tunes over this.
# Needs the corresponding `Domain` to which the `TuneToken` refers, both to
# get the range (e.g. if `to_tune()` was used) and to verify that the `TuneToken`
Expand All @@ -212,6 +239,13 @@ tunetoken_to_ps = function(tt, param) {
UseMethod("tunetoken_to_ps")
}

tunetoken_to_ps.InnerTuneToken = function(tt, params) {
ps = NextMethod()
browser()
ps$tags = map(ps$tags, function(tags) union(tags, "inner_tune"))
return(ps)
}

tunetoken_to_ps.FullTuneToken = function(tt, param) {
if (!domain_is_bounded(param)) {
stopf("%s must give a range for unbounded parameter %s.", tt$call, param$id)
Expand All @@ -224,6 +258,7 @@ tunetoken_to_ps.FullTuneToken = function(tt, param) {
}
}


tunetoken_to_ps.RangeTuneToken = function(tt, param) {
if (!domain_is_number(param)) {
stopf("%s for non-numeric param must have zero or one argument.", tt$call)
Expand Down
20 changes: 15 additions & 5 deletions man/Domain.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/Sampler.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/Sampler1D.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/Sampler1DCateg.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 441d754

Please sign in to comment.