Skip to content

Commit

Permalink
keep diff small
Browse files Browse the repository at this point in the history
  • Loading branch information
mb706 committed Jun 10, 2024
1 parent 949bbe8 commit 18e4088
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 36 deletions.
1 change: 0 additions & 1 deletion R/Design.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ Design = R6Class("Design",
# set fixed param vals to their constant values
# FIXME: this might also be problematic for LHS
# do we still create an LHS like this?

imap(param_set$values, function(v, n) set(data, j = n, value = v))
self$data = data
if (param_set$has_deps) {
Expand Down
2 changes: 2 additions & 0 deletions R/Domain.R
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ Domain = function(cls, grouping,
assert_character(tags, any.missing = FALSE, unique = TRUE)
assert_function(trafo, null.ok = TRUE)


# depends may be an expression, but may also be quote() or expression()
if (length(depends_expr) == 1) {
depends_expr = eval(depends_expr, envir = parent.frame(2))
Expand All @@ -217,6 +218,7 @@ Domain = function(cls, grouping,
.tags = list(tags),
.trafo = list(trafo),
.requirements = list(parse_depends(depends_expr, parent.frame(2))),

.init_given = !missing(init),
.init = list(if (!missing(init)) init)
)
Expand Down
57 changes: 22 additions & 35 deletions R/ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ ParamSet = R6Class("ParamSet",

if (".requirements" %in% names(paramtbl)) {
requirements = paramtbl$.requirements
private$.params = paramtbl # self$add_dep needs this
private$.params = paramtbl # self$add_dep needs this
for (row in seq_len(nrow(paramtbl))) {
for (req in requirements[[row]]) {
invoke(self$add_dep, id = paramtbl$id[[row]], allow_dangling_dependencies = allow_dangling_dependencies,
Expand All @@ -107,7 +107,7 @@ ParamSet = R6Class("ParamSet",

setindexv(paramtbl, c("id", "cls", "grouping"))

private$.params = paramtbl # I am 99% sure this is not necessary, but maybe set() creates a copy when deleting too many cols?
private$.params = paramtbl # I am 99% sure this is not necessary, but maybe set() creates a copy when deleting too many cols?

if (!is.null(initvalues)) self$values = initvalues
},
Expand Down Expand Up @@ -379,9 +379,7 @@ ParamSet = R6Class("ParamSet",
private$get_tune_ps(xs)
TRUE
}, error = function(e) paste("tune token invalid:", conditionMessage(e)))
if (!isTRUE(tunecheck)) {
return(tunecheck)
}
if (!isTRUE(tunecheck)) return(tunecheck)
}

xs_internaltune = keep(xs, is, "InternalTuneToken")
Expand Down Expand Up @@ -413,9 +411,7 @@ ParamSet = R6Class("ParamSet",
## if (length(required) > 0L) {
## return(sprintf("Missing required parameters: %s", str_collapse(required)))
## }
if (!self$test_constraint(xs, assert_value = FALSE)) {
return(sprintf("Constraint not fulfilled."))
}
if (!self$test_constraint(xs, assert_value = FALSE)) return(sprintf("Constraint not fulfilled."))
return(self$check_dependencies(xs))
}

Expand All @@ -430,37 +426,29 @@ ParamSet = R6Class("ParamSet",
#' @return If successful `TRUE`, if not a string with an error message.
check_dependencies = function(xs) {
deps = self$deps
if (!nrow(deps)) {
return(TRUE)
}
if (!nrow(deps)) return(TRUE)
params = private$.params
ns = names(xs)
errors = pmap(deps[id %in% ns], function(id, on, cond) {
onval = xs[[on]]
if (inherits(xs[[id]], "TuneToken") || inherits(onval, "TuneToken")) {
return(NULL)
}
if (inherits(xs[[id]], "TuneToken") || inherits(onval, "TuneToken")) return(NULL)

# we are ONLY ok if:
# - if 'id' is there, then 'on' must be there, and cond must be true
# - if 'id' is not there. but that is skipped (deps[id %in% ns] filter)
if (on %in% ns && condition_test(cond, onval)) {
return(NULL)
}
if (on %in% ns && condition_test(cond, onval)) return(NULL)
msg = sprintf("%s: can only be set if the following condition is met '%s'.",
id, condition_as_string(cond, on))
if (is.null(onval)) {
msg = sprintf(paste("%s Instead the parameter value for '%s' is not set at all.",
"Try setting '%s' to a value that satisfies the condition"), msg, on, on)
"Try setting '%s' to a value that satisfies the condition"), msg, on, on)
} else {
msg = sprintf("%s Instead the current parameter value is: %s == %s", msg, on, as_short_string(onval))
}
msg
})
errors = unlist(errors)
if (!length(errors)) {
return(TRUE)
}
if (!length(errors)) return(TRUE)
str_collapse(errors, sep = "\n")
},

Expand Down Expand Up @@ -491,7 +479,7 @@ ParamSet = R6Class("ParamSet",
#' Name of the checked object to print in error messages.\cr
#' Defaults to the heuristic implemented in [vname][checkmate::vname].
#' @return If successful `xs` invisibly, if not an error message.
assert = function(xs, check_strict = TRUE, .var.name = vname(xs)) makeAssertion(xs, self$check(xs, check_strict = check_strict), .var.name, NULL), # nolint
assert = function(xs, check_strict = TRUE, .var.name = vname(xs)) makeAssertion(xs, self$check(xs, check_strict = check_strict), .var.name, NULL), # nolint

#' @description
#' \pkg{checkmate}-like check-function. Takes a [data.table::data.table]
Expand Down Expand Up @@ -579,9 +567,10 @@ ParamSet = R6Class("ParamSet",
paramrow[, `:=`(
.tags = list(private$.tags[id, tag, nomatch = 0]),
.trafo = private$.trafos[id, trafo],
.requirements = list(if (nrow(depstbl)) transpose_list(depstbl)), # NULL if no deps
.requirements = list(if (nrow(depstbl)) transpose_list(depstbl)), # NULL if no deps
.init_given = id %in% names(vals),
.init = unname(vals[id]))]
.init = unname(vals[id]))
]

set_class(paramrow, c(paramrow$cls, "Domain", class(paramrow)))
},
Expand All @@ -606,7 +595,7 @@ ParamSet = R6Class("ParamSet",
pids_not_there = setdiff(parents, ids)
if (length(pids_not_there) > 0L) {
stopf(paste0("Subsetting so that dependencies on params exist which would be gone: %s.",
"\nIf you still want to subset, set allow_dangling_dependencies to TRUE."), str_collapse(pids_not_there))
"\nIf you still want to subset, set allow_dangling_dependencies to TRUE."), str_collapse(pids_not_there))
}
}
result = ParamSet$new()
Expand Down Expand Up @@ -662,7 +651,7 @@ ParamSet = R6Class("ParamSet",
assert_list(values)
assert_names(names(values), subset.of = self$ids())
pars = private$get_tune_ps(values)
on = NULL # pacify static code check
on = NULL # pacify static code check
dangling_deps = pars$deps[!pars$ids(), on = "on"]
if (nrow(dangling_deps)) {
stopf("Dangling dependencies not allowed: Dependencies on %s dangling.", str_collapse(dangling_deps$on))
Expand All @@ -687,7 +676,7 @@ ParamSet = R6Class("ParamSet",
stopf("A param cannot depend on itself!")
}

if (on %in% ids) { # not necessarily true when allow_dangling_dependencies
if (on %in% ids) { # not necessarily true when allow_dangling_dependencies
feasible_on_values = map_lgl(cond$rhs, function(x) domain_test(self$get_domain(on), list(x)))
if (any(!feasible_on_values)) {
stopf("Condition has infeasible values for %s: %s", on, str_collapse(cond$rhs[!feasible_on_values]))
Expand Down Expand Up @@ -847,7 +836,7 @@ ParamSet = R6Class("ParamSet",
assert_character(v$on, any.missing = FALSE)
assert_list(v$cond, types = "Condition", any.missing = FALSE)
} else {
v = data.table(id = character(0), on = character(0), cond = list()) # make sure we have the right columns
v = data.table(id = character(0), on = character(0), cond = list()) # make sure we have the right columns
}
private$.deps = v
}
Expand Down Expand Up @@ -954,9 +943,7 @@ ParamSet = R6Class("ParamSet",

get_tune_ps = function(values) {
values = keep(values, inherits, "TuneToken")
if (!length(values)) {
return(ParamSet$new())
}
if (!length(values)) return(ParamSet$new())
params = map(names(values), function(pn) {
domain = private$.params[pn, on = "id"]
set_class(domain, c(domain$cls, "Domain", class(domain)))
Expand All @@ -965,14 +952,13 @@ ParamSet = R6Class("ParamSet",

# package-internal S3 fails if we don't call the function indirectly here
partsets = pmap(list(values, params), function(...) tunetoken_to_ps(...))

pars = ps_union(partsets) # partsets does not have names here, wihch is what we want.
pars = ps_union(partsets) # partsets does not have names here, wihch is what we want.

names(partsets) = names(values)
idmapping = map(partsets, function(x) x$ids())

# only add the dependencies that are also in the tuning PS
on = id = NULL # pacify static code check
on = id = NULL # pacify static code check
pmap(self$deps[id %in% names(idmapping) & on %in% names(partsets), c("on", "id", "cond")], function(on, id, cond) {
onpar = partsets[[on]]
if (onpar$has_trafo || !identical(onpar$ids(), on)) {
Expand Down Expand Up @@ -1042,7 +1028,7 @@ rd_info.ParamSet = function(obj, descriptions = character(), ...) { # nolint
is_default = map_lgl(params$default, inherits, "NoDefault")
is_uty = params$storage_type == "list"
set(params, i = which(is_uty & !is_default), j = "default",
value = map(cargo[!is_default & is_uty], function(x) x$repr))
value = map(cargo[!is_default & is_uty], function(x) x$repr))
set(params, i = which(is_uty), j = "storage_type", value = list("untyped"))
set(params, i = which(is_default), j = "default", value = list("-"))

Expand All @@ -1060,3 +1046,4 @@ rd_info.ParamSet = function(obj, descriptions = character(), ...) { # nolint
x = c("", knitr::kable(params, col.names = capitalize(names(params))))
paste(x, collapse = "\n")
}

0 comments on commit 18e4088

Please sign in to comment.