diff --git a/DESCRIPTION b/DESCRIPTION index 946f9bc..3329e5d 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -29,7 +29,7 @@ Suggests: MASS (>= 7.3.47), testthat VignetteBuilder: knitr -RoxygenNote: 7.3.0 +RoxygenNote: 7.3.1 URL: https://github.com/ModelOriented/randomForestExplainer, https://modeloriented.github.io/randomForestExplainer/ Config/testthat/edition: 3 Config/Needs/website: ModelOriented/DrWhyTemplate diff --git a/R/min_depth_interactions.R b/R/min_depth_interactions.R index 6a2439c..7e706be 100644 --- a/R/min_depth_interactions.R +++ b/R/min_depth_interactions.R @@ -187,7 +187,8 @@ plot_min_depth_interactions <- function(interactions_frame, k = 30, #' @examples #' forest <- randomForest::randomForest(Species ~., data = iris) #' plot_predict_interaction(forest, iris, "Petal.Width", "Sepal.Width") -#' forest_ranger <- ranger::ranger(Species ~., data = iris) +#' +#' forest <- ranger::ranger(Species ~., data = iris, probability = TRUE) #' plot_predict_interaction(forest, iris, "Petal.Width", "Sepal.Width") #' #' @export @@ -214,12 +215,8 @@ plot_predict_interaction.randomForest <- function(forest, data, variable1, varia newdata <- expand.grid(seq(min(data[[variable1]]), max(data[[variable1]]), length.out = grid), seq(min(data[[variable2]]), max(data[[variable2]]), length.out = grid)) colnames(newdata) <- c(variable1, variable2) - if(as.character(forest$call$formula)[3] == "."){ - other_vars <- setdiff(names(data), as.character(forest$call$formula)[2]) - } else { - other_vars <- labels(terms(as.formula(forest$call$formula))) - } - other_vars <- setdiff(other_vars, c(variable1, variable2)) + + other_vars <- setdiff(get_feature_names(forest), c(variable1, variable2)) n <- nrow(data) for(i in other_vars){ newdata[[i]] <- data[[i]][sample(1:n, nrow(newdata), replace = TRUE)] @@ -263,12 +260,8 @@ plot_predict_interaction.ranger <- function(forest, data, variable1, variable2, newdata <- expand.grid(seq(min(data[[variable1]]), max(data[[variable1]]), length.out = grid), seq(min(data[[variable2]]), max(data[[variable2]]), length.out = grid)) colnames(newdata) <- c(variable1, variable2) - if(as.character(forest$call[[2]])[3] == "."){ - other_vars <- setdiff(names(data), as.character(forest$call[[2]])[2]) - } else { - other_vars <- labels(terms(as.formula(forest$call[[2]]))) - } - other_vars <- setdiff(other_vars, c(variable1, variable2)) + + other_vars <- setdiff(get_feature_names(forest), c(variable1, variable2)) n <- nrow(data) for(i in other_vars){ newdata[[i]] <- data[[i]][sample(1:n, nrow(newdata), replace = TRUE)] diff --git a/R/utils.R b/R/utils.R index 71f7ff9..57b55db 100644 --- a/R/utils.R +++ b/R/utils.R @@ -74,6 +74,17 @@ ntrees <- function(x) { if (inherits(x, "randomForest")) x$ntree else x$num.trees } +# Helper function that extracts feature names from fitted random forest +# Used in plot_predict_interaction() +get_feature_names <- function(x) { + stopifnot(inherits(x, c("randomForest", "ranger"))) + if (inherits(x, "randomForest")) { + rownames(x[["importance"]]) + } else { # ranger + x[[c("forest", "independent.variable.names")]] + } +} + # Applies tree2df() to each tree and stacks the results forest2df <- function(x) { rbindlist(lapply(seq_len(ntrees(x)), function(i) tree2df(x, i))) diff --git a/man/plot_predict_interaction.Rd b/man/plot_predict_interaction.Rd index 64bf789..870a452 100644 --- a/man/plot_predict_interaction.Rd +++ b/man/plot_predict_interaction.Rd @@ -41,7 +41,8 @@ Plot the prediction of the forest for a grid of values of two numerical variable \examples{ forest <- randomForest::randomForest(Species ~., data = iris) plot_predict_interaction(forest, iris, "Petal.Width", "Sepal.Width") -forest_ranger <- ranger::ranger(Species ~., data = iris) + +forest <- ranger::ranger(Species ~., data = iris, probability = TRUE) plot_predict_interaction(forest, iris, "Petal.Width", "Sepal.Width") } diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index a5bcddb..7270bb3 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -1,5 +1,32 @@ test_that("functions work as expected without warnings", { - expect_equal(min_na(NA), NA) expect_equal(max_na(NA), NA) }) + +test_that("get_feature_names() work with '.' features", { + fit_rf <- randomForest::randomForest(Sepal.Width ~ ., data = iris) + fit_ranger <- ranger::ranger(Sepal.Width ~ ., data = iris) + + expected <- setdiff(colnames(iris), "Sepal.Width") + + expect_equal(get_feature_names(fit_rf), expected) + expect_equal(get_feature_names(fit_ranger), expected) +}) + +test_that("get_feature_names() work with explicit features", { + form <- Sepal.Width ~ Sepal.Length + Species + fit_rf <- randomForest::randomForest(form, data = iris) + fit_ranger <- ranger::ranger(form, data = iris) + + expected <- c("Sepal.Length", "Species") + + expect_equal(get_feature_names(fit_rf), expected) + expect_equal(get_feature_names(fit_ranger), expected) +}) + +test_that("get_feature_names() work with xy interface of ranger", { + xvars <- setdiff(colnames(iris), "Sepal.Width") + fit_ranger <- ranger::ranger(y = iris$Sepal.Width, x = iris[xvars]) + + expect_equal(get_feature_names(fit_ranger), xvars) +})