diff --git a/DESCRIPTION b/DESCRIPTION index 621fe625..ad897c6b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -2,7 +2,7 @@ Type: Package Package: paradox Title: Define and Work with Parameter Spaces for Complex Algorithms -Version: 1.0.0 +Version: 1.0.0-9000 Authors@R: c(person(given = "Michel", family = "Lang", @@ -61,7 +61,7 @@ Config/testthat/edition: 3 Config/testthat/parallel: false NeedsCompilation: no Roxygen: list(markdown = TRUE, r6 = TRUE) -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 VignetteBuilder: knitr Collate: 'Condition.R' diff --git a/NEWS.md b/NEWS.md index 6ce734e4..16bad2ce 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,7 @@ +# paradox 1.0.0-9000 + +* Performance improvements. + # paradox 1.0.0 * Removed `Param` objects. `ParamSet` now uses a `data.table` internally; individual parameters are more like `Domain` objects now. `ParamSets` should be constructed using the `ps()` shorthand and `Domain` objects. This entails the following major changes: diff --git a/R/Domain.R b/R/Domain.R index b7dee1e8..675a393d 100644 --- a/R/Domain.R +++ b/R/Domain.R @@ -206,25 +206,41 @@ Domain = function(cls, grouping, } } + # repr: what to print + # This takes the call of the shortform-constructor (such as `p_dbl()`) and inserts all the + # given values. + constructorcall = match.call(sys.function(-1), sys.call(-1), envir = parent.frame(2)) + trafoexpr = constructorcall$trafo + constructorcall$trafo = NULL + constructorcall$depends = NULL + reprargs = sapply(names(constructorcall)[-1], get, pos = parent.frame(1), simplify = FALSE) + reprargs$depends = depends_expr + reprargs$trafo = trafoexpr + if (isTRUE(reprargs$logscale)) reprargs$trafo = NULL + param_repr = as.call(c(constructorcall[[1]], reprargs)) + # domain is a data.table with a few classes. # setting `id` to something preliminary so that `domain_assert()` works. + # we construct this data.table as structure(list(...)), however, since this is *much* faster. 
+ param = structure(list( + id = deparse1(param_repr, collapse = "\n", width.cutoff = 80), + cls = cls, grouping = grouping, + cargo = list(cargo), + lower = lower, upper = upper, tolerance = tolerance, levels = list(levels), + special_vals = list(special_vals), + default = list(default), + storage_type = storage_type, + .tags = list(tags), + .trafo = list(trafo), + .requirements = list(parse_depends(depends_expr, parent.frame(2))), - param = data.table(id = "domain being constructed", cls = cls, grouping = grouping, - cargo = list(cargo), - lower = lower, upper = upper, tolerance = tolerance, levels = list(levels), - special_vals = list(special_vals), - default = list(default), - storage_type = storage_type, - .tags = list(tags), - .trafo = list(trafo), - .requirements = list(parse_depends(depends_expr, parent.frame(2))), - - .init_given = !missing(init), - .init = list(if (!missing(init)) init) + .init_given = !missing(init), + .init = list(if (!missing(init)) init) + ), + class = c(cls, "Domain", "data.table", "data.frame"), + repr = param_repr ) - class(param) = c(cls, "Domain", class(param)) - if (!is_nodefault(default)) { domain_assert(param, list(default)) if ("required" %in% tags) stop("A 'required' parameter can not have a 'default'.\nWhen the method behaves the same as if the parameter value were 'X' whenever the parameter is missing, then 'X' should be a 'default', but the 'required' indicates that the parameter may not be missing.") @@ -236,21 +252,8 @@ Domain = function(cls, grouping, if (identical(init, default)) warning("Initial value and 'default' value seem to be the same, this is usually a mistake due to a misunderstanding of the meaning of 'default'.\nWhen the method behaves the same as if the parameter value were 'X' whenever the parameter is missing, then 'X' should be a 'default' (but then there is no point in setting it as initial value). 
'default' should not be used to indicate the value with which values are initialized.") } - # repr: what to print - # This takes the call of the shortform-constructor (such as `p_dbl()`) and inserts all the - # given values. - constructorcall = match.call(sys.function(-1), sys.call(-1), envir = parent.frame(2)) - trafoexpr = constructorcall$trafo - constructorcall$trafo = NULL - constructorcall$depends = NULL - reprargs = sapply(names(constructorcall)[-1], get, pos = parent.frame(1), simplify = FALSE) - reprargs$depends = depends_expr - reprargs$trafo = trafoexpr - if (isTRUE(reprargs$logscale)) reprargs$trafo = NULL - attr(param, "repr") = as.call(c(constructorcall[[1]], reprargs)) - set(param, , "id", repr(attr(param, "repr"))) # some ID for consistency with ParamSet$params, only for error messages. - - assert_names(names(param), identical.to = domain_names) # If this is not true then there is either a bug in Domain(), or empty_domain was not updated. + # If this is not true then there is either a bug in Domain(), or empty_domain was not updated. + if (!identical(names(param), domain_names)) stop("Unexpected names in constructed Domain object; this is probably a bug in paradox.") param } diff --git a/R/Domain_methods.R b/R/Domain_methods.R index 83d3dab6..2e149196 100644 --- a/R/Domain_methods.R +++ b/R/Domain_methods.R @@ -14,21 +14,26 @@ #' #' @param param (`Domain`). #' @param values (`any`). +#' @param internal (`logical(1)`)\cr +#' When set, function arguments are not checked for plausibility and `special_vals` are not respected. +#' This is an optimization for internal purposes and should not be used. #' @return If successful `TRUE`, if not a string with the error message.
#' @keywords internal #' @export -domain_check = function(param, values) { - if (!test_list(values, len = nrow(param))) return("values must be a list") +domain_check = function(param, values, internal = FALSE) { if (length(values) == 0) return(TRUE) # happens when there are no params + values to check - assert_string(unique(param$grouping)) - special_vals_hit = pmap_lgl(list(param$special_vals, values), has_element) - if (any(special_vals_hit)) { - # don't annoy domain_check methods with the burdon of having to filter out - # values that match special_values - Recall(param[!special_vals_hit], values[!special_vals_hit]) - } else { - UseMethod("domain_check") + if (!internal) { + if (!test_list(values, len = nrow(param))) return("values must be a list") + assert_string(unique(param$grouping)) + + special_vals_hit = pmap_lgl(list(param$special_vals, values), has_element) + if (any(special_vals_hit)) { + # don't annoy domain_check methods with the burden of having to filter out + # values that match special_values + return(Recall(param[!special_vals_hit], values[!special_vals_hit], internal = TRUE)) + } } + UseMethod("domain_check") } #' @export @@ -133,8 +138,7 @@ domain_qunif = function(param, x) { #' @keywords internal #' @export domain_sanitize = function(param, values) { - if (!nrow(param)) return(values) - assert_string(unique(param$grouping)) + if (!length(values)) return(values) UseMethod("domain_sanitize") } diff --git a/R/ParamDbl.R b/R/ParamDbl.R index 6b355bf5..c878ad1e 100644 --- a/R/ParamDbl.R +++ b/R/ParamDbl.R @@ -27,7 +27,7 @@ p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_ } #' @export -domain_check.ParamDbl = function(param, values) { +domain_check.ParamDbl = function(param, values, internal = FALSE) { lower = param$lower - param$tolerance * pmax(1, abs(param$lower)) upper = param$upper + param$tolerance * pmax(1, abs(param$upper)) if (qtestr(values, "N1")) { diff --git a/R/ParamFct.R b/R/ParamFct.R index
3adc69c7..6d21e736 100644 --- a/R/ParamFct.R +++ b/R/ParamFct.R @@ -33,7 +33,7 @@ p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = charact } #' @export -domain_check.ParamFct = function(param, values) { +domain_check.ParamFct = function(param, values, internal = FALSE) { if (qtestr(values, "S1")) { values_str = as.character(values) if (all(values_str %in% param$levels[[1]])) return(TRUE) # this works because we have the grouping -- all 'levels' are the same here. diff --git a/R/ParamInt.R b/R/ParamInt.R index c98a5ea3..b103dc65 100644 --- a/R/ParamInt.R +++ b/R/ParamInt.R @@ -35,7 +35,7 @@ p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_ } #' @export -domain_check.ParamInt = function(param, values) { +domain_check.ParamInt = function(param, values, internal = FALSE) { if (!qtestr(values, "N1()")) { return(check_domain_vectorize(param$id, values, check_int, more_args = list(lower = param$lower - 0.5, upper = param$upper + 0.5, # be lenient with bounds, because they would refer to the rounded values @@ -45,11 +45,11 @@ domain_check.ParamInt = function(param, values) { values_num = as.numeric(values) - if (all(abs(trunc(values_num + 0.5) - 0.5) <= param$tolerance)) { - values_num = round(values_num) - if (all(values_num >= param$lower) && all(values_num <= param$upper)) { - return(TRUE) - } + rounded = round(values_num) + if (all(abs(values_num - rounded) <= param$tolerance & + rounded >= param$lower & + rounded <= param$upper)) { + return(TRUE) } check_domain_vectorize(param$id, values_num, check_int, diff --git a/R/ParamLgl.R b/R/ParamLgl.R index 016d4a7d..ca3475a7 100644 --- a/R/ParamLgl.R +++ b/R/ParamLgl.R @@ -10,7 +10,7 @@ p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), de } #' @export -domain_check.ParamLgl = function(param, values) { +domain_check.ParamLgl = function(param, values, internal = FALSE) { if (qtestr(values, "B1")) { return(TRUE) } diff --git a/R/ParamSet.R 
b/R/ParamSet.R index 529338dd..b680fdba 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -76,20 +76,46 @@ ParamSet = R6Class("ParamSet", } else { paramtbl = rbindlist(params) set(paramtbl, , "id", names(params)) - if (".tags" %in% colnames(paramtbl)) { - private$.tags = paramtbl[, .(tag = unlist(.tags)), keyby = "id"] - setindexv(private$.tags, "tag") - } } + if (".tags" %in% colnames(paramtbl)) { + # fastest way to init a data.table + private$.tags = structure(list( + id = rep(paramtbl$id, lengths(paramtbl$.tags)), + tag = unlist(paramtbl$.tags) + ), class = c("data.table", "data.frame") + ) + } else { + private$.tags = structure(list( + id = character(0), tag = character(0) + ), class = c("data.table", "data.frame") + ) + } + setkeyv(private$.tags, "id") + setindexv(private$.tags, "tag") + # get initvalues here, so we can delete the relevant column. # we only assign it later, so checks can run normally. .init_given = .init = NULL # pacify checks - initvalues = if (".init" %in% names(paramtbl)) with(paramtbl[(.init_given), .(.init, id)], set_names(.init, id)) + initvalues = if (".init" %in% names(paramtbl)) structure( + paramtbl$.init[paramtbl$.init_given], + names = paramtbl$id[paramtbl$.init_given] + ) if (".trafo" %in% names(paramtbl)) { - private$.trafos = setkeyv(paramtbl[!map_lgl(.trafo, is.null), .(id, trafo = .trafo)], "id") + trafo_given = lengths(paramtbl$.trafo) != 0 + private$.trafos = structure(list( + id = paramtbl$id[trafo_given], + trafo = paramtbl$.trafo[trafo_given] + ), class = c("data.table", "data.frame") + ) + } else { + private$.trafos = structure(list( + id = character(0), trafo = list() + ), class = c("data.table", "data.frame") + ) } + setkeyv(private$.trafos, "id") if (".requirements" %in% names(paramtbl)) { requirements = paramtbl$.requirements @@ -135,6 +161,12 @@ ParamSet = R6Class("ParamSet", if (is.null(class) && is.null(tags) && is.null(any_tags)) { return(private$.params$id) } + if (length(tags) == 1 && is.null(any_tags) && 
is.null(class)) { + # very typical case: only 'tags' is given. + rv = private$.tags$id[private$.tags$tag == tags] + # keep original order + return(rv[match(private$.params$id, rv, nomatch = 0)]) + } ptbl = if (is.null(class)) private$.params else private$.params[cls %in% class, .(id)] if (is.null(tags) && is.null(any_tags)) { return(ptbl$id) @@ -359,14 +391,29 @@ ParamSet = R6Class("ParamSet", #' @param xs (named `list()`). #' @param check_strict (`logical(1)`)\cr #' Whether to check that constraints and dependencies are satisfied. + #' @param sanitize (`logical(1)`)\cr + #' Whether to move values that are slightly outside bounds to valid values. + #' These values are accepted independent of `sanitize` (depending on the + #' `tolerance` arguments of `p_dbl()` and `p_int()`). If `sanitize` + #' is `TRUE`, the additional effect is that, should checks pass, the + #' sanitized values of `xs` are added to the result as attribute `"sanitized"`. #' @return If successful `TRUE`, if not a string with an error message. - check = function(xs, check_strict = TRUE) { + check = function(xs, check_strict = TRUE, sanitize = FALSE) { assert_flag(check_strict) ok = check_list(xs, names = "unique") if (!isTRUE(ok)) { return(ok) } + trueret = TRUE + if (sanitize) { + attr(trueret, "sanitized") = xs + } + + # return early, this makes the following code easier since we don't need to consider edge cases with empty vectors.
+ if (!length(xs)) return(trueret) + + params = private$.params ns = names(xs) ids = private$.params$id @@ -376,12 +423,20 @@ ParamSet = R6Class("ParamSet", return(sprintf("Parameter '%s' not available.%s", ns[extra], did_you_mean(extra, ids))) } - if (length(xs) && test_list(xs, types = "TuneToken")) { + if (some(xs, inherits, "TuneToken")) { tunecheck = tryCatch({ private$get_tune_ps(xs) TRUE }, error = function(e) paste("tune token invalid:", conditionMessage(e))) if (!isTRUE(tunecheck)) return(tunecheck) + xs_nontune = discard(xs, inherits, "TuneToken") + + # only had TuneTokens, nothing else to check here. + if (!length(xs_nontune)) { + return(trueret) + } + } else { + xs_nontune = xs } xs_internaltune = keep(xs, is, "InternalTuneToken") @@ -393,16 +448,54 @@ ParamSet = R6Class("ParamSet", # check each parameter group's feasibility - xs_nontune = discard(xs, inherits, "TuneToken") + pidx = match(names(xs_nontune), params$id) + nonspecial = !pmap_lgl(list(params$special_vals[pidx], xs_nontune), has_element) + pidx = pidx[nonspecial] + + if (sanitize) { + bylevels = paste0(params$cls[pidx], params$grouping[pidx]) + if (length(unique(bylevels)) <= 7) { + # if we do few splits, it is faster to do the subsetting of `params` manually instead of using data.table `by`. 
+ checkresults = list() + sanitized_list = list() + for (spl in split(pidx, bylevels)) { + values = xs[params$id[spl]] + spltbl = params[spl] + spltbl = recover_domain(spltbl) + cr = domain_check(spltbl, values, internal = TRUE) + if (isTRUE(cr)) { + sanitized_list[[length(sanitized_list) + 1]] = structure(domain_sanitize(spltbl, values), names = names(values)) + } + checkresults[[length(checkresults) + 1]] = cr + } + } else { - # need to make sure we index w/ empty character instead of NULL - params = params[names(xs_nontune) %??% character(0), on = "id"] + params = params[pidx] + set(params, , "values", list(xs_nontune[nonspecial])) + + checks = params[, { + domain = recover_domain(.SD) + cr = domain_check(domain, values, internal = TRUE) + if (isTRUE(cr)) { + values = domain_sanitize(domain, values) + } + list(list(cr), list(structure(values, names = id))) + }, by = c("cls", "grouping"), + .SDcols = colnames(params)] + checkresults = checks[[3]] + sanitized_list = checks[[4]] + } + sanitized = unlist(sanitized_list, recursive = FALSE) + sanitized_all = xs + sanitized_all[names(sanitized)] = sanitized + attr(trueret, "sanitized") = sanitized_all + } else { + params = params[pidx] + set(params, , "values", list(xs_nontune[nonspecial])) - set(params, , "values", list(xs_nontune)) - pgroups = split(params, by = c("cls", "grouping")) - checkresults = map(pgroups, function(x) { - domain_check(set_class(x, c(x$cls[[1]], "Domain", class(x))), x$values) - }) + checkresults = params[, list(list(domain_check(recover_domain(.SD), values))), by = c("cls", "grouping"), + .SDcols = colnames(params)][[3]] # first two cols are 'cls' and 'grouping' + } checkresults = discard(checkresults, isTRUE) if (length(checkresults)) { return(str_collapse(checkresults, sep = "\n")) @@ -414,10 +507,10 @@ ParamSet = R6Class("ParamSet", ## return(sprintf("Missing required parameters: %s", str_collapse(required))) ## } if (!self$test_constraint(xs, assert_value = FALSE)) 
return(sprintf("Constraint not fulfilled.")) - return(self$check_dependencies(xs)) + cd = self$check_dependencies(xs) + if (!isTRUE(cd)) return(cd) } - - TRUE # we passed all checks + trueret # we passed all checks }, #' @description @@ -480,8 +573,16 @@ ParamSet = R6Class("ParamSet", #' @param .var.name (`character(1)`)\cr #' Name of the checked object to print in error messages.\cr #' Defaults to the heuristic implemented in [vname][checkmate::vname]. + #' @param sanitize (`logical(1)`)\cr + #' Whether to move values that are slightly outside bounds to valid values. + #' These values are accepted independent of `sanitize` (depending on the + #' `tolerance` arguments of `p_dbl()` and `p_int()`). If `sanitize` + #' is `TRUE`, the additional effect is that `xs` is converted to within bounds. #' @return If successful `xs` invisibly, if not an error message. - assert = function(xs, check_strict = TRUE, .var.name = vname(xs)) makeAssertion(xs, self$check(xs, check_strict = check_strict), .var.name, NULL), # nolint + assert = function(xs, check_strict = TRUE, .var.name = vname(xs), sanitize = FALSE) { + checkresult = self$check(xs, check_strict = check_strict, sanitize = sanitize) + makeAssertion(if (sanitize) attr(checkresult, "sanitized") else xs, checkresult, .var.name, NULL) # nolint + }, #' @description #' \pkg{checkmate}-like check-function.
Takes a [data.table::data.table] @@ -549,8 +650,9 @@ ParamSet = R6Class("ParamSet", params = private$.params[rownames(x), on = "id"] params$result = list() result = NULL # static checks - params[, result := list(as.list(as.data.frame(t(matrix(domain_qunif(recover_domain(.SD, .BY), x[id, ]), nrow = .N))))), - by = c("cls", "grouping")] + params[, result := list(as.list(as.data.frame(t(matrix(domain_qunif(recover_domain(.SD), x[id, ]), nrow = .N))))), + by = c("cls", "grouping"), + .SDcols = colnames(private$.params)] as.data.table(set_names(params$result, params$id)) }, @@ -741,25 +843,17 @@ ParamSet = R6Class("ParamSet", #' @template field_values values = function(xs) { if (missing(xs)) { - return(private$.values) - } - if (self$assert_values) { - self$assert(xs) + return(private$.get_values()) } if (length(xs) == 0L) { xs = named_list() - } else if (self$assert_values) { # this only makes sense when we have asserts on + } else if (self$assert_values) { + # this only makes sense when we have asserts on # convert all integer params really to storage type int, move doubles to within bounds etc. 
# solves issue #293, #317 - nontt = discard(xs, inherits, "TuneToken") - values = special_vals = NULL # static checks - sanitized = set(private$.params[names(nontt), on = "id"], , "values", list(nontt))[ - !pmap_lgl(list(special_vals, values), has_element), - .(id, values = domain_sanitize(recover_domain(.SD, .BY), values)), by = c("cls", "grouping")] - xs = insert_named(xs, with(sanitized, set_names(values, id))) + xs = self$assert(xs, sanitize = TRUE) } - # store with param ordering, return value with original ordering - private$.values = xs[match(private$.params$id, names(xs), nomatch = 0)] + private$.store_values(xs) xs }, @@ -812,7 +906,9 @@ ParamSet = R6Class("ParamSet", if (missing(f)) { private$.extra_trafo } else { - assert(check_function(f, args = c("x", "param_set"), null.ok = TRUE), check_function(f, args = "x", null.ok = TRUE)) + if (!is.null(f)) { # for speed, since asserts below are slow apparently + assert(check_function(f, args = c("x", "param_set"), null.ok = TRUE), check_function(f, args = "x", null.ok = TRUE)) + } private$.extra_trafo = f } }, @@ -903,8 +999,9 @@ ParamSet = R6Class("ParamSet", #' Named with param IDs. nlevels = function() { tmp = private$.params[, - list(id, nlevels = domain_nlevels(recover_domain(.SD, .BY))), - by = c("cls", "grouping") + list(id, nlevels = domain_nlevels(recover_domain(.SD))), + by = c("cls", "grouping"), + .SDcols = colnames(private$.params) ] with(tmp[private$.params$id, on = "id"], set_names(nlevels, id)) }, @@ -912,8 +1009,9 @@ ParamSet = R6Class("ParamSet", #' @field is_number (named `logical()`)\cr Whether parameter is [`p_dbl()`] or [`p_int()`]. Named with parameter IDs. 
is_number = function() { tmp = private$.params[, - list(id, is_number = rep(domain_is_number(recover_domain(.SD, .BY)), .N)), - by = c("cls", "grouping") + list(id, is_number = rep(domain_is_number(recover_domain(.SD)), .N)), + by = c("cls", "grouping"), + .SDcols = colnames(private$.params) ] with(tmp[private$.params$id, on = "id"], set_names(is_number, id)) }, @@ -921,8 +1019,9 @@ ParamSet = R6Class("ParamSet", #' @field is_categ (named `logical()`)\cr Whether parameter is [`p_fct()`] or [`p_lgl()`]. Named with parameter IDs. is_categ = function() { tmp = private$.params[, - list(id, is_categ = rep(domain_is_categ(recover_domain(.SD, .BY)), .N)), - by = c("cls", "grouping") + list(id, is_categ = rep(domain_is_categ(recover_domain(.SD)), .N)), + by = c("cls", "grouping"), + .SDcols = colnames(private$.params) ] with(tmp[private$.params$id, on = "id"], set_names(is_categ, id)) }, @@ -930,14 +1029,20 @@ ParamSet = R6Class("ParamSet", #' @field is_bounded (named `logical()`)\cr Whether parameters have finite bounds. Named with parameter IDs. 
is_bounded = function() { tmp = private$.params[, - list(id, is_bounded = domain_is_bounded(recover_domain(.SD, .BY))), - by = c("cls", "grouping") + list(id, is_bounded = domain_is_bounded(recover_domain(.SD))), + by = c("cls", "grouping"), + .SDcols = colnames(private$.params) ] with(tmp[private$.params$id, on = "id"], set_names(is_bounded, id)) } ), private = list( + .store_values = function(xs) { + # store with param ordering + private$.values = xs[match(private$.params$id, names(xs), nomatch = 0)] + }, + .get_values = function() private$.values, .extra_trafo = NULL, .constraint = NULL, .params = NULL, @@ -1004,10 +1109,9 @@ ParamSet = R6Class("ParamSet", ) ) -recover_domain = function(sd, by) { - domain = as.data.table(c(by, sd)) - class(domain) = c(domain$cls, "Domain", class(domain)) - domain +recover_domain = function(sd) { + class(sd) = c(sd$cls[1], "Domain", class(sd)) + sd } #' @export diff --git a/R/ParamSetCollection.R b/R/ParamSetCollection.R index 14af5099..df342b0c 100644 --- a/R/ParamSetCollection.R +++ b/R/ParamSetCollection.R @@ -43,13 +43,24 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, assert_names(names(sets)[names(sets) != ""], type = "strict") - paramtbl = rbindlist(map(seq_along(sets), function(i) { + paramtbl = rbindlist_proto(map(seq_along(sets), function(i) { s = sets[[i]] n = names(sets)[[i]] - params_child = s$params[, `:=`(original_id = id, owner_ps_index = i, owner_name = n)] - if (n != "") set(params_child, , "id", sprintf("%s.%s", n, params_child$id)) - params_child - })) + params_child = s$.__enclos_env__$private$.params + if (nrow(params_child)) { + params_child = copy(params_child) + set(params_child, , "original_id", params_child$id) + set(params_child, , "owner_ps_index", i) + set(params_child, , "owner_name", n) + if (n != "") set(params_child, , "id", sprintf("%s.%s", n, params_child$id)) + params_child + } + }), prototype = { + paramtbl = copy(empty_domain) + set(paramtbl, , "original_id", 
character(0)) + set(paramtbl, , "owner_ps_index", integer(0)) + set(paramtbl, , "owner_name", character(0)) + }) dups = duplicated(paramtbl$id) if (any(dups)) { @@ -57,18 +68,59 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, str_collapse(unique(paramtbl$id[dups]))) } - if (!nrow(paramtbl)) { - # when paramtbl is empty, use special setup to make sure information about the `.tags` column is present. - paramtbl = copy(empty_domain)[, `:=`(original_id = character(0), owner_ps_index = integer(0), owner_name = character(0))] - } - owner_name = NULL # static check - if (tag_sets) paramtbl[owner_name != "", .tags := pmap(list(.tags, owner_name), function(x, n) c(x, sprintf("set_%s", n)))] - if (tag_params) paramtbl[, .tags := pmap(list(.tags, original_id), function(x, n) c(x, sprintf("param_%s", n)))] - private$.tags = paramtbl[, .(tag = unique(unlist(.tags))), keyby = "id"] + alltagstables = map(seq_along(sets), function(i) { + s = sets[[i]] + n = names(sets)[[i]] + if (tag_sets || tag_params) { + ids = s$.__enclos_env__$private$.params$id + newids = ids + if (n != "") newids = sprintf("%s.%s", n, ids) + } + tags_child = s$.__enclos_env__$private$.tags + list( + if (nrow(tags_child)) { + tags_child = copy(tags_child) + if (n != "") set(tags_child, , "id", sprintf("%s.%s", n, tags_child$id)) + tags_child + }, + if (tag_sets && n != "" && length(newids)) { + structure(list(id = newids, tag = rep_len(sprintf("set_%s", n), length(ids))), + class = c("data.table", "data.frame")) + }, + if (tag_params && length(newids)) { + structure(list(id = newids, tag = sprintf("param_%s", ids)), + class = c("data.table", "data.frame")) + } + ) + }) - private$.trafos = setkeyv(paramtbl[!map_lgl(.trafo, is.null), .(id, trafo = .trafo)], "id") + # this may introduce duplicate tags, if param_... or set_... is present before... 
+ private$.tags = rbindlist_proto(unlist(alltagstables, recursive = FALSE, use.names = FALSE), + prototype = structure(list( + id = character(0), tag = character(0) + ), class = c("data.table", "data.frame") + ) + ) + setkeyv(private$.tags, "id") - private$.translation = paramtbl[, c("id", "original_id", "owner_ps_index", "owner_name"), with = FALSE] + private$.trafos = rbindlist_proto(map(seq_along(sets), function(i) { + s = sets[[i]] + n = names(sets)[[i]] + trafos_child = s$.__enclos_env__$private$.trafos + if (nrow(trafos_child)) { + trafos_child = copy(trafos_child) + if (n != "" && nrow(trafos_child)) set(trafos_child, , "id", sprintf("%s.%s", n, trafos_child$id)) + trafos_child + } + }), prototype = structure(list( + id = character(0), trafo = list() + ), class = c("data.table", "data.frame") + ) + ) + setkeyv(private$.trafos, "id") + + private$.translation = structure(unclass(copy(paramtbl))[c("id", "original_id", "owner_ps_index", "owner_name")], + class = c("data.table", "data.frame")) setkeyv(private$.translation, "id") setindexv(private$.translation, "original_id") @@ -279,31 +331,6 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, rbindlist(c(d_all, list(private$.deps)), use.names = TRUE) }, - #' @template field_values - values = function(xs) { - sets = private$.sets - if (!missing(xs)) { - assert_list(xs) - # make sure everything is valid and feasible. - # We do this here because we don't want the loop to be aborted early and have half an update. 
- self$assert(xs) - - # %??% character(0) in case xs is an empty unnamed list - translate = private$.translation[names(xs) %??% character(0), list(original_id, owner_ps_index), on = "id"] - set(translate, , j = "values", list(xs)) - for (xtl in split(translate, by = "owner_ps_index")) { - sets[[xtl$owner_ps_index[[1]]]]$values = set_names(xtl$values, xtl$original_id) - } - # clear the values of all sets that are not touched by xs - for (clearing in setdiff(seq_along(sets), translate$owner_ps_index)) { - sets[[clearing]]$values = named_list() - } - } - vals = unlist(map(sets, "values"), recursive = FALSE) - if (!length(vals)) return(named_list()) - vals - }, - #' @template field_extra_trafo extra_trafo = function(f) { if (!missing(f)) stop("extra_trafo is read-only in ParamSetCollection.") @@ -329,6 +356,24 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, ), private = list( + .get_values = function() { + vals = unlist(map(private$.sets, "values"), recursive = FALSE) + if (length(vals)) vals else named_list() + }, + .store_values = function(xs) { + sets = private$.sets + # %??% character(0) in case xs is an empty unnamed list + idx = match(names(xs) %??% character(0), private$.translation$id) + translate = private$.translation[idx, c("original_id", "owner_ps_index"), with = FALSE] + set(translate, , j = "values", list(xs)) + for (xtl in split(translate, f = translate$owner_ps_index)) { + sets[[xtl$owner_ps_index[[1]]]]$.__enclos_env__$private$.store_values(set_names(xtl$values, xtl$original_id)) + } + # clear the values of all sets that are not touched by xs + for (clearing in setdiff(seq_along(sets), translate$owner_ps_index)) { + sets[[clearing]]$.__enclos_env__$private$.store_values(named_list()) + } + }, .sets = NULL, .translation = data.table(id = character(0), original_id = character(0), owner_ps_index = integer(0), owner_name = character(0), key = "id"), .children_with_trafos = NULL, diff --git a/R/ParamUty.R b/R/ParamUty.R index 
5acd8fad..14bf0d72 100644 --- a/R/ParamUty.R +++ b/R/ParamUty.R @@ -21,7 +21,7 @@ p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, t } #' @export -domain_check.ParamUty = function(param, values) { +domain_check.ParamUty = function(param, values, internal = FALSE) { cargo = map(param$cargo, "custom_check") subset = !map_lgl(cargo, is.null) if (!any(subset)) return(TRUE) diff --git a/R/helper.R b/R/helper.R index 53ca3f44..84934327 100644 --- a/R/helper.R +++ b/R/helper.R @@ -53,3 +53,15 @@ col_to_nl = function(dt, col = 1, idcol = 2) { names(data) = dt[[idcol]] data } + + +# rbindlist, but +# (1) some optimization if only one table given and we know it is a data.table +# (2) if no table is given, return a prototype table. +# Input is potentially not copied, so input should be copied itself if by-reference-modification could be an issue! +rbindlist_proto = function(l, prototype) { + tbls_given = which(lengths(l) != 0) + if (length(tbls_given) == 0) return(prototype) + if (length(tbls_given) == 1) return(l[[tbls_given]]) + rbindlist(l, use.names = TRUE) +} diff --git a/man/ParamSet.Rd b/man/ParamSet.Rd index a4ad34e7..d83d9e75 100644 --- a/man/ParamSet.Rd +++ b/man/ParamSet.Rd @@ -473,7 +473,7 @@ all individual param constraints are satisfied and all dependencies are satisfie Params for which dependencies are not satisfied should not be part of \code{x}. Constraints and dependencies are not checked when \code{check_strict} is \code{FALSE}. \subsection{Usage}{ -\if{html}{\out{