Skip to content

Commit

Permalink
detached extra_trafo (#411)
Browse files Browse the repository at this point in the history
* detached extra_trafo

* fixes and document()

* NEWS

* fix

* crating fix

* dumb fix
  • Loading branch information
mb706 authored Aug 15, 2024
1 parent 589858f commit 5a84ccb
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 44 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# paradox 1.0.1-9000

* `ParamSetCollection$flatten()` now detaches `$extra_trafo` completely from original ParamSetCollection.

# paradox 1.0.1

* Performance improvements.
Expand Down
1 change: 0 additions & 1 deletion R/ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,6 @@ ParamSet = R6Class("ParamSet",
result$assert_values = FALSE
result$deps = deps[ids, on = "id", nomatch = NULL]
if (keep_constraint) result$constraint = self$constraint
# TODO: ParamSetCollection trafo currently drags along the entire original paramset in its environment
result$extra_trafo = self$extra_trafo
# restrict to ids already in pvals
values = self$values
Expand Down
152 changes: 110 additions & 42 deletions R/ParamSetCollection.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,6 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
setindexv(paramtbl, c("id", "cls", "grouping"))
private$.params = paramtbl

private$.children_with_trafos = which(!map_lgl(map(sets, "extra_trafo"), is.null))
private$.children_with_constraints = which(!map_lgl(map(sets, "constraint"), is.null))

private$.sets = sets
},

Expand Down Expand Up @@ -188,21 +185,28 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
private$.params = rbind(private$.params, paramtbl)
setindexv(private$.params, c("id", "cls", "grouping"))

if (!is.null(p$extra_trafo)) {
entry = if (n == "") length(private$.children_with_trafos) + 1 else n
private$.children_with_trafos[[entry]] = new_index
}

if (!is.null(p$constraint)) {
entry = if (n == "") length(private$.children_with_constraints) + 1 else n
private$.children_with_constraints[[entry]] = new_index
}

entry = if (n == "") length(private$.sets) + 1 else n
private$.sets[[n]] = p
invisible(self)
},

#' @description
#' Create a new `ParamSet` restricted to the passed IDs.
#' @param ids (`character()`).
#' @param allow_dangling_dependencies (`logical(1)`)\cr
#' Whether to allow subsets that cut across parameter dependencies.
#' Dependencies that point to dropped parameters are kept (but will be "dangling", i.e. their `"on"` will not be present).
#' @param keep_constraint (`logical(1)`)\cr
#' Whether to keep the `$constraint` function.
#' @return `ParamSet`.
subset = function(ids, allow_dangling_dependencies = FALSE, keep_constraint = TRUE) {
# need to take care of extra_trafo and constraint.
result = super$subset(ids, allow_dangling_dependencies = allow_dangling_dependencies, keep_constraint = keep_constraint)
if (keep_constraint) result$constraint = private$.get_constraint_detached(ids)
result$extra_trafo = private$.get_extra_trafo_detached(ids)
result
},

