Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Dec 20, 2024
1 parent ddefaee commit 26c0785
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 11 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# mlr3tuning (development version)

* fix: The `$predict_type` was written to the model even when the `AutoTuner` was not trained.

# mlr3tuning 1.3.0

* feat: Save `ArchiveAsyncTuning` to a `data.table` with `ArchiveAsyncTuningFrozen`.
Expand Down
10 changes: 7 additions & 3 deletions R/AutoTuner.R
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ AutoTuner = R6Class("AutoTuner",
#' @field predict_type (`character(1)`)\cr
#' Stores the currently active predict type, e.g. `"response"`.
#' Must be an element of `$predict_types`.
#' A few learners already use the predict type during training.
#' So there is no guarantee that changing the predict type after tuning and training will have any effect or does not lead to errors.
predict_type = function(rhs) {
if (missing(rhs)) {
return(private$.predict_type)
Expand All @@ -322,10 +324,12 @@ AutoTuner = R6Class("AutoTuner",
stopf("Learner '%s' does not support predict type '%s'", self$id, rhs)
}

# Catches 'Error: Field/Binding is read-only' bug
tryCatch({
self$instance_args$learner$predict_type = rhs


if (!is.null(self$model)) {
self$model$learner$predict_type = rhs
}, error = function(cond){})
}

private$.predict_type = rhs
},
Expand Down
4 changes: 3 additions & 1 deletion man/AutoTuner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 29 additions & 7 deletions tests/testthat/test_AutoTuner.R
Original file line number Diff line number Diff line change
Expand Up @@ -231,22 +231,44 @@ test_that("store_tuning_instance, store_benchmark_result and store_models flags
})

test_that("predict_type works", {
te = trm("evals", n_evals = 4)
task = tsk("iris")
ps = TEST_MAKE_PS1(n_dim = 1)
ms = msr("classif.ce")
tuner = tnr("grid_search", resolution = 3)
task = tsk("pima")

at = AutoTuner$new(lrn("classif.rpart"), rsmp("holdout"), ms, te,
tuner = tuner, ps)
# response predict type
at = auto_tuner(
tuner = tnr("random_search"),
learner = lrn("classif.rpart"),
resampling = rsmp("holdout"),
measure = msr("classif.ce"),
terminator = trm("evals", n_evals = 4))

expect_equal(at$predict_type, "response")

at$train(task)
expect_equal(at$predict_type, "response")
expect_equal(at$model$learner$predict_type, "response")

# change predict type after training
at$predict_type = "prob"
expect_equal(at$predict_type, "prob")
expect_equal(at$model$learner$predict_type, "prob")

# prob predict type
at = auto_tuner(
tuner = tnr("random_search"),
learner = lrn("classif.rpart", predict_type = "prob"),
resampling = rsmp("holdout"),
measure = msr("classif.ce"),
terminator = trm("evals", n_evals = 4))

expect_equal(at$predict_type, "prob")

at$train(task)

expect_equal(at$predict_type, "prob")
expect_equal(at$model$learner$predict_type, "prob")

pred = at$predict(task)
expect_numeric(pred$score(msr("classif.auc")))
})

test_that("search space from TuneToken works", {
Expand Down

0 comments on commit 26c0785

Please sign in to comment.