Skip to content

Commit

Permalink
faster ids() for only one tag (#405)
Browse files Browse the repository at this point in the history
* faster ids() for only one tag

* Update ParamSet.R

* Update ParamSet.R

* some optimizations and bugfixes

* with = FALSE

* possibly overdoing it now

* bugfix

* document

* small edits, trying better psc set_values

* better inharitance for ParamSet

* ParamSetCollection could be faster now

* document

* actually make use of new better inheritance

* experimental PSC improvements

* important fix

* another bugfix

* bugfix!

* version bunp
  • Loading branch information
mb706 authored Jun 30, 2024
1 parent 33c0b9e commit 7f2520c
Show file tree
Hide file tree
Showing 15 changed files with 345 additions and 151 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Type: Package
Package: paradox
Title: Define and Work with Parameter Spaces for Complex
Algorithms
Version: 1.0.0
Version: 1.0.0-9000
Authors@R:
c(person(given = "Michel",
family = "Lang",
Expand Down Expand Up @@ -61,7 +61,7 @@ Config/testthat/edition: 3
Config/testthat/parallel: false
NeedsCompilation: no
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
VignetteBuilder: knitr
Collate:
'Condition.R'
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# paradox 1.0.0-9000

* Performance improvements.

# paradox 1.0.0

* Removed `Param` objects. `ParamSet` now uses a `data.table` internally; individual parameters are more like `Domain` objects now. `ParamSets` should be constructed using the `ps()` shorthand and `Domain` objects. This entails the following major changes:
Expand Down
61 changes: 32 additions & 29 deletions R/Domain.R
Original file line number Diff line number Diff line change
Expand Up @@ -206,25 +206,41 @@ Domain = function(cls, grouping,
}
}

# repr: what to print
# This takes the call of the shortform-constructor (such as `p_dbl()`) and inserts all the
# given values.
constructorcall = match.call(sys.function(-1), sys.call(-1), envir = parent.frame(2))
trafoexpr = constructorcall$trafo
constructorcall$trafo = NULL
constructorcall$depends = NULL
reprargs = sapply(names(constructorcall)[-1], get, pos = parent.frame(1), simplify = FALSE)
reprargs$depends = depends_expr
reprargs$trafo = trafoexpr
if (isTRUE(reprargs$logscale)) reprargs$trafo = NULL
param_repr = as.call(c(constructorcall[[1]], reprargs))

# domain is a data.table with a few classes.
# setting `id` to something preliminary so that `domain_assert()` works.
# we construct this data.table as structure(list(...)), however, since this is *much* faster.
param = structure(list(
id = deparse1(param_repr, collapse = "\n", width.cutoff = 80),
cls = cls, grouping = grouping,
cargo = list(cargo),
lower = lower, upper = upper, tolerance = tolerance, levels = list(levels),
special_vals = list(special_vals),
default = list(default),
storage_type = storage_type,
.tags = list(tags),
.trafo = list(trafo),
.requirements = list(parse_depends(depends_expr, parent.frame(2))),

param = data.table(id = "domain being constructed", cls = cls, grouping = grouping,
cargo = list(cargo),
lower = lower, upper = upper, tolerance = tolerance, levels = list(levels),
special_vals = list(special_vals),
default = list(default),
storage_type = storage_type,
.tags = list(tags),
.trafo = list(trafo),
.requirements = list(parse_depends(depends_expr, parent.frame(2))),

.init_given = !missing(init),
.init = list(if (!missing(init)) init)
.init_given = !missing(init),
.init = list(if (!missing(init)) init)
),
class = c(cls, "Domain", "data.table", "data.frame"),
repr = param_repr
)

class(param) = c(cls, "Domain", class(param))

if (!is_nodefault(default)) {
domain_assert(param, list(default))
if ("required" %in% tags) stop("A 'required' parameter can not have a 'default'.\nWhen the method behaves the same as if the parameter value were 'X' whenever the parameter is missing, then 'X' should be a 'default', but the 'required' indicates that the parameter may not be missing.")
Expand All @@ -236,21 +252,8 @@ Domain = function(cls, grouping,
if (identical(init, default)) warning("Initial value and 'default' value seem to be the same, this is usually a mistake due to a misunderstanding of the meaning of 'default'.\nWhen the method behaves the same as if the parameter value were 'X' whenever the parameter is missing, then 'X' should be a 'default' (but then there is no point in setting it as initial value). 'default' should not be used to indicate the value with which values are initialized.")
}

# repr: what to print
# This takes the call of the shortform-constructor (such as `p_dbl()`) and inserts all the
# given values.
constructorcall = match.call(sys.function(-1), sys.call(-1), envir = parent.frame(2))
trafoexpr = constructorcall$trafo
constructorcall$trafo = NULL
constructorcall$depends = NULL
reprargs = sapply(names(constructorcall)[-1], get, pos = parent.frame(1), simplify = FALSE)
reprargs$depends = depends_expr
reprargs$trafo = trafoexpr
if (isTRUE(reprargs$logscale)) reprargs$trafo = NULL
attr(param, "repr") = as.call(c(constructorcall[[1]], reprargs))
set(param, , "id", repr(attr(param, "repr"))) # some ID for consistency with ParamSet$params, only for error messages.

assert_names(names(param), identical.to = domain_names) # If this is not true then there is either a bug in Domain(), or empty_domain was not updated.
# If this is not true then there is either a bug in Domain(), or empty_domain was not updated.
if (!identical(names(param), domain_names)) stop("Unexpected names in constructed Domain object; this is probably a bug in paradox.")

param
}
Expand Down
28 changes: 16 additions & 12 deletions R/Domain_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,26 @@
#'
#' @param param (`Domain`).
#' @param values (`any`).
#' @param internal (`logical(1)`)\cr
#' When set, function arguments are not checked for plausibility and `special_values` are not respected.
#' This is an optimization for internal purposes and should not be used.
#' @return If successful `TRUE`, if not a string with the error message.
#' @keywords internal
#' @export
domain_check = function(param, values) {
if (!test_list(values, len = nrow(param))) return("values must be a list")
domain_check = function(param, values, internal = FALSE) {
if (length(values) == 0) return(TRUE) # happens when there are no params + values to check
assert_string(unique(param$grouping))
special_vals_hit = pmap_lgl(list(param$special_vals, values), has_element)
if (any(special_vals_hit)) {
# don't annoy domain_check methods with the burdon of having to filter out
# values that match special_values
Recall(param[!special_vals_hit], values[!special_vals_hit])
} else {
UseMethod("domain_check")
if (!internal) {
if (!test_list(values, len = nrow(param))) return("values must be a list")
assert_string(unique(param$grouping))

special_vals_hit = pmap_lgl(list(param$special_vals, values), has_element)
if (any(special_vals_hit)) {
# don't annoy domain_check methods with the burdon of having to filter out
# values that match special_values
return(Recall(param[!special_vals_hit], values[!special_vals_hit], internal = TRUE))
}
}
UseMethod("domain_check")
}

#' @export
Expand Down Expand Up @@ -133,8 +138,7 @@ domain_qunif = function(param, x) {
#' @keywords internal
#' @export
domain_sanitize = function(param, values) {
if (!nrow(param)) return(values)
assert_string(unique(param$grouping))
if (!length(values)) return(values)
UseMethod("domain_sanitize")
}

Expand Down
2 changes: 1 addition & 1 deletion R/ParamDbl.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_
}

#' @export
domain_check.ParamDbl = function(param, values) {
domain_check.ParamDbl = function(param, values, internal = FALSE) {
lower = param$lower - param$tolerance * pmax(1, abs(param$lower))
upper = param$upper + param$tolerance * pmax(1, abs(param$upper))
if (qtestr(values, "N1")) {
Expand Down
2 changes: 1 addition & 1 deletion R/ParamFct.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = charact
}

#' @export
domain_check.ParamFct = function(param, values) {
domain_check.ParamFct = function(param, values, internal = FALSE) {
if (qtestr(values, "S1")) {
values_str = as.character(values)
if (all(values_str %in% param$levels[[1]])) return(TRUE) # this works because we have the grouping -- all 'levels' are the same here.
Expand Down
12 changes: 6 additions & 6 deletions R/ParamInt.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_
}

#' @export
domain_check.ParamInt = function(param, values) {
domain_check.ParamInt = function(param, values, internal = FALSE) {
if (!qtestr(values, "N1()")) {
return(check_domain_vectorize(param$id, values, check_int,
more_args = list(lower = param$lower - 0.5, upper = param$upper + 0.5, # be lenient with bounds, because they would refer to the rounded values
Expand All @@ -45,11 +45,11 @@ domain_check.ParamInt = function(param, values) {

values_num = as.numeric(values)

if (all(abs(trunc(values_num + 0.5) - 0.5) <= param$tolerance)) {
values_num = round(values_num)
if (all(values_num >= param$lower) && all(values_num <= param$upper)) {
return(TRUE)
}
rounded = round(values_num)
if (all(abs(values_num - rounded) <= param$tolerance &
rounded >= param$lower &
rounded <= param$upper)) {
return(TRUE)
}

check_domain_vectorize(param$id, values_num, check_int,
Expand Down
2 changes: 1 addition & 1 deletion R/ParamLgl.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), de
}

#' @export
domain_check.ParamLgl = function(param, values) {
domain_check.ParamLgl = function(param, values, internal = FALSE) {
if (qtestr(values, "B1")) {
return(TRUE)
}
Expand Down
Loading

0 comments on commit 7f2520c

Please sign in to comment.