-
-
Notifications
You must be signed in to change notification settings - Fork 37
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add convert oml dataset to mlr3 #444
base: master
Are you sure you want to change the base?
Changes from 3 commits
ec4c156
a63f22d
a239b3a
af4fa54
d9f29af
b66aced
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
#' @title Convert an OpenML data set to mlr3 task.
#'
#' @description
#' Converts an \code{\link{OMLDataSet}} to a \code{\link[mlr3]{Task}}.
#'
#' @param obj [\code{\link{OMLDataSet}}]\cr
#'   The object that should be converted.
#' @param mlr.task.id [\code{character(1)}]\cr
#'   Id string for \code{\link[mlr3]{Task}} object.
#'   The strings \code{<oml.data.name>}, \code{<oml.data.id>} and \code{<oml.data.version>}
#'   will be replaced by their respective values contained in the \code{\link{OMLDataSet}} object.
#'   If \code{NULL}, the task keeps the data set name as its id.
#'   Default is \code{<oml.data.name>}.
#' @param task.type [\code{character(1)}]\cr
#'   As we only pass the data set, we need to define the task type manually.
#'   Possible are: \dQuote{Supervised Classification}, \dQuote{Supervised Regression},
#'   \dQuote{Survival Analysis}.
#'   Default is \code{NULL} which means to guess it from the target column in the
#'   data set. If that is a factor or a logical, we choose classification.
#'   If it is numeric we choose regression. In all other cases an error is thrown.
#' @param target [\code{character}]\cr
#'   The target for the classification/regression task.
#'   Default is the \code{default.target.attribute} of the \code{\link{OMLDataSetDescription}}.
#' @param ignore.flagged.attributes [\code{logical(1)}]\cr
#'   Should those features that are listed in the data set description slot \dQuote{ignore.attribute}
#'   be removed?
#'   Default is \code{TRUE}.
#' @param drop.levels [\code{logical(1)}]\cr
#'   Should empty factor levels be dropped in the data?
#'   Default is \code{TRUE}.
#' @param fix.colnames [\code{logical(1)}]\cr
#'   Should colnames of the data be fixed using \code{\link[base]{make.names}}?
#'   Default is \code{TRUE}.
#' @template arg_verbosity
#' @return [\code{\link[mlr3]{Task}}].
#' @family data set-related functions
#' @example /inst/examples/convertOMLDataSetToMlr3.R
#' @export
convertOMLDataSetToMlr3 = function(
  obj,
  mlr.task.id = "<oml.data.name>",
  task.type = NULL,
  target = obj$desc$default.target.attribute,
  ignore.flagged.attributes = TRUE,
  drop.levels = TRUE,
  fix.colnames = TRUE,
  verbosity = NULL) {

  # validate every argument up front so misuse fails with a clear message
  assertClass(obj, "OMLDataSet")
  assertSubset(target, obj$colnames.new)
  assertString(mlr.task.id, null.ok = TRUE)
  assertFlag(ignore.flagged.attributes)
  assertFlag(drop.levels)
  assertFlag(fix.colnames)
  # verbosity is only validated: nothing in the mlr3 task construction below
  # consumes it, so no config lookup is needed here
  assertInt(verbosity, null.ok = TRUE)

  data = obj$data
  desc = obj$desc

  # no task type? guess it by looking at target
  if (is.null(task.type))
    task.type = guessTaskType(data[, target])
  assertChoice(task.type, getValidTaskTypes())

  # remove ignored attributes from data
  # (scalar condition, hence the short-circuiting &&)
  if (ignore.flagged.attributes && any(!is.na(desc$ignore.attribute))) {
    keep.cols = obj$colnames.old %nin% desc$ignore.attribute
    data = data[, keep.cols, drop = FALSE]
  }

  # drop levels
  if (drop.levels)
    data = droplevels(data)

  # fix colnames using make.names
  if (fix.colnames) {
    colnames(data) = make.names(colnames(data), unique = TRUE)
    target = make.names(target, unique = TRUE)
  }

  # "Multilabel" passes the choice check above but has no mlr3 counterpart
  # here, so it ends up in the stopf() fallback
  mlr.task = switch(task.type,
    "Supervised Classification" = mlr3::TaskClassif$new(id = desc$name, backend = data, target = target),
    "Supervised Regression" = mlr3::TaskRegr$new(id = desc$name, backend = data, target = target),
    "Survival Analysis" = mlr3survival::TaskSurv$new(id = desc$name, backend = data, target = target),
    stopf("Encountered currently unsupported task type: %s", task.type)
  )

  # expand the <oml.data.*> placeholders in the requested task id
  if (!is.null(mlr.task.id))
    mlr.task$id = replaceOMLDataSetString(mlr.task.id, obj)

  return(mlr.task)
}
|
||
# @title Helper to expand data set placeholders in a string.
#
# @param string [character]
#   String that may contain <oml.data.id>, <oml.data.name> and <oml.data.version>.
# @param data.set [OMLDataSet]
#   Data set whose desc$id, desc$name and desc$version are substituted.
# @return [character] String with all placeholders replaced.
#
# NOTE(review): a reviewer flagged that this function already exists elsewhere
# in the package; confirm and reuse the existing helper instead of this copy.
replaceOMLDataSetString = function(string, data.set) {
  desc = data.set$desc
  # order matters: values substituted earlier are visible to later replacements
  substitutions = list(
    "<oml.data.id>" = desc$id,
    "<oml.data.name>" = desc$name,
    "<oml.data.version>" = desc$version
  )
  for (placeholder in names(substitutions)) {
    string = stri_replace_all_fixed(string, placeholder, substitutions[[placeholder]])
  }
  string
}
|
||
# @title Helper to guess task type from target column format.
#
# @param target [vector | data.frame]
#   Target values. A data.frame must consist of logical columns only.
# @return [character(1)]
#   "Multilabel" for a data.frame, "Supervised Classification" for factor or
#   logical vectors, "Supervised Regression" for numeric vectors.
#   Any other input raises an error.
guessTaskType = function(target) {
  # several target columns at once: only logical indicator columns are accepted
  if (is.data.frame(target)) {
    assertDataFrame(target, types = "logical")
    return("Multilabel")
  }

  if (is.factor(target) || is.logical(target))
    return("Supervised Classification")

  if (is.numeric(target))
    return("Supervised Regression")

  stopf("Cannot guess task.type from data!")
}
|
||
# @title Helper listing the task type names accepted by the converters.
#
# @return [character(4)]
getValidTaskTypes = function() {
  valid.types = c(
    "Supervised Classification",
    "Supervised Regression",
    "Survival Analysis",
    "Multilabel"
  )
  valid.types
}
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# @title Convert an OpenML estimation procedure to an instantiated mlr3 resampling.
#
# @param estim.proc [list]
#   Estimation procedure; its type and parameters fields are read.
# @param mlr.task [Task]
#   Task the resampling is instantiated on.
# @param predict [character(1)]
#   Currently unused; kept so existing callers keep working.
# @return Instantiated mlr3 resampling object.
#
# NOTE(review): a reviewer pointed out that the data splits (estim.proc$data.splits)
# need to be stored. This function still draws fresh random splits via
# $instantiate() and does not use the predefined ones.
convertOMLSplitsToMlr3 = function(estim.proc, mlr.task, predict = "both") {
  type = estim.proc$type
  params = estim.proc$parameters
  n.repeats = params[["number_repeats"]]
  n.folds = params[["number_folds"]]
  percentage = as.numeric(params[["percentage"]])

  # the parameter is the string "true"/"false" or absent; absent means FALSE
  stratified = params[["stratified_sampling"]]
  stratified = !is.null(stratified) && isTRUE(stratified == "true")

  if (type == "crossvalidation") {
    if (n.repeats == 1L)
      mlr.rdesc = mlr3::rsmp("cv", folds = n.folds, stratify = stratified)
    else
      mlr.rdesc = mlr3::rsmp("repeated_cv", reps = n.repeats, folds = n.folds, stratify = stratified)
  } else if (type == "holdout") {
    # Honour the requested holdout size instead of silently using the mlr3 default.
    # Assumes `percentage` is the test-set share in percent, so the training
    # ratio is its complement -- TODO confirm against the OpenML task definition.
    if (length(percentage) == 1L && !is.na(percentage))
      mlr.rdesc = mlr3::rsmp("holdout", ratio = 1 - percentage / 100)
    else
      mlr.rdesc = mlr3::rsmp("holdout")
  } else {
    stopf("Unsupported estimation procedure type: %s", type)
  }

  mlr.rin = mlr.rdesc$instantiate(mlr.task)
  return(mlr.rin)
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# \dontrun{ | ||
# library("mlr3") | ||
# autosOML = getOMLDataSet(data.id = 9) | ||
# autosMlr3 = convertOMLDataSetToMlr3(autosOML) | ||
# } |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
context("convertOMLDataSetToMlr3")

test_that("convertOMLDataSetToMlr3", {
  with_test_cache({
    ds = getOMLDataSet(10)

    # shared sanity checks for a converted classification task
    expect_is_mlr_task = function(mlr.task, ds) {
      expect_equal(mlr.task$task_type, "classif")
      expect_equal(mlr.task$nrow, nrow(ds$data))
      expect_equal(ds$desc$default.target.attribute, mlr.task$target_names)
    }

    # id of the task obtained by converting the current state of `ds`
    converted.id = function(...) {
      convertOMLDataSetToMlr3(ds, ...)$id
    }

    # plain conversion with all defaults
    expect_equal(convertOMLDataSetToMlr3(ds)$task_type, "classif")

    # modify the data set by hand (no more server calls) to check the
    # ignore-attributes handling: flag the first two attributes as ignored
    ds$desc$ignore.attribute = colnames(ds$data)[1:2]

    reduced.task = convertOMLDataSetToMlr3(ds, ignore.flagged.attributes = TRUE)
    expect_is_mlr_task(reduced.task, ds)
    # we removed two attributes (and the target column is not considered here)
    #expect_equal(sum(mlr.task$task.desc$n.feat), ncol(ds$data) - 3L)
    expect_equal(reduced.task$ncol, ncol(ds$data) - 2L)

    # faulty parameters must be rejected
    expect_error(convertOMLDataSetToMlr3(ds, task.type = "Nonexistent task type"), "element of")

    # setting the mlr task id, including placeholder expansion
    expect_equal(converted.id(), ds$desc$name)
    expect_equal(
      converted.id(mlr.task.id = "<oml.data.name>.<oml.data.id>"),
      sprintf("%s.%s", ds$desc$name, ds$desc$id)
    )
    expect_equal(converted.id(mlr.task.id = "test"), "test")
    expect_equal(converted.id(mlr.task.id = "<oml.data.id>"), as.character(ds$desc$id))
    expect_equal(converted.id(mlr.task.id = "<oml.data.name>"), as.character(ds$desc$name))
    expect_equal(converted.id(mlr.task.id = "<oml.data.version>"), as.character(ds$desc$version))
    # unknown placeholders are left untouched
    expect_equal(converted.id(mlr.task.id = "<oml.task.id>"), "<oml.task.id>")

    # a numeric target must yield a regression task
    ds$desc$target.features = ds$desc$default.target.attribute = "no_of_nodes_in"
    expect_equal(convertOMLDataSetToMlr3(ds)$task_type, "regr")
  })
})
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
context("convertOMLSplitsToMlr3")

test_that("convertOMLSplitsToMlr3", {
  with_test_cache({
    task = getOMLTask(59)
    mlr.task = convertOMLTaskToMlr3(task)$mlr.task

    # OpenML estimation procedure type -> expected mlr3 resampling id
    expected.ids = c(crossvalidation = "cv", holdout = "holdout")

    for (oml.type in names(expected.ids)) {
      task$input$estimation.procedure$type = oml.type
      if (oml.type == "holdout") {
        task$input$estimation.procedure$parameters$percentage = "50"
      }
      resampling = convertOMLSplitsToMlr3(task$input$estimation.procedure, mlr.task)
      expect_is(resampling, "Resampling")
      expect_equal(resampling$id, expected.ids[[oml.type]])
    }

    # an unknown estimation procedure type must be rejected
    task$input$estimation.procedure$type = "blabla"
    expect_error(
      convertOMLSplitsToMlr3(task$input$estimation.procedure, mlr.task),
      "Unsupported estimation procedure type: blabla"
    )
  })
})
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
some assertions missing (mlr.task.id, fix.colnames, verbosity)