From 5a84ccbb3ccd0080d65ec879c9e2a46c43ca9a43 Mon Sep 17 00:00:00 2001 From: mb706 Date: Thu, 15 Aug 2024 15:51:40 +0200 Subject: [PATCH] detached extra_trafo (#411) * detached extra_trafo * fixes and document() * NEWS * fix * crating fix * dumb fix --- NEWS.md | 2 + R/ParamSet.R | 1 - R/ParamSetCollection.R | 152 +++++++++++++++++++++++++++----------- man/ParamSetCollection.Rd | 33 ++++++++- 4 files changed, 144 insertions(+), 44 deletions(-) diff --git a/NEWS.md b/NEWS.md index 54c6bc4b..c1a3b058 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # paradox 1.0.1-9000 +* `ParamSetCollection$flatten()` now detaches `$extra_trafo` completely from original ParamSetCollection. + # paradox 1.0.1 * Performance improvements. diff --git a/R/ParamSet.R b/R/ParamSet.R index 7fd9dfdc..5103b965 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -712,7 +712,6 @@ ParamSet = R6Class("ParamSet", result$assert_values = FALSE result$deps = deps[ids, on = "id", nomatch = NULL] if (keep_constraint) result$constraint = self$constraint - # TODO: ParamSetCollection trafo currently drags along the entire original paramset in its environment result$extra_trafo = self$extra_trafo # restrict to ids already in pvals values = self$values diff --git a/R/ParamSetCollection.R b/R/ParamSetCollection.R index 5ea93599..d8d2f3f6 100644 --- a/R/ParamSetCollection.R +++ b/R/ParamSetCollection.R @@ -129,9 +129,6 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, setindexv(paramtbl, c("id", "cls", "grouping")) private$.params = paramtbl - private$.children_with_trafos = which(!map_lgl(map(sets, "extra_trafo"), is.null)) - private$.children_with_constraints = which(!map_lgl(map(sets, "constraint"), is.null)) - private$.sets = sets }, @@ -188,21 +185,28 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, private$.params = rbind(private$.params, paramtbl) setindexv(private$.params, c("id", "cls", "grouping")) - if (!is.null(p$extra_trafo)) { - entry = if (n == "") 
length(private$.children_with_trafos) + 1 else n - private$.children_with_trafos[[entry]] = new_index - } - - if (!is.null(p$constraint)) { - entry = if (n == "") length(private$.children_with_constraints) + 1 else n - private$.children_with_constraints[[entry]] = new_index - } - entry = if (n == "") length(private$.sets) + 1 else n private$.sets[[n]] = p invisible(self) }, + #' @description + #' Create a new `ParamSet` restricted to the passed IDs. + #' @param ids (`character()`). + #' @param allow_dangling_dependencies (`logical(1)`)\cr + #' Whether to allow subsets that cut across parameter dependencies. + #' Dependencies that point to dropped parameters are kept (but will be "dangling", i.e. their `"on"` will not be present). + #' @param keep_constraint (`logical(1)`)\cr + #' Whether to keep the `$constraint` function. + #' @return `ParamSet`. + subset = function(ids, allow_dangling_dependencies = FALSE, keep_constraint = TRUE) { + # need to take care of extra_trafo and constraint. + result = super$subset(ids, allow_dangling_dependencies = allow_dangling_dependencies, keep_constraint = keep_constraint) + if (keep_constraint) result$constraint = private$.get_constraint_detached(ids) + result$extra_trafo = private$.get_extra_trafo_detached(ids) + result + }, + #' @description #' #' Set the parameter values so that internal tuning for the selected parameters is disabled. @@ -261,6 +265,8 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, flatten = function() { flatps = super$flatten() + # This function is a mistake. It should not have been written. Sorry for allowing it to be merged. 
+ recurse_prefix = function(id_, param_set, prefix = "") { info = get_private(param_set)$.translation[list(id_), c("owner_name", "owner_ps_index"), on = "id"] prefix = if (info$owner_name == "") { @@ -334,15 +340,16 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, #' @template field_extra_trafo extra_trafo = function(f) { if (!missing(f)) stop("extra_trafo is read-only in ParamSetCollection.") - if (!length(private$.children_with_trafos)) return(NULL) - private$.extra_trafo_explicit + if (!length(private$.children_with_trafos())) return(NULL) + # The reason why we don't crate a function here is that the extra_trafo of private$.sets could change. + private$.extra_trafo_explicit }, #' @template field_constraint constraint = function(f) { if (!missing(f)) stop("constraint is read-only in ParamSetCollection.") - if (!length(private$.children_with_constraints)) return(NULL) + if (!length(private$.children_with_constraints())) return(NULL) private$.constraint_explicit }, @@ -376,36 +383,47 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, }, .sets = NULL, .translation = data.table(id = character(0), original_id = character(0), owner_ps_index = integer(0), owner_name = character(0), key = "id"), - .children_with_trafos = NULL, - .children_with_constraints = NULL, + .children_with_trafos = function() { + which(!map_lgl(map(private$.sets, "extra_trafo"), is.null)) + }, + .children_with_constraints = function() { + which(!map_lgl(map(private$.sets, "constraint"), is.null)) + }, .extra_trafo_explicit = function(x) { - changed = unlist(lapply(private$.children_with_trafos, function(set_index) { - changing_ids = private$.translation[J(set_index), id, on = "owner_ps_index"] - trafo = private$.sets[[set_index]]$extra_trafo - changing_values_in = x[names(x) %in% changing_ids] - names(changing_values_in) = private$.translation[names(changing_values_in), original_id] - # input of trafo() must not be changed after the call; otherwise 
the trafo would have to `force()` it in - # some circumstances. - changing_values = trafo(changing_values_in) - prefix = names(private$.sets)[[set_index]] - if (prefix != "") { - names(changing_values) = sprintf("%s.%s", prefix, names(changing_values)) - } - changing_values - }), recursive = FALSE) - unchanged_ids = private$.translation[!J(private$.children_with_trafos), id, on = "owner_ps_index"] - unchanged = x[names(x) %in% unchanged_ids] - c(unchanged, changed) + children_with_trafos = private$.children_with_trafos() + sets_with_trafos = private$.sets[children_with_trafos] + translation = private$.translation + psc_extra_trafo(x, children_with_trafos, sets_with_trafos, translation) + }, + # get an extra_trafo function that does not have any references to the PSC object or any of its contained sets. + # This is used for flattening. + # `ids`: subset of params to consider + .get_extra_trafo_detached = function(ids = NULL) { + translation = if (is.null(ids)) copy(private$.translation) else private$.translation[id %in% ids] + children_with_trafos = private$.children_with_trafos() # just an integer vector, no need to worry here + if (!is.null(ids)) { + children_with_trafos = intersect(children_with_trafos, translation$owner_ps_index) + } + if (!length(children_with_trafos)) return(NULL) + sets_with_trafos = lapply(private$.sets[children_with_trafos], function(x) x$clone(deep = TRUE)) # get new objects that are detached from PSC + crate(function(x) psc_extra_trafo(x, children_with_trafos, sets_with_trafos, translation), children_with_trafos, sets_with_trafos, translation, psc_extra_trafo) }, .constraint_explicit = function(x) { - for (set_index in private$.children_with_constraints) { - constraining_ids = private$.translation[J(set_index), id, on = "owner_ps_index"] - constraint = private$.sets[[set_index]]$constraint - constraining_values = x[names(x) %in% constraining_ids] - names(constraining_values) = private$.translation[names(constraining_values), original_id] 
- if (!constraint(x)) return(FALSE) + children_with_constraints = private$.children_with_constraints() + sets_with_constraints = private$.sets[children_with_constraints] + translation = private$.translation + psc_constraint(x, children_with_constraints, sets_with_constraints, translation) + }, + # same as with extra_trafo above + .get_constraint_detached = function(ids = NULL) { + translation = if (is.null(ids)) copy(private$.translation) else private$.translation[id %in% ids] + children_with_constraints = private$.children_with_constraints() + if (!is.null(ids)) { + children_with_constraints = intersect(children_with_constraints, translation$owner_ps_index) } - TRUE + if (!length(children_with_constraints)) return(NULL) + sets_with_constraints = lapply(private$.sets[children_with_constraints], function(x) x$clone(deep = TRUE)) + crate(function(x) psc_constraint(x, children_with_constraints, sets_with_constraints, translation), children_with_constraints, sets_with_constraints, translation, psc_constraint) }, deep_clone = function(name, value) { switch(name, @@ -418,3 +436,53 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, } ) ) + +# extra_trafo function for ParamSetCollection +# This function is used as extra_trafo for ParamSetCollection, in the case that any of its children has an extra_trafo. +# Arguments: +# - children_with_trafos: set-indices (i.e. index inside PSC's private$.sets, and inside `translation`) of children with extra_trafo +# - sets_with_trafos: subset of PSC's private$.sets of children with extra_trafo +# - translation: PSC's private$.translation +# +# We have this function outside of the ParamSetCollection class, because we anticipate that PSC can be "flattened", i.e. turned into +# a normal ParamSet. In that case, the resulting ParamSet's extra_trafo should be a function that can stand on its own, without +# referring to private$. 
+psc_extra_trafo = function(x, children_with_trafos, sets_with_trafos, translation) {
+  changed = unlist(lapply(seq_along(children_with_trafos), function(i) {  # one list of transformed values per child with a trafo
+    set_index = children_with_trafos[[i]]
+    changing_ids = translation[J(set_index), id, on = "owner_ps_index"]  # ids (as named in `x`) owned by this child
+    trafo = sets_with_trafos[[i]]$extra_trafo
+    changing_values_in = x[names(x) %in% changing_ids]
+    names(changing_values_in) = translation[names(changing_values_in), original_id]  # rename to the child's own ids
+    # input of trafo() must not be changed after the call; otherwise the trafo would have to `force()` it in
+    # some circumstances.
+    if (test_function(trafo, args = c("x", "param_set"))) {  # trafo is called exactly once, in one of these two branches
+      changing_values = trafo(x = changing_values_in, param_set = sets_with_trafos[[i]])
+    } else {
+      changing_values = trafo(changing_values_in)
+    }
+    prefix = names(sets_with_trafos)[[i]]
+    if (prefix != "") {  # named children get their results re-prefixed as "<name>.<id>"
+      names(changing_values) = sprintf("%s.%s", prefix, names(changing_values))
+    }
+    changing_values
+  }), recursive = FALSE)
+  unchanged_ids = translation[!J(children_with_trafos), id, on = "owner_ps_index"]  # ids of children without a trafo
+  unchanged = x[names(x) %in% unchanged_ids]
+  c(unchanged, changed)
+}
+
+# constraint function for ParamSetCollection, analogous to psc_extra_trafo: returns TRUE iff the constraint of every child in `sets_with_constraints` accepts that child's values.
+psc_constraint = function(x, children_with_constraints, sets_with_constraints, translation) {
+  for (i in seq_along(children_with_constraints)) {
+    set_index = children_with_constraints[[i]]
+    constraining_ids = translation[J(set_index), id, on = "owner_ps_index"]  # ids (as named in `x`) owned by this child
+    constraint = sets_with_constraints[[i]]$constraint
+    constraining_values = x[names(x) %in% constraining_ids]
+    names(constraining_values) = translation[names(constraining_values), original_id]  # rename to the child's own ids
+    if (!constraint(constraining_values)) return(FALSE)  # check only this child's values, under the child's own ids
+  }
+  TRUE
+}
+
+
diff --git a/man/ParamSetCollection.Rd b/man/ParamSetCollection.Rd
index 1ce82386..6bf63db8 100644
--- a/man/ParamSetCollection.Rd
+++ b/man/ParamSetCollection.Rd
@@ -66,6 +66,7 @@ This field provides direct references to the \code{\link{ParamSet}} objects.}
 \itemize{
 \item 
\href{#method-ParamSetCollection-new}{\code{ParamSetCollection$new()}} \item \href{#method-ParamSetCollection-add}{\code{ParamSetCollection$add()}} +\item \href{#method-ParamSetCollection-subset}{\code{ParamSetCollection$subset()}} \item \href{#method-ParamSetCollection-disable_internal_tuning}{\code{ParamSetCollection$disable_internal_tuning()}} \item \href{#method-ParamSetCollection-convert_internal_search_space}{\code{ParamSetCollection$convert_internal_search_space()}} \item \href{#method-ParamSetCollection-flatten}{\code{ParamSetCollection$flatten()}} @@ -90,7 +91,6 @@ This field provides direct references to the \code{\link{ParamSet}} objects.}
  • paradox::ParamSet$qunif()
  • paradox::ParamSet$search_space()
  • paradox::ParamSet$set_values()
  • -
  • paradox::ParamSet$subset()
  • paradox::ParamSet$subspaces()
  • paradox::ParamSet$test()
  • paradox::ParamSet$test_constraint()
  • @@ -152,6 +152,37 @@ Whether to add tags of the form \code{"param_"} to each parameter with } } \if{html}{\out{
    }} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ParamSetCollection-subset}{}}} +\subsection{Method \code{subset()}}{ +Create a new \code{ParamSet} restricted to the passed IDs. +\subsection{Usage}{ +\if{html}{\out{
    }}\preformatted{ParamSetCollection$subset( + ids, + allow_dangling_dependencies = FALSE, + keep_constraint = TRUE +)}\if{html}{\out{
    }} +} + +\subsection{Arguments}{ +\if{html}{\out{
    }} +\describe{ +\item{\code{ids}}{(\code{character()}).} + +\item{\code{allow_dangling_dependencies}}{(\code{logical(1)})\cr +Whether to allow subsets that cut across parameter dependencies. +Dependencies that point to dropped parameters are kept (but will be "dangling", i.e. their \code{"on"} will not be present).} + +\item{\code{keep_constraint}}{(\code{logical(1)})\cr +Whether to keep the \verb{$constraint} function.} +} +\if{html}{\out{
    }} +} +\subsection{Returns}{ +\code{ParamSet}. +} +} +\if{html}{\out{
    }} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ParamSetCollection-disable_internal_tuning}{}}} \subsection{Method \code{disable_internal_tuning()}}{