#' @description
#'
#' Set the parameter values so that internal tuning for the selected parameters is disabled.
Expand Down Expand Up @@ -261,6 +265,8 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
flatten = function() {
flatps = super$flatten()

# This function is a mistake. It should not have been written. Sorry for allowing it to be merged.

recurse_prefix = function(id_, param_set, prefix = "") {
info = get_private(param_set)$.translation[list(id_), c("owner_name", "owner_ps_index"), on = "id"]
prefix = if (info$owner_name == "") {
Expand Down Expand Up @@ -334,15 +340,16 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
#' @template field_extra_trafo
extra_trafo = function(f) {
if (!missing(f)) stop("extra_trafo is read-only in ParamSetCollection.")
if (!length(private$.children_with_trafos)) return(NULL)
private$.extra_trafo_explicit
if (!length(private$.children_with_trafos())) return(NULL)

# The reason why we don't crate a function here is that the extra_trafo of private$.sets could change.
private$.extra_trafo_explicit
},

#' @template field_constraint
constraint = function(f) {
if (!missing(f)) stop("constraint is read-only in ParamSetCollection.")
if (!length(private$.children_with_constraints)) return(NULL)
if (!length(private$.children_with_constraints())) return(NULL)
private$.constraint_explicit
},

Expand Down Expand Up @@ -376,36 +383,47 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
},
.sets = NULL,
.translation = data.table(id = character(0), original_id = character(0), owner_ps_index = integer(0), owner_name = character(0), key = "id"),
.children_with_trafos = NULL,
.children_with_constraints = NULL,
.children_with_trafos = function() {
which(!map_lgl(map(private$.sets, "extra_trafo"), is.null))
},
.children_with_constraints = function() {
which(!map_lgl(map(private$.sets, "constraint"), is.null))
},
.extra_trafo_explicit = function(x) {
changed = unlist(lapply(private$.children_with_trafos, function(set_index) {
changing_ids = private$.translation[J(set_index), id, on = "owner_ps_index"]
trafo = private$.sets[[set_index]]$extra_trafo
changing_values_in = x[names(x) %in% changing_ids]
names(changing_values_in) = private$.translation[names(changing_values_in), original_id]
# input of trafo() must not be changed after the call; otherwise the trafo would have to `force()` it in
# some circumstances.
changing_values = trafo(changing_values_in)
prefix = names(private$.sets)[[set_index]]
if (prefix != "") {
names(changing_values) = sprintf("%s.%s", prefix, names(changing_values))
}
changing_values
}), recursive = FALSE)
unchanged_ids = private$.translation[!J(private$.children_with_trafos), id, on = "owner_ps_index"]
unchanged = x[names(x) %in% unchanged_ids]
c(unchanged, changed)
children_with_trafos = private$.children_with_trafos()
sets_with_trafos = private$.sets[children_with_trafos]
translation = private$.translation
psc_extra_trafo(x, children_with_trafos, sets_with_trafos, translation)
},
# get an extra_trafo function that does not have any references to the PSC object or any of its contained sets.
# This is used for flattening.
# `ids`: subset of params to consider
.get_extra_trafo_detached = function(ids = NULL) {
translation = if (is.null(ids)) copy(private$.translation) else private$.translation[id %in% ids]
children_with_trafos = private$.children_with_trafos() # just an integer vector, no need to worry here
if (!is.null(ids)) {
children_with_trafos = intersect(children_with_trafos, translation$owner_ps_index)
}
if (!length(children_with_trafos)) return(NULL)
sets_with_trafos = lapply(private$.sets[children_with_trafos], function(x) x$clone(deep = TRUE)) # get new objects that are detached from PSC
crate(function(x) psc_extra_trafo(x, children_with_trafos, sets_with_trafos, translation), children_with_trafos, sets_with_trafos, translation, psc_extra_trafo)
},
.constraint_explicit = function(x) {
for (set_index in private$.children_with_constraints) {
constraining_ids = private$.translation[J(set_index), id, on = "owner_ps_index"]
constraint = private$.sets[[set_index]]$constraint
constraining_values = x[names(x) %in% constraining_ids]
names(constraining_values) = private$.translation[names(constraining_values), original_id]
if (!constraint(x)) return(FALSE)
children_with_constraints = private$.children_with_constraints()
sets_with_constraints = private$.sets[children_with_constraints]
translation = private$.translation
psc_constraint(x, children_with_constraints, sets_with_constraints, translation)
},
# same as with extra_trafo above
.get_constraint_detached = function(ids = NULL) {
translation = if (is.null(ids)) copy(private$.translation) else private$.translation[id %in% ids]
children_with_constraints = private$.children_with_constraints()
if (!is.null(ids)) {
children_with_constraints = intersect(children_with_constraints, translation$owner_ps_index)
}
TRUE
if (!length(children_with_constraints)) return(NULL)
sets_with_constraints = lapply(private$.sets[children_with_constraints], function(x) x$clone(deep = TRUE))
crate(function(x) psc_constraint(x, children_with_constraints, sets_with_constraints, translation), children_with_constraints, sets_with_constraints, translation, psc_constraint)
},
deep_clone = function(name, value) {
switch(name,
Expand All @@ -418,3 +436,53 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
}
)
)

# extra_trafo function for ParamSetCollection
# This function is used as extra_trafo for ParamSetCollection, in the case that any of its children has an extra_trafo.
# Arguments:
# - children_with_trafos: set-indices (i.e. index inside PSC's private$.sets, and inside `translation`) of children with extra_trafo
# - sets_with_trafos: subset of PSC's private$.sets of children with extra_trafo
# - translation: PSC's private$.translation
#
# We have this functoin outside of the ParamSetCollection class, because we anticipate that PSC can be "flattened", i.e. turned into
# a normal ParamSet. In that case, the resulting ParamSet's extra_trafo should be a function that can stand on its own, without
# referring to private$<anything>.
psc_extra_trafo = function(x, children_with_trafos, sets_with_trafos, translation) {
changed = unlist(lapply(seq_along(children_with_trafos), function(i) {
set_index = children_with_trafos[[i]]
changing_ids = translation[J(set_index), id, on = "owner_ps_index"]
trafo = sets_with_trafos[[i]]$extra_trafo
changing_values_in = x[names(x) %in% changing_ids]
names(changing_values_in) = translation[names(changing_values_in), original_id]
# input of trafo() must not be changed after the call; otherwise the trafo would have to `force()` it in
# some circumstances.
if (test_function(trafo, args = c("x", "param_set"))) {
changing_values = trafo(x = changing_values_in, param_set = sets_with_trafos[[i]])
} else {
changing_values = trafo(changing_values_in)
}
changing_values = trafo(changing_values_in)
prefix = names(sets_with_trafos)[[i]]
if (prefix != "") {
names(changing_values) = sprintf("%s.%s", prefix, names(changing_values))
}
changing_values
}), recursive = FALSE)
unchanged_ids = translation[!J(children_with_trafos), id, on = "owner_ps_index"]
unchanged = x[names(x) %in% unchanged_ids]
c(unchanged, changed)
}

psc_constraint = function(x, children_with_constraints, sets_with_constraints, translation) {
for (i in seq_along(children_with_constraints)) {
set_index = children_with_constraints[[i]]
constraining_ids = translation[J(set_index), id, on = "owner_ps_index"]
constraint = sets_with_constraints[[i]]$constraint
constraining_values = x[names(x) %in% constraining_ids]
names(constraining_values) = translation[names(constraining_values), original_id]
if (!constraint(x)) return(FALSE)
}
TRUE
}


33 changes: 32 additions & 1 deletion man/ParamSetCollection.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 5a84ccb

Please sign in to comment